diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 853b9e1..440b232 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 9de6f61..42fa23e 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -1,12 +1,13 @@ // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining -// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc +// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc #include #include #include #include #include +#include #define CONSTANT_SPACE __attribute__((address_space(4))) @@ -476,4 +477,46 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); { return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag}); } + + __device__ static __hip_fp8_storage_t cvt_float_to_fp8(float f, __hip_fp8_interpretation_t interp) + { + const uint32_t bits = reinterpret_cast(f); + const uint8_t sign = (bits & 0x80000000) ? 0x80 : 0x0; + const uint32_t abs = bits & 0x7fffffff; + + const uint32_t min = interp == __HIP_E4M3 ? 0x3A800000 : 0x37000000; + if (abs < min) + { + return sign; // +/- 0 + } + + return __hip_cvt_float_to_fp8(f, __HIP_SATFINITE, interp); + } + + struct Fp8x2 + { + __hip_fp8_storage_t b : 8; + __hip_fp8_storage_t a : 8; + }; + + Fp8x2 FUNC(cvt_rn_satfinite_e4m3x2_f32)(float a, float b) + { + // If built-in support for fp8 formats is added to LLVM IR we should switch to use that. + return {cvt_float_to_fp8(b, __HIP_E4M3), cvt_float_to_fp8(a, __HIP_E4M3)}; + } + + Fp8x2 FUNC(cvt_rn_satfinite_e5m2x2_f32)(float a, float b) + { + return {cvt_float_to_fp8(b, __HIP_E5M2), cvt_float_to_fp8(a, __HIP_E5M2)}; + } + + __half2 FUNC(cvt_rn_f16x2_e4m3x2)(__hip_fp8x2_e4m3 in) + { + return in; + } + + __half2 FUNC(cvt_rn_f16x2_e5m2x2)(__hip_fp8x2_e5m2 in) + { + return in; + } } diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index f6b8ca0..177efa6 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -299,7 +299,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }, ptx_parser::ScalarType::S16 | ptx_parser::ScalarType::B16 - | ptx_parser::ScalarType::U16 => unsafe { + | ptx_parser::ScalarType::U16 + | ptx_parser::ScalarType::E4m3x2 + | ptx_parser::ScalarType::E5m2x2 => unsafe { LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) }, ptx_parser::ScalarType::S32 @@ -1586,6 +1588,26 @@ impl<'a> MethodEmitContext<'a> { data: ptx_parser::CvtDetails, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> { + // Truncating conversions to FP8 types should be replaced by a function call. + match data { + ptx_parser::CvtDetails { + to: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2, + mode: ast::CvtMode::FPTruncate { .. }, + .. + } => return Err(error_unreachable()), + _ => {} + } + + // Extending conversions from FP8 types should be replaced by a function call. + match data { + ptx_parser::CvtDetails { + from: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2, + mode: ast::CvtMode::FPExtend { .. }, + .. + } => return Err(error_unreachable()), + _ => {} + } + let dst_type = get_scalar_type(self.context, data.to); let llvm_fn = match data.mode { ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, @@ -3096,7 +3118,11 @@ impl std::fmt::Display for LLVMTypeDisplay { match self.0 { ast::ScalarType::Pred => write!(f, "i1"), ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"), - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"), + ast::ScalarType::B16 + | ast::ScalarType::U16 + | ast::ScalarType::S16 + | ast::ScalarType::E4m3x2 + | ast::ScalarType::E5m2x2 => write!(f, "i16"), ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"), ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"), ptx_parser::ScalarType::B128 => write!(f, "i128"), diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 3513e88..24f790e 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -153,9 +153,11 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { LLVMInt8TypeInContext(context) }, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { - LLVMInt16TypeInContext(context) - }, + ast::ScalarType::B16 + | ast::ScalarType::U16 + | ast::ScalarType::S16 + | ast::ScalarType::E4m3x2 + | ast::ScalarType::E5m2x2 => unsafe { LLVMInt16TypeInContext(context) }, ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { LLVMInt32TypeInContext(context) }, @@ -169,7 +171,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, ast::ScalarType::U16x2 => todo!(), ast::ScalarType::S16x2 => todo!(), - ast::ScalarType::F16x2 => todo!(), + ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => todo!(), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 79f5e99..d5edccf 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -995,6 +995,8 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { ast::ScalarType::BF16 => "bf16", ast::ScalarType::BF16x2 => "bf16x2", ast::ScalarType::Pred => "pred", + ast::ScalarType::E4m3x2 => "e4m3x2", + ast::ScalarType::E5m2x2 => "e5m2x2", } } diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index edcaaa1..4b9b2bb 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -57,6 +57,38 @@ fn run_directive<'input>( }) } +fn get_or_declare_function<'input, S: Into>>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut HashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + rustc_hash::FxBuildHasher, + >, + name: S, + return_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, + input_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, +) -> SpirvWord { + let func = match fn_declarations.entry(name.into()) { + hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, + hash_map::Entry::Vacant(vacant_entry) => { + let name = vacant_entry.key().clone(); + let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); + let name = resolver.register_named(Cow::Owned(full_name.clone()), None); + vacant_entry.insert(( + to_variables(resolver, return_arguments), + name, + to_variables(resolver, input_arguments), + )); + name + } + }; + func +} + fn run_statements<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, fn_declarations: &mut FxHashMap< @@ -99,7 +131,7 @@ fn run_statements<'input>( ast::Type::Scalar(ast::ScalarType::U32), ptx_parser::StateSpace::Reg, ))); - let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat(); + let name = ["shfl_sync_", mode, "_b32_pred"].concat(); let return_arguments = vec![( ast::Type::Vector(2, ast::ScalarType::U32), ptx_parser::StateSpace::Reg, @@ -122,45 +154,19 @@ fn run_statements<'input>( ptx_parser::StateSpace::Reg, ), ]; - let func = match fn_declarations.entry(full_name.into()) { - hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, - hash_map::Entry::Vacant(vacant_entry) => { - let name = vacant_entry.key().clone(); - let name = resolver.register_named(name, None); - vacant_entry.insert(( - to_variables(resolver, &return_arguments), - name, - to_variables(resolver, &input_arguments), - )); - name - } - }; + let func = get_or_declare_function( + resolver, + fn_declarations, + name, + &return_arguments, + &input_arguments, + ); smallvec![ Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call { data: ptx_parser::CallDetails { uniform: false, - return_arguments: vec![( - ast::Type::Vector(2, ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - )], - input_arguments: vec![ - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ( - ast::Type::Scalar(ast::ScalarType::U32), - ptx_parser::StateSpace::Reg, - ), - ], + return_arguments, + input_arguments }, arguments: ptx_parser::CallArgs { return_arguments: vec![packed_var], @@ -184,6 +190,73 @@ fn run_statements<'input>( arguments: ast::CvtArgs { dst: dst_pred, src: dst_pred_wide, + src2: None, + }, + }) + ] + } + Statement::Instruction(ast::Instruction::Cvt { + data: + ast::CvtDetails { + from: from @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2), + to: ast::ScalarType::F16x2, + mode: _, + }, + arguments: + ast::CvtArgs { + dst, + src, + src2: None, + }, + }) => { + let from_str = match from { + ast::ScalarType::E4m3x2 => "e4m3x2", + ast::ScalarType::E5m2x2 => "e5m2x2", + _ => unreachable!(), + }; + let packed_output = resolver.register_unnamed(Some(( + ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + ))); + let name = format!("cvt_rn_f16x2_{}", from_str); + let return_arguments = vec![( + ast::Type::Scalar(ast::ScalarType::B32), + ast::StateSpace::Reg, + )]; + let input_arguments = vec![( + ast::Type::Scalar(ast::ScalarType::B16), + ast::StateSpace::Reg, + )]; + let func = get_or_declare_function( + resolver, + fn_declarations, + name, + &return_arguments, + &input_arguments, + ); + smallvec![ + Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call { + data: ptx_parser::CallDetails { + uniform: false, + return_arguments, + input_arguments, + }, + arguments: ptx_parser::CallArgs { + return_arguments: vec![packed_output], + func, + input_arguments: vec![src], + }, + }), + Statement::Instruction(ast::Instruction::Cvt { + data: ast::CvtDetails { + from: ast::ScalarType::B32, + to: ast::ScalarType::F16x2, + mode: ast::CvtMode::Bitcast + }, + arguments: ast::CvtArgs { + dst, + src: packed_output, + src2: None, }, }) ] @@ -335,6 +408,29 @@ fn run_instruction<'input>( i @ ptx_parser::Instruction::Nanosleep { .. } => { to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)? } + i @ ptx_parser::Instruction::Cvt { + data: + ptx_parser::CvtDetails { + from: ast::ScalarType::F32, + to: to @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2), + mode: _, + }, + arguments: _, + } => { + let to = match to { + ptx_parser::ScalarType::E4m3x2 => "e4m3x2", + ptx_parser::ScalarType::E5m2x2 => "e5m2x2", + _ => unreachable!(), + }; + // Conversions from f32 to f8 must have two source arguments. + // satfinite is mandatory for conversions to e4m3x2. + to_call( + resolver, + fn_declarations, + format!("cvt_rn_satfinite_{}_f32", to).into(), + i, + )? + } i => i, }) } @@ -373,20 +469,8 @@ fn to_call<'input>( }; Ok::<_, TranslateError>(()) })?; - let fn_name = match fn_declarations.entry(name) { - hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, - hash_map::Entry::Vacant(vacant_entry) => { - let name = vacant_entry.key().clone(); - let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); - let name = resolver.register_named(Cow::Owned(full_name.clone()), None); - vacant_entry.insert(( - to_variables(resolver, &data_return), - name, - to_variables(resolver, &data_input), - )); - name - } - }; + let fn_name = + get_or_declare_function(resolver, fn_declarations, name, &data_return, &data_input); Ok(ast::Instruction::Call { data: ptx_parser::CallDetails { uniform: false, diff --git a/ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll b/ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll new file mode 100644 index 0000000..ffa7ecf --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll @@ -0,0 +1,35 @@ +declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16) #0 + +define amdgpu_kernel void @cvt_rn_f16x2_e4m3x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i16, align 2, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 8 + store i64 %"37", ptr addrspace(5) %"33", align 8 + %"38" = load i64, ptr addrspace(4) %"32", align 8 + store i64 %"38", ptr addrspace(5) %"34", align 8 + %"40" = load i64, ptr addrspace(5) %"33", align 8 + %"45" = inttoptr i64 %"40" to ptr + %"39" = load i16, ptr %"45", align 2 + store i16 %"39", ptr addrspace(5) %"35", align 2 + %"42" = load i16, ptr addrspace(5) %"35", align 2 + %"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16 %"42") + %"46" = bitcast i32 %"49" to <2 x half> + %"41" = bitcast <2 x half> %"46" to i32 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"43" = load i64, ptr addrspace(5) %"34", align 8 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"48" = inttoptr i64 %"43" to ptr + store i32 %"44", ptr %"48", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll b/ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll new file mode 100644 index 0000000..d63c684 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll @@ -0,0 +1,35 @@ +declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16) #0 + +define amdgpu_kernel void @cvt_rn_f16x2_e5m2x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 { + %"33" = alloca i64, align 8, addrspace(5) + %"34" = alloca i64, align 8, addrspace(5) + %"35" = alloca i16, align 2, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"30" + +"30": ; preds = %1 + %"37" = load i64, ptr addrspace(4) %"31", align 8 + store i64 %"37", ptr addrspace(5) %"33", align 8 + %"38" = load i64, ptr addrspace(4) %"32", align 8 + store i64 %"38", ptr addrspace(5) %"34", align 8 + %"40" = load i64, ptr addrspace(5) %"33", align 8 + %"45" = inttoptr i64 %"40" to ptr + %"39" = load i16, ptr %"45", align 2 + store i16 %"39", ptr addrspace(5) %"35", align 2 + %"42" = load i16, ptr addrspace(5) %"35", align 2 + %"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16 %"42") + %"46" = bitcast i32 %"49" to <2 x half> + %"41" = bitcast <2 x half> %"46" to i32 + store i32 %"41", ptr addrspace(5) %"36", align 4 + %"43" = load i64, ptr addrspace(5) %"34", align 8 + %"44" = load i32, ptr addrspace(5) %"36", align 4 + %"48" = inttoptr i64 %"43" to ptr + store i32 %"44", ptr %"48", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll b/ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll new file mode 100644 index 0000000..eaa932a --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll @@ -0,0 +1,40 @@ +declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float, float) #0 + +define amdgpu_kernel void @cvt_rn_satfinite_e4m3x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 { + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + %"38" = alloca float, align 4, addrspace(5) + %"39" = alloca float, align 4, addrspace(5) + %"40" = alloca i16, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"33" + +"33": ; preds = %1 + %"41" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"41", ptr addrspace(5) %"36", align 8 + %"42" = load i64, ptr addrspace(4) %"35", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"44" = load i64, ptr addrspace(5) %"36", align 8 + %"52" = inttoptr i64 %"44" to ptr + %"43" = load float, ptr %"52", align 4 + store float %"43", ptr addrspace(5) %"38", align 4 + %"45" = load i64, ptr addrspace(5) %"36", align 8 + %"53" = inttoptr i64 %"45" to ptr + %"32" = getelementptr inbounds i8, ptr %"53", i64 4 + %"46" = load float, ptr %"32", align 4 + store float %"46", ptr addrspace(5) %"39", align 4 + %"48" = load float, ptr addrspace(5) %"38", align 4 + %"49" = load float, ptr addrspace(5) %"39", align 4 + %"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float %"48", float %"49") + store i16 %"54", ptr addrspace(5) %"40", align 2 + %"50" = load i64, ptr addrspace(5) %"37", align 8 + %"51" = load i16, ptr addrspace(5) %"40", align 2 + %"55" = inttoptr i64 %"50" to ptr + store i16 %"51", ptr %"55", align 2 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll b/ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll new file mode 100644 index 0000000..bec74d3 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll @@ -0,0 +1,40 @@ +declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float, float) #0 + +define amdgpu_kernel void @cvt_rn_satfinite_e5m2x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 { + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + %"38" = alloca float, align 4, addrspace(5) + %"39" = alloca float, align 4, addrspace(5) + %"40" = alloca i16, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"33" + +"33": ; preds = %1 + %"41" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"41", ptr addrspace(5) %"36", align 8 + %"42" = load i64, ptr addrspace(4) %"35", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"44" = load i64, ptr addrspace(5) %"36", align 8 + %"52" = inttoptr i64 %"44" to ptr + %"43" = load float, ptr %"52", align 4 + store float %"43", ptr addrspace(5) %"38", align 4 + %"45" = load i64, ptr addrspace(5) %"36", align 8 + %"53" = inttoptr i64 %"45" to ptr + %"32" = getelementptr inbounds i8, ptr %"53", i64 4 + %"46" = load float, ptr %"32", align 4 + store float %"46", ptr addrspace(5) %"39", align 4 + %"48" = load float, ptr addrspace(5) %"38", align 4 + %"49" = load float, ptr addrspace(5) %"39", align 4 + %"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float %"48", float %"49") + store i16 %"54", ptr addrspace(5) %"40", align 2 + %"50" = load i64, ptr addrspace(5) %"37", align 8 + %"51" = load i16, ptr addrspace(5) %"40", align 2 + %"55" = inttoptr i64 %"50" to ptr + store i16 %"51", ptr %"55", align 2 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx b/ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx new file mode 100644 index 0000000..946c498 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx @@ -0,0 +1,23 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_f16x2_e4m3x2( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b16 in; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b16 in, [in_addr]; + + cvt.rn.f16x2.e4m3x2 result, in; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx b/ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx new file mode 100644 index 0000000..7dcaee0 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx @@ -0,0 +1,23 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_f16x2_e5m2x2( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b16 in; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b16 in, [in_addr]; + + cvt.rn.f16x2.e5m2x2 result, in; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx new file mode 100644 index 0000000..8a470cf --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_satfinite_e4m3x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b16 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.satfinite.e4m3x2.f32 result, in_a, in_b; + st.b16 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx new file mode 100644 index 0000000..0e5dc8e --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_satfinite_e5m2x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b16 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.satfinite.e5m2x2.f32 result, in_a, in_b; + st.b16 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index ca412be..28a8112 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -185,6 +185,14 @@ test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]); test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]); test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]); +test_ptx!(cvt_rn_satfinite_e4m3x2_f32, [0.40625, 12.9f32], [0x2D55u16]); +test_ptx!( + cvt_rn_satfinite_e5m2x2_f32, + [0.375, -5256.6f32], + [0x36EDu16] +); +test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]); +test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]); test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); test_ptx!( diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f198795..bee81ac 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -233,6 +233,11 @@ ptx_parser_macros::generate_instruction_type!( type: { Type::Scalar(data.from) }, relaxed_type_check: true, }, + src2: { + repr: Option, + type: { Type::Scalar(data.from) }, + relaxed_type_check: true, + }, } }, Cvta { @@ -1047,7 +1052,9 @@ impl ScalarType { | ScalarType::S16 | ScalarType::B16 | ScalarType::F16 - | ScalarType::BF16 => 2, + | ScalarType::BF16 + | ScalarType::E4m3x2 + | ScalarType::E5m2x2 => 2, ScalarType::U32 | ScalarType::S32 | ScalarType::B32 @@ -1069,7 +1076,9 @@ impl ScalarType { | ScalarType::S16 | ScalarType::B16 | ScalarType::F16 - | ScalarType::BF16 => Layout::new::(), + | ScalarType::BF16 + | ScalarType::E4m3x2 + | ScalarType::E5m2x2 => Layout::new::(), ScalarType::U32 | ScalarType::S32 | ScalarType::B32 @@ -1110,6 +1119,8 @@ impl ScalarType { ScalarType::F64 => ScalarKind::Float, ScalarType::BF16 => ScalarKind::Float, ScalarType::BF16x2 => ScalarKind::Float, + ScalarType::E4m3x2 => ScalarKind::Float, + ScalarType::E5m2x2 => ScalarKind::Float, ScalarType::Pred => ScalarKind::Pred, } } @@ -1884,7 +1895,9 @@ impl CvtDetails { saturate, }, Ordering::Greater => { - if rounding.is_some() { + if rounding.is_some() + && !(src == ScalarType::E4m3x2 || src == ScalarType::E5m2x2) + { errors.push(PtxError::SyntaxError( "should not have rounding mode when dst is larger than src in cvt" .to_string(), diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 9c08f95..ea458f6 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -2370,7 +2370,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); - let arguments = ast::CvtArgs { dst: d, src: a }; + let arguments = ast::CvtArgs { dst: d, src: a, src2: None }; ast::Instruction::Cvt { data, arguments } @@ -2381,18 +2381,38 @@ derive_parser!( // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; // cvt.rna{.satfinite}.tf32.f32 d, a; // cvt.frnd2{.relu}.tf32.f32 d, a; - // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => { + if relu { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, f8x2type, ScalarType::F32); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) } + } + } // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; - // cvt.rn.{.relu}.f16x2.f8x2type d, a; + cvt.rn{.relu}.f16x2.f8x2type d, a => { + if relu { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, ScalarType::F16x2, f8x2type); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: None } + } + } .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; .frnd2: RawRoundingMode = { .rn, .rz }; + RawRoundingMode = { .rn }; .dtype: ScalarType = { .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .bf16, .f16, .f32, .f64 }; .atype: ScalarType = { .u8, .u16, .u32, .u64, .s8, .s16, .s32, .s64, .bf16, .f16, .f32, .f64 }; + .f8x2type: ScalarType = { .e4m3x2, .e5m2x2 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl shl.type d, a, b => { ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }