diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 15177bc..ce1eb84 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -451,7 +451,7 @@ impl<'a> MethodEmitContext<'a> { Statement::Variable(var) => self.emit_variable(var)?, Statement::Label(label) => self.emit_label_delayed(label)?, Statement::Instruction(inst) => self.emit_instruction(inst)?, - Statement::Conditional(_) => todo!(), + Statement::Conditional(cond) => self.emit_conditional(cond)?, Statement::Conversion(conversion) => self.emit_conversion(conversion)?, Statement::Constant(constant) => self.emit_constant(constant)?, Statement::RetValue(_, values) => self.emit_ret_value(values)?, @@ -515,9 +515,9 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), - ast::Instruction::Setp { .. } => todo!(), + ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), ast::Instruction::SetpBool { .. } => todo!(), - ast::Instruction::Not { .. } => todo!(), + ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), ast::Instruction::Or { .. } => todo!(), ast::Instruction::And { arguments, .. } => self.emit_and(arguments), ast::Instruction::Bra { arguments } => self.emit_bra(arguments), @@ -526,11 +526,11 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Shr { .. } => todo!(), ast::Instruction::Shl { .. } => todo!(), ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), - ast::Instruction::Cvta { .. } => todo!(), + ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), ast::Instruction::Abs { .. } => todo!(), ast::Instruction::Mad { .. } => todo!(), ast::Instruction::Fma { .. } => todo!(), - ast::Instruction::Sub { .. } => todo!(), + ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), ast::Instruction::Min { .. } => todo!(), ast::Instruction::Max { .. } => todo!(), ast::Instruction::Rcp { .. } => todo!(), @@ -541,8 +541,8 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments), - ast::Instruction::Neg { .. } => todo!(), - ast::Instruction::Sin { .. } => todo!(), + ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments), + ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments), ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), ast::Instruction::Lg2 { .. } => todo!(), ast::Instruction::Ex2 { .. } => todo!(), @@ -651,6 +651,16 @@ impl<'a> MethodEmitContext<'a> { } } } + (ast::Type::Vector(..), ast::Type::Scalar(..)) + | (ast::Type::Scalar(..), ast::Type::Array(..)) + | (ast::Type::Array(..), ast::Type::Scalar(..)) => { + let src = self.resolver.value(conversion.src)?; + let dst_type = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildBitCast(builder, src, dst_type, dst) + }); + Ok(()) + } _ => todo!(), } } @@ -997,20 +1007,13 @@ impl<'a> MethodEmitContext<'a> { _data: ast::FlushToZero, arguments: ast::CosArgs, ) -> Result<(), TranslateError> { - let llvm_fn = c"llvm.cos.f32"; - let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; - let fn_type = get_function_type( - self.context, - iter::once(&ast::ScalarType::F32.into()), - iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))), + let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); + let cos = self.emit_intrinsic( + c"llvm.cos.f32", + Some(arguments.dst), + &ast::ScalarType::F32.into(), + vec![(arguments.src, llvm_f32)], )?; - if fn_ == ptr::null_mut() { - fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; - } - let mut src = self.resolver.value(arguments.src)?; - let cos = self.resolver.with_result(arguments.dst, |dst| unsafe { - LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst) - }); unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } Ok(()) } @@ -1168,6 +1171,241 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_cvta( + &mut self, + data: ptx_parser::CvtaDetails, + arguments: ptx_parser::CvtaArgs, + ) -> Result<(), TranslateError> { + let (from_space, to_space) = match data.direction { + ptx_parser::CvtaDirection::GenericToExplicit => { + (ast::StateSpace::Generic, data.state_space) + } + ptx_parser::CvtaDirection::ExplicitToGeneric => { + (data.state_space, ast::StateSpace::Generic) + } + }; + let from_type = get_pointer_type(self.context, from_space)?; + let dest_type = get_pointer_type(self.context, to_space)?; + let src = self.resolver.value(arguments.src)?; + let temp_ptr = + unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst) + }); + Ok(()) + } + + fn emit_sub( + &mut self, + data: ptx_parser::ArithDetails, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + match data { + ptx_parser::ArithDetails::Integer(arith_integer) => { + self.emit_sub_integer(arith_integer, arguments) + } + ptx_parser::ArithDetails::Float(arith_float) => { + self.emit_sub_float(arith_float, arguments) + } + } + } + + fn emit_sub_integer( + &mut self, + arith_integer: ptx_parser::ArithInteger, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + if arith_integer.saturate { + todo!() + } + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildSub(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_sub_float( + &mut self, + arith_float: ptx_parser::ArithFloat, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + if arith_float.saturate { + todo!() + } + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFSub(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_sin( + &mut self, + _data: ptx_parser::FlushToZero, + arguments: ptx_parser::SinArgs, + ) -> Result<(), TranslateError> { + let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); + let sin = self.emit_intrinsic( + c"llvm.sin.f32", + Some(arguments.dst), + &ast::ScalarType::F32.into(), + vec![(arguments.src, llvm_f32)], + )?; + unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + fn emit_intrinsic( + &mut self, + name: &CStr, + dst: Option, + return_type: &ast::Type, + arguments: Vec<(SpirvWord, LLVMTypeRef)>, + ) -> Result { + let fn_type = get_function_type( + self.context, + iter::once(return_type), + arguments.iter().map(|(_, type_)| Ok(*type_)), + )?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + } + let mut arguments = arguments + .iter() + .map(|(arg, _)| self.resolver.value(*arg)) + .collect::, _>>()?; + Ok(self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildCall2( + self.builder, + fn_type, + fn_, + arguments.as_mut_ptr(), + arguments.len() as u32, + dst, + ) + })) + } + + fn emit_neg( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::NegArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float { + LLVMBuildFNeg + } else { + LLVMBuildNeg + }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst) + }); + Ok(()) + } + + fn emit_not( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::NotArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildNot(self.builder, src, dst) + }); + Ok(()) + } + + fn emit_setp( + &mut self, + data: ptx_parser::SetpData, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + if arguments.dst2.is_some() { + todo!() + } + match data.cmp_op { + ptx_parser::SetpCompareOp::Integer(setp_compare_int) => { + self.emit_setp_int(setp_compare_int, arguments) + } + ptx_parser::SetpCompareOp::Float(setp_compare_float) => { + self.emit_setp_float(setp_compare_float, arguments) + } + } + } + + fn emit_setp_int( + &mut self, + setp: ptx_parser::SetpCompareInt, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + let op = match setp { + ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ, + ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE, + ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT, + ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE, + ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT, + ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE, + ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT, + ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE, + ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT, + ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE, + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst1, |dst1| unsafe { + LLVMBuildICmp(self.builder, op, src1, src2, dst1) + }); + Ok(()) + } + + fn emit_setp_float( + &mut self, + setp: ptx_parser::SetpCompareFloat, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + let op = match setp { + ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ, + ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE, + ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT, + ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE, + ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT, + ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE, + ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ, + ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE, + ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT, + ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE, + ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT, + ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE, + ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, + ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst1, |dst1| unsafe { + LLVMBuildFCmp(self.builder, op, src1, src2, dst1) + }); + Ok(()) + } + + fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { + let predicate = self.resolver.value(cond.predicate)?; + let if_true = self.resolver.value(cond.if_true)?; + let if_false = self.resolver.value(cond.if_false)?; + unsafe { + LLVMBuildCondBr( + self.builder, + predicate, + LLVMValueAsBasicBlock(if_true), + LLVMValueAsBasicBlock(if_false), + ) + }; + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 @@ -1328,8 +1566,7 @@ fn get_function_type<'a>( mut return_args: impl ExactSizeIterator, input_args: impl ExactSizeIterator>, ) -> Result { - let mut input_args: Vec<*mut llvm_zluda::LLVMType> = - input_args.collect::, _>>()?; + let mut input_args = input_args.collect::, _>>()?; let return_type = match return_args.len() { 0 => unsafe { LLVMVoidTypeInContext(context) }, 1 => get_type(context, return_args.next().unwrap())?,