// Copyright 2008 Dolphin Emulator Project
// Licensed under GPLv2+
// Refer to the license.txt file included.

#include "Common/x64ABI.h"
#include "Common/CommonTypes.h"
#include "Common/x64Emitter.h"

using namespace Gen;

// Shared code between Win64 and Unix64

void XEmitter::ABI_CalculateFrameSize(BitSet32 mask, size_t rsp_alignment, size_t needed_frame_size,
                                      size_t* shadowp, size_t* subtractionp, size_t* xmm_offsetp)
{
  size_t shadow = 0;
#if defined(_WIN32)
  shadow = 0x20;
#endif

  int count = (mask & ABI_ALL_GPRS).Count();
  rsp_alignment -= count * 8;
  size_t subtraction = 0;
  int fpr_count = (mask & ABI_ALL_FPRS).Count();
  if (fpr_count)
  {
    // If we have any XMMs to save, we must align the stack here.
    subtraction = rsp_alignment & 0xf;
  }
  subtraction += 16 * fpr_count;
  size_t xmm_base_subtraction = subtraction;
  subtraction += needed_frame_size;
  subtraction += shadow;
  // Final alignment.
  rsp_alignment -= subtraction;
  subtraction += rsp_alignment & 0xf;

  *shadowp = shadow;
  *subtractionp = subtraction;
  *xmm_offsetp = subtraction - xmm_base_subtraction;
}

size_t XEmitter::ABI_PushRegistersAndAdjustStack(BitSet32 mask, size_t rsp_alignment,
                                                 size_t needed_frame_size)
{
  size_t shadow, subtraction, xmm_offset;
  ABI_CalculateFrameSize(mask, rsp_alignment, needed_frame_size, &shadow, &subtraction,
                         &xmm_offset);

  for (int r : mask& ABI_ALL_GPRS)
    PUSH((X64Reg)r);

  if (subtraction)
    SUB(64, R(RSP), subtraction >= 0x80 ? Imm32((u32)subtraction) : Imm8((u8)subtraction));

  for (int x : mask& ABI_ALL_FPRS)
  {
    MOVAPD(MDisp(RSP, (int)xmm_offset), (X64Reg)(x - 16));
    xmm_offset += 16;
  }

  return shadow;
}

void XEmitter::ABI_PopRegistersAndAdjustStack(BitSet32 mask, size_t rsp_alignment,
                                              size_t needed_frame_size)
{
  size_t shadow, subtraction, xmm_offset;
  ABI_CalculateFrameSize(mask, rsp_alignment, needed_frame_size, &shadow, &subtraction,
                         &xmm_offset);

  for (int x : mask& ABI_ALL_FPRS)
  {
    MOVAPD((X64Reg)(x - 16), MDisp(RSP, (int)xmm_offset));
    xmm_offset += 16;
  }

  if (subtraction)
    ADD(64, R(RSP), subtraction >= 0x80 ? Imm32((u32)subtraction) : Imm8((u8)subtraction));

  for (int r = 15; r >= 0; r--)
  {
    if (mask[r])
      POP((X64Reg)r);
  }
}

// Common functions
void XEmitter::ABI_CallFunction(const void* func)
{
  u64 distance = u64(func) - (u64(code) + 5);
  if (distance >= 0x0000000080000000ULL && distance < 0xFFFFFFFF80000000ULL)
  {
    // Far call
    MOV(64, R(RAX), Imm64((u64)func));
    CALLptr(R(RAX));
  }
  else
  {
    CALL(func);
  }
}

void XEmitter::ABI_CallFunctionC16(const void* func, u16 param1)
{
  MOV(32, R(ABI_PARAM1), Imm32((u32)param1));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionCC16(const void* func, u32 param1, u16 param2)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  MOV(32, R(ABI_PARAM2), Imm32((u32)param2));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionC(const void* func, u32 param1)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionCC(const void* func, u32 param1, u32 param2)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  MOV(32, R(ABI_PARAM2), Imm32(param2));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionCP(const void* func, u32 param1, void* param2)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  MOV(64, R(ABI_PARAM2), Imm64((u64)param2));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionCCC(const void* func, u32 param1, u32 param2, u32 param3)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  MOV(32, R(ABI_PARAM2), Imm32(param2));
  MOV(32, R(ABI_PARAM3), Imm32(param3));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionCCP(const void* func, u32 param1, u32 param2, void* param3)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  MOV(32, R(ABI_PARAM2), Imm32(param2));
  MOV(64, R(ABI_PARAM3), Imm64((u64)param3));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionCCCP(const void* func, u32 param1, u32 param2, u32 param3,
                                    void* param4)
{
  MOV(32, R(ABI_PARAM1), Imm32(param1));
  MOV(32, R(ABI_PARAM2), Imm32(param2));
  MOV(32, R(ABI_PARAM3), Imm32(param3));
  MOV(64, R(ABI_PARAM4), Imm64((u64)param4));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionPC(const void* func, void* param1, u32 param2)
{
  MOV(64, R(ABI_PARAM1), Imm64((u64)param1));
  MOV(32, R(ABI_PARAM2), Imm32(param2));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionPPC(const void* func, void* param1, void* param2, u32 param3)
{
  MOV(64, R(ABI_PARAM1), Imm64((u64)param1));
  MOV(64, R(ABI_PARAM2), Imm64((u64)param2));
  MOV(32, R(ABI_PARAM3), Imm32(param3));
  ABI_CallFunction(func);
}

// Pass a register as a parameter.
void XEmitter::ABI_CallFunctionR(const void* func, X64Reg reg1)
{
  if (reg1 != ABI_PARAM1)
    MOV(32, R(ABI_PARAM1), R(reg1));
  ABI_CallFunction(func);
}

// Pass two registers as parameters.
void XEmitter::ABI_CallFunctionRR(const void* func, X64Reg reg1, X64Reg reg2)
{
  MOVTwo(64, ABI_PARAM1, reg1, 0, ABI_PARAM2, reg2);
  ABI_CallFunction(func);
}

void XEmitter::MOVTwo(int bits, Gen::X64Reg dst1, Gen::X64Reg src1, s32 offset1, Gen::X64Reg dst2,
                      Gen::X64Reg src2)
{
  if (dst1 == src2 && dst2 == src1)
  {
    XCHG(bits, R(src1), R(src2));
    if (offset1)
      ADD(bits, R(dst1), Imm32(offset1));
  }
  else if (src2 != dst1)
  {
    if (dst1 != src1 && offset1)
      LEA(bits, dst1, MDisp(src1, offset1));
    else if (dst1 != src1)
      MOV(bits, R(dst1), R(src1));
    else if (offset1)
      ADD(bits, R(dst1), Imm32(offset1));
    if (dst2 != src2)
      MOV(bits, R(dst2), R(src2));
  }
  else
  {
    if (dst2 != src2)
      MOV(bits, R(dst2), R(src2));
    if (dst1 != src1 && offset1)
      LEA(bits, dst1, MDisp(src1, offset1));
    else if (dst1 != src1)
      MOV(bits, R(dst1), R(src1));
    else if (offset1)
      ADD(bits, R(dst1), Imm32(offset1));
  }
}

void XEmitter::ABI_CallFunctionAC(int bits, const void* func, const Gen::OpArg& arg1, u32 param2)
{
  if (!arg1.IsSimpleReg(ABI_PARAM1))
    MOV(bits, R(ABI_PARAM1), arg1);
  MOV(32, R(ABI_PARAM2), Imm32(param2));
  ABI_CallFunction(func);
}

void XEmitter::ABI_CallFunctionA(int bits, const void* func, const Gen::OpArg& arg1)
{
  if (!arg1.IsSimpleReg(ABI_PARAM1))
    MOV(bits, R(ABI_PARAM1), arg1);
  ABI_CallFunction(func);
}