mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-13 13:02:35 +00:00
Add support for fp8 to cvt
(#468)
This implements specifically the fp8 conversion instructions needed by llm.c: * `cvt.rn.satfinite{.relu}.f8x2type.f32` * `cvt.rn{.relu}.f16x2.f8x2type` It uses HIP's fp8 and fp16 headers: https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/low_fp_types.html#fp8-quarter-precision.
This commit is contained in:
parent
3632f2bf03
commit
8f484d6a5f
17 changed files with 507 additions and 63 deletions
Binary file not shown.
|
@ -1,12 +1,13 @@
|
||||||
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
// Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17`
|
||||||
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
// `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining
|
||||||
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <bit>
|
#include <bit>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <hip/amd_detail/amd_device_functions.h>
|
#include <hip/amd_detail/amd_device_functions.h>
|
||||||
|
#include <hip/hip_fp8.h>
|
||||||
|
|
||||||
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
#define CONSTANT_SPACE __attribute__((address_space(4)))
|
||||||
|
|
||||||
|
@ -476,4 +477,46 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
||||||
{
|
{
|
||||||
return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag});
|
return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ static __hip_fp8_storage_t cvt_float_to_fp8(float f, __hip_fp8_interpretation_t interp)
|
||||||
|
{
|
||||||
|
const uint32_t bits = reinterpret_cast<uint32_t &>(f);
|
||||||
|
const uint8_t sign = (bits & 0x80000000) ? 0x80 : 0x0;
|
||||||
|
const uint32_t abs = bits & 0x7fffffff;
|
||||||
|
|
||||||
|
const uint32_t min = interp == __HIP_E4M3 ? 0x3A800000 : 0x37000000;
|
||||||
|
if (abs < min)
|
||||||
|
{
|
||||||
|
return sign; // +/- 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return __hip_cvt_float_to_fp8(f, __HIP_SATFINITE, interp);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Fp8x2
|
||||||
|
{
|
||||||
|
__hip_fp8_storage_t b : 8;
|
||||||
|
__hip_fp8_storage_t a : 8;
|
||||||
|
};
|
||||||
|
|
||||||
|
Fp8x2 FUNC(cvt_rn_satfinite_e4m3x2_f32)(float a, float b)
|
||||||
|
{
|
||||||
|
// If built-in support for fp8 formats is added to LLVM IR we should switch to use that.
|
||||||
|
return {cvt_float_to_fp8(b, __HIP_E4M3), cvt_float_to_fp8(a, __HIP_E4M3)};
|
||||||
|
}
|
||||||
|
|
||||||
|
Fp8x2 FUNC(cvt_rn_satfinite_e5m2x2_f32)(float a, float b)
|
||||||
|
{
|
||||||
|
return {cvt_float_to_fp8(b, __HIP_E5M2), cvt_float_to_fp8(a, __HIP_E5M2)};
|
||||||
|
}
|
||||||
|
|
||||||
|
__half2 FUNC(cvt_rn_f16x2_e4m3x2)(__hip_fp8x2_e4m3 in)
|
||||||
|
{
|
||||||
|
return in;
|
||||||
|
}
|
||||||
|
|
||||||
|
__half2 FUNC(cvt_rn_f16x2_e5m2x2)(__hip_fp8x2_e5m2 in)
|
||||||
|
{
|
||||||
|
return in;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -299,7 +299,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
},
|
},
|
||||||
ptx_parser::ScalarType::S16
|
ptx_parser::ScalarType::S16
|
||||||
| ptx_parser::ScalarType::B16
|
| ptx_parser::ScalarType::B16
|
||||||
| ptx_parser::ScalarType::U16 => unsafe {
|
| ptx_parser::ScalarType::U16
|
||||||
|
| ptx_parser::ScalarType::E4m3x2
|
||||||
|
| ptx_parser::ScalarType::E5m2x2 => unsafe {
|
||||||
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
|
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||||
},
|
},
|
||||||
ptx_parser::ScalarType::S32
|
ptx_parser::ScalarType::S32
|
||||||
|
@ -1586,6 +1588,26 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
data: ptx_parser::CvtDetails,
|
data: ptx_parser::CvtDetails,
|
||||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
|
// Truncating conversions to FP8 types should be replaced by a function call.
|
||||||
|
match data {
|
||||||
|
ptx_parser::CvtDetails {
|
||||||
|
to: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2,
|
||||||
|
mode: ast::CvtMode::FPTruncate { .. },
|
||||||
|
..
|
||||||
|
} => return Err(error_unreachable()),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extending conversions from FP8 types should be replaced by a function call.
|
||||||
|
match data {
|
||||||
|
ptx_parser::CvtDetails {
|
||||||
|
from: ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2,
|
||||||
|
mode: ast::CvtMode::FPExtend { .. },
|
||||||
|
..
|
||||||
|
} => return Err(error_unreachable()),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
let dst_type = get_scalar_type(self.context, data.to);
|
let dst_type = get_scalar_type(self.context, data.to);
|
||||||
let llvm_fn = match data.mode {
|
let llvm_fn = match data.mode {
|
||||||
ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
|
ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
|
||||||
|
@ -3096,7 +3118,11 @@ impl std::fmt::Display for LLVMTypeDisplay {
|
||||||
match self.0 {
|
match self.0 {
|
||||||
ast::ScalarType::Pred => write!(f, "i1"),
|
ast::ScalarType::Pred => write!(f, "i1"),
|
||||||
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
|
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
|
||||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
|
ast::ScalarType::B16
|
||||||
|
| ast::ScalarType::U16
|
||||||
|
| ast::ScalarType::S16
|
||||||
|
| ast::ScalarType::E4m3x2
|
||||||
|
| ast::ScalarType::E5m2x2 => write!(f, "i16"),
|
||||||
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
|
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
|
||||||
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
|
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
|
||||||
ptx_parser::ScalarType::B128 => write!(f, "i128"),
|
ptx_parser::ScalarType::B128 => write!(f, "i128"),
|
||||||
|
|
|
@ -153,9 +153,11 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
||||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
|
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
|
||||||
LLVMInt8TypeInContext(context)
|
LLVMInt8TypeInContext(context)
|
||||||
},
|
},
|
||||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
|
ast::ScalarType::B16
|
||||||
LLVMInt16TypeInContext(context)
|
| ast::ScalarType::U16
|
||||||
},
|
| ast::ScalarType::S16
|
||||||
|
| ast::ScalarType::E4m3x2
|
||||||
|
| ast::ScalarType::E5m2x2 => unsafe { LLVMInt16TypeInContext(context) },
|
||||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
|
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
|
||||||
LLVMInt32TypeInContext(context)
|
LLVMInt32TypeInContext(context)
|
||||||
},
|
},
|
||||||
|
@ -169,7 +171,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
||||||
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
|
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
|
||||||
ast::ScalarType::U16x2 => todo!(),
|
ast::ScalarType::U16x2 => todo!(),
|
||||||
ast::ScalarType::S16x2 => todo!(),
|
ast::ScalarType::S16x2 => todo!(),
|
||||||
ast::ScalarType::F16x2 => todo!(),
|
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
|
||||||
ast::ScalarType::BF16x2 => todo!(),
|
ast::ScalarType::BF16x2 => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -995,6 +995,8 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
|
||||||
ast::ScalarType::BF16 => "bf16",
|
ast::ScalarType::BF16 => "bf16",
|
||||||
ast::ScalarType::BF16x2 => "bf16x2",
|
ast::ScalarType::BF16x2 => "bf16x2",
|
||||||
ast::ScalarType::Pred => "pred",
|
ast::ScalarType::Pred => "pred",
|
||||||
|
ast::ScalarType::E4m3x2 => "e4m3x2",
|
||||||
|
ast::ScalarType::E5m2x2 => "e5m2x2",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,38 @@ fn run_directive<'input>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_or_declare_function<'input, S: Into<Cow<'input, str>>>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
fn_declarations: &mut HashMap<
|
||||||
|
Cow<'input, str>,
|
||||||
|
(
|
||||||
|
Vec<ptx_parser::Variable<SpirvWord>>,
|
||||||
|
SpirvWord,
|
||||||
|
Vec<ptx_parser::Variable<SpirvWord>>,
|
||||||
|
),
|
||||||
|
rustc_hash::FxBuildHasher,
|
||||||
|
>,
|
||||||
|
name: S,
|
||||||
|
return_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
input_arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
) -> SpirvWord {
|
||||||
|
let func = match fn_declarations.entry(name.into()) {
|
||||||
|
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||||
|
hash_map::Entry::Vacant(vacant_entry) => {
|
||||||
|
let name = vacant_entry.key().clone();
|
||||||
|
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||||
|
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||||
|
vacant_entry.insert((
|
||||||
|
to_variables(resolver, return_arguments),
|
||||||
|
name,
|
||||||
|
to_variables(resolver, input_arguments),
|
||||||
|
));
|
||||||
|
name
|
||||||
|
}
|
||||||
|
};
|
||||||
|
func
|
||||||
|
}
|
||||||
|
|
||||||
fn run_statements<'input>(
|
fn run_statements<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_declarations: &mut FxHashMap<
|
fn_declarations: &mut FxHashMap<
|
||||||
|
@ -99,7 +131,7 @@ fn run_statements<'input>(
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
ast::Type::Scalar(ast::ScalarType::U32),
|
||||||
ptx_parser::StateSpace::Reg,
|
ptx_parser::StateSpace::Reg,
|
||||||
)));
|
)));
|
||||||
let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat();
|
let name = ["shfl_sync_", mode, "_b32_pred"].concat();
|
||||||
let return_arguments = vec![(
|
let return_arguments = vec![(
|
||||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
ast::Type::Vector(2, ast::ScalarType::U32),
|
||||||
ptx_parser::StateSpace::Reg,
|
ptx_parser::StateSpace::Reg,
|
||||||
|
@ -122,45 +154,19 @@ fn run_statements<'input>(
|
||||||
ptx_parser::StateSpace::Reg,
|
ptx_parser::StateSpace::Reg,
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
let func = match fn_declarations.entry(full_name.into()) {
|
let func = get_or_declare_function(
|
||||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
resolver,
|
||||||
hash_map::Entry::Vacant(vacant_entry) => {
|
fn_declarations,
|
||||||
let name = vacant_entry.key().clone();
|
name,
|
||||||
let name = resolver.register_named(name, None);
|
&return_arguments,
|
||||||
vacant_entry.insert((
|
&input_arguments,
|
||||||
to_variables(resolver, &return_arguments),
|
);
|
||||||
name,
|
|
||||||
to_variables(resolver, &input_arguments),
|
|
||||||
));
|
|
||||||
name
|
|
||||||
}
|
|
||||||
};
|
|
||||||
smallvec![
|
smallvec![
|
||||||
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
|
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
|
||||||
data: ptx_parser::CallDetails {
|
data: ptx_parser::CallDetails {
|
||||||
uniform: false,
|
uniform: false,
|
||||||
return_arguments: vec![(
|
return_arguments,
|
||||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
input_arguments
|
||||||
ptx_parser::StateSpace::Reg,
|
|
||||||
)],
|
|
||||||
input_arguments: vec![
|
|
||||||
(
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ptx_parser::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ptx_parser::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ptx_parser::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
ast::Type::Scalar(ast::ScalarType::U32),
|
|
||||||
ptx_parser::StateSpace::Reg,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
arguments: ptx_parser::CallArgs {
|
arguments: ptx_parser::CallArgs {
|
||||||
return_arguments: vec![packed_var],
|
return_arguments: vec![packed_var],
|
||||||
|
@ -184,6 +190,73 @@ fn run_statements<'input>(
|
||||||
arguments: ast::CvtArgs {
|
arguments: ast::CvtArgs {
|
||||||
dst: dst_pred,
|
dst: dst_pred,
|
||||||
src: dst_pred_wide,
|
src: dst_pred_wide,
|
||||||
|
src2: None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
]
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Cvt {
|
||||||
|
data:
|
||||||
|
ast::CvtDetails {
|
||||||
|
from: from @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2),
|
||||||
|
to: ast::ScalarType::F16x2,
|
||||||
|
mode: _,
|
||||||
|
},
|
||||||
|
arguments:
|
||||||
|
ast::CvtArgs {
|
||||||
|
dst,
|
||||||
|
src,
|
||||||
|
src2: None,
|
||||||
|
},
|
||||||
|
}) => {
|
||||||
|
let from_str = match from {
|
||||||
|
ast::ScalarType::E4m3x2 => "e4m3x2",
|
||||||
|
ast::ScalarType::E5m2x2 => "e5m2x2",
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let packed_output = resolver.register_unnamed(Some((
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B32),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
let name = format!("cvt_rn_f16x2_{}", from_str);
|
||||||
|
let return_arguments = vec![(
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B32),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)];
|
||||||
|
let input_arguments = vec![(
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B16),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)];
|
||||||
|
let func = get_or_declare_function(
|
||||||
|
resolver,
|
||||||
|
fn_declarations,
|
||||||
|
name,
|
||||||
|
&return_arguments,
|
||||||
|
&input_arguments,
|
||||||
|
);
|
||||||
|
smallvec![
|
||||||
|
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
|
||||||
|
data: ptx_parser::CallDetails {
|
||||||
|
uniform: false,
|
||||||
|
return_arguments,
|
||||||
|
input_arguments,
|
||||||
|
},
|
||||||
|
arguments: ptx_parser::CallArgs {
|
||||||
|
return_arguments: vec![packed_output],
|
||||||
|
func,
|
||||||
|
input_arguments: vec![src],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
Statement::Instruction(ast::Instruction::Cvt {
|
||||||
|
data: ast::CvtDetails {
|
||||||
|
from: ast::ScalarType::B32,
|
||||||
|
to: ast::ScalarType::F16x2,
|
||||||
|
mode: ast::CvtMode::Bitcast
|
||||||
|
},
|
||||||
|
arguments: ast::CvtArgs {
|
||||||
|
dst,
|
||||||
|
src: packed_output,
|
||||||
|
src2: None,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
]
|
]
|
||||||
|
@ -335,6 +408,29 @@ fn run_instruction<'input>(
|
||||||
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
||||||
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
to_call(resolver, fn_declarations, "nanosleep_u32".into(), i)?
|
||||||
}
|
}
|
||||||
|
i @ ptx_parser::Instruction::Cvt {
|
||||||
|
data:
|
||||||
|
ptx_parser::CvtDetails {
|
||||||
|
from: ast::ScalarType::F32,
|
||||||
|
to: to @ (ast::ScalarType::E4m3x2 | ast::ScalarType::E5m2x2),
|
||||||
|
mode: _,
|
||||||
|
},
|
||||||
|
arguments: _,
|
||||||
|
} => {
|
||||||
|
let to = match to {
|
||||||
|
ptx_parser::ScalarType::E4m3x2 => "e4m3x2",
|
||||||
|
ptx_parser::ScalarType::E5m2x2 => "e5m2x2",
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
// Conversions from f32 to f8 must have two source arguments.
|
||||||
|
// satfinite is mandatory for conversions to e4m3x2.
|
||||||
|
to_call(
|
||||||
|
resolver,
|
||||||
|
fn_declarations,
|
||||||
|
format!("cvt_rn_satfinite_{}_f32", to).into(),
|
||||||
|
i,
|
||||||
|
)?
|
||||||
|
}
|
||||||
i => i,
|
i => i,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -373,20 +469,8 @@ fn to_call<'input>(
|
||||||
};
|
};
|
||||||
Ok::<_, TranslateError>(())
|
Ok::<_, TranslateError>(())
|
||||||
})?;
|
})?;
|
||||||
let fn_name = match fn_declarations.entry(name) {
|
let fn_name =
|
||||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
get_or_declare_function(resolver, fn_declarations, name, &data_return, &data_input);
|
||||||
hash_map::Entry::Vacant(vacant_entry) => {
|
|
||||||
let name = vacant_entry.key().clone();
|
|
||||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
|
||||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
|
||||||
vacant_entry.insert((
|
|
||||||
to_variables(resolver, &data_return),
|
|
||||||
name,
|
|
||||||
to_variables(resolver, &data_input),
|
|
||||||
));
|
|
||||||
name
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(ast::Instruction::Call {
|
Ok(ast::Instruction::Call {
|
||||||
data: ptx_parser::CallDetails {
|
data: ptx_parser::CallDetails {
|
||||||
uniform: false,
|
uniform: false,
|
||||||
|
|
35
ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll
Normal file
35
ptx/src/test/ll/cvt_rn_f16x2_e4m3x2.ll
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @cvt_rn_f16x2_e4m3x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 {
|
||||||
|
%"33" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"34" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"35" = alloca i16, align 2, addrspace(5)
|
||||||
|
%"36" = alloca i32, align 4, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"30"
|
||||||
|
|
||||||
|
"30": ; preds = %1
|
||||||
|
%"37" = load i64, ptr addrspace(4) %"31", align 8
|
||||||
|
store i64 %"37", ptr addrspace(5) %"33", align 8
|
||||||
|
%"38" = load i64, ptr addrspace(4) %"32", align 8
|
||||||
|
store i64 %"38", ptr addrspace(5) %"34", align 8
|
||||||
|
%"40" = load i64, ptr addrspace(5) %"33", align 8
|
||||||
|
%"45" = inttoptr i64 %"40" to ptr
|
||||||
|
%"39" = load i16, ptr %"45", align 2
|
||||||
|
store i16 %"39", ptr addrspace(5) %"35", align 2
|
||||||
|
%"42" = load i16, ptr addrspace(5) %"35", align 2
|
||||||
|
%"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e4m3x2(i16 %"42")
|
||||||
|
%"46" = bitcast i32 %"49" to <2 x half>
|
||||||
|
%"41" = bitcast <2 x half> %"46" to i32
|
||||||
|
store i32 %"41", ptr addrspace(5) %"36", align 4
|
||||||
|
%"43" = load i64, ptr addrspace(5) %"34", align 8
|
||||||
|
%"44" = load i32, ptr addrspace(5) %"36", align 4
|
||||||
|
%"48" = inttoptr i64 %"43" to ptr
|
||||||
|
store i32 %"44", ptr %"48", align 4
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
35
ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll
Normal file
35
ptx/src/test/ll/cvt_rn_f16x2_e5m2x2.ll
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @cvt_rn_f16x2_e5m2x2(ptr addrspace(4) byref(i64) %"31", ptr addrspace(4) byref(i64) %"32") #1 {
|
||||||
|
%"33" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"34" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"35" = alloca i16, align 2, addrspace(5)
|
||||||
|
%"36" = alloca i32, align 4, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"30"
|
||||||
|
|
||||||
|
"30": ; preds = %1
|
||||||
|
%"37" = load i64, ptr addrspace(4) %"31", align 8
|
||||||
|
store i64 %"37", ptr addrspace(5) %"33", align 8
|
||||||
|
%"38" = load i64, ptr addrspace(4) %"32", align 8
|
||||||
|
store i64 %"38", ptr addrspace(5) %"34", align 8
|
||||||
|
%"40" = load i64, ptr addrspace(5) %"33", align 8
|
||||||
|
%"45" = inttoptr i64 %"40" to ptr
|
||||||
|
%"39" = load i16, ptr %"45", align 2
|
||||||
|
store i16 %"39", ptr addrspace(5) %"35", align 2
|
||||||
|
%"42" = load i16, ptr addrspace(5) %"35", align 2
|
||||||
|
%"49" = call i32 @__zluda_ptx_impl_cvt_rn_f16x2_e5m2x2(i16 %"42")
|
||||||
|
%"46" = bitcast i32 %"49" to <2 x half>
|
||||||
|
%"41" = bitcast <2 x half> %"46" to i32
|
||||||
|
store i32 %"41", ptr addrspace(5) %"36", align 4
|
||||||
|
%"43" = load i64, ptr addrspace(5) %"34", align 8
|
||||||
|
%"44" = load i32, ptr addrspace(5) %"36", align 4
|
||||||
|
%"48" = inttoptr i64 %"43" to ptr
|
||||||
|
store i32 %"44", ptr %"48", align 4
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
40
ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll
Normal file
40
ptx/src/test/ll/cvt_rn_satfinite_e4m3x2_f32.ll
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float, float) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @cvt_rn_satfinite_e4m3x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 {
|
||||||
|
%"36" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"37" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"38" = alloca float, align 4, addrspace(5)
|
||||||
|
%"39" = alloca float, align 4, addrspace(5)
|
||||||
|
%"40" = alloca i16, align 2, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"33"
|
||||||
|
|
||||||
|
"33": ; preds = %1
|
||||||
|
%"41" = load i64, ptr addrspace(4) %"34", align 8
|
||||||
|
store i64 %"41", ptr addrspace(5) %"36", align 8
|
||||||
|
%"42" = load i64, ptr addrspace(4) %"35", align 8
|
||||||
|
store i64 %"42", ptr addrspace(5) %"37", align 8
|
||||||
|
%"44" = load i64, ptr addrspace(5) %"36", align 8
|
||||||
|
%"52" = inttoptr i64 %"44" to ptr
|
||||||
|
%"43" = load float, ptr %"52", align 4
|
||||||
|
store float %"43", ptr addrspace(5) %"38", align 4
|
||||||
|
%"45" = load i64, ptr addrspace(5) %"36", align 8
|
||||||
|
%"53" = inttoptr i64 %"45" to ptr
|
||||||
|
%"32" = getelementptr inbounds i8, ptr %"53", i64 4
|
||||||
|
%"46" = load float, ptr %"32", align 4
|
||||||
|
store float %"46", ptr addrspace(5) %"39", align 4
|
||||||
|
%"48" = load float, ptr addrspace(5) %"38", align 4
|
||||||
|
%"49" = load float, ptr addrspace(5) %"39", align 4
|
||||||
|
%"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e4m3x2_f32(float %"48", float %"49")
|
||||||
|
store i16 %"54", ptr addrspace(5) %"40", align 2
|
||||||
|
%"50" = load i64, ptr addrspace(5) %"37", align 8
|
||||||
|
%"51" = load i16, ptr addrspace(5) %"40", align 2
|
||||||
|
%"55" = inttoptr i64 %"50" to ptr
|
||||||
|
store i16 %"51", ptr %"55", align 2
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
40
ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll
Normal file
40
ptx/src/test/ll/cvt_rn_satfinite_e5m2x2_f32.ll
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
declare hidden i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float, float) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @cvt_rn_satfinite_e5m2x2_f32(ptr addrspace(4) byref(i64) %"34", ptr addrspace(4) byref(i64) %"35") #1 {
|
||||||
|
%"36" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"37" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"38" = alloca float, align 4, addrspace(5)
|
||||||
|
%"39" = alloca float, align 4, addrspace(5)
|
||||||
|
%"40" = alloca i16, align 2, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"33"
|
||||||
|
|
||||||
|
"33": ; preds = %1
|
||||||
|
%"41" = load i64, ptr addrspace(4) %"34", align 8
|
||||||
|
store i64 %"41", ptr addrspace(5) %"36", align 8
|
||||||
|
%"42" = load i64, ptr addrspace(4) %"35", align 8
|
||||||
|
store i64 %"42", ptr addrspace(5) %"37", align 8
|
||||||
|
%"44" = load i64, ptr addrspace(5) %"36", align 8
|
||||||
|
%"52" = inttoptr i64 %"44" to ptr
|
||||||
|
%"43" = load float, ptr %"52", align 4
|
||||||
|
store float %"43", ptr addrspace(5) %"38", align 4
|
||||||
|
%"45" = load i64, ptr addrspace(5) %"36", align 8
|
||||||
|
%"53" = inttoptr i64 %"45" to ptr
|
||||||
|
%"32" = getelementptr inbounds i8, ptr %"53", i64 4
|
||||||
|
%"46" = load float, ptr %"32", align 4
|
||||||
|
store float %"46", ptr addrspace(5) %"39", align 4
|
||||||
|
%"48" = load float, ptr addrspace(5) %"38", align 4
|
||||||
|
%"49" = load float, ptr addrspace(5) %"39", align 4
|
||||||
|
%"54" = call i16 @__zluda_ptx_impl_cvt_rn_satfinite_e5m2x2_f32(float %"48", float %"49")
|
||||||
|
store i16 %"54", ptr addrspace(5) %"40", align 2
|
||||||
|
%"50" = load i64, ptr addrspace(5) %"37", align 8
|
||||||
|
%"51" = load i16, ptr addrspace(5) %"40", align 2
|
||||||
|
%"55" = inttoptr i64 %"50" to ptr
|
||||||
|
store i16 %"51", ptr %"55", align 2
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx
Normal file
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e4m3x2.ptx
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
.version 7.8
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry cvt_rn_f16x2_e4m3x2(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .b16 in;
|
||||||
|
.reg .b32 result;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.b16 in, [in_addr];
|
||||||
|
|
||||||
|
cvt.rn.f16x2.e4m3x2 result, in;
|
||||||
|
st.b32 [out_addr], result;
|
||||||
|
ret;
|
||||||
|
}
|
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx
Normal file
23
ptx/src/test/spirv_run/cvt_rn_f16x2_e5m2x2.ptx
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
.version 7.8
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry cvt_rn_f16x2_e5m2x2(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .b16 in;
|
||||||
|
.reg .b32 result;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.b16 in, [in_addr];
|
||||||
|
|
||||||
|
cvt.rn.f16x2.e5m2x2 result, in;
|
||||||
|
st.b32 [out_addr], result;
|
||||||
|
ret;
|
||||||
|
}
|
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx
Normal file
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e4m3x2_f32.ptx
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
.version 7.8
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry cvt_rn_satfinite_e4m3x2_f32(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 in_a;
|
||||||
|
.reg .f32 in_b;
|
||||||
|
.reg .b16 result;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 in_a, [in_addr];
|
||||||
|
ld.f32 in_b, [in_addr + 4];
|
||||||
|
|
||||||
|
cvt.rn.satfinite.e4m3x2.f32 result, in_a, in_b;
|
||||||
|
st.b16 [out_addr], result;
|
||||||
|
ret;
|
||||||
|
}
|
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx
Normal file
25
ptx/src/test/spirv_run/cvt_rn_satfinite_e5m2x2_f32.ptx
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
.version 7.8
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry cvt_rn_satfinite_e5m2x2_f32(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 in_a;
|
||||||
|
.reg .f32 in_b;
|
||||||
|
.reg .b16 result;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 in_a, [in_addr];
|
||||||
|
ld.f32 in_b, [in_addr + 4];
|
||||||
|
|
||||||
|
cvt.rn.satfinite.e5m2x2.f32 result, in_a, in_b;
|
||||||
|
st.b16 [out_addr], result;
|
||||||
|
ret;
|
||||||
|
}
|
|
@ -185,6 +185,14 @@ test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
|
||||||
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
|
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
|
||||||
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);
|
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);
|
||||||
test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]);
|
test_ptx!(cvt_rni_u16_f32, [0x477FFF80u32], [65535u16]);
|
||||||
|
test_ptx!(cvt_rn_satfinite_e4m3x2_f32, [0.40625, 12.9f32], [0x2D55u16]);
|
||||||
|
test_ptx!(
|
||||||
|
cvt_rn_satfinite_e5m2x2_f32,
|
||||||
|
[0.375, -5256.6f32],
|
||||||
|
[0x36EDu16]
|
||||||
|
);
|
||||||
|
test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]);
|
||||||
|
test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]);
|
||||||
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
|
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
|
||||||
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
|
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
|
||||||
test_ptx!(
|
test_ptx!(
|
||||||
|
|
|
@ -233,6 +233,11 @@ ptx_parser_macros::generate_instruction_type!(
|
||||||
type: { Type::Scalar(data.from) },
|
type: { Type::Scalar(data.from) },
|
||||||
relaxed_type_check: true,
|
relaxed_type_check: true,
|
||||||
},
|
},
|
||||||
|
src2: {
|
||||||
|
repr: Option<T>,
|
||||||
|
type: { Type::Scalar(data.from) },
|
||||||
|
relaxed_type_check: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Cvta {
|
Cvta {
|
||||||
|
@ -1047,7 +1052,9 @@ impl ScalarType {
|
||||||
| ScalarType::S16
|
| ScalarType::S16
|
||||||
| ScalarType::B16
|
| ScalarType::B16
|
||||||
| ScalarType::F16
|
| ScalarType::F16
|
||||||
| ScalarType::BF16 => 2,
|
| ScalarType::BF16
|
||||||
|
| ScalarType::E4m3x2
|
||||||
|
| ScalarType::E5m2x2 => 2,
|
||||||
ScalarType::U32
|
ScalarType::U32
|
||||||
| ScalarType::S32
|
| ScalarType::S32
|
||||||
| ScalarType::B32
|
| ScalarType::B32
|
||||||
|
@ -1069,7 +1076,9 @@ impl ScalarType {
|
||||||
| ScalarType::S16
|
| ScalarType::S16
|
||||||
| ScalarType::B16
|
| ScalarType::B16
|
||||||
| ScalarType::F16
|
| ScalarType::F16
|
||||||
| ScalarType::BF16 => Layout::new::<u16>(),
|
| ScalarType::BF16
|
||||||
|
| ScalarType::E4m3x2
|
||||||
|
| ScalarType::E5m2x2 => Layout::new::<u16>(),
|
||||||
ScalarType::U32
|
ScalarType::U32
|
||||||
| ScalarType::S32
|
| ScalarType::S32
|
||||||
| ScalarType::B32
|
| ScalarType::B32
|
||||||
|
@ -1110,6 +1119,8 @@ impl ScalarType {
|
||||||
ScalarType::F64 => ScalarKind::Float,
|
ScalarType::F64 => ScalarKind::Float,
|
||||||
ScalarType::BF16 => ScalarKind::Float,
|
ScalarType::BF16 => ScalarKind::Float,
|
||||||
ScalarType::BF16x2 => ScalarKind::Float,
|
ScalarType::BF16x2 => ScalarKind::Float,
|
||||||
|
ScalarType::E4m3x2 => ScalarKind::Float,
|
||||||
|
ScalarType::E5m2x2 => ScalarKind::Float,
|
||||||
ScalarType::Pred => ScalarKind::Pred,
|
ScalarType::Pred => ScalarKind::Pred,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1884,7 +1895,9 @@ impl CvtDetails {
|
||||||
saturate,
|
saturate,
|
||||||
},
|
},
|
||||||
Ordering::Greater => {
|
Ordering::Greater => {
|
||||||
if rounding.is_some() {
|
if rounding.is_some()
|
||||||
|
&& !(src == ScalarType::E4m3x2 || src == ScalarType::E5m2x2)
|
||||||
|
{
|
||||||
errors.push(PtxError::SyntaxError(
|
errors.push(PtxError::SyntaxError(
|
||||||
"should not have rounding mode when dst is larger than src in cvt"
|
"should not have rounding mode when dst is larger than src in cvt"
|
||||||
.to_string(),
|
.to_string(),
|
||||||
|
|
|
@ -2370,7 +2370,7 @@ derive_parser!(
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
|
||||||
cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
|
cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
|
||||||
let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
|
let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
|
||||||
let arguments = ast::CvtArgs { dst: d, src: a };
|
let arguments = ast::CvtArgs { dst: d, src: a, src2: None };
|
||||||
ast::Instruction::Cvt {
|
ast::Instruction::Cvt {
|
||||||
data, arguments
|
data, arguments
|
||||||
}
|
}
|
||||||
|
@ -2381,18 +2381,38 @@ derive_parser!(
|
||||||
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
|
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
|
||||||
// cvt.rna{.satfinite}.tf32.f32 d, a;
|
// cvt.rna{.satfinite}.tf32.f32 d, a;
|
||||||
// cvt.frnd2{.relu}.tf32.f32 d, a;
|
// cvt.frnd2{.relu}.tf32.f32 d, a;
|
||||||
// cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b;
|
cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {
|
||||||
|
if relu {
|
||||||
|
state.errors.push(PtxError::Todo);
|
||||||
|
}
|
||||||
|
let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, f8x2type, ScalarType::F32);
|
||||||
|
ast::Instruction::Cvt {
|
||||||
|
data,
|
||||||
|
arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) }
|
||||||
|
}
|
||||||
|
}
|
||||||
// cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
|
// cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
|
||||||
// cvt.rn.{.relu}.f16x2.f8x2type d, a;
|
cvt.rn{.relu}.f16x2.f8x2type d, a => {
|
||||||
|
if relu {
|
||||||
|
state.errors.push(PtxError::Todo);
|
||||||
|
}
|
||||||
|
let data = ast::CvtDetails::new(&mut state.errors, Some(rn), false, false, ScalarType::F16x2, f8x2type);
|
||||||
|
ast::Instruction::Cvt {
|
||||||
|
data,
|
||||||
|
arguments: ast::CvtArgs { dst: d, src: a, src2: None }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
.ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
|
.ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
|
||||||
.frnd2: RawRoundingMode = { .rn, .rz };
|
.frnd2: RawRoundingMode = { .rn, .rz };
|
||||||
|
RawRoundingMode = { .rn };
|
||||||
.dtype: ScalarType = { .u8, .u16, .u32, .u64,
|
.dtype: ScalarType = { .u8, .u16, .u32, .u64,
|
||||||
.s8, .s16, .s32, .s64,
|
.s8, .s16, .s32, .s64,
|
||||||
.bf16, .f16, .f32, .f64 };
|
.bf16, .f16, .f32, .f64 };
|
||||||
.atype: ScalarType = { .u8, .u16, .u32, .u64,
|
.atype: ScalarType = { .u8, .u16, .u32, .u64,
|
||||||
.s8, .s16, .s32, .s64,
|
.s8, .s16, .s32, .s64,
|
||||||
.bf16, .f16, .f32, .f64 };
|
.bf16, .f16, .f32, .f64 };
|
||||||
|
.f8x2type: ScalarType = { .e4m3x2, .e5m2x2 };
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
|
||||||
shl.type d, a, b => {
|
shl.type d, a, b => {
|
||||||
ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }
|
ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue