diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs new file mode 100644 index 0000000..1dac7fd --- /dev/null +++ b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs @@ -0,0 +1,299 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use super::*; + +/* + PTX represents dynamically allocated shared local memory as + .extern .shared .b32 shared_mem[]; + In SPIRV/OpenCL world this is expressed as an additional argument to the kernel + And in AMD compilation + This pass looks for all uses of .extern .shared and converts them to + an additional method argument + The question is how this artificial argument should be expressed. There are + several options: + * Straight conversion: + .shared .b32 shared_mem[] + * Introduce .param_shared statespace: + .param_shared .b32 shared_mem + or + .param_shared .b32 shared_mem[] + * Introduce .shared_ptr type: + .param .shared_ptr .b32 shared_mem + * Reuse .ptr hint: + .param .u64 .ptr shared_mem + This is the most tempting, but also the most nonsensical, .ptr is just a + hint, which has no semantical meaning (and the output of our + transformation has a semantical meaning - we emit additional + "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") +*/ +pub(super) fn run<'input>( + module: Vec>, + kernels_methods_call_map: &MethodsCallMap<'input>, + new_id: &mut impl FnMut() -> SpirvWord, +) -> Result>, TranslateError> { + let mut globals_shared = HashMap::new(); + for dir in module.iter() { + match dir { + Directive::Variable( + _, + ast::Variable { + state_space: ast::StateSpace::Shared, + name, + v_type, + .. 
+ }, + ) => { + globals_shared.insert(*name, v_type.clone()); + } + _ => {} + } + } + if globals_shared.len() == 0 { + return Ok(module); + } + let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); + let module = module + .into_iter() + .map(|directive| match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + }) => { + let call_key = (*func_decl).borrow().name; + let statements = statements + .into_iter() + .map(|statement| { + statement.visit_map( + &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { + if let Some(_) = globals_shared.get(&id) { + methods_to_directly_used_shared_globals + .entry(call_key) + .or_insert_with(HashSet::new) + .insert(id); + } + Ok::<_, TranslateError>(id) + }, + ) + }) + .collect::, _>>()?; + Ok::<_, TranslateError>(Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + })) + } + directive => Ok(directive), + }) + .collect::, _>>()?; + // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, + // make sure it gets propagated to `fn1` and `kernel` + let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared( + methods_to_directly_used_shared_globals, + kernels_methods_call_map, + ); + // now visit every method declaration and inject those additional arguments + let mut directives = Vec::with_capacity(module.len()); + for directive in module.into_iter() { + match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + }) => { + let statements = { + let func_decl_ref = &mut (*func_decl).borrow_mut(); + let method_name = func_decl_ref.name; + insert_arguments_remap_statements( + new_id, + kernels_methods_call_map, + &globals_shared, + &methods_to_indirectly_used_shared_globals, + method_name, + &mut directives, + func_decl_ref, + statements, 
+ )? + }; + directives.push(Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + })); + } + directive => directives.push(directive), + } + } + Ok(directives) +} + +// We need to compute two kinds of information: +// * If it's a kernel -> size of .shared globals in use (direct or indirect) +// * If it's a function -> does it use .shared global (directly or indirectly) +fn resolve_indirect_uses_of_globals_shared<'input>( + methods_use_of_globals_shared: HashMap, HashSet>, + kernels_methods_call_map: &MethodsCallMap<'input>, +) -> HashMap, BTreeSet> { + let mut result = HashMap::new(); + for (method, callees) in kernels_methods_call_map.methods() { + let mut indirect_globals = methods_use_of_globals_shared + .get(&method) + .into_iter() + .flatten() + .copied() + .collect::>(); + for &callee in callees { + indirect_globals.extend( + methods_use_of_globals_shared + .get(&ast::MethodName::Func(callee)) + .into_iter() + .flatten() + .copied(), + ); + } + result.insert(method, indirect_globals); + } + result +} + +fn insert_arguments_remap_statements<'input>( + new_id: &mut impl FnMut() -> SpirvWord, + kernels_methods_call_map: &MethodsCallMap<'input>, + globals_shared: &HashMap, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, SpirvWord>, + BTreeSet, + >, + method_name: ast::MethodName, + result: &mut Vec, + func_decl_ref: &mut std::cell::RefMut>, + statements: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let remapped_globals_in_method = + if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { + match method_name { + ast::MethodName::Func(..) 
=> { + let remapped_globals = method_globals + .iter() + .map(|global| { + ( + *global, + ( + new_id(), + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(); + for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() { + func_decl_ref.input_arguments.push(ast::Variable { + align: None, + v_type: shared_global_type.clone(), + state_space: ast::StateSpace::Shared, + name: *new_shared_global_id, + array_init: Vec::new(), + }); + } + remapped_globals + } + ast::MethodName::Kernel(..) => method_globals + .iter() + .map(|global| { + ( + *global, + ( + *global, + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(), + } + } else { + return Ok(statements); + }; + replace_uses_of_shared_memory( + new_id, + methods_to_indirectly_used_shared_globals, + statements, + remapped_globals_in_method, + ) +} + +fn replace_uses_of_shared_memory<'input>( + new_id: &mut impl FnMut() -> SpirvWord, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, SpirvWord>, + BTreeSet, + >, + statements: Vec, + remapped_globals_in_method: BTreeMap, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + match statement { + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + // We can safely skip checking call arguments, + // because there's simply no way to pass shared ptr + // without converting it to .b64 first + if let Some(shared_globals_used_by_callee) = + methods_to_indirectly_used_shared_globals + .get(&ast::MethodName::Func(arguments.func)) + { + for &shared_global_used_by_callee in shared_globals_used_by_callee { + let (remapped_shared_id, type_) = remapped_globals_in_method + .get(&shared_global_used_by_callee) + .unwrap_or_else(|| todo!()); + data.input_arguments + .push((type_.clone(), ast::StateSpace::Shared)); + 
arguments.input_arguments.push(*remapped_shared_id); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })) + } + statement => { + let new_statement = + statement.visit_map(&mut |id, + _: Option<(&ast::Type, ast::StateSpace)>, + _, + _| { + Ok::<_, TranslateError>( + if let Some((remapped_shared_id, _)) = + remapped_globals_in_method.get(&id) + { + *remapped_shared_id + } else { + id + }, + ) + })?; + result.push(new_statement); + } + } + } + Ok(result) +} diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index 829e1e6..61b31ad 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -394,25 +394,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { } } -fn multi_hash_map_append< - K: Eq + std::hash::Hash, - V, - Collection: std::iter::Extend + std::default::Default, ->( - m: &mut HashMap, - key: K, - value: V, -) { - match m.entry(key) { - hash_map::Entry::Occupied(mut entry) => { - entry.get_mut().extend(iter::once(value)); - } - hash_map::Entry::Vacant(entry) => { - entry.insert(Default::default()).extend(iter::once(value)); - } - } -} - fn is_add_ptr_direct( remapped_ids: &HashMap, arg: &ast::AddArgs, diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs new file mode 100644 index 0000000..9dff12e --- /dev/null +++ b/ptx/src/pass/emit_spirv.rs @@ -0,0 +1,2767 @@ +use super::*; +use half::f16; +use ptx_parser as ast; +use rspirv::{binary::Assemble, dr}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + mem, +}; + +pub(super) fn run<'input>( + mut builder: dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + call_map: MethodsCallMap<'input>, + denorm_information: HashMap< + ptx_parser::MethodName, + HashMap, + >, + directives: Vec>, +) -> Result<(), TranslateError> { + builder.set_version(1, 3); + emit_capabilities(&mut builder); + 
emit_extensions(&mut builder); + let opencl_id = emit_opencl_import(&mut builder); + emit_memory_model(&mut builder); + let mut map = TypeWordMap::new(&mut builder); + //emit_builtins(&mut builder, &mut map, &id_defs); + let mut kernel_info = HashMap::new(); + let (build_options, should_flush_denorms) = + emit_denorm_build_string(&call_map, &denorm_information); + let (directives, globals_use_map) = get_globals_use_map(directives); + emit_directives( + &mut builder, + &mut map, + &id_defs, + opencl_id, + should_flush_denorms, + &call_map, + globals_use_map, + directives, + &mut kernel_info, + ) +} + +fn emit_capabilities(builder: &mut dr::Builder) { + builder.capability(spirv::Capability::GenericPointer); + builder.capability(spirv::Capability::Linkage); + builder.capability(spirv::Capability::Addresses); + builder.capability(spirv::Capability::Kernel); + builder.capability(spirv::Capability::Int8); + builder.capability(spirv::Capability::Int16); + builder.capability(spirv::Capability::Int64); + builder.capability(spirv::Capability::Float16); + builder.capability(spirv::Capability::Float64); + builder.capability(spirv::Capability::DenormFlushToZero); + // TODO: re-enable when Intel float control extension works + //builder.capability(spirv::Capability::FunctionFloatControlINTEL); +} + +// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html +fn emit_extensions(builder: &mut dr::Builder) { + // TODO: re-enable when Intel float control extension works + //builder.extension("SPV_INTEL_float_controls2"); + builder.extension("SPV_KHR_float_controls"); + builder.extension("SPV_KHR_no_integer_wrap_decoration"); +} + +fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { + builder.ext_inst_import("OpenCL.std") +} + +fn emit_memory_model(builder: &mut dr::Builder) { + builder.memory_model( + spirv::AddressingModel::Physical64, + spirv::MemoryModel::OpenCL, + ); +} + +struct 
TypeWordMap { + void: spirv::Word, + complex: HashMap, + constants: HashMap<(SpirvType, u64), SpirvWord>, +} + +impl TypeWordMap { + fn new(b: &mut dr::Builder) -> TypeWordMap { + let void = b.type_void(None); + TypeWordMap { + void: void, + complex: HashMap::::new(), + constants: HashMap::new(), + } + } + + fn void(&self) -> spirv::Word { + self.void + } + + fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { + let key: SpirvScalarKey = t.into(); + self.get_or_add_spirv_scalar(b, key) + } + + fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { + *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { + SpirvWord(match key { + SpirvScalarKey::B8 => b.type_int(None, 8, 0), + SpirvScalarKey::B16 => b.type_int(None, 16, 0), + SpirvScalarKey::B32 => b.type_int(None, 32, 0), + SpirvScalarKey::B64 => b.type_int(None, 64, 0), + SpirvScalarKey::F16 => b.type_float(None, 16), + SpirvScalarKey::F32 => b.type_float(None, 32), + SpirvScalarKey::F64 => b.type_float(None, 64), + SpirvScalarKey::Pred => b.type_bool(None), + SpirvScalarKey::F16x2 => todo!(), + }) + }) + } + + fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord { + match t { + SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), + SpirvType::Pointer(ref typ, storage) => { + let base = self.get_or_add(b, *typ.clone()); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0))) + } + SpirvType::Vector(typ, len) => { + let base = self.get_or_add_spirv_scalar(b, typ); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32))) + } + SpirvType::Array(typ, array_dimensions) => { + let (base_type, length) = match &*array_dimensions { + &[] => { + return self.get_or_add(b, SpirvType::Base(typ)); + } + &[len] => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); + let base = self.get_or_add_spirv_scalar(b, 
typ); + let len_const = b.constant_u32(u32_type.0, None, len); + (base, len_const) + } + array_dimensions => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); + let base = self + .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); + let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]); + (base, len_const) + } + }; + *self + .complex + .entry(SpirvType::Array(typ, array_dimensions)) + .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length))) + } + SpirvType::Func(ref out_params, ref in_params) => { + let out_t = match out_params { + Some(p) => self.get_or_add(b, *p.clone()), + None => SpirvWord(self.void()), + }; + let in_t = in_params + .iter() + .map(|t| self.get_or_add(b, t.clone()).0) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t))) + } + SpirvType::Struct(ref underlying) => { + let underlying_ids = underlying + .iter() + .map(|t| self.get_or_add_spirv_scalar(b, *t).0) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids))) + } + } + } + + fn get_or_add_fn( + &mut self, + b: &mut dr::Builder, + in_params: impl Iterator, + mut out_params: impl ExactSizeIterator, + ) -> (SpirvWord, SpirvWord) { + let (out_args, out_spirv_type) = if out_params.len() == 0 { + (None, SpirvWord(self.void())) + } else if out_params.len() == 1 { + let arg_as_key = out_params.next().unwrap(); + ( + Some(Box::new(arg_as_key.clone())), + self.get_or_add(b, arg_as_key), + ) + } else { + // TODO: support multiple return values + todo!() + }; + ( + out_spirv_type, + self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), + ) + } + + fn get_or_add_constant( + &mut self, + b: &mut dr::Builder, + typ: &ast::Type, + init: &[u8], + ) -> Result { + Ok(match typ { + ast::Type::Scalar(t) => match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self + 
.get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v), + ), + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v, + |b, result_type, v| b.constant_u64(result_type, None, v), + ), + ast::ScalarType::F16 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u16>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), + ), + ast::ScalarType::F32 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u32>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v), + ), + ast::ScalarType::F64 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u64>(v) }, + |b, result_type, v| b.constant_f64(result_type, None, v), + ), + ast::ScalarType::F16x2 => return Err(TranslateError::Todo), + ast::ScalarType::Pred => self.get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| { + if v == 0 { + b.constant_false(result_type, None) + } else { + b.constant_true(result_type, None) + } + }, + ), + ast::ScalarType::S16x2 + | ast::ScalarType::U16x2 + | ast::ScalarType::BF16 + | ast::ScalarType::BF16x2 + | ast::ScalarType::B128 => todo!(), + }, + ast::Type::Vector(typ, len) => { + let result_type = + self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); + let size_of_t = typ.size_of(); + let components = 
(0..*len) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + ast::Type::Array(typ, dims) => match dims.as_slice() { + [] => return Err(error_unreachable()), + [dim] => { + let result_type = self + .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); + let size_of_t = typ.size_of(); + let components = (0..*dim) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + [first_dim, rest @ ..] => { + let result_type = self.get_or_add( + b, + SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), + ); + let size_of_t = rest + .iter() + .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); + let components = (0..*first_dim) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Array(*typ, rest.to_vec()), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + }, + ast::Type::Pointer(..) 
=> return Err(error_unreachable()), + }) + } + + fn get_or_add_constant_single< + T: Copy, + CastAsU64: FnOnce(T) -> u64, + InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, + >( + &mut self, + b: &mut dr::Builder, + key: ast::ScalarType, + init: &[u8], + cast: CastAsU64, + f: InsertConstant, + ) -> SpirvWord { + let value = unsafe { *(init.as_ptr() as *const T) }; + let value_64 = cast(value); + let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); + match self.constants.get(&ht_key) { + Some(value) => *value, + None => { + let spirv_type = self.get_or_add_scalar(b, key); + let result = SpirvWord(f(b, spirv_type.0, value)); + self.constants.insert(ht_key, result); + result + } + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum SpirvType { + Base(SpirvScalarKey), + Vector(SpirvScalarKey, u8), + Array(SpirvScalarKey, Vec), + Pointer(Box, spirv::StorageClass), + Func(Option>, Vec), + Struct(Vec), +} + +impl SpirvType { + fn new(t: ast::Type) -> Self { + match t { + ast::Type::Scalar(t) => SpirvType::Base(t.into()), + ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), + ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), + ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( + Box::new(SpirvType::Base(pointer_t.into())), + space_to_spirv(space), + ), + } + } + + fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { + let key = Self::new(t); + SpirvType::Pointer(Box::new(key), outer_space) + } +} + +impl From for SpirvType { + fn from(t: ast::ScalarType) -> Self { + SpirvType::Base(t.into()) + } +} +// SPIR-V integer type definitions are signless, more below: +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +enum SpirvScalarKey 
{ + B8, + B16, + B32, + B64, + F16, + F32, + F64, + Pred, + F16x2, +} + +impl From for SpirvScalarKey { + fn from(t: ast::ScalarType) -> Self { + match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { + SpirvScalarKey::B16 + } + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { + SpirvScalarKey::B32 + } + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { + SpirvScalarKey::B64 + } + ast::ScalarType::F16 => SpirvScalarKey::F16, + ast::ScalarType::F32 => SpirvScalarKey::F32, + ast::ScalarType::F64 => SpirvScalarKey::F64, + ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, + ast::ScalarType::Pred => SpirvScalarKey::Pred, + ast::ScalarType::S16x2 + | ast::ScalarType::U16x2 + | ast::ScalarType::BF16 + | ast::ScalarType::BF16x2 + | ast::ScalarType::B128 => todo!(), + } + } +} + +fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { + match this { + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, + ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta => todo!(), + } +} + +// TODO: remove this once we have per-function support for denorms +fn emit_denorm_build_string<'input>( + call_map: &MethodsCallMap, + denorm_information: &HashMap< + ast::MethodName<'input, SpirvWord>, + HashMap, + >, +) -> (CString, bool) { + let denorm_counts = denorm_information + .iter() + .map(|(method, meth_denorm)| { + let 
f16_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + let f32_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + (method, (f16_count + f32_count)) + }) + .collect::>(); + let mut flush_over_preserve = 0; + for (kernel, children) in call_map.kernels() { + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Kernel(kernel)) + .unwrap_or(&0); + for child_fn in children { + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Func(*child_fn)) + .unwrap_or(&0); + } + } + if flush_over_preserve > 0 { + ( + CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), + true, + ) + } else { + (CString::new("-ze-take-global-address").unwrap(), false) + } +} + +fn get_globals_use_map<'input>( + directives: Vec>, +) -> ( + Vec>, + HashMap, HashSet>, +) { + let mut known_globals = HashSet::new(); + for directive in directives.iter() { + match directive { + Directive::Variable(_, ast::Variable { name, .. }) => { + known_globals.insert(*name); + } + Directive::Method(..) => {} + } + } + let mut symbol_uses_map = HashMap::new(); + let directives = directives + .into_iter() + .map(|directive| match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. 
}) => directive, + Directive::Method(Function { + func_decl, + body: Some(mut statements), + globals, + import_as, + tuning, + linkage, + }) => { + let method_name = func_decl.borrow().name; + statements = statements + .into_iter() + .map(|statement| { + statement.visit_map( + &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { + if known_globals.contains(&symbol) { + multi_hash_map_append( + &mut symbol_uses_map, + method_name, + symbol, + ); + } + Ok::<_, TranslateError>(symbol) + }, + ) + }) + .collect::, _>>() + .unwrap(); + Directive::Method(Function { + func_decl, + body: Some(statements), + globals, + import_as, + tuning, + linkage, + }) + } + }) + .collect::>(); + (directives, symbol_uses_map) +} + +fn emit_directives<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl_id: spirv::Word, + should_flush_denorms: bool, + call_map: &MethodsCallMap<'input>, + globals_use_map: HashMap, HashSet>, + directives: Vec>, + kernel_info: &mut HashMap, +) -> Result<(), TranslateError> { + let empty_body = Vec::new(); + for d in directives.iter() { + match d { + Directive::Variable(linking, var) => { + emit_variable(builder, map, id_defs, *linking, &var)?; + } + Directive::Method(f) => { + let f_body = match &f.body { + Some(f) => f, + None => { + if f.linkage.contains(ast::LinkingDirective::EXTERN) { + &empty_body + } else { + continue; + } + } + }; + for var in f.globals.iter() { + emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; + } + let func_decl = (*f.func_decl).borrow(); + let fn_id = emit_function_header( + builder, + map, + &id_defs, + &*func_decl, + call_map, + &globals_use_map, + kernel_info, + )?; + if matches!(func_decl.name, ast::MethodName::Kernel(_)) { + if should_flush_denorms { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::DenormFlushToZero, + [16], + ); + builder.execution_mode( + fn_id.0, + 
spirv_headers::ExecutionMode::DenormFlushToZero, + [32], + ); + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::DenormFlushToZero, + [64], + ); + } + // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::ContractionOff, + [], + ); + for t in f.tuning.iter() { + match *t { + ast::TuningDirective::MaxNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, + [nx, ny, nz], + ); + } + ast::TuningDirective::ReqNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::LocalSize, + [nx, ny, nz], + ); + } + // Too architecture specific + ast::TuningDirective::MaxNReg(..) + | ast::TuningDirective::MinNCtaPerSm(..) => {} + } + } + } + emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; + emit_function_linkage(builder, id_defs, f, fn_id)?; + builder.select_block(None)?; + builder.end_function()?; + } + } + } + Ok(()) +} + +fn emit_variable<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + linking: ast::LinkingDirective, + var: &ast::Variable, +) -> Result<(), TranslateError> { + let (must_init, st_class) = match var.state_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { + (false, spirv::StorageClass::Function) + } + ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), + ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), + ast::StateSpace::Generic => todo!(), + ast::StateSpace::Sreg => todo!(), + ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta => todo!(), + }; + let initalizer = if var.array_init.len() > 0 { + Some( + 
map.get_or_add_constant( + builder, + &ast::Type::from(var.v_type.clone()), + &*var.array_init, + )? + .0, + ) + } else if must_init { + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); + Some(builder.constant_null(type_id.0, None)) + } else { + None + }; + let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); + builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer); + if let Some(align) = var.align { + builder.decorate( + var.name.0, + spirv::Decoration::Alignment, + [dr::Operand::LiteralInt32(align)].iter().cloned(), + ); + } + if var.state_space != ast::StateSpace::Shared + || !linking.contains(ast::LinkingDirective::EXTERN) + { + emit_linking_decoration(builder, id_defs, None, var.name, linking); + } + Ok(()) +} + +fn emit_function_header<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + defined_globals: &GlobalStringIdResolver<'input>, + func_decl: &ast::MethodDeclaration<'input, SpirvWord>, + call_map: &MethodsCallMap<'input>, + globals_use_map: &HashMap, HashSet>, + kernel_info: &mut HashMap, +) -> Result { + if let ast::MethodName::Kernel(name) = func_decl.name { + let args_lens = func_decl + .input_arguments + .iter() + .map(|param| { + ( + type_size_of(¶m.v_type), + matches!(param.v_type, ast::Type::Pointer(..)), + ) + }) + .collect(); + kernel_info.insert( + name.to_string(), + KernelInfo { + arguments_sizes: args_lens, + uses_shared_mem: func_decl.shared_mem.is_some(), + }, + ); + } + let (ret_type, func_type) = get_function_type( + builder, + map, + effective_input_arguments(func_decl).map(|(_, typ)| typ), + &func_decl.return_arguments, + ); + let fn_id = match func_decl.name { + ast::MethodName::Kernel(name) => { + let fn_id = defined_globals.get_id(name)?; + let interface = globals_use_map + .get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + .copied() + .chain({ + call_map + .get_kernel_children(name) + .copied() + 
.flat_map(|subfunction| { + globals_use_map + .get(&ast::MethodName::Func(subfunction)) + .into_iter() + .flatten() + .copied() + }) + .into_iter() + }) + .map(|word| word.0) + .collect::>(); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface); + fn_id + } + ast::MethodName::Func(name) => name, + }; + builder.begin_function( + ret_type.0, + Some(fn_id.0), + spirv::FunctionControl::NONE, + func_type.0, + )?; + for (name, typ) in effective_input_arguments(func_decl) { + let result_type = map.get_or_add(builder, typ); + builder.function_parameter(Some(name.0), result_type.0)?; + } + Ok(fn_id) +} + +pub fn type_size_of(this: &ast::Type) -> usize { + match this { + ast::Type::Scalar(typ) => typ.size_of() as usize, + ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), + ast::Type::Array(typ, len) => len + .iter() + .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), + ast::Type::Pointer(..) => mem::size_of::(), + } +} +fn emit_function_body_ops<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl: spirv::Word, + func: &[ExpandedStatement], +) -> Result<(), TranslateError> { + for s in func { + match s { + Statement::Label(id) => { + if builder.selected_block().is_some() { + builder.branch(id.0)?; + } + builder.begin_block(Some(id.0))?; + } + _ => { + if builder.selected_block().is_none() && builder.selected_function().is_some() { + builder.begin_block(None)?; + } + } + } + match s { + Statement::Label(_) => (), + Statement::Variable(var) => { + emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; + } + Statement::Constant(cnst) => { + let typ_id = map.get_or_add_scalar(builder, cnst.typ); + match (cnst.typ, cnst.value) { + (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); + } + 
(ast::ScalarType::B16, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64); + } + (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); + } + (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); + } + (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { + 
builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { + builder.constant_f32( + typ_id.0, + Some(cnst.dst.0), + f16::from_f32(value).to_f32(), + ); + } + (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { + builder.constant_f32(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { + builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { + builder.constant_f32( + typ_id.0, + Some(cnst.dst.0), + f16::from_f64(value).to_f32(), + ); + } + (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { + builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32); + } + (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { + builder.constant_f64(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst.0)); + } else { + builder.constant_true(bool_type, Some(cnst.dst.0)); + } + } + (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst.0)); + } else { + builder.constant_true(bool_type, Some(cnst.dst.0)); + } + } + _ => return Err(TranslateError::MismatchedType), + } + } + Statement::Conversion(cv) => 
emit_implicit_conversion(builder, map, cv)?, + Statement::Conditional(bra) => { + builder.branch_conditional( + bra.predicate.0, + bra.if_true.0, + bra.if_false.0, + iter::empty(), + )?; + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + // TODO: implement properly + let zero = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U64), + &vec_repr(0u64), + )?; + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); + builder.copy_object(result_type.0, Some(dst.0), zero.0)?; + } + Statement::Instruction(inst) => match inst { + ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(), + ast::Instruction::Call { data, arguments } => { + let (result_type, result_id) = + match (&*data.return_arguments, &*arguments.return_arguments) { + ([(type_, space)], [id]) => { + if *space != ast::StateSpace::Reg { + return Err(error_unreachable()); + } + ( + map.get_or_add(builder, SpirvType::new(type_.clone())).0, + Some(id.0), + ) + } + ([], []) => (map.void(), None), + _ => todo!(), + }; + let arg_list = arguments + .input_arguments + .iter() + .map(|id| id.0) + .collect::>(); + builder.function_call(result_type, result_id, arguments.func.0, arg_list)?; + } + ast::Instruction::Abs { data, arguments } => { + emit_abs(builder, map, opencl, data, arguments)? + } + // SPIR-V does not support marking jumps as guaranteed-converged + ast::Instruction::Bra { arguments, .. 
} => { + builder.branch(arguments.src.0)?; + } + ast::Instruction::Ld { data, arguments } => { + let mem_access = match data.qualifier { + ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, + // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad + ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, + _ => return Err(TranslateError::Todo), + }; + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); + builder.load( + result_type.0, + Some(arguments.dst.0), + arguments.src.0, + Some(mem_access | spirv::MemoryAccess::ALIGNED), + [dr::Operand::LiteralInt32( + type_size_of(&ast::Type::from(data.typ.clone())) as u32, + )] + .iter() + .cloned(), + )?; + } + ast::Instruction::St { data, arguments } => { + let mem_access = match data.qualifier { + ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, + // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore + ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, + _ => return Err(TranslateError::Todo), + }; + builder.store( + arguments.src1.0, + arguments.src2.0, + Some(mem_access | spirv::MemoryAccess::ALIGNED), + [dr::Operand::LiteralInt32( + type_size_of(&ast::Type::from(data.typ.clone())) as u32, + )] + .iter() + .cloned(), + )?; + } + // SPIR-V does not support ret as guaranteed-converged + ast::Instruction::Ret { .. } => builder.ret()?, + ast::Instruction::Mov { data, arguments } => { + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); + builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Mul { data, arguments } => match data { + ast::MulDetails::Integer { type_, control } => { + emit_mul_int(builder, map, opencl, *type_, *control, arguments)? + } + ast::MulDetails::Float(ref ctr) => { + emit_mul_float(builder, map, ctr, arguments)? 
+ } + }, + ast::Instruction::Add { data, arguments } => match data { + ast::ArithDetails::Integer(desc) => { + emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)? + } + ast::ArithDetails::Float(desc) => { + emit_add_float(builder, map, desc, arguments)? + } + }, + ast::Instruction::Setp { data, arguments } => { + if arguments.dst2.is_some() { + todo!() + } + emit_setp(builder, map, data, arguments)?; + } + ast::Instruction::Not { data, arguments } => { + let result_type = map.get_or_add(builder, SpirvType::from(*data)); + let result_id = Some(arguments.dst.0); + let operand = arguments.src; + match data { + ast::ScalarType::Pred => { + logical_not(builder, result_type.0, result_id, operand.0) + } + _ => builder.not(result_type.0, result_id, operand.0), + }?; + } + ast::Instruction::Shl { data, arguments } => { + let full_type = ast::Type::Scalar(*data); + let size_of = type_size_of(&full_type); + let result_type = map.get_or_add(builder, SpirvType::new(full_type)); + let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?; + builder.shift_left_logical( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + ast::Instruction::Shr { data, arguments } => { + let full_type = ast::ScalarType::from(data.type_); + let size_of = full_type.size_of(); + let result_type = map.get_or_add_scalar(builder, full_type).0; + let offset_src = + insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?; + match data.kind { + ptx_parser::RightShiftKind::Arithmetic => { + builder.shift_right_arithmetic( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + ptx_parser::RightShiftKind::Logical => { + builder.shift_right_logical( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + } + } + ast::Instruction::Cvt { data, arguments } => { + emit_cvt(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Cvta { data, arguments } => { + 
// This would be only meaningful if const/slm/global pointers + // had a different format than generic pointers, but they don't pretty much by ptx definition + // Honestly, I have no idea why this instruction exists and is emitted by the compiler + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); + builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::SetpBool { .. } => todo!(), + ast::Instruction::Mad { data, arguments } => match data { + ast::MadDetails::Integer { + type_, + control, + saturate, + } => { + if *saturate { + todo!() + } + if type_.kind() == ast::ScalarKind::Signed { + emit_mad_sint(builder, map, opencl, *type_, *control, arguments)? + } else { + emit_mad_uint(builder, map, opencl, *type_, *control, arguments)? + } + } + ast::MadDetails::Float(desc) => { + emit_mad_float(builder, map, opencl, desc, arguments)? + } + }, + ast::Instruction::Fma { data, arguments } => { + emit_fma_float(builder, map, opencl, data, arguments)? 
+ } + ast::Instruction::Or { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data).0; + if *data == ast::ScalarType::Pred { + builder.logical_or( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } else { + builder.bitwise_or( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + } + ast::Instruction::Sub { data, arguments } => match data { + ast::ArithDetails::Integer(desc) => { + emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?; + } + ast::ArithDetails::Float(desc) => { + emit_sub_float(builder, map, desc, arguments)?; + } + }, + ast::Instruction::Min { data, arguments } => { + emit_min(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Max { data, arguments } => { + emit_max(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Rcp { data, arguments } => { + emit_rcp(builder, map, opencl, data, arguments)?; + } + ast::Instruction::And { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data); + if *data == ast::ScalarType::Pred { + builder.logical_and( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } else { + builder.bitwise_and( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + } + ast::Instruction::Selp { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data); + builder.select( + result_type.0, + Some(arguments.dst.0), + arguments.src3.0, + arguments.src1.0, + arguments.src2.0, + )?; + } + // TODO: implement named barriers + ast::Instruction::Bar { data, arguments } => { + let workgroup_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(spirv::Scope::Workgroup as u32), + )?; + let barrier_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr( + 
spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + )?; + builder.control_barrier( + workgroup_scope.0, + workgroup_scope.0, + barrier_semantics.0, + )?; + } + ast::Instruction::Atom { data, arguments } => { + emit_atom(builder, map, data, arguments)?; + } + ast::Instruction::AtomCas { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope_to_spirv(data.scope) as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics_to_spirv(data.semantics).bits()), + )?; + builder.atomic_compare_exchange( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + memory_const.0, + semantics_const.0, + semantics_const.0, + arguments.src3.0, + arguments.src2.0, + )?; + } + ast::Instruction::Div { data, arguments } => match data { + ast::DivDetails::Unsigned(t) => { + let result_type = map.get_or_add_scalar(builder, (*t).into()); + builder.u_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::DivDetails::Signed(t) => { + let result_type = map.get_or_add_scalar(builder, (*t).into()); + builder.s_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::DivDetails::Float(t) => { + let result_type = map.get_or_add_scalar(builder, t.type_.into()); + builder.f_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + emit_float_div_decoration(builder, arguments.dst, t.kind); + } + }, + ast::Instruction::Sqrt { data, arguments } => { + emit_sqrt(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Rsqrt { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_.into()); + 
builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::rsqrt as spirv::Word, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Neg { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_); + let negate_func = if data.type_.kind() == ast::ScalarKind::Float { + dr::Builder::f_negate + } else { + dr::Builder::s_negate + }; + negate_func( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src.0, + )?; + } + ast::Instruction::Sin { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::sin as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Cos { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::cos as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Lg2 { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::log2 as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Ex2 { arguments, .. 
} => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::exp2 as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Clz { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::clz as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Brev { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Popc { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Xor { data, arguments } => { + let builder_fn: fn( + &mut dr::Builder, + u32, + Option, + u32, + u32, + ) -> Result = match data { + ast::ScalarType::Pred => emit_logical_xor_spirv, + _ => dr::Builder::bitwise_xor, + }; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. 
+                } => {
+                    // Should have been replaced with a function call earlier
+                    return Err(error_unreachable());
+                }
+
+                ast::Instruction::Rem { data, arguments } => {
+                    let builder_fn = if data.kind() == ast::ScalarKind::Signed {
+                        dr::Builder::s_mod
+                    } else {
+                        dr::Builder::u_mod
+                    };
+                    let result_type = map.get_or_add_scalar(builder, (*data).into());
+                    builder_fn(
+                        builder,
+                        result_type.0,
+                        Some(arguments.dst.0),
+                        arguments.src1.0,
+                        arguments.src2.0,
+                    )?;
+                }
+                ast::Instruction::Prmt { data, arguments } => {
+                    let control = *data as u32;
+                    let components = [
+                        (control >> 0) & 0b1111,
+                        (control >> 4) & 0b1111,
+                        (control >> 8) & 0b1111,
+                        (control >> 12) & 0b1111,
+                    ];
+                    if components.iter().any(|&c| c > 7) {
+                        return Err(TranslateError::Todo);
+                    }
+                    let vec4_b8_type =
+                        map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
+                    let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+                    let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?;
+                    let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?;
+                    let dst_vector = builder.vector_shuffle(
+                        vec4_b8_type.0,
+                        None,
+                        src1_vector,
+                        src2_vector,
+                        components,
+                    )?;
+                    builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?;
+                }
+                ast::Instruction::Membar { data } => {
+                    let (scope, semantics) = match data {
+                        ast::MemScope::Cta => (
+                            spirv::Scope::Workgroup,
+                            spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+                                | spirv::MemorySemantics::WORKGROUP_MEMORY
+                                | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+                        ),
+                        ast::MemScope::Gpu => (
+                            spirv::Scope::Device,
+                            spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+                                | spirv::MemorySemantics::WORKGROUP_MEMORY
+                                | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+                        ),
+                        ast::MemScope::Sys => (
+                            spirv::Scope::CrossDevice,
+                            spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+                                | spirv::MemorySemantics::WORKGROUP_MEMORY
+                                | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+                        ),
+
+                        ast::MemScope::Cluster =>
todo!(), + }; + let spirv_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope as u32), + )?; + let spirv_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics), + )?; + builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?; + } + }, + Statement::LoadVar(details) => { + emit_load_var(builder, map, details)?; + } + Statement::StoreVar(details) => { + let dst_ptr = match details.member_index { + Some(index) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::pointer_to( + details.typ.clone(), + spirv::StorageClass::Function, + ), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type.0, + None, + details.arg.src1.0, + [index_spirv.0].iter().copied(), + )? + } + None => details.arg.src1.0, + }; + builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?; + } + Statement::RetValue(_, id) => { + builder.ret_value(id.0)?; + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let u8_pointer = map.get_or_add( + builder, + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), + ); + let result_type = map.get_or_add( + builder, + SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)), + ); + let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?; + let temp = builder.in_bounds_ptr_access_chain( + u8_pointer.0, + None, + ptr_src_u8, + offset_src.0, + iter::empty(), + )?; + builder.bitcast(result_type.0, Some(dst.0), temp)?; + } + Statement::RepackVector(repack) => { + if repack.is_extract { + let scalar_type = map.get_or_add_scalar(builder, repack.typ); + for (index, dst_id) in repack.unpacked.iter().enumerate() { + builder.composite_extract( + scalar_type.0, + Some(dst_id.0), + repack.packed.0, 
+ [index as u32].iter().copied(), + )?; + } + } else { + let vector_type = map.get_or_add( + builder, + SpirvType::Vector( + SpirvScalarKey::from(repack.typ), + repack.unpacked.len() as u8, + ), + ); + let mut temp_vec = builder.undef(vector_type.0, None); + for (index, src_id) in repack.unpacked.iter().enumerate() { + temp_vec = builder.composite_insert( + vector_type.0, + None, + src_id.0, + temp_vec, + [index as u32].iter().copied(), + )?; + } + builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; + } + } + } + } + Ok(()) +} + +fn emit_function_linkage<'input>( + builder: &mut dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + f: &Function, + fn_name: SpirvWord, +) -> Result<(), TranslateError> { + if f.linkage == ast::LinkingDirective::NONE { + return Ok(()); + }; + let linking_name = match f.func_decl.borrow().name { + // According to SPIR-V rules linkage attributes are invalid on kernels + ast::MethodName::Kernel(..) => return Ok(()), + ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( + || match id_defs.reverse_variables.get(&fn_id) { + Some(fn_name) => Ok(fn_name), + None => Err(error_unknown_symbol()), + }, + Result::Ok, + )?, + }; + emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); + Ok(()) +} + +fn get_function_type( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + spirv_input: impl Iterator, + spirv_output: &[ast::Variable], +) -> (SpirvWord, SpirvWord) { + map.get_or_add_fn( + builder, + spirv_input, + spirv_output + .iter() + .map(|var| SpirvType::new(var.v_type.clone())), + ) +} + +fn emit_linking_decoration<'input>( + builder: &mut dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + name_override: Option<&str>, + name: SpirvWord, + linking: ast::LinkingDirective, +) { + if linking == ast::LinkingDirective::NONE { + return; + } + if linking.contains(ast::LinkingDirective::VISIBLE) { + let string_name = + name_override.unwrap_or_else(|| 
id_defs.reverse_variables.get(&name).unwrap()); + builder.decorate( + name.0, + spirv::Decoration::LinkageAttributes, + [ + dr::Operand::LiteralString(string_name.to_string()), + dr::Operand::LinkageType(spirv::LinkageType::Export), + ] + .iter() + .cloned(), + ); + } else if linking.contains(ast::LinkingDirective::EXTERN) { + let string_name = + name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); + builder.decorate( + name.0, + spirv::Decoration::LinkageAttributes, + [ + dr::Operand::LiteralString(string_name.to_string()), + dr::Operand::LinkageType(spirv::LinkageType::Import), + ] + .iter() + .cloned(), + ); + } + // TODO: handle LinkingDirective::WEAK +} + +fn effective_input_arguments<'a>( + this: &'a ast::MethodDeclaration<'a, SpirvWord>, +) -> impl Iterator + 'a { + let is_kernel = matches!(this.name, ast::MethodName::Kernel(_)); + this.input_arguments.iter().map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) +} + +fn emit_implicit_conversion( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + cv: &ImplicitConversion, +) -> Result<(), TranslateError> { + let from_parts = to_parts(&cv.from_type); + let to_parts = to_parts(&cv.to_type); + match (from_parts.kind, to_parts.kind, &cv.kind) { + (_, _, &ConversionKind::BitToPtr) => { + let dst_type = map.get_or_add( + builder, + SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)), + ); + builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { + if from_parts.width == to_parts.width { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + if from_parts.scalar_kind != ast::ScalarKind::Float + && to_parts.scalar_kind != ast::ScalarKind::Float + { + 
+                    // It is a noop, but another instruction expects the result of this conversion
+                    builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?;
+                } else {
+                    builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?;
+                }
+            } else {
+                // This block is safe because it's illegal to implicitly convert between floating point values
+                let same_width_bit_type = map.get_or_add(
+                    builder,
+                    SpirvType::new(type_from_parts(TypeParts {
+                        scalar_kind: ast::ScalarKind::Bit,
+                        ..from_parts
+                    })),
+                );
+                let same_width_bit_value =
+                    builder.bitcast(same_width_bit_type.0, None, cv.src.0)?;
+                let wide_bit_type = type_from_parts(TypeParts {
+                    scalar_kind: ast::ScalarKind::Bit,
+                    ..to_parts
+                });
+                let wide_bit_type_spirv =
+                    map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
+                if to_parts.scalar_kind == ast::ScalarKind::Unsigned
+                    || to_parts.scalar_kind == ast::ScalarKind::Bit
+                {
+                    builder.u_convert(
+                        wide_bit_type_spirv.0,
+                        Some(cv.dst.0),
+                        same_width_bit_value,
+                    )?;
+                } else {
+                    let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
+                        && to_parts.scalar_kind == ast::ScalarKind::Signed
+                    {
+                        dr::Builder::s_convert
+                    } else {
+                        dr::Builder::u_convert
+                    };
+                    let wide_bit_value =
+                        conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?;
+                    emit_implicit_conversion(
+                        builder,
+                        map,
+                        &ImplicitConversion {
+                            src: SpirvWord(wide_bit_value),
+                            dst: cv.dst,
+                            from_type: wide_bit_type,
+                            from_space: cv.from_space,
+                            to_type: cv.to_type.clone(),
+                            to_space: cv.to_space,
+                            kind: ConversionKind::Default,
+                        },
+                    )?;
+                }
+            }
+        }
+        (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
+            let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+            builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?;
+        }
+        (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
+        | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
+        | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
+            let
into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?; + } + (_, _, &ConversionKind::PtrToPtr) => { + let result_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.to_space), + ), + ); + if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic + { + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.from_space), + ), + ); + builder.bitcast(temp_type.0, None, cv.src.0)? + } else { + cv.src.0 + }; + builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?; + } else if cv.from_space == ast::StateSpace::Generic + && cv.to_space != ast::StateSpace::Generic + { + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.from_space), + ), + ); + builder.bitcast(temp_type.0, None, cv.src.0)? 
+ } else { + cv.src.0 + }; + builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?; + } else { + builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + } + (_, _, &ConversionKind::AddressOf) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + _ => unreachable!(), + } + Ok(()) +} + +fn vec_repr(t: T) -> Vec { + let mut result = vec![0; mem::size_of::()]; + unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; + result +} + +fn emit_abs( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + d: &ast::TypeFtz, + arg: &ast::AbsArgs, +) -> Result<(), dr::Error> { + let scalar_t = ast::ScalarType::from(d.type_); + let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); + let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { + spirv::CLOp::s_abs + } else { + spirv::CLOp::fabs + }; + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + cl_abs as spirv::Word, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + Ok(()) +} + +fn emit_mul_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MulArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(type_)); + match control { + ast::MulIntControl::Low => { + builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + } + 
ast::MulIntControl::High => { + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::s_mul_hi as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + } + ast::MulIntControl::Wide => { + let instr_width = type_.size_of(); + let instr_kind = type_.kind(); + let dst_type = scalar_from_parts(instr_width * 2, instr_kind); + let dst_type_id = map.get_or_add_scalar(builder, dst_type); + let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed { + let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?; + let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?; + (src1, src2) + } else { + let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?; + let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?; + (src1, src2) + }; + builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?; + builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty()); + } + } + Ok(()) +} + +fn emit_mul_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + ctr: &ast::ArithFloat, + arg: &ast::MulArgs, +) -> Result<(), dr::Error> { + if ctr.saturate { + todo!() + } + let result_type = map.get_or_add_scalar(builder, ctr.type_.into()); + builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, ctr.rounding); + Ok(()) +} + +fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { + match kind { + ast::ScalarKind::Float => match width { + 2 => ast::ScalarType::F16, + 4 => ast::ScalarType::F32, + 8 => ast::ScalarType::F64, + _ => unreachable!(), + }, + ast::ScalarKind::Bit => match width { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => unreachable!(), + }, + ast::ScalarKind::Signed => match width { + 1 => ast::ScalarType::S8, + 2 => ast::ScalarType::S16, + 4 => ast::ScalarType::S32, + 8 => 
ast::ScalarType::S64, + _ => unreachable!(), + }, + ast::ScalarKind::Unsigned => match width { + 1 => ast::ScalarType::U8, + 2 => ast::ScalarType::U16, + 4 => ast::ScalarType::U32, + 8 => ast::ScalarType::U64, + _ => unreachable!(), + }, + ast::ScalarKind::Pred => ast::ScalarType::Pred, + } +} + +fn emit_rounding_decoration( + builder: &mut dr::Builder, + dst: SpirvWord, + rounding: Option, +) { + if let Some(rounding) = rounding { + builder.decorate( + dst.0, + spirv::Decoration::FPRoundingMode, + [rounding_to_spirv(rounding)].iter().cloned(), + ); + } +} + +fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { + let mode = match this { + ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, + ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, + ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, + ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, + }; + rspirv::dr::Operand::FPRoundingMode(mode) +} + +fn emit_add_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + typ: ast::ScalarType, + saturate: bool, + arg: &ast::AddArgs, +) -> Result<(), dr::Error> { + if saturate { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); + builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + Ok(()) +} + +fn emit_add_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::ArithFloat, + arg: &ast::AddArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))); + builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + +fn emit_setp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + setp: &ast::SetpData, + arg: &ast::SetpArgs, +) -> Result<(), dr::Error> { + let result_type = map + .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)) + .0; + let result_id = 
Some(arg.dst1.0); + let operand_1 = arg.src1.0; + let operand_2 = arg.src2.0; + match setp.cmp_op { + ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => { + builder.i_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => { + builder.f_ord_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => { + builder.i_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => { + builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => { + builder.u_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => { + builder.s_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => { + builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => { + builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => { + builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => { + builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => { + builder.u_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => { + builder.s_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => { + builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) + } + 
ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => { + builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => { + builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => { + builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => { + builder.f_unord_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => { + builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => { + builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => { + builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => { + builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => { + builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => { + let temp1 = builder.is_nan(result_type, None, operand_1)?; + let temp2 = builder.is_nan(result_type, None, operand_2)?; + builder.logical_or(result_type, result_id, temp1, temp2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => { + let temp1 = builder.is_nan(result_type, None, operand_1)?; + let temp2 = builder.is_nan(result_type, None, operand_2)?; + let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; + logical_not(builder, result_type, result_id, any_nan) + } + _ => todo!(), + }?; + Ok(()) +} + +// HACK ALERT +// Temporary workaround 
until IGC gets its shit together +// Currently IGC carries two copies of SPIRV-LLVM translator +// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. +// Obviously, old and buggy one is used for compiling L0 SPIRV +// https://github.com/intel/intel-graphics-compiler/issues/148 +fn logical_not( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + operand: spirv::Word, +) -> Result { + let const_true = builder.constant_true(result_type, None); + let const_false = builder.constant_false(result_type, None); + builder.select(result_type, result_id, operand, const_false, const_true) +} + +// HACK ALERT +// For some reason IGC fails linking if the value and shift size are of different type +fn insert_shift_hack( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + offset_var: spirv::Word, + size_of: usize, +) -> Result { + let result_type = match size_of { + 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), + 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), + 4 => return Ok(offset_var), + _ => return Err(error_unreachable()), + }; + Ok(builder.u_convert(result_type.0, None, offset_var)?) 
}

/// Emits a `cvt` (conversion) instruction.
/// Integer widenings/narrowings are routed through the generic implicit-conversion
/// machinery; FP conversions map to OpFConvert/OpConvert* plus rounding decorations.
fn emit_cvt(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    dets: &ast::CvtDetails,
    arg: &ast::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    match dets.mode {
        ptx_parser::CvtMode::SignExtend => {
            let conversion = ImplicitConversion {
                src: arg.src,
                dst: arg.dst,
                from_type: dets.from.into(),
                from_space: ast::StateSpace::Reg,
                to_type: dets.to.into(),
                to_space: ast::StateSpace::Reg,
                kind: ConversionKind::SignExtend,
            };
            emit_implicit_conversion(builder, map, &conversion)?;
        }
        ptx_parser::CvtMode::ZeroExtend
        | ptx_parser::CvtMode::Truncate
        | ptx_parser::CvtMode::Bitcast => {
            let conversion = ImplicitConversion {
                src: arg.src,
                dst: arg.dst,
                from_type: dets.from.into(),
                from_space: ast::StateSpace::Reg,
                to_type: dets.to.into(),
                to_space: ast::StateSpace::Reg,
                kind: ConversionKind::Default,
            };
            emit_implicit_conversion(builder, map, &conversion)?;
        }
        ptx_parser::CvtMode::SaturateUnsignedToSigned => {
            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
            builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
        }
        ptx_parser::CvtMode::SaturateSignedToUnsigned => {
            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
            builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
        }
        ptx_parser::CvtMode::FPExtend { flush_to_zero } => {
            if flush_to_zero == Some(true) {
                todo!()
            }
            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
            builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
        }
        ptx_parser::CvtMode::FPTruncate {
            rounding,
            flush_to_zero,
        } => {
            if flush_to_zero == Some(true) {
                todo!()
            }
            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
            builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
            emit_rounding_decoration(builder, arg.dst, Some(rounding));
        }
        ptx_parser::CvtMode::FPRound {
            integer_rounding,
            flush_to_zero,
        } => {
            if flush_to_zero == Some(true) {
                todo!()
            }
            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
            // Round-to-integral in the same FP type maps to the OpenCL
            // rint/trunc/floor/ceil extended instructions.
            match integer_rounding {
                Some(ast::RoundingMode::NearestEven) => {
                    builder.ext_inst(
                        result_type.0,
                        Some(arg.dst.0),
                        opencl,
                        spirv::CLOp::rint as u32,
                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
                    )?;
                }
                Some(ast::RoundingMode::Zero) => {
                    builder.ext_inst(
                        result_type.0,
                        Some(arg.dst.0),
                        opencl,
                        spirv::CLOp::trunc as u32,
                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
                    )?;
                }
                Some(ast::RoundingMode::NegativeInf) => {
                    builder.ext_inst(
                        result_type.0,
                        Some(arg.dst.0),
                        opencl,
                        spirv::CLOp::floor as u32,
                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
                    )?;
                }
                Some(ast::RoundingMode::PositiveInf) => {
                    builder.ext_inst(
                        result_type.0,
                        Some(arg.dst.0),
                        opencl,
                        spirv::CLOp::ceil as u32,
                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
                    )?;
                }
                // No rounding requested: plain value copy.
                None => {
                    builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?;
                }
            }
        }
        ptx_parser::CvtMode::SignedFromFP {
            rounding,
            flush_to_zero,
        } => {
            if flush_to_zero == Some(true) {
                todo!()
            }
            let dest_t: ast::ScalarType = dets.to.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
            emit_rounding_decoration(builder, arg.dst, Some(rounding));
        }
        ptx_parser::CvtMode::UnsignedFromFP {
            rounding,
            flush_to_zero,
        } => {
            if flush_to_zero == Some(true) {
                todo!()
            }
            let dest_t: ast::ScalarType = dets.to.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
            emit_rounding_decoration(builder, arg.dst, Some(rounding));
        }
        ptx_parser::CvtMode::FPFromSigned(rounding) => {
            let dest_t: ast::ScalarType = dets.to.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
            emit_rounding_decoration(builder, arg.dst, Some(rounding));
        }
        ptx_parser::CvtMode::FPFromUnsigned(rounding) => {
            let dest_t: ast::ScalarType = dets.to.into();
            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
            builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
            emit_rounding_decoration(builder, arg.dst, Some(rounding));
        }
    }
    Ok(())
}

/// Emits an unsigned `mad` (multiply-add). `Low` is an explicit mul+add pair,
/// `High` maps to OpenCL u_mad_hi; widening `mad` is not implemented yet.
fn emit_mad_uint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    type_: ast::ScalarType,
    control: ast::MulIntControl,
    arg: &ast::MadArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_)))
        .0;
    match control {
        ast::MulIntControl::Low => {
            let product = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
            builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, product)?;
        }
        ast::MulIntControl::High => {
            builder.ext_inst(
                inst_type,
                Some(arg.dst.0),
                opencl,
                spirv::CLOp::u_mad_hi as spirv::Word,
                [
                    dr::Operand::IdRef(arg.src1.0),
                    dr::Operand::IdRef(arg.src2.0),
                    dr::Operand::IdRef(arg.src3.0),
                ]
                .iter()
                .cloned(),
            )?;
        }
        ast::MulIntControl::Wide => todo!(),
    };
    Ok(())
}

/// Emits a signed `mad`; mirrors `emit_mad_uint` with the signed OpenCL op.
fn emit_mad_sint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    type_: ast::ScalarType,
    control: ast::MulIntControl,
    arg: &ast::MadArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0;
    match control {
        ast::MulIntControl::Low => {
            let product = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
            builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, product)?;
        }
        ast::MulIntControl::High => {
            builder.ext_inst(
                inst_type,
                Some(arg.dst.0),
                opencl,
                spirv::CLOp::s_mad_hi as spirv::Word,
                [
                    dr::Operand::IdRef(arg.src1.0),
                    dr::Operand::IdRef(arg.src2.0),
                    dr::Operand::IdRef(arg.src3.0),
                ]
                .iter()
                .cloned(),
            )?;
        }
        ast::MulIntControl::Wide => todo!(),
    };
    Ok(())
}

/// Emits a floating-point `mad` via the OpenCL `mad` extended instruction.
fn emit_mad_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::ArithFloat,
    arg: &ast::MadArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
        .0;
    builder.ext_inst(
        inst_type,
        Some(arg.dst.0),
        opencl,
        spirv::CLOp::mad as spirv::Word,
        [
            dr::Operand::IdRef(arg.src1.0),
            dr::Operand::IdRef(arg.src2.0),
            dr::Operand::IdRef(arg.src3.0),
        ]
        .iter()
        .cloned(),
    )?;
    Ok(())
}

/// Emits `fma` via the OpenCL `fma` extended instruction.
fn emit_fma_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::ArithFloat,
    arg: &ast::FmaArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
        .0;
    builder.ext_inst(
        inst_type,
        Some(arg.dst.0),
        opencl,
        spirv::CLOp::fma as spirv::Word,
        [
            dr::Operand::IdRef(arg.src1.0),
            dr::Operand::IdRef(arg.src2.0),
            dr::Operand::IdRef(arg.src3.0),
        ]
        .iter()
        .cloned(),
    )?;
    Ok(())
}

/// Emits an integer `sub`. Saturating subtract is not implemented yet.
fn emit_sub_int(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    typ: ast::ScalarType,
    saturate: bool,
    arg: &ast::SubArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    if saturate {
        todo!()
    }
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)))
        .0;
    builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
    Ok(())
}

/// Emits a floating-point `sub` with an optional rounding decoration.
fn emit_sub_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    desc: &ast::ArithFloat,
    arg: &ast::SubArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
        .0;
    builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
    emit_rounding_decoration(builder, arg.dst, desc.rounding);
    Ok(())
}

/// Emits `min` via the matching OpenCL extended instruction.
fn emit_min(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc:
&ast::MinMaxDetails, + arg: &ast::MinArgs, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + cl_op as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_max( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MinMaxDetails, + arg: &ast::MaxArgs, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + cl_op as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_rcp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::RcpData, + arg: &ast::RcpArgs, +) -> Result<(), TranslateError> { + let is_f64 = desc.type_ == ast::ScalarType::F64; + let (instr_type, constant) = if is_f64 { + (ast::ScalarType::F64, vec_repr(1.0f64)) + } else { + (ast::ScalarType::F32, vec_repr(1.0f32)) + }; + let result_type = map.get_or_add_scalar(builder, instr_type); + let rounding = match desc.kind { + ptx_parser::RcpKind::Approx => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::native_recip as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + return Ok(()); + } + ptx_parser::RcpKind::Compliant(rounding) => rounding, + }; + let one = map.get_or_add_constant(builder, 
&ast::Type::Scalar(instr_type), &constant)?; + builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + builder.decorate( + arg.dst.0, + spirv::Decoration::FPFastMathMode, + [dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )] + .iter() + .cloned(), + ); + Ok(()) +} + +fn emit_atom( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &ast::AtomDetails, + arg: &ast::AtomArgs, +) -> Result<(), TranslateError> { + let spirv_op = match details.op { + ptx_parser::AtomicOp::And => dr::Builder::atomic_and, + ptx_parser::AtomicOp::Or => dr::Builder::atomic_or, + ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor, + ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange, + ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add, + ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => { + return Err(error_unreachable()) + } + ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min, + ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min, + ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max, + ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max, + ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext, + ptx_parser::AtomicOp::FloatMin => todo!(), + ptx_parser::AtomicOp::FloatMax => todo!(), + }; + let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone())); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope_to_spirv(details.scope) as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics_to_spirv(details.semantics).bits()), + )?; + spirv_op( + builder, + result_type.0, + Some(arg.dst.0), + arg.src1.0, + memory_const.0, + semantics_const.0, + arg.src2.0, + )?; + Ok(()) +} + +fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope 
{ + match this { + ast::MemScope::Cta => spirv::Scope::Workgroup, + ast::MemScope::Gpu => spirv::Scope::Device, + ast::MemScope::Sys => spirv::Scope::CrossDevice, + ptx_parser::MemScope::Cluster => todo!(), + } +} + +fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics { + match this { + ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, + ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, + ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, + ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE, + } +} + +fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) { + match kind { + ast::DivFloatKind::Approx => { + builder.decorate( + dst.0, + spirv::Decoration::FPFastMathMode, + [dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )] + .iter() + .cloned(), + ); + } + ast::DivFloatKind::Rounding(rnd) => { + emit_rounding_decoration(builder, dst, Some(rnd)); + } + ast::DivFloatKind::ApproxFull => {} + } +} + +fn emit_sqrt( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + details: &ast::RcpData, + a: &ast::SqrtArgs, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add_scalar(builder, details.type_.into()); + let (ocl_op, rounding) = match details.kind { + ast::RcpKind::Approx => (spirv::CLOp::sqrt, None), + ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)), + }; + builder.ext_inst( + result_type.0, + Some(a.dst.0), + opencl, + ocl_op as spirv::Word, + [dr::Operand::IdRef(a.src.0)].iter().cloned(), + )?; + emit_rounding_decoration(builder, a.dst, rounding); + Ok(()) +} + +// TODO: check what kind of assembly do we emit +fn emit_logical_xor_spirv( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + op1: spirv::Word, + op2: spirv::Word, +) -> Result { + let temp_or = builder.logical_or(result_type, None, op1, op2)?; + let temp_and = 
builder.logical_and(result_type, None, op1, op2)?; + let temp_neg = logical_not(builder, result_type, None, temp_and)?; + builder.logical_and(result_type, result_id, temp_or, temp_neg) +} + +fn emit_load_var( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &LoadVarDetails, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); + match details.member_index { + Some((index, Some(width))) => { + let vector_type = match details.typ { + ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + _ => return Err(TranslateError::MismatchedType), + }; + let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); + let vector_temp = builder.load( + vector_type_spirv.0, + None, + details.arg.src.0, + None, + iter::empty(), + )?; + builder.composite_extract( + result_type.0, + Some(details.arg.dst.0), + vector_temp, + [index as u32].iter().copied(), + )?; + } + Some((index, None)) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + let src = builder.in_bounds_access_chain( + result_ptr_type.0, + None, + details.arg.src.0, + [index_spirv.0].iter().copied(), + )?; + builder.load( + result_type.0, + Some(details.arg.dst.0), + src, + None, + iter::empty(), + )?; + } + None => { + builder.load( + result_type.0, + Some(details.arg.dst.0), + details.arg.src.0, + None, + iter::empty(), + )?; + } + }; + Ok(()) +} + +fn to_parts(this: &ast::Type) -> TypeParts { + match this { + ast::Type::Scalar(scalar) => TypeParts { + kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + ast::Type::Vector(scalar, components) => TypeParts { + kind: TypeKind::Vector, + state_space: 
ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: vec![*components as u32], + }, + ast::Type::Array(scalar, components) => TypeParts { + kind: TypeKind::Array, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: components.clone(), + }, + ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + } +} + +fn type_from_parts(t: TypeParts) -> ast::Type { + match t.kind { + TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), + TypeKind::Vector => ast::Type::Vector( + scalar_from_parts(t.width, t.scalar_kind), + t.components[0] as u8, + ), + TypeKind::Array => { + ast::Type::Array(scalar_from_parts(t.width, t.scalar_kind), t.components) + } + TypeKind::Pointer => { + ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) + } + } +} + +#[derive(Eq, PartialEq, Clone)] +struct TypeParts { + kind: TypeKind, + scalar_kind: ast::ScalarKind, + width: u8, + state_space: ast::StateSpace, + components: Vec, +} + +#[derive(Eq, PartialEq, Copy, Clone)] +enum TypeKind { + Scalar, + Vector, + Array, + Pointer, +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 1fdf3a6..8923718 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -3,22 +3,27 @@ use rspirv::{binary::Assemble, dr}; use std::{ borrow::Cow, cell::RefCell, - collections::{hash_map, HashMap}, + collections::{hash_map, HashMap, HashSet}, ffi::CString, + iter, marker::PhantomData, + mem, rc::Rc, }; +use std::hash::Hash; +mod convert_dynamic_shared_memory_usage; mod convert_to_stateful_memory_access; mod convert_to_typed; mod expand_arguments; +mod extract_globals; mod fix_special_registers; +mod insert_implicit_conversions; mod insert_mem_ssa_statements; mod normalize_identifiers; -mod normalize_predicates; -mod 
insert_implicit_conversions; mod normalize_labels; -mod extract_globals; +mod normalize_predicates; +mod emit_spirv; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -34,7 +39,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result, _>>()?; - /* let directives = hoist_function_globals(directives); let must_link_ptx_impl = ptx_impl_imports.len() > 0; let mut directives = ptx_impl_imports @@ -43,21 +47,19 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result>(); let mut builder = dr::Builder::new(); - builder.reserve_ids(id_defs.current_id()); + builder.reserve_ids(id_defs.current_id().0); let call_map = MethodsCallMap::new(&directives); let mut directives = - convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id()); + convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || { + SpirvWord(builder.id()) + })?; normalize_variable_decls(&mut directives); let denorm_information = compute_denorm_information(&directives); + emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module - builder.set_version(1, 3); - emit_capabilities(&mut builder); - emit_extensions(&mut builder); - let opencl_id = emit_opencl_import(&mut builder); - emit_memory_model(&mut builder); - let mut map = TypeWordMap::new(&mut builder); - //emit_builtins(&mut builder, &mut map, &id_defs); - let mut kernel_info = HashMap::new(); + + todo!() + /* let (build_options, should_flush_denorms) = emit_denorm_build_string(&call_map, &denorm_information); let (directives, globals_use_map) = get_globals_use_map(directives); @@ -84,7 +86,6 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result( @@ -1273,3 +1274,399 @@ fn fn_arguments_to_variables<'a>( }) 
.collect::>() } + +fn hoist_function_globals(directives: Vec) -> Vec { + let mut result = Vec::with_capacity(directives.len()); + for directive in directives { + match directive { + Directive::Method(method) => { + for variable in method.globals { + result.push(Directive::Variable(ast::LinkingDirective::NONE, variable)); + } + result.push(Directive::Method(Function { + globals: Vec::new(), + ..method + })) + } + _ => result.push(directive), + } + } + result +} + +struct MethodsCallMap<'input> { + map: HashMap, HashSet>, +} + +impl<'input> MethodsCallMap<'input> { + fn new(module: &[Directive<'input>]) -> Self { + let mut directly_called_by = HashMap::new(); + for directive in module { + match directive { + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; + if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { + entry.insert(Vec::new()); + } + for statement in statements { + match statement { + Statement::Instruction(ast::Instruction::Call { data, arguments }) => { + multi_hash_map_append( + &mut directly_called_by, + call_key, + arguments.func, + ); + } + _ => {} + } + } + } + _ => {} + } + } + let mut result = HashMap::new(); + for (&method_key, children) in directly_called_by.iter() { + let mut visited = HashSet::new(); + for child in children { + Self::add_call_map_single(&directly_called_by, &mut visited, *child); + } + result.insert(method_key, visited); + } + MethodsCallMap { map: result } + } + + fn add_call_map_single( + directly_called_by: &HashMap, Vec>, + visited: &mut HashSet, + current: SpirvWord, + ) { + if !visited.insert(current) { + return; + } + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { + for child in children { + Self::add_call_map_single(directly_called_by, visited, *child); + } + } + } + + fn get_kernel_children(&self, name: &'input str) -> impl Iterator { + self.map + 
.get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + } + + fn kernels(&self) -> impl Iterator)> { + self.map + .iter() + .filter_map(|(method, children)| match method { + ast::MethodName::Kernel(kernel) => Some((*kernel, children)), + ast::MethodName::Func(..) => None, + }) + } + + fn methods( + &self, + ) -> impl Iterator, &HashSet)> { + self.map + .iter() + .map(|(method, children)| (*method, children)) + } + + fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) { + self.map + .get(&method) + .into_iter() + .flatten() + .copied() + .for_each(f); + } +} + +fn multi_hash_map_append< + K: Eq + std::hash::Hash, + V, + Collection: std::iter::Extend + std::default::Default, +>( + m: &mut HashMap, + key: K, + value: V, +) { + match m.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().extend(iter::once(value)); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(Default::default()).extend(iter::once(value)); + } + } +} + +fn normalize_variable_decls(directives: &mut Vec) { + for directive in directives { + match directive { + Directive::Method(Function { + body: Some(func), .. + }) => { + func[1..].sort_by_key(|s| match s { + Statement::Variable(_) => 0, + _ => 1, + }); + } + _ => (), + } + } +} + +// HACK ALERT! +// This function is a "good enough" heuristic of whetever to mark f16/f32 operations +// in the kernel as flushing denorms to zero or preserving them +// PTX support per-instruction ftz information. Unfortunately SPIR-V has no +// such capability, so instead we guesstimate which use is more common in the kernel +// and emit suitable execution mode +fn compute_denorm_information<'input>( + module: &[Directive<'input>], +) -> HashMap, HashMap> { + let mut denorm_methods = HashMap::new(); + for directive in module { + match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. 
}) => {} + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let mut flush_counter = DenormCountMap::new(); + let method_key = (**func_decl).borrow().name; + for statement in statements { + match statement { + Statement::Instruction(inst) => { + if let Some((flush, width)) = flush_to_zero(inst) { + denorm_count_map_update(&mut flush_counter, width, flush); + } + } + Statement::LoadVar(..) => {} + Statement::StoreVar(..) => {} + Statement::Conditional(_) => {} + Statement::Conversion(_) => {} + Statement::Constant(_) => {} + Statement::RetValue(_, _) => {} + Statement::Label(_) => {} + Statement::Variable(_) => {} + Statement::PtrAccess { .. } => {} + Statement::RepackVector(_) => {} + Statement::FunctionPointer(_) => {} + } + } + denorm_methods.insert(method_key, flush_counter); + } + } + } + denorm_methods + .into_iter() + .map(|(name, v)| { + let width_to_denorm = v + .into_iter() + .map(|(k, flush_over_preserve)| { + let mode = if flush_over_preserve > 0 { + spirv::FPDenormMode::FlushToZero + } else { + spirv::FPDenormMode::Preserve + }; + (k, (mode, flush_over_preserve)) + }) + .collect(); + (name, width_to_denorm) + }) + .collect() +} + +fn flush_to_zero(this: &ast::Instruction) -> Option<(bool, u8)> { + match this { + ast::Instruction::Ld { .. } => None, + ast::Instruction::St { .. } => None, + ast::Instruction::Mov { .. } => None, + ast::Instruction::Not { .. } => None, + ast::Instruction::Bra { .. } => None, + ast::Instruction::Shl { .. } => None, + ast::Instruction::Shr { .. } => None, + ast::Instruction::Ret { .. } => None, + ast::Instruction::Call { .. } => None, + ast::Instruction::Or { .. } => None, + ast::Instruction::And { .. } => None, + ast::Instruction::Cvta { .. } => None, + ast::Instruction::Selp { .. } => None, + ast::Instruction::Bar { .. } => None, + ast::Instruction::Atom { .. } => None, + ast::Instruction::AtomCas { .. } => None, + ast::Instruction::Sub { + data: ast::ArithDetails::Integer(_), + .. 
+ } => None, + ast::Instruction::Add { + data: ast::ArithDetails::Integer(_), + .. + } => None, + ast::Instruction::Mul { + data: ast::MulDetails::Integer { .. }, + .. + } => None, + ast::Instruction::Mad { + data: ast::MadDetails::Integer { .. }, + .. + } => None, + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(_), + .. + } => None, + ast::Instruction::Min { + data: ast::MinMaxDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(_), + .. + } => None, + ast::Instruction::Max { + data: ast::MinMaxDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Cvt { + data: + ast::CvtDetails { + mode: + ast::CvtMode::ZeroExtend + | ast::CvtMode::SignExtend + | ast::CvtMode::Truncate + | ast::CvtMode::Bitcast + | ast::CvtMode::SaturateUnsignedToSigned + | ast::CvtMode::SaturateSignedToUnsigned + | ast::CvtMode::FPFromSigned(_) + | ast::CvtMode::FPFromUnsigned(_), + .. + }, + .. + } => None, + ast::Instruction::Div { + data: ast::DivDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Div { + data: ast::DivDetails::Signed(_), + .. + } => None, + ast::Instruction::Clz { .. } => None, + ast::Instruction::Brev { .. } => None, + ast::Instruction::Popc { .. } => None, + ast::Instruction::Xor { .. } => None, + ast::Instruction::Bfe { .. } => None, + ast::Instruction::Bfi { .. } => None, + ast::Instruction::Rem { .. } => None, + ast::Instruction::Prmt { .. } => None, + ast::Instruction::Activemask { .. } => None, + ast::Instruction::Membar { .. } => None, + ast::Instruction::Sub { + data: ast::ArithDetails::Float(float_control), + .. + } + | ast::Instruction::Add { + data: ast::ArithDetails::Float(float_control), + .. + } + | ast::Instruction::Mul { + data: ast::MulDetails::Float(float_control), + .. + } + | ast::Instruction::Mad { + data: ast::MadDetails::Float(float_control), + .. + } => float_control + .flush_to_zero + .map(|ftz| (ftz, float_control.type_.size_of())), + ast::Instruction::Fma { data, .. 
} => data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())),
        ast::Instruction::Setp { data, .. } => {
            data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
        }
        ast::Instruction::SetpBool { data, .. } => data
            .base
            .flush_to_zero
            .map(|ftz| (ftz, data.base.type_.size_of())),
        ast::Instruction::Abs { data, .. }
        | ast::Instruction::Rsqrt { data, .. }
        | ast::Instruction::Neg { data, .. }
        | ast::Instruction::Ex2 { data, .. } => {
            data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
        }
        ast::Instruction::Min {
            data: ast::MinMaxDetails::Float(float_control),
            ..
        }
        | ast::Instruction::Max {
            data: ast::MinMaxDetails::Float(float_control),
            ..
        } => float_control
            .flush_to_zero
            .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())),
        ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => {
            data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
        }
        // Modifier .ftz can only be specified when either .dtype or .atype
        // is .f32 and applies only to single precision (.f32) inputs and results.
        ast::Instruction::Cvt {
            data:
                ast::CvtDetails {
                    mode:
                        ast::CvtMode::FPExtend { flush_to_zero }
                        | ast::CvtMode::FPTruncate { flush_to_zero, .. }
                        | ast::CvtMode::FPRound { flush_to_zero, .. }
                        | ast::CvtMode::SignedFromFP { flush_to_zero, .. }
                        | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. },
                    ..
                },
            ..
        } => flush_to_zero.map(|ftz| (ftz, 4)),
        ast::Instruction::Div {
            data:
                ast::DivDetails::Float(ast::DivFloatDetails {
                    type_,
                    flush_to_zero,
                    ..
                }),
            ..
        } => flush_to_zero.map(|ftz| (ftz, type_.size_of())),
        ast::Instruction::Sin { data, .. }
        | ast::Instruction::Cos { data, .. }
        | ast::Instruction::Lg2 { data, .. } => {
            Some((data.flush_to_zero, mem::size_of::<f32>() as u8))
        }
        ptx_parser::Instruction::PrmtSlow { ..
} => None,
        ptx_parser::Instruction::Trap {} => None,
    }
}

type DenormCountMap<T> = HashMap<T, isize>;

fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
    let num_value = if value { 1 } else { -1 };
    denorm_count_map_update_impl(map, key, num_value);
}

fn denorm_count_map_update_impl<T: Eq + Hash>(
    map: &mut DenormCountMap<T>,
    key: T,
    num_value: isize,
) {
    match map.entry(key) {
        hash_map::Entry::Occupied(mut counter) => {
            *(counter.get_mut()) += num_value;
        }
        hash_map::Entry::Vacant(entry) => {
            entry.insert(num_value);
        }
    }
}
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index 59815f2..39b464e 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -514,8 +514,11 @@ pub trait Visitor<T: Operand, Err> {
     ) -> Result<(), Err>;
 }
 
-impl<T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>>
-    Visitor<T, Err> for Fn
+impl<
+        T: Operand,
+        Err,
+        Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>,
+    > Visitor<T, Err> for Fn
 {
     fn visit(
         &mut self,
@@ -760,7 +763,7 @@ pub enum Type {
     Vector(ScalarType, u8),
     // .param.b32 foo[4];
     Array(ScalarType, Vec<u32>),
-    Pointer(ScalarType, StateSpace)
+    Pointer(ScalarType, StateSpace),
 }
 
 impl Type {
@@ -1097,7 +1100,7 @@ impl SetpData {
         let cmp_op = if type_kind == ScalarKind::Float {
             SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
         } else {
-            match SetpCompareInt::try_from(cmp_op) {
+            match SetpCompareInt::try_from((cmp_op, type_kind)) {
                 Ok(op) => SetpCompareOp::Integer(op),
                 Err(err) => {
                     state.errors.push(err);
@@ -1129,10 +1132,14 @@ pub enum SetpCompareOp {
 pub enum SetpCompareInt {
     Eq,
     NotEq,
-    Less,
-    LessOrEq,
-    Greater,
-    GreaterOrEq,
+    UnsignedLess,
+    UnsignedLessOrEq,
+    UnsignedGreater,
+    UnsignedGreaterOrEq,
+    SignedLess,
+    SignedLessOrEq,
+    SignedGreater,
+    SignedGreaterOrEq,
 }
 
 #[derive(PartialEq, Eq, Copy, Clone)]
@@ -1153,29 +1160,41 @@ pub enum SetpCompareFloat {
     IsAnyNan,
 }
 
-impl TryFrom<RawSetpCompareOp> for SetpCompareInt {
+impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt {
     type Error = PtxError;
-    fn try_from(value: RawSetpCompareOp) -> Result<Self, Self::Error> {
-        match value {
-            RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq),
-            RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq),
-            RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less),
-            RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq),
-            RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater),
-            RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq),
-            RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less),
-            RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq),
-            RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater),
-            RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq),
-            RawSetpCompareOp::Equ => Err(PtxError::WrongType),
-            RawSetpCompareOp::Neu => Err(PtxError::WrongType),
-            RawSetpCompareOp::Ltu => Err(PtxError::WrongType),
-            RawSetpCompareOp::Leu => Err(PtxError::WrongType),
-            RawSetpCompareOp::Gtu => Err(PtxError::WrongType),
-            RawSetpCompareOp::Geu => Err(PtxError::WrongType),
-            RawSetpCompareOp::Num => Err(PtxError::WrongType),
-            RawSetpCompareOp::Nan => Err(PtxError::WrongType),
+    fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result<Self, Self::Error> {
+        match (value, kind) {
+            (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq),
+            (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq),
+            (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => {
+                Ok(SetpCompareInt::SignedLess)
+            }
+            (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess),
+            (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => {
+                Ok(SetpCompareInt::SignedLessOrEq)
+            }
+            (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => {
+                Ok(SetpCompareInt::UnsignedLessOrEq)
+            }
+            (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => {
+                Ok(SetpCompareInt::SignedGreater)
+            }
+            (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater),
+            (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => {
+                Ok(SetpCompareInt::SignedGreaterOrEq)
+            }
+            (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => {
+                Ok(SetpCompareInt::UnsignedGreaterOrEq)
+            }
+            (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType),
+            (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType),
         }
     }
 }
@@ -1276,7 +1295,9 @@ impl CallArgs {
             .return_arguments
             .into_iter()
             .zip(details.return_arguments.iter())
-            .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true, false))
+            .map(|(param, (type_, space))| {
+                visitor.visit_ident(param, Some((type_, *space)), true, false)
+            })
             .collect::<Result<Vec<_>, _>>()?;
         let func = visitor.visit_ident(self.func, None, false, false)?;
         let input_arguments = self
@@ -1305,6 +1326,8 @@ pub enum CvtMode {
     SignExtend,
     Truncate,
     Bitcast,
+    SaturateUnsignedToSigned,
+    SaturateSignedToUnsigned,
     // float from float
     FPExtend {
         flush_to_zero: Option<bool>,
@@ -1389,21 +1412,11 @@ impl CvtDetails {
             },
             (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
             (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
-            (
-                ScalarKind::Unsigned | ScalarKind::Signed,
-                ScalarKind::Unsigned | ScalarKind::Signed,
-            ) => match dst.size_of().cmp(&src.size_of()) {
-                Ordering::Less => {
-                    if dst.kind() != src.kind() {
-                        errors.push(PtxError::Todo);
-                    }
-                    CvtMode::Truncate
-                }
+            (ScalarKind::Unsigned, ScalarKind::Unsigned)
+            | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
+                Ordering::Less => CvtMode::Truncate,
                 Ordering::Equal => CvtMode::Bitcast,
                 Ordering::Greater => {
-                    if dst.kind() != src.kind() {
-                        errors.push(PtxError::Todo);
-                    }
                     if src.kind() == ScalarKind::Signed {
CvtMode::SignExtend } else {