diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index b70019e..5e95dae 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -24,11 +24,10 @@ lalrpop_mod!( ); pub mod ast; -mod pass; +pub(crate) mod pass; #[cfg(test)] mod test; mod translate; -mod translate2; use std::fmt; diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 7b794d6..934a472 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,12 +1,347 @@ use ptx_parser as ast; +use rspirv::{binary::Assemble, dr}; use std::{ borrow::Cow, cell::RefCell, collections::{hash_map, HashMap}, + ffi::CString, rc::Rc, }; -mod normalize; +pub(crate) mod normalize; + +static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); +static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; + +pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result { + let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); + let mut ptx_impl_imports = HashMap::new(); + let directives = ast + .directives + .into_iter() + .filter_map(|directive| { + translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose() + }) + .collect::, _>>()?; + /* + let directives = hoist_function_globals(directives); + let must_link_ptx_impl = ptx_impl_imports.len() > 0; + let mut directives = ptx_impl_imports + .into_iter() + .map(|(_, v)| v) + .chain(directives.into_iter()) + .collect::>(); + let mut builder = dr::Builder::new(); + builder.reserve_ids(id_defs.current_id()); + let call_map = MethodsCallMap::new(&directives); + let mut directives = + convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id()); + normalize_variable_decls(&mut directives); + let denorm_information = compute_denorm_information(&directives); + // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module + builder.set_version(1, 3); + emit_capabilities(&mut builder); + emit_extensions(&mut builder); + let opencl_id = emit_opencl_import(&mut builder); + emit_memory_model(&mut builder); + let mut map = TypeWordMap::new(&mut builder); + //emit_builtins(&mut builder, &mut map, &id_defs); + let mut kernel_info = HashMap::new(); + let (build_options, should_flush_denorms) = + emit_denorm_build_string(&call_map, &denorm_information); + let (directives, globals_use_map) = get_globals_use_map(directives); + emit_directives( + &mut builder, + &mut map, + &id_defs, + opencl_id, + should_flush_denorms, + &call_map, + globals_use_map, + directives, + &mut kernel_info, + )?; + let spirv = builder.module(); + Ok(Module { + spirv, + kernel_info, + should_link_ptx_impl: if must_link_ptx_impl { + Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD)) + } else { + None + }, + build_options, + }) + */ + todo!() +} + +fn translate_directive<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + d: ast::Directive<'input, ast::ParsedOperand<&'input str>>, +) -> Result>, TranslateError> { + Ok(match d { + ast::Directive::Variable(linking, var) => Some(Directive::Variable( + linking, + ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + }, + )), + ast::Directive::Method(linkage, f) => { + translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method) + } + }) +} + +type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement>>; + +fn translate_function<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + linkage: ast::LinkingDirective, + f: ParsedFunction<'input>, +) -> Result>, TranslateError> { + let import_as = match &f.func_directive { + ast::MethodDeclaration { + name: ast::MethodName::Func(func_name), + .. + } if *func_name == "__assertfail" || *func_name == "vprintf" => { + Some([ZLUDA_PTX_PREFIX, func_name].concat()) + } + _ => None, + }; + let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; + let mut func = to_ssa( + ptx_impl_imports, + str_resolver, + fn_resolver, + fn_decl, + f.body, + f.tuning, + linkage, + )?; + func.import_as = import_as; + if func.import_as.is_some() { + ptx_impl_imports.insert( + func.import_as.as_ref().unwrap().clone(), + Directive::Method(func), + ); + Ok(None) + } else { + Ok(Some(func)) + } +} + +fn to_ssa<'input, 'b>( + ptx_impl_imports: &'b mut HashMap>, + mut id_defs: FnStringIdResolver<'input, 'b>, + fn_defs: GlobalFnDeclResolver<'input, 'b>, + func_decl: Rc>>, + f_body: Option>>>, + tuning: Vec, + linkage: ast::LinkingDirective, +) -> Result, TranslateError> { + //deparamize_function_decl(&func_decl)?; + let f_body = match f_body { + Some(vec) => vec, + None => { + return Ok(Function { + func_decl: func_decl, + body: None, + globals: Vec::new(), + import_as: None, + tuning, + linkage, + }) + } + }; + let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?; + todo!() + /* + let mut numeric_id_defs = id_defs.finish(); + let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; + let typed_statements = + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; + let ssa_statements = insert_mem_ssa_statements( + typed_statements, + &mut numeric_id_defs, + &mut (*func_decl).borrow_mut(), + )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; + let expanded_statements = + insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; + let mut numeric_id_defs = numeric_id_defs.unmut(); + let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + let (f_body, globals) = + extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + Ok(Function { + func_decl: func_decl, + globals: globals, + body: Some(f_body), + import_as: None, + tuning, + linkage, + }) + */ +} + +pub struct Module { + pub spirv: dr::Module, + pub kernel_info: HashMap, + pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>, + pub build_options: CString, +} + +impl Module { + pub fn assemble(&self) -> Vec { + self.spirv.assemble() + } +} + +struct GlobalStringIdResolver<'input> { + current_id: SpirvWord, + variables: HashMap, SpirvWord>, + reverse_variables: HashMap, + variables_type_check: HashMap>, + special_registers: SpecialRegistersMap, + fns: HashMap>, +} + +impl<'input> GlobalStringIdResolver<'input> { + fn new(start_id: SpirvWord) -> Self { + Self { + current_id: start_id, + variables: HashMap::new(), + reverse_variables: HashMap::new(), + variables_type_check: HashMap::new(), + special_registers: SpecialRegistersMap::new(), + fns: HashMap::new(), + } + } + + fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord { + self.get_or_add_impl(id, None) + } + + fn get_or_add_def_typed( + &mut self, + id: &'input str, + typ: ast::Type, + state_space: ast::StateSpace, + is_variable: bool, + ) -> SpirvWord { + self.get_or_add_impl(id, Some((typ, state_space, is_variable))) + } + + fn get_or_add_impl( + &mut self, + id: &'input str, + typ: Option<(ast::Type, ast::StateSpace, bool)>, + ) -> SpirvWord { + let id = match self.variables.entry(Cow::Borrowed(id)) { + hash_map::Entry::Occupied(e) => *(e.get()), + hash_map::Entry::Vacant(e) => { + let numeric_id = self.current_id; + e.insert(numeric_id); + self.reverse_variables.insert(numeric_id, id); + self.current_id.0 += 1; + numeric_id + } + }; + self.variables_type_check.insert(id, typ); + id + } + + fn get_id(&self, id: &str) -> Result { + self.variables + .get(id) + .copied() + .ok_or_else(error_unknown_symbol) + } + + fn current_id(&self) -> SpirvWord { + self.current_id + } + + fn start_fn<'b>( + &'b mut self, + header: &'b ast::MethodDeclaration<'input, &'input str>, + ) -> Result< + ( + FnStringIdResolver<'input, 'b>, + GlobalFnDeclResolver<'input, 'b>, + Rc>>, + ), + TranslateError, + > { + // In case a function decl was inserted earlier we want to use its id + let name_id = self.get_or_add_def(header.name()); + let mut fn_resolver = FnStringIdResolver { + current_id: &mut self.current_id, + global_variables: &self.variables, + global_type_check: &self.variables_type_check, + special_registers: &mut self.special_registers, + variables: vec![HashMap::new(); 1], + type_check: HashMap::new(), + }; + let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); + let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); + let name = match header.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(_) => ast::MethodName::Func(name_id), + }; + let fn_decl = ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + shared_mem: None, + }; + let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) { + let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); + let new_fn_decl = resolver.func_decl.clone(); + self.fns.insert(name_id, resolver); + new_fn_decl + } else { + Rc::new(RefCell::new(fn_decl)) + }; + Ok(( + fn_resolver, + GlobalFnDeclResolver { fns: &self.fns }, + new_fn_decl, + )) + } +} + +fn rename_fn_params<'a, 'b>( + fn_resolver: &mut FnStringIdResolver<'a, 'b>, + args: &'b [ast::Variable<&'a str>], +) -> Vec> { + args.iter() + .map(|a| ast::Variable { + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), + v_type: a.v_type.clone(), + state_space: a.state_space, + align: a.align, + array_init: a.array_init.clone(), + }) + .collect() +} + +pub struct KernelInfo { + pub arguments_sizes: Vec<(usize, bool)>, + pub uses_shared_mem: bool, +} #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { @@ -108,10 +443,10 @@ impl SpecialRegistersMap { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut SpirvWord, global_variables: &'b HashMap, SpirvWord>, - global_type_check: &'b HashMap>, + global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, variables: Vec, SpirvWord>>, - type_check: HashMap>, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -160,7 +495,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .unwrap() .insert(Cow::Borrowed(id), numeric_id); self.type_check.insert( - numeric_id.0, + numeric_id, typ.map(|(typ, space)| (typ, space, is_variable)), ); self.current_id.0 += 1; @@ -183,7 +518,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { SpirvWord(numeric_id.0 + i), ); self.type_check.insert( - numeric_id.0 + i, + SpirvWord(numeric_id.0 + i), Some((typ.clone(), state_space, is_variable)), ); } @@ -196,8 +531,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut SpirvWord, - global_type_check: &'b HashMap>, - type_check: HashMap>, + global_type_check: &'b HashMap>, + type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, } @@ -210,12 +545,12 @@ impl<'b> NumericIdResolver<'b> { &self, id: SpirvWord, ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { - match self.type_check.get(&id.0) { + match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), - None => match self.global_type_check.get(&id.0) { + None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), }, @@ -228,7 +563,7 @@ impl<'b> NumericIdResolver<'b> { fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { let new_id = *self.current_id; self.type_check - .insert(new_id.0, Some((typ, state_space, true))); + .insert(new_id, Some((typ, state_space, true))); self.current_id.0 += 1; new_id } @@ -236,7 +571,7 @@ impl<'b> NumericIdResolver<'b> { fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { let new_id = *self.current_id; self.type_check - .insert(new_id.0, typ.map(|(t, space)| (t, space, false))); + .insert(new_id, typ.map(|(t, space)| (t, space, false))); self.current_id.0 += 1; new_id } @@ -490,6 +825,10 @@ impl From for spirv::Word { impl ast::Operand for SpirvWord { type Ident = Self; + + fn from_ident(ident: Self::Ident) -> Self { + ident + } } fn pred_map_variable Result>( @@ -503,29 +842,18 @@ fn pred_map_variable Result>( }) } -impl Result, Err> ast::VisitorMap for X { - fn visit( - &mut self, - args: T, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - ) -> U { - todo!() - } - - fn visit_ident( - &mut self, - args: ::Ident, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - ) -> ::Ident { - todo!() - } +pub(crate) enum Directive<'input> { + Variable(ast::LinkingDirective, ast::Variable), + Method(Function<'input>), } -fn op_map_variable<'a, F: FnMut(&str) -> Result>( - this: ast::Instruction>, - f: &mut F, -) -> Result>, TranslateError> { - ast::visit_map(this , f) +pub(crate) struct Function<'input> { + pub func_decl: Rc>>, + pub globals: Vec>, + pub body: Option>, + import_as: Option, + tuning: Vec, + linkage: ast::LinkingDirective, } + +type ExpandedStatement = Statement, SpirvWord>; \ No newline at end of file diff --git a/ptx/src/pass/normalize.rs b/ptx/src/pass/normalize.rs index 3832685..68ac26e 100644 --- a/ptx/src/pass/normalize.rs +++ b/ptx/src/pass/normalize.rs @@ -9,7 +9,7 @@ type NormalizedStatement = Statement< ast::ParsedOperand, >; -fn run<'input, 'b>( +pub(crate) fn run<'input, 'b>( id_defs: &mut FnStringIdResolver<'input, 'b>, fn_defs: &GlobalFnDeclResolver<'input, 'b>, func: Vec>>, @@ -47,7 +47,11 @@ fn expand_map_variables<'a, 'b>( ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id))) .transpose()?, - op_map_variable(i, &mut |id| id_defs.get_id(id))?, + ast::visit_map(i, &mut |id, + _: Option<(&ast::Type, ast::StateSpace)>, + _: bool| { + id_defs.get_id(id) + })?, ))), ast::Statement::Variable(var) => { let var_type = var.var.v_type.clone(); diff --git a/ptx/src/translate2.rs b/ptx/src/translate2.rs deleted file mode 100644 index 4ac5dea..0000000 --- a/ptx/src/translate2.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::collections::HashMap; -use half::f16; -use ptx_parser as ast; - -fn to_ssa<'input, 'b>( - ptx_impl_imports: &'b mut HashMap>, - mut id_defs: FnStringIdResolver<'input, 'b>, - fn_defs: GlobalFnDeclResolver<'input, 'b>, - func_decl: Rc>>, - f_body: Option>>>, - tuning: Vec, - linkage: ast::LinkingDirective, -) -> Result, TranslateError> { - //deparamize_function_decl(&func_decl)?; - let f_body = match f_body { - Some(vec) => vec, - None => { - return Ok(Function { - func_decl: func_decl, - body: None, - globals: Vec::new(), - import_as: None, - tuning, - linkage, - }) - } - }; - let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; - /* - let mut numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; - let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; - let (func_decl, typed_statements) = - convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; - let ssa_statements = insert_mem_ssa_statements( - typed_statements, - &mut numeric_id_defs, - &mut (*func_decl).borrow_mut(), - )?; - let mut numeric_id_defs = numeric_id_defs.finish(); - let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; - let expanded_statements = - insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; - let mut numeric_id_defs = numeric_id_defs.unmut(); - let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); - let (f_body, globals) = - extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; - Ok(Function { - func_decl: func_decl, - globals: globals, - body: Some(f_body), - import_as: None, - tuning, - linkage, - }) - */ -} \ No newline at end of file diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 87a2f6b..ee9f968 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -555,12 +555,46 @@ pub trait VisitorMap { ) -> Result; } -impl< - T: Operand, - U: Operand, - Err, - Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, - > VisitorMap for Fn +impl VisitorMap, ParsedOperand, Err> for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, +{ + fn visit( + &mut self, + args: ParsedOperand, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result, Err> { + Ok(match args { + ParsedOperand::Reg(ident) => ParsedOperand::Reg((self)(ident, type_space, is_dst)?), + ParsedOperand::RegOffset(ident, imm) => { + ParsedOperand::RegOffset((self)(ident, type_space, is_dst)?, imm) + } + ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), + ParsedOperand::VecMember(ident, index) => { + ParsedOperand::VecMember((self)(ident, type_space, is_dst)?, index) + } + ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( + vec.into_iter() + .map(|ident| (self)(ident, type_space, is_dst)) + .collect::, _>>()?, + ), + }) + } + + fn visit_ident( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result { + (self)(args, type_space, is_dst) + } +} + +impl, U: Operand, Err, Fn> VisitorMap for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, { fn visit( &mut self, @@ -573,12 +607,11 @@ impl< fn visit_ident( &mut self, - args: T::Ident, + args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, - ) -> Result { - let value: U = (self)(T::from_ident(args), type_space, is_dst)?; - Ok(value) + ) -> Result { + (self)(args, type_space, is_dst) } } @@ -925,6 +958,15 @@ pub struct MethodDeclaration<'input, ID> { pub shared_mem: Option, } +impl<'input> MethodDeclaration<'input, &'input str> { + pub fn name(&self) -> &'input str { + match self.name { + MethodName::Kernel(n) => n, + MethodName::Func(n) => n, + } + } +} + #[derive(Hash, PartialEq, Eq, Copy, Clone)] pub enum MethodName<'input, ID> { Kernel(&'input str),