diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 08532e3..d0e826c 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -1646,9 +1646,39 @@ impl<'a> MethodEmitContext<'a> { } }; let src = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_fn(self.builder, src, dst_type, dst) - }); + if let Some(src2) = arguments.src2 { + let packed_type = get_scalar_type( + self.context, + data.to + .packed_type() + .ok_or_else(|| error_mismatched_type())?, + ); + let src2 = self.resolver.value(src2)?; + self.resolver.with_result(arguments.dst, |dst| { + let vec = unsafe { + LLVMBuildInsertElement( + self.builder, + LLVMGetPoison(dst_type), + llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()), + LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32), + LLVM_UNNAMED.as_ptr(), + ) + }; + unsafe { + LLVMBuildInsertElement( + self.builder, + vec, + llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()), + LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32), + dst, + ) + } + }) + } else { + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }) + }; Ok(()) } diff --git a/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll b/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll new file mode 100644 index 0000000..1e19037 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll @@ -0,0 +1,41 @@ +define amdgpu_kernel void @cvt_rn_bf16x2_f32(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca float, align 4, addrspace(5) + %"42" = alloca float, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"44" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"44", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr 
addrspace(4) %"38", align 8 + store i64 %"45", ptr addrspace(5) %"40", align 8 + %"47" = load i64, ptr addrspace(5) %"39", align 8 + %"55" = inttoptr i64 %"47" to ptr + %"46" = load float, ptr %"55", align 4 + store float %"46", ptr addrspace(5) %"41", align 4 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"56" = inttoptr i64 %"48" to ptr + %"35" = getelementptr inbounds i8, ptr %"56", i64 4 + %"49" = load float, ptr %"35", align 4 + store float %"49", ptr addrspace(5) %"42", align 4 + %"51" = load float, ptr addrspace(5) %"41", align 4 + %"52" = load float, ptr addrspace(5) %"42", align 4 + %2 = fptrunc float %"51" to bfloat + %3 = insertelement <2 x bfloat> poison, bfloat %2, i32 1 + %4 = fptrunc float %"52" to bfloat + %"57" = insertelement <2 x bfloat> %3, bfloat %4, i32 0 + %"50" = bitcast <2 x bfloat> %"57" to i32 + store i32 %"50", ptr addrspace(5) %"43", align 4 + %"53" = load i64, ptr addrspace(5) %"40", align 8 + %"54" = load i32, ptr addrspace(5) %"43", align 4 + %"58" = inttoptr i64 %"53" to ptr + store i32 %"54", ptr %"58", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx new file mode 100644 index 0000000..2bad276 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_bf16x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.bf16x2.f32 result, in_a, in_b; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs 
b/ptx/src/test/spirv_run/mod.rs index e6d9a58..6e1b27e 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -200,6 +200,7 @@ test_ptx!( ); test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]); test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]); +test_ptx!(cvt_rn_bf16x2_f32, [0.40625, 12.9f32], [0x3ED0414Eu32]); test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); test_ptx!( diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 86719ef..37b5f6b 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1174,6 +1174,35 @@ impl ScalarType { ScalarType::Pred => ScalarKind::Pred, } } + + pub fn packed_type(&self) -> Option<ScalarType> { + match self { + ScalarType::E4m3x2 => Some(ScalarType::B8), + ScalarType::E5m2x2 => Some(ScalarType::B8), + ScalarType::F16x2 => Some(ScalarType::F16), + ScalarType::BF16x2 => Some(ScalarType::BF16), + ScalarType::U16x2 => Some(ScalarType::U16), + ScalarType::S16x2 => Some(ScalarType::S16), + ScalarType::S16 + | ScalarType::BF16 + | ScalarType::U32 + | ScalarType::S8 + | ScalarType::S32 + | ScalarType::Pred + | ScalarType::B8 + | ScalarType::U64 + | ScalarType::B16 + | ScalarType::S64 + | ScalarType::B32 + | ScalarType::U8 + | ScalarType::F32 + | ScalarType::B64 + | ScalarType::B128 + | ScalarType::U16 + | ScalarType::F64 + | ScalarType::F16 => None, + } + } } #[derive(Clone, Copy, PartialEq, Eq)] @@ -1945,8 +1974,13 @@ impl CvtDetails { (RoundingMode::NearestEven, false) } }; + let dst_size = if dst.packed_type().is_some() { + dst.size_of() / 2 + } else { + dst.size_of() + }; let mode = match (dst.kind(), src.kind()) { - (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + (ScalarKind::Float, ScalarKind::Float) => match dst_size.cmp(&src.size_of()) { Ordering::Less => { let (rounding, is_integer_rounding) = unwrap_rounding(); CvtMode::FPTruncate { diff --git 
a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 2701127..3dec840 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -2442,7 +2442,16 @@ derive_parser!( // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; - // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b => { + if relu || satfinite { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(frnd2), false, false, ScalarType::BF16x2, ScalarType::F32); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) } + } + } // cvt.rna{.satfinite}.tf32.f32 d, a; // cvt.frnd2{.relu}.tf32.f32 d, a; cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {