Fix gapless decode and combine split buffers
Some checks are pending
Build and Release / reuse (push) Waiting to run
Build and Release / clang-format (push) Waiting to run
Build and Release / get-info (push) Waiting to run
Build and Release / windows-sdl (push) Blocked by required conditions
Build and Release / windows-qt (push) Blocked by required conditions
Build and Release / macos-sdl (push) Blocked by required conditions
Build and Release / macos-qt (push) Blocked by required conditions
Build and Release / linux-sdl (push) Blocked by required conditions
Build and Release / linux-qt (push) Blocked by required conditions
Build and Release / pre-release (push) Blocked by required conditions

This commit is contained in:
Vladislav Mikhalin 2024-10-29 20:37:04 +03:00
parent da5f7f232a
commit d7f78e6720
6 changed files with 177 additions and 103 deletions

View file

@ -165,24 +165,8 @@ struct AjmDevice {
p_instance->gapless.skip_samples = params.skip_samples;
}
ASSERT_MSG(job.input.buffers.size() <= job.output.buffers.size(),
"Unsupported combination of input/output buffers.");
for (size_t i = 0; i < job.input.buffers.size(); ++i) {
// Decode as much of the input bitstream as possible.
const auto& in_buffer = job.input.buffers[i];
auto& out_buffer = job.output.buffers[i];
const u8* in_address = in_buffer.data();
u8* out_address = out_buffer.data();
const auto [in_remain, out_remain] = p_instance->Decode(
in_address, in_buffer.size(), out_address, out_buffer.size(), &job.output);
if (job.output.p_stream != nullptr) {
job.output.p_stream->input_consumed += in_buffer.size() - in_remain;
job.output.p_stream->output_written += out_buffer.size() - out_remain;
job.output.p_stream->total_decoded_samples += p_instance->decoded_samples;
}
if (!job.input.buffer.empty()) {
p_instance->Decode(&job.input, &job.output);
}
if (job.output.p_gapless_decode != nullptr) {
@ -439,7 +423,6 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size
const auto batch_info = std::make_shared<BatchInfo>();
auto batch_id = dev->batches.Create(batch_info);
if (!batch_id.has_value()) {
LOG_ERROR(Lib_Ajm, "Too many batches in job!");
return ORBIS_AJM_ERROR_OUT_OF_MEMORY;
}
batch_info->id = batch_id.value();
@ -471,7 +454,7 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size
case Identifier::AjmIdentInputRunBuf: {
auto& buffer = AjmBufferExtract<AjmChunkBuffer>(p_current);
u8* p_begin = reinterpret_cast<u8*>(buffer.p_address);
job.input.buffers.emplace_back(
job.input.buffer.append_range(
std::vector<u8>(p_begin, p_begin + buffer.header.size));
break;
}
@ -614,7 +597,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
std::lock_guard guard(dev->batches_mutex);
const auto opt_batch = dev->batches.Get(batch_id);
if (!opt_batch.has_value()) {
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_INVALID_BATCH");
return ORBIS_AJM_ERROR_INVALID_BATCH;
}
@ -623,7 +605,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
bool expected = false;
if (!batch->waiting.compare_exchange_strong(expected, true)) {
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_BUSY");
return ORBIS_AJM_ERROR_BUSY;
}
@ -631,7 +612,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
batch->finished.acquire();
} else if (!batch->finished.try_acquire_for(std::chrono::milliseconds(timeout))) {
batch->waiting = false;
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_IN_PROGRESS");
return ORBIS_AJM_ERROR_IN_PROGRESS;
}
@ -641,11 +621,9 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
}
if (batch->canceled) {
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_CANCELLED");
return ORBIS_AJM_ERROR_CANCELLED;
}
LOG_INFO(Lib_Ajm, "ORBIS_OK");
return ORBIS_OK;
}
@ -656,7 +634,7 @@ int PS4_SYSV_ABI sceAjmDecAt9ParseConfigData() {
int PS4_SYSV_ABI sceAjmDecMp3ParseFrame(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame) {
LOG_INFO(Lib_Ajm, "called parse_ofl = {}", parse_ofl);
LOG_INFO(Lib_Ajm, "called stream_size = {} parse_ofl = {}", stream_size, parse_ofl);
if (buf == nullptr || stream_size < 4 || frame == nullptr) {
return ORBIS_AJM_ERROR_INVALID_PARAMETER;
}
@ -688,6 +666,9 @@ int PS4_SYSV_ABI sceAjmInstanceCodecType() {
int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmInstanceFlags flags,
u32* out_instance) {
LOG_INFO(Lib_Ajm, "called context = {}, codec_type = {}, flags = {:#x}", context,
magic_enum::enum_name(codec_type), flags.raw);
if (codec_type >= AjmCodecType::Max) {
return ORBIS_AJM_ERROR_INVALID_PARAMETER;
}
@ -720,8 +701,8 @@ int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmI
instance->flags = flags;
dev->instances[index] = std::move(instance);
*out_instance = index;
LOG_INFO(Lib_Ajm, "called codec_type = {}, flags = {:#x}, instance = {}",
magic_enum::enum_name(codec_type), flags.raw, index);
LOG_INFO(Lib_Ajm, "instance = {}", index);
return ORBIS_OK;
}

View file

@ -26,17 +26,23 @@ AjmAt9Decoder::~AjmAt9Decoder() {
}
// Returns the instance to its initial state: clears the running total of decoded samples
// and the gapless parameters, then rebuilds the codec via ResetCodec().
void AjmAt9Decoder::Reset() {
    total_decoded_samples = 0;
    gapless = {};

    ResetCodec();
}

// Recreates the ATRAC9 handle from the stored config data and rewinds the per-stream
// bookkeeping: frame counter, superframe byte budget and the gapless skipped/decoded
// counters. Unlike Reset(), the remaining gapless fields and total_decoded_samples are
// left untouched.
void AjmAt9Decoder::ResetCodec() {
    Atrac9ReleaseHandle(handle);
    handle = Atrac9GetHandle();
    Atrac9InitDecoder(handle, config_data);

    Atrac9CodecInfo codec_info;
    Atrac9GetCodecInfo(handle, &codec_info);

    num_frames = 0;
    // A fresh codec starts at a superframe boundary, so the whole superframe is pending.
    superframe_bytes_remain = codec_info.superframeSize;
    gapless.skipped_samples = 0;
    gapless_decoded_samples = 0;
}
void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) {
@ -58,72 +64,106 @@ void AjmAt9Decoder::GetCodecInfo(void* out_info) {
codec_info->uiSuperFrameSize = decoder_codec_info.superframeSize;
}
std::tuple<u32, u32> AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size_in, u8* out_buf,
u32 out_size_in, AjmJobOutput* output) {
const auto decoder_handle = static_cast<Atrac9Handle*>(handle);
void AjmAt9Decoder::Decode(const AjmJobInput* input, AjmJobOutput* output) {
LOG_TRACE(Lib_Ajm, "Decoding with instance {} in size = {}", index, input->buffer.size());
Atrac9CodecInfo codec_info;
Atrac9GetCodecInfo(handle, &codec_info);
int bytes_used = 0;
int num_superframes = 0;
u32 in_size = in_size_in;
u32 out_size = out_size_in;
const auto ShouldDecode = [&] {
if (in_size == 0 || out_size == 0) {
size_t out_buffer_index = 0;
std::span<const u8> in_buf(input->buffer);
std::span<u8> out_buf = output->buffers[out_buffer_index];
const auto should_decode = [&] {
if (in_buf.empty() || out_buf.empty()) {
return false;
}
if (gapless.total_samples != 0 && gapless.total_samples < decoded_samples) {
if (gapless.total_samples != 0 && gapless.total_samples < gapless_decoded_samples) {
return false;
}
return true;
};
const auto written_size = codec_info.channels * codec_info.frameSamples * sizeof(u16);
std::vector<s16> pcm_buffer(written_size >> 1);
while (ShouldDecode()) {
u32 ret = Atrac9Decode(decoder_handle, in_buf, pcm_buffer.data(), &bytes_used);
ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret);
in_buf += bytes_used;
in_size -= bytes_used;
if (output->p_mframe) {
++output->p_mframe->num_frames;
const auto write_output = [&](std::span<s16> pcm) {
while (!pcm.empty()) {
auto size = std::min(pcm.size() * sizeof(u16), out_buf.size());
std::memcpy(out_buf.data(), pcm.data(), size);
pcm = pcm.subspan(size >> 1);
out_buf = out_buf.subspan(size);
if (out_buf.empty()) {
out_buffer_index += 1;
if (out_buffer_index >= output->buffers.size()) {
return pcm.empty();
}
out_buf = output->buffers[out_buffer_index];
}
}
num_frames++;
bytes_remain -= bytes_used;
return true;
};
int num_superframes = 0;
const auto pcm_frame_size = codec_info.channels * codec_info.frameSamples * sizeof(u16);
std::vector<s16> pcm_buffer(pcm_frame_size >> 1);
while (should_decode()) {
int bytes_used = 0;
u32 ret = Atrac9Decode(handle, in_buf.data(), pcm_buffer.data(), &bytes_used);
ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret);
in_buf = in_buf.subspan(bytes_used);
superframe_bytes_remain -= bytes_used;
const size_t samples_remain = gapless.total_samples != 0
? gapless.total_samples - gapless_decoded_samples
: std::numeric_limits<size_t>::max();
bool written = false;
if (gapless.skipped_samples < gapless.skip_samples) {
gapless.skipped_samples += decoder_handle->Config.FrameSamples;
gapless.skipped_samples += codec_info.frameSamples;
if (gapless.skipped_samples > gapless.skip_samples) {
const auto size = gapless.skipped_samples - gapless.skip_samples;
const auto start = decoder_handle->Config.FrameSamples - size;
memcpy(out_buf, pcm_buffer.data() + start, size * sizeof(s16));
out_buf += size * sizeof(s16);
out_size -= size * sizeof(s16);
const u32 nsamples = gapless.skipped_samples - gapless.skip_samples;
const auto start = codec_info.frameSamples - nsamples;
written = write_output({pcm_buffer.data() + start, nsamples});
gapless.skipped_samples = gapless.skip_samples;
total_decoded_samples += nsamples;
gapless_decoded_samples += nsamples;
}
} else {
memcpy(out_buf, pcm_buffer.data(), written_size);
out_buf += written_size;
out_size -= written_size;
written =
write_output({pcm_buffer.data(), std::min(pcm_buffer.size(), samples_remain)});
total_decoded_samples += codec_info.frameSamples;
gapless_decoded_samples += codec_info.frameSamples;
}
decoded_samples += decoder_handle->Config.FrameSamples;
num_frames += 1;
if ((num_frames % codec_info.framesInSuperframe) == 0) {
in_buf += bytes_remain;
in_size -= bytes_remain;
bytes_remain = codec_info.superframeSize;
num_superframes++;
if (superframe_bytes_remain) {
if (output->p_stream) {
output->p_stream->input_consumed += superframe_bytes_remain;
}
in_buf = in_buf.subspan(superframe_bytes_remain);
}
superframe_bytes_remain = codec_info.superframeSize;
num_superframes += 1;
}
if (output->p_stream) {
output->p_stream->input_consumed += bytes_used;
if (written) {
output->p_stream->output_written +=
std::min(pcm_frame_size, samples_remain * sizeof(16));
}
}
if (output->p_mframe) {
output->p_mframe->num_frames += 1;
}
}
if (gapless.total_samples == decoded_samples) {
decoded_samples = 0;
if (gapless_decoded_samples >= gapless.total_samples) {
if (flags.gapless_loop) {
gapless.skipped_samples = 0;
ResetCodec();
}
}
LOG_TRACE(Lib_Ajm, "Decoded {} samples, frame count: {}", decoded_samples, num_frames);
return std::tuple(in_size, out_size);
if (output->p_stream) {
output->p_stream->total_decoded_samples = total_decoded_samples;
}
LOG_TRACE(Lib_Ajm, "Decoded buffer, in remain = {}, out remain = {}", in_buf.size(),
out_buf.size());
}
} // namespace Libraries::Ajm

View file

@ -30,6 +30,8 @@ struct AjmAt9Decoder final : AjmInstance {
std::fstream file;
int length;
u8 config_data[ORBIS_AT9_CONFIG_DATA_SIZE];
u32 superframe_bytes_remain{};
u32 num_frames{};
explicit AjmAt9Decoder();
~AjmAt9Decoder() override;
@ -43,8 +45,10 @@ struct AjmAt9Decoder final : AjmInstance {
return sizeof(AjmSidebandDecAt9CodecInfo);
}
std::tuple<u32, u32> Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size,
AjmJobOutput* output) override;
void Decode(const AjmJobInput* input, AjmJobOutput* output) override;
private:
void ResetCodec();
};
} // namespace Libraries::Ajm

View file

@ -101,7 +101,7 @@ struct AjmJobInput {
std::optional<AjmSidebandResampleParameters> resample_parameters;
std::optional<AjmSidebandFormat> format;
std::optional<AjmSidebandGaplessDecode> gapless_decode;
boost::container::small_vector<std::vector<u8>, 4> buffers;
std::vector<u8> buffer;
};
struct AjmJobOutput {
@ -132,9 +132,8 @@ struct AjmInstance {
AjmInstanceFlags flags{.raw = 0};
u32 num_channels{};
u32 index{};
u32 bytes_remain{};
u32 num_frames{};
u32 decoded_samples{};
u32 gapless_decoded_samples{};
u32 total_decoded_samples{};
AjmSidebandFormat format{};
AjmSidebandGaplessDecode gapless{};
AjmSidebandResampleParameters resample_parameters{};
@ -149,8 +148,7 @@ struct AjmInstance {
virtual void GetCodecInfo(void* out_info) = 0;
virtual u32 GetCodecInfoSize() = 0;
virtual std::tuple<u32, u32> Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size,
AjmJobOutput* output) = 0;
virtual void Decode(const AjmJobInput* input, AjmJobOutput* output) = 0;
};
} // namespace Libraries::Ajm

View file

@ -73,20 +73,51 @@ void AjmMp3Decoder::Reset() {
ASSERT_MSG(c, "Could not allocate audio codec context");
int ret = avcodec_open2(c, codec, nullptr);
ASSERT_MSG(ret >= 0, "Could not open codec");
decoded_samples = 0;
num_frames = 0;
total_decoded_samples = 0;
}
std::tuple<u32, u32> AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_buf, u32 out_size,
AjmJobOutput* output) {
void AjmMp3Decoder::Decode(const AjmJobInput* input, AjmJobOutput* output) {
AVPacket* pkt = av_packet_alloc();
while (in_size > 0 && out_size > 0) {
int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, buf, in_size, AV_NOPTS_VALUE,
AV_NOPTS_VALUE, 0);
ASSERT_MSG(ret >= 0, "Error while parsing {}", ret);
buf += ret;
in_size -= ret;
size_t out_buffer_index = 0;
std::span<const u8> in_buf(input->buffer);
std::span<u8> out_buf = output->buffers[out_buffer_index];
const auto should_decode = [&] {
if (in_buf.empty() || out_buf.empty()) {
return false;
}
if (gapless.total_samples != 0 && gapless.total_samples < gapless_decoded_samples) {
return false;
}
return true;
};
const auto write_output = [&](std::span<s16> pcm) {
while (!pcm.empty()) {
auto size = std::min(pcm.size() * sizeof(u16), out_buf.size());
std::memcpy(out_buf.data(), pcm.data(), size);
pcm = pcm.subspan(size >> 1);
out_buf = out_buf.subspan(size);
if (out_buf.empty()) {
out_buffer_index += 1;
if (out_buffer_index >= output->buffers.size()) {
return pcm.empty();
}
out_buf = output->buffers[out_buffer_index];
}
}
return true;
};
while (should_decode()) {
int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, in_buf.data(), in_buf.size(),
AV_NOPTS_VALUE, AV_NOPTS_VALUE, 0);
ASSERT_MSG(ret >= 0, "Error while parsing {}", ret);
in_buf = in_buf.subspan(ret);
if (output->p_stream) {
output->p_stream->input_consumed += ret;
}
if (pkt->size) {
// Send the packet with the compressed data to the decoder
pkt->pts = parser->pts;
@ -107,22 +138,43 @@ std::tuple<u32, u32> AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_b
if (frame->format != AV_SAMPLE_FMT_S16) {
frame = ConvertAudioFrame(frame);
}
const auto size = frame->ch_layout.nb_channels * frame->nb_samples * sizeof(u16);
std::memcpy(out_buf, frame->data[0], size);
file.write((const char*)frame->data[0], size);
out_buf += size;
out_size -= size;
decoded_samples += frame->nb_samples;
num_frames++;
const auto frame_samples = frame->ch_layout.nb_channels * frame->nb_samples;
const auto size = frame_samples * sizeof(u16);
if (gapless.skipped_samples < gapless.skip_samples) {
gapless.skipped_samples += frame_samples;
if (gapless.skipped_samples > gapless.skip_samples) {
const u32 nsamples = gapless.skipped_samples - gapless.skip_samples;
const auto start = frame_samples - nsamples;
write_output({reinterpret_cast<s16*>(frame->data[0]), nsamples});
gapless.skipped_samples = gapless.skip_samples;
total_decoded_samples += nsamples;
gapless_decoded_samples += nsamples;
}
} else {
write_output({reinterpret_cast<s16*>(frame->data[0]), size >> 1});
total_decoded_samples += frame_samples;
gapless_decoded_samples += frame_samples;
}
av_frame_free(&frame);
if (output->p_stream) {
output->p_stream->output_written += size;
}
if (output->p_mframe) {
output->p_mframe->num_frames += 1;
}
}
}
}
av_packet_free(&pkt);
if (output->p_mframe) {
output->p_mframe->num_frames += num_frames;
if (gapless_decoded_samples >= gapless.total_samples) {
if (flags.gapless_loop) {
gapless.skipped_samples = 0;
gapless_decoded_samples = 0;
}
}
if (output->p_stream) {
output->p_stream->total_decoded_samples = total_decoded_samples;
}
return std::make_tuple(in_size, out_size);
}
int AjmMp3Decoder::ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,

View file

@ -74,8 +74,7 @@ struct AjmMp3Decoder : public AjmInstance {
return sizeof(AjmSidebandDecMp3CodecInfo);
}
std::tuple<u32, u32> Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size,
AjmJobOutput* output) override;
void Decode(const AjmJobInput* input, AjmJobOutput* output) override;
static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame);