diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index ec562fe..ec49925 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -9,6 +9,8 @@ quick_error! { display("{}", err) cause(err) } + SyntaxError {} + NonF32Ftz {} } } @@ -101,9 +103,11 @@ pub enum ScalarType { impl From for ScalarType { fn from(t: IntType) -> Self { match t { + IntType::S8 => ScalarType::S8, IntType::S16 => ScalarType::S16, IntType::S32 => ScalarType::S32, IntType::S64 => ScalarType::S64, + IntType::U8 => ScalarType::U8, IntType::U16 => ScalarType::U16, IntType::U32 => ScalarType::U32, IntType::U64 => ScalarType::U64, @@ -113,14 +117,38 @@ impl From for ScalarType { #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum IntType { + U8, U16, U32, U64, + S8, S16, S32, S64, } +impl IntType { + pub fn is_signed(self) -> bool { + match self { + IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false, + IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true, + } + } + + pub fn width(self) -> u8 { + match self { + IntType::U8 => 1, + IntType::U16 => 2, + IntType::U32 => 4, + IntType::U64 => 8, + IntType::S8 => 1, + IntType::S16 => 2, + IntType::S32 => 4, + IntType::S64 => 8, + } + } +} + #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum FloatType { F16, @@ -178,7 +206,7 @@ pub enum Instruction { SetpBool(SetpBoolData, Arg5

), Not(NotType, Arg2

), Bra(BraData, Arg1

), - Cvt(CvtData, Arg2

), + Cvt(CvtDetails, Arg2

), Shl(ShlType, Arg3

), St(StData, Arg2St

), Ret(RetData), @@ -398,7 +426,88 @@ pub struct BraData { pub uniform: bool, } -pub struct CvtData {} +pub enum CvtDetails { + IntFromInt(CvtIntToIntDesc), + FloatFromFloat(CvtDesc), + IntFromFloat(CvtDesc), + FloatFromInt(CvtDesc), +} + +pub struct CvtIntToIntDesc { + pub dst: IntType, + pub src: IntType, + pub saturate: bool, +} + +pub struct CvtDesc { + pub rounding: Option, + pub flush_to_zero: bool, + pub saturate: bool, + pub dst: Dst, + pub src: Src, +} + +impl CvtDetails { + pub fn new_int_from_int_checked( + saturate: bool, + dst: IntType, + src: IntType, + err: &mut Vec, + ) -> Self { + if saturate { + if src.is_signed() { + if dst.is_signed() && dst.width() >= src.width() { + err.push(PtxError::SyntaxError); + } + } else { + if dst == src || dst.width() >= src.width() { + err.push(PtxError::SyntaxError); + } + } + } + CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate }) + } + + pub fn new_float_from_int_checked( + rounding: RoundingMode, + flush_to_zero: bool, + saturate: bool, + dst: FloatType, + src: IntType, + err: &mut Vec, + ) -> Self { + if flush_to_zero && dst != FloatType::F32 { + err.push(PtxError::NonF32Ftz); + } + CvtDetails::FloatFromInt(CvtDesc { + dst, + src, + saturate, + flush_to_zero, + rounding: Some(rounding), + }) + } + + pub fn new_int_from_float_checked( + rounding: RoundingMode, + flush_to_zero: bool, + saturate: bool, + dst: IntType, + src: FloatType, + err: &mut Vec, + ) -> Self { + if flush_to_zero && src != FloatType::F32 { + err.push(PtxError::NonF32Ftz); + } + CvtDetails::IntFromFloat(CvtDesc { + dst, + src, + saturate, + flush_to_zero, + rounding: Some(rounding), + }) + } +} #[derive(PartialEq, Eq, Copy, Clone)] pub enum ShlType { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index bd5678e..5f97e6c 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -403,13 +403,13 @@ InstMulMode: ast::MulDetails = { typ: t, control: ctr }), - ".f32" => ast::MulDetails::Float(ast::MulFloatDesc { + ".f32" => 
ast::MulDetails::Float(ast::MulFloatDesc { typ: ast::FloatType::F32, rounding: r, flush_to_zero: ftz.is_some(), saturate: s.is_some() }), - ".f64" => ast::MulDetails::Float(ast::MulFloatDesc { + ".f64" => ast::MulDetails::Float(ast::MulFloatDesc { typ: ast::FloatType::F64, rounding: r, flush_to_zero: false, @@ -436,13 +436,20 @@ MulIntControl: ast::MulIntControl = { }; #[inline] -RoundingMode : ast::RoundingMode = { +RoundingModeFloat : ast::RoundingMode = { ".rn" => ast::RoundingMode::NearestEven, ".rz" => ast::RoundingMode::Zero, ".rm" => ast::RoundingMode::NegativeInf, ".rp" => ast::RoundingMode::PositiveInf, }; +RoundingModeInt : ast::RoundingMode = { + ".rni" => ast::RoundingMode::NearestEven, + ".rzi" => ast::RoundingMode::Zero, + ".rmi" => ast::RoundingMode::NegativeInf, + ".rpi" => ast::RoundingMode::PositiveInf, +}; + IntType : ast::IntType = { ".u16" => ast::IntType::U16, ".u32" => ast::IntType::U32, @@ -468,13 +475,13 @@ InstAddMode: ast::AddDetails = { typ: ast::IntType::S32, saturate: true, }), - ".f32" => ast::AddDetails::Float(ast::AddFloatDesc { + ".f32" => ast::AddDetails::Float(ast::AddFloatDesc { typ: ast::FloatType::F32, rounding: rn, flush_to_zero: ftz.is_some(), saturate: sat.is_some(), }), - ".f64" => ast::AddDetails::Float(ast::AddFloatDesc { + ".f64" => ast::AddDetails::Float(ast::AddFloatDesc { typ: ast::FloatType::F64, rounding: rn, flush_to_zero: false, @@ -580,28 +587,153 @@ InstBra: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt InstCvt: ast::Instruction> = { - "cvt" CvtRnd? ".ftz"? ".sat"? 
CvtType CvtType => { - ast::Instruction::Cvt(ast::CvtData{}, a) - } + "cvt" => { + ast::Instruction::Cvt(ast::CvtDetails::new_int_from_int_checked( + s.is_some(), + dst_t, + src_t, + errors + ), + a) + }, + "cvt" => { + ast::Instruction::Cvt(ast::CvtDetails::new_float_from_int_checked( + r, + f.is_some(), + s.is_some(), + dst_t, + src_t, + errors + ), + a) + }, + "cvt" => { + ast::Instruction::Cvt(ast::CvtDetails::new_int_from_float_checked( + r, + f.is_some(), + s.is_some(), + dst_t, + src_t, + errors + ), + a) + }, + "cvt" ".f16" ".f16" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: r, + flush_to_zero: false, + saturate: s.is_some(), + dst: ast::FloatType::F16, + src: ast::FloatType::F16 + } + ), a) + }, + "cvt" ".f32" ".f16" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: None, + flush_to_zero: f.is_some(), + saturate: s.is_some(), + dst: ast::FloatType::F32, + src: ast::FloatType::F16 + } + ), a) + }, + "cvt" ".f64" ".f16" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: None, + flush_to_zero: false, + saturate: s.is_some(), + dst: ast::FloatType::F64, + src: ast::FloatType::F16 + } + ), a) + }, + "cvt" ".f16" ".f32" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: Some(r), + flush_to_zero: f.is_some(), + saturate: s.is_some(), + dst: ast::FloatType::F16, + src: ast::FloatType::F32 + } + ), a) + }, + "cvt" ".f32" ".f32" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: r, + flush_to_zero: f.is_some(), + saturate: s.is_some(), + dst: ast::FloatType::F32, + src: ast::FloatType::F32 + } + ), a) + }, + "cvt" ".f64" ".f32" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: None, + flush_to_zero: false, + saturate: s.is_some(), + dst: ast::FloatType::F64, + src: ast::FloatType::F32 + } + ), a) + }, + "cvt" ".f16" ".f64" 
=> { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: Some(r), + flush_to_zero: false, + saturate: s.is_some(), + dst: ast::FloatType::F16, + src: ast::FloatType::F64 + } + ), a) + }, + "cvt" ".f32" ".f64" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: Some(r), + flush_to_zero: f.is_some(), + saturate: s.is_some(), + dst: ast::FloatType::F32, + src: ast::FloatType::F64 + } + ), a) + }, + "cvt" ".f64" ".f64" => { + ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( + ast::CvtDesc { + rounding: r, + flush_to_zero: false, + saturate: s.is_some(), + dst: ast::FloatType::F64, + src: ast::FloatType::F64 + } + ), a) + }, }; -CvtRnd = { - CvtIrnd, - CvtFrnd -} - -CvtIrnd = { - ".rni", ".rzi", ".rmi", ".rpi" +CvtTypeInt: ast::IntType = { + ".u8" => ast::IntType::U8, + ".u16" => ast::IntType::U16, + ".u32" => ast::IntType::U32, + ".u64" => ast::IntType::U64, + ".s8" => ast::IntType::S8, + ".s16" => ast::IntType::S16, + ".s32" => ast::IntType::S32, + ".s64" => ast::IntType::S64, }; -CvtFrnd = { - ".rn", ".rz", ".rm", ".rp" -}; - -CvtType = { - ".u8", ".u16", ".u32", ".u64", - ".s8", ".s16", ".s32", ".s64", - ".f16", ".f32", ".f64" +CvtTypeFloat: ast::FloatType = { + ".f16" => ast::FloatType::F16, + ".f32" => ast::FloatType::F32, + ".f64" => ast::FloatType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7091fc9..9e51046 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1580,13 +1580,6 @@ impl ast::MulDetails { } impl ast::IntType { - fn is_signed(self) -> bool { - match self { - ast::IntType::S16 | ast::IntType::S32 | ast::IntType::S64 => true, - ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false, - } - } - fn try_new(t: ast::ScalarType) -> Option { match t { ast::ScalarType::U16 => Some(ast::IntType::U16), diff --git a/ptx/tools/cvt.py 
b/ptx/tools/cvt.py new file mode 100644 index 0000000..ab6e5ce --- /dev/null +++ b/ptx/tools/cvt.py @@ -0,0 +1,36 @@ +import os +import subprocess +import tempfile + +types = ["u8", "u16", "u32", "u64", "s8", "s16", "s32", "s64", "f16", "f32", "f64"] +rnd = ["", ".rn", ".rni"] +ftz_all = ["", ".ftz"] +sat = ["", ".sat"] + +for in_type in types: + for out_type in types: + for r in rnd: + for ftz in ftz_all: + for s in sat: + with tempfile.TemporaryDirectory() as dir: + f_name = os.path.join(dir, 'ptx') + out_name = os.path.join(dir, 'out') + with open(f_name, 'w') as f: + f.write( + f""" + .version 6.5 + .target sm_30 + .address_size 64 + .visible .entry VecAdd_kernel() + {{ + .reg.{in_type} r1; + .reg.{out_type} r2; + cvt{r}{ftz}{s}.{out_type}.{in_type} r2, r1; + ret; + }} + """) + err = subprocess.run(["ptxas", f_name, "-o", out_name], capture_output = True) + if err.returncode == 0: + print(f"cvt{r}{ftz}{s}.{out_type}.{in_type}") + #else: + # print(f"[INVALID] cvt{r}{ftz}{s}.{out_type}.{in_type}") \ No newline at end of file