Fix st/ld offsets implement abs and fix remaining bugs from vectorAdd generation

This commit is contained in:
Andrzej Janik 2020-09-16 00:20:49 +02:00
parent fcf3aaeb16
commit 42bad8fcc2
5 changed files with 258 additions and 82 deletions

View file

@ -338,6 +338,7 @@ pub struct MovVectorDetails {
pub typ: MovVectorType,
pub length: u8,
}
#[derive(Copy, Clone)]
pub struct AbsDetails {
pub flush_to_zero: bool,
pub typ: ScalarType,

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry ld_st_offset(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 temp1;
.reg .u32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 temp1, [in_addr];
ld.u32 temp2, [in_addr+4];
st.u32 [out_addr], temp2;
st.u32 [out_addr+4], temp1;
ret;
}

View file

@ -0,0 +1,58 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%34 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "ld_st_offset"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%37 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%ulong_4_0 = OpConstant %ulong 4
%1 = OpFunction %void None %37
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%32 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%26 = OpConvertUToPtr %_ptr_Generic_uint %15
%14 = OpLoad %uint %26
OpStore %6 %14
%17 = OpLoad %ulong %4
%27 = OpCopyObject %ulong %17
%23 = OpIAdd %ulong %27 %ulong_4
%28 = OpConvertUToPtr %_ptr_Generic_uint %23
%16 = OpLoad %uint %28
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %uint %7
%29 = OpConvertUToPtr %_ptr_Generic_uint %18
OpStore %29 %19
%20 = OpLoad %ulong %5
%21 = OpLoad %uint %6
%30 = OpCopyObject %ulong %20
%25 = OpIAdd %ulong %30 %ulong_4_0
%31 = OpConvertUToPtr %_ptr_Generic_uint %25
OpStore %31 %21
OpReturn
OpFunctionEnd

View file

@ -58,6 +58,7 @@ test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -606,9 +606,9 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
) {
let mut post_statements = Vec::new();
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
let id_type = match (id_def.get_type(desc.op), desc.is_pointer) {
(Some(t), false) => t,
(Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
let id_type = match (id_def.get_type(desc.op), desc.sema) {
(Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t,
(Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
};
let generated_id = id_def.new_id(Some(id_type));
@ -725,33 +725,71 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
id
}
ast::Operand::RegOffset(reg, offset) => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
};
let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i128,
}));
let result_id = self.id_def.new_id(Some(typ));
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
ast::AddDetails::Int(ast::AddIntDesc {
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
result_id
match desc.sema {
ArgumentSemantics::Default => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
};
let id_constant_stmt =
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
let result_id = self.id_def.new_id(Some(typ));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i128,
}));
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
ast::AddDetails::Int(ast::AddIntDesc {
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
result_id
}
ArgumentSemantics::Ptr => {
let scalar_t = ast::ScalarType::U64;
let id_constant_stmt =
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
let result_id = self.id_def.new_id(Some(typ));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i128,
}));
let int_type = ast::IntType::U64;
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
ast::AddDetails::Int(ast::AddIntDesc {
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
result_id
}
ArgumentSemantics::ParamPtr => {
if offset == 0 {
return reg;
}
// Will be needed for arrays
todo!()
}
}
}
}
}
@ -977,8 +1015,9 @@ fn emit_function_body_ops(
let (result_type, result_id) = match &*call.ret_params {
[(id, typ)] => (
map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
*id,
Some(*id),
),
[] => (map.void(), None),
_ => todo!(),
};
let arg_list = call
@ -986,7 +1025,7 @@ fn emit_function_body_ops(
.iter()
.map(|(id, _)| *id)
.collect::<Vec<_>>();
builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
builder.function_call(result_type, result_id, call.func, arg_list)?;
}
Statement::Variable(ast::Variable {
align,
@ -1047,7 +1086,7 @@ fn emit_function_body_ops(
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Abs(_, _) => todo!(),
ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?,
ast::Instruction::Call(_) => unreachable!(),
// SPIR-V does not support marking jumps as guaranteed-converged
ast::Instruction::Bra(_, arg) => {
@ -1098,10 +1137,8 @@ fn emit_function_body_ops(
ast::MulDetails::Float(_) => todo!(),
},
ast::Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => {
emit_add_int(builder, map, desc, arg)?;
}
ast::AddDetails::Float(_) => todo!(),
ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?,
ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
},
ast::Instruction::Setp(setp, arg) => {
if arg.dst2.is_some() {
@ -1184,6 +1221,21 @@ fn emit_function_body_ops(
Ok(())
}
fn emit_add_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
desc: &ast::AddFloatDesc,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if desc.flush_to_zero {
todo!()
}
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding);
Ok(())
}
fn emit_cvt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -1415,6 +1467,30 @@ fn emit_mul_int(
Ok(())
}
fn emit_abs(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
d: &ast::AbsDetails,
arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let scalar_t = ast::ScalarType::from(d.typ);
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
spirv::CLOp::s_abs
} else {
spirv::CLOp::fabs
};
builder.ext_inst(
result_type,
Some(arg.dst),
opencl,
cl_abs as spirv::Word,
[arg.src],
)?;
Ok(())
}
fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -1821,7 +1897,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
ArgumentDescriptor {
op: id,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(typ.into()),
);
@ -1832,7 +1908,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
ArgumentDescriptor {
op: self.func,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
None,
);
@ -1844,7 +1920,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
ArgumentDescriptor {
op: id,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
typ.into(),
);
@ -2024,7 +2100,14 @@ where
struct ArgumentDescriptor<Op> {
op: Op,
is_dst: bool,
is_pointer: bool,
sema: ArgumentSemantics,
}
#[derive(Copy, Clone, PartialEq, Eq)]
enum ArgumentSemantics {
Default,
Ptr,
ParamPtr,
}
impl<T> ArgumentDescriptor<T> {
@ -2032,7 +2115,7 @@ impl<T> ArgumentDescriptor<T> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
is_pointer: self.is_pointer,
sema: self.sema,
}
}
}
@ -2046,13 +2129,15 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::MovVector(t, a) => {
ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length)))
}
ast::Instruction::Abs(_, _) => todo!(),
ast::Instruction::Abs(d, arg) => {
ast::Instruction::Abs(d, arg.map(visitor, ast::Type::Scalar(d.typ)))
}
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let src_is_pointer = d.state_space != ast::LdStateSpace::Param;
ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, src_is_pointer))
let is_param = d.state_space == ast::LdStateSpace::Param;
ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param))
}
ast::Instruction::Mov(mov_type, a) => {
ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into()))
@ -2100,8 +2185,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let param_space = d.state_space == ast::StStateSpace::Param;
ast::Instruction::St(d, a.map(visitor, inst_type, param_space))
let is_param = d.state_space == ast::StStateSpace::Param;
ast::Instruction::St(d, a.map(visitor, inst_type, is_param))
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
@ -2350,7 +2435,7 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2369,7 +2454,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(t),
),
@ -2377,7 +2462,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2388,14 +2473,14 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
self,
visitor: &mut V,
t: ast::Type,
is_src_pointer: bool,
is_param: bool,
) -> ast::Arg2<U> {
ast::Arg2 {
dst: visitor.variable(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(t),
),
@ -2403,7 +2488,11 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
is_pointer: is_src_pointer,
sema: if is_param {
ArgumentSemantics::ParamPtr
} else {
ArgumentSemantics::Ptr
},
},
t,
),
@ -2421,7 +2510,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(dst_t),
),
@ -2429,7 +2518,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
src_t,
),
@ -2442,14 +2531,18 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
self,
visitor: &mut V,
t: ast::Type,
param_space: bool,
is_param: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
src1: visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: param_space,
is_pointer: !param_space,
is_dst: is_param,
sema: if is_param {
ArgumentSemantics::ParamPtr
} else {
ArgumentSemantics::Ptr
},
},
t,
),
@ -2457,7 +2550,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2486,7 +2579,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
@ -2496,7 +2589,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: composite_src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
@ -2504,7 +2597,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: scalar_src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
@ -2514,7 +2607,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
@ -2522,7 +2615,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
(scalar_type, vec_len),
),
@ -2533,7 +2626,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
@ -2543,7 +2636,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: composite_src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
@ -2551,7 +2644,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
ArgumentDescriptor {
op: src,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
(scalar_type, vec_len),
),
@ -2571,7 +2664,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(t),
),
@ -2579,7 +2672,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2587,7 +2680,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2604,7 +2697,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(t),
),
@ -2612,7 +2705,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2620,7 +2713,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
ast::Type::Scalar(ast::ScalarType::U32),
),
@ -2639,7 +2732,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(ast::ScalarType::Pred)),
),
@ -2648,7 +2741,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: dst2,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(ast::ScalarType::Pred)),
)
@ -2657,7 +2750,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2665,7 +2758,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2684,7 +2777,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(ast::ScalarType::Pred)),
),
@ -2693,7 +2786,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ArgumentDescriptor {
op: dst2,
is_dst: true,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(ast::ScalarType::Pred)),
)
@ -2702,7 +2795,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2710,7 +2803,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
t,
),
@ -2718,7 +2811,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ArgumentDescriptor {
op: self.src3,
is_dst: false,
is_pointer: false,
sema: ArgumentSemantics::Default,
},
ast::Type::Scalar(ast::ScalarType::Pred),
),