Add or, mad, fma, min, max, selp, lg2, ex2, popc, rem

This commit is contained in:
Andrzej Janik 2024-10-15 04:42:44 +02:00
parent ae42eac925
commit 6f2944d9be
2 changed files with 298 additions and 21 deletions

View file

@ -518,7 +518,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
ast::Instruction::SetpBool { .. } => todo!(),
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
ast::Instruction::Or { .. } => todo!(),
ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
@ -528,15 +528,15 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { .. } => todo!(),
ast::Instruction::Mad { .. } => todo!(),
ast::Instruction::Fma { .. } => todo!(),
ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
ast::Instruction::Min { .. } => todo!(),
ast::Instruction::Max { .. } => todo!(),
ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
ast::Instruction::Selp { .. } => todo!(),
ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
@ -544,13 +544,13 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
ast::Instruction::Lg2 { .. } => todo!(),
ast::Instruction::Ex2 { .. } => todo!(),
ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
ast::Instruction::Popc { .. } => todo!(),
ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
ast::Instruction::Rem { .. } => todo!(),
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
ast::Instruction::PrmtSlow { .. } => todo!(),
ast::Instruction::Prmt { .. } => todo!(),
ast::Instruction::Membar { .. } => todo!(),
@ -664,7 +664,14 @@ impl<'a> MethodEmitContext<'a> {
_ => todo!(),
}
}
ConversionKind::SignExtend => todo!(),
ConversionKind::SignExtend => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildSExt(builder, src, type_, dst)
});
Ok(())
}
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_pointer_type(self.context, conversion.to_space)?;
@ -986,20 +993,82 @@ impl<'a> MethodEmitContext<'a> {
data: ast::MulDetails,
arguments: ast::MulArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
Ok(())
}
/// Shared multiplication emitter, called from `emit_mul` (with `Some(dst)`) and from
/// `emit_mad` (with `None`, because the product is only an intermediate value there).
///
/// Returns the LLVM value produced for the multiplication so callers can keep building on it.
fn emit_mul_impl(
&mut self,
data: ast::MulDetails,
dst: Option<SpirvWord>,
src1: SpirvWord,
src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
// Select the LLVM builder function for the plain (same-width) multiply cases.
let mul_fn = match data {
ast::MulDetails::Integer { control, .. } => match control {
ast::MulDetails::Integer { control, type_ } => match control {
ast::MulIntControl::Low => LLVMBuildMul,
ast::MulIntControl::High => todo!(),
ast::MulIntControl::Wide => todo!(),
// `High` and `Wide` need a double-width product, so they are delegated and return early.
ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
ast::MulIntControl::Wide => {
// `emit_mul_wide_impl` returns `(wide_type, value)`; only the value is needed here.
return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
}
},
ast::MulDetails::Float(..) => LLVMBuildFMul,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
mul_fn(self.builder, src1, src2, dst)
});
Ok(())
// Resolve both operands to LLVM values before building the instruction.
let src1 = self.resolver.value(src1)?;
let src2 = self.resolver.value(src2)?;
Ok(self
.resolver
.with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
}
/// Emits the upper half of an integer product (`mul.hi`).
///
/// The operands are multiplied at twice their width through `emit_mul_wide_impl`,
/// the product is shifted right by the operand bit width, and the shifted value is
/// truncated back to `type_`. The truncation is built through
/// `resolver.with_result_option(dst, ..)` and its value is returned.
fn emit_mul_high(
    &mut self,
    type_: ptx_parser::ScalarType,
    dst: Option<SpirvWord>,
    src1: SpirvWord,
    src2: SpirvWord,
) -> Result<LLVMValueRef, TranslateError> {
    // The wide product is only an intermediate here, so it gets no destination of its own.
    let (wide_type, product) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
    let operand_bits = (type_.layout().size() * 8) as u64;
    let high_half = unsafe {
        let shift_amount = LLVMConstInt(wide_type, operand_bits, 0);
        // Only the low half of the shifted value survives the truncation below, so the
        // kind of right shift does not influence the final bits.
        LLVMBuildLShr(self.builder, product, shift_amount, LLVM_UNNAMED.as_ptr())
    };
    let narrow_type = get_scalar_type(self.context, type_);
    let truncated = self.resolver.with_result_option(dst, |name| unsafe {
        LLVMBuildTrunc(self.builder, high_half, narrow_type, name)
    });
    Ok(truncated)
}
/// Emits a double-width integer product (`mul.wide`).
///
/// Both operands are extended to an integer type twice as wide as `type_`
/// (sign-extended for signed kinds, zero-extended for unsigned kinds) and multiplied.
/// Any other scalar kind yields `error_unreachable()`.
///
/// Returns the wide LLVM integer type together with the product value.
fn emit_mul_wide_impl(
    &mut self,
    type_: ptx_parser::ScalarType,
    dst: Option<SpirvWord>,
    src1: SpirvWord,
    src2: SpirvWord,
) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
    let lhs = self.resolver.value(src1)?;
    let rhs = self.resolver.value(src2)?;
    let wide_bits = (type_.layout().size() * 8 * 2) as u32;
    let wide_type = unsafe { LLVMIntTypeInContext(self.context, wide_bits) };
    // The extension must follow the signedness of the operands to keep their value.
    let extend = match type_.kind() {
        ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
        ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
        _ => return Err(error_unreachable()),
    };
    let (lhs, rhs) = unsafe {
        (
            extend(self.builder, lhs, wide_type, LLVM_UNNAMED.as_ptr()),
            extend(self.builder, rhs, wide_type, LLVM_UNNAMED.as_ptr()),
        )
    };
    let product = self.resolver.with_result_option(dst, |name| unsafe {
        LLVMBuildMul(self.builder, lhs, rhs, name)
    });
    Ok((wide_type, product))
}
fn emit_cos(
@ -1018,6 +1087,19 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
/// Emits PTX `or` as an LLVM `or` instruction over the two resolved source operands.
///
/// The scalar type carried by the instruction is not consulted.
fn emit_or(
    &mut self,
    _data: ptx_parser::ScalarType,
    arguments: ptx_parser::OrArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let lhs = self.resolver.value(arguments.src1)?;
    let rhs = self.resolver.value(arguments.src2)?;
    self.resolver.with_result(arguments.dst, |name| unsafe {
        LLVMBuildOr(self.builder, lhs, rhs, name)
    });
    Ok(())
}
fn emit_xor(
&mut self,
_data: ptx_parser::ScalarType,
@ -1612,6 +1694,191 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
/// Emits PTX `ex2` as a call to the AMDGCN `exp2` intrinsic for the operand type.
///
/// Only `f16` and `f32` are accepted; any other type yields `error_unreachable()`.
fn emit_ex2(
    &mut self,
    data: ptx_parser::TypeFtz,
    arguments: ptx_parser::Ex2Args<SpirvWord>,
) -> Result<(), TranslateError> {
    let scalar = data.type_;
    let intrinsic = match scalar {
        ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16",
        ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32",
        _ => return Err(error_unreachable()),
    };
    // The single operand has the same type as the result.
    let operand_type = get_scalar_type(self.context, scalar);
    self.emit_intrinsic(
        intrinsic,
        Some(arguments.dst),
        &scalar.into(),
        vec![(arguments.src, operand_type)],
    )?;
    Ok(())
}
/// Emits PTX `lg2` as a call to `llvm.amdgcn.log.f32`.
///
/// Operand and result are always `f32`; the flush-to-zero flag is not consulted.
fn emit_lg2(
    &mut self,
    _data: ptx_parser::FlushToZero,
    arguments: ptx_parser::Lg2Args<SpirvWord>,
) -> Result<(), TranslateError> {
    let operand_type = get_scalar_type(self.context, ast::ScalarType::F32);
    let operands = vec![(arguments.src, operand_type)];
    self.emit_intrinsic(
        c"llvm.amdgcn.log.f32",
        Some(arguments.dst),
        &ast::ScalarType::F32.into(),
        operands,
    )?;
    Ok(())
}
/// Emits PTX `selp` as an LLVM `select`.
///
/// `src3` is the condition; the result is `src1` when it holds and `src2` otherwise.
/// The scalar type carried by the instruction is not consulted.
fn emit_selp(
    &mut self,
    _data: ptx_parser::ScalarType,
    arguments: ptx_parser::SelpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let when_true = self.resolver.value(arguments.src1)?;
    let when_false = self.resolver.value(arguments.src2)?;
    let condition = self.resolver.value(arguments.src3)?;
    self.resolver.with_result(arguments.dst, |name| unsafe {
        LLVMBuildSelect(self.builder, condition, when_true, when_false, name)
    });
    Ok(())
}
/// Emits PTX `rem` as an LLVM integer remainder.
///
/// Unsigned types use `urem`, signed types use `srem`; any other scalar kind yields
/// `error_unreachable()`.
fn emit_rem(
    &mut self,
    data: ptx_parser::ScalarType,
    arguments: ptx_parser::RemArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    // Signedness decides which remainder instruction matches the PTX semantics.
    let build_rem = match data.kind() {
        ptx_parser::ScalarKind::Unsigned => LLVMBuildURem,
        ptx_parser::ScalarKind::Signed => LLVMBuildSRem,
        _ => return Err(error_unreachable()),
    };
    let dividend = self.resolver.value(arguments.src1)?;
    let divisor = self.resolver.value(arguments.src2)?;
    self.resolver.with_result(arguments.dst, |name| unsafe {
        build_rem(self.builder, dividend, divisor, name)
    });
    Ok(())
}
/// Emits PTX `popc` as a call to the `llvm.ctpop` intrinsic of matching width.
///
/// Only `b32` and `b64` are accepted; any other type yields `error_unreachable()`.
///
/// NOTE(review): the intrinsic result is declared with the operand type, so the 64-bit
/// form produces an `i64`. PTX specifies a 32-bit destination for `popc` — confirm that
/// the destination is narrowed elsewhere.
fn emit_popc(
    &mut self,
    type_: ptx_parser::ScalarType,
    arguments: ptx_parser::PopcArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let intrinsic = match type_ {
        ast::ScalarType::B32 => c"llvm.ctpop.i32",
        ast::ScalarType::B64 => c"llvm.ctpop.i64",
        _ => return Err(error_unreachable()),
    };
    let operands = vec![(arguments.src, get_scalar_type(self.context, type_))];
    self.emit_intrinsic(intrinsic, Some(arguments.dst), &type_.into(), operands)?;
    Ok(())
}
/// Emits PTX `min` as a call to the matching LLVM minimum intrinsic.
///
/// * signed integers   -> `llvm.smin.<type>`
/// * unsigned integers -> `llvm.umin.<type>`
/// * floats            -> `llvm.minnum.<type>`
///
/// The NaN-propagating float variant (`nan: true`) is not implemented and yields
/// `error_todo()`. Other float modifiers are not consulted.
fn emit_min(
    &mut self,
    data: ptx_parser::MinMaxDetails,
    arguments: ptx_parser::MinArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let llvm_prefix = match data {
        ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
        ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
        ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
            return Err(error_todo())
        }
        // A minimum must lower to `minnum`; `maxnum` here would compute the maximum.
        ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
    };
    // The intrinsic name is suffixed with the LLVM spelling of the operand type and
    // explicitly NUL-terminated so it can be viewed as a C string.
    let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
    let llvm_type = get_scalar_type(self.context, data.type_());
    self.emit_intrinsic(
        // SAFETY: the string ends with the `\0` written by `format!` above. This assumes
        // the `ScalarTypeInLLVM` rendering contains no interior NUL — TODO confirm.
        unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
        Some(arguments.dst),
        &data.type_().into(),
        vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
    )?;
    Ok(())
}
/// Emits PTX `max` as a call to the matching LLVM maximum intrinsic.
///
/// * signed integers   -> `llvm.smax.<type>`
/// * unsigned integers -> `llvm.umax.<type>`
/// * floats            -> `llvm.maxnum.<type>`
///
/// The NaN-propagating float variant (`nan: true`) is not implemented and yields
/// `error_todo()`. Other float modifiers are not consulted.
fn emit_max(
    &mut self,
    data: ptx_parser::MinMaxDetails,
    arguments: ptx_parser::MaxArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let llvm_prefix = match data {
        ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
        ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
        ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
            return Err(error_todo())
        }
        // A maximum must lower to `maxnum`; `minnum` here would compute the minimum.
        ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
    };
    // The intrinsic name is suffixed with the LLVM spelling of the operand type and
    // explicitly NUL-terminated so it can be viewed as a C string.
    let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
    let llvm_type = get_scalar_type(self.context, data.type_());
    self.emit_intrinsic(
        // SAFETY: the string ends with the `\0` written by `format!` above. This assumes
        // the `ScalarTypeInLLVM` rendering contains no interior NUL — TODO confirm.
        unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
        Some(arguments.dst),
        &data.type_().into(),
        vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
    )?;
    Ok(())
}
/// Emits PTX `fma` as a call to `llvm.fma.<type>` with all three operands sharing
/// the instruction's scalar type.
fn emit_fma(
    &mut self,
    data: ptx_parser::ArithFloat,
    arguments: ptx_parser::FmaArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let scalar = data.type_;
    // Explicit trailing NUL so the name can be handed over as a C string.
    let name = format!("llvm.fma.{}\0", ScalarTypeInLLVM(scalar));
    // SAFETY: `name` ends with the `\0` written above. This assumes the
    // `ScalarTypeInLLVM` rendering contains no interior NUL — TODO confirm.
    let intrinsic = unsafe { CStr::from_bytes_with_nul_unchecked(name.as_bytes()) };
    let operands = vec![
        (arguments.src1, get_scalar_type(self.context, scalar)),
        (arguments.src2, get_scalar_type(self.context, scalar)),
        (arguments.src3, get_scalar_type(self.context, scalar)),
    ];
    self.emit_intrinsic(intrinsic, Some(arguments.dst), &scalar.into(), operands)?;
    Ok(())
}
/// Emits PTX `mad` (multiply-add).
///
/// * Float `mad` is forwarded unchanged to `emit_fma`.
/// * Saturating integer `mad` is not implemented and yields `error_todo()`.
/// * Other integer forms multiply `src1` and `src2` through `emit_mul_impl` (without a
///   destination, since the product is an intermediate) and add `src3` to the product.
fn emit_mad(
    &mut self,
    data: ptx_parser::MadDetails,
    arguments: ptx_parser::MadArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let (type_, control) = match data {
        ptx_parser::MadDetails::Float(mad_float) => {
            let fma_arguments = ast::FmaArgs {
                dst: arguments.dst,
                src1: arguments.src1,
                src2: arguments.src2,
                src3: arguments.src3,
            };
            return self.emit_fma(mad_float, fma_arguments);
        }
        ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
        ptx_parser::MadDetails::Integer { type_, control, .. } => (type_, control),
    };
    // Reuse the multiplication emitter so low/high/wide handling stays in one place.
    let product = self.emit_mul_impl(
        ast::MulDetails::Integer { control, type_ },
        None,
        arguments.src1,
        arguments.src2,
    )?;
    let addend = self.resolver.value(arguments.src3)?;
    self.resolver.with_result(arguments.dst, |name| unsafe {
        LLVMBuildAdd(self.builder, product, addend, name)
    });
    Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
@ -1870,7 +2137,6 @@ impl ResolveIdent {
}
}
/*
struct ScalarTypeInLLVM(ast::ScalarType);
impl std::fmt::Display for ScalarTypeInLLVM {
@ -1893,6 +2159,7 @@ impl std::fmt::Display for ScalarTypeInLLVM {
}
}
/*
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
match this {
ptx_parser::RoundingMode::Zero => 0,

View file

@ -149,6 +149,16 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable
}
/// Error constructor used by emitters for instruction forms that are not implemented yet.
///
/// With debug assertions enabled this panics at the call site instead of returning,
/// so an unimplemented path is noticed immediately during development.
#[cfg(debug_assertions)]
fn error_todo() -> TranslateError {
unreachable!()
}
/// Error constructor used by emitters for instruction forms that are not implemented yet.
///
/// Without debug assertions the condition is reported to the caller as
/// `TranslateError::Todo`.
#[cfg(not(debug_assertions))]
fn error_todo() -> TranslateError {
TranslateError::Todo
}
#[cfg(debug_assertions)]
fn error_unknown_symbol() -> TranslateError {
panic!()