From e95d728a9803529290f237ad8ea87b6020b05481 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Sch=C3=BCrmann?= Date: Fri, 1 Aug 2025 11:11:41 +0200 Subject: [PATCH] aco/scheduler: split downwards_move_clause() from downwards_move() We will do batched moves for clauses with the next commit. Part-of: --- src/amd/compiler/aco_scheduler.cpp | 168 +++++++++++++++++------------ 1 file changed, 99 insertions(+), 69 deletions(-) diff --git a/src/amd/compiler/aco_scheduler.cpp b/src/amd/compiler/aco_scheduler.cpp index 0195e168a4f..8e7911986b1 100644 --- a/src/amd/compiler/aco_scheduler.cpp +++ b/src/amd/compiler/aco_scheduler.cpp @@ -102,7 +102,8 @@ struct MoveState { /* for moving instructions before the current instruction to after it */ DownwardsCursor downwards_init(int current_idx, bool improved_rar, bool may_form_clauses); - MoveResult downwards_move(DownwardsCursor&, bool clause); + MoveResult downwards_move(DownwardsCursor&); + MoveResult downwards_move_clause(DownwardsCursor&); void downwards_skip(DownwardsCursor&); /* for moving instructions after the first use of the current instruction upwards */ @@ -216,80 +217,109 @@ check_dependencies(Instruction* instr, std::vector& def_dep, std::vector& candidate = block->instructions[cursor.source_idx]; /* check if one of candidate's operands is killed by depending instruction */ - if (add_to_clause) { - assert(improved_rar); - aco_ptr& instr = block->instructions[cursor.insert_idx_clause]; - int i = cursor.source_idx; - while (should_form_clause(block->instructions[i].get(), instr.get())) { - if (check_dependencies(block->instructions[i].get(), depends_on, RAR_dependencies_clause)) - return move_fail_ssa; - i--; - } - } else { - std::vector& RAR_deps = improved_rar ? RAR_dependencies : depends_on; - if (check_dependencies(candidate.get(), depends_on, RAR_deps)) - return move_fail_ssa; - } - - if (add_to_clause) { - for (const Operand& op : candidate->operands) { - if (op.isTemp()) { - depends_on[op.tempId()] = true; - if (op.isFirstKill()) - RAR_dependencies[op.tempId()] = true; - } - } - } - - const int dest_insert_idx = add_to_clause ? cursor.insert_idx_clause : cursor.insert_idx; - RegisterDemand register_pressure = cursor.total_demand; - if (!add_to_clause) { - register_pressure.update(cursor.clause_demand); - } + std::vector& RAR_deps = improved_rar ? RAR_dependencies : depends_on; + if (check_dependencies(candidate.get(), depends_on, RAR_deps)) + return move_fail_ssa; /* Check the new demand of the instructions being moved over */ + RegisterDemand register_pressure = cursor.total_demand; + register_pressure.update(cursor.clause_demand); const RegisterDemand candidate_diff = get_live_changes(candidate.get()); if (RegisterDemand(register_pressure - candidate_diff).exceeds(max_registers)) return move_fail_pressure; /* New demand for the moved instruction */ const RegisterDemand temp = get_temp_registers(candidate.get()); - const RegisterDemand insert_demand = - add_to_clause ? cursor.insert_demand_clause : cursor.insert_demand; + const RegisterDemand insert_demand = cursor.insert_demand; const RegisterDemand new_demand = insert_demand + temp; if (new_demand.exceeds(max_registers)) return move_fail_pressure; /* move the candidate below the memory load */ - move_element(block->instructions.begin(), cursor.source_idx, dest_insert_idx); + move_element(block->instructions.begin(), cursor.source_idx, cursor.insert_idx); + cursor.insert_idx--; + cursor.insert_idx_clause--; /* update register pressure */ - for (int i = cursor.source_idx; i < dest_insert_idx - 1; i++) + for (int i = cursor.source_idx; i < cursor.insert_idx; i++) block->instructions[i]->register_demand -= candidate_diff; - block->instructions[dest_insert_idx - 1]->register_demand = new_demand; - cursor.insert_idx_clause--; + block->instructions[cursor.insert_idx]->register_demand = new_demand; if (cursor.source_idx != cursor.insert_idx_clause) { /* Update demand if we moved over any instructions before the clause */ cursor.total_demand -= candidate_diff; } else { assert(cursor.total_demand == RegisterDemand{}); } - if (add_to_clause) { - cursor.clause_demand.update(new_demand); - } else { - cursor.clause_demand -= candidate_diff; - cursor.insert_demand -= candidate_diff; - cursor.insert_idx--; + + cursor.clause_demand -= candidate_diff; + cursor.insert_demand -= candidate_diff; + cursor.insert_demand_clause -= candidate_diff; + + cursor.source_idx--; + cursor.verify_invariants(block); + return move_success; +} + +/* The current clause is extended by moving the instruction at source_idx + * in front of the clause. + */ +MoveResult +MoveState::downwards_move_clause(DownwardsCursor& cursor) +{ + assert(improved_rar); + aco_ptr& candidate = block->instructions[cursor.source_idx]; + + /* check if one of candidate's operands is killed by depending instruction */ + aco_ptr& instr = block->instructions[cursor.insert_idx_clause]; + int idx = cursor.source_idx; + while (should_form_clause(block->instructions[idx].get(), instr.get())) { + if (check_dependencies(block->instructions[idx].get(), depends_on, RAR_dependencies_clause)) + return move_fail_ssa; + idx--; } + + for (const Operand& op : candidate->operands) { + if (op.isTemp()) { + depends_on[op.tempId()] = true; + if (op.isFirstKill()) + RAR_dependencies[op.tempId()] = true; + } + } + + /* Check the new demand of the instructions being moved over */ + RegisterDemand register_pressure = cursor.total_demand; + const RegisterDemand candidate_diff = get_live_changes(candidate.get()); + if (RegisterDemand(register_pressure - candidate_diff).exceeds(max_registers)) + return move_fail_pressure; + + /* New demand for the moved instruction */ + const RegisterDemand temp = get_temp_registers(candidate.get()); + const RegisterDemand new_demand = cursor.insert_demand_clause + temp; + if (new_demand.exceeds(max_registers)) + return move_fail_pressure; + + /* move the candidate below the memory load */ + move_element(block->instructions.begin(), cursor.source_idx, cursor.insert_idx_clause); + cursor.insert_idx_clause--; + + /* update register pressure */ + for (int i = cursor.source_idx; i < cursor.insert_idx_clause; i++) + block->instructions[i]->register_demand -= candidate_diff; + block->instructions[cursor.insert_idx_clause]->register_demand = new_demand; + if (cursor.source_idx != cursor.insert_idx_clause) { + /* Update demand if we moved over any instructions before the clause */ + cursor.total_demand -= candidate_diff; + } else { + assert(cursor.total_demand == RegisterDemand{}); + } + cursor.clause_demand.update(new_demand); cursor.insert_demand_clause -= candidate_diff; cursor.source_idx--; @@ -789,7 +819,7 @@ schedule_SMEM(sched_ctx& ctx, Block* block, Instruction* current, int idx) continue; } - MoveResult res = ctx.mv.downwards_move(cursor, false); + MoveResult res = ctx.mv.downwards_move(cursor); if (res == move_fail_ssa || res == move_fail_rar) { add_to_hazard_query(&hq, candidate.get()); ctx.mv.downwards_skip(cursor); @@ -944,29 +974,28 @@ schedule_VMEM(sched_ctx& ctx, Block* block, Instruction* current, int idx) } Instruction* candidate_ptr = candidate.get(); - MoveResult res = ctx.mv.downwards_move(cursor, part_of_clause); - if (res == move_fail_ssa || res == move_fail_rar) { - if (part_of_clause) - break; - add_to_hazard_query(&indep_hq, candidate.get()); - add_to_hazard_query(&clause_hq, candidate.get()); - ctx.mv.downwards_skip(cursor); - continue; - } else if (res == move_fail_pressure) { - only_clauses = true; - if (part_of_clause) - break; - add_to_hazard_query(&indep_hq, candidate.get()); - add_to_hazard_query(&clause_hq, candidate.get()); - ctx.mv.downwards_skip(cursor); - continue; - } if (part_of_clause) { + if (ctx.mv.downwards_move_clause(cursor) != move_success) + break; add_to_hazard_query(&indep_hq, candidate_ptr); only_clauses = true; } else { + MoveResult res = ctx.mv.downwards_move(cursor); + if (res == move_fail_ssa || res == move_fail_rar) { + add_to_hazard_query(&indep_hq, candidate.get()); + add_to_hazard_query(&clause_hq, candidate.get()); + ctx.mv.downwards_skip(cursor); + continue; + } else if (res == move_fail_pressure) { + only_clauses = true; + add_to_hazard_query(&indep_hq, candidate.get()); + add_to_hazard_query(&clause_hq, candidate.get()); + ctx.mv.downwards_skip(cursor); + continue; + } k++; } + if (candidate_idx < ctx.last_SMEM_dep_idx) ctx.last_SMEM_stall++; } @@ -1062,7 +1091,7 @@ schedule_LDS(sched_ctx& ctx, Block* block, Instruction* current, int idx) } if (perform_hazard_query(&hq, candidate.get(), false) != hazard_success || - ctx.mv.downwards_move(cursor, false) != move_success) + ctx.mv.downwards_move(cursor) != move_success) break; k++; @@ -1146,7 +1175,7 @@ schedule_position_export(sched_ctx& ctx, Block* block, Instruction* current, int continue; } - MoveResult res = ctx.mv.downwards_move(cursor, false); + MoveResult res = ctx.mv.downwards_move(cursor); if (res == move_fail_ssa || res == move_fail_rar) { add_to_hazard_query(&hq, candidate.get()); ctx.mv.downwards_skip(cursor); @@ -1179,8 +1208,9 @@ schedule_VMEM_store(sched_ctx& ctx, Block* block, Instruction* current, int idx) continue; } - if (perform_hazard_query(&hq, candidate.get(), false) != hazard_success || - ctx.mv.downwards_move(cursor, true) != move_success) + if (perform_hazard_query(&hq, candidate.get(), false) != hazard_success) + break; + if (ctx.mv.downwards_move_clause(cursor) != move_success) break; skip++;