mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-05 00:57:24 +00:00
Add support for fp8 to cvt
(#468)
This implements specifically the fp8 conversion instructions needed by llm.c: * `cvt.rn.satfinite{.relu}.f8x2type.f32` * `cvt.rn{.relu}.f16x2.f8x2type` It uses HIP's fp8 and fp16 headers: https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/low_fp_types.html#fp8-quarter-precision.
This commit is contained in:
parent
3632f2bf03
commit
8f484d6a5f
17 changed files with 507 additions and 63 deletions
Binary file not shown.
|
@ -1,12 +1,13 @@
|
|||
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
||||
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <bit>
|
||||
#include <cmath>
|
||||
#include <hip/amd_detail/amd_device_functions.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
||||
|
||||
|
@ -476,4 +477,46 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
|||
{
|
||||
return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag});
|
||||
}
|
||||
|
||||
__device__ static __hip_fp8_storage_t cvt_float_to_fp8(float f, __hip_fp8_interpretation_t interp)
|
||||
{
|
||||
const uint32_t bits = reinterpret_cast<uint32_t &>(f);
|
||||
const uint8_t sign = (bits & 0x80000000) ? 0x80 : 0x0;
|
||||
const uint32_t abs = bits & 0x7fffffff;
|
||||
|
||||
const uint32_t min = interp == __HIP_E4M3 ? 0x3A800000 : 0x37000000;
|
||||
if (abs < min)
|
||||
{
|
||||
return sign; // +/- 0
|
||||
}
|
||||
|
||||
return __hip_cvt_float_to_fp8(f, __HIP_SATFINITE, interp);
|
||||
}
|
||||
|
||||
struct Fp8x2
|
||||
{
|
||||
__hip_fp8_storage_t b : 8;
|
||||
__hip_fp8_storage_t a : 8;
|
||||
};
|
||||
|
||||
Fp8x2 FUNC(cvt_rn_satfinite_e4m3x2_f32)(float a, float b)
|
||||
{
|
||||
// If built-in support for fp8 formats is added to LLVM IR we should switch to use that.
|
||||
return {cvt_float_to_fp8(b, __HIP_E4M3), cvt_float_to_fp8(a, __HIP_E4M3)};
|
||||
}
|
||||
|
||||
Fp8x2 FUNC(cvt_rn_satfinite_e5m2x2_f32)(float a, float b)
|
||||
{
|
||||
return {cvt_float_to_fp8(b, __HIP_E5M2), cvt_float_to_fp8(a, __HIP_E5M2)};
|
||||
}
|
||||
|
||||
__half2 FUNC(cvt_rn_f16x2_e4m3x2)(__hip_fp8x2_e4m3 in)
|
||||
{
|
||||
return in;
|
||||
}
|
||||
|
||||
__half2 FUNC(cvt_rn_f16x2_e5m2x2)(__hip_fp8x2_e5m2 in)
|
||||
{
|
||||
return in;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -299,7 +299,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
},
|
||||
ptx_parser::ScalarType::S16
|
||||
| ptx_parser::ScalarType::B16
|
||||
| ptx_parser::ScalarType::U16 => unsafe {
|
||||
| ptx_parser::ScalarType::U16
|
||||
| ptx_parser::ScalarType::E4m3x2
|
||||
| ptx_parser::ScalarType::E5m2x2 => unsafe {
|
||||
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::S32
|
||||
|
@ -1586,6 +1588,26 @@ impl<'a> MethodEmitContext<'a> {
|
|||
data: ptx_parser::CvtDetails,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
// Truncating conversions to FP8 types should be replaced by a function call.
|
||||
match data {
|
||||
ptx_parser::CvtDetails {
|
||||
to: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2,
|
||||
mode: ast::CvtMode::FPTruncate { .. },
|
||||
..
|
||||
} => return Err(error_unreachable()),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Extending conversions from FP8 types should be replaced by a function call.
|
||||
match data {
|
||||
ptx_parser::CvtDetails {
|
||||
from: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2,
|
||||
mode: ast::CvtMode::FPExtend { .. },
|
||||
..
|
||||
} => return Err(error_unreachable()),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let dst_type = get_scalar_type(self.context, data.to);
|
||||
let llvm_fn = match data.mode {
|
||||
ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
|
||||
|
@ -3096,7 +3118,11 @@ impl std::fmt::Display for LLVMTypeDisplay {
|
|||
match self.0 {
|
||||
ast::ScalarType::Pred => write!(f, "i1"),
|
||||
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
|
||||
ast::ScalarType::B16
|
||||
| ast::ScalarType::U16
|
||||
| ast::ScalarType::S16
|
||||
| ast::ScalarType::E4m3x2
|
||||
| ast::ScalarType::E5m2x2 => write!(f, "i16"),
|
||||
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
|
||||
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
|
||||
ptx_parser::ScalarType::B128 => write!(f, "i128"),
|
||||
|
|
|
@ -153,9 +153,11 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
|
||||
LLVMInt8TypeInContext(context)
|
||||
},
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
|
||||
LLVMInt16TypeInContext(context)
|
||||
},
|
||||
ast::ScalarType::B16
|
||||
| ast::ScalarType::U16
|
||||
| ast::ScalarType::S16
|
||||
| ast::ScalarType::E4m3x2
|
||||
| ast::ScalarType::E5m2x2 => unsafe { LLVMInt16TypeInContext(context) },
|
||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
|
||||
LLVMInt32TypeInContext(context)
|
||||
},
|
||||
|
@ -169,7 +171,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
|
||||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
|
||||
ast::ScalarType::BF16x2 => todo!(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -995,6 +995,8 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
|
|||
ast::ScalarType::BF16 => "bf16",
|
||||
ast::ScalarType::BF16x2 => "bf16x2",
|
||||
ast::ScalarType::Pred => "pred",
|
||||
ast::ScalarType::E4m3x2 => "e4m3x2",
|
||||
ast::ScalarType::E5m2x2 => "e5m2x2",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -57,6 +57,38 @@ fn run_directive<'input>(
|
|||
})
|
||||
}
|
||||
|
||||
fn get_or_declare_function<'input, S: Into<Cow<'input, str>>>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut HashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ptx_parser::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ptx_parser::Variable<SpirvWord>>,
|
||||
),
|
||||
rustc_hash::FxBuildHasher,
|
||||
>,
|
||||
name: S,
|
||||
return_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
input_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
) -> SpirvWord {
|
||||
let func = match fn_declarations.entry(name.into()) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, return_arguments),
|
||||
name,
|
||||
to_variables(resolver, input_arguments),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
func
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
|
@ -99,7 +131,7 @@ fn run_statements<'input>(
|
|||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
)));
|
||||
let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat();
|
||||
let name = ["shfl_sync_", mode, "_b32_pred"].concat();
|
||||
let return_arguments = vec![(
|
||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
|
@ -122,45 +154,19 @@ fn run_statements<'input>(
|
|||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
];
|
||||
let func = match fn_declarations.entry(full_name.into()) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let name = resolver.register_named(name, None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, &return_arguments),
|
||||
name,
|
||||
to_variables(resolver, &input_arguments),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
let func = get_or_declare_function(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
name,
|
||||
&return_arguments,
|
||||
&input_arguments,
|
||||
);
|
||||
smallvec![
|
||||
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: vec![(
|
||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
)],
|
||||
input_arguments: vec![
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
return_arguments,
|
||||
input_arguments
|
||||
},
|
||||
arguments: ptx_parser::CallArgs {
|
||||
return_arguments: vec![packed_var],
|
||||
|
@ -184,6 +190,73 @@ fn run_statements<'input>(
|
|||
arguments: ast::CvtArgs {
|
||||
dst: dst_pred,
|
||||
src: dst_pred_wide,
|
||||
src2: None,
|
||||
},
|
||||
})
|
||||
]
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
from: from @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2),
|
||||
to: ast::ScalarType::F16x2,
|
||||
mode: _,
|
||||
},
|
||||
arguments:
|
||||
ast::CvtArgs {
|
||||
dst,
|
||||
src,
|
||||
src2: None,
|
||||
},
|
||||
}) => {
|
||||
let from_str = match from {
|
||||
ast::ScalarType::E4m3x2 => "e4m3x2",
|
||||
ast::ScalarType::E5m2x2 => "e5m2x2",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let packed_output = resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::B32),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
let name = format!("cvt_rn_f16x2_{}", from_str);
|
||||
let return_arguments = vec![(
|
||||
ast::Type::Scalar(ast::ScalarType::B32),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let input_arguments = vec![(
|
||||
ast::Type::Scalar(ast::ScalarType::B16),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let func = get_or_declare_function(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
name,
|
||||
&return_arguments,
|
||||
&input_arguments,
|
||||
);
|
||||
smallvec![
|
||||
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments,
|
||||
input_arguments,
|
||||
},
|
||||
arguments: ptx_parser::CallArgs {
|
||||
return_arguments: vec![packed_output],
|
||||
func,
|
||||
input_arguments: vec![src],
|
||||
},
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Cvt {
|
||||
data: ast::CvtDetails {
|
||||
from: ast::ScalarType::B32,
|
||||
to: ast::ScalarType::F16x2,
|
||||
mode: ast::CvtMode::Bitcast
|
||||
},
|
||||
arguments: ast::CvtArgs {
|
||||
dst,
|
||||
src: packed_output,
|
||||
src2: None,
|
||||
},
|
||||
})
|
||||
]
|
||||
|
@ -335,6 +408,29 @@ fn run_instruction<'input>(
|
|||
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
||||
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Cvt {
|
||||
data:
|
||||
ptx_parser::CvtDetails {
|
||||
from: ast::ScalarType::F32,
|
||||
to: to @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2),
|
||||
mode: _,
|
||||
},
|
||||
arguments: _,
|
||||
} => {
|
||||
let to = match to {
|
||||
ptx_parser::ScalarType::E4m3x2 => "e4m3x2",
|
||||
ptx_parser::ScalarType::E5m2x2 => "e5m2x2",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
// Conversions from f32 to f8 must have two source arguments.
|
||||
// satfinite is mandatory for conversions to e4m3x2.
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
format!("cvt_rn_satfinite_{}_f32", to).into(),
|
||||
i,
|
||||
)?
|
||||
}
|
||||
i => i,
|
||||
})
|
||||
}
|
||||
|
@ -373,20 +469,8 @@ fn to_call<'input>(
|
|||
};
|
||||
Ok::<_, TranslateError>(())
|
||||
})?;
|
||||
let fn_name = match fn_declarations.entry(name) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, &data_return),
|
||||
name,
|
||||
to_variables(resolver, &data_input),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
let fn_name =
|
||||
get_or_declare_function(resolver, fn_declarations, name, &data_return, &data_input);
|
||||
Ok(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
|
|
35
ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll
Normal file
35
ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll
Normal file
|
@ -0,0 +1,35 @@
|
|||
declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16) #0
|
||||
|
||||
define amdgpu_kernel void @cvt_rn_f16x2_e4m3x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 {
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca i64, align 8, addrspace(5)
|
||||
%"35" = alloca i16, align 2, addrspace(5)
|
||||
%"36" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"30"
|
||||
|
||||
"30": ; preds = %1
|
||||
%"37" = load i64, ptr addrspace(4) %"31", align 8
|
||||
store i64 %"37", ptr addrspace(5) %"33", align 8
|
||||
%"38" = load i64, ptr addrspace(4) %"32", align 8
|
||||
store i64 %"38", ptr addrspace(5) %"34", align 8
|
||||
%"40" = load i64, ptr addrspace(5) %"33", align 8
|
||||
%"45" = inttoptr i64 %"40" to ptr
|
||||
%"39" = load i16, ptr %"45", align 2
|
||||
store i16 %"39", ptr addrspace(5) %"35", align 2
|
||||
%"42" = load i16, ptr addrspace(5) %"35", align 2
|
||||
%"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16 %"42")
|
||||
%"46" = bitcast i32 %"49" to <2 x half>
|
||||
%"41" = bitcast <2 x half> %"46" to i32
|
||||
store i32 %"41", ptr addrspace(5) %"36", align 4
|
||||
%"43" = load i64, ptr addrspace(5) %"34", align 8
|
||||
%"44" = load i32, ptr addrspace(5) %"36", align 4
|
||||
%"48" = inttoptr i64 %"43" to ptr
|
||||
store i32 %"44", ptr %"48", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
35
ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll
Normal file
35
ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll
Normal file
|
@ -0,0 +1,35 @@
|
|||
declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16) #0
|
||||
|
||||
define amdgpu_kernel void @cvt_rn_f16x2_e5m2x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 {
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca i64, align 8, addrspace(5)
|
||||
%"35" = alloca i16, align 2, addrspace(5)
|
||||
%"36" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"30"
|
||||
|
||||
"30": ; preds = %1
|
||||
%"37" = load i64, ptr addrspace(4) %"31", align 8
|
||||
store i64 %"37", ptr addrspace(5) %"33", align 8
|
||||
%"38" = load i64, ptr addrspace(4) %"32", align 8
|
||||
store i64 %"38", ptr addrspace(5) %"34", align 8
|
||||
%"40" = load i64, ptr addrspace(5) %"33", align 8
|
||||
%"45" = inttoptr i64 %"40" to ptr
|
||||
%"39" = load i16, ptr %"45", align 2
|
||||
store i16 %"39", ptr addrspace(5) %"35", align 2
|
||||
%"42" = load i16, ptr addrspace(5) %"35", align 2
|
||||
%"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16 %"42")
|
||||
%"46" = bitcast i32 %"49" to <2 x half>
|
||||
%"41" = bitcast <2 x half> %"46" to i32
|
||||
store i32 %"41", ptr addrspace(5) %"36", align 4
|
||||
%"43" = load i64, ptr addrspace(5) %"34", align 8
|
||||
%"44" = load i32, ptr addrspace(5) %"36", align 4
|
||||
%"48" = inttoptr i64 %"43" to ptr
|
||||
store i32 %"44", ptr %"48", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
40
ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll
Normal file
40
ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll
Normal file
|
@ -0,0 +1,40 @@
|
|||
declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float, float) #0
|
||||
|
||||
define amdgpu_kernel void @cvt_rn_satfinite_e4m3x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 {
|
||||
%"36" = alloca i64, align 8, addrspace(5)
|
||||
%"37" = alloca i64, align 8, addrspace(5)
|
||||
%"38" = alloca float, align 4, addrspace(5)
|
||||
%"39" = alloca float, align 4, addrspace(5)
|
||||
%"40" = alloca i16, align 2, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"33"
|
||||
|
||||
"33": ; preds = %1
|
||||
%"41" = load i64, ptr addrspace(4) %"34", align 8
|
||||
store i64 %"41", ptr addrspace(5) %"36", align 8
|
||||
%"42" = load i64, ptr addrspace(4) %"35", align 8
|
||||
store i64 %"42", ptr addrspace(5) %"37", align 8
|
||||
%"44" = load i64, ptr addrspace(5) %"36", align 8
|
||||
%"52" = inttoptr i64 %"44" to ptr
|
||||
%"43" = load float, ptr %"52", align 4
|
||||
store float %"43", ptr addrspace(5) %"38", align 4
|
||||
%"45" = load i64, ptr addrspace(5) %"36", align 8
|
||||
%"53" = inttoptr i64 %"45" to ptr
|
||||
%"32" = getelementptr inbounds i8, ptr %"53", i64 4
|
||||
%"46" = load float, ptr %"32", align 4
|
||||
store float %"46", ptr addrspace(5) %"39", align 4
|
||||
%"48" = load float, ptr addrspace(5) %"38", align 4
|
||||
%"49" = load float, ptr addrspace(5) %"39", align 4
|
||||
%"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float %"48", float %"49")
|
||||
store i16 %"54", ptr addrspace(5) %"40", align 2
|
||||
%"50" = load i64, ptr addrspace(5) %"37", align 8
|
||||
%"51" = load i16, ptr addrspace(5) %"40", align 2
|
||||
%"55" = inttoptr i64 %"50" to ptr
|
||||
store i16 %"51", ptr %"55", align 2
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
40
ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll
Normal file
40
ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll
Normal file
|
@ -0,0 +1,40 @@
|
|||
declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float, float) #0
|
||||
|
||||
define amdgpu_kernel void @cvt_rn_satfinite_e5m2x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 {
|
||||
%"36" = alloca i64, align 8, addrspace(5)
|
||||
%"37" = alloca i64, align 8, addrspace(5)
|
||||
%"38" = alloca float, align 4, addrspace(5)
|
||||
%"39" = alloca float, align 4, addrspace(5)
|
||||
%"40" = alloca i16, align 2, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"33"
|
||||
|
||||
"33": ; preds = %1
|
||||
%"41" = load i64, ptr addrspace(4) %"34", align 8
|
||||
store i64 %"41", ptr addrspace(5) %"36", align 8
|
||||
%"42" = load i64, ptr addrspace(4) %"35", align 8
|
||||
store i64 %"42", ptr addrspace(5) %"37", align 8
|
||||
%"44" = load i64, ptr addrspace(5) %"36", align 8
|
||||
%"52" = inttoptr i64 %"44" to ptr
|
||||
%"43" = load float, ptr %"52", align 4
|
||||
store float %"43", ptr addrspace(5) %"38", align 4
|
||||
%"45" = load i64, ptr addrspace(5) %"36", align 8
|
||||
%"53" = inttoptr i64 %"45" to ptr
|
||||
%"32" = getelementptr inbounds i8, ptr %"53", i64 4
|
||||
%"46" = load float, ptr %"32", align 4
|
||||
store float %"46", ptr addrspace(5) %"39", align 4
|
||||
%"48" = load float, ptr addrspace(5) %"38", align 4
|
||||
%"49" = load float, ptr addrspace(5) %"39", align 4
|
||||
%"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float %"48", float %"49")
|
||||
store i16 %"54", ptr addrspace(5) %"40", align 2
|
||||
%"50" = load i64, ptr addrspace(5) %"37", align 8
|
||||
%"51" = load i16, ptr addrspace(5) %"40", align 2
|
||||
%"55" = inttoptr i64 %"50" to ptr
|
||||
store i16 %"51", ptr %"55", align 2
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx
Normal file
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 7.8
|
||||
.target sm_90
|
||||
.address_size 64
|
||||
|
||||
.visible .entry cvt_rn_f16x2_e4m3x2(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .b16 in;
|
||||
.reg .b32 result;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.b16 in, [in_addr];
|
||||
|
||||
cvt.rn.f16x2.e4m3x2 result, in;
|
||||
st.b32 [out_addr], result;
|
||||
ret;
|
||||
}
|
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx
Normal file
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 7.8
|
||||
.target sm_90
|
||||
.address_size 64
|
||||
|
||||
.visible .entry cvt_rn_f16x2_e5m2x2(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .b16 in;
|
||||
.reg .b32 result;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.b16 in, [in_addr];
|
||||
|
||||
cvt.rn.f16x2.e5m2x2 result, in;
|
||||
st.b32 [out_addr], result;
|
||||
ret;
|
||||
}
|
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx
Normal file
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx
Normal file
|
@ -0,0 +1,25 @@
|
|||
.version 7.8
|
||||
.target sm_90
|
||||
.address_size 64
|
||||
|
||||
.visible .entry cvt_rn_satfinite_e4m3x2_f32(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 in_a;
|
||||
.reg .f32 in_b;
|
||||
.reg .b16 result;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 in_a, [in_addr];
|
||||
ld.f32 in_b, [in_addr + 4];
|
||||
|
||||
cvt.rn.satfinite.e4m3x2.f32 result, in_a, in_b;
|
||||
st.b16 [out_addr], result;
|
||||
ret;
|
||||
}
|
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx
Normal file
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx
Normal file
|
@ -0,0 +1,25 @@
|
|||
.version 7.8
|
||||
.target sm_90
|
||||
.address_size 64
|
||||
|
||||
.visible .entry cvt_rn_satfinite_e5m2x2_f32(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 in_a;
|
||||
.reg .f32 in_b;
|
||||
.reg .b16 result;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 in_a, [in_addr];
|
||||
ld.f32 in_b, [in_addr + 4];
|
||||
|
||||
cvt.rn.satfinite.e5m2x2.f32 result, in_a, in_b;
|
||||
st.b16 [out_addr], result;
|
||||
ret;
|
||||
}
|
|
@ -185,6 +185,14 @@ test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
|
|||
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
|
||||
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);
|
||||
test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]);
|
||||
test_ptx!(cvt_rn_satfinite_e4m3x2_f32, [0.40625, 12.9f32], [0x2D55u16]);
|
||||
test_ptx!(
|
||||
cvt_rn_satfinite_e5m2x2_f32,
|
||||
[0.375, -5256.6f32],
|
||||
[0x36EDu16]
|
||||
);
|
||||
test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]);
|
||||
test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]);
|
||||
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
|
||||
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
|
||||
test_ptx!(
|
||||
|
|
|
@ -233,6 +233,11 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
type: { Type::Scalar(data.from) },
|
||||
relaxed_type_check: true,
|
||||
},
|
||||
src2: {
|
||||
repr: Option<T>,
|
||||
type: { Type::Scalar(data.from) },
|
||||
relaxed_type_check: true,
|
||||
},
|
||||
}
|
||||
},
|
||||
Cvta {
|
||||
|
@ -1047,7 +1052,9 @@ impl ScalarType {
|
|||
| ScalarType::S16
|
||||
| ScalarType::B16
|
||||
| ScalarType::F16
|
||||
| ScalarType::BF16 => 2,
|
||||
| ScalarType::BF16
|
||||
| ScalarType::E4m3x2
|
||||
| ScalarType::E5m2x2 => 2,
|
||||
ScalarType::U32
|
||||
| ScalarType::S32
|
||||
| ScalarType::B32
|
||||
|
@ -1069,7 +1076,9 @@ impl ScalarType {
|
|||
| ScalarType::S16
|
||||
| ScalarType::B16
|
||||
| ScalarType::F16
|
||||
| ScalarType::BF16 => Layout::new::<u16>(),
|
||||
| ScalarType::BF16
|
||||
| ScalarType::E4m3x2
|
||||
| ScalarType::E5m2x2 => Layout::new::<u16>(),
|
||||
ScalarType::U32
|
||||
| ScalarType::S32
|
||||
| ScalarType::B32
|
||||
|
@ -1110,6 +1119,8 @@ impl ScalarType {
|
|||
ScalarType::F64 => ScalarKind::Float,
|
||||
ScalarType::BF16 => ScalarKind::Float,
|
||||
ScalarType::BF16x2 => ScalarKind::Float,
|
||||
ScalarType::E4m3x2 => ScalarKind::Float,
|
||||
ScalarType::E5m2x2 => ScalarKind::Float,
|
||||
ScalarType::Pred => ScalarKind::Pred,
|
||||
}
|
||||
}
|
||||
|
@ -1884,7 +1895,9 @@ impl CvtDetails {
|
|||
saturate,
|
||||
},
|
||||
Ordering::Greater => {
|
||||
if rounding.is_some() {
|
||||
if rounding.is_some()
|
||||
&& !(src == ScalarType::E4m3x2 || src == ScalarType::E5m2x2)
|
||||
{
|
||||
errors.push(PtxError::SyntaxError(
|
||||
"should not have rounding mode when dst is larger than src in cvt"
|
||||
.to_string(),
|
||||
|
|
|
@ -2370,7 +2370,7 @@ derive_parser!(
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
|
||||
cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
|
||||
let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
|
||||
let arguments = ast::CvtArgs { dst: d, src: a };
|
||||
let arguments = ast::CvtArgs { dst: d, src: a, src2: None };
|
||||
ast::Instruction::Cvt {
|
||||
data, arguments
|
||||
}
|
||||
|
@ -2381,18 +2381,38 @@ derive_parser!(
|
|||
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
|
||||
// cvt.rna{.satfinite}.tf32.f32 d, a;
|
||||
// cvt.frnd2{.relu}.tf32.f32 d, a;
|
||||
// cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b;
|
||||
cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {
|
||||
if relu {
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, f8x2type, ScalarType::F32);
|
||||
ast::Instruction::Cvt {
|
||||
data,
|
||||
arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) }
|
||||
}
|
||||
}
|
||||
// cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
|
||||
// cvt.rn.{.relu}.f16x2.f8x2type d, a;
|
||||
cvt.rn{.relu}.f16x2.f8x2type d, a => {
|
||||
if relu {
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, ScalarType::F16x2, f8x2type);
|
||||
ast::Instruction::Cvt {
|
||||
data,
|
||||
arguments: ast::CvtArgs { dst: d, src: a, src2: None }
|
||||
}
|
||||
}
|
||||
|
||||
.ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
|
||||
.frnd2: RawRoundingMode = { .rn, .rz };
|
||||
RawRoundingMode = { .rn };
|
||||
.dtype: ScalarType = { .u8, .u16, .u32, .u64,
|
||||
.s8, .s16, .s32, .s64,
|
||||
.bf16, .f16, .f32, .f64 };
|
||||
.atype: ScalarType = { .u8, .u16, .u32, .u64,
|
||||
.s8, .s16, .s32, .s64,
|
||||
.bf16, .f16, .f32, .f64 };
|
||||
.f8x2type: ScalarType = { .e4m3x2, .e5m2x2 };
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
|
||||
shl.type d, a, b => {
|
||||
ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue