mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add br, setp, not, cvta, sub, neg, sin
This commit is contained in:
parent
c8b88f4483
commit
d9c33ca505
1 changed files with 259 additions and 22 deletions
|
@ -451,7 +451,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Statement::Variable(var) => self.emit_variable(var)?,
|
||||
Statement::Label(label) => self.emit_label_delayed(label)?,
|
||||
Statement::Instruction(inst) => self.emit_instruction(inst)?,
|
||||
Statement::Conditional(_) => todo!(),
|
||||
Statement::Conditional(cond) => self.emit_conditional(cond)?,
|
||||
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
|
||||
Statement::Constant(constant) => self.emit_constant(constant)?,
|
||||
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
|
||||
|
@ -515,9 +515,9 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
|
||||
ast::Instruction::Setp { .. } => todo!(),
|
||||
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
|
||||
ast::Instruction::SetpBool { .. } => todo!(),
|
||||
ast::Instruction::Not { .. } => todo!(),
|
||||
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
|
||||
ast::Instruction::Or { .. } => todo!(),
|
||||
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
|
||||
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
|
||||
|
@ -526,11 +526,11 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Shr { .. } => todo!(),
|
||||
ast::Instruction::Shl { .. } => todo!(),
|
||||
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
|
||||
ast::Instruction::Cvta { .. } => todo!(),
|
||||
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
|
||||
ast::Instruction::Abs { .. } => todo!(),
|
||||
ast::Instruction::Mad { .. } => todo!(),
|
||||
ast::Instruction::Fma { .. } => todo!(),
|
||||
ast::Instruction::Sub { .. } => todo!(),
|
||||
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
|
||||
ast::Instruction::Min { .. } => todo!(),
|
||||
ast::Instruction::Max { .. } => todo!(),
|
||||
ast::Instruction::Rcp { .. } => todo!(),
|
||||
|
@ -541,8 +541,8 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
|
||||
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
|
||||
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
|
||||
ast::Instruction::Neg { .. } => todo!(),
|
||||
ast::Instruction::Sin { .. } => todo!(),
|
||||
ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
|
||||
ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
|
||||
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
|
||||
ast::Instruction::Lg2 { .. } => todo!(),
|
||||
ast::Instruction::Ex2 { .. } => todo!(),
|
||||
|
@ -651,6 +651,16 @@ impl<'a> MethodEmitContext<'a> {
|
|||
}
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(..), ast::Type::Scalar(..))
|
||||
| (ast::Type::Scalar(..), ast::Type::Array(..))
|
||||
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let dst_type = get_type(self.context, &conversion.to_type)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildBitCast(builder, src, dst_type, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
@ -997,20 +1007,13 @@ impl<'a> MethodEmitContext<'a> {
|
|||
_data: ast::FlushToZero,
|
||||
arguments: ast::CosArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_fn = c"llvm.cos.f32";
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
iter::once(&ast::ScalarType::F32.into()),
|
||||
iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))),
|
||||
let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
|
||||
let cos = self.emit_intrinsic(
|
||||
c"llvm.cos.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
vec![(arguments.src, llvm_f32)],
|
||||
)?;
|
||||
if fn_ == ptr::null_mut() {
|
||||
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
|
||||
}
|
||||
let mut src = self.resolver.value(arguments.src)?;
|
||||
let cos = self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
|
||||
});
|
||||
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1168,6 +1171,241 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cvta(
|
||||
&mut self,
|
||||
data: ptx_parser::CvtaDetails,
|
||||
arguments: ptx_parser::CvtaArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let (from_space, to_space) = match data.direction {
|
||||
ptx_parser::CvtaDirection::GenericToExplicit => {
|
||||
(ast::StateSpace::Generic, data.state_space)
|
||||
}
|
||||
ptx_parser::CvtaDirection::ExplicitToGeneric => {
|
||||
(data.state_space, ast::StateSpace::Generic)
|
||||
}
|
||||
};
|
||||
let from_type = get_pointer_type(self.context, from_space)?;
|
||||
let dest_type = get_pointer_type(self.context, to_space)?;
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let temp_ptr =
|
||||
unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) };
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_sub(
|
||||
&mut self,
|
||||
data: ptx_parser::ArithDetails,
|
||||
arguments: ptx_parser::SubArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match data {
|
||||
ptx_parser::ArithDetails::Integer(arith_integer) => {
|
||||
self.emit_sub_integer(arith_integer, arguments)
|
||||
}
|
||||
ptx_parser::ArithDetails::Float(arith_float) => {
|
||||
self.emit_sub_float(arith_float, arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_sub_integer(
|
||||
&mut self,
|
||||
arith_integer: ptx_parser::ArithInteger,
|
||||
arguments: ptx_parser::SubArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if arith_integer.saturate {
|
||||
todo!()
|
||||
}
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildSub(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_sub_float(
|
||||
&mut self,
|
||||
arith_float: ptx_parser::ArithFloat,
|
||||
arguments: ptx_parser::SubArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if arith_float.saturate {
|
||||
todo!()
|
||||
}
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildFSub(self.builder, src1, src2, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_sin(
|
||||
&mut self,
|
||||
_data: ptx_parser::FlushToZero,
|
||||
arguments: ptx_parser::SinArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
|
||||
let sin = self.emit_intrinsic(
|
||||
c"llvm.sin.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
vec![(arguments.src, llvm_f32)],
|
||||
)?;
|
||||
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_intrinsic(
|
||||
&mut self,
|
||||
name: &CStr,
|
||||
dst: Option<SpirvWord>,
|
||||
return_type: &ast::Type,
|
||||
arguments: Vec<(SpirvWord, LLVMTypeRef)>,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
iter::once(return_type),
|
||||
arguments.iter().map(|(_, type_)| Ok(*type_)),
|
||||
)?;
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
||||
if fn_ == ptr::null_mut() {
|
||||
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
}
|
||||
let mut arguments = arguments
|
||||
.iter()
|
||||
.map(|(arg, _)| self.resolver.value(*arg))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
fn_type,
|
||||
fn_,
|
||||
arguments.as_mut_ptr(),
|
||||
arguments.len() as u32,
|
||||
dst,
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
fn emit_neg(
|
||||
&mut self,
|
||||
data: ptx_parser::TypeFtz,
|
||||
arguments: ptx_parser::NegArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
|
||||
LLVMBuildFNeg
|
||||
} else {
|
||||
LLVMBuildNeg
|
||||
};
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
llvm_fn(self.builder, src, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_not(
|
||||
&mut self,
|
||||
_data: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::NotArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildNot(self.builder, src, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_setp(
|
||||
&mut self,
|
||||
data: ptx_parser::SetpData,
|
||||
arguments: ptx_parser::SetpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if arguments.dst2.is_some() {
|
||||
todo!()
|
||||
}
|
||||
match data.cmp_op {
|
||||
ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
|
||||
self.emit_setp_int(setp_compare_int, arguments)
|
||||
}
|
||||
ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
|
||||
self.emit_setp_float(setp_compare_float, arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_setp_int(
|
||||
&mut self,
|
||||
setp: ptx_parser::SetpCompareInt,
|
||||
arguments: ptx_parser::SetpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let op = match setp {
|
||||
ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
|
||||
ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
|
||||
ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT,
|
||||
ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE,
|
||||
ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT,
|
||||
ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE,
|
||||
ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT,
|
||||
ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE,
|
||||
ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
|
||||
ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
|
||||
LLVMBuildICmp(self.builder, op, src1, src2, dst1)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_setp_float(
|
||||
&mut self,
|
||||
setp: ptx_parser::SetpCompareFloat,
|
||||
arguments: ptx_parser::SetpArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let op = match setp {
|
||||
ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
|
||||
ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
|
||||
ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT,
|
||||
ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE,
|
||||
ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT,
|
||||
ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE,
|
||||
ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ,
|
||||
ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE,
|
||||
ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT,
|
||||
ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE,
|
||||
ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT,
|
||||
ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE,
|
||||
ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
|
||||
ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
|
||||
};
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
self.resolver.with_result(arguments.dst1, |dst1| unsafe {
|
||||
LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
|
||||
let predicate = self.resolver.value(cond.predicate)?;
|
||||
let if_true = self.resolver.value(cond.if_true)?;
|
||||
let if_false = self.resolver.value(cond.if_false)?;
|
||||
unsafe {
|
||||
LLVMBuildCondBr(
|
||||
self.builder,
|
||||
predicate,
|
||||
LLVMValueAsBasicBlock(if_true),
|
||||
LLVMValueAsBasicBlock(if_false),
|
||||
)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
|
@ -1328,8 +1566,7 @@ fn get_function_type<'a>(
|
|||
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
|
||||
input_args.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
|
||||
let return_type = match return_args.len() {
|
||||
0 => unsafe { LLVMVoidTypeInContext(context) },
|
||||
1 => get_type(context, return_args.next().unwrap())?,
|
||||
|
|
Loading…
Add table
Reference in a new issue