From ae42eac925201578d74ed3f49e380d07f6f7d0ed Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 14 Oct 2024 19:09:47 +0200 Subject: [PATCH] Add shifts, cvt, rsqrt, sqrt, rcp, more sregs --- ptx/lib/zluda_ptx_impl.bc | Bin 4136 -> 4624 bytes ptx/lib/zluda_ptx_impl.cpp | 18 ++ ptx/src/pass/emit_llvm.rs | 218 ++++++++++++++++++++++- ptx/src/test/spirv_run/mod.rs | 1 - ptx/src/test/spirv_run/shl_link_hack.ptx | 30 ---- 5 files changed, 230 insertions(+), 37 deletions(-) delete mode 100644 ptx/src/test/spirv_run/shl_link_hack.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 9533233ef8adad3c8cba1d5339c9bfaae496681d..6651430dc0372d68f8e8118c1b2f5ed333d7d0a9 100644 GIT binary patch delta 1349 zcmaKqUr19?9LIlm_s`Ka=gt4Qh-+y7ll^NhdRYikVPIiISnbYj>2jj3Sth}55>XJi zx=HpBjEI(DVE7UQF~f%lq!mOx)?-0E+T;42d4CUybjG>gb3ecD?{|K8=RDK0t=U!~ z@v?13%>2+#=D*sxwbrn5DRHIrsJk2q6eT1;P*{3)$z5ei5DuWYWc(y?=UI zLZIT$hExTY`$_-p6)+Oyfv-l(=QBMODo7w1oJs}7F}FJ91HGBQmPbH8k{Ne3*d!1@ z&^QtngwE%yg<(^H)B2Pcyg7bv<3|a?7woT(bc1o`$B{z+F#-_tad!Q?MCeV24I8q3P{$8qfF zWz*tamaml#D=l3Pdz-Wb52qZqB(jZSxBnBnT8=GF+RxjLZo+)8 zQ{8?8=H*?vK0{iP2aU5v#RvJBe$H_Gl)o0kM)izRv3ba-dB1iw;`rrvF|S6vKOv#= z|NWJ@rOM^0!&Q)90vvDfifl&xfLep<;t&VL`MS=GR~^c1!IcLm*Afo372p`IYjKit zUWxo1GJa)9$$teIKW$R--$BNACMAC!8SfS~Q(jQQYlK^Il2XDKWDPh;$^Q#k6HZd{ z%lO$&+f7n0+p17&!5|(KQ8v|T1OrK|CFYV9 zv|d-%LO_rRQtClSiHOIb9uz4=ym%0)Codu%M7+f}n|&AG@aCKO-uHfH?8LTYCw%Fh z@lDPt1d{clEvy@~X417T|3{?jODZ{Hy5v+h&BA=lMhub-?8M^ z&IoCQkzq4^y#MCEnU;>XcrwAEh|M$?RSYPK7MFJBn}uC?!xX%)J;oO}O`lO_Qj)4E z*_5Q~*()oVjFwHSt7$!}&B|I@mR0$(tmpNsN=nUV^0Gd=(tZ`1^l{n;-Z?Nh(#{Po z`sYK?LlwkBcgcg`;a}&iz*ex(ykODYU?X?Q3g6-1HLv3ss}VwqaEWpc3v68Glg_uH zxFW!5+k3bII1Kj%z$joGoEgU1S#YL>qxBe^6mW(;Ffl>MDLC&3M`U~mcoVo$a75Hs zfn(hfQNIBk(<7q(vBv|}T>%DOLVy)Q#6oYq1Vq$-@pwe!e|vQhv5p_UgAIh(>mBL8 jIOhEqB1;&U0wvf*`0;aq@Z&H1&bYYbJQ9=6#PIijbScv1 diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 85823b4..f1b416d 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -13,12 +13,30 @@ extern "C" return __builtin_amdgcn_read_exec_lo(); } + size_t __ockl_get_local_id(uint32_t) __device__; + uint32_t FUNC(sreg_tid)(uint8_t member) + { + return (uint32_t)__ockl_get_local_id(member); + } + size_t __ockl_get_local_size(uint32_t) __device__; uint32_t FUNC(sreg_ntid)(uint8_t member) { return (uint32_t)__ockl_get_local_size(member); } + size_t __ockl_get_global_id(uint32_t) __device__; + uint32_t FUNC(sreg_ctaid)(uint8_t member) + { + return (uint32_t)__ockl_get_global_id(member); + } + + size_t __ockl_get_global_size(uint32_t) __device__; + uint32_t FUNC(sreg_nctaid)(uint8_t member) + { + return (uint32_t)__ockl_get_global_size(member); + } + uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device)); uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32) { diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index ce1eb84..1784745 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -522,9 +522,9 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::And { arguments, .. } => self.emit_and(arguments), ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), - ast::Instruction::Cvt { .. } => todo!(), - ast::Instruction::Shr { .. } => todo!(), - ast::Instruction::Shl { .. } => todo!(), + ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments), + ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments), + ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments), ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), ast::Instruction::Abs { .. } => todo!(), @@ -533,9 +533,9 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), ast::Instruction::Min { .. } => todo!(), ast::Instruction::Max { .. } => todo!(), - ast::Instruction::Rcp { .. } => todo!(), - ast::Instruction::Sqrt { .. } => todo!(), - ast::Instruction::Rsqrt { .. } => todo!(), + ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments), + ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments), + ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments), ast::Instruction::Selp { .. } => todo!(), ast::Instruction::Bar { .. } => todo!(), ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), @@ -1406,6 +1406,212 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_cvt( + &mut self, + data: ptx_parser::CvtDetails, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let dst_type = get_scalar_type(self.context, data.to); + let llvm_fn = match data.mode { + ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, + ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, + ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, + ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, + ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(), + ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(), + ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(), + ptx_parser::CvtMode::FPTruncate { + rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::FPRound { + integer_rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::SignedFromFP { + rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::UnsignedFromFP { + rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(), + ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(), + }; + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }); + Ok(()) + } + + fn emit_rsqrt( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::RsqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match data.type_ { + ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32", + ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, type_)], + )?; + Ok(()) + } + + fn emit_sqrt( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::SqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32", + (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32", + (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, type_)], + )?; + Ok(()) + } + + fn emit_rcp( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32", + (_, ast::RcpKind::Compliant(rnd)) => { + return self.emit_rcp_compliant(data, arguments, rnd) + } + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, type_)], + )?; + Ok(()) + } + + fn emit_rcp_compliant( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + _rnd: ast::RoundingMode, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let one = unsafe { LLVMConstReal(type_, 1.0) }; + let src = self.resolver.value(arguments.src)?; + let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(self.builder, one, src, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) }; + Ok(()) + } + + fn emit_shr( + &mut self, + data: ptx_parser::ShrData, + arguments: ptx_parser::ShrArgs, + ) -> Result<(), TranslateError> { + let shift_fn = match data.kind { + ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr, + ptx_parser::RightShiftKind::Logical => LLVMBuildLShr, + }; + self.emit_shift( + data.type_, + arguments.dst, + arguments.src1, + arguments.src2, + shift_fn, + ) + } + + fn emit_shl( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::ShlArgs, + ) -> Result<(), TranslateError> { + self.emit_shift( + type_, + arguments.dst, + arguments.src1, + arguments.src2, + LLVMBuildShl, + ) + } + + fn emit_shift( + &mut self, + type_: ast::ScalarType, + dst: SpirvWord, + src1: SpirvWord, + src2: SpirvWord, + llvm_fn: unsafe extern "C" fn( + LLVMBuilderRef, + LLVMValueRef, + LLVMValueRef, + *const i8, + ) -> LLVMValueRef, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(src1)?; + let shift_size = self.resolver.value(src2)?; + let integer_bits = type_.layout().size() * 8; + let integer_bits_constant = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::U32), + integer_bits as u64, + 0, + ) + }; + let should_clamp = unsafe { + LLVMBuildICmp( + self.builder, + LLVMIntPredicate::LLVMIntUGE, + shift_size, + integer_bits_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let llvm_type = get_scalar_type(self.context, type_); + let zero = unsafe { LLVMConstNull(llvm_type) }; + let normalized_shift_size = if type_.layout().size() >= 4 { + unsafe { + LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) + } + } else { + unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } + }; + let shifted = unsafe { + llvm_fn( + self.builder, + src1, + normalized_shift_size, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst) + }); + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 1b5afee..f4b7921 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -45,7 +45,6 @@ test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(shl, [11u64], [44u64]); -test_ptx!(shl_link_hack, [11u64], [44u64]); test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); diff --git a/ptx/src/test/spirv_run/shl_link_hack.ptx b/ptx/src/test/spirv_run/shl_link_hack.ptx deleted file mode 100644 index a32555c..0000000 --- a/ptx/src/test/spirv_run/shl_link_hack.ptx +++ /dev/null @@ -1,30 +0,0 @@ -// HACK ALERT -// This test is for testing workaround for a bug in IGC where linking fails -// if there is shl/shr with different width of value and shift - -.version 6.5 -.target sm_30 -.address_size 64 - -.visible .entry shl_link_hack( - .param .u64 input, - .param .u64 output -) -{ - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .u64 temp; - .reg .u64 temp2; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - // Here only to trigger linking - .reg .u32 unused; - atom.inc.u32 unused, [out_addr], 2000000; - - ld.u64 temp, [in_addr]; - shl.b64 temp2, temp, 2; - st.u64 [out_addr], temp2; - ret; -}