mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Continue working on ftz modes
This commit is contained in:
parent
17529f951d
commit
5121bba285
15 changed files with 559 additions and 226 deletions
|
@ -2,8 +2,8 @@ use super::*;
|
|||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
|
@ -12,8 +12,8 @@ pub(super) fn run<'a, 'input>(
|
|||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
|
@ -22,13 +22,13 @@ fn run_directive<'input>(
|
|||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2,
|
||||
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let is_declaration = method.body.is_none();
|
||||
let mut body = Vec::new();
|
||||
let mut remap_returns = Vec::new();
|
||||
if !method.func_decl.name.is_kernel() {
|
||||
for arg in method.func_decl.return_arguments.iter_mut() {
|
||||
if !method.is_kernel {
|
||||
for arg in method.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
|
@ -51,7 +51,7 @@ fn run_method<'input>(
|
|||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in method.func_decl.input_arguments.iter_mut() {
|
||||
for arg in method.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
|
@ -96,12 +96,14 @@ fn run_method<'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -168,7 +168,7 @@ impl Deref for MemoryBuffer {
|
|||
|
||||
pub(super) fn run<'input>(
|
||||
id_defs: GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<MemoryBuffer, TranslateError> {
|
||||
let context = Context::new();
|
||||
let module = Module::new(&context, LLVM_UNNAMED);
|
||||
|
@ -218,24 +218,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
|
||||
fn emit_method(
|
||||
&mut self,
|
||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let func_decl = method.func_decl;
|
||||
let name = method
|
||||
.import_as
|
||||
.as_deref()
|
||||
.or_else(|| match func_decl.name {
|
||||
ast::MethodName::Kernel(name) => Some(name),
|
||||
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
|
||||
})
|
||||
.or_else(|| self.id_defs.ident_map[&method.name].name.as_deref())
|
||||
.ok_or_else(|| error_unreachable())?;
|
||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
||||
if fn_ == ptr::null_mut() {
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl
|
||||
method.return_arguments.iter().map(|v| &v.v_type),
|
||||
method
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
||||
|
@ -245,15 +241,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
|
||||
self.emit_fn_attribute(fn_, "no-trapping-math", "true");
|
||||
}
|
||||
if let ast::MethodName::Func(name) = func_decl.name {
|
||||
self.resolver.register(name, fn_);
|
||||
if !method.is_kernel {
|
||||
self.resolver.register(method.name, fn_);
|
||||
}
|
||||
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
||||
for (i, param) in method.input_arguments.iter().enumerate() {
|
||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||
let name = self.resolver.get_or_add(param.name);
|
||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||
self.resolver.register(param.name, value);
|
||||
if func_decl.name.is_kernel() {
|
||||
if method.is_kernel {
|
||||
let attr_kind = unsafe {
|
||||
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
|
||||
};
|
||||
|
@ -267,7 +263,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
|
||||
}
|
||||
}
|
||||
let call_conv = if func_decl.name.is_kernel() {
|
||||
let call_conv = if method.is_kernel {
|
||||
Self::kernel_call_convention()
|
||||
} else {
|
||||
Self::func_call_convention()
|
||||
|
@ -282,7 +278,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
||||
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
|
||||
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
|
||||
for var in func_decl.return_arguments {
|
||||
for var in method.return_arguments {
|
||||
method_emitter.emit_variable(var)?;
|
||||
}
|
||||
for statement in statements.iter() {
|
||||
|
@ -1558,7 +1554,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
return self.emit_cvt_float_to_int(
|
||||
data.from,
|
||||
data.to,
|
||||
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
|
||||
integer_rounding,
|
||||
arguments,
|
||||
Some(LLVMBuildFPToSI),
|
||||
)
|
||||
|
|
|
@ -2,8 +2,8 @@ use super::*;
|
|||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
|
@ -13,11 +13,10 @@ pub(super) fn run<'a, 'input>(
|
|||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
|
@ -27,11 +26,10 @@ fn run_directive<'input>(
|
|||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: Function2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -43,12 +41,14 @@ fn run_method<'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,30 +1,29 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
special_registers: &'a SpecialRegistersMap2,
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
|
||||
let mut result = Vec::with_capacity(declarations.len() + directives.len());
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
|
||||
let mut sreg_to_function =
|
||||
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
|
||||
for (sreg, declaration) in declarations {
|
||||
let name = if let ast::MethodName::Func(name) = declaration.name {
|
||||
name
|
||||
} else {
|
||||
return Err(error_unreachable());
|
||||
};
|
||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||
func_decl: declaration,
|
||||
globals: Vec::new(),
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
}
|
||||
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
|
||||
SpecialRegistersMap2::foreach_declaration(
|
||||
resolver,
|
||||
|sreg, (return_arguments, name, input_arguments)| {
|
||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
},
|
||||
);
|
||||
let mut visitor = SpecialRegisterResolver {
|
||||
resolver,
|
||||
special_registers,
|
||||
|
@ -39,8 +38,8 @@ pub(super) fn run<'a, 'input>(
|
|||
|
||||
fn run_directive<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
directive: UnconditionalDirective<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
directive: UnconditionalDirective,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||
|
@ -49,8 +48,8 @@ fn run_directive<'a, 'input>(
|
|||
|
||||
fn run_method<'a, 'input>(
|
||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||
method: UnconditionalFunction<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
method: UnconditionalFunction,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -62,12 +61,14 @@ fn run_method<'a, 'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(directives.len());
|
||||
for mut directive in directives.into_iter() {
|
||||
run_directive(&mut result, &mut directive)?;
|
||||
|
@ -12,8 +12,8 @@ pub(super) fn run<'input>(
|
|||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
||||
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match directive {
|
||||
Directive2::Variable(..) => {}
|
||||
|
@ -23,8 +23,8 @@ fn run_directive<'input>(
|
|||
}
|
||||
|
||||
fn run_function<'input>(
|
||||
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
||||
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
||||
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) {
|
||||
function.body = function.body.take().map(|statements| {
|
||||
statements
|
||||
|
|
|
@ -11,8 +11,8 @@ use super::*;
|
|||
// pass, so we do nothing there
|
||||
pub(super) fn run<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
|
@ -21,8 +21,8 @@ pub(super) fn run<'a, 'input>(
|
|||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
|
@ -34,12 +34,11 @@ fn run_directive<'a, 'input>(
|
|||
|
||||
fn run_method<'a, 'input>(
|
||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let mut func_decl = method.func_decl;
|
||||
let is_kernel = func_decl.name.is_kernel();
|
||||
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let is_kernel = method.is_kernel;
|
||||
if is_kernel {
|
||||
for arg in func_decl.input_arguments.iter_mut() {
|
||||
for arg in method.input_arguments.iter_mut() {
|
||||
let old_name = arg.name;
|
||||
let old_space = arg.state_space;
|
||||
let new_space = ast::StateSpace::ParamEntry;
|
||||
|
@ -51,10 +50,10 @@ fn run_method<'a, 'input>(
|
|||
arg.state_space = new_space;
|
||||
}
|
||||
};
|
||||
for arg in func_decl.return_arguments.iter_mut() {
|
||||
for arg in method.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg)?;
|
||||
}
|
||||
let return_arguments = &func_decl.return_arguments[..];
|
||||
let return_arguments = &method.return_arguments[..];
|
||||
let body = method
|
||||
.body
|
||||
.map(move |statements| {
|
||||
|
@ -66,12 +65,14 @@ fn run_method<'a, 'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: func_decl,
|
||||
globals: method.globals,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use crate::pass::error_unreachable;
|
||||
|
||||
use super::BrachCondition;
|
||||
use super::Directive2;
|
||||
use super::Function2;
|
||||
|
@ -17,6 +19,178 @@ use rustc_hash::FxHashSet;
|
|||
use std::hash::Hash;
|
||||
use std::iter;
|
||||
|
||||
#[derive(Default)]
|
||||
enum DenormalMode {
|
||||
#[default]
|
||||
FlushToZero,
|
||||
Preserve,
|
||||
}
|
||||
|
||||
impl DenormalMode {
|
||||
fn from_ftz(ftz: bool) -> Self {
|
||||
if ftz {
|
||||
DenormalMode::FlushToZero
|
||||
} else {
|
||||
DenormalMode::Preserve
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
enum RoundingMode {
|
||||
#[default]
|
||||
NearestEven,
|
||||
Zero,
|
||||
NegativeInf,
|
||||
PositiveInf,
|
||||
}
|
||||
|
||||
impl RoundingMode {
|
||||
fn to_ast(self) -> ast::RoundingMode {
|
||||
match self {
|
||||
RoundingMode::NearestEven => ast::RoundingMode::NearestEven,
|
||||
RoundingMode::Zero => ast::RoundingMode::Zero,
|
||||
RoundingMode::NegativeInf => ast::RoundingMode::NegativeInf,
|
||||
RoundingMode::PositiveInf => ast::RoundingMode::PositiveInf,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_ast(rnd: ast::RoundingMode) -> Self {
|
||||
match rnd {
|
||||
ast::RoundingMode::NearestEven => RoundingMode::NearestEven,
|
||||
ast::RoundingMode::Zero => RoundingMode::Zero,
|
||||
ast::RoundingMode::NegativeInf => RoundingMode::NegativeInf,
|
||||
ast::RoundingMode::PositiveInf => RoundingMode::PositiveInf,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InstructionModes {
|
||||
denormal_f32: Option<DenormalMode>,
|
||||
denormal_f16_f64: Option<DenormalMode>,
|
||||
rounding_f32: Option<RoundingMode>,
|
||||
rounding_f16_f64: Option<RoundingMode>,
|
||||
}
|
||||
|
||||
impl InstructionModes {
|
||||
fn none() -> Self {
|
||||
Self {
|
||||
denormal_f32: None,
|
||||
denormal_f16_f64: None,
|
||||
rounding_f32: None,
|
||||
rounding_f16_f64: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn new(
|
||||
type_: ast::ScalarType,
|
||||
denormal: Option<DenormalMode>,
|
||||
rounding: Option<RoundingMode>,
|
||||
) -> Self {
|
||||
if type_ != ast::ScalarType::F32 {
|
||||
Self {
|
||||
denormal_f16_f64: denormal,
|
||||
rounding_f16_f64: rounding,
|
||||
..Self::none()
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
denormal_f32: denormal,
|
||||
rounding_f32: rounding,
|
||||
..Self::none()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn mixed_ftz_f32(
|
||||
type_: ast::ScalarType,
|
||||
denormal: Option<DenormalMode>,
|
||||
rounding: Option<RoundingMode>,
|
||||
) -> Self {
|
||||
if type_ != ast::ScalarType::F32 {
|
||||
Self {
|
||||
denormal_f16_f64: denormal,
|
||||
rounding_f32: rounding,
|
||||
..Self::none()
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
denormal_f32: denormal,
|
||||
rounding_f32: rounding,
|
||||
..Self::none()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes {
|
||||
let denormal = arith.flush_to_zero.map(DenormalMode::from_ftz);
|
||||
let rounding = Some(RoundingMode::from_ast(arith.rounding));
|
||||
InstructionModes::new(arith.type_, denormal, rounding)
|
||||
}
|
||||
|
||||
fn from_ftz(type_: ast::ScalarType, ftz: Option<bool>) -> Self {
|
||||
Self::new(type_, ftz.map(DenormalMode::from_ftz), None)
|
||||
}
|
||||
|
||||
fn from_ftz_f32(ftz: bool) -> Self {
|
||||
Self::new(
|
||||
ast::ScalarType::F32,
|
||||
Some(DenormalMode::from_ftz(ftz)),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
fn from_rcp(data: ast::RcpData) -> InstructionModes {
|
||||
let rounding = match data.kind {
|
||||
ast::RcpKind::Approx => None,
|
||||
ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)),
|
||||
};
|
||||
let denormal = data.flush_to_zero.map(DenormalMode::from_ftz);
|
||||
InstructionModes::new(data.type_, denormal, rounding)
|
||||
}
|
||||
|
||||
fn from_cvt(cvt: &ast::CvtDetails) -> InstructionModes {
|
||||
match cvt.mode {
|
||||
ast::CvtMode::ZeroExtend
|
||||
| ast::CvtMode::SignExtend
|
||||
| ast::CvtMode::Truncate
|
||||
| ast::CvtMode::Bitcast
|
||||
| ast::CvtMode::SaturateUnsignedToSigned
|
||||
| ast::CvtMode::SaturateSignedToUnsigned => Self::none(),
|
||||
ast::CvtMode::FPExtend { flush_to_zero } => {
|
||||
Self::from_ftz(ast::ScalarType::F32, flush_to_zero)
|
||||
}
|
||||
ast::CvtMode::FPTruncate {
|
||||
rounding,
|
||||
flush_to_zero,
|
||||
}
|
||||
| ast::CvtMode::FPRound {
|
||||
integer_rounding: rounding,
|
||||
flush_to_zero,
|
||||
} => Self::mixed_ftz_f32(
|
||||
cvt.to,
|
||||
flush_to_zero.map(DenormalMode::from_ftz),
|
||||
Some(RoundingMode::from_ast(rounding)),
|
||||
),
|
||||
ast::CvtMode::SignedFromFP {
|
||||
flush_to_zero,
|
||||
rounding,
|
||||
}
|
||||
| ast::CvtMode::UnsignedFromFP {
|
||||
flush_to_zero,
|
||||
rounding,
|
||||
} => Self::new(
|
||||
cvt.from,
|
||||
flush_to_zero.map(DenormalMode::from_ftz),
|
||||
Some(RoundingMode::from_ast(rounding)),
|
||||
),
|
||||
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
|
||||
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ControlFlowGraph<T: Eq + PartialEq> {
|
||||
entry_points: FxHashMap<SpirvWord, NodeIndex>,
|
||||
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
|
||||
|
@ -74,19 +248,40 @@ struct Node<T> {
|
|||
|
||||
pub(crate) fn run<'input>(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<super::Directive2<'input, ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut cfg = ControlFlowGraph::<bool>::new();
|
||||
let mut node_idx_to_name = FxHashMap::<NodeIndex<u32>, SpirvWord>::default();
|
||||
for directive in directives.iter() {
|
||||
match directive {
|
||||
super::Directive2::Method(Function2 {
|
||||
func_decl: ast::MethodDeclaration { name, .. },
|
||||
body,
|
||||
name,
|
||||
body: Some(body),
|
||||
..
|
||||
}) => {
|
||||
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
|
||||
for statement in body.iter() {
|
||||
todo!()
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||
cfg.add_jump(bb_index, arguments.src);
|
||||
basic_block = None;
|
||||
}
|
||||
Statement::Label(label) => {
|
||||
basic_block = Some(cfg.get_or_add_basic_block(*label));
|
||||
}
|
||||
Statement::Conditional(BrachCondition {
|
||||
if_true, if_false, ..
|
||||
}) => {
|
||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||
cfg.add_jump(bb_index, *if_true);
|
||||
cfg.add_jump(bb_index, *if_false);
|
||||
basic_block = None;
|
||||
}
|
||||
Statement::Instruction(instruction) => {
|
||||
let modes = get_modes(instruction);
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => continue,
|
||||
|
@ -280,6 +475,169 @@ impl<T: Copy + Eq + Hash> UniqueVec<T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||
match inst {
|
||||
// TODO: review it when implementing virtual calls
|
||||
ast::Instruction::Call { .. }
|
||||
| ast::Instruction::Mov { .. }
|
||||
| ast::Instruction::Ld { .. }
|
||||
| ast::Instruction::St { .. }
|
||||
| ast::Instruction::PrmtSlow { .. }
|
||||
| ast::Instruction::Prmt { .. }
|
||||
| ast::Instruction::Activemask { .. }
|
||||
| ast::Instruction::Membar { .. }
|
||||
| ast::Instruction::Trap {}
|
||||
| ast::Instruction::Not { .. }
|
||||
| ast::Instruction::Or { .. }
|
||||
| ast::Instruction::And { .. }
|
||||
| ast::Instruction::Bra { .. }
|
||||
| ast::Instruction::Clz { .. }
|
||||
| ast::Instruction::Brev { .. }
|
||||
| ast::Instruction::Popc { .. }
|
||||
| ast::Instruction::Xor { .. }
|
||||
| ast::Instruction::Rem { .. }
|
||||
| ast::Instruction::Bfe { .. }
|
||||
| ast::Instruction::Bfi { .. }
|
||||
| ast::Instruction::Shr { .. }
|
||||
| ast::Instruction::Shl { .. }
|
||||
| ast::Instruction::Selp { .. }
|
||||
| ast::Instruction::Ret { .. }
|
||||
| ast::Instruction::Bar { .. }
|
||||
| ast::Instruction::Cvta { .. }
|
||||
| ast::Instruction::Atom { .. }
|
||||
| ast::Instruction::AtomCas { .. } => InstructionModes::none(),
|
||||
ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Integer(_),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Sub {
|
||||
data: ast::ArithDetails::Integer(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Integer { .. },
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mad {
|
||||
data: ast::MadDetails::Integer { .. },
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Min {
|
||||
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Max {
|
||||
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Div {
|
||||
data: ast::DivDetails::Signed(..) | ast::DivDetails::Unsigned(..),
|
||||
..
|
||||
} => InstructionModes::none(),
|
||||
ast::Instruction::Fma { data, .. }
|
||||
| ast::Instruction::Sub {
|
||||
data: ast::ArithDetails::Float(data),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mul {
|
||||
data: ast::MulDetails::Float(data),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Mad {
|
||||
data: ast::MadDetails::Float(data),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Add {
|
||||
data: ast::ArithDetails::Float(data),
|
||||
..
|
||||
} => InstructionModes::from_arith_float(data),
|
||||
ast::Instruction::Setp {
|
||||
data:
|
||||
ast::SetpData {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::SetpBool {
|
||||
data:
|
||||
ast::SetpBoolData {
|
||||
base:
|
||||
ast::SetpData {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
},
|
||||
..
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Neg {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Ex2 {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Rsqrt {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Abs {
|
||||
data: ast::TypeFtz {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
},
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Min {
|
||||
data:
|
||||
ast::MinMaxDetails::Float(ast::MinMaxFloat {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
}),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Max {
|
||||
data:
|
||||
ast::MinMaxDetails::Float(ast::MinMaxFloat {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
}),
|
||||
..
|
||||
}
|
||||
| ast::Instruction::Div {
|
||||
data:
|
||||
ast::DivDetails::Float(ast::DivFloatDetails {
|
||||
type_,
|
||||
flush_to_zero,
|
||||
..
|
||||
}),
|
||||
..
|
||||
} => InstructionModes::from_ftz(*type_, *flush_to_zero),
|
||||
ast::Instruction::Sin { data, .. }
|
||||
| ast::Instruction::Cos { data, .. }
|
||||
| ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
|
||||
ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => {
|
||||
InstructionModes::from_rcp(*data)
|
||||
}
|
||||
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -19,8 +19,8 @@ use ptx_parser as ast;
|
|||
*/
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
|
@ -29,8 +29,8 @@ pub(super) fn run<'input>(
|
|||
|
||||
fn run_directive<'a, 'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
|
|
|
@ -44,7 +44,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
|||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
||||
let directives = replace_known_functions::run(&flat_resolver, directives);
|
||||
let directives = replace_known_functions::run(&mut flat_resolver, directives);
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||
let directives = resolve_function_pointers::run(directives)?;
|
||||
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
|
@ -559,22 +559,23 @@ type NormalizedStatement = Statement<
|
|||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
enum Directive2<'input, Instruction, Operand: ast::Operand> {
|
||||
enum Directive2<Instruction, Operand: ast::Operand> {
|
||||
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
||||
Method(Function2<'input, Instruction, Operand>),
|
||||
Method(Function2<Instruction, Operand>),
|
||||
}
|
||||
|
||||
struct Function2<'input, Instruction, Operand: ast::Operand> {
|
||||
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
|
||||
pub globals: Vec<ast::Variable<SpirvWord>>,
|
||||
struct Function2<Instruction, Operand: ast::Operand> {
|
||||
pub return_arguments: Vec<ast::Variable<Operand::Ident>>,
|
||||
pub name: Operand::Ident,
|
||||
pub input_arguments: Vec<ast::Variable<Operand::Ident>>,
|
||||
pub body: Option<Vec<Statement<Instruction, Operand>>>,
|
||||
is_kernel: bool,
|
||||
import_as: Option<String>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
}
|
||||
|
||||
type NormalizedDirective2<'input> = Directive2<
|
||||
'input,
|
||||
type NormalizedDirective2 = Directive2<
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
|
@ -582,8 +583,7 @@ type NormalizedDirective2<'input> = Directive2<
|
|||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type NormalizedFunction2<'input> = Function2<
|
||||
'input,
|
||||
type NormalizedFunction2 = Function2<
|
||||
(
|
||||
Option<ast::PredAt<SpirvWord>>,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
|
@ -591,17 +591,11 @@ type NormalizedFunction2<'input> = Function2<
|
|||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type UnconditionalDirective<'input> = Directive2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
type UnconditionalDirective =
|
||||
Directive2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
|
||||
|
||||
type UnconditionalFunction<'input> = Function2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
type UnconditionalFunction =
|
||||
Function2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
|
||||
|
||||
struct GlobalStringIdentResolver2<'input> {
|
||||
pub(crate) current_id: SpirvWord,
|
||||
|
@ -807,47 +801,45 @@ impl SpecialRegistersMap2 {
|
|||
self.id_to_reg.get(&id).copied()
|
||||
}
|
||||
|
||||
fn generate_declarations<'a, 'input>(
|
||||
fn len() -> usize {
|
||||
PtxSpecialRegister::iter().len()
|
||||
}
|
||||
|
||||
fn foreach_declaration<'a, 'input>(
|
||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||
) -> impl ExactSizeIterator<
|
||||
Item = (
|
||||
mut fn_: impl FnMut(
|
||||
PtxSpecialRegister,
|
||||
ast::MethodDeclaration<'input, SpirvWord>,
|
||||
(
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
SpirvWord,
|
||||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
),
|
||||
> + 'a {
|
||||
PtxSpecialRegister::iter().map(|sreg| {
|
||||
) {
|
||||
for sreg in PtxSpecialRegister::iter() {
|
||||
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||
let name =
|
||||
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
|
||||
let name = resolver.register_named(Cow::Owned(external_fn_name), None);
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let input_type = sreg.get_function_input_type();
|
||||
(
|
||||
sreg,
|
||||
ast::MethodDeclaration {
|
||||
return_arguments: vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: return_type.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver
|
||||
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
name: name,
|
||||
input_arguments: input_type
|
||||
.into_iter()
|
||||
.map(|type_| ast::Variable {
|
||||
align: None,
|
||||
v_type: type_.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver
|
||||
.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
shared_mem: None,
|
||||
},
|
||||
)
|
||||
})
|
||||
let return_arguments = vec![ast::Variable {
|
||||
align: None,
|
||||
v_type: return_type.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
}];
|
||||
let input_arguments = input_type
|
||||
.into_iter()
|
||||
.map(|type_| ast::Variable {
|
||||
align: None,
|
||||
v_type: type_.into(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
fn_(sreg, (return_arguments, name, input_arguments));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use ptx_parser as ast;
|
|||
pub(crate) fn run<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
||||
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
|
||||
resolver.start_scope();
|
||||
let result = directives
|
||||
.into_iter()
|
||||
|
@ -17,7 +17,7 @@ pub(crate) fn run<'input, 'b>(
|
|||
fn run_directive<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
||||
) -> Result<NormalizedDirective2, TranslateError> {
|
||||
Ok(match directive {
|
||||
ast::Directive::Variable(linking, var) => {
|
||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||
|
@ -32,15 +32,11 @@ fn run_method<'input, 'b>(
|
|||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
linkage: ast::LinkingDirective,
|
||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
||||
let name = match method.func_directive.name {
|
||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||
ast::MethodName::Func(text) => {
|
||||
ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
|
||||
}
|
||||
};
|
||||
) -> Result<NormalizedFunction2, TranslateError> {
|
||||
let is_kernel = method.func_directive.name.is_kernel();
|
||||
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
|
||||
resolver.start_scope();
|
||||
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
|
||||
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -51,20 +47,21 @@ fn run_method<'input, 'b>(
|
|||
.transpose()?;
|
||||
resolver.end_scope();
|
||||
Ok(Function2 {
|
||||
func_decl,
|
||||
globals: Vec::new(),
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body,
|
||||
import_as: None,
|
||||
tuning: method.tuning,
|
||||
linkage,
|
||||
is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_function_decl<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||
name: ast::MethodName<'input, SpirvWord>,
|
||||
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
|
||||
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||
assert!(func_directive.shared_mem.is_none());
|
||||
let return_arguments = func_directive
|
||||
.return_arguments
|
||||
|
@ -76,12 +73,7 @@ fn run_function_decl<'input, 'b>(
|
|||
.into_iter()
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(ast::MethodDeclaration {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
})
|
||||
Ok((return_arguments, input_arguments))
|
||||
}
|
||||
|
||||
fn run_variable<'input, 'b>(
|
||||
|
|
|
@ -3,8 +3,8 @@ use ptx_parser as ast;
|
|||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<NormalizedDirective2<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
directives: Vec<NormalizedDirective2>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
|
@ -13,8 +13,8 @@ pub(crate) fn run<'input>(
|
|||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: NormalizedDirective2<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
directive: NormalizedDirective2,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
|
@ -23,8 +23,8 @@ fn run_directive<'input>(
|
|||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: NormalizedFunction2<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
method: NormalizedFunction2,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -36,12 +36,14 @@ fn run_method<'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,8 @@ use super::*;
|
|||
|
||||
pub(super) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut fn_declarations = FxHashMap::default();
|
||||
let remapped_directives = directives
|
||||
.into_iter()
|
||||
|
@ -13,17 +13,14 @@ pub(super) fn run<'input>(
|
|||
.into_iter()
|
||||
.map(|(_, (return_arguments, name, input_arguments))| {
|
||||
Directive2::Method(Function2 {
|
||||
func_decl: ast::MethodDeclaration {
|
||||
return_arguments,
|
||||
name: ast::MethodName::Func(name),
|
||||
input_arguments,
|
||||
shared_mem: None,
|
||||
},
|
||||
globals: Vec::new(),
|
||||
return_arguments,
|
||||
name: name,
|
||||
input_arguments,
|
||||
body: None,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -41,8 +38,8 @@ fn run_directive<'input>(
|
|||
Vec<ast::Variable<SpirvWord>>,
|
||||
),
|
||||
>,
|
||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(mut method) => {
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
use std::borrow::Cow;
|
||||
|
||||
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &GlobalStringIdentResolver2<'input>,
|
||||
mut directives: Vec<NormalizedDirective2<'input>>,
|
||||
) -> Vec<NormalizedDirective2<'input>> {
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
mut directives: Vec<NormalizedDirective2>,
|
||||
) -> Vec<NormalizedDirective2> {
|
||||
for directive in directives.iter_mut() {
|
||||
match directive {
|
||||
NormalizedDirective2::Method(func) => {
|
||||
func.import_as =
|
||||
replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take());
|
||||
replace_with_ptx_impl(resolver, func.name);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -17,22 +18,16 @@ pub(crate) fn run<'input>(
|
|||
}
|
||||
|
||||
fn replace_with_ptx_impl<'input>(
|
||||
resolver: &GlobalStringIdentResolver2<'input>,
|
||||
fn_name: &ptx_parser::MethodName<'input, SpirvWord>,
|
||||
name: Option<String>,
|
||||
) -> Option<String> {
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
fn_name: SpirvWord,
|
||||
) {
|
||||
let known_names = ["__assertfail"];
|
||||
match name {
|
||||
Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)),
|
||||
Some(name) => Some(name),
|
||||
None => match fn_name {
|
||||
ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) {
|
||||
Some(super::IdentEntry {
|
||||
name: Some(name), ..
|
||||
}) => Some(format!("__zluda_ptx_impl_{}", name)),
|
||||
_ => None,
|
||||
},
|
||||
ptx_parser::MethodName::Kernel(..) => None,
|
||||
},
|
||||
if let Some(super::IdentEntry {
|
||||
name: Some(name), ..
|
||||
}) = resolver.ident_map.get_mut(&fn_name)
|
||||
{
|
||||
if known_names.contains(&&**name) {
|
||||
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,8 @@ use ptx_parser as ast;
|
|||
use rustc_hash::FxHashSet;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
directives: Vec<UnconditionalDirective<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
directives: Vec<UnconditionalDirective>,
|
||||
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||
let mut functions = FxHashSet::default();
|
||||
directives
|
||||
.into_iter()
|
||||
|
@ -14,19 +14,13 @@ pub(crate) fn run<'input>(
|
|||
|
||||
fn run_directive<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
directive: UnconditionalDirective<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
directive: UnconditionalDirective,
|
||||
) -> Result<UnconditionalDirective, TranslateError> {
|
||||
Ok(match directive {
|
||||
var @ Directive2::Variable(..) => var,
|
||||
Directive2::Method(method) => {
|
||||
{
|
||||
let func_decl = &method.func_decl;
|
||||
match func_decl.name {
|
||||
ptx_parser::MethodName::Kernel(_) => {}
|
||||
ptx_parser::MethodName::Func(name) => {
|
||||
functions.insert(name);
|
||||
}
|
||||
}
|
||||
if !method.is_kernel {
|
||||
functions.insert(method.name);
|
||||
}
|
||||
Directive2::Method(run_method(functions, method)?)
|
||||
}
|
||||
|
@ -35,8 +29,8 @@ fn run_directive<'input>(
|
|||
|
||||
fn run_method<'input>(
|
||||
functions: &mut FxHashSet<SpirvWord>,
|
||||
method: UnconditionalFunction<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
method: UnconditionalFunction,
|
||||
) -> Result<UnconditionalFunction, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -47,12 +41,14 @@ fn run_method<'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1028,7 +1028,7 @@ pub struct ArithFloat {
|
|||
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
|
||||
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
|
||||
// instructions on the target device.
|
||||
pub is_fusable: bool
|
||||
pub is_fusable: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
|
@ -1447,6 +1447,7 @@ pub struct CvtDetails {
|
|||
pub mode: CvtMode,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum CvtMode {
|
||||
// int from int
|
||||
ZeroExtend,
|
||||
|
@ -1465,7 +1466,7 @@ pub enum CvtMode {
|
|||
flush_to_zero: Option<bool>,
|
||||
},
|
||||
FPRound {
|
||||
integer_rounding: Option<RoundingMode>,
|
||||
integer_rounding: RoundingMode,
|
||||
flush_to_zero: Option<bool>,
|
||||
},
|
||||
// int from float
|
||||
|
@ -1519,7 +1520,7 @@ impl CvtDetails {
|
|||
flush_to_zero,
|
||||
},
|
||||
Ordering::Equal => CvtMode::FPRound {
|
||||
integer_rounding: rounding,
|
||||
integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
|
||||
flush_to_zero,
|
||||
},
|
||||
Ordering::Greater => {
|
||||
|
|
Loading…
Add table
Reference in a new issue