From 3d6991e0ca808f05025ee84574642efcdd7ed696 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 15 Jul 2020 23:56:00 +0200 Subject: [PATCH] Emit movs --- doc/NOTES.md | 7 +- level_zero/src/ze.rs | 29 ++-- ptx/src/ast.rs | 4 +- ptx/src/ptx.lalrpop | 212 ++++++++++++++++++++++++------ ptx/src/test/spirv_run/mod.rs | 94 +++++++++---- ptx/src/test/spirv_run/mov.ptx | 22 ++++ ptx/src/test/spirv_run/mov.spvtxt | 26 ++++ ptx/src/translate.rs | 32 ++++- 8 files changed, 348 insertions(+), 78 deletions(-) create mode 100644 ptx/src/test/spirv_run/mov.ptx create mode 100644 ptx/src/test/spirv_run/mov.spvtxt diff --git a/doc/NOTES.md b/doc/NOTES.md index be91c4c..f3a1a0b 100644 --- a/doc/NOTES.md +++ b/doc/NOTES.md @@ -48,4 +48,9 @@ Which is sensible, but completely untrue. In reality ptxas compiles silly code l ``` * Surprise, surprise, there's two kind of implicit conversions at play in the example above: * "Relaxed type-checking rules": this is the conversion of b16 operation type to s32 dst register - * Undocumented type coercion when dereferencing param_1. The PTX behaviour is to coerce **every** type. It's something like `[param_1] = *(b16*)param_1` \ No newline at end of file + * Undocumented type coercion when dereferencing param_1. The PTX behaviour is to coerce **every** type. It's something to the effect of `[param_1] = *(b16*)param_1` + +PTX grammar +----------- +* PTX grammar rules are atrocious, keywords can be freely reused as ids without escaping +* Modifiers can be applied to instructions in any arbitrary order. We don't support it and hope we will never have to diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index a4e1bcc..5df6323 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -17,6 +17,15 @@ macro_rules! check { }; } +macro_rules! check_panic { + ($expr:expr) => { + let err = unsafe { $expr }; + if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS { + panic!(err); + } + }; +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Error { NotReady = 1, @@ -295,7 +304,7 @@ impl CommandQueue { impl Drop for CommandQueue { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeCommandQueueDestroy(self.0) }; + check_panic! { sys::zeCommandQueueDestroy(self.0) }; } } @@ -344,7 +353,7 @@ impl Module { impl Drop for Module { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeModuleDestroy(self.0) }; + check_panic! { sys::zeModuleDestroy(self.0) }; } } @@ -408,7 +417,7 @@ impl DeviceBuffer { impl Drop for DeviceBuffer { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeDriverFreeMem(self.driver, self.ptr) }; + check_panic! { sys::zeDriverFreeMem(self.driver, self.ptr) }; } } @@ -506,7 +515,7 @@ impl<'a> CommandList<'a> { impl<'a> Drop for CommandList<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeCommandListDestroy(self.0) }; + check_panic! { sys::zeCommandListDestroy(self.0) }; } } @@ -531,9 +540,9 @@ impl<'a> FenceGuard<'a> { impl<'a> Drop for FenceGuard<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeFenceHostSynchronize(self.0, u32::max_value()) }; - unsafe { sys::zeFenceDestroy(self.0) }; - unsafe { sys::zeCommandListDestroy(self.1) }; + check_panic! { sys::zeFenceHostSynchronize(self.0, u32::max_value()) }; + check_panic! { sys::zeFenceDestroy(self.0) }; + check_panic! { sys::zeCommandListDestroy(self.1) }; } } @@ -653,7 +662,7 @@ impl<'a> EventPool<'a> { impl<'a> Drop for EventPool<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeEventPoolDestroy(self.0) }; + check_panic! { sys::zeEventPoolDestroy(self.0) }; } } @@ -684,7 +693,7 @@ impl<'a> Event<'a> { impl<'a> Drop for Event<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeEventDestroy(self.0) }; + check_panic! { sys::zeEventDestroy(self.0) }; } } @@ -744,6 +753,6 @@ impl<'a> Kernel<'a> { impl<'a> Drop for Kernel<'a> { #[allow(unused_must_use)] fn drop(&mut self) { - unsafe { sys::zeKernelDestroy(self.0) }; + check_panic! { sys::zeKernelDestroy(self.0) }; } } diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index bf8ea0d..c7cb7f7 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -239,7 +239,9 @@ pub enum LdCacheOperator { Uncached, } -pub struct MovData {} +pub struct MovData { + pub typ: Type, +} pub struct MulData {} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 22b91af..64d7725 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -8,12 +8,149 @@ match { r"\s+" => { }, r"//[^\n\r]*[\n\r]*" => { }, r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { }, + r"-?[?:0x]?[0-9]+" => Num, + r#""[^"]*""# => String, + r"[0-9]+\.[0-9]+" => VersionNumber, + "!", + "(", ")", + "+", + ",", + ".", + ":", + ";", + "@", + "[", "]", + "{", "}", + "|", + ".acquire", + ".address_size", + ".and", + ".b16", + ".b32", + ".b64", + ".b8", + ".ca", + ".cg", + ".const", + ".cs", + ".cta", + ".cv", + ".entry", + ".eq", + ".equ", + ".extern", + ".f16", + ".f16x2", + ".f32", + ".f64", + ".file", + ".ftz", + ".func", + ".ge", + ".geu", + ".global", + ".gpu", + ".gt", + ".gtu", + ".hi", + ".hs", + ".le", + ".leu", + ".lo", + ".loc", + ".local", + ".ls", + ".lt", + ".ltu", + ".lu", + ".nan", + ".ne", + ".neu", + ".num", + ".or", + ".param", + ".pred", + ".reg", + ".relaxed", + ".rm", + ".rmi", + ".rn", + ".rni", + ".rp", + ".rpi", + ".rz", + ".rzi", + ".s16", + ".s32", + ".s64", + ".s8" , + ".sat", + ".section", + ".shared", + ".sreg", + ".sys", + ".target", + ".u16", + ".u32", + ".u64", + ".u8" , + ".uni", + ".v2", + ".v4", + ".version", + ".visible", + ".volatile", + ".wb", + ".weak", + ".wide", + ".wt", + ".xor", +} else { + // IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID + "add", + "bra", + "cvt", + "debug", + "ld", + "map_f64_to_f32", + "mov", + "mul", + "not", + "ret", + "setp", + "shl", + "shr", r"sm_[0-9]+" => ShaderModel, - r"-?[?:0x]?[0-9]+" => Num + "st", + "texmode_independent", + "texmode_unified", +} else { + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers + r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID, + r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID, } else { r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID, -} else { - _ +} + +ExtendedID : &'input str = { + "add", + "bra", + "cvt", + "debug", + "ld", + "map_f64_to_f32", + "mov", + "mul", + "not", + "ret", + "setp", + "shl", + "shr", + ShaderModel, + "st", + "texmode_independent", + "texmode_unified", + ID } pub Module: ast::Module<'input> = { @@ -58,7 +195,7 @@ AddressSize = { Function: ast::Function<'input> = { LinkingDirective* - + "(" > ")" => ast::Function{<>} }; @@ -76,10 +213,10 @@ IsKernel: bool = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space FunctionInput: ast::Argument<'input> = { - ".param" <_type:ScalarType> => { + ".param" <_type:ScalarType> => { ast::Argument {a_type: _type, name: name, length: 1 } }, - ".param" "[" "]" => { + ".param" "[" "]" => { let length = length.parse::(); let length = length.unwrap_with(errors); ast::Argument { a_type: a_type, name: name, length: length } @@ -149,7 +286,7 @@ DebugLocation = { }; Label: &'input str = { - ":" => id + ":" => id }; Variable: ast::Variable<&'input str> = { @@ -160,7 +297,7 @@ Variable: ast::Variable<&'input str> = { }; VariableName: (&'input str, Option) = { - => (id, None), + => (id, None), // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names => { let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap(); @@ -189,7 +326,7 @@ Instruction: ast::Instruction<&'input str> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction<&'input str> = { - "ld" "," "[" "]" => { + "ld" "," "[" "]" => { ast::Instruction::Ld( ast::LdData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -234,17 +371,24 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction<&'input str> = { - "mov" MovType => { - ast::Instruction::Mov(ast::MovData{}, a) + "mov" => { + ast::Instruction::Mov(ast::MovData{ typ:t }, a) } }; -MovType = { - ".b16", ".b32", ".b64", - ".u16", ".u32", ".u64", - ".s16", ".s32", ".s64", - ".f32", ".f64", - ".pred" +MovType: ast::Type = { + ".b16" => ast::Type::Scalar(ast::ScalarType::B16), + ".b32" => ast::Type::Scalar(ast::ScalarType::B32), + ".b64" => ast::Type::Scalar(ast::ScalarType::B64), + ".u16" => ast::Type::Scalar(ast::ScalarType::U16), + ".u32" => ast::Type::Scalar(ast::ScalarType::U32), + ".u64" => ast::Type::Scalar(ast::ScalarType::U64), + ".s16" => ast::Type::Scalar(ast::ScalarType::S16), + ".s32" => ast::Type::Scalar(ast::ScalarType::S32), + ".s64" => ast::Type::Scalar(ast::ScalarType::S64), + ".f32" => ast::Type::Scalar(ast::ScalarType::F32), + ".f64" => ast::Type::Scalar(ast::ScalarType::F64), + ".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul @@ -263,7 +407,7 @@ InstMulMode: ast::MulData = { }; MulIntControl = { - "hi", ".lo", ".wide" + ".hi", ".lo", ".wide" }; #[inline] @@ -339,8 +483,8 @@ NotType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at PredAt: ast::PredAt<&'input str> = { - "@" => ast::PredAt { not: false, label:label }, - "@" "!" => ast::PredAt { not: true, label:label } + "@" => ast::PredAt { not: false, label:label }, + "@" "!" => ast::PredAt { not: true, label:label } }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra @@ -420,8 +564,8 @@ InstRet: ast::Instruction<&'input str> = { }; Operand: ast::Operand<&'input str> = { - => ast::Operand::Reg(r), - "+" => { + => ast::Operand::Reg(r), + "+" => { let offset = o.parse::(); let offset = offset.unwrap_with(errors); ast::Operand::RegOffset(r, offset) @@ -444,37 +588,37 @@ MovOperand: ast::MovOperand<&'input str> = { }; VectorOperand: (&'input str, &'input str) = { - "." => (pref, suf), - => (pref, &suf[1..]), + "." => (pref, suf), + => (pref, &suf[1..]), }; Arg1: ast::Arg1<&'input str> = { - => ast::Arg1{<>} + => ast::Arg1{<>} }; Arg2: ast::Arg2<&'input str> = { - "," => ast::Arg2{<>} + "," => ast::Arg2{<>} }; Arg2Mov: ast::Arg2Mov<&'input str> = { - "," => ast::Arg2Mov{<>} + "," => ast::Arg2Mov{<>} }; Arg3: ast::Arg3<&'input str> = { - "," "," => ast::Arg3{<>} + "," "," => ast::Arg3{<>} }; Arg4: ast::Arg4<&'input str> = { - "," "," => ast::Arg4{<>} + "," "," => ast::Arg4{<>} }; // TODO: pass src3 negation somewhere Arg5: ast::Arg5<&'input str> = { - "," "," "," "!"? => ast::Arg5{<>} + "," "," "," "!"? => ast::Arg5{<>} }; OptionalDst: &'input str = { - "|" => dst2 + "|" => dst2 } VectorPrefix: ast::VectorPrefix = { @@ -519,9 +663,3 @@ Comma: Vec = { } } }; - -String = r#""[^"]*""#; -VersionNumber = r"[0-9]+\.[0-9]+"; -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers -ID: &'input str = => s; -DotID: &'input str = => s; \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 83d0510..b573f2c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,6 +1,6 @@ use crate::ptx; use crate::translate; -use rspirv::dr::{Block, Function, Instruction, Loader, Operand}; +use rspirv::{binary::Disassemble, dr::{Block, Function, Instruction, Loader, Operand}}; use spirv_headers::Word; use spirv_tools_sys::{ spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env, @@ -37,6 +37,7 @@ macro_rules! test_ptx { } test_ptx!(ld_st, [1u64], [1u64]); +test_ptx!(mov, [1u64], [1u64]); struct DisplayError { err: T, @@ -148,73 +149,110 @@ fn test_spvtxt_assert<'a>( ptr::null_mut(), ) }; + assert!(result == spv_result_t::SPV_SUCCESS); let mut loader = Loader::new(); rspirv::binary::parse_words(&parsed_spirv, &mut loader)?; let spvtxt_mod = loader.module(); - assert_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]); - assert!(result == spv_result_t::SPV_SUCCESS); + if !is_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]) { + panic!(ptx_mod.disassemble()) + } Ok(()) } -fn assert_spirv_fn_equal(fn1: &Function, fn2: &Function) { +fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { let mut map = HashMap::new(); - assert_option_equal(&fn1.def, &fn2.def, &mut map, assert_instr_equal); - assert_option_equal(&fn1.end, &fn2.end, &mut map, assert_instr_equal); + if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) { + return false; + } + if !is_option_equal(&fn1.end, &fn2.end, &mut map, is_instr_equal) { + return false; + } for (inst1, inst2) in fn1.parameters.iter().zip(fn2.parameters.iter()) { - assert_instr_equal(inst1, inst2, &mut map); + if !is_instr_equal(inst1, inst2, &mut map) { + return false; + } } for (b1, b2) in fn1.blocks.iter().zip(fn2.blocks.iter()) { - assert_block_equal(b1, b2, &mut map); + if !is_block_equal(b1, b2, &mut map) { + return false; + } } + true } -fn assert_block_equal(b1: &Block, b2: &Block, map: &mut HashMap) { - assert_option_equal(&b1.label, &b2.label, map, assert_instr_equal); +fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap) -> bool { + if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) { + return false; + } for (inst1, inst2) in b1.instructions.iter().zip(b2.instructions.iter()) { - assert_instr_equal(inst1, inst2, map); + if !is_instr_equal(inst1, inst2, map) { + return false; + } } + true } -fn assert_instr_equal(instr1: &Instruction, instr2: &Instruction, map: &mut HashMap) { - assert_option_equal( - &instr1.result_type, - &instr2.result_type, - map, - assert_word_equal, - ); - assert_option_equal(&instr1.result_id, &instr2.result_id, map, assert_word_equal); +fn is_instr_equal( + instr1: &Instruction, + instr2: &Instruction, + map: &mut HashMap, +) -> bool { + if !is_option_equal(&instr1.result_type, &instr2.result_type, map, is_word_equal) { + return false; + } + if !is_option_equal(&instr1.result_id, &instr2.result_id, map, is_word_equal) { + return false; + } for (o1, o2) in instr1.operands.iter().zip(instr2.operands.iter()) { match (o1, o2) { (Operand::IdMemorySemantics(w1), Operand::IdMemorySemantics(w2)) => { - assert_word_equal(w1, w2, map) + if !is_word_equal(w1, w2, map) { + return false; + } + } + (Operand::IdScope(w1), Operand::IdScope(w2)) => { + if !is_word_equal(w1, w2, map) { + return false; + } + } + (Operand::IdRef(w1), Operand::IdRef(w2)) => { + if !is_word_equal(w1, w2, map) { + return false; + } + } + (o1, o2) => { + if o1 != o2 { + return false; + } } - (Operand::IdScope(w1), Operand::IdScope(w2)) => assert_word_equal(w1, w2, map), - (Operand::IdRef(w1), Operand::IdRef(w2)) => assert_word_equal(w1, w2, map), - (o1, o2) => assert_eq!(o1, o2), } } + true } -fn assert_word_equal(w1: &Word, w2: &Word, map: &mut HashMap) { +fn is_word_equal(w1: &Word, w2: &Word, map: &mut HashMap) -> bool { match map.entry(*w1) { std::collections::hash_map::Entry::Occupied(entry) => { - assert_eq!(entry.get(), w2); + if entry.get() != w2 { + return false; + } } std::collections::hash_map::Entry::Vacant(entry) => { entry.insert(*w2); } } + true } -fn assert_option_equal)>( +fn is_option_equal) -> bool>( o1: &Option, o2: &Option, map: &mut HashMap, f: F, -) { +) -> bool { match (o1, o2) { (Some(t1), Some(t2)) => f(t1, t2, map), - (None, None) => (), + (None, None) => true, _ => panic!(), } } diff --git a/ptx/src/test/spirv_run/mov.ptx b/ptx/src/test/spirv_run/mov.ptx new file mode 100644 index 0000000..5ca61f1 --- /dev/null +++ b/ptx/src/test/spirv_run/mov.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mov( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + mov.u64 temp2, temp; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/mov.spvtxt b/ptx/src/test/spirv_run/mov.spvtxt new file mode 100644 index 0000000..3f11b26 --- /dev/null +++ b/ptx/src/test/spirv_run/mov.spvtxt @@ -0,0 +1,26 @@ +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int64 +OpCapability Int8 +%1 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %5 "mov" +%2 = OpTypeVoid +%3 = OpTypeInt 64 0 +%4 = OpTypeFunction %2 %3 %3 +%19 = OpTypePointer Generic %3 +%5 = OpFunction %2 None %4 +%6 = OpFunctionParameter %3 +%7 = OpFunctionParameter %3 +%18 = OpLabel +%13 = OpCopyObject %3 %6 +%14 = OpCopyObject %3 %7 +%15 = OpConvertUToPtr %19 %13 +%16 = OpLoad %3 %15 +%100 = OpCopyObject %3 %16 +%17 = OpConvertUToPtr %19 %14 +OpStore %17 %100 +OpReturn +OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 1942586..ee28bb7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -10,9 +10,19 @@ use rspirv::binary::Assemble; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { Base(ast::ScalarType), + Extended(ast::ExtendedScalarType), Pointer(ast::ScalarType, spirv::StorageClass), } +impl From for SpirvType { + fn from(t: ast::Type) -> Self { + match t { + ast::Type::Scalar(t) => SpirvType::Base(t), + ast::Type::ExtendedScalar(t) => SpirvType::Extended(t) + } + } +} + struct TypeWordMap { void: spirv::Word, complex: HashMap, @@ -50,9 +60,20 @@ impl TypeWordMap { }) } + fn get_or_add_extended(&mut self, b: &mut dr::Builder, t: ast::ExtendedScalarType) -> spirv::Word { + *self + .complex + .entry(SpirvType::Extended(t)) + .or_insert_with(|| match t { + ast::ExtendedScalarType::Pred => b.type_bool(), + ast::ExtendedScalarType::F16x2 => todo!(), + }) + } + fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { match t { SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar), + SpirvType::Extended(t) => self.get_or_add_extended(b, t), SpirvType::Pointer(scalar, storage) => { let base = self.get_or_add_scalar(b, scalar); *self @@ -416,6 +437,14 @@ fn emit_function_body_ops( } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, + ast::Instruction::Mov(mov, arg) => { + let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); + let src = match arg.src { + ast::MovOperand::Op(ast::Operand::Reg(id)) => id, + _ => todo!(), + }; + builder.copy_object(result_type, Some(arg.dst), src)?; + } _ => todo!(), }, } @@ -1190,7 +1219,8 @@ impl ast::Instruction { ast::Instruction::Ret(_) => None, ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), - _ => todo!(), + ast::Instruction::Mov(mov, _) => Some(mov.typ), + _ => todo!() } } }