mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Emit movs
This commit is contained in:
parent
f2f3eeb303
commit
3d6991e0ca
8 changed files with 348 additions and 78 deletions
|
@ -48,4 +48,9 @@ Which is sensible, but completely untrue. In reality ptxas compiles silly code l
|
|||
```
|
||||
* Surprise, surprise, there are two kinds of implicit conversions at play in the example above:
|
||||
* "Relaxed type-checking rules": this is the conversion of b16 operation type to s32 dst register
|
||||
* Undocumented type coercion when dereferencing param_1. The PTX behaviour is to coerce **every** type. It's something like `[param_1] = *(b16*)param_1`
|
||||
* Undocumented type coercion when dereferencing param_1. The PTX behaviour is to coerce **every** type. It's something to the effect of `[param_1] = *(b16*)param_1`
|
||||
|
||||
PTX grammar
|
||||
-----------
|
||||
* PTX grammar rules are atrocious, keywords can be freely reused as ids without escaping
|
||||
* Modifiers can be applied to instructions in any arbitrary order. We don't support it and hope we will never have to
|
||||
|
|
|
@ -17,6 +17,15 @@ macro_rules! check {
|
|||
};
|
||||
}
|
||||
|
||||
/// Evaluates an unsafe Level Zero call and panics unless it returned
/// `ZE_RESULT_SUCCESS`.
///
/// Intended for places where the status cannot be handed back to a caller;
/// the visible call sites are all `Drop::drop` implementations.
///
/// NOTE(review): formatting with `{:?}` assumes `ze_result_t` implements
/// `Debug` — confirm against the generated bindings.
macro_rules! check_panic {
    ($expr:expr) => {
        let err = unsafe { $expr };
        if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
            // Format the status explicitly: a bare `panic!(err)` carries an
            // opaque payload, so the failing code never appears in the panic
            // message, and that form is rejected by the 2021 edition.
            panic!("{:?}", err);
        }
    };
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Error {
|
||||
NotReady = 1,
|
||||
|
@ -295,7 +304,7 @@ impl CommandQueue {
|
|||
impl Drop for CommandQueue {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeCommandQueueDestroy(self.0) };
|
||||
check_panic! { sys::zeCommandQueueDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -344,7 +353,7 @@ impl Module {
|
|||
impl Drop for Module {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeModuleDestroy(self.0) };
|
||||
check_panic! { sys::zeModuleDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -408,7 +417,7 @@ impl<T: SafeRepr> DeviceBuffer<T> {
|
|||
impl<T: SafeRepr> Drop for DeviceBuffer<T> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeDriverFreeMem(self.driver, self.ptr) };
|
||||
check_panic! { sys::zeDriverFreeMem(self.driver, self.ptr) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -506,7 +515,7 @@ impl<'a> CommandList<'a> {
|
|||
impl<'a> Drop for CommandList<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeCommandListDestroy(self.0) };
|
||||
check_panic! { sys::zeCommandListDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -531,9 +540,9 @@ impl<'a> FenceGuard<'a> {
|
|||
impl<'a> Drop for FenceGuard<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeFenceHostSynchronize(self.0, u32::max_value()) };
|
||||
unsafe { sys::zeFenceDestroy(self.0) };
|
||||
unsafe { sys::zeCommandListDestroy(self.1) };
|
||||
check_panic! { sys::zeFenceHostSynchronize(self.0, u32::max_value()) };
|
||||
check_panic! { sys::zeFenceDestroy(self.0) };
|
||||
check_panic! { sys::zeCommandListDestroy(self.1) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -653,7 +662,7 @@ impl<'a> EventPool<'a> {
|
|||
impl<'a> Drop for EventPool<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeEventPoolDestroy(self.0) };
|
||||
check_panic! { sys::zeEventPoolDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -684,7 +693,7 @@ impl<'a> Event<'a> {
|
|||
impl<'a> Drop for Event<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeEventDestroy(self.0) };
|
||||
check_panic! { sys::zeEventDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -744,6 +753,6 @@ impl<'a> Kernel<'a> {
|
|||
impl<'a> Drop for Kernel<'a> {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
unsafe { sys::zeKernelDestroy(self.0) };
|
||||
check_panic! { sys::zeKernelDestroy(self.0) };
|
||||
}
|
||||
}
|
||||
|
|
|
@ -239,7 +239,9 @@ pub enum LdCacheOperator {
|
|||
Uncached,
|
||||
}
|
||||
|
||||
pub struct MovData {}
|
||||
pub struct MovData {
|
||||
pub typ: Type,
|
||||
}
|
||||
|
||||
pub struct MulData {}
|
||||
|
||||
|
|
|
@ -8,12 +8,149 @@ match {
|
|||
r"\s+" => { },
|
||||
r"//[^\n\r]*[\n\r]*" => { },
|
||||
r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { },
|
||||
r"-?[?:0x]?[0-9]+" => Num,
|
||||
r#""[^"]*""# => String,
|
||||
r"[0-9]+\.[0-9]+" => VersionNumber,
|
||||
"!",
|
||||
"(", ")",
|
||||
"+",
|
||||
",",
|
||||
".",
|
||||
":",
|
||||
";",
|
||||
"@",
|
||||
"[", "]",
|
||||
"{", "}",
|
||||
"|",
|
||||
".acquire",
|
||||
".address_size",
|
||||
".and",
|
||||
".b16",
|
||||
".b32",
|
||||
".b64",
|
||||
".b8",
|
||||
".ca",
|
||||
".cg",
|
||||
".const",
|
||||
".cs",
|
||||
".cta",
|
||||
".cv",
|
||||
".entry",
|
||||
".eq",
|
||||
".equ",
|
||||
".extern",
|
||||
".f16",
|
||||
".f16x2",
|
||||
".f32",
|
||||
".f64",
|
||||
".file",
|
||||
".ftz",
|
||||
".func",
|
||||
".ge",
|
||||
".geu",
|
||||
".global",
|
||||
".gpu",
|
||||
".gt",
|
||||
".gtu",
|
||||
".hi",
|
||||
".hs",
|
||||
".le",
|
||||
".leu",
|
||||
".lo",
|
||||
".loc",
|
||||
".local",
|
||||
".ls",
|
||||
".lt",
|
||||
".ltu",
|
||||
".lu",
|
||||
".nan",
|
||||
".ne",
|
||||
".neu",
|
||||
".num",
|
||||
".or",
|
||||
".param",
|
||||
".pred",
|
||||
".reg",
|
||||
".relaxed",
|
||||
".rm",
|
||||
".rmi",
|
||||
".rn",
|
||||
".rni",
|
||||
".rp",
|
||||
".rpi",
|
||||
".rz",
|
||||
".rzi",
|
||||
".s16",
|
||||
".s32",
|
||||
".s64",
|
||||
".s8" ,
|
||||
".sat",
|
||||
".section",
|
||||
".shared",
|
||||
".sreg",
|
||||
".sys",
|
||||
".target",
|
||||
".u16",
|
||||
".u32",
|
||||
".u64",
|
||||
".u8" ,
|
||||
".uni",
|
||||
".v2",
|
||||
".v4",
|
||||
".version",
|
||||
".visible",
|
||||
".volatile",
|
||||
".wb",
|
||||
".weak",
|
||||
".wide",
|
||||
".wt",
|
||||
".xor",
|
||||
} else {
|
||||
// IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID
|
||||
"add",
|
||||
"bra",
|
||||
"cvt",
|
||||
"debug",
|
||||
"ld",
|
||||
"map_f64_to_f32",
|
||||
"mov",
|
||||
"mul",
|
||||
"not",
|
||||
"ret",
|
||||
"setp",
|
||||
"shl",
|
||||
"shr",
|
||||
r"sm_[0-9]+" => ShaderModel,
|
||||
r"-?[?:0x]?[0-9]+" => Num
|
||||
"st",
|
||||
"texmode_independent",
|
||||
"texmode_unified",
|
||||
} else {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
|
||||
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
|
||||
r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID,
|
||||
} else {
|
||||
r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID,
|
||||
} else {
|
||||
_
|
||||
}
|
||||
|
||||
ExtendedID : &'input str = {
|
||||
"add",
|
||||
"bra",
|
||||
"cvt",
|
||||
"debug",
|
||||
"ld",
|
||||
"map_f64_to_f32",
|
||||
"mov",
|
||||
"mul",
|
||||
"not",
|
||||
"ret",
|
||||
"setp",
|
||||
"shl",
|
||||
"shr",
|
||||
ShaderModel,
|
||||
"st",
|
||||
"texmode_independent",
|
||||
"texmode_unified",
|
||||
ID
|
||||
}
|
||||
|
||||
pub Module: ast::Module<'input> = {
|
||||
|
@ -58,7 +195,7 @@ AddressSize = {
|
|||
Function: ast::Function<'input> = {
|
||||
LinkingDirective*
|
||||
<kernel:IsKernel>
|
||||
<name:ID>
|
||||
<name:ExtendedID>
|
||||
"(" <args:Comma<FunctionInput>> ")"
|
||||
<body:FunctionBody> => ast::Function{<>}
|
||||
};
|
||||
|
@ -76,10 +213,10 @@ IsKernel: bool = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||
FunctionInput: ast::Argument<'input> = {
|
||||
".param" <_type:ScalarType> <name:ID> => {
|
||||
".param" <_type:ScalarType> <name:ExtendedID> => {
|
||||
ast::Argument {a_type: _type, name: name, length: 1 }
|
||||
},
|
||||
".param" <a_type:ScalarType> <name:ID> "[" <length:Num> "]" => {
|
||||
".param" <a_type:ScalarType> <name:ExtendedID> "[" <length:Num> "]" => {
|
||||
let length = length.parse::<u32>();
|
||||
let length = length.unwrap_with(errors);
|
||||
ast::Argument { a_type: a_type, name: name, length: length }
|
||||
|
@ -149,7 +286,7 @@ DebugLocation = {
|
|||
};
|
||||
|
||||
Label: &'input str = {
|
||||
<id:ID> ":" => id
|
||||
<id:ExtendedID> ":" => id
|
||||
};
|
||||
|
||||
Variable: ast::Variable<&'input str> = {
|
||||
|
@ -160,7 +297,7 @@ Variable: ast::Variable<&'input str> = {
|
|||
};
|
||||
|
||||
VariableName: (&'input str, Option<u32>) = {
|
||||
<id:ID> => (id, None),
|
||||
<id:ExtendedID> => (id, None),
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
||||
<id:ParametrizedID> => {
|
||||
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
|
||||
|
@ -189,7 +326,7 @@ Instruction: ast::Instruction<&'input str> = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
InstLd: ast::Instruction<&'input str> = {
|
||||
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
|
||||
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
|
||||
ast::Instruction::Ld(
|
||||
ast::LdData {
|
||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||
|
@ -234,17 +371,24 @@ LdCacheOperator: ast::LdCacheOperator = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||
InstMov: ast::Instruction<&'input str> = {
|
||||
"mov" MovType <a:Arg2Mov> => {
|
||||
ast::Instruction::Mov(ast::MovData{}, a)
|
||||
"mov" <t:MovType> <a:Arg2Mov> => {
|
||||
ast::Instruction::Mov(ast::MovData{ typ:t }, a)
|
||||
}
|
||||
};
|
||||
|
||||
MovType = {
|
||||
".b16", ".b32", ".b64",
|
||||
".u16", ".u32", ".u64",
|
||||
".s16", ".s32", ".s64",
|
||||
".f32", ".f64",
|
||||
".pred"
|
||||
MovType: ast::Type = {
|
||||
".b16" => ast::Type::Scalar(ast::ScalarType::B16),
|
||||
".b32" => ast::Type::Scalar(ast::ScalarType::B32),
|
||||
".b64" => ast::Type::Scalar(ast::ScalarType::B64),
|
||||
".u16" => ast::Type::Scalar(ast::ScalarType::U16),
|
||||
".u32" => ast::Type::Scalar(ast::ScalarType::U32),
|
||||
".u64" => ast::Type::Scalar(ast::ScalarType::U64),
|
||||
".s16" => ast::Type::Scalar(ast::ScalarType::S16),
|
||||
".s32" => ast::Type::Scalar(ast::ScalarType::S32),
|
||||
".s64" => ast::Type::Scalar(ast::ScalarType::S64),
|
||||
".f32" => ast::Type::Scalar(ast::ScalarType::F32),
|
||||
".f64" => ast::Type::Scalar(ast::ScalarType::F64),
|
||||
".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
||||
|
@ -263,7 +407,7 @@ InstMulMode: ast::MulData = {
|
|||
};
|
||||
|
||||
MulIntControl = {
|
||||
"hi", ".lo", ".wide"
|
||||
".hi", ".lo", ".wide"
|
||||
};
|
||||
|
||||
#[inline]
|
||||
|
@ -339,8 +483,8 @@ NotType = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
|
||||
PredAt: ast::PredAt<&'input str> = {
|
||||
"@" <label:ID> => ast::PredAt { not: false, label:label },
|
||||
"@" "!" <label:ID> => ast::PredAt { not: true, label:label }
|
||||
"@" <label:ExtendedID> => ast::PredAt { not: false, label:label },
|
||||
"@" "!" <label:ExtendedID> => ast::PredAt { not: true, label:label }
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
|
||||
|
@ -420,8 +564,8 @@ InstRet: ast::Instruction<&'input str> = {
|
|||
};
|
||||
|
||||
Operand: ast::Operand<&'input str> = {
|
||||
<r:ID> => ast::Operand::Reg(r),
|
||||
<r:ID> "+" <o:Num> => {
|
||||
<r:ExtendedID> => ast::Operand::Reg(r),
|
||||
<r:ExtendedID> "+" <o:Num> => {
|
||||
let offset = o.parse::<i32>();
|
||||
let offset = offset.unwrap_with(errors);
|
||||
ast::Operand::RegOffset(r, offset)
|
||||
|
@ -444,37 +588,37 @@ MovOperand: ast::MovOperand<&'input str> = {
|
|||
};
|
||||
|
||||
VectorOperand: (&'input str, &'input str) = {
|
||||
<pref:ID> "." <suf:ID> => (pref, suf),
|
||||
<pref:ID> <suf:DotID> => (pref, &suf[1..]),
|
||||
<pref:ExtendedID> "." <suf:ExtendedID> => (pref, suf),
|
||||
<pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]),
|
||||
};
|
||||
|
||||
Arg1: ast::Arg1<&'input str> = {
|
||||
<src:ID> => ast::Arg1{<>}
|
||||
<src:ExtendedID> => ast::Arg1{<>}
|
||||
};
|
||||
|
||||
Arg2: ast::Arg2<&'input str> = {
|
||||
<dst:ID> "," <src:Operand> => ast::Arg2{<>}
|
||||
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
|
||||
};
|
||||
|
||||
Arg2Mov: ast::Arg2Mov<&'input str> = {
|
||||
<dst:ID> "," <src:MovOperand> => ast::Arg2Mov{<>}
|
||||
<dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>}
|
||||
};
|
||||
|
||||
Arg3: ast::Arg3<&'input str> = {
|
||||
<dst:ID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
|
||||
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
|
||||
};
|
||||
|
||||
Arg4: ast::Arg4<&'input str> = {
|
||||
<dst1:ID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>}
|
||||
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>}
|
||||
};
|
||||
|
||||
// TODO: pass src3 negation somewhere
|
||||
Arg5: ast::Arg5<&'input str> = {
|
||||
<dst1:ID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
|
||||
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
|
||||
};
|
||||
|
||||
OptionalDst: &'input str = {
|
||||
"|" <dst2:ID> => dst2
|
||||
"|" <dst2:ExtendedID> => dst2
|
||||
}
|
||||
|
||||
VectorPrefix: ast::VectorPrefix = {
|
||||
|
@ -519,9 +663,3 @@ Comma<T>: Vec<T> = {
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
String = r#""[^"]*""#;
|
||||
VersionNumber = r"[0-9]+\.[0-9]+";
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
|
||||
ID: &'input str = <s:r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+"> => s;
|
||||
DotID: &'input str = <s:r"\.[a-zA-Z][a-zA-Z0-9_$]*"> => s;
|
|
@ -1,6 +1,6 @@
|
|||
use crate::ptx;
|
||||
use crate::translate;
|
||||
use rspirv::dr::{Block, Function, Instruction, Loader, Operand};
|
||||
use rspirv::{binary::Disassemble, dr::{Block, Function, Instruction, Loader, Operand}};
|
||||
use spirv_headers::Word;
|
||||
use spirv_tools_sys::{
|
||||
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
|
||||
|
@ -37,6 +37,7 @@ macro_rules! test_ptx {
|
|||
}
|
||||
|
||||
test_ptx!(ld_st, [1u64], [1u64]);
|
||||
test_ptx!(mov, [1u64], [1u64]);
|
||||
|
||||
struct DisplayError<T: Display + Debug> {
|
||||
err: T,
|
||||
|
@ -148,73 +149,110 @@ fn test_spvtxt_assert<'a>(
|
|||
ptr::null_mut(),
|
||||
)
|
||||
};
|
||||
assert!(result == spv_result_t::SPV_SUCCESS);
|
||||
let mut loader = Loader::new();
|
||||
rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
|
||||
let spvtxt_mod = loader.module();
|
||||
assert_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]);
|
||||
assert!(result == spv_result_t::SPV_SUCCESS);
|
||||
if !is_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]) {
|
||||
panic!(ptx_mod.disassemble())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assert_spirv_fn_equal(fn1: &Function, fn2: &Function) {
|
||||
fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
|
||||
let mut map = HashMap::new();
|
||||
assert_option_equal(&fn1.def, &fn2.def, &mut map, assert_instr_equal);
|
||||
assert_option_equal(&fn1.end, &fn2.end, &mut map, assert_instr_equal);
|
||||
if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) {
|
||||
return false;
|
||||
}
|
||||
if !is_option_equal(&fn1.end, &fn2.end, &mut map, is_instr_equal) {
|
||||
return false;
|
||||
}
|
||||
for (inst1, inst2) in fn1.parameters.iter().zip(fn2.parameters.iter()) {
|
||||
assert_instr_equal(inst1, inst2, &mut map);
|
||||
if !is_instr_equal(inst1, inst2, &mut map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (b1, b2) in fn1.blocks.iter().zip(fn2.blocks.iter()) {
|
||||
assert_block_equal(b1, b2, &mut map);
|
||||
if !is_block_equal(b1, b2, &mut map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn assert_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) {
|
||||
assert_option_equal(&b1.label, &b2.label, map, assert_instr_equal);
|
||||
fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool {
|
||||
if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) {
|
||||
return false;
|
||||
}
|
||||
for (inst1, inst2) in b1.instructions.iter().zip(b2.instructions.iter()) {
|
||||
assert_instr_equal(inst1, inst2, map);
|
||||
if !is_instr_equal(inst1, inst2, map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn assert_instr_equal(instr1: &Instruction, instr2: &Instruction, map: &mut HashMap<Word, Word>) {
|
||||
assert_option_equal(
|
||||
&instr1.result_type,
|
||||
&instr2.result_type,
|
||||
map,
|
||||
assert_word_equal,
|
||||
);
|
||||
assert_option_equal(&instr1.result_id, &instr2.result_id, map, assert_word_equal);
|
||||
fn is_instr_equal(
|
||||
instr1: &Instruction,
|
||||
instr2: &Instruction,
|
||||
map: &mut HashMap<Word, Word>,
|
||||
) -> bool {
|
||||
if !is_option_equal(&instr1.result_type, &instr2.result_type, map, is_word_equal) {
|
||||
return false;
|
||||
}
|
||||
if !is_option_equal(&instr1.result_id, &instr2.result_id, map, is_word_equal) {
|
||||
return false;
|
||||
}
|
||||
for (o1, o2) in instr1.operands.iter().zip(instr2.operands.iter()) {
|
||||
match (o1, o2) {
|
||||
(Operand::IdMemorySemantics(w1), Operand::IdMemorySemantics(w2)) => {
|
||||
assert_word_equal(w1, w2, map)
|
||||
if !is_word_equal(w1, w2, map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
(Operand::IdScope(w1), Operand::IdScope(w2)) => {
|
||||
if !is_word_equal(w1, w2, map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
(Operand::IdRef(w1), Operand::IdRef(w2)) => {
|
||||
if !is_word_equal(w1, w2, map) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
(o1, o2) => {
|
||||
if o1 != o2 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
(Operand::IdScope(w1), Operand::IdScope(w2)) => assert_word_equal(w1, w2, map),
|
||||
(Operand::IdRef(w1), Operand::IdRef(w2)) => assert_word_equal(w1, w2, map),
|
||||
(o1, o2) => assert_eq!(o1, o2),
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn assert_word_equal(w1: &Word, w2: &Word, map: &mut HashMap<Word, Word>) {
|
||||
fn is_word_equal(w1: &Word, w2: &Word, map: &mut HashMap<Word, Word>) -> bool {
|
||||
match map.entry(*w1) {
|
||||
std::collections::hash_map::Entry::Occupied(entry) => {
|
||||
assert_eq!(entry.get(), w2);
|
||||
if entry.get() != w2 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
std::collections::hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(*w2);
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn assert_option_equal<T, F: FnOnce(&T, &T, &mut HashMap<Word, Word>)>(
|
||||
fn is_option_equal<T, F: FnOnce(&T, &T, &mut HashMap<Word, Word>) -> bool>(
|
||||
o1: &Option<T>,
|
||||
o2: &Option<T>,
|
||||
map: &mut HashMap<Word, Word>,
|
||||
f: F,
|
||||
) {
|
||||
) -> bool {
|
||||
match (o1, o2) {
|
||||
(Some(t1), Some(t2)) => f(t1, t2, map),
|
||||
(None, None) => (),
|
||||
(None, None) => true,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
|
22
ptx/src/test/spirv_run/mov.ptx
Normal file
22
ptx/src/test/spirv_run/mov.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry mov(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
mov.u64 temp2, temp;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
26
ptx/src/test/spirv_run/mov.spvtxt
Normal file
26
ptx/src/test/spirv_run/mov.spvtxt
Normal file
|
@ -0,0 +1,26 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%1 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %5 "mov"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 64 0
|
||||
%4 = OpTypeFunction %2 %3 %3
|
||||
%19 = OpTypePointer Generic %3
|
||||
%5 = OpFunction %2 None %4
|
||||
%6 = OpFunctionParameter %3
|
||||
%7 = OpFunctionParameter %3
|
||||
%18 = OpLabel
|
||||
%13 = OpCopyObject %3 %6
|
||||
%14 = OpCopyObject %3 %7
|
||||
%15 = OpConvertUToPtr %19 %13
|
||||
%16 = OpLoad %3 %15
|
||||
%100 = OpCopyObject %3 %16
|
||||
%17 = OpConvertUToPtr %19 %14
|
||||
OpStore %17 %100
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -10,9 +10,19 @@ use rspirv::binary::Assemble;
|
|||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
enum SpirvType {
|
||||
Base(ast::ScalarType),
|
||||
Extended(ast::ExtendedScalarType),
|
||||
Pointer(ast::ScalarType, spirv::StorageClass),
|
||||
}
|
||||
|
||||
impl From<ast::Type> for SpirvType {
|
||||
fn from(t: ast::Type) -> Self {
|
||||
match t {
|
||||
ast::Type::Scalar(t) => SpirvType::Base(t),
|
||||
ast::Type::ExtendedScalar(t) => SpirvType::Extended(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TypeWordMap {
|
||||
void: spirv::Word,
|
||||
complex: HashMap<SpirvType, spirv::Word>,
|
||||
|
@ -50,9 +60,20 @@ impl TypeWordMap {
|
|||
})
|
||||
}
|
||||
|
||||
fn get_or_add_extended(&mut self, b: &mut dr::Builder, t: ast::ExtendedScalarType) -> spirv::Word {
|
||||
*self
|
||||
.complex
|
||||
.entry(SpirvType::Extended(t))
|
||||
.or_insert_with(|| match t {
|
||||
ast::ExtendedScalarType::Pred => b.type_bool(),
|
||||
ast::ExtendedScalarType::F16x2 => todo!(),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
|
||||
match t {
|
||||
SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
|
||||
SpirvType::Extended(t) => self.get_or_add_extended(b, t),
|
||||
SpirvType::Pointer(scalar, storage) => {
|
||||
let base = self.get_or_add_scalar(b, scalar);
|
||||
*self
|
||||
|
@ -416,6 +437,14 @@ fn emit_function_body_ops(
|
|||
}
|
||||
// SPIR-V does not support ret as guaranteed-converged
|
||||
ast::Instruction::Ret(_) => builder.ret()?,
|
||||
ast::Instruction::Mov(mov, arg) => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
|
||||
let src = match arg.src {
|
||||
ast::MovOperand::Op(ast::Operand::Reg(id)) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
builder.copy_object(result_type, Some(arg.dst), src)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
},
|
||||
}
|
||||
|
@ -1190,7 +1219,8 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Ret(_) => None,
|
||||
ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
|
||||
ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
|
||||
_ => todo!(),
|
||||
ast::Instruction::Mov(mov, _) => Some(mov.typ),
|
||||
_ => todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue