diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 0281961..93793e6 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -380,6 +380,7 @@ pub enum Instruction { }, MadCC { type_: ScalarType, + is_hi: bool, arg: Arg4

, }, Fma(ArithFloat, Arg4

), diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index 94cc973..5a68bb8 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -621,8 +621,8 @@ fn emit_statement( crate::translate::Statement::MadC(MadCDetails { type_, is_hi, arg }) => { emit_inst_madc(ctx, type_, is_hi, &arg)? } - crate::translate::Statement::MadCC(MadCCDetails { type_, arg }) => { - emit_inst_madcc(ctx, type_, &arg)? + crate::translate::Statement::MadCC(MadCCDetails { type_, is_hi, arg }) => { + emit_inst_madcc(ctx, type_, is_hi, &arg)? } crate::translate::Statement::AddC(type_, arg) => emit_inst_add_c(ctx, type_, &arg)?, crate::translate::Statement::AddCC(type_, arg) => { @@ -2083,6 +2083,7 @@ fn emit_inst_mad_lo( fn emit_inst_madcc( ctx: &mut EmitContext, type_: ast::ScalarType, + is_hi: bool, arg: &Arg4CarryOut, ) -> Result<(), TranslateError> { let builder = ctx.builder.get(); diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index ae57575..d5c9b61 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1516,7 +1516,12 @@ InstMad: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-arithmetic-instructions-mad-cc InstMadCC: ast::Instruction> = { - "mad" ".lo" ".cc" => ast::Instruction::MadCC{<>}, + "mad" ".lo" ".cc" => { + ast::Instruction::MadCC { type_, arg, is_hi: false } + }, + "mad" ".hi" ".cc" => { + ast::Instruction::MadCC { type_, arg, is_hi: true } + }, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-arithmetic-instructions-madc diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 041c690..1a203bd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1999,9 +1999,10 @@ fn insert_hardware_registers_impl<'input>( is_hi, arg: Arg4CarryIn::new(arg, carry_out, TypedOperand::Reg(overflow_flag)), })), - Statement::Instruction(ast::Instruction::MadCC { type_, arg }) => { + Statement::Instruction(ast::Instruction::MadCC { type_, is_hi, arg }) => { result.push(Statement::MadCC(MadCCDetails { type_, + is_hi, arg: Arg4CarryOut::new(arg, TypedOperand::Reg(overflow_flag)), })) } @@ -5568,6 +5569,7 @@ impl, U: ArgParamsEx> Visitable for MadCD pub(crate) struct MadCCDetails { pub(crate) type_: ast::ScalarType, + pub(crate) is_hi: bool, pub(crate) arg: Arg4CarryOut

, } @@ -5578,6 +5580,7 @@ impl, U: ArgParamsEx> Visitable for MadCC ) -> Result, U>, TranslateError> { Ok(Statement::MadCC(MadCCDetails { type_: self.type_, + is_hi: self.is_hi, arg: self.arg.map(visitor, self.type_)?, })) } @@ -6486,8 +6489,9 @@ impl ast::Instruction { carry_out, arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?, }, - ast::Instruction::MadCC { type_, arg } => ast::Instruction::MadCC { + ast::Instruction::MadCC { type_, arg, is_hi } => ast::Instruction::MadCC { type_, + is_hi, arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?, }, ast::Instruction::Tex(details, arg) => {