Add more tests

This commit is contained in:
Andrzej Janik 2020-09-20 15:44:52 +02:00
parent 17f2d09cc7
commit dcaea507ba
7 changed files with 175 additions and 139 deletions

View file

@ -13,8 +13,10 @@
%25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uchar = OpTypeInt 8 0
%_arr_uchar_8 = OpTypeArray %uchar %8
%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
%uint = OpTypeInt 32 0
%uint_8 = OpConstant %uint 8
%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8
%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%1 = OpFunction %void None %25
%8 = OpFunctionParameter %ulong
@ -22,7 +24,7 @@
%20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
%4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function

View file

@ -8,7 +8,7 @@ use spirv_headers::Word;
use spirv_tools_sys::{
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
};
use std::collections::hash_map::Entry;
use std::{collections::hash_map::Entry, cmp};
use std::error;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
@ -59,8 +59,9 @@ test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
//test_ptx!(ntid, [3u32], [4u32]);
//test_ptx!(reg_slm, [12u64], [12u64]);
test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [12u64]);
test_ptx!(mov_address, [0xDEADu64], [0u64]);
struct DisplayError<T: Debug> {
err: T,
@ -123,8 +124,8 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?;
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, input.len())?;
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, output.len())?;
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(),1))?;
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?;

View file

@ -0,0 +1,15 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry mov_address(
.param .u64 input,
.param .u64 output
)
{
.local .b8 __local_depot0[8];
.reg .u64 temp;
mov.u64 temp, __local_depot0;
ret;
}

View file

