Fix how full-precision fp32 sqrt and div are handled (#467)

Previously, when compiling full-precision `sqrt`/`div` we'd leave the lowering to LLVM. LLVM looks at the module's `denormal-fp-math-f32` mode, which is incompatible with how we handle denormals and could give wrong results in certain edge cases.
Instead, handle it fully inside ZLUDA.
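For a concrete illustration of such an edge case, here is a minimal host-side sketch (not part of the commit) that decodes the inputs used by the div_ftz/div_noftz tests added in this change; their quotient is subnormal in f32, so the stored result depends entirely on whether FTZ is in effect when the division executes:

    // Hypothetical standalone sketch; the bit patterns are the ones used by the new tests.
    fn main() {
        let a = f32::from_bits(0x16A2028D); // roughly 1.27 * 2^-82
        let b = f32::from_bits(0x5E89F6AE); // roughly 1.08 * 2^62
        let q = a / b;                      // roughly 1.17 * 2^-144, subnormal in f32
        // With denormals preserved this rounds to bit pattern 0x26;
        // with flush-to-zero it is flushed to 0x0, matching the two test expectations.
        println!("{:#x}", q.to_bits());
    }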
Andrzej Janik 2025-08-15 02:24:40 +02:00 committed by GitHub
commit 65367f04ee
18 changed files with 1092 additions and 139 deletions

Cargo.lock generated
View file

@ -2603,6 +2603,7 @@ dependencies = [
"quick-error",
"rustc-hash 2.0.0",
"serde",
"smallvec",
"strum 0.26.3",
"strum_macros 0.26.4",
"tempfile",
@ -2940,9 +2941,9 @@ checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "smallvec"
version = "1.13.2"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "sprs"

View file

@ -21,6 +21,7 @@ petgraph = "0.7.1"
microlp = "0.2.11"
int-enum = "1.1"
unwrap_or = "1.0.1"
smallvec = "1.15.1"
serde = { version = "1.0.219", features = ["derive"] }
[dev-dependencies]

Binary file not shown.

View file

@ -10,7 +10,9 @@
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__
#define DECLARE_ATTR(TYPE, NAME) \
extern const TYPE ATTR(NAME) \
__device__
extern "C"
{
@ -100,19 +102,6 @@ extern "C"
}
}
static __device__ uint32_t sub_sat(uint32_t x, uint32_t y)
{
uint32_t result;
if (__builtin_sub_overflow(x, y, &result))
{
return 0;
}
else
{
return result;
}
}
int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len)
{
// NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len`
@ -122,7 +111,7 @@ extern "C"
if (pos >= 64)
return (base >> 63U);
if (add_sat(pos, len) >= 64)
len = sub_sat(64, pos);
len = 64 - pos;
return (base << (64U - pos - len)) >> (64U - len);
}
@ -174,11 +163,8 @@ extern "C"
BAR_RED_IMPL(and);
BAR_RED_IMPL(or);
struct ShflSyncResult
{
uint32_t output;
bool in_bounds;
};
typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
// shfl.sync opts consists of two values, the warp end ID and the subsection mask.
//
@ -192,7 +178,6 @@ extern "C"
// The warp end ID is the max lane ID for a specific mode. For the CUDA __shfl_sync
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
// bounds check.
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
{ \
@ -208,12 +193,12 @@ extern "C"
idx = self; \
} \
int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
return {(uint32_t)output, !out_of_bounds}; \
return {(uint32_t)output, uint32_t(!out_of_bounds)}; \
} \
\
uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \
{ \
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \
}
// We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync
@ -226,7 +211,8 @@ extern "C"
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
DECLARE_ATTR(uint32_t, CLOCK_RATE);
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
void FUNC(nanosleep_u32)(uint32_t nanoseconds)
{
// clock_rate is in kHz
uint64_t cycles_per_ns = ATTR(CLOCK_RATE) / 1000000;
uint64_t cycles = nanoseconds * cycles_per_ns;
@ -335,4 +321,157 @@ extern "C"
else
return value;
}
// Logic taken from legalizeFSQRTF32/lowerFSQRTF32 in the LLVM AMDGPU target
__device__ static float precise_square_root(float x, bool needs_denorm_handling)
{
// Constants for denormal handling
const float scale_threshold = 0x1.0p-96f; // Very small value threshold
const float scale_up_factor = 0x1.0p+32f; // 2^32
const float scale_down_factor = 0x1.0p-16f; // 2^-16
// Check if input needs scaling (for very small values)
bool need_scale = scale_threshold > x;
auto scaled = scale_up_factor * x;
// Scale up input if needed
float sqrt_x = need_scale ? scaled : x;
float sqrt_s;
// Check if we need special denormal handling
if (needs_denorm_handling)
{
// Use hardware sqrt as initial approximation
sqrt_s = __builtin_sqrtf(sqrt_x); // Or equivalent hardware instruction
// Bit manipulations to get next values up/down
uint32_t sqrt_s_bits = std::bit_cast<uint32_t>(sqrt_s);
// Next value down (subtract 1 from bit pattern)
uint32_t sqrt_s_next_down_bits = sqrt_s_bits - 1;
float sqrt_s_next_down = std::bit_cast<float>(sqrt_s_next_down_bits);
// Calculate residual: x - sqrt_next_down * sqrt
float neg_sqrt_s_next_down = -sqrt_s_next_down;
float sqrt_vp = std::fma(neg_sqrt_s_next_down, sqrt_s, sqrt_x);
// Next value up (add 1 to bit pattern)
uint32_t sqrt_s_next_up_bits = sqrt_s_bits + 1;
float sqrt_s_next_up = std::bit_cast<float>(sqrt_s_next_up_bits);
// Calculate residual: x - sqrt_next_up * sqrt
float neg_sqrt_s_next_up = -sqrt_s_next_up;
float sqrt_vs = std::fma(neg_sqrt_s_next_up, sqrt_s, sqrt_x);
// Select correctly rounded result
if (sqrt_vp <= 0.0f)
{
sqrt_s = sqrt_s_next_down;
}
if (sqrt_vs > 0.0f)
{
sqrt_s = sqrt_s_next_up;
}
}
else
{
// Use Newton-Raphson method with reciprocal square root
// Initial approximation
float sqrt_r = __builtin_amdgcn_rsqf(sqrt_x); // Or equivalent hardware 1/sqrt instruction
sqrt_s = sqrt_x * sqrt_r;
// Refine approximation
float half = 0.5f;
float sqrt_h = sqrt_r * half;
float neg_sqrt_h = -sqrt_h;
// Calculate error term
float sqrt_e = std::fma(neg_sqrt_h, sqrt_s, half);
// First refinement
sqrt_h = std::fma(sqrt_h, sqrt_e, sqrt_h);
sqrt_s = std::fma(sqrt_s, sqrt_e, sqrt_s);
// Second refinement
float neg_sqrt_s = -sqrt_s;
float sqrt_d = std::fma(neg_sqrt_s, sqrt_s, sqrt_x);
sqrt_s = std::fma(sqrt_d, sqrt_h, sqrt_s);
}
// Scale back if input was scaled
if (need_scale)
{
sqrt_s *= scale_down_factor;
}
// Special case handling for zero and infinity
bool is_zero_or_inf = __builtin_isfpclass(sqrt_x, __FPCLASS_POSINF | __FPCLASS_POSZERO | __FPCLASS_NEGZERO);
return is_zero_or_inf ? sqrt_x : sqrt_s;
}
float FUNC(sqrt_rn_f32)(float x)
{
return precise_square_root(x, true);
}
float FUNC(sqrt_rn_ftz_f32)(float x)
{
return precise_square_root(x, false);
}
struct DivRnFtzF32Part1Result
{
float fma_4;
float fma_1;
float fma_3;
uint8_t numerator_scaled_flag;
};
DivRnFtzF32Part1Result FUNC(div_f32_part1)(float lhs, float rhs)
{
float one = 1.0f;
// Division scale operations
bool denominator_scaled_flag;
float denominator_scaled = __builtin_amdgcn_div_scalef(lhs, rhs, false, &denominator_scaled_flag);
bool numerator_scaled_flag;
float numerator_scaled = __builtin_amdgcn_div_scalef(lhs, rhs, true, &numerator_scaled_flag);
// Reciprocal approximation
float approx_rcp = __builtin_amdgcn_rcpf(denominator_scaled);
float neg_div_scale0 = -denominator_scaled;
// Perform division approximation steps
float fma_0 = fmaf(neg_div_scale0, approx_rcp, one);
float fma_1 = fmaf(fma_0, approx_rcp, approx_rcp);
float mul = numerator_scaled * fma_1;
float fma_2 = fmaf(neg_div_scale0, mul, numerator_scaled);
float fma_3 = fmaf(fma_2, fma_1, mul);
float fma_4 = fmaf(neg_div_scale0, fma_3, numerator_scaled);
return {fma_4, fma_1, fma_3, numerator_scaled_flag};
}
__device__ static float div_f32_part2(float x, float y, DivRnFtzF32Part1Result part1)
{
float fmas = __builtin_amdgcn_div_fmasf(part1.fma_4, part1.fma_1, part1.fma_3, part1.numerator_scaled_flag);
float result = __builtin_amdgcn_div_fixupf(fmas, y, x);
return result;
}
float FUNC(div_f32_part2)(float x, float y,
float fma_4,
float fma_1,
float fma_3,
uint8_t numerator_scaled_flag)
{
return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag});
}
}

View file

@ -803,6 +803,13 @@ fn create_control_flow_graph(
let modes = get_modes(instruction);
bb_state.append(modes);
}
Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
bb_state.append(InstructionModes::new(
ast::ScalarType::F32,
ftz_f32.map(DenormalMode::from_ftz),
rnd_f32.map(RoundingMode::from_ast),
));
}
_ => {}
}
}
@ -1021,6 +1028,16 @@ fn apply_global_mode_controls(
let modes = get_modes(&instruction);
bb_state.insert(&mut result, modes)?;
}
Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
bb_state.insert(
&mut result,
InstructionModes::new(
ast::ScalarType::F32,
ftz_f32.map(DenormalMode::from_ftz),
rnd_f32.map(RoundingMode::from_ast),
),
)?;
}
_ => {}
}
result.push(statement);

View file

@ -397,6 +397,8 @@ impl<'a> MethodEmitContext<'a> {
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?,
// No-op
Statement::FpModeRequired { .. } => {}
})
}
@ -825,25 +827,12 @@ impl<'a> MethodEmitContext<'a> {
match &*arguments.return_arguments {
[] => {}
[name] => self.resolver.register(*name, llvm_call),
[b32, pred] => {
self.resolver.with_result(*b32, |name| unsafe {
LLVMBuildExtractValue(self.builder, llvm_call, 0, name)
});
self.resolver.with_result(*pred, |name| unsafe {
let extracted =
LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr());
LLVMBuildTrunc(
self.builder,
extracted,
get_scalar_type(self.context, ast::ScalarType::Pred),
name,
)
});
}
_ => {
return Err(error_todo_msg(
"Only two return arguments (.b32, .pred) currently supported",
))
args => {
for (i, arg) in args.iter().copied().enumerate() {
self.resolver.with_result(arg, |name| unsafe {
LLVMBuildExtractValue(self.builder, llvm_call, i as u32, name)
});
}
}
}
Ok(())
@ -992,44 +981,28 @@ impl<'a> MethodEmitContext<'a> {
unsafe {
LLVMSetAlignment(load, type_.layout().align() as u32);
}
Ok(load)
Ok((load, type_))
})
.collect::<Result<Vec<_>, _>>()?;
match &*loads {
[] => unsafe { LLVMBuildRetVoid(self.builder) },
[value] => unsafe { LLVMBuildRet(self.builder, *value) },
_ => {
check_multiple_return_types(values.iter().map(|(_, type_)| type_))?;
let array_ty =
get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?;
let insert_b32 = unsafe {
LLVMBuildInsertValue(
self.builder,
LLVMGetPoison(array_ty),
loads[0],
0,
LLVM_UNNAMED.as_ptr(),
)
};
let zext_pred = unsafe {
LLVMBuildZExt(
self.builder,
loads[1],
get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?,
LLVM_UNNAMED.as_ptr(),
)
};
let insert_pred = unsafe {
LLVMBuildInsertValue(
self.builder,
insert_b32,
zext_pred,
1,
LLVM_UNNAMED.as_ptr(),
)
};
unsafe { LLVMBuildRet(self.builder, insert_pred) }
[(value, _)] => unsafe { LLVMBuildRet(self.builder, *value) },
loads => {
let struct_type =
get_or_create_struct_type(self.context, loads.iter().map(|(_, type_)| *type_))?;
let mut value = unsafe { LLVMGetUndef(struct_type) };
for (i, (load, _)) in loads.iter().enumerate() {
value = unsafe {
LLVMBuildInsertValue(
self.builder,
value,
*load,
i as u32,
LLVM_UNNAMED.as_ptr(),
)
};
}
unsafe { LLVMBuildRet(self.builder, value) }
}
};
Ok(())
@ -1898,10 +1871,10 @@ impl<'a> MethodEmitContext<'a> {
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
llvm_func: unsafe extern "C" fn(
arg1: LLVMBuilderRef,
Val: LLVMValueRef,
DestTy: LLVMTypeRef,
Name: *const i8,
LLVMBuilderRef,
LLVMValueRef,
LLVMTypeRef,
*const i8,
) -> LLVMValueRef,
) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, to);
@ -2928,46 +2901,60 @@ fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, T
})
}
fn get_array_type<'a>(
fn get_or_create_struct_type<'a>(
context: LLVMContextRef,
elem_type: &'a ast::Type,
count: u64,
mut elem_types: impl Iterator<Item = &'a ast::Type>,
) -> Result<LLVMTypeRef, TranslateError> {
let elem_type = get_type(context, elem_type)?;
Ok(unsafe { LLVMArrayType2(elem_type, count) })
use std::fmt::Write;
let (mut name, types) = elem_types.try_fold(
("struct".to_string(), Vec::new()),
|(mut name, mut types), t| {
name.push('.');
if let ast::Type::Scalar(scalar) = t {
write!(name, "{}", LLVMTypeDisplay(*scalar)).ok();
} else {
return Err(error_unreachable());
}
types.push(get_type(context, t)?);
Ok((name, types))
},
)?;
name.push('\0');
let mut struct_type = unsafe { LLVMGetTypeByName2(context, name.as_ptr().cast()) };
if struct_type.is_null() {
struct_type = create_struct_type(context, name, types);
}
Ok(struct_type)
}
fn check_multiple_return_types<'a>(
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
) -> Result<(), TranslateError> {
let err_msg = "Only (.b32, .pred) multiple return types are supported";
let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
match (first, second) {
(ast::Type::Scalar(first), ast::Type::Scalar(second)) => {
if first.size_of() != 4 || second.size_of() != 1 {
return Err(error_todo_msg(err_msg));
}
}
_ => return Err(error_todo_msg(err_msg)),
fn create_struct_type(
context: LLVMContextRef,
name: String,
mut elem_types: Vec<LLVMTypeRef>,
) -> LLVMTypeRef {
let llvm_type = unsafe { LLVMStructCreateNamed(context, name.as_ptr().cast()) };
unsafe {
LLVMStructSetBody(
llvm_type,
elem_types.as_mut_ptr(),
elem_types.len() as u32,
0,
)
}
Ok(())
llvm_type
}
fn get_function_type<'a>(
context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
mut return_args: impl DoubleEndedIterator<Item = &'a ast::Type>
+ ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, &return_args.next().unwrap())?,
_ => {
check_multiple_return_types(return_args)?;
get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
}
_ => get_or_create_struct_type(context, return_args)?,
};
Ok(unsafe {

View file

@ -24,7 +24,8 @@ mod normalize_basic_blocks;
mod normalize_identifiers2;
mod normalize_predicates2;
mod remove_unreachable_basic_blocks;
mod replace_instructions_with_function_calls;
mod replace_instructions_with_functions;
mod replace_instructions_with_functions_fp_required;
mod replace_known_functions;
mod resolve_function_pointers;
@ -68,12 +69,14 @@ pub fn to_llvm_module<'input>(
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives =
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
let directives = remove_unreachable_basic_blocks::run(directives)?;
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let context = llvm::Context::new();
@ -235,6 +238,15 @@ enum Statement<I, P: ast::Operand> {
VectorRead(VectorRead),
VectorWrite(VectorWrite),
SetMode(ModeRegister),
// This statement is a no-op; it serves as a marker indicating that the
// next instruction requires certain floating-point modes to be set.
// Some transcendentals compile to a sequence of instructions that
// require certain modes to be set _mid-function_.
// See the replace_instructions_with_functions_fp_required pass for details
FpModeRequired {
ftz_f32: Option<bool>,
rnd_f32: Option<ast::RoundingMode>,
},
FpSaturate {
dst: SpirvWord,
src: SpirvWord,
@ -541,6 +553,9 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
)?;
Statement::FpSaturate { dst, src, type_ }
}
Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
Statement::FpModeRequired { ftz_f32, rnd_f32 }
}
})
}
}

View file

@ -1,4 +1,5 @@
use super::*;
use smallvec::*;
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
@ -71,13 +72,136 @@ fn run_statements<'input>(
statements
.into_iter()
.map(|statement| {
Ok(match statement {
Statement::Instruction(instruction) => {
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
Ok::<SmallVec<[_; 3]>, _>(match statement {
Statement::Instruction(ast::Instruction::ShflSync {
data,
arguments:
ast::ShflSyncArgs {
dst_pred: Some(dst_pred),
dst,
src,
src_lane,
src_opts,
src_membermask,
},
}) => {
let mode = match data.mode {
ptx_parser::ShuffleMode::Up => "up",
ptx_parser::ShuffleMode::Down => "down",
ptx_parser::ShuffleMode::BFly => "bfly",
ptx_parser::ShuffleMode::Idx => "idx",
};
let packed_var = resolver.register_unnamed(Some((
ast::Type::Vector(2, ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
)));
let dst_pred_wide = resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
)));
let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat();
let return_arguments = vec![(
ast::Type::Vector(2, ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
)];
let input_arguments = vec![
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
];
let func = match fn_declarations.entry(full_name.into()) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let name = resolver.register_named(name, None);
vacant_entry.insert((
to_variables(resolver, &return_arguments),
name,
to_variables(resolver, &input_arguments),
));
name
}
};
smallvec![
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,
return_arguments: vec![(
ast::Type::Vector(2, ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
)],
input_arguments: vec![
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::U32),
ptx_parser::StateSpace::Reg,
),
],
},
arguments: ptx_parser::CallArgs {
return_arguments: vec![packed_var],
func,
input_arguments: vec![src, src_lane, src_opts, src_membermask],
},
}),
Statement::RepackVector(RepackVectorDetails {
is_extract: true,
typ: ast::ScalarType::U32,
packed: packed_var,
unpacked: vec![dst, dst_pred_wide],
relaxed_type_check: false,
}),
Statement::Instruction(ast::Instruction::Cvt {
data: ast::CvtDetails {
from: ast::ScalarType::U32,
to: ast::ScalarType::Pred,
mode: ast::CvtMode::Truncate
},
arguments: ast::CvtArgs {
dst: dst_pred,
src: dst_pred_wide,
},
})
]
}
s => s,
Statement::<ast::Instruction<SpirvWord>, SpirvWord>::Instruction(instruction) => {
smallvec![
Statement::<ast::Instruction<SpirvWord>, SpirvWord>::Instruction(
run_instruction(resolver, fn_declarations, instruction)?
)
]
}
s => smallvec![s],
})
})
.flat_map(|result| match result {
Ok(vec) => vec.into_iter().map(|item| Ok(item)).collect(),
Err(er) => vec![Err(er)],
})
.collect::<Result<Vec<_>, _>>()
}
@ -141,6 +265,52 @@ fn run_instruction<'input>(
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {
kind: ast::RcpKind::Compliant(ast::RoundingMode::NearestEven),
flush_to_zero: Some(true),
..
},
..
} => {
let name = "sqrt_rn_ftz_f32";
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Sqrt {
data:
ast::RcpData {
kind: ast::RcpKind::Compliant(ast::RoundingMode::NearestEven),
..
},
..
} => {
let name = "sqrt_rn_f32";
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Div {
data:
ast::DivDetails::Float(ast::DivFloatDetails {
kind: ast::DivFloatKind::Rounding(_),
flush_to_zero: Some(true),
..
}),
..
} => {
let name = "div_rn_ftz_f32";
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Div {
data:
ast::DivDetails::Float(ast::DivFloatDetails {
kind: ast::DivFloatKind::Rounding(_),
..
}),
..
} => {
let name = "div_rn_f32";
to_call(resolver, fn_declarations, name.into(), i)?
}
i @ ptx_parser::Instruction::Bfi { data, .. } => {
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
@ -163,23 +333,24 @@ fn run_instruction<'input>(
ptx_parser::Instruction::BarRed { data, arguments },
)?
}
ptx_parser::Instruction::ShflSync { data, arguments } => {
ptx_parser::Instruction::ShflSync {
data,
arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },
} => {
let mode = match data.mode {
ptx_parser::ShuffleMode::Up => "up",
ptx_parser::ShuffleMode::Down => "down",
ptx_parser::ShuffleMode::BFly => "bfly",
ptx_parser::ShuffleMode::Idx => "idx",
};
let pred = if arguments.dst_pred.is_some() {
"_pred"
} else {
""
};
to_call(
resolver,
fn_declarations,
format!("shfl_sync_{}_b32{}", mode, pred).into(),
ptx_parser::Instruction::ShflSync { data, arguments },
format!("shfl_sync_{}_b32", mode).into(),
ptx_parser::Instruction::ShflSync {
data,
arguments: orig_arguments,
},
)?
}
i @ ptx_parser::Instruction::Nanosleep { .. } => {

View file

@ -0,0 +1,367 @@
// This pass exists specifically to replace the `div.rn.ftz.f32` instruction
// with a function call. One inherent quirk of the replacement function is
// that it requires different floating-point modes for the first part of the
// division and the second part: the first part is executed with FTZ disabled
// and the second part with FTZ enabled.
// Because of this we can't do the replacement after FTZ mode insertion without
// making the function read and restore the FTZ mode, so we split the
// replacement function in two and prefix each call with a no-op
// (FpModeRequired) that carries the FTZ mode information.
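// Roughly, for `div.rn{.ftz}.f32 dst, a, b` the pass emits the following
// sequence (a sketch; see run_statement below for the exact statements):
//   FpModeRequired { ftz_f32: Some(false), rnd_f32: Some(NearestEven) }
//   call (fma_4, fma_1, fma_3, scaled_flag) = div_f32_part1(a, b)
//   FpModeRequired { ftz_f32: <instruction's ftz>, rnd_f32: <instruction's rnd> }
//   call dst = div_f32_part2(a, b, fma_4, fma_1, fma_3, scaled_flag)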
use super::*;
use ptx_parser as ast;
use smallvec::smallvec;
use smallvec::SmallVec;
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut imports = None;
let directives = directives
.into_iter()
.map(|directive| run_directive(resolver, directive, &mut imports))
.collect::<Result<Vec<_>, _>>()?;
Ok(match imports {
Some(imports) => {
let mut result = Vec::with_capacity(directives.len() + 2);
result.extend([
Directive2::Method(Function2 {
return_arguments: vec![
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::U8),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
],
name: imports.part1,
input_arguments: vec![
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
],
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}),
Directive2::Method(Function2 {
return_arguments: vec![ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
}],
name: imports.part2,
input_arguments: vec![
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
ast::Variable {
name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::U8),
ast::StateSpace::Reg,
))),
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
},
],
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}),
]);
result.extend(directives);
result
}
None => directives,
})
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
imports: &mut Option<FunctionImports>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method, imports)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
imports: &mut Option<FunctionImports>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
method.body = method.body.map(|body| {
body.into_iter()
.flat_map(|stmt| run_statement(resolver, stmt, imports))
.collect()
});
Ok(method)
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
stmt: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
imports: &mut Option<FunctionImports>,
) -> SmallVec<[Statement<ast::Instruction<SpirvWord>, SpirvWord>; 4]> {
match stmt {
Statement::Instruction(ast::Instruction::Div {
data:
ast::DivDetails::Float(ast::DivFloatDetails {
flush_to_zero,
kind: ast::DivFloatKind::Rounding(rnd),
type_: ast::ScalarType::F32,
}),
arguments,
}) => {
let ftz = flush_to_zero.unwrap_or(false);
let FunctionImports { part1, part2, .. } = FunctionImports::init(imports, resolver);
let fma_4 = resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
)));
let fma_1 = resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
)));
let fma3_ = resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
)));
let numerator_scaled_flag = resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::U8),
ast::StateSpace::Reg,
)));
smallvec![
Statement::FpModeRequired {
ftz_f32: Some(false),
rnd_f32: Some(ast::RoundingMode::NearestEven),
},
Statement::Instruction(ast::Instruction::Call {
arguments: ast::CallArgs {
return_arguments: vec![fma_4, fma_1, fma3_, numerator_scaled_flag],
func: *part1,
input_arguments: vec![arguments.src1, arguments.src2],
},
data: ast::CallDetails {
uniform: false,
return_arguments: vec![
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg,)
],
input_arguments: vec![
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
)
]
}
}),
Statement::FpModeRequired {
ftz_f32: Some(ftz),
rnd_f32: Some(rnd),
},
Statement::Instruction(ast::Instruction::Call {
arguments: ast::CallArgs {
return_arguments: vec![arguments.dst],
func: *part2,
input_arguments: vec![
arguments.src1,
arguments.src2,
fma_4,
fma_1,
fma3_,
numerator_scaled_flag
],
},
data: ast::CallDetails {
uniform: false,
return_arguments: vec![(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
)],
input_arguments: vec![
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(
ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg,
),
(ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg,)
]
}
})
]
}
_ => smallvec![stmt],
}
}
#[derive(Clone)]
struct FunctionImports {
part1: SpirvWord,
part2: SpirvWord,
}
impl FunctionImports {
fn init<'a>(
this: &'a mut Option<FunctionImports>,
resolver: &mut GlobalStringIdentResolver2,
) -> &'a FunctionImports {
this.get_or_insert_with(|| {
let part1_name = [ZLUDA_PTX_PREFIX, "div_f32_part1"].concat();
let part1 = resolver.register_named(part1_name.into(), None);
let part2_name = [ZLUDA_PTX_PREFIX, "div_f32_part2"].concat();
let part2 = resolver.register_named(part2_name.into(), None);
FunctionImports { part1, part2 }
})
}
}

View file

@ -0,0 +1,74 @@
%struct.f32.f32.f32.i8 = type { float, float, float, i8 }
declare %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float, float) #0
declare float @__zluda_ptx_impl_div_f32_part2(float, float, float, float, float, i8) #0
define amdgpu_kernel void @div_ftz(ptr addrspace(4) byref(i64) %"63", ptr addrspace(4) byref(i64) %"64") #1 {
%"65" = alloca i64, align 8, addrspace(5)
%"66" = alloca i64, align 8, addrspace(5)
%"67" = alloca float, align 4, addrspace(5)
%"68" = alloca float, align 4, addrspace(5)
%"69" = alloca float, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"54"
"54": ; preds = %1
%"70" = load i64, ptr addrspace(4) %"63", align 8
store i64 %"70", ptr addrspace(5) %"65", align 8
%"71" = load i64, ptr addrspace(4) %"64", align 8
store i64 %"71", ptr addrspace(5) %"66", align 8
%"73" = load i64, ptr addrspace(5) %"65", align 8
%"88" = inttoptr i64 %"73" to ptr
%"72" = load float, ptr %"88", align 4
store float %"72", ptr addrspace(5) %"67", align 4
%"74" = load i64, ptr addrspace(5) %"65", align 8
%"89" = inttoptr i64 %"74" to ptr
%"32" = getelementptr inbounds i8, ptr %"89", i64 4
%"75" = load float, ptr %"32", align 4
store float %"75", ptr addrspace(5) %"68", align 4
%"77" = load float, ptr addrspace(5) %"67", align 4
%"78" = load float, ptr addrspace(5) %"68", align 4
%"76" = fmul float %"77", %"78"
store float %"76", ptr addrspace(5) %"69", align 4
%"79" = load float, ptr addrspace(5) %"67", align 4
%"80" = load float, ptr addrspace(5) %"68", align 4
%2 = call %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float %"79", float %"80")
%"37" = extractvalue %struct.f32.f32.f32.i8 %2, 0
%"38" = extractvalue %struct.f32.f32.f32.i8 %2, 1
%"39" = extractvalue %struct.f32.f32.f32.i8 %2, 2
%"40" = extractvalue %struct.f32.f32.f32.i8 %2, 3
br label %"57"
"57": ; preds = %"54"
call void @llvm.amdgcn.s.setreg(i32 6401, i32 0)
br label %"55"
"55": ; preds = %"57"
%"82" = load float, ptr addrspace(5) %"67", align 4
%"83" = load float, ptr addrspace(5) %"68", align 4
%"81" = call float @__zluda_ptx_impl_div_f32_part2(float %"82", float %"83", float %"37", float %"38", float %"39", i8 %"40")
store float %"81", ptr addrspace(5) %"67", align 4
br label %"56"
"56": ; preds = %"55"
%"84" = load i64, ptr addrspace(5) %"66", align 8
%"85" = load float, ptr addrspace(5) %"67", align 4
%"90" = inttoptr i64 %"84" to ptr
store float %"85", ptr %"90", align 4
%"86" = load i64, ptr addrspace(5) %"66", align 8
%"91" = inttoptr i64 %"86" to ptr
%"34" = getelementptr inbounds i8, ptr %"91", i64 4
%"87" = load float, ptr addrspace(5) %"69", align 4
store float %"87", ptr %"34", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind willreturn
declare void @llvm.amdgcn.s.setreg(i32 immarg, i32) #2
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #2 = { nocallback nofree nosync nounwind willreturn }

View file

@ -0,0 +1,71 @@
%struct.f32.f32.f32.i8 = type { float, float, float, i8 }
declare %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float, float) #0
declare float @__zluda_ptx_impl_div_f32_part2(float, float, float, float, float, i8) #0
define amdgpu_kernel void @div_noftz(ptr addrspace(4) byref(i64) %"62", ptr addrspace(4) byref(i64) %"63") #1 {
%"64" = alloca i64, align 8, addrspace(5)
%"65" = alloca i64, align 8, addrspace(5)
%"66" = alloca float, align 4, addrspace(5)
%"67" = alloca float, align 4, addrspace(5)
%"68" = alloca float, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"54"
"54": ; preds = %1
%"69" = load i64, ptr addrspace(4) %"62", align 8
store i64 %"69", ptr addrspace(5) %"64", align 8
%"70" = load i64, ptr addrspace(4) %"63", align 8
store i64 %"70", ptr addrspace(5) %"65", align 8
%"72" = load i64, ptr addrspace(5) %"64", align 8
%"87" = inttoptr i64 %"72" to ptr
%"71" = load float, ptr %"87", align 4
store float %"71", ptr addrspace(5) %"66", align 4
%"73" = load i64, ptr addrspace(5) %"64", align 8
%"88" = inttoptr i64 %"73" to ptr
%"32" = getelementptr inbounds i8, ptr %"88", i64 4
%"74" = load float, ptr %"32", align 4
store float %"74", ptr addrspace(5) %"67", align 4
%"76" = load float, ptr addrspace(5) %"66", align 4
%"77" = load float, ptr addrspace(5) %"67", align 4
%"75" = fmul float %"76", %"77"
store float %"75", ptr addrspace(5) %"68", align 4
call void @llvm.amdgcn.s.setreg(i32 6401, i32 3)
%"78" = load float, ptr addrspace(5) %"66", align 4
%"79" = load float, ptr addrspace(5) %"67", align 4
%2 = call %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float %"78", float %"79")
%"37" = extractvalue %struct.f32.f32.f32.i8 %2, 0
%"38" = extractvalue %struct.f32.f32.f32.i8 %2, 1
%"39" = extractvalue %struct.f32.f32.f32.i8 %2, 2
%"40" = extractvalue %struct.f32.f32.f32.i8 %2, 3
br label %"55"
"55": ; preds = %"54"
%"81" = load float, ptr addrspace(5) %"66", align 4
%"82" = load float, ptr addrspace(5) %"67", align 4
%"80" = call float @__zluda_ptx_impl_div_f32_part2(float %"81", float %"82", float %"37", float %"38", float %"39", i8 %"40")
store float %"80", ptr addrspace(5) %"66", align 4
br label %"56"
"56": ; preds = %"55"
%"83" = load i64, ptr addrspace(5) %"65", align 8
%"84" = load float, ptr addrspace(5) %"66", align 4
%"89" = inttoptr i64 %"83" to ptr
store float %"84", ptr %"89", align 4
%"85" = load i64, ptr addrspace(5) %"65", align 8
%"90" = inttoptr i64 %"85" to ptr
%"34" = getelementptr inbounds i8, ptr %"90", i64 4
%"86" = load float, ptr addrspace(5) %"68", align 4
store float %"86", ptr %"34", align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind willreturn
declare void @llvm.amdgcn.s.setreg(i32 immarg, i32) #2
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #2 = { nocallback nofree nosync nounwind willreturn }

View file

@ -1,4 +1,6 @@
define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 {
declare float @__zluda_ptx_impl_sqrt_approx_f32(float) #0
define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #1 {
%"32" = alloca i64, align 8, addrspace(5)
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca float, align 4, addrspace(5)
@ -17,7 +19,7 @@ define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace
%"37" = load float, ptr %"43", align 4
store float %"37", ptr addrspace(5) %"34", align 4
%"40" = load float, ptr addrspace(5) %"34", align 4
%"39" = call float @llvm.amdgcn.sqrt.f32(float %"40")
%"39" = call float @__zluda_ptx_impl_sqrt_approx_f32(float %"40")
store float %"39", ptr addrspace(5) %"34", align 4
%"41" = load i64, ptr addrspace(5) %"33", align 8
%"42" = load float, ptr addrspace(5) %"34", align 4
@ -26,8 +28,5 @@ define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.amdgcn.sqrt.f32(float) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,32 @@
declare float @__zluda_ptx_impl_sqrt_rn_ftz_f32(float) #0
define amdgpu_kernel void @sqrt_rn_ftz(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #1 {
%"32" = alloca i64, align 8, addrspace(5)
%"33" = alloca i64, align 8, addrspace(5)
%"34" = alloca float, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"29"
"29": ; preds = %1
%"35" = load i64, ptr addrspace(4) %"30", align 8
store i64 %"35", ptr addrspace(5) %"32", align 8
%"36" = load i64, ptr addrspace(4) %"31", align 8
store i64 %"36", ptr addrspace(5) %"33", align 8
%"38" = load i64, ptr addrspace(5) %"32", align 8
%"43" = inttoptr i64 %"38" to ptr
%"37" = load float, ptr %"43", align 4
store float %"37", ptr addrspace(5) %"34", align 4
%"40" = load float, ptr addrspace(5) %"34", align 4
%"39" = call float @__zluda_ptx_impl_sqrt_rn_ftz_f32(float %"40")
store float %"39", ptr addrspace(5) %"34", align 4
%"41" = load i64, ptr addrspace(5) %"33", align 8
%"42" = load float, ptr addrspace(5) %"34", align 4
%"44" = inttoptr i64 %"41" to ptr
store float %"42", ptr %"44", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,27 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry div_ftz(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp1;
.reg .f32 temp2;
.reg .f32 force_ftz_mode;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 temp1, [in_addr];
ld.f32 temp2, [in_addr+4];
// DO NOT REMOVE THIS MULTIPLICATION
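// (the non-ftz mul deliberately differs from the ftz div below, forcing the
// module-level denormal attribute away from the mode the div needs)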
mul.f32 force_ftz_mode, temp1, temp2;
div.ftz.rn.f32 temp1, temp1, temp2;
st.f32 [out_addr], temp1;
st.f32 [out_addr+4], force_ftz_mode;
ret;
}

View file

@ -0,0 +1,27 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry div_noftz(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp1;
.reg .f32 temp2;
.reg .f32 force_ftz_mode;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 temp1, [in_addr];
ld.f32 temp2, [in_addr+4];
// DO NOT REMOVE THIS MULTIPLICATION
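// (the ftz mul deliberately differs from the non-ftz div below, forcing the
// module-level denormal attribute away from the mode the div needs)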
mul.ftz.f32 force_ftz_mode, temp1, temp2;
div.rn.f32 temp1, temp1, temp2;
st.f32 [out_addr], temp1;
st.f32 [out_addr+4], force_ftz_mode;
ret;
}

View file

@ -173,6 +173,7 @@ test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(sqrt_rn_ftz, [0x1u32], [0x0u32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
test_ptx!(neg, [181i32], [-181i32]);
test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]);
@ -279,6 +280,19 @@ test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
test_ptx!(warp_sz, [0u8], [32u8]);
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]);
// The two tests below exercise a very important compiler feature; make sure
// that you fully understand what's going on before you touch them.
// The problem is that the full-precision division gets legalized by LLVM
// using a module-level attribute.
// In the two tests below we deliberately force our compiler to emit a module
// whose module-level denormal attribute differs from the denormal attribute
// of the instruction, to catch cases like this
test_ptx!(div_ftz, [0x16A2028Du32, 0x5E89F6AE], [0x0, 900636404u32]);
test_ptx!(
div_noftz,
[0x16A2028Du32, 0x5E89F6AE],
[0x26u32, 900636404u32]
);
test_ptx!(nanosleep, [0u64], [0u64]);
test_ptx!(shf_l, [0x12345678u32, 0x9abcdef0u32, 12], [0xcdef0123u32]);

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry sqrt_rn_ftz(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp1;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 temp1, [in_addr];
sqrt.rn.ftz.f32 temp1, temp1;
st.f32 [out_addr], temp1;
ret;
}

View file

@ -1058,22 +1058,12 @@ impl From<ScalarType> for Type {
#[derive(Clone)]
pub struct MovDetails {
pub typ: super::Type,
pub src_is_address: bool,
// two fields below are in use by member moves
pub dst_width: u8,
pub src_width: u8,
// This is in use by auto-generated movs
pub relaxed_src2_conv: bool,
}
impl MovDetails {
pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
MovDetails {
typ: Type::maybe_vector(vector, scalar),
src_is_address: false,
dst_width: 0,
src_width: 0,
relaxed_src2_conv: false,
}
}
}