diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 313a1deb30..8978b75a31 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -4,6 +4,7 @@
 // SPDX-FileCopyrightText: Copyright 2021 yuzu Emulator Project
 // SPDX-License-Identifier: GPL-2.0-or-later

+#include <array>
 #include
 #include
 #include
@@ -24,6 +25,119 @@ template struct FuncTraits {};

 thread_local std::unique_ptr thread_optimizer;

+/// Tries to fold a run of adjacent 8-bit or 16-bit storage writes into one 32-bit word write.
+///
+/// Nothing is emitted unless the write at @p it has a width the profile reports as unsupported
+/// (support_int8 / support_int16) and both its binding and byte offset are immediates. The
+/// instructions directly following it are merged while they have the same width, the same
+/// binding, consecutive byte offsets and stay inside the same 32-bit word. The packed value is
+/// written with OpStore when it fills the whole word, otherwise through
+/// write_storage_cas_loop_func with the bit offset and bit count of the covered range.
+///
+/// @param ctx SPIR-V emission context
+/// @param it  Instruction to inspect; on success it is moved to the last merged instruction
+/// @param end End iterator of the block being emitted
+/// @return True if at least two writes were merged and emitted, false if nothing was emitted
+static bool TryEmitCoalescedStorage(EmitContext& ctx, IR::Block::iterator& it,
+                                    IR::Block::iterator end) {
+    constexpr u32 word_bytes = 4;
+    const auto is_write8 = [](IR::Opcode op) {
+        return op == IR::Opcode::WriteStorageU8 || op == IR::Opcode::WriteStorageS8;
+    };
+    const auto is_write16 = [](IR::Opcode op) {
+        return op == IR::Opcode::WriteStorageU16 || op == IR::Opcode::WriteStorageS16;
+    };
+    const auto has_immediate_address = [](IR::Inst& write) {
+        return write.Arg(0).IsImmediate() && write.Arg(1).IsImmediate();
+    };
+
+    IR::Inst& first = *it;
+    const IR::Opcode opcode = first.GetOpcode();
+    const bool is_u8 = is_write8(opcode);
+    const bool is_u16 = is_write16(opcode);
+    if (!is_u8 && !is_u16) {
+        return false;
+    }
+    // Widths the profile supports natively are left to the regular per-instruction emitter.
+    if (is_u8 ? ctx.profile.support_int8 : ctx.profile.support_int16) {
+        return false;
+    }
+    if (!has_immediate_address(first)) {
+        return false;
+    }
+
+    const u32 binding = first.Arg(0).U32();
+    const u32 base_offset = first.Arg(1).U32();
+    const u32 step_bytes = is_u8 ? 1u : 2u;
+    // A misaligned 16-bit write does not map onto one half of a word.
+    if (base_offset % step_bytes != 0) {
+        return false;
+    }
+    const u32 base_word = base_offset / word_bytes;
+
+    // One word holds at most four byte-sized writes, so four slots are always enough.
+    std::array<IR::Inst*, word_bytes> grouped{};
+    grouped[0] = &first;
+    u32 count = 1;
+    auto last = it;
+    while (count * step_bytes < word_bytes) {
+        const auto next = std::next(last);
+        if (next == end) {
+            break;
+        }
+        const IR::Opcode next_opcode = next->GetOpcode();
+        if (is_write8(next_opcode) != is_u8 || is_write16(next_opcode) != is_u16) {
+            break;
+        }
+        if (!has_immediate_address(*next) || next->Arg(0).U32() != binding) {
+            break;
+        }
+        const u32 next_offset = next->Arg(1).U32();
+        // The write must be the next consecutive element and must not leave the base word.
+        if (next_offset / word_bytes != base_word ||
+            next_offset != base_offset + count * step_bytes) {
+            break;
+        }
+        grouped[count] = &*next;
+        ++count;
+        last = next;
+    }
+    if (count < 2) {
+        return false;
+    }
+
+    // Mask each value to its width and OR it into the lane matching its position in the run.
+    const u32 lane_bits = step_bytes * 8u;
+    const u32 lane_mask = is_u8 ? 0xFFu : 0xFFFFu;
+    Id combined = ctx.u32_zero_value;
+    for (u32 i = 0; i < count; ++i) {
+        const Id value = ctx.Def(grouped[i]->Arg(2));
+        const Id lane = ctx.OpBitwiseAnd(ctx.U32[1], value, ctx.Const(lane_mask));
+        if (i == 0) {
+            combined = lane;
+            continue;
+        }
+        const Id shifted = ctx.OpShiftLeftLogical(ctx.U32[1], lane, ctx.Const(i * lane_bits));
+        combined = ctx.OpBitwiseOr(ctx.U32[1], combined, shifted);
+    }
+
+    const u32 bit_offset = (base_offset % word_bytes) * 8u;
+    const u32 bit_count = count * lane_bits;
+    const Id ssbo = ctx.ssbos[binding].U32;
+    const Id index = ctx.Const(base_word);
+    const Id pointer =
+        ctx.OpAccessChain(ctx.storage_types.U32.element, ssbo, ctx.u32_zero_value, index);
+    if (bit_count == 32 && bit_offset == 0) {
+        ctx.OpStore(pointer, combined);
+    } else {
+        ctx.OpFunctionCall(ctx.TypeVoid(), ctx.write_storage_cas_loop_func, pointer, combined,
+                           ctx.Const(bit_offset), ctx.Const(bit_count));
+    }
+
+    it = last;
+    return true;
+}
+
 spvtools::Optimizer& GetThreadOptimizer() {
     if (!thread_optimizer) {
         thread_optimizer = std::make_unique(SPV_ENV_VULKAN_1_3);
@@ -140,8 +254,11 @@ void Traverse(EmitContext& ctx, IR::Program& program) {
             }
             current_block = node.data.block;
             ctx.AddLabel(label);
-            for (IR::Inst& inst : node.data.block->Instructions()) {
-                EmitInst(ctx, &inst);
+            for (auto it = node.data.block->begin(); it != node.data.block->end(); ++it) {
+                if (TryEmitCoalescedStorage(ctx, it, node.data.block->end())) {
+                    continue;
+                }
+                EmitInst(ctx, &*it);
             }
             break;
         }