From 27cfd50ddde5b5d74f8abaf417838d4f27c5b2b0 Mon Sep 17 00:00:00 2001 From: Violet Date: Mon, 21 Jul 2025 17:42:04 -0700 Subject: [PATCH] Implement `nanosleep.u32` (#421) --- comgr/src/lib.rs | 3 + ptx/lib/zluda_ptx_impl.bc | Bin 9132 -> 9812 bytes ptx/lib/zluda_ptx_impl.cpp | 25 +++ ptx/src/lib.rs | 1 + ptx/src/pass/insert_post_saturation.rs | 1 + .../instruction_mode_to_global_mode/mod.rs | 1 + ptx/src/pass/llvm/attributes.rs | 34 ++++ ptx/src/pass/{emit_llvm.rs => llvm/emit.rs} | 186 +----------------- ptx/src/pass/llvm/mod.rs | 173 ++++++++++++++++ ptx/src/pass/mod.rs | 21 +- ...eplace_instructions_with_function_calls.rs | 3 + ptx/src/test/ll/_attributes.ll | 1 + ptx/src/test/ll/nanosleep.ll | 15 ++ ptx/src/test/mod.rs | 5 +- ptx/src/test/spirv_run/mod.rs | 18 +- ptx/src/test/spirv_run/nanosleep.ptx | 13 ++ ptx_parser/src/ast.rs | 6 + ptx_parser/src/lib.rs | 7 + zluda/src/impl/module.rs | 4 +- 19 files changed, 330 insertions(+), 187 deletions(-) create mode 100644 ptx/src/pass/llvm/attributes.rs rename ptx/src/pass/{emit_llvm.rs => llvm/emit.rs} (92%) create mode 100644 ptx/src/pass/llvm/mod.rs create mode 100644 ptx/src/test/ll/_attributes.ll create mode 100644 ptx/src/test/ll/nanosleep.ll create mode 100644 ptx/src/test/spirv_run/nanosleep.ptx diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 4d4af11..776f76c 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -178,11 +178,14 @@ pub fn compile_bitcode( comgr: &Comgr, gcn_arch: &CStr, main_buffer: &[u8], + attributes_buffer: &[u8], ptx_impl: &[u8], ) -> Result, Error> { let bitcode_data_set = DataSet::new(comgr)?; let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?; bitcode_data_set.add(&main_bitcode_data)?; + let attributes_bitcode_data = Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?; + bitcode_data_set.add(&attributes_bitcode_data)?; let stdlib_bitcode_data = Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?; bitcode_data_set.add(&stdlib_bitcode_data)?; let linking_info = ActionInfo::new(comgr)?; diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 6bf56ca029a2372aa006cf2efae6debc4a2c3b3b..039e4d181b136715645d836b04f2e7b01edd6f47 100644 GIT binary patch delta 4367 zcmZ4Ee#K{k3ggF#s?+PAaU5tcS|GuwaWJrJu|`KyAA^V*cbAdOsUW7MB3(fmE|DUw z0V1w}BCQ&XhVB9)jU8PMEUp0^P68~BB2IxEi$qu)1zaRKnlxG$X|M#gG;6rHxC9Dv zD6w_~m9RJp9JLT>4VV%n!@ajdjc;*FoV+MEav+9!;k-h zgc#;9Jzzf|&meSxy-}AzdeM83Ao~Lm2788o3=<~%CN+f>d7~Ofx`B_v zxdeqSPDUvoiE{}GoByz_XRLpwpvl0%!N9=K#KTbL6DcXc&cM?E7I0)^_;Rtsk)5GK zHSsVX3*X8sijEK&1~4#6@G#}~ZDMF@%K|+Cn4a5cE1|Er^n@6rO zFgh?efDLtENI5V;OV5U#qk)Nm$$_Dsfq{WZfw?VY{X$y@1_=fM1}3m^k_Jn>`gxfe zUV=;k;}Z>vqMZ>!0$^jnA`WakD@#;nGbDI{xgY}MCUs#}mIfvUMg@=n*iFF~)fi?7 zfw&+Ha#N$XD?SGcJf;RsMvwv! zKGC3f^?rzm00S2TO9M!N1>~j-_SFmtQ6MfD8zp!gwew?TVPHmf)8B(S3^Sy_{s*Z7 zIqIZ97~_F0ARZW}9GKwMJBN#dfejQ;uyCqqiDWP^1uFm%k_JnfRNpW)oJ4k$sFkIJ z00Td=o8EAqW=LoP=>TJ+1dpddGuaq|>){@r!05bc^C|%a1{JUh1_lO@n?8n0dY^nfV#ixL`t{sL3*C6KC)dfeC?v zJ&K>#hGB*VObBG<)%y}LXjUFN_<#pVSv}a7e-jljj5}ob!GvLk7P4`FlMON*7$RUo zpum`Pa=~wggk5kUqXdty(->PB85=vmEAWE)X$hY3X56(-1W?+if=8WRsEkuD~Xy$?5_hQ_~q37=&3o zbwC*v6t*r?7&M={UgTiVWOX^h!Jw(iz`*d)fPn!dr_G`osifw=!z|i5nM+7n9VEu5qmdwx!pY{}24R@C#2LsDwCSle3t&STQ84?vB0U^xd$mY@`A(9}_l;E(@aRMWwG3SEebq*8^9?BQt2S&lhyHgd2iiijXgk^<|2SPn|1E?t)0s@H}7$*ygiA!S0M#E$s#l)puL|lGeIl$o{7+?wn!Vcx@Q9|^NKinXQ%W(c+R zNhrkeG&wF}Z%$z0VJK8!V31&7U@+nnP-6 zG6YvQh#%paadB;f)K9JxhC3T%S8@Mf+})tKhuflQPlM_s?u@4W4eFD4I+{*3XddEO z(R8*!`xDQJOE((yZa`eZz`&5o!@%Ip$iN`PJS72?Yd{`&$-)3`!1*vRfNN%OeFti% zffz;%3=ABs5Q+CJ3=E(m2vli+T34Xt1M;E*s1gA&K)!&|SquyeoD2*MFdEbX=Fnka zC}v<_0J#Q6gTxsO7#KjUDKStRgn@wpMAw5PKpLhoFfd3nFfhPqkT|Cb1H(oJ1_n6> z1Pv2E&%nR{@)?YViGv!JAah_eNL)aLfq{h)Vj+wMiF4>M)H6siGBALA2%|v~APvTh z3=E*c8%Be~K@M_bgjfutLE<3w35*O3q6`cSFd8HdQeO({yf82@z-W*-Nc{w``SqZd z5X2&o1W3bTMg|5M1_lNg4H5@w*blV;MuWsb>aQ?D5*CaGiG$R?VPs(7V_;x_(I9b< z`d?7>U>X$vAPJBLAts1HFd8PV#RM@3MuWsb23bJW!)TB=$eaMEIE;pgXF$z?(ID}9 zkU?co4KNxc0WzqM2@*9h8YaG+2@=#W8YB)%LwlGQ7!(;87+^F=9AwU2CI$vjg#e>L z;vjoJGBMOMfVvYOnHaz{NCKomkQovsFd8Hd(xAf3z@Wy!zyPB`;vn^2%nS^m>KR6Z z#6b>9WQI5hMuWsb>YJDu7(fl$Cgyqw117PGnSnu>fq?-=!^F=*HNa??_*bZU7!4Db zU}0cTXJBA}(J*m)76t|-1_lNg{U4P7KrTvTfn+ln4buSXM}Tqxj0TB=d^U-NfkBpm zfdNK?#6jk)XMqGYj0TB=)E|bHDKHu)&Uu9e5|>vdN6KlNf|{$K%m|`EO1M}U80?@q z9Y%x18(0{??R8M`6ft?BoV=tEE2O{yl}oXR($8LM@=ZB8X9Pp8lj*LY_ie`oLbAjc5b&0pnASQPRS^YV*xQd0}!OO1_ka>{b`5_411lk@b7 M^}xcDC6p%t04WmzfB*mh delta 3764 zcmccOv&Mab3S-_x)#>#b%pHOaLK4L+2?1Qq3Ia)e3?gdWT}Be8f|!^zu2}hH{Q3Dayz$rlz+zVWo8G{)Z0wow2ycii6q$fQ1|NsC0&DM-;EPReK+dv%o z22%zf#*WEBtPS=2-$5eG1zZjC4D}2i4)QG97;-vxfdu6nf*JlW{9txqdd2Xf`G8mh zZ!eN!p$qJdx;)Yo^cjQ_7+)kXGqEvEVBikORXEhZ@FRg)kqu;oJ;NooR34@VktPPN zOAmO>G*T28mI$x}Okm63a8RT9|9|2C|3R)X}0!}AM9^noniH;Hsla@4idMGXuA3~E1XaT@fzerO?n(^>1`!4h29QfR zKyKR08^(BG3rr0|%7F>5O6G8JFt8!JX=`&NgMlerk)*+rEcrD|4Vm>IW5B{E8Wcl0 zEky+wc#s1sjbk%ILKav7h%id<*y?tJm0>b7hzG);a0)CkP+*WlcGFRzX^aP0;AW>B zn4mS+h?ij*D562S7(j8S6e{ab&)@;(g9u53CA(G{++awU1>%9RQG!QnxClE-0}}%y zIMNwGVQ`J9nU`T1Cs+k2`k327-k26TFlaCcFn|I}fI-q=$uE854-5&5V2T)w5S!r6I{($j-3vx(!u`?r-#e(!3C3qZdcl^$HAOu;RmN+vPGt&W(hd_!z zDLK@dO^}V{04tc!zyMAV^*mYz3``902sy#f$;;HbmxqC2lf&d^JTB}Uj};rZc#<|7 z^71n?8clW-Q09;j&|%bIWMG(_BhVpk%+lDyqNX8gX13dGy2&OZ^Id^W3=9mDy#+to zR5LIz2(x(VfbtT^)|M#@nonIXaxiGJx}4x(&;+IOj|L13k_-$CO715n1oE-SfgE=} znUP`g7XcGjVU|?BFpl=kP7ZTKogS?b5J`7-$dPC}dc(k=XQQG9gRvk}6Z=wwq=Xa& z4z?aSiGv-Ciy9Rq8XQIbGzRua2t;!W&5@-nsfSMu&Ru8chlrmd-7^cW;K6Uq)JW*7Zk#X{RQBTH} z$v;Ka9R=AGIoM`M3LMp8Gz@9R}OGE2nLvgJiyQ{QPH??xr|j(Lc?T#adob_jtOVYEEp%S3QcYhml206 zXJB9`7pZVuI9&p)bhEfR%S^|B+{tp{VtCEummu2AVhJM6JTIY!?9TrZvN+voEU5-@ zCn3N0OQM*$UlO;OpC#4mL4FqyVg|)K_oTxb9gZ9fW}p~gfTVIzYzVVBinXLUW(c+P zNhrkdG&wf0Hz%<0Fcc~18q{ucR|xHI(CFrw zA#|!iYd6o1q_a?WFr@M@Fjz7&FbFYE*#RoNK#rX}nNdc%UY~)1;Q$K*xPG0en+eFi0>kFu-V#I7t0*Mg|5x1_lNg4H5^bzY0|kqe0^JAPsM!5-=Jh z0TO3rf*1s&LE<2Tgqa`~!)TB=NWC^x97e;$ouTHyXplI_oIoaqdaw@yp&CIFAcHcQ zKykys0Ha~zO-zuWhS4B#P+U%DVqj2YU|@jJAaRg6JD3<4Kn)uh4H5@ge1-|s3T9w9 z!&DDpfFwW~-Y`Mp9!7)2K^p!uF)*kxFfhRB|NsAk)T=T>GBAt=sRud8h8f}@7!48! zsgGr5U;s6%VD$fbs6nmF3=GN)3=A+DreOtC1B`}=Ux2EI(J=AP%nS_b3=9k~8YV8q z!oZ*e$~jOPCLU #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME +#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME +#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__ extern "C" { @@ -220,6 +222,29 @@ extern "C" SHFL_SYNC_IMPL(bfly, self ^ delta, >); SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >); + DECLARE_ATTR(uint32_t, CLOCK_RATE); + void FUNC(nanosleep_u32)(uint32_t nanoseconds) { + // clock_rate is in kHz + uint64_t cycles_per_ns = ATTR(CLOCK_RATE) / 1000000; + uint64_t cycles = nanoseconds * cycles_per_ns; + // Avoid small sleep values resulting in s_sleep 0 + cycles += 63; + // s_sleep N sleeps for 64 * N cycles + uint64_t sleep_amount = cycles / 64; + + // The argument to s_sleep must be a constant + for (size_t i = 0; i < sleep_amount >> 4; i++) + __builtin_amdgcn_s_sleep(16); + if (sleep_amount & 8U) + __builtin_amdgcn_s_sleep(8); + if (sleep_amount & 4U) + __builtin_amdgcn_s_sleep(4); + if (sleep_amount & 2U) + __builtin_amdgcn_s_sleep(2); + if (sleep_amount & 1U) + __builtin_amdgcn_s_sleep(1); + } + void FUNC(__assertfail)(uint64_t message, uint64_t file, uint32_t line, diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index da972f6..7aa9ee1 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -3,4 +3,5 @@ pub(crate) mod pass; mod test; pub use pass::to_llvm_module; +pub use pass::Attributes; diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 4ad5339..f2fead7 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -152,6 +152,7 @@ fn run_instruction<'input>( .. } | ast::Instruction::Mul24 { .. } + | ast::Instruction::Nanosleep { .. } | ast::Instruction::Neg { .. } | ast::Instruction::Not { .. } | ast::Instruction::Or { .. } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 91bf1a8..5692365 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1809,6 +1809,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Cvta { .. } | ast::Instruction::Atom { .. } | ast::Instruction::Mul24 { .. } + | ast::Instruction::Nanosleep { .. } | ast::Instruction::AtomCas { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), diff --git a/ptx/src/pass/llvm/attributes.rs b/ptx/src/pass/llvm/attributes.rs new file mode 100644 index 0000000..4479ece --- /dev/null +++ b/ptx/src/pass/llvm/attributes.rs @@ -0,0 +1,34 @@ +use std::ffi::CStr; + +use super::*; +use super::super::*; +use llvm_zluda::{core::*}; + +pub(crate) fn run(context: &Context, attributes: Attributes) -> Result { + let module = llvm::Module::new(context, LLVM_UNNAMED); + + emit_attribute(context, &module, "clock_rate", attributes.clock_rate)?; + + if let Err(err) = module.verify() { + panic!("{:?}", err); + } + + Ok(module) +} + +fn emit_attribute(context: &Context, module: &llvm::Module, name: &str, attribute: u32) -> Result<(), TranslateError> { + let name = format!("{}attribute_{}\0", ZLUDA_PTX_PREFIX, name).to_ascii_uppercase(); + let name = unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) }; + let attribute_type = get_scalar_type(context.get(), ast::ScalarType::U32); + let global = unsafe { + LLVMAddGlobalInAddressSpace( + module.get(), + attribute_type, + name.as_ptr(), + get_state_space(ast::StateSpace::Global)?, + ) + }; + unsafe { LLVMSetInitializer(global, LLVMConstInt(attribute_type, attribute as u64, 0)) }; + unsafe { LLVMSetGlobalConstant(global, 1) }; + Ok(()) +} \ No newline at end of file diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/llvm/emit.rs similarity index 92% rename from ptx/src/pass/emit_llvm.rs rename to ptx/src/pass/llvm/emit.rs index b888202..f7aafaa 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -27,98 +27,15 @@ use std::array::TryFromSliceError; use std::convert::TryInto; use std::ffi::{CStr, NulError}; -use std::ops::Deref; use std::{i8, ptr}; use super::*; -use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; -use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; +use crate::pass::*; use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; use ptx_parser::Mul24Control; -const LLVM_UNNAMED: &CStr = c""; -// https://llvm.org/docs/AMDGPUUsage.html#address-spaces -const GENERIC_ADDRESS_SPACE: u32 = 0; -const GLOBAL_ADDRESS_SPACE: u32 = 1; -const SHARED_ADDRESS_SPACE: u32 = 3; -const CONSTANT_ADDRESS_SPACE: u32 = 4; -const PRIVATE_ADDRESS_SPACE: u32 = 5; - -struct Context(LLVMContextRef); - -impl Context { - fn new() -> Self { - Self(unsafe { LLVMContextCreate() }) - } - - fn get(&self) -> LLVMContextRef { - self.0 - } -} - -impl Drop for Context { - fn drop(&mut self) { - unsafe { - LLVMContextDispose(self.0); - } - } -} - -pub struct Module(LLVMModuleRef, Context); - -impl Module { - fn new(ctx: Context, name: &CStr) -> Self { - Self( - unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }, - ctx, - ) - } - - fn get(&self) -> LLVMModuleRef { - self.0 - } - - fn context(&self) -> &Context { - &self.1 - } - - fn verify(&self) -> Result<(), Message> { - let mut err = ptr::null_mut(); - let error = unsafe { - LLVMVerifyModule( - self.get(), - LLVMVerifierFailureAction::LLVMReturnStatusAction, - &mut err, - ) - }; - if error == 1 && err != ptr::null_mut() { - Err(Message(unsafe { CStr::from_ptr(err) })) - } else { - Ok(()) - } - } - - pub fn write_bitcode_to_memory(&self) -> MemoryBuffer { - let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) }; - MemoryBuffer(memory_buffer) - } - - pub fn print_module_to_string(&self) -> Message { - let asm = unsafe { LLVMPrintModuleToString(self.get()) }; - Message(unsafe { CStr::from_ptr(asm) }) - } -} - -impl Drop for Module { - fn drop(&mut self) { - unsafe { - LLVMDisposeModule(self.0); - } - } -} - struct Builder(LLVMBuilderRef); impl Builder { @@ -143,55 +60,13 @@ impl Drop for Builder { } } -pub struct Message(&'static CStr); - -impl Drop for Message { - fn drop(&mut self) { - unsafe { - LLVMDisposeMessage(self.0.as_ptr().cast_mut()); - } - } -} - -impl std::fmt::Debug for Message { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Debug::fmt(&self.0, f) - } -} - -impl Message { - pub fn to_str(&self) -> &str { - self.0.to_str().unwrap().trim() - } -} - -pub struct MemoryBuffer(LLVMMemoryBufferRef); - -impl Drop for MemoryBuffer { - fn drop(&mut self) { - unsafe { - LLVMDisposeMemoryBuffer(self.0); - } - } -} - -impl Deref for MemoryBuffer { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - let data = unsafe { LLVMGetBufferStart(self.0) }; - let len = unsafe { LLVMGetBufferSize(self.0) }; - unsafe { std::slice::from_raw_parts(data.cast(), len) } - } -} - -pub(super) fn run<'input>( +pub(crate) fn run<'input>( + context: &Context, id_defs: GlobalStringIdentResolver2<'input>, directives: Vec, SpirvWord>>, -) -> Result { - let context = Context::new(); - let module = Module::new(context, LLVM_UNNAMED); - let mut emit_ctx = ModuleEmitContext::new(&module, &id_defs); +) -> Result { + let module = llvm::Module::new(context, LLVM_UNNAMED); + let mut emit_ctx = ModuleEmitContext::new(context, &module, &id_defs); for directive in directives { match directive { Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, @@ -213,8 +88,7 @@ struct ModuleEmitContext<'a, 'input> { } impl<'a, 'input> ModuleEmitContext<'a, 'input> { - fn new(module: &Module, id_defs: &'a GlobalStringIdentResolver2<'input>) -> Self { - let context = module.context(); + fn new(context: &Context, module: &llvm::Module, id_defs: &'a GlobalStringIdentResolver2<'input>) -> Self { ModuleEmitContext { context: context.get(), module: module.get(), @@ -642,7 +516,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::BarRed { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Activemask { .. } - | ast::Instruction::ShflSync { .. } => return Err(error_unreachable()), + | ast::Instruction::ShflSync { .. } + | ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()), } } @@ -2729,33 +2604,6 @@ fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result LLVMTypeRef { - match type_ { - ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) }, - ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { - LLVMInt8TypeInContext(context) - }, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { - LLVMInt16TypeInContext(context) - }, - ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { - LLVMInt32TypeInContext(context) - }, - ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe { - LLVMInt64TypeInContext(context) - }, - ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) }, - ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) }, - ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, - ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, - ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), - ast::ScalarType::F16x2 => todo!(), - ast::ScalarType::BF16x2 => todo!(), - } -} - fn get_array_type<'a>( context: LLVMContextRef, elem_type: &'a ast::Type, @@ -2808,22 +2656,6 @@ fn get_function_type<'a>( }) } -fn get_state_space(space: ast::StateSpace) -> Result { - match space { - ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), - ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), - ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), - ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), - ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), - ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), - ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), - ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), - } -} - struct ResolveIdent { words: HashMap, values: HashMap, diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs new file mode 100644 index 0000000..daaa91f --- /dev/null +++ b/ptx/src/pass/llvm/mod.rs @@ -0,0 +1,173 @@ +pub(super) mod emit; +pub(super) mod attributes; + +use std::ffi::CStr; +use std::ops::Deref; +use std::ptr; + +use crate::pass::*; +use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; +use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; +use llvm_zluda::core::*; +use llvm_zluda::prelude::*; + +const LLVM_UNNAMED: &CStr = c""; + +// https://llvm.org/docs/AMDGPUUsage.html#address-spaces +const GENERIC_ADDRESS_SPACE: u32 = 0; +const GLOBAL_ADDRESS_SPACE: u32 = 1; +const SHARED_ADDRESS_SPACE: u32 = 3; +const CONSTANT_ADDRESS_SPACE: u32 = 4; +const PRIVATE_ADDRESS_SPACE: u32 = 5; + +pub(super) struct Context(LLVMContextRef); + +impl Context { + pub fn new() -> Self { + Self(unsafe { LLVMContextCreate() }) + } + + fn get(&self) -> LLVMContextRef { + self.0 + } +} + +impl Drop for Context { + fn drop(&mut self) { + unsafe { + LLVMContextDispose(self.0); + } + } +} + +pub struct Module(LLVMModuleRef); + +impl Module { + fn new(ctx: &Context, name: &CStr) -> Self { + Self( + unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }, + ) + } + + fn get(&self) -> LLVMModuleRef { + self.0 + } + + fn verify(&self) -> Result<(), Message> { + let mut err = ptr::null_mut(); + let error = unsafe { + LLVMVerifyModule( + self.get(), + LLVMVerifierFailureAction::LLVMReturnStatusAction, + &mut err, + ) + }; + if error == 1 && err != ptr::null_mut() { + Err(Message(unsafe { CStr::from_ptr(err) })) + } else { + Ok(()) + } + } + + pub fn write_bitcode_to_memory(&self) -> MemoryBuffer { + let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) }; + MemoryBuffer(memory_buffer) + } + + pub fn print_module_to_string(&self) -> Message { + let asm = unsafe { LLVMPrintModuleToString(self.get()) }; + Message(unsafe { CStr::from_ptr(asm) }) + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { + LLVMDisposeModule(self.0); + } + } +} + +pub struct Message(&'static CStr); + +impl Drop for Message { + fn drop(&mut self) { + unsafe { + LLVMDisposeMessage(self.0.as_ptr().cast_mut()); + } + } +} + +impl std::fmt::Debug for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&self.0, f) + } +} + +impl Message { + pub fn to_str(&self) -> &str { + self.0.to_str().unwrap().trim() + } +} +pub struct MemoryBuffer(LLVMMemoryBufferRef); + +impl Drop for MemoryBuffer { + fn drop(&mut self) { + unsafe { + LLVMDisposeMemoryBuffer(self.0); + } + } +} + +impl Deref for MemoryBuffer { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + let data = unsafe { LLVMGetBufferStart(self.0) }; + let len = unsafe { LLVMGetBufferSize(self.0) }; + unsafe { std::slice::from_raw_parts(data.cast(), len) } + } +} + +fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef { + match type_ { + ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) }, + ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { + LLVMInt8TypeInContext(context) + }, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { + LLVMInt16TypeInContext(context) + }, + ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { + LLVMInt32TypeInContext(context) + }, + ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe { + LLVMInt64TypeInContext(context) + }, + ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) }, + ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) }, + ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, + ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, + ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, + ast::ScalarType::U16x2 => todo!(), + ast::ScalarType::S16x2 => todo!(), + ast::ScalarType::F16x2 => todo!(), + ast::ScalarType::BF16x2 => todo!(), + } +} + +fn get_state_space(space: ast::StateSpace) -> Result { + match space { + ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), + ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), + ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), + ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), + ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), + ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), + ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), + ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), + ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 385f759..ace910e 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -12,7 +12,6 @@ use strum::IntoEnumIterator; use strum_macros::EnumIter; mod deparamize_functions; -pub(crate) mod emit_llvm; mod expand_operands; mod fix_special_registers2; mod hoist_globals; @@ -20,6 +19,7 @@ mod insert_explicit_load_store; mod insert_implicit_conversions2; mod insert_post_saturation; mod instruction_mode_to_global_mode; +mod llvm; mod normalize_basic_blocks; mod normalize_identifiers2; mod normalize_predicates2; @@ -46,7 +46,13 @@ quick_error! { } } -pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { +/// GPU attributes needed at compile time. +pub struct Attributes { + /// Clock frequency in kHz. + pub clock_rate: u32, +} + +pub fn to_llvm_module<'input>(ast: ast::Module<'input>, attributes: Attributes) -> Result { let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; @@ -65,16 +71,23 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result, + _context: llvm::Context, } impl Module { diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 0480e5f..6420c79 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -137,6 +137,9 @@ fn run_instruction<'input>( ptx_parser::Instruction::ShflSync { data, arguments }, )? } + i @ ptx_parser::Instruction::Nanosleep { .. } => { + to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)? + } i => i, }) } diff --git a/ptx/src/test/ll/_attributes.ll b/ptx/src/test/ll/_attributes.ll new file mode 100644 index 0000000..bd06a06 --- /dev/null +++ b/ptx/src/test/ll/_attributes.ll @@ -0,0 +1 @@ +@__ZLUDA_PTX_IMPL_ATTRIBUTE_CLOCK_RATE = addrspace(1) constant i32 2124000 \ No newline at end of file diff --git a/ptx/src/test/ll/nanosleep.ll b/ptx/src/test/ll/nanosleep.ll new file mode 100644 index 0000000..d567302 --- /dev/null +++ b/ptx/src/test/ll/nanosleep.ll @@ -0,0 +1,15 @@ +declare void @__zluda_ptx_impl_nanosleep_u32(i32) #0 + +define amdgpu_kernel void @nanosleep(ptr addrspace(4) byref(i64) %"28", ptr addrspace(4) byref(i64) %"29") #1 { + br label %1 + +1: ; preds = %0 + br label %"27" + +"27": ; preds = %1 + call void @__zluda_ptx_impl_nanosleep_u32(i32 1) + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index e9943f4..ba45ec0 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -1,4 +1,4 @@ -use crate::pass::TranslateError; +use crate::pass::{self, TranslateError}; use ptx_parser as ast; mod spirv_run; @@ -9,7 +9,8 @@ fn parse_and_assert(ptx_text: &str) { fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> { let ast = ast::parse_module_checked(ptx_text).unwrap(); - crate::to_llvm_module(ast)?; + let attributes = pass::Attributes { clock_rate: 2124000 }; + crate::to_llvm_module(ast, attributes)?; Ok(()) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index d861593..df85a24 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -297,6 +297,8 @@ test_ptx!( test_ptx!(multiple_return, [5u32], [6u32, 123u32]); test_ptx!(warp_sz, [0u8], [32u8]); +test_ptx!(nanosleep, [0u64], [0u64]); + test_ptx!(assertfail); // TODO: not yet supported //test_ptx!(func_ptr); @@ -375,7 +377,7 @@ fn test_hip_assert< block_dim_x: u32, ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module(ast).unwrap(); + let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap(); let name = CString::new(name)?; let result = run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x).map_err(|err| DisplayError { err })?; @@ -389,9 +391,19 @@ fn test_llvm_assert( expected_ll: &str, ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module(ast).unwrap(); + let llvm_ir = pass::to_llvm_module(ast, pass::Attributes { clock_rate: 2124000 }).unwrap(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); let actual_ll = actual_ll.to_str(); + compare_llvm(name, actual_ll, expected_ll); + + let expected_attributes_ll = read_test_file!(concat!("../ll/_attributes.ll")); + let actual_attributes_ll = llvm_ir.attributes_ir.print_module_to_string(); + let actual_attributes_ll = actual_attributes_ll.to_str(); + compare_llvm("_attributes", actual_attributes_ll, &expected_attributes_ll); + Ok(()) +} + +fn compare_llvm(name: &str, actual_ll: &str, expected_ll: &str) { if actual_ll != expected_ll { let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR"); if let Ok(output_dir) = output_dir { @@ -404,7 +416,6 @@ fn test_llvm_assert( let comparison = pretty_assertions::StrComparison::new(&expected_ll, &actual_ll); panic!("assertion failed: `(left == right)`\n\n{}", comparison); } - Ok(()) } fn test_cuda_assert< @@ -567,6 +578,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def &comgr, unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, &*module.llvm_ir.write_bitcode_to_memory(), + &*module.attributes_ir.write_bitcode_to_memory(), module.linked_bitcode(), ) .unwrap(); diff --git a/ptx/src/test/spirv_run/nanosleep.ptx b/ptx/src/test/spirv_run/nanosleep.ptx new file mode 100644 index 0000000..c96d02a --- /dev/null +++ b/ptx/src/test/spirv_run/nanosleep.ptx @@ -0,0 +1,13 @@ +.version 6.5 +.target sm_70 +.address_size 64 + +.visible .entry nanosleep( + .param .u64 input, + .param .u64 output +) +{ + // TODO: check if there's some way of testing that it actually sleeps + nanosleep.u32 1; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 63155f4..53c7fa5 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -327,6 +327,12 @@ ptx_parser_macros::generate_instruction_type!( src2: T, } }, + Nanosleep { + type: Type::Scalar(ScalarType::U32), + arguments: { + src: T + } + }, Neg { type: Type::Scalar(data.type_), data: TypeFtz, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index bed2c55..e2c87fc 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3502,6 +3502,13 @@ derive_parser!( } } .mode: ShuffleMode = { .up, .down, .bfly, .idx }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-nanosleep + nanosleep.u32 t => { + Instruction::Nanosleep { + arguments: NanosleepArgs { src: t } + } + } ); #[cfg(test)] diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 6601628..da37a22 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -64,15 +64,17 @@ pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result