aco: Clean up and fix quad group instructions with WQM.

According to the Vulkan spec, chapter 9.25 "Helper Invocations",
quad group operations must also be executed by helper invocations.

This commit cleans up the code for quad group instructions by
unifying the code path of quad broadcast with that of the other quad
group instructions (quad swaps and quad swizzle), and then calling
emit_wqm just once at the end.

Fixes: 93c8ebfa78
Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/5570
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13929>
This commit is contained in:
Timur Kristóf
2021-11-23 16:50:20 +01:00
committed by Marge Bot
parent af163d7220
commit 77db4e27b1
+68 -108
View File
@@ -8465,146 +8465,106 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
}
break;
}
case nir_intrinsic_quad_broadcast: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
if (!nir_dest_is_divergent(instr->dest)) {
emit_uniform_subgroup(ctx, instr, src);
} else {
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
unsigned lane = nir_src_as_const_value(instr->src[1])->u32;
uint32_t dpp_ctrl = dpp_quad_perm(lane, lane, lane, lane);
if (instr->dest.ssa.bit_size != 1)
src = as_vgpr(ctx, src);
if (instr->dest.ssa.bit_size == 1) {
assert(src.regClass() == bld.lm);
assert(dst.regClass() == bld.lm);
uint32_t half_mask = 0x11111111u << lane;
Temp mask_tmp = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2),
Operand::c32(half_mask), Operand::c32(half_mask));
Temp tmp = bld.tmp(bld.lm);
bld.sop1(Builder::s_wqm, Definition(tmp),
bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), mask_tmp,
bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src,
Operand(exec, bld.lm))));
emit_wqm(bld, tmp, dst);
} else if (instr->dest.ssa.bit_size == 8) {
Temp tmp = bld.tmp(v1);
if (ctx->program->chip_class >= GFX8)
emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
else
emit_wqm(bld,
bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl),
tmp);
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v3b), tmp);
} else if (instr->dest.ssa.bit_size == 16) {
Temp tmp = bld.tmp(v1);
if (ctx->program->chip_class >= GFX8)
emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
else
emit_wqm(bld,
bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl),
tmp);
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
} else if (instr->dest.ssa.bit_size == 32) {
if (ctx->program->chip_class >= GFX8)
emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), dst);
else
emit_wqm(bld,
bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl),
dst);
} else if (instr->dest.ssa.bit_size == 64) {
Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
if (ctx->program->chip_class >= GFX8) {
lo = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), lo, dpp_ctrl));
hi = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), hi, dpp_ctrl));
} else {
lo = emit_wqm(
bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, (1 << 15) | dpp_ctrl));
hi = emit_wqm(
bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, (1 << 15) | dpp_ctrl));
}
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
emit_split_vector(ctx, dst, 2);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
}
break;
}
case nir_intrinsic_quad_broadcast:
case nir_intrinsic_quad_swap_horizontal:
case nir_intrinsic_quad_swap_vertical:
case nir_intrinsic_quad_swap_diagonal:
case nir_intrinsic_quad_swizzle_amd: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
if (!nir_dest_is_divergent(instr->dest)) {
emit_uniform_subgroup(ctx, instr, src);
break;
}
/* Quad broadcast lane. */
unsigned lane = 0;
/* Use VALU for the bool instructions that don't have a SALU-only special case. */
bool bool_use_valu = instr->dest.ssa.bit_size == 1;
uint16_t dpp_ctrl = 0;
switch (instr->intrinsic) {
case nir_intrinsic_quad_swap_horizontal: dpp_ctrl = dpp_quad_perm(1, 0, 3, 2); break;
case nir_intrinsic_quad_swap_vertical: dpp_ctrl = dpp_quad_perm(2, 3, 0, 1); break;
case nir_intrinsic_quad_swap_diagonal: dpp_ctrl = dpp_quad_perm(3, 2, 1, 0); break;
case nir_intrinsic_quad_swizzle_amd: dpp_ctrl = nir_intrinsic_swizzle_mask(instr); break;
case nir_intrinsic_quad_broadcast:
lane = nir_src_as_const_value(instr->src[1])->u32;
dpp_ctrl = dpp_quad_perm(lane, lane, lane, lane);
bool_use_valu = false;
break;
default: break;
}
if (ctx->program->chip_class < GFX8)
dpp_ctrl |= (1 << 15);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
Temp tmp(dst);
if (instr->dest.ssa.bit_size != 1)
src = as_vgpr(ctx, src);
if (instr->dest.ssa.bit_size == 1) {
assert(src.regClass() == bld.lm);
/* Setup source. */
if (bool_use_valu)
src = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::zero(),
Operand::c32(-1), src);
else if (instr->dest.ssa.bit_size != 1)
src = as_vgpr(ctx, src);
/* Setup temporary destination. */
if (bool_use_valu)
tmp = bld.tmp(v1);
else if (ctx->program->stage == fragment_fs)
tmp = bld.tmp(dst.regClass());
if (instr->dest.ssa.bit_size == 1 && instr->intrinsic == nir_intrinsic_quad_broadcast) {
/* Special case for quad broadcast using SALU only. */
assert(src.regClass() == bld.lm && tmp.regClass() == bld.lm);
uint32_t half_mask = 0x11111111u << lane;
Operand mask_tmp = bld.lm.bytes() == 4
? Operand::c32(half_mask)
: bld.pseudo(aco_opcode::p_create_vector, bld.def(bld.lm),
Operand::c32(half_mask), Operand::c32(half_mask));
src =
bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm));
src = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), mask_tmp, src);
bld.sop1(Builder::s_wqm, Definition(tmp), src);
} else if (instr->dest.ssa.bit_size <= 32 || bool_use_valu) {
unsigned excess_bytes = bool_use_valu ? 0 : 4 - instr->dest.ssa.bit_size / 8;
Definition def = excess_bytes ? bld.def(v1) : Definition(tmp);
if (ctx->program->chip_class >= GFX8)
src = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
bld.vop1_dpp(aco_opcode::v_mov_b32, def, src, dpp_ctrl);
else
src = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl);
Temp tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(bld.lm), Operand::zero(), src);
emit_wqm(bld, tmp, dst);
} else if (instr->dest.ssa.bit_size == 8) {
Temp tmp = bld.tmp(v1);
if (ctx->program->chip_class >= GFX8)
emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
else
emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl), tmp);
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v3b), tmp);
} else if (instr->dest.ssa.bit_size == 16) {
Temp tmp = bld.tmp(v1);
if (ctx->program->chip_class >= GFX8)
emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
else
emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl), tmp);
bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
} else if (instr->dest.ssa.bit_size == 32) {
Temp tmp;
if (ctx->program->chip_class >= GFX8)
tmp = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
else
tmp = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl);
emit_wqm(bld, tmp, dst);
bld.ds(aco_opcode::ds_swizzle_b32, def, src, (1 << 15) | dpp_ctrl);
if (excess_bytes)
bld.pseudo(aco_opcode::p_split_vector, Definition(tmp),
bld.def(RegClass::get(tmp.type(), excess_bytes)), def.getTemp());
} else if (instr->dest.ssa.bit_size == 64) {
Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
if (ctx->program->chip_class >= GFX8) {
lo = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), lo, dpp_ctrl));
hi = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), hi, dpp_ctrl));
lo = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), lo, dpp_ctrl);
hi = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), hi, dpp_ctrl);
} else {
lo = emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, dpp_ctrl));
hi = emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, dpp_ctrl));
lo = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, (1 << 15) | dpp_ctrl);
hi = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, (1 << 15) | dpp_ctrl);
}
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
emit_split_vector(ctx, dst, 2);
bld.pseudo(aco_opcode::p_create_vector, Definition(tmp), lo, hi);
emit_split_vector(ctx, tmp, 2);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
isel_err(&instr->instr, "Unimplemented NIR quad group instruction bit size.");
}
if (tmp.id() != dst.id()) {
if (bool_use_valu)
tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(bld.lm), Operand::zero(), tmp);
/* Vulkan spec 9.25: Helper invocations must be active for quad group instructions. */
emit_wqm(bld, tmp, dst, true);
}
break;
}
case nir_intrinsic_masked_swizzle_amd: {