From de734305cfe8124c1a3a4a0adfee143e4ff5b680 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 3 Sep 2020 01:45:08 +0200 Subject: [PATCH] Start refactoring SPIRV module generation in preparation for support of functions --- notcuda/src/impl/module.rs | 5 +- ptx/src/ast.rs | 25 ++--- ptx/src/ptx.lalrpop | 19 ++-- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 193 +++++++++++++++++++++------------- 5 files changed, 141 insertions(+), 102 deletions(-) diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index feae40b..491778a 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -14,10 +14,7 @@ pub enum ModuleCompileError<'a> { } impl<'a> ModuleCompileError<'a> { - pub fn get_build_log(&self) { - - } - + pub fn get_build_log(&self) {} } impl<'a> From for ModuleCompileError<'a> { diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 5de1db6..7550d55 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -48,24 +48,25 @@ impl< pub struct Module<'a> { pub version: (u8, u8), - pub functions: Vec>, + pub functions: Vec>, } -pub enum FunctionReturn<'a> { - Func(Vec>), - Kernel, +pub enum FunctionHeader<'a, P: ArgParams> { + Func(Vec>, P::ID), + Kernel(&'a str), } -pub struct Function<'a> { - pub func_directive: FunctionReturn<'a>, - pub name: &'a str, - pub args: Vec>, - pub body: Option>>>, +pub struct Function<'a, P: ArgParams, S> { + pub func_directive: FunctionHeader<'a, P>, + pub args: Vec>, + pub body: Option>, } +pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement>>; + #[derive(Default)] -pub struct Argument<'a> { - pub name: &'a str, +pub struct Argument { + pub name: P::ID, pub a_type: ScalarType, pub length: u32, } @@ -231,7 +232,7 @@ pub struct CallData { pub struct AbsDetails { pub flush_to_zero: bool, - pub typ: ScalarType + pub typ: ScalarType, } pub struct ArgCall { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 7438e97..7e38b78 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -189,7 +189,7 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option> = { +Directive: Option, ast::Statement>>> = { AddressSize => None, => Some(f), File => None, @@ -200,12 +200,11 @@ AddressSize = { ".address_size" Num }; -Function: ast::Function<'input> = { +Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement>> = { LinkingDirective* - - + - => ast::Function{<>} + => ast::Function{<>} }; LinkingDirective = { @@ -214,17 +213,17 @@ LinkingDirective = { ".weak" }; -FunctionReturn: ast::FunctionReturn<'input> = { - ".entry" => ast::FunctionReturn::Kernel, - ".func" => ast::FunctionReturn::Func(args.unwrap_or_else(|| Vec::new())) +FunctionHeader: ast::FunctionHeader<'input, ast::ParsedArgParams<'input>> = { + ".entry" => ast::FunctionHeader::Kernel(name), + ".func" => ast::FunctionHeader::Func(args.unwrap_or_else(|| Vec::new()), name) }; -Arguments: Vec> = { +Arguments: Vec>> = { "(" > ")" => args } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -FunctionInput: ast::Argument<'input> = { +FunctionInput: ast::Argument> = { ".param" <_type:ScalarType> => { ast::Argument {a_type: _type, name: name, length: 1 } }, diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 8883669..9ea0100 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -52,6 +52,7 @@ test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); test_ptx!(local_align, [1u64], [1u64]); +test_ptx!(call, [1u64], [2u64]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 642e6ec..8cf3aca 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -154,7 +154,7 @@ impl TypeWordMap { } } -pub fn to_spirv_module(ast: ast::Module) -> Result { +pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { let mut builder = dr::Builder::new(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module builder.set_version(1, 3); @@ -163,13 +163,21 @@ pub fn to_spirv_module(ast: ast::Module) -> Result { let opencl_id = emit_opencl_import(&mut builder); emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); - for f in ast.functions { - emit_function(&mut builder, &mut map, opencl_id, f)?; + let mut id_defs = GlobalStringIdResolver::new(builder.id()); + let ssa_functions = ast + .functions + .into_iter() + .map(|f| to_ssa_function(&mut id_defs, opencl_id, f)) + .collect::>(); + for f in ssa_functions { + emit_function_args(&mut builder, &mut map, &*f.args); + emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?; + builder.end_function()?; } Ok(builder.module()) } -pub fn to_spirv(ast: ast::Module) -> Result, dr::Error> { +pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result, dr::Error> { let module = to_spirv_module(ast)?; Ok(module.assemble()) } @@ -196,28 +204,28 @@ fn emit_memory_model(builder: &mut dr::Builder) { ); } -fn emit_function<'a>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, +fn to_ssa_function<'a>( + id_defs: &mut GlobalStringIdResolver<'a>, opencl_id: spirv::Word, - f: ast::Function<'a>, -) -> Result { - let func_type = get_function_type(builder, map, &f.args); - let func_id = - builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?; - match f.func_directive { - ast::FunctionReturn::Kernel => { - builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]) - } - _ => todo!(), + f: ast::ParsedFunction<'a>, +) -> ExpandedFunction<'a> { + let ids_start = id_defs.current_id(); + let fn_resolver = FnStringIdResolver::new(id_defs); + let f_header = match f.func_directive { + ast::FunctionHeader::Kernel(name) => todo!(), + ast::FunctionHeader::Func(ret_params, name) => todo!(), + }; + let f_args = todo!(); + let f_body = Some(to_ssa( + fn_resolver, + &f.args, + f.body.unwrap_or_else(|| todo!()), + )); + ExpandedFunction { + func_directive: f_header, + args: f_args, + body: f_body, } - let (mut func_body, unique_ids) = to_ssa(&f.args, f.body.unwrap_or_else(|| todo!())); - let id_offset = builder.reserve_ids(unique_ids); - emit_function_args(builder, id_offset, map, &f.args); - func_body = apply_id_offset(func_body, id_offset); - emit_function_body_ops(builder, map, opencl_id, &func_body)?; - builder.end_function()?; - Ok(func_id) } fn apply_id_offset(func_body: Vec, id_offset: u32) -> Vec { @@ -228,16 +236,19 @@ fn apply_id_offset(func_body: Vec, id_offset: u32) -> Vec( - f_args: &'b [ast::Argument<'a>], + mut id_defs: FnStringIdResolver<'a, 'b>, + f_args: &'b [ast::Argument>], f_body: Vec>>, -) -> (Vec, spirv::Word) { - let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body); - let normalized_statements = normalize_predicates(normalized_ids, &mut id_def); - let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def); - let expanded_statements = expand_arguments(ssa_statements, &mut id_def); - let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def); - let labeled_statements = normalize_labels(expanded_statements, &mut id_def); - (labeled_statements, id_def.ids_count()) +) -> Vec { + let normalized_ids = normalize_identifiers(&mut id_defs, &f_args, f_body); + let mut numeric_id_defs = id_defs.finish(); + let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); + let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs); + let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs); + let expanded_statements = + insert_implicit_conversions(expanded_statements, &mut numeric_id_defs); + let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + labeled_statements } fn normalize_labels( @@ -391,9 +402,9 @@ fn insert_mem_ssa_statements( result } -fn expand_arguments( +fn expand_arguments<'a, 'b, 'c>( func: Vec, - id_def: &mut NumericIdResolver, + id_def: &'c mut NumericIdResolver<'a, 'b>, ) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func { @@ -416,18 +427,23 @@ fn expand_arguments( result } -struct FlattenArguments<'a> { - func: &'a mut Vec, - id_def: &'a mut NumericIdResolver, +struct FlattenArguments<'a, 'b, 'c> { + func: &'c mut Vec, + id_def: &'c mut NumericIdResolver<'a, 'b>, } -impl<'a> FlattenArguments<'a> { - fn new(func: &'a mut Vec, id_def: &'a mut NumericIdResolver) -> Self { +impl<'a, 'b, 'c> FlattenArguments<'a, 'b, 'c> { + fn new( + func: &'c mut Vec, + id_def: &'c mut NumericIdResolver<'a, 'b>, + ) -> Self { FlattenArguments { func, id_def } } } -impl<'a> ArgumentMapVisitor for FlattenArguments<'a> { +impl<'a, 'b, 'c> ArgumentMapVisitor + for FlattenArguments<'a, 'b, 'c> +{ fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { desc.op } @@ -577,18 +593,17 @@ fn insert_implicit_conversions( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - args: &[ast::Argument], + args: &[ast::Argument], ) -> spirv::Word { map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type))) } fn emit_function_args( builder: &mut dr::Builder, - id_offset: spirv::Word, map: &mut TypeWordMap, - args: &[ast::Argument], + args: &[ast::Argument], ) { - let mut id = id_offset; + let mut id = todo!(); for arg in args { let result_type = map.get_or_add_scalar(builder, arg.a_type); let inst = dr::Instruction::new( @@ -606,9 +621,9 @@ fn emit_function_body_ops( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, - func: &[ExpandedStatement], + func: &Option>, ) -> Result<(), dr::Error> { - for s in func { + for s in func.as_ref().unwrap() { match s { Statement::Label(id) => { if builder.block.is_some() { @@ -1079,10 +1094,10 @@ fn emit_implicit_conversion( // TODO: support scopes fn normalize_identifiers<'a, 'b>( - args: &'b [ast::Argument<'a>], + id_defs: &mut FnStringIdResolver<'a, 'b>, + args: &[ast::Argument>], func: Vec>>, -) -> (Vec>, NumericIdResolver) { - let mut id_defs = StringIdResolver::new(); +) -> Vec> { for arg in args { id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type))); } @@ -1096,13 +1111,13 @@ fn normalize_identifiers<'a, 'b>( } let mut result = Vec::new(); for s in func { - expand_map_variables(&mut id_defs, &mut result, s); + expand_map_variables(id_defs, &mut result, s); } - (result, id_defs.finish()) + result } -fn expand_map_variables<'a>( - id_defs: &mut StringIdResolver<'a>, +fn expand_map_variables<'a, 'b>( + id_defs: &mut FnStringIdResolver<'a, 'b>, result: &mut Vec>, s: ast::Statement>, ) { @@ -1145,24 +1160,53 @@ fn expand_map_variables<'a>( } } -struct StringIdResolver<'a> { +struct GlobalStringIdResolver<'a> { current_id: spirv::Word, + variables: HashMap, spirv::Word>, +} + +impl<'a> GlobalStringIdResolver<'a> { + fn new(start_id: spirv::Word) -> Self { + Self { + current_id: start_id, + variables: HashMap::new(), + } + } + + fn add_def(&mut self, id: &'a str) -> spirv::Word { + let numeric_id = self.current_id; + self.variables.insert(Cow::Borrowed(id), numeric_id); + self.current_id += 1; + numeric_id + } + + fn reserve_id(&mut self) { + self.current_id += 1; + } + + fn current_id(&self) -> spirv::Word { + self.current_id + } +} + +struct FnStringIdResolver<'a, 'b> { + global: &'b mut GlobalStringIdResolver<'a>, variables: Vec, spirv::Word>>, type_check: HashMap, } -impl<'a> StringIdResolver<'a> { - fn new() -> Self { - StringIdResolver { - current_id: 0u32, +impl<'a, 'b> FnStringIdResolver<'a, 'b> { + fn new(global: &'b mut GlobalStringIdResolver<'a>) -> Self { + Self { + global: global, variables: vec![HashMap::new(); 1], type_check: HashMap::new(), } } - fn finish(self) -> NumericIdResolver { + fn finish(self) -> NumericIdResolver<'a, 'b> { NumericIdResolver { - current_id: self.current_id, + global: self.global, type_check: self.type_check, } } @@ -1175,18 +1219,18 @@ impl<'a> StringIdResolver<'a> { self.variables.pop(); } - fn get_id(&self, id: &'a str) -> spirv::Word { + fn get_id(&self, id: &str) -> spirv::Word { for scope in self.variables.iter().rev() { match scope.get(id) { Some(id) => return *id, None => continue, } } - panic!() + self.global.variables[id] } fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { - let numeric_id = self.current_id; + let numeric_id = self.global.current_id; self.variables .last_mut() .unwrap() @@ -1194,7 +1238,7 @@ impl<'a> StringIdResolver<'a> { if let Some(typ) = typ { self.type_check.insert(numeric_id, typ); } - self.current_id += 1; + self.global.current_id += 1; numeric_id } @@ -1205,7 +1249,7 @@ impl<'a> StringIdResolver<'a> { count: u32, typ: ast::Type, ) -> impl Iterator { - let numeric_id = self.current_id; + let numeric_id = self.global.current_id; for i in 0..count { self.variables .last_mut() @@ -1213,33 +1257,29 @@ impl<'a> StringIdResolver<'a> { .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); self.type_check.insert(numeric_id + i, typ); } - self.current_id += count; + self.global.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) } } -struct NumericIdResolver { - current_id: spirv::Word, +struct NumericIdResolver<'a, 'b> { + global: &'b mut GlobalStringIdResolver<'a>, type_check: HashMap, } -impl NumericIdResolver { +impl<'a, 'b> NumericIdResolver<'a, 'b> { fn get_type(&self, id: spirv::Word) -> ast::Type { self.type_check[&id] } fn new_id(&mut self, typ: Option) -> spirv::Word { - let new_id = self.current_id; + let new_id = self.global.current_id; if let Some(typ) = typ { self.type_check.insert(new_id, typ); } - self.current_id += 1; + self.global.current_id += 1; new_id } - - fn ids_count(&self) -> spirv::Word { - self.current_id - } } enum Statement { @@ -1284,6 +1324,7 @@ impl ast::ArgParams for NormalizedArgParams { enum ExpandedArgParams {} type ExpandedStatement = Statement>; +type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>; impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word;