Support immediates in vector operands (#488)
Some checks are pending
ZLUDA / Build (Linux) (push) Waiting to run
ZLUDA / Build (Windows) (push) Waiting to run
ZLUDA / Build AMD GPU unit tests (push) Waiting to run
ZLUDA / Run AMD GPU unit tests (push) Blocked by required conditions

This commit is contained in:
Violet 2025-09-08 10:26:58 -07:00 committed by GitHub
commit 4306646739
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 236 additions and 10 deletions

BIN
elf.o Normal file

Binary file not shown.

View file

@ -230,15 +230,33 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
fn vec_pack(
&mut self,
vector_elements: Vec<SpirvWord>,
vector_elements: Vec<ast::RegOrImmediate<SpirvWord>>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
let (width, scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
Some((ast::Type::Scalar(scalar_t), space))
if scalar_t.kind() == ast::ScalarKind::Bit =>
{
let type_ =
ast::ScalarType::from_size(scalar_t.size_of() / (vector_elements.len() as u8))
.ok_or_else(|| error_mismatched_type())?;
(vector_elements.len() as u8, type_, space)
}
_ => return Err(error_mismatched_type()),
};
let vector_elements = vector_elements
.into_iter()
.map(|element| match element {
ast::RegOrImmediate::Reg(name) => self.reg(name),
ast::RegOrImmediate::Imm(immediate_value) => self.immediate(
immediate_value,
Some((&ast::Type::Scalar(scalar_t), state_space)),
),
})
.collect::<Result<Vec<_>, _>>()?;
let temporary_vector = self
.resolver
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));

View file

@ -198,10 +198,15 @@ pub fn map_operand<T: Copy, Err>(
Some(ident) => ast::ParsedOperand::Reg(ident),
None => ast::ParsedOperand::VecMember(ident, member),
},
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents
ast::ParsedOperand::VecPack(elements) => ast::ParsedOperand::VecPack(
elements
.into_iter()
.map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
.map(|element| match element {
ast::RegOrImmediate::Reg(ident) => {
Ok(ast::RegOrImmediate::Reg(fn_(ident, None)?.unwrap_or(ident)))
}
ast::RegOrImmediate::Imm(imm) => Ok(ast::RegOrImmediate::Imm(imm)),
})
.collect::<Result<Vec<_>, _>>()?,
),
})

View file

@ -613,6 +613,12 @@ struct ConstantDefinition {
pub value: ast::ImmediateValue,
}
impl std::fmt::Display for ConstantDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "zluda.constant{} {}", self.typ, self.value)
}
}
pub struct PtrAccess<T> {
underlying_type: ast::Type,
state_space: ast::StateSpace,
@ -629,6 +635,22 @@ struct RepackVectorDetails {
relaxed_type_check: bool,
}
impl std::fmt::Display for RepackVectorDetails {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let extract = if self.is_extract {
".extract"
} else {
".composite"
};
let relaxed = if self.relaxed_type_check {
".relaxed"
} else {
""
};
write!(f, "zluda.repack_vector{}{}{}", extract, relaxed, self.typ)
}
}
struct FunctionPointerDetails {
dst: SpirvWord,
src: SpirvWord,

View file

@ -0,0 +1,22 @@
use crate::pass::{test::directive2_vec_to_string, *};
use super::test_pass;
macro_rules! test_expand_operands {
($test_name:ident) => {
test_pass!(run_expand_operands, $test_name);
};
}
fn run_expand_operands(ptx: ptx_parser::Module) -> String {
// We run the minimal number of passes required to produce the input expected by expand_operands
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
directive2_vec_to_string(&flat_resolver, directives)
}
test_expand_operands!(vector_operand);
test_expand_operands!(vector_operand_convert);

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_60
.address_size 64
.func (.reg .v2.b16 output) default (
.reg .b16 input
)
{
mov.v2.b16 output, {0x5678, input};
ret;
}
// %%% output %%%
.func (.reg .v2 .b16 %2) %1 (
.reg .b16 %3
)
{
.b16.reg %4 = zluda.constant.b16 22136;
.v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3;
mov.v2.b16 %2, %5;
ret;
}

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_60
.address_size 64
.func (.reg .b32 output) default (
.reg .b16 input
)
{
mov.b32 output, {0x5678, input};
ret;
}
// %%% output %%%
.func (.reg .b32 %2) %1 (
.reg .b16 %3
)
{
.b16.reg %4 = zluda.constant.b16 22136;
.v2.b16.reg %5 = zluda.repack_vector.composite.b16 %4, %3;
mov.b32 %2, %5;
ret;
}

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.func (.reg .b32 output) default (
.reg .v2.b16 input
)
{
mov.b32 output, input;
ret;
}
// %%% output %%%
.func (.reg .b32 %2) %1 (
.reg .v2 .b16 %3
)
{
.b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.v2.b16 %3;
mov.b32 %2, %4;
ret;
}

View file

@ -21,3 +21,4 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String {
test_insert_implicit_conversions!(default);
test_insert_implicit_conversions!(default_reg_b32_reg_f16x2);
test_insert_implicit_conversions!(default_reg_b32_reg_v2_b16);

View file

@ -6,6 +6,7 @@ use std::{
path::Path,
};
mod expand_operands;
mod insert_implicit_conversions;
#[macro_export]
@ -202,6 +203,8 @@ fn statement_to_string(
Statement::Variable(var) => format!("{}", var),
Statement::Instruction(instr) => format!("{}", instr),
Statement::Conversion(conv) => format!("{}", conv),
Statement::Constant(constant) => format!("{}", constant),
Statement::RepackVector(repack) => format!("{}", repack),
_ => todo!(),
};
let mut args_formatter = StatementFormatter::new(resolver);

View file

@ -0,0 +1,31 @@
define amdgpu_kernel void @vector_operand(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 {
%"38" = alloca i64, align 8, addrspace(5)
%"39" = alloca i64, align 8, addrspace(5)
%"40" = alloca i16, align 2, addrspace(5)
%"41" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"35"
"35": ; preds = %1
%"42" = load i64, ptr addrspace(4) %"36", align 8
store i64 %"42", ptr addrspace(5) %"38", align 8
%"43" = load i64, ptr addrspace(4) %"37", align 8
store i64 %"43", ptr addrspace(5) %"39", align 8
%"45" = load i64, ptr addrspace(5) %"38", align 8
%"50" = inttoptr i64 %"45" to ptr
%"44" = load i16, ptr %"50", align 2
store i16 %"44", ptr addrspace(5) %"40", align 2
%"46" = load i16, ptr addrspace(5) %"40", align 2
%"34" = insertelement <2 x i16> <i16 22136, i16 undef>, i16 %"46", i8 1
%"51" = bitcast <2 x i16> %"34" to i32
store i32 %"51", ptr addrspace(5) %"41", align 4
%"48" = load i64, ptr addrspace(5) %"39", align 8
%"49" = load i32, ptr addrspace(5) %"41", align 4
%"52" = inttoptr i64 %"48" to ptr
store i32 %"49", ptr %"52", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -7,6 +7,7 @@
)
{
.reg .u32 %reg<10>;
.reg .b16 %reg_16;
.reg .u64 %reg_64;
.reg .pred p;
.reg .pred q;
@ -30,4 +31,7 @@
// vector index - only supported by mov (maybe: ld, st, tex)
mov.u32 %reg0, %ntid.x;
// vector operand
mov.u32 %reg0, {0, %reg_16};
}

View file

@ -137,6 +137,7 @@ test_ptx!(
[0x1_00_00_00_00_00_00i64]
);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(vector_operand, [0x1234u16], [0x12345678]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(shr_oob, [-32768i16], [-1i16]);
test_ptx!(or, [1u64, 2u64], [3u64]);

View file

@ -0,0 +1,25 @@
.version 6.5
.target sm_60
.address_size 64
.visible .entry vector_operand(
.param .u64 input_p,
.param .u64 output_p
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b16 in;
.reg .b32 out;
ld.param.u64 in_addr, [input_p];
ld.param.u64 out_addr, [output_p];
ld.b16 in, [in_addr];
mov.b32 out, {0x5678, in};
st.b32 [out_addr], out;
ret;
}

View file

@ -808,7 +808,17 @@ where
),
ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
vec.into_iter()
.map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check))
.map(|reg_or_immediate| {
Ok(match reg_or_immediate {
RegOrImmediate::Reg(ident) => RegOrImmediate::Reg((self)(
ident,
type_space,
is_dst,
relaxed_type_check,
)?),
RegOrImmediate::Imm(imm) => RegOrImmediate::Imm(imm),
})
})
.collect::<Result<Vec<_>, _>>()?,
),
})
@ -1005,6 +1015,7 @@ impl std::fmt::Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Scalar(scalar_type) => write!(f, "{}", scalar_type),
Type::Vector(count, scalar_type) => write!(f, ".v{}{}", count, scalar_type),
_ => todo!(),
}
}
@ -1063,6 +1074,15 @@ impl Type {
}
impl ScalarType {
pub fn from_size(size: u8) -> Option<Self> {
Some(match size {
1 => ScalarType::B8,
2 => ScalarType::B16,
4 => ScalarType::B32,
16 => ScalarType::B128,
_ => return None,
})
}
pub fn size_of(self) -> u8 {
match self {
ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
@ -1225,7 +1245,7 @@ pub enum ParsedOperand<Ident> {
RegOffset(Ident, i32),
Imm(ImmediateValue),
VecMember(Ident, u8),
VecPack(Vec<Ident>),
VecPack(Vec<RegOrImmediate<Ident>>),
}
impl<Ident> ParsedOperand<Ident> {

View file

@ -1283,12 +1283,18 @@ impl<Ident> ast::ParsedOperand<Ident> {
}
fn vector_operand<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<Vec<&'input str>> {
let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?;
) -> PResult<Vec<ast::RegOrImmediate<&'input str>>> {
let (_, r1, _, r2) = (
Token::LBrace,
reg_or_immediate,
Token::Comma,
reg_or_immediate,
)
.parse_next(stream)?;
// TODO: parse .v8 literals
dispatch! {any;
(Token::RBrace, _) => empty.map(|_| vec![r1, r2]),
(Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
(Token::Comma, _) => (reg_or_immediate, Token::Comma, reg_or_immediate, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
_ => fail
}
.parse_next(stream)
@ -1325,7 +1331,7 @@ pub enum PtxError<'input> {
#[from]
source: TokenError,
},
#[error("{0}")]
#[error("Context error: {0}")]
Parser(ContextError),
#[error("")]
Todo,