mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Reimplement vector member access
This commit is contained in:
parent
100831daaf
commit
18d5aa85b5
1 changed files with 208 additions and 84 deletions
|
@ -981,7 +981,6 @@ fn compute_denorm_information<'input>(
|
|||
Statement::Conversion(_) => {}
|
||||
Statement::Constant(_) => {}
|
||||
Statement::RetValue(_, _) => {}
|
||||
Statement::Undef(_, _) => {}
|
||||
Statement::Label(_) => {}
|
||||
Statement::Variable(_) => {}
|
||||
Statement::PtrAccess { .. } => {}
|
||||
|
@ -1845,7 +1844,6 @@ fn normalize_labels(
|
|||
| Statement::Conversion(..)
|
||||
| Statement::Constant(..)
|
||||
| Statement::Label(..)
|
||||
| Statement::Undef(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::RepackVector(..) => {}
|
||||
}
|
||||
|
@ -2071,46 +2069,158 @@ impl<
|
|||
}
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_statement_default<'a, S: Visitable<TypedArgParams, TypedArgParams>>(
|
||||
id_def: &mut NumericIdResolver,
|
||||
result: &mut Vec<TypedStatement>,
|
||||
stmt: S,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement =
|
||||
stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
|
||||
expected_type: Option<&ast::Type>| {
|
||||
if expected_type.is_none() {
|
||||
return Ok(desc.op);
|
||||
};
|
||||
let (var_type, is_variable) = id_def.get_typed(desc.op)?;
|
||||
if !is_variable {
|
||||
return Ok(desc.op);
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
id_def: &'a mut NumericIdResolver<'input>,
|
||||
func: &'a mut Vec<TypedStatement>,
|
||||
post_statements: Vec<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn symbol(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
|
||||
expected_type: Option<&ast::Type>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
let symbol = desc.op.0;
|
||||
if expected_type.is_none() {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
|
||||
if !is_variable {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let member_index = match desc.op.1 {
|
||||
Some(idx) => {
|
||||
match var_type {
|
||||
ast::Type::Vector(scalar_t, _) => {
|
||||
var_type = ast::Type::Scalar(scalar_t);
|
||||
}
|
||||
_ => return Err(TranslateError::MismatchedType),
|
||||
}
|
||||
Some((idx, self.id_def.special_registers.contains_key(&symbol)))
|
||||
}
|
||||
let generated_id = id_def.new_non_variable(Some(var_type.clone()));
|
||||
if !desc.is_dst {
|
||||
result.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: Arg2 {
|
||||
dst: generated_id,
|
||||
src: desc.op,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index: None,
|
||||
}));
|
||||
} else {
|
||||
post_statements.push(Statement::StoreVar(StoreVarDetails {
|
||||
None => None,
|
||||
};
|
||||
let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
|
||||
if !desc.is_dst {
|
||||
self.func.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: Arg2 {
|
||||
dst: generated_id,
|
||||
src: symbol,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index,
|
||||
}));
|
||||
} else {
|
||||
self.post_statements
|
||||
.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: Arg2St {
|
||||
src1: desc.op,
|
||||
src1: symbol,
|
||||
src2: generated_id,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index: None,
|
||||
member_index,
|
||||
}));
|
||||
}
|
||||
Ok(generated_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn id(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<spirv::Word>,
|
||||
typ: Option<&ast::Type>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
self.symbol(desc.new_op((desc.op, None)), typ)
|
||||
}
|
||||
|
||||
fn dst_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::DstOperand<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
|
||||
Ok(match desc.op {
|
||||
ast::DstOperand::Reg(reg) => {
|
||||
ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
|
||||
}
|
||||
Ok(generated_id)
|
||||
})?;
|
||||
result.push(new_statement);
|
||||
result.append(&mut post_statements);
|
||||
ast::DstOperand::VecMember(symbol, index) => {
|
||||
ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn src_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::SrcOperand<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
|
||||
Ok(match desc.op {
|
||||
ast::SrcOperand::Reg(reg) => {
|
||||
ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
|
||||
}
|
||||
ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(
|
||||
self.symbol(desc.new_op((reg, None)), Some(typ))?,
|
||||
offset,
|
||||
),
|
||||
op @ ast::SrcOperand::Imm(..) => op,
|
||||
ast::SrcOperand::VecMember(symbol, index) => {
|
||||
ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn dst_operand_vec(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::DstOperand<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
|
||||
Ok(match desc.op {
|
||||
ast::DstOperand::Reg(reg) => {
|
||||
ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
|
||||
}
|
||||
ast::DstOperand::VecMember(symbol, index) => {
|
||||
ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn src_operand_vec(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::SrcOperand<spirv::Word>>,
|
||||
typ: &ast::Type,
|
||||
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
|
||||
Ok(match desc.op {
|
||||
ast::SrcOperand::Reg(reg) => {
|
||||
ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
|
||||
}
|
||||
ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(
|
||||
self.symbol(desc.new_op((reg, None)), Some(typ))?,
|
||||
offset,
|
||||
),
|
||||
op @ ast::SrcOperand::Imm(..) => op,
|
||||
ast::SrcOperand::VecMember(symbol, index) => {
|
||||
ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, TypedArgParams>>(
|
||||
id_def: &'a mut NumericIdResolver<'input>,
|
||||
func: &'a mut Vec<TypedStatement>,
|
||||
stmt: S,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = InsertMemSSAVisitor {
|
||||
id_def,
|
||||
func,
|
||||
post_statements: Vec::new(),
|
||||
};
|
||||
let new_stmt = stmt.visit(&mut visitor)?;
|
||||
visitor.func.push(new_stmt);
|
||||
visitor.func.extend(visitor.post_statements);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -2162,7 +2272,7 @@ fn expand_arguments<'a, 'b>(
|
|||
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
||||
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
||||
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
||||
Statement::Constant(_) | Statement::Undef(_, _) => return Err(error_unreachable()),
|
||||
Statement::Constant(_) => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
|
@ -2472,7 +2582,6 @@ fn insert_implicit_conversions(
|
|||
| s @ Statement::Variable(_)
|
||||
| s @ Statement::LoadVar(..)
|
||||
| s @ Statement::StoreVar(..)
|
||||
| s @ Statement::Undef(_, _)
|
||||
| s @ Statement::RetValue(_, _) => result.push(s),
|
||||
}
|
||||
}
|
||||
|
@ -3049,26 +3158,66 @@ fn emit_function_body_ops(
|
|||
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
}
|
||||
},
|
||||
Statement::LoadVar(LoadVarDetails {
|
||||
arg,
|
||||
typ,
|
||||
member_index,
|
||||
}) => {
|
||||
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
|
||||
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
|
||||
Statement::LoadVar(details) => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
|
||||
let src = match details.member_index {
|
||||
Some((index, is_sreg)) => {
|
||||
let storage_class = if is_sreg {
|
||||
spirv::StorageClass::Input
|
||||
} else {
|
||||
spirv::StorageClass::Function
|
||||
};
|
||||
let result_ptr_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::new_pointer(details.typ.clone(), storage_class),
|
||||
);
|
||||
let index_spirv = map.get_or_add_constant(
|
||||
builder,
|
||||
&ast::Type::Scalar(ast::ScalarType::U32),
|
||||
&vec_repr(index as u32),
|
||||
)?;
|
||||
builder.in_bounds_access_chain(
|
||||
result_ptr_type,
|
||||
None,
|
||||
details.arg.src,
|
||||
&[index_spirv],
|
||||
)?
|
||||
}
|
||||
None => details.arg.src,
|
||||
};
|
||||
builder.load(result_type, Some(details.arg.dst), src, None, [])?;
|
||||
}
|
||||
Statement::StoreVar(StoreVarDetails {
|
||||
arg, member_index, ..
|
||||
}) => {
|
||||
builder.store(arg.src1, arg.src2, None, [])?;
|
||||
Statement::StoreVar(details) => {
|
||||
let dst_ptr = match details.member_index {
|
||||
Some((index, is_sreg)) => {
|
||||
let storage_class = if is_sreg {
|
||||
spirv::StorageClass::Input
|
||||
} else {
|
||||
spirv::StorageClass::Function
|
||||
};
|
||||
let result_ptr_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::new_pointer(details.typ.clone(), storage_class),
|
||||
);
|
||||
let index_spirv = map.get_or_add_constant(
|
||||
builder,
|
||||
&ast::Type::Scalar(ast::ScalarType::U32),
|
||||
&vec_repr(index as u32),
|
||||
)?;
|
||||
builder.in_bounds_access_chain(
|
||||
result_ptr_type,
|
||||
None,
|
||||
details.arg.src1,
|
||||
&[index_spirv],
|
||||
)?
|
||||
}
|
||||
None => details.arg.src1,
|
||||
};
|
||||
builder.store(dst_ptr, details.arg.src2, None, [])?;
|
||||
}
|
||||
Statement::RetValue(_, id) => {
|
||||
builder.ret_value(*id)?;
|
||||
}
|
||||
Statement::Undef(t, id) => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
|
||||
builder.undef(result_type, Some(*id));
|
||||
}
|
||||
Statement::PtrAccess(PtrAccess {
|
||||
underlying_type,
|
||||
state_space,
|
||||
|
@ -4870,7 +5019,6 @@ enum Statement<I, P: ast::ArgParams> {
|
|||
Conversion(ImplicitConversion),
|
||||
Constant(ConstantDefinition),
|
||||
RetValue(ast::RetData, spirv::Word),
|
||||
Undef(ast::Type, spirv::Word),
|
||||
PtrAccess(PtrAccess<P>),
|
||||
RepackVector(RepackVectorDetails),
|
||||
}
|
||||
|
@ -4932,10 +5080,6 @@ impl ExpandedStatement {
|
|||
let id = f(id, false);
|
||||
Statement::RetValue(data, id)
|
||||
}
|
||||
Statement::Undef(typ, id) => {
|
||||
let id = f(id, true);
|
||||
Statement::Undef(typ, id)
|
||||
}
|
||||
Statement::PtrAccess(PtrAccess {
|
||||
underlying_type,
|
||||
state_space,
|
||||
|
@ -4962,13 +5106,15 @@ impl ExpandedStatement {
|
|||
struct LoadVarDetails {
|
||||
arg: ast::Arg2<ExpandedArgParams>,
|
||||
typ: ast::Type,
|
||||
member_index: Option<u8>,
|
||||
// (index, is_sreg)
|
||||
member_index: Option<(u8, bool)>,
|
||||
}
|
||||
|
||||
struct StoreVarDetails {
|
||||
arg: ast::Arg2St<ExpandedArgParams>,
|
||||
typ: ast::Type,
|
||||
member_index: Option<u8>,
|
||||
// (index, is_sreg)
|
||||
member_index: Option<(u8, bool)>,
|
||||
}
|
||||
|
||||
struct RepackVectorDetails {
|
||||
|
@ -5261,29 +5407,6 @@ impl ArgParamsEx for ExpandedArgParams {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
Const,
|
||||
Global,
|
||||
Local,
|
||||
Shared,
|
||||
Param,
|
||||
}
|
||||
|
||||
impl From<ast::StateSpace> for StateSpace {
|
||||
fn from(ss: ast::StateSpace) -> Self {
|
||||
match ss {
|
||||
ast::StateSpace::Reg => StateSpace::Reg,
|
||||
ast::StateSpace::Const => StateSpace::Const,
|
||||
ast::StateSpace::Global => StateSpace::Global,
|
||||
ast::StateSpace::Local => StateSpace::Local,
|
||||
ast::StateSpace::Shared => StateSpace::Shared,
|
||||
ast::StateSpace::Param => StateSpace::Param,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Directive<'input> {
|
||||
Variable(ast::Variable<ast::VariableType, spirv::Word>),
|
||||
Method(Function<'input>),
|
||||
|
@ -5388,7 +5511,7 @@ where
|
|||
fn dst_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::DstOperand<&str>>,
|
||||
typ: &ast::Type,
|
||||
_: &ast::Type,
|
||||
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
|
||||
Ok(match desc.op {
|
||||
ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?),
|
||||
|
@ -5399,7 +5522,7 @@ where
|
|||
fn src_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::SrcOperand<&str>>,
|
||||
typ: &ast::Type,
|
||||
_: &ast::Type,
|
||||
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
|
||||
Ok(match desc.op {
|
||||
ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?),
|
||||
|
@ -6810,6 +6933,7 @@ impl<T> ast::DstOperand<T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::DstOperand<spirv::Word> {
|
||||
fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> {
|
||||
match self {
|
||||
|
|
Loading…
Add table
Reference in a new issue