aco/optimizer: unify constant labels

Foz-DB Navi21:
Totals from 14 (0.02% of 79789) affected shaders:
Instrs: 44868 -> 44867 (-0.00%)
CodeSize: 279132 -> 279124 (-0.00%)
Copies: 11692 -> 11691 (-0.01%)
VALU: 30353 -> 30352 (-0.00%)

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35272>
This commit is contained in:
Georg Lehmann
2024-12-16 17:54:22 +01:00
committed by Marge Bot
parent 2d410cf18e
commit f1cbac7a8e

View File

@@ -40,14 +40,13 @@ struct mad_info {
};
enum Label {
label_constant_32bit = 1 << 1,
label_constant = 1 << 1,
/* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
* 32-bit operations but this shouldn't cause any issues because we don't
* look through any conversions */
label_abs = 1 << 2,
label_neg = 1 << 3,
label_temp = 1 << 5,
label_literal = 1 << 6,
label_mad = 1 << 7,
label_omod2 = 1 << 8,
label_omod4 = 1 << 9,
@@ -56,7 +55,6 @@ enum Label {
label_b2f = 1 << 16,
/* This label means that it's either 0 or -1, and the ssa_info::temp is an s1 which is 0 or 1. */
label_uniform_bool = 1 << 21,
label_constant_64bit = 1 << 22,
/* This label is added to the first definition of s_not/s_or/s_xor/s_and when all operands are
* uniform_bool or uniform_bitwise. The first definition of ssa_info::instr would be 0 or -1 and
* the second is SCC.
@@ -67,7 +65,6 @@ enum Label {
label_scc_needed = 1 << 26,
label_b2i = 1 << 27,
label_fcanonicalize = 1 << 28,
label_constant_16bit = 1 << 29,
label_canonicalized = 1ull << 32, /* 1ull to prevent sign extension */
label_extract = 1ull << 33,
label_insert = 1ull << 34,
@@ -80,8 +77,8 @@ static constexpr uint64_t instr_mod_labels =
static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_b2f |
label_uniform_bool | label_scc_invert | label_b2i |
label_fcanonicalize;
static constexpr uint32_t val_labels =
label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
static constexpr uint32_t val_labels = label_constant | label_mad;
static_assert((instr_mod_labels & temp_labels) == 0, "labels cannot intersect");
static_assert((instr_mod_labels & val_labels) == 0, "labels cannot intersect");
@@ -90,7 +87,7 @@ static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
struct ssa_info {
uint64_t label;
union {
uint32_t val;
uint64_t val;
Temp temp;
Instruction* mod_instr;
};
@@ -110,12 +107,7 @@ struct ssa_info {
label &= ~(instr_mod_labels | val_labels); /* instr, temp and val alias */
}
uint32_t const_labels =
label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
if (new_label & const_labels) {
label &= ~val_labels | const_labels;
label &= ~(instr_mod_labels | temp_labels); /* instr, temp and val alias */
} else if (new_label & val_labels) {
if (new_label & val_labels) {
label &= ~val_labels;
label &= ~(instr_mod_labels | temp_labels); /* instr, temp and val alias */
}
@@ -123,61 +115,13 @@ struct ssa_info {
label |= new_label;
}
void set_constant(amd_gfx_level gfx_level, uint64_t constant)
void set_constant(uint64_t constant)
{
Operand op16 = Operand::c16(constant);
Operand op32 = Operand::get_const(gfx_level, constant, 4);
add_label(label_literal);
add_label(label_constant);
val = constant;
/* check that no upper bits are lost in case of packed 16bit constants */
if (gfx_level >= GFX8 && !op16.isLiteral() &&
op16.constantValue16(true) == ((constant >> 16) & 0xffff))
add_label(label_constant_16bit);
if (!op32.isLiteral())
add_label(label_constant_32bit);
if (Operand::is_constant_representable(constant, 8))
add_label(label_constant_64bit);
if (label & label_constant_64bit) {
val = Operand::c64(constant).constantValue();
if (val != constant)
label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
}
}
/* Width-specific constant query: reports whether the label associated with
 * the requested bit width is set on this ssa_info.
 *   8  -> label_literal
 *   16 -> label_constant_16bit
 *   32 -> label_constant_32bit
 *   64 -> label_constant_64bit
 * Any other width returns false. */
bool is_constant(unsigned bits)
{
switch (bits) {
case 8: return label & label_literal;
case 16: return label & label_constant_16bit;
case 32: return label & label_constant_32bit;
case 64: return label & label_constant_64bit;
}
return false;
}
/* Width-specific literal query. Widths 8 and 64, and any width not listed,
 * return false. Widths 16 and 32 combine label_literal with the complement
 * of the matching label_constant_{16,32}bit test.
 *
 * NOTE(review): the 16/32-bit cases use bitwise `~`, not logical `!`. The
 * masked value is either 0 or a single label bit, so its complement is never
 * zero and both cases evaluate to `is_lit` alone. Presumably `!` was
 * intended; confirm against callers before changing the operator, since the
 * result then differs whenever the width-specific constant label is set. */
bool is_literal(unsigned bits)
{
bool is_lit = label & label_literal;
switch (bits) {
case 8: return false;
case 16: return is_lit && ~(label & label_constant_16bit);
case 32: return is_lit && ~(label & label_constant_32bit);
case 64: return false;
}
return false;
}
/* Reports whether a constant value is recorded for the requested width:
 * for 64 bits this tests label_constant_64bit, for every other width it
 * tests label_literal. */
bool is_constant_or_literal(unsigned bits)
{
if (bits == 64)
return label & label_constant_64bit;
else
return label & label_literal;
}
/* True when label_constant is set on this ssa_info; takes no bit-width
 * argument. */
bool is_constant() { return label & label_constant; }
void set_abs(Temp abs_temp)
{
@@ -510,8 +454,25 @@ optimize_constants(opt_ctx& ctx, alu_opt_info& info)
if (!type.constant_bits())
return false;
if (type.bytes() > 4)
if (type.bytes() > 4) {
if (!op_info.op.isLiteral())
continue;
int64_t constant = op_info.op.constantValue64();
if (type.base_type == aco_base_type_float)
return false; /* Operand doesn't support double literal yet. */
else if (type.base_type == aco_base_type_int && constant >= 0x7fff'ffff)
return false;
else if (type.base_type != aco_base_type_int && constant < 0)
return false;
uint32_t constant32 = op_info.op.constantValue();
if (literal != (constant32 & BITFIELD_MASK(litbits_used)))
return false;
literal = constant32;
litbits_used = 32;
continue;
}
/* remove modifiers on constants: apply extract, f2f32, abs, neg */
assert(op_info.op.size() == 1);
@@ -756,8 +717,6 @@ alu_opt_info_is_valid(opt_ctx& ctx, alu_opt_info& info)
if (op == other) {
constant_limit++;
break;
} else if (op.isLiteral() && other.isLiteral()) {
return false;
}
}
}
@@ -1708,7 +1667,7 @@ parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* bas
if (add_instr->operands[i].isConstant()) {
*offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
} else if (add_instr->operands[i].isTemp() &&
ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
ctx.info[add_instr->operands[i].tempId()].is_constant()) {
*offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
} else {
continue;
@@ -1783,14 +1742,14 @@ smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
Temp base;
uint32_t offset;
if (info.is_constant_or_literal(32) && info.val <= ctx.program->dev.smem_offset_max) {
if (info.is_constant() && info.val <= ctx.program->dev.smem_offset_max) {
instr->operands[1] = Operand::c32(info.val);
} else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) &&
base.regClass() == s1 && offset <= ctx.program->dev.smem_offset_max &&
ctx.program->gfx_level >= GFX9 && offset % align == 0) {
bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
if (soe) {
if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
if (ctx.info[smem.operands.back().tempId()].is_constant() &&
ctx.info[smem.operands.back().tempId()].val == 0) {
smem.operands[1] = Operand::c32(offset);
smem.operands.back() = Operand(base);
@@ -1947,8 +1906,8 @@ is_op_canonicalized(opt_ctx& ctx, Operand op)
(op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
return true;
if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant())) {
uint64_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
if (op.bytes() == 2)
return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
else if (op.bytes() == 4)
@@ -2059,7 +2018,7 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type)
return true;
}
if (info.is_constant_or_literal(type.bit_size)) {
if (info.is_constant()) {
op_info.op = get_constant_op(ctx, info, type.bit_size);
return true;
}
@@ -2388,7 +2347,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
info = ctx.info[info.temp.id()];
}
unsigned bits = instr->operands[i].bytes() * 8u;
if (info.is_constant_or_literal(bits) && pseudo_can_accept_constant(instr, i)) {
if (info.is_constant() && pseudo_can_accept_constant(instr, i)) {
instr->operands[i] = get_constant_op(ctx, info, bits);
continue;
}
@@ -2423,15 +2382,14 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
mubuf.offset += info.parent_instr->operands[1].constantValue();
mubuf.offen = false;
continue;
} else if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
} else if (mubuf.offen && i == 1 && info.is_constant() &&
mubuf.offset + info.val <= const_max) {
assert(!mubuf.idxen);
instr->operands[1] = Operand(v1);
mubuf.offset += info.val;
mubuf.offen = false;
continue;
} else if (i == 2 && info.is_constant_or_literal(32) &&
mubuf.offset + info.val <= const_max) {
} else if (i == 2 && info.is_constant() && mubuf.offset + info.val <= const_max) {
instr->operands[2] = Operand::c32(0);
mubuf.offset += info.val;
continue;
@@ -2497,8 +2455,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->operands[i].setTemp(base);
scratch.offset += (int32_t)offset;
continue;
} else if (i <= 1 && info.is_constant_or_literal(32) &&
ctx.program->gfx_level >= GFX10_3 &&
} else if (i <= 1 && info.is_constant() && ctx.program->gfx_level >= GFX10_3 &&
is_scratch_offset_valid(ctx, NULL, scratch.offset, (int32_t)info.val)) {
/* GFX10.3+ can disable both SADDR and ADDR. */
instr->operands[i] = Operand(instr->operands[i].regClass());
@@ -2626,11 +2583,11 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
case aco_opcode::p_split_vector: {
ssa_info& info = ctx.info[instr->operands[0].tempId()];
if (info.is_constant_or_literal(32)) {
if (info.is_constant()) {
uint64_t val = info.val;
for (Definition def : instr->definitions) {
uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
uint64_t mask = u_bit_consecutive64(0, def.bytes() * 8u);
ctx.info[def.tempId()].set_constant(val & mask);
val >>= def.bytes() * 8u;
}
break;
@@ -2661,8 +2618,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
Operand vec_op = vec->operands[vec_index];
if (vec_op.isConstant()) {
ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
vec_op.constantValue64());
ctx.info[instr->definitions[i].tempId()].set_constant(vec_op.constantValue64());
} else if (vec_op.isTemp()) {
ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
}
@@ -2691,13 +2647,12 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->operands[0] = op;
break;
}
} else if (info.is_constant_or_literal(32)) {
} else if (info.is_constant()) {
/* propagate constants */
uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
uint64_t mask = u_bit_consecutive64(0, instr->definitions[0].bytes() * 8u);
uint64_t val = (info.val >> (dst_offset * 8u)) & mask;
instr->operands[0] =
Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
;
}
}
@@ -2746,7 +2701,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
/* don't copy-propagate copies into fixed registers */
} else if (instr->operands[0].isConstant()) {
ctx.info[instr->definitions[0].tempId()].set_constant(
ctx.program->gfx_level, instr->operands[0].constantValue64());
instr->operands[0].constantValue64());
} else if (instr->operands[0].isTemp()) {
ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
@@ -2757,7 +2712,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
break;
case aco_opcode::p_is_helper:
if (!ctx.program->needs_wqm)
ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
ctx.info[instr->definitions[0].tempId()].set_constant(0u);
break;
case aco_opcode::v_mul_f16:
case aco_opcode::v_mul_f32:
@@ -2807,7 +2762,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
((!instr->definitions[0].isNaNPreserve() &&
!instr->definitions[0].isInfPreserve()) ||
instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
ctx.info[instr->definitions[0].tempId()].set_constant(0u);
} else if (denorm_mode != fp_denorm_flush) {
/* omod has no effect if denormals are enabled. */
continue;
@@ -2997,7 +2952,7 @@ is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value
return true;
} else if (op.isTemp()) {
unsigned id = original_temp_id(ctx, op.getTemp());
if (!ctx.info[id].is_constant_or_literal(bit_size))
if (!ctx.info[id].is_constant())
return false;
*value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
return true;
@@ -3677,8 +3632,7 @@ combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opc
bool hi16 = opsel & (1 << i);
if (operands[i].isConstant()) {
val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
} else if (operands[i].isTemp() &&
ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
} else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant()) {
val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
} else {
continue;
@@ -4400,9 +4354,12 @@ combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
bool
is_pow_of_two(opt_ctx& ctx, Operand op)
{
if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
else if (!op.isConstant())
if (op.isTemp()) {
unsigned id = original_temp_id(ctx, op.getTemp());
if (ctx.info[id].is_constant())
op = get_constant_op(ctx, ctx.info[id], op.bytes() * 8);
}
if (!op.isConstant())
return false;
uint64_t val = op.constantValue64();
@@ -5279,7 +5236,7 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
if (!op.isTemp() || op.isFixed())
return false;
auto& temp_info = ctx.info[op.tempId()];
return temp_info.is_constant_or_literal(op.size() * 32);
return temp_info.is_constant();
}))
return;
@@ -5293,7 +5250,7 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
if (!op.isTemp() || op.isFixed())
continue;
auto& temp_info = ctx.info[op.tempId()];
if (temp_info.is_constant_or_literal(op.size() * 32))
if (temp_info.is_constant())
literal_mask |= BITFIELD_BIT(i);
}
@@ -5519,8 +5476,7 @@ apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
/* apply literals on SALU/VALU */
if (instr->isSALU() || instr->isVALU()) {
for (const Operand& op : instr->operands) {
if (op.isTemp() && ctx.info[op.tempId()].is_literal(op.size() * 32) &&
ctx.uses[op.tempId()] == 0) {
if (op.isTemp() && ctx.info[op.tempId()].is_constant() && ctx.uses[op.tempId()] == 0) {
alu_opt_info info;
if (!alu_opt_gather_info(ctx, instr.get(), info))
UNREACHABLE("We already check that we can apply lit");