aco/spill: Handle calls

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38281>
This commit is contained in:
Natalie Vock
2025-02-17 18:42:49 +01:00
committed by Marge Bot
parent ecc548cd37
commit 369a3c0dca
+71 -22
View File
@@ -89,16 +89,18 @@ struct spill_ctx {
unsigned vgpr_spill_slots;
Temp scratch_rsrc;
unsigned extra_vgprs;
unsigned resume_idx;
spill_ctx(const RegisterDemand target_pressure_, Program* program_)
spill_ctx(const RegisterDemand target_pressure_, Program* program_, unsigned extra_vgprs_)
: target_pressure(target_pressure_), program(program_), memory(),
renames(program->blocks.size(), aco::map<Temp, Temp>(memory)),
spills_entry(program->blocks.size(), aco::unordered_map<Temp, uint32_t>(memory)),
spills_exit(program->blocks.size(), aco::unordered_map<Temp, uint32_t>(memory)),
processed(program->blocks.size(), false), ssa_infos(program->peekAllocationId()),
remat(memory), wave_size(program->wave_size), sgpr_spill_slots(0), vgpr_spill_slots(0),
resume_idx(0)
extra_vgprs(extra_vgprs_), resume_idx(0)
{}
void add_affinity(uint32_t first, uint32_t second)
@@ -322,9 +324,13 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
/* check how many live-through variables should be spilled */
RegisterDemand reg_pressure = block->live_in_demand;
RegisterDemand loop_demand = reg_pressure;
RegisterDemand loop_call_spills = RegisterDemand();
unsigned i = block_idx;
while (ctx.program->blocks[i].loop_nest_depth >= block->loop_nest_depth)
loop_demand.update(ctx.program->blocks[i++].register_demand);
while (ctx.program->blocks[i].loop_nest_depth >= block->loop_nest_depth) {
loop_demand.update(ctx.program->blocks[i].register_demand);
loop_call_spills.update(ctx.program->blocks[i].call_spills);
++i;
}
for (auto spilled : ctx.spills_exit[block_idx - 1]) {
/* variable is not live at loop entry: probably a phi operand */
@@ -352,12 +358,15 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
/* select more live-through variables and constants */
RegType type = RegType::vgpr;
while (loop_demand.exceeds(ctx.target_pressure)) {
while (loop_demand.exceeds(ctx.target_pressure) ||
loop_call_spills.exceeds(spilled_registers)) {
/* if VGPR demand is low enough, select SGPRs */
if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr)
if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr &&
loop_call_spills.vgpr <= spilled_registers.vgpr)
type = RegType::sgpr;
/* if SGPR demand is low enough, break */
if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr)
if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr &&
loop_call_spills.sgpr <= spilled_registers.sgpr)
break;
float score = 0.0;
@@ -932,12 +941,28 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
}
/* check if register demand is low enough during and after the current instruction */
if (block->register_demand.exceeds(ctx.target_pressure)) {
if (block->register_demand.exceeds(ctx.target_pressure) || instr->isCall()) {
RegisterDemand new_demand = instr->register_demand;
std::optional<RegisterDemand> live_changes;
/* if reg pressure is too high, spill variable with furthest next use */
while ((new_demand - spilled_registers).exceeds(ctx.target_pressure)) {
while (true) {
bool needs_spill = (new_demand - spilled_registers).exceeds(ctx.target_pressure);
if (instr->isCall()) {
RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit;
/* Exclude the linear VGPRs created for spilling SGPRs from the limit,
* if they are placed in clobbered register ranges (i.e. the preserved limit
* can't fit all of them). The preserved spiller will take care of those. */
call_preserved_limit.vgpr =
MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0);
needs_spill |= (instr->call().caller_preserved_demand - spilled_registers)
.exceeds(call_preserved_limit);
}
if (!needs_spill)
break;
float score = 0.0;
Temp to_spill = Temp();
bool spill_is_operand = false;
@@ -947,7 +972,15 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
unsigned avoid_respill = 0;
RegType type = RegType::sgpr;
if (new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr)
bool spill_vgpr = new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr;
if (instr->isCall()) {
RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit;
call_preserved_limit.vgpr =
MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0);
spill_vgpr |= instr->call().caller_preserved_demand.vgpr - spilled_registers.vgpr >
call_preserved_limit.vgpr;
}
if (spill_vgpr)
type = RegType::vgpr;
for (unsigned t : ctx.program->live.live_in[block_idx]) {
@@ -970,19 +1003,32 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
for (auto& op : instr->operands) {
if (!op.isTemp() || op.getTemp() != var)
continue;
/* Spilling vector operands causes us to emit a split_vector, increasing live
* state temporarily.
*/
if (op.isLateKill() || op.isKill() || op.size() > 1) {
if (op.isLateKill() || op.isKill()) {
can_spill = false;
break;
}
/* Spilling vector operands causes us to emit a split_vector, increasing live
* state temporarily. This is ok when we have enough register headroom (when
* spilling calls), but causes spilling to fail otherwise.
*/
if (op.size() > 1) {
RegisterDemand before = new_demand - spilled_registers;
before -= get_temp_registers(instr.get()) - get_live_changes(instr.get());
if ((before + op.getTemp()).exceeds(ctx.target_pressure)) {
can_spill = false;
break;
}
}
if (!live_changes)
live_changes = get_temp_reg_changes(instr.get());
/* Don't spill operands if killing operands won't help with register pressure */
if (!op.isClobbered() && RegisterDemand(op.getTemp()).exceeds(*live_changes)) {
if (!instr->isCall() && !op.isClobbered() && RegisterDemand(op.getTemp()).exceeds(*live_changes)) {
can_spill = false;
break;
}
if (instr->isCall() && !op.isClobbered()) {
can_spill = false;
break;
}
@@ -1653,20 +1699,23 @@ spill(Program* program)
uint16_t extra_sgprs = 0;
/* calculate extra VGPRs required for spilling SGPRs */
if (demand.sgpr > limit.sgpr) {
unsigned sgpr_spills = demand.sgpr - limit.sgpr;
unsigned sgpr_spills = demand.sgpr - std::min((uint16_t)demand.sgpr, (uint16_t)limit.sgpr);
sgpr_spills += program->max_call_spills.sgpr;
if (sgpr_spills)
extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1;
}
/* add extra SGPRs required for spilling VGPRs */
if (demand.vgpr + extra_vgprs > limit.vgpr) {
if (demand.vgpr + extra_vgprs > limit.vgpr || program->max_call_spills.vgpr) {
if (program->gfx_level >= GFX9)
extra_sgprs =
program->stack_ptr.id() ? 2 : 1; /* SADDR + scc for stack pointer additions */
else
extra_sgprs = 5; /* scratch_resource (s4) + scratch_offset (s1) */
if (demand.sgpr + extra_sgprs > limit.sgpr) {
if (demand.sgpr + extra_sgprs > limit.sgpr || program->max_call_spills.sgpr) {
/* re-calculate in case something has changed */
unsigned sgpr_spills = demand.sgpr + extra_sgprs - limit.sgpr;
sgpr_spills = program->max_call_spills.sgpr;
if (demand.sgpr + extra_sgprs > limit.sgpr)
sgpr_spills += demand.sgpr + extra_sgprs - limit.sgpr;
extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1;
}
}
@@ -1674,7 +1723,7 @@ spill(Program* program)
const RegisterDemand target(limit.vgpr - extra_vgprs, limit.sgpr - extra_sgprs);
/* initialize ctx */
spill_ctx ctx(target, program);
spill_ctx ctx(target, program, extra_vgprs);
gather_ssa_use_info(ctx);
get_rematerialize_info(ctx);