mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add brev
This commit is contained in:
parent
bf2aef9be0
commit
6456f0d1a1
1 changed files with 62 additions and 40 deletions
|
@ -293,7 +293,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
fn emit_global(
|
||||
&mut self,
|
||||
linking: ast::LinkingDirective,
|
||||
var: ptx_parser::Variable<SpirvWord>,
|
||||
var: ast::Variable<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let name = self
|
||||
.id_defs
|
||||
|
@ -330,14 +330,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
|
||||
fn get_input_argument_type(
|
||||
context: LLVMContextRef,
|
||||
v_type: &ptx_parser::Type,
|
||||
state_space: ptx_parser::StateSpace,
|
||||
v_type: &ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
match state_space {
|
||||
ptx_parser::StateSpace::ParamEntry => {
|
||||
ast::StateSpace::ParamEntry => {
|
||||
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => get_type(context, v_type),
|
||||
ast::StateSpace::Reg => get_type(context, v_type),
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
@ -481,19 +481,18 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
ast::Instruction::Lg2 { data, arguments } => todo!(),
|
||||
ast::Instruction::Ex2 { data, arguments } => todo!(),
|
||||
ast::Instruction::Clz { data, arguments } => todo!(),
|
||||
ast::Instruction::Brev { data, arguments } => todo!(),
|
||||
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
|
||||
ast::Instruction::Popc { data, arguments } => todo!(),
|
||||
ast::Instruction::Xor { data, arguments } => todo!(),
|
||||
ast::Instruction::Rem { data, arguments } => todo!(),
|
||||
ast::Instruction::Bfi { data, arguments } => todo!(),
|
||||
ast::Instruction::PrmtSlow { arguments } => todo!(),
|
||||
ast::Instruction::Prmt { data, arguments } => todo!(),
|
||||
ast::Instruction::Membar { data } => todo!(),
|
||||
ast::Instruction::Trap {} => todo!(),
|
||||
// replaced by a function call
|
||||
ast::Instruction::Bfe { .. } | ast::Instruction::Activemask { .. } => {
|
||||
return Err(error_unreachable())
|
||||
}
|
||||
ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bfi { .. }
|
||||
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -597,8 +596,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
|
||||
fn emit_st(
|
||||
&self,
|
||||
data: ptx_parser::StData,
|
||||
arguments: ptx_parser::StArgs<SpirvWord>,
|
||||
data: ast::StData,
|
||||
arguments: ast::StArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let ptr = self.resolver.value(arguments.src1)?;
|
||||
let value = self.resolver.value(arguments.src2)?;
|
||||
|
@ -609,14 +608,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
||||
fn emit_ret(&self, _data: ast::RetData) {
|
||||
unsafe { LLVMBuildRetVoid(self.builder) };
|
||||
}
|
||||
|
||||
fn emit_call(
|
||||
&mut self,
|
||||
data: ptx_parser::CallDetails,
|
||||
arguments: ptx_parser::CallArgs<SpirvWord>,
|
||||
data: ast::CallDetails,
|
||||
arguments: ast::CallArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if cfg!(debug_assertions) {
|
||||
for (_, space) in data.return_arguments.iter() {
|
||||
|
@ -669,8 +668,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
|
||||
fn emit_mov(
|
||||
&mut self,
|
||||
_data: ptx_parser::MovDetails,
|
||||
arguments: ptx_parser::MovArgs<SpirvWord>,
|
||||
_data: ast::MovDetails,
|
||||
arguments: ast::MovArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
self.resolver
|
||||
.register(arguments.dst, self.resolver.value(arguments.src)?);
|
||||
|
@ -699,35 +698,31 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
|
||||
fn emit_atom(
|
||||
&mut self,
|
||||
data: ptx_parser::AtomDetails,
|
||||
arguments: ptx_parser::AtomArgs<SpirvWord>,
|
||||
data: ast::AtomDetails,
|
||||
arguments: ast::AtomArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let op = match data.op {
|
||||
ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
|
||||
ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
|
||||
ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
|
||||
ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
|
||||
ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
|
||||
ptx_parser::AtomicOp::IncrementWrap => {
|
||||
ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
|
||||
ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
|
||||
ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
|
||||
ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
|
||||
ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
|
||||
ast::AtomicOp::IncrementWrap => {
|
||||
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
|
||||
}
|
||||
ptx_parser::AtomicOp::DecrementWrap => {
|
||||
ast::AtomicOp::DecrementWrap => {
|
||||
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
|
||||
}
|
||||
ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
|
||||
ptx_parser::AtomicOp::UnsignedMin => {
|
||||
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin
|
||||
}
|
||||
ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
|
||||
ptx_parser::AtomicOp::UnsignedMax => {
|
||||
LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax
|
||||
}
|
||||
ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
|
||||
ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
|
||||
ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
|
||||
ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
|
||||
ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin,
|
||||
ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
|
||||
ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax,
|
||||
ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
|
||||
ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
|
||||
ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
|
||||
};
|
||||
self.resolver.register(arguments.dst, unsafe {
|
||||
LLVMZludaBuildAtomicRMW(
|
||||
|
@ -744,8 +739,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
|
||||
fn emit_atom_cas(
|
||||
&mut self,
|
||||
data: ptx_parser::AtomCasDetails,
|
||||
arguments: ptx_parser::AtomCasArgs<SpirvWord>,
|
||||
data: ast::AtomCasDetails,
|
||||
arguments: ast::AtomCasArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
|
@ -769,12 +764,39 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_bra(&self, arguments: ptx_parser::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
|
||||
fn emit_bra(&self, arguments: ast::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
|
||||
let src = self.resolver.value(arguments.src)?;
|
||||
let src = unsafe { LLVMValueAsBasicBlock(src) };
|
||||
unsafe { LLVMBuildBr(self.builder, src) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_brev(
|
||||
&mut self,
|
||||
data: ast::ScalarType,
|
||||
arguments: ast::BrevArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let llvm_fn = match data.size_of() {
|
||||
4 => c"llvm.bitreverse.i32",
|
||||
8 => c"llvm.bitreverse.i64",
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
|
||||
let type_ = get_scalar_type(self.context, data);
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
iter::once(&data.into()),
|
||||
iter::once(Ok(type_)),
|
||||
)?;
|
||||
if fn_ == ptr::null_mut() {
|
||||
fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
|
||||
}
|
||||
let mut src = self.resolver.value(arguments.src)?;
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
|
|
Loading…
Add table
Reference in a new issue