zink: pass KERNEL shaders through successfully

basically just merging with COMPUTE cases

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19327>
This commit is contained in:
Mike Blumenkrantz
2022-10-17 10:11:08 -04:00
committed by Marge Bot
parent 2a08b97330
commit 037bbabcb9
4 changed files with 21 additions and 12 deletions
@@ -3335,7 +3335,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
break;
case nir_intrinsic_control_barrier:
if (ctx->stage == MESA_SHADER_COMPUTE)
if (gl_shader_stage_is_compute(ctx->stage))
spirv_builder_emit_control_barrier(&ctx->builder, SpvScopeWorkgroup,
SpvScopeWorkgroup,
SpvMemorySemanticsWorkgroupMemoryMask | SpvMemorySemanticsAcquireReleaseMask);
@@ -4428,7 +4428,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
ctx.explicit_lod = true;
spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageUnknown, 0);
if (s->info.stage == MESA_SHADER_COMPUTE) {
if (gl_shader_stage_is_compute(s->info.stage)) {
SpvAddressingModel model;
if (s->info.cs.ptr_size == 32)
model = SpvAddressingModelPhysical32;
@@ -4474,6 +4474,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
exec_model = SpvExecutionModelFragment;
break;
case MESA_SHADER_COMPUTE:
case MESA_SHADER_KERNEL:
exec_model = SpvExecutionModelGLCompute;
break;
default:
@@ -4597,6 +4598,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
SpvExecutionModeOutputVertices,
MAX2(s->info.gs.vertices_out, 1));
break;
case MESA_SHADER_KERNEL:
case MESA_SHADER_COMPUTE:
if (s->info.workgroup_size[0] || s->info.workgroup_size[1] || s->info.workgroup_size[2])
spirv_builder_emit_exec_mode_literal3(&ctx.builder, entry_point, SpvExecutionModeLocalSize,
+9 -8
View File
@@ -2245,7 +2245,7 @@ zink_shader_spirv_compile(struct zink_screen *screen, struct zink_shader *zs, st
}
nir_shader *nir = spirv_to_nir(spirv->words, spirv->num_words,
spec_entries, num_spec_entries,
zs->nir->info.stage, "main", &spirv_options, &screen->nir_options);
clamp_stage(zs->nir), "main", &spirv_options, &screen->nir_options);
assert(nir);
ralloc_free(nir);
free(spec_entries);
@@ -2791,7 +2791,7 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index, bool compa
} else {
unsigned base = stage;
/* clamp compute bindings for better driver efficiency */
if (stage == MESA_SHADER_COMPUTE)
if (gl_shader_stage_is_compute(stage))
base = 0;
switch (type) {
case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
@@ -3263,7 +3263,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
subgroup_options.ballot_bit_size = 32;
subgroup_options.ballot_components = 4;
subgroup_options.lower_subgroup_masks = true;
if (!(screen->info.subgroup.supportedStages & mesa_to_vk_shader_stage(nir->info.stage))) {
if (!(screen->info.subgroup.supportedStages & mesa_to_vk_shader_stage(clamp_stage(nir)))) {
subgroup_options.subgroup_size = 1;
subgroup_options.lower_vote_trivial = true;
}
@@ -3325,8 +3325,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
ztype = ZINK_DESCRIPTOR_TYPE_UBO;
/* buffer 0 is a push descriptor */
var->data.descriptor_set = !!var->data.driver_location;
var->data.binding = !var->data.driver_location ? nir->info.stage :
zink_binding(nir->info.stage,
var->data.binding = !var->data.driver_location ? clamp_stage(nir) :
zink_binding(clamp_stage(nir),
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
var->data.driver_location,
screen->compact_descriptors);
@@ -3347,7 +3347,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
} else if (var->data.mode == nir_var_mem_ssbo) {
ztype = ZINK_DESCRIPTOR_TYPE_SSBO;
var->data.descriptor_set = screen->desc_set_id[ztype];
var->data.binding = zink_binding(nir->info.stage,
var->data.binding = zink_binding(clamp_stage(nir),
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
var->data.driver_location,
screen->compact_descriptors);
@@ -3370,7 +3370,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
ret->num_texel_buffers++;
var->data.driver_location = var->data.binding;
var->data.descriptor_set = screen->desc_set_id[ztype];
var->data.binding = zink_binding(nir->info.stage, vktype, var->data.driver_location, screen->compact_descriptors);
var->data.binding = zink_binding(clamp_stage(nir), vktype, var->data.driver_location, screen->compact_descriptors);
ret->bindings[ztype][ret->num_bindings[ztype]].index = var->data.driver_location;
ret->bindings[ztype][ret->num_bindings[ztype]].binding = var->data.binding;
ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype;
@@ -3389,7 +3389,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
if (!screen->info.feats.features.shaderInt64 || !screen->info.feats.features.shaderFloat64)
NIR_PASS_V(nir, lower_64bit_vars, screen->info.feats.features.shaderInt64);
NIR_PASS_V(nir, match_tex_dests);
if (nir->info.stage != MESA_SHADER_KERNEL)
NIR_PASS_V(nir, match_tex_dests);
ret->nir = nir;
nir_foreach_shader_out_variable(var, nir)
+5
View File
@@ -40,6 +40,11 @@ struct spirv_shader;
struct tgsi_token;
static inline gl_shader_stage
clamp_stage(nir_shader *nir)
{
return nir->info.stage == MESA_SHADER_KERNEL ? MESA_SHADER_COMPUTE : nir->info.stage;
}
const void *
zink_get_compiler_options(struct pipe_screen *screen,
+3 -2
View File
@@ -26,6 +26,7 @@
*/
#include "zink_context.h"
#include "zink_compiler.h"
#include "zink_descriptors.h"
#include "zink_program.h"
#include "zink_render_pass.h"
@@ -308,7 +309,7 @@ init_template_entry(struct zink_shader *shader, enum zink_descriptor_type type,
unsigned idx, VkDescriptorUpdateTemplateEntry *entry, unsigned *entry_idx)
{
int index = shader->bindings[type][idx].index;
gl_shader_stage stage = shader->nir->info.stage;
gl_shader_stage stage = clamp_stage(shader->nir);
entry->dstArrayElement = 0;
entry->dstBinding = shader->bindings[type][idx].binding;
entry->descriptorCount = shader->bindings[type][idx].size;
@@ -423,7 +424,7 @@ zink_descriptor_program_init(struct zink_context *ctx, struct zink_program *pg)
if (!shader)
continue;
gl_shader_stage stage = shader->nir->info.stage;
gl_shader_stage stage = clamp_stage(shader->nir);
VkShaderStageFlagBits stage_flags = mesa_to_vk_shader_stage(stage);
/* uniform ubos handled in push */
if (shader->has_uniforms) {