@ -2,12 +2,12 @@
.target sm_30
.address_size 64
.visible .entry reg_slm(
.visible .entry reg_local(
.param .u64 input,
.param .u64 output
)
{
.local .align 8 .b8 slm[8];
.local .align 8 .b8 local_x[8];
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b64 temp;
@ -16,11 +16,9 @@
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
mov.s64 unused, slm;
ld.global.u64 temp, [in_addr];
st.u64 [slm], temp;
ld.u64 temp, [slm];
st.u64 [local_x], temp;
ld.u64 temp, [local_x];
st.global.u64 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,46 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "add"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpIAdd %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View file

@ -217,11 +217,13 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
for f in ssa_functions {
let f_body = match f.body {
Some(f) => f,
None => continue,
};
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?;
@ -229,6 +231,33 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
Ok(builder.module())
}
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver,
) {
for (reg, id) in id_defs.special_registers.iter() {
let result_type = map.get_or_add(
builder,
SpirvType::Pointer(
Box::new(SpirvType::from(reg.get_type())),
spirv::StorageClass::UniformConstant,
),
);
builder.variable(
result_type,
Some(*id),
spirv::StorageClass::UniformConstant,
None,
);
builder.decorate(
*id,
spirv::Decoration::BuiltIn,
&[dr::Operand::BuiltIn(reg.get_builtin())],
);
}
}
fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -239,7 +268,12 @@ fn emit_function_header<'a>(
let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => {
let fn_id = global.get_id(name)?;
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
let interface = global
.special_registers
.iter()
.map(|(_, id)| *id)
.collect::<Vec<_>>();
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
fn_id
}
ast::MethodDecl::Func(_, name, _) => name,
@ -293,7 +327,7 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>,
) -> Result<ExpandedFunction<'a>, TranslateError> {
) -> Result<Function<'a>, TranslateError> {
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
}
@ -333,13 +367,14 @@ fn to_ssa<'input, 'b>(
fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, ExpandedArgParams>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
) -> Result<ExpandedFunction<'input>, TranslateError> {
) -> Result<Function<'input>, TranslateError> {
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(ExpandedFunction {
return Ok(Function {
func_directive: f_args,
body: None,
globals: Vec::new(),
})
}
};
@ -357,12 +392,21 @@ fn to_ssa<'input, 'b>(
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let sorted_statements = normalize_variable_decls(labeled_statements);
Ok(ExpandedFunction {
let (f_body, globals) = extract_globals(sorted_statements);
Ok(Function {
func_directive: f_args,
body: Some(sorted_statements),
globals: globals,
body: Some(f_body),
})
}
fn extract_globals(
sorted_statements: Vec<ExpandedStatement>,
) -> (Vec<ExpandedStatement>, Vec<ExpandedStatement>) {
// This fn will be used for SLM
(sorted_statements, Vec::new())
}
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
func[1..].sort_by_key(|s| match s {
Statement::Variable(_) => 0,
@ -477,7 +521,9 @@ fn add_types_to_statements(
},
_ => dets,
};
Ok(Statement::Instruction(ast::Instruction::MovVector(new_dets, args)))
Ok(Statement::Instruction(ast::Instruction::MovVector(
new_dets, args,
)))
}
s => Ok(s),
}
@ -724,7 +770,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
(_, ArgumentSemantics::Address) => return Ok(desc.op),
(t, ArgumentSemantics::RegisterPointer)
| (t, ArgumentSemantics::Default)
| (t, ArgumentSemantics::Ptr) => t,
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
let generated_id = id_def.new_id(id_type);
if !desc.is_dst {
@ -873,7 +919,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
));
Ok(result_id)
}
ArgumentSemantics::Ptr => {
ArgumentSemantics::PhysicalPointer => {
let scalar_t = ast::ScalarType::U64;
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
let result_id = self.id_def.new_id(typ);
@ -1137,7 +1183,7 @@ fn emit_function_body_ops(
builder.begin_block(Some(*id))?;
}
_ => {
if builder.block.is_none() {
if builder.block.is_none() && builder.function.is_some() {
builder.begin_block(None)?;
}
}
@ -1166,10 +1212,9 @@ fn emit_function_body_ops(
name,
}) => {
let st_class = match v_type {
ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
spirv::StorageClass::Function
}
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
ast::VariableType::Reg(_)
| ast::VariableType::Param(_)
| ast::VariableType::Local(_) => spirv::StorageClass::Function,
};
let type_id = map.get_or_add(
builder,
@ -1234,7 +1279,7 @@ fn emit_function_body_ops(
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param => {
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
@ -1242,18 +1287,20 @@ fn emit_function_body_ops(
}
}
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
|| (data.state_space != ast::StStateSpace::Generic
&& data.state_space != ast::StStateSpace::Param
&& data.state_space != ast::StStateSpace::Global)
{
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
if data.state_space == ast::StStateSpace::Param {
if data.state_space == ast::StStateSpace::Param
|| data.state_space == ast::StStateSpace::Local
{
let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
} else {
} else if data.state_space == ast::StStateSpace::Generic
|| data.state_space == ast::StStateSpace::Global
{
builder.store(arg.src1, arg.src2, None, &[])?;
} else {
todo!()
}
}
// SPIR-V does not support ret as guaranteed-converged
@ -1643,7 +1690,7 @@ fn emit_implicit_conversion(
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
(_, _, ConversionKind::Ptr(space)) => {
(_, _, ConversionKind::BitToPtr(space)) => {
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
@ -1699,14 +1746,11 @@ fn emit_implicit_conversion(
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => {
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.convert_ptr_to_u(into_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(),
}
Ok(())
@ -2181,7 +2225,7 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
decl.get_fn_decl_str(id)
}
fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics {
fn get_src_semantics(_: &Self::MovOperand) -> ArgumentSemantics {
ArgumentSemantics::Default
}
}
@ -2230,7 +2274,12 @@ pub enum StateSpace {
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
struct Function<'input> {
pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>,
pub globals: Vec<ExpandedStatement>,
pub body: Option<Vec<ExpandedStatement>>,
}
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
@ -2248,7 +2297,7 @@ impl ArgParamsEx for ExpandedArgParams {
decl.get_fn_decl(*id)
}
fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics {
fn get_src_semantics(_: &spirv::Word) -> ArgumentSemantics {
ArgumentSemantics::Default
}
}
@ -2398,12 +2447,12 @@ struct ArgumentDescriptor<Op> {
sema: ArgumentSemantics,
}
#[derive(Copy, Clone, PartialEq, Eq)]
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
// normal register access
Default,
// st/ld global
Ptr,
PhysicalPointer,
// st/ld .param, .local
RegisterPointer,
// mov of .local/.global variables
@ -2720,7 +2769,8 @@ enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
Ptr(ast::LdStateSpace),
BitToPtr(ast::LdStateSpace),
PtrToBit,
}
impl<T> ast::PredAt<T> {
@ -2831,7 +2881,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: if is_param {
ArgumentSemantics::RegisterPointer
} else {
ArgumentSemantics::Ptr
ArgumentSemantics::PhysicalPointer
},
},
t,
@ -2919,7 +2969,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
sema: if is_param {
ArgumentSemantics::RegisterPointer
} else {
ArgumentSemantics::Ptr
ArgumentSemantics::PhysicalPointer
},
},
t,
@ -3518,7 +3568,7 @@ fn get_implicit_conversions_ld_src(
) -> Result<Vec<ImplicitConversion>, TranslateError> {
let src_type = id_def.get_typed(src)?;
match state_space {
ast::LdStateSpace::Param => {
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
if src_type != instr_type {
Ok(vec![
ImplicitConversion {
@ -3560,7 +3610,7 @@ fn get_implicit_conversions_ld_src(
dst: u32::max_value(),
from: src_type,
to: instr_type,
kind: ConversionKind::Ptr(state_space),
kind: ConversionKind::BitToPtr(state_space),
});
if result.len() == 2 {
let new_id = id_def.new_id(new_src_type);
@ -3570,92 +3620,9 @@ fn get_implicit_conversions_ld_src(
}
Ok(result)
}
_ => todo!(),
_ => Err(TranslateError::Todo),
}
}
fn insert_implicit_conversions_ld_src(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
id_def: &mut MutableNumericIdResolver,
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> Result<spirv::Word, TranslateError> {
match state_space {
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
func,
id_def,
instr_type,
src,
should_convert_ld_param_src,
),
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
ScalarKind::Bit,
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
id_def,
new_src_type,
src,
should_convert_ld_generic_src_to_bitcast,
)?;
Ok(insert_conversion_src(
func,
id_def,
new_src,
new_src_type,
instr_type,
ConversionKind::Ptr(state_space),
))
}
_ => todo!(),
}
}
fn insert_implicit_conversions_ld_src_impl<
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
) -> Result<spirv::Word, TranslateError> {
let src_type = id_def.get_typed(src)?;
if let Some(conv) = should_convert(src_type, instr_type) {
Ok(insert_conversion_src(
func, id_def, src, src_type, instr_type, conv,
))
} else {
Ok(src)
}
}
fn should_convert_ld_param_src(
src_type: ast::Type,
instr_type: ast::Type,
) -> Option<ConversionKind> {
if src_type != instr_type {
return Some(ConversionKind::Default);
}
None
}
// HACK ALERT
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
fn should_convert_ld_generic_src_to_bitcast(
src_type: ast::Type,
_instr_type: ast::Type,
) -> Option<ConversionKind> {
if let ast::Type::Scalar(src_type) = src_type {
if src_type.kind() == ScalarKind::Signed {
return Some(ConversionKind::Default);
}
}
None
}
#[must_use]
fn insert_conversion_src(
@ -3832,14 +3799,21 @@ fn insert_implicit_bitcasts(
None => return Ok(desc.op),
};
let id_actual_type = id_def.get_typed(desc.op)?;
if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
let conv_kind = if desc.sema == ArgumentSemantics::Address {
Some(ConversionKind::PtrToBit)
} else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
Some(ConversionKind::Default)
} else {
None
};
if let Some(conv_kind) = conv_kind {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
&mut desc.op,
id_type_from_instr,
id_actual_type,
ConversionKind::Default,
conv_kind,
));
Ok(desc.op)
} else {
@ -3849,7 +3823,7 @@ fn insert_implicit_bitcasts(
desc.op,
id_actual_type,
id_type_from_instr,
ConversionKind::Default,
conv_kind,
))
}
} else {