|
|
@ -325,6 +325,7 @@ public: |
|
|
DeclareRegisters(); |
|
|
DeclareRegisters(); |
|
|
DeclarePredicates(); |
|
|
DeclarePredicates(); |
|
|
DeclareLocalMemory(); |
|
|
DeclareLocalMemory(); |
|
|
|
|
|
DeclareSharedMemory(); |
|
|
DeclareInternalFlags(); |
|
|
DeclareInternalFlags(); |
|
|
DeclareInputAttributes(); |
|
|
DeclareInputAttributes(); |
|
|
DeclareOutputAttributes(); |
|
|
DeclareOutputAttributes(); |
|
|
@ -500,6 +501,13 @@ private: |
|
|
code.AddNewLine(); |
|
|
code.AddNewLine(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void DeclareSharedMemory() { |
|
|
|
|
|
if (stage != ProgramType::Compute) { |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
code.AddLine("shared uint {}[];", GetSharedMemory()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void DeclareInternalFlags() { |
|
|
void DeclareInternalFlags() { |
|
|
for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) { |
|
|
for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) { |
|
|
const auto flag_code = static_cast<InternalFlag>(flag); |
|
|
const auto flag_code = static_cast<InternalFlag>(flag); |
|
|
@ -858,6 +866,12 @@ private: |
|
|
Type::Uint}; |
|
|
Type::Uint}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (const auto smem = std::get_if<SmemNode>(&*node)) { |
|
|
|
|
|
return { |
|
|
|
|
|
fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), |
|
|
|
|
|
Type::Uint}; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { |
|
|
if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { |
|
|
return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool}; |
|
|
return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool}; |
|
|
} |
|
|
} |
|
|
@ -1195,6 +1209,11 @@ private: |
|
|
target = { |
|
|
target = { |
|
|
fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), |
|
|
fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), |
|
|
Type::Uint}; |
|
|
Type::Uint}; |
|
|
|
|
|
} else if (const auto smem = std::get_if<SmemNode>(&*dest)) { |
|
|
|
|
|
ASSERT(stage == ProgramType::Compute); |
|
|
|
|
|
target = { |
|
|
|
|
|
fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), |
|
|
|
|
|
Type::Uint}; |
|
|
} else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { |
|
|
} else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { |
|
|
const std::string real = Visit(gmem->GetRealAddress()).AsUint(); |
|
|
const std::string real = Visit(gmem->GetRealAddress()).AsUint(); |
|
|
const std::string base = Visit(gmem->GetBaseAddress()).AsUint(); |
|
|
const std::string base = Visit(gmem->GetBaseAddress()).AsUint(); |
|
|
@ -2076,6 +2095,10 @@ private: |
|
|
return "lmem_" + suffix; |
|
|
return "lmem_" + suffix; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string GetSharedMemory() const { |
|
|
|
|
|
return fmt::format("smem_{}", suffix); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::string GetInternalFlag(InternalFlag flag) const { |
|
|
std::string GetInternalFlag(InternalFlag flag) const { |
|
|
constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag", |
|
|
constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag", |
|
|
"overflow_flag"}; |
|
|
"overflow_flag"}; |
|
|
|