From 425edfcdd49a4fa49d480f1b078c55dba4709e29 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 7 May 2021 18:22:09 +0200 Subject: [PATCH] Simplify typing --- ptx/src/ast.rs | 21 +- ptx/src/ptx.lalrpop | 21 +- ptx/src/translate.rs | 524 +++++++++++++++++------------------------- zluda_dump/src/lib.rs | 14 +- 4 files changed, 247 insertions(+), 333 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 364ec01..e45a6fb 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] @@ -86,19 +86,20 @@ pub enum Directive<'a, P: ArgParams> { Method(Function<'a, &'a str, Statement

>), } -pub enum MethodDecl<'a, ID> { - Func(Vec>, ID, Vec>), - Kernel { - name: &'a str, - in_args: Vec>, - }, +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), } -pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, +} pub struct Function<'a, ID, S> { - pub func_directive: MethodDecl<'a, ID>, + pub func_directive: MethodDeclaration<'a, ID>, pub tuning: Vec, pub body: Option>, } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 8fee7c2..78ebf1d 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -360,7 +360,7 @@ AddressSize = { Function: ast::Function<'input, &'input str, ast::Statement>> = { LinkingDirectives - + => ast::Function{<>} }; @@ -388,19 +388,24 @@ LinkingDirectives: ast::LinkingDirective = { } } -MethodDecl: ast::MethodDecl<'input, &'input str> = { - ".entry" => - ast::MethodDecl::Kernel{ name, in_args }, - ".func" => { - ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) +MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { + ".entry" => { + let return_arguments = Vec::new(); + let name = ast::MethodName::Kernel(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments } + }, + ".func" => { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments } } }; -KernelArguments: Vec> = { +KernelArguments: Vec> = { "(" > ")" => args }; -FnArguments: Vec> = { +FnArguments: Vec> = { "(" > ")" => args }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 1a2eda3..88ef51b 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,9 @@ use crate::ast; +use core::borrow; use half::f16; use rspirv::dr; -use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem}; +use std::{borrow::Borrow, cell::RefCell}; +use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryInto, @@ -458,7 +460,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result>(); let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); - let call_map = get_call_map(&directives); + let call_map = get_kernels_call_map(&directives); let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id()); normalize_variable_decls(&mut directives); let denorm_information = compute_denorm_information(&directives); @@ -496,9 +498,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( call_map: &HashMap<&str, HashSet>, - denorm_information: &HashMap>, + denorm_information: &HashMap< + ast::MethodName<'input, spirv::Word>, + HashMap, + >, ) -> CString { let denorm_counts = denorm_information .iter() @@ -516,10 +521,12 @@ fn emit_denorm_build_string( .collect::>(); let mut flush_over_preserve = 0; for (kernel, children) in call_map { - flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0); + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Kernel(kernel)) + .unwrap_or(&0); for child_fn in children { flush_over_preserve += *denorm_counts - .get(&MethodName::Func(*child_fn)) + .get(&ast::MethodName::Func(*child_fn)) .unwrap_or(&0); } } @@ -535,9 +542,12 @@ fn emit_directives<'input>( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, opencl_id: spirv::Word, - denorm_information: &HashMap, HashMap>, + denorm_information: &HashMap< + ast::MethodName<'input, spirv::Word>, + HashMap, + >, call_map: &HashMap<&'input str, HashSet>, - directives: Vec, + directives: Vec>, kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { let empty_body = Vec::new(); @@ -560,16 +570,18 @@ fn emit_directives<'input>( for var in f.globals.iter() { emit_variable(builder, map, var)?; } + let func_decl = (*f.func_decl).borrow(); let fn_id = emit_function_header( builder, map, &id_defs, &f.globals, - &f.spirv_decl, + &*func_decl, &denorm_information, call_map, &directives, kernel_info, + f.uses_shared_mem, )?; for t in f.tuning.iter() { match *t { @@ -594,8 +606,13 @@ fn emit_directives<'input>( } emit_function_body_ops(builder, map, opencl_id, &f_body)?; builder.end_function()?; - if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) = - (&f.func_decl, &f.import_as) + if let ( + ast::MethodDeclaration { + name: ast::MethodName::Func(fn_id), + .. + }, + Some(name), + ) = (&*func_decl, &f.import_as) { builder.decorate( *fn_id, @@ -614,7 +631,7 @@ fn emit_directives<'input>( Ok(()) } -fn get_call_map<'input>( +fn get_kernels_call_map<'input>( module: &[Directive<'input>], ) -> HashMap<&'input str, HashSet> { let mut directly_called_by = HashMap::new(); @@ -625,7 +642,7 @@ fn get_call_map<'input>( body: Some(statements), .. }) => { - let call_key = MethodName::new(&func_decl); + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { entry.insert(Vec::new()); } @@ -644,28 +661,28 @@ fn get_call_map<'input>( let mut result = HashMap::new(); for (method_key, children) in directly_called_by.iter() { match method_key { - MethodName::Kernel(name) => { + ast::MethodName::Kernel(name) => { let mut visited = HashSet::new(); for child in children { add_call_map_single(&directly_called_by, &mut visited, *child); } result.insert(*name, visited); } - MethodName::Func(_) => {} + ast::MethodName::Func(_) => {} } } result } fn add_call_map_single<'input>( - directly_called_by: &MultiHashMap, spirv::Word>, + directly_called_by: &MultiHashMap, spirv::Word>, visited: &mut HashSet, current: spirv::Word, ) { if !visited.insert(current) { return; } - if let Some(children) = directly_called_by.get(&MethodName::Func(current)) { + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { for child in children { add_call_map_single(directly_called_by, visited, *child); } @@ -739,10 +756,10 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem, }) => { - let call_key = MethodName::new(&func_decl); + let call_key = (*func_decl).borrow().name; let statements = statements .into_iter() .map(|statement| match statement { @@ -763,8 +780,8 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem, }) } directive => directive, @@ -782,30 +799,32 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - mut spirv_decl, tuning, + uses_shared_mem, }) => { - if !methods_using_extern_shared.contains(&spirv_decl.name) { + if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) { return Directive::Method(Function { func_decl, globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem, }); } let shared_id_param = new_id(); - spirv_decl.input.push({ - ast::Variable { - name: shared_id_param, - align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8), - state_space: ast::StateSpace::Shared, - array_init: Vec::new(), - } - }); - spirv_decl.uses_shared_mem = true; + { + let mut func_decl = (*func_decl).borrow_mut(); + func_decl.input_arguments.push({ + ast::Variable { + name: shared_id_param, + align: None, + v_type: ast::Type::Pointer(ast::ScalarType::B8), + state_space: ast::StateSpace::Shared, + array_init: Vec::new(), + } + }); + } let statements = replace_uses_of_shared_memory( new_id, &extern_shared_decls, @@ -818,8 +837,8 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem: true, }) } directive => directive, @@ -830,7 +849,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>( new_id: &mut impl FnMut() -> spirv::Word, extern_shared_decls: &HashMap, - methods_using_extern_shared: &mut HashSet>, + methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, statements: Vec, ) -> Vec { @@ -841,7 +860,7 @@ fn replace_uses_of_shared_memory<'a>( // We can safely skip checking call arguments, // because there's simply no way to pass shared ptr // without converting it to .b64 first - if methods_using_extern_shared.contains(&MethodName::Func(call.func)) { + if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) { call.param_list.push(( shared_id_param, ast::Type::Scalar(ast::ScalarType::B8), @@ -881,13 +900,13 @@ fn replace_uses_of_shared_memory<'a>( } fn get_callers_of_extern_shared<'a>( - methods_using_extern_shared: &mut HashSet>, - directly_called_by: &MultiHashMap>, + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, ) { let direct_uses_of_extern_shared = methods_using_extern_shared .iter() .filter_map(|method| { - if let MethodName::Func(f_id) = method { + if let ast::MethodName::Func(f_id) = method { Some(*f_id) } else { None @@ -900,14 +919,14 @@ fn get_callers_of_extern_shared<'a>( } fn get_callers_of_extern_shared_single<'a>( - methods_using_extern_shared: &mut HashSet>, - directly_called_by: &MultiHashMap>, + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, fn_id: spirv::Word, ) { if let Some(callers) = directly_called_by.get(&fn_id) { for caller in callers { if methods_using_extern_shared.insert(*caller) { - if let MethodName::Func(caller_fn) = caller { + if let ast::MethodName::Func(caller_fn) = caller { get_callers_of_extern_shared_single( methods_using_extern_shared, directly_called_by, @@ -949,7 +968,7 @@ fn denorm_count_map_update_impl( // and emit suitable execution mode fn compute_denorm_information<'input>( module: &[Directive<'input>], -) -> HashMap, HashMap> { +) -> HashMap, HashMap> { let mut denorm_methods = HashMap::new(); for directive in module { match directive { @@ -960,7 +979,7 @@ fn compute_denorm_information<'input>( .. }) => { let mut flush_counter = DenormCountMap::new(); - let method_key = MethodName::new(func_decl); + let method_key = (**func_decl).borrow().name; for statement in statements { match statement { Statement::Instruction(inst) => { @@ -1004,21 +1023,6 @@ fn compute_denorm_information<'input>( .collect() } -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -enum MethodName<'input> { - Kernel(&'input str), - Func(spirv::Word), -} - -impl<'input> MethodName<'input> { - fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { - match decl { - ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name), - ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id), - } - } -} - fn emit_builtins( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1047,17 +1051,21 @@ fn emit_function_header<'a>( map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'a>, synthetic_globals: &[ast::Variable], - func_decl: &SpirvMethodDecl<'a>, - _denorm_information: &HashMap, HashMap>, + func_decl: &ast::MethodDeclaration<'a, spirv::Word>, + _denorm_information: &HashMap< + ast::MethodName<'a, spirv::Word>, + HashMap, + >, call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, + uses_shared_mem: bool, ) -> Result { - if let MethodName::Kernel(name) = func_decl.name { - let input_args = if !func_decl.uses_shared_mem { - func_decl.input.as_slice() + if let ast::MethodName::Kernel(name) = func_decl.name { + let input_args = if !uses_shared_mem { + func_decl.input_arguments.as_slice() } else { - &func_decl.input[0..func_decl.input.len() - 1] + &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1] }; let args_lens = input_args .iter() @@ -1067,14 +1075,18 @@ fn emit_function_header<'a>( name.to_string(), KernelInfo { arguments_sizes: args_lens, - uses_shared_mem: func_decl.uses_shared_mem, + uses_shared_mem: uses_shared_mem, }, ); } - let (ret_type, func_type) = - get_function_type(builder, map, &func_decl.input, &func_decl.output); + let (ret_type, func_type) = get_function_type( + builder, + map, + &func_decl.input_arguments, + &func_decl.return_arguments, + ); let fn_id = match func_decl.name { - MethodName::Kernel(name) => { + ast::MethodName::Kernel(name) => { let fn_id = defined_globals.get_id(name)?; let mut global_variables = defined_globals .variables_type_check @@ -1090,15 +1102,16 @@ fn emit_function_header<'a>( for directive in direcitves { match directive { Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - globals, - .. + func_decl, globals, .. }) => { - if child_fns.contains(name) { - for var in globals { - interface.push(var.name); + match (**func_decl).borrow().name { + ast::MethodName::Func(name) => { + for var in globals { + interface.push(var.name); + } } - } + ast::MethodName::Kernel(_) => {} + }; } _ => {} } @@ -1107,7 +1120,7 @@ fn emit_function_header<'a>( builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); fn_id } - MethodName::Func(name) => name, + ast::MethodName::Func(name) => name, }; builder.begin_function( ret_type, @@ -1130,7 +1143,7 @@ fn emit_function_header<'a>( } } */ - for input in &func_decl.input { + for input in &func_decl.input_arguments { let result_type = map.get_or_add( builder, SpirvType::new(input.v_type.clone(), input.state_space), @@ -1225,9 +1238,10 @@ fn translate_function<'a>( f: ast::ParsedFunction<'a>, ) -> Result>, TranslateError> { let import_as = match &f.func_directive { - ast::MethodDecl::Func(_, "__assertfail", _) => { - Some("__zluda_ptx_impl____assertfail".to_owned()) - } + ast::MethodDeclaration { + name: ast::MethodName::Func("__assertfail"), + .. + } => Some("__zluda_ptx_impl____assertfail".to_owned()), _ => None, }; let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; @@ -1253,10 +1267,10 @@ fn translate_function<'a>( fn expand_kernel_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { + args: impl Iterator>, +) -> Result>, TranslateError> { args.map(|a| { - Ok(ast::KernelArgument { + Ok(ast::Variable { name: fn_resolver.add_def( a.name, Some(( @@ -1274,42 +1288,39 @@ fn expand_kernel_params<'a, 'b>( .collect::>() } -fn expand_fn_params<'a, 'b>( +fn rename_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { - args.map(|a| { - let is_variable = a.state_space == ast::StateSpace::Reg; - Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), is_variable), + args: &'b [ast::Variable<&'a str>], +) -> Vec> { + args.iter() + .map(|a| ast::Variable { + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false), v_type: a.v_type.clone(), state_space: a.state_space, align: a.align, - array_init: Vec::new(), + array_init: a.array_init.clone(), }) - }) - .collect() + .collect() } fn to_ssa<'input, 'b>( ptx_impl_imports: &mut HashMap, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, - f_args: ast::MethodDecl<'input, spirv::Word>, + func_decl: Rc>>, f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { - let mut spirv_decl = SpirvMethodDecl::new(&f_args); let f_body = match f_body { Some(vec) => vec, None => { return Ok(Function { - func_decl: f_args, + func_decl: func_decl, body: None, globals: Vec::new(), import_as: None, - spirv_decl, tuning, + uses_shared_mem: false, }) } }; @@ -1323,8 +1334,7 @@ fn to_ssa<'input, 'b>( let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, - &f_args, - &mut spirv_decl, + &mut (*func_decl).borrow_mut(), )?; let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); @@ -1336,12 +1346,12 @@ fn to_ssa<'input, 'b>( let (f_body, globals) = extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs); Ok(Function { - func_decl: f_args, + func_decl: func_decl, globals: globals, body: Some(f_body), import_as: None, - spirv_decl, tuning, + uses_shared_mem: false, }) } @@ -1573,9 +1583,9 @@ fn convert_to_typed_statements( Statement::Instruction(inst) => match inst { ast::Instruction::Call(call) => { // TODO: error out if lengths don't match - let fn_def = fn_defs.get_fn_decl(call.func)?; - let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); - let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); + let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow(); + let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); + let in_args = to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args .into_iter() .partition(|(_, _, space)| *space == ast::StateSpace::Param); @@ -1731,24 +1741,24 @@ fn to_ptx_impl_atomic_call( let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, v_type: ast::Type::Pointer(typ), state_space: ptr_space, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, @@ -1756,24 +1766,23 @@ fn to_ptx_impl_atomic_call( array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), + uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; @@ -1810,31 +1819,31 @@ fn to_ptx_impl_bfe_call( let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, @@ -1842,24 +1851,23 @@ fn to_ptx_impl_bfe_call( array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), + uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; @@ -1903,38 +1911,38 @@ fn to_ptx_impl_bfi_call( let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, @@ -1942,24 +1950,23 @@ fn to_ptx_impl_bfi_call( array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), + uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; @@ -1994,12 +2001,12 @@ fn to_ptx_impl_bfi_call( fn to_resolved_fn_args( params: Vec, - params_decl: &[(ast::Type, ast::StateSpace)], + params_decl: &[ast::Variable], ) -> Vec<(T, ast::Type, ast::StateSpace)> { params .into_iter() .zip(params_decl.iter()) - .map(|(id, (typ, space))| (id, typ.clone(), *space)) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) .collect::>() } @@ -2084,11 +2091,10 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, - _: &'a ast::MethodDecl<'b, spirv::Word>, - fn_decl: &mut SpirvMethodDecl, + fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.output.iter() { + for arg in fn_decl.return_arguments.iter() { result.push(Statement::Variable(ast::Variable { align: arg.align, v_type: arg.v_type.clone(), @@ -2097,27 +2103,27 @@ fn insert_mem_ssa_statements<'a, 'b>( array_init: arg.array_init.clone(), })); } - for spirv_arg in fn_decl.input.iter_mut() { - let typ = spirv_arg.v_type.clone(); - let state_space = spirv_arg.state_space; + for arg in fn_decl.input_arguments.iter_mut() { + let typ = arg.v_type.clone(); + let state_space = arg.state_space; let new_id = id_def.register_intermediate(Some((typ.clone(), state_space))); result.push(Statement::Variable(ast::Variable { - align: spirv_arg.align, - v_type: spirv_arg.v_type.clone(), - state_space: spirv_arg.state_space, - name: spirv_arg.name, - array_init: spirv_arg.array_init.clone(), + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: Vec::new(), })); result.push(Statement::StoreVar(StoreVarDetails { arg: ast::Arg2St { - src1: spirv_arg.name, + src1: arg.name, src2: new_id, }, state_space, typ, member_index: None, })); - spirv_arg.name = new_id; + arg.name = new_id; } for s in func { match s { @@ -2127,7 +2133,7 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { // TODO: handle multiple output args - if let &[out_param] = &fn_decl.output.as_slice() { + if let &[out_param] = &fn_decl.return_arguments.as_slice() { let (typ, space, _) = id_def.get_typed(out_param.name)?; let new_id = id_def.register_intermediate(Some((typ.clone(), space))); result.push(Statement::LoadVar(LoadVarDetails { @@ -5081,15 +5087,10 @@ struct GlobalStringIdResolver<'input> { variables: HashMap, spirv::Word>, variables_type_check: HashMap>, special_registers: SpecialRegistersMap, - fns: HashMap, + fns: HashMap>>>, } -pub struct FnDecl { - ret_vals: Vec<(ast::Type, ast::StateSpace)>, - params: Vec<(ast::Type, ast::StateSpace)>, -} - -impl<'a> GlobalStringIdResolver<'a> { +impl<'input> GlobalStringIdResolver<'input> { fn new(start_id: spirv::Word) -> Self { Self { current_id: start_id, @@ -5100,13 +5101,13 @@ impl<'a> GlobalStringIdResolver<'a> { } } - fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word { + fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word { self.get_or_add_impl(id, None) } fn get_or_add_def_typed( &mut self, - id: &'a str, + id: &'input str, typ: ast::Type, state_space: ast::StateSpace, is_variable: bool, @@ -5116,7 +5117,7 @@ impl<'a> GlobalStringIdResolver<'a> { fn get_or_add_impl( &mut self, - id: &'a str, + id: &'input str, typ: Option<(ast::Type, ast::StateSpace, bool)>, ) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { @@ -5145,12 +5146,12 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>( &'b mut self, - header: &'b ast::MethodDecl<'a, &'a str>, + header: &'b ast::MethodDeclaration<'input, &'input str>, ) -> Result< ( - FnStringIdResolver<'a, 'b>, - GlobalFnDeclResolver<'a, 'b>, - ast::MethodDecl<'a, spirv::Word>, + FnStringIdResolver<'input, 'b>, + GlobalFnDeclResolver<'input, 'b>, + Rc>>, ), TranslateError, > { @@ -5164,30 +5165,18 @@ impl<'a> GlobalStringIdResolver<'a> { variables: vec![HashMap::new(); 1], type_check: HashMap::new(), }; - let new_fn_decl = match header { - ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel { - name, - in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?, - }, - ast::MethodDecl::Func(ret_params, _, params) => { - let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?; - let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?; - self.fns.insert( - name_id, - FnDecl { - ret_vals: ret_params_ids - .iter() - .map(|p| (p.v_type.clone(), p.state_space)) - .collect(), - params: params_ids - .iter() - .map(|p| (p.v_type.clone(), p.state_space)) - .collect(), - }, - ); - ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) - } + let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); + let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); + let name = match header.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(_) => ast::MethodName::Func(name_id), }; + let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + })); + self.fns.insert(name_id, Rc::clone(&new_fn_decl)); Ok(( fn_resolver, GlobalFnDeclResolver { @@ -5201,15 +5190,21 @@ impl<'a> GlobalStringIdResolver<'a> { pub struct GlobalFnDeclResolver<'input, 'a> { variables: &'a HashMap, spirv::Word>, - fns: &'a HashMap, + fns: &'a HashMap>>>, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> { + fn get_fn_decl( + &self, + id: spirv::Word, + ) -> Result<&Rc>>, TranslateError> { self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } - fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> { + fn get_fn_decl_str( + &self, + id: &str, + ) -> Result<&'a Rc>>, TranslateError> { match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { Some(Some(fn_d)) => Ok(fn_d), _ => Err(TranslateError::UnknownSymbol), @@ -5713,21 +5708,9 @@ impl, U: ArgParamsEx> Visitab } } -pub trait ArgParamsEx: ast::ArgParams + Sized { - fn get_fn_decl<'x, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'x, 'b>, - ) -> Result<&'b FnDecl, TranslateError>; -} +pub trait ArgParamsEx: ast::ArgParams + Sized {} -impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { - fn get_fn_decl<'x, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'x, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl_str(id) - } -} +impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {} enum NormalizedArgParams {} @@ -5736,14 +5719,7 @@ impl ast::ArgParams for NormalizedArgParams { type Operand = ast::Operand; } -impl ArgParamsEx for NormalizedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for NormalizedArgParams {} type NormalizedStatement = Statement< ( @@ -5762,14 +5738,7 @@ impl ast::ArgParams for TypedArgParams { type Operand = TypedOperand; } -impl ArgParamsEx for TypedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for TypedArgParams {} #[derive(Copy, Clone)] enum TypedOperand { @@ -5800,14 +5769,7 @@ impl ast::ArgParams for ExpandedArgParams { type Operand = spirv::Word; } -impl ArgParamsEx for ExpandedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for ExpandedArgParams {} enum Directive<'input> { Variable(ast::Variable), @@ -5815,10 +5777,10 @@ enum Directive<'input> { } struct Function<'input> { - pub func_decl: ast::MethodDecl<'input, spirv::Word>, - pub spirv_decl: SpirvMethodDecl<'input>, + pub func_decl: Rc>>, pub globals: Vec>, pub body: Option>, + pub uses_shared_mem: bool, import_as: Option, tuning: Vec, } @@ -7671,73 +7633,11 @@ fn should_convert_relaxed_dst( } } -impl<'a> ast::MethodDecl<'a, &'a str> { +impl<'a> ast::MethodDeclaration<'a, &'a str> { fn name(&self) -> &'a str { - match self { - ast::MethodDecl::Kernel { name, .. } => name, - ast::MethodDecl::Func(_, name, _) => name, - } - } -} - -struct SpirvMethodDecl<'input> { - input: Vec>, - output: Vec>, - name: MethodName<'input>, - uses_shared_mem: bool, -} - -impl<'input> SpirvMethodDecl<'input> { - fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { - let (input, output) = match ast_decl { - ast::MethodDecl::Kernel { in_args, .. } => { - let spirv_input = in_args - .iter() - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - array_init: var.array_init.clone(), - }) - .collect(); - (spirv_input, Vec::new()) - } - ast::MethodDecl::Func(out_args, _, in_args) => { - let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args - .iter() - .partition(|var| var.state_space == ast::StateSpace::Param); - let spirv_output = non_param_output - .into_iter() - .cloned() - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - array_init: var.array_init.clone(), - }) - .collect(); - let spirv_input = param_output - .into_iter() - .cloned() - .chain(in_args.iter().cloned()) - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - array_init: var.array_init.clone(), - }) - .collect(); - (spirv_input, spirv_output) - } - }; - SpirvMethodDecl { - input, - output, - name: MethodName::new(ast_decl), - uses_shared_mem: false, + match self.name { + ast::MethodName::Kernel(name) => name, + ast::MethodName::Func(name) => name, } } } diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index 4ea449c..f168930 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -191,7 +191,10 @@ unsafe fn record_module_image(module: CUmodule, image: &str) { unsafe fn try_dump_module_image(image: &str) -> Result<(), Box> { let mut dump_path = get_dump_dir()?; - dump_path.push(format!("module_{:04}.ptx", MODULES.as_ref().unwrap().len() - 1)); + dump_path.push(format!( + "module_{:04}.ptx", + MODULES.as_ref().unwrap().len() - 1 + )); let mut file = File::create(dump_path)?; file.write_all(image.as_bytes())?; Ok(()) @@ -217,10 +220,15 @@ unsafe fn to_str(image: *const T) -> Option<&'static str> { fn directive_to_kernel(dir: &ast::Directive) -> Option<(String, Vec)> { match dir { ast::Directive::Method(ast::Function { - func_directive: ast::MethodDecl::Kernel { name, in_args }, + func_directive: + ast::MethodDeclaration { + name: ast::MethodName::Kernel(name), + input_arguments, + .. + }, .. }) => { - let arg_sizes = in_args + let arg_sizes = input_arguments .iter() .map(|arg| ast::Type::from(arg.v_type.clone()).size_of()) .collect();