diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 9533233..6651430 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 85823b4..f1b416d 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -13,12 +13,30 @@ extern "C" return __builtin_amdgcn_read_exec_lo(); } + size_t __ockl_get_local_id(uint32_t) __device__; + uint32_t FUNC(sreg_tid)(uint8_t member) + { + return (uint32_t)__ockl_get_local_id(member); + } + size_t __ockl_get_local_size(uint32_t) __device__; uint32_t FUNC(sreg_ntid)(uint8_t member) { return (uint32_t)__ockl_get_local_size(member); } + size_t __ockl_get_global_id(uint32_t) __device__; + uint32_t FUNC(sreg_ctaid)(uint8_t member) + { + return (uint32_t)__ockl_get_global_id(member); + } + + size_t __ockl_get_global_size(uint32_t) __device__; + uint32_t FUNC(sreg_nctaid)(uint8_t member) + { + return (uint32_t)__ockl_get_global_size(member); + } + uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device)); uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32) { diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index ce1eb84..1784745 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -522,9 +522,9 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::And { arguments, .. } => self.emit_and(arguments), ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), - ast::Instruction::Cvt { .. } => todo!(), - ast::Instruction::Shr { .. } => todo!(), - ast::Instruction::Shl { .. } => todo!(), + ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments), + ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments), + ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments), ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), ast::Instruction::Abs { .. } => todo!(), @@ -533,9 +533,9 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), ast::Instruction::Min { .. } => todo!(), ast::Instruction::Max { .. } => todo!(), - ast::Instruction::Rcp { .. } => todo!(), - ast::Instruction::Sqrt { .. } => todo!(), - ast::Instruction::Rsqrt { .. } => todo!(), + ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments), + ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments), + ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments), ast::Instruction::Selp { .. } => todo!(), ast::Instruction::Bar { .. } => todo!(), ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), @@ -1406,6 +1406,212 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_cvt( + &mut self, + data: ptx_parser::CvtDetails, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let dst_type = get_scalar_type(self.context, data.to); + let llvm_fn = match data.mode { + ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, + ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, + ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, + ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, + ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(), + ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(), + ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(), + ptx_parser::CvtMode::FPTruncate { + rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::FPRound { + integer_rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::SignedFromFP { + rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::UnsignedFromFP { + rounding, + flush_to_zero, + } => todo!(), + ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(), + ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(), + }; + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }); + Ok(()) + } + + fn emit_rsqrt( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::RsqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match data.type_ { + ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32", + ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, type_)], + )?; + Ok(()) + } + + fn emit_sqrt( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::SqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32", + (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32", + (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, type_)], + )?; + Ok(()) + } + + fn emit_rcp( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32", + (_, ast::RcpKind::Compliant(rnd)) => { + return self.emit_rcp_compliant(data, arguments, rnd) + } + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(arguments.src, type_)], + )?; + Ok(()) + } + + fn emit_rcp_compliant( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + _rnd: ast::RoundingMode, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let one = unsafe { LLVMConstReal(type_, 1.0) }; + let src = self.resolver.value(arguments.src)?; + let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(self.builder, one, src, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) }; + Ok(()) + } + + fn emit_shr( + &mut self, + data: ptx_parser::ShrData, + arguments: ptx_parser::ShrArgs, + ) -> Result<(), TranslateError> { + let shift_fn = match data.kind { + ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr, + ptx_parser::RightShiftKind::Logical => LLVMBuildLShr, + }; + self.emit_shift( + data.type_, + arguments.dst, + arguments.src1, + arguments.src2, + shift_fn, + ) + } + + fn emit_shl( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::ShlArgs, + ) -> Result<(), TranslateError> { + self.emit_shift( + type_, + arguments.dst, + arguments.src1, + arguments.src2, + LLVMBuildShl, + ) + } + + fn emit_shift( + &mut self, + type_: ast::ScalarType, + dst: SpirvWord, + src1: SpirvWord, + src2: SpirvWord, + llvm_fn: unsafe extern "C" fn( + LLVMBuilderRef, + LLVMValueRef, + LLVMValueRef, + *const i8, + ) -> LLVMValueRef, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(src1)?; + let shift_size = self.resolver.value(src2)?; + let integer_bits = type_.layout().size() * 8; + let integer_bits_constant = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::U32), + integer_bits as u64, + 0, + ) + }; + let should_clamp = unsafe { + LLVMBuildICmp( + self.builder, + LLVMIntPredicate::LLVMIntUGE, + shift_size, + integer_bits_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let llvm_type = get_scalar_type(self.context, type_); + let zero = unsafe { LLVMConstNull(llvm_type) }; + let normalized_shift_size = if type_.layout().size() >= 4 { + unsafe { + LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) + } + } else { + unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } + }; + let shifted = unsafe { + llvm_fn( + self.builder, + src1, + normalized_shift_size, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst) + }); + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 1b5afee..f4b7921 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -45,7 +45,6 @@ test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(shl, [11u64], [44u64]); -test_ptx!(shl_link_hack, [11u64], [44u64]); test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); diff --git a/ptx/src/test/spirv_run/shl_link_hack.ptx b/ptx/src/test/spirv_run/shl_link_hack.ptx deleted file mode 100644 index a32555c..0000000 --- a/ptx/src/test/spirv_run/shl_link_hack.ptx +++ /dev/null @@ -1,30 +0,0 @@ -// HACK ALERT -// This test is for testing workaround for a bug in IGC where linking fails -// if there is shl/shr with different width of value and shift - -.version 6.5 -.target sm_30 -.address_size 64 - -.visible .entry shl_link_hack( - .param .u64 input, - .param .u64 output -) -{ - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .u64 temp; - .reg .u64 temp2; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - // Here only to trigger linking - .reg .u32 unused; - atom.inc.u32 unused, [out_addr], 2000000; - - ld.u64 temp, [in_addr]; - shl.b64 temp2, temp, 2; - st.u64 [out_addr], temp2; - ret; -}