diff --git a/Cargo.lock b/Cargo.lock index 7480124..04f0fea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2603,6 +2603,7 @@ dependencies = [ "quick-error", "rustc-hash 2.0.0", "serde", + "smallvec", "strum 0.26.3", "strum_macros 0.26.4", "tempfile", @@ -2940,9 +2941,9 @@ checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" [[package]] name = "smallvec" -version = "1.13.2" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "sprs" diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index b2c5d99..2f9b174 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -21,6 +21,7 @@ petgraph = "0.7.1" microlp = "0.2.11" int-enum = "1.1" unwrap_or = "1.0.1" +smallvec = "1.15.1" serde = { version = "1.0.219", features = ["derive"] } [dev-dependencies] diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 64593d3..9cabbcc 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 0f0f1d3..38a26bb 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -10,7 +10,9 @@ #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME -#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__ +#define DECLARE_ATTR(TYPE, NAME) \ + extern const TYPE ATTR(NAME) \ + __device__ extern "C" { @@ -100,19 +102,6 @@ extern "C" } } - static __device__ uint32_t sub_sat(uint32_t x, uint32_t y) - { - uint32_t result; - if (__builtin_sub_overflow(x, y, &result)) - { - return 0; - } - else - { - return result; - } - } - int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len) { // NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len` @@ -122,7 +111,7 @@ extern "C" if (pos >= 64) return (base >> 63U); if (add_sat(pos, len) >= 64) - len = sub_sat(64, pos); + len = 64 - pos; return (base << (64U - pos - len)) >> (64U - len); } @@ -174,11 +163,8 @@ extern "C" BAR_RED_IMPL(and); BAR_RED_IMPL(or); - struct ShflSyncResult - { - uint32_t output; - bool in_bounds; - }; + +typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); // shfl.sync opts consists of two values, the warp end ID and the subsection mask. // @@ -192,7 +178,6 @@ extern "C" // The warp end ID is the max lane ID for a specific mode. For the CUDA __shfl_sync // intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the // bounds check. 
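// For example (illustrative, not taken verbatim from the PTX docs): a full-width
// __shfl_down_sync is expected to arrive here with opts = 0x1f, i.e. warp end ID 31
// and a zero subsection mask. A lane whose computed source index exceeds 31 then fails
// the bounds check and, per the macro below, keeps its own input value.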
- #define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \ ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \ { \ @@ -208,12 +193,12 @@ extern "C" idx = self; \ } \ int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \ - return {(uint32_t)output, !out_of_bounds}; \ + return {(uint32_t)output, uint32_t(!out_of_bounds)}; \ } \ \ uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \ { \ - return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \ + return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \ } // We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync @@ -226,7 +211,8 @@ extern "C" SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >); DECLARE_ATTR(uint32_t, CLOCK_RATE); - void FUNC(nanosleep_u32)(uint32_t nanoseconds) { + void FUNC(nanosleep_u32)(uint32_t nanoseconds) + { // clock_rate is in kHz uint64_t cycles_per_ns = ATTR(CLOCK_RATE) / 1000000; uint64_t cycles = nanoseconds * cycles_per_ns; @@ -335,4 +321,157 @@ extern "C" else return value; } + + // Logic taken from legalizeFSQRTF32/lowerFSQRTF32 in LLVM AMDGPU target + __device__ static float precise_square_root(float x, bool needs_denorm_handling) + { + + // Constants for denormal handling + const float scale_threshold = 0x1.0p-96f; // Very small value threshold + const float scale_up_factor = 0x1.0p+32f; // 2^32 + const float scale_down_factor = 0x1.0p-16f; // 2^-16 + + // Check if input needs scaling (for very small values) + bool need_scale = scale_threshold > x; + auto scaled = scale_up_factor * x; + + // Scale up input if needed + float sqrt_x = need_scale ? 
scaled : x;
+
+        float sqrt_s;
+
+        // Check if we need special denormal handling
+
+        if (needs_denorm_handling)
+        {
+            // Use hardware sqrt as initial approximation
+            sqrt_s = __builtin_sqrtf(sqrt_x); // Or equivalent hardware instruction
+
+            // Bit manipulations to get next values up/down
+            uint32_t sqrt_s_bits = std::bit_cast<uint32_t>(sqrt_s);
+
+            // Next value down (subtract 1 from bit pattern)
+            uint32_t sqrt_s_next_down_bits = sqrt_s_bits - 1;
+            float sqrt_s_next_down = std::bit_cast<float>(sqrt_s_next_down_bits);
+
+            // Calculate residual: x - sqrt_next_down * sqrt
+            float neg_sqrt_s_next_down = -sqrt_s_next_down;
+            float sqrt_vp = std::fma(neg_sqrt_s_next_down, sqrt_s, sqrt_x);
+
+            // Next value up (add 1 to bit pattern)
+            uint32_t sqrt_s_next_up_bits = sqrt_s_bits + 1;
+            float sqrt_s_next_up = std::bit_cast<float>(sqrt_s_next_up_bits);
+
+            // Calculate residual: x - sqrt_next_up * sqrt
+            float neg_sqrt_s_next_up = -sqrt_s_next_up;
+            float sqrt_vs = std::fma(neg_sqrt_s_next_up, sqrt_s, sqrt_x);
+
+            // Select correctly rounded result
+            if (sqrt_vp <= 0.0f)
+            {
+                sqrt_s = sqrt_s_next_down;
+            }
+
+            if (sqrt_vs > 0.0f)
+            {
+                sqrt_s = sqrt_s_next_up;
+            }
+        }
+        else
+        {
+            // Use Newton-Raphson method with reciprocal square root
+
+            // Initial approximation
+            float sqrt_r = __builtin_amdgcn_rsqf(sqrt_x); // Or equivalent hardware 1/sqrt instruction
+            sqrt_s = sqrt_x * sqrt_r;
+
+            // Refine approximation
+            float half = 0.5f;
+            float sqrt_h = sqrt_r * half;
+            float neg_sqrt_h = -sqrt_h;
+
+            // Calculate error term
+            float sqrt_e = std::fma(neg_sqrt_h, sqrt_s, half);
+
+            // First refinement
+            sqrt_h = std::fma(sqrt_h, sqrt_e, sqrt_h);
+            sqrt_s = std::fma(sqrt_s, sqrt_e, sqrt_s);
+
+            // Second refinement
+            float neg_sqrt_s = -sqrt_s;
+            float sqrt_d = std::fma(neg_sqrt_s, sqrt_s, sqrt_x);
+            sqrt_s = std::fma(sqrt_d, sqrt_h, sqrt_s);
+        }
+
+        // Scale back if input was scaled
+        if (need_scale)
+        {
+            sqrt_s *= scale_down_factor;
+        }
+
+        // Special case handling for zero and infinity
+        bool is_zero_or_inf = __builtin_isfpclass(sqrt_x, __FPCLASS_POSINF | __FPCLASS_POSZERO | __FPCLASS_NEGZERO);
+
+        return is_zero_or_inf ?
sqrt_x : sqrt_s; + } + + float FUNC(sqrt_rn_f32)(float x) + { + return precise_square_root(x, true); + } + + float FUNC(sqrt_rn_ftz_f32)(float x) + { + return precise_square_root(x, false); + } + + struct DivRnFtzF32Part1Result + { + float fma_4; + float fma_1; + float fma_3; + uint8_t numerator_scaled_flag; + }; + + DivRnFtzF32Part1Result FUNC(div_f32_part1)(float lhs, float rhs) + { + float one = 1.0f; + + // Division scale operations + bool denominator_scaled_flag; + float denominator_scaled = __builtin_amdgcn_div_scalef(lhs, rhs, false, &denominator_scaled_flag); + + bool numerator_scaled_flag; + float numerator_scaled = __builtin_amdgcn_div_scalef(lhs, rhs, true, &numerator_scaled_flag); + + // Reciprocal approximation + float approx_rcp = __builtin_amdgcn_rcpf(denominator_scaled); + float neg_div_scale0 = -denominator_scaled; + + // Perform division approximation steps + float fma_0 = fmaf(neg_div_scale0, approx_rcp, one); + float fma_1 = fmaf(fma_0, approx_rcp, approx_rcp); + float mul = numerator_scaled * fma_1; + float fma_2 = fmaf(neg_div_scale0, mul, numerator_scaled); + float fma_3 = fmaf(fma_2, fma_1, mul); + float fma_4 = fmaf(neg_div_scale0, fma_3, numerator_scaled); + return {fma_4, fma_1, fma_3, numerator_scaled_flag}; + } + + __device__ static float div_f32_part2(float x, float y, DivRnFtzF32Part1Result part1) + { + float fmas = __builtin_amdgcn_div_fmasf(part1.fma_4, part1.fma_1, part1.fma_3, part1.numerator_scaled_flag); + float result = __builtin_amdgcn_div_fixupf(fmas, y, x); + + return result; + } + + float FUNC(div_f32_part2)(float x, float y, + float fma_4, + float fma_1, + float fma_3, + uint8_t numerator_scaled_flag) + { + return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag}); + } } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 4e1ca5c..ba98a23 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -803,6 +803,13 @@ fn create_control_flow_graph( let modes = get_modes(instruction); bb_state.append(modes); } + Statement::FpModeRequired { ftz_f32, rnd_f32 } => { + bb_state.append(InstructionModes::new( + ast::ScalarType::F32, + ftz_f32.map(DenormalMode::from_ftz), + rnd_f32.map(RoundingMode::from_ast), + )); + } _ => {} } } @@ -1021,6 +1028,16 @@ fn apply_global_mode_controls( let modes = get_modes(&instruction); bb_state.insert(&mut result, modes)?; } + Statement::FpModeRequired { ftz_f32, rnd_f32 } => { + bb_state.insert( + &mut result, + InstructionModes::new( + ast::ScalarType::F32, + ftz_f32.map(DenormalMode::from_ftz), + rnd_f32.map(RoundingMode::from_ast), + ), + )?; + } _ => {} } result.push(statement); diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 5d9516f..90e12e6 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -397,6 +397,8 @@ impl<'a> MethodEmitContext<'a> { Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?, Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?, + // No-op + Statement::FpModeRequired { .. 
} => {}
        })
    }
@@ -825,27 +827,12 @@ impl<'a> MethodEmitContext<'a> {
        match &*arguments.return_arguments {
            [] => {}
            [name] => self.resolver.register(*name, llvm_call),
-            [b32, pred] => {
-                self.resolver.with_result(*b32, |name| unsafe {
-                    LLVMBuildExtractValue(self.builder, llvm_call, 0, name)
-                });
-                self.resolver.with_result(*pred, |name| unsafe {
-                    let extracted =
-                        LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr());
-                    LLVMBuildTrunc(
-                        self.builder,
-                        extracted,
-                        get_scalar_type(self.context, ast::ScalarType::Pred),
-                        name,
-                    )
-                });
-            }
-            _ => {
-                return Err(error_todo_msg(
-                    "Only two return arguments (.b32, .pred) currently supported",
-                ))
+            args => {
+                for (i, arg) in args.iter().copied().enumerate() {
+                    self.resolver.with_result(arg, |name| unsafe {
+                        LLVMBuildExtractValue(self.builder, llvm_call, i as u32, name)
+                    });
+                }
            }
        }
        Ok(())
@@ -992,44 +981,28 @@ impl<'a> MethodEmitContext<'a> {
                unsafe {
                    LLVMSetAlignment(load, type_.layout().align() as u32);
                }
-                Ok(load)
+                Ok((load, type_))
            })
            .collect::<Result<Vec<_>, _>>()?;
-
        match &*loads {
            [] => unsafe { LLVMBuildRetVoid(self.builder) },
-            [value] => unsafe { LLVMBuildRet(self.builder, *value) },
-            _ => {
-                check_multiple_return_types(values.iter().map(|(_, type_)| type_))?;
-                let array_ty =
-                    get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?;
-                let insert_b32 = unsafe {
-                    LLVMBuildInsertValue(
-                        self.builder,
-                        LLVMGetPoison(array_ty),
-                        loads[0],
-                        0,
-                        LLVM_UNNAMED.as_ptr(),
-                    )
-                };
-                let zext_pred = unsafe {
-                    LLVMBuildZExt(
-                        self.builder,
-                        loads[1],
-                        get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?,
-                        LLVM_UNNAMED.as_ptr(),
-                    )
-                };
-                let insert_pred = unsafe {
-                    LLVMBuildInsertValue(
-                        self.builder,
-                        insert_b32,
-                        zext_pred,
-                        1,
-                        LLVM_UNNAMED.as_ptr(),
-                    )
-                };
-                unsafe { LLVMBuildRet(self.builder, insert_pred) }
+            [(value, _)] => unsafe { LLVMBuildRet(self.builder, *value) },
+            loads => {
+                let struct_type =
+                    get_or_create_struct_type(self.context, loads.iter().map(|(_, type_)| *type_))?;
+                let mut value = unsafe { LLVMGetUndef(struct_type) };
+                for (i, (load, _)) in loads.iter().enumerate() {
+                    value = unsafe {
+                        LLVMBuildInsertValue(
+                            self.builder,
+                            value,
+                            *load,
+                            i as u32,
+                            LLVM_UNNAMED.as_ptr(),
+                        )
+                    };
+                }
+                unsafe { LLVMBuildRet(self.builder, value) }
            }
        };
        Ok(())
@@ -1898,10 +1871,10 @@ impl<'a> MethodEmitContext<'a> {
        to: ptx_parser::ScalarType,
        arguments: ptx_parser::CvtArgs<SpirvWord>,
        llvm_func: unsafe extern "C" fn(
-            arg1: LLVMBuilderRef,
-            Val: LLVMValueRef,
-            DestTy: LLVMTypeRef,
-            Name: *const i8,
+            LLVMBuilderRef,
+            LLVMValueRef,
+            LLVMTypeRef,
+            *const i8,
        ) -> LLVMValueRef,
    ) -> Result<(), TranslateError> {
        let type_ = get_scalar_type(self.context, to);
@@ -2928,46 +2901,60 @@ fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
-fn get_array_type<'a>(
+fn get_or_create_struct_type<'a>(
     context: LLVMContextRef,
-    elem_type: &'a ast::Type,
-    count: u64,
+    mut elem_types: impl Iterator<Item = &'a ast::Type>,
 ) -> Result<LLVMTypeRef, TranslateError> {
-    let elem_type = get_type(context, elem_type)?;
-    Ok(unsafe { LLVMArrayType2(elem_type, count) })
+    use std::fmt::Write;
+    let (mut name, types) = elem_types.try_fold(
+        ("struct".to_string(), Vec::new()),
+        |(mut name, mut types), t| {
+            name.push('.');
+            if let ast::Type::Scalar(scalar) = t {
+                write!(name, "{}", LLVMTypeDisplay(*scalar)).ok();
+            } else {
+                return Err(error_unreachable());
+            }
+            types.push(get_type(context, t)?);
+            Ok((name, types))
+        },
+    )?;
+    name.push('\0');
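+    // Named struct types are interned per LLVM context, so the mangled name built above
+    // (e.g. "struct.f32.f32.f32.i8" for the div helper's {f32, f32, f32, u8} return, as
+    // seen in ptx/src/test/ll/div_ftz.ll) lets every multi-value return with the same
+    // element types reuse a single named struct instead of creating one per call site.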
+    let mut struct_type = unsafe { LLVMGetTypeByName2(context, name.as_ptr().cast()) };
+    if struct_type.is_null() {
+        struct_type = create_struct_type(context, name, types);
+    }
+    Ok(struct_type)
 }

-fn check_multiple_return_types<'a>(
-    mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
-) -> Result<(), TranslateError> {
-    let err_msg = "Only (.b32, .pred) multiple return types are supported";
-
-    let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
-    let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
-    match (first, second) {
-        (ast::Type::Scalar(first), ast::Type::Scalar(second)) => {
-            if first.size_of() != 4 || second.size_of() != 1 {
-                return Err(error_todo_msg(err_msg));
-            }
-        }
-        _ => return Err(error_todo_msg(err_msg)),
+fn create_struct_type(
+    context: LLVMContextRef,
+    name: String,
+    mut elem_types: Vec<LLVMTypeRef>,
+) -> LLVMTypeRef {
+    let llvm_type = unsafe { LLVMStructCreateNamed(context, name.as_ptr().cast()) };
+    unsafe {
+        LLVMStructSetBody(
+            llvm_type,
+            elem_types.as_mut_ptr(),
+            elem_types.len() as u32,
+            0,
+        )
     }
-    Ok(())
+    llvm_type
 }

 fn get_function_type<'a>(
     context: LLVMContextRef,
-    mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+    mut return_args: impl DoubleEndedIterator<Item = &'a ast::Type>
+        + ExactSizeIterator,
     input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
 ) -> Result<LLVMTypeRef, TranslateError> {
     let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
     let return_type = match return_args.len() {
         0 => unsafe { LLVMVoidTypeInContext(context) },
         1 => get_type(context, return_args.next().unwrap())?,
-        _ => {
-            check_multiple_return_types(return_args)?;
-            get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
-        }
+        _ => get_or_create_struct_type(context, return_args)?,
     };
     Ok(unsafe {
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index eeb2c7f..d31c0ec 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -24,7 +24,8 @@ mod normalize_basic_blocks;
 mod normalize_identifiers2;
 mod normalize_predicates2;
 mod remove_unreachable_basic_blocks;
-mod replace_instructions_with_function_calls;
+mod replace_instructions_with_functions;
+mod replace_instructions_with_functions_fp_required;
 mod replace_known_functions;
 mod resolve_function_pointers;
@@ -68,12 +69,14 @@ pub fn to_llvm_module<'input>(
     let directives = expand_operands::run(&mut flat_resolver, directives)?;
     let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
     let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
+    let directives =
+        replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
     let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
     let directives = remove_unreachable_basic_blocks::run(directives)?;
     let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
     let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
     let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
-    let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
+    let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
     let directives = hoist_globals::run(directives)?;
     let context = llvm::Context::new();
@@ -235,6 +238,15 @@ enum Statement {
     VectorRead(VectorRead),
     VectorWrite(VectorWrite),
     SetMode(ModeRegister),
+    // This statement is a no-op; it serves as a marker indicating that the
+    // next instruction requires certain floating-point modes to be set.
+    // Some transcendentals compile to a sequence of instructions that
+    // require certain modes to be set _mid-function_.
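+    // For example, `div.rn.ftz.f32 d, a, b` is expected to expand (roughly) to:
+    //   FpModeRequired { ftz_f32: Some(false), .. }  (part 1 must run with FTZ disabled)
+    //   call (fma_4, fma_1, fma_3, flag) = div_f32_part1(a, b)
+    //   FpModeRequired { ftz_f32: Some(true), .. }   (part 2 runs in the instruction's own mode)
+    //   call d = div_f32_part2(a, b, fma_4, fma_1, fma_3, flag)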
+    // See the replace_instructions_with_functions_fp_required pass for details
+    FpModeRequired {
+        ftz_f32: Option<bool>,
+        rnd_f32: Option<ast::RoundingMode>,
+    },
     FpSaturate {
         dst: SpirvWord,
         src: SpirvWord,
@@ -541,6 +553,9 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
                 )?;
                 Statement::FpSaturate { dst, src, type_ }
             }
+            Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
+                Statement::FpModeRequired { ftz_f32, rnd_f32 }
+            }
         })
     }
 }
diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_functions.rs
similarity index 52%
rename from ptx/src/pass/replace_instructions_with_function_calls.rs
rename to ptx/src/pass/replace_instructions_with_functions.rs
index 8123e41..0f6a36c 100644
--- a/ptx/src/pass/replace_instructions_with_function_calls.rs
+++ b/ptx/src/pass/replace_instructions_with_functions.rs
@@ -1,4 +1,5 @@
 use super::*;
+use smallvec::*;

 pub(super) fn run<'input>(
     resolver: &mut GlobalStringIdentResolver2<'input>,
@@ -71,13 +72,136 @@ fn run_statements<'input>(
     statements
         .into_iter()
         .map(|statement| {
-            Ok(match statement {
-                Statement::Instruction(instruction) => {
-                    Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
+            Ok::<SmallVec<[_; 4]>, _>(match statement {
+                Statement::Instruction(ast::Instruction::ShflSync {
+                    data,
+                    arguments:
+                        ast::ShflSyncArgs {
+                            dst_pred: Some(dst_pred),
+                            dst,
+                            src,
+                            src_lane,
+                            src_opts,
+                            src_membermask,
+                        },
+                }) => {
+                    let mode = match data.mode {
+                        ptx_parser::ShuffleMode::Up => "up",
+                        ptx_parser::ShuffleMode::Down => "down",
+                        ptx_parser::ShuffleMode::BFly => "bfly",
+                        ptx_parser::ShuffleMode::Idx => "idx",
+                    };
+                    let packed_var = resolver.register_unnamed(Some((
+                        ast::Type::Vector(2, ast::ScalarType::U32),
+                        ptx_parser::StateSpace::Reg,
+                    )));
+                    let dst_pred_wide = resolver.register_unnamed(Some((
+                        ast::Type::Scalar(ast::ScalarType::U32),
+                        ptx_parser::StateSpace::Reg,
+                    )));
+                    let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat();
+                    let return_arguments = vec![(
+                        ast::Type::Vector(2, ast::ScalarType::U32),
+                        ptx_parser::StateSpace::Reg,
+                    )];
+                    let input_arguments = vec![
+                        (
+                            ast::Type::Scalar(ast::ScalarType::U32),
+                            ptx_parser::StateSpace::Reg,
+                        ),
+                        (
+                            ast::Type::Scalar(ast::ScalarType::U32),
+                            ptx_parser::StateSpace::Reg,
+                        ),
+                        (
+                            ast::Type::Scalar(ast::ScalarType::U32),
+                            ptx_parser::StateSpace::Reg,
+                        ),
+                        (
+                            ast::Type::Scalar(ast::ScalarType::U32),
+                            ptx_parser::StateSpace::Reg,
+                        ),
+                    ];
+                    let func = match fn_declarations.entry(full_name.into()) {
+                        hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
+                        hash_map::Entry::Vacant(vacant_entry) => {
+                            let name = vacant_entry.key().clone();
+                            let name = resolver.register_named(name, None);
+                            vacant_entry.insert((
+                                to_variables(resolver, &return_arguments),
+                                name,
+                                to_variables(resolver, &input_arguments),
+                            ));
+                            name
+                        }
+                    };
+                    smallvec![
+                        Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
+                            data: ptx_parser::CallDetails {
+                                uniform: false,
+                                return_arguments: vec![(
+                                    ast::Type::Vector(2, ast::ScalarType::U32),
+                                    ptx_parser::StateSpace::Reg,
+                                )],
+                                input_arguments: vec![
+                                    (
+                                        ast::Type::Scalar(ast::ScalarType::U32),
+                                        ptx_parser::StateSpace::Reg,
+                                    ),
+                                    (
+                                        ast::Type::Scalar(ast::ScalarType::U32),
+                                        ptx_parser::StateSpace::Reg,
+                                    ),
+                                    (
+                                        ast::Type::Scalar(ast::ScalarType::U32),
+                                        ptx_parser::StateSpace::Reg,
+                                    ),
+                                    (
+                                        ast::Type::Scalar(ast::ScalarType::U32),
+                                        ptx_parser::StateSpace::Reg,
+                                    ),
+                                ],
+                            },
+                            arguments: ptx_parser::CallArgs {
+                                return_arguments: vec![packed_var],
+                                func,
+                                input_arguments: vec![src, src_lane, src_opts, src_membermask],
+                            },
+                        }),
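+                        // The *_pred helper returns a uvec2 (see ShflSyncResult in
+                        // zluda_ptx_impl.cpp): .x is the shuffled value and .y is the
+                        // in-bounds flag widened to u32. The repack below splits the
+                        // pair and the cvt truncates the flag into the predicate.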
+                        Statement::RepackVector(RepackVectorDetails {
+                            is_extract: true,
+                            typ: ast::ScalarType::U32,
+                            packed: packed_var,
+                            unpacked: vec![dst, dst_pred_wide],
+                            relaxed_type_check: false,
+                        }),
+                        Statement::Instruction(ast::Instruction::Cvt {
+                            data: ast::CvtDetails {
+                                from: ast::ScalarType::U32,
+                                to: ast::ScalarType::Pred,
+                                mode: ast::CvtMode::Truncate
+                            },
+                            arguments: ast::CvtArgs {
+                                dst: dst_pred,
+                                src: dst_pred_wide,
+                            },
+                        })
+                    ]
                 }
-                s => s,
+                Statement::<ast::Instruction<SpirvWord>, SpirvWord>::Instruction(instruction) => {
+                    smallvec![
+                        Statement::<ast::Instruction<SpirvWord>, SpirvWord>::Instruction(
+                            run_instruction(resolver, fn_declarations, instruction)?
+                        )
+                    ]
+                }
+                s => smallvec![s],
             })
         })
+        .flat_map(|result| match result {
+            Ok(vec) => vec.into_iter().map(|item| Ok(item)).collect(),
+            Err(er) => vec![Err(er)],
+        })
         .collect::<Result<Vec<_>, _>>()
 }
@@ -141,6 +265,52 @@ fn run_instruction<'input>(
             let name = ["bfe_", scalar_to_ptx_name(data)].concat();
             to_call(resolver, fn_declarations, name.into(), i)?
         }
+        i @ ptx_parser::Instruction::Sqrt {
+            data:
+                ast::RcpData {
+                    kind: ast::RcpKind::Compliant(ast::RoundingMode::NearestEven),
+                    flush_to_zero: Some(true),
+                    ..
+                },
+            ..
+        } => {
+            let name = "sqrt_rn_ftz_f32";
+            to_call(resolver, fn_declarations, name.into(), i)?
+        }
+        i @ ptx_parser::Instruction::Sqrt {
+            data:
+                ast::RcpData {
+                    kind: ast::RcpKind::Compliant(ast::RoundingMode::NearestEven),
+                    ..
+                },
+            ..
+        } => {
+            let name = "sqrt_rn_f32";
+            to_call(resolver, fn_declarations, name.into(), i)?
+        }
+        i @ ptx_parser::Instruction::Div {
+            data:
+                ast::DivDetails::Float(ast::DivFloatDetails {
+                    kind: ast::DivFloatKind::Rounding(_),
+                    flush_to_zero: Some(true),
+                    ..
+                }),
+            ..
+        } => {
+            let name = "div_rn_ftz_f32";
+            to_call(resolver, fn_declarations, name.into(), i)?
+        }
+        i @ ptx_parser::Instruction::Div {
+            data:
+                ast::DivDetails::Float(ast::DivFloatDetails {
+                    kind: ast::DivFloatKind::Rounding(_),
+                    ..
+                }),
+            ..
+        } => {
+            let name = "div_rn_f32";
+            to_call(resolver, fn_declarations, name.into(), i)?
+        }
         i @ ptx_parser::Instruction::Bfi { data, .. } => {
             let name = ["bfi_", scalar_to_ptx_name(data)].concat();
             to_call(resolver, fn_declarations, name.into(), i)?
@@ -163,23 +333,24 @@ fn run_instruction<'input>(
                 ptx_parser::Instruction::BarRed { data, arguments },
             )?
         }
-        ptx_parser::Instruction::ShflSync { data, arguments } => {
+        ptx_parser::Instruction::ShflSync {
+            data,
+            arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },
+        } => {
             let mode = match data.mode {
                 ptx_parser::ShuffleMode::Up => "up",
                 ptx_parser::ShuffleMode::Down => "down",
                 ptx_parser::ShuffleMode::BFly => "bfly",
                 ptx_parser::ShuffleMode::Idx => "idx",
             };
-            let pred = if arguments.dst_pred.is_some() {
-                "_pred"
-            } else {
-                ""
-            };
             to_call(
                 resolver,
                 fn_declarations,
-                format!("shfl_sync_{}_b32{}", mode, pred).into(),
-                ptx_parser::Instruction::ShflSync { data, arguments },
+                format!("shfl_sync_{}_b32", mode).into(),
+                ptx_parser::Instruction::ShflSync {
+                    data,
+                    arguments: orig_arguments,
+                },
             )?
         }
         i @ ptx_parser::Instruction::Nanosleep { .. } => {
diff --git a/ptx/src/pass/replace_instructions_with_functions_fp_required.rs b/ptx/src/pass/replace_instructions_with_functions_fp_required.rs
new file mode 100644
index 0000000..bf2d690
--- /dev/null
+++ b/ptx/src/pass/replace_instructions_with_functions_fp_required.rs
@@ -0,0 +1,367 @@
+// This pass exists specifically to replace the `div.rn.ftz.f32` instruction
+// with a function call.
+// One inherent weirdness of the replacement function is that it requires a
+// different FTZ (denormal) mode for the first part of the division than for
+// the second part: the first part is executed with FTZ disabled and the
+// second part with FTZ enabled.
+// Because of this we can't do the replacement after FTZ mode insertion
+// without making the function read and restore the FTZ mode. Instead, we
+// split the replacement function in two and prefix each call with a no-op
+// (FpModeRequired) that carries the required FTZ mode information.

+use super::*;
+use ptx_parser as ast;
+use smallvec::smallvec;
+use smallvec::SmallVec;
+
+pub(crate) fn run<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+    let mut imports = None;
+    let directives = directives
+        .into_iter()
+        .map(|directive| run_directive(resolver, directive, &mut imports))
+        .collect::<Result<Vec<_>, _>>()?;
+    Ok(match imports {
+        Some(imports) => {
+            let mut result = Vec::with_capacity(directives.len() + 2);
+            result.extend([
+                Directive2::Method(Function2 {
+                    return_arguments: vec![
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::U8),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::U8),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                    ],
+                    name: imports.part1,
+                    input_arguments: vec![
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                    ],
+                    body: None,
+                    import_as: None,
+                    tuning: Vec::new(),
+                    linkage: ast::LinkingDirective::EXTERN,
+                    is_kernel: false,
+                    flush_to_zero_f32: false,
+                    flush_to_zero_f16f64: false,
+                    rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
+                    rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
+                }),
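+                // Extern declaration for __zluda_ptx_impl_div_f32_part2: the original
+                // operands plus part 1's four intermediate values go in, the final
+                // quotient comes out (mirrors div_f32_part2 in zluda_ptx_impl.cpp).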
+                Directive2::Method(Function2 {
+                    return_arguments: vec![ast::Variable {
+                        name: resolver.register_unnamed(Some((
+                            ast::Type::Scalar(ast::ScalarType::F32),
+                            ast::StateSpace::Reg,
+                        ))),
+                        align: None,
+                        v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                        state_space: ast::StateSpace::Reg,
+                        array_init: Vec::new(),
+                    }],
+                    name: imports.part2,
+                    input_arguments: vec![
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::F32),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                        ast::Variable {
+                            name: resolver.register_unnamed(Some((
+                                ast::Type::Scalar(ast::ScalarType::U8),
+                                ast::StateSpace::Reg,
+                            ))),
+                            align: None,
+                            v_type: ast::Type::Scalar(ast::ScalarType::U8),
+                            state_space: ast::StateSpace::Reg,
+                            array_init: Vec::new(),
+                        },
+                    ],
+                    body: None,
+                    import_as: None,
+                    tuning: Vec::new(),
+                    linkage: ast::LinkingDirective::EXTERN,
+                    is_kernel: false,
+                    flush_to_zero_f32: false,
+                    flush_to_zero_f16f64: false,
+                    rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
+                    rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
+                }),
+            ]);
+            result.extend(directives);
+            result
+        }
+        None => directives,
+    })
+}
+
+fn run_directive<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
+    imports: &mut Option<FunctionImports>,
+) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    Ok(match directive {
+        Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
+        Directive2::Method(method) => Directive2::Method(run_method(resolver, method, imports)?),
+    })
+}
+
+fn run_method<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
+    imports: &mut Option<FunctionImports>,
+) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    method.body = method.body.map(|body| {
+        body.into_iter()
+            .flat_map(|stmt| run_statement(resolver, stmt, imports))
+            .collect()
+    });
+    Ok(method)
+}
+
+fn run_statement<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    stmt: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
+    imports: &mut Option<FunctionImports>,
+) -> SmallVec<[Statement<ast::Instruction<SpirvWord>, SpirvWord>; 4]> {
+    match stmt {
+        Statement::Instruction(ast::Instruction::Div {
+            data:
+                ast::DivDetails::Float(ast::DivFloatDetails {
+                    flush_to_zero,
+                    kind: ast::DivFloatKind::Rounding(rnd),
+                    type_: ast::ScalarType::F32,
+                }),
+            arguments,
+        }) => {
+            let ftz = flush_to_zero.unwrap_or(false);
+            let FunctionImports { part1, part2, .. } = FunctionImports::init(imports, resolver);
+            let fma_4 = resolver.register_unnamed(Some((
+                ast::Type::Scalar(ast::ScalarType::F32),
+                ast::StateSpace::Reg,
+            )));
+            let fma_1 = resolver.register_unnamed(Some((
+                ast::Type::Scalar(ast::ScalarType::F32),
+                ast::StateSpace::Reg,
+            )));
+            let fma3_ = resolver.register_unnamed(Some((
+                ast::Type::Scalar(ast::ScalarType::F32),
+                ast::StateSpace::Reg,
+            )));
+            let numerator_scaled_flag = resolver.register_unnamed(Some((
+                ast::Type::Scalar(ast::ScalarType::U8),
+                ast::StateSpace::Reg,
+            )));
+            smallvec![
+                Statement::FpModeRequired {
+                    ftz_f32: Some(false),
+                    rnd_f32: Some(ast::RoundingMode::NearestEven),
+                },
+                Statement::Instruction(ast::Instruction::Call {
+                    arguments: ast::CallArgs {
+                        return_arguments: vec![fma_4, fma_1, fma3_, numerator_scaled_flag],
+                        func: *part1,
+                        input_arguments: vec![arguments.src1, arguments.src2],
+                    },
+                    data: ast::CallDetails {
+                        uniform: false,
+                        return_arguments: vec![
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg,)
+                        ],
+                        input_arguments: vec![
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            )
+                        ]
+                    }
+                }),
+                Statement::FpModeRequired {
+                    ftz_f32: Some(ftz),
+                    rnd_f32: Some(rnd),
+                },
+                Statement::Instruction(ast::Instruction::Call {
+                    arguments: ast::CallArgs {
+                        return_arguments: vec![arguments.dst],
+                        func: *part2,
+                        input_arguments: vec![
+                            arguments.src1,
+                            arguments.src2,
+                            fma_4,
+                            fma_1,
+                            fma3_,
+                            numerator_scaled_flag
+                        ],
+                    },
+                    data: ast::CallDetails {
+                        uniform: false,
+                        return_arguments: vec![(
+                            ast::Type::Scalar(ast::ScalarType::F32),
+                            ast::StateSpace::Reg,
+                        )],
+                        input_arguments: vec![
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (
+                                ast::Type::Scalar(ast::ScalarType::F32),
+                                ast::StateSpace::Reg,
+                            ),
+                            (ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg,)
+                        ]
+                    }
+                })
+            ]
+        }
+        _ => smallvec![stmt],
+    }
+}
+
+#[derive(Clone)]
+struct FunctionImports {
+    part1: SpirvWord,
+    part2: SpirvWord,
+}
+
+impl FunctionImports {
+    fn init<'a>(
+        this: &'a mut Option<FunctionImports>,
+        resolver: &mut GlobalStringIdentResolver2,
+    ) -> &'a FunctionImports {
+        this.get_or_insert_with(|| {
+            let part1_name = [ZLUDA_PTX_PREFIX, "div_f32_part1"].concat();
+            let part1 = resolver.register_named(part1_name.into(), None);
+            let part2_name = [ZLUDA_PTX_PREFIX, "div_f32_part2"].concat();
+            let part2 = resolver.register_named(part2_name.into(), None);
+            FunctionImports { part1, part2 }
+        })
+    }
+}
diff --git a/ptx/src/test/ll/div_ftz.ll b/ptx/src/test/ll/div_ftz.ll
new file mode 100644
index 0000000..6898edb
--- /dev/null
+++ b/ptx/src/test/ll/div_ftz.ll
@@ -0,0 +1,74 @@
+%struct.f32.f32.f32.i8 = type { float, float, float, i8 }
+
+declare %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float, float) #0
+
+declare float @__zluda_ptx_impl_div_f32_part2(float, float, float, float, float, i8) #0
+
+define amdgpu_kernel void @div_ftz(ptr addrspace(4) byref(i64) %"63", ptr addrspace(4) byref(i64) %"64") #1 {
+  %"65" =
alloca i64, align 8, addrspace(5) + %"66" = alloca i64, align 8, addrspace(5) + %"67" = alloca float, align 4, addrspace(5) + %"68" = alloca float, align 4, addrspace(5) + %"69" = alloca float, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"54" + +"54": ; preds = %1 + %"70" = load i64, ptr addrspace(4) %"63", align 8 + store i64 %"70", ptr addrspace(5) %"65", align 8 + %"71" = load i64, ptr addrspace(4) %"64", align 8 + store i64 %"71", ptr addrspace(5) %"66", align 8 + %"73" = load i64, ptr addrspace(5) %"65", align 8 + %"88" = inttoptr i64 %"73" to ptr + %"72" = load float, ptr %"88", align 4 + store float %"72", ptr addrspace(5) %"67", align 4 + %"74" = load i64, ptr addrspace(5) %"65", align 8 + %"89" = inttoptr i64 %"74" to ptr + %"32" = getelementptr inbounds i8, ptr %"89", i64 4 + %"75" = load float, ptr %"32", align 4 + store float %"75", ptr addrspace(5) %"68", align 4 + %"77" = load float, ptr addrspace(5) %"67", align 4 + %"78" = load float, ptr addrspace(5) %"68", align 4 + %"76" = fmul float %"77", %"78" + store float %"76", ptr addrspace(5) %"69", align 4 + %"79" = load float, ptr addrspace(5) %"67", align 4 + %"80" = load float, ptr addrspace(5) %"68", align 4 + %2 = call %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float %"79", float %"80") + %"37" = extractvalue %struct.f32.f32.f32.i8 %2, 0 + %"38" = extractvalue %struct.f32.f32.f32.i8 %2, 1 + %"39" = extractvalue %struct.f32.f32.f32.i8 %2, 2 + %"40" = extractvalue %struct.f32.f32.f32.i8 %2, 3 + br label %"57" + +"57": ; preds = %"54" + call void @llvm.amdgcn.s.setreg(i32 6401, i32 0) + br label %"55" + +"55": ; preds = %"57" + %"82" = load float, ptr addrspace(5) %"67", align 4 + %"83" = load float, ptr addrspace(5) %"68", align 4 + %"81" = call float @__zluda_ptx_impl_div_f32_part2(float %"82", float %"83", float %"37", float %"38", float %"39", i8 %"40") + store float %"81", ptr addrspace(5) %"67", align 4 + br label %"56" + +"56": ; preds = %"55" + %"84" = load i64, ptr addrspace(5) %"66", align 8 + %"85" = load float, ptr addrspace(5) %"67", align 4 + %"90" = inttoptr i64 %"84" to ptr + store float %"85", ptr %"90", align 4 + %"86" = load i64, ptr addrspace(5) %"66", align 8 + %"91" = inttoptr i64 %"86" to ptr + %"34" = getelementptr inbounds i8, ptr %"91", i64 4 + %"87" = load float, ptr addrspace(5) %"69", align 4 + store float %"87", ptr %"34", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind willreturn +declare void @llvm.amdgcn.s.setreg(i32 immarg, i32) #2 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #2 = { nocallback nofree nosync nounwind willreturn } \ No newline at end of file diff --git a/ptx/src/test/ll/div_noftz.ll b/ptx/src/test/ll/div_noftz.ll new file mode 100644 index 0000000..46be55f --- /dev/null +++ b/ptx/src/test/ll/div_noftz.ll @@ -0,0 +1,71 @@ +%struct.f32.f32.f32.i8 = type { float, float, float, i8 } + +declare %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float, float) #0 + +declare float @__zluda_ptx_impl_div_f32_part2(float, float, float, float, float, i8) #0 + +define amdgpu_kernel void @div_noftz(ptr addrspace(4) byref(i64) %"62", ptr addrspace(4) byref(i64) %"63") #1 { + %"64" = alloca i64, align 8, 
addrspace(5) + %"65" = alloca i64, align 8, addrspace(5) + %"66" = alloca float, align 4, addrspace(5) + %"67" = alloca float, align 4, addrspace(5) + %"68" = alloca float, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"54" + +"54": ; preds = %1 + %"69" = load i64, ptr addrspace(4) %"62", align 8 + store i64 %"69", ptr addrspace(5) %"64", align 8 + %"70" = load i64, ptr addrspace(4) %"63", align 8 + store i64 %"70", ptr addrspace(5) %"65", align 8 + %"72" = load i64, ptr addrspace(5) %"64", align 8 + %"87" = inttoptr i64 %"72" to ptr + %"71" = load float, ptr %"87", align 4 + store float %"71", ptr addrspace(5) %"66", align 4 + %"73" = load i64, ptr addrspace(5) %"64", align 8 + %"88" = inttoptr i64 %"73" to ptr + %"32" = getelementptr inbounds i8, ptr %"88", i64 4 + %"74" = load float, ptr %"32", align 4 + store float %"74", ptr addrspace(5) %"67", align 4 + %"76" = load float, ptr addrspace(5) %"66", align 4 + %"77" = load float, ptr addrspace(5) %"67", align 4 + %"75" = fmul float %"76", %"77" + store float %"75", ptr addrspace(5) %"68", align 4 + call void @llvm.amdgcn.s.setreg(i32 6401, i32 3) + %"78" = load float, ptr addrspace(5) %"66", align 4 + %"79" = load float, ptr addrspace(5) %"67", align 4 + %2 = call %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float %"78", float %"79") + %"37" = extractvalue %struct.f32.f32.f32.i8 %2, 0 + %"38" = extractvalue %struct.f32.f32.f32.i8 %2, 1 + %"39" = extractvalue %struct.f32.f32.f32.i8 %2, 2 + %"40" = extractvalue %struct.f32.f32.f32.i8 %2, 3 + br label %"55" + +"55": ; preds = %"54" + %"81" = load float, ptr addrspace(5) %"66", align 4 + %"82" = load float, ptr addrspace(5) %"67", align 4 + %"80" = call float @__zluda_ptx_impl_div_f32_part2(float %"81", float %"82", float %"37", float %"38", float %"39", i8 %"40") + store float %"80", ptr addrspace(5) %"66", align 4 + br label %"56" + +"56": ; preds = %"55" + %"83" = load i64, ptr addrspace(5) %"65", align 8 + %"84" = load float, ptr addrspace(5) %"66", align 4 + %"89" = inttoptr i64 %"83" to ptr + store float %"84", ptr %"89", align 4 + %"85" = load i64, ptr addrspace(5) %"65", align 8 + %"90" = inttoptr i64 %"85" to ptr + %"34" = getelementptr inbounds i8, ptr %"90", i64 4 + %"86" = load float, ptr addrspace(5) %"68", align 4 + store float %"86", ptr %"34", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind willreturn +declare void @llvm.amdgcn.s.setreg(i32 immarg, i32) #2 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #2 = { nocallback nofree nosync nounwind willreturn } \ No newline at end of file diff --git a/ptx/src/test/ll/sqrt.ll b/ptx/src/test/ll/sqrt.ll index 4c7ce98..e8ec284 100644 --- a/ptx/src/test/ll/sqrt.ll +++ b/ptx/src/test/ll/sqrt.ll @@ -1,4 +1,6 @@ -define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 { +declare float @__zluda_ptx_impl_sqrt_approx_f32(float) #0 + +define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #1 { %"32" = alloca i64, align 8, addrspace(5) %"33" = alloca i64, align 8, addrspace(5) %"34" = alloca float, align 4, addrspace(5) @@ -17,7 +19,7 @@ define amdgpu_kernel void 
@sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace %"37" = load float, ptr %"43", align 4 store float %"37", ptr addrspace(5) %"34", align 4 %"40" = load float, ptr addrspace(5) %"34", align 4 - %"39" = call float @llvm.amdgcn.sqrt.f32(float %"40") + %"39" = call float @__zluda_ptx_impl_sqrt_approx_f32(float %"40") store float %"39", ptr addrspace(5) %"34", align 4 %"41" = load i64, ptr addrspace(5) %"33", align 8 %"42" = load float, ptr addrspace(5) %"34", align 4 @@ -26,8 +28,5 @@ define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace ret void } -; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) -declare float @llvm.amdgcn.sqrt.f32(float) #1 - -attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } -attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/sqrt_rn_ftz.ll b/ptx/src/test/ll/sqrt_rn_ftz.ll new file mode 100644 index 0000000..5881807 --- /dev/null +++ b/ptx/src/test/ll/sqrt_rn_ftz.ll @@ -0,0 +1,32 @@ +declare float @__zluda_ptx_impl_sqrt_rn_ftz_f32(float) #0 + +define amdgpu_kernel void @sqrt_rn_ftz(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #1 { + %"32" = alloca i64, align 8, addrspace(5) + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca float, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"29" + +"29": ; preds = %1 + %"35" = load i64, ptr addrspace(4) %"30", align 8 + store i64 %"35", ptr addrspace(5) %"32", align 8 + %"36" = load i64, ptr addrspace(4) %"31", align 8 + store i64 %"36", ptr addrspace(5) %"33", align 8 + %"38" = load i64, ptr addrspace(5) %"32", align 8 + %"43" = inttoptr i64 %"38" to ptr + %"37" = load float, ptr %"43", align 4 + store float %"37", ptr addrspace(5) %"34", align 4 + %"40" = load float, ptr addrspace(5) %"34", align 4 + %"39" = call float @__zluda_ptx_impl_sqrt_rn_ftz_f32(float %"40") + store float %"39", ptr addrspace(5) %"34", align 4 + %"41" = load i64, ptr addrspace(5) %"33", align 8 + %"42" = load float, ptr addrspace(5) %"34", align 4 + %"44" = inttoptr i64 %"41" to ptr + store float %"42", ptr %"44", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/div_ftz.ptx b/ptx/src/test/spirv_run/div_ftz.ptx new file mode 100644 index 0000000..3f547c5 --- /dev/null +++ b/ptx/src/test/spirv_run/div_ftz.ptx @@ -0,0 +1,27 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry div_ftz( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp1; + .reg .f32 temp2; + .reg .f32 
force_ftz_mode;
+
+    ld.param.u64 in_addr, [input];
+    ld.param.u64 out_addr, [output];
+
+    ld.f32 temp1, [in_addr];
+    ld.f32 temp2, [in_addr+4];
+    // DO NOT REMOVE THIS MULTIPLICATION
+    mul.f32 force_ftz_mode, temp1, temp2;
+    div.ftz.rn.f32 temp1, temp1, temp2;
+    st.f32 [out_addr], temp1;
+    st.f32 [out_addr+4], force_ftz_mode;
+    ret;
+}
diff --git a/ptx/src/test/spirv_run/div_noftz.ptx b/ptx/src/test/spirv_run/div_noftz.ptx
new file mode 100644
index 0000000..fa935fb
--- /dev/null
+++ b/ptx/src/test/spirv_run/div_noftz.ptx
@@ -0,0 +1,27 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry div_noftz(
+    .param .u64 input,
+    .param .u64 output
+)
+{
+    .reg .u64 in_addr;
+    .reg .u64 out_addr;
+    .reg .f32 temp1;
+    .reg .f32 temp2;
+    .reg .f32 force_ftz_mode;
+
+    ld.param.u64 in_addr, [input];
+    ld.param.u64 out_addr, [output];
+
+    ld.f32 temp1, [in_addr];
+    ld.f32 temp2, [in_addr+4];
+    // DO NOT REMOVE THIS MULTIPLICATION
+    mul.ftz.f32 force_ftz_mode, temp1, temp2;
+    div.rn.f32 temp1, temp1, temp2;
+    st.f32 [out_addr], temp1;
+    st.f32 [out_addr+4], force_ftz_mode;
+    ret;
+}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index e508a5c..ed760ed 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -173,6 +173,7 @@ test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
 test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
 test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
 test_ptx!(sqrt, [0.25f32], [0.5f32]);
+test_ptx!(sqrt_rn_ftz, [0x1u32], [0x0u32]);
 test_ptx!(rsqrt, [0.25f64], [2f64]);
 test_ptx!(neg, [181i32], [-181i32]);
 test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]);
@@ -279,6 +280,19 @@ test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
 test_ptx!(warp_sz, [0u8], [32u8]);
 test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
 test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]);
+// The two tests below test a very important compiler feature; make sure that
+// you fully understand what's going on before you touch them.
+// The problem is that the full-precision division gets legalized by LLVM
+// using a __module-level attribute__.
+// In the two tests below we deliberately force our compiler to emit a module
+// whose module-level denormal attribute differs from the denormal attribute
+// of the `div` instruction, to catch cases like this.
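+// The inputs are chosen so that the quotient (roughly 1.2 * 2^-144) is an f32
+// denormal: with FTZ on it flushes to 0x0, with FTZ off it survives as the
+// denormal bit pattern 0x26.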
+test_ptx!(div_ftz, [0x16A2028Du32, 0x5E89F6AE], [0x0, 900636404u32]);
+test_ptx!(
+    div_noftz,
+    [0x16A2028Du32, 0x5E89F6AE],
+    [0x26u32, 900636404u32]
+);
 test_ptx!(nanosleep, [0u64], [0u64]);
 test_ptx!(shf_l, [0x12345678u32, 0x9abcdef0u32, 12], [0xcdef0123u32]);
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index b2e15e5..ed5eb9d 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -1058,22 +1058,12 @@ impl From<ScalarType> for Type {
 #[derive(Clone)]
 pub struct MovDetails {
     pub typ: super::Type,
-    pub src_is_address: bool,
-    // two fields below are in use by member moves
-    pub dst_width: u8,
-    pub src_width: u8,
-    // This is in use by auto-generated movs
-    pub relaxed_src2_conv: bool,
 }

 impl MovDetails {
     pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
         MovDetails {
             typ: Type::maybe_vector(vector, scalar),
-            src_is_address: false,
-            dst_width: 0,
-            src_width: 0,
-            relaxed_src2_conv: false,
         }
     }
 }