mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Implement activemask
This commit is contained in:
parent
81baecf2c8
commit
3942afd8ff
4 changed files with 249 additions and 3 deletions
|
@ -436,9 +436,10 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
ast::Instruction::Bfi { data, arguments } => todo!(),
|
||||
ast::Instruction::PrmtSlow { arguments } => todo!(),
|
||||
ast::Instruction::Prmt { data, arguments } => todo!(),
|
||||
ast::Instruction::Activemask { arguments } => todo!(),
|
||||
ast::Instruction::Membar { data } => todo!(),
|
||||
ast::Instruction::Trap {} => todo!(),
|
||||
// replaced by a function call
|
||||
ast::Instruction::Activemask { arguments } => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -478,7 +479,20 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
match conversion.kind {
|
||||
ConversionKind::Default => todo!(),
|
||||
ConversionKind::Default => {
|
||||
let from_layout = conversion.from_type.layout();
|
||||
let to_layout = conversion.to_type.layout();
|
||||
if from_layout.size() == to_layout.size() {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let type_ = get_type(self.context, &conversion.to_type)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildBitCast(builder, src, type_, dst)
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
},
|
||||
ConversionKind::SignExtend => todo!(),
|
||||
ConversionKind::BitToPtr => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
|
|
|
@ -37,6 +37,7 @@ mod normalize_identifiers2;
|
|||
mod normalize_labels;
|
||||
mod normalize_predicates;
|
||||
mod normalize_predicates2;
|
||||
mod replace_instructions_with_function_calls;
|
||||
mod resolve_function_pointers;
|
||||
|
||||
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
|
@ -90,6 +91,7 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
|||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
||||
Ok(Module {
|
||||
|
|
179
ptx/src/pass/replace_instructions_with_function_calls.rs
Normal file
179
ptx/src/pass/replace_instructions_with_function_calls.rs
Normal file
|
@ -0,0 +1,179 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut fn_declarations = FxHashMap::default();
|
||||
let remapped_directives = directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, &mut fn_declarations, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut result = fn_declarations
|
||||
.into_iter()
|
||||
.map(|(_, (return_arguments, name, input_arguments))| {
|
||||
Directive2::Method(Function2 {
|
||||
func_decl: ast::MethodDeclaration {
|
||||
return_arguments,
|
||||
name: ast::MethodName::Func(name),
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
},
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
result.extend(remapped_directives);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
method.body = method
|
||||
.body
|
||||
.map(|statements| run_statements(resolver, fn_declarations, statements))
|
||||
.transpose()?;
|
||||
Directive2::Method(method)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statements<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|statement| {
|
||||
Ok(match statement {
|
||||
Statement::Instruction(instruction) => {
|
||||
Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
|
||||
}
|
||||
s => s,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_instruction<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
instruction: ptx_parser::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
Ok(match instruction {
|
||||
i @ ptx_parser::Instruction::Activemask { .. } => {
|
||||
to_call(resolver, fn_declarations, "activemask".into(), i)?
|
||||
}
|
||||
i => i,
|
||||
})
|
||||
}
|
||||
|
||||
fn to_call<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_declarations: &mut FxHashMap<
|
||||
Cow<'input, str>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
name: Cow<'input, str>,
|
||||
i: ast::Instruction<SpirvWord>,
|
||||
) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
|
||||
let mut data_return = Vec::new();
|
||||
let mut data_input = Vec::new();
|
||||
let mut arguments_return = Vec::new();
|
||||
let mut arguments_input = Vec::new();
|
||||
ast::visit(&i, &mut |name: &SpirvWord,
|
||||
type_space: Option<(
|
||||
&ptx_parser::Type,
|
||||
ptx_parser::StateSpace,
|
||||
)>,
|
||||
is_dst: bool,
|
||||
_: bool| {
|
||||
let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
|
||||
if is_dst {
|
||||
data_return.push((type_.clone(), space));
|
||||
arguments_return.push(*name);
|
||||
} else {
|
||||
data_input.push((type_.clone(), space));
|
||||
arguments_input.push(*name);
|
||||
};
|
||||
Ok::<_, TranslateError>(())
|
||||
})?;
|
||||
let fn_name = match fn_declarations.entry(name) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let name = vacant_entry.key().clone();
|
||||
let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
|
||||
let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
|
||||
vacant_entry.insert((
|
||||
to_variables(resolver, &data_return),
|
||||
name,
|
||||
to_variables(resolver, &data_input),
|
||||
));
|
||||
name
|
||||
}
|
||||
};
|
||||
Ok(ast::Instruction::Call {
|
||||
data: ptx_parser::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: data_return,
|
||||
input_arguments: data_input,
|
||||
},
|
||||
arguments: ptx_parser::CallArgs {
|
||||
return_arguments: arguments_return,
|
||||
func: fn_name,
|
||||
input_arguments: arguments_input,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn to_variables<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
) -> Vec<ptx_parser::Variable<SpirvWord>> {
|
||||
arguments
|
||||
.iter()
|
||||
.map(|(type_, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: type_.clone(),
|
||||
state_space: *space,
|
||||
name: resolver.register_unnamed(Some((type_.clone(), *space))),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
|
@ -4,7 +4,7 @@ use super::{
|
|||
};
|
||||
use crate::{PtxError, PtxParserState};
|
||||
use bitflags::bitflags;
|
||||
use std::{cmp::Ordering, num::NonZeroU8};
|
||||
use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
|
||||
|
||||
pub enum Statement<P: Operand> {
|
||||
Label(P::Ident),
|
||||
|
@ -806,6 +806,32 @@ impl Type {
|
|||
None => Self::maybe_vector_parsed(prefix, scalar),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layout(&self) -> Layout {
|
||||
match self {
|
||||
Type::Scalar(type_) => type_.layout(),
|
||||
Type::Vector(elements, scalar_type) => {
|
||||
let scalar_layout = scalar_type.layout();
|
||||
unsafe {
|
||||
Layout::from_size_align_unchecked(
|
||||
scalar_layout.size() * *elements as usize,
|
||||
scalar_layout.align() * *elements as usize,
|
||||
)
|
||||
}
|
||||
}
|
||||
Type::Array(non_zero, scalar, vec) => {
|
||||
let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout();
|
||||
let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0);
|
||||
unsafe {
|
||||
Layout::from_size_align_unchecked(
|
||||
element_layout.size() * (len as usize),
|
||||
element_layout.align(),
|
||||
)
|
||||
}
|
||||
}
|
||||
Type::Pointer(..) => Layout::new::<usize>(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarType {
|
||||
|
@ -831,6 +857,31 @@ impl ScalarType {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn layout(self) -> Layout {
|
||||
match self {
|
||||
ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::<u8>(),
|
||||
ScalarType::U16
|
||||
| ScalarType::S16
|
||||
| ScalarType::B16
|
||||
| ScalarType::F16
|
||||
| ScalarType::BF16 => Layout::new::<u16>(),
|
||||
ScalarType::U32
|
||||
| ScalarType::S32
|
||||
| ScalarType::B32
|
||||
| ScalarType::F32
|
||||
| ScalarType::U16x2
|
||||
| ScalarType::S16x2
|
||||
| ScalarType::F16x2
|
||||
| ScalarType::BF16x2 => Layout::new::<u32>(),
|
||||
ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => {
|
||||
Layout::new::<u64>()
|
||||
}
|
||||
ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) },
|
||||
// Close enough
|
||||
ScalarType::Pred => Layout::new::<u8>(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kind(self) -> ScalarKind {
|
||||
match self {
|
||||
ScalarType::U8 => ScalarKind::Unsigned,
|
||||
|
|
Loading…
Add table
Reference in a new issue