mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Bring back support for dynamic shared memory
This commit is contained in:
parent
491e71e346
commit
e940b9400f
3 changed files with 87 additions and 76 deletions
|
@ -82,8 +82,8 @@ pub struct Module<'a> {
|
|||
}
|
||||
|
||||
pub enum Directive<'a, P: ArgParams> {
|
||||
Variable(Variable<P::Id>),
|
||||
Method(Function<'a, &'a str, Statement<P>>),
|
||||
Variable(LinkingDirective, Variable<P::Id>),
|
||||
Method(LinkingDirective, Function<'a, &'a str, Statement<P>>),
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
||||
|
@ -96,7 +96,7 @@ pub struct MethodDeclaration<'input, ID> {
|
|||
pub return_arguments: Vec<Variable<ID>>,
|
||||
pub name: MethodName<'input, ID>,
|
||||
pub input_arguments: Vec<Variable<ID>>,
|
||||
pub shared_mem: Option<Variable<ID>>,
|
||||
pub shared_mem: Option<ID>,
|
||||
}
|
||||
|
||||
pub struct Function<'a, ID, S> {
|
||||
|
|
|
@ -343,10 +343,16 @@ TargetSpecifier = {
|
|||
|
||||
Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
|
||||
AddressSize => None,
|
||||
<f:Function> => Some(ast::Directive::Method(f)),
|
||||
<f:Function> => {
|
||||
let (linking, func) = f;
|
||||
Some(ast::Directive::Method(linking, func))
|
||||
},
|
||||
File => None,
|
||||
Section => None,
|
||||
<v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)),
|
||||
<v:ModuleVariable> ";" => {
|
||||
let (linking, var) = v;
|
||||
Some(ast::Directive::Variable(linking, var))
|
||||
},
|
||||
! => {
|
||||
let err = <>;
|
||||
errors.push(err.error);
|
||||
|
@ -358,11 +364,13 @@ AddressSize = {
|
|||
".address_size" U8Num
|
||||
};
|
||||
|
||||
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
|
||||
LinkingDirectives
|
||||
Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>) = {
|
||||
<linking:LinkingDirectives>
|
||||
<func_directive:MethodDeclaration>
|
||||
<tuning:TuningDirective*>
|
||||
<body:FunctionBody> => ast::Function{<>}
|
||||
<body:FunctionBody> => {
|
||||
(linking, ast::Function{func_directive, tuning, body})
|
||||
}
|
||||
};
|
||||
|
||||
LinkingDirective: ast::LinkingDirective = {
|
||||
|
@ -598,18 +606,18 @@ SharedVariable: ast::Variable<&'input str> = {
|
|||
}
|
||||
}
|
||||
|
||||
ModuleVariable: ast::Variable<&'input str> = {
|
||||
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
|
||||
ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = {
|
||||
<linking:LinkingDirectives> ".global" <def:GlobalVariableDefinitionNoArray> => {
|
||||
let (align, v_type, name, array_init) = def;
|
||||
let state_space = ast::StateSpace::Global;
|
||||
ast::Variable { align, v_type, state_space, name, array_init }
|
||||
(linking, ast::Variable { align, v_type, state_space, name, array_init })
|
||||
},
|
||||
LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
|
||||
<linking:LinkingDirectives> ".shared" <def:GlobalVariableDefinitionNoArray> => {
|
||||
let (align, v_type, name, array_init) = def;
|
||||
let state_space = ast::StateSpace::Shared;
|
||||
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
|
||||
(linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() })
|
||||
},
|
||||
<ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||
<linking:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||
let (align, t, name, arr_or_ptr) = var;
|
||||
let (v_type, state_space, array_init) = match arr_or_ptr {
|
||||
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||
|
@ -620,17 +628,17 @@ ModuleVariable: ast::Variable<&'input str> = {
|
|||
}
|
||||
}
|
||||
ast::ArrayOrPointer::Pointer => {
|
||||
if !ldirs.contains(ast::LinkingDirective::EXTERN) {
|
||||
if !linking.contains(ast::LinkingDirective::EXTERN) {
|
||||
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
|
||||
}
|
||||
if space == ".global" {
|
||||
(ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new())
|
||||
(ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new())
|
||||
} else {
|
||||
(ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new())
|
||||
(ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new())
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(ast::Variable{ align, v_type, state_space, name, array_init })
|
||||
Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init }))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -172,14 +172,18 @@ impl TypeWordMap {
|
|||
.or_insert_with(|| b.type_vector(None, base, len as u32))
|
||||
}
|
||||
SpirvType::Array(typ, array_dimensions) => {
|
||||
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
|
||||
let (base_type, length) = match &*array_dimensions {
|
||||
&[] => {
|
||||
return self.get_or_add(b, SpirvType::Base(typ));
|
||||
}
|
||||
&[len] => {
|
||||
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
|
||||
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||
let len_const = b.constant_u32(u32_type, None, len);
|
||||
(base, len_const)
|
||||
}
|
||||
array_dimensions => {
|
||||
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
|
||||
let base = self
|
||||
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
|
||||
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
|
||||
|
@ -221,7 +225,7 @@ impl TypeWordMap {
|
|||
fn get_or_add_fn(
|
||||
&mut self,
|
||||
b: &mut dr::Builder,
|
||||
in_params: impl ExactSizeIterator<Item = SpirvType>,
|
||||
in_params: impl Iterator<Item = SpirvType>,
|
||||
mut out_params: impl ExactSizeIterator<Item = SpirvType>,
|
||||
) -> (spirv::Word, spirv::Word) {
|
||||
let (out_args, out_spirv_type) = if out_params.len() == 0 {
|
||||
|
@ -233,6 +237,7 @@ impl TypeWordMap {
|
|||
self.get_or_add(b, arg_as_key),
|
||||
)
|
||||
} else {
|
||||
// TODO: support multiple return values
|
||||
todo!()
|
||||
};
|
||||
(
|
||||
|
@ -436,7 +441,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
|||
let mut builder = dr::Builder::new();
|
||||
builder.reserve_ids(id_defs.current_id());
|
||||
let call_map = get_kernels_call_map(&directives);
|
||||
//let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
|
||||
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
|
||||
normalize_variable_decls(&mut directives);
|
||||
let denorm_information = compute_denorm_information(&directives);
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||
|
@ -528,7 +533,7 @@ fn emit_directives<'input>(
|
|||
let empty_body = Vec::new();
|
||||
for d in directives.iter() {
|
||||
match d {
|
||||
Directive::Variable(var) => {
|
||||
Directive::Variable(_, var) => {
|
||||
emit_variable(builder, map, &var)?;
|
||||
}
|
||||
Directive::Method(f) => {
|
||||
|
@ -699,7 +704,6 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
|
|||
transformation has a semantical meaning - we emit additional
|
||||
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
|
||||
*/
|
||||
/*
|
||||
fn convert_dynamic_shared_memory_usage<'input>(
|
||||
module: Vec<Directive<'input>>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
|
@ -707,13 +711,16 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
let mut extern_shared_decls = HashMap::new();
|
||||
for dir in module.iter() {
|
||||
match dir {
|
||||
Directive::Variable(ast::Variable {
|
||||
v_type: ast::Type::Pointer(p_type),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name,
|
||||
..
|
||||
}) => {
|
||||
extern_shared_decls.insert(*name, p_type.clone());
|
||||
Directive::Variable(
|
||||
linking,
|
||||
ast::Variable {
|
||||
v_type: ast::Type::Array(p_type, dims),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name,
|
||||
..
|
||||
},
|
||||
) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
|
||||
extern_shared_decls.insert(*name, *p_type);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -732,14 +739,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
uses_shared_mem,
|
||||
}) => {
|
||||
let call_key = (*func_decl).borrow().name;
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|statement| match statement {
|
||||
Statement::Call(call) => {
|
||||
multi_hash_map_append(&mut directly_called_by, call.func, call_key);
|
||||
multi_hash_map_append(&mut directly_called_by, call.name, call_key);
|
||||
Statement::Call(call)
|
||||
}
|
||||
statement => statement.map_id(&mut |id, _| {
|
||||
|
@ -756,7 +762,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
uses_shared_mem,
|
||||
})
|
||||
}
|
||||
directive => directive,
|
||||
|
@ -775,7 +780,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
uses_shared_mem,
|
||||
}) => {
|
||||
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
|
||||
return Directive::Method(Function {
|
||||
|
@ -784,21 +788,12 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
uses_shared_mem,
|
||||
});
|
||||
}
|
||||
let shared_id_param = new_id();
|
||||
{
|
||||
let mut func_decl = (*func_decl).borrow_mut();
|
||||
func_decl.input_arguments.push({
|
||||
ast::Variable {
|
||||
name: shared_id_param,
|
||||
align: None,
|
||||
v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
array_init: Vec::new(),
|
||||
}
|
||||
});
|
||||
func_decl.shared_mem = Some(shared_id_param);
|
||||
}
|
||||
let statements = replace_uses_of_shared_memory(
|
||||
new_id,
|
||||
|
@ -813,7 +808,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
uses_shared_mem: true,
|
||||
})
|
||||
}
|
||||
directive => directive,
|
||||
|
@ -835,8 +829,8 @@ fn replace_uses_of_shared_memory<'a>(
|
|||
// We can safely skip checking call arguments,
|
||||
// because there's simply no way to pass shared ptr
|
||||
// without converting it to .b64 first
|
||||
if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) {
|
||||
call.param_list.push((
|
||||
if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
|
||||
call.input_arguments.push((
|
||||
shared_id_param,
|
||||
ast::Type::Scalar(ast::ScalarType::B8),
|
||||
ast::StateSpace::Shared,
|
||||
|
@ -854,13 +848,11 @@ fn replace_uses_of_shared_memory<'a>(
|
|||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: shared_id_param,
|
||||
dst: replacement_id,
|
||||
from_type: ast::Type::Pointer(ast::ScalarType::B8),
|
||||
from_type: ast::Type::Scalar(ast::ScalarType::B8),
|
||||
from_space: ast::StateSpace::Shared,
|
||||
to_type: ast::Type::Pointer((*scalar_type).into()),
|
||||
to_type: ast::Type::Scalar(*scalar_type),
|
||||
to_space: ast::StateSpace::Shared,
|
||||
kind: ConversionKind::PtrToPtr { spirv_ptr: true },
|
||||
src_
|
||||
dst_
|
||||
kind: ConversionKind::PtrToPtr,
|
||||
}));
|
||||
replacement_id
|
||||
} else {
|
||||
|
@ -912,7 +904,6 @@ fn get_callers_of_extern_shared_single<'a>(
|
|||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
type DenormCountMap<T> = HashMap<T, isize>;
|
||||
|
||||
|
@ -948,7 +939,7 @@ fn compute_denorm_information<'input>(
|
|||
let mut denorm_methods = HashMap::new();
|
||||
for directive in module {
|
||||
match directive {
|
||||
Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
|
||||
Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
body: Some(statements),
|
||||
|
@ -1158,14 +1149,17 @@ fn translate_directive<'input>(
|
|||
d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
|
||||
) -> Result<Option<Directive<'input>>, TranslateError> {
|
||||
Ok(match d {
|
||||
ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable {
|
||||
align: var.align,
|
||||
v_type: var.v_type.clone(),
|
||||
state_space: var.state_space,
|
||||
name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
|
||||
array_init: var.array_init,
|
||||
})),
|
||||
ast::Directive::Method(f) => {
|
||||
ast::Directive::Variable(linking, var) => Some(Directive::Variable(
|
||||
linking,
|
||||
ast::Variable {
|
||||
align: var.align,
|
||||
v_type: var.v_type.clone(),
|
||||
state_space: var.state_space,
|
||||
name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
|
||||
array_init: var.array_init,
|
||||
},
|
||||
)),
|
||||
ast::Directive::Method(_, f) => {
|
||||
translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
|
||||
}
|
||||
})
|
||||
|
@ -2576,7 +2570,7 @@ fn insert_implicit_conversions_impl(
|
|||
fn get_function_type(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
spirv_input: impl ExactSizeIterator<Item = SpirvType>,
|
||||
spirv_input: impl Iterator<Item = SpirvType>,
|
||||
spirv_output: &[ast::Variable<spirv::Word>],
|
||||
) -> (spirv::Word, spirv::Word) {
|
||||
map.get_or_add_fn(
|
||||
|
@ -5597,7 +5591,7 @@ impl ast::ArgParams for ExpandedArgParams {
|
|||
impl ArgParamsEx for ExpandedArgParams {}
|
||||
|
||||
enum Directive<'input> {
|
||||
Variable(ast::Variable<spirv::Word>),
|
||||
Variable(ast::LinkingDirective, ast::Variable<spirv::Word>),
|
||||
Method(Function<'input>),
|
||||
}
|
||||
|
||||
|
@ -7582,19 +7576,28 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
|
|||
}
|
||||
|
||||
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
|
||||
fn effective_input_arguments(
|
||||
&self,
|
||||
) -> impl ExactSizeIterator<Item = (spirv::Word, SpirvType)> + '_ {
|
||||
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
|
||||
let is_kernel = self.name.is_kernel();
|
||||
self.input_arguments.iter().map(move |arg| {
|
||||
if !is_kernel && arg.state_space != ast::StateSpace::Reg {
|
||||
let spirv_type =
|
||||
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
|
||||
(arg.name, spirv_type)
|
||||
} else {
|
||||
(arg.name, SpirvType::new(arg.v_type.clone()))
|
||||
}
|
||||
})
|
||||
self.input_arguments
|
||||
.iter()
|
||||
.map(move |arg| {
|
||||
if !is_kernel && arg.state_space != ast::StateSpace::Reg {
|
||||
let spirv_type =
|
||||
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
|
||||
(arg.name, spirv_type)
|
||||
} else {
|
||||
(arg.name, SpirvType::new(arg.v_type.clone()))
|
||||
}
|
||||
})
|
||||
.chain(self.shared_mem.iter().map(|id| {
|
||||
(
|
||||
*id,
|
||||
SpirvType::Pointer(
|
||||
Box::new(SpirvType::Base(SpirvScalarKey::B8)),
|
||||
spirv::StorageClass::Workgroup,
|
||||
),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue