diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 4d4142c..dd0c901 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -293,7 +293,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn emit_global( &mut self, linking: ast::LinkingDirective, - var: ptx_parser::Variable, + var: ast::Variable, ) -> Result<(), TranslateError> { let name = self .id_defs @@ -330,14 +330,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn get_input_argument_type( context: LLVMContextRef, - v_type: &ptx_parser::Type, - state_space: ptx_parser::StateSpace, + v_type: &ast::Type, + state_space: ast::StateSpace, ) -> Result { match state_space { - ptx_parser::StateSpace::ParamEntry => { + ast::StateSpace::ParamEntry => { Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) }) } - ptx_parser::StateSpace::Reg => get_type(context, v_type), + ast::StateSpace::Reg => get_type(context, v_type), _ => return Err(error_unreachable()), } } @@ -481,19 +481,18 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Lg2 { data, arguments } => todo!(), ast::Instruction::Ex2 { data, arguments } => todo!(), ast::Instruction::Clz { data, arguments } => todo!(), - ast::Instruction::Brev { data, arguments } => todo!(), + ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), ast::Instruction::Popc { data, arguments } => todo!(), ast::Instruction::Xor { data, arguments } => todo!(), ast::Instruction::Rem { data, arguments } => todo!(), - ast::Instruction::Bfi { data, arguments } => todo!(), ast::Instruction::PrmtSlow { arguments } => todo!(), ast::Instruction::Prmt { data, arguments } => todo!(), ast::Instruction::Membar { data } => todo!(), ast::Instruction::Trap {} => todo!(), // replaced by a function call - ast::Instruction::Bfe { .. } | ast::Instruction::Activemask { .. } => { - return Err(error_unreachable()) - } + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. } => return Err(error_unreachable()), } } @@ -597,8 +596,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_st( &self, - data: ptx_parser::StData, - arguments: ptx_parser::StArgs, + data: ast::StData, + arguments: ast::StArgs, ) -> Result<(), TranslateError> { let ptr = self.resolver.value(arguments.src1)?; let value = self.resolver.value(arguments.src2)?; @@ -609,14 +608,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(()) } - fn emit_ret(&self, _data: ptx_parser::RetData) { + fn emit_ret(&self, _data: ast::RetData) { unsafe { LLVMBuildRetVoid(self.builder) }; } fn emit_call( &mut self, - data: ptx_parser::CallDetails, - arguments: ptx_parser::CallArgs, + data: ast::CallDetails, + arguments: ast::CallArgs, ) -> Result<(), TranslateError> { if cfg!(debug_assertions) { for (_, space) in data.return_arguments.iter() { @@ -669,8 +668,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_mov( &mut self, - _data: ptx_parser::MovDetails, - arguments: ptx_parser::MovArgs, + _data: ast::MovDetails, + arguments: ast::MovArgs, ) -> Result<(), TranslateError> { self.resolver .register(arguments.dst, self.resolver.value(arguments.src)?); @@ -699,35 +698,31 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_atom( &mut self, - data: ptx_parser::AtomDetails, - arguments: ptx_parser::AtomArgs, + data: ast::AtomDetails, + arguments: ast::AtomArgs, ) -> Result<(), TranslateError> { let builder = self.builder; let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; let op = match data.op { - ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd, - ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr, - ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor, - ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg, - ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd, - ptx_parser::AtomicOp::IncrementWrap => { + ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd, + ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr, + ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor, + ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg, + ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd, + ast::AtomicOp::IncrementWrap => { LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap } - ptx_parser::AtomicOp::DecrementWrap => { + ast::AtomicOp::DecrementWrap => { LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap } - ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin, - ptx_parser::AtomicOp::UnsignedMin => { - LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin - } - ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax, - ptx_parser::AtomicOp::UnsignedMax => { - LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax - } - ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd, - ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin, - ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax, + ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin, + ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin, + ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax, + ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax, + ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd, + ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin, + ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax, }; self.resolver.register(arguments.dst, unsafe { LLVMZludaBuildAtomicRMW( @@ -744,8 +739,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_atom_cas( &mut self, - data: ptx_parser::AtomCasDetails, - arguments: ptx_parser::AtomCasArgs, + data: ast::AtomCasDetails, + arguments: ast::AtomCasArgs, ) -> Result<(), TranslateError> { let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; @@ -769,12 +764,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(()) } - fn emit_bra(&self, arguments: ptx_parser::BraArgs) -> Result<(), TranslateError> { + fn emit_bra(&self, arguments: ast::BraArgs) -> Result<(), TranslateError> { let src = self.resolver.value(arguments.src)?; let src = unsafe { LLVMValueAsBasicBlock(src) }; unsafe { LLVMBuildBr(self.builder, src) }; Ok(()) } + + fn emit_brev( + &mut self, + data: ast::ScalarType, + arguments: ast::BrevArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.size_of() { + 4 => c"llvm.bitreverse.i32", + 8 => c"llvm.bitreverse.i64", + _ => return Err(error_unreachable()), + }; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + let type_ = get_scalar_type(self.context, data); + let fn_type = get_function_type( + self.context, + iter::once(&data.into()), + iter::once(Ok(type_)), + )?; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; + } + let mut src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst) + }); + Ok(()) + } } fn get_pointer_type<'ctx>(