Reimplement vector member access

This commit is contained in:
Andrzej Janik 2020-12-08 00:46:10 +01:00
parent 100831daaf
commit 18d5aa85b5

View file

@ -981,7 +981,6 @@ fn compute_denorm_information<'input>(
Statement::Conversion(_) => {}
Statement::Constant(_) => {}
Statement::RetValue(_, _) => {}
Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
@ -1845,7 +1844,6 @@ fn normalize_labels(
| Statement::Conversion(..)
| Statement::Constant(..)
| Statement::Label(..)
| Statement::Undef(..)
| Statement::PtrAccess { .. }
| Statement::RepackVector(..) => {}
}
@ -2071,46 +2069,158 @@ impl<
}
}
fn insert_mem_ssa_statement_default<'a, S: Visitable<TypedArgParams, TypedArgParams>>(
id_def: &mut NumericIdResolver,
result: &mut Vec<TypedStatement>,
stmt: S,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
let new_statement =
stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<&ast::Type>| {
if expected_type.is_none() {
return Ok(desc.op);
};
let (var_type, is_variable) = id_def.get_typed(desc.op)?;
if !is_variable {
return Ok(desc.op);
struct InsertMemSSAVisitor<'a, 'input> {
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
post_statements: Vec<TypedStatement>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn symbol(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
expected_type: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
let symbol = desc.op.0;
if expected_type.is_none() {
return Ok(symbol);
};
let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
if !is_variable {
return Ok(symbol);
};
let member_index = match desc.op.1 {
Some(idx) => {
match var_type {
ast::Type::Vector(scalar_t, _) => {
var_type = ast::Type::Scalar(scalar_t);
}
_ => return Err(TranslateError::MismatchedType),
}
Some((idx, self.id_def.special_registers.contains_key(&symbol)))
}
let generated_id = id_def.new_non_variable(Some(var_type.clone()));
if !desc.is_dst {
result.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: desc.op,
},
typ: var_type,
member_index: None,
}));
} else {
post_statements.push(Statement::StoreVar(StoreVarDetails {
None => None,
};
let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
if !desc.is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: symbol,
},
typ: var_type,
member_index,
}));
} else {
self.post_statements
.push(Statement::StoreVar(StoreVarDetails {
arg: Arg2St {
src1: desc.op,
src1: symbol,
src2: generated_id,
},
typ: var_type,
member_index: None,
member_index,
}));
}
Ok(generated_id)
}
}
impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
for InsertMemSSAVisitor<'a, 'input>
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
fn dst_operand(
&mut self,
desc: ArgumentDescriptor<ast::DstOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::DstOperand::Reg(reg) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
Ok(generated_id)
})?;
result.push(new_statement);
result.append(&mut post_statements);
ast::DstOperand::VecMember(symbol, index) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
fn src_operand(
&mut self,
desc: ArgumentDescriptor<ast::SrcOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::SrcOperand::Reg(reg) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(
self.symbol(desc.new_op((reg, None)), Some(typ))?,
offset,
),
op @ ast::SrcOperand::Imm(..) => op,
ast::SrcOperand::VecMember(symbol, index) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
fn dst_operand_vec(
&mut self,
desc: ArgumentDescriptor<ast::DstOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::DstOperand::Reg(reg) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
ast::DstOperand::VecMember(symbol, index) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
fn src_operand_vec(
&mut self,
desc: ArgumentDescriptor<ast::SrcOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::SrcOperand::Reg(reg) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(
self.symbol(desc.new_op((reg, None)), Some(typ))?,
offset,
),
op @ ast::SrcOperand::Imm(..) => op,
ast::SrcOperand::VecMember(symbol, index) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
}
fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, TypedArgParams>>(
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
stmt: S,
) -> Result<(), TranslateError> {
let mut visitor = InsertMemSSAVisitor {
id_def,
func,
post_statements: Vec::new(),
};
let new_stmt = stmt.visit(&mut visitor)?;
visitor.func.push(new_stmt);
visitor.func.extend(visitor.post_statements);
Ok(())
}
@ -2162,7 +2272,7 @@ fn expand_arguments<'a, 'b>(
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
Statement::Constant(_) | Statement::Undef(_, _) => return Err(error_unreachable()),
Statement::Constant(_) => return Err(error_unreachable()),
}
}
Ok(result)
@ -2472,7 +2582,6 @@ fn insert_implicit_conversions(
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(..)
| s @ Statement::StoreVar(..)
| s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
}
}
@ -3049,26 +3158,66 @@ fn emit_function_body_ops(
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
},
Statement::LoadVar(LoadVarDetails {
arg,
typ,
member_index,
}) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
Statement::LoadVar(details) => {
let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
let src = match details.member_index {
Some((index, is_sreg)) => {
let storage_class = if is_sreg {
spirv::StorageClass::Input
} else {
spirv::StorageClass::Function
};
let result_ptr_type = map.get_or_add(
builder,
SpirvType::new_pointer(details.typ.clone(), storage_class),
);
let index_spirv = map.get_or_add_constant(
builder,
&ast::Type::Scalar(ast::ScalarType::U32),
&vec_repr(index as u32),
)?;
builder.in_bounds_access_chain(
result_ptr_type,
None,
details.arg.src,
&[index_spirv],
)?
}
None => details.arg.src,
};
builder.load(result_type, Some(details.arg.dst), src, None, [])?;
}
Statement::StoreVar(StoreVarDetails {
arg, member_index, ..
}) => {
builder.store(arg.src1, arg.src2, None, [])?;
Statement::StoreVar(details) => {
let dst_ptr = match details.member_index {
Some((index, is_sreg)) => {
let storage_class = if is_sreg {
spirv::StorageClass::Input
} else {
spirv::StorageClass::Function
};
let result_ptr_type = map.get_or_add(
builder,
SpirvType::new_pointer(details.typ.clone(), storage_class),
);
let index_spirv = map.get_or_add_constant(
builder,
&ast::Type::Scalar(ast::ScalarType::U32),
&vec_repr(index as u32),
)?;
builder.in_bounds_access_chain(
result_ptr_type,
None,
details.arg.src1,
&[index_spirv],
)?
}
None => details.arg.src1,
};
builder.store(dst_ptr, details.arg.src2, None, [])?;
}
Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
}
Statement::Undef(t, id) => {
let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@ -4870,7 +5019,6 @@ enum Statement<I, P: ast::ArgParams> {
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
}
@ -4932,10 +5080,6 @@ impl ExpandedStatement {
let id = f(id, false);
Statement::RetValue(data, id)
}
Statement::Undef(typ, id) => {
let id = f(id, true);
Statement::Undef(typ, id)
}
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@ -4962,13 +5106,15 @@ impl ExpandedStatement {
struct LoadVarDetails {
arg: ast::Arg2<ExpandedArgParams>,
typ: ast::Type,
member_index: Option<u8>,
// (index, is_sreg)
member_index: Option<(u8, bool)>,
}
struct StoreVarDetails {
arg: ast::Arg2St<ExpandedArgParams>,
typ: ast::Type,
member_index: Option<u8>,
// (index, is_sreg)
member_index: Option<(u8, bool)>,
}
struct RepackVectorDetails {
@ -5261,29 +5407,6 @@ impl ArgParamsEx for ExpandedArgParams {
}
}
#[derive(Copy, Clone)]
pub enum StateSpace {
Reg,
Const,
Global,
Local,
Shared,
Param,
}
impl From<ast::StateSpace> for StateSpace {
fn from(ss: ast::StateSpace) -> Self {
match ss {
ast::StateSpace::Reg => StateSpace::Reg,
ast::StateSpace::Const => StateSpace::Const,
ast::StateSpace::Global => StateSpace::Global,
ast::StateSpace::Local => StateSpace::Local,
ast::StateSpace::Shared => StateSpace::Shared,
ast::StateSpace::Param => StateSpace::Param,
}
}
}
enum Directive<'input> {
Variable(ast::Variable<ast::VariableType, spirv::Word>),
Method(Function<'input>),
@ -5388,7 +5511,7 @@ where
fn dst_operand(
&mut self,
desc: ArgumentDescriptor<ast::DstOperand<&str>>,
typ: &ast::Type,
_: &ast::Type,
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?),
@ -5399,7 +5522,7 @@ where
fn src_operand(
&mut self,
desc: ArgumentDescriptor<ast::SrcOperand<&str>>,
typ: &ast::Type,
_: &ast::Type,
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?),
@ -6810,6 +6933,7 @@ impl<T> ast::DstOperand<T> {
}
}
}
impl ast::DstOperand<spirv::Word> {
fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> {
match self {