diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 235ad7d..65ab918 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -436,9 +436,10 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Bfi { data, arguments } => todo!(), ast::Instruction::PrmtSlow { arguments } => todo!(), ast::Instruction::Prmt { data, arguments } => todo!(), - ast::Instruction::Activemask { arguments } => todo!(), ast::Instruction::Membar { data } => todo!(), ast::Instruction::Trap {} => todo!(), + // replaced by a function call + ast::Instruction::Activemask { arguments } => return Err(error_unreachable()), } } @@ -478,7 +479,20 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> { let builder = self.builder; match conversion.kind { - ConversionKind::Default => todo!(), + ConversionKind::Default => { + let from_layout = conversion.from_type.layout(); + let to_layout = conversion.to_type.layout(); + if from_layout.size() == to_layout.size() { + let src = self.resolver.value(conversion.src)?; + let type_ = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildBitCast(builder, src, type_, dst) + }); + Ok(()) + } else { + todo!() + } + }, ConversionKind::SignExtend => todo!(), ConversionKind::BitToPtr => { let src = self.resolver.value(conversion.src)?; diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 7ba9ed0..ff7e2ad 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -37,6 +37,7 @@ mod normalize_identifiers2; mod normalize_labels; mod normalize_predicates; mod normalize_predicates2; +mod replace_instructions_with_function_calls; mod resolve_function_pointers; static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -90,6 +91,7 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut fn_declarations = FxHashMap::default(); + let remapped_directives = directives + .into_iter() + .map(|directive| run_directive(resolver, &mut fn_declarations, directive)) + .collect::, _>>()?; + let mut result = fn_declarations + .into_iter() + .map(|(_, (return_arguments, name, input_arguments))| { + Directive2::Method(Function2 { + func_decl: ast::MethodDeclaration { + return_arguments, + name: ast::MethodName::Func(name), + input_arguments, + shared_mem: None, + }, + globals: Vec::new(), + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }) + }) + .collect::>(); + result.extend(remapped_directives); + Ok(result) +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(mut method) => { + method.body = method + .body + .map(|statements| run_statements(resolver, fn_declarations, statements)) + .transpose()?; + Directive2::Method(method) + } + }) +} + +fn run_statements<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + statements: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + statements + .into_iter() + .map(|statement| { + Ok(match statement { + Statement::Instruction(instruction) => { + Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?) + } + s => s, + }) + }) + .collect::, _>>() +} + +fn run_instruction<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + instruction: ptx_parser::Instruction, +) -> Result, TranslateError> { + Ok(match instruction { + i @ ptx_parser::Instruction::Activemask { .. } => { + to_call(resolver, fn_declarations, "activemask".into(), i)? + } + i => i, + }) +} + +fn to_call<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + name: Cow<'input, str>, + i: ast::Instruction, +) -> Result, TranslateError> { + let mut data_return = Vec::new(); + let mut data_input = Vec::new(); + let mut arguments_return = Vec::new(); + let mut arguments_input = Vec::new(); + ast::visit(&i, &mut |name: &SpirvWord, + type_space: Option<( + &ptx_parser::Type, + ptx_parser::StateSpace, + )>, + is_dst: bool, + _: bool| { + let (type_, space) = type_space.ok_or_else(error_mismatched_type)?; + if is_dst { + data_return.push((type_.clone(), space)); + arguments_return.push(*name); + } else { + data_input.push((type_.clone(), space)); + arguments_input.push(*name); + }; + Ok::<_, TranslateError>(()) + })?; + let fn_name = match fn_declarations.entry(name) { + hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, + hash_map::Entry::Vacant(vacant_entry) => { + let name = vacant_entry.key().clone(); + let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); + let name = resolver.register_named(Cow::Owned(full_name.clone()), None); + vacant_entry.insert(( + to_variables(resolver, &data_return), + name, + to_variables(resolver, &data_input), + )); + name + } + }; + Ok(ast::Instruction::Call { + data: ptx_parser::CallDetails { + uniform: false, + return_arguments: data_return, + input_arguments: data_input, + }, + arguments: ptx_parser::CallArgs { + return_arguments: arguments_return, + func: fn_name, + input_arguments: arguments_input, + }, + }) +} + +fn to_variables<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, +) -> Vec> { + arguments + .iter() + .map(|(type_, space)| ast::Variable { + align: None, + v_type: type_.clone(), + state_space: *space, + name: resolver.register_unnamed(Some((type_.clone(), *space))), + array_init: Vec::new(), + }) + .collect::>() +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 65c624e..f0d3fbe 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; -use std::{cmp::Ordering, num::NonZeroU8}; +use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; pub enum Statement { Label(P::Ident), @@ -806,6 +806,32 @@ impl Type { None => Self::maybe_vector_parsed(prefix, scalar), } } + + pub fn layout(&self) -> Layout { + match self { + Type::Scalar(type_) => type_.layout(), + Type::Vector(elements, scalar_type) => { + let scalar_layout = scalar_type.layout(); + unsafe { + Layout::from_size_align_unchecked( + scalar_layout.size() * *elements as usize, + scalar_layout.align() * *elements as usize, + ) + } + } + Type::Array(non_zero, scalar, vec) => { + let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout(); + let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0); + unsafe { + Layout::from_size_align_unchecked( + element_layout.size() * (len as usize), + element_layout.align(), + ) + } + } + Type::Pointer(..) => Layout::new::(), + } + } } impl ScalarType { @@ -831,6 +857,31 @@ impl ScalarType { } } + pub fn layout(self) -> Layout { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::(), + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | ScalarType::BF16 => Layout::new::(), + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => Layout::new::(), + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => { + Layout::new::() + } + ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) }, + // Close enough + ScalarType::Pred => Layout::new::(), + } + } + pub fn kind(self) -> ScalarKind { match self { ScalarType::U8 => ScalarKind::Unsigned,