From 2b3ecc99e3b2a1c0a1989733da17b359e974951c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 18 Oct 2020 14:46:05 +0200 Subject: [PATCH] Implement pass to handle .extern .shared and add parsing code for it --- ptx/Cargo.toml | 1 + ptx/src/ast.rs | 84 +++- ptx/src/lib.rs | 3 + ptx/src/ptx.lalrpop | 270 +++++++----- .../test/spirv_build/global_extern_array.ptx | 5 + .../test/spirv_build/param_func_array_0.ptx | 10 + ptx/src/test/spirv_fail/const_ptr.ptx | 5 + ptx/src/test/spirv_fail/global_ptr.ptx | 5 + ptx/src/test/spirv_fail/local_ptr.txt | 12 + .../test/spirv_fail/param_entry_array_0.ptx | 10 + ptx/src/test/spirv_fail/param_vector.ptx | 10 + ptx/src/test/spirv_fail/shared_ptr.ptx | 5 + ptx/src/test/spirv_fail/shared_ptr2.ptx | 13 + ptx/src/test/spirv_run/extern_shared.ptx | 24 ++ ptx/src/test/spirv_run/extern_shared.spvtxt | 53 +++ ptx/src/test/spirv_run/extern_shared_call.ptx | 45 ++ .../test/spirv_run/extern_shared_call.spvtxt | 53 +++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 391 ++++++++++++++++-- 19 files changed, 877 insertions(+), 123 deletions(-) create mode 100644 ptx/src/test/spirv_build/global_extern_array.ptx create mode 100644 ptx/src/test/spirv_build/param_func_array_0.ptx create mode 100644 ptx/src/test/spirv_fail/const_ptr.ptx create mode 100644 ptx/src/test/spirv_fail/global_ptr.ptx create mode 100644 ptx/src/test/spirv_fail/local_ptr.txt create mode 100644 ptx/src/test/spirv_fail/param_entry_array_0.ptx create mode 100644 ptx/src/test/spirv_fail/param_vector.ptx create mode 100644 ptx/src/test/spirv_fail/shared_ptr.ptx create mode 100644 ptx/src/test/spirv_fail/shared_ptr2.ptx create mode 100644 ptx/src/test/spirv_run/extern_shared.ptx create mode 100644 ptx/src/test/spirv_run/extern_shared.spvtxt create mode 100644 ptx/src/test/spirv_run/extern_shared_call.ptx create mode 100644 ptx/src/test/spirv_run/extern_shared_call.spvtxt diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 96ab9d0..409cd1f 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -14,6 +14,7 @@ spirv_headers = "~1.4.2" quick-error = "1.2" bit-vec = "0.6" half ="1.6" +bitflags = "1.2" [build-dependencies.lalrpop] version = "0.19" diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index c6510da..1e90eba 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,4 +1,5 @@ -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; +use std::convert::TryInto; +use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; use half::f16; @@ -22,6 +23,8 @@ quick_error! { WrongVectorElement {} MultiArrayVariable {} ZeroDimensionArray {} + ArrayInitalizer {} + NonExternPointer {} } } @@ -78,6 +81,21 @@ macro_rules! sub_type { } } } + + impl std::convert::TryFrom for $type_name { + type Error = (); + + #[allow(non_snake_case)] + #[allow(unreachable_patterns)] + fn try_from(t: Type) -> Result { + match t { + $( + Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), + )+ + _ => Err(()), + } + } + } }; } @@ -98,14 +116,39 @@ sub_type! { } } +impl TryFrom for VariableLocalType { + type Error = PtxError; + + fn try_from(value: VariableGlobalType) -> Result { + match value { + VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), + VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), + VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), + VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), + } + } +} + +sub_type! { + VariableGlobalType { + Scalar(SizedScalarType), + Vector(SizedScalarType, u8), + Array(SizedScalarType, VecU32), + Pointer(SizedScalarType, PointerStateSpace), + } +} + // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; +// even more interestingly this is legal, but only in .func (not in .entry): +// .param .b32 foobar[] sub_type! { VariableParamType { Scalar(ParamScalarType), Array(SizedScalarType, VecU32), + Pointer(SizedScalarType, PointerStateSpace), } } @@ -193,7 +236,7 @@ pub enum MethodDecl<'a, ID> { } pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub type KernelArgument = Variable; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, @@ -206,6 +249,12 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement for Type { @@ -213,15 +262,25 @@ impl From for Type { match t { FnArgumentType::Reg(x) => x.into(), FnArgumentType::Param(x) => x.into(), + FnArgumentType::Shared => Type::Scalar(ScalarType::B64), } } } -#[derive(PartialEq, Eq, Hash, Clone)] +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum PointerStateSpace { + Global, + Const, + Shared, + Param, +} + +#[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, Vec), + Pointer(ScalarType, PointerStateSpace), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -343,7 +402,8 @@ pub enum VariableType { Reg(VariableRegType), Local(VariableLocalType), Param(VariableParamType), - Global(VariableLocalType), + Global(VariableGlobalType), + Shared(VariableGlobalType), } impl VariableType { @@ -353,6 +413,7 @@ impl VariableType { VariableType::Local(t) => (StateSpace::Local, t.clone().into()), VariableType::Param(t) => (StateSpace::Param, t.clone().into()), VariableType::Global(t) => (StateSpace::Global, t.clone().into()), + VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), } } } @@ -364,6 +425,7 @@ impl From for Type { VariableType::Local(t) => t.into(), VariableType::Param(t) => t.into(), VariableType::Global(t) => t.into(), + VariableType::Shared(t) => t.into(), } } } @@ -1039,6 +1101,20 @@ impl<'a> NumsOrArrays<'a> { } } +pub enum ArrayOrPointer { + Array { dimensions: Vec, init: Vec }, + Pointer, +} + +bitflags! { + pub struct LinkingDirective: u8 { + const NONE = 0b000; + const EXTERN = 0b001; + const VISIBLE = 0b10; + const WEAK = 0b100; + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 8ae1c6d..1aac8ab 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -17,6 +17,9 @@ extern crate spirv_headers as spirv; #[cfg(test)] extern crate spirv_tools_sys as spirv_tools; +#[macro_use] +extern crate bitflags; + lalrpop_mod!( #[allow(warnings)] ptx diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 0b6fa0f..4624580 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -3,6 +3,7 @@ use crate::ast::UnwrapWithVec; use crate::{without_none, vector_index}; use lalrpop_util::ParseError; +use std::convert::TryInto; grammar<'a>(errors: &mut Vec); @@ -210,7 +211,7 @@ Directive: Option>> = { => Some(ast::Directive::Method(f)), File => None, Section => None, - ";" => Some(ast::Directive::Variable(v)), + ";" => Some(ast::Directive::Variable(v)), }; AddressSize = { @@ -218,17 +219,23 @@ AddressSize = { }; Function: ast::Function<'input, &'input str, ast::Statement>> = { - LinkingDirective* + LinkingDirectives => ast::Function{<>} }; -LinkingDirective = { - ".extern", - ".visible", - ".weak" +LinkingDirective: ast::LinkingDirective = { + ".extern" => ast::LinkingDirective::EXTERN, + ".visible" => ast::LinkingDirective::VISIBLE, + ".weak" => ast::LinkingDirective::WEAK, }; +LinkingDirectives: ast::LinkingDirective = { + => { + ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y) + } +} + MethodDecl: ast::MethodDecl<'input, &'input str> = { ".entry" => ast::MethodDecl::Kernel(name, params), ".func" => { @@ -244,10 +251,15 @@ FnArguments: Vec> = { "(" > ")" => args }; -KernelInput: ast::Variable = { +KernelInput: ast::Variable = { => { let (align, v_type, name) = v; - ast::Variable{ align, v_type, name, array_init: Vec::new() } + ast::Variable { + align, + v_type: ast::KernelArgumentType::Normal(v_type), + name, + array_init: Vec::new() + } } } @@ -357,69 +369,120 @@ Variable: ast::Variable = { }; RegVariable: (Option, ast::VariableRegType, &'input str) = { - ".reg" => { + ".reg" > => { + let (align, t, name) = var; let v_type = ast::VariableRegType::Scalar(t); (align, v_type, name) }, - ".reg" => { + ".reg" > => { + let (align, v_len, t, name) = var; let v_type = ast::VariableRegType::Vector(t, v_len); (align, v_type, name) } } LocalVariable: ast::Variable = { - ".local" => { - let (align, array_init, v_type, name) = def; - ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init } + ".local" > => { + let (align, t, name) = var; + let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); + ast::Variable { align, v_type, name, array_init: Vec::new() } + }, + ".local" > => { + let (align, v_len, t, name) = var; + let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); + ast::Variable { align, v_type, name, array_init: Vec::new() } + }, + ".local" > =>? { + let (align, t, name, arr_or_ptr) = var; + let (v_type, array_init) = match arr_or_ptr { + ast::ArrayOrPointer::Array { dimensions, init } => { + (ast::VariableLocalType::Array(t, dimensions), init) + } + ast::ArrayOrPointer::Pointer => { + return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); + } + }; + Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) } } -GlobalVariable: ast::Variable = { - ".global" => { - let (align, array_init, v_type, name) = def; +ModuleVariable: ast::Variable = { + LinkingDirectives ".global" => { + let (align, v_type, name, array_init) = def; ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } + }, + LinkingDirectives ".shared" => { + let (align, v_type, name, array_init) = def; + ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + }, + > > =>? { + let (align, t, name, arr_or_ptr) = var; + let (v_type, array_init) = match arr_or_ptr { + ast::ArrayOrPointer::Array { dimensions, init } => { + if space == ".global" { + (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) + } else { + (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) + } + } + ast::ArrayOrPointer::Pointer => { + if !ldirs.contains(ast::LinkingDirective::EXTERN) { + return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); + } + if space == ".global" { + (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new()) + } else { + (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new()) + } + } + }; + Ok(ast::Variable{ align, array_init, v_type, name }) } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { - ".param" => { + ".param" > => { + let (align, t, name) = var; let v_type = ast::VariableParamType::Scalar(t); (align, Vec::new(), v_type, name) }, - ".param" => { - let (array_init, name, (t, dimensions)) = arr; - let v_type = ast::VariableParamType::Array(t, dimensions); + ".param" > => { + let (align, t, name, arr_or_ptr) = var; + let (v_type, array_init) = match arr_or_ptr { + ast::ArrayOrPointer::Array { dimensions, init } => { + (ast::VariableParamType::Array(t, dimensions), init) + } + ast::ArrayOrPointer::Pointer => { + (ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new()) + } + }; (align, array_init, v_type, name) } } ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { - ".param" => { - let v_type = ast::VariableParamType::Scalar(t); - (align, v_type, name) - }, - ".param" => { - let (name, (t, dimensions)) = arr; - let v_type = ast::VariableParamType::Array(t, dimensions); - (align, v_type, name) + =>? { + let (align, array_init, v_type, name) = var; + if array_init.len() > 0 { + Err(ParseError::User { error: ast::PtxError::ArrayInitalizer }) + } else { + Ok((align, v_type, name)) + } } } -LocalVariableDefinition: (Option, Vec, ast::VariableLocalType, &'input str) = { - => { - let v_type = ast::VariableLocalType::Scalar(t); - (align, Vec::new(), v_type, name) +GlobalVariableDefinitionNoArray: (Option, ast::VariableGlobalType, &'input str, Vec) = { + > => { + let (align, t, name) = scalar; + let v_type = ast::VariableGlobalType::Scalar(t); + (align, v_type, name, Vec::new()) }, - => { - let v_type = ast::VariableLocalType::Vector(t, v_len); - (align, Vec::new(), v_type, name) + > => { + let (align, v_len, t, name) = var; + let v_type = ast::VariableGlobalType::Vector(t, v_len); + (align, v_type, name, Vec::new()) }, - => { - let (array_init, name, (t, dimensions)) = arr; - let v_type = ast::VariableLocalType::Array(t, dimensions); - (align, array_init, v_type, name) - } } #[inline] @@ -461,60 +524,6 @@ ParamScalarType: ast::ParamScalarType = { ".f64" => ast::ParamScalarType::F64, } -ArrayDefinition: (Vec, &'input str, (ast::SizedScalarType, Vec)) = { - =>? { - let mut dims = dims; - let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?; - Ok(( - array_init, - name, - (typ, dims) - )) - } -} - -ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec)) = { - =>? { - let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::>()?; - Ok((name, (typ, dims))) - } -} - -// [0] and [] are treated the same -ArrayDimensions: Vec = { - ArrayEmptyDimension => vec![0u32], - ArrayEmptyDimension => { - let mut dims = dims; - let mut result = vec![0u32]; - result.append(&mut dims); - result - }, - => dims -} - -ArrayEmptyDimension = { - "[" "]" -} - -ArrayDimension: u32 = { - "[" "]" =>? { - str::parse::(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) }) - } -} - -ArrayInitializer: ast::NumsOrArrays<'input> = { - "=" => nums -} - -NumsOrArraysBracket: ast::NumsOrArrays<'input> = { - "{" "}" => nums -} - -NumsOrArrays: ast::NumsOrArrays<'input> = { - > => ast::NumsOrArrays::Arrays(n), - > => ast::NumsOrArrays::Nums(n), -} - Instruction: ast::Instruction> = { InstLd, InstMov, @@ -1311,6 +1320,73 @@ BitType = { ".b8", ".b16", ".b32", ".b64" }; +VariableScalar: (Option, T, &'input str) = { + => { + (align, v_type, name) + } +} + +VariableVector: (Option, u8, T, &'input str) = { + => { + (align, v_len, v_type, name) + } +} + +// empty dimensions [0] means it's a pointer +VariableArrayOrPointer: (Option, T, &'input str, ast::ArrayOrPointer) = { + =>? { + let mut dims = dims; + let array_init = match init { + Some(init) => { + let init_vec = init.to_vec(typ, &mut dims)?; + ast::ArrayOrPointer::Array { dimensions: dims, init: init_vec } + } + None => { + if dims.len() > 1 && dims.contains(&0) { + return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) + } + ast::ArrayOrPointer::Pointer + } + }; + Ok((align, typ, name, array_init)) + } +} + +// [0] and [] are treated the same +ArrayDimensions: Vec = { + ArrayEmptyDimension => vec![0u32], + ArrayEmptyDimension => { + let mut dims = dims; + let mut result = vec![0u32]; + result.append(&mut dims); + result + }, + => dims +} + +ArrayEmptyDimension = { + "[" "]" +} + +ArrayDimension: u32 = { + "[" "]" =>? { + str::parse::(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) }) + } +} + +ArrayInitializer: ast::NumsOrArrays<'input> = { + "=" => nums +} + +NumsOrArraysBracket: ast::NumsOrArrays<'input> = { + "{" "}" => nums +} + +NumsOrArrays: ast::NumsOrArrays<'input> = { + > => ast::NumsOrArrays::Arrays(n), + > => ast::NumsOrArrays::Nums(n), +} + Comma: Vec = { ",")*> => match e { None => v, @@ -1329,3 +1405,9 @@ CommaNonEmpty: Vec = { v } }; + +#[inline] +Or: T1 = { + T1, + T2 +} \ No newline at end of file diff --git a/ptx/src/test/spirv_build/global_extern_array.ptx b/ptx/src/test/spirv_build/global_extern_array.ptx new file mode 100644 index 0000000..fe0f19f --- /dev/null +++ b/ptx/src/test/spirv_build/global_extern_array.ptx @@ -0,0 +1,5 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .global .b32 foobar [1]; \ No newline at end of file diff --git a/ptx/src/test/spirv_build/param_func_array_0.ptx b/ptx/src/test/spirv_build/param_func_array_0.ptx new file mode 100644 index 0000000..005af52 --- /dev/null +++ b/ptx/src/test/spirv_build/param_func_array_0.ptx @@ -0,0 +1,10 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .func foobar( + .param .b32 foobar[] +) +{ +ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_fail/const_ptr.ptx b/ptx/src/test/spirv_fail/const_ptr.ptx new file mode 100644 index 0000000..0efd729 --- /dev/null +++ b/ptx/src/test/spirv_fail/const_ptr.ptx @@ -0,0 +1,5 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.const .b32 foobar []; \ No newline at end of file diff --git a/ptx/src/test/spirv_fail/global_ptr.ptx b/ptx/src/test/spirv_fail/global_ptr.ptx new file mode 100644 index 0000000..7ce4c83 --- /dev/null +++ b/ptx/src/test/spirv_fail/global_ptr.ptx @@ -0,0 +1,5 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.global .b32 foobar []; \ No newline at end of file diff --git a/ptx/src/test/spirv_fail/local_ptr.txt b/ptx/src/test/spirv_fail/local_ptr.txt new file mode 100644 index 0000000..9375011 --- /dev/null +++ b/ptx/src/test/spirv_fail/local_ptr.txt @@ -0,0 +1,12 @@ +.version 6.5 +.target sm_30 +.address_size 64 + + +.visible .entry func() +{ + + .local .b32 foobar [1]; + + ret; +} diff --git a/ptx/src/test/spirv_fail/param_entry_array_0.ptx b/ptx/src/test/spirv_fail/param_entry_array_0.ptx new file mode 100644 index 0000000..86dd5eb --- /dev/null +++ b/ptx/src/test/spirv_fail/param_entry_array_0.ptx @@ -0,0 +1,10 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry foobar( + .param .b32 foobar[] +) +{ +ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_fail/param_vector.ptx b/ptx/src/test/spirv_fail/param_vector.ptx new file mode 100644 index 0000000..28895e2 --- /dev/null +++ b/ptx/src/test/spirv_fail/param_vector.ptx @@ -0,0 +1,10 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .func foobar( + .param .b32 .v2 foobar +) +{ +ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_fail/shared_ptr.ptx b/ptx/src/test/spirv_fail/shared_ptr.ptx new file mode 100644 index 0000000..b1b815a --- /dev/null +++ b/ptx/src/test/spirv_fail/shared_ptr.ptx @@ -0,0 +1,5 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +extern .shared .b32 foobar []; \ No newline at end of file diff --git a/ptx/src/test/spirv_fail/shared_ptr2.ptx b/ptx/src/test/spirv_fail/shared_ptr2.ptx new file mode 100644 index 0000000..fb2472a --- /dev/null +++ b/ptx/src/test/spirv_fail/shared_ptr2.ptx @@ -0,0 +1,13 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .shared .b32 foobar1 []; + +.visible .func _Z4dupaPf( + .param .b64 _Z4dupaPf_param_0 +) +{ +.shared .b32 foobar2 []; +ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/extern_shared.ptx b/ptx/src/test/spirv_run/extern_shared.ptx new file mode 100644 index 0000000..ac5c256 --- /dev/null +++ b/ptx/src/test/spirv_run/extern_shared.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .shared .b32 shared_mem []; + +.visible .entry extern_shared( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp, [in_addr]; + st.shared.u64 [shared_mem], temp; + ld.shared.u64 temp, [shared_mem]; + st.global.u64 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt new file mode 100644 index 0000000..84e7eac --- /dev/null +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -0,0 +1,53 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %29 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "cvta" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %32 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %1 = OpFunction %void None %32 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %27 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %22 = OpCopyObject %ulong %14 + %21 = OpCopyObject %ulong %22 + %13 = OpCopyObject %ulong %21 + OpStore %4 %13 + %16 = OpLoad %ulong %5 + %24 = OpCopyObject %ulong %16 + %23 = OpCopyObject %ulong %24 + %15 = OpCopyObject %ulong %23 + OpStore %5 %15 + %18 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 + %17 = OpLoad %float %25 + OpStore %6 %17 + %19 = OpLoad %ulong %5 + %20 = OpLoad %float %6 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 + OpStore %26 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.ptx b/ptx/src/test/spirv_run/extern_shared_call.ptx new file mode 100644 index 0000000..6626783 --- /dev/null +++ b/ptx/src/test/spirv_run/extern_shared_call.ptx @@ -0,0 +1,45 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .shared .align 4 .b32 shared_mem[]; + +.func (.param .u64 output) incr_shared_2_param( + .param .u64 .ptr .shared shared_mem_addr +) +{ + .reg .u64 temp; + ld.shared.u64 temp, [shared_mem_addr]; + add.u64 temp, temp, 2; + st.param.u64 [output], temp; + ret; +} + +.func (.param .u64 output) incr_shared_2_global() +{ + .reg .u64 temp; + ld.shared.u64 temp, [shared_mem]; + add.u64 temp, temp, 2; + st.param.u64 [output], temp; + ret; +} + + +.visible .entry extern_shared( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp, [in_addr]; + st.shared.u64 [shared_mem], temp; + ld.shared.u64 temp, [shared_mem]; + st.global.u64 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt new file mode 100644 index 0000000..84e7eac --- /dev/null +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -0,0 +1,53 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %29 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "cvta" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %32 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %1 = OpFunction %void None %32 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %27 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %22 = OpCopyObject %ulong %14 + %21 = OpCopyObject %ulong %22 + %13 = OpCopyObject %ulong %21 + OpStore %4 %13 + %16 = OpLoad %ulong %5 + %24 = OpCopyObject %ulong %16 + %23 = OpCopyObject %ulong %24 + %15 = OpCopyObject %ulong %23 + OpStore %5 %15 + %18 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 + %17 = OpLoad %float %25 + OpStore %6 %17 + %19 = OpLoad %ulong %5 + %20 = OpLoad %float %6 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 + OpStore %26 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 0c881d9..14c3bc9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -78,6 +78,7 @@ test_ptx!(sub, [2u64], [1u64]); test_ptx!(min, [555i32, 444i32], [444i32]); test_ptx!(max, [555i32, 444i32], [555i32]); test_ptx!(global_array, [0xDEADu32], [1u32]); +test_ptx!(extern_shared, [127u64], [127u64]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a86ab3c..09dd0bb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -34,11 +34,7 @@ enum SpirvType { impl SpirvType { fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { - let key = match t { - ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), - ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len), - ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len), - }; + let key = t.into(); SpirvType::Pointer(Box::new(key), sc) } } @@ -49,6 +45,20 @@ impl From for SpirvType { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), + ast::Type::Pointer(typ, state_space) => { + SpirvType::Pointer(Box::new(SpirvType::Base(typ.into())), state_space.into()) + } + } + } +} + +impl Into for ast::PointerStateSpace { + fn into(self) -> spirv::StorageClass { + match self { + ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant, + ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup, + ast::PointerStateSpace::Param => spirv::StorageClass::Function, } } } @@ -354,6 +364,14 @@ impl TypeWordMap { b.constant_composite(result_type, None, &components) } }, + ast::Type::Pointer(typ, state_space) => { + let base = self.get_or_add_constant(b, &ast::Type::Scalar(*typ), &[])?; + let result_type = self.get_or_add( + b, + SpirvType::Pointer(Box::new(SpirvType::from(*typ)), (*state_space).into()), + ); + b.variable(result_type, None, (*state_space).into(), Some(base)) + } }) } @@ -415,13 +433,7 @@ pub fn to_spirv_module<'a>( None => continue, }; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; - emit_function_header( - &mut builder, - &mut map, - &id_defs, - f.func_directive, - &mut args_len, - )?; + emit_function_header(&mut builder, &mut map, &id_defs, f.func_decl, &mut args_len)?; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; builder.end_function()?; } @@ -430,6 +442,202 @@ pub fn to_spirv_module<'a>( Ok((builder.module(), args_len)) } +type MultiHashMap = HashMap>; + +fn multi_hash_map_append(m: &mut MultiHashMap, key: K, value: V) { + match m.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().push(value); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![value]); + } + } +} + +// PTX represents dynamically allocated shared local memory as +// .extern .shared .align 4 .b8 shared_mem[]; +// In SPIRV/OpenCL world this is expressed as an additional argument +// This pass looks for all uses of .extern .shared and converts them to +// an additional method argument +fn convert_dynamic_shared_memory_usage<'input>( + new_id: &mut impl FnMut() -> spirv::Word, + id_defs: &mut GlobalStringIdResolver<'input>, + module: Vec>, +) -> Vec> { + let mut extern_shared_decls = HashSet::new(); + for dir in module.iter() { + match dir { + Directive::Variable(var) => { + if let ast::VariableType::Shared(_) = var.v_type { + extern_shared_decls.insert(var.name); + } + } + _ => {} + } + } + if extern_shared_decls.len() == 0 { + return module; + } + let mut methods_using_extern_shared = HashSet::new(); + let mut directly_called_by = MultiHashMap::new(); + let module = module + .into_iter() + .map(|directive| match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + }) => { + let call_key = match func_decl { + ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name), + ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), + }; + let statements = statements + .into_iter() + .map(|statement| match statement { + Statement::Call(call) => { + multi_hash_map_append(&mut directly_called_by, call.func, call_key); + Statement::Call(call) + } + statement => statement.map_id(&mut |id| { + if extern_shared_decls.contains(&id) { + methods_using_extern_shared.insert(call_key); + } + id + }), + }) + .collect(); + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + }) + } + directive => directive, + }) + .collect::>(); + // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, + // make sure it gets propagated to `fn1` and `kernel` + get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by); + // now visit every method declaration and inject those additional arguments + module + .into_iter() + .map(|directive| match directive { + Directive::Method(Function { + mut func_decl, + globals, + body: Some(statements), + }) => { + let call_key = match func_decl { + ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name), + ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), + }; + if !methods_using_extern_shared.contains(&call_key) { + return Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + }); + } + let shared_id_param = new_id(); + match &mut func_decl { + ast::MethodDecl::Func(_, _, input_args) => { + input_args.push(ast::Variable { + align: None, + v_type: ast::FnArgumentType::Shared, + array_init: Vec::new(), + name: shared_id_param, + }); + } + ast::MethodDecl::Kernel(_, input_args) => { + input_args.push(ast::Variable { + align: None, + v_type: ast::KernelArgumentType::Shared, + array_init: Vec::new(), + name: shared_id_param, + }); + } + } + let statements = statements + .into_iter() + .map(|statement| match statement { + Statement::Call(mut call) => { + // We can safely skip checking call arguments, + // because there's simply no way to pass shared ptr + // without converting it to .b64 first + if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func)) + { + call.param_list + .push((shared_id_param, ast::FnArgumentType::Shared)); + } + Statement::Call(call) + } + statement => statement.map_id(&mut |id| { + if extern_shared_decls.contains(&id) { + shared_id_param + } else { + id + } + }), + }) + .collect(); + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + }) + } + directive => directive, + }) + .collect::>() +} + +fn get_callers_of_extern_shared<'a>( + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, +) { + let direct_uses_of_extern_shared = methods_using_extern_shared + .iter() + .filter_map(|method| { + if let CallgraphKey::Func(f_id) = method { + Some(*f_id) + } else { + None + } + }) + .collect::>(); + for fn_id in direct_uses_of_extern_shared { + get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id); + } +} + +fn get_callers_of_extern_shared_single<'a>( + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, + fn_id: spirv::Word, +) { + if let Some(callers) = directly_called_by.get(&fn_id) { + for caller in callers { + if methods_using_extern_shared.insert(*caller) { + if let CallgraphKey::Func(caller_fn) = caller { + get_callers_of_extern_shared_single( + methods_using_extern_shared, + directly_called_by, + *caller_fn, + ); + } + } + } + } +} + +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +enum CallgraphKey<'input> { + Kernel(&'input str), + Func(spirv::Word), +} + fn emit_builtins( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -594,6 +802,7 @@ fn expand_fn_params<'a, 'b>( let ss = match a.v_type { ast::FnArgumentType::Reg(_) => StateSpace::Reg, ast::FnArgumentType::Param(_) => StateSpace::Param, + ast::FnArgumentType::Shared => StateSpace::Shared, }; ast::FnArgument { name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))), @@ -615,7 +824,7 @@ fn to_ssa<'input, 'b>( Some(vec) => vec, None => { return Ok(Function { - func_directive: f_args, + func_decl: f_args, body: None, globals: Vec::new(), }) @@ -637,7 +846,7 @@ fn to_ssa<'input, 'b>( let sorted_statements = normalize_variable_decls(labeled_statements); let (f_body, globals) = extract_globals(sorted_statements); Ok(Function { - func_directive: f_args, + func_decl: f_args, globals: globals, body: Some(f_body), }) @@ -935,7 +1144,7 @@ fn insert_mem_ssa_statements<'a, 'b>( let new_id = id_def.new_id(typ.clone()); result.push(Statement::Variable(ast::Variable { align: p.align, - v_type: ast::VariableType::Param(p.v_type.clone()), + v_type: ast::VariableType::Param(p.v_type.clone().to_param()), name: p.name, array_init: p.array_init.clone(), })); @@ -1878,26 +2087,33 @@ fn emit_variable( map: &mut TypeWordMap, var: &ast::Variable, ) -> Result<(), TranslateError> { - let (should_init, st_class) = match var.v_type { + let (must_init, st_class) = match var.v_type { ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { (false, spirv::StorageClass::Function) } ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), + ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup), }; - let type_id = map.get_or_add( - builder, - SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), - ); - let initalizer = if should_init { + let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( builder, &ast::Type::from(var.v_type.clone()), &*var.array_init, )?) + } else if must_init { + let type_id = map.get_or_add( + builder, + SpirvType::from(ast::Type::from(var.v_type.clone())), + ); + Some(builder.constant_null(type_id, None)) } else { None }; - builder.variable(type_id, Some(var.name), st_class, initalizer); + let ptr_type_id = map.get_or_add( + builder, + SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), + ); + builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { builder.decorate( var.name, @@ -2537,7 +2753,8 @@ fn expand_map_variables<'a, 'b>( ast::VariableType::Reg(_) => StateSpace::Reg, ast::VariableType::Local(_) => StateSpace::Local, ast::VariableType::Param(_) => StateSpace::ParamReg, - ast::VariableType::Global(_) => todo!(), + ast::VariableType::Global(_) => StateSpace::Global, + ast::VariableType::Shared(_) => StateSpace::Shared, }; match var.count { Some(count) => { @@ -2888,6 +3105,69 @@ enum Statement { Undef(ast::Type, spirv::Word), } +impl ExpandedStatement { + fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement { + match self { + Statement::Label(id) => Statement::Label(f(id)), + Statement::Variable(mut var) => { + var.name = f(var.name); + Statement::Variable(var) + } + Statement::Instruction(inst) => inst + .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op))) + .unwrap(), + Statement::LoadVar(mut arg, typ) => { + arg.dst = f(arg.dst); + arg.src = f(arg.src); + Statement::LoadVar(arg, typ) + } + Statement::StoreVar(mut arg, typ) => { + arg.src1 = f(arg.src1); + arg.src2 = f(arg.src2); + Statement::StoreVar(arg, typ) + } + Statement::Call(mut call) => { + for (id, _) in call.ret_params.iter_mut() { + *id = f(*id); + } + call.func = f(call.func); + for (id, _) in call.param_list.iter_mut() { + *id = f(*id); + } + Statement::Call(call) + } + Statement::Composite(mut composite) => { + composite.dst = f(composite.dst); + composite.src_composite = f(composite.src_composite); + Statement::Composite(composite) + } + Statement::Conditional(mut conditional) => { + conditional.predicate = f(conditional.predicate); + conditional.if_true = f(conditional.if_true); + conditional.if_false = f(conditional.if_false); + Statement::Conditional(conditional) + } + Statement::Conversion(mut conv) => { + conv.dst = f(conv.dst); + conv.src = f(conv.src); + Statement::Conversion(conv) + } + Statement::Constant(mut constant) => { + constant.dst = f(constant.dst); + Statement::Constant(constant) + } + Statement::RetValue(data, id) => { + let id = f(id); + Statement::RetValue(data, id) + } + Statement::Undef(typ, id) => { + let id = f(id); + Statement::Undef(typ, id) + } + } + } +} + struct ResolvedCall { pub uniform: bool, pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>, @@ -3106,7 +3386,7 @@ enum Directive<'input> { } struct Function<'input> { - pub func_directive: ast::MethodDecl<'input, spirv::Word>, + pub func_decl: ast::MethodDecl<'input, spirv::Word>, pub globals: Vec, pub body: Option>, } @@ -3546,18 +3826,28 @@ impl ast::Type { scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), + state_space: ast::PointerStateSpace::Global, }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], + state_space: ast::PointerStateSpace::Global, }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), + state_space: ast::PointerStateSpace::Global, + }, + ast::Type::Pointer(scalar, state_space) => TypeParts { + kind: TypeKind::Pointer, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + state_space: *state_space, }, } } @@ -3575,6 +3865,10 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), + TypeKind::Pointer => ast::Type::Pointer( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.state_space, + ), } } } @@ -3585,6 +3879,7 @@ struct TypeParts { scalar_kind: ScalarKind, width: u8, components: Vec, + state_space: ast::PointerStateSpace, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -3592,6 +3887,7 @@ enum TypeKind { Scalar, Vector, Array, + Pointer, } impl ast::Instruction { @@ -3762,6 +4058,36 @@ impl ast::VariableParamType { (ast::ScalarType::from(*t).size_of() as usize) * (len.iter().fold(1, |x, y| x * (*y)) as usize) } + ast::VariableParamType::Pointer(_, _) => mem::size_of::() + } + } +} + +impl ast::KernelArgumentType { + fn width(&self) -> usize { + match self { + ast::KernelArgumentType::Normal(t) => t.width(), + ast::KernelArgumentType::Shared => mem::size_of::(), + } + } +} + +impl From for ast::Type { + fn from(this: ast::KernelArgumentType) -> Self { + match this { + ast::KernelArgumentType::Normal(typ) => typ.into(), + ast::KernelArgumentType::Shared => ast::Type::Scalar(ast::ScalarType::B64), + } + } +} + +impl ast::KernelArgumentType { + fn to_param(self) -> ast::VariableParamType { + match self { + ast::KernelArgumentType::Normal(p) => p, + ast::KernelArgumentType::Shared => { + ast::VariableParamType::Scalar(ast::ParamScalarType::B64) + } } } } @@ -4598,6 +4924,7 @@ impl From for ast::VariableType { match t { ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t), ast::FnArgumentType::Param(t) => ast::VariableType::Param(t), + ast::FnArgumentType::Shared => todo!(), } } } @@ -4648,6 +4975,17 @@ fn bitcast_physical_pointer( ss: Option, ) -> Result, TranslateError> { match operand_type { + // array decays to a pointer + ast::Type::Array(_, vec) => { + if vec.len() != 0 { + return Err(TranslateError::MismatchedType); + } + if let Some(space) = ss { + Ok(Some(ConversionKind::BitToPtr(space))) + } else { + Err(TranslateError::Unreachable) + } + } ast::Type::Scalar(ast::ScalarType::B64) | ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::S64) => { @@ -4882,7 +5220,10 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> { f(&ast::FnArgument { align: arg.align, name: arg.name, - v_type: ast::FnArgumentType::Param(arg.v_type.clone()), + v_type: match arg.v_type.clone() { + ast::KernelArgumentType::Normal(typ) => ast::FnArgumentType::Param(typ), + ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared, + }, array_init: arg.array_init.clone(), }) }),