Introduce vector repack statement

This commit is contained in:
Andrzej Janik 2020-12-06 00:47:56 +01:00
parent adf88bc1af
commit eb841b3a88

View file

@ -975,8 +975,7 @@ fn compute_denorm_information<'input>(
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
Statement::PackVector(_) => {}
Statement::UnpackVector(_) => {}
Statement::RepackVector(_) => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@ -1477,7 +1476,7 @@ fn convert_to_typed_statements(
};
d.src_is_address = take_address;
}
let mut visitor = VectorPackingVisitor::new(&mut result, id_defs);
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(
ast::Instruction::Mov(
d,
@ -1488,12 +1487,14 @@ fn convert_to_typed_statements(
)
.map(&mut visitor)?,
);
result.push(instruction);
visitor.func.push(instruction);
visitor.func.extend(visitor.post_stmts);
}
inst => {
let mut visitor = VectorPackingVisitor::new(&mut result, id_defs);
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
result.push(instruction);
visitor.func.push(instruction);
visitor.func.extend(visitor.post_stmts);
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
@ -1505,24 +1506,52 @@ fn convert_to_typed_statements(
Ok(result)
}
struct VectorPackingVisitor<'a, 'b> {
struct VectorRepackVisitor<'a, 'b> {
func: &'b mut Vec<TypedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
post_stmts: Vec<TypedStatement>,
post_stmts: Option<TypedStatement>,
}
impl<'a, 'b> VectorPackingVisitor<'a, 'b> {
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
VectorPackingVisitor {
VectorRepackVisitor {
func,
id_def,
post_stmts: Vec::new(),
post_stmts: None,
}
}
fn convert_vector(
&mut self,
is_dst: bool,
vector_sema: ArgumentSemantics,
typ: &ast::Type,
idx: Vec<spirv::Word>,
) -> Result<spirv::Word, TranslateError> {
// mov.u32 foobar, {a,b};
let scalar_t = match typ {
ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
};
let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
let statement = Statement::RepackVector(RepackVector {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: idx,
vector_sema,
});
if is_dst {
self.post_stmts = Some(statement);
} else {
self.func.push(statement);
}
Ok(temp_vec)
}
}
impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
for VectorPackingVisitor<'a, 'b>
for VectorRepackVisitor<'a, 'b>
{
fn id(
&mut self,
@ -1555,7 +1584,12 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
match desc.op {
ast::DstOperandVec::Normal(op) => self.dst_operand(desc.new_op(op), typ),
ast::DstOperandVec::Vector(vec) => todo!(),
ast::DstOperandVec::Vector(vec) => Ok(ast::DstOperand::Reg(self.convert_vector(
desc.is_dst,
desc.sema,
typ,
vec,
)?)),
}
}
@ -1566,7 +1600,12 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
match desc.op {
ast::SrcOperandVec::Normal(op) => self.src_operand(desc.new_op(op), typ),
ast::SrcOperandVec::Vector(_) => todo!(),
ast::SrcOperandVec::Vector(vec) => Ok(ast::SrcOperand::Reg(self.convert_vector(
desc.is_dst,
desc.sema,
typ,
vec,
)?)),
}
}
}
@ -1799,8 +1838,7 @@ fn normalize_labels(
| Statement::Label(..)
| Statement::Undef(..)
| Statement::PtrAccess { .. }
| Statement::PackVector(..)
| Statement::UnpackVector(..) => {}
| Statement::RepackVector(..) => {}
}
}
iter::once(Statement::Label(id_def.new_non_variable(None)))
@ -1944,6 +1982,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::PtrAccess(ptr_access) => {
insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
}
Statement::RepackVector(repack) => {
insert_mem_ssa_statement_default(id_def, &mut result, repack)?
}
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
_ => return Err(TranslateError::Unreachable),
}
@ -2111,6 +2152,12 @@ fn expand_arguments<'a, 'b>(
result.push(Statement::PtrAccess(new_inst));
result.extend(post_stmts);
}
Statement::RepackVector(repack) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts);
result.push(Statement::RepackVector(new_inst));
result.extend(post_stmts);
}
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
@ -2120,8 +2167,6 @@ fn expand_arguments<'a, 'b>(
Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
return Err(TranslateError::Unreachable)
}
Statement::PackVector(_) => todo!(),
Statement::UnpackVector(_) => todo!(),
}
}
Ok(result)
@ -2445,6 +2490,13 @@ fn insert_implicit_conversions(
Some(state_space),
)?;
}
Statement::RepackVector(repack) => insert_implicit_conversions_impl(
&mut result,
id_def,
repack,
should_bitcast_wrapper,
None,
)?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
@ -2454,8 +2506,6 @@ fn insert_implicit_conversions(
| s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
Statement::PackVector(_) => todo!(),
Statement::UnpackVector(_) => todo!(),
}
}
Ok(result)
@ -3081,8 +3131,38 @@ fn emit_function_body_ops(
)?;
builder.bitcast(result_type, Some(*dst), temp)?;
}
Statement::PackVector(_) => todo!(),
Statement::UnpackVector(_) => todo!(),
Statement::RepackVector(repack) => {
if repack.is_extract {
let scalar_type = map.get_or_add_scalar(builder, repack.typ);
for (index, dst_id) in repack.unpacked.iter().enumerate() {
builder.composite_extract(
scalar_type,
Some(*dst_id),
repack.packed,
&[index as u32],
)?;
}
} else {
let vector_type = map.get_or_add(
builder,
SpirvType::Vector(
SpirvScalarKey::from(repack.typ),
repack.unpacked.len() as u8,
),
);
let mut temp_vec = builder.undef(vector_type, None);
for (index, src_id) in repack.unpacked.iter().enumerate() {
temp_vec = builder.composite_insert(
vector_type,
None,
*src_id,
temp_vec,
&[index as u32],
)?;
}
builder.copy_object(vector_type, Some(repack.packed), temp_vec)?;
}
}
}
}
Ok(())
@ -4334,9 +4414,7 @@ fn convert_to_stateful_memory_access<'a>(
},
)?;
result.push(new_statement);
for s in post_statements {
result.push(s);
}
result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
@ -4354,9 +4432,25 @@ fn convert_to_stateful_memory_access<'a>(
},
)?;
result.push(new_statement);
for s in post_statements {
result.push(s);
}
result.extend(post_statements);
}
Statement::RepackVector(pack) => {
let mut post_statements = Vec::new();
let new_statement = pack.visit_variable(
&mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
},
)?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(TranslateError::Unreachable),
}
@ -4810,8 +4904,7 @@ enum Statement<I, P: ast::ArgParams> {
RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
PtrAccess(PtrAccess<P>),
PackVector(PackVector),
UnpackVector(UnpackVector),
RepackVector(RepackVector),
}
impl ExpandedStatement {
@ -4898,14 +4991,101 @@ impl ExpandedStatement {
offset_src: constant_src,
})
}
Statement::PackVector(_) => todo!(),
Statement::UnpackVector(_) => todo!(),
Statement::RepackVector(_) => todo!(),
}
}
}
struct PackVector {}
struct UnpackVector {}
struct RepackVector {
is_extract: bool,
typ: ast::ScalarType,
packed: spirv::Word,
unpacked: Vec<spirv::Word>,
vector_sema: ArgumentSemantics,
}
impl RepackVector {
fn map<
From: ArgParamsEx<Id = spirv::Word>,
To: ArgParamsEx<Id = spirv::Word>,
V: ArgumentMapVisitor<From, To>,
>(
self,
visitor: &mut V,
) -> Result<RepackVector, TranslateError> {
let scalar = visitor.id(
ArgumentDescriptor {
op: self.packed,
is_dst: !self.is_extract,
sema: ArgumentSemantics::Default,
},
Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
)?;
let scalar_type = self.typ;
let is_extract = self.is_extract;
let vector_sema = self.vector_sema;
let vector = self
.unpacked
.into_iter()
.map(|id| {
visitor.id(
ArgumentDescriptor {
op: id,
is_dst: is_extract,
sema: vector_sema,
},
Some(&ast::Type::Scalar(scalar_type)),
)
})
.collect::<Result<_, _>>()?;
Ok(RepackVector {
is_extract,
typ: self.typ,
packed: scalar,
unpacked: vector,
vector_sema,
})
}
}
impl VisitVariable for RepackVector {
fn visit_variable<
'a,
F: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
) -> Result<TypedStatement, TranslateError> {
Ok(TypedStatement::RepackVector(
self.map::<TypedArgParams, _, _>(f)?,
))
}
}
impl VisitVariableExpanded for RepackVector {
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
) -> Result<ExpandedStatement, TranslateError> {
Ok(ExpandedStatement::RepackVector(
self.map::<ExpandedArgParams, _, _>(f)?,
))
}
}
struct UnpackVector {
typ: ast::ScalarType,
dst: Vec<spirv::Word>,
src: spirv::Word,
}
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
@ -6737,22 +6917,6 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
}
}
impl ast::Type {
fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
match self {
ast::Type::Vector(t, len) => Ok((*t, *len)),
_ => Err(TranslateError::MismatchedType),
}
}
fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
match self {
ast::Type::Scalar(t) => Ok(*t),
_ => Err(TranslateError::MismatchedType),
}
}
}
impl<T> ast::SrcOperand<T> {
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
self,