Implement activemask

This commit is contained in:
Andrzej Janik 2024-09-25 15:54:32 +02:00
parent 81baecf2c8
commit 3942afd8ff
4 changed files with 249 additions and 3 deletions

View file

@ -436,9 +436,10 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Bfi { data, arguments } => todo!(),
ast::Instruction::PrmtSlow { arguments } => todo!(),
ast::Instruction::Prmt { data, arguments } => todo!(),
ast::Instruction::Activemask { arguments } => todo!(),
ast::Instruction::Membar { data } => todo!(),
ast::Instruction::Trap {} => todo!(),
// replaced by a function call
ast::Instruction::Activemask { arguments } => return Err(error_unreachable()),
}
}
@ -478,7 +479,20 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
let builder = self.builder;
match conversion.kind {
ConversionKind::Default => todo!(),
ConversionKind::Default => {
let from_layout = conversion.from_type.layout();
let to_layout = conversion.to_type.layout();
if from_layout.size() == to_layout.size() {
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildBitCast(builder, src, type_, dst)
});
Ok(())
} else {
todo!()
}
},
ConversionKind::SignExtend => todo!(),
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?;

View file

@ -37,6 +37,7 @@ mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
mod normalize_predicates2;
mod replace_instructions_with_function_calls;
mod resolve_function_pointers;
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@ -90,6 +91,7 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
Ok(Module {

View file

@ -0,0 +1,179 @@
use super::*;
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut fn_declarations = FxHashMap::default();
let remapped_directives = directives
.into_iter()
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
.collect::<Result<Vec<_>, _>>()?;
let mut result = fn_declarations
.into_iter()
.map(|(_, (return_arguments, name, input_arguments))| {
Directive2::Method(Function2 {
func_decl: ast::MethodDeclaration {
return_arguments,
name: ast::MethodName::Func(name),
input_arguments,
shared_mem: None,
},
globals: Vec::new(),
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
})
})
.collect::<Vec<_>>();
result.extend(remapped_directives);
Ok(result)
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {
method.body = method
.body
.map(|statements| run_statements(resolver, fn_declarations, statements))
.transpose()?;
Directive2::Method(method)
}
})
}
fn run_statements<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
statements
.into_iter()
.map(|statement| {
Ok(match statement {
Statement::Instruction(instruction) => {
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
}
s => s,
})
})
.collect::<Result<Vec<_>, _>>()
}
fn run_instruction<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
instruction: ptx_parser::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
Ok(match instruction {
i @ ptx_parser::Instruction::Activemask { .. } => {
to_call(resolver, fn_declarations, "activemask".into(), i)?
}
i => i,
})
}
fn to_call<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_declarations: &mut FxHashMap<
Cow<'input, str>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
>,
name: Cow<'input, str>,
i: ast::Instruction<SpirvWord>,
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
let mut data_return = Vec::new();
let mut data_input = Vec::new();
let mut arguments_return = Vec::new();
let mut arguments_input = Vec::new();
ast::visit(&i, &mut |name: &SpirvWord,
type_space: Option<(
&ptx_parser::Type,
ptx_parser::StateSpace,
)>,
is_dst: bool,
_: bool| {
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
if is_dst {
data_return.push((type_.clone(), space));
arguments_return.push(*name);
} else {
data_input.push((type_.clone(), space));
arguments_input.push(*name);
};
Ok::<_, TranslateError>(())
})?;
let fn_name = match fn_declarations.entry(name) {
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
hash_map::Entry::Vacant(vacant_entry) => {
let name = vacant_entry.key().clone();
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
vacant_entry.insert((
to_variables(resolver, &data_return),
name,
to_variables(resolver, &data_input),
));
name
}
};
Ok(ast::Instruction::Call {
data: ptx_parser::CallDetails {
uniform: false,
return_arguments: data_return,
input_arguments: data_input,
},
arguments: ptx_parser::CallArgs {
return_arguments: arguments_return,
func: fn_name,
input_arguments: arguments_input,
},
})
}
fn to_variables<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
) -> Vec<ptx_parser::Variable<SpirvWord>> {
arguments
.iter()
.map(|(type_, space)| ast::Variable {
align: None,
v_type: type_.clone(),
state_space: *space,
name: resolver.register_unnamed(Some((type_.clone(), *space))),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}

View file

@ -4,7 +4,7 @@ use super::{
};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
use std::{cmp::Ordering, num::NonZeroU8};
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
pub enum Statement<P: Operand> {
Label(P::Ident),
@ -806,6 +806,32 @@ impl Type {
None => Self::maybe_vector_parsed(prefix, scalar),
}
}
pub fn layout(&self) -> Layout {
match self {
Type::Scalar(type_) => type_.layout(),
Type::Vector(elements, scalar_type) => {
let scalar_layout = scalar_type.layout();
unsafe {
Layout::from_size_align_unchecked(
scalar_layout.size() * *elements as usize,
scalar_layout.align() * *elements as usize,
)
}
}
Type::Array(non_zero, scalar, vec) => {
let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout();
let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0);
unsafe {
Layout::from_size_align_unchecked(
element_layout.size() * (len as usize),
element_layout.align(),
)
}
}
Type::Pointer(..) => Layout::new::<usize>(),
}
}
}
impl ScalarType {
@ -831,6 +857,31 @@ impl ScalarType {
}
}
pub fn layout(self) -> Layout {
match self {
ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::<u8>(),
ScalarType::U16
| ScalarType::S16
| ScalarType::B16
| ScalarType::F16
| ScalarType::BF16 => Layout::new::<u16>(),
ScalarType::U32
| ScalarType::S32
| ScalarType::B32
| ScalarType::F32
| ScalarType::U16x2
| ScalarType::S16x2
| ScalarType::F16x2
| ScalarType::BF16x2 => Layout::new::<u32>(),
ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => {
Layout::new::<u64>()
}
ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) },
// Close enough
ScalarType::Pred => Layout::new::<u8>(),
}
}
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,