mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-03 06:40:21 +00:00
Continue attempts at fixing code emission for method args
This commit is contained in:
parent
2e6f7e3fdc
commit
f70abd065b
3 changed files with 224 additions and 104 deletions
|
@ -96,6 +96,7 @@ pub struct MethodDeclaration<'input, ID> {
|
||||||
pub return_arguments: Vec<Variable<ID>>,
|
pub return_arguments: Vec<Variable<ID>>,
|
||||||
pub name: MethodName<'input, ID>,
|
pub name: MethodName<'input, ID>,
|
||||||
pub input_arguments: Vec<Variable<ID>>,
|
pub input_arguments: Vec<Variable<ID>>,
|
||||||
|
pub shared_mem: Option<Variable<ID>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Function<'a, ID, S> {
|
pub struct Function<'a, ID, S> {
|
||||||
|
|
|
@ -392,12 +392,12 @@ MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
|
||||||
".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
|
".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
|
||||||
let return_arguments = Vec::new();
|
let return_arguments = Vec::new();
|
||||||
let name = ast::MethodName::Kernel(name);
|
let name = ast::MethodName::Kernel(name);
|
||||||
ast::MethodDeclaration{ return_arguments, name, input_arguments }
|
ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
|
||||||
},
|
},
|
||||||
".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
|
".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
|
||||||
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
|
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
|
||||||
let name = ast::MethodName::Func(name);
|
let name = ast::MethodName::Func(name);
|
||||||
ast::MethodDeclaration{ return_arguments, name, input_arguments }
|
ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -562,7 +562,6 @@ fn emit_directives<'input>(
|
||||||
call_map,
|
call_map,
|
||||||
&directives,
|
&directives,
|
||||||
kernel_info,
|
kernel_info,
|
||||||
f.uses_shared_mem,
|
|
||||||
)?;
|
)?;
|
||||||
for t in f.tuning.iter() {
|
for t in f.tuning.iter() {
|
||||||
match *t {
|
match *t {
|
||||||
|
@ -1038,10 +1037,9 @@ fn emit_function_header<'a>(
|
||||||
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
||||||
direcitves: &[Directive],
|
direcitves: &[Directive],
|
||||||
kernel_info: &mut HashMap<String, KernelInfo>,
|
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||||
uses_shared_mem: bool,
|
|
||||||
) -> Result<spirv::Word, TranslateError> {
|
) -> Result<spirv::Word, TranslateError> {
|
||||||
if let ast::MethodName::Kernel(name) = func_decl.name {
|
if let ast::MethodName::Kernel(name) = func_decl.name {
|
||||||
let input_args = if !uses_shared_mem {
|
let input_args = if func_decl.shared_mem.is_none() {
|
||||||
func_decl.input_arguments.as_slice()
|
func_decl.input_arguments.as_slice()
|
||||||
} else {
|
} else {
|
||||||
&func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
|
&func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
|
||||||
|
@ -1054,7 +1052,7 @@ fn emit_function_header<'a>(
|
||||||
name.to_string(),
|
name.to_string(),
|
||||||
KernelInfo {
|
KernelInfo {
|
||||||
arguments_sizes: args_lens,
|
arguments_sizes: args_lens,
|
||||||
uses_shared_mem: uses_shared_mem,
|
uses_shared_mem: func_decl.shared_mem.is_some(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1218,7 +1216,7 @@ fn rename_fn_params<'a, 'b>(
|
||||||
) -> Vec<ast::Variable<spirv::Word>> {
|
) -> Vec<ast::Variable<spirv::Word>> {
|
||||||
args.iter()
|
args.iter()
|
||||||
.map(|a| ast::Variable {
|
.map(|a| ast::Variable {
|
||||||
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false),
|
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
|
||||||
v_type: a.v_type.clone(),
|
v_type: a.v_type.clone(),
|
||||||
state_space: a.state_space,
|
state_space: a.state_space,
|
||||||
align: a.align,
|
align: a.align,
|
||||||
|
@ -1245,7 +1243,6 @@ fn to_ssa<'input, 'b>(
|
||||||
globals: Vec::new(),
|
globals: Vec::new(),
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning,
|
tuning,
|
||||||
uses_shared_mem: false,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1276,7 +1273,6 @@ fn to_ssa<'input, 'b>(
|
||||||
body: Some(f_body),
|
body: Some(f_body),
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning,
|
tuning,
|
||||||
uses_shared_mem: false,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1529,18 +1525,8 @@ fn convert_to_typed_statements(
|
||||||
match s {
|
match s {
|
||||||
Statement::Instruction(inst) => match inst {
|
Statement::Instruction(inst) => match inst {
|
||||||
ast::Instruction::Call(call) => {
|
ast::Instruction::Call(call) => {
|
||||||
// TODO: error out if lengths don't match
|
let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
|
||||||
let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow();
|
let resolved_call = resolver.resolve_in_spirv_repr(call)?;
|
||||||
let return_arguments =
|
|
||||||
to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments);
|
|
||||||
let input_arguments =
|
|
||||||
to_resolved_fn_args(call.param_list, &*fn_def.input_arguments);
|
|
||||||
let resolved_call = ResolvedCall {
|
|
||||||
uniform: call.uniform,
|
|
||||||
return_arguments,
|
|
||||||
name: call.func,
|
|
||||||
input_arguments,
|
|
||||||
};
|
|
||||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||||
let reresolved_call = resolved_call.visit(&mut visitor)?;
|
let reresolved_call = resolved_call.visit(&mut visitor)?;
|
||||||
visitor.func.push(reresolved_call);
|
visitor.func.push(reresolved_call);
|
||||||
|
@ -1683,6 +1669,7 @@ fn to_ptx_impl_atomic_call(
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
shared_mem: None,
|
||||||
};
|
};
|
||||||
let func = Function {
|
let func = Function {
|
||||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||||
|
@ -1690,7 +1677,6 @@ fn to_ptx_impl_atomic_call(
|
||||||
body: None,
|
body: None,
|
||||||
import_as: Some(entry.key().clone()),
|
import_as: Some(entry.key().clone()),
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
uses_shared_mem: false,
|
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
fn_id
|
||||||
|
@ -1772,6 +1758,7 @@ fn to_ptx_impl_bfe_call(
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
shared_mem: None,
|
||||||
};
|
};
|
||||||
let func = Function {
|
let func = Function {
|
||||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||||
|
@ -1779,7 +1766,6 @@ fn to_ptx_impl_bfe_call(
|
||||||
body: None,
|
body: None,
|
||||||
import_as: Some(entry.key().clone()),
|
import_as: Some(entry.key().clone()),
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
uses_shared_mem: false,
|
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
fn_id
|
||||||
|
@ -1871,6 +1857,7 @@ fn to_ptx_impl_bfi_call(
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
shared_mem: None,
|
||||||
};
|
};
|
||||||
let func = Function {
|
let func = Function {
|
||||||
func_decl: Rc::new(RefCell::new(func_decl)),
|
func_decl: Rc::new(RefCell::new(func_decl)),
|
||||||
|
@ -1878,7 +1865,6 @@ fn to_ptx_impl_bfi_call(
|
||||||
body: None,
|
body: None,
|
||||||
import_as: Some(entry.key().clone()),
|
import_as: Some(entry.key().clone()),
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
uses_shared_mem: false,
|
|
||||||
};
|
};
|
||||||
entry.insert(Directive::Method(func));
|
entry.insert(Directive::Method(func));
|
||||||
fn_id
|
fn_id
|
||||||
|
@ -2009,42 +1995,44 @@ fn normalize_predicates(
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
How do we handle arguments:
|
||||||
|
- input .params
|
||||||
|
.param .b64 in_arg
|
||||||
|
get turned into this SPIR-V:
|
||||||
|
%1 = OpFunctionParameter %ulong
|
||||||
|
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %1
|
||||||
|
We do this for two reasons. One, common treatment for argument-declared
|
||||||
|
.param variables and .param variables inside function (we assume that
|
||||||
|
at SPIR-V level every .param is a pointer in Function storage class). Two,
|
||||||
|
PTX devs in their infinite wisdom decided that .reg arguments are writable
|
||||||
|
- input .regs
|
||||||
|
.reg .b64 in_arg
|
||||||
|
get turned into this SPIR-V:
|
||||||
|
%1 = OpFunctionParameter %ulong
|
||||||
|
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %1
|
||||||
|
with the difference that %2 is defined as a variable and not temp
|
||||||
|
- output .regs
|
||||||
|
.reg .b64 out_arg
|
||||||
|
get just a variable declaration:
|
||||||
|
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||||
|
- output .params
|
||||||
|
.param .b64 out_arg
|
||||||
|
get treated the same as input .params, because there's no difference
|
||||||
|
*/
|
||||||
fn insert_mem_ssa_statements<'a, 'b>(
|
fn insert_mem_ssa_statements<'a, 'b>(
|
||||||
func: Vec<TypedStatement>,
|
func: Vec<TypedStatement>,
|
||||||
id_def: &mut NumericIdResolver,
|
id_def: &mut NumericIdResolver,
|
||||||
fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
|
fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
|
||||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
let mut result = Vec::with_capacity(func.len());
|
let mut result = Vec::with_capacity(func.len());
|
||||||
for arg in fn_decl.return_arguments.iter() {
|
|
||||||
result.push(Statement::Variable(ast::Variable {
|
|
||||||
align: arg.align,
|
|
||||||
v_type: arg.v_type.clone(),
|
|
||||||
state_space: arg.state_space,
|
|
||||||
name: arg.name,
|
|
||||||
array_init: arg.array_init.clone(),
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
for arg in fn_decl.input_arguments.iter_mut() {
|
for arg in fn_decl.input_arguments.iter_mut() {
|
||||||
let typ = arg.v_type.clone();
|
insert_mem_ssa_argument(id_def, &mut result, arg);
|
||||||
let state_space = arg.state_space;
|
}
|
||||||
let new_id = id_def.register_intermediate(Some((typ.clone(), state_space)));
|
for arg in fn_decl.return_arguments.iter() {
|
||||||
result.push(Statement::Variable(ast::Variable {
|
insert_mem_ssa_argument_reg_return(&mut result, arg);
|
||||||
align: arg.align,
|
|
||||||
v_type: arg.v_type.clone(),
|
|
||||||
state_space: arg.state_space,
|
|
||||||
name: arg.name,
|
|
||||||
array_init: Vec::new(),
|
|
||||||
}));
|
|
||||||
result.push(Statement::StoreVar(StoreVarDetails {
|
|
||||||
arg: ast::Arg2St {
|
|
||||||
src1: arg.name,
|
|
||||||
src2: new_id,
|
|
||||||
},
|
|
||||||
state_space,
|
|
||||||
typ,
|
|
||||||
member_index: None,
|
|
||||||
}));
|
|
||||||
arg.name = new_id;
|
|
||||||
}
|
}
|
||||||
for s in func {
|
for s in func {
|
||||||
match s {
|
match s {
|
||||||
|
@ -2054,22 +2042,26 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
||||||
Statement::Instruction(inst) => match inst {
|
Statement::Instruction(inst) => match inst {
|
||||||
ast::Instruction::Ret(d) => {
|
ast::Instruction::Ret(d) => {
|
||||||
// TODO: handle multiple output args
|
// TODO: handle multiple output args
|
||||||
if let &[out_param] = &fn_decl.return_arguments.as_slice() {
|
match &fn_decl.return_arguments[..] {
|
||||||
let (typ, space, _) = id_def.get_typed(out_param.name)?;
|
[return_reg] => {
|
||||||
let new_id = id_def.register_intermediate(Some((typ.clone(), space)));
|
let new_id = id_def.register_intermediate(Some((
|
||||||
result.push(Statement::LoadVar(LoadVarDetails {
|
return_reg.v_type.clone(),
|
||||||
arg: ast::Arg2 {
|
ast::StateSpace::Reg,
|
||||||
dst: new_id,
|
)));
|
||||||
src: out_param.name,
|
result.push(Statement::LoadVar(LoadVarDetails {
|
||||||
},
|
arg: ast::Arg2 {
|
||||||
// TODO: ret with stateful conversion
|
dst: new_id,
|
||||||
state_space: new_todo!(),
|
src: return_reg.name,
|
||||||
typ: typ.clone(),
|
},
|
||||||
member_index: None,
|
// TODO: ret with stateful conversion
|
||||||
}));
|
state_space: ast::StateSpace::Reg,
|
||||||
result.push(Statement::RetValue(d, new_id));
|
typ: return_reg.v_type.clone(),
|
||||||
} else {
|
member_index: None,
|
||||||
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
|
}));
|
||||||
|
result.push(Statement::RetValue(d, new_id));
|
||||||
|
}
|
||||||
|
[] => result.push(Statement::Instruction(ast::Instruction::Ret(d))),
|
||||||
|
_ => unimplemented!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
|
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
|
||||||
|
@ -2107,6 +2099,43 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_argument(
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
func: &mut Vec<TypedStatement>,
|
||||||
|
arg: &mut ast::Variable<spirv::Word>,
|
||||||
|
) {
|
||||||
|
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
|
func.push(Statement::Variable(ast::Variable {
|
||||||
|
align: arg.align,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name: arg.name,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}));
|
||||||
|
func.push(Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg: ast::Arg2St {
|
||||||
|
src1: arg.name,
|
||||||
|
src2: new_id,
|
||||||
|
},
|
||||||
|
typ: arg.v_type.clone(),
|
||||||
|
member_index: None,
|
||||||
|
}));
|
||||||
|
arg.name = new_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_argument_reg_return(
|
||||||
|
func: &mut Vec<TypedStatement>,
|
||||||
|
arg: &ast::Variable<spirv::Word>,
|
||||||
|
) {
|
||||||
|
func.push(Statement::Variable(ast::Variable {
|
||||||
|
align: arg.align,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: arg.state_space,
|
||||||
|
name: arg.name,
|
||||||
|
array_init: arg.array_init.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
|
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
|
||||||
fn visit(
|
fn visit(
|
||||||
self,
|
self,
|
||||||
|
@ -2202,7 +2231,6 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
src1: symbol,
|
src1: symbol,
|
||||||
src2: generated_id,
|
src2: generated_id,
|
||||||
},
|
},
|
||||||
state_space: ast::StateSpace::Reg,
|
|
||||||
typ: var_type,
|
typ: var_type,
|
||||||
member_index: member_index.map(|(idx, _)| idx),
|
member_index: member_index.map(|(idx, _)| idx),
|
||||||
}));
|
}));
|
||||||
|
@ -4162,10 +4190,10 @@ fn emit_load_var(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_identifiers<'a, 'b>(
|
fn normalize_identifiers<'input, 'b>(
|
||||||
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||||
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
|
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
||||||
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
|
func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
|
||||||
) -> Result<Vec<NormalizedStatement>, TranslateError> {
|
) -> Result<Vec<NormalizedStatement>, TranslateError> {
|
||||||
for s in func.iter() {
|
for s in func.iter() {
|
||||||
match s {
|
match s {
|
||||||
|
@ -4796,12 +4824,92 @@ impl SpecialRegistersMap {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct FnSigMapper<'input> {
|
||||||
|
// true - stays as return argument
|
||||||
|
// false - is moved to input argument
|
||||||
|
return_param_args: Vec<bool>,
|
||||||
|
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'input> FnSigMapper<'input> {
|
||||||
|
fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self {
|
||||||
|
let return_param_args = method
|
||||||
|
.return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.state_space != ast::StateSpace::Param)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut new_return_arguments = Vec::new();
|
||||||
|
for arg in method.return_arguments.into_iter() {
|
||||||
|
if arg.state_space == ast::StateSpace::Param {
|
||||||
|
method.input_arguments.push(arg);
|
||||||
|
} else {
|
||||||
|
new_return_arguments.push(arg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
method.return_arguments = new_return_arguments;
|
||||||
|
FnSigMapper {
|
||||||
|
return_param_args,
|
||||||
|
func_decl: Rc::new(RefCell::new(method)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_in_spirv_repr(
|
||||||
|
&self,
|
||||||
|
call_inst: ast::CallInst<NormalizedArgParams>,
|
||||||
|
) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
|
||||||
|
let func_decl = (*self.func_decl).borrow();
|
||||||
|
let mut return_arguments = Vec::new();
|
||||||
|
let mut input_arguments = call_inst
|
||||||
|
.param_list
|
||||||
|
.into_iter()
|
||||||
|
.zip(func_decl.input_arguments.iter())
|
||||||
|
.map(|(id, var)| (id, var.v_type.clone(), var.state_space))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut func_decl_return_iter = func_decl.return_arguments.iter();
|
||||||
|
let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
|
||||||
|
for (idx, id) in call_inst.ret_params.iter().enumerate() {
|
||||||
|
let stays_as_return = match self.return_param_args.get(idx) {
|
||||||
|
Some(x) => *x,
|
||||||
|
None => return Err(TranslateError::MismatchedType),
|
||||||
|
};
|
||||||
|
if stays_as_return {
|
||||||
|
if let Some(var) = func_decl_return_iter.next() {
|
||||||
|
return_arguments.push((*id, var.v_type.clone(), var.state_space));
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if let Some(var) = func_decl_input_iter.next() {
|
||||||
|
input_arguments.push((
|
||||||
|
ast::Operand::Reg(*id),
|
||||||
|
var.v_type.clone(),
|
||||||
|
var.state_space,
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if return_arguments.len() != func_decl.return_arguments.len()
|
||||||
|
|| input_arguments.len() != func_decl.input_arguments.len()
|
||||||
|
{
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
Ok(ResolvedCall {
|
||||||
|
return_arguments,
|
||||||
|
input_arguments,
|
||||||
|
uniform: call_inst.uniform,
|
||||||
|
name: call_inst.func,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct GlobalStringIdResolver<'input> {
|
struct GlobalStringIdResolver<'input> {
|
||||||
current_id: spirv::Word,
|
current_id: spirv::Word,
|
||||||
variables: HashMap<Cow<'input, str>, spirv::Word>,
|
variables: HashMap<Cow<'input, str>, spirv::Word>,
|
||||||
variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
|
||||||
special_registers: SpecialRegistersMap,
|
special_registers: SpecialRegistersMap,
|
||||||
fns: HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
|
fns: HashMap<spirv::Word, FnSigMapper<'input>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'input> GlobalStringIdResolver<'input> {
|
impl<'input> GlobalStringIdResolver<'input> {
|
||||||
|
@ -4885,45 +4993,36 @@ impl<'input> GlobalStringIdResolver<'input> {
|
||||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||||
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
|
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
|
||||||
};
|
};
|
||||||
let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration {
|
let fn_decl = ast::MethodDeclaration {
|
||||||
return_arguments,
|
return_arguments,
|
||||||
name,
|
name,
|
||||||
input_arguments,
|
input_arguments,
|
||||||
}));
|
shared_mem: None,
|
||||||
self.fns.insert(name_id, Rc::clone(&new_fn_decl));
|
};
|
||||||
|
let new_fn_decl = if !fn_decl.name.is_kernel() {
|
||||||
|
let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
|
||||||
|
let new_fn_decl = resolver.func_decl.clone();
|
||||||
|
self.fns.insert(name_id, resolver);
|
||||||
|
new_fn_decl
|
||||||
|
} else {
|
||||||
|
Rc::new(RefCell::new(fn_decl))
|
||||||
|
};
|
||||||
Ok((
|
Ok((
|
||||||
fn_resolver,
|
fn_resolver,
|
||||||
GlobalFnDeclResolver {
|
GlobalFnDeclResolver { fns: &self.fns },
|
||||||
variables: &self.variables,
|
|
||||||
fns: &self.fns,
|
|
||||||
},
|
|
||||||
new_fn_decl,
|
new_fn_decl,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct GlobalFnDeclResolver<'input, 'a> {
|
pub struct GlobalFnDeclResolver<'input, 'a> {
|
||||||
variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
|
fns: &'a HashMap<spirv::Word, FnSigMapper<'input>>,
|
||||||
fns: &'a HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
|
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
|
||||||
fn get_fn_decl(
|
fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
|
||||||
&self,
|
|
||||||
id: spirv::Word,
|
|
||||||
) -> Result<&Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
|
|
||||||
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
|
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_fn_decl_str(
|
|
||||||
&self,
|
|
||||||
id: &str,
|
|
||||||
) -> Result<&'a Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
|
|
||||||
match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
|
|
||||||
Some(Some(fn_d)) => Ok(fn_d),
|
|
||||||
_ => Err(TranslateError::UnknownSymbol),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FnStringIdResolver<'input, 'b> {
|
struct FnStringIdResolver<'input, 'b> {
|
||||||
|
@ -5209,7 +5308,6 @@ struct LoadVarDetails {
|
||||||
struct StoreVarDetails {
|
struct StoreVarDetails {
|
||||||
arg: ast::Arg2St<ExpandedArgParams>,
|
arg: ast::Arg2St<ExpandedArgParams>,
|
||||||
typ: ast::Type,
|
typ: ast::Type,
|
||||||
state_space: ast::StateSpace,
|
|
||||||
member_index: Option<u8>,
|
member_index: Option<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5300,7 +5398,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
|
||||||
self,
|
self,
|
||||||
visitor: &mut V,
|
visitor: &mut V,
|
||||||
) -> Result<ResolvedCall<To>, TranslateError> {
|
) -> Result<ResolvedCall<To>, TranslateError> {
|
||||||
let ret_params = self
|
let return_arguments = self
|
||||||
.return_arguments
|
.return_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
|
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
|
||||||
|
@ -5324,7 +5422,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
)?;
|
)?;
|
||||||
let param_list = self
|
let input_arguments = self
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
|
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
|
||||||
|
@ -5342,9 +5440,9 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok(ResolvedCall {
|
Ok(ResolvedCall {
|
||||||
uniform: self.uniform,
|
uniform: self.uniform,
|
||||||
return_arguments: ret_params,
|
return_arguments,
|
||||||
name: func,
|
name: func,
|
||||||
input_arguments: param_list,
|
input_arguments,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5485,7 +5583,6 @@ struct Function<'input> {
|
||||||
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||||
pub globals: Vec<ast::Variable<spirv::Word>>,
|
pub globals: Vec<ast::Variable<spirv::Word>>,
|
||||||
pub body: Option<Vec<ExpandedStatement>>,
|
pub body: Option<Vec<ExpandedStatement>>,
|
||||||
pub uses_shared_mem: bool,
|
|
||||||
import_as: Option<String>,
|
import_as: Option<String>,
|
||||||
tuning: Vec<ast::TuningDirective>,
|
tuning: Vec<ast::TuningDirective>,
|
||||||
}
|
}
|
||||||
|
@ -7185,6 +7282,19 @@ fn default_implicit_conversion_space(
|
||||||
},
|
},
|
||||||
_ => Err(TranslateError::MismatchedType),
|
_ => Err(TranslateError::MismatchedType),
|
||||||
}
|
}
|
||||||
|
} else if instruction_space.is_compatible(ast::StateSpace::Reg) {
|
||||||
|
match instruction_type {
|
||||||
|
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||||
|
if operand_space == *instruction_ptr_space =>
|
||||||
|
{
|
||||||
|
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(TranslateError::MismatchedType),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
Err(TranslateError::MismatchedType)
|
Err(TranslateError::MismatchedType)
|
||||||
}
|
}
|
||||||
|
@ -7432,6 +7542,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'input, ID> ast::MethodName<'input, ID> {
|
||||||
|
fn is_kernel(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
ast::MethodName::Kernel(..) => true,
|
||||||
|
ast::MethodName::Func(..) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue