diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 97bc6784414..e1fc4e89486 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -884,6 +884,200 @@ parse_rt_stage(struct radv_device *device, struct radv_pipeline_layout *layout, return shader; } +static nir_function_impl * +lower_any_hit_for_intersection(nir_shader *any_hit) +{ /* Rewrites the entrypoint of any_hit in place into a function taking + * three parameters (commit pointer, hit T, hit kind) and returns its + * impl. ignore_ray_intersection becomes a store of false through the + * commit pointer followed by a halt, loads of ray_t_max and + * ray_hit_kind are replaced by the hit T and hit kind parameters, and + * halts are turned into returns before nir_lower_returns_impl() runs. + */ + nir_function_impl *impl = nir_shader_get_entrypoint(any_hit); + + /* The entrypoint must not have any parameters yet; the three parameters + * of the lowered any-hit function are installed here. + */ + assert(impl->function->num_params == 0); + nir_parameter params[] = { + { + /* A pointer to a boolean value for whether or not the hit was + * accepted. It is cast to a function_temp bool deref below. + */ + .num_components = 1, + .bit_size = 32, + }, + { + /* The hit T value */ + .num_components = 1, + .bit_size = 32, + }, + { + /* The hit kind */ + .num_components = 1, + .bit_size = 32, + }, + }; + impl->function->num_params = ARRAY_SIZE(params); + impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params)); + memcpy(impl->function->params, params, sizeof(params)); + + nir_builder build; + nir_builder_init(&build, impl); + nir_builder *b = &build; + + b->cursor = nir_before_cf_list(&impl->body); + + nir_ssa_def *commit_ptr = nir_load_param(b, 0); + nir_ssa_def *hit_t = nir_load_param(b, 1); + nir_ssa_def *hit_kind = nir_load_param(b, 2); + + nir_deref_instr *commit = + nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0); + + nir_foreach_block_safe (block, impl) { + nir_foreach_instr_safe (instr, block) { + switch (instr->type) { + case nir_instr_type_intrinsic: { + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + switch (intrin->intrinsic) { + case nir_intrinsic_ignore_ray_intersection: + b->cursor = nir_instr_remove(&intrin->instr); + /* Rejecting the hit: store false through the commit pointer, + * then emit a halt. + * + * We put the newly emitted code inside a dummy if because it's + * going to contain a jump instruction and we don't want to + * deal with that mess here. 
It'll get dealt with by our + * control-flow optimization passes. + * + * NOTE(review): the halt emitted here is expected to be turned + * into a return by the nir_instr_type_jump case below; confirm + * that the _safe iterators actually visit the blocks created + * by nir_push_if(). + */ + nir_store_deref(b, commit, nir_imm_false(b), 0x1); + nir_push_if(b, nir_imm_true(b)); + nir_jump(b, nir_jump_halt); + nir_pop_if(b, NULL); + break; + + case nir_intrinsic_terminate_ray: + /* The "normal" handling of terminateRay works fine in + * intersection shaders. + */ + break; + + /* Loads of the ray T max are replaced by the hit T parameter. */ + case nir_intrinsic_load_ray_t_max: + nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t); + nir_instr_remove(&intrin->instr); + break; + + /* Loads of the hit kind are replaced by the hit kind parameter. */ + case nir_intrinsic_load_ray_hit_kind: + nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind); + nir_instr_remove(&intrin->instr); + break; + + default: + break; + } + break; + } + case nir_instr_type_jump: { /* Every halt is replaced by a return. */ + nir_jump_instr *jump = nir_instr_as_jump(instr); + if (jump->type == nir_jump_halt) { + b->cursor = nir_instr_remove(instr); + nir_jump(b, nir_jump_return); + } + break; + } + + default: + break; + } + } + } + + nir_validate_shader(any_hit, "after initial any-hit lowering"); + + nir_lower_returns_impl(impl); + + nir_validate_shader(any_hit, "after lowering returns"); + + return impl; +} + +/* Inline the any-hit shader into the intersection shader so that we neither + * have to implement yet another shader call interface here nor deal with + * any recursion. + * + * any_hit may be NULL. When it is not, a clone of it is lowered with 
+ * lower_any_hit_for_intersection() and inlined; the clone lives in a + * temporary ralloc context that is freed before returning. + * + * Each report_ray_intersection in the intersection shader is replaced by a + * range check of the hit T against ray_t_min and ray_t_max, the optional + * inlined any-hit code, and a re-emitted report_ray_intersection that only + * runs if the hit is still committed. Uses of the original intrinsic's + * result are rewritten to the final value of the commit flag. + */ +static void +nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit) +{ + void *dead_ctx = ralloc_context(intersection); + /* The any-hit lowering below operates on a clone allocated from dead_ctx, + * not on the shader that was passed in. + */ + nir_function_impl *any_hit_impl = NULL; + struct hash_table *any_hit_var_remap = NULL; + if (any_hit) { + any_hit = nir_shader_clone(dead_ctx, any_hit); + NIR_PASS_V(any_hit, nir_opt_dce); + any_hit_impl = lower_any_hit_for_intersection(any_hit); + any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx); + } + + nir_function_impl *impl = nir_shader_get_entrypoint(intersection); + + nir_builder build; + nir_builder_init(&build, impl); + nir_builder *b = &build; + + b->cursor = nir_before_cf_list(&impl->body); + /* NOTE(review): ray_commit is initialized to false here but is never + * loaded anywhere in this function; confirm whether it is still needed. + */ + nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit"); + nir_store_var(b, commit, nir_imm_false(b), 0x1); + + nir_foreach_block_safe (block, impl) { + nir_foreach_instr_safe (instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + if (intrin->intrinsic != nir_intrinsic_report_ray_intersection) + continue; + + b->cursor = nir_instr_remove(&intrin->instr); + nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1); + nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1); + nir_ssa_def *min_t = nir_load_ray_t_min(b); + nir_ssa_def *max_t = nir_load_ray_t_max(b); + + /* bool commit_tmp = false; */ + nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp"); + nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1); + /* Only consider the hit when ray_t_min <= hit_t <= ray_t_max. */ + nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t))); + { + /* Any-hit defaults to commit */ + nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1); + /* The inlined any-hit code only runs when + * load_intersection_opaque_amd is false; it receives a pointer + * to commit_tmp and may store false through it. + */ + if (any_hit_impl != NULL) { + nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b))); + { + nir_ssa_def *params[] = { + &nir_build_deref_var(b, commit_tmp)->dest.ssa, + hit_t, + hit_kind, + }; + nir_inline_function_impl(b, any_hit_impl, params, 
any_hit_var_remap); + } + nir_pop_if(b, NULL); + } + /* Re-emit the report only if the hit is still committed. */ + nir_push_if(b, nir_load_var(b, commit_tmp)); + { + nir_report_ray_intersection(b, 1, hit_t, hit_kind); + } + nir_pop_if(b, NULL); + } + nir_pop_if(b, NULL); + /* The result of the original intrinsic becomes the final commit flag. */ + nir_ssa_def *accepted = nir_load_var(b, commit_tmp); + nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted); + } + } + + /* We did some inlining; have to re-index SSA defs */ + nir_index_ssa_defs(impl); + + /* Eliminate the casts introduced for the commit return of the any-hit shader. */ + NIR_PASS_V(intersection, nir_opt_deref); + /* Releases everything allocated from dead_ctx (the any-hit clone and the + * variable remap table, when they were created). + */ + ralloc_free(dead_ctx); +} + static nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_pipeline_shader_stack_size *stack_sizes)