mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Add or, mad, fma, min, max, selp, lg2, ex2, popc, rem
This commit is contained in:
parent
ae42eac925
commit
6f2944d9be
2 changed files with 298 additions and 21 deletions
|
@ -518,7 +518,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
|
||||
ast::Instruction::SetpBool { .. } => todo!(),
|
||||
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
|
||||
ast::Instruction::Or { .. } => todo!(),
|
||||
ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
|
||||
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
|
||||
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
|
||||
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
|
||||
|
@ -528,15 +528,15 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
|
||||
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
|
||||
ast::Instruction::Abs { .. } => todo!(),
|
||||
ast::Instruction::Mad { .. } => todo!(),
|
||||
ast::Instruction::Fma { .. } => todo!(),
|
||||
ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
|
||||
ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
|
||||
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
|
||||
ast::Instruction::Min { .. } => todo!(),
|
||||
ast::Instruction::Max { .. } => todo!(),
|
||||
ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
|
||||
ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
|
||||
ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
|
||||
ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
|
||||
ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
|
||||
ast::Instruction::Selp { .. } => todo!(),
|
||||
ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
|
||||
ast::Instruction::Bar { .. } => todo!(),
|
||||
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
|
||||
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
|
||||
|
@ -544,13 +544,13 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
|
||||
ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
|
||||
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
|
||||
ast::Instruction::Lg2 { .. } => todo!(),
|
||||
ast::Instruction::Ex2 { .. } => todo!(),
|
||||
ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
|
||||
ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
|
||||
ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
|
||||
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
|
||||
ast::Instruction::Popc { .. } => todo!(),
|
||||
ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
|
||||
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
|
||||
ast::Instruction::Rem { .. } => todo!(),
|
||||
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
|
||||
ast::Instruction::PrmtSlow { .. } => todo!(),
|
||||
ast::Instruction::Prmt { .. } => todo!(),
|
||||
ast::Instruction::Membar { .. } => todo!(),
|
||||
|
@ -664,7 +664,14 @@ impl<'a> MethodEmitContext<'a> {
|
|||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
ConversionKind::SignExtend => todo!(),
|
||||
ConversionKind::SignExtend => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let type_ = get_type(self.context, &conversion.to_type)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildSExt(builder, src, type_, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
ConversionKind::BitToPtr => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let type_ = get_pointer_type(self.context, conversion.to_space)?;
|
||||
|
@ -986,20 +993,82 @@ impl<'a> MethodEmitContext<'a> {
|
|||
data: ast::MulDetails,
|
||||
arguments: ast::MulArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mul_impl(
|
||||
&mut self,
|
||||
data: ast::MulDetails,
|
||||
dst: Option<SpirvWord>,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let mul_fn = match data {
|
||||
ast::MulDetails::Integer { control, .. } => match control {
|
||||
ast::MulDetails::Integer { control, type_ } => match control {
|
||||
ast::MulIntControl::Low => LLVMBuildMul,
|
||||
ast::MulIntControl::High => todo!(),
|
||||
ast::MulIntControl::Wide => todo!(),
|
||||
ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
|
||||
ast::MulIntControl::Wide => {
|
||||
return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
|
||||
}
|
||||
},
|
||||
ast::MulDetails::Float(..) => LLVMBuildFMul,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
mul_fn(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
let src1 = self.resolver.value(src1)?;
|
||||
let src2 = self.resolver.value(src2)?;
|
||||
Ok(self
|
||||
.resolver
|
||||
.with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
|
||||
}
|
||||
|
||||
fn emit_mul_high(
|
||||
&mut self,
|
||||
type_: ptx_parser::ScalarType,
|
||||
dst: Option<SpirvWord>,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
|
||||
let shift_constant =
|
||||
unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) };
|
||||
let shifted = unsafe {
|
||||
LLVMBuildLShr(
|
||||
self.builder,
|
||||
wide_value,
|
||||
shift_constant,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let narrow_type = get_scalar_type(self.context, type_);
|
||||
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
|
||||
LLVMBuildTrunc(self.builder, shifted, narrow_type, dst)
|
||||
}))
|
||||
}
|
||||
|
||||
fn emit_mul_wide_impl(
|
||||
&mut self,
|
||||
type_: ptx_parser::ScalarType,
|
||||
dst: Option<SpirvWord>,
|
||||
src1: SpirvWord,
|
||||
src2: SpirvWord,
|
||||
) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
|
||||
let src1 = self.resolver.value(src1)?;
|
||||
let src2 = self.resolver.value(src2)?;
|
||||
let wide_type =
|
||||
unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) };
|
||||
let llvm_cast = match type_.kind() {
|
||||
ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
|
||||
ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) };
|
||||
let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) };
|
||||
Ok((
|
||||
wide_type,
|
||||
self.resolver.with_result_option(dst, |dst| unsafe {
|
||||
LLVMBuildMul(self.builder, src1, src2, dst)
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
fn emit_cos(
|
||||
|
@ -1018,6 +1087,19 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_or(
|
||||
&mut self,
|
||||
_data: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::OrArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildOr(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_xor(
|
||||
&mut self,
|
||||
_data: ptx_parser::ScalarType,
|
||||
|
@ -1612,6 +1694,191 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_ex2(
|
||||
&mut self,
|
||||
data: ptx_parser::TypeFtz,
|
||||
arguments: ptx_parser::Ex2Args<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let intrinsic = match data.type_ {
|
||||
ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16",
|
||||
ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32",
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![(arguments.src, get_scalar_type(self.context, data.type_))],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_lg2(
|
||||
&mut self,
|
||||
_data: ptx_parser::FlushToZero,
|
||||
arguments: ptx_parser::Lg2Args<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
self.emit_intrinsic(
|
||||
c"llvm.amdgcn.log.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
vec![(
|
||||
arguments.src,
|
||||
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
||||
)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_selp(
|
||||
&mut self,
|
||||
_data: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::SelpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let src3 = self.resolver.value(arguments.src3)?;
|
||||
self.resolver.with_result(arguments.dst, |dst_name| unsafe {
|
||||
LLVMBuildSelect(self.builder, src3, src1, src2, dst_name)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_rem(
|
||||
&mut self,
|
||||
data: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::RemArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_fn = match data.kind() {
|
||||
ptx_parser::ScalarKind::Unsigned => LLVMBuildURem,
|
||||
ptx_parser::ScalarKind::Signed => LLVMBuildSRem,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst_name| unsafe {
|
||||
llvm_fn(self.builder, src1, src2, dst_name)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_popc(
|
||||
&mut self,
|
||||
type_: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::PopcArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let intrinsic = match type_ {
|
||||
ast::ScalarType::B32 => c"llvm.ctpop.i32",
|
||||
ast::ScalarType::B64 => c"llvm.ctpop.i64",
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let llvm_type = get_scalar_type(self.context, type_);
|
||||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&type_.into(),
|
||||
vec![(arguments.src, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_min(
|
||||
&mut self,
|
||||
data: ptx_parser::MinMaxDetails,
|
||||
arguments: ptx_parser::MinArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_prefix = match data {
|
||||
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
|
||||
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
|
||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
|
||||
return Err(error_todo())
|
||||
}
|
||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
|
||||
};
|
||||
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
|
||||
let llvm_type = get_scalar_type(self.context, data.type_());
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_().into(),
|
||||
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_max(
|
||||
&mut self,
|
||||
data: ptx_parser::MinMaxDetails,
|
||||
arguments: ptx_parser::MaxArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_prefix = match data {
|
||||
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
|
||||
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
|
||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
|
||||
return Err(error_todo())
|
||||
}
|
||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
|
||||
};
|
||||
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
|
||||
let llvm_type = get_scalar_type(self.context, data.type_());
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_().into(),
|
||||
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_fma(
|
||||
&mut self,
|
||||
data: ptx_parser::ArithFloat,
|
||||
arguments: ptx_parser::FmaArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_));
|
||||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
vec![
|
||||
(arguments.src1, get_scalar_type(self.context, data.type_)),
|
||||
(arguments.src2, get_scalar_type(self.context, data.type_)),
|
||||
(arguments.src3, get_scalar_type(self.context, data.type_)),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mad(
|
||||
&mut self,
|
||||
data: ptx_parser::MadDetails,
|
||||
arguments: ptx_parser::MadArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mul_control = match data {
|
||||
ptx_parser::MadDetails::Float(mad_float) => {
|
||||
return self.emit_fma(
|
||||
mad_float,
|
||||
ast::FmaArgs {
|
||||
dst: arguments.dst,
|
||||
src1: arguments.src1,
|
||||
src2: arguments.src2,
|
||||
src3: arguments.src3,
|
||||
},
|
||||
)
|
||||
}
|
||||
ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
|
||||
ptx_parser::MadDetails::Integer { type_, control, .. } => {
|
||||
ast::MulDetails::Integer { control, type_ }
|
||||
}
|
||||
};
|
||||
let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?;
|
||||
let src3 = self.resolver.value(arguments.src3)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildAdd(self.builder, temp, src3, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
|
@ -1870,7 +2137,6 @@ impl ResolveIdent {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
struct ScalarTypeInLLVM(ast::ScalarType);
|
||||
|
||||
impl std::fmt::Display for ScalarTypeInLLVM {
|
||||
|
@ -1893,6 +2159,7 @@ impl std::fmt::Display for ScalarTypeInLLVM {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
|
||||
match this {
|
||||
ptx_parser::RoundingMode::Zero => 0,
|
||||
|
|
|
@ -149,6 +149,16 @@ fn error_unreachable() -> TranslateError {
|
|||
TranslateError::Unreachable
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn error_todo() -> TranslateError {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
fn error_todo() -> TranslateError {
|
||||
TranslateError::Todo
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn error_unknown_symbol() -> TranslateError {
|
||||
panic!()
|
||||
|
|
Loading…
Add table
Reference in a new issue