Add br, setp, not, cvta, sub, neg, sin

This commit is contained in:
Andrzej Janik 2024-10-11 22:55:10 +02:00
parent c8b88f4483
commit d9c33ca505

View file

@ -451,7 +451,7 @@ impl<'a> MethodEmitContext<'a> {
Statement::Variable(var) => self.emit_variable(var)?,
Statement::Label(label) => self.emit_label_delayed(label)?,
Statement::Instruction(inst) => self.emit_instruction(inst)?,
Statement::Conditional(_) => todo!(),
Statement::Conditional(cond) => self.emit_conditional(cond)?,
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
Statement::Constant(constant) => self.emit_constant(constant)?,
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
@ -515,9 +515,9 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
ast::Instruction::Setp { .. } => todo!(),
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
ast::Instruction::SetpBool { .. } => todo!(),
ast::Instruction::Not { .. } => todo!(),
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
ast::Instruction::Or { .. } => todo!(),
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
@ -526,11 +526,11 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Shr { .. } => todo!(),
ast::Instruction::Shl { .. } => todo!(),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { .. } => todo!(),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { .. } => todo!(),
ast::Instruction::Mad { .. } => todo!(),
ast::Instruction::Fma { .. } => todo!(),
ast::Instruction::Sub { .. } => todo!(),
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
ast::Instruction::Min { .. } => todo!(),
ast::Instruction::Max { .. } => todo!(),
ast::Instruction::Rcp { .. } => todo!(),
@ -541,8 +541,8 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
ast::Instruction::Neg { .. } => todo!(),
ast::Instruction::Sin { .. } => todo!(),
ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
ast::Instruction::Lg2 { .. } => todo!(),
ast::Instruction::Ex2 { .. } => todo!(),
@ -651,6 +651,16 @@ impl<'a> MethodEmitContext<'a> {
}
}
}
(ast::Type::Vector(..), ast::Type::Scalar(..))
| (ast::Type::Scalar(..), ast::Type::Array(..))
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
let src = self.resolver.value(conversion.src)?;
let dst_type = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildBitCast(builder, src, dst_type, dst)
});
Ok(())
}
_ => todo!(),
}
}
@ -997,20 +1007,13 @@ impl<'a> MethodEmitContext<'a> {
_data: ast::FlushToZero,
arguments: ast::CosArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_fn = c"llvm.cos.f32";
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
let fn_type = get_function_type(
self.context,
iter::once(&ast::ScalarType::F32.into()),
iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))),
let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
let cos = self.emit_intrinsic(
c"llvm.cos.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(arguments.src, llvm_f32)],
)?;
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
}
let mut src = self.resolver.value(arguments.src)?;
let cos = self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
});
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
Ok(())
}
@ -1168,6 +1171,241 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
/// Emits the IR for a `cvta` instruction.
///
/// The source operand is first turned into a pointer in the originating
/// address space with `inttoptr`, and that pointer is then
/// address-space-cast into the target space. Which side is the generic
/// space is decided by `data.direction`; the other side is
/// `data.state_space`.
///
/// # Errors
///
/// Propagates failures from `get_pointer_type` and from resolving
/// `arguments.src`.
fn emit_cvta(
    &mut self,
    data: ptx_parser::CvtaDetails,
    arguments: ptx_parser::CvtaArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let generic = ast::StateSpace::Generic;
    let explicit = data.state_space;
    let (source_space, target_space) = match data.direction {
        ptx_parser::CvtaDirection::ExplicitToGeneric => (explicit, generic),
        ptx_parser::CvtaDirection::GenericToExplicit => (generic, explicit),
    };
    let source_ptr_type = get_pointer_type(self.context, source_space)?;
    let target_ptr_type = get_pointer_type(self.context, target_space)?;
    let operand = self.resolver.value(arguments.src)?;
    let builder = self.builder;
    // SAFETY: FFI call into LLVM; assumes `builder`, `operand` and the type
    // handle are valid for the current context (not verifiable from here).
    let as_pointer = unsafe {
        LLVMBuildIntToPtr(builder, operand, source_ptr_type, LLVM_UNNAMED.as_ptr())
    };
    self.resolver.with_result(arguments.dst, |dst_name| {
        // SAFETY: same assumptions as above; `as_pointer` was just produced
        // by this builder.
        unsafe { LLVMBuildAddrSpaceCast(builder, as_pointer, target_ptr_type, dst_name) }
    });
    Ok(())
}
/// Emits a `sub` instruction by forwarding to the integer or the
/// floating-point implementation, depending on which `ArithDetails`
/// variant `data` carries.
fn emit_sub(
    &mut self,
    data: ptx_parser::ArithDetails,
    arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    match data {
        ptx_parser::ArithDetails::Float(float_details) => {
            self.emit_sub_float(float_details, arguments)
        }
        ptx_parser::ArithDetails::Integer(integer_details) => {
            self.emit_sub_integer(integer_details, arguments)
        }
    }
}
/// Emits an integer subtraction `dst = src1 - src2` via `LLVMBuildSub`.
///
/// # Panics
///
/// The saturating form (`arith_integer.saturate`) is not implemented and
/// hits `todo!()`.
///
/// # Errors
///
/// Propagates failures from resolving either source operand.
fn emit_sub_integer(
    &mut self,
    arith_integer: ptx_parser::ArithInteger,
    arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    if arith_integer.saturate {
        todo!()
    }
    let minuend = self.resolver.value(arguments.src1)?;
    let subtrahend = self.resolver.value(arguments.src2)?;
    let builder = self.builder;
    self.resolver.with_result(arguments.dst, |dst_name| {
        // SAFETY: FFI call into LLVM; assumes `builder` and both operands are
        // valid handles for the current context (not verifiable from here).
        unsafe { LLVMBuildSub(builder, minuend, subtrahend, dst_name) }
    });
    Ok(())
}
/// Emits a floating-point subtraction `dst = src1 - src2` via
/// `LLVMBuildFSub`.
///
/// Only the `saturate` field of `arith_float` is consulted here.
///
/// # Panics
///
/// The saturating form (`arith_float.saturate`) is not implemented and hits
/// `todo!()`.
///
/// # Errors
///
/// Propagates failures from resolving either source operand.
fn emit_sub_float(
    &mut self,
    arith_float: ptx_parser::ArithFloat,
    arguments: ptx_parser::SubArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    if arith_float.saturate {
        todo!()
    }
    let minuend = self.resolver.value(arguments.src1)?;
    let subtrahend = self.resolver.value(arguments.src2)?;
    let builder = self.builder;
    self.resolver.with_result(arguments.dst, |dst_name| {
        // SAFETY: FFI call into LLVM; assumes `builder` and both operands are
        // valid handles for the current context (not verifiable from here).
        unsafe { LLVMBuildFSub(builder, minuend, subtrahend, dst_name) }
    });
    Ok(())
}
/// Emits a `sin` instruction as a call to the `llvm.sin.f32` intrinsic.
///
/// The call takes `arguments.src` as its single `f32` operand, writes to
/// `arguments.dst`, and is then tagged with `LLVMZludaFastMathApproxFunc`.
/// The flush-to-zero setting in `_data` is not consulted.
///
/// # Errors
///
/// Propagates any failure reported by `emit_intrinsic`.
fn emit_sin(
    &mut self,
    _data: ptx_parser::FlushToZero,
    arguments: ptx_parser::SinArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let f32_llvm_type = get_scalar_type(self.context, ast::ScalarType::F32);
    let return_type: ast::Type = ast::ScalarType::F32.into();
    let operands = vec![(arguments.src, f32_llvm_type)];
    let call = self.emit_intrinsic(
        c"llvm.sin.f32",
        Some(arguments.dst),
        &return_type,
        operands,
    )?;
    // SAFETY: FFI call; `call` is the value `emit_intrinsic` just returned.
    unsafe { LLVMZludaSetFastMathFlags(call, LLVMZludaFastMathApproxFunc) }
    Ok(())
}
/// Emits a call to the function called `name`, declaring it in the module
/// first if no function with that name exists yet.
///
/// * `name` - symbol to look up or declare.
/// * `dst` - optional destination, handed to
///   `resolver.with_result_option` together with the call.
/// * `return_type` - the single return type used to build the function type.
/// * `arguments` - one `(operand, LLVM type)` pair per parameter; the types
///   shape the function signature and the operands are resolved into the
///   call arguments, in the same order.
///
/// Returns the value produced by `LLVMBuildCall2`.
///
/// # Errors
///
/// Propagates failures from `get_function_type` and from resolving any
/// operand; the first operand that fails to resolve aborts the emission.
fn emit_intrinsic(
    &mut self,
    name: &CStr,
    dst: Option<SpirvWord>,
    return_type: &ast::Type,
    arguments: Vec<(SpirvWord, LLVMTypeRef)>,
) -> Result<LLVMValueRef, TranslateError> {
    let signature = get_function_type(
        self.context,
        iter::once(return_type),
        arguments.iter().map(|&(_, llvm_type)| Ok(llvm_type)),
    )?;
    // SAFETY: FFI lookup; `name` is a NUL-terminated `CStr`. Assumes
    // `self.module` is a valid module handle (not verifiable from here).
    let existing = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
    let callee = if existing.is_null() {
        // SAFETY: same assumptions as the lookup above.
        unsafe { LLVMAddFunction(self.module, name.as_ptr(), signature) }
    } else {
        existing
    };
    let mut operands = Vec::with_capacity(arguments.len());
    for &(word, _) in &arguments {
        operands.push(self.resolver.value(word)?);
    }
    let operand_count = operands.len() as u32;
    let builder = self.builder;
    let call = self.resolver.with_result_option(dst, |dst_name| {
        // SAFETY: `operands` outlives the call and `operand_count` is its
        // length; the handles are assumed valid for the current context.
        unsafe {
            LLVMBuildCall2(
                builder,
                signature,
                callee,
                operands.as_mut_ptr(),
                operand_count,
                dst_name,
            )
        }
    });
    Ok(call)
}
fn emit_neg(
&mut self,
data: ptx_parser::TypeFtz,
arguments: ptx_parser::NegArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
LLVMBuildFNeg
} else {
LLVMBuildNeg
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
llvm_fn(self.builder, src, dst)
});
Ok(())
}
/// Emits a `not` instruction as `LLVMBuildNot` applied to `arguments.src`,
/// with the result bound to `arguments.dst`. The scalar type in `_data` is
/// not consulted.
///
/// # Errors
///
/// Propagates a failure to resolve `arguments.src`.
fn emit_not(
    &mut self,
    _data: ptx_parser::ScalarType,
    arguments: ptx_parser::NotArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let operand = self.resolver.value(arguments.src)?;
    let builder = self.builder;
    self.resolver.with_result(arguments.dst, |dst_name| {
        // SAFETY: FFI call into LLVM; assumes `builder` and `operand` are
        // valid handles for the current context (not verifiable from here).
        unsafe { LLVMBuildNot(builder, operand, dst_name) }
    });
    Ok(())
}
/// Emits a `setp` instruction by forwarding to the integer or the
/// floating-point comparison emitter according to `data.cmp_op`.
///
/// # Panics
///
/// The form with a second destination (`arguments.dst2` is `Some`) is not
/// implemented and hits `todo!()`.
fn emit_setp(
    &mut self,
    data: ptx_parser::SetpData,
    arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    if arguments.dst2.is_some() {
        todo!()
    }
    match data.cmp_op {
        ptx_parser::SetpCompareOp::Float(float_cmp) => {
            self.emit_setp_float(float_cmp, arguments)
        }
        ptx_parser::SetpCompareOp::Integer(int_cmp) => {
            self.emit_setp_int(int_cmp, arguments)
        }
    }
}
fn emit_setp_int(
&mut self,
setp: ptx_parser::SetpCompareInt,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let op = match setp {
ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT,
ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE,
ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT,
ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE,
ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT,
ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE,
ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
LLVMBuildICmp(self.builder, op, src1, src2, dst1)
});
Ok(())
}
fn emit_setp_float(
&mut self,
setp: ptx_parser::SetpCompareFloat,
arguments: ptx_parser::SetpArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let op = match setp {
ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT,
ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE,
ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT,
ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE,
ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ,
ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE,
ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT,
ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE,
ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT,
ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE,
ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
};
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
});
Ok(())
}
/// Emits a conditional branch on `cond.predicate` that jumps to
/// `cond.if_true` or `cond.if_false`.
///
/// All three fields are resolved to values first; the two targets are then
/// reinterpreted as basic blocks for `LLVMBuildCondBr`.
///
/// # Errors
///
/// Propagates a failure to resolve any of the three fields.
fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
    let condition = self.resolver.value(cond.predicate)?;
    let taken = self.resolver.value(cond.if_true)?;
    let not_taken = self.resolver.value(cond.if_false)?;
    // SAFETY: FFI calls into LLVM; assumes both targets resolve to values
    // that really are basic blocks (not verifiable from this block).
    unsafe {
        let taken_block = LLVMValueAsBasicBlock(taken);
        let not_taken_block = LLVMValueAsBasicBlock(not_taken);
        LLVMBuildCondBr(self.builder, condition, taken_block, not_taken_block)
    };
    Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
@ -1328,8 +1566,7 @@ fn get_function_type<'a>(
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
input_args.collect::<Result<Vec<_>, _>>()?;
let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,