Fix more failing tests

This commit is contained in:
Andrzej Janik 2024-09-03 16:24:50 +02:00
parent 340ad86d56
commit 7a45b44854
3 changed files with 67 additions and 27 deletions

View file

@ -26,6 +26,15 @@ pub(crate) fn run(
src: src_reg,
}));
}
ast::Instruction::Call { data, arguments } => {
let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call =
Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
visitor.func.push(reresolved_call);
visitor.func.extend(visitor.post_stmts);
}
inst => {
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);

View file

@ -668,57 +668,56 @@ impl<'input> FnSigMapper<'input> {
}
}
/*
fn resolve_in_spirv_repr(
&self,
call_inst: ast::CallInst<NormalizedArgParams>,
) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
data: ast::CallDetails,
arguments: ast::CallArgs<ast::ParsedOperand<SpirvWord>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
let func_decl = (*self.func_decl).borrow();
let mut return_arguments = Vec::new();
let mut input_arguments = call_inst
.param_list
.into_iter()
.zip(func_decl.input_arguments.iter())
.map(|(id, var)| (id, var.v_type.clone(), var.state_space))
.collect::<Vec<_>>();
let mut data_return = Vec::new();
let mut arguments_return = Vec::new();
let mut data_input = data.input_arguments;
let mut arguments_input = arguments.input_arguments;
let mut func_decl_return_iter = func_decl.return_arguments.iter();
let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
for (idx, id) in call_inst.ret_params.iter().enumerate() {
let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter();
for (idx, id) in arguments.return_arguments.iter().enumerate() {
let stays_as_return = match self.return_param_args.get(idx) {
Some(x) => *x,
None => return Err(TranslateError::MismatchedType),
};
if stays_as_return {
if let Some(var) = func_decl_return_iter.next() {
return_arguments.push((*id, var.v_type.clone(), var.state_space));
data_return.push((var.v_type.clone(), var.state_space));
arguments_return.push(*id);
} else {
return Err(TranslateError::MismatchedType);
}
} else {
if let Some(var) = func_decl_input_iter.next() {
input_arguments.push((
ast::Operand::Reg(*id),
var.v_type.clone(),
var.state_space,
));
data_input.push((var.v_type.clone(), var.state_space));
arguments_input.push(ast::ParsedOperand::Reg(*id));
} else {
return Err(TranslateError::MismatchedType);
}
}
}
if return_arguments.len() != func_decl.return_arguments.len()
|| input_arguments.len() != func_decl.input_arguments.len()
if arguments_return.len() != func_decl.return_arguments.len()
|| arguments_input.len() != func_decl.input_arguments.len()
{
return Err(TranslateError::MismatchedType);
}
Ok(ResolvedCall {
return_arguments,
input_arguments,
uniform: call_inst.uniform,
name: call_inst.func,
})
let data = ast::CallDetails {
uniform: data.uniform,
return_arguments: data_return,
input_arguments: data_input,
};
let arguments = ast::CallArgs {
func: arguments.func,
return_arguments: arguments_return,
input_arguments: arguments_input,
};
Ok(ast::Instruction::Call { data, arguments })
}
*/
}
enum Statement<I, P: ast::Operand> {

View file

@ -1663,6 +1663,38 @@ derive_parser!(
RawLdStQualifier = { .weak, .volatile };
StateSpace = { .global };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc
ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
if cop.is_some() && level_eviction_priority.is_some() {
state.errors.push(PtxError::SyntaxError);
}
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
state.errors.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: global,
caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(),
typ: Type::maybe_vector(vec, type_),
non_coherent: true
},
arguments: LdArgs { dst:d, src:a }
}
}
.cop: RawLdCacheOperator = { .ca, .cg, .cs };
.level::eviction_priority: EvictionPriority =
{ .L1::evict_normal, .L1::evict_unchanged,
.L1::evict_first, .L1::evict_last, .L1::no_allocate};
.level::cache_hint = { .L2::cache_hint };
.level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B };
.vec: VectorPrefix = { .v2, .v4 };
.type: ScalarType = { .b8, .b16, .b32, .b64, .b128,
.u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.f32, .f64 };
StateSpace = { .global };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add
add.type d, a, b => {
Instruction::Add {