Add shifts, cvt, rsqrt, sqrt, rcp, more sregs

This commit is contained in:
Andrzej Janik 2024-10-14 19:09:47 +02:00
commit ae42eac925
5 changed files with 230 additions and 37 deletions

Binary file not shown.

View file

@ -13,12 +13,30 @@ extern "C"
return __builtin_amdgcn_read_exec_lo(); return __builtin_amdgcn_read_exec_lo();
} }
// OCKL device-library entry point; takes a dimension index.
size_t __ockl_get_local_id(uint32_t) __device__;
// Backs the PTX `%tid` special register: forwards `member` as the dimension
// index to OCKL and narrows the returned size_t to 32 bits.
uint32_t FUNC(sreg_tid)(uint8_t member)
{
    const size_t local_id = __ockl_get_local_id(member);
    return (uint32_t)local_id;
}
size_t __ockl_get_local_size(uint32_t) __device__; size_t __ockl_get_local_size(uint32_t) __device__;
uint32_t FUNC(sreg_ntid)(uint8_t member) uint32_t FUNC(sreg_ntid)(uint8_t member)
{ {
return (uint32_t)__ockl_get_local_size(member); return (uint32_t)__ockl_get_local_size(member);
} }
// OCKL device-library entry point: index of the calling work-group along the
// requested dimension.
size_t __ockl_get_group_id(uint32_t) __device__;
// Backs the PTX `%ctaid` special register (index of the CTA within the grid).
// A CTA corresponds to a work-group, so this must query the group id; the
// global id is a per-thread index (group_id * local_size + local_id) and would
// give a different value for every thread in the same CTA.
// `member` is forwarded as the dimension index; the result is narrowed to
// 32 bits.
uint32_t FUNC(sreg_ctaid)(uint8_t member)
{
    return (uint32_t)__ockl_get_group_id(member);
}
// OCKL device-library entry point: number of work-groups along the requested
// dimension.
size_t __ockl_get_num_groups(uint32_t) __device__;
// Backs the PTX `%nctaid` special register (number of CTAs in the grid).
// A CTA corresponds to a work-group, so this must query the group count; the
// global size is the total number of threads (num_groups * local_size), which
// over-reports by a factor of the CTA size.
// `member` is forwarded as the dimension index; the result is narrowed to
// 32 bits.
uint32_t FUNC(sreg_nctaid)(uint8_t member)
{
    return (uint32_t)__ockl_get_num_groups(member);
}
uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device)); uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device));
uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32) uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32)
{ {

View file

@ -522,9 +522,9 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::And { arguments, .. } => self.emit_and(arguments), ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { .. } => todo!(), ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
ast::Instruction::Shr { .. } => todo!(), ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
ast::Instruction::Shl { .. } => todo!(), ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { .. } => todo!(), ast::Instruction::Abs { .. } => todo!(),
@ -533,9 +533,9 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
ast::Instruction::Min { .. } => todo!(), ast::Instruction::Min { .. } => todo!(),
ast::Instruction::Max { .. } => todo!(), ast::Instruction::Max { .. } => todo!(),
ast::Instruction::Rcp { .. } => todo!(), ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
ast::Instruction::Sqrt { .. } => todo!(), ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
ast::Instruction::Rsqrt { .. } => todo!(), ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
ast::Instruction::Selp { .. } => todo!(), ast::Instruction::Selp { .. } => todo!(),
ast::Instruction::Bar { .. } => todo!(), ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
@ -1406,6 +1406,212 @@ impl<'a> MethodEmitContext<'a> {
Ok(()) Ok(())
} }
/// Lowers a PTX `cvt` instruction to a single LLVM cast.
///
/// Only the pure integer/bit conversions (zero-extend, sign-extend, truncate,
/// bitcast) are implemented; every other conversion mode hits `todo!()` and
/// panics. The cast result is registered under `arguments.dst`.
///
/// # Errors
/// Propagates the resolver error if `arguments.src` cannot be resolved.
fn emit_cvt(
    &mut self,
    data: ptx_parser::CvtDetails,
    arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let target_type = get_scalar_type(self.context, data.to);
    // Pick the LLVM cast builder matching the conversion mode. The mode is
    // checked before the source operand is resolved, so an unsupported mode
    // panics regardless of the operand.
    let build_cast = match data.mode {
        ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
        ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
        ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
        ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
        // Saturating and floating-point conversions are not implemented yet.
        ptx_parser::CvtMode::SaturateUnsignedToSigned
        | ptx_parser::CvtMode::SaturateSignedToUnsigned
        | ptx_parser::CvtMode::FPExtend { .. }
        | ptx_parser::CvtMode::FPTruncate { .. }
        | ptx_parser::CvtMode::FPRound { .. }
        | ptx_parser::CvtMode::SignedFromFP { .. }
        | ptx_parser::CvtMode::UnsignedFromFP { .. }
        | ptx_parser::CvtMode::FPFromSigned(..)
        | ptx_parser::CvtMode::FPFromUnsigned(..) => todo!(),
    };
    let source = self.resolver.value(arguments.src)?;
    let builder = self.builder;
    self.resolver.with_result(arguments.dst, |name| unsafe {
        build_cast(builder, source, target_type, name)
    });
    Ok(())
}
/// Lowers PTX `rsqrt` to a call to the matching `llvm.amdgcn.rsq.*` intrinsic.
///
/// The intrinsic takes `arguments.src` as its only operand and its result is
/// bound to `arguments.dst`; operand and result share the instruction's
/// scalar type.
///
/// # Errors
/// Returns `error_unreachable()` for any type other than `f32`/`f64`, and
/// propagates failures from `emit_intrinsic`.
fn emit_rsqrt(
    &mut self,
    data: ptx_parser::TypeFtz,
    arguments: ptx_parser::RsqrtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let operand_type = get_scalar_type(self.context, data.type_);
    let intrinsic_name = match data.type_ {
        ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64",
        ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32",
        _ => return Err(error_unreachable()),
    };
    let operands = vec![(arguments.src, operand_type)];
    self.emit_intrinsic(
        intrinsic_name,
        Some(arguments.dst),
        &data.type_.into(),
        operands,
    )?;
    Ok(())
}
/// Lowers PTX `sqrt` to an intrinsic call.
///
/// Approximate `f32` uses the AMDGPU-specific `llvm.amdgcn.sqrt.f32`; the
/// compliant `f32`/`f64` forms use the generic `llvm.sqrt.*` intrinsics. The
/// rounding mode carried by the compliant kind is not inspected here.
///
/// # Errors
/// Returns `error_unreachable()` for any other type/kind combination, and
/// propagates failures from `emit_intrinsic`.
fn emit_sqrt(
    &mut self,
    data: ptx_parser::RcpData,
    arguments: ptx_parser::SqrtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let operand_type = get_scalar_type(self.context, data.type_);
    let intrinsic_name = match (data.type_, data.kind) {
        (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64",
        (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32",
        (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32",
        _ => return Err(error_unreachable()),
    };
    let operands = vec![(arguments.src, operand_type)];
    self.emit_intrinsic(
        intrinsic_name,
        Some(arguments.dst),
        &data.type_.into(),
        operands,
    )?;
    Ok(())
}
/// Lowers PTX `rcp`.
///
/// Any compliant variant is delegated to `emit_rcp_compliant`. Approximate
/// `f32` becomes a call to `llvm.amdgcn.rcp.f32` with `arguments.src` as the
/// operand and `arguments.dst` as the result.
///
/// # Errors
/// Returns `error_unreachable()` for every remaining type/kind combination,
/// and propagates failures from the delegated emitters.
fn emit_rcp(
    &mut self,
    data: ptx_parser::RcpData,
    arguments: ptx_parser::RcpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let operand_type = get_scalar_type(self.context, data.type_);
    // Compliant reciprocals take a separate path for every scalar type.
    if let ast::RcpKind::Compliant(rnd) = data.kind {
        return self.emit_rcp_compliant(data, arguments, rnd);
    }
    let intrinsic_name = match (data.type_, data.kind) {
        (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32",
        _ => return Err(error_unreachable()),
    };
    let operands = vec![(arguments.src, operand_type)];
    self.emit_intrinsic(
        intrinsic_name,
        Some(arguments.dst),
        &data.type_.into(),
        operands,
    )?;
    Ok(())
}
/// Emits a compliant reciprocal as the floating-point division `1.0 / src`.
///
/// The quotient is registered under `arguments.dst` and then tagged with the
/// allow-reciprocal fast-math flag.
///
/// NOTE(review): the requested rounding mode (`_rnd`) is not consulted, and
/// the allow-reciprocal flag presumably lets LLVM replace the division with a
/// reciprocal sequence; confirm this still meets the precision expected from
/// the compliant variants.
///
/// # Errors
/// Propagates the resolver error if `arguments.src` cannot be resolved.
fn emit_rcp_compliant(
    &mut self,
    data: ptx_parser::RcpData,
    arguments: ptx_parser::RcpArgs<SpirvWord>,
    _rnd: ast::RoundingMode,
) -> Result<(), TranslateError> {
    let float_type = get_scalar_type(self.context, data.type_);
    let numerator = unsafe { LLVMConstReal(float_type, 1.0) };
    let denominator = self.resolver.value(arguments.src)?;
    let builder = self.builder;
    let quotient = self.resolver.with_result(arguments.dst, |name| unsafe {
        LLVMBuildFDiv(builder, numerator, denominator, name)
    });
    unsafe { LLVMZludaSetFastMathFlags(quotient, LLVMZludaFastMathAllowReciprocal) };
    Ok(())
}
/// Lowers PTX `shr`.
///
/// Logical shifts fill with zeros, so an out-of-range shift amount yields 0.
/// Arithmetic shifts fill with the sign bit: PTX clamps the shift amount to
/// the operand width, so an out-of-range amount must yield 0 for non-negative
/// inputs and all ones for negative inputs, not an unconditional 0.
///
/// # Errors
/// Propagates resolver errors for the source operands.
fn emit_shr(
    &mut self,
    data: ptx_parser::ShrData,
    arguments: ptx_parser::ShrArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    match data.kind {
        ptx_parser::RightShiftKind::Arithmetic => self.emit_shift_clamped(
            data.type_,
            arguments.dst,
            arguments.src1,
            arguments.src2,
            LLVMBuildAShr,
            true,
        ),
        ptx_parser::RightShiftKind::Logical => self.emit_shift(
            data.type_,
            arguments.dst,
            arguments.src1,
            arguments.src2,
            LLVMBuildLShr,
        ),
    }
}
/// Lowers PTX `shl`; an out-of-range shift amount yields 0.
///
/// # Errors
/// Propagates resolver errors for the source operands.
fn emit_shl(
    &mut self,
    type_: ptx_parser::ScalarType,
    arguments: ptx_parser::ShlArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    self.emit_shift(
        type_,
        arguments.dst,
        arguments.src1,
        arguments.src2,
        LLVMBuildShl,
    )
}
/// Emits a shift whose result is 0 when the shift amount is greater than or
/// equal to the bit width of `type_`.
///
/// `llvm_fn` is the LLVM builder function performing the in-range shift.
/// This is the zero-filling case of `emit_shift_clamped`.
fn emit_shift(
    &mut self,
    type_: ast::ScalarType,
    dst: SpirvWord,
    src1: SpirvWord,
    src2: SpirvWord,
    llvm_fn: unsafe extern "C" fn(
        LLVMBuilderRef,
        LLVMValueRef,
        LLVMValueRef,
        *const i8,
    ) -> LLVMValueRef,
) -> Result<(), TranslateError> {
    self.emit_shift_clamped(type_, dst, src1, src2, llvm_fn, false)
}
/// Emits `llvm_fn(src1, src2)` guarded against out-of-range shift amounts.
///
/// LLVM shifts by an amount >= the operand width produce poison, so the
/// amount is compared (unsigned) against the bit width of `type_` and the
/// final value is chosen with a `select`:
/// * in range: the result of `llvm_fn`;
/// * out of range, `sign_fill == false`: zero;
/// * out of range, `sign_fill == true`: `src1` arithmetically shifted by
///   `width - 1`, i.e. every bit equal to the sign bit.
///
/// The selected value is registered under `dst`.
///
/// # Errors
/// Propagates resolver errors for `src1` and `src2`.
fn emit_shift_clamped(
    &mut self,
    type_: ast::ScalarType,
    dst: SpirvWord,
    src1: SpirvWord,
    src2: SpirvWord,
    llvm_fn: unsafe extern "C" fn(
        LLVMBuilderRef,
        LLVMValueRef,
        LLVMValueRef,
        *const i8,
    ) -> LLVMValueRef,
    sign_fill: bool,
) -> Result<(), TranslateError> {
    let src1 = self.resolver.value(src1)?;
    let shift_size = self.resolver.value(src2)?;
    let integer_bits = type_.layout().size() * 8;
    let llvm_type = get_scalar_type(self.context, type_);
    // The width constant is built as u32 so it can be compared directly
    // against the (not yet widened/narrowed) shift amount.
    let integer_bits_constant = unsafe {
        LLVMConstInt(
            get_scalar_type(self.context, ast::ScalarType::U32),
            integer_bits as u64,
            0,
        )
    };
    let should_clamp = unsafe {
        LLVMBuildICmp(
            self.builder,
            LLVMIntPredicate::LLVMIntUGE,
            shift_size,
            integer_bits_constant,
            LLVM_UNNAMED.as_ptr(),
        )
    };
    // Value produced when the shift amount is out of range.
    let out_of_range = if sign_fill {
        // Shifting by `width - 1` is always in range and replicates the sign
        // bit into every position.
        let sign_shift = unsafe { LLVMConstInt(llvm_type, (integer_bits - 1) as u64, 0) };
        unsafe { LLVMBuildAShr(self.builder, src1, sign_shift, LLVM_UNNAMED.as_ptr()) }
    } else {
        unsafe { LLVMConstNull(llvm_type) }
    };
    // LLVM requires both shift operands to have the same type, so bring the
    // shift amount to the width of the shifted value.
    let normalized_shift_size = if type_.layout().size() >= 4 {
        unsafe {
            LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
        }
    } else {
        unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) }
    };
    let shifted = unsafe {
        llvm_fn(
            self.builder,
            src1,
            normalized_shift_size,
            LLVM_UNNAMED.as_ptr(),
        )
    };
    self.resolver.with_result(dst, |dst| unsafe {
        LLVMBuildSelect(self.builder, should_clamp, out_of_range, shifted, dst)
    });
    Ok(())
}
/* /*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19 // Should be available in LLVM 19

View file

@ -45,7 +45,6 @@ test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]);
test_ptx!(bra, [10u64], [11u64]); test_ptx!(bra, [10u64], [11u64]);
test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(not, [0u64], [u64::max_value()]);
test_ptx!(shl, [11u64], [44u64]); test_ptx!(shl, [11u64], [44u64]);
test_ptx!(shl_link_hack, [11u64], [44u64]);
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]); test_ptx!(block, [1u64], [2u64]);

View file

@ -1,30 +0,0 @@
// HACK ALERT
// This test exercises the workaround for an IGC bug where linking fails
// if a shl/shr instruction uses a value and a shift amount of different widths
.version 6.5
.target sm_30
.address_size 64
// Kernel entry point: reads one u64 through `input`, shifts it left by 2 and
// writes the result through `output`.
.visible .entry shl_link_hack(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
// Fetch the input and output addresses from the kernel parameters.
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
// Here only to trigger linking
// The value returned in `unused` is never read, and the memory touched at
// [out_addr] is overwritten by the store below.
.reg .u32 unused;
atom.inc.u32 unused, [out_addr], 2000000;
ld.u64 temp, [in_addr];
// 64-bit value shifted by an immediate amount.
shl.b64 temp2, temp, 2;
st.u64 [out_addr], temp2;
ret;
}