Add support for fp8 to cvt (#468)

This implements specifically the fp8 conversion instructions needed by llm.c:

* `cvt.rn.satfinite{.relu}.f8x2type.f32`
* `cvt.rn{.relu}.f16x2.f8x2type`

It uses HIP's fp8 and fp16 headers: https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/low_fp_types.html#fp8-quarter-precision.
This commit is contained in:
Violet 2025-08-28 17:54:07 -07:00 committed by GitHub
commit 8f484d6a5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 507 additions and 63 deletions

Binary file not shown.

View file

@ -1,12 +1,13 @@
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
#include <cstddef>
#include <cstdint>
#include <bit>
#include <cmath>
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/hip_fp8.h>
#define CONSTANT_SPACE __attribute__((address_space(4)))
@ -476,4 +477,46 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
{
return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag});
}
__device__ static __hip_fp8_storage_t cvt_float_to_fp8(float f, __hip_fp8_interpretation_t interp)
{
const uint32_t bits = reinterpret_cast<uint32_t &>(f);
const uint8_t sign = (bits & 0x80000000) ? 0x80 : 0x0;
const uint32_t abs = bits & 0x7fffffff;
const uint32_t min = interp == __HIP_E4M3 ? 0x3A800000 : 0x37000000;
if (abs < min)
{
return sign; // +/- 0
}
return __hip_cvt_float_to_fp8(f, __HIP_SATFINITE, interp);
}
struct Fp8x2
{
__hip_fp8_storage_t b : 8;
__hip_fp8_storage_t a : 8;
};
Fp8x2 FUNC(cvt_rn_satfinite_e4m3x2_f32)(float a, float b)
{
// If built-in support for fp8 formats is added to LLVM IR we should switch to use that.
return {cvt_float_to_fp8(b, __HIP_E4M3), cvt_float_to_fp8(a, __HIP_E4M3)};
}
Fp8x2 FUNC(cvt_rn_satfinite_e5m2x2_f32)(float a, float b)
{
return {cvt_float_to_fp8(b, __HIP_E5M2), cvt_float_to_fp8(a, __HIP_E5M2)};
}
__half2 FUNC(cvt_rn_f16x2_e4m3x2)(__hip_fp8x2_e4m3 in)
{
return in;
}
__half2 FUNC(cvt_rn_f16x2_e5m2x2)(__hip_fp8x2_e5m2 in)
{
return in;
}
}

View file

@ -299,7 +299,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
},
ptx_parser::ScalarType::S16
| ptx_parser::ScalarType::B16
| ptx_parser::ScalarType::U16 => unsafe {
| ptx_parser::ScalarType::U16
| ptx_parser::ScalarType::E4m3x2
| ptx_parser::ScalarType::E5m2x2 => unsafe {
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::S32
@ -1586,6 +1588,26 @@ impl<'a> MethodEmitContext<'a> {
data: ptx_parser::CvtDetails,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
// Truncating conversions to FP8 types should be replaced by a function call.
match data {
ptx_parser::CvtDetails {
to: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2,
mode: ast::CvtMode::FPTruncate { .. },
..
} => return Err(error_unreachable()),
_ => {}
}
// Extending conversions from FP8 types should be replaced by a function call.
match data {
ptx_parser::CvtDetails {
from: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2,
mode: ast::CvtMode::FPExtend { .. },
..
} => return Err(error_unreachable()),
_ => {}
}
let dst_type = get_scalar_type(self.context, data.to);
let llvm_fn = match data.mode {
ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
@ -3096,7 +3118,11 @@ impl std::fmt::Display for LLVMTypeDisplay {
match self.0 {
ast::ScalarType::Pred => write!(f, "i1"),
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
ast::ScalarType::B16
| ast::ScalarType::U16
| ast::ScalarType::S16
| ast::ScalarType::E4m3x2
| ast::ScalarType::E5m2x2 => write!(f, "i16"),
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
ptx_parser::ScalarType::B128 => write!(f, "i128"),

View file

@ -153,9 +153,11 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
LLVMInt8TypeInContext(context)
},
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
LLVMInt16TypeInContext(context)
},
ast::ScalarType::B16
| ast::ScalarType::U16
| ast::ScalarType::S16
| ast::ScalarType::E4m3x2
| ast::ScalarType::E5m2x2 => unsafe { LLVMInt16TypeInContext(context) },
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
LLVMInt32TypeInContext(context)
},
@ -169,7 +171,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
ast::ScalarType::F16x2 => todo!(),
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
ast::ScalarType::BF16x2 => todo!(),
}
}

View file

@ -995,6 +995,8 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
ast::ScalarType::BF16 => "bf16",
ast::ScalarType::BF16x2 => "bf16x2",
ast::ScalarType::Pred => "pred",
ast::ScalarType::E4m3x2 => "e4m3x2",
ast::ScalarType::E5m2x2 => "e5m2x2",
}
}

View file

@ -57,6 +57,38 @@ fn run_directive<'input>(
})
}
fn get_or_declare_function<'input, S: Into<Cow<'input, str>>>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut HashMap<
Cow<'input, str>,
(
Vec<ptx_parser::Variable<SpirvWord>>,
SpirvWord,
Vec<ptx_parser::Variable<SpirvWord>>,
),
rustc_hash::FxBuildHasher,
>,
name: S,
return_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
input_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
) -> SpirvWord {
let func = match fn_declarations.entry(name.into()) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
vacant_entry.insert((
to_variables(resolver, return_arguments),
name,
to_variables(resolver, input_arguments),
));
name
}
};
func
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
@ -99,7 +131,7 @@ fn run_statements<'input>(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
)));
let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat();
let name = ["shfl_sync_", mode, "_b32_pred"].concat();
let return_arguments = vec![(
ast::Type::Vector(2, ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
@ -122,45 +154,19 @@ fn run_statements<'input>(
ptx_parser::StateSpace::Reg,
),
];
let func = match fn_declarations.entry(full_name.into()) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let name = resolver.register_named(name, None);
vacant_entry.insert((
to_variables(resolver, &return_arguments),
name,
to_variables(resolver, &input_arguments),
));
name
}
};
let func = get_or_declare_function(
resolver,
fn_declarations,
name,
&return_arguments,
&input_arguments,
);
smallvec![
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,
return_arguments: vec![(
ast::Type::Vector(2, ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
)],
input_arguments: vec![
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
],
return_arguments,
input_arguments
},
arguments: ptx_parser::CallArgs {
return_arguments: vec![packed_var],
@ -184,6 +190,73 @@ fn run_statements<'input>(
arguments: ast::CvtArgs {
dst: dst_pred,
src: dst_pred_wide,
src2: None,
},
})
]
}
Statement::Instruction(ast::Instruction::Cvt {
data:
ast::CvtDetails {
from: from @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2),
to: ast::ScalarType::F16x2,
mode: _,
},
arguments:
ast::CvtArgs {
dst,
src,
src2: None,
},
}) => {
let from_str = match from {
ast::ScalarType::E4m3x2 => "e4m3x2",
ast::ScalarType::E5m2x2 => "e5m2x2",
_ => unreachable!(),
};
let packed_output = resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::B32),
ast::StateSpace::Reg,
)));
let name = format!("cvt_rn_f16x2_{}", from_str);
let return_arguments = vec![(
ast::Type::Scalar(ast::ScalarType::B32),
ast::StateSpace::Reg,
)];
let input_arguments = vec![(
ast::Type::Scalar(ast::ScalarType::B16),
ast::StateSpace::Reg,
)];
let func = get_or_declare_function(
resolver,
fn_declarations,
name,
&return_arguments,
&input_arguments,
);
smallvec![
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,
return_arguments,
input_arguments,
},
arguments: ptx_parser::CallArgs {
return_arguments: vec![packed_output],
func,
input_arguments: vec![src],
},
}),
Statement::Instruction(ast::Instruction::Cvt {
data: ast::CvtDetails {
from: ast::ScalarType::B32,
to: ast::ScalarType::F16x2,
mode: ast::CvtMode::Bitcast
},
arguments: ast::CvtArgs {
dst,
src: packed_output,
src2: None,
},
})
]
@ -335,6 +408,29 @@ fn run_instruction<'input>(
i @ ptx_parser::Instruction::Nanosleep { .. } => {
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
}
i @ ptx_parser::Instruction::Cvt {
data:
ptx_parser::CvtDetails {
from: ast::ScalarType::F32,
to: to @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2),
mode: _,
},
arguments: _,
} => {
let to = match to {
ptx_parser::ScalarType::E4m3x2 => "e4m3x2",
ptx_parser::ScalarType::E5m2x2 => "e5m2x2",
_ => unreachable!(),
};
// Conversions from f32 to f8 must have two source arguments.
// satfinite is mandatory for conversions to e4m3x2.
to_call(
resolver,
fn_declarations,
format!("cvt_rn_satfinite_{}_f32", to).into(),
i,
)?
}
i => i,
})
}
@ -373,20 +469,8 @@ fn to_call<'input>(
};
Ok::<_, TranslateError>(())
})?;
let fn_name = match fn_declarations.entry(name) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
vacant_entry.insert((
to_variables(resolver, &data_return),
name,
to_variables(resolver, &data_input),
));
name
}
};
let fn_name =
get_or_declare_function(resolver, fn_declarations, name, &data_return, &data_input);
Ok(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,

View file

@ -0,0 +1,35 @@
declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16) #0
define amdgpu_kernel void @cvt_rn_f16x2_e4m3x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 {
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i16, align 2, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
%"37" = load i64, ptr addrspace(4) %"31", align 8
store i64 %"37", ptr addrspace(5) %"33", align 8
%"38" = load i64, ptr addrspace(4) %"32", align 8
store i64 %"38", ptr addrspace(5) %"34", align 8
%"40" = load i64, ptr addrspace(5) %"33", align 8
%"45" = inttoptr i64 %"40" to ptr
%"39" = load i16, ptr %"45", align 2
store i16 %"39", ptr addrspace(5) %"35", align 2
%"42" = load i16, ptr addrspace(5) %"35", align 2
%"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16 %"42")
%"46" = bitcast i32 %"49" to <2 x half>
%"41" = bitcast <2 x half> %"46" to i32
store i32 %"41", ptr addrspace(5) %"36", align 4
%"43" = load i64, ptr addrspace(5) %"34", align 8
%"44" = load i32, ptr addrspace(5) %"36", align 4
%"48" = inttoptr i64 %"43" to ptr
store i32 %"44", ptr %"48", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,35 @@
declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16) #0
define amdgpu_kernel void @cvt_rn_f16x2_e5m2x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 {
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca i64, align 8, addrspace(5)
%"35" = alloca i16, align 2, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"30"
"30": ; preds = %1
%"37" = load i64, ptr addrspace(4) %"31", align 8
store i64 %"37", ptr addrspace(5) %"33", align 8
%"38" = load i64, ptr addrspace(4) %"32", align 8
store i64 %"38", ptr addrspace(5) %"34", align 8
%"40" = load i64, ptr addrspace(5) %"33", align 8
%"45" = inttoptr i64 %"40" to ptr
%"39" = load i16, ptr %"45", align 2
store i16 %"39", ptr addrspace(5) %"35", align 2
%"42" = load i16, ptr addrspace(5) %"35", align 2
%"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16 %"42")
%"46" = bitcast i32 %"49" to <2 x half>
%"41" = bitcast <2 x half> %"46" to i32
store i32 %"41", ptr addrspace(5) %"36", align 4
%"43" = load i64, ptr addrspace(5) %"34", align 8
%"44" = load i32, ptr addrspace(5) %"36", align 4
%"48" = inttoptr i64 %"43" to ptr
store i32 %"44", ptr %"48", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,40 @@
declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float, float) #0
define amdgpu_kernel void @cvt_rn_satfinite_e4m3x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 {
%"36" = alloca i64, align 8, addrspace(5)
%"37" = alloca i64, align 8, addrspace(5)
%"38" = alloca float, align 4, addrspace(5)
%"39" = alloca float, align 4, addrspace(5)
%"40" = alloca i16, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"33"
"33": ; preds = %1
%"41" = load i64, ptr addrspace(4) %"34", align 8
store i64 %"41", ptr addrspace(5) %"36", align 8
%"42" = load i64, ptr addrspace(4) %"35", align 8
store i64 %"42", ptr addrspace(5) %"37", align 8
%"44" = load i64, ptr addrspace(5) %"36", align 8
%"52" = inttoptr i64 %"44" to ptr
%"43" = load float, ptr %"52", align 4
store float %"43", ptr addrspace(5) %"38", align 4
%"45" = load i64, ptr addrspace(5) %"36", align 8
%"53" = inttoptr i64 %"45" to ptr
%"32" = getelementptr inbounds i8, ptr %"53", i64 4
%"46" = load float, ptr %"32", align 4
store float %"46", ptr addrspace(5) %"39", align 4
%"48" = load float, ptr addrspace(5) %"38", align 4
%"49" = load float, ptr addrspace(5) %"39", align 4
%"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float %"48", float %"49")
store i16 %"54", ptr addrspace(5) %"40", align 2
%"50" = load i64, ptr addrspace(5) %"37", align 8
%"51" = load i16, ptr addrspace(5) %"40", align 2
%"55" = inttoptr i64 %"50" to ptr
store i16 %"51", ptr %"55", align 2
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,40 @@
declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float, float) #0
define amdgpu_kernel void @cvt_rn_satfinite_e5m2x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 {
%"36" = alloca i64, align 8, addrspace(5)
%"37" = alloca i64, align 8, addrspace(5)
%"38" = alloca float, align 4, addrspace(5)
%"39" = alloca float, align 4, addrspace(5)
%"40" = alloca i16, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"33"
"33": ; preds = %1
%"41" = load i64, ptr addrspace(4) %"34", align 8
store i64 %"41", ptr addrspace(5) %"36", align 8
%"42" = load i64, ptr addrspace(4) %"35", align 8
store i64 %"42", ptr addrspace(5) %"37", align 8
%"44" = load i64, ptr addrspace(5) %"36", align 8
%"52" = inttoptr i64 %"44" to ptr
%"43" = load float, ptr %"52", align 4
store float %"43", ptr addrspace(5) %"38", align 4
%"45" = load i64, ptr addrspace(5) %"36", align 8
%"53" = inttoptr i64 %"45" to ptr
%"32" = getelementptr inbounds i8, ptr %"53", i64 4
%"46" = load float, ptr %"32", align 4
store float %"46", ptr addrspace(5) %"39", align 4
%"48" = load float, ptr addrspace(5) %"38", align 4
%"49" = load float, ptr addrspace(5) %"39", align 4
%"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float %"48", float %"49")
store i16 %"54", ptr addrspace(5) %"40", align 2
%"50" = load i64, ptr addrspace(5) %"37", align 8
%"51" = load i16, ptr addrspace(5) %"40", align 2
%"55" = inttoptr i64 %"50" to ptr
store i16 %"51", ptr %"55", align 2
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,23 @@
.version 7.8
.target sm_90
.address_size 64
.visible .entry cvt_rn_f16x2_e4m3x2(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b16 in;
.reg .b32 result;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.b16 in, [in_addr];
cvt.rn.f16x2.e4m3x2 result, in;
st.b32 [out_addr], result;
ret;
}

View file

@ -0,0 +1,23 @@
.version 7.8
.target sm_90
.address_size 64
.visible .entry cvt_rn_f16x2_e5m2x2(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b16 in;
.reg .b32 result;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.b16 in, [in_addr];
cvt.rn.f16x2.e5m2x2 result, in;
st.b32 [out_addr], result;
ret;
}

View file

@ -0,0 +1,25 @@
.version 7.8
.target sm_90
.address_size 64
.visible .entry cvt_rn_satfinite_e4m3x2_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 in_a;
.reg .f32 in_b;
.reg .b16 result;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 in_a, [in_addr];
ld.f32 in_b, [in_addr + 4];
cvt.rn.satfinite.e4m3x2.f32 result, in_a, in_b;
st.b16 [out_addr], result;
ret;
}

View file

@ -0,0 +1,25 @@
.version 7.8
.target sm_90
.address_size 64
.visible .entry cvt_rn_satfinite_e5m2x2_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 in_a;
.reg .f32 in_b;
.reg .b16 result;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 in_a, [in_addr];
ld.f32 in_b, [in_addr + 4];
cvt.rn.satfinite.e5m2x2.f32 result, in_a, in_b;
st.b16 [out_addr], result;
ret;
}

View file

@ -185,6 +185,14 @@ test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);
test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]);
test_ptx!(cvt_rn_satfinite_e4m3x2_f32, [0.40625, 12.9f32], [0x2D55u16]);
test_ptx!(
cvt_rn_satfinite_e5m2x2_f32,
[0.375, -5256.6f32],
[0x36EDu16]
);
test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]);
test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]);
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
test_ptx!(

View file

@ -233,6 +233,11 @@ ptx_parser_macros::generate_instruction_type!(
type: { Type::Scalar(data.from) },
relaxed_type_check: true,
},
src2: {
repr: Option<T>,
type: { Type::Scalar(data.from) },
relaxed_type_check: true,
},
}
},
Cvta {
@ -1047,7 +1052,9 @@ impl ScalarType {
| ScalarType::S16
| ScalarType::B16
| ScalarType::F16
| ScalarType::BF16 => 2,
| ScalarType::BF16
| ScalarType::E4m3x2
| ScalarType::E5m2x2 => 2,
ScalarType::U32
| ScalarType::S32
| ScalarType::B32
@ -1069,7 +1076,9 @@ impl ScalarType {
| ScalarType::S16
| ScalarType::B16
| ScalarType::F16
| ScalarType::BF16 => Layout::new::<u16>(),
| ScalarType::BF16
| ScalarType::E4m3x2
| ScalarType::E5m2x2 => Layout::new::<u16>(),
ScalarType::U32
| ScalarType::S32
| ScalarType::B32
@ -1110,6 +1119,8 @@ impl ScalarType {
ScalarType::F64 => ScalarKind::Float,
ScalarType::BF16 => ScalarKind::Float,
ScalarType::BF16x2 => ScalarKind::Float,
ScalarType::E4m3x2 => ScalarKind::Float,
ScalarType::E5m2x2 => ScalarKind::Float,
ScalarType::Pred => ScalarKind::Pred,
}
}
@ -1884,7 +1895,9 @@ impl CvtDetails {
saturate,
},
Ordering::Greater => {
if rounding.is_some() {
if rounding.is_some()
&& !(src == ScalarType::E4m3x2 || src == ScalarType::E5m2x2)
{
errors.push(PtxError::SyntaxError(
"should not have rounding mode when dst is larger than src in cvt"
.to_string(),

View file

@ -2370,7 +2370,7 @@ derive_parser!(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
let arguments = ast::CvtArgs { dst: d, src: a };
let arguments = ast::CvtArgs { dst: d, src: a, src2: None };
ast::Instruction::Cvt {
data, arguments
}
@ -2381,18 +2381,38 @@ derive_parser!(
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
// cvt.rna{.satfinite}.tf32.f32 d, a;
// cvt.frnd2{.relu}.tf32.f32 d, a;
// cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b;
cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {
if relu {
state.errors.push(PtxError::Todo);
}
let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, f8x2type, ScalarType::F32);
ast::Instruction::Cvt {
data,
arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) }
}
}
// cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
// cvt.rn.{.relu}.f16x2.f8x2type d, a;
cvt.rn{.relu}.f16x2.f8x2type d, a => {
if relu {
state.errors.push(PtxError::Todo);
}
let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, ScalarType::F16x2, f8x2type);
ast::Instruction::Cvt {
data,
arguments: ast::CvtArgs { dst: d, src: a, src2: None }
}
}
.ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
.frnd2: RawRoundingMode = { .rn, .rz };
RawRoundingMode = { .rn };
.dtype: ScalarType = { .u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.bf16, .f16, .f32, .f64 };
.atype: ScalarType = { .u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.bf16, .f16, .f32, .f64 };
.f8x2type: ScalarType = { .e4m3x2, .e5m2x2 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
shl.type d, a, b => {
ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }