mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add support for parsing instruction cvt
This commit is contained in:
parent
ff449289eb
commit
a10ee48e91
4 changed files with 302 additions and 32 deletions
113
ptx/src/ast.rs
113
ptx/src/ast.rs
|
@ -9,6 +9,8 @@ quick_error! {
|
|||
display("{}", err)
|
||||
cause(err)
|
||||
}
|
||||
SyntaxError {}
|
||||
NonF32Ftz {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,9 +103,11 @@ pub enum ScalarType {
|
|||
impl From<IntType> for ScalarType {
|
||||
fn from(t: IntType) -> Self {
|
||||
match t {
|
||||
IntType::S8 => ScalarType::S8,
|
||||
IntType::S16 => ScalarType::S16,
|
||||
IntType::S32 => ScalarType::S32,
|
||||
IntType::S64 => ScalarType::S64,
|
||||
IntType::U8 => ScalarType::U8,
|
||||
IntType::U16 => ScalarType::U16,
|
||||
IntType::U32 => ScalarType::U32,
|
||||
IntType::U64 => ScalarType::U64,
|
||||
|
@ -113,14 +117,38 @@ impl From<IntType> for ScalarType {
|
|||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum IntType {
|
||||
U8,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S8,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
}
|
||||
|
||||
impl IntType {
|
||||
pub fn is_signed(self) -> bool {
|
||||
match self {
|
||||
IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false,
|
||||
IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn width(self) -> u8 {
|
||||
match self {
|
||||
IntType::U8 => 1,
|
||||
IntType::U16 => 2,
|
||||
IntType::U32 => 4,
|
||||
IntType::U64 => 8,
|
||||
IntType::S8 => 1,
|
||||
IntType::S16 => 2,
|
||||
IntType::S32 => 4,
|
||||
IntType::S64 => 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum FloatType {
|
||||
F16,
|
||||
|
@ -178,7 +206,7 @@ pub enum Instruction<P: ArgParams> {
|
|||
SetpBool(SetpBoolData, Arg5<P>),
|
||||
Not(NotType, Arg2<P>),
|
||||
Bra(BraData, Arg1<P>),
|
||||
Cvt(CvtData, Arg2<P>),
|
||||
Cvt(CvtDetails, Arg2<P>),
|
||||
Shl(ShlType, Arg3<P>),
|
||||
St(StData, Arg2St<P>),
|
||||
Ret(RetData),
|
||||
|
@ -398,7 +426,88 @@ pub struct BraData {
|
|||
pub uniform: bool,
|
||||
}
|
||||
|
||||
pub struct CvtData {}
|
||||
pub enum CvtDetails {
|
||||
IntFromInt(CvtIntToIntDesc),
|
||||
FloatFromFloat(CvtDesc<FloatType, FloatType>),
|
||||
IntFromFloat(CvtDesc<IntType, FloatType>),
|
||||
FloatFromInt(CvtDesc<FloatType, IntType>),
|
||||
}
|
||||
|
||||
pub struct CvtIntToIntDesc {
|
||||
pub dst: IntType,
|
||||
pub src: IntType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
pub struct CvtDesc<Dst, Src> {
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: bool,
|
||||
pub saturate: bool,
|
||||
pub dst: Dst,
|
||||
pub src: Src,
|
||||
}
|
||||
|
||||
impl CvtDetails {
|
||||
pub fn new_int_from_int_checked(
|
||||
saturate: bool,
|
||||
dst: IntType,
|
||||
src: IntType,
|
||||
err: &mut Vec<PtxError>,
|
||||
) -> Self {
|
||||
if saturate {
|
||||
if src.is_signed() {
|
||||
if dst.is_signed() && dst.width() >= src.width() {
|
||||
err.push(PtxError::SyntaxError);
|
||||
}
|
||||
} else {
|
||||
if dst == src || dst.width() >= src.width() {
|
||||
err.push(PtxError::SyntaxError);
|
||||
}
|
||||
}
|
||||
}
|
||||
CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate })
|
||||
}
|
||||
|
||||
pub fn new_float_from_int_checked(
|
||||
rounding: RoundingMode,
|
||||
flush_to_zero: bool,
|
||||
saturate: bool,
|
||||
dst: FloatType,
|
||||
src: IntType,
|
||||
err: &mut Vec<PtxError>,
|
||||
) -> Self {
|
||||
if flush_to_zero && dst != FloatType::F32 {
|
||||
err.push(PtxError::NonF32Ftz);
|
||||
}
|
||||
CvtDetails::FloatFromInt(CvtDesc {
|
||||
dst,
|
||||
src,
|
||||
saturate,
|
||||
flush_to_zero,
|
||||
rounding: Some(rounding),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_int_from_float_checked(
|
||||
rounding: RoundingMode,
|
||||
flush_to_zero: bool,
|
||||
saturate: bool,
|
||||
dst: IntType,
|
||||
src: FloatType,
|
||||
err: &mut Vec<PtxError>,
|
||||
) -> Self {
|
||||
if flush_to_zero && src != FloatType::F32 {
|
||||
err.push(PtxError::NonF32Ftz);
|
||||
}
|
||||
CvtDetails::IntFromFloat(CvtDesc {
|
||||
dst,
|
||||
src,
|
||||
saturate,
|
||||
flush_to_zero,
|
||||
rounding: Some(rounding),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum ShlType {
|
||||
|
|
|
@ -403,13 +403,13 @@ InstMulMode: ast::MulDetails = {
|
|||
typ: t,
|
||||
control: ctr
|
||||
}),
|
||||
<r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
<r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
typ: ast::FloatType::F32,
|
||||
rounding: r,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
}),
|
||||
<r:RoundingMode?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
<r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
typ: ast::FloatType::F64,
|
||||
rounding: r,
|
||||
flush_to_zero: false,
|
||||
|
@ -436,13 +436,20 @@ MulIntControl: ast::MulIntControl = {
|
|||
};
|
||||
|
||||
#[inline]
|
||||
RoundingMode : ast::RoundingMode = {
|
||||
RoundingModeFloat : ast::RoundingMode = {
|
||||
".rn" => ast::RoundingMode::NearestEven,
|
||||
".rz" => ast::RoundingMode::Zero,
|
||||
".rm" => ast::RoundingMode::NegativeInf,
|
||||
".rp" => ast::RoundingMode::PositiveInf,
|
||||
};
|
||||
|
||||
RoundingModeInt : ast::RoundingMode = {
|
||||
".rni" => ast::RoundingMode::NearestEven,
|
||||
".rzi" => ast::RoundingMode::Zero,
|
||||
".rmi" => ast::RoundingMode::NegativeInf,
|
||||
".rpi" => ast::RoundingMode::PositiveInf,
|
||||
};
|
||||
|
||||
IntType : ast::IntType = {
|
||||
".u16" => ast::IntType::U16,
|
||||
".u32" => ast::IntType::U32,
|
||||
|
@ -468,13 +475,13 @@ InstAddMode: ast::AddDetails = {
|
|||
typ: ast::IntType::S32,
|
||||
saturate: true,
|
||||
}),
|
||||
<rn:RoundingMode?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
typ: ast::FloatType::F32,
|
||||
rounding: rn,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: sat.is_some(),
|
||||
}),
|
||||
<rn:RoundingMode?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
<rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
typ: ast::FloatType::F64,
|
||||
rounding: rn,
|
||||
flush_to_zero: false,
|
||||
|
@ -580,28 +587,153 @@ InstBra: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
|
||||
InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtData{}, a)
|
||||
}
|
||||
"cvt" <s:".sat"?> <dst_t:CvtTypeInt> <src_t:CvtTypeInt> <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::new_int_from_int_checked(
|
||||
s.is_some(),
|
||||
dst_t,
|
||||
src_t,
|
||||
errors
|
||||
),
|
||||
a)
|
||||
},
|
||||
"cvt" <r:RoundingModeFloat> <f:".ftz"?> <s:".sat"?> <dst_t:CvtTypeFloat> <src_t:CvtTypeInt> <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::new_float_from_int_checked(
|
||||
r,
|
||||
f.is_some(),
|
||||
s.is_some(),
|
||||
dst_t,
|
||||
src_t,
|
||||
errors
|
||||
),
|
||||
a)
|
||||
},
|
||||
"cvt" <r:RoundingModeInt> <f:".ftz"?> <s:".sat"?> <dst_t:CvtTypeInt> <src_t:CvtTypeFloat> <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::new_int_from_float_checked(
|
||||
r,
|
||||
f.is_some(),
|
||||
s.is_some(),
|
||||
dst_t,
|
||||
src_t,
|
||||
errors
|
||||
),
|
||||
a)
|
||||
},
|
||||
"cvt" <r:RoundingModeInt?> <s:".sat"?> ".f16" ".f16" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: r,
|
||||
flush_to_zero: false,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F16,
|
||||
src: ast::FloatType::F16
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <f:".ftz"?> <s:".sat"?> ".f32" ".f16" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: None,
|
||||
flush_to_zero: f.is_some(),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F32,
|
||||
src: ast::FloatType::F16
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <s:".sat"?> ".f64" ".f16" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: None,
|
||||
flush_to_zero: false,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F64,
|
||||
src: ast::FloatType::F16
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <r:RoundingModeFloat> <f:".ftz"?> <s:".sat"?> ".f16" ".f32" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: Some(r),
|
||||
flush_to_zero: f.is_some(),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F16,
|
||||
src: ast::FloatType::F32
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <r:RoundingModeFloat?> <f:".ftz"?> <s:".sat"?> ".f32" ".f32" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: r,
|
||||
flush_to_zero: f.is_some(),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F32,
|
||||
src: ast::FloatType::F32
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <s:".sat"?> ".f64" ".f32" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: None,
|
||||
flush_to_zero: false,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F64,
|
||||
src: ast::FloatType::F32
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <r:RoundingModeFloat> <s:".sat"?> ".f16" ".f64" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: Some(r),
|
||||
flush_to_zero: false,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F16,
|
||||
src: ast::FloatType::F64
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <r:RoundingModeFloat> <f:".ftz"?> <s:".sat"?> ".f32" ".f64" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: Some(r),
|
||||
flush_to_zero: s.is_some(),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F32,
|
||||
src: ast::FloatType::F64
|
||||
}
|
||||
), a)
|
||||
},
|
||||
"cvt" <r:RoundingModeFloat?> <s:".sat"?> ".f64" ".f64" <a:Arg2> => {
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
|
||||
ast::CvtDesc {
|
||||
rounding: r,
|
||||
flush_to_zero: false,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F64,
|
||||
src: ast::FloatType::F64
|
||||
}
|
||||
), a)
|
||||
},
|
||||
};
|
||||
|
||||
CvtRnd = {
|
||||
CvtIrnd,
|
||||
CvtFrnd
|
||||
}
|
||||
|
||||
CvtIrnd = {
|
||||
".rni", ".rzi", ".rmi", ".rpi"
|
||||
CvtTypeInt: ast::IntType = {
|
||||
".u8" => ast::IntType::U8,
|
||||
".u16" => ast::IntType::U16,
|
||||
".u32" => ast::IntType::U32,
|
||||
".u64" => ast::IntType::U64,
|
||||
".s8" => ast::IntType::S8,
|
||||
".s16" => ast::IntType::S16,
|
||||
".s32" => ast::IntType::S32,
|
||||
".s64" => ast::IntType::S64,
|
||||
};
|
||||
|
||||
CvtFrnd = {
|
||||
".rn", ".rz", ".rm", ".rp"
|
||||
};
|
||||
|
||||
CvtType = {
|
||||
".u8", ".u16", ".u32", ".u64",
|
||||
".s8", ".s16", ".s32", ".s64",
|
||||
".f16", ".f32", ".f64"
|
||||
CvtTypeFloat: ast::FloatType = {
|
||||
".f16" => ast::FloatType::F16,
|
||||
".f32" => ast::FloatType::F32,
|
||||
".f64" => ast::FloatType::F64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
|
||||
|
|
|
@ -1580,13 +1580,6 @@ impl ast::MulDetails {
|
|||
}
|
||||
|
||||
impl ast::IntType {
|
||||
fn is_signed(self) -> bool {
|
||||
match self {
|
||||
ast::IntType::S16 | ast::IntType::S32 | ast::IntType::S64 => true,
|
||||
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_new(t: ast::ScalarType) -> Option<Self> {
|
||||
match t {
|
||||
ast::ScalarType::U16 => Some(ast::IntType::U16),
|
||||
|
|
36
ptx/tools/cvt.py
Normal file
36
ptx/tools/cvt.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
types = ["u8", "u16", "u32", "u64", "s8", "s16", "s32", "s64", "f16", "f32", "f64"]
|
||||
rnd = ["", ".rn", ".rni"]
|
||||
ftz_all = ["", ".ftz"]
|
||||
sat = ["", ".sat"]
|
||||
|
||||
for in_type in types:
|
||||
for out_type in types:
|
||||
for r in rnd:
|
||||
for ftz in ftz_all:
|
||||
for s in sat:
|
||||
with tempfile.TemporaryDirectory() as dir:
|
||||
f_name = os.path.join(dir, 'ptx')
|
||||
out_name = os.path.join(dir, 'out')
|
||||
with open(f_name, 'w') as f:
|
||||
f.write(
|
||||
f"""
|
||||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
.visible .entry VecAdd_kernel()
|
||||
{{
|
||||
.reg.{in_type} r1;
|
||||
.reg.{out_type} r2;
|
||||
cvt{r}{ftz}{s}.{out_type}.{in_type} r2, r1;
|
||||
ret;
|
||||
}}
|
||||
""")
|
||||
err = subprocess.run(f"ptxas {f_name} -o {out_name}", capture_output = True)
|
||||
if err.returncode == 0:
|
||||
print(f"cvt{r}{ftz}{s}.{out_type}.{in_type}")
|
||||
#else:
|
||||
# print(f"[INVALID] cvt{r}{ftz}{s}.{out_type}.{in_type}")
|
Loading…
Add table
Reference in a new issue