This commit is contained in:
Andrzej Janik 2024-10-05 00:52:16 +02:00
parent bf2aef9be0
commit 6456f0d1a1

View file

@ -293,7 +293,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn emit_global(
&mut self,
linking: ast::LinkingDirective,
var: ptx_parser::Variable<SpirvWord>,
var: ast::Variable<SpirvWord>,
) -> Result<(), TranslateError> {
let name = self
.id_defs
@ -330,14 +330,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn get_input_argument_type(
context: LLVMContextRef,
v_type: &ptx_parser::Type,
state_space: ptx_parser::StateSpace,
v_type: &ast::Type,
state_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
match state_space {
ptx_parser::StateSpace::ParamEntry => {
ast::StateSpace::ParamEntry => {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
ptx_parser::StateSpace::Reg => get_type(context, v_type),
ast::StateSpace::Reg => get_type(context, v_type),
_ => return Err(error_unreachable()),
}
}
@ -481,19 +481,18 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Lg2 { data, arguments } => todo!(),
ast::Instruction::Ex2 { data, arguments } => todo!(),
ast::Instruction::Clz { data, arguments } => todo!(),
ast::Instruction::Brev { data, arguments } => todo!(),
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
ast::Instruction::Popc { data, arguments } => todo!(),
ast::Instruction::Xor { data, arguments } => todo!(),
ast::Instruction::Rem { data, arguments } => todo!(),
ast::Instruction::Bfi { data, arguments } => todo!(),
ast::Instruction::PrmtSlow { arguments } => todo!(),
ast::Instruction::Prmt { data, arguments } => todo!(),
ast::Instruction::Membar { data } => todo!(),
ast::Instruction::Trap {} => todo!(),
// replaced by a function call
ast::Instruction::Bfe { .. } | ast::Instruction::Activemask { .. } => {
return Err(error_unreachable())
}
ast::Instruction::Bfe { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
}
}
@ -597,8 +596,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_st(
&self,
data: ptx_parser::StData,
arguments: ptx_parser::StArgs<SpirvWord>,
data: ast::StData,
arguments: ast::StArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let ptr = self.resolver.value(arguments.src1)?;
let value = self.resolver.value(arguments.src2)?;
@ -609,14 +608,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Ok(())
}
fn emit_ret(&self, _data: ptx_parser::RetData) {
fn emit_ret(&self, _data: ast::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
fn emit_call(
&mut self,
data: ptx_parser::CallDetails,
arguments: ptx_parser::CallArgs<SpirvWord>,
data: ast::CallDetails,
arguments: ast::CallArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if cfg!(debug_assertions) {
for (_, space) in data.return_arguments.iter() {
@ -669,8 +668,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_mov(
&mut self,
_data: ptx_parser::MovDetails,
arguments: ptx_parser::MovArgs<SpirvWord>,
_data: ast::MovDetails,
arguments: ast::MovArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.resolver
.register(arguments.dst, self.resolver.value(arguments.src)?);
@ -699,35 +698,31 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_atom(
&mut self,
data: ptx_parser::AtomDetails,
arguments: ptx_parser::AtomArgs<SpirvWord>,
data: ast::AtomDetails,
arguments: ast::AtomArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let builder = self.builder;
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let op = match data.op {
ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
ptx_parser::AtomicOp::IncrementWrap => {
ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
ast::AtomicOp::IncrementWrap => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
}
ptx_parser::AtomicOp::DecrementWrap => {
ast::AtomicOp::DecrementWrap => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
}
ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
ptx_parser::AtomicOp::UnsignedMin => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin
}
ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
ptx_parser::AtomicOp::UnsignedMax => {
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax
}
ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin,
ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax,
ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
};
self.resolver.register(arguments.dst, unsafe {
LLVMZludaBuildAtomicRMW(
@ -744,8 +739,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_atom_cas(
&mut self,
data: ptx_parser::AtomCasDetails,
arguments: ptx_parser::AtomCasArgs<SpirvWord>,
data: ast::AtomCasDetails,
arguments: ast::AtomCasArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
@ -769,12 +764,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Ok(())
}
fn emit_bra(&self, arguments: ptx_parser::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
fn emit_bra(&self, arguments: ast::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
let src = self.resolver.value(arguments.src)?;
let src = unsafe { LLVMValueAsBasicBlock(src) };
unsafe { LLVMBuildBr(self.builder, src) };
Ok(())
}
fn emit_brev(
&mut self,
data: ast::ScalarType,
arguments: ast::BrevArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let llvm_fn = match data.size_of() {
4 => c"llvm.bitreverse.i32",
8 => c"llvm.bitreverse.i64",
_ => return Err(error_unreachable()),
};
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
let type_ = get_scalar_type(self.context, data);
let fn_type = get_function_type(
self.context,
iter::once(&data.into()),
iter::once(Ok(type_)),
)?;
if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
}
let mut src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
});
Ok(())
}
}
fn get_pointer_type<'ctx>(