Carry state space with pointer

This commit is contained in:
Andrzej Janik 2021-05-15 15:58:11 +02:00
commit 82b5cef0bd
3 changed files with 108 additions and 106 deletions

View file

@ -108,10 +108,49 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub enum Type { pub enum Type {
// .param.b32 foo;
// -> OpTypeInt
Scalar(ScalarType), Scalar(ScalarType),
// .param.v2.b32 foo;
// -> OpTypeVector
Vector(ScalarType, u8), Vector(ScalarType, u8),
// .param.b32 foo[4];
// -> OpTypeArray
Array(ScalarType, Vec<u32>), Array(ScalarType, Vec<u32>),
Pointer(ScalarType), /*
Variables of this type almost never exist in the original .ptx and are
usually artificially created. Some examples below:
- extern pointers to the .shared memory in the form:
.extern .shared .b32 shared_mem[];
which we first parse as
.extern .shared .b32 shared_mem;
and then convert to an additional function parameter:
.param .ptr<.b32.shared> shared_mem;
and do a load at the start of the function (and renames inside fn):
.reg .ptr<.b32.shared> temp;
ld.param.ptr<.b32.shared> temp, [shared_mem];
note, we don't support non-.shared extern pointers, because there's
zero use for them in the ptxas
- artifical pointers created by stateful conversion, which work
similiarly to the above
- function parameters:
foobar(.param .align 4 .b8 numbers[])
which get parsed to
foobar(.param .align 4 .b8 numbers)
and then converted to
foobar(.reg .align 4 .ptr<.b8.param> numbers)
- ld/st with offset:
.reg.b32 x;
.param.b64 arg0;
st.param.b32 [arg0+4], x;
Yes, this code is legal and actually emitted by the NV compiler!
We convert the st to:
.reg ptr<.b64.param> temp = ptr_offset(arg0, 4);
st.param.b32 [temp], x;
*/
// .reg ptr<.b64.param>
// -> OpTypePointer Function
Pointer(ScalarType, StateSpace),
} }
#[derive(PartialEq, Eq, Hash, Clone, Copy)] #[derive(PartialEq, Eq, Hash, Clone, Copy)]

View file

@ -624,9 +624,9 @@ ModuleVariable: ast::Variable<&'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
} }
if space == ".global" { if space == ".global" {
(ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new()) (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new())
} else { } else {
(ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new()) (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new())
} }
} }
}; };
@ -648,7 +648,7 @@ ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
(ast::Type::Array(t, dimensions), init) (ast::Type::Array(t, dimensions), init)
} }
ast::ArrayOrPointer::Pointer => { ast::ArrayOrPointer::Pointer => {
(ast::Type::Pointer(t), Vec::new()) (ast::Type::Scalar(t), Vec::new())
} }
}; };
(align, array_init, v_type, name) (align, array_init, v_type, name)

View file

