mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-07-28 03:48:38 +00:00
Port ssa conversion
This commit is contained in:
parent
3e0a15ac84
commit
cccd37f6ee
3 changed files with 286 additions and 9 deletions
|
@ -527,9 +527,3 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|
|
||||||
this == other
|
|
||||||
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|
|
||||||
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
|
|
||||||
}
|
|
||||||
|
|
276
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
276
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
|
@ -0,0 +1,276 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
/*
|
||||||
|
How do we handle arguments:
|
||||||
|
- input .params in kernels
|
||||||
|
.param .b64 in_arg
|
||||||
|
get turned into this SPIR-V:
|
||||||
|
%1 = OpFunctionParameter %ulong
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %1
|
||||||
|
We do this for two reasons. One, common treatment for argument-declared
|
||||||
|
.param variables and .param variables inside function (we assume that
|
||||||
|
at SPIR-V level every .param is a pointer in Function storage class)
|
||||||
|
- input .params in functions
|
||||||
|
.param .b64 in_arg
|
||||||
|
get turned into this SPIR-V:
|
||||||
|
%1 = OpFunctionParameter %_ptr_Function_ulong
|
||||||
|
- input .regs
|
||||||
|
.reg .b64 in_arg
|
||||||
|
get turned into the same SPIR-V as kernel .params:
|
||||||
|
%1 = OpFunctionParameter %ulong
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %1
|
||||||
|
- output .regs
|
||||||
|
.reg .b64 out_arg
|
||||||
|
get just a variable declaration:
|
||||||
|
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||||
|
- output .params don't exist, they have been moved to input positions
|
||||||
|
by an earlier pass
|
||||||
|
Distinguishing betweem kernel .params and function .params is not the
|
||||||
|
cleanest solution. Alternatively, we could "deparamize" all kernel .param
|
||||||
|
arguments by turning them into .reg arguments like this:
|
||||||
|
.param .b64 arg -> .reg ptr<.b64,.param> arg
|
||||||
|
This has the massive downside that this transformation would have to run
|
||||||
|
very early and would muddy up already difficult code. It's simpler to just
|
||||||
|
have an if here
|
||||||
|
*/
|
||||||
|
pub(super) fn run<'a, 'b>(
|
||||||
|
func: Vec<TypedStatement>,
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
|
||||||
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for arg in fn_decl.input_arguments.iter_mut() {
|
||||||
|
insert_mem_ssa_argument(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
arg,
|
||||||
|
matches!(fn_decl.name, ast::MethodName::Kernel(_)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for arg in fn_decl.return_arguments.iter() {
|
||||||
|
insert_mem_ssa_argument_reg_return(&mut result, arg);
|
||||||
|
}
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(inst) => match inst {
|
||||||
|
ast::Instruction::Ret { data } => {
|
||||||
|
// TODO: handle multiple output args
|
||||||
|
match &fn_decl.return_arguments[..] {
|
||||||
|
[return_reg] => {
|
||||||
|
let new_id = id_def.register_intermediate(Some((
|
||||||
|
return_reg.v_type.clone(),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
result.push(Statement::LoadVar(LoadVarDetails {
|
||||||
|
arg: ast::LdArgs {
|
||||||
|
dst: new_id,
|
||||||
|
src: return_reg.name,
|
||||||
|
},
|
||||||
|
typ: return_reg.v_type.clone(),
|
||||||
|
member_index: None,
|
||||||
|
}));
|
||||||
|
result.push(Statement::RetValue(data, new_id));
|
||||||
|
}
|
||||||
|
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inst => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::Instruction(inst),
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
Statement::Conditional(bra) => {
|
||||||
|
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
|
||||||
|
}
|
||||||
|
Statement::Conversion(conv) => {
|
||||||
|
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
|
||||||
|
}
|
||||||
|
Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::PtrAccess(ptr_access),
|
||||||
|
)?,
|
||||||
|
Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::RepackVector(repack),
|
||||||
|
)?,
|
||||||
|
Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::FunctionPointer(func_ptr),
|
||||||
|
)?,
|
||||||
|
s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
|
||||||
|
result.push(s)
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_argument(
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
func: &mut Vec<TypedStatement>,
|
||||||
|
arg: &mut ast::Variable<SpirvWord>,
|
||||||
|
is_kernel: bool,
|
||||||
|
) {
|
||||||
|
if !is_kernel && arg.state_space == ast::StateSpace::Param {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
|
func.push(Statement::Variable(ast::Variable {
|
||||||
|
align: arg.align,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name: arg.name,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}));
|
||||||
|
func.push(Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg: ast::StArgs {
|
||||||
|
src1: arg.name,
|
||||||
|
src2: new_id,
|
||||||
|
},
|
||||||
|
typ: arg.v_type.clone(),
|
||||||
|
member_index: None,
|
||||||
|
}));
|
||||||
|
arg.name = new_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_argument_reg_return(
|
||||||
|
func: &mut Vec<TypedStatement>,
|
||||||
|
arg: &ast::Variable<SpirvWord>,
|
||||||
|
) {
|
||||||
|
func.push(Statement::Variable(ast::Variable {
|
||||||
|
align: arg.align,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: arg.state_space,
|
||||||
|
name: arg.name,
|
||||||
|
array_init: arg.array_init.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_statement_default<'a, 'input>(
|
||||||
|
id_def: &'a mut NumericIdResolver<'input>,
|
||||||
|
func: &'a mut Vec<TypedStatement>,
|
||||||
|
stmt: TypedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let mut visitor = InsertMemSSAVisitor {
|
||||||
|
id_def,
|
||||||
|
func,
|
||||||
|
post_statements: Vec::new(),
|
||||||
|
};
|
||||||
|
let new_stmt = stmt.visit_map(&mut visitor)?;
|
||||||
|
visitor.func.push(new_stmt);
|
||||||
|
visitor.func.extend(visitor.post_statements);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InsertMemSSAVisitor<'a, 'input> {
|
||||||
|
id_def: &'a mut NumericIdResolver<'input>,
|
||||||
|
func: &'a mut Vec<TypedStatement>,
|
||||||
|
post_statements: Vec<TypedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
|
fn symbol(
|
||||||
|
&mut self,
|
||||||
|
symbol: SpirvWord,
|
||||||
|
member_index: Option<u8>,
|
||||||
|
expected: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if expected.is_none() {
|
||||||
|
return Ok(symbol);
|
||||||
|
};
|
||||||
|
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
||||||
|
if !state_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
|
||||||
|
return Ok(symbol);
|
||||||
|
};
|
||||||
|
let member_index = match member_index {
|
||||||
|
Some(idx) => {
|
||||||
|
let vector_width = match var_type {
|
||||||
|
ast::Type::Vector(scalar_t, width) => {
|
||||||
|
var_type = ast::Type::Scalar(scalar_t);
|
||||||
|
width
|
||||||
|
}
|
||||||
|
_ => return Err(TranslateError::MismatchedType),
|
||||||
|
};
|
||||||
|
Some((
|
||||||
|
idx,
|
||||||
|
if self.id_def.special_registers.get(symbol).is_some() {
|
||||||
|
Some(vector_width)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let generated_id = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
|
||||||
|
if !is_dst {
|
||||||
|
self.func.push(Statement::LoadVar(LoadVarDetails {
|
||||||
|
arg: ast::LdArgs {
|
||||||
|
dst: generated_id,
|
||||||
|
src: symbol,
|
||||||
|
},
|
||||||
|
typ: var_type,
|
||||||
|
member_index,
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
self.post_statements
|
||||||
|
.push(Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg: ast::StArgs {
|
||||||
|
src1: symbol,
|
||||||
|
src2: generated_id,
|
||||||
|
},
|
||||||
|
typ: var_type,
|
||||||
|
member_index: member_index.map(|(idx, _)| idx),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Ok(generated_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||||
|
for InsertMemSSAVisitor<'a, 'input>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
operand: TypedOperand,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
Ok(match operand {
|
||||||
|
TypedOperand::Reg(reg) => {
|
||||||
|
TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
|
||||||
|
}
|
||||||
|
TypedOperand::RegOffset(reg, offset) => {
|
||||||
|
TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
|
||||||
|
}
|
||||||
|
op @ TypedOperand::Imm(..) => op,
|
||||||
|
TypedOperand::VecMember(symbol, index) => TypedOperand::VecMember(
|
||||||
|
self.symbol(symbol, Some(index), type_space, is_dst)?,
|
||||||
|
index,
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: SpirvWord,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.symbol(args, None, type_space, is_dst)
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,6 +12,7 @@ use std::{
|
||||||
mod convert_to_stateful_memory_access;
|
mod convert_to_stateful_memory_access;
|
||||||
mod convert_to_typed;
|
mod convert_to_typed;
|
||||||
mod fix_special_registers;
|
mod fix_special_registers;
|
||||||
|
mod insert_mem_ssa_statements;
|
||||||
mod normalize_identifiers;
|
mod normalize_identifiers;
|
||||||
mod normalize_predicates;
|
mod normalize_predicates;
|
||||||
|
|
||||||
|
@ -175,13 +176,13 @@ fn to_ssa<'input, 'b>(
|
||||||
fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
|
||||||
let (func_decl, typed_statements) =
|
let (func_decl, typed_statements) =
|
||||||
convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
|
convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||||
todo!()
|
let ssa_statements = insert_mem_ssa_statements::run(
|
||||||
/*
|
|
||||||
let ssa_statements = insert_mem_ssa_statements(
|
|
||||||
typed_statements,
|
typed_statements,
|
||||||
&mut numeric_id_defs,
|
&mut numeric_id_defs,
|
||||||
&mut (*func_decl).borrow_mut(),
|
&mut (*func_decl).borrow_mut(),
|
||||||
)?;
|
)?;
|
||||||
|
todo!()
|
||||||
|
/*
|
||||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||||
let expanded_statements =
|
let expanded_statements =
|
||||||
|
@ -1206,3 +1207,9 @@ impl<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
|
||||||
|
this == other
|
||||||
|
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|
||||||
|
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue