diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 15125b0..a2a60dc 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -2,8 +2,8 @@ use super::*; pub(super) fn run<'a, 'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { directives .into_iter() .map(|directive| run_directive(resolver, directive)) @@ -12,8 +12,8 @@ pub(super) fn run<'a, 'input>( fn run_directive<'input>( resolver: &mut GlobalStringIdentResolver2, - directive: Directive2<'input, ast::Instruction, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { Ok(match directive { var @ Directive2::Variable(..) => var, Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), @@ -22,13 +22,13 @@ fn run_directive<'input>( fn run_method<'input>( resolver: &mut GlobalStringIdentResolver2, - mut method: Function2<'input, ast::Instruction, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { + mut method: Function2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { let is_declaration = method.body.is_none(); let mut body = Vec::new(); let mut remap_returns = Vec::new(); - if !method.func_decl.name.is_kernel() { - for arg in method.func_decl.return_arguments.iter_mut() { + if !method.is_kernel { + for arg in method.return_arguments.iter_mut() { match arg.state_space { ptx_parser::StateSpace::Param => { arg.state_space = ptx_parser::StateSpace::Reg; @@ -51,7 +51,7 @@ fn run_method<'input>( _ => return Err(error_unreachable()), } } - for arg in method.func_decl.input_arguments.iter_mut() { + for arg in method.input_arguments.iter_mut() { match arg.state_space { ptx_parser::StateSpace::Param => { arg.state_space = ptx_parser::StateSpace::Reg; @@ -96,12 +96,14 @@ fn run_method<'input>( }) .transpose()?; Ok(Function2 { - func_decl: method.func_decl, - globals: method.globals, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, + is_kernel: method.is_kernel, }) } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 2d1269d..255aee0 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -168,7 +168,7 @@ impl Deref for MemoryBuffer { pub(super) fn run<'input>( id_defs: GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, + directives: Vec, SpirvWord>>, ) -> Result { let context = Context::new(); let module = Module::new(&context, LLVM_UNNAMED); @@ -218,24 +218,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn emit_method( &mut self, - method: Function2<'input, ast::Instruction, SpirvWord>, + method: Function2, SpirvWord>, ) -> Result<(), TranslateError> { - let func_decl = method.func_decl; let name = method .import_as .as_deref() - .or_else(|| match func_decl.name { - ast::MethodName::Kernel(name) => Some(name), - ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(), - }) + .or_else(|| self.id_defs.ident_map[&method.name].name.as_deref()) .ok_or_else(|| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?; let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; if fn_ == ptr::null_mut() { let fn_type = get_function_type( self.context, - func_decl.return_arguments.iter().map(|v| &v.v_type), - func_decl + method.return_arguments.iter().map(|v| &v.v_type), + method .input_arguments .iter() .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), @@ -245,15 +241,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { self.emit_fn_attribute(fn_, "uniform-work-group-size", "true"); self.emit_fn_attribute(fn_, "no-trapping-math", "true"); } - if let ast::MethodName::Func(name) = func_decl.name { - self.resolver.register(name, fn_); + if !method.is_kernel { + self.resolver.register(method.name, fn_); } - for (i, param) in func_decl.input_arguments.iter().enumerate() { + for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); - if func_decl.name.is_kernel() { + if method.is_kernel { let attr_kind = unsafe { LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) }; @@ -267,7 +263,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; } } - let call_conv = if func_decl.name.is_kernel() { + let call_conv = if method.is_kernel { Self::kernel_call_convention() } else { Self::func_call_convention() @@ -282,7 +278,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); - for var in func_decl.return_arguments { + for var in method.return_arguments { method_emitter.emit_variable(var)?; } for statement in statements.iter() { @@ -1558,7 +1554,7 @@ impl<'a> MethodEmitContext<'a> { return self.emit_cvt_float_to_int( data.from, data.to, - integer_rounding.unwrap_or(ast::RoundingMode::NearestEven), + integer_rounding, arguments, Some(LLVMBuildFPToSI), ) diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index f2de786..07806f9 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -2,8 +2,8 @@ use super::*; pub(super) fn run<'a, 'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, +) -> Result, SpirvWord>>, TranslateError> { directives .into_iter() .map(|directive| run_directive(resolver, directive)) @@ -13,11 +13,10 @@ pub(super) fn run<'a, 'input>( fn run_directive<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, directive: Directive2< - 'input, ast::Instruction>, ast::ParsedOperand, >, -) -> Result, SpirvWord>, TranslateError> { +) -> Result, SpirvWord>, TranslateError> { Ok(match directive { Directive2::Variable(linking, var) => Directive2::Variable(linking, var), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), @@ -27,11 +26,10 @@ fn run_directive<'input>( fn run_method<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, method: Function2< - 'input, ast::Instruction>, ast::ParsedOperand, >, -) -> Result, SpirvWord>, TranslateError> { +) -> Result, SpirvWord>, TranslateError> { let body = method .body .map(|statements| { @@ -43,12 +41,14 @@ fn run_method<'input>( }) .transpose()?; Ok(Function2 { - func_decl: method.func_decl, - globals: method.globals, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, + is_kernel: method.is_kernel, }) } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 8c3b794..3323305 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -1,30 +1,29 @@ use super::*; pub(super) fn run<'a, 'input>( - resolver: &mut GlobalStringIdentResolver2<'input>, + resolver: &'a mut GlobalStringIdentResolver2<'input>, special_registers: &'a SpecialRegistersMap2, - directives: Vec>, -) -> Result>, TranslateError> { - let declarations = SpecialRegistersMap2::generate_declarations(resolver); - let mut result = Vec::with_capacity(declarations.len() + directives.len()); + directives: Vec, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len()); let mut sreg_to_function = - FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default()); - for (sreg, declaration) in declarations { - let name = if let ast::MethodName::Func(name) = declaration.name { - name - } else { - return Err(error_unreachable()); - }; - result.push(UnconditionalDirective::Method(UnconditionalFunction { - func_decl: declaration, - globals: Vec::new(), - body: None, - import_as: None, - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - })); - sreg_to_function.insert(sreg, name); - } + FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default()); + SpecialRegistersMap2::foreach_declaration( + resolver, + |sreg, (return_arguments, name, input_arguments)| { + result.push(UnconditionalDirective::Method(UnconditionalFunction { + return_arguments, + name, + input_arguments, + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + is_kernel: false, + })); + sreg_to_function.insert(sreg, name); + }, + ); let mut visitor = SpecialRegisterResolver { resolver, special_registers, @@ -39,8 +38,8 @@ pub(super) fn run<'a, 'input>( fn run_directive<'a, 'input>( visitor: &mut SpecialRegisterResolver<'a, 'input>, - directive: UnconditionalDirective<'input>, -) -> Result, TranslateError> { + directive: UnconditionalDirective, +) -> Result { Ok(match directive { var @ Directive2::Variable(..) => var, Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), @@ -49,8 +48,8 @@ fn run_directive<'a, 'input>( fn run_method<'a, 'input>( visitor: &mut SpecialRegisterResolver<'a, 'input>, - method: UnconditionalFunction<'input>, -) -> Result, TranslateError> { + method: UnconditionalFunction, +) -> Result { let body = method .body .map(|statements| { @@ -62,12 +61,14 @@ fn run_method<'a, 'input>( }) .transpose()?; Ok(Function2 { - func_decl: method.func_decl, - globals: method.globals, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, + is_kernel: method.is_kernel, }) } diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 718c052..654a7e9 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -1,8 +1,8 @@ use super::*; pub(super) fn run<'input>( - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { let mut result = Vec::with_capacity(directives.len()); for mut directive in directives.into_iter() { run_directive(&mut result, &mut directive)?; @@ -12,8 +12,8 @@ pub(super) fn run<'input>( } fn run_directive<'input>( - result: &mut Vec, SpirvWord>>, - directive: &mut Directive2<'input, ptx_parser::Instruction, SpirvWord>, + result: &mut Vec, SpirvWord>>, + directive: &mut Directive2, SpirvWord>, ) -> Result<(), TranslateError> { match directive { Directive2::Variable(..) => {} @@ -23,8 +23,8 @@ fn run_directive<'input>( } fn run_function<'input>( - result: &mut Vec, SpirvWord>>, - function: &mut Function2<'input, ptx_parser::Instruction, SpirvWord>, + result: &mut Vec, SpirvWord>>, + function: &mut Function2, SpirvWord>, ) { function.body = function.body.take().map(|statements| { statements diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 702f733..014c49b 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -11,8 +11,8 @@ use super::*; // pass, so we do nothing there pub(super) fn run<'a, 'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { directives .into_iter() .map(|directive| run_directive(resolver, directive)) @@ -21,8 +21,8 @@ pub(super) fn run<'a, 'input>( fn run_directive<'a, 'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directive: Directive2<'input, ast::Instruction, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { Ok(match directive { var @ Directive2::Variable(..) => var, Directive2::Method(method) => { @@ -34,12 +34,11 @@ fn run_directive<'a, 'input>( fn run_method<'a, 'input>( mut visitor: InsertMemSSAVisitor<'a, 'input>, - method: Function2<'input, ast::Instruction, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { - let mut func_decl = method.func_decl; - let is_kernel = func_decl.name.is_kernel(); + mut method: Function2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + let is_kernel = method.is_kernel; if is_kernel { - for arg in func_decl.input_arguments.iter_mut() { + for arg in method.input_arguments.iter_mut() { let old_name = arg.name; let old_space = arg.state_space; let new_space = ast::StateSpace::ParamEntry; @@ -51,10 +50,10 @@ fn run_method<'a, 'input>( arg.state_space = new_space; } }; - for arg in func_decl.return_arguments.iter_mut() { + for arg in method.return_arguments.iter_mut() { visitor.visit_variable(arg)?; } - let return_arguments = &func_decl.return_arguments[..]; + let return_arguments = &method.return_arguments[..]; let body = method .body .map(move |statements| { @@ -66,12 +65,14 @@ fn run_method<'a, 'input>( }) .transpose()?; Ok(Function2 { - func_decl: func_decl, - globals: method.globals, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, + is_kernel: method.is_kernel, }) } diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control.rs index 4fc4136..2d16015 100644 --- a/ptx/src/pass/insert_ftz_control.rs +++ b/ptx/src/pass/insert_ftz_control.rs @@ -1,3 +1,5 @@ +use crate::pass::error_unreachable; + use super::BrachCondition; use super::Directive2; use super::Function2; @@ -17,6 +19,178 @@ use rustc_hash::FxHashSet; use std::hash::Hash; use std::iter; +#[derive(Default)] +enum DenormalMode { + #[default] + FlushToZero, + Preserve, +} + +impl DenormalMode { + fn from_ftz(ftz: bool) -> Self { + if ftz { + DenormalMode::FlushToZero + } else { + DenormalMode::Preserve + } + } +} + +#[derive(Default)] +enum RoundingMode { + #[default] + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} + +impl RoundingMode { + fn to_ast(self) -> ast::RoundingMode { + match self { + RoundingMode::NearestEven => ast::RoundingMode::NearestEven, + RoundingMode::Zero => ast::RoundingMode::Zero, + RoundingMode::NegativeInf => ast::RoundingMode::NegativeInf, + RoundingMode::PositiveInf => ast::RoundingMode::PositiveInf, + } + } + + fn from_ast(rnd: ast::RoundingMode) -> Self { + match rnd { + ast::RoundingMode::NearestEven => RoundingMode::NearestEven, + ast::RoundingMode::Zero => RoundingMode::Zero, + ast::RoundingMode::NegativeInf => RoundingMode::NegativeInf, + ast::RoundingMode::PositiveInf => RoundingMode::PositiveInf, + } + } +} + +struct InstructionModes { + denormal_f32: Option, + denormal_f16_f64: Option, + rounding_f32: Option, + rounding_f16_f64: Option, +} + +impl InstructionModes { + fn none() -> Self { + Self { + denormal_f32: None, + denormal_f16_f64: None, + rounding_f32: None, + rounding_f16_f64: None, + } + } + + fn new( + type_: ast::ScalarType, + denormal: Option, + rounding: Option, + ) -> Self { + if type_ != ast::ScalarType::F32 { + Self { + denormal_f16_f64: denormal, + rounding_f16_f64: rounding, + ..Self::none() + } + } else { + Self { + denormal_f32: denormal, + rounding_f32: rounding, + ..Self::none() + } + } + } + + fn mixed_ftz_f32( + type_: ast::ScalarType, + denormal: Option, + rounding: Option, + ) -> Self { + if type_ != ast::ScalarType::F32 { + Self { + denormal_f16_f64: denormal, + rounding_f32: rounding, + ..Self::none() + } + } else { + Self { + denormal_f32: denormal, + rounding_f32: rounding, + ..Self::none() + } + } + } + + fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes { + let denormal = arith.flush_to_zero.map(DenormalMode::from_ftz); + let rounding = Some(RoundingMode::from_ast(arith.rounding)); + InstructionModes::new(arith.type_, denormal, rounding) + } + + fn from_ftz(type_: ast::ScalarType, ftz: Option) -> Self { + Self::new(type_, ftz.map(DenormalMode::from_ftz), None) + } + + fn from_ftz_f32(ftz: bool) -> Self { + Self::new( + ast::ScalarType::F32, + Some(DenormalMode::from_ftz(ftz)), + None, + ) + } + + fn from_rcp(data: ast::RcpData) -> InstructionModes { + let rounding = match data.kind { + ast::RcpKind::Approx => None, + ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)), + }; + let denormal = data.flush_to_zero.map(DenormalMode::from_ftz); + InstructionModes::new(data.type_, denormal, rounding) + } + + fn from_cvt(cvt: &ast::CvtDetails) -> InstructionModes { + match cvt.mode { + ast::CvtMode::ZeroExtend + | ast::CvtMode::SignExtend + | ast::CvtMode::Truncate + | ast::CvtMode::Bitcast + | ast::CvtMode::SaturateUnsignedToSigned + | ast::CvtMode::SaturateSignedToUnsigned => Self::none(), + ast::CvtMode::FPExtend { flush_to_zero } => { + Self::from_ftz(ast::ScalarType::F32, flush_to_zero) + } + ast::CvtMode::FPTruncate { + rounding, + flush_to_zero, + } + | ast::CvtMode::FPRound { + integer_rounding: rounding, + flush_to_zero, + } => Self::mixed_ftz_f32( + cvt.to, + flush_to_zero.map(DenormalMode::from_ftz), + Some(RoundingMode::from_ast(rounding)), + ), + ast::CvtMode::SignedFromFP { + flush_to_zero, + rounding, + } + | ast::CvtMode::UnsignedFromFP { + flush_to_zero, + rounding, + } => Self::new( + cvt.from, + flush_to_zero.map(DenormalMode::from_ftz), + Some(RoundingMode::from_ast(rounding)), + ), + ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => { + Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd))) + } + } + } +} + struct ControlFlowGraph { entry_points: FxHashMap, basic_blocks: FxHashMap, @@ -74,19 +248,40 @@ struct Node { pub(crate) fn run<'input>( flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, - directives: Vec, super::SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, super::SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { let mut cfg = ControlFlowGraph::::new(); - let mut node_idx_to_name = FxHashMap::, SpirvWord>::default(); for directive in directives.iter() { match directive { super::Directive2::Method(Function2 { - func_decl: ast::MethodDeclaration { name, .. }, - body, + name, + body: Some(body), .. }) => { + let mut basic_block = Some(cfg.add_entry_basic_block(*name)); for statement in body.iter() { - todo!() + match statement { + Statement::Instruction(ast::Instruction::Bra { arguments }) => { + let bb_index = basic_block.ok_or_else(error_unreachable)?; + cfg.add_jump(bb_index, arguments.src); + basic_block = None; + } + Statement::Label(label) => { + basic_block = Some(cfg.get_or_add_basic_block(*label)); + } + Statement::Conditional(BrachCondition { + if_true, if_false, .. + }) => { + let bb_index = basic_block.ok_or_else(error_unreachable)?; + cfg.add_jump(bb_index, *if_true); + cfg.add_jump(bb_index, *if_false); + basic_block = None; + } + Statement::Instruction(instruction) => { + let modes = get_modes(instruction); + } + _ => continue, + } } } _ => continue, @@ -280,6 +475,169 @@ impl UniqueVec { } } +fn get_modes(inst: &ast::Instruction) -> InstructionModes { + match inst { + // TODO: review it when implementing virtual calls + ast::Instruction::Call { .. } + | ast::Instruction::Mov { .. } + | ast::Instruction::Ld { .. } + | ast::Instruction::St { .. } + | ast::Instruction::PrmtSlow { .. } + | ast::Instruction::Prmt { .. } + | ast::Instruction::Activemask { .. } + | ast::Instruction::Membar { .. } + | ast::Instruction::Trap {} + | ast::Instruction::Not { .. } + | ast::Instruction::Or { .. } + | ast::Instruction::And { .. } + | ast::Instruction::Bra { .. } + | ast::Instruction::Clz { .. } + | ast::Instruction::Brev { .. } + | ast::Instruction::Popc { .. } + | ast::Instruction::Xor { .. } + | ast::Instruction::Rem { .. } + | ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Shr { .. } + | ast::Instruction::Shl { .. } + | ast::Instruction::Selp { .. } + | ast::Instruction::Ret { .. } + | ast::Instruction::Bar { .. } + | ast::Instruction::Cvta { .. } + | ast::Instruction::Atom { .. } + | ast::Instruction::AtomCas { .. } => InstructionModes::none(), + ast::Instruction::Add { + data: ast::ArithDetails::Integer(_), + .. + } + | ast::Instruction::Sub { + data: ast::ArithDetails::Integer(..), + .. + } + | ast::Instruction::Mul { + data: ast::MulDetails::Integer { .. }, + .. + } + | ast::Instruction::Mad { + data: ast::MadDetails::Integer { .. }, + .. + } + | ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..), + .. + } + | ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..), + .. + } + | ast::Instruction::Div { + data: ast::DivDetails::Signed(..) | ast::DivDetails::Unsigned(..), + .. + } => InstructionModes::none(), + ast::Instruction::Fma { data, .. } + | ast::Instruction::Sub { + data: ast::ArithDetails::Float(data), + .. + } + | ast::Instruction::Mul { + data: ast::MulDetails::Float(data), + .. + } + | ast::Instruction::Mad { + data: ast::MadDetails::Float(data), + .. + } + | ast::Instruction::Add { + data: ast::ArithDetails::Float(data), + .. + } => InstructionModes::from_arith_float(data), + ast::Instruction::Setp { + data: + ast::SetpData { + type_, + flush_to_zero, + .. + }, + .. + } + | ast::Instruction::SetpBool { + data: + ast::SetpBoolData { + base: + ast::SetpData { + type_, + flush_to_zero, + .. + }, + .. + }, + .. + } + | ast::Instruction::Neg { + data: ast::TypeFtz { + type_, + flush_to_zero, + }, + .. + } + | ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_, + flush_to_zero, + }, + .. + } + | ast::Instruction::Rsqrt { + data: ast::TypeFtz { + type_, + flush_to_zero, + }, + .. + } + | ast::Instruction::Abs { + data: ast::TypeFtz { + type_, + flush_to_zero, + }, + .. + } + | ast::Instruction::Min { + data: + ast::MinMaxDetails::Float(ast::MinMaxFloat { + type_, + flush_to_zero, + .. + }), + .. + } + | ast::Instruction::Max { + data: + ast::MinMaxDetails::Float(ast::MinMaxFloat { + type_, + flush_to_zero, + .. + }), + .. + } + | ast::Instruction::Div { + data: + ast::DivDetails::Float(ast::DivFloatDetails { + type_, + flush_to_zero, + .. + }), + .. + } => InstructionModes::from_ftz(*type_, *flush_to_zero), + ast::Instruction::Sin { data, .. } + | ast::Instruction::Cos { data, .. } + | ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero), + ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => { + InstructionModes::from_rcp(*data) + } + ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index 4f738f5..9f8b01c 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -19,8 +19,8 @@ use ptx_parser as ast; */ pub(super) fn run<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { directives .into_iter() .map(|directive| run_directive(resolver, directive)) @@ -29,8 +29,8 @@ pub(super) fn run<'input>( fn run_directive<'a, 'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directive: Directive2<'input, ast::Instruction, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { Ok(match directive { var @ Directive2::Variable(..) => var, Directive2::Method(mut method) => { diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 8cc9926..442b1e7 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -44,7 +44,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result, >; -enum Directive2<'input, Instruction, Operand: ast::Operand> { +enum Directive2 { Variable(ast::LinkingDirective, ast::Variable), - Method(Function2<'input, Instruction, Operand>), + Method(Function2), } -struct Function2<'input, Instruction, Operand: ast::Operand> { - pub func_decl: ast::MethodDeclaration<'input, SpirvWord>, - pub globals: Vec>, +struct Function2 { + pub return_arguments: Vec>, + pub name: Operand::Ident, + pub input_arguments: Vec>, pub body: Option>>, + is_kernel: bool, import_as: Option, tuning: Vec, linkage: ast::LinkingDirective, } -type NormalizedDirective2<'input> = Directive2< - 'input, +type NormalizedDirective2 = Directive2< ( Option>, ast::Instruction>, @@ -582,8 +583,7 @@ type NormalizedDirective2<'input> = Directive2< ast::ParsedOperand, >; -type NormalizedFunction2<'input> = Function2< - 'input, +type NormalizedFunction2 = Function2< ( Option>, ast::Instruction>, @@ -591,17 +591,11 @@ type NormalizedFunction2<'input> = Function2< ast::ParsedOperand, >; -type UnconditionalDirective<'input> = Directive2< - 'input, - ast::Instruction>, - ast::ParsedOperand, ->; +type UnconditionalDirective = + Directive2>, ast::ParsedOperand>; -type UnconditionalFunction<'input> = Function2< - 'input, - ast::Instruction>, - ast::ParsedOperand, ->; +type UnconditionalFunction = + Function2>, ast::ParsedOperand>; struct GlobalStringIdentResolver2<'input> { pub(crate) current_id: SpirvWord, @@ -807,47 +801,45 @@ impl SpecialRegistersMap2 { self.id_to_reg.get(&id).copied() } - fn generate_declarations<'a, 'input>( + fn len() -> usize { + PtxSpecialRegister::iter().len() + } + + fn foreach_declaration<'a, 'input>( resolver: &'a mut GlobalStringIdentResolver2<'input>, - ) -> impl ExactSizeIterator< - Item = ( + mut fn_: impl FnMut( PtxSpecialRegister, - ast::MethodDeclaration<'input, SpirvWord>, + ( + Vec>, + SpirvWord, + Vec>, + ), ), - > + 'a { - PtxSpecialRegister::iter().map(|sreg| { + ) { + for sreg in PtxSpecialRegister::iter() { let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); - let name = - ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); + let name = resolver.register_named(Cow::Owned(external_fn_name), None); let return_type = sreg.get_function_return_type(); let input_type = sreg.get_function_input_type(); - ( - sreg, - ast::MethodDeclaration { - return_arguments: vec![ast::Variable { - align: None, - v_type: return_type.into(), - state_space: ast::StateSpace::Reg, - name: resolver - .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }], - name: name, - input_arguments: input_type - .into_iter() - .map(|type_| ast::Variable { - align: None, - v_type: type_.into(), - state_space: ast::StateSpace::Reg, - name: resolver - .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }) - .collect::>(), - shared_mem: None, - }, - ) - }) + let return_arguments = vec![ast::Variable { + align: None, + v_type: return_type.into(), + state_space: ast::StateSpace::Reg, + name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }]; + let input_arguments = input_type + .into_iter() + .map(|type_| ast::Variable { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }) + .collect::>(); + fn_(sreg, (return_arguments, name, input_arguments)); + } } } diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index 5155886..4d94897 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -4,7 +4,7 @@ use ptx_parser as ast; pub(crate) fn run<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, directives: Vec>>, -) -> Result>, TranslateError> { +) -> Result, TranslateError> { resolver.start_scope(); let result = directives .into_iter() @@ -17,7 +17,7 @@ pub(crate) fn run<'input, 'b>( fn run_directive<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, -) -> Result, TranslateError> { +) -> Result { Ok(match directive { ast::Directive::Variable(linking, var) => { NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) @@ -32,15 +32,11 @@ fn run_method<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, linkage: ast::LinkingDirective, method: ast::Function<'input, &'input str, ast::Statement>>, -) -> Result, TranslateError> { - let name = match method.func_directive.name { - ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), - ast::MethodName::Func(text) => { - ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?) - } - }; +) -> Result { + let is_kernel = method.func_directive.name.is_kernel(); + let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?; resolver.start_scope(); - let func_decl = run_function_decl(resolver, method.func_directive, name)?; + let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?; let body = method .body .map(|statements| { @@ -51,20 +47,21 @@ fn run_method<'input, 'b>( .transpose()?; resolver.end_scope(); Ok(Function2 { - func_decl, - globals: Vec::new(), + return_arguments, + name, + input_arguments, body, import_as: None, tuning: method.tuning, linkage, + is_kernel, }) } fn run_function_decl<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, func_directive: ast::MethodDeclaration<'input, &'input str>, - name: ast::MethodName<'input, SpirvWord>, -) -> Result, TranslateError> { +) -> Result<(Vec>, Vec>), TranslateError> { assert!(func_directive.shared_mem.is_none()); let return_arguments = func_directive .return_arguments @@ -76,12 +73,7 @@ fn run_function_decl<'input, 'b>( .into_iter() .map(|var| run_variable(resolver, var)) .collect::, _>>()?; - Ok(ast::MethodDeclaration { - return_arguments, - name, - input_arguments, - shared_mem: None, - }) + Ok((return_arguments, input_arguments)) } fn run_variable<'input, 'b>( diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs index d91e23c..f8505b6 100644 --- a/ptx/src/pass/normalize_predicates2.rs +++ b/ptx/src/pass/normalize_predicates2.rs @@ -3,8 +3,8 @@ use ptx_parser as ast; pub(crate) fn run<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec>, -) -> Result>, TranslateError> { + directives: Vec, +) -> Result, TranslateError> { directives .into_iter() .map(|directive| run_directive(resolver, directive)) @@ -13,8 +13,8 @@ pub(crate) fn run<'input>( fn run_directive<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directive: NormalizedDirective2<'input>, -) -> Result, TranslateError> { + directive: NormalizedDirective2, +) -> Result { Ok(match directive { Directive2::Variable(linking, var) => Directive2::Variable(linking, var), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), @@ -23,8 +23,8 @@ fn run_directive<'input>( fn run_method<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - method: NormalizedFunction2<'input>, -) -> Result, TranslateError> { + method: NormalizedFunction2, +) -> Result { let body = method .body .map(|statements| { @@ -36,12 +36,14 @@ fn run_method<'input>( }) .transpose()?; Ok(Function2 { - func_decl: method.func_decl, - globals: method.globals, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, + is_kernel: method.is_kernel, }) } diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 668cc21..089d276 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -2,8 +2,8 @@ use super::*; pub(super) fn run<'input>( resolver: &mut GlobalStringIdentResolver2<'input>, - directives: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { let mut fn_declarations = FxHashMap::default(); let remapped_directives = directives .into_iter() @@ -13,17 +13,14 @@ pub(super) fn run<'input>( .into_iter() .map(|(_, (return_arguments, name, input_arguments))| { Directive2::Method(Function2 { - func_decl: ast::MethodDeclaration { - return_arguments, - name: ast::MethodName::Func(name), - input_arguments, - shared_mem: None, - }, - globals: Vec::new(), + return_arguments, + name: name, + input_arguments, body: None, import_as: None, tuning: Vec::new(), linkage: ast::LinkingDirective::EXTERN, + is_kernel: false, }) }) .collect::>(); @@ -41,8 +38,8 @@ fn run_directive<'input>( Vec>, ), >, - directive: Directive2<'input, ast::Instruction, SpirvWord>, -) -> Result, SpirvWord>, TranslateError> { + directive: Directive2, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { Ok(match directive { var @ Directive2::Variable(..) => var, Directive2::Method(mut method) => { diff --git a/ptx/src/pass/replace_known_functions.rs b/ptx/src/pass/replace_known_functions.rs index 56bb7e6..48f2b45 100644 --- a/ptx/src/pass/replace_known_functions.rs +++ b/ptx/src/pass/replace_known_functions.rs @@ -1,14 +1,15 @@ +use std::borrow::Cow; + use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; pub(crate) fn run<'input>( - resolver: &GlobalStringIdentResolver2<'input>, - mut directives: Vec>, -) -> Vec> { + resolver: &mut GlobalStringIdentResolver2<'input>, + mut directives: Vec, +) -> Vec { for directive in directives.iter_mut() { match directive { NormalizedDirective2::Method(func) => { - func.import_as = - replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take()); + replace_with_ptx_impl(resolver, func.name); } _ => {} } @@ -17,22 +18,16 @@ pub(crate) fn run<'input>( } fn replace_with_ptx_impl<'input>( - resolver: &GlobalStringIdentResolver2<'input>, - fn_name: &ptx_parser::MethodName<'input, SpirvWord>, - name: Option, -) -> Option { + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_name: SpirvWord, +) { let known_names = ["__assertfail"]; - match name { - Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)), - Some(name) => Some(name), - None => match fn_name { - ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) { - Some(super::IdentEntry { - name: Some(name), .. - }) => Some(format!("__zluda_ptx_impl_{}", name)), - _ => None, - }, - ptx_parser::MethodName::Kernel(..) => None, - }, + if let Some(super::IdentEntry { + name: Some(name), .. + }) = resolver.ident_map.get_mut(&fn_name) + { + if known_names.contains(&&**name) { + *name = Cow::Owned(format!("__zluda_ptx_impl_{}", name)); + } } } diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs index eb7abb1..1721afd 100644 --- a/ptx/src/pass/resolve_function_pointers.rs +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -3,8 +3,8 @@ use ptx_parser as ast; use rustc_hash::FxHashSet; pub(crate) fn run<'input>( - directives: Vec>, -) -> Result>, TranslateError> { + directives: Vec, +) -> Result, TranslateError> { let mut functions = FxHashSet::default(); directives .into_iter() @@ -14,19 +14,13 @@ pub(crate) fn run<'input>( fn run_directive<'input>( functions: &mut FxHashSet, - directive: UnconditionalDirective<'input>, -) -> Result, TranslateError> { + directive: UnconditionalDirective, +) -> Result { Ok(match directive { var @ Directive2::Variable(..) => var, Directive2::Method(method) => { - { - let func_decl = &method.func_decl; - match func_decl.name { - ptx_parser::MethodName::Kernel(_) => {} - ptx_parser::MethodName::Func(name) => { - functions.insert(name); - } - } + if !method.is_kernel { + functions.insert(method.name); } Directive2::Method(run_method(functions, method)?) } @@ -35,8 +29,8 @@ fn run_directive<'input>( fn run_method<'input>( functions: &mut FxHashSet, - method: UnconditionalFunction<'input>, -) -> Result, TranslateError> { + method: UnconditionalFunction, +) -> Result { let body = method .body .map(|statements| { @@ -47,12 +41,14 @@ fn run_method<'input>( }) .transpose()?; Ok(Function2 { - func_decl: method.func_decl, - globals: method.globals, + return_arguments: method.return_arguments, + name: method.name, + input_arguments: method.input_arguments, body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, + is_kernel: method.is_kernel, }) } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 19a2897..c2776c8 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1028,7 +1028,7 @@ pub struct ArithFloat { // round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular, // mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add // instructions on the target device. - pub is_fusable: bool + pub is_fusable: bool, } #[derive(Copy, Clone, PartialEq, Eq)] @@ -1447,6 +1447,7 @@ pub struct CvtDetails { pub mode: CvtMode, } +#[derive(Clone, Copy)] pub enum CvtMode { // int from int ZeroExtend, @@ -1465,7 +1466,7 @@ pub enum CvtMode { flush_to_zero: Option, }, FPRound { - integer_rounding: Option, + integer_rounding: RoundingMode, flush_to_zero: Option, }, // int from float @@ -1519,7 +1520,7 @@ impl CvtDetails { flush_to_zero, }, Ordering::Equal => CvtMode::FPRound { - integer_rounding: rounding, + integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven), flush_to_zero, }, Ordering::Greater => {