mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Introduce vector repack statement
This commit is contained in:
parent
adf88bc1af
commit
eb841b3a88
1 changed files with 214 additions and 50 deletions
|
@ -975,8 +975,7 @@ fn compute_denorm_information<'input>(
|
|||
Statement::Label(_) => {}
|
||||
Statement::Variable(_) => {}
|
||||
Statement::PtrAccess { .. } => {}
|
||||
Statement::PackVector(_) => {}
|
||||
Statement::UnpackVector(_) => {}
|
||||
Statement::RepackVector(_) => {}
|
||||
}
|
||||
}
|
||||
denorm_methods.insert(method_key, flush_counter);
|
||||
|
@ -1477,7 +1476,7 @@ fn convert_to_typed_statements(
|
|||
};
|
||||
d.src_is_address = take_address;
|
||||
}
|
||||
let mut visitor = VectorPackingVisitor::new(&mut result, id_defs);
|
||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||
let instruction = Statement::Instruction(
|
||||
ast::Instruction::Mov(
|
||||
d,
|
||||
|
@ -1488,12 +1487,14 @@ fn convert_to_typed_statements(
|
|||
)
|
||||
.map(&mut visitor)?,
|
||||
);
|
||||
result.push(instruction);
|
||||
visitor.func.push(instruction);
|
||||
visitor.func.extend(visitor.post_stmts);
|
||||
}
|
||||
inst => {
|
||||
let mut visitor = VectorPackingVisitor::new(&mut result, id_defs);
|
||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
|
||||
result.push(instruction);
|
||||
visitor.func.push(instruction);
|
||||
visitor.func.extend(visitor.post_stmts);
|
||||
}
|
||||
},
|
||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||
|
@ -1505,24 +1506,52 @@ fn convert_to_typed_statements(
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
struct VectorPackingVisitor<'a, 'b> {
|
||||
struct VectorRepackVisitor<'a, 'b> {
|
||||
func: &'b mut Vec<TypedStatement>,
|
||||
id_def: &'b mut NumericIdResolver<'a>,
|
||||
post_stmts: Vec<TypedStatement>,
|
||||
post_stmts: Option<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> VectorPackingVisitor<'a, 'b> {
|
||||
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
|
||||
VectorPackingVisitor {
|
||||
VectorRepackVisitor {
|
||||
func,
|
||||
id_def,
|
||||
post_stmts: Vec::new(),
|
||||
post_stmts: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_vector(
|
||||
&mut self,
|
||||
is_dst: bool,
|
||||
vector_sema: ArgumentSemantics,
|
||||
typ: &ast::Type,
|
||||
idx: Vec<spirv::Word>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
// mov.u32 foobar, {a,b};
|
||||
let scalar_t = match typ {
|
||||
ast::Type::Vector(scalar_t, _) => *scalar_t,
|
||||
_ => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
|
||||
let statement = Statement::RepackVector(RepackVector {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temp_vec,
|
||||
unpacked: idx,
|
||||
vector_sema,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts = Some(statement);
|
||||
} else {
|
||||
self.func.push(statement);
|
||||
}
|
||||
Ok(temp_vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
|
||||
for VectorPackingVisitor<'a, 'b>
|
||||
for VectorRepackVisitor<'a, 'b>
|
||||
{
|
||||
fn id(
|
||||
&mut self,
|
||||
|
@ -1555,7 +1584,12 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
|
|||
) -> Result<ast::DstOperand<spirv::Word>, TranslateError> {
|
||||
match desc.op {
|
||||
ast::DstOperandVec::Normal(op) => self.dst_operand(desc.new_op(op), typ),
|
||||
ast::DstOperandVec::Vector(vec) => todo!(),
|
||||
ast::DstOperandVec::Vector(vec) => Ok(ast::DstOperand::Reg(self.convert_vector(
|
||||
desc.is_dst,
|
||||
desc.sema,
|
||||
typ,
|
||||
vec,
|
||||
)?)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1566,7 +1600,12 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
|
|||
) -> Result<ast::SrcOperand<spirv::Word>, TranslateError> {
|
||||
match desc.op {
|
||||
ast::SrcOperandVec::Normal(op) => self.src_operand(desc.new_op(op), typ),
|
||||
ast::SrcOperandVec::Vector(_) => todo!(),
|
||||
ast::SrcOperandVec::Vector(vec) => Ok(ast::SrcOperand::Reg(self.convert_vector(
|
||||
desc.is_dst,
|
||||
desc.sema,
|
||||
typ,
|
||||
vec,
|
||||
)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1799,8 +1838,7 @@ fn normalize_labels(
|
|||
| Statement::Label(..)
|
||||
| Statement::Undef(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::PackVector(..)
|
||||
| Statement::UnpackVector(..) => {}
|
||||
| Statement::RepackVector(..) => {}
|
||||
}
|
||||
}
|
||||
iter::once(Statement::Label(id_def.new_non_variable(None)))
|
||||
|
@ -1944,6 +1982,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
Statement::PtrAccess(ptr_access) => {
|
||||
insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
|
||||
}
|
||||
Statement::RepackVector(repack) => {
|
||||
insert_mem_ssa_statement_default(id_def, &mut result, repack)?
|
||||
}
|
||||
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
|
||||
_ => return Err(TranslateError::Unreachable),
|
||||
}
|
||||
|
@ -2111,6 +2152,12 @@ fn expand_arguments<'a, 'b>(
|
|||
result.push(Statement::PtrAccess(new_inst));
|
||||
result.extend(post_stmts);
|
||||
}
|
||||
Statement::RepackVector(repack) => {
|
||||
let mut visitor = FlattenArguments::new(&mut result, id_def);
|
||||
let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts);
|
||||
result.push(Statement::RepackVector(new_inst));
|
||||
result.extend(post_stmts);
|
||||
}
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
|
||||
|
@ -2120,8 +2167,6 @@ fn expand_arguments<'a, 'b>(
|
|||
Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
|
||||
return Err(TranslateError::Unreachable)
|
||||
}
|
||||
Statement::PackVector(_) => todo!(),
|
||||
Statement::UnpackVector(_) => todo!(),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
|
@ -2445,6 +2490,13 @@ fn insert_implicit_conversions(
|
|||
Some(state_space),
|
||||
)?;
|
||||
}
|
||||
Statement::RepackVector(repack) => insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
repack,
|
||||
should_bitcast_wrapper,
|
||||
None,
|
||||
)?,
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Conversion(_)
|
||||
| s @ Statement::Label(_)
|
||||
|
@ -2454,8 +2506,6 @@ fn insert_implicit_conversions(
|
|||
| s @ Statement::StoreVar(_, _)
|
||||
| s @ Statement::Undef(_, _)
|
||||
| s @ Statement::RetValue(_, _) => result.push(s),
|
||||
Statement::PackVector(_) => todo!(),
|
||||
Statement::UnpackVector(_) => todo!(),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
|
@ -3081,8 +3131,38 @@ fn emit_function_body_ops(
|
|||
)?;
|
||||
builder.bitcast(result_type, Some(*dst), temp)?;
|
||||
}
|
||||
Statement::PackVector(_) => todo!(),
|
||||
Statement::UnpackVector(_) => todo!(),
|
||||
Statement::RepackVector(repack) => {
|
||||
if repack.is_extract {
|
||||
let scalar_type = map.get_or_add_scalar(builder, repack.typ);
|
||||
for (index, dst_id) in repack.unpacked.iter().enumerate() {
|
||||
builder.composite_extract(
|
||||
scalar_type,
|
||||
Some(*dst_id),
|
||||
repack.packed,
|
||||
&[index as u32],
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let vector_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Vector(
|
||||
SpirvScalarKey::from(repack.typ),
|
||||
repack.unpacked.len() as u8,
|
||||
),
|
||||
);
|
||||
let mut temp_vec = builder.undef(vector_type, None);
|
||||
for (index, src_id) in repack.unpacked.iter().enumerate() {
|
||||
temp_vec = builder.composite_insert(
|
||||
vector_type,
|
||||
None,
|
||||
*src_id,
|
||||
temp_vec,
|
||||
&[index as u32],
|
||||
)?;
|
||||
}
|
||||
builder.copy_object(vector_type, Some(repack.packed), temp_vec)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -4334,9 +4414,7 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
},
|
||||
)?;
|
||||
result.push(new_statement);
|
||||
for s in post_statements {
|
||||
result.push(s);
|
||||
}
|
||||
result.extend(post_statements);
|
||||
}
|
||||
Statement::Call(call) => {
|
||||
let mut post_statements = Vec::new();
|
||||
|
@ -4354,9 +4432,25 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
},
|
||||
)?;
|
||||
result.push(new_statement);
|
||||
for s in post_statements {
|
||||
result.push(s);
|
||||
}
|
||||
result.extend(post_statements);
|
||||
}
|
||||
Statement::RepackVector(pack) => {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement = pack.visit_variable(
|
||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
|
||||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
&func_args_ptr,
|
||||
&mut result,
|
||||
&mut post_statements,
|
||||
arg_desc,
|
||||
expected_type,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
result.push(new_statement);
|
||||
result.extend(post_statements);
|
||||
}
|
||||
_ => return Err(TranslateError::Unreachable),
|
||||
}
|
||||
|
@ -4810,8 +4904,7 @@ enum Statement<I, P: ast::ArgParams> {
|
|||
RetValue(ast::RetData, spirv::Word),
|
||||
Undef(ast::Type, spirv::Word),
|
||||
PtrAccess(PtrAccess<P>),
|
||||
PackVector(PackVector),
|
||||
UnpackVector(UnpackVector),
|
||||
RepackVector(RepackVector),
|
||||
}
|
||||
|
||||
impl ExpandedStatement {
|
||||
|
@ -4898,14 +4991,101 @@ impl ExpandedStatement {
|
|||
offset_src: constant_src,
|
||||
})
|
||||
}
|
||||
Statement::PackVector(_) => todo!(),
|
||||
Statement::UnpackVector(_) => todo!(),
|
||||
Statement::RepackVector(_) => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PackVector {}
|
||||
struct UnpackVector {}
|
||||
struct RepackVector {
|
||||
is_extract: bool,
|
||||
typ: ast::ScalarType,
|
||||
packed: spirv::Word,
|
||||
unpacked: Vec<spirv::Word>,
|
||||
vector_sema: ArgumentSemantics,
|
||||
}
|
||||
|
||||
impl RepackVector {
|
||||
fn map<
|
||||
From: ArgParamsEx<Id = spirv::Word>,
|
||||
To: ArgParamsEx<Id = spirv::Word>,
|
||||
V: ArgumentMapVisitor<From, To>,
|
||||
>(
|
||||
self,
|
||||
visitor: &mut V,
|
||||
) -> Result<RepackVector, TranslateError> {
|
||||
let scalar = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
op: self.packed,
|
||||
is_dst: !self.is_extract,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
|
||||
)?;
|
||||
let scalar_type = self.typ;
|
||||
let is_extract = self.is_extract;
|
||||
let vector_sema = self.vector_sema;
|
||||
let vector = self
|
||||
.unpacked
|
||||
.into_iter()
|
||||
.map(|id| {
|
||||
visitor.id(
|
||||
ArgumentDescriptor {
|
||||
op: id,
|
||||
is_dst: is_extract,
|
||||
sema: vector_sema,
|
||||
},
|
||||
Some(&ast::Type::Scalar(scalar_type)),
|
||||
)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(RepackVector {
|
||||
is_extract,
|
||||
typ: self.typ,
|
||||
packed: scalar,
|
||||
unpacked: vector,
|
||||
vector_sema,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl VisitVariable for RepackVector {
|
||||
fn visit_variable<
|
||||
'a,
|
||||
F: FnMut(
|
||||
ArgumentDescriptor<spirv::Word>,
|
||||
Option<&ast::Type>,
|
||||
) -> Result<spirv::Word, TranslateError>,
|
||||
>(
|
||||
self,
|
||||
f: &mut F,
|
||||
) -> Result<TypedStatement, TranslateError> {
|
||||
Ok(TypedStatement::RepackVector(
|
||||
self.map::<TypedArgParams, _, _>(f)?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl VisitVariableExpanded for RepackVector {
|
||||
fn visit_variable_extended<
|
||||
F: FnMut(
|
||||
ArgumentDescriptor<spirv::Word>,
|
||||
Option<&ast::Type>,
|
||||
) -> Result<spirv::Word, TranslateError>,
|
||||
>(
|
||||
self,
|
||||
f: &mut F,
|
||||
) -> Result<ExpandedStatement, TranslateError> {
|
||||
Ok(ExpandedStatement::RepackVector(
|
||||
self.map::<ExpandedArgParams, _, _>(f)?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
struct UnpackVector {
|
||||
typ: ast::ScalarType,
|
||||
dst: Vec<spirv::Word>,
|
||||
src: spirv::Word,
|
||||
}
|
||||
|
||||
struct ResolvedCall<P: ast::ArgParams> {
|
||||
pub uniform: bool,
|
||||
|
@ -6737,22 +6917,6 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::Type {
|
||||
fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
|
||||
match self {
|
||||
ast::Type::Vector(t, len) => Ok((*t, *len)),
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
|
||||
match self {
|
||||
ast::Type::Scalar(t) => Ok(*t),
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ast::SrcOperand<T> {
|
||||
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
|
||||
self,
|
||||
|
|
Loading…
Add table
Reference in a new issue