mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-05 17:16:10 +00:00
Fix how full-precision fp32 sqrt and div are handled (#467)
Previously, when compiling full-precision `sqrt`/`div`, we'd leave the expansion to LLVM. LLVM looks at the module's `denormal-fp-math-f32` mode, which is incompatible with how we handle denormals and could give wrong results in certain edge cases. Instead, handle it fully inside ZLUDA.
This commit is contained in:
parent
a420601128
commit
65367f04ee
18 changed files with 1092 additions and 139 deletions
5
Cargo.lock
generated
5
Cargo.lock
generated
|
@ -2603,6 +2603,7 @@ dependencies = [
|
|||
"quick-error",
|
||||
"rustc-hash 2.0.0",
|
||||
"serde",
|
||||
"smallvec",
|
||||
"strum 0.26.3",
|
||||
"strum_macros 0.26.4",
|
||||
"tempfile",
|
||||
|
@ -2940,9 +2941,9 @@ checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
|
|||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.13.2"
|
||||
version = "1.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
|
||||
[[package]]
|
||||
name = "sprs"
|
||||
|
|
|
@ -21,6 +21,7 @@ petgraph = "0.7.1"
|
|||
microlp = "0.2.11"
|
||||
int-enum = "1.1"
|
||||
unwrap_or = "1.0.1"
|
||||
smallvec = "1.15.1"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
Binary file not shown.
|
@ -10,7 +10,9 @@
|
|||
|
||||
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
||||
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
|
||||
#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__
|
||||
#define DECLARE_ATTR(TYPE, NAME) \
|
||||
extern const TYPE ATTR(NAME) \
|
||||
__device__
|
||||
|
||||
extern "C"
|
||||
{
|
||||
|
@ -100,19 +102,6 @@ extern "C"
|
|||
}
|
||||
}
|
||||
|
||||
// Saturating unsigned subtraction: returns x - y, or 0 when the subtraction
// would wrap around (i.e. when y is larger than x).
static __device__ uint32_t sub_sat(uint32_t x, uint32_t y)
{
    uint32_t difference;
    // __builtin_sub_overflow stores the wrapped difference and reports
    // whether the mathematically exact result did not fit in uint32_t.
    const bool wrapped = __builtin_sub_overflow(x, y, &difference);
    return wrapped ? 0 : difference;
}
|
||||
|
||||
int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len)
|
||||
{
|
||||
// NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len`
|
||||
|
@ -122,7 +111,7 @@ extern "C"
|
|||
if (pos >= 64)
|
||||
return (base >> 63U);
|
||||
if (add_sat(pos, len) >= 64)
|
||||
len = sub_sat(64, pos);
|
||||
len = 64 - pos;
|
||||
return (base << (64U - pos - len)) >> (64U - len);
|
||||
}
|
||||
|
||||
|
@ -174,11 +163,8 @@ extern "C"
|
|||
BAR_RED_IMPL(and);
|
||||
BAR_RED_IMPL(or);
|
||||
|
||||
struct ShflSyncResult
|
||||
{
|
||||
uint32_t output;
|
||||
bool in_bounds;
|
||||
};
|
||||
|
||||
typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
||||
|
||||
// shfl.sync opts consists of two values, the warp end ID and the subsection mask.
|
||||
//
|
||||
|
@ -192,7 +178,6 @@ extern "C"
|
|||
// The warp end ID is the max lane ID for a specific mode. For the CUDA __shfl_sync
|
||||
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
|
||||
// bounds check.
|
||||
|
||||
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
|
||||
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
|
||||
{ \
|
||||
|
@ -208,12 +193,12 @@ extern "C"
|
|||
idx = self; \
|
||||
} \
|
||||
int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
|
||||
return {(uint32_t)output, !out_of_bounds}; \
|
||||
return {(uint32_t)output, uint32_t(!out_of_bounds)}; \
|
||||
} \
|
||||
\
|
||||
uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \
|
||||
{ \
|
||||
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
|
||||
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \
|
||||
}
|
||||
|
||||
// We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync
|
||||
|
@ -226,7 +211,8 @@ extern "C"
|
|||
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
|
||||
|
||||
DECLARE_ATTR(uint32_t, CLOCK_RATE);
|
||||
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
|
||||
void FUNC(nanosleep_u32)(uint32_t nanoseconds)
|
||||
{
|
||||
// clock_rate is in kHz
|
||||
uint64_t cycles_per_ns = ATTR(CLOCK_RATE) / 1000000;
|
||||
uint64_t cycles = nanoseconds * cycles_per_ns;
|
||||
|
@ -335,4 +321,157 @@ extern "C"
|
|||
else
|
||||
return value;
|
||||
}
|
||||
|
||||
// Logic taken from legalizeFSQRTF32/lowerFSQRTF32 in LLVM AMDGPU target
|
||||
__device__ static float precise_square_root(float x, bool needs_denorm_handling)
|
||||
{
|
||||
|
||||
// Constants for denormal handling
|
||||
const float scale_threshold = 0x1.0p-96f; // Very small value threshold
|
||||
const float scale_up_factor = 0x1.0p+32f; // 2^32
|
||||
const float scale_down_factor = 0x1.0p-16f; // 2^-16
|
||||
|
||||
// Check if input needs scaling (for very small values)
|
||||
bool need_scale = scale_threshold > x;
|
||||
auto scaled = scale_up_factor * x;
|
||||
|
||||
// Scale up input if needed
|
||||
float sqrt_x = need_scale ? scaled : x;
|
||||
|
||||
float sqrt_s;
|
||||
|
||||
// Check if we need special denormal handling
|
||||
|
||||
if (needs_denorm_handling)
|
||||
{
|
||||
// Use hardware sqrt as initial approximation
|
||||
sqrt_s = __builtin_sqrtf(sqrt_x); // Or equivalent hardware instruction
|
||||
|
||||
// Bit manipulations to get next values up/down
|
||||
uint32_t sqrt_s_bits = std::bit_cast<uint32_t>(sqrt_s);
|
||||
|
||||
// Next value down (subtract 1 from bit pattern)
|
||||
uint32_t sqrt_s_next_down_bits = sqrt_s_bits - 1;
|
||||
float sqrt_s_next_down = std::bit_cast<float>(sqrt_s_next_down_bits);
|
||||
|
||||
// Calculate residual: x - sqrt_next_down * sqrt
|
||||
float neg_sqrt_s_next_down = -sqrt_s_next_down;
|
||||
float sqrt_vp = std::fma(neg_sqrt_s_next_down, sqrt_s, sqrt_x);
|
||||
|
||||
// Next value up (add 1 to bit pattern)
|
||||
uint32_t sqrt_s_next_up_bits = sqrt_s_bits + 1;
|
||||
float sqrt_s_next_up = std::bit_cast<float>(sqrt_s_next_up_bits);
|
||||
|
||||
// Calculate residual: x - sqrt_next_up * sqrt
|
||||
float neg_sqrt_s_next_up = -sqrt_s_next_up;
|
||||
float sqrt_vs = std::fma(neg_sqrt_s_next_up, sqrt_s, sqrt_x);
|
||||
|
||||
// Select correctly rounded result
|
||||
if (sqrt_vp <= 0.0f)
|
||||
{
|
||||
sqrt_s = sqrt_s_next_down;
|
||||
}
|
||||
|
||||
if (sqrt_vs > 0.0f)
|
||||
{
|
||||
sqrt_s = sqrt_s_next_up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use Newton-Raphson method with reciprocal square root
|
||||
|
||||
// Initial approximation
|
||||
float sqrt_r = __builtin_amdgcn_rsqf(sqrt_x); // Or equivalent hardware 1/sqrt instruction
|
||||
sqrt_s = sqrt_x * sqrt_r;
|
||||
|
||||
// Refine approximation
|
||||
float half = 0.5f;
|
||||
float sqrt_h = sqrt_r * half;
|
||||
float neg_sqrt_h = -sqrt_h;
|
||||
|
||||
// Calculate error term
|
||||
float sqrt_e = std::fma(neg_sqrt_h, sqrt_s, half);
|
||||
|
||||
// First refinement
|
||||
sqrt_h = std::fma(sqrt_h, sqrt_e, sqrt_h);
|
||||
sqrt_s = std::fma(sqrt_s, sqrt_e, sqrt_s);
|
||||
|
||||
// Second refinement
|
||||
float neg_sqrt_s = -sqrt_s;
|
||||
float sqrt_d = std::fma(neg_sqrt_s, sqrt_s, sqrt_x);
|
||||
sqrt_s = std::fma(sqrt_d, sqrt_h, sqrt_s);
|
||||
}
|
||||
|
||||
// Scale back if input was scaled
|
||||
if (need_scale)
|
||||
{
|
||||
sqrt_s *= scale_down_factor;
|
||||
}
|
||||
|
||||
// Special case handling for zero and infinity
|
||||
bool is_zero_or_inf = __builtin_isfpclass(sqrt_x, __FPCLASS_POSINF | __FPCLASS_POSZERO | __FPCLASS_NEGZERO);
|
||||
|
||||
return is_zero_or_inf ? sqrt_x : sqrt_s;
|
||||
}
|
||||
|
||||
// Exported implementation helper named sqrt_rn_f32: full-precision square
// root computed via precise_square_root with denormal handling requested.
float FUNC(sqrt_rn_f32)(float x)
{
    const bool needs_denorm_handling = true;
    return precise_square_root(x, needs_denorm_handling);
}
|
||||
|
||||
// Exported implementation helper named sqrt_rn_ftz_f32: full-precision square
// root computed via precise_square_root without the denormal-handling path.
float FUNC(sqrt_rn_ftz_f32)(float x)
{
    const bool needs_denorm_handling = false;
    return precise_square_root(x, needs_denorm_handling);
}
|
||||
|
||||
// Intermediate values handed from the first half of the fp32 division
// (div_f32_part1) to the second half (div_f32_part2).
// The member order is significant: instances are brace-initialised
// positionally as {fma_4, fma_1, fma_3, numerator_scaled_flag}.
struct DivRnFtzF32Part1Result
{
    // First operand passed to __builtin_amdgcn_div_fmasf
    float fma_4;
    // Second operand passed to __builtin_amdgcn_div_fmasf
    float fma_1;
    // Third operand passed to __builtin_amdgcn_div_fmasf
    float fma_3;
    // Flag produced by the numerator __builtin_amdgcn_div_scalef call,
    // forwarded as the last operand of __builtin_amdgcn_div_fmasf
    uint8_t numerator_scaled_flag;
};
|
||||
|
||||
// First half of the fp32 division lhs / rhs: scales both operands, takes a
// reciprocal estimate of the scaled denominator and refines reciprocal and
// quotient with fused multiply-adds.
// Returns the three values and the flag that the second half feeds into
// __builtin_amdgcn_div_fmasf.
DivRnFtzF32Part1Result FUNC(div_f32_part1)(float lhs, float rhs)
{
    // The flag out-parameter is required by the builtin's signature; only the
    // numerator's flag is used afterwards.
    bool denominator_flag;
    const float denominator = __builtin_amdgcn_div_scalef(lhs, rhs, false, &denominator_flag);

    bool numerator_flag;
    const float numerator = __builtin_amdgcn_div_scalef(lhs, rhs, true, &numerator_flag);

    // Reciprocal estimate of the scaled denominator
    const float rcp = __builtin_amdgcn_rcpf(denominator);
    const float neg_denominator = -denominator;

    // Refine the reciprocal: rcp + rcp * (1 - denominator * rcp)
    const float rcp_error = fmaf(neg_denominator, rcp, 1.0f);
    const float rcp_refined = fmaf(rcp_error, rcp, rcp);

    // Quotient estimate followed by one correction step
    const float quotient = numerator * rcp_refined;
    const float remainder = fmaf(neg_denominator, quotient, numerator);
    const float quotient_refined = fmaf(remainder, rcp_refined, quotient);

    // Remainder of the refined quotient
    const float remainder_refined = fmaf(neg_denominator, quotient_refined, numerator);

    // Positional order: fma_4, fma_1, fma_3, numerator_scaled_flag
    return {remainder_refined, rcp_refined, quotient_refined, numerator_flag};
}
|
||||
|
||||
// Second half of the fp32 division: combines the intermediates from part 1
// with __builtin_amdgcn_div_fmasf and then applies
// __builtin_amdgcn_div_fixupf with the operands y and x.
__device__ static float div_f32_part2(float x, float y, DivRnFtzF32Part1Result part1)
{
    return __builtin_amdgcn_div_fixupf(
        __builtin_amdgcn_div_fmasf(part1.fma_4, part1.fma_1, part1.fma_3, part1.numerator_scaled_flag),
        y,
        x);
}
|
||||
|
||||
// Exported entry point for the second half of the fp32 division. Takes the
// part 1 results as individual scalars, repacks them into
// DivRnFtzF32Part1Result and forwards to the static div_f32_part2 helper.
float FUNC(div_f32_part2)(float x, float y,
                          float fma_4,
                          float fma_1,
                          float fma_3,
                          uint8_t numerator_scaled_flag)
{
    const DivRnFtzF32Part1Result part1{fma_4, fma_1, fma_3, numerator_scaled_flag};
    return div_f32_part2(x, y, part1);
}
|
||||
}
|
||||
|
|
|
@ -803,6 +803,13 @@ fn create_control_flow_graph(
|
|||
let modes = get_modes(instruction);
|
||||
bb_state.append(modes);
|
||||
}
|
||||
Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
|
||||
bb_state.append(InstructionModes::new(
|
||||
ast::ScalarType::F32,
|
||||
ftz_f32.map(DenormalMode::from_ftz),
|
||||
rnd_f32.map(RoundingMode::from_ast),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
@ -1021,6 +1028,16 @@ fn apply_global_mode_controls(
|
|||
let modes = get_modes(&instruction);
|
||||
bb_state.insert(&mut result, modes)?;
|
||||
}
|
||||
Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
|
||||
bb_state.insert(
|
||||
&mut result,
|
||||
InstructionModes::new(
|
||||
ast::ScalarType::F32,
|
||||
ftz_f32.map(DenormalMode::from_ftz),
|
||||
rnd_f32.map(RoundingMode::from_ast),
|
||||
),
|
||||
)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
result.push(statement);
|
||||
|
|
|
@ -397,6 +397,8 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
|
||||
Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?,
|
||||
// No-op
|
||||
Statement::FpModeRequired { .. } => {}
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -825,25 +827,12 @@ impl<'a> MethodEmitContext<'a> {
|
|||
match &*arguments.return_arguments {
|
||||
[] => {}
|
||||
[name] => self.resolver.register(*name, llvm_call),
|
||||
[b32, pred] => {
|
||||
self.resolver.with_result(*b32, |name| unsafe {
|
||||
LLVMBuildExtractValue(self.builder, llvm_call, 0, name)
|
||||
});
|
||||
self.resolver.with_result(*pred, |name| unsafe {
|
||||
let extracted =
|
||||
LLVMBuildExtractValue(self.builder, llvm_call, 1, LLVM_UNNAMED.as_ptr());
|
||||
LLVMBuildTrunc(
|
||||
self.builder,
|
||||
extracted,
|
||||
get_scalar_type(self.context, ast::ScalarType::Pred),
|
||||
name,
|
||||
)
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
return Err(error_todo_msg(
|
||||
"Only two return arguments (.b32, .pred) currently supported",
|
||||
))
|
||||
args => {
|
||||
for (i, arg) in args.iter().copied().enumerate() {
|
||||
self.resolver.with_result(arg, |name| unsafe {
|
||||
LLVMBuildExtractValue(self.builder, llvm_call, i as u32, name)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -992,44 +981,28 @@ impl<'a> MethodEmitContext<'a> {
|
|||
unsafe {
|
||||
LLVMSetAlignment(load, type_.layout().align() as u32);
|
||||
}
|
||||
Ok(load)
|
||||
Ok((load, type_))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
match &*loads {
|
||||
[] => unsafe { LLVMBuildRetVoid(self.builder) },
|
||||
[value] => unsafe { LLVMBuildRet(self.builder, *value) },
|
||||
_ => {
|
||||
check_multiple_return_types(values.iter().map(|(_, type_)| type_))?;
|
||||
let array_ty =
|
||||
get_array_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?;
|
||||
let insert_b32 = unsafe {
|
||||
LLVMBuildInsertValue(
|
||||
self.builder,
|
||||
LLVMGetPoison(array_ty),
|
||||
loads[0],
|
||||
0,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let zext_pred = unsafe {
|
||||
LLVMBuildZExt(
|
||||
self.builder,
|
||||
loads[1],
|
||||
get_type(self.context, &ast::Type::Scalar(ast::ScalarType::B32))?,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let insert_pred = unsafe {
|
||||
LLVMBuildInsertValue(
|
||||
self.builder,
|
||||
insert_b32,
|
||||
zext_pred,
|
||||
1,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
unsafe { LLVMBuildRet(self.builder, insert_pred) }
|
||||
[(value, _)] => unsafe { LLVMBuildRet(self.builder, *value) },
|
||||
loads => {
|
||||
let struct_type =
|
||||
get_or_create_struct_type(self.context, loads.iter().map(|(_, type_)| *type_))?;
|
||||
let mut value = unsafe { LLVMGetUndef(struct_type) };
|
||||
for (i, (load, _)) in loads.iter().enumerate() {
|
||||
value = unsafe {
|
||||
LLVMBuildInsertValue(
|
||||
self.builder,
|
||||
value,
|
||||
*load,
|
||||
i as u32,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
}
|
||||
unsafe { LLVMBuildRet(self.builder, value) }
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
|
@ -1898,10 +1871,10 @@ impl<'a> MethodEmitContext<'a> {
|
|||
to: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
llvm_func: unsafe extern "C" fn(
|
||||
arg1: LLVMBuilderRef,
|
||||
Val: LLVMValueRef,
|
||||
DestTy: LLVMTypeRef,
|
||||
Name: *const i8,
|
||||
LLVMBuilderRef,
|
||||
LLVMValueRef,
|
||||
LLVMTypeRef,
|
||||
*const i8,
|
||||
) -> LLVMValueRef,
|
||||
) -> Result<(), TranslateError> {
|
||||
let type_ = get_scalar_type(self.context, to);
|
||||
|
@ -2928,46 +2901,60 @@ fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, T
|
|||
})
|
||||
}
|
||||
|
||||
fn get_array_type<'a>(
|
||||
fn get_or_create_struct_type<'a>(
|
||||
context: LLVMContextRef,
|
||||
elem_type: &'a ast::Type,
|
||||
count: u64,
|
||||
mut elem_types: impl Iterator<Item = &'a ast::Type>,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
let elem_type = get_type(context, elem_type)?;
|
||||
Ok(unsafe { LLVMArrayType2(elem_type, count) })
|
||||
use std::fmt::Write;
|
||||
let (mut name, types) = elem_types.try_fold(
|
||||
("struct".to_string(), Vec::new()),
|
||||
|(mut name, mut types), t| {
|
||||
name.push('.');
|
||||
if let ast::Type::Scalar(scalar) = t {
|
||||
write!(name, "{}", LLVMTypeDisplay(*scalar)).ok();
|
||||
} else {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
types.push(get_type(context, t)?);
|
||||
Ok((name, types))
|
||||
},
|
||||
)?;
|
||||
name.push('\0');
|
||||
let mut struct_type = unsafe { LLVMGetTypeByName2(context, name.as_ptr().cast()) };
|
||||
if struct_type.is_null() {
|
||||
struct_type = create_struct_type(context, name, types);
|
||||
}
|
||||
Ok(struct_type)
|
||||
}
|
||||
|
||||
fn check_multiple_return_types<'a>(
|
||||
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let err_msg = "Only (.b32, .pred) multiple return types are supported";
|
||||
|
||||
let first = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
|
||||
let second = return_args.next().ok_or_else(|| error_todo_msg(err_msg))?;
|
||||
match (first, second) {
|
||||
(ast::Type::Scalar(first), ast::Type::Scalar(second)) => {
|
||||
if first.size_of() != 4 || second.size_of() != 1 {
|
||||
return Err(error_todo_msg(err_msg));
|
||||
}
|
||||
}
|
||||
_ => return Err(error_todo_msg(err_msg)),
|
||||
fn create_struct_type(
|
||||
context: LLVMContextRef,
|
||||
name: String,
|
||||
mut elem_types: Vec<LLVMTypeRef>,
|
||||
) -> LLVMTypeRef {
|
||||
let llvm_type = unsafe { LLVMStructCreateNamed(context, name.as_ptr().cast()) };
|
||||
unsafe {
|
||||
LLVMStructSetBody(
|
||||
llvm_type,
|
||||
elem_types.as_mut_ptr(),
|
||||
elem_types.len() as u32,
|
||||
0,
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
llvm_type
|
||||
}
|
||||
|
||||
fn get_function_type<'a>(
|
||||
context: LLVMContextRef,
|
||||
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
mut return_args: impl DoubleEndedIterator<Item = &'a ast::Type>
|
||||
+ ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
|
||||
let return_type = match return_args.len() {
|
||||
0 => unsafe { LLVMVoidTypeInContext(context) },
|
||||
1 => get_type(context, &return_args.next().unwrap())?,
|
||||
_ => {
|
||||
check_multiple_return_types(return_args)?;
|
||||
get_array_type(context, &ast::Type::Scalar(ast::ScalarType::B32), 2)?
|
||||
}
|
||||
_ => get_or_create_struct_type(context, return_args)?,
|
||||
};
|
||||
|
||||
Ok(unsafe {
|
||||
|
|
|
@ -24,7 +24,8 @@ mod normalize_basic_blocks;
|
|||
mod normalize_identifiers2;
|
||||
mod normalize_predicates2;
|
||||
mod remove_unreachable_basic_blocks;
|
||||
mod replace_instructions_with_function_calls;
|
||||
mod replace_instructions_with_functions;
|
||||
mod replace_instructions_with_functions_fp_required;
|
||||
mod replace_known_functions;
|
||||
mod resolve_function_pointers;
|
||||
|
||||
|
@ -68,12 +69,14 @@ pub fn to_llvm_module<'input>(
|
|||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives =
|
||||
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
|
||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
|
||||
let directives = remove_unreachable_basic_blocks::run(directives)?;
|
||||
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
|
||||
let context = llvm::Context::new();
|
||||
|
@ -235,6 +238,15 @@ enum Statement<I, P: ast::Operand> {
|
|||
VectorRead(VectorRead),
|
||||
VectorWrite(VectorWrite),
|
||||
SetMode(ModeRegister),
|
||||
// This instruction is a nop, it serves as a marker to indicate that the
|
||||
// next instruction requires certain floating-point modes to be set.
|
||||
// Some transcendentals compile to a sequence of instructions that
|
||||
// require certain modes to be set _mid-function_.
|
||||
// See replace_instructions_with_functions_fp_required pass for details
|
||||
FpModeRequired {
|
||||
ftz_f32: Option<bool>,
|
||||
rnd_f32: Option<ast::RoundingMode>,
|
||||
},
|
||||
FpSaturate {
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
|
@ -541,6 +553,9 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
|||
)?;
|
||||
Statement::FpSaturate { dst, src, type_ }
|
||||
}
|
||||
Statement::FpModeRequired { ftz_f32, rnd_f32 } => {
|
||||
Statement::FpModeRequired { ftz_f32, rnd_f32 }
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::*;
|
||||
use smallvec::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
|
@ -71,13 +72,136 @@ fn run_statements<'input>(
|
|||
statements
|
||||
.into_iter()
|
||||
.map(|statement| {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(instruction) => {
|
||||
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
|
||||
Ok::<SmallVec<[_; 3]>, _>(match statement {
|
||||
Statement::Instruction(ast::Instruction::ShflSync {
|
||||
data,
|
||||
arguments:
|
||||
ast::ShflSyncArgs {
|
||||
dst_pred: Some(dst_pred),
|
||||
dst,
|
||||
src,
|
||||
src_lane,
|
||||
src_opts,
|
||||
src_membermask,
|
||||
},
|
||||
}) => {
|
||||
let mode = match data.mode {
|
||||
ptx_parser::ShuffleMode::Up => "up",
|
||||
ptx_parser::ShuffleMode::Down => "down",
|
||||
ptx_parser::ShuffleMode::BFly => "bfly",
|
||||
ptx_parser::ShuffleMode::Idx => "idx",
|
||||
};
|
||||
let packed_var = resolver.register_unnamed(Some((
|
||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
)));
|
||||
let dst_pred_wide = resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
)));
|
||||
let full_name = [ZLUDA_PTX_PREFIX, "shfl_sync_", mode, "_b32_pred"].concat();
|
||||
let return_arguments = vec![(
|
||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
)];
|
||||
let input_arguments = vec![
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
];
|
||||
let func = match fn_declarations.entry(full_name.into()) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let name = resolver.register_named(name, None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, &return_arguments),
|
||||
name,
|
||||
to_variables(resolver, &input_arguments),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
smallvec![
|
||||
Statement::Instruction::<_, SpirvWord>(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: vec![(
|
||||
ast::Type::Vector(2, ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
)],
|
||||
input_arguments: vec![
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ptx_parser::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
},
|
||||
arguments: ptx_parser::CallArgs {
|
||||
return_arguments: vec![packed_var],
|
||||
func,
|
||||
input_arguments: vec![src, src_lane, src_opts, src_membermask],
|
||||
},
|
||||
}),
|
||||
Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: true,
|
||||
typ: ast::ScalarType::U32,
|
||||
packed: packed_var,
|
||||
unpacked: vec![dst, dst_pred_wide],
|
||||
relaxed_type_check: false,
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Cvt {
|
||||
data: ast::CvtDetails {
|
||||
from: ast::ScalarType::U32,
|
||||
to: ast::ScalarType::Pred,
|
||||
mode: ast::CvtMode::Truncate
|
||||
},
|
||||
arguments: ast::CvtArgs {
|
||||
dst: dst_pred,
|
||||
src: dst_pred_wide,
|
||||
},
|
||||
})
|
||||
]
|
||||
}
|
||||
s => s,
|
||||
Statement::<ast::Instruction<SpirvWord>, SpirvWord>::Instruction(instruction) => {
|
||||
smallvec![
|
||||
Statement::<ast::Instruction<SpirvWord>, SpirvWord>::Instruction(
|
||||
run_instruction(resolver, fn_declarations, instruction)?
|
||||
)
|
||||
]
|
||||
}
|
||||
s => smallvec![s],
|
||||
})
|
||||
})
|
||||
.flat_map(|result| match result {
|
||||
Ok(vec) => vec.into_iter().map(|item| Ok(item)).collect(),
|
||||
Err(er) => vec![Err(er)],
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
|
@ -141,6 +265,52 @@ fn run_instruction<'input>(
|
|||
let name = ["bfe_", scalar_to_ptx_name(data)].concat();
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Compliant(ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: Some(true),
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {
|
||||
let name = "sqrt_rn_ftz_f32";
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Sqrt {
|
||||
data:
|
||||
ast::RcpData {
|
||||
kind: ast::RcpKind::Compliant(ast::RoundingMode::NearestEven),
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {
|
||||
let name = "sqrt_rn_f32";
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Div {
|
||||
data:
|
||||
ast::DivDetails::Float(ast::DivFloatDetails {
|
||||
kind: ast::DivFloatKind::Rounding(_),
|
||||
flush_to_zero: Some(true),
|
||||
..
|
||||
}),
|
||||
..
|
||||
} => {
|
||||
let name = "div_rn_ftz_f32";
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Div {
|
||||
data:
|
||||
ast::DivDetails::Float(ast::DivFloatDetails {
|
||||
kind: ast::DivFloatKind::Rounding(_),
|
||||
..
|
||||
}),
|
||||
..
|
||||
} => {
|
||||
let name = "div_rn_f32";
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Bfi { data, .. } => {
|
||||
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
|
||||
to_call(resolver, fn_declarations, name.into(), i)?
|
||||
|
@ -163,23 +333,24 @@ fn run_instruction<'input>(
|
|||
ptx_parser::Instruction::BarRed { data, arguments },
|
||||
)?
|
||||
}
|
||||
ptx_parser::Instruction::ShflSync { data, arguments } => {
|
||||
ptx_parser::Instruction::ShflSync {
|
||||
data,
|
||||
arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },
|
||||
} => {
|
||||
let mode = match data.mode {
|
||||
ptx_parser::ShuffleMode::Up => "up",
|
||||
ptx_parser::ShuffleMode::Down => "down",
|
||||
ptx_parser::ShuffleMode::BFly => "bfly",
|
||||
ptx_parser::ShuffleMode::Idx => "idx",
|
||||
};
|
||||
let pred = if arguments.dst_pred.is_some() {
|
||||
"_pred"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
to_call(
|
||||
resolver,
|
||||
fn_declarations,
|
||||
format!("shfl_sync_{}_b32{}", mode, pred).into(),
|
||||
ptx_parser::Instruction::ShflSync { data, arguments },
|
||||
format!("shfl_sync_{}_b32", mode).into(),
|
||||
ptx_parser::Instruction::ShflSync {
|
||||
data,
|
||||
arguments: orig_arguments,
|
||||
},
|
||||
)?
|
||||
}
|
||||
i @ ptx_parser::Instruction::Nanosleep { .. } => {
|
367
ptx/src/pass/replace_instructions_with_functions_fp_required.rs
Normal file
367
ptx/src/pass/replace_instructions_with_functions_fp_required.rs
Normal file
|
@ -0,0 +1,367 @@
|
|||
// This pass exists specifically to replace the `div.rn.ftz.f32` instruction
|
||||
// with a function call. One inherent weirdness of the replacement function is
|
||||
// that it requires a different denormal (FTZ) mode for the first part of the
|
||||
// division and the second part. The first part is executed with FTZ disabled
|
||||
// and the second part with FTZ enabled.
|
||||
// For this reason we can't handle this past FTZ mode insertion without making
|
||||
// the function read and restore the FTZ mode. Instead, we split the
|
||||
// replacement function in two functions and prefix them with a noop
|
||||
// (FpModeRequired) that carries the FTZ mode information.
|
||||
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
use smallvec::smallvec;
|
||||
use smallvec::SmallVec;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut imports = None;
|
||||
let directives = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive, &mut imports))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(match imports {
|
||||
Some(imports) => {
|
||||
let mut result = Vec::with_capacity(directives.len() + 2);
|
||||
result.extend([
|
||||
Directive2::Method(Function2 {
|
||||
return_arguments: vec![
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::U8),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
],
|
||||
name: imports.part1,
|
||||
input_arguments: vec![
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
],
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}),
|
||||
Directive2::Method(Function2 {
|
||||
return_arguments: vec![ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
name: imports.part2,
|
||||
input_arguments: vec![
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::Variable {
|
||||
name: resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::U8),
|
||||
ast::StateSpace::Reg,
|
||||
))),
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
],
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}),
|
||||
]);
|
||||
result.extend(directives);
|
||||
result
|
||||
}
|
||||
None => directives,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
imports: &mut Option<FunctionImports>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method, imports)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
imports: &mut Option<FunctionImports>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
method.body = method.body.map(|body| {
|
||||
body.into_iter()
|
||||
.flat_map(|stmt| run_statement(resolver, stmt, imports))
|
||||
.collect()
|
||||
});
|
||||
Ok(method)
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
stmt: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
imports: &mut Option<FunctionImports>,
|
||||
) -> SmallVec<[Statement<ast::Instruction<SpirvWord>, SpirvWord>; 4]> {
|
||||
match stmt {
|
||||
Statement::Instruction(ast::Instruction::Div {
|
||||
data:
|
||||
ast::DivDetails::Float(ast::DivFloatDetails {
|
||||
flush_to_zero,
|
||||
kind: ast::DivFloatKind::Rounding(rnd),
|
||||
type_: ast::ScalarType::F32,
|
||||
}),
|
||||
arguments,
|
||||
}) => {
|
||||
let ftz = flush_to_zero.unwrap_or(false);
|
||||
let FunctionImports { part1, part2, .. } = FunctionImports::init(imports, resolver);
|
||||
let fma_4 = resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
let fma_1 = resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
let fma3_ = resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
let numerator_scaled_flag = resolver.register_unnamed(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::U8),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
smallvec![
|
||||
Statement::FpModeRequired {
|
||||
ftz_f32: Some(false),
|
||||
rnd_f32: Some(ast::RoundingMode::NearestEven),
|
||||
},
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
arguments: ast::CallArgs {
|
||||
return_arguments: vec![fma_4, fma_1, fma3_, numerator_scaled_flag],
|
||||
func: *part1,
|
||||
input_arguments: vec![arguments.src1, arguments.src2],
|
||||
},
|
||||
data: ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: vec![
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg,)
|
||||
],
|
||||
input_arguments: vec![
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
)
|
||||
]
|
||||
}
|
||||
}),
|
||||
Statement::FpModeRequired {
|
||||
ftz_f32: Some(ftz),
|
||||
rnd_f32: Some(rnd),
|
||||
},
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
arguments: ast::CallArgs {
|
||||
return_arguments: vec![arguments.dst],
|
||||
func: *part2,
|
||||
input_arguments: vec![
|
||||
arguments.src1,
|
||||
arguments.src2,
|
||||
fma_4,
|
||||
fma_1,
|
||||
fma3_,
|
||||
numerator_scaled_flag
|
||||
],
|
||||
},
|
||||
data: ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: vec![(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
)],
|
||||
input_arguments: vec![
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
ast::Type::Scalar(ast::ScalarType::F32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg,)
|
||||
]
|
||||
}
|
||||
})
|
||||
]
|
||||
}
|
||||
_ => smallvec![stmt],
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FunctionImports {
|
||||
part1: SpirvWord,
|
||||
part2: SpirvWord,
|
||||
}
|
||||
|
||||
impl FunctionImports {
|
||||
fn init<'a>(
|
||||
this: &'a mut Option<FunctionImports>,
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
) -> &'a FunctionImports {
|
||||
this.get_or_insert_with(|| {
|
||||
let part1_name = [ZLUDA_PTX_PREFIX, "div_f32_part1"].concat();
|
||||
let part1 = resolver.register_named(part1_name.into(), None);
|
||||
let part2_name = [ZLUDA_PTX_PREFIX, "div_f32_part2"].concat();
|
||||
let part2 = resolver.register_named(part2_name.into(), None);
|
||||
FunctionImports { part1, part2 }
|
||||
})
|
||||
}
|
||||
}
|
74
ptx/src/test/ll/div_ftz.ll
Normal file
74
ptx/src/test/ll/div_ftz.ll
Normal file
|
@ -0,0 +1,74 @@
|
|||
%struct.f32.f32.f32.i8 = type { float, float, float, i8 }
|
||||
|
||||
declare %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float, float) #0
|
||||
|
||||
declare float @__zluda_ptx_impl_div_f32_part2(float, float, float, float, float, i8) #0
|
||||
|
||||
define amdgpu_kernel void @div_ftz(ptr addrspace(4) byref(i64) %"63", ptr addrspace(4) byref(i64) %"64") #1 {
|
||||
%"65" = alloca i64, align 8, addrspace(5)
|
||||
%"66" = alloca i64, align 8, addrspace(5)
|
||||
%"67" = alloca float, align 4, addrspace(5)
|
||||
%"68" = alloca float, align 4, addrspace(5)
|
||||
%"69" = alloca float, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"54"
|
||||
|
||||
"54": ; preds = %1
|
||||
%"70" = load i64, ptr addrspace(4) %"63", align 8
|
||||
store i64 %"70", ptr addrspace(5) %"65", align 8
|
||||
%"71" = load i64, ptr addrspace(4) %"64", align 8
|
||||
store i64 %"71", ptr addrspace(5) %"66", align 8
|
||||
%"73" = load i64, ptr addrspace(5) %"65", align 8
|
||||
%"88" = inttoptr i64 %"73" to ptr
|
||||
%"72" = load float, ptr %"88", align 4
|
||||
store float %"72", ptr addrspace(5) %"67", align 4
|
||||
%"74" = load i64, ptr addrspace(5) %"65", align 8
|
||||
%"89" = inttoptr i64 %"74" to ptr
|
||||
%"32" = getelementptr inbounds i8, ptr %"89", i64 4
|
||||
%"75" = load float, ptr %"32", align 4
|
||||
store float %"75", ptr addrspace(5) %"68", align 4
|
||||
%"77" = load float, ptr addrspace(5) %"67", align 4
|
||||
%"78" = load float, ptr addrspace(5) %"68", align 4
|
||||
%"76" = fmul float %"77", %"78"
|
||||
store float %"76", ptr addrspace(5) %"69", align 4
|
||||
%"79" = load float, ptr addrspace(5) %"67", align 4
|
||||
%"80" = load float, ptr addrspace(5) %"68", align 4
|
||||
%2 = call %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float %"79", float %"80")
|
||||
%"37" = extractvalue %struct.f32.f32.f32.i8 %2, 0
|
||||
%"38" = extractvalue %struct.f32.f32.f32.i8 %2, 1
|
||||
%"39" = extractvalue %struct.f32.f32.f32.i8 %2, 2
|
||||
%"40" = extractvalue %struct.f32.f32.f32.i8 %2, 3
|
||||
br label %"57"
|
||||
|
||||
"57": ; preds = %"54"
|
||||
call void @llvm.amdgcn.s.setreg(i32 6401, i32 0)
|
||||
br label %"55"
|
||||
|
||||
"55": ; preds = %"57"
|
||||
%"82" = load float, ptr addrspace(5) %"67", align 4
|
||||
%"83" = load float, ptr addrspace(5) %"68", align 4
|
||||
%"81" = call float @__zluda_ptx_impl_div_f32_part2(float %"82", float %"83", float %"37", float %"38", float %"39", i8 %"40")
|
||||
store float %"81", ptr addrspace(5) %"67", align 4
|
||||
br label %"56"
|
||||
|
||||
"56": ; preds = %"55"
|
||||
%"84" = load i64, ptr addrspace(5) %"66", align 8
|
||||
%"85" = load float, ptr addrspace(5) %"67", align 4
|
||||
%"90" = inttoptr i64 %"84" to ptr
|
||||
store float %"85", ptr %"90", align 4
|
||||
%"86" = load i64, ptr addrspace(5) %"66", align 8
|
||||
%"91" = inttoptr i64 %"86" to ptr
|
||||
%"34" = getelementptr inbounds i8, ptr %"91", i64 4
|
||||
%"87" = load float, ptr addrspace(5) %"69", align 4
|
||||
store float %"87", ptr %"34", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind willreturn
|
||||
declare void @llvm.amdgcn.s.setreg(i32 immarg, i32) #2
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #2 = { nocallback nofree nosync nounwind willreturn }
|
71
ptx/src/test/ll/div_noftz.ll
Normal file
71
ptx/src/test/ll/div_noftz.ll
Normal file
|
@ -0,0 +1,71 @@
|
|||
%struct.f32.f32.f32.i8 = type { float, float, float, i8 }
|
||||
|
||||
declare %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float, float) #0
|
||||
|
||||
declare float @__zluda_ptx_impl_div_f32_part2(float, float, float, float, float, i8) #0
|
||||
|
||||
define amdgpu_kernel void @div_noftz(ptr addrspace(4) byref(i64) %"62", ptr addrspace(4) byref(i64) %"63") #1 {
|
||||
%"64" = alloca i64, align 8, addrspace(5)
|
||||
%"65" = alloca i64, align 8, addrspace(5)
|
||||
%"66" = alloca float, align 4, addrspace(5)
|
||||
%"67" = alloca float, align 4, addrspace(5)
|
||||
%"68" = alloca float, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"54"
|
||||
|
||||
"54": ; preds = %1
|
||||
%"69" = load i64, ptr addrspace(4) %"62", align 8
|
||||
store i64 %"69", ptr addrspace(5) %"64", align 8
|
||||
%"70" = load i64, ptr addrspace(4) %"63", align 8
|
||||
store i64 %"70", ptr addrspace(5) %"65", align 8
|
||||
%"72" = load i64, ptr addrspace(5) %"64", align 8
|
||||
%"87" = inttoptr i64 %"72" to ptr
|
||||
%"71" = load float, ptr %"87", align 4
|
||||
store float %"71", ptr addrspace(5) %"66", align 4
|
||||
%"73" = load i64, ptr addrspace(5) %"64", align 8
|
||||
%"88" = inttoptr i64 %"73" to ptr
|
||||
%"32" = getelementptr inbounds i8, ptr %"88", i64 4
|
||||
%"74" = load float, ptr %"32", align 4
|
||||
store float %"74", ptr addrspace(5) %"67", align 4
|
||||
%"76" = load float, ptr addrspace(5) %"66", align 4
|
||||
%"77" = load float, ptr addrspace(5) %"67", align 4
|
||||
%"75" = fmul float %"76", %"77"
|
||||
store float %"75", ptr addrspace(5) %"68", align 4
|
||||
call void @llvm.amdgcn.s.setreg(i32 6401, i32 3)
|
||||
%"78" = load float, ptr addrspace(5) %"66", align 4
|
||||
%"79" = load float, ptr addrspace(5) %"67", align 4
|
||||
%2 = call %struct.f32.f32.f32.i8 @__zluda_ptx_impl_div_f32_part1(float %"78", float %"79")
|
||||
%"37" = extractvalue %struct.f32.f32.f32.i8 %2, 0
|
||||
%"38" = extractvalue %struct.f32.f32.f32.i8 %2, 1
|
||||
%"39" = extractvalue %struct.f32.f32.f32.i8 %2, 2
|
||||
%"40" = extractvalue %struct.f32.f32.f32.i8 %2, 3
|
||||
br label %"55"
|
||||
|
||||
"55": ; preds = %"54"
|
||||
%"81" = load float, ptr addrspace(5) %"66", align 4
|
||||
%"82" = load float, ptr addrspace(5) %"67", align 4
|
||||
%"80" = call float @__zluda_ptx_impl_div_f32_part2(float %"81", float %"82", float %"37", float %"38", float %"39", i8 %"40")
|
||||
store float %"80", ptr addrspace(5) %"66", align 4
|
||||
br label %"56"
|
||||
|
||||
"56": ; preds = %"55"
|
||||
%"83" = load i64, ptr addrspace(5) %"65", align 8
|
||||
%"84" = load float, ptr addrspace(5) %"66", align 4
|
||||
%"89" = inttoptr i64 %"83" to ptr
|
||||
store float %"84", ptr %"89", align 4
|
||||
%"85" = load i64, ptr addrspace(5) %"65", align 8
|
||||
%"90" = inttoptr i64 %"85" to ptr
|
||||
%"34" = getelementptr inbounds i8, ptr %"90", i64 4
|
||||
%"86" = load float, ptr addrspace(5) %"68", align 4
|
||||
store float %"86", ptr %"34", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind willreturn
|
||||
declare void @llvm.amdgcn.s.setreg(i32 immarg, i32) #2
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #2 = { nocallback nofree nosync nounwind willreturn }
|
|
@ -1,4 +1,6 @@
|
|||
define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #0 {
|
||||
declare float @__zluda_ptx_impl_sqrt_approx_f32(float) #0
|
||||
|
||||
define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #1 {
|
||||
%"32" = alloca i64, align 8, addrspace(5)
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca float, align 4, addrspace(5)
|
||||
|
@ -17,7 +19,7 @@ define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace
|
|||
%"37" = load float, ptr %"43", align 4
|
||||
store float %"37", ptr addrspace(5) %"34", align 4
|
||||
%"40" = load float, ptr addrspace(5) %"34", align 4
|
||||
%"39" = call float @llvm.amdgcn.sqrt.f32(float %"40")
|
||||
%"39" = call float @__zluda_ptx_impl_sqrt_approx_f32(float %"40")
|
||||
store float %"39", ptr addrspace(5) %"34", align 4
|
||||
%"41" = load i64, ptr addrspace(5) %"33", align 8
|
||||
%"42" = load float, ptr addrspace(5) %"34", align 4
|
||||
|
@ -26,8 +28,5 @@ define amdgpu_kernel void @sqrt(ptr addrspace(4) byref(i64) %"30", ptr addrspace
|
|||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare float @llvm.amdgcn.sqrt.f32(float) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
32
ptx/src/test/ll/sqrt_rn_ftz.ll
Normal file
32
ptx/src/test/ll/sqrt_rn_ftz.ll
Normal file
|
@ -0,0 +1,32 @@
|
|||
declare float @__zluda_ptx_impl_sqrt_rn_ftz_f32(float) #0
|
||||
|
||||
define amdgpu_kernel void @sqrt_rn_ftz(ptr addrspace(4) byref(i64) %"30", ptr addrspace(4) byref(i64) %"31") #1 {
|
||||
%"32" = alloca i64, align 8, addrspace(5)
|
||||
%"33" = alloca i64, align 8, addrspace(5)
|
||||
%"34" = alloca float, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"29"
|
||||
|
||||
"29": ; preds = %1
|
||||
%"35" = load i64, ptr addrspace(4) %"30", align 8
|
||||
store i64 %"35", ptr addrspace(5) %"32", align 8
|
||||
%"36" = load i64, ptr addrspace(4) %"31", align 8
|
||||
store i64 %"36", ptr addrspace(5) %"33", align 8
|
||||
%"38" = load i64, ptr addrspace(5) %"32", align 8
|
||||
%"43" = inttoptr i64 %"38" to ptr
|
||||
%"37" = load float, ptr %"43", align 4
|
||||
store float %"37", ptr addrspace(5) %"34", align 4
|
||||
%"40" = load float, ptr addrspace(5) %"34", align 4
|
||||
%"39" = call float @__zluda_ptx_impl_sqrt_rn_ftz_f32(float %"40")
|
||||
store float %"39", ptr addrspace(5) %"34", align 4
|
||||
%"41" = load i64, ptr addrspace(5) %"33", align 8
|
||||
%"42" = load float, ptr addrspace(5) %"34", align 4
|
||||
%"44" = inttoptr i64 %"41" to ptr
|
||||
store float %"42", ptr %"44", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
27
ptx/src/test/spirv_run/div_ftz.ptx
Normal file
27
ptx/src/test/spirv_run/div_ftz.ptx
Normal file
|
@ -0,0 +1,27 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry div_ftz(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
.reg .f32 temp2;
|
||||
.reg .f32 force_ftz_mode;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 temp1, [in_addr];
|
||||
ld.f32 temp2, [in_addr+4];
|
||||
// DO NOT REMOVE THIS MULTIPLICATION
|
||||
mul.f32 force_ftz_mode, temp1, temp2;
|
||||
div.ftz.rn.f32 temp1, temp1, temp2;
|
||||
st.f32 [out_addr], temp1;
|
||||
st.f32 [out_addr+4], force_ftz_mode;
|
||||
ret;
|
||||
}
|
27
ptx/src/test/spirv_run/div_noftz.ptx
Normal file
27
ptx/src/test/spirv_run/div_noftz.ptx
Normal file
|
@ -0,0 +1,27 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry div_noftz(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
.reg .f32 temp2;
|
||||
.reg .f32 force_ftz_mode;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 temp1, [in_addr];
|
||||
ld.f32 temp2, [in_addr+4];
|
||||
// DO NOT REMOVE THIS MULTIPLICATION
|
||||
mul.ftz.f32 force_ftz_mode, temp1, temp2;
|
||||
div.rn.f32 temp1, temp1, temp2;
|
||||
st.f32 [out_addr], temp1;
|
||||
st.f32 [out_addr+4], force_ftz_mode;
|
||||
ret;
|
||||
}
|
|
@ -173,6 +173,7 @@ test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
|
|||
test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
|
||||
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
|
||||
test_ptx!(sqrt, [0.25f32], [0.5f32]);
|
||||
test_ptx!(sqrt_rn_ftz, [0x1u32], [0x0u32]);
|
||||
test_ptx!(rsqrt, [0.25f64], [2f64]);
|
||||
test_ptx!(neg, [181i32], [-181i32]);
|
||||
test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]);
|
||||
|
@ -279,6 +280,19 @@ test_ptx!(multiple_return, [5u32], [6u32, 123u32]);
|
|||
test_ptx!(warp_sz, [0u8], [32u8]);
|
||||
test_ptx!(tanh, [f32::INFINITY], [1.0f32]);
|
||||
test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]);
|
||||
// The two tests below exercise a very important compiler feature; make sure that you
|
||||
// understand fully what's going on before you touch it.
|
||||
// The problem is that the full-precision division gets legalized by LLVM
|
||||
// using a module-level attribute.
|
||||
// In the two tests below we deliberately force our compiler to emit
|
||||
// a module that has a different module-level denormal attribute
|
||||
// from the denormal attribute of the instruction to catch cases like this
|
||||
test_ptx!(div_ftz, [0x16A2028Du32, 0x5E89F6AE], [0x0, 900636404u32]);
|
||||
test_ptx!(
|
||||
div_noftz,
|
||||
[0x16A2028Du32, 0x5E89F6AE],
|
||||
[0x26u32, 900636404u32]
|
||||
);
|
||||
|
||||
test_ptx!(nanosleep, [0u64], [0u64]);
|
||||
test_ptx!(shf_l, [0x12345678u32, 0x9abcdef0u32, 12], [0xcdef0123u32]);
|
||||
|
|
21
ptx/src/test/spirv_run/sqrt_rn_ftz.ptx
Normal file
21
ptx/src/test/spirv_run/sqrt_rn_ftz.ptx
Normal file
|
@ -0,0 +1,21 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry sqrt_rn_ftz(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 temp1, [in_addr];
|
||||
sqrt.rn.ftz.f32 temp1, temp1;
|
||||
st.f32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
|
@ -1058,22 +1058,12 @@ impl From<ScalarType> for Type {
|
|||
#[derive(Clone)]
|
||||
pub struct MovDetails {
|
||||
pub typ: super::Type,
|
||||
pub src_is_address: bool,
|
||||
// two fields below are in use by member moves
|
||||
pub dst_width: u8,
|
||||
pub src_width: u8,
|
||||
// This is in use by auto-generated movs
|
||||
pub relaxed_src2_conv: bool,
|
||||
}
|
||||
|
||||
impl MovDetails {
|
||||
pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
|
||||
MovDetails {
|
||||
typ: Type::maybe_vector(vector, scalar),
|
||||
src_is_address: false,
|
||||
dst_width: 0,
|
||||
src_width: 0,
|
||||
relaxed_src2_conv: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue