mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-02 14:19:27 +00:00
Merge commit '4306646739
' into parser_recovery
This commit is contained in:
commit
6f14025e9b
26 changed files with 593 additions and 163 deletions
BIN
elf.o
Normal file
BIN
elf.o
Normal file
Binary file not shown.
|
@ -230,15 +230,33 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
|
|||
|
||||
fn vec_pack(
|
||||
&mut self,
|
||||
vector_elements: Vec<SpirvWord>,
|
||||
vector_elements: Vec<ast::RegOrImmediate<SpirvWord>>,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (width, scalar_t, state_space) = match type_space {
|
||||
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
|
||||
Some((ast::Type::Scalar(scalar_t), space))
|
||||
if scalar_t.kind() == ast::ScalarKind::Bit =>
|
||||
{
|
||||
let type_ =
|
||||
ast::ScalarType::from_size(scalar_t.size_of() / (vector_elements.len() as u8))
|
||||
.ok_or_else(|| error_mismatched_type())?;
|
||||
(vector_elements.len() as u8, type_, space)
|
||||
}
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let vector_elements = vector_elements
|
||||
.into_iter()
|
||||
.map(|element| match element {
|
||||
ast::RegOrImmediate::Reg(name) => self.reg(name),
|
||||
ast::RegOrImmediate::Imm(immediate_value) => self.immediate(
|
||||
immediate_value,
|
||||
Some((&ast::Type::Scalar(scalar_t), state_space)),
|
||||
),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let temporary_vector = self
|
||||
.resolver
|
||||
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
|
||||
|
|
|
@ -198,10 +198,15 @@ pub fn map_operand<T: Copy, Err>(
|
|||
Some(ident) => ast::ParsedOperand::Reg(ident),
|
||||
None => ast::ParsedOperand::VecMember(ident, member),
|
||||
},
|
||||
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
|
||||
idents
|
||||
ast::ParsedOperand::VecPack(elements) => ast::ParsedOperand::VecPack(
|
||||
elements
|
||||
.into_iter()
|
||||
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
|
||||
.map(|element| match element {
|
||||
ast::RegOrImmediate::Reg(ident) => {
|
||||
Ok(ast::RegOrImmediate::Reg(fn_(ident, None)?.unwrap_or(ident)))
|
||||
}
|
||||
ast::RegOrImmediate::Imm(imm) => Ok(ast::RegOrImmediate::Imm(imm)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
|
|
|
@ -208,7 +208,7 @@ fn default_implicit_conversion_type(
|
|||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
Err(error_mismatched_type())
|
||||
}
|
||||
} else {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
|
@ -264,14 +264,14 @@ pub(crate) fn should_convert_relaxed_dst_wrapper(
|
|||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if operand_space != instruction_space {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
None => Err(error_mismatched_type()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,8 +24,6 @@
|
|||
// shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel",
|
||||
// but it will too fail similarly, but with "unable to legalize instruction"
|
||||
|
||||
use std::array::TryFromSliceError;
|
||||
use std::convert::TryInto;
|
||||
use std::ffi::{CStr, NulError};
|
||||
use std::{i8, ptr, u64};
|
||||
|
||||
|
@ -249,79 +247,51 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
unsafe { LLVMSetAlignment(global, align) };
|
||||
}
|
||||
if !var.array_init.is_empty() {
|
||||
self.emit_array_init(&var.v_type, &*var.array_init, global)?;
|
||||
let initializer = self.get_array_init(&var.v_type, &*var.array_init)?;
|
||||
unsafe { LLVMSetInitializer(global, initializer) };
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: instead of Vec<u8> we should emit a typed initializer
|
||||
fn emit_array_init(
|
||||
&mut self,
|
||||
fn get_array_init(
|
||||
&self,
|
||||
type_: &ast::Type,
|
||||
array_init: &[u8],
|
||||
global: *mut llvm_zluda::LLVMValue,
|
||||
) -> Result<(), TranslateError> {
|
||||
match type_ {
|
||||
array_init: &[ast::RegOrImmediate<SpirvWord>],
|
||||
) -> Result<*mut LLVMValue, TranslateError> {
|
||||
let initializer = match type_ {
|
||||
ast::Type::Array(None, scalar, dimensions) => {
|
||||
if dimensions.len() != 1 {
|
||||
todo!()
|
||||
}
|
||||
if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() {
|
||||
if dimensions[0] as usize != array_init.len() {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
let type_ = get_scalar_type(self.context, *scalar);
|
||||
let mut elements = array_init
|
||||
.chunks(scalar.size_of() as usize)
|
||||
.map(|chunk| self.constant_from_bytes(*scalar, chunk, type_))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|_| error_unreachable())?;
|
||||
let initializer =
|
||||
unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) };
|
||||
unsafe { LLVMSetInitializer(global, initializer) };
|
||||
.iter()
|
||||
.map(|elem| match elem {
|
||||
ast::RegOrImmediate::Reg(reg) => {
|
||||
Ok(unsafe { LLVMConstPtrToInt(self.resolver.value(*reg)?, type_) })
|
||||
}
|
||||
ast::RegOrImmediate::Imm(imm) => {
|
||||
Ok(get_immediate_value(self.context, scalar, imm))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn constant_from_bytes(
|
||||
&self,
|
||||
scalar: ast::ScalarType,
|
||||
bytes: &[u8],
|
||||
llvm_type: LLVMTypeRef,
|
||||
) -> Result<LLVMValueRef, TryFromSliceError> {
|
||||
Ok(match scalar {
|
||||
ptx_parser::ScalarType::Pred
|
||||
| ptx_parser::ScalarType::S8
|
||||
| ptx_parser::ScalarType::B8
|
||||
| ptx_parser::ScalarType::U8 => unsafe {
|
||||
LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::S16
|
||||
| ptx_parser::ScalarType::B16
|
||||
| ptx_parser::ScalarType::U16
|
||||
| ptx_parser::ScalarType::E4m3x2
|
||||
| ptx_parser::ScalarType::E5m2x2 => unsafe {
|
||||
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::S32
|
||||
| ptx_parser::ScalarType::B32
|
||||
| ptx_parser::ScalarType::U32 => unsafe {
|
||||
LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::F16 => todo!(),
|
||||
ptx_parser::ScalarType::BF16 => todo!(),
|
||||
ptx_parser::ScalarType::U64 => todo!(),
|
||||
ptx_parser::ScalarType::S64 => todo!(),
|
||||
ptx_parser::ScalarType::S16x2 => todo!(),
|
||||
ptx_parser::ScalarType::F32 => todo!(),
|
||||
ptx_parser::ScalarType::B64 => todo!(),
|
||||
ptx_parser::ScalarType::F64 => todo!(),
|
||||
ptx_parser::ScalarType::B128 => todo!(),
|
||||
ptx_parser::ScalarType::U16x2 => todo!(),
|
||||
ptx_parser::ScalarType::F16x2 => todo!(),
|
||||
ptx_parser::ScalarType::BF16x2 => todo!(),
|
||||
})
|
||||
ast::Type::Scalar(scalar) => {
|
||||
let initializer = match array_init {
|
||||
[ast::RegOrImmediate::Imm(init)] => init,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
get_immediate_value(self.context, scalar, initializer)
|
||||
}
|
||||
_ => {
|
||||
todo!()
|
||||
}
|
||||
};
|
||||
Ok(initializer)
|
||||
}
|
||||
|
||||
fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) {
|
||||
|
@ -338,6 +308,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_immediate_value(
|
||||
context: LLVMContextRef,
|
||||
scalar_type: &ast::ScalarType,
|
||||
imm: &ast::ImmediateValue,
|
||||
) -> *mut LLVMValue {
|
||||
let type_ = get_scalar_type(context, *scalar_type);
|
||||
match imm {
|
||||
ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, *x, 0) },
|
||||
ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, *x as u64, 0) },
|
||||
ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, *x as f64) },
|
||||
ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, *x) },
|
||||
}
|
||||
}
|
||||
|
||||
fn llvm_ftz(ftz: bool) -> &'static str {
|
||||
if ftz {
|
||||
"preserve-sign"
|
||||
|
@ -404,7 +388,6 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
|
||||
Statement::FpSaturate { dst, src, type_ } => self.emit_fp_saturate(type_, dst, src)?,
|
||||
// No-op
|
||||
Statement::FpModeRequired { .. } => {}
|
||||
})
|
||||
}
|
||||
|
@ -445,7 +428,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
unsafe { LLVMSetAlignment(alloca, align) };
|
||||
}
|
||||
if !var.array_init.is_empty() {
|
||||
todo!()
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -722,13 +705,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
}
|
||||
|
||||
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
|
||||
let type_ = get_scalar_type(self.context, constant.typ);
|
||||
let value = match constant.value {
|
||||
ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) },
|
||||
ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) },
|
||||
ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) },
|
||||
ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) },
|
||||
};
|
||||
let value = get_immediate_value(self.context, &constant.typ, &constant.value);
|
||||
self.resolver.register(constant.dst, value);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -3133,7 +3110,7 @@ impl std::fmt::Display for LLVMTypeDisplay {
|
|||
ast::ScalarType::F64 => write!(f, "f64"),
|
||||
ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
|
||||
ast::ScalarType::F16x2 => write!(f, "v2f16"),
|
||||
ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
|
||||
ptx_parser::ScalarType::BF16x2 => write!(f, "v2bf16"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,7 +172,7 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
|
||||
ast::ScalarType::BF16x2 => todo!(),
|
||||
ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) },
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -224,7 +224,7 @@ fn error_unknown_symbol<T: Into<String>>(symbol: T) -> TranslateError {
|
|||
|
||||
#[cfg(debug_assertions)]
|
||||
fn error_mismatched_type() -> TranslateError {
|
||||
panic!()
|
||||
panic!("Mismatched type")
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
|
@ -613,6 +613,12 @@ struct ConstantDefinition {
|
|||
pub value: ast::ImmediateValue,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ConstantDefinition {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "zluda.constant{} {}", self.typ, self.value)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PtrAccess<T> {
|
||||
underlying_type: ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
|
@ -629,6 +635,22 @@ struct RepackVectorDetails {
|
|||
relaxed_type_check: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RepackVectorDetails {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let extract = if self.is_extract {
|
||||
".extract"
|
||||
} else {
|
||||
".composite"
|
||||
};
|
||||
let relaxed = if self.relaxed_type_check {
|
||||
".relaxed"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
write!(f, "zluda.repack_vector{}{}{}", extract, relaxed, self.typ)
|
||||
}
|
||||
}
|
||||
|
||||
struct FunctionPointerDetails {
|
||||
dst: SpirvWord,
|
||||
src: SpirvWord,
|
||||
|
|
|
@ -92,7 +92,7 @@ fn run_variable<'input, 'b>(
|
|||
align: variable.align,
|
||||
v_type: variable.v_type,
|
||||
state_space: variable.state_space,
|
||||
array_init: variable.array_init,
|
||||
array_init: run_array_init(resolver, &variable.array_init)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -171,7 +171,7 @@ fn run_multivariable<'input, 'b>(
|
|||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
array_init: run_array_init(resolver, &variable.var.array_init)?,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
@ -186,9 +186,22 @@ fn run_multivariable<'input, 'b>(
|
|||
v_type: variable.var.v_type.clone(),
|
||||
state_space: variable.var.state_space,
|
||||
name: ident,
|
||||
array_init: variable.var.array_init.clone(),
|
||||
array_init: run_array_init(resolver, &variable.var.array_init)?,
|
||||
}));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_array_init<'input, 'b>(
|
||||
resolver: &mut ScopedResolver<'input, 'b>,
|
||||
array_init: &[ast::RegOrImmediate<&'input str>],
|
||||
) -> Result<Vec<ast::RegOrImmediate<SpirvWord>>, TranslateError> {
|
||||
Ok(array_init
|
||||
.iter()
|
||||
.map(|elem| match elem {
|
||||
ast::RegOrImmediate::Reg(name) => Ok(ast::RegOrImmediate::Reg(resolver.get(name)?)),
|
||||
ast::RegOrImmediate::Imm(imm) => Ok(ast::RegOrImmediate::Imm(*imm)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?)
|
||||
}
|
||||
|
|
22
ptx/src/pass/test/expand_operands/mod.rs
Normal file
22
ptx/src/pass/test/expand_operands/mod.rs
Normal file
|
@ -0,0 +1,22 @@
|
|||
use crate::pass::{test::directive2_vec_to_string, *};
|
||||
|
||||
use super::test_pass;
|
||||
|
||||
macro_rules! test_expand_operands {
|
||||
($test_name:ident) => {
|
||||
test_pass!(run_expand_operands, $test_name);
|
||||
};
|
||||
}
|
||||
|
||||
fn run_expand_operands(ptx: ptx_parser::Module) -> String {
|
||||
// We run the minimal number of passes required to produce the input expected by expand_operands
|
||||
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
|
||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap();
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
|
||||
directive2_vec_to_string(&flat_resolver, directives)
|
||||
}
|
||||
|
||||
test_expand_operands!(vector_operand);
|
||||
test_expand_operands!(vector_operand_convert);
|
23
ptx/src/pass/test/expand_operands/vector_operand.ptx
Normal file
23
ptx/src/pass/test/expand_operands/vector_operand.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 6.5
|
||||
.target sm_60
|
||||
.address_size 64
|
||||
|
||||
.func (.reg .v2.b16 output) default (
|
||||
.reg .b16 input
|
||||
)
|
||||
{
|
||||
mov.v2.b16 output, {0x5678, input};
|
||||
ret;
|
||||
}
|
||||
|
||||
// %%% output %%%
|
||||
|
||||
.func (.reg .v2 .b16 %2) %1 (
|
||||
.reg .b16 %3
|
||||
)
|
||||
{
|
||||
.b16.reg %4 = zluda.constant.b16 22136;
|
||||
.v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3;
|
||||
mov.v2.b16 %2, %5;
|
||||
ret;
|
||||
}
|
23
ptx/src/pass/test/expand_operands/vector_operand_convert.ptx
Normal file
23
ptx/src/pass/test/expand_operands/vector_operand_convert.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 6.5
|
||||
.target sm_60
|
||||
.address_size 64
|
||||
|
||||
.func (.reg .b32 output) default (
|
||||
.reg .b16 input
|
||||
)
|
||||
{
|
||||
mov.b32 output, {0x5678, input};
|
||||
ret;
|
||||
}
|
||||
|
||||
// %%% output %%%
|
||||
|
||||
.func (.reg .b32 %2) %1 (
|
||||
.reg .b16 %3
|
||||
)
|
||||
{
|
||||
.b16.reg %4 = zluda.constant.b16 22136;
|
||||
.v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3;
|
||||
mov.b32 %2, %5;
|
||||
ret;
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.func (.reg .b32 output) default (
|
||||
.reg .v2.b16 input
|
||||
)
|
||||
{
|
||||
mov.b32 output, input;
|
||||
ret;
|
||||
}
|
||||
|
||||
// %%% output %%%
|
||||
|
||||
.func (.reg .b32 %2) %1 (
|
||||
.reg .v2 .b16 %3
|
||||
)
|
||||
{
|
||||
.b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.v2.b16 %3;
|
||||
mov.b32 %2, %4;
|
||||
ret;
|
||||
}
|
|
@ -21,3 +21,4 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String {
|
|||
|
||||
test_insert_implicit_conversions!(default);
|
||||
test_insert_implicit_conversions!(default_reg_b32_reg_f16x2);
|
||||
test_insert_implicit_conversions!(default_reg_b32_reg_v2_b16);
|
||||
|
|
|
@ -6,6 +6,7 @@ use std::{
|
|||
path::Path,
|
||||
};
|
||||
|
||||
mod expand_operands;
|
||||
mod insert_implicit_conversions;
|
||||
|
||||
#[macro_export]
|
||||
|
@ -202,6 +203,8 @@ fn statement_to_string(
|
|||
Statement::Variable(var) => format!("{}", var),
|
||||
Statement::Instruction(instr) => format!("{}", instr),
|
||||
Statement::Conversion(conv) => format!("{}", conv),
|
||||
Statement::Constant(constant) => format!("{}", constant),
|
||||
Statement::RepackVector(repack) => format!("{}", repack),
|
||||
_ => todo!(),
|
||||
};
|
||||
let mut args_formatter = StatementFormatter::new(resolver);
|
||||
|
|
55
ptx/src/test/ll/const_ident.ll
Normal file
55
ptx/src/test/ll/const_ident.ll
Normal file
|
@ -0,0 +1,55 @@
|
|||
@x = addrspace(4) global i64 1
|
||||
@y = addrspace(4) global [4 x i64] [i64 4, i64 5, i64 6, i64 0]
|
||||
@constparams = addrspace(4) global [4 x i64] [i64 ptrtoint (ptr addrspace(4) @x to i64), i64 ptrtoint (ptr addrspace(4) @y to i64)]
|
||||
|
||||
define amdgpu_kernel void @const_ident(ptr addrspace(4) byref(i64) %"49", ptr addrspace(4) byref(i64) %"50") #0 {
|
||||
%"51" = alloca i64, align 8, addrspace(5)
|
||||
%"52" = alloca i64, align 8, addrspace(5)
|
||||
%"53" = alloca i64, align 8, addrspace(5)
|
||||
%"54" = alloca i64, align 8, addrspace(5)
|
||||
%"55" = alloca i64, align 8, addrspace(5)
|
||||
%"56" = alloca i64, align 8, addrspace(5)
|
||||
%"57" = alloca i64, align 8, addrspace(5)
|
||||
%"58" = alloca i64, align 8, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"48"
|
||||
|
||||
"48": ; preds = %1
|
||||
%"59" = load i64, ptr addrspace(4) %"49", align 8
|
||||
store i64 %"59", ptr addrspace(5) %"51", align 8
|
||||
%"60" = load i64, ptr addrspace(4) %"50", align 8
|
||||
store i64 %"60", ptr addrspace(5) %"52", align 8
|
||||
store i64 ptrtoint (ptr addrspace(4) @x to i64), ptr addrspace(5) %"53", align 8
|
||||
store i64 ptrtoint (ptr addrspace(4) @y to i64), ptr addrspace(5) %"54", align 8
|
||||
%"63" = load i64, ptr addrspace(4) @constparams, align 8
|
||||
store i64 %"63", ptr addrspace(5) %"55", align 8
|
||||
%"64" = load i64, ptr addrspace(4) getelementptr inbounds (i8, ptr addrspace(4) @constparams, i64 8), align 8
|
||||
store i64 %"64", ptr addrspace(5) %"56", align 8
|
||||
%"66" = load i64, ptr addrspace(5) %"53", align 8
|
||||
%"67" = load i64, ptr addrspace(5) %"55", align 8
|
||||
%"65" = xor i64 %"66", %"67"
|
||||
store i64 %"65", ptr addrspace(5) %"57", align 8
|
||||
%"69" = load i64, ptr addrspace(5) %"54", align 8
|
||||
%"70" = load i64, ptr addrspace(5) %"56", align 8
|
||||
%"68" = xor i64 %"69", %"70"
|
||||
store i64 %"68", ptr addrspace(5) %"58", align 8
|
||||
%"71" = load i64, ptr addrspace(5) %"52", align 8
|
||||
%"72" = load i64, ptr addrspace(5) %"57", align 8
|
||||
%"85" = inttoptr i64 %"71" to ptr
|
||||
store i64 %"72", ptr %"85", align 8
|
||||
%"73" = load i64, ptr addrspace(5) %"52", align 8
|
||||
%"87" = inttoptr i64 %"73" to ptr
|
||||
%"45" = getelementptr inbounds i8, ptr %"87", i64 8
|
||||
%"74" = load i64, ptr addrspace(5) %"58", align 8
|
||||
store i64 %"74", ptr %"45", align 8
|
||||
%"75" = load i64, ptr addrspace(5) %"52", align 8
|
||||
%"89" = inttoptr i64 %"75" to ptr
|
||||
%"47" = getelementptr inbounds i8, ptr %"89", i64 8
|
||||
%"76" = load i64, ptr addrspace(5) %"58", align 8
|
||||
store i64 %"76", ptr %"47", align 8
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
51
ptx/src/test/ll/fma_bf16x2.ll
Normal file
51
ptx/src/test/ll/fma_bf16x2.ll
Normal file
|
@ -0,0 +1,51 @@
|
|||
define amdgpu_kernel void @fma_bf16x2(ptr addrspace(4) byref(i64) %"39", ptr addrspace(4) byref(i64) %"40") #0 {
|
||||
%"41" = alloca i64, align 8, addrspace(5)
|
||||
%"42" = alloca i64, align 8, addrspace(5)
|
||||
%"43" = alloca i32, align 4, addrspace(5)
|
||||
%"44" = alloca i32, align 4, addrspace(5)
|
||||
%"45" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"38"
|
||||
|
||||
"38": ; preds = %1
|
||||
%"46" = load i64, ptr addrspace(4) %"39", align 8
|
||||
store i64 %"46", ptr addrspace(5) %"41", align 8
|
||||
%"47" = load i64, ptr addrspace(4) %"40", align 8
|
||||
store i64 %"47", ptr addrspace(5) %"42", align 8
|
||||
%"49" = load i64, ptr addrspace(5) %"41", align 8
|
||||
%"60" = inttoptr i64 %"49" to ptr
|
||||
%"48" = load i32, ptr %"60", align 4
|
||||
store i32 %"48", ptr addrspace(5) %"43", align 4
|
||||
%"50" = load i64, ptr addrspace(5) %"41", align 8
|
||||
%"61" = inttoptr i64 %"50" to ptr
|
||||
%"35" = getelementptr inbounds i8, ptr %"61", i64 4
|
||||
%"51" = load i32, ptr %"35", align 4
|
||||
store i32 %"51", ptr addrspace(5) %"44", align 4
|
||||
%"52" = load i64, ptr addrspace(5) %"41", align 8
|
||||
%"62" = inttoptr i64 %"52" to ptr
|
||||
%"37" = getelementptr inbounds i8, ptr %"62", i64 8
|
||||
%"53" = load i32, ptr %"37", align 4
|
||||
store i32 %"53", ptr addrspace(5) %"45", align 4
|
||||
%"55" = load i32, ptr addrspace(5) %"43", align 4
|
||||
%"56" = load i32, ptr addrspace(5) %"44", align 4
|
||||
%"57" = load i32, ptr addrspace(5) %"45", align 4
|
||||
%"64" = bitcast i32 %"55" to <2 x bfloat>
|
||||
%"65" = bitcast i32 %"56" to <2 x bfloat>
|
||||
%"66" = bitcast i32 %"57" to <2 x bfloat>
|
||||
%"63" = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %"64", <2 x bfloat> %"65", <2 x bfloat> %"66")
|
||||
%"54" = bitcast <2 x bfloat> %"63" to i32
|
||||
store i32 %"54", ptr addrspace(5) %"43", align 4
|
||||
%"58" = load i64, ptr addrspace(5) %"42", align 8
|
||||
%"59" = load i32, ptr addrspace(5) %"43", align 4
|
||||
%"67" = inttoptr i64 %"58" to ptr
|
||||
store i32 %"59", ptr %"67", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
|
||||
declare <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>) #1
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
28
ptx/src/test/ll/global_array_f32.ll
Normal file
28
ptx/src/test/ll/global_array_f32.ll
Normal file
|
@ -0,0 +1,28 @@
|
|||
@foobar = addrspace(1) global [4 x float] [float 1.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00]
|
||||
|
||||
define amdgpu_kernel void @global_array_f32(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 {
|
||||
%"38" = alloca i64, align 8, addrspace(5)
|
||||
%"39" = alloca i64, align 8, addrspace(5)
|
||||
%"40" = alloca float, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"35"
|
||||
|
||||
"35": ; preds = %1
|
||||
store i64 ptrtoint (ptr addrspace(1) @foobar to i64), ptr addrspace(5) %"38", align 8
|
||||
%"42" = load i64, ptr addrspace(4) %"37", align 8
|
||||
store i64 %"42", ptr addrspace(5) %"39", align 8
|
||||
%"43" = load i64, ptr addrspace(5) %"38", align 8
|
||||
%"48" = inttoptr i64 %"43" to ptr addrspace(1)
|
||||
%"34" = getelementptr inbounds i8, ptr addrspace(1) %"48", i64 4
|
||||
%"44" = load float, ptr addrspace(1) %"34", align 4
|
||||
store float %"44", ptr addrspace(5) %"40", align 4
|
||||
%"45" = load i64, ptr addrspace(5) %"39", align 8
|
||||
%"46" = load float, ptr addrspace(5) %"40", align 4
|
||||
%"49" = inttoptr i64 %"45" to ptr
|
||||
store float %"46", ptr %"49", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
31
ptx/src/test/ll/vector_operand.ll
Normal file
31
ptx/src/test/ll/vector_operand.ll
Normal file
|
@ -0,0 +1,31 @@
|
|||
define amdgpu_kernel void @vector_operand(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 {
|
||||
%"38" = alloca i64, align 8, addrspace(5)
|
||||
%"39" = alloca i64, align 8, addrspace(5)
|
||||
%"40" = alloca i16, align 2, addrspace(5)
|
||||
%"41" = alloca i32, align 4, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"35"
|
||||
|
||||
"35": ; preds = %1
|
||||
%"42" = load i64, ptr addrspace(4) %"36", align 8
|
||||
store i64 %"42", ptr addrspace(5) %"38", align 8
|
||||
%"43" = load i64, ptr addrspace(4) %"37", align 8
|
||||
store i64 %"43", ptr addrspace(5) %"39", align 8
|
||||
%"45" = load i64, ptr addrspace(5) %"38", align 8
|
||||
%"50" = inttoptr i64 %"45" to ptr
|
||||
%"44" = load i16, ptr %"50", align 2
|
||||
store i16 %"44", ptr addrspace(5) %"40", align 2
|
||||
%"46" = load i16, ptr addrspace(5) %"40", align 2
|
||||
%"34" = insertelement <2 x i16> <i16 22136, i16 undef>, i16 %"46", i8 1
|
||||
%"51" = bitcast <2 x i16> %"34" to i32
|
||||
store i32 %"51", ptr addrspace(5) %"41", align 4
|
||||
%"48" = load i64, ptr addrspace(5) %"39", align 8
|
||||
%"49" = load i32, ptr addrspace(5) %"41", align 4
|
||||
%"52" = inttoptr i64 %"48" to ptr
|
||||
store i32 %"49", ptr %"52", align 4
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
|
@ -7,6 +7,7 @@
|
|||
)
|
||||
{
|
||||
.reg .u32 %reg<10>;
|
||||
.reg .b16 %reg_16;
|
||||
.reg .u64 %reg_64;
|
||||
.reg .pred p;
|
||||
.reg .pred q;
|
||||
|
@ -30,4 +31,7 @@
|
|||
|
||||
// vector index - only supported by mov (maybe: ld, st, tex)
|
||||
mov.u32 %reg0, %ntid.x;
|
||||
|
||||
// vector operand
|
||||
mov.u32 %reg0, {0, %reg_16};
|
||||
}
|
||||
|
|
39
ptx/src/test/spirv_run/const_ident.ptx
Normal file
39
ptx/src/test/spirv_run/const_ident.ptx
Normal file
|
@ -0,0 +1,39 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.const .u64 x = 1;
|
||||
.const .u64 y[4] = {4,5,6};
|
||||
.const .u64 constparams[2] = { x, y };
|
||||
|
||||
.visible .entry const_ident(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 x_addr;
|
||||
.reg .u64 y_addr;
|
||||
.reg .u64 constparams_0;
|
||||
.reg .u64 constparams_1;
|
||||
.reg .b64 x_equal;
|
||||
.reg .b64 y_equal;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
mov.u64 x_addr, x;
|
||||
mov.u64 y_addr, y;
|
||||
|
||||
ld.const.u64 constparams_0, [constparams+0];
|
||||
ld.const.u64 constparams_1, [constparams+8];
|
||||
|
||||
xor.b64 x_equal, x_addr, constparams_0;
|
||||
xor.b64 y_equal, y_addr, constparams_1;
|
||||
|
||||
st.u64 [out_addr], x_equal;
|
||||
st.u64 [out_addr+8], y_equal;
|
||||
st.u64 [out_addr+8], y_equal;
|
||||
ret;
|
||||
}
|
25
ptx/src/test/spirv_run/fma_bf16x2.ptx
Normal file
25
ptx/src/test/spirv_run/fma_bf16x2.ptx
Normal file
|
@ -0,0 +1,25 @@
|
|||
.version 7.0
|
||||
.target sm_80
|
||||
.address_size 64
|
||||
|
||||
.visible .entry fma_bf16x2(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .b32 temp1;
|
||||
.reg .b32 temp2;
|
||||
.reg .b32 temp3;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.b32 temp1, [in_addr];
|
||||
ld.b32 temp2, [in_addr+4];
|
||||
ld.b32 temp3, [in_addr+8];
|
||||
fma.rn.bf16x2 temp1, temp1, temp2, temp3;
|
||||
st.b32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
22
ptx/src/test/spirv_run/global_array_f32.ptx
Normal file
22
ptx/src/test/spirv_run/global_array_f32.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.global .f32 foobar[4] = {0f3f800000};
|
||||
|
||||
.visible .entry global_array_f32(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp;
|
||||
|
||||
mov.u64 in_addr, foobar;
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.global.f32 temp, [in_addr+4];
|
||||
st.f32 [out_addr], temp;
|
||||
ret;
|
||||
}
|
|
@ -137,6 +137,7 @@ test_ptx!(
|
|||
[0x1_00_00_00_00_00_00i64]
|
||||
);
|
||||
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
|
||||
test_ptx!(vector_operand, [0x1234u16], [0x12345678]);
|
||||
test_ptx!(shr, [-2i32], [-1i32]);
|
||||
test_ptx!(shr_oob, [-32768i16], [-1i16]);
|
||||
test_ptx!(or, [1u64, 2u64], [3u64]);
|
||||
|
@ -144,6 +145,7 @@ test_ptx!(sub, [2u64], [1u64]);
|
|||
test_ptx!(min, [555i32, 444i32], [444i32]);
|
||||
test_ptx!(max, [555i32, 444i32], [555i32]);
|
||||
test_ptx!(global_array, [0xDEADu32], [1u32]);
|
||||
test_ptx!(global_array_f32, [0x0], [0f32]);
|
||||
test_ptx!(extern_shared, [127u64], [127u64]);
|
||||
test_ptx!(extern_shared_call, [121u64], [123u64]);
|
||||
test_ptx!(rcp, [2f32], [0.5f32]);
|
||||
|
@ -166,6 +168,11 @@ test_ptx!(and, [6u32, 3u32], [2u32]);
|
|||
test_ptx!(selp, [100u16, 200u16], [200u16]);
|
||||
test_ptx!(selp_true, [100u16, 200u16], [100u16]);
|
||||
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
|
||||
test_ptx!(
|
||||
fma_bf16x2,
|
||||
[0x40004040, 0x40404080, 0x40A04040],
|
||||
[0x41304170]
|
||||
);
|
||||
test_ptx!(shared_variable, [513u64], [513u64]);
|
||||
test_ptx!(shared_ptr_32, [513u64], [513u64]);
|
||||
test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
|
||||
|
@ -256,6 +263,7 @@ test_ptx!(
|
|||
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
|
||||
test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
|
||||
test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
|
||||
test_ptx!(const_ident, [0u16], [0u64, 0u64]);
|
||||
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
|
||||
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
|
||||
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
|
||||
|
|
25
ptx/src/test/spirv_run/vector_operand.ptx
Normal file
25
ptx/src/test/spirv_run/vector_operand.ptx
Normal file
|
@ -0,0 +1,25 @@
|
|||
.version 6.5
|
||||
.target sm_60
|
||||
.address_size 64
|
||||
|
||||
.visible .entry vector_operand(
|
||||
.param .u64 input_p,
|
||||
.param .u64 output_p
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
|
||||
.reg .b16 in;
|
||||
.reg .b32 out;
|
||||
|
||||
ld.param.u64 in_addr, [input_p];
|
||||
ld.param.u64 out_addr, [output_p];
|
||||
|
||||
ld.b16 in, [in_addr];
|
||||
|
||||
mov.b32 out, {0x5678, in};
|
||||
|
||||
st.b32 [out_addr], out;
|
||||
ret;
|
||||
}
|
|
@ -7,6 +7,7 @@ use crate::{
|
|||
ShuffleMode, VoteMode,
|
||||
};
|
||||
use bitflags::bitflags;
|
||||
use derive_more::Display;
|
||||
use std::{alloc::Layout, cmp::Ordering, fmt::Write, num::NonZeroU8};
|
||||
|
||||
pub enum Statement<P: Operand> {
|
||||
|
@ -807,7 +808,17 @@ where
|
|||
),
|
||||
ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
|
||||
vec.into_iter()
|
||||
.map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check))
|
||||
.map(|reg_or_immediate| {
|
||||
Ok(match reg_or_immediate {
|
||||
RegOrImmediate::Reg(ident) => RegOrImmediate::Reg((self)(
|
||||
ident,
|
||||
type_space,
|
||||
is_dst,
|
||||
relaxed_type_check,
|
||||
)?),
|
||||
RegOrImmediate::Imm(imm) => RegOrImmediate::Imm(imm),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
),
|
||||
})
|
||||
|
@ -948,7 +959,7 @@ pub struct Variable<ID> {
|
|||
pub v_type: Type,
|
||||
pub state_space: StateSpace,
|
||||
pub name: ID,
|
||||
pub array_init: Vec<u8>,
|
||||
pub array_init: Vec<RegOrImmediate<ID>>,
|
||||
}
|
||||
|
||||
impl<ID: std::fmt::Display> std::fmt::Display for Variable<ID> {
|
||||
|
@ -1004,6 +1015,7 @@ impl std::fmt::Display for Type {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Type::Scalar(scalar_type) => write!(f, "{}", scalar_type),
|
||||
Type::Vector(count, scalar_type) => write!(f, ".v{}{}", count, scalar_type),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
@ -1062,6 +1074,15 @@ impl Type {
|
|||
}
|
||||
|
||||
impl ScalarType {
|
||||
pub fn from_size(size: u8) -> Option<Self> {
|
||||
Some(match size {
|
||||
1 => ScalarType::B8,
|
||||
2 => ScalarType::B16,
|
||||
4 => ScalarType::B32,
|
||||
16 => ScalarType::B128,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
pub fn size_of(self) -> u8 {
|
||||
match self {
|
||||
ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
|
||||
|
@ -1212,13 +1233,19 @@ pub struct ShfDetails {
|
|||
pub mode: FunnelShiftMode,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Display)]
|
||||
pub enum RegOrImmediate<Ident> {
|
||||
Reg(Ident),
|
||||
Imm(ImmediateValue),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum ParsedOperand<Ident> {
|
||||
Reg(Ident),
|
||||
RegOffset(Ident, i32),
|
||||
Imm(ImmediateValue),
|
||||
VecMember(Ident, u8),
|
||||
VecPack(Vec<Ident>),
|
||||
VecPack(Vec<RegOrImmediate<Ident>>),
|
||||
}
|
||||
|
||||
impl<Ident> ParsedOperand<Ident> {
|
||||
|
|
|
@ -332,6 +332,19 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
|
|||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn reg_or_immediate<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::RegOrImmediate<&'input str>> {
|
||||
trace(
|
||||
"reg_or_immediate",
|
||||
alt((
|
||||
immediate_value.map(|imm| ast::RegOrImmediate::Imm(imm)),
|
||||
ident.map(|id| ast::RegOrImmediate::Reg(id)),
|
||||
)),
|
||||
)
|
||||
.parse_next(stream)
|
||||
}
|
||||
|
||||
pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
|
||||
let (tokens, mut errors) = lex_with_span_unchecked(text);
|
||||
let parse_result = {
|
||||
|
@ -964,9 +977,9 @@ fn multi_variable<'a, 'input: 'a>(
|
|||
let initializer = match state_space {
|
||||
StateSpace::Global | StateSpace::Const => match array_dimensions {
|
||||
Some(ref mut dimensions) => {
|
||||
opt(array_initializer(vector, type_, dimensions)).parse_next(stream)?
|
||||
opt(array_initializer(type_, vector, dimensions)).parse_next(stream)?
|
||||
}
|
||||
None => opt(value_initializer(vector, type_)).parse_next(stream)?,
|
||||
None => opt(value_initializer(vector)).parse_next(stream)?,
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
@ -990,10 +1003,10 @@ fn multi_variable<'a, 'input: 'a>(
|
|||
}
|
||||
|
||||
fn array_initializer<'b, 'a: 'b, 'input: 'a>(
|
||||
vector: Option<NonZeroU8>,
|
||||
type_: ScalarType,
|
||||
vector: Option<NonZeroU8>,
|
||||
array_dimensions: &'b mut Vec<u32>,
|
||||
) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> + 'b {
|
||||
) -> impl Parser<PtxParser<'a, 'input>, Vec<RegOrImmediate<&'input str>>, ContextError> + 'b {
|
||||
trace(
|
||||
"array_initializer",
|
||||
move |stream: &mut PtxParser<'a, 'input>| {
|
||||
|
@ -1007,15 +1020,24 @@ fn array_initializer<'b, 'a: 'b, 'input: 'a>(
|
|||
Token::LBrace,
|
||||
separated::<_, (), (), _, _, _, _>(
|
||||
0..=array_dimensions[0] as usize,
|
||||
single_value_append(&mut result, type_),
|
||||
single_value_append(&mut result),
|
||||
Token::Comma,
|
||||
),
|
||||
Token::RBrace,
|
||||
)
|
||||
.parse_next(stream)?;
|
||||
// pad with zeros
|
||||
let result_size = type_.size_of() as usize * array_dimensions[0] as usize;
|
||||
result.extend(iter::repeat(0u8).take(result_size - result.len()));
|
||||
let result_size = array_dimensions[0] as usize;
|
||||
let default = match type_.kind() {
|
||||
ScalarKind::Bit | ScalarKind::Unsigned | ScalarKind::Pred => {
|
||||
ast::ImmediateValue::U64(0)
|
||||
}
|
||||
ScalarKind::Signed => ast::ImmediateValue::S64(0),
|
||||
ScalarKind::Float => ast::ImmediateValue::F64(0.0),
|
||||
};
|
||||
result.extend(
|
||||
iter::repeat(ast::RegOrImmediate::Imm(default)).take(result_size - result.len()),
|
||||
);
|
||||
Ok(result)
|
||||
},
|
||||
)
|
||||
|
@ -1023,8 +1045,7 @@ fn array_initializer<'b, 'a: 'b, 'input: 'a>(
|
|||
|
||||
fn value_initializer<'a, 'input: 'a>(
|
||||
vector: Option<NonZeroU8>,
|
||||
type_: ScalarType,
|
||||
) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> {
|
||||
) -> impl Parser<PtxParser<'a, 'input>, Vec<RegOrImmediate<&'input str>>, ContextError> {
|
||||
trace(
|
||||
"value_initializer",
|
||||
move |stream: &mut PtxParser<'a, 'input>| {
|
||||
|
@ -1034,77 +1055,20 @@ fn value_initializer<'a, 'input: 'a>(
|
|||
if vector.is_some() {
|
||||
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
|
||||
}
|
||||
single_value_append(&mut result, type_).parse_next(stream)?;
|
||||
single_value_append(&mut result).parse_next(stream)?;
|
||||
Ok(result)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn single_value_append<'b, 'a: 'b, 'input: 'a>(
|
||||
accumulator: &'b mut Vec<u8>,
|
||||
type_: ScalarType,
|
||||
accumulator: &'b mut Vec<RegOrImmediate<&'input str>>,
|
||||
) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> + 'b {
|
||||
trace(
|
||||
"single_value_append",
|
||||
move |stream: &mut PtxParser<'a, 'input>| {
|
||||
let value = immediate_value.parse_next(stream)?;
|
||||
match (type_, value) {
|
||||
(ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u8).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u8).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u16).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u16).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u32).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u32).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u64).to_le_bytes())
|
||||
}
|
||||
(ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as u64).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S8, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i8).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S8, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i8).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S16, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i16).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S16, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i16).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S32, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i32).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S32, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i32).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S64, ImmediateValue::U64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i64).to_le_bytes())
|
||||
}
|
||||
(ScalarType::S64, ImmediateValue::S64(x)) => {
|
||||
accumulator.extend_from_slice(&(x as i64).to_le_bytes())
|
||||
}
|
||||
(ScalarType::F32, ImmediateValue::F32(x)) => {
|
||||
accumulator.extend_from_slice(&x.to_le_bytes())
|
||||
}
|
||||
(ScalarType::F64, ImmediateValue::F64(x)) => {
|
||||
accumulator.extend_from_slice(&x.to_le_bytes())
|
||||
}
|
||||
_ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)),
|
||||
}
|
||||
let value = reg_or_immediate.parse_next(stream)?;
|
||||
accumulator.push(value);
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
|
@ -1379,12 +1343,18 @@ impl<Ident> ast::ParsedOperand<Ident> {
|
|||
}
|
||||
fn vector_operand<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<Vec<&'input str>> {
|
||||
let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?;
|
||||
) -> PResult<Vec<ast::RegOrImmediate<&'input str>>> {
|
||||
let (_, r1, _, r2) = (
|
||||
Token::LBrace,
|
||||
reg_or_immediate,
|
||||
Token::Comma,
|
||||
reg_or_immediate,
|
||||
)
|
||||
.parse_next(stream)?;
|
||||
// TODO: parse .v8 literals
|
||||
dispatch! {any;
|
||||
(Token::RBrace, _) => empty.map(|_| vec![r1, r2]),
|
||||
(Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
|
||||
(Token::Comma, _) => (reg_or_immediate, Token::Comma, reg_or_immediate, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
|
||||
_ => fail
|
||||
}
|
||||
.parse_next(stream)
|
||||
|
@ -1421,7 +1391,7 @@ pub enum PtxError<'input> {
|
|||
#[from]
|
||||
source: TokenError,
|
||||
},
|
||||
#[error("{0}")]
|
||||
#[error("Context error: {0}")]
|
||||
Parser(ContextError),
|
||||
#[error("")]
|
||||
Todo,
|
||||
|
@ -2772,14 +2742,30 @@ derive_parser!(
|
|||
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16 };
|
||||
//fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c;
|
||||
//fma.rnd{.ftz}.relu.f16 d, a, b, c;
|
||||
//fma.rnd{.ftz}.relu.f16x2 d, a, b, c;
|
||||
//fma.rnd{.relu}.bf16 d, a, b, c;
|
||||
//fma.rnd{.relu}.bf16x2 d, a, b, c;
|
||||
//fma.rnd.oob.{relu}.type d, a, b, c;
|
||||
fma.rnd{.relu}.bf16x2 d, a, b, c => {
|
||||
if relu {
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
ast::Instruction::Fma {
|
||||
data: ast::ArithFloat {
|
||||
type_: bf16x2,
|
||||
rounding: rnd.into(),
|
||||
flush_to_zero: None,
|
||||
saturate: false,
|
||||
is_fusable: false
|
||||
},
|
||||
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
|
||||
}
|
||||
}
|
||||
.rnd: RawRoundingMode = { .rn };
|
||||
ScalarType = { .f16 };
|
||||
ScalarType = { .bf16x2 };
|
||||
//fma.rnd.oob.{relu}.type d, a, b, c;
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue