diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs
index c2af204..ab5b246 100644
--- a/ptx/src/pass/convert_to_typed.rs
+++ b/ptx/src/pass/convert_to_typed.rs
@@ -26,6 +26,15 @@ pub(crate) fn run(
                     src: src_reg,
                 }));
             }
+            ast::Instruction::Call { data, arguments } => {
+                let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
+                let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
+                let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+                let reresolved_call =
+                    Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
+                visitor.func.push(reresolved_call);
+                visitor.func.extend(visitor.post_stmts);
+            }
             inst => {
                 let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
                 let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index 92d1bf4..2be6297 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -668,57 +668,56 @@ impl<'input> FnSigMapper<'input> {
         }
     }
 
-    /*
     fn resolve_in_spirv_repr(
         &self,
-        call_inst: ast::CallInst<SpirvWord>,
-    ) -> Result<ResolvedCall<SpirvWord>, TranslateError> {
+        data: ast::CallDetails,
+        arguments: ast::CallArgs<ast::ParsedOperand<SpirvWord>>,
+    ) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
         let func_decl = (*self.func_decl).borrow();
-        let mut return_arguments = Vec::new();
-        let mut input_arguments = call_inst
-            .param_list
-            .into_iter()
-            .zip(func_decl.input_arguments.iter())
-            .map(|(id, var)| (id, var.v_type.clone(), var.state_space))
-            .collect::<Vec<_>>();
+        let mut data_return = Vec::new();
+        let mut arguments_return = Vec::new();
+        let mut data_input = data.input_arguments;
+        let mut arguments_input = arguments.input_arguments;
         let mut func_decl_return_iter = func_decl.return_arguments.iter();
-        let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
-        for (idx, id) in call_inst.ret_params.iter().enumerate() {
+        let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter();
+        for (idx, id) in arguments.return_arguments.iter().enumerate() {
             let stays_as_return = match self.return_param_args.get(idx) {
                 Some(x) => *x,
                 None => return Err(TranslateError::MismatchedType),
             };
             if stays_as_return {
                 if let Some(var) = func_decl_return_iter.next() {
-                    return_arguments.push((*id, var.v_type.clone(), var.state_space));
+                    data_return.push((var.v_type.clone(), var.state_space));
+                    arguments_return.push(*id);
                 } else {
                     return Err(TranslateError::MismatchedType);
                 }
             } else {
                 if let Some(var) = func_decl_input_iter.next() {
-                    input_arguments.push((
-                        ast::Operand::Reg(*id),
-                        var.v_type.clone(),
-                        var.state_space,
-                    ));
+                    data_input.push((var.v_type.clone(), var.state_space));
+                    arguments_input.push(ast::ParsedOperand::Reg(*id));
                 } else {
                     return Err(TranslateError::MismatchedType);
                }
            }
        }
-        if return_arguments.len() != func_decl.return_arguments.len()
-            || input_arguments.len() != func_decl.input_arguments.len()
+        if arguments_return.len() != func_decl.return_arguments.len()
+            || arguments_input.len() != func_decl.input_arguments.len()
         {
             return Err(TranslateError::MismatchedType);
         }
-        Ok(ResolvedCall {
-            return_arguments,
-            input_arguments,
-            uniform: call_inst.uniform,
-            name: call_inst.func,
-        })
+        let data = ast::CallDetails {
+            uniform: data.uniform,
+            return_arguments: data_return,
+            input_arguments: data_input,
+        };
+        let arguments = ast::CallArgs {
+            func: arguments.func,
+            return_arguments: arguments_return,
+            input_arguments: arguments_input,
+        };
+        Ok(ast::Instruction::Call { data, arguments })
     }
-    */
 }
 
 enum Statement<I, P: ast::Operand> {
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs
index dfe78ee..3d09511 100644
--- a/ptx_parser/src/lib.rs
+++ b/ptx_parser/src/lib.rs
@@ -1663,6 +1663,38 @@ derive_parser!(
     RawLdStQualifier = { .weak, .volatile };
     StateSpace = { .global };
 
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc
+    ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type  d, [a]{, cache_policy} => {
+        if cop.is_some() && level_eviction_priority.is_some() {
+            state.errors.push(PtxError::SyntaxError);
+        }
+        if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
+            state.errors.push(PtxError::Todo);
+        }
+        Instruction::Ld {
+            data: LdDetails {
+                qualifier: ast::LdStQualifier::Weak,
+                state_space: global,
+                caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(),
+                typ: Type::maybe_vector(vec, type_),
+                non_coherent: true
+            },
+            arguments: LdArgs { dst:d, src:a }
+        }
+    }
+    .cop: RawLdCacheOperator =      { .ca, .cg, .cs };
+    .level::eviction_priority: EvictionPriority =
+                { .L1::evict_normal, .L1::evict_unchanged,
+                  .L1::evict_first, .L1::evict_last, .L1::no_allocate};
+    .level::cache_hint = { .L2::cache_hint };
+    .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B };
+    .vec: VectorPrefix =            { .v2, .v4 };
+    .type: ScalarType =             { .b8, .b16, .b32, .b64, .b128,
+                                      .u8, .u16, .u32, .u64,
+                                      .s8, .s16, .s32, .s64,
+                                      .f32, .f64 };
+    StateSpace = { .global };
+
     // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add
     add.type  d, a, b => {
         Instruction::Add {