Start implementing .shared unification

This commit is contained in:
Andrzej Janik 2021-09-24 01:31:50 +02:00
parent 9609f86033
commit 370c0bd09e
5 changed files with 305 additions and 119 deletions

View file

@ -1970,6 +1970,9 @@ ArgCall: (Vec<&'input str>, &'input str, Vec<ast::Operand<&'input str>>) = {
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
(ret_params, func, param_list)
},
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> => {
(ret_params, func, Vec::new())
},
<func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
<func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()),
};

View file

@ -221,6 +221,8 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64], [15362u64]);
test_ptx!(func_ptr);
test_ptx!(lanemask_lt);
test_ptx!(extern_func);

View file

@ -0,0 +1,34 @@
.version 6.5
.target sm_30
.address_size 64
.extern .shared .b32 shared_ex[];
.shared .b32 shared_mod[4];
.func (.reg .b64 out) load_from_shared()
{
ld.shared.u64 out, [shared_mod];
ret;
}
.visible .entry shared_unify_extern(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp1;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp1, [in_addr];
st.shared.u64 [shared_ex], temp1;
call (temp2), load_from_shared;
add.u64 temp2, temp2, temp1;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,62 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
OpExecutionMode %2 ContractionOff
OpDecorate %1 Alignment 4
OpDecorate %1 LinkageAttributes "shared_mem" Import
%void = OpTypeVoid
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
%1 = OpVariable %_ptr_Workgroup_uchar Workgroup
%ulong = OpTypeInt 64 0
%35 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%2 = OpFunction %void None %35
%10 = OpFunctionParameter %ulong
%11 = OpFunctionParameter %ulong
%28 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
%8 = OpVariable %_ptr_Function_ulong Function
%9 = OpVariable %_ptr_Function_ulong Function
OpStore %3 %10
OpStore %4 %11
%12 = OpLoad %ulong %3 Aligned 8
OpStore %5 %12
%13 = OpLoad %ulong %4 Aligned 8
OpStore %6 %13
%23 = OpConvertPtrToU %ulong %1
%14 = OpCopyObject %ulong %23
OpStore %7 %14
%16 = OpLoad %ulong %5
%24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
%15 = OpLoad %ulong %24 Aligned 8
OpStore %8 %15
%17 = OpLoad %ulong %7
%18 = OpLoad %ulong %8
%25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
OpStore %25 %18 Aligned 8
%20 = OpLoad %ulong %7
%26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
%19 = OpLoad %ulong %26 Aligned 8
OpStore %9 %19
%21 = OpLoad %ulong %6
%22 = OpLoad %ulong %9
%27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
OpStore %27 %22 Aligned 8
OpReturn
OpFunctionEnd

View file

@ -443,7 +443,8 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives);
//let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
let mut directives =
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@ -607,7 +608,7 @@ fn emit_directives<'input>(
}
}
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
emit_function_linkage(builder, id_defs, f, fn_id);
emit_function_linkage(builder, id_defs, f, fn_id)?;
builder.select_block(None)?;
builder.end_function()?;
}
@ -683,7 +684,7 @@ fn get_kernels_call_map<'input>(
}
fn add_call_map_single<'input>(
directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
directly_called_by: &HashMap<ast::MethodName<'input, spirv::Word>, Vec<spirv::Word>>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
@ -697,15 +698,21 @@ fn add_call_map_single<'input>(
}
}
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
fn multi_hash_map_append<
K: Eq + std::hash::Hash,
V,
Collection: std::iter::Extend<V> + std::default::Default,
>(
m: &mut HashMap<K, Collection>,
key: K,
value: V,
) {
match m.entry(key) {
hash_map::Entry::Occupied(mut entry) => {
entry.get_mut().push(value);
entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
entry.insert(vec![value]);
entry.insert(Default::default());
}
}
}
@ -713,7 +720,8 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
/*
PTX represents dynamically allocated shared local memory as
.extern .shared .b32 shared_mem[];
In SPIRV/OpenCL world this is expressed as an additional argument
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
And in AMD compilation
This pass looks for all uses of .extern .shared and converts them to
an additional method argument
The question is how this artificial argument should be expressed. There are
@ -735,30 +743,35 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
let mut extern_shared_decls = HashMap::new();
let mut globals_shared = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(
linking,
ast::Variable {
v_type: ast::Type::Array(p_type, dims),
state_space: ast::StateSpace::Shared,
name,
v_type,
..
},
) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
extern_shared_decls.insert(*name, *p_type);
) => {
let size = if linking.contains(ast::LinkingDirective::EXTERN) {
GlobalSharedSize::ExternUnsized
} else {
GlobalSharedSize::Sized((*v_type).size_of())
};
globals_shared.insert(*name, (size, v_type.clone()));
}
_ => {}
}
}
if extern_shared_decls.len() == 0 {
if globals_shared.len() == 0 {
return module;
}
let mut methods_using_extern_shared = HashSet::new();
let mut directly_called_by = MultiHashMap::new();
let mut methods_to_globals_shared_direct_only_use = HashMap::<_, GlobalSharedSize>::new();
let module = module
.into_iter()
.map(|directive| match directive {
@ -773,17 +786,21 @@ fn convert_dynamic_shared_memory_usage<'input>(
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
multi_hash_map_append(&mut directly_called_by, call.name, call_key);
Statement::Call(call)
}
statement => statement.map_id(&mut |id, _| {
if extern_shared_decls.contains_key(&id) {
methods_using_extern_shared.insert(call_key);
.map(|statement| {
statement.map_id(&mut |id, _| {
if let Some((size, _)) = globals_shared.get(&id) {
match methods_to_globals_shared_direct_only_use.entry(call_key) {
hash_map::Entry::Occupied(mut e) => {
let original_size = *e.get();
e.insert(original_size.fold(*size));
}
hash_map::Entry::Vacant(mut e) => {
e.insert(*size);
}
}
}
id
}),
})
})
.collect();
Directive::Method(Function {
@ -800,11 +817,15 @@ fn convert_dynamic_shared_memory_usage<'input>(
.collect::<Vec<_>>();
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel`
get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
let (kernels_to_global_shared, functions_to_global_shared) =
resolve_indirect_uses_of_globals_shared(
methods_to_globals_shared_direct_only_use,
kernels_methods_call_map,
);
// now visit every method declaration and inject those additional arguments
module
.into_iter()
.map(|directive| match directive {
let mut result = Vec::with_capacity(module.len());
for directive in module.into_iter() {
match directive {
Directive::Method(Function {
func_decl,
globals,
@ -813,46 +834,119 @@ fn convert_dynamic_shared_memory_usage<'input>(
tuning,
linkage,
}) => {
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
});
}
let shared_id_param = new_id();
{
let mut func_decl = (*func_decl).borrow_mut();
func_decl.shared_mem = Some(shared_id_param);
}
let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
statements,
);
Directive::Method(Function {
let statements = {
let func_decl_ref = &mut (*func_decl).borrow_mut();
let method_name = func_decl_ref.name;
insert_arguments_remap_statements(
method_name,
&kernels_to_global_shared,
new_id,
&mut result,
&functions_to_global_shared,
func_decl_ref,
&globals_shared,
statements,
)
};
result.push(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
})
}));
}
directive => directive,
})
.collect::<Vec<_>>()
directive => result.push(directive),
}
}
result
}
fn insert_arguments_remap_statements(
method_name: ast::MethodName<u32>,
kernels_to_global_shared: &HashMap<&str, GlobalSharedSize>,
new_id: &mut impl FnMut() -> u32,
result: &mut Vec<Directive>,
functions_to_global_shared: &HashSet<u32>,
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<u32>>,
globals_shared: &HashMap<u32, (GlobalSharedSize, ast::Type)>,
statements: Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
) -> Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>> {
let shared_id_param = match method_name {
ast::MethodName::Kernel(kernel_name) => {
let globals_shared_size = match kernels_to_global_shared.get(kernel_name) {
Some(s) => *s,
None => return statements,
};
let shared_id_param = new_id();
let (linkage, type_) = match globals_shared_size {
GlobalSharedSize::ExternUnsized => (
ast::LinkingDirective::EXTERN,
ast::Type::Array(ast::ScalarType::U8, Vec::new()),
),
GlobalSharedSize::Sized(size) => (
ast::LinkingDirective::NONE,
ast::Type::Array(ast::ScalarType::U8, vec![size as u32]),
),
};
result.push(Directive::Variable(
linkage,
ast::Variable {
align: None,
v_type: type_,
state_space: ast::StateSpace::Shared,
name: shared_id_param,
array_init: Vec::new(),
},
));
shared_id_param
}
ast::MethodName::Func(function_name) => {
if !functions_to_global_shared.contains(&function_name) {
return statements;
}
let shared_id_param = new_id();
func_decl_ref.input_arguments.push(ast::Variable {
align: None,
v_type: ast::Type::Pointer(ast::ScalarType::B8, ast::StateSpace::Shared),
state_space: ast::StateSpace::Reg,
name: shared_id_param,
array_init: Vec::new(),
});
shared_id_param
}
};
replace_uses_of_shared_memory(
new_id,
globals_shared,
functions_to_global_shared,
shared_id_param,
statements,
)
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum GlobalSharedSize {
ExternUnsized,
Sized(usize),
}
impl GlobalSharedSize {
fn fold(self, other: GlobalSharedSize) -> GlobalSharedSize {
match (self, other) {
(GlobalSharedSize::Sized(s1), GlobalSharedSize::Sized(s2)) => {
GlobalSharedSize::Sized(usize::max(s1, s2))
}
_ => GlobalSharedSize::ExternUnsized,
}
}
}
fn replace_uses_of_shared_memory<'a>(
new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
extern_shared_decls: &HashMap<spirv::Word, (GlobalSharedSize, ast::Type)>,
methods_using_extern_shared: &HashSet<spirv::Word>,
shared_id_param: spirv::Word,
statements: Vec<ExpandedStatement>,
) -> Vec<ExpandedStatement> {
@ -863,7 +957,7 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
if methods_using_extern_shared.contains(&call.name) {
call.input_arguments.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
@ -874,8 +968,8 @@ fn replace_uses_of_shared_memory<'a>(
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
if let Some(scalar_type) = extern_shared_decls.get(&id) {
if *scalar_type == ast::ScalarType::B8 {
if let Some((_, type_)) = extern_shared_decls.get(&id) {
if *type_ == ast::Type::Scalar(ast::ScalarType::B8) {
return shared_id_param;
}
let replacement_id = new_id();
@ -884,7 +978,7 @@ fn replace_uses_of_shared_memory<'a>(
dst: replacement_id,
from_type: ast::Type::Scalar(ast::ScalarType::B8),
from_space: ast::StateSpace::Shared,
to_type: ast::Type::Scalar(*scalar_type),
to_type: type_.clone(),
to_space: ast::StateSpace::Shared,
kind: ConversionKind::PtrToPtr,
}));
@ -900,43 +994,40 @@ fn replace_uses_of_shared_memory<'a>(
result
}
fn get_callers_of_extern_shared<'a>(
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
) {
let direct_uses_of_extern_shared = methods_using_extern_shared
.iter()
.filter_map(|method| {
if let ast::MethodName::Func(f_id) = method {
Some(*f_id)
} else {
None
}
})
.collect::<Vec<_>>();
for fn_id in direct_uses_of_extern_shared {
get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
}
}
fn get_callers_of_extern_shared_single<'a>(
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
fn_id: spirv::Word,
) {
if let Some(callers) = directly_called_by.get(&fn_id) {
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
*caller_fn,
);
// We need to compute two kinds of information:
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
// * If it's a function -> does it use .shared global (directly or indirectly)
fn resolve_indirect_uses_of_globals_shared<'input>(
mut methods_use_of_globals_shared: HashMap<
ast::MethodName<'input, spirv::Word>,
GlobalSharedSize,
>,
kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
) -> (HashMap<&'input str, GlobalSharedSize>, HashSet<spirv::Word>) {
let mut kernel_use = HashMap::new();
let mut functions_using_global = HashSet::new();
let empty = HashSet::new();
for (method, globals) in methods_use_of_globals_shared.iter() {
match method {
ast::MethodName::Kernel(kernel_name) => {
let mut size = *globals;
for &called_subfunction in
kernels_methods_call_map.get(kernel_name).unwrap_or(&empty)
{
if let Some(new_size) = methods_use_of_globals_shared
.get(&ast::MethodName::Func(called_subfunction))
{
size = size.fold(*new_size);
}
}
kernel_use.insert(*kernel_name, size);
}
ast::MethodName::Func(fn_id) => {
functions_using_global.insert(*fn_id);
}
}
}
(kernel_use, functions_using_global)
}
type DenormCountMap<T> = HashMap<T, isize>;
@ -3480,7 +3571,10 @@ fn emit_variable<'input>(
[dr::Operand::LiteralInt32(align)].iter().cloned(),
);
}
emit_linking_decoration(builder, id_defs, None, var.name, linking);
if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN)
{
emit_linking_decoration(builder, id_defs, None, var.name, linking);
}
Ok(())
}
@ -3494,9 +3588,9 @@ fn emit_linking_decoration<'input>(
if linking == ast::LinkingDirective::NONE {
return;
}
let string_name =
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
if linking.contains(ast::LinkingDirective::VISIBLE) {
let string_name =
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
@ -3508,6 +3602,8 @@ fn emit_linking_decoration<'input>(
.cloned(),
);
} else if linking.contains(ast::LinkingDirective::EXTERN) {
let string_name =
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
@ -4454,7 +4550,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
})
.collect::<HashSet<_>>();
let mut stateful_markers = Vec::new();
let mut stateful_init_reg = MultiHashMap::new();
let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Cvta(
@ -7863,26 +7959,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
self.input_arguments
.iter()
.map(move |arg| {
if !is_kernel && arg.state_space != ast::StateSpace::Reg {
let spirv_type =
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
(arg.name, spirv_type)
} else {
(arg.name, SpirvType::new(arg.v_type.clone()))
}
})
.chain(self.shared_mem.iter().map(|id| {
(
*id,
SpirvType::Pointer(
Box::new(SpirvType::Base(SpirvScalarKey::B8)),
spirv::StorageClass::Workgroup,
),
)
}))
self.input_arguments.iter().map(move |arg| {
if !is_kernel && arg.state_space != ast::StateSpace::Reg {
let spirv_type =
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
(arg.name, spirv_type)
} else {
(arg.name, SpirvType::new(arg.v_type.clone()))
}
})
}
}