Continue working on ftz modes

This commit is contained in:
Andrzej Janik 2025-02-18 02:42:17 +00:00
parent 17529f951d
commit 5121bba285
15 changed files with 559 additions and 226 deletions

View file

@ -2,8 +2,8 @@ use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
@ -12,8 +12,8 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
@ -22,13 +22,13 @@ fn run_directive<'input>(
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
if !method.func_decl.name.is_kernel() {
for arg in method.func_decl.return_arguments.iter_mut() {
if !method.is_kernel {
for arg in method.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
@ -51,7 +51,7 @@ fn run_method<'input>(
_ => return Err(error_unreachable()),
}
}
for arg in method.func_decl.input_arguments.iter_mut() {
for arg in method.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
@ -96,12 +96,14 @@ fn run_method<'input>(
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
}

View file

@ -168,7 +168,7 @@ impl Deref for MemoryBuffer {
pub(super) fn run<'input>(
id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
@ -218,24 +218,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn emit_method(
&mut self,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
.or_else(|| match func_decl.name {
ast::MethodName::Kernel(name) => Some(name),
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
})
.or_else(|| self.id_defs.ident_map[&method.name].name.as_deref())
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
if fn_ == ptr::null_mut() {
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl
method.return_arguments.iter().map(|v| &v.v_type),
method
.input_arguments
.iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
@ -245,15 +241,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
self.emit_fn_attribute(fn_, "no-trapping-math", "true");
}
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
if !method.is_kernel {
self.resolver.register(method.name, fn_);
}
for (i, param) in func_decl.input_arguments.iter().enumerate() {
for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
if func_decl.name.is_kernel() {
if method.is_kernel {
let attr_kind = unsafe {
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
};
@ -267,7 +263,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
}
}
let call_conv = if func_decl.name.is_kernel() {
let call_conv = if method.is_kernel {
Self::kernel_call_convention()
} else {
Self::func_call_convention()
@ -282,7 +278,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
for var in func_decl.return_arguments {
for var in method.return_arguments {
method_emitter.emit_variable(var)?;
}
for statement in statements.iter() {
@ -1558,7 +1554,7 @@ impl<'a> MethodEmitContext<'a> {
return self.emit_cvt_float_to_int(
data.from,
data.to,
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
integer_rounding,
arguments,
Some(LLVMBuildFPToSI),
)

View file

@ -2,8 +2,8 @@ use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
@ -13,11 +13,10 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
@ -27,11 +26,10 @@ fn run_directive<'input>(
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: Function2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let body = method
.body
.map(|statements| {
@ -43,12 +41,14 @@ fn run_method<'input>(
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
}

View file

@ -1,30 +1,29 @@
use super::*;
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
let mut result = Vec::with_capacity(declarations.len() + directives.len());
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
for (sreg, declaration) in declarations {
let name = if let ast::MethodName::Func(name) = declaration.name {
name
} else {
return Err(error_unreachable());
};
result.push(UnconditionalDirective::Method(UnconditionalFunction {
func_decl: declaration,
globals: Vec::new(),
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
}));
sreg_to_function.insert(sreg, name);
}
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
SpecialRegistersMap2::foreach_declaration(
resolver,
|sreg, (return_arguments, name, input_arguments)| {
result.push(UnconditionalDirective::Method(UnconditionalFunction {
return_arguments,
name,
input_arguments,
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
}));
sreg_to_function.insert(sreg, name);
},
);
let mut visitor = SpecialRegisterResolver {
resolver,
special_registers,
@ -39,8 +38,8 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
directive: UnconditionalDirective<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
@ -49,8 +48,8 @@ fn run_directive<'a, 'input>(
fn run_method<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>,
method: UnconditionalFunction<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
@ -62,12 +61,14 @@ fn run_method<'a, 'input>(
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
}

View file

@ -1,8 +1,8 @@
use super::*;
pub(super) fn run<'input>(
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive)?;
@ -12,8 +12,8 @@ pub(super) fn run<'input>(
}
fn run_directive<'input>(
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match directive {
Directive2::Variable(..) => {}
@ -23,8 +23,8 @@ fn run_directive<'input>(
}
fn run_function<'input>(
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) {
function.body = function.body.take().map(|statements| {
statements

View file

@ -11,8 +11,8 @@ use super::*;
// pass, so we do nothing there
pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
@ -21,8 +21,8 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
@ -34,12 +34,11 @@ fn run_directive<'a, 'input>(
fn run_method<'a, 'input>(
mut visitor: InsertMemSSAVisitor<'a, 'input>,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
let is_kernel = func_decl.name.is_kernel();
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_kernel = method.is_kernel;
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
for arg in method.input_arguments.iter_mut() {
let old_name = arg.name;
let old_space = arg.state_space;
let new_space = ast::StateSpace::ParamEntry;
@ -51,10 +50,10 @@ fn run_method<'a, 'input>(
arg.state_space = new_space;
}
};
for arg in func_decl.return_arguments.iter_mut() {
for arg in method.return_arguments.iter_mut() {
visitor.visit_variable(arg)?;
}
let return_arguments = &func_decl.return_arguments[..];
let return_arguments = &method.return_arguments[..];
let body = method
.body
.map(move |statements| {
@ -66,12 +65,14 @@ fn run_method<'a, 'input>(
})
.transpose()?;
Ok(Function2 {
func_decl: func_decl,
globals: method.globals,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
}

View file

@ -1,3 +1,5 @@
use crate::pass::error_unreachable;
use super::BrachCondition;
use super::Directive2;
use super::Function2;
@ -17,6 +19,178 @@ use rustc_hash::FxHashSet;
use std::hash::Hash;
use std::iter;
/// Denormal handling mode requested by a floating-point instruction.
///
/// The default is [`DenormalMode::FlushToZero`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum DenormalMode {
    /// Selected when an instruction's `ftz` flag is set.
    #[default]
    FlushToZero,
    /// Selected when an instruction's `ftz` flag is clear.
    Preserve,
}

impl DenormalMode {
    /// Maps an instruction's `ftz` flag onto a mode: `true` yields
    /// `FlushToZero`, `false` yields `Preserve`.
    fn from_ftz(ftz: bool) -> Self {
        if ftz {
            DenormalMode::FlushToZero
        } else {
            DenormalMode::Preserve
        }
    }
}
/// Floating-point rounding mode; its variants mirror the ones matched on
/// `ast::RoundingMode` in the conversions below.
///
/// The default is [`RoundingMode::NearestEven`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum RoundingMode {
    #[default]
    NearestEven,
    Zero,
    NegativeInf,
    PositiveInf,
}

impl RoundingMode {
    /// Converts this mode into the `ast::RoundingMode` variant of the same name.
    ///
    /// Takes `self` by value, which is the conventional receiver for a `to_*`
    /// method only because the type is `Copy`.
    fn to_ast(self) -> ast::RoundingMode {
        match self {
            RoundingMode::NearestEven => ast::RoundingMode::NearestEven,
            RoundingMode::Zero => ast::RoundingMode::Zero,
            RoundingMode::NegativeInf => ast::RoundingMode::NegativeInf,
            RoundingMode::PositiveInf => ast::RoundingMode::PositiveInf,
        }
    }

    /// Converts an `ast::RoundingMode` into the local variant of the same name.
    fn from_ast(rnd: ast::RoundingMode) -> Self {
        match rnd {
            ast::RoundingMode::NearestEven => RoundingMode::NearestEven,
            ast::RoundingMode::Zero => RoundingMode::Zero,
            ast::RoundingMode::NegativeInf => RoundingMode::NegativeInf,
            ast::RoundingMode::PositiveInf => RoundingMode::PositiveInf,
        }
    }
}
/// Denormal and rounding modes that a single instruction depends on.
///
/// Modes are tracked in two groups: one for `f32` and one for every other
/// type (named `f16_f64` after its fields). A `None` slot means the
/// instruction places no requirement on that mode.
struct InstructionModes {
    denormal_f32: Option<DenormalMode>,
    denormal_f16_f64: Option<DenormalMode>,
    rounding_f32: Option<RoundingMode>,
    rounding_f16_f64: Option<RoundingMode>,
}

impl InstructionModes {
    /// An instruction with no mode requirements at all.
    fn none() -> Self {
        Self {
            denormal_f32: None,
            denormal_f16_f64: None,
            rounding_f32: None,
            rounding_f16_f64: None,
        }
    }

    /// Stores `denormal` and `rounding` in the group selected by `type_`:
    /// the `f32` slots for `F32`, the `f16_f64` slots for anything else.
    fn new(
        type_: ast::ScalarType,
        denormal: Option<DenormalMode>,
        rounding: Option<RoundingMode>,
    ) -> Self {
        let mut modes = Self::none();
        if type_ == ast::ScalarType::F32 {
            modes.denormal_f32 = denormal;
            modes.rounding_f32 = rounding;
        } else {
            modes.denormal_f16_f64 = denormal;
            modes.rounding_f16_f64 = rounding;
        }
        modes
    }

    /// Stores `denormal` in the group selected by `type_`, but always stores
    /// `rounding` in the `f32` slot.
    // NOTE(review): for non-`F32` types `rounding_f16_f64` is left as `None`
    // while `rounding_f32` is set; presumably intentional given the "mixed"
    // name, but confirm it is not a copy-paste slip.
    fn mixed_ftz_f32(
        type_: ast::ScalarType,
        denormal: Option<DenormalMode>,
        rounding: Option<RoundingMode>,
    ) -> Self {
        let mut modes = Self::none();
        modes.rounding_f32 = rounding;
        if type_ == ast::ScalarType::F32 {
            modes.denormal_f32 = denormal;
        } else {
            modes.denormal_f16_f64 = denormal;
        }
        modes
    }

    /// Modes of a floating-point arithmetic instruction: optional `ftz` flag
    /// plus a rounding mode that is always present.
    fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes {
        Self::new(
            arith.type_,
            arith.flush_to_zero.map(DenormalMode::from_ftz),
            Some(RoundingMode::from_ast(arith.rounding)),
        )
    }

    /// Modes of an instruction that only carries an optional `ftz` flag.
    fn from_ftz(type_: ast::ScalarType, ftz: Option<bool>) -> Self {
        let denormal = ftz.map(DenormalMode::from_ftz);
        Self::new(type_, denormal, None)
    }

    /// Modes of an `F32` instruction whose `ftz` flag is always present.
    fn from_ftz_f32(ftz: bool) -> Self {
        Self::from_ftz(ast::ScalarType::F32, Some(ftz))
    }

    /// Modes described by `ast::RcpData`: only the `Compliant` kind carries a
    /// rounding mode, `Approx` carries none.
    fn from_rcp(data: ast::RcpData) -> InstructionModes {
        let rounding = match data.kind {
            ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)),
            ast::RcpKind::Approx => None,
        };
        Self::new(
            data.type_,
            data.flush_to_zero.map(DenormalMode::from_ftz),
            rounding,
        )
    }

    /// Modes of a `cvt` instruction, chosen by its conversion mode.
    fn from_cvt(cvt: &ast::CvtDetails) -> InstructionModes {
        match cvt.mode {
            // The `ftz` flag of an extension is always recorded in the `F32` group.
            ast::CvtMode::FPExtend { flush_to_zero } => {
                Self::from_ftz(ast::ScalarType::F32, flush_to_zero)
            }
            // Float -> float narrowing/rounding: keyed on the destination type.
            ast::CvtMode::FPTruncate {
                rounding,
                flush_to_zero,
            }
            | ast::CvtMode::FPRound {
                integer_rounding: rounding,
                flush_to_zero,
            } => {
                let denormal = flush_to_zero.map(DenormalMode::from_ftz);
                let rounding = Some(RoundingMode::from_ast(rounding));
                Self::mixed_ftz_f32(cvt.to, denormal, rounding)
            }
            // Float -> integer: keyed on the source type.
            ast::CvtMode::SignedFromFP {
                flush_to_zero,
                rounding,
            }
            | ast::CvtMode::UnsignedFromFP {
                flush_to_zero,
                rounding,
            } => {
                let denormal = flush_to_zero.map(DenormalMode::from_ftz);
                let rounding = Some(RoundingMode::from_ast(rounding));
                Self::new(cvt.from, denormal, rounding)
            }
            // Integer -> float: rounding only, keyed on the destination type.
            ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
                Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
            }
            // Remaining conversion modes place no mode requirements.
            ast::CvtMode::ZeroExtend
            | ast::CvtMode::SignExtend
            | ast::CvtMode::Truncate
            | ast::CvtMode::Bitcast
            | ast::CvtMode::SaturateUnsignedToSigned
            | ast::CvtMode::SaturateSignedToUnsigned => Self::none(),
        }
    }
}
struct ControlFlowGraph<T: Eq + PartialEq> {
entry_points: FxHashMap<SpirvWord, NodeIndex>,
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
@ -74,19 +248,40 @@ struct Node<T> {
pub(crate) fn run<'input>(
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
directives: Vec<super::Directive2<'input, ast::Instruction<SpirvWord>, super::SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut cfg = ControlFlowGraph::<bool>::new();
let mut node_idx_to_name = FxHashMap::<NodeIndex<u32>, SpirvWord>::default();
for directive in directives.iter() {
match directive {
super::Directive2::Method(Function2 {
func_decl: ast::MethodDeclaration { name, .. },
body,
name,
body: Some(body),
..
}) => {
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
for statement in body.iter() {
todo!()
match statement {
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
let bb_index = basic_block.ok_or_else(error_unreachable)?;
cfg.add_jump(bb_index, arguments.src);
basic_block = None;
}
Statement::Label(label) => {
basic_block = Some(cfg.get_or_add_basic_block(*label));
}
Statement::Conditional(BrachCondition {
if_true, if_false, ..
}) => {
let bb_index = basic_block.ok_or_else(error_unreachable)?;
cfg.add_jump(bb_index, *if_true);
cfg.add_jump(bb_index, *if_false);
basic_block = None;
}
Statement::Instruction(instruction) => {
let modes = get_modes(instruction);
}
_ => continue,
}
}
}
_ => continue,
@ -280,6 +475,169 @@ impl<T: Copy + Eq + Hash> UniqueVec<T> {
}
}
/// Returns the denormal/rounding modes that `inst` depends on.
///
/// The match has no wildcard arm, so every `ast::Instruction` variant has to
/// be classified here explicitly.
fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
match inst {
// Instructions classified as having no mode requirements.
// TODO: review it when implementing virtual calls
ast::Instruction::Call { .. }
| ast::Instruction::Mov { .. }
| ast::Instruction::Ld { .. }
| ast::Instruction::St { .. }
| ast::Instruction::PrmtSlow { .. }
| ast::Instruction::Prmt { .. }
| ast::Instruction::Activemask { .. }
| ast::Instruction::Membar { .. }
| ast::Instruction::Trap {}
| ast::Instruction::Not { .. }
| ast::Instruction::Or { .. }
| ast::Instruction::And { .. }
| ast::Instruction::Bra { .. }
| ast::Instruction::Clz { .. }
| ast::Instruction::Brev { .. }
| ast::Instruction::Popc { .. }
| ast::Instruction::Xor { .. }
| ast::Instruction::Rem { .. }
| ast::Instruction::Bfe { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Shr { .. }
| ast::Instruction::Shl { .. }
| ast::Instruction::Selp { .. }
| ast::Instruction::Ret { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::Cvta { .. }
| ast::Instruction::Atom { .. }
| ast::Instruction::AtomCas { .. } => InstructionModes::none(),
// Integer (signed/unsigned) flavours of arithmetic instructions: no mode
// requirements either.
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..
}
| ast::Instruction::Sub {
data: ast::ArithDetails::Integer(..),
..
}
| ast::Instruction::Mul {
data: ast::MulDetails::Integer { .. },
..
}
| ast::Instruction::Mad {
data: ast::MadDetails::Integer { .. },
..
}
| ast::Instruction::Min {
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
..
}
| ast::Instruction::Max {
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
..
}
| ast::Instruction::Div {
data: ast::DivDetails::Signed(..) | ast::DivDetails::Unsigned(..),
..
} => InstructionModes::none(),
// Floating-point arithmetic described by `ast::ArithFloat`: both the
// optional ftz flag and the rounding mode are taken from `data`.
ast::Instruction::Fma { data, .. }
| ast::Instruction::Sub {
data: ast::ArithDetails::Float(data),
..
}
| ast::Instruction::Mul {
data: ast::MulDetails::Float(data),
..
}
| ast::Instruction::Mad {
data: ast::MadDetails::Float(data),
..
}
| ast::Instruction::Add {
data: ast::ArithDetails::Float(data),
..
} => InstructionModes::from_arith_float(data),
// Instructions that expose only a type and an optional ftz flag; no
// rounding mode is recorded for them.
ast::Instruction::Setp {
data:
ast::SetpData {
type_,
flush_to_zero,
..
},
..
}
| ast::Instruction::SetpBool {
data:
ast::SetpBoolData {
base:
ast::SetpData {
type_,
flush_to_zero,
..
},
..
},
..
}
| ast::Instruction::Neg {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Ex2 {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Rsqrt {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Abs {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Min {
data:
ast::MinMaxDetails::Float(ast::MinMaxFloat {
type_,
flush_to_zero,
..
}),
..
}
| ast::Instruction::Max {
data:
ast::MinMaxDetails::Float(ast::MinMaxFloat {
type_,
flush_to_zero,
..
}),
..
}
| ast::Instruction::Div {
data:
ast::DivDetails::Float(ast::DivFloatDetails {
type_,
flush_to_zero,
..
}),
..
} => InstructionModes::from_ftz(*type_, *flush_to_zero),
// These carry a plain ftz flag and are always recorded in the f32 group.
ast::Instruction::Sin { data, .. }
| ast::Instruction::Cos { data, .. }
| ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
// `Rcp` and `Sqrt` share the same data type, so both go through `from_rcp`.
ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => {
InstructionModes::from_rcp(*data)
}
// Conversions: the modes depend on the conversion mode, see `from_cvt`.
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -19,8 +19,8 @@ use ptx_parser as ast;
*/
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
@ -29,8 +29,8 @@ pub(super) fn run<'input>(
fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {

View file

@ -44,7 +44,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
let directives = replace_known_functions::run(&flat_resolver, directives);
let directives = replace_known_functions::run(&mut flat_resolver, directives);
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
@ -559,22 +559,23 @@ type NormalizedStatement = Statement<
ast::ParsedOperand<SpirvWord>,
>;
enum Directive2<'input, Instruction, Operand: ast::Operand> {
enum Directive2<Instruction, Operand: ast::Operand> {
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
Method(Function2<'input, Instruction, Operand>),
Method(Function2<Instruction, Operand>),
}
struct Function2<'input, Instruction, Operand: ast::Operand> {
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
pub globals: Vec<ast::Variable<SpirvWord>>,
struct Function2<Instruction, Operand: ast::Operand> {
pub return_arguments: Vec<ast::Variable<Operand::Ident>>,
pub name: Operand::Ident,
pub input_arguments: Vec<ast::Variable<Operand::Ident>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>,
is_kernel: bool,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
}
type NormalizedDirective2<'input> = Directive2<
'input,
type NormalizedDirective2 = Directive2<
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
@ -582,8 +583,7 @@ type NormalizedDirective2<'input> = Directive2<
ast::ParsedOperand<SpirvWord>,
>;
type NormalizedFunction2<'input> = Function2<
'input,
type NormalizedFunction2 = Function2<
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
@ -591,17 +591,11 @@ type NormalizedFunction2<'input> = Function2<
ast::ParsedOperand<SpirvWord>,
>;
type UnconditionalDirective<'input> = Directive2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
type UnconditionalDirective =
Directive2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
type UnconditionalFunction<'input> = Function2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
type UnconditionalFunction =
Function2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
struct GlobalStringIdentResolver2<'input> {
pub(crate) current_id: SpirvWord,
@ -807,47 +801,45 @@ impl SpecialRegistersMap2 {
self.id_to_reg.get(&id).copied()
}
fn generate_declarations<'a, 'input>(
/// Number of special registers, i.e. the length of `PtxSpecialRegister::iter()`.
fn len() -> usize {
PtxSpecialRegister::iter().len()
}
fn foreach_declaration<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
) -> impl ExactSizeIterator<
Item = (
mut fn_: impl FnMut(
PtxSpecialRegister,
ast::MethodDeclaration<'input, SpirvWord>,
(
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
),
> + 'a {
PtxSpecialRegister::iter().map(|sreg| {
) {
for sreg in PtxSpecialRegister::iter() {
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let name =
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
let name = resolver.register_named(Cow::Owned(external_fn_name), None);
let return_type = sreg.get_function_return_type();
let input_type = sreg.get_function_input_type();
(
sreg,
ast::MethodDeclaration {
return_arguments: vec![ast::Variable {
align: None,
v_type: return_type.into(),
state_space: ast::StateSpace::Reg,
name: resolver
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
}],
name: name,
input_arguments: input_type
.into_iter()
.map(|type_| ast::Variable {
align: None,
v_type: type_.into(),
state_space: ast::StateSpace::Reg,
name: resolver
.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
})
.collect::<Vec<_>>(),
shared_mem: None,
},
)
})
let return_arguments = vec![ast::Variable {
align: None,
v_type: return_type.into(),
state_space: ast::StateSpace::Reg,
name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
}];
let input_arguments = input_type
.into_iter()
.map(|type_| ast::Variable {
align: None,
v_type: type_.into(),
state_space: ast::StateSpace::Reg,
name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
})
.collect::<Vec<_>>();
fn_(sreg, (return_arguments, name, input_arguments));
}
}
}

View file

@ -4,7 +4,7 @@ use ptx_parser as ast;
pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
resolver.start_scope();
let result = directives
.into_iter()
@ -17,7 +17,7 @@ pub(crate) fn run<'input, 'b>(
fn run_directive<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
) -> Result<NormalizedDirective2, TranslateError> {
Ok(match directive {
ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
@ -32,15 +32,11 @@ fn run_method<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
let name = match method.func_directive.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(text) => {
ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
}
};
) -> Result<NormalizedFunction2, TranslateError> {
let is_kernel = method.func_directive.name.is_kernel();
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
resolver.start_scope();
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
let body = method
.body
.map(|statements| {
@ -51,20 +47,21 @@ fn run_method<'input, 'b>(
.transpose()?;
resolver.end_scope();
Ok(Function2 {
func_decl,
globals: Vec::new(),
return_arguments,
name,
input_arguments,
body,
import_as: None,
tuning: method.tuning,
linkage,
is_kernel,
})
}
fn run_function_decl<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
assert!(func_directive.shared_mem.is_none());
let return_arguments = func_directive
.return_arguments
@ -76,12 +73,7 @@ fn run_function_decl<'input, 'b>(
.into_iter()
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
Ok(ast::MethodDeclaration {
return_arguments,
name,
input_arguments,
shared_mem: None,
})
Ok((return_arguments, input_arguments))
}
fn run_variable<'input, 'b>(

View file

@ -3,8 +3,8 @@ use ptx_parser as ast;
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<NormalizedDirective2<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
directives: Vec<NormalizedDirective2>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
@ -13,8 +13,8 @@ pub(crate) fn run<'input>(
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: NormalizedDirective2<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
directive: NormalizedDirective2,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
@ -23,8 +23,8 @@ fn run_directive<'input>(
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: NormalizedFunction2<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
method: NormalizedFunction2,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
@ -36,12 +36,14 @@ fn run_method<'input>(
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
}

View file

@ -2,8 +2,8 @@ use super::*;
pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut fn_declarations = FxHashMap::default();
let remapped_directives = directives
.into_iter()
@ -13,17 +13,14 @@ pub(super) fn run<'input>(
.into_iter()
.map(|(_, (return_arguments, name, input_arguments))| {
Directive2::Method(Function2 {
func_decl: ast::MethodDeclaration {
return_arguments,
name: ast::MethodName::Func(name),
input_arguments,
shared_mem: None,
},
globals: Vec::new(),
return_arguments,
name: name,
input_arguments,
body: None,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
})
})
.collect::<Vec<_>>();
@ -41,8 +38,8 @@ fn run_directive<'input>(
Vec<ast::Variable<SpirvWord>>,
),
>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => {

View file

@ -1,14 +1,15 @@
use std::borrow::Cow;
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
pub(crate) fn run<'input>(
resolver: &GlobalStringIdentResolver2<'input>,
mut directives: Vec<NormalizedDirective2<'input>>,
) -> Vec<NormalizedDirective2<'input>> {
resolver: &mut GlobalStringIdentResolver2<'input>,
mut directives: Vec<NormalizedDirective2>,
) -> Vec<NormalizedDirective2> {
for directive in directives.iter_mut() {
match directive {
NormalizedDirective2::Method(func) => {
func.import_as =
replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take());
replace_with_ptx_impl(resolver, func.name);
}
_ => {}
}
@ -17,22 +18,16 @@ pub(crate) fn run<'input>(
}
fn replace_with_ptx_impl<'input>(
resolver: &GlobalStringIdentResolver2<'input>,
fn_name: &ptx_parser::MethodName<'input, SpirvWord>,
name: Option<String>,
) -> Option<String> {
resolver: &mut GlobalStringIdentResolver2<'input>,
fn_name: SpirvWord,
) {
let known_names = ["__assertfail"];
match name {
Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)),
Some(name) => Some(name),
None => match fn_name {
ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) {
Some(super::IdentEntry {
name: Some(name), ..
}) => Some(format!("__zluda_ptx_impl_{}", name)),
_ => None,
},
ptx_parser::MethodName::Kernel(..) => None,
},
if let Some(super::IdentEntry {
name: Some(name), ..
}) = resolver.ident_map.get_mut(&fn_name)
{
if known_names.contains(&&**name) {
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
}
}
}

View file

@ -3,8 +3,8 @@ use ptx_parser as ast;
use rustc_hash::FxHashSet;
pub(crate) fn run<'input>(
directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut functions = FxHashSet::default();
directives
.into_iter()
@ -14,19 +14,13 @@ pub(crate) fn run<'input>(
fn run_directive<'input>(
functions: &mut FxHashSet<SpirvWord>,
directive: UnconditionalDirective<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
directive: UnconditionalDirective,
) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
{
let func_decl = &method.func_decl;
match func_decl.name {
ptx_parser::MethodName::Kernel(_) => {}
ptx_parser::MethodName::Func(name) => {
functions.insert(name);
}
}
if !method.is_kernel {
functions.insert(method.name);
}
Directive2::Method(run_method(functions, method)?)
}
@ -35,8 +29,8 @@ fn run_directive<'input>(
fn run_method<'input>(
functions: &mut FxHashSet<SpirvWord>,
method: UnconditionalFunction<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
method: UnconditionalFunction,
) -> Result<UnconditionalFunction, TranslateError> {
let body = method
.body
.map(|statements| {
@ -47,12 +41,14 @@ fn run_method<'input>(
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
}

View file

@ -1028,7 +1028,7 @@ pub struct ArithFloat {
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
// instructions on the target device.
pub is_fusable: bool
pub is_fusable: bool,
}
#[derive(Copy, Clone, PartialEq, Eq)]
@ -1447,6 +1447,7 @@ pub struct CvtDetails {
pub mode: CvtMode,
}
#[derive(Clone, Copy)]
pub enum CvtMode {
// int from int
ZeroExtend,
@ -1465,7 +1466,7 @@ pub enum CvtMode {
flush_to_zero: Option<bool>,
},
FPRound {
integer_rounding: Option<RoundingMode>,
integer_rounding: RoundingMode,
flush_to_zero: Option<bool>,
},
// int from float
@ -1519,7 +1520,7 @@ impl CvtDetails {
flush_to_zero,
},
Ordering::Equal => CvtMode::FPRound {
integer_rounding: rounding,
integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
flush_to_zero,
},
Ordering::Greater => {