From 033531509bc3465f743527af3a8cd59f3b3b0c70 Mon Sep 17 00:00:00 2001 From: CamilleLaVey Date: Sun, 9 Nov 2025 23:14:51 -0400 Subject: [PATCH] [shader_recompiler, spir-v] Adding INT64 emulation path --- .../backend/spirv/emit_spirv_memory.cpp | 12 +- .../backend/spirv/spirv_emit_context.cpp | 195 ++++++++++++++++-- .../backend/spirv/spirv_emit_context.h | 9 + 3 files changed, 192 insertions(+), 24 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp index bdcbccfde9..0ac7086995 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp @@ -92,7 +92,7 @@ void EmitLoadGlobalS16(EmitContext&) { } Id EmitLoadGlobal32(EmitContext& ctx, Id address) { - if (ctx.profile.support_int64) { + if (ctx.SupportsNativeInt64() || ctx.UsesInt64Emulation()) { return ctx.OpFunctionCall(ctx.U32[1], ctx.load_global_func_u32, address); } LOG_WARNING(Shader_SPIRV, "Int64 not supported, ignoring memory operation"); @@ -100,7 +100,7 @@ Id EmitLoadGlobal32(EmitContext& ctx, Id address) { } Id EmitLoadGlobal64(EmitContext& ctx, Id address) { - if (ctx.profile.support_int64) { + if (ctx.SupportsNativeInt64() || ctx.UsesInt64Emulation()) { return ctx.OpFunctionCall(ctx.U32[2], ctx.load_global_func_u32x2, address); } LOG_WARNING(Shader_SPIRV, "Int64 not supported, ignoring memory operation"); @@ -108,7 +108,7 @@ Id EmitLoadGlobal64(EmitContext& ctx, Id address) { } Id EmitLoadGlobal128(EmitContext& ctx, Id address) { - if (ctx.profile.support_int64) { + if (ctx.SupportsNativeInt64() || ctx.UsesInt64Emulation()) { return ctx.OpFunctionCall(ctx.U32[4], ctx.load_global_func_u32x4, address); } LOG_WARNING(Shader_SPIRV, "Int64 not supported, ignoring memory operation"); @@ -132,7 +132,7 @@ void EmitWriteGlobalS16(EmitContext&) { } void EmitWriteGlobal32(EmitContext& ctx, Id address, Id value) { - if (ctx.profile.support_int64) { + if (ctx.SupportsNativeInt64() || ctx.UsesInt64Emulation()) { ctx.OpFunctionCall(ctx.void_id, ctx.write_global_func_u32, address, value); return; } @@ -140,7 +140,7 @@ void EmitWriteGlobal32(EmitContext& ctx, Id address, Id value) { } void EmitWriteGlobal64(EmitContext& ctx, Id address, Id value) { - if (ctx.profile.support_int64) { + if (ctx.SupportsNativeInt64() || ctx.UsesInt64Emulation()) { ctx.OpFunctionCall(ctx.void_id, ctx.write_global_func_u32x2, address, value); return; } @@ -148,7 +148,7 @@ void EmitWriteGlobal64(EmitContext& ctx, Id address, Id value) { } void EmitWriteGlobal128(EmitContext& ctx, Id address, Id value) { - if (ctx.profile.support_int64) { + if (ctx.SupportsNativeInt64() || ctx.UsesInt64Emulation()) { ctx.OpFunctionCall(ctx.void_id, ctx.write_global_func_u32x4, address, value); return; } diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp index 4c3e101433..c4b72b5888 100644 --- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp @@ -460,9 +460,10 @@ void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_vie EmitContext::EmitContext(const Profile& profile_, const RuntimeInfo& runtime_info_, IR::Program& program, Bindings& bindings) - : Sirit::Module(profile_.supported_spirv), profile{profile_}, runtime_info{runtime_info_}, - stage{program.stage}, texture_rescaling_index{bindings.texture_scaling_index}, - image_rescaling_index{bindings.image_scaling_index} { + : Sirit::Module(profile_.supported_spirv), profile{profile_}, runtime_info{runtime_info_}, + stage{program.stage}, emulate_int64{program.info.uses_int64 && !profile.support_int64}, + texture_rescaling_index{bindings.texture_scaling_index}, + image_rescaling_index{bindings.image_scaling_index} { const bool is_unified{profile.unified_descriptor_binding}; u32& uniform_binding{is_unified ? bindings.unified : bindings.uniform_buffer}; u32& storage_binding{is_unified ? bindings.unified : bindings.storage_buffer}; @@ -932,11 +933,163 @@ void EmitContext::DefineWriteStorageCasLoopFunction(const Info& info) { } void EmitContext::DefineGlobalMemoryFunctions(const Info& info) { - if (!info.uses_global_memory || !profile.support_int64) { + if (!info.uses_global_memory) { return; } using DefPtr = Id StorageDefinitions::*; const Id zero{u32_zero_value}; + + if (SupportsNativeInt64()) { + const auto define_body{[&](DefPtr ssbo_member, Id addr, Id element_pointer, u32 shift, + auto&& callback) { + AddLabel(); + const size_t num_buffers{info.storage_buffers_descriptors.size()}; + for (size_t index = 0; index < num_buffers; ++index) { + if (!info.nvn_buffer_used[index]) { + continue; + } + const auto& ssbo{info.storage_buffers_descriptors[index]}; + const Id ssbo_addr_cbuf_offset{Const(ssbo.cbuf_offset / 8)}; + const Id ssbo_size_cbuf_offset{Const(ssbo.cbuf_offset / 4 + 2)}; + const Id ssbo_addr_pointer{OpAccessChain( + uniform_types.U32x2, cbufs[ssbo.cbuf_index].U32x2, zero, + ssbo_addr_cbuf_offset)}; + const Id ssbo_size_pointer{OpAccessChain( + uniform_types.U32, cbufs[ssbo.cbuf_index].U32, zero, ssbo_size_cbuf_offset)}; + + const u64 ssbo_align_mask{~(profile.min_ssbo_alignment - 1U)}; + const Id unaligned_addr{OpBitcast(U64, OpLoad(U32[2], ssbo_addr_pointer))}; + const Id ssbo_addr{OpBitwiseAnd(U64, unaligned_addr, Constant(U64, ssbo_align_mask))}; + const Id ssbo_size{OpUConvert(U64, OpLoad(U32[1], ssbo_size_pointer))}; + const Id ssbo_end{OpIAdd(U64, ssbo_addr, ssbo_size)}; + const Id cond{OpLogicalAnd(U1, OpUGreaterThanEqual(U1, addr, ssbo_addr), + OpULessThan(U1, addr, ssbo_end))}; + const Id then_label{OpLabel()}; + const Id else_label{OpLabel()}; + OpSelectionMerge(else_label, spv::SelectionControlMask::MaskNone); + OpBranchConditional(cond, then_label, else_label); + AddLabel(then_label); + const Id ssbo_id{ssbos[index].*ssbo_member}; + const Id ssbo_offset{OpUConvert(U32[1], OpISub(U64, addr, ssbo_addr))}; + const Id ssbo_index{OpShiftRightLogical(U32[1], ssbo_offset, Const(shift))}; + const Id ssbo_pointer{OpAccessChain(element_pointer, ssbo_id, zero, ssbo_index)}; + callback(ssbo_pointer); + AddLabel(else_label); + } + }}; + const auto define_load{[&](DefPtr ssbo_member, Id element_pointer, Id type, u32 shift) { + const Id function_type{TypeFunction(type, U64)}; + const Id func_id{OpFunction(type, spv::FunctionControlMask::MaskNone, function_type)}; + const Id addr{OpFunctionParameter(U64)}; + define_body(ssbo_member, addr, element_pointer, shift, + [&](Id ssbo_pointer) { OpReturnValue(OpLoad(type, ssbo_pointer)); }); + OpReturnValue(ConstantNull(type)); + OpFunctionEnd(); + return func_id; + }}; + const auto define_write{[&](DefPtr ssbo_member, Id element_pointer, Id type, u32 shift) { + const Id function_type{TypeFunction(void_id, U64, type)}; + const Id func_id{ + OpFunction(void_id, spv::FunctionControlMask::MaskNone, function_type)}; + const Id addr{OpFunctionParameter(U64)}; + const Id data{OpFunctionParameter(type)}; + define_body(ssbo_member, addr, element_pointer, shift, [&](Id ssbo_pointer) { + OpStore(ssbo_pointer, data); + OpReturn(); + }); + OpReturn(); + OpFunctionEnd(); + return func_id; + }}; + const auto define{ + [&](DefPtr ssbo_member, const StorageTypeDefinition& type_def, Id type, size_t size) { + const Id element_type{type_def.element}; + const u32 shift{static_cast(std::countr_zero(size))}; + const Id load_func{define_load(ssbo_member, element_type, type, shift)}; + const Id write_func{define_write(ssbo_member, element_type, type, shift)}; + return std::make_pair(load_func, write_func); + }}; + std::tie(load_global_func_u32, write_global_func_u32) = + define(&StorageDefinitions::U32, storage_types.U32, U32[1], sizeof(u32)); + std::tie(load_global_func_u32x2, write_global_func_u32x2) = + define(&StorageDefinitions::U32x2, storage_types.U32x2, U32[2], sizeof(u32[2])); + std::tie(load_global_func_u32x4, write_global_func_u32x4) = + define(&StorageDefinitions::U32x4, storage_types.U32x4, U32[4], sizeof(u32[4])); + return; + } + + if (!UsesInt64Emulation()) { + return; + } + + const auto make_pair = [&](Id lo, Id hi) { + return OpCompositeConstruct(U32[2], lo, hi); + }; + const auto split_pair = [&](Id value) { + return std::array{OpCompositeExtract(U32[1], value, 0U), + OpCompositeExtract(U32[1], value, 1U)}; + }; + const auto bool_to_u32 = [&](Id predicate) { + return OpSelect(U32[1], predicate, Const(1u), zero); + }; + const auto and_pair = [&](Id value, Id mask) { + const auto value_parts{split_pair(value)}; + const auto mask_parts{split_pair(mask)}; + return make_pair(OpBitwiseAnd(U32[1], value_parts[0], mask_parts[0]), + OpBitwiseAnd(U32[1], value_parts[1], mask_parts[1])); + }; + const auto add_pair = [&](Id lhs, Id rhs) { + const auto lhs_parts{split_pair(lhs)}; + const auto rhs_parts{split_pair(rhs)}; + const Id sum_lo{OpIAdd(U32[1], lhs_parts[0], rhs_parts[0])}; + const Id carry{OpULessThan(U1, sum_lo, lhs_parts[0])}; + Id sum_hi{OpIAdd(U32[1], lhs_parts[1], rhs_parts[1])}; + sum_hi = OpIAdd(U32[1], sum_hi, bool_to_u32(carry)); + return make_pair(sum_lo, sum_hi); + }; + const auto sub_pair = [&](Id lhs, Id rhs) { + const auto lhs_parts{split_pair(lhs)}; + const auto rhs_parts{split_pair(rhs)}; + const Id borrow{OpULessThan(U1, lhs_parts[0], rhs_parts[0])}; + const Id diff_lo{OpISub(U32[1], lhs_parts[0], rhs_parts[0])}; + Id diff_hi{OpISub(U32[1], lhs_parts[1], rhs_parts[1])}; + diff_hi = OpISub(U32[1], diff_hi, bool_to_u32(borrow)); + return make_pair(diff_lo, diff_hi); + }; + const auto shift_right_pair = [&](Id value, u32 shift) { + if (shift == 0) { + return value; + } + const auto parts{split_pair(value)}; + const Id shift_id{Const(shift)}; + const Id high_shifted{OpShiftRightLogical(U32[1], parts[1], shift_id)}; + Id low_shifted{OpShiftRightLogical(U32[1], parts[0], shift_id)}; + const Id carry_bits{OpShiftLeftLogical(U32[1], parts[1], Const(32u - shift))}; + low_shifted = OpBitwiseOr(U32[1], low_shifted, carry_bits); + return make_pair(low_shifted, high_shifted); + }; + const auto greater_equal_pair = [&](Id lhs, Id rhs) { + const auto lhs_parts{split_pair(lhs)}; + const auto rhs_parts{split_pair(rhs)}; + const Id hi_gt{OpUGreaterThan(U1, lhs_parts[1], rhs_parts[1])}; + const Id hi_eq{OpIEqual(U1, lhs_parts[1], rhs_parts[1])}; + const Id lo_ge{OpUGreaterThanEqual(U1, lhs_parts[0], rhs_parts[0])}; + return OpLogicalOr(U1, hi_gt, OpLogicalAnd(U1, hi_eq, lo_ge)); + }; + const auto less_than_pair = [&](Id lhs, Id rhs) { + const auto lhs_parts{split_pair(lhs)}; + const auto rhs_parts{split_pair(rhs)}; + const Id hi_lt{OpULessThan(U1, lhs_parts[1], rhs_parts[1])}; + const Id hi_eq{OpIEqual(U1, lhs_parts[1], rhs_parts[1])}; + const Id lo_lt{OpULessThan(U1, lhs_parts[0], rhs_parts[0])}; + return OpLogicalOr(U1, hi_lt, OpLogicalAnd(U1, hi_eq, lo_lt)); + }; + + const u64 ssbo_align_mask_value{~(profile.min_ssbo_alignment - 1U)}; + const Id ssbo_align_mask{ + Const(static_cast(ssbo_align_mask_value & 0xFFFFFFFFu), + static_cast(ssbo_align_mask_value >> 32))}; + const auto define_body{[&](DefPtr ssbo_member, Id addr, Id element_pointer, u32 shift, auto&& callback) { AddLabel(); @@ -953,40 +1106,44 @@ void EmitContext::DefineGlobalMemoryFunctions(const Info& info) { const Id ssbo_size_pointer{OpAccessChain(uniform_types.U32, cbufs[ssbo.cbuf_index].U32, zero, ssbo_size_cbuf_offset)}; - const u64 ssbo_align_mask{~(profile.min_ssbo_alignment - 1U)}; - const Id unaligned_addr{OpBitcast(U64, OpLoad(U32[2], ssbo_addr_pointer))}; - const Id ssbo_addr{OpBitwiseAnd(U64, unaligned_addr, Constant(U64, ssbo_align_mask))}; - const Id ssbo_size{OpUConvert(U64, OpLoad(U32[1], ssbo_size_pointer))}; - const Id ssbo_end{OpIAdd(U64, ssbo_addr, ssbo_size)}; - const Id cond{OpLogicalAnd(U1, OpUGreaterThanEqual(U1, addr, ssbo_addr), - OpULessThan(U1, addr, ssbo_end))}; + const Id unaligned_addr_pair{OpLoad(U32[2], ssbo_addr_pointer)}; + const Id ssbo_addr_pair{and_pair(unaligned_addr_pair, ssbo_align_mask)}; + const Id ssbo_size_value{OpLoad(U32[1], ssbo_size_pointer)}; + const Id ssbo_size_pair{make_pair(ssbo_size_value, zero)}; + const Id ssbo_end_pair{add_pair(ssbo_addr_pair, ssbo_size_pair)}; + const Id cond{OpLogicalAnd(U1, greater_equal_pair(addr, ssbo_addr_pair), + less_than_pair(addr, ssbo_end_pair))}; const Id then_label{OpLabel()}; const Id else_label{OpLabel()}; OpSelectionMerge(else_label, spv::SelectionControlMask::MaskNone); OpBranchConditional(cond, then_label, else_label); AddLabel(then_label); const Id ssbo_id{ssbos[index].*ssbo_member}; - const Id ssbo_offset{OpUConvert(U32[1], OpISub(U64, addr, ssbo_addr))}; - const Id ssbo_index{OpShiftRightLogical(U32[1], ssbo_offset, Const(shift))}; + const Id ssbo_offset_pair{sub_pair(addr, ssbo_addr_pair)}; + const Id ssbo_index_pair{shift_right_pair(ssbo_offset_pair, shift)}; + const Id ssbo_index{OpCompositeExtract(U32[1], ssbo_index_pair, 0U)}; const Id ssbo_pointer{OpAccessChain(element_pointer, ssbo_id, zero, ssbo_index)}; callback(ssbo_pointer); AddLabel(else_label); } }}; + const auto define_load{[&](DefPtr ssbo_member, Id element_pointer, Id type, u32 shift) { - const Id function_type{TypeFunction(type, U64)}; + const Id function_type{TypeFunction(type, U32[2])}; const Id func_id{OpFunction(type, spv::FunctionControlMask::MaskNone, function_type)}; - const Id addr{OpFunctionParameter(U64)}; + const Id addr{OpFunctionParameter(U32[2])}; define_body(ssbo_member, addr, element_pointer, shift, [&](Id ssbo_pointer) { OpReturnValue(OpLoad(type, ssbo_pointer)); }); OpReturnValue(ConstantNull(type)); OpFunctionEnd(); return func_id; }}; + const auto define_write{[&](DefPtr ssbo_member, Id element_pointer, Id type, u32 shift) { - const Id function_type{TypeFunction(void_id, U64, type)}; - const Id func_id{OpFunction(void_id, spv::FunctionControlMask::MaskNone, function_type)}; - const Id addr{OpFunctionParameter(U64)}; + const Id function_type{TypeFunction(void_id, U32[2], type)}; + const Id func_id{ + OpFunction(void_id, spv::FunctionControlMask::MaskNone, function_type)}; + const Id addr{OpFunctionParameter(U32[2])}; const Id data{OpFunctionParameter(type)}; define_body(ssbo_member, addr, element_pointer, shift, [&](Id ssbo_pointer) { OpStore(ssbo_pointer, data); @@ -996,6 +1153,7 @@ void EmitContext::DefineGlobalMemoryFunctions(const Info& info) { OpFunctionEnd(); return func_id; }}; + const auto define{ [&](DefPtr ssbo_member, const StorageTypeDefinition& type_def, Id type, size_t size) { const Id element_type{type_def.element}; @@ -1004,6 +1162,7 @@ void EmitContext::DefineGlobalMemoryFunctions(const Info& info) { const Id write_func{define_write(ssbo_member, element_type, type, shift)}; return std::make_pair(load_func, write_func); }}; + std::tie(load_global_func_u32, write_global_func_u32) = define(&StorageDefinitions::U32, storage_types.U32, U32[1], sizeof(u32)); std::tie(load_global_func_u32x2, write_global_func_u32x2) = diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.h b/src/shader_recompiler/backend/spirv/spirv_emit_context.h index 66cdb1d3db..2dbeeb0911 100644 --- a/src/shader_recompiler/backend/spirv/spirv_emit_context.h +++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.h @@ -207,6 +207,15 @@ public: const Profile& profile; const RuntimeInfo& runtime_info; Stage stage{}; + const bool emulate_int64{}; + + bool SupportsNativeInt64() const { + return profile.support_int64; + } + + bool UsesInt64Emulation() const { + return emulate_int64; + } Id void_id{}; Id U1{};