Emit most of SPIR-V

This commit is contained in:
Andrzej Janik 2024-08-30 03:12:33 +02:00
parent 144f8bd5ed
commit 790fe18579
5 changed files with 3535 additions and 78 deletions

View file

@ -0,0 +1,299 @@
use std::collections::{BTreeMap, BTreeSet};
use super::*;
/*
PTX represents dynamically allocated shared local memory as
.extern .shared .b32 shared_mem[];
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
And in AMD compilation
This pass looks for all uses of .extern .shared and converts them to
an additional method argument
The question is how this artificial argument should be expressed. There are
several options:
* Straight conversion:
.shared .b32 shared_mem[]
* Introduce .param_shared statespace:
.param_shared .b32 shared_mem
or
.param_shared .b32 shared_mem[]
* Introduce .shared_ptr <SCALAR> type:
.param .shared_ptr .b32 shared_mem
* Reuse .ptr hint:
.param .u64 .ptr shared_mem
This is the most tempting, but also the most nonsensical, .ptr is just a
hint, which has no semantical meaning (and the output of our
transformation has a semantical meaning - we emit additional
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
*/
pub(super) fn run<'input>(
module: Vec<Directive<'input>>,
kernels_methods_call_map: &MethodsCallMap<'input>,
new_id: &mut impl FnMut() -> SpirvWord,
) -> Result<Vec<Directive<'input>>, TranslateError> {
let mut globals_shared = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(
_,
ast::Variable {
state_space: ast::StateSpace::Shared,
name,
v_type,
..
},
) => {
globals_shared.insert(*name, v_type.clone());
}
_ => {}
}
}
if globals_shared.len() == 0 {
return Ok(module);
}
let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
let module = module
.into_iter()
.map(|directive| match directive {
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}) => {
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| {
statement.visit_map(
&mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
if let Some(_) = globals_shared.get(&id) {
methods_to_directly_used_shared_globals
.entry(call_key)
.or_insert_with(HashSet::new)
.insert(id);
}
Ok::<_, TranslateError>(id)
},
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok::<_, TranslateError>(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}))
}
directive => Ok(directive),
})
.collect::<Result<Vec<_>, _>>()?;
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel`
let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
methods_to_directly_used_shared_globals,
kernels_methods_call_map,
);
// now visit every method declaration and inject those additional arguments
let mut directives = Vec::with_capacity(module.len());
for directive in module.into_iter() {
match directive {
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}) => {
let statements = {
let func_decl_ref = &mut (*func_decl).borrow_mut();
let method_name = func_decl_ref.name;
insert_arguments_remap_statements(
new_id,
kernels_methods_call_map,
&globals_shared,
&methods_to_indirectly_used_shared_globals,
method_name,
&mut directives,
func_decl_ref,
statements,
)?
};
directives.push(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}));
}
directive => directives.push(directive),
}
}
Ok(directives)
}
// We need to compute two kinds of information:
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
// * If it's a function -> does it use .shared global (directly or indirectly)
fn resolve_indirect_uses_of_globals_shared<'input>(
methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
kernels_methods_call_map: &MethodsCallMap<'input>,
) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
let mut result = HashMap::new();
for (method, callees) in kernels_methods_call_map.methods() {
let mut indirect_globals = methods_use_of_globals_shared
.get(&method)
.into_iter()
.flatten()
.copied()
.collect::<BTreeSet<_>>();
for &callee in callees {
indirect_globals.extend(
methods_use_of_globals_shared
.get(&ast::MethodName::Func(callee))
.into_iter()
.flatten()
.copied(),
);
}
result.insert(method, indirect_globals);
}
result
}
fn insert_arguments_remap_statements<'input>(
new_id: &mut impl FnMut() -> SpirvWord,
kernels_methods_call_map: &MethodsCallMap<'input>,
globals_shared: &HashMap<SpirvWord, ast::Type>,
methods_to_indirectly_used_shared_globals: &HashMap<
ast::MethodName<'input, SpirvWord>,
BTreeSet<SpirvWord>,
>,
method_name: ast::MethodName<SpirvWord>,
result: &mut Vec<Directive>,
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let remapped_globals_in_method =
if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
match method_name {
ast::MethodName::Func(..) => {
let remapped_globals = method_globals
.iter()
.map(|global| {
(
*global,
(
new_id(),
globals_shared
.get(&global)
.unwrap_or_else(|| todo!())
.clone(),
),
)
})
.collect::<BTreeMap<_, _>>();
for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
func_decl_ref.input_arguments.push(ast::Variable {
align: None,
v_type: shared_global_type.clone(),
state_space: ast::StateSpace::Shared,
name: *new_shared_global_id,
array_init: Vec::new(),
});
}
remapped_globals
}
ast::MethodName::Kernel(..) => method_globals
.iter()
.map(|global| {
(
*global,
(
*global,
globals_shared
.get(&global)
.unwrap_or_else(|| todo!())
.clone(),
),
)
})
.collect::<BTreeMap<_, _>>(),
}
} else {
return Ok(statements);
};
replace_uses_of_shared_memory(
new_id,
methods_to_indirectly_used_shared_globals,
statements,
remapped_globals_in_method,
)
}
fn replace_uses_of_shared_memory<'input>(
new_id: &mut impl FnMut() -> SpirvWord,
methods_to_indirectly_used_shared_globals: &HashMap<
ast::MethodName<'input, SpirvWord>,
BTreeSet<SpirvWord>,
>,
statements: Vec<ExpandedStatement>,
remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
match statement {
Statement::Instruction(ast::Instruction::Call {
mut data,
mut arguments,
}) => {
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
if let Some(shared_globals_used_by_callee) =
methods_to_indirectly_used_shared_globals
.get(&ast::MethodName::Func(arguments.func))
{
for &shared_global_used_by_callee in shared_globals_used_by_callee {
let (remapped_shared_id, type_) = remapped_globals_in_method
.get(&shared_global_used_by_callee)
.unwrap_or_else(|| todo!());
data.input_arguments
.push((type_.clone(), ast::StateSpace::Shared));
arguments.input_arguments.push(*remapped_shared_id);
}
}
result.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}))
}
statement => {
let new_statement =
statement.visit_map(&mut |id,
_: Option<(&ast::Type, ast::StateSpace)>,
_,
_| {
Ok::<_, TranslateError>(
if let Some((remapped_shared_id, _)) =
remapped_globals_in_method.get(&id)
{
*remapped_shared_id
} else {
id
},
)
})?;
result.push(new_statement);
}
}
}
Ok(result)
}

View file

@ -394,25 +394,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
}
}
fn multi_hash_map_append<
K: Eq + std::hash::Hash,
V,
Collection: std::iter::Extend<V> + std::default::Default,
>(
m: &mut HashMap<K, Collection>,
key: K,
value: V,
) {
match m.entry(key) {
hash_map::Entry::Occupied(mut entry) => {
entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
entry.insert(Default::default()).extend(iter::once(value));
}
}
}
fn is_add_ptr_direct(
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
arg: &ast::AddArgs<TypedOperand>,

2767
ptx/src/pass/emit_spirv.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -3,22 +3,27 @@ use rspirv::{binary::Assemble, dr};
use std::{
borrow::Cow,
cell::RefCell,
collections::{hash_map, HashMap},
collections::{hash_map, HashMap, HashSet},
ffi::CString,
iter,
marker::PhantomData,
mem,
rc::Rc,
};
use std::hash::Hash;
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
mod expand_arguments;
mod extract_globals;
mod fix_special_registers;
mod insert_implicit_conversions;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
mod normalize_predicates;
mod insert_implicit_conversions;
mod normalize_labels;
mod extract_globals;
mod normalize_predicates;
mod emit_spirv;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@ -34,7 +39,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
})
.collect::<Result<Vec<_>, _>>()?;
/*
let directives = hoist_function_globals(directives);
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
let mut directives = ptx_impl_imports
@ -43,21 +47,19 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
.chain(directives.into_iter())
.collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
builder.reserve_ids(id_defs.current_id().0);
let call_map = MethodsCallMap::new(&directives);
let mut directives =
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || {
SpirvWord(builder.id())
})?;
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3);
emit_capabilities(&mut builder);
emit_extensions(&mut builder);
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
//emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
todo!()
/*
let (build_options, should_flush_denorms) =
emit_denorm_build_string(&call_map, &denorm_information);
let (directives, globals_use_map) = get_globals_use_map(directives);
@ -84,7 +86,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
build_options,
})
*/
todo!()
}
fn translate_directive<'input, 'a>(
@ -1273,3 +1274,399 @@ fn fn_arguments_to_variables<'a>(
})
.collect::<Vec<_>>()
}
fn hoist_function_globals(directives: Vec<Directive>) -> Vec<Directive> {
let mut result = Vec::with_capacity(directives.len());
for directive in directives {
match directive {
Directive::Method(method) => {
for variable in method.globals {
result.push(Directive::Variable(ast::LinkingDirective::NONE, variable));
}
result.push(Directive::Method(Function {
globals: Vec::new(),
..method
}))
}
_ => result.push(directive),
}
}
result
}
struct MethodsCallMap<'input> {
map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
}
impl<'input> MethodsCallMap<'input> {
fn new(module: &[Directive<'input>]) -> Self {
let mut directly_called_by = HashMap::new();
for directive in module {
match directive {
Directive::Method(Function {
func_decl,
body: Some(statements),
..
}) => {
let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
for statement in statements {
match statement {
Statement::Instruction(ast::Instruction::Call { data, arguments }) => {
multi_hash_map_append(
&mut directly_called_by,
call_key,
arguments.func,
);
}
_ => {}
}
}
}
_ => {}
}
}
let mut result = HashMap::new();
for (&method_key, children) in directly_called_by.iter() {
let mut visited = HashSet::new();
for child in children {
Self::add_call_map_single(&directly_called_by, &mut visited, *child);
}
result.insert(method_key, visited);
}
MethodsCallMap { map: result }
}
fn add_call_map_single(
directly_called_by: &HashMap<ast::MethodName<'input, SpirvWord>, Vec<SpirvWord>>,
visited: &mut HashSet<SpirvWord>,
current: SpirvWord,
) {
if !visited.insert(current) {
return;
}
if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
for child in children {
Self::add_call_map_single(directly_called_by, visited, *child);
}
}
}
fn get_kernel_children(&self, name: &'input str) -> impl Iterator<Item = &SpirvWord> {
self.map
.get(&ast::MethodName::Kernel(name))
.into_iter()
.flatten()
}
fn kernels(&self) -> impl Iterator<Item = (&'input str, &HashSet<SpirvWord>)> {
self.map
.iter()
.filter_map(|(method, children)| match method {
ast::MethodName::Kernel(kernel) => Some((*kernel, children)),
ast::MethodName::Func(..) => None,
})
}
fn methods(
&self,
) -> impl Iterator<Item = (ast::MethodName<'input, SpirvWord>, &HashSet<SpirvWord>)> {
self.map
.iter()
.map(|(method, children)| (*method, children))
}
fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) {
self.map
.get(&method)
.into_iter()
.flatten()
.copied()
.for_each(f);
}
}
fn multi_hash_map_append<
K: Eq + std::hash::Hash,
V,
Collection: std::iter::Extend<V> + std::default::Default,
>(
m: &mut HashMap<K, Collection>,
key: K,
value: V,
) {
match m.entry(key) {
hash_map::Entry::Occupied(mut entry) => {
entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
entry.insert(Default::default()).extend(iter::once(value));
}
}
}
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
for directive in directives {
match directive {
Directive::Method(Function {
body: Some(func), ..
}) => {
func[1..].sort_by_key(|s| match s {
Statement::Variable(_) => 0,
_ => 1,
});
}
_ => (),
}
}
}
// HACK ALERT!
// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
// in the kernel as flushing denorms to zero or preserving them
// PTX support per-instruction ftz information. Unfortunately SPIR-V has no
// such capability, so instead we guesstimate which use is more common in the kernel
// and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
) -> HashMap<ast::MethodName<'input, SpirvWord>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
func_decl,
body: Some(statements),
..
}) => {
let mut flush_counter = DenormCountMap::new();
let method_key = (**func_decl).borrow().name;
for statement in statements {
match statement {
Statement::Instruction(inst) => {
if let Some((flush, width)) = flush_to_zero(inst) {
denorm_count_map_update(&mut flush_counter, width, flush);
}
}
Statement::LoadVar(..) => {}
Statement::StoreVar(..) => {}
Statement::Conditional(_) => {}
Statement::Conversion(_) => {}
Statement::Constant(_) => {}
Statement::RetValue(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
}
denorm_methods.insert(method_key, flush_counter);
}
}
}
denorm_methods
.into_iter()
.map(|(name, v)| {
let width_to_denorm = v
.into_iter()
.map(|(k, flush_over_preserve)| {
let mode = if flush_over_preserve > 0 {
spirv::FPDenormMode::FlushToZero
} else {
spirv::FPDenormMode::Preserve
};
(k, (mode, flush_over_preserve))
})
.collect();
(name, width_to_denorm)
})
.collect()
}
fn flush_to_zero(this: &ast::Instruction<SpirvWord>) -> Option<(bool, u8)> {
match this {
ast::Instruction::Ld { .. } => None,
ast::Instruction::St { .. } => None,
ast::Instruction::Mov { .. } => None,
ast::Instruction::Not { .. } => None,
ast::Instruction::Bra { .. } => None,
ast::Instruction::Shl { .. } => None,
ast::Instruction::Shr { .. } => None,
ast::Instruction::Ret { .. } => None,
ast::Instruction::Call { .. } => None,
ast::Instruction::Or { .. } => None,
ast::Instruction::And { .. } => None,
ast::Instruction::Cvta { .. } => None,
ast::Instruction::Selp { .. } => None,
ast::Instruction::Bar { .. } => None,
ast::Instruction::Atom { .. } => None,
ast::Instruction::AtomCas { .. } => None,
ast::Instruction::Sub {
data: ast::ArithDetails::Integer(_),
..
} => None,
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..
} => None,
ast::Instruction::Mul {
data: ast::MulDetails::Integer { .. },
..
} => None,
ast::Instruction::Mad {
data: ast::MadDetails::Integer { .. },
..
} => None,
ast::Instruction::Min {
data: ast::MinMaxDetails::Signed(_),
..
} => None,
ast::Instruction::Min {
data: ast::MinMaxDetails::Unsigned(_),
..
} => None,
ast::Instruction::Max {
data: ast::MinMaxDetails::Signed(_),
..
} => None,
ast::Instruction::Max {
data: ast::MinMaxDetails::Unsigned(_),
..
} => None,
ast::Instruction::Cvt {
data:
ast::CvtDetails {
mode:
ast::CvtMode::ZeroExtend
| ast::CvtMode::SignExtend
| ast::CvtMode::Truncate
| ast::CvtMode::Bitcast
| ast::CvtMode::SaturateUnsignedToSigned
| ast::CvtMode::SaturateSignedToUnsigned
| ast::CvtMode::FPFromSigned(_)
| ast::CvtMode::FPFromUnsigned(_),
..
},
..
} => None,
ast::Instruction::Div {
data: ast::DivDetails::Unsigned(_),
..
} => None,
ast::Instruction::Div {
data: ast::DivDetails::Signed(_),
..
} => None,
ast::Instruction::Clz { .. } => None,
ast::Instruction::Brev { .. } => None,
ast::Instruction::Popc { .. } => None,
ast::Instruction::Xor { .. } => None,
ast::Instruction::Bfe { .. } => None,
ast::Instruction::Bfi { .. } => None,
ast::Instruction::Rem { .. } => None,
ast::Instruction::Prmt { .. } => None,
ast::Instruction::Activemask { .. } => None,
ast::Instruction::Membar { .. } => None,
ast::Instruction::Sub {
data: ast::ArithDetails::Float(float_control),
..
}
| ast::Instruction::Add {
data: ast::ArithDetails::Float(float_control),
..
}
| ast::Instruction::Mul {
data: ast::MulDetails::Float(float_control),
..
}
| ast::Instruction::Mad {
data: ast::MadDetails::Float(float_control),
..
} => float_control
.flush_to_zero
.map(|ftz| (ftz, float_control.type_.size_of())),
ast::Instruction::Fma { data, .. } => data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())),
ast::Instruction::Setp { data, .. } => {
data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
}
ast::Instruction::SetpBool { data, .. } => data
.base
.flush_to_zero
.map(|ftz| (ftz, data.base.type_.size_of())),
ast::Instruction::Abs { data, .. }
| ast::Instruction::Rsqrt { data, .. }
| ast::Instruction::Neg { data, .. }
| ast::Instruction::Ex2 { data, .. } => {
data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
}
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(float_control),
..
}
| ast::Instruction::Max {
data: ast::MinMaxDetails::Float(float_control),
..
} => float_control
.flush_to_zero
.map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())),
ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => {
data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
}
// Modifier .ftz can only be specified when either .dtype or .atype
// is .f32 and applies only to single precision (.f32) inputs and results.
ast::Instruction::Cvt {
data:
ast::CvtDetails {
mode:
ast::CvtMode::FPExtend { flush_to_zero }
| ast::CvtMode::FPTruncate { flush_to_zero, .. }
| ast::CvtMode::FPRound { flush_to_zero, .. }
| ast::CvtMode::SignedFromFP { flush_to_zero, .. }
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. },
..
},
..
} => flush_to_zero.map(|ftz| (ftz, 4)),
ast::Instruction::Div {
data:
ast::DivDetails::Float(ast::DivFloatDetails {
type_,
flush_to_zero,
..
}),
..
} => flush_to_zero.map(|ftz| (ftz, type_.size_of())),
ast::Instruction::Sin { data, .. }
| ast::Instruction::Cos { data, .. }
| ast::Instruction::Lg2 { data, .. } => {
Some((data.flush_to_zero, mem::size_of::<f32>() as u8))
}
ptx_parser::Instruction::PrmtSlow { .. } => None,
ptx_parser::Instruction::Trap {} => None,
}
}
type DenormCountMap<T> = HashMap<T, isize>;
fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
let num_value = if value { 1 } else { -1 };
denorm_count_map_update_impl(map, key, num_value);
}
fn denorm_count_map_update_impl<T: Eq + Hash>(
map: &mut DenormCountMap<T>,
key: T,
num_value: isize,
) {
match map.entry(key) {
hash_map::Entry::Occupied(mut counter) => {
*(counter.get_mut()) += num_value;
}
hash_map::Entry::Vacant(entry) => {
entry.insert(num_value);
}
}
}

View file

@ -514,8 +514,11 @@ pub trait Visitor<T: Operand, Err> {
) -> Result<(), Err>;
}
impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>>
Visitor<T, Err> for Fn
impl<
T: Operand,
Err,
Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>,
> Visitor<T, Err> for Fn
{
fn visit(
&mut self,
@ -760,7 +763,7 @@ pub enum Type {
Vector(ScalarType, u8),
// .param.b32 foo[4];
Array(ScalarType, Vec<u32>),
Pointer(ScalarType, StateSpace)
Pointer(ScalarType, StateSpace),
}
impl Type {
@ -1097,7 +1100,7 @@ impl SetpData {
let cmp_op = if type_kind == ScalarKind::Float {
SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
} else {
match SetpCompareInt::try_from(cmp_op) {
match SetpCompareInt::try_from((cmp_op, type_kind)) {
Ok(op) => SetpCompareOp::Integer(op),
Err(err) => {
state.errors.push(err);
@ -1129,10 +1132,14 @@ pub enum SetpCompareOp {
pub enum SetpCompareInt {
Eq,
NotEq,
Less,
LessOrEq,
Greater,
GreaterOrEq,
UnsignedLess,
UnsignedLessOrEq,
UnsignedGreater,
UnsignedGreaterOrEq,
SignedLess,
SignedLessOrEq,
SignedGreater,
SignedGreaterOrEq,
}
#[derive(PartialEq, Eq, Copy, Clone)]
@ -1153,29 +1160,41 @@ pub enum SetpCompareFloat {
IsAnyNan,
}
impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt {
type Error = PtxError;
fn try_from(value: RawSetpCompareOp) -> Result<Self, PtxError> {
match value {
RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
RawSetpCompareOp::Equ => Err(PtxError::WrongType),
RawSetpCompareOp::Neu => Err(PtxError::WrongType),
RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
RawSetpCompareOp::Leu => Err(PtxError::WrongType),
RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
RawSetpCompareOp::Geu => Err(PtxError::WrongType),
RawSetpCompareOp::Num => Err(PtxError::WrongType),
RawSetpCompareOp::Nan => Err(PtxError::WrongType),
fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result<Self, PtxError> {
match (value, kind) {
(RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq),
(RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq),
(RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => {
Ok(SetpCompareInt::SignedLess)
}
(RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess),
(RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => {
Ok(SetpCompareInt::SignedLessOrEq)
}
(RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => {
Ok(SetpCompareInt::UnsignedLessOrEq)
}
(RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => {
Ok(SetpCompareInt::SignedGreater)
}
(RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater),
(RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => {
Ok(SetpCompareInt::SignedGreaterOrEq)
}
(RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => {
Ok(SetpCompareInt::UnsignedGreaterOrEq)
}
(RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Num, _) => Err(PtxError::WrongType),
(RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType),
}
}
}
@ -1276,7 +1295,9 @@ impl<T: Operand> CallArgs<T> {
.return_arguments
.into_iter()
.zip(details.return_arguments.iter())
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true, false))
.map(|(param, (type_, space))| {
visitor.visit_ident(param, Some((type_, *space)), true, false)
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.visit_ident(self.func, None, false, false)?;
let input_arguments = self
@ -1305,6 +1326,8 @@ pub enum CvtMode {
SignExtend,
Truncate,
Bitcast,
SaturateUnsignedToSigned,
SaturateSignedToUnsigned,
// float from float
FPExtend {
flush_to_zero: Option<bool>,
@ -1389,21 +1412,11 @@ impl CvtDetails {
},
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
(
ScalarKind::Unsigned | ScalarKind::Signed,
ScalarKind::Unsigned | ScalarKind::Signed,
) => match dst.size_of().cmp(&src.size_of()) {
Ordering::Less => {
if dst.kind() != src.kind() {
errors.push(PtxError::Todo);
}
CvtMode::Truncate
}
(ScalarKind::Unsigned, ScalarKind::Unsigned)
| (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
Ordering::Less => CvtMode::Truncate,
Ordering::Equal => CvtMode::Bitcast,
Ordering::Greater => {
if dst.kind() != src.kind() {
errors.push(PtxError::Todo);
}
if src.kind() == ScalarKind::Signed {
CvtMode::SignExtend
} else {