diff --git a/src/gallium/drivers/radeonsi/si_state.h b/src/gallium/drivers/radeonsi/si_state.h index 581df046d31..f750ea4bbe1 100644 --- a/src/gallium/drivers/radeonsi/si_state.h +++ b/src/gallium/drivers/radeonsi/si_state.h @@ -595,7 +595,7 @@ void si_ps_key_update_sample_shading(struct si_context *sctx); void si_ps_key_update_framebuffer_rasterizer_sample_shading(struct si_context *sctx); void si_init_tess_factor_ring(struct si_context *sctx); bool si_update_gs_ring_buffers(struct si_context *sctx); -bool si_update_spi_tmpring_size(struct si_context *sctx); +bool si_update_spi_tmpring_size(struct si_context *sctx, unsigned bytes); /* si_state_draw.c */ void si_init_draw_functions_GFX6(struct si_context *sctx); diff --git a/src/gallium/drivers/radeonsi/si_state_draw.cpp b/src/gallium/drivers/radeonsi/si_state_draw.cpp index b115dd1b619..0940e69dd38 100644 --- a/src/gallium/drivers/radeonsi/si_state_draw.cpp +++ b/src/gallium/drivers/radeonsi/si_state_draw.cpp @@ -289,7 +289,34 @@ static bool si_update_shaders(struct si_context *sctx) (si_pm4_state_enabled_and_changed(sctx, ls) || si_pm4_state_enabled_and_changed(sctx, es))) || si_pm4_state_enabled_and_changed(sctx, hs) || si_pm4_state_enabled_and_changed(sctx, gs) || si_pm4_state_enabled_and_changed(sctx, vs) || si_pm4_state_enabled_and_changed(sctx, ps)) { - if (!si_update_spi_tmpring_size(sctx)) + unsigned scratch_size = 0; + + if (HAS_TESS) { + if (GFX_VERSION <= GFX8) /* LS */ + scratch_size = MAX2(scratch_size, sctx->shader.vs.current->config.scratch_bytes_per_wave); + + scratch_size = MAX2(scratch_size, sctx->queued.named.hs->shader->config.scratch_bytes_per_wave); + + if (HAS_GS) { + if (GFX_VERSION <= GFX8) /* ES */ + scratch_size = MAX2(scratch_size, sctx->shader.tes.current->config.scratch_bytes_per_wave); + + scratch_size = MAX2(scratch_size, sctx->shader.gs.current->config.scratch_bytes_per_wave); + } else { + scratch_size = MAX2(scratch_size, sctx->shader.tes.current->config.scratch_bytes_per_wave); + } + } else if (HAS_GS) { + if (GFX_VERSION <= GFX8) /* ES */ + scratch_size = MAX2(scratch_size, sctx->shader.vs.current->config.scratch_bytes_per_wave); + + scratch_size = MAX2(scratch_size, sctx->shader.gs.current->config.scratch_bytes_per_wave); + } else { + scratch_size = MAX2(scratch_size, sctx->shader.vs.current->config.scratch_bytes_per_wave); + } + + scratch_size = MAX2(scratch_size, sctx->shader.ps.current->config.scratch_bytes_per_wave); + + if (scratch_size && !si_update_spi_tmpring_size(sctx, scratch_size)) return false; if (GFX_VERSION >= GFX7) { diff --git a/src/gallium/drivers/radeonsi/si_state_shaders.c b/src/gallium/drivers/radeonsi/si_state_shaders.c index 703987e7427..7e4543d8701 100644 --- a/src/gallium/drivers/radeonsi/si_state_shaders.c +++ b/src/gallium/drivers/radeonsi/si_state_shaders.c @@ -3832,11 +3832,6 @@ static int si_update_scratch_buffer(struct si_context *sctx, struct si_shader *s return 1; } -static unsigned si_get_scratch_buffer_bytes_per_wave(struct si_shader *shader) -{ - return shader ? shader->config.scratch_bytes_per_wave : 0; -} - static struct si_shader *si_get_tcs_current(struct si_context *sctx) { if (!sctx->shader.tes.cso) @@ -3904,7 +3899,7 @@ static bool si_update_scratch_relocs(struct si_context *sctx) return true; } -bool si_update_spi_tmpring_size(struct si_context *sctx) +bool si_update_spi_tmpring_size(struct si_context *sctx, unsigned bytes) { /* SPI_TMPRING_SIZE.WAVESIZE must be constant for each scratch buffer. * There are 2 cases to handle: @@ -3919,17 +3914,6 @@ bool si_update_spi_tmpring_size(struct si_context *sctx) * Otherwise, the number of waves that can use scratch is * SPI_TMPRING_SIZE.WAVES. */ - unsigned bytes = 0; - - bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->shader.ps.current)); - bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->shader.gs.current)); - bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->shader.vs.current)); - - if (sctx->shader.tes.cso) { - bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->shader.tes.current)); - bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(si_get_tcs_current(sctx))); - } - sctx->max_seen_scratch_bytes_per_wave = MAX2(sctx->max_seen_scratch_bytes_per_wave, bytes); unsigned scratch_needed_size = sctx->max_seen_scratch_bytes_per_wave * sctx->scratch_waves;