@ -56,33 +56,20 @@ enum SpirvType {
} }
impl SpirvType { impl SpirvType {
fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self { fn new(t: ast::Type) -> Self {
match t { match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
ast::Type::Pointer(pointer_t) => { ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
let spirv_space = match decl_space { Box::new(SpirvType::Base(pointer_t.into())),
ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { space.to_spirv(),
spirv::StorageClass::Private ),
}
ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
ast::StateSpace::Generic => spirv::StorageClass::Generic,
ast::StateSpace::Sreg => spirv::StorageClass::Input,
};
SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space)
}
} }
} }
fn pointer_to( fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
t: ast::Type, let key = Self::new(t);
inner_space: ast::StateSpace,
outer_space: spirv::StorageClass,
) -> Self {
let key = Self::new(t, inner_space);
SpirvType::Pointer(Box::new(key), outer_space) SpirvType::Pointer(Box::new(key), outer_space)
} }
} }
@ -394,7 +381,7 @@ impl TypeWordMap {
b.constant_composite(result_type, None, components.into_iter()) b.constant_composite(result_type, None, components.into_iter())
} }
}, },
ast::Type::Pointer(typ) => return Err(error_unreachable()), ast::Type::Pointer(..) => return Err(error_unreachable()),
}) })
} }
@ -453,7 +440,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let must_link_ptx_impl = ptx_impl_imports.len() > 0; let must_link_ptx_impl = ptx_impl_imports.len() > 0;
let directives = ptx_impl_imports let mut directives = ptx_impl_imports
.into_iter() .into_iter()
.map(|(_, v)| v) .map(|(_, v)| v)
.chain(directives.into_iter()) .chain(directives.into_iter())
@ -461,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let mut builder = dr::Builder::new(); let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id()); builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives); let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id()); //let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives); normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives); let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@ -725,6 +712,7 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
transformation has a semantical meaning - we emit additional transformation has a semantical meaning - we emit additional
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
*/ */
/*
fn convert_dynamic_shared_memory_usage<'input>( fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>, module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut() -> spirv::Word,
@ -819,7 +807,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
ast::Variable { ast::Variable {
name: shared_id_param, name: shared_id_param,
align: None, align: None,
v_type: ast::Type::Pointer(ast::ScalarType::B8), v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()),
state_space: ast::StateSpace::Shared, state_space: ast::StateSpace::Shared,
array_init: Vec::new(), array_init: Vec::new(),
} }
@ -937,6 +925,7 @@ fn get_callers_of_extern_shared_single<'a>(
} }
} }
} }
*/
type DenormCountMap<T> = HashMap<T, isize>; type DenormCountMap<T> = HashMap<T, isize>;
@ -1031,11 +1020,7 @@ fn emit_builtins(
for (reg, id) in id_defs.special_registers.builtins() { for (reg, id) in id_defs.special_registers.builtins() {
let result_type = map.get_or_add( let result_type = map.get_or_add(
builder, builder,
SpirvType::pointer_to( SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input),
reg.get_type(),
ast::StateSpace::Reg,
spirv::StorageClass::Input,
),
); );
builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
builder.decorate( builder.decorate(
@ -1144,10 +1129,7 @@ fn emit_function_header<'a>(
} }
*/ */
for input in &func_decl.input_arguments { for input in &func_decl.input_arguments {
let result_type = map.get_or_add( let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone()));
builder,
SpirvType::new(input.v_type.clone(), input.state_space),
);
builder.function_parameter(Some(input.name), result_type)?; builder.function_parameter(Some(input.name), result_type)?;
} }
Ok(fn_id) Ok(fn_id)
@ -1753,8 +1735,8 @@ fn to_ptx_impl_atomic_call(
input_arguments: vec![ input_arguments: vec![
ast::Variable { ast::Variable {
align: None, align: None,
v_type: ast::Type::Pointer(typ), v_type: ast::Type::Pointer(typ, ptr_space),
state_space: ptr_space, state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None), name: id_defs.register_intermediate(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
@ -1791,7 +1773,11 @@ fn to_ptx_impl_atomic_call(
func: fn_id, func: fn_id,
ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
param_list: vec![ param_list: vec![
(arg.src1, ast::Type::Pointer(typ), ptr_space), (
arg.src1,
ast::Type::Pointer(typ, ptr_space),
ast::StateSpace::Reg,
),
( (
arg.src2, arg.src2,
ast::Type::Scalar(scalar_typ), ast::Type::Scalar(scalar_typ),
@ -2629,8 +2615,8 @@ fn insert_implicit_conversions(
is_dst: false, is_dst: false,
sema: ArgumentSemantics::PhysicalPointer, sema: ArgumentSemantics::PhysicalPointer,
}, },
typ: &ast::Type::Pointer(underlying_type), typ: &ast::Type::Pointer(underlying_type, state_space),
state_space, state_space: new_todo!(),
stmt_ctor: |new_ptr_src| { stmt_ctor: |new_ptr_src| {
Statement::PtrAccess(PtrAccess { Statement::PtrAccess(PtrAccess {
underlying_type, underlying_type,
@ -2758,10 +2744,10 @@ fn get_function_type(
builder, builder,
spirv_input spirv_input
.iter() .iter()
.map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), .map(|var| SpirvType::new(var.v_type.clone())),
spirv_output spirv_output
.iter() .iter()
.map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), .map(|var| SpirvType::new(var.v_type.clone())),
) )
} }
@ -2790,7 +2776,7 @@ fn emit_function_body_ops(
Statement::Call(call) => { Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params { let (result_type, result_id) = match &*call.ret_params {
[(id, typ, space)] => ( [(id, typ, space)] => (
map.get_or_add(builder, SpirvType::new(typ.clone(), *space)), map.get_or_add(builder, SpirvType::new(typ.clone())),
Some(*id), Some(*id),
), ),
[] => (map.void(), None), [] => (map.void(), None),
@ -2922,10 +2908,8 @@ fn emit_function_body_ops(
if data.qualifier != ast::LdStQualifier::Weak { if data.qualifier != ast::LdStQualifier::Weak {
todo!() todo!()
} }
let result_type = map.get_or_add( let result_type =
builder, map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space),
);
builder.load( builder.load(
result_type, result_type,
Some(arg.dst), Some(arg.dst),
@ -2956,10 +2940,8 @@ fn emit_function_body_ops(
// SPIR-V does not support ret as guaranteed-converged // SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(d, arg) => { ast::Instruction::Mov(d, arg) => {
let result_type = map.get_or_add( let result_type =
builder, map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone())));
SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg),
);
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
ast::Instruction::Mul(mul, arg) => match mul { ast::Instruction::Mul(mul, arg) => match mul {
@ -3000,8 +2982,7 @@ fn emit_function_body_ops(
ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a) => {
let full_type = ast::Type::Scalar(*t); let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of(); let size_of = full_type.size_of();
let result_type = let result_type = map.get_or_add(builder, SpirvType::new(full_type));
map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
} }
@ -3265,7 +3246,6 @@ fn emit_function_body_ops(
builder, builder,
SpirvType::pointer_to( SpirvType::pointer_to(
details.typ.clone(), details.typ.clone(),
details.state_space,
spirv::StorageClass::Function, spirv::StorageClass::Function,
), ),
); );
@ -3297,11 +3277,11 @@ fn emit_function_body_ops(
}) => { }) => {
let u8_pointer = map.get_or_add( let u8_pointer = map.get_or_add(
builder, builder,
SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space), SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
); );
let result_type = map.get_or_add( let result_type = map.get_or_add(
builder, builder,
SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space), SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
); );
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain( let temp = builder.in_bounds_ptr_access_chain(
@ -3596,15 +3576,12 @@ fn emit_variable(
&*var.array_init, &*var.array_init,
)?) )?)
} else if must_init { } else if must_init {
let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space)); let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
Some(builder.constant_null(type_id, None)) Some(builder.constant_null(type_id, None))
} else { } else {
None None
}; };
let ptr_type_id = map.get_or_add( let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
builder,
SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class),
);
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align { if let Some(align) = var.align {
builder.decorate( builder.decorate(
@ -3742,10 +3719,7 @@ fn emit_min(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
}; };
let inst_type = map.get_or_add( let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder,
SpirvType::new(desc.get_type(), ast::StateSpace::Reg),
);
builder.ext_inst( builder.ext_inst(
inst_type, inst_type,
Some(arg.dst), Some(arg.dst),
@ -3770,10 +3744,7 @@ fn emit_max(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
}; };
let inst_type = map.get_or_add( let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder,
SpirvType::new(desc.get_type(), ast::StateSpace::Reg),
);
builder.ext_inst( builder.ext_inst(
inst_type, inst_type,
Some(arg.dst), Some(arg.dst),
@ -4255,14 +4226,13 @@ fn emit_implicit_conversion(
(_, _, ConversionKind::BitToPtr) => { (_, _, ConversionKind::BitToPtr) => {
let dst_type = map.get_or_add( let dst_type = map.get_or_add(
builder, builder,
SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()), SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()),
); );
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
} }
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width { if from_parts.width == to_parts.width {
let dst_type = let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
if from_parts.scalar_kind != ast::ScalarKind::Float if from_parts.scalar_kind != ast::ScalarKind::Float
&& to_parts.scalar_kind != ast::ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float
{ {
@ -4275,13 +4245,10 @@ fn emit_implicit_conversion(
// This block is safe because it's illegal to implictly convert between floating point values // This block is safe because it's illegal to implictly convert between floating point values
let same_width_bit_type = map.get_or_add( let same_width_bit_type = map.get_or_add(
builder, builder,
SpirvType::new( SpirvType::new(ast::Type::from_parts(TypeParts {
ast::Type::from_parts(TypeParts { scalar_kind: ast::ScalarKind::Bit,
scalar_kind: ast::ScalarKind::Bit, ..from_parts
..from_parts })),
}),
cv.from_space,
),
); );
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts { let wide_bit_type = ast::Type::from_parts(TypeParts {
@ -4289,7 +4256,7 @@ fn emit_implicit_conversion(
..to_parts ..to_parts
}); });
let wide_bit_type_spirv = let wide_bit_type_spirv =
map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space)); map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
if to_parts.scalar_kind == ast::ScalarKind::Unsigned if to_parts.scalar_kind == ast::ScalarKind::Unsigned
|| to_parts.scalar_kind == ast::ScalarKind::Bit || to_parts.scalar_kind == ast::ScalarKind::Bit
{ {
@ -4323,15 +4290,13 @@ fn emit_implicit_conversion(
} }
} }
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
let result_type = let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
builder.s_convert(result_type, Some(cv.dst), cv.src)?; builder.s_convert(result_type, Some(cv.dst), cv.src)?;
} }
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
builder.bitcast(into_type, Some(cv.dst), cv.src)?; builder.bitcast(into_type, Some(cv.dst), cv.src)?;
} }
(_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
@ -4339,12 +4304,12 @@ fn emit_implicit_conversion(
map.get_or_add( map.get_or_add(
builder, builder,
SpirvType::Pointer( SpirvType::Pointer(
Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)), Box::new(SpirvType::new(cv.to_type.clone())),
spirv::StorageClass::Function, spirv::StorageClass::Function,
), ),
) )
} else { } else {
map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)) map.get_or_add(builder, SpirvType::new(cv.to_type.clone()))
}; };
builder.bitcast(result_type, Some(cv.dst), cv.src)?; builder.bitcast(result_type, Some(cv.dst), cv.src)?;
} }
@ -4358,18 +4323,14 @@ fn emit_load_var(
map: &mut TypeWordMap, map: &mut TypeWordMap,
details: &LoadVarDetails, details: &LoadVarDetails,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let result_type = map.get_or_add( let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
builder,
SpirvType::new(details.typ.clone(), details.state_space),
);
match details.member_index { match details.member_index {
Some((index, Some(width))) => { Some((index, Some(width))) => {
let vector_type = match details.typ { let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType), _ => return Err(TranslateError::MismatchedType),
}; };
let vector_type_spirv = let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
map.get_or_add(builder, SpirvType::new(vector_type, details.state_space));
let vector_temp = builder.load( let vector_temp = builder.load(
vector_type_spirv, vector_type_spirv,
None, None,
@ -4387,11 +4348,7 @@ fn emit_load_var(
Some((index, None)) => { Some((index, None)) => {
let result_ptr_type = map.get_or_add( let result_ptr_type = map.get_or_add(
builder, builder,
SpirvType::pointer_to( SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
details.typ.clone(),
details.state_space,
spirv::StorageClass::Function,
),
); );
let index_spirv = map.get_or_add_constant( let index_spirv = map.get_or_add_constant(
builder, builder,
@ -5661,7 +5618,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
ast::StateSpace::Reg => new_todo!(), ast::StateSpace::Reg => new_todo!(),
ast::StateSpace::Sreg => new_todo!(), ast::StateSpace::Sreg => new_todo!(),
}; };
let ptr_type = ast::Type::Pointer(self.underlying_type.clone()); let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), new_todo!());
let new_dst = visitor.id( let new_dst = visitor.id(
ArgumentDescriptor { ArgumentDescriptor {
op: self.dst, op: self.dst,
@ -6231,24 +6188,28 @@ impl ast::Type {
match self { match self {
ast::Type::Scalar(scalar) => TypeParts { ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar, kind: TypeKind::Scalar,
state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: Vec::new(), components: Vec::new(),
}, },
ast::Type::Vector(scalar, components) => TypeParts { ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector, kind: TypeKind::Vector,
state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: vec![*components as u32], components: vec![*components as u32],
}, },
ast::Type::Array(scalar, components) => TypeParts { ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array, kind: TypeKind::Array,
state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: components.clone(), components: components.clone(),
}, },
ast::Type::Pointer(scalar) => TypeParts { ast::Type::Pointer(scalar, space) => TypeParts {
kind: TypeKind::PointerScalar, kind: TypeKind::Pointer,
state_space: *space,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: Vec::new(), components: Vec::new(),
@ -6269,9 +6230,10 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind), ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components, t.components,
), ),
TypeKind::PointerScalar => { TypeKind::Pointer => ast::Type::Pointer(
ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind)) ast::ScalarType::from_parts(t.width, t.scalar_kind),
} t.state_space,
),
} }
} }
@ -6292,6 +6254,7 @@ struct TypeParts {
kind: TypeKind, kind: TypeKind,
scalar_kind: ast::ScalarKind, scalar_kind: ast::ScalarKind,
width: u8, width: u8,
state_space: ast::StateSpace,
components: Vec<u32>, components: Vec<u32>,
} }
@ -6300,7 +6263,7 @@ enum TypeKind {
Scalar, Scalar,
Vector, Vector,
Array, Array,
PointerScalar, Pointer,
} }
impl ast::Instruction<ExpandedArgParams> { impl ast::Instruction<ExpandedArgParams> {