diff --git a/src/asahi/compiler/agx_compile.h b/src/asahi/compiler/agx_compile.h
index b874ee1b6a1..a3392faeec1 100644
--- a/src/asahi/compiler/agx_compile.h
+++ b/src/asahi/compiler/agx_compile.h
@@ -94,6 +94,9 @@ struct agx_shader_info {
/* Does the shader write point size? */
bool writes_psiz;
+
+ /* Does the shader control the sample mask? */
+ bool writes_sample_mask;
};
#define AGX_MAX_RTS (8)
diff --git a/src/asahi/lib/cmdbuf.xml b/src/asahi/lib/cmdbuf.xml
index c3d96ca2ff3..47b0767bc9d 100644
--- a/src/asahi/lib/cmdbuf.xml
+++ b/src/asahi/lib/cmdbuf.xml
@@ -313,6 +313,7 @@
+
diff --git a/src/gallium/drivers/asahi/agx_state.c b/src/gallium/drivers/asahi/agx_state.c
index d1136131c1e..f6f4428c5bd 100644
--- a/src/gallium/drivers/asahi/agx_state.c
+++ b/src/gallium/drivers/asahi/agx_state.c
@@ -1101,13 +1101,16 @@ agx_build_pipeline(struct agx_context *ctx, struct agx_compiled_shader *cs, enum
/* TODO: Can we prepack this? */
if (stage == PIPE_SHADER_FRAGMENT) {
+ bool writes_sample_mask = ctx->fs->info.writes_sample_mask;
+
agx_pack(record, SET_SHADER_EXTENDED, cfg) {
cfg.code = cs->bo->ptr.gpu;
cfg.register_quadwords = 0;
cfg.unk_3 = 0x8d;
cfg.unk_1 = 0x2010bd;
cfg.unk_2 = 0x0d;
- cfg.unk_2b = 1;
+ cfg.unk_2b = writes_sample_mask ? 5 : 1;
+ cfg.fragment_parameters.early_z_testing = !writes_sample_mask;
cfg.unk_3b = 0x1;
cfg.unk_4 = 0x800;
cfg.preshader_unk = 0xc080;
@@ -1389,13 +1392,14 @@ demo_rasterizer(struct agx_context *ctx, struct agx_pool *pool, bool is_points)
}
static uint64_t
-demo_unk11(struct agx_pool *pool, bool prim_lines, bool prim_points, bool reads_tib)
+demo_unk11(struct agx_pool *pool, bool prim_lines, bool prim_points, bool reads_tib, bool sample_mask_from_shader)
{
struct agx_ptr T = agx_pool_alloc_aligned(pool, AGX_UNKNOWN_4A_LENGTH, 64);
agx_pack(T.cpu, UNKNOWN_4A, cfg) {
cfg.lines_or_points = (prim_lines || prim_points);
cfg.reads_tilebuffer = reads_tib;
+ cfg.sample_mask_from_shader = sample_mask_from_shader;
cfg.front.lines = cfg.back.lines = prim_lines;
cfg.front.points = cfg.back.points = prim_points;
@@ -1461,12 +1465,13 @@ agx_encode_state(struct agx_context *ctx, uint8_t *out,
struct agx_pool *pool = &ctx->batch->pool;
bool reads_tib = ctx->fs->info.reads_tib;
+ bool sample_mask_from_shader = ctx->fs->info.writes_sample_mask;
agx_push_record(&out, 5, demo_interpolation(ctx->fs, pool));
agx_push_record(&out, 5, demo_launch_fragment(ctx, pool, pipeline_fragment, varyings, ctx->fs->info.varyings.nr_descs));
agx_push_record(&out, 4, demo_linkage(ctx->vs, pool));
agx_push_record(&out, 7, demo_rasterizer(ctx, pool, is_points));
- agx_push_record(&out, 5, demo_unk11(pool, is_lines, is_points, reads_tib));
+ agx_push_record(&out, 5, demo_unk11(pool, is_lines, is_points, reads_tib, sample_mask_from_shader));
if (ctx->dirty & (AGX_DIRTY_VIEWPORT | AGX_DIRTY_SCISSOR)) {
struct agx_viewport_scissor vps = agx_upload_viewport_scissor(pool,