Reimplement vector member access

This commit is contained in:
Andrzej Janik 2020-12-08 00:46:10 +01:00
commit 18d5aa85b5

View file

@ -981,7 +981,6 @@ fn compute_denorm_information<'input>(
Statement::Conversion(_) => {} Statement::Conversion(_) => {}
Statement::Constant(_) => {} Statement::Constant(_) => {}
Statement::RetValue(_, _) => {} Statement::RetValue(_, _) => {}
Statement::Undef(_, _) => {}
Statement::Label(_) => {} Statement::Label(_) => {}
Statement::Variable(_) => {} Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {} Statement::PtrAccess { .. } => {}
@ -1845,7 +1844,6 @@ fn normalize_labels(
| Statement::Conversion(..) | Statement::Conversion(..)
| Statement::Constant(..) | Statement::Constant(..)
| Statement::Label(..) | Statement::Label(..)
| Statement::Undef(..)
| Statement::PtrAccess { .. } | Statement::PtrAccess { .. }
| Statement::RepackVector(..) => {} | Statement::RepackVector(..) => {}
} }
@ -2071,46 +2069,158 @@ impl<
} }
} }
fn insert_mem_ssa_statement_default<'a, S: Visitable<TypedArgParams, TypedArgParams>>( struct InsertMemSSAVisitor<'a, 'input> {
id_def: &mut NumericIdResolver, id_def: &'a mut NumericIdResolver<'input>,
result: &mut Vec<TypedStatement>, func: &'a mut Vec<TypedStatement>,
stmt: S, post_statements: Vec<TypedStatement>,
) -> Result<(), TranslateError> { }
let mut post_statements = Vec::new();
let new_statement = impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>, fn symbol(
expected_type: Option<&ast::Type>| { &mut self,
if expected_type.is_none() { desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
return Ok(desc.op); expected_type: Option<&ast::Type>,
}; ) -> Result<spirv::Word, TranslateError> {
let (var_type, is_variable) = id_def.get_typed(desc.op)?; let symbol = desc.op.0;
if !is_variable { if expected_type.is_none() {
return Ok(desc.op); return Ok(symbol);
};
let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
if !is_variable {
return Ok(symbol);
};
let member_index = match desc.op.1 {
Some(idx) => {
match var_type {
ast::Type::Vector(scalar_t, _) => {
var_type = ast::Type::Scalar(scalar_t);
}
_ => return Err(TranslateError::MismatchedType),
}
Some((idx, self.id_def.special_registers.contains_key(&symbol)))
} }
let generated_id = id_def.new_non_variable(Some(var_type.clone())); None => None,
if !desc.is_dst { };
result.push(Statement::LoadVar(LoadVarDetails { let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
arg: Arg2 { if !desc.is_dst {
dst: generated_id, self.func.push(Statement::LoadVar(LoadVarDetails {
src: desc.op, arg: Arg2 {
}, dst: generated_id,
typ: var_type, src: symbol,
member_index: None, },
})); typ: var_type,
} else { member_index,
post_statements.push(Statement::StoreVar(StoreVarDetails { }));
} else {
self.post_statements
.push(Statement::StoreVar(StoreVarDetails {
arg: Arg2St { arg: Arg2St {
src1: desc.op, src1: symbol,
src2: generated_id, src2: generated_id,
}, },
typ: var_type, typ: var_type,
member_index: None, member_index,
})); }));
}
Ok(generated_id)
}
}
impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
for InsertMemSSAVisitor<'a, 'input>
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
fn dst_operand(
&mut self,
desc: ArgumentDescriptor<ast::DstOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::DstOperand::Reg(reg) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
} }
Ok(generated_id) ast::DstOperand::VecMember(symbol, index) => {
})?; ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
result.push(new_statement); }
result.append(&mut post_statements); })
}
fn src_operand(
&mut self,
desc: ArgumentDescriptor<ast::SrcOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::SrcOperand::Reg(reg) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(
self.symbol(desc.new_op((reg, None)), Some(typ))?,
offset,
),
op @ ast::SrcOperand::Imm(..) => op,
ast::SrcOperand::VecMember(symbol, index) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
fn dst_operand_vec(
&mut self,
desc: ArgumentDescriptor<ast::DstOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::DstOperand::Reg(reg) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
ast::DstOperand::VecMember(symbol, index) => {
ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
fn src_operand_vec(
&mut self,
desc: ArgumentDescriptor<ast::SrcOperand<spirv::Word>>,
typ: &ast::Type,
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::SrcOperand::Reg(reg) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset(
self.symbol(desc.new_op((reg, None)), Some(typ))?,
offset,
),
op @ ast::SrcOperand::Imm(..) => op,
ast::SrcOperand::VecMember(symbol, index) => {
ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
}
})
}
}
fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, TypedArgParams>>(
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
stmt: S,
) -> Result<(), TranslateError> {
let mut visitor = InsertMemSSAVisitor {
id_def,
func,
post_statements: Vec::new(),
};
let new_stmt = stmt.visit(&mut visitor)?;
visitor.func.push(new_stmt);
visitor.func.extend(visitor.post_statements);
Ok(()) Ok(())
} }
@ -2162,7 +2272,7 @@ fn expand_arguments<'a, 'b>(
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
Statement::Constant(_) | Statement::Undef(_, _) => return Err(error_unreachable()), Statement::Constant(_) => return Err(error_unreachable()),
} }
} }
Ok(result) Ok(result)
@ -2472,7 +2582,6 @@ fn insert_implicit_conversions(
| s @ Statement::Variable(_) | s @ Statement::Variable(_)
| s @ Statement::LoadVar(..) | s @ Statement::LoadVar(..)
| s @ Statement::StoreVar(..) | s @ Statement::StoreVar(..)
| s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s), | s @ Statement::RetValue(_, _) => result.push(s),
} }
} }
@ -3049,26 +3158,66 @@ fn emit_function_body_ops(
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
} }
}, },
Statement::LoadVar(LoadVarDetails { Statement::LoadVar(details) => {
arg, let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
typ, let src = match details.member_index {
member_index, Some((index, is_sreg)) => {
}) => { let storage_class = if is_sreg {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); spirv::StorageClass::Input
builder.load(type_id, Some(arg.dst), arg.src, None, [])?; } else {
spirv::StorageClass::Function
};
let result_ptr_type = map.get_or_add(
builder,
SpirvType::new_pointer(details.typ.clone(), storage_class),
);
let index_spirv = map.get_or_add_constant(
builder,
&ast::Type::Scalar(ast::ScalarType::U32),
&vec_repr(index as u32),
)?;
builder.in_bounds_access_chain(
result_ptr_type,
None,
details.arg.src,
&[index_spirv],
)?
}
None => details.arg.src,
};
builder.load(result_type, Some(details.arg.dst), src, None, [])?;
} }
Statement::StoreVar(StoreVarDetails { Statement::StoreVar(details) => {
arg, member_index, .. let dst_ptr = match details.member_index {
}) => { Some((index, is_sreg)) => {
builder.store(arg.src1, arg.src2, None, [])?; let storage_class = if is_sreg {
spirv::StorageClass::Input
} else {
spirv::StorageClass::Function
};
let result_ptr_type = map.get_or_add(
builder,
SpirvType::new_pointer(details.typ.clone(), storage_class),
);
let index_spirv = map.get_or_add_constant(
builder,
&ast::Type::Scalar(ast::ScalarType::U32),
&vec_repr(index as u32),
)?;
builder.in_bounds_access_chain(
result_ptr_type,
None,
details.arg.src1,
&[index_spirv],
)?
}
None => details.arg.src1,
};
builder.store(dst_ptr, details.arg.src2, None, [])?;
} }
Statement::RetValue(_, id) => { Statement::RetValue(_, id) => {
builder.ret_value(*id)?; builder.ret_value(*id)?;
} }
Statement::Undef(t, id) => {
let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
Statement::PtrAccess(PtrAccess { Statement::PtrAccess(PtrAccess {
underlying_type, underlying_type,
state_space, state_space,
@ -4870,7 +5019,6 @@ enum Statement<I, P: ast::ArgParams> {
Conversion(ImplicitConversion), Conversion(ImplicitConversion),
Constant(ConstantDefinition), Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word), RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
PtrAccess(PtrAccess<P>), PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails), RepackVector(RepackVectorDetails),
} }
@ -4932,10 +5080,6 @@ impl ExpandedStatement {
let id = f(id, false); let id = f(id, false);
Statement::RetValue(data, id) Statement::RetValue(data, id)
} }
Statement::Undef(typ, id) => {
let id = f(id, true);
Statement::Undef(typ, id)
}
Statement::PtrAccess(PtrAccess { Statement::PtrAccess(PtrAccess {
underlying_type, underlying_type,
state_space, state_space,
@ -4962,13 +5106,15 @@ impl ExpandedStatement {
struct LoadVarDetails { struct LoadVarDetails {
arg: ast::Arg2<ExpandedArgParams>, arg: ast::Arg2<ExpandedArgParams>,
typ: ast::Type, typ: ast::Type,
member_index: Option<u8>, // (index, is_sreg)
member_index: Option<(u8, bool)>,
} }
struct StoreVarDetails { struct StoreVarDetails {
arg: ast::Arg2St<ExpandedArgParams>, arg: ast::Arg2St<ExpandedArgParams>,
typ: ast::Type, typ: ast::Type,
member_index: Option<u8>, // (index, is_sreg)
member_index: Option<(u8, bool)>,
} }
struct RepackVectorDetails { struct RepackVectorDetails {
@ -5261,29 +5407,6 @@ impl ArgParamsEx for ExpandedArgParams {
} }
} }
#[derive(Copy, Clone)]
pub enum StateSpace {
Reg,
Const,
Global,
Local,
Shared,
Param,
}
impl From<ast::StateSpace> for StateSpace {
fn from(ss: ast::StateSpace) -> Self {
match ss {
ast::StateSpace::Reg => StateSpace::Reg,
ast::StateSpace::Const => StateSpace::Const,
ast::StateSpace::Global => StateSpace::Global,
ast::StateSpace::Local => StateSpace::Local,
ast::StateSpace::Shared => StateSpace::Shared,
ast::StateSpace::Param => StateSpace::Param,
}
}
}
enum Directive<'input> { enum Directive<'input> {
Variable(ast::Variable<ast::VariableType, spirv::Word>), Variable(ast::Variable<ast::VariableType, spirv::Word>),
Method(Function<'input>), Method(Function<'input>),
@ -5388,7 +5511,7 @@ where
fn dst_operand( fn dst_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<ast::DstOperand<&str>>, desc: ArgumentDescriptor<ast::DstOperand<&str>>,
typ: &ast::Type, _: &ast::Type,
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> { ) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
Ok(match desc.op { Ok(match desc.op {
ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?), ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?),
@ -5399,7 +5522,7 @@ where
fn src_operand( fn src_operand(
&mut self, &mut self,
desc: ArgumentDescriptor<ast::SrcOperand<&str>>, desc: ArgumentDescriptor<ast::SrcOperand<&str>>,
typ: &ast::Type, _: &ast::Type,
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> { ) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
Ok(match desc.op { Ok(match desc.op {
ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?), ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?),
@ -6810,6 +6933,7 @@ impl<T> ast::DstOperand<T> {
} }
} }
} }
impl ast::DstOperand<spirv::Word> { impl ast::DstOperand<spirv::Word> {
fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> { fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> {
match self { match self {