mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Emit most of SPIR-V
This commit is contained in:
parent
144f8bd5ed
commit
790fe18579
5 changed files with 3535 additions and 78 deletions
299
ptx/src/pass/convert_dynamic_shared_memory_usage.rs
Normal file
299
ptx/src/pass/convert_dynamic_shared_memory_usage.rs
Normal file
|
@ -0,0 +1,299 @@
|
|||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use super::*;
|
||||
|
||||
/*
|
||||
PTX represents dynamically allocated shared local memory as
|
||||
.extern .shared .b32 shared_mem[];
|
||||
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
|
||||
And in AMD compilation
|
||||
This pass looks for all uses of .extern .shared and converts them to
|
||||
an additional method argument
|
||||
The question is how this artificial argument should be expressed. There are
|
||||
several options:
|
||||
* Straight conversion:
|
||||
.shared .b32 shared_mem[]
|
||||
* Introduce .param_shared statespace:
|
||||
.param_shared .b32 shared_mem
|
||||
or
|
||||
.param_shared .b32 shared_mem[]
|
||||
* Introduce .shared_ptr <SCALAR> type:
|
||||
.param .shared_ptr .b32 shared_mem
|
||||
* Reuse .ptr hint:
|
||||
.param .u64 .ptr shared_mem
|
||||
This is the most tempting, but also the most nonsensical, .ptr is just a
|
||||
hint, which has no semantical meaning (and the output of our
|
||||
transformation has a semantical meaning - we emit additional
|
||||
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
|
||||
*/
|
||||
pub(super) fn run<'input>(
|
||||
module: Vec<Directive<'input>>,
|
||||
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||
new_id: &mut impl FnMut() -> SpirvWord,
|
||||
) -> Result<Vec<Directive<'input>>, TranslateError> {
|
||||
let mut globals_shared = HashMap::new();
|
||||
for dir in module.iter() {
|
||||
match dir {
|
||||
Directive::Variable(
|
||||
_,
|
||||
ast::Variable {
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name,
|
||||
v_type,
|
||||
..
|
||||
},
|
||||
) => {
|
||||
globals_shared.insert(*name, v_type.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if globals_shared.len() == 0 {
|
||||
return Ok(module);
|
||||
}
|
||||
let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
|
||||
let module = module
|
||||
.into_iter()
|
||||
.map(|directive| match directive {
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}) => {
|
||||
let call_key = (*func_decl).borrow().name;
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|statement| {
|
||||
statement.visit_map(
|
||||
&mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
|
||||
if let Some(_) = globals_shared.get(&id) {
|
||||
methods_to_directly_used_shared_globals
|
||||
.entry(call_key)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(id);
|
||||
}
|
||||
Ok::<_, TranslateError>(id)
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok::<_, TranslateError>(Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}))
|
||||
}
|
||||
directive => Ok(directive),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
|
||||
// make sure it gets propagated to `fn1` and `kernel`
|
||||
let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
|
||||
methods_to_directly_used_shared_globals,
|
||||
kernels_methods_call_map,
|
||||
);
|
||||
// now visit every method declaration and inject those additional arguments
|
||||
let mut directives = Vec::with_capacity(module.len());
|
||||
for directive in module.into_iter() {
|
||||
match directive {
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}) => {
|
||||
let statements = {
|
||||
let func_decl_ref = &mut (*func_decl).borrow_mut();
|
||||
let method_name = func_decl_ref.name;
|
||||
insert_arguments_remap_statements(
|
||||
new_id,
|
||||
kernels_methods_call_map,
|
||||
&globals_shared,
|
||||
&methods_to_indirectly_used_shared_globals,
|
||||
method_name,
|
||||
&mut directives,
|
||||
func_decl_ref,
|
||||
statements,
|
||||
)?
|
||||
};
|
||||
directives.push(Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}));
|
||||
}
|
||||
directive => directives.push(directive),
|
||||
}
|
||||
}
|
||||
Ok(directives)
|
||||
}
|
||||
|
||||
// We need to compute two kinds of information:
|
||||
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
|
||||
// * If it's a function -> does it use .shared global (directly or indirectly)
|
||||
fn resolve_indirect_uses_of_globals_shared<'input>(
|
||||
methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
|
||||
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||
) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
|
||||
let mut result = HashMap::new();
|
||||
for (method, callees) in kernels_methods_call_map.methods() {
|
||||
let mut indirect_globals = methods_use_of_globals_shared
|
||||
.get(&method)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.copied()
|
||||
.collect::<BTreeSet<_>>();
|
||||
for &callee in callees {
|
||||
indirect_globals.extend(
|
||||
methods_use_of_globals_shared
|
||||
.get(&ast::MethodName::Func(callee))
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.copied(),
|
||||
);
|
||||
}
|
||||
result.insert(method, indirect_globals);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn insert_arguments_remap_statements<'input>(
|
||||
new_id: &mut impl FnMut() -> SpirvWord,
|
||||
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||
globals_shared: &HashMap<SpirvWord, ast::Type>,
|
||||
methods_to_indirectly_used_shared_globals: &HashMap<
|
||||
ast::MethodName<'input, SpirvWord>,
|
||||
BTreeSet<SpirvWord>,
|
||||
>,
|
||||
method_name: ast::MethodName<SpirvWord>,
|
||||
result: &mut Vec<Directive>,
|
||||
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
|
||||
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let remapped_globals_in_method =
|
||||
if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
|
||||
match method_name {
|
||||
ast::MethodName::Func(..) => {
|
||||
let remapped_globals = method_globals
|
||||
.iter()
|
||||
.map(|global| {
|
||||
(
|
||||
*global,
|
||||
(
|
||||
new_id(),
|
||||
globals_shared
|
||||
.get(&global)
|
||||
.unwrap_or_else(|| todo!())
|
||||
.clone(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
|
||||
func_decl_ref.input_arguments.push(ast::Variable {
|
||||
align: None,
|
||||
v_type: shared_global_type.clone(),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name: *new_shared_global_id,
|
||||
array_init: Vec::new(),
|
||||
});
|
||||
}
|
||||
remapped_globals
|
||||
}
|
||||
ast::MethodName::Kernel(..) => method_globals
|
||||
.iter()
|
||||
.map(|global| {
|
||||
(
|
||||
*global,
|
||||
(
|
||||
*global,
|
||||
globals_shared
|
||||
.get(&global)
|
||||
.unwrap_or_else(|| todo!())
|
||||
.clone(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>(),
|
||||
}
|
||||
} else {
|
||||
return Ok(statements);
|
||||
};
|
||||
replace_uses_of_shared_memory(
|
||||
new_id,
|
||||
methods_to_indirectly_used_shared_globals,
|
||||
statements,
|
||||
remapped_globals_in_method,
|
||||
)
|
||||
}
|
||||
|
||||
fn replace_uses_of_shared_memory<'input>(
|
||||
new_id: &mut impl FnMut() -> SpirvWord,
|
||||
methods_to_indirectly_used_shared_globals: &HashMap<
|
||||
ast::MethodName<'input, SpirvWord>,
|
||||
BTreeSet<SpirvWord>,
|
||||
>,
|
||||
statements: Vec<ExpandedStatement>,
|
||||
remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
mut data,
|
||||
mut arguments,
|
||||
}) => {
|
||||
// We can safely skip checking call arguments,
|
||||
// because there's simply no way to pass shared ptr
|
||||
// without converting it to .b64 first
|
||||
if let Some(shared_globals_used_by_callee) =
|
||||
methods_to_indirectly_used_shared_globals
|
||||
.get(&ast::MethodName::Func(arguments.func))
|
||||
{
|
||||
for &shared_global_used_by_callee in shared_globals_used_by_callee {
|
||||
let (remapped_shared_id, type_) = remapped_globals_in_method
|
||||
.get(&shared_global_used_by_callee)
|
||||
.unwrap_or_else(|| todo!());
|
||||
data.input_arguments
|
||||
.push((type_.clone(), ast::StateSpace::Shared));
|
||||
arguments.input_arguments.push(*remapped_shared_id);
|
||||
}
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}))
|
||||
}
|
||||
statement => {
|
||||
let new_statement =
|
||||
statement.visit_map(&mut |id,
|
||||
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_,
|
||||
_| {
|
||||
Ok::<_, TranslateError>(
|
||||
if let Some((remapped_shared_id, _)) =
|
||||
remapped_globals_in_method.get(&id)
|
||||
{
|
||||
*remapped_shared_id
|
||||
} else {
|
||||
id
|
||||
},
|
||||
)
|
||||
})?;
|
||||
result.push(new_statement);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
|
@ -394,25 +394,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn multi_hash_map_append<
|
||||
K: Eq + std::hash::Hash,
|
||||
V,
|
||||
Collection: std::iter::Extend<V> + std::default::Default,
|
||||
>(
|
||||
m: &mut HashMap<K, Collection>,
|
||||
key: K,
|
||||
value: V,
|
||||
) {
|
||||
match m.entry(key) {
|
||||
hash_map::Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().extend(iter::once(value));
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(Default::default()).extend(iter::once(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_add_ptr_direct(
|
||||
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||
arg: &ast::AddArgs<TypedOperand>,
|
||||
|
|
2767
ptx/src/pass/emit_spirv.rs
Normal file
2767
ptx/src/pass/emit_spirv.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -3,22 +3,27 @@ use rspirv::{binary::Assemble, dr};
|
|||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
collections::{hash_map, HashMap},
|
||||
collections::{hash_map, HashMap, HashSet},
|
||||
ffi::CString,
|
||||
iter,
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
rc::Rc,
|
||||
};
|
||||
use std::hash::Hash;
|
||||
|
||||
mod convert_dynamic_shared_memory_usage;
|
||||
mod convert_to_stateful_memory_access;
|
||||
mod convert_to_typed;
|
||||
mod expand_arguments;
|
||||
mod extract_globals;
|
||||
mod fix_special_registers;
|
||||
mod insert_implicit_conversions;
|
||||
mod insert_mem_ssa_statements;
|
||||
mod normalize_identifiers;
|
||||
mod normalize_predicates;
|
||||
mod insert_implicit_conversions;
|
||||
mod normalize_labels;
|
||||
mod extract_globals;
|
||||
mod normalize_predicates;
|
||||
mod emit_spirv;
|
||||
|
||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
|
@ -34,7 +39,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
|||
translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
/*
|
||||
let directives = hoist_function_globals(directives);
|
||||
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
|
||||
let mut directives = ptx_impl_imports
|
||||
|
@ -43,21 +47,19 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
|||
.chain(directives.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
let mut builder = dr::Builder::new();
|
||||
builder.reserve_ids(id_defs.current_id());
|
||||
builder.reserve_ids(id_defs.current_id().0);
|
||||
let call_map = MethodsCallMap::new(&directives);
|
||||
let mut directives =
|
||||
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
|
||||
convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || {
|
||||
SpirvWord(builder.id())
|
||||
})?;
|
||||
normalize_variable_decls(&mut directives);
|
||||
let denorm_information = compute_denorm_information(&directives);
|
||||
emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives);
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||
builder.set_version(1, 3);
|
||||
emit_capabilities(&mut builder);
|
||||
emit_extensions(&mut builder);
|
||||
let opencl_id = emit_opencl_import(&mut builder);
|
||||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
//emit_builtins(&mut builder, &mut map, &id_defs);
|
||||
let mut kernel_info = HashMap::new();
|
||||
|
||||
todo!()
|
||||
/*
|
||||
let (build_options, should_flush_denorms) =
|
||||
emit_denorm_build_string(&call_map, &denorm_information);
|
||||
let (directives, globals_use_map) = get_globals_use_map(directives);
|
||||
|
@ -84,7 +86,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
|||
build_options,
|
||||
})
|
||||
*/
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn translate_directive<'input, 'a>(
|
||||
|
@ -1273,3 +1274,399 @@ fn fn_arguments_to_variables<'a>(
|
|||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn hoist_function_globals(directives: Vec<Directive>) -> Vec<Directive> {
|
||||
let mut result = Vec::with_capacity(directives.len());
|
||||
for directive in directives {
|
||||
match directive {
|
||||
Directive::Method(method) => {
|
||||
for variable in method.globals {
|
||||
result.push(Directive::Variable(ast::LinkingDirective::NONE, variable));
|
||||
}
|
||||
result.push(Directive::Method(Function {
|
||||
globals: Vec::new(),
|
||||
..method
|
||||
}))
|
||||
}
|
||||
_ => result.push(directive),
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
struct MethodsCallMap<'input> {
|
||||
map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
|
||||
}
|
||||
|
||||
impl<'input> MethodsCallMap<'input> {
|
||||
fn new(module: &[Directive<'input>]) -> Self {
|
||||
let mut directly_called_by = HashMap::new();
|
||||
for directive in module {
|
||||
match directive {
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
body: Some(statements),
|
||||
..
|
||||
}) => {
|
||||
let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
|
||||
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
|
||||
entry.insert(Vec::new());
|
||||
}
|
||||
for statement in statements {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Call { data, arguments }) => {
|
||||
multi_hash_map_append(
|
||||
&mut directly_called_by,
|
||||
call_key,
|
||||
arguments.func,
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let mut result = HashMap::new();
|
||||
for (&method_key, children) in directly_called_by.iter() {
|
||||
let mut visited = HashSet::new();
|
||||
for child in children {
|
||||
Self::add_call_map_single(&directly_called_by, &mut visited, *child);
|
||||
}
|
||||
result.insert(method_key, visited);
|
||||
}
|
||||
MethodsCallMap { map: result }
|
||||
}
|
||||
|
||||
fn add_call_map_single(
|
||||
directly_called_by: &HashMap<ast::MethodName<'input, SpirvWord>, Vec<SpirvWord>>,
|
||||
visited: &mut HashSet<SpirvWord>,
|
||||
current: SpirvWord,
|
||||
) {
|
||||
if !visited.insert(current) {
|
||||
return;
|
||||
}
|
||||
if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
|
||||
for child in children {
|
||||
Self::add_call_map_single(directly_called_by, visited, *child);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_kernel_children(&self, name: &'input str) -> impl Iterator<Item = &SpirvWord> {
|
||||
self.map
|
||||
.get(&ast::MethodName::Kernel(name))
|
||||
.into_iter()
|
||||
.flatten()
|
||||
}
|
||||
|
||||
fn kernels(&self) -> impl Iterator<Item = (&'input str, &HashSet<SpirvWord>)> {
|
||||
self.map
|
||||
.iter()
|
||||
.filter_map(|(method, children)| match method {
|
||||
ast::MethodName::Kernel(kernel) => Some((*kernel, children)),
|
||||
ast::MethodName::Func(..) => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn methods(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (ast::MethodName<'input, SpirvWord>, &HashSet<SpirvWord>)> {
|
||||
self.map
|
||||
.iter()
|
||||
.map(|(method, children)| (*method, children))
|
||||
}
|
||||
|
||||
fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) {
|
||||
self.map
|
||||
.get(&method)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.copied()
|
||||
.for_each(f);
|
||||
}
|
||||
}
|
||||
|
||||
fn multi_hash_map_append<
|
||||
K: Eq + std::hash::Hash,
|
||||
V,
|
||||
Collection: std::iter::Extend<V> + std::default::Default,
|
||||
>(
|
||||
m: &mut HashMap<K, Collection>,
|
||||
key: K,
|
||||
value: V,
|
||||
) {
|
||||
match m.entry(key) {
|
||||
hash_map::Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().extend(iter::once(value));
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(Default::default()).extend(iter::once(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
|
||||
for directive in directives {
|
||||
match directive {
|
||||
Directive::Method(Function {
|
||||
body: Some(func), ..
|
||||
}) => {
|
||||
func[1..].sort_by_key(|s| match s {
|
||||
Statement::Variable(_) => 0,
|
||||
_ => 1,
|
||||
});
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HACK ALERT!
|
||||
// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
|
||||
// in the kernel as flushing denorms to zero or preserving them
|
||||
// PTX support per-instruction ftz information. Unfortunately SPIR-V has no
|
||||
// such capability, so instead we guesstimate which use is more common in the kernel
|
||||
// and emit suitable execution mode
|
||||
fn compute_denorm_information<'input>(
|
||||
module: &[Directive<'input>],
|
||||
) -> HashMap<ast::MethodName<'input, SpirvWord>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
|
||||
let mut denorm_methods = HashMap::new();
|
||||
for directive in module {
|
||||
match directive {
|
||||
Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
body: Some(statements),
|
||||
..
|
||||
}) => {
|
||||
let mut flush_counter = DenormCountMap::new();
|
||||
let method_key = (**func_decl).borrow().name;
|
||||
for statement in statements {
|
||||
match statement {
|
||||
Statement::Instruction(inst) => {
|
||||
if let Some((flush, width)) = flush_to_zero(inst) {
|
||||
denorm_count_map_update(&mut flush_counter, width, flush);
|
||||
}
|
||||
}
|
||||
Statement::LoadVar(..) => {}
|
||||
Statement::StoreVar(..) => {}
|
||||
Statement::Conditional(_) => {}
|
||||
Statement::Conversion(_) => {}
|
||||
Statement::Constant(_) => {}
|
||||
Statement::RetValue(_, _) => {}
|
||||
Statement::Label(_) => {}
|
||||
Statement::Variable(_) => {}
|
||||
Statement::PtrAccess { .. } => {}
|
||||
Statement::RepackVector(_) => {}
|
||||
Statement::FunctionPointer(_) => {}
|
||||
}
|
||||
}
|
||||
denorm_methods.insert(method_key, flush_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
denorm_methods
|
||||
.into_iter()
|
||||
.map(|(name, v)| {
|
||||
let width_to_denorm = v
|
||||
.into_iter()
|
||||
.map(|(k, flush_over_preserve)| {
|
||||
let mode = if flush_over_preserve > 0 {
|
||||
spirv::FPDenormMode::FlushToZero
|
||||
} else {
|
||||
spirv::FPDenormMode::Preserve
|
||||
};
|
||||
(k, (mode, flush_over_preserve))
|
||||
})
|
||||
.collect();
|
||||
(name, width_to_denorm)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn flush_to_zero(this: &ast::Instruction<SpirvWord>) -> Option<(bool, u8)> {
|
||||
match this {
|
||||
ast::Instruction::Ld { .. } => None,
|
||||
ast::Instruction::St { .. } => None,
|
||||
ast::Instruction::Mov { .. } => None,
|
||||
ast::Instruction::Not { .. } => None,
|
||||
ast::Instruction::Bra { .. } => None,
|
||||
ast::Instruction::Shl { .. } => None,
|
||||
ast::Instruction::Shr { .. } => None,
|
||||
ast::Instruction::Ret { .. } => None,
|
||||
ast::Instruction::Call { .. } => None,
|
||||
ast::Instruction::Or { .. } => None,
|
||||
ast::Instruction::And { .. } => None,
|
||||
ast::Instruction::Cvta { .. } => None,
|
||||
ast::Instruction::Selp { .. } => None,
|
||||
ast::Instruction::Bar { .. } => None,
|
||||
ast::Instruction::Atom { .. } => None,
|
||||
ast::Instruction::AtomCas { .. } => None,
|
||||
ast::Instruction::Sub {
|
||||
data: ast::ArithDetails::Integer(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Integer(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Integer { .. },
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Mad {
|
||||
data: ast::MadDetails::Integer { .. },
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Min {
|
||||
data: ast::MinMaxDetails::Signed(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Min {
|
||||
data: ast::MinMaxDetails::Unsigned(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Max {
|
||||
data: ast::MinMaxDetails::Signed(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Max {
|
||||
data: ast::MinMaxDetails::Unsigned(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
mode:
|
||||
ast::CvtMode::ZeroExtend
|
||||
| ast::CvtMode::SignExtend
|
||||
| ast::CvtMode::Truncate
|
||||
| ast::CvtMode::Bitcast
|
||||
| ast::CvtMode::SaturateUnsignedToSigned
|
||||
| ast::CvtMode::SaturateSignedToUnsigned
|
||||
| ast::CvtMode::FPFromSigned(_)
|
||||
| ast::CvtMode::FPFromUnsigned(_),
|
||||
..
|
||||
},
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Div {
|
||||
data: ast::DivDetails::Unsigned(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Div {
|
||||
data: ast::DivDetails::Signed(_),
|
||||
..
|
||||
} => None,
|
||||
ast::Instruction::Clz { .. } => None,
|
||||
ast::Instruction::Brev { .. } => None,
|
||||
ast::Instruction::Popc { .. } => None,
|
||||
ast::Instruction::Xor { .. } => None,
|
||||
ast::Instruction::Bfe { .. } => None,
|
||||
ast::Instruction::Bfi { .. } => None,
|
||||
ast::Instruction::Rem { .. } => None,
|
||||
ast::Instruction::Prmt { .. } => None,
|
||||
ast::Instruction::Activemask { .. } => None,
|
||||
ast::Instruction::Membar { .. } => None,
|
||||
ast::Instruction::Sub {
|
||||
data: ast::ArithDetails::Float(float_control),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Float(float_control),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float(float_control),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mad {
|
||||
data: ast::MadDetails::Float(float_control),
|
||||
..
|
||||
} => float_control
|
||||
.flush_to_zero
|
||||
.map(|ftz| (ftz, float_control.type_.size_of())),
|
||||
ast::Instruction::Fma { data, .. } => data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())),
|
||||
ast::Instruction::Setp { data, .. } => {
|
||||
data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
|
||||
}
|
||||
ast::Instruction::SetpBool { data, .. } => data
|
||||
.base
|
||||
.flush_to_zero
|
||||
.map(|ftz| (ftz, data.base.type_.size_of())),
|
||||
ast::Instruction::Abs { data, .. }
|
||||
| ast::Instruction::Rsqrt { data, .. }
|
||||
| ast::Instruction::Neg { data, .. }
|
||||
| ast::Instruction::Ex2 { data, .. } => {
|
||||
data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
|
||||
}
|
||||
ast::Instruction::Min {
|
||||
data: ast::MinMaxDetails::Float(float_control),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Max {
|
||||
data: ast::MinMaxDetails::Float(float_control),
|
||||
..
|
||||
} => float_control
|
||||
.flush_to_zero
|
||||
.map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())),
|
||||
ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => {
|
||||
data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
|
||||
}
|
||||
// Modifier .ftz can only be specified when either .dtype or .atype
|
||||
// is .f32 and applies only to single precision (.f32) inputs and results.
|
||||
ast::Instruction::Cvt {
|
||||
data:
|
||||
ast::CvtDetails {
|
||||
mode:
|
||||
ast::CvtMode::FPExtend { flush_to_zero }
|
||||
| ast::CvtMode::FPTruncate { flush_to_zero, .. }
|
||||
| ast::CvtMode::FPRound { flush_to_zero, .. }
|
||||
| ast::CvtMode::SignedFromFP { flush_to_zero, .. }
|
||||
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
} => flush_to_zero.map(|ftz| (ftz, 4)),
|
||||
ast::Instruction::Div {
|
||||
data:
|
||||
ast::DivDetails::Float(ast::DivFloatDetails {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
}),
|
||||
..
|
||||
} => flush_to_zero.map(|ftz| (ftz, type_.size_of())),
|
||||
ast::Instruction::Sin { data, .. }
|
||||
| ast::Instruction::Cos { data, .. }
|
||||
| ast::Instruction::Lg2 { data, .. } => {
|
||||
Some((data.flush_to_zero, mem::size_of::<f32>() as u8))
|
||||
}
|
||||
ptx_parser::Instruction::PrmtSlow { .. } => None,
|
||||
ptx_parser::Instruction::Trap {} => None,
|
||||
}
|
||||
}
|
||||
|
||||
type DenormCountMap<T> = HashMap<T, isize>;
|
||||
|
||||
fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
|
||||
let num_value = if value { 1 } else { -1 };
|
||||
denorm_count_map_update_impl(map, key, num_value);
|
||||
}
|
||||
|
||||
fn denorm_count_map_update_impl<T: Eq + Hash>(
|
||||
map: &mut DenormCountMap<T>,
|
||||
key: T,
|
||||
num_value: isize,
|
||||
) {
|
||||
match map.entry(key) {
|
||||
hash_map::Entry::Occupied(mut counter) => {
|
||||
*(counter.get_mut()) += num_value;
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(num_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -514,8 +514,11 @@ pub trait Visitor<T: Operand, Err> {
|
|||
) -> Result<(), Err>;
|
||||
}
|
||||
|
||||
impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>>
|
||||
Visitor<T, Err> for Fn
|
||||
impl<
|
||||
T: Operand,
|
||||
Err,
|
||||
Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>,
|
||||
> Visitor<T, Err> for Fn
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
|
@ -760,7 +763,7 @@ pub enum Type {
|
|||
Vector(ScalarType, u8),
|
||||
// .param.b32 foo[4];
|
||||
Array(ScalarType, Vec<u32>),
|
||||
Pointer(ScalarType, StateSpace)
|
||||
Pointer(ScalarType, StateSpace),
|
||||
}
|
||||
|
||||
impl Type {
|
||||
|
@ -1097,7 +1100,7 @@ impl SetpData {
|
|||
let cmp_op = if type_kind == ScalarKind::Float {
|
||||
SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
|
||||
} else {
|
||||
match SetpCompareInt::try_from(cmp_op) {
|
||||
match SetpCompareInt::try_from((cmp_op, type_kind)) {
|
||||
Ok(op) => SetpCompareOp::Integer(op),
|
||||
Err(err) => {
|
||||
state.errors.push(err);
|
||||
|
@ -1129,10 +1132,14 @@ pub enum SetpCompareOp {
|
|||
pub enum SetpCompareInt {
|
||||
Eq,
|
||||
NotEq,
|
||||
Less,
|
||||
LessOrEq,
|
||||
Greater,
|
||||
GreaterOrEq,
|
||||
UnsignedLess,
|
||||
UnsignedLessOrEq,
|
||||
UnsignedGreater,
|
||||
UnsignedGreaterOrEq,
|
||||
SignedLess,
|
||||
SignedLessOrEq,
|
||||
SignedGreater,
|
||||
SignedGreaterOrEq,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
|
@ -1153,29 +1160,41 @@ pub enum SetpCompareFloat {
|
|||
IsAnyNan,
|
||||
}
|
||||
|
||||
impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
|
||||
impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt {
|
||||
type Error = PtxError;
|
||||
|
||||
fn try_from(value: RawSetpCompareOp) -> Result<Self, PtxError> {
|
||||
match value {
|
||||
RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
|
||||
RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
|
||||
RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
|
||||
RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
|
||||
RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
|
||||
RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
|
||||
RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
|
||||
RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
|
||||
RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
|
||||
RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
|
||||
RawSetpCompareOp::Equ => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Neu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Leu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Geu => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Num => Err(PtxError::WrongType),
|
||||
RawSetpCompareOp::Nan => Err(PtxError::WrongType),
|
||||
fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result<Self, PtxError> {
|
||||
match (value, kind) {
|
||||
(RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq),
|
||||
(RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq),
|
||||
(RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => {
|
||||
Ok(SetpCompareInt::SignedLess)
|
||||
}
|
||||
(RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess),
|
||||
(RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => {
|
||||
Ok(SetpCompareInt::SignedLessOrEq)
|
||||
}
|
||||
(RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => {
|
||||
Ok(SetpCompareInt::UnsignedLessOrEq)
|
||||
}
|
||||
(RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => {
|
||||
Ok(SetpCompareInt::SignedGreater)
|
||||
}
|
||||
(RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater),
|
||||
(RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => {
|
||||
Ok(SetpCompareInt::SignedGreaterOrEq)
|
||||
}
|
||||
(RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => {
|
||||
Ok(SetpCompareInt::UnsignedGreaterOrEq)
|
||||
}
|
||||
(RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Num, _) => Err(PtxError::WrongType),
|
||||
(RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1276,7 +1295,9 @@ impl<T: Operand> CallArgs<T> {
|
|||
.return_arguments
|
||||
.into_iter()
|
||||
.zip(details.return_arguments.iter())
|
||||
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true, false))
|
||||
.map(|(param, (type_, space))| {
|
||||
visitor.visit_ident(param, Some((type_, *space)), true, false)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let func = visitor.visit_ident(self.func, None, false, false)?;
|
||||
let input_arguments = self
|
||||
|
@ -1305,6 +1326,8 @@ pub enum CvtMode {
|
|||
SignExtend,
|
||||
Truncate,
|
||||
Bitcast,
|
||||
SaturateUnsignedToSigned,
|
||||
SaturateSignedToUnsigned,
|
||||
// float from float
|
||||
FPExtend {
|
||||
flush_to_zero: Option<bool>,
|
||||
|
@ -1389,21 +1412,11 @@ impl CvtDetails {
|
|||
},
|
||||
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
|
||||
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
|
||||
(
|
||||
ScalarKind::Unsigned | ScalarKind::Signed,
|
||||
ScalarKind::Unsigned | ScalarKind::Signed,
|
||||
) => match dst.size_of().cmp(&src.size_of()) {
|
||||
Ordering::Less => {
|
||||
if dst.kind() != src.kind() {
|
||||
errors.push(PtxError::Todo);
|
||||
}
|
||||
CvtMode::Truncate
|
||||
}
|
||||
(ScalarKind::Unsigned, ScalarKind::Unsigned)
|
||||
| (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
|
||||
Ordering::Less => CvtMode::Truncate,
|
||||
Ordering::Equal => CvtMode::Bitcast,
|
||||
Ordering::Greater => {
|
||||
if dst.kind() != src.kind() {
|
||||
errors.push(PtxError::Todo);
|
||||
}
|
||||
if src.kind() == ScalarKind::Signed {
|
||||
CvtMode::SignExtend
|
||||
} else {
|
||||
|
|
Loading…
Add table
Reference in a new issue