mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Port ssa conversion
This commit is contained in:
parent
3e0a15ac84
commit
cccd37f6ee
3 changed files with 286 additions and 9 deletions
|
@ -527,9 +527,3 @@ fn convert_to_stateful_memory_access_postprocess(
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|
||||
this == other
|
||||
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|
||||
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
|
||||
}
|
||||
|
|
276
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
276
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
|
@ -0,0 +1,276 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
How do we handle arguments:
|
||||
- input .params in kernels
|
||||
.param .b64 in_arg
|
||||
get turned into this SPIR-V:
|
||||
%1 = OpFunctionParameter %ulong
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %1
|
||||
We do this for two reasons. One, common treatment for argument-declared
|
||||
.param variables and .param variables inside function (we assume that
|
||||
at SPIR-V level every .param is a pointer in Function storage class)
|
||||
- input .params in functions
|
||||
.param .b64 in_arg
|
||||
get turned into this SPIR-V:
|
||||
%1 = OpFunctionParameter %_ptr_Function_ulong
|
||||
- input .regs
|
||||
.reg .b64 in_arg
|
||||
get turned into the same SPIR-V as kernel .params:
|
||||
%1 = OpFunctionParameter %ulong
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %1
|
||||
- output .regs
|
||||
.reg .b64 out_arg
|
||||
get just a variable declaration:
|
||||
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||
- output .params don't exist, they have been moved to input positions
|
||||
by an earlier pass
|
||||
Distinguishing betweem kernel .params and function .params is not the
|
||||
cleanest solution. Alternatively, we could "deparamize" all kernel .param
|
||||
arguments by turning them into .reg arguments like this:
|
||||
.param .b64 arg -> .reg ptr<.b64,.param> arg
|
||||
This has the massive downside that this transformation would have to run
|
||||
very early and would muddy up already difficult code. It's simpler to just
|
||||
have an if here
|
||||
*/
|
||||
pub(super) fn run<'a, 'b>(
|
||||
func: Vec<TypedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for arg in fn_decl.input_arguments.iter_mut() {
|
||||
insert_mem_ssa_argument(
|
||||
id_def,
|
||||
&mut result,
|
||||
arg,
|
||||
matches!(fn_decl.name, ast::MethodName::Kernel(_)),
|
||||
);
|
||||
}
|
||||
for arg in fn_decl.return_arguments.iter() {
|
||||
insert_mem_ssa_argument_reg_return(&mut result, arg);
|
||||
}
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Ret { data } => {
|
||||
// TODO: handle multiple output args
|
||||
match &fn_decl.return_arguments[..] {
|
||||
[return_reg] => {
|
||||
let new_id = id_def.register_intermediate(Some((
|
||||
return_reg.v_type.clone(),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
result.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: ast::LdArgs {
|
||||
dst: new_id,
|
||||
src: return_reg.name,
|
||||
},
|
||||
typ: return_reg.v_type.clone(),
|
||||
member_index: None,
|
||||
}));
|
||||
result.push(Statement::RetValue(data, new_id));
|
||||
}
|
||||
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
inst => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::Instruction(inst),
|
||||
)?,
|
||||
},
|
||||
Statement::Conditional(bra) => {
|
||||
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
|
||||
}
|
||||
Statement::Conversion(conv) => {
|
||||
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
|
||||
}
|
||||
Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::PtrAccess(ptr_access),
|
||||
)?,
|
||||
Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::RepackVector(repack),
|
||||
)?,
|
||||
Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::FunctionPointer(func_ptr),
|
||||
)?,
|
||||
s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
|
||||
result.push(s)
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_argument(
|
||||
id_def: &mut NumericIdResolver,
|
||||
func: &mut Vec<TypedStatement>,
|
||||
arg: &mut ast::Variable<SpirvWord>,
|
||||
is_kernel: bool,
|
||||
) {
|
||||
if !is_kernel && arg.state_space == ast::StateSpace::Param {
|
||||
return;
|
||||
}
|
||||
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
|
||||
func.push(Statement::Variable(ast::Variable {
|
||||
align: arg.align,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: arg.name,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
func.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: ast::StArgs {
|
||||
src1: arg.name,
|
||||
src2: new_id,
|
||||
},
|
||||
typ: arg.v_type.clone(),
|
||||
member_index: None,
|
||||
}));
|
||||
arg.name = new_id;
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_argument_reg_return(
|
||||
func: &mut Vec<TypedStatement>,
|
||||
arg: &ast::Variable<SpirvWord>,
|
||||
) {
|
||||
func.push(Statement::Variable(ast::Variable {
|
||||
align: arg.align,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: arg.state_space,
|
||||
name: arg.name,
|
||||
array_init: arg.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_statement_default<'a, 'input>(
|
||||
id_def: &'a mut NumericIdResolver<'input>,
|
||||
func: &'a mut Vec<TypedStatement>,
|
||||
stmt: TypedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = InsertMemSSAVisitor {
|
||||
id_def,
|
||||
func,
|
||||
post_statements: Vec::new(),
|
||||
};
|
||||
let new_stmt = stmt.visit_map(&mut visitor)?;
|
||||
visitor.func.push(new_stmt);
|
||||
visitor.func.extend(visitor.post_statements);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
id_def: &'a mut NumericIdResolver<'input>,
|
||||
func: &'a mut Vec<TypedStatement>,
|
||||
post_statements: Vec<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn symbol(
|
||||
&mut self,
|
||||
symbol: SpirvWord,
|
||||
member_index: Option<u8>,
|
||||
expected: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if expected.is_none() {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
||||
if !state_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let member_index = match member_index {
|
||||
Some(idx) => {
|
||||
let vector_width = match var_type {
|
||||
ast::Type::Vector(scalar_t, width) => {
|
||||
var_type = ast::Type::Scalar(scalar_t);
|
||||
width
|
||||
}
|
||||
_ => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
Some((
|
||||
idx,
|
||||
if self.id_def.special_registers.get(symbol).is_some() {
|
||||
Some(vector_width)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
let generated_id = self
|
||||
.id_def
|
||||
.register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
|
||||
if !is_dst {
|
||||
self.func.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: ast::LdArgs {
|
||||
dst: generated_id,
|
||||
src: symbol,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index,
|
||||
}));
|
||||
} else {
|
||||
self.post_statements
|
||||
.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: ast::StArgs {
|
||||
src1: symbol,
|
||||
src2: generated_id,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index: member_index.map(|(idx, _)| idx),
|
||||
}));
|
||||
}
|
||||
Ok(generated_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: TypedOperand,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<TypedOperand, TranslateError> {
|
||||
Ok(match operand {
|
||||
TypedOperand::Reg(reg) => {
|
||||
TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
|
||||
}
|
||||
TypedOperand::RegOffset(reg, offset) => {
|
||||
TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
|
||||
}
|
||||
op @ TypedOperand::Imm(..) => op,
|
||||
TypedOperand::VecMember(symbol, index) => TypedOperand::VecMember(
|
||||
self.symbol(symbol, Some(index), type_space, is_dst)?,
|
||||
index,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.symbol(args, None, type_space, is_dst)
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ use std::{
|
|||
mod convert_to_stateful_memory_access;
|
||||
mod convert_to_typed;
|
||||
mod fix_special_registers;
|
||||
mod insert_mem_ssa_statements;
|
||||
mod normalize_identifiers;
|
||||
mod normalize_predicates;
|
||||
|
||||
|
@ -175,13 +176,13 @@ fn to_ssa<'input, 'b>(
|
|||
fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||
let (func_decl, typed_statements) =
|
||||
convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
todo!()
|
||||
/*
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
let ssa_statements = insert_mem_ssa_statements::run(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
&mut (*func_decl).borrow_mut(),
|
||||
)?;
|
||||
todo!()
|
||||
/*
|
||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||
let expanded_statements =
|
||||
|
@ -1206,3 +1207,9 @@ impl<
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|
||||
this == other
|
||||
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|
||||
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue