From f7e6b61379ca2a55c85427818565c3d6ff722dce Mon Sep 17 00:00:00 2001 From: Tony Wasserka Date: Thu, 29 Oct 2020 11:41:11 +0100 Subject: [PATCH] aco/ra: Add helpers to test for intersection/containment of reg intervals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Note: intersects() uses the symmetric half-open overlap test, so unlike the open-coded check it replaces, it also reports an intersection when one interval strictly contains the other. Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_register_allocation.cpp | 25 ++++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 2b0533a2dcf..ff74900b5ea 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -158,6 +158,14 @@ struct PhysRegInterval { return { first, end - first }; } + bool contains(PhysReg reg) const { + return lo() <= reg && reg < hi(); + } + + bool contains(const PhysRegInterval& needle) const { + return needle.lo() >= lo() && needle.hi() <= hi(); + } + PhysRegIterator begin() const { return { lo_ }; } @@ -167,6 +175,11 @@ struct PhysRegInterval { } }; +bool intersects(const PhysRegInterval& a, const PhysRegInterval& b) { + /* Half-open intervals [lo, hi) overlap iff each one starts before the other ends; this also covers one interval strictly containing the other. */ + return a.hi() > b.lo() && b.hi() > a.lo(); +} + /* Gets the stride for full (non-subdword) registers */ uint32_t get_stride(RegClass rc) { if (rc.type() == RegType::vgpr) { @@ -811,7 +824,7 @@ std::pair get_reg_simple(ra_ctx& ctx, if (rc.is_subdword()) { for (std::pair> entry : reg_file.subdword_regs) { assert(reg_file[entry.first] == 0xF0000000); - if (bounds.lo() > entry.first || entry.first >= bounds.hi()) + if (!bounds.contains(PhysReg{entry.first})) continue; for (unsigned i = 0; i < 4; i+= info.stride) { @@ -943,8 +956,7 @@ bool get_regs_for_copies(ra_ctx& ctx, unsigned stride = var.rc.is_subdword() ? 
1 : info.stride; for (PhysRegInterval reg_win { bounds.lo(), size }; reg_win.hi() <= bounds.hi(); reg_win += stride) { - if (!is_dead_operand && ((reg_win.lo() >= def_reg.lo() && reg_win.lo() < def_reg.hi()) || - (reg_win.hi() > def_reg.lo() && reg_win.hi() <= def_reg.hi()))) + if (!is_dead_operand && intersects(reg_win, def_reg)) continue; /* second, check that we have at most k=num_moves elements in the window @@ -1048,8 +1060,7 @@ std::pair get_reg_impl(ra_ctx& ctx, Operand& op = instr->operands[j]; if (op.isTemp() && op.isFirstKillBeforeDef() && - op.physReg() >= bounds.lo() && - op.physReg() < bounds.hi() && + bounds.contains(op.physReg()) && !reg_file.test(PhysReg{op.physReg().reg()}, align(op.bytes() + op.physReg().byte(), 4))) { assert(op.isFixed()); @@ -1214,7 +1225,7 @@ bool get_reg_specified(ra_ctx& ctx, } PhysRegInterval reg_win = { reg.reg(), rc.size() }; - if (reg_win.lo() < bounds.lo() || reg_win.hi() > bounds.hi()) + if (!bounds.contains(reg_win)) return false; if (rc.is_subdword()) { @@ -1383,7 +1394,7 @@ PhysReg get_reg_create_vector(ra_ctx& ctx, /* check borders */ // TODO: this can be improved */ - if (reg_win.lo() < bounds.lo() || reg_win.hi() > bounds.hi() || reg_win.lo() % stride != 0) + if (!bounds.contains(reg_win) || reg_win.lo() % stride != 0) continue; if (reg_win.lo() > bounds.lo() && reg_file[reg_win.lo()] != 0 && reg_file.get_id(PhysReg(reg_win.lo())) == reg_file.get_id(PhysReg(reg_win.lo()).advance(-1))) continue;