mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Apply computed denormal modes to basic blocks
This commit is contained in:
parent
aaa31da026
commit
82ca92c5c3
14 changed files with 626 additions and 123 deletions
|
@ -95,16 +95,7 @@ fn run_method<'input>(
|
|||
Ok::<_, TranslateError>(body)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
|
|
|
@ -243,6 +243,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
}
|
||||
if !method.is_kernel {
|
||||
self.resolver.register(method.name, fn_);
|
||||
self.emit_fn_attribute(fn_, "denormal-fp-math-f32", "dynamic");
|
||||
self.emit_fn_attribute(fn_, "denormal-fp-math", "dynamic");
|
||||
} else {
|
||||
self.emit_fn_attribute(
|
||||
fn_,
|
||||
"denormal-fp-math-f32",
|
||||
llvm_ftz(method.flush_to_zero_f32),
|
||||
);
|
||||
self.emit_fn_attribute(
|
||||
fn_,
|
||||
"denormal-fp-math",
|
||||
llvm_ftz(method.flush_to_zero_f16f64),
|
||||
);
|
||||
}
|
||||
for (i, param) in method.input_arguments.iter().enumerate() {
|
||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||
|
@ -413,6 +426,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
}
|
||||
}
|
||||
|
||||
fn llvm_ftz(ftz: bool) -> &'static str {
|
||||
if ftz {
|
||||
"preserve-sign"
|
||||
} else {
|
||||
"ieee"
|
||||
}
|
||||
}
|
||||
|
||||
fn get_input_argument_type(
|
||||
context: LLVMContextRef,
|
||||
v_type: &ast::Type,
|
||||
|
@ -469,6 +490,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Statement::FunctionPointer(_) => todo!(),
|
||||
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
|
||||
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1124,7 +1146,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let cos = self.emit_intrinsic(
|
||||
c"llvm.cos.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
Some(&ast::ScalarType::F32.into()),
|
||||
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||
)?;
|
||||
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
||||
|
@ -1377,7 +1399,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let sin = self.emit_intrinsic(
|
||||
c"llvm.sin.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
Some(&ast::ScalarType::F32.into()),
|
||||
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||
)?;
|
||||
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
||||
|
@ -1388,12 +1410,12 @@ impl<'a> MethodEmitContext<'a> {
|
|||
&mut self,
|
||||
name: &CStr,
|
||||
dst: Option<SpirvWord>,
|
||||
return_type: &ast::Type,
|
||||
return_type: Option<&ast::Type>,
|
||||
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
|
||||
) -> Result<LLVMValueRef, TranslateError> {
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
iter::once(return_type),
|
||||
return_type.into_iter(),
|
||||
arguments.iter().map(|(_, type_)| Ok(*type_)),
|
||||
)?;
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
||||
|
@ -1612,7 +1634,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let clamped = self.emit_intrinsic(
|
||||
c"llvm.umin",
|
||||
None,
|
||||
&from.into(),
|
||||
Some(&from.into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(max, from_llvm),
|
||||
|
@ -1642,7 +1664,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let zero_clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
|
||||
None,
|
||||
&from.into(),
|
||||
Some(&from.into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(zero, from_llvm),
|
||||
|
@ -1661,7 +1683,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let fully_clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
|
||||
None,
|
||||
&from.into(),
|
||||
Some(&from.into()),
|
||||
vec![(zero_clamped, from_llvm), (max, from_llvm)],
|
||||
)?;
|
||||
let resize_fn = if to.layout().size() >= from.layout().size() {
|
||||
|
@ -1701,7 +1723,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let rounded_float = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
None,
|
||||
&from.into(),
|
||||
Some(&from.into()),
|
||||
vec![(
|
||||
self.resolver.value(arguments.src)?,
|
||||
get_scalar_type(self.context, from),
|
||||
|
@ -1770,7 +1792,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
Some(&data.type_.into()),
|
||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||
)?;
|
||||
Ok(())
|
||||
|
@ -1791,7 +1813,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
Some(&data.type_.into()),
|
||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||
)?;
|
||||
Ok(())
|
||||
|
@ -1813,7 +1835,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
Some(&data.type_.into()),
|
||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||
)?;
|
||||
Ok(())
|
||||
|
@ -1935,7 +1957,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
Some(&data.type_.into()),
|
||||
vec![(
|
||||
self.resolver.value(arguments.src)?,
|
||||
get_scalar_type(self.context, data.type_),
|
||||
|
@ -1952,7 +1974,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
c"llvm.amdgcn.log.f32",
|
||||
Some(arguments.dst),
|
||||
&ast::ScalarType::F32.into(),
|
||||
Some(&ast::ScalarType::F32.into()),
|
||||
vec![(
|
||||
self.resolver.value(arguments.src)?,
|
||||
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
||||
|
@ -2007,7 +2029,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
Some(arguments.dst),
|
||||
&type_.into(),
|
||||
Some(&type_.into()),
|
||||
vec![(self.resolver.value(arguments.src)?, llvm_type)],
|
||||
)?;
|
||||
Ok(())
|
||||
|
@ -2031,7 +2053,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_().into(),
|
||||
Some(&data.type_().into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||
|
@ -2058,7 +2080,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_().into(),
|
||||
Some(&data.type_().into()),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||
|
@ -2076,7 +2098,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
Some(&data.type_.into()),
|
||||
vec![
|
||||
(
|
||||
self.resolver.value(arguments.src1)?,
|
||||
|
@ -2197,12 +2219,49 @@ impl<'a> MethodEmitContext<'a> {
|
|||
self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
|
||||
Some(arguments.dst),
|
||||
&data.type_.into(),
|
||||
Some(&data.type_.into()),
|
||||
intrinsic_arguments,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> {
|
||||
let intrinsic = c"llvm.amdgcn.s.setreg";
|
||||
let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
|
||||
let (hwreg, value) = match mode_reg {
|
||||
ModeRegister::DenormalF32(ftz) => {
|
||||
let (reg, offset, size) = (1, 4, 2u32);
|
||||
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||
(hwreg, if ftz { 0u32 } else { 3 })
|
||||
}
|
||||
ModeRegister::DenormalF16F64(ftz) => {
|
||||
let (reg, offset, size) = (1, 6, 2u32);
|
||||
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||
(hwreg, if ftz { 0 } else { 3 })
|
||||
}
|
||||
ModeRegister::DenormalBoth { f32, f16f64 } => {
|
||||
let (reg, offset, size) = (1, 4, 4u32);
|
||||
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||
let f32 = if f32 { 0 } else { 3 };
|
||||
let f16f64 = if f16f64 { 0 } else { 3 };
|
||||
let value = f32 | f16f64 << 2;
|
||||
(hwreg, value)
|
||||
}
|
||||
ModeRegister::RoundingF32(rounding_mode) => todo!(),
|
||||
ModeRegister::RoundingF16F64(rounding_mode) => todo!(),
|
||||
ModeRegister::RoundingBoth { f32, f16f64 } => todo!(),
|
||||
};
|
||||
let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) };
|
||||
let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) };
|
||||
self.emit_intrinsic(
|
||||
intrinsic,
|
||||
None,
|
||||
None,
|
||||
vec![(hwreg_llvm, llvm_i32), (value_llvm, llvm_i32)],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
|
|
|
@ -41,14 +41,18 @@ fn run_method<'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
body,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
roundind_mode_f32: method.roundind_mode_f32,
|
||||
roundind_mode_f16f64: method.roundind_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,10 @@ pub(super) fn run<'a, 'input>(
|
|||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
},
|
||||
|
@ -60,16 +64,7 @@ fn run_method<'a, 'input>(
|
|||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
|
|
|
@ -64,16 +64,7 @@ fn run_method<'a, 'input>(
|
|||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'a, 'input>(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::BrachCondition;
|
||||
use super::Directive2;
|
||||
use super::Function2;
|
||||
use super::ModeRegister;
|
||||
use super::SpirvWord;
|
||||
use super::Statement;
|
||||
use super::TranslateError;
|
||||
|
@ -18,6 +19,7 @@ use rustc_hash::FxHashSet;
|
|||
use std::hash::Hash;
|
||||
use std::iter;
|
||||
use std::mem;
|
||||
use std::u32;
|
||||
use strum::EnumCount;
|
||||
use strum_macros::{EnumCount, VariantArray};
|
||||
|
||||
|
@ -36,6 +38,13 @@ impl DenormalMode {
|
|||
DenormalMode::Preserve
|
||||
}
|
||||
}
|
||||
|
||||
fn to_ftz(self) -> bool {
|
||||
match self {
|
||||
DenormalMode::FlushToZero => true,
|
||||
DenormalMode::Preserve => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<usize> for DenormalMode {
|
||||
|
@ -94,20 +103,19 @@ impl InstructionModes {
|
|||
_ => {}
|
||||
}
|
||||
}
|
||||
fn set_if_some<T: Copy>(source: &mut Option<T>, value: Option<T>) {
|
||||
match (source, value) {
|
||||
(Some(ref mut x), Some(y)) => *x = y,
|
||||
_ => {}
|
||||
fn set_if_any<T: Copy>(source: &mut Option<T>, value: Option<T>) {
|
||||
if let Some(x) = value {
|
||||
*source = Some(x);
|
||||
}
|
||||
}
|
||||
set_if_none(&mut entry.denormal_f32, self.denormal_f32);
|
||||
set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64);
|
||||
set_if_none(&mut entry.rounding_f32, self.rounding_f32);
|
||||
set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64);
|
||||
set_if_some(&mut exit.denormal_f32, self.denormal_f32);
|
||||
set_if_some(&mut exit.denormal_f16f64, self.denormal_f16f64);
|
||||
set_if_some(&mut exit.rounding_f32, self.rounding_f32);
|
||||
set_if_some(&mut exit.rounding_f16f64, self.rounding_f16f64);
|
||||
set_if_any(&mut exit.denormal_f32, self.denormal_f32);
|
||||
set_if_any(&mut exit.denormal_f16f64, self.denormal_f16f64);
|
||||
set_if_any(&mut exit.rounding_f32, self.rounding_f32);
|
||||
set_if_any(&mut exit.rounding_f16f64, self.rounding_f16f64);
|
||||
}
|
||||
|
||||
fn none() -> Self {
|
||||
|
@ -209,18 +217,12 @@ impl InstructionModes {
|
|||
flush_to_zero.map(DenormalMode::from_ftz),
|
||||
Some(RoundingMode::from_ast(rounding)),
|
||||
),
|
||||
ast::CvtMode::SignedFromFP {
|
||||
flush_to_zero,
|
||||
rounding,
|
||||
// float to int contains rounding field, but it's not a rounding
|
||||
// mode but rather round-to-int operation that will be applied
|
||||
ast::CvtMode::SignedFromFP { flush_to_zero, .. }
|
||||
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => {
|
||||
Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None)
|
||||
}
|
||||
| ast::CvtMode::UnsignedFromFP {
|
||||
flush_to_zero,
|
||||
rounding,
|
||||
} => Self::new(
|
||||
cvt.from,
|
||||
flush_to_zero.map(DenormalMode::from_ftz),
|
||||
Some(RoundingMode::from_ast(rounding)),
|
||||
),
|
||||
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
|
||||
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
|
||||
}
|
||||
|
@ -263,22 +265,15 @@ impl ControlFlowGraph {
|
|||
}
|
||||
|
||||
fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) {
|
||||
self.graph[node].denormal_f32 = Mode {
|
||||
entry: entry.denormal_f32.map(ExtendedMode::BasicBlock),
|
||||
exit: exit.denormal_f32.map(ExtendedMode::BasicBlock),
|
||||
};
|
||||
self.graph[node].denormal_f16f64 = Mode {
|
||||
entry: entry.denormal_f16f64.map(ExtendedMode::BasicBlock),
|
||||
exit: exit.denormal_f16f64.map(ExtendedMode::BasicBlock),
|
||||
};
|
||||
self.graph[node].rounding_f32 = Mode {
|
||||
entry: entry.rounding_f32.map(ExtendedMode::BasicBlock),
|
||||
exit: exit.rounding_f32.map(ExtendedMode::BasicBlock),
|
||||
};
|
||||
self.graph[node].rounding_f16f64 = Mode {
|
||||
entry: entry.rounding_f16f64.map(ExtendedMode::BasicBlock),
|
||||
exit: exit.rounding_f16f64.map(ExtendedMode::BasicBlock),
|
||||
};
|
||||
let node = &mut self.graph[node];
|
||||
node.denormal_f32.entry = entry.denormal_f32.map(ExtendedMode::BasicBlock);
|
||||
node.denormal_f16f64.entry = entry.denormal_f16f64.map(ExtendedMode::BasicBlock);
|
||||
node.rounding_f32.entry = entry.rounding_f32.map(ExtendedMode::BasicBlock);
|
||||
node.rounding_f16f64.entry = entry.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
||||
node.denormal_f32.exit = exit.denormal_f32.map(ExtendedMode::BasicBlock);
|
||||
node.denormal_f16f64.exit = exit.denormal_f16f64.map(ExtendedMode::BasicBlock);
|
||||
node.rounding_f32.exit = exit.rounding_f32.map(ExtendedMode::BasicBlock);
|
||||
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -343,7 +338,7 @@ trait EnumTuple {
|
|||
|
||||
pub(crate) fn run<'input>(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||
mut directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut cfg = ControlFlowGraph::new();
|
||||
for directive in directives.iter() {
|
||||
|
@ -351,42 +346,39 @@ pub(crate) fn run<'input>(
|
|||
super::Directive2::Method(Function2 {
|
||||
name,
|
||||
body: Some(body),
|
||||
is_kernel,
|
||||
..
|
||||
}) => {
|
||||
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
|
||||
let mut entry = InstructionModes::none();
|
||||
let mut exit = InstructionModes::none();
|
||||
// TODO: implement for non-kernels
|
||||
if !*is_kernel {
|
||||
todo!()
|
||||
}
|
||||
let entry_index = cfg.add_entry_basic_block(*name);
|
||||
let mut bb_state = BasicBlockState::new(&mut cfg);
|
||||
let mut body_iter = body.iter();
|
||||
match body_iter.next() {
|
||||
Some(Statement::Label(label)) => {
|
||||
bb_state.cfg.add_jump(entry_index, *label);
|
||||
bb_state.start(*label);
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
for statement in body.iter() {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||
cfg.add_jump(bb_index, arguments.src);
|
||||
cfg.set_modes(
|
||||
bb_index,
|
||||
mem::replace(&mut entry, InstructionModes::none()),
|
||||
mem::replace(&mut exit, InstructionModes::none()),
|
||||
);
|
||||
basic_block = None;
|
||||
bb_state.end(&[arguments.src]);
|
||||
}
|
||||
Statement::Label(label) => {
|
||||
basic_block = Some(cfg.get_or_add_basic_block(*label));
|
||||
bb_state.start(*label);
|
||||
}
|
||||
Statement::Conditional(BrachCondition {
|
||||
if_true, if_false, ..
|
||||
}) => {
|
||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||
cfg.add_jump(bb_index, *if_true);
|
||||
cfg.add_jump(bb_index, *if_false);
|
||||
cfg.set_modes(
|
||||
bb_index,
|
||||
mem::replace(&mut entry, InstructionModes::none()),
|
||||
mem::replace(&mut exit, InstructionModes::none()),
|
||||
);
|
||||
basic_block = None;
|
||||
bb_state.end(&[*if_true, *if_false]);
|
||||
}
|
||||
Statement::Instruction(instruction) => {
|
||||
let modes = get_modes(instruction);
|
||||
modes.fold_into(&mut entry, &mut exit);
|
||||
bb_state.append(modes);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -395,7 +387,370 @@ pub(crate) fn run<'input>(
|
|||
_ => {}
|
||||
}
|
||||
}
|
||||
todo!()
|
||||
let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32);
|
||||
let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64);
|
||||
let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32);
|
||||
let rounding_f16f64 = compute_single_mode(&cfg, |node| node.rounding_f16f64);
|
||||
let denormal_f32 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f32);
|
||||
let denormal_f16f64 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f16f64);
|
||||
let rounding_f32 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f32);
|
||||
let rounding_f16f64 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f16f64);
|
||||
insert_mode_control(
|
||||
flat_resolver,
|
||||
&mut directives,
|
||||
&cfg,
|
||||
denormal_f32,
|
||||
denormal_f16f64,
|
||||
rounding_f32,
|
||||
rounding_f16f64,
|
||||
)?;
|
||||
Ok(directives)
|
||||
}
|
||||
|
||||
fn insert_mode_control<'input>(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||
directives: &mut [Directive2<ast::Instruction<SpirvWord>, SpirvWord>],
|
||||
cfg: &ControlFlowGraph,
|
||||
denormal_f32: ModeInsertions<DenormalMode>,
|
||||
denormal_f16f64: ModeInsertions<DenormalMode>,
|
||||
rounding_f32: ModeInsertions<RoundingMode>,
|
||||
rounding_f16f64: ModeInsertions<RoundingMode>,
|
||||
) -> Result<(), TranslateError> {
|
||||
for directive in directives.iter_mut() {
|
||||
let body_ptr = match directive {
|
||||
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => continue,
|
||||
Directive2::Method(Function2 {
|
||||
name,
|
||||
body: Some(body),
|
||||
flush_to_zero_f32,
|
||||
flush_to_zero_f16f64,
|
||||
roundind_mode_f32: rounding_mode_f32,
|
||||
roundind_mode_f16f64: rounding_mode_f16f64,
|
||||
..
|
||||
}) => {
|
||||
*flush_to_zero_f32 = denormal_f32
|
||||
.kernels
|
||||
.get(name)
|
||||
.copied()
|
||||
.unwrap_or(DenormalMode::default())
|
||||
.to_ftz();
|
||||
*flush_to_zero_f16f64 = denormal_f16f64
|
||||
.kernels
|
||||
.get(name)
|
||||
.copied()
|
||||
.unwrap_or(DenormalMode::default())
|
||||
.to_ftz();
|
||||
*rounding_mode_f32 = rounding_f32
|
||||
.kernels
|
||||
.get(name)
|
||||
.copied()
|
||||
.unwrap_or(RoundingMode::default())
|
||||
.to_ast();
|
||||
*rounding_mode_f16f64 = rounding_f16f64
|
||||
.kernels
|
||||
.get(name)
|
||||
.copied()
|
||||
.unwrap_or(RoundingMode::default())
|
||||
.to_ast();
|
||||
body
|
||||
}
|
||||
};
|
||||
let mut old_body = mem::replace(body_ptr, Vec::new());
|
||||
let mut result = Vec::with_capacity(old_body.len());
|
||||
let mut bb_state = BasicBlockControlState::new(
|
||||
&denormal_f32,
|
||||
&denormal_f16f64,
|
||||
&rounding_f32,
|
||||
&rounding_f16f64,
|
||||
);
|
||||
for statement in old_body.into_iter() {
|
||||
match &statement {
|
||||
Statement::Label(label) => {
|
||||
bb_state.start(*label);
|
||||
}
|
||||
Statement::Instruction(instruction) => {
|
||||
let modes = get_modes(&instruction);
|
||||
bb_state.insert(&mut result, modes)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
result.push(statement);
|
||||
}
|
||||
*body_ptr = result;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct BasicBlockControlState<'a> {
|
||||
global_denormal_f32: &'a ModeInsertions<DenormalMode>,
|
||||
global_denormal_f16f64: &'a ModeInsertions<DenormalMode>,
|
||||
global_rounding_f32: &'a ModeInsertions<RoundingMode>,
|
||||
global_rounding_f16f64: &'a ModeInsertions<RoundingMode>,
|
||||
basic_block: SpirvWord,
|
||||
denormal_f32: RegisterState<bool>,
|
||||
denormal_f16f64: RegisterState<bool>,
|
||||
foldable_rounding_f32: Option<usize>,
|
||||
foldable_rounding_f16f64: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum RegisterState<T> {
|
||||
Inherited,
|
||||
Unknown,
|
||||
Value(Option<usize>, T),
|
||||
}
|
||||
|
||||
impl<T> RegisterState<T> {
|
||||
fn empty() -> Self {
|
||||
Self::Unknown
|
||||
}
|
||||
|
||||
fn new(must_insert: bool) -> Self {
|
||||
if must_insert {
|
||||
Self::Unknown
|
||||
} else {
|
||||
Self::Inherited
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> BasicBlockControlState<'a> {
|
||||
fn new(
|
||||
global_denormal_f32: &'a ModeInsertions<DenormalMode>,
|
||||
global_denormal_f16f64: &'a ModeInsertions<DenormalMode>,
|
||||
global_rounding_f32: &'a ModeInsertions<RoundingMode>,
|
||||
global_rounding_f16f64: &'a ModeInsertions<RoundingMode>,
|
||||
) -> Self {
|
||||
BasicBlockControlState {
|
||||
global_denormal_f32,
|
||||
global_denormal_f16f64,
|
||||
global_rounding_f32,
|
||||
global_rounding_f16f64,
|
||||
basic_block: SpirvWord(u32::MAX),
|
||||
denormal_f32: RegisterState::empty(),
|
||||
denormal_f16f64: RegisterState::empty(),
|
||||
foldable_rounding_f32: None,
|
||||
foldable_rounding_f16f64: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn start(&mut self, label: SpirvWord) {
|
||||
self.denormal_f32 =
|
||||
RegisterState::new(self.global_denormal_f32.basic_blocks.contains(&label));
|
||||
self.denormal_f32 =
|
||||
RegisterState::new(self.global_denormal_f16f64.basic_blocks.contains(&label));
|
||||
self.foldable_rounding_f32 = None;
|
||||
self.foldable_rounding_f16f64 = None;
|
||||
self.basic_block = label;
|
||||
}
|
||||
|
||||
fn add_or_fold_mode_set(
|
||||
&mut self,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
new_mode: bool,
|
||||
) -> Option<usize> {
|
||||
// try and fold into the other mode set
|
||||
if let RegisterState::Value(Some(other_index), other_value) = self.denormal_f16f64 {
|
||||
if let Some(Statement::SetMode(ModeRegister::DenormalF16F64(_))) =
|
||||
result.get_mut(other_index)
|
||||
{
|
||||
result[other_index] = Statement::SetMode(ModeRegister::DenormalBoth {
|
||||
f32: new_mode,
|
||||
f16f64: other_value,
|
||||
});
|
||||
return None;
|
||||
}
|
||||
}
|
||||
result.push(Statement::SetMode(ModeRegister::DenormalF32(new_mode)));
|
||||
Some(result.len() - 1)
|
||||
}
|
||||
|
||||
fn insert(
|
||||
&mut self,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
modes: InstructionModes,
|
||||
) -> Result<(), TranslateError> {
|
||||
self.insert_one::<DenormalF32View>(result, modes.denormal_f32.map(DenormalMode::to_ftz))?;
|
||||
self.insert_one::<DenormalF16F64View>(
|
||||
result,
|
||||
modes.denormal_f16f64.map(DenormalMode::to_ftz),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_one<View: ModeView>(
|
||||
&mut self,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
mode: Option<View::Value>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if let Some(new_mode) = mode {
|
||||
let register_state = View::get_register(self);
|
||||
match register_state {
|
||||
RegisterState::Inherited => {
|
||||
View::set_register(self, RegisterState::Value(None, new_mode));
|
||||
}
|
||||
RegisterState::Unknown => {
|
||||
View::set_register(
|
||||
self,
|
||||
RegisterState::Value(
|
||||
Some(self.add_or_fold_mode_set2::<View>(result, new_mode)),
|
||||
new_mode,
|
||||
),
|
||||
);
|
||||
}
|
||||
RegisterState::Value(_, old_value) => {
|
||||
if new_mode == old_value {
|
||||
return Ok(());
|
||||
}
|
||||
View::set_register(
|
||||
self,
|
||||
RegisterState::Value(
|
||||
Some(self.add_or_fold_mode_set2::<View>(result, new_mode)),
|
||||
new_mode,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Return the index of the last insertion of SetMode with this mode
|
||||
fn add_or_fold_mode_set2<View: ModeView>(
|
||||
&self,
|
||||
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
new_mode: View::Value,
|
||||
) -> usize {
|
||||
// try and fold into the other mode set in struction
|
||||
if let RegisterState::Value(Some(twin_index), _) = View::TwinView::get_register(self) {
|
||||
if let Some(Statement::SetMode(register_mode)) = result.get_mut(twin_index) {
|
||||
if let Some(twin_mode) = View::TwinView::get_single_mode(register_mode) {
|
||||
*register_mode = View::new_mode(new_mode, Some(twin_mode));
|
||||
return twin_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push(Statement::SetMode(View::new_mode(new_mode, None)));
|
||||
result.len() - 1
|
||||
}
|
||||
}
|
||||
|
||||
trait ModeView {
|
||||
type Value: PartialEq + Eq + Copy + Clone;
|
||||
type TwinView: ModeView<Value = Self::Value>;
|
||||
|
||||
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value>;
|
||||
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>);
|
||||
fn new_mode(t: Self::Value, other: Option<Self::Value>) -> ModeRegister;
|
||||
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value>;
|
||||
}
|
||||
|
||||
struct DenormalF32View;
|
||||
|
||||
impl ModeView for DenormalF32View {
|
||||
type Value = bool;
|
||||
type TwinView = DenormalF16F64View;
|
||||
|
||||
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value> {
|
||||
bb.denormal_f32
|
||||
}
|
||||
|
||||
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>) {
|
||||
bb.denormal_f32 = reg;
|
||||
}
|
||||
|
||||
fn new_mode(f32: Self::Value, f16f64: Option<Self::Value>) -> ModeRegister {
|
||||
match f16f64 {
|
||||
Some(f16f64) => ModeRegister::DenormalBoth { f32, f16f64 },
|
||||
None => ModeRegister::DenormalF32(f32),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value> {
|
||||
match reg {
|
||||
ModeRegister::DenormalF32(value) => Some(*value),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct DenormalF16F64View;
|
||||
|
||||
impl ModeView for DenormalF16F64View {
|
||||
type Value = bool;
|
||||
type TwinView = DenormalF32View;
|
||||
|
||||
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value> {
|
||||
bb.denormal_f16f64
|
||||
}
|
||||
|
||||
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>) {
|
||||
bb.denormal_f16f64 = reg;
|
||||
}
|
||||
|
||||
fn new_mode(f16f64: Self::Value, f32: Option<Self::Value>) -> ModeRegister {
|
||||
match f32 {
|
||||
Some(f32) => ModeRegister::DenormalBoth { f16f64, f32 },
|
||||
None => ModeRegister::DenormalF16F64(f16f64),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value> {
|
||||
match reg {
|
||||
ModeRegister::DenormalF16F64(value) => Some(*value),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BasicBlockState<'a> {
|
||||
cfg: &'a mut ControlFlowGraph,
|
||||
node_index: Option<NodeIndex>,
|
||||
// If it's a kernel basic block then we don't track entry instruction mode
|
||||
entry: InstructionModes,
|
||||
exit: InstructionModes,
|
||||
}
|
||||
|
||||
impl<'a> BasicBlockState<'a> {
|
||||
fn new(cfg: &'a mut ControlFlowGraph) -> BasicBlockState<'a> {
|
||||
Self {
|
||||
cfg,
|
||||
node_index: None,
|
||||
entry: InstructionModes::none(),
|
||||
exit: InstructionModes::none(),
|
||||
}
|
||||
}
|
||||
|
||||
fn start(&mut self, label: SpirvWord) {
|
||||
self.end(&[]);
|
||||
self.node_index = Some(self.cfg.get_or_add_basic_block(label));
|
||||
}
|
||||
|
||||
fn end(&mut self, jumps: &[SpirvWord]) {
|
||||
let node_index = self.node_index.take();
|
||||
let node_index = match node_index {
|
||||
Some(x) => x,
|
||||
None => return,
|
||||
};
|
||||
for target in jumps {
|
||||
self.cfg.add_jump(node_index, *target);
|
||||
}
|
||||
self.cfg.set_modes(
|
||||
node_index,
|
||||
mem::replace(&mut self.entry, InstructionModes::none()),
|
||||
mem::replace(&mut self.exit, InstructionModes::none()),
|
||||
);
|
||||
}
|
||||
|
||||
fn append(&mut self, modes: InstructionModes) {
|
||||
modes.fold_into(&mut self.entry, &mut self.exit);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for BasicBlockState<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.end(&[]);
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_single_mode<T: Copy + Eq>(
|
||||
|
@ -424,10 +779,9 @@ fn compute_single_mode<T: Copy + Eq>(
|
|||
UniqueVec::new(graph.graph.neighbors_directed(index, Direction::Incoming));
|
||||
let mut visited = FxHashSet::default();
|
||||
while let Some(current) = to_visit.pop() {
|
||||
if visited.contains(¤t) {
|
||||
if !visited.insert(current) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(current);
|
||||
let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit;
|
||||
match exit_mode {
|
||||
None => {
|
||||
|
@ -462,6 +816,7 @@ fn compute_single_mode<T: Copy + Eq>(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PartialModeInsertion<T> {
|
||||
bb_must_insert_mode: FxHashSet<SpirvWord>,
|
||||
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
|
||||
|
@ -498,10 +853,11 @@ fn optimize<T: Copy + Into<usize> + strum::VariantArray + std::fmt::Debug, const
|
|||
}
|
||||
}
|
||||
let mut kernels = FxHashMap::default();
|
||||
for (kernel, modes) in kernel_modes {
|
||||
'iterate_kernels: for (kernel, modes) in kernel_modes {
|
||||
for (mode, var) in modes.into_iter().enumerate() {
|
||||
if solution[var] > 0.5 {
|
||||
kernels.insert(kernel, T::VARIANTS[mode]);
|
||||
continue 'iterate_kernels;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ mod hoist_globals;
|
|||
mod insert_explicit_load_store;
|
||||
mod insert_ftz_control;
|
||||
mod insert_implicit_conversions2;
|
||||
mod normalize_basic_blocks;
|
||||
mod normalize_identifiers2;
|
||||
mod normalize_predicates2;
|
||||
mod replace_instructions_with_function_calls;
|
||||
|
@ -52,6 +53,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
|||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
|
||||
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
|
@ -197,6 +199,22 @@ enum Statement<I, P: ast::Operand> {
|
|||
FunctionPointer(FunctionPointerDetails),
|
||||
VectorRead(VectorRead),
|
||||
VectorWrite(VectorWrite),
|
||||
SetMode(ModeRegister),
|
||||
}
|
||||
|
||||
enum ModeRegister {
|
||||
DenormalF32(bool),
|
||||
DenormalF16F64(bool),
|
||||
DenormalBoth {
|
||||
f32: bool,
|
||||
f16f64: bool,
|
||||
},
|
||||
RoundingF32(ast::RoundingMode),
|
||||
RoundingF16F64(ast::RoundingMode),
|
||||
RoundingBoth {
|
||||
f32: ast::RoundingMode,
|
||||
f16f64: ast::RoundingMode,
|
||||
},
|
||||
}
|
||||
|
||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||
|
@ -469,6 +487,7 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
|||
let src = visitor.visit_ident(src, None, false, false)?;
|
||||
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
|
||||
}
|
||||
Statement::SetMode(mode_register) => Statement::SetMode(mode_register),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -573,6 +592,10 @@ struct Function2<Instruction, Operand: ast::Operand> {
|
|||
import_as: Option<String>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
linkage: ast::LinkingDirective,
|
||||
flush_to_zero_f32: bool,
|
||||
flush_to_zero_f16f64: bool,
|
||||
roundind_mode_f32: ast::RoundingMode,
|
||||
roundind_mode_f16f64: ast::RoundingMode,
|
||||
}
|
||||
|
||||
type NormalizedDirective2 = Directive2<
|
||||
|
|
52
ptx/src/pass/normalize_basic_blocks.rs
Normal file
52
ptx/src/pass/normalize_basic_blocks.rs
Normal file
|
@ -0,0 +1,52 @@
|
|||
use super::*;
|
||||
|
||||
// This pass normalized ptx modules in two ways that makes mode computation pass
|
||||
// and code emissions passes much simpler:
|
||||
// * Inserts label at the start of every function
|
||||
// This makes control flow graph simpler in mode computation block: we can
|
||||
// represent kernels as separate nodes with its own separate entry/exit mode
|
||||
// * Inserts label at the start of every basic block
|
||||
|
||||
pub(crate) fn run(
|
||||
flat_resolver: &mut GlobalStringIdentResolver2<'_>,
|
||||
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>> {
|
||||
for directive in directives.iter_mut() {
|
||||
let body_ref = match directive {
|
||||
Directive2::Method(Function2 {
|
||||
body: Some(body), ..
|
||||
}) => body,
|
||||
_ => continue,
|
||||
};
|
||||
let body = std::mem::replace(body_ref, Vec::new());
|
||||
let mut result = Vec::with_capacity(body.len());
|
||||
let mut needs_label = false;
|
||||
let mut body_iterator = body.into_iter();
|
||||
match body_iterator.next() {
|
||||
Some(Statement::Label(_)) => {}
|
||||
Some(statement) => {
|
||||
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
|
||||
result.push(statement);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
for statement in body_iterator {
|
||||
if needs_label && !matches!(statement, Statement::Label(..)) {
|
||||
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
|
||||
}
|
||||
needs_label = is_block_terminator(&statement);
|
||||
result.push(statement);
|
||||
}
|
||||
*body_ref = result;
|
||||
}
|
||||
directives
|
||||
}
|
||||
|
||||
fn is_block_terminator(instruction: &Statement<ast::Instruction<SpirvWord>, SpirvWord>) -> bool {
|
||||
match instruction {
|
||||
Statement::Conditional(..)
|
||||
| Statement::Instruction(ast::Instruction::Bra { .. })
|
||||
| Statement::Instruction(ast::Instruction::Ret { .. }) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
|
@ -52,9 +52,13 @@ fn run_method<'input, 'b>(
|
|||
input_arguments,
|
||||
body,
|
||||
import_as: None,
|
||||
tuning: method.tuning,
|
||||
linkage,
|
||||
is_kernel,
|
||||
tuning: method.tuning,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -36,14 +36,18 @@ fn run_method<'input>(
|
|||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
body,
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
roundind_mode_f32: method.roundind_mode_f32,
|
||||
roundind_mode_f16f64: method.roundind_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,10 @@ pub(super) fn run<'input>(
|
|||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::EXTERN,
|
||||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
|
|
@ -40,16 +40,7 @@ fn run_method<'input>(
|
|||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
return_arguments: method.return_arguments,
|
||||
name: method.name,
|
||||
input_arguments: method.input_arguments,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
is_kernel: method.is_kernel,
|
||||
})
|
||||
Ok(Function2 { body, ..method })
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
|
|
27
ptx/src/test/spirv_run/malformed_label.ptx
Normal file
27
ptx/src/test/spirv_run/malformed_label.ptx
Normal file
|
@ -0,0 +1,27 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry malformed_label(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
bra BB0;
|
||||
// this basic block does not start with a label
|
||||
ld.u64 temp, [out_addr];
|
||||
|
||||
BB0:
|
||||
ld.u64 temp, [in_addr];
|
||||
add.u64 temp2, temp, 1;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
|
@ -186,6 +186,8 @@ test_ptx!(
|
|||
[0x800000u32, 0xFFFFFF]
|
||||
);
|
||||
|
||||
test_ptx!(malformed_label, [2u64], [3u64]);
|
||||
|
||||
test_ptx!(assertfail);
|
||||
test_ptx!(func_ptr);
|
||||
test_ptx!(lanemask_lt);
|
||||
|
|
Loading…
Add table
Reference in a new issue