diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 5df6323..4f120de 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -708,10 +708,10 @@ impl<'a> Kernel<'a> { Self(x, PhantomData) } - pub fn new(module: &'a Module, name: &CStr) -> Result { + pub fn new_resident(module: &'a Module, name: &CStr) -> Result { let desc = sys::ze_kernel_desc_t { version: sys::ze_kernel_desc_version_t::ZE_KERNEL_DESC_VERSION_CURRENT, - flags: sys::ze_kernel_flag_t::ZE_KERNEL_FLAG_NONE, + flags: sys::ze_kernel_flag_t::ZE_KERNEL_FLAG_FORCE_RESIDENCY, pKernelName: name.as_ptr() as *const _, }; let mut result = ptr::null_mut(); @@ -719,6 +719,21 @@ impl<'a> Kernel<'a> { Ok(Self(result, PhantomData)) } + pub fn set_attribute_bool( + &mut self, + attr: sys::ze_kernel_attribute_t, + value: bool, + ) -> Result<()> { + let ze_bool: sys::ze_bool_t = if value { 1 } else { 0 }; + check!(sys::zeKernelSetAttribute( + self.0, + attr, + mem::size_of::() as u32, + &ze_bool as *const _ as *const _ + )); + Ok(()) + } + pub fn set_arg_buffer>>( &self, index: u32, diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 32e46ce..3abcae7 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -100,7 +100,11 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let dev = devices.drain(0..1).next().unwrap(); let queue = ze::CommandQueue::new(&dev)?; let module = ze::Module::new_spirv(&dev, byte_il, None)?; - let kernel = ze::Kernel::new(&module, name)?; + let mut kernel = ze::Kernel::new_resident(&module, name)?; + kernel.set_attribute_bool( + ze::sys::ze_kernel_attribute_t::ZE_KERNEL_ATTR_INDIRECT_DEVICE_ACCESS, + true, + )?; let mut inp_b = ze::DeviceBuffer::::new(&drv, &dev, input.len())?; let mut out_b = ze::DeviceBuffer::::new(&drv, &dev, output.len())?; let inp_b_ptr_mut: ze::BufferPtrMut = (&mut inp_b).into(); diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 233c67f..c399f0d 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -326,7 +326,8 @@ fn expand_arguments( for s in func { match s { Statement::Instruction(inst) => { - let new_inst = normalize_insert_instruction(&mut result, id_def, inst); + let mut visitor = FlattenArguments::new(&mut result, id_def); + let new_inst = inst.map(&mut visitor); result.push(Statement::Instruction(new_inst)); } Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)), @@ -340,170 +341,56 @@ fn expand_arguments( result } -#[must_use] -fn normalize_insert_instruction( - func: &mut Vec, - id_def: &mut NumericIdResolver, - instr: ast::Instruction, -) -> ast::Instruction { - match instr { - ast::Instruction::Ld(d, a) => { - let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a); - ast::Instruction::Ld(d, arg) - } - ast::Instruction::Mov(d, a) => { - let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a); - ast::Instruction::Mov(d, arg) - } - ast::Instruction::Mul(d, a) => { - let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); - ast::Instruction::Mul(d, arg) - } - ast::Instruction::Add(d, a) => { - let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); - ast::Instruction::Add(d, arg) - } - ast::Instruction::Setp(d, a) => { - let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a); - ast::Instruction::Setp(d, arg) - } - ast::Instruction::SetpBool(d, a) => { - let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a); - ast::Instruction::SetpBool(d, arg) - } - ast::Instruction::Not(d, a) => { - let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); - ast::Instruction::Not(d, arg) - } - ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, ast::Arg1 { src: a.src }), - ast::Instruction::Cvt(d, a) => { - let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); - ast::Instruction::Cvt(d, arg) - } - ast::Instruction::Shl(d, a) => { - let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a); - ast::Instruction::Shl(d, arg) - } - ast::Instruction::St(d, a) => { - let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a); - ast::Instruction::St(d, arg) - } - ast::Instruction::Ret(d) => ast::Instruction::Ret(d), +struct FlattenArguments<'a> { + func: &'a mut Vec, + id_def: &'a mut NumericIdResolver, +} + +impl<'a> FlattenArguments<'a> { + fn new(func: &'a mut Vec, id_def: &'a mut NumericIdResolver) -> Self { + FlattenArguments { func, id_def } } } -fn normalize_expand_arg2( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - a: ast::Arg2, -) -> ast::Arg2 { - ast::Arg2 { - dst: a.dst, - src: normalize_expand_operand(func, id_def, inst_type, a.src), +impl<'a> ArgumentMapVisitor for FlattenArguments<'a> { + fn dst_variable(&mut self, x: spirv::Word, _: Option) -> spirv::Word { + x } -} -fn normalize_expand_arg2mov( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - a: ast::Arg2Mov, -) -> ast::Arg2Mov { - ast::Arg2Mov { - dst: a.dst, - src: normalize_expand_mov_operand(func, id_def, inst_type, a.src), - } -} - -fn normalize_expand_arg2st( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - a: ast::Arg2St, -) -> ast::Arg2St { - ast::Arg2St { - src1: normalize_expand_operand(func, id_def, inst_type, a.src1), - src2: normalize_expand_operand(func, id_def, inst_type, a.src2), - } -} - -fn normalize_expand_arg3( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - a: ast::Arg3, -) -> ast::Arg3 { - ast::Arg3 { - dst: a.dst, - src1: normalize_expand_operand(func, id_def, inst_type, a.src1), - src2: normalize_expand_operand(func, id_def, inst_type, a.src2), - } -} - -fn normalize_expand_arg4( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - a: ast::Arg4, -) -> ast::Arg4 { - ast::Arg4 { - dst1: a.dst1, - dst2: a.dst2, - src1: normalize_expand_operand(func, id_def, inst_type, a.src1), - src2: normalize_expand_operand(func, id_def, inst_type, a.src2), - } -} - -fn normalize_expand_arg5( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - a: ast::Arg5, -) -> ast::Arg5 { - ast::Arg5 { - dst1: a.dst1, - dst2: a.dst2, - src1: normalize_expand_operand(func, id_def, inst_type, a.src1), - src2: normalize_expand_operand(func, id_def, inst_type, a.src2), - src3: normalize_expand_operand(func, id_def, inst_type, a.src3), - } -} - -fn normalize_expand_operand( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - opr: ast::Operand, -) -> spirv::Word { - match opr { - ast::Operand::Reg(r) => r, - ast::Operand::Imm(x) => { - if let Some(typ) = inst_type() { - let id = id_def.new_id(Some(ast::Type::Scalar(typ))); - func.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: typ, - value: x, - })); - id - } else { - todo!() + fn src_operand(&mut self, op: ast::Operand, t: Option) -> spirv::Word { + match op { + ast::Operand::Reg(r) => r, + ast::Operand::Imm(x) => { + if let Some(typ) = t { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { + todo!() + }; + let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value: x, + })); + id + } else { + todo!() + } } + _ => todo!(), } - _ => todo!(), } -} -fn normalize_expand_mov_operand( - func: &mut Vec, - id_def: &mut NumericIdResolver, - inst_type: &impl Fn() -> Option, - opr: ast::MovOperand, -) -> spirv::Word { - match opr { - ast::MovOperand::Op(opr) => normalize_expand_operand(func, id_def, inst_type, opr), - _ => todo!(), + fn src_mov_operand( + &mut self, + op: ast::MovOperand, + t: Option, + ) -> spirv::Word { + match op { + ast::MovOperand::Op(opr) => self.src_operand(opr, t), + ast::MovOperand::Vec(_, _) => todo!(), + } } } @@ -1023,53 +910,6 @@ trait ArgumentMapVisitor { fn src_mov_operand(&mut self, o: T::MovOperand, typ: Option) -> U::MovOperand; } -struct FlattenArguments<'a> { - func: &'a mut Vec, - id_def: &'a mut NumericIdResolver, -} - -impl<'a> ArgumentMapVisitor for FlattenArguments<'a> { - fn dst_variable(&mut self, x: spirv::Word, _: Option) -> spirv::Word { - x - } - - fn src_operand(&mut self, op: ast::Operand, t: Option) -> spirv::Word { - match op { - ast::Operand::Reg(r) => r, - ast::Operand::Imm(x) => { - if let Some(typ) = t { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: scalar_t, - value: x, - })); - id - } else { - todo!() - } - } - _ => todo!(), - } - } - - fn src_mov_operand( - &mut self, - op: ast::MovOperand, - t: Option, - ) -> spirv::Word { - match op { - ast::MovOperand::Op(opr) => self.src_operand(opr, t), - ast::MovOperand::Vec(_, _) => todo!(), - } - } -} - impl ArgumentMapVisitor for T where T: FnMut(spirv::Word, bool, Option) -> spirv::Word, @@ -1118,7 +958,7 @@ where } impl ast::Instruction { - fn map_variable_new>( + fn map>( self, visitor: &mut V, ) -> ast::Instruction { @@ -1165,7 +1005,7 @@ impl ast::Instruction { self, f: &mut F, ) -> ast::Instruction { - self.map_variable_new(f) + self.map(f) } } @@ -1213,14 +1053,14 @@ fn reduced_visitor<'a>( impl ast::Instruction { fn visit_variable spirv::Word>(self, f: &mut F) -> Self { let mut visitor = reduced_visitor(f); - self.map_variable_new(&mut visitor) + self.map(&mut visitor) } fn visit_variable_extended) -> spirv::Word>( self, f: &mut F, ) -> Self { - self.map_variable_new(f) + self.map(f) } fn jump_target(&self) -> Option { @@ -1319,7 +1159,7 @@ impl<'a> ast::Instruction> { self, f: &mut F, ) -> ast::Instruction { - self.map_variable_new(f) + self.map(f) } } @@ -1458,15 +1298,6 @@ enum ScalarKind { Float, } -impl ast::Type { - fn try_as_scalar(self) -> Option { - match self { - ast::Type::Scalar(s) => Some(s), - ast::Type::ExtendedScalar(_) => None, - } - } -} - impl ast::ScalarType { fn width(self) -> u8 { match self {