diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index e1fc4e89486..f65c83a7f0e 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -1078,6 +1078,654 @@ nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
    ralloc_free(dead_ctx);
 }
 
+/* Variables only used internally to ray traversal. This is data that describes
+ * the current state of the traversal vs. what we'd give to a shader. e.g. what
+ * is the instance we're currently visiting vs. what is the instance of the
+ * closest hit. */
+struct rt_traversal_vars {
+   nir_variable *origin;
+   nir_variable *dir;
+   nir_variable *inv_dir;
+   nir_variable *sbt_offset_and_flags;
+   nir_variable *instance_id;
+   nir_variable *custom_instance_and_mask;
+   nir_variable *instance_addr;
+   nir_variable *should_return;
+   nir_variable *bvh_base;
+   nir_variable *stack;
+   nir_variable *top_stack;
+};
+
+static struct rt_traversal_vars
+init_traversal_vars(nir_builder *b)
+{
+   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
+   struct rt_traversal_vars ret;
+
+   ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
+   ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
+   ret.inv_dir =
+      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
+   ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
+                                                  "traversal_sbt_offset_and_flags");
+   ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
+                                         "traversal_instance_id");
+   ret.custom_instance_and_mask = nir_variable_create(
+      b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask");
+   ret.instance_addr =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
+   ret.should_return = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(),
+                                           "traversal_should_return");
+   ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
+                                      "traversal_bvh_base");
+   ret.stack =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
+   ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
+                                       "traversal_top_stack_ptr");
+   return ret;
+}
+
+static nir_ssa_def *
+nir_build_addr_to_node(nir_builder *b, nir_ssa_def *addr)
+{
+   const uint64_t bvh_size = 1ull << 42;
+   nir_ssa_def *node = nir_ushr(b, addr, nir_imm_int(b, 3));
+   return nir_iand(b, node, nir_imm_int64(b, (bvh_size - 1) << 3));
+}
+
+static nir_ssa_def *
+nir_build_node_to_addr(nir_builder *b, nir_ssa_def *node)
+{
+   nir_ssa_def *addr = nir_iand(b, node, nir_imm_int64(b, ~7ull));
+   addr = nir_ishl(b, addr, nir_imm_int(b, 3));
+   /* Assumes everything is in the top half of address space, which is true in
+    * GFX9+ for now. */
+   return nir_ior(b, addr, nir_imm_int64(b, 0xffffull << 48));
+}
+
+/* When a hit is opaque the any_hit shader is skipped for this hit and the hit
+ * is assumed to be an actual hit. */
+static nir_ssa_def *
+hit_is_opaque(nir_builder *b, const struct rt_variables *vars,
+              const struct rt_traversal_vars *trav_vars, nir_ssa_def *geometry_id_and_flags)
+{
+   nir_ssa_def *geom_force_opaque = nir_ine(
+      b, nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 1u << 28 /* VK_GEOMETRY_OPAQUE_BIT */)),
+      nir_imm_int(b, 0));
+   nir_ssa_def *instance_force_opaque =
+      nir_ine(b,
+              nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
+                       nir_imm_int(b, 4 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT */)),
+              nir_imm_int(b, 0));
+   nir_ssa_def *instance_force_non_opaque =
+      nir_ine(b,
+              nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
+                       nir_imm_int(b, 8 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT */)),
+              nir_imm_int(b, 0));
+
+   nir_ssa_def *opaque = geom_force_opaque;
+   opaque = nir_bcsel(b, instance_force_opaque, nir_imm_bool(b, true), opaque);
+   opaque = nir_bcsel(b, instance_force_non_opaque, nir_imm_bool(b, false), opaque);
+
+   nir_ssa_def *ray_force_opaque =
+      nir_ine(b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 1 /* RayFlagsOpaque */)),
+              nir_imm_int(b, 0));
+   nir_ssa_def *ray_force_non_opaque = nir_ine(
+      b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 2 /* RayFlagsNoOpaque */)),
+      nir_imm_int(b, 0));
+
+   opaque = nir_bcsel(b, ray_force_opaque, nir_imm_bool(b, true), opaque);
+   opaque = nir_bcsel(b, ray_force_non_opaque, nir_imm_bool(b, false), opaque);
+   return opaque;
+}
+
+static void
+visit_any_hit_shaders(struct radv_device *device,
+                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
+                      struct rt_variables *vars)
+{
+   RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
+   nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
+
+   nir_push_if(b, nir_ine(b, sbt_idx, nir_imm_int(b, 0)));
+   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
+      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
+      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+
+      switch (group_info->type) {
+      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
+         shader_id = group_info->anyHitShader;
+         break;
+      default:
+         break;
+      }
+      if (shader_id == VK_SHADER_UNUSED_KHR)
+         continue;
+
+      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
+      nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
+
+      vars->group_idx = i;
+      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
+   }
+   nir_pop_if(b, NULL);
+}
+
+static void
+insert_traversal_triangle_case(struct radv_device *device,
+                               const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
+                               nir_ssa_def *result, const struct rt_variables *vars,
+                               const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
+{
+   nir_ssa_def *dist = nir_vector_extract(b, result, nir_imm_int(b, 0));
+   nir_ssa_def *div = nir_vector_extract(b, result, nir_imm_int(b, 1));
+   dist = nir_fdiv(b, dist, div);
+   nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div);
+   nir_ssa_def *switch_ccw = nir_ine(
+      b,
+      nir_iand(
+         b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
+         nir_imm_int(b, 2 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FRONT_COUNTERCLOCKWISE_BIT */)),
+      nir_imm_int(b, 0));
+   frontface = nir_ixor(b, frontface, switch_ccw);
+
+   nir_ssa_def *not_cull = nir_ieq(
+      b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 256 /* RayFlagsSkipTriangles */)),
+      nir_imm_int(b, 0));
+   nir_ssa_def *not_facing_cull = nir_ieq(
+      b,
+      nir_iand(b, nir_load_var(b, vars->flags),
+               nir_bcsel(b, frontface,
+                         nir_imm_int(b, 32 /* RayFlagsCullFrontFacingTriangles */),
+                         nir_imm_int(b, 16 /* RayFlagsCullBackFacingTriangles */))),
+      nir_imm_int(b, 0));
+
+   not_cull = nir_iand(
+      b, not_cull,
+      nir_ior(
+         b, not_facing_cull,
+         nir_ine(
+            b,
+            nir_iand(
+               b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
+               nir_imm_int(b, 1 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT */)),
+            nir_imm_int(b, 0))));
+
+   nir_push_if(b, nir_iand(b,
+                           nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)),
+                                    nir_fge(b, dist, nir_load_var(b, vars->tmin))),
+                           not_cull));
+   {
+
+      nir_ssa_def *triangle_info = nir_build_load_global(
+         b, 2, 32,
+         nir_iadd(b, nir_build_node_to_addr(b, bvh_node),
+                  nir_imm_int64(b, offsetof(struct radv_bvh_triangle_node, triangle_id))),
+         .align_mul = 4, .align_offset = 0);
+      nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
+      nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
+      nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
+      nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);
+
+      not_cull =
+         nir_ieq(b,
+                 nir_iand(b, nir_load_var(b, vars->flags),
+                          nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
+                 nir_imm_int(b, 0));
+      nir_push_if(b, not_cull);
+      {
+         nir_ssa_def *sbt_idx =
+            nir_iadd(b,
+                     nir_iadd(b, nir_load_var(b, vars->sbt_offset),
+                              nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
+                                       nir_imm_int(b, 0xffffff))),
+                     nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
+         nir_ssa_def *divs[2] = {div, div};
+         nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
+         nir_ssa_def *hit_kind =
+            nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
+
+         nir_store_scratch(
+            b, ij,
+            nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, RADV_HIT_ATTRIB_OFFSET)),
+            .align_mul = 16, .write_mask = 3);
+
+         nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
+
+         nir_push_if(b, nir_ine(b, is_opaque, nir_imm_bool(b, true)));
+         {
+            struct rt_variables inner_vars = create_inner_vars(b, vars);
+
+            nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
+            nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
+            nir_store_var(b, inner_vars.tmax, dist, 0x1);
+            nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
+            nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr),
+                          0x1);
+            nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
+            nir_store_var(b, inner_vars.custom_instance_and_mask,
+                          nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
+
+            load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
+
+            visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars);
+
+            nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
+            {
+               nir_jump(b, nir_jump_continue);
+            }
+            nir_pop_if(b, NULL);
+         }
+         nir_pop_if(b, NULL);
+
+         nir_store_var(b, vars->primitive_id, primitive_id, 1);
+         nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
+         nir_store_var(b, vars->tmax, dist, 0x1);
+         nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
+         nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
+         nir_store_var(b, vars->hit_kind, hit_kind, 0x1);
+         nir_store_var(b, vars->custom_instance_and_mask,
+                       nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
+
+         load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);
+
+         nir_store_var(b, trav_vars->should_return,
+                       nir_ior(b,
+                               nir_ine(b,
+                                       nir_iand(b, nir_load_var(b, vars->flags),
+                                                nir_imm_int(b, 8 /* SkipClosestHitShader */)),
+                                       nir_imm_int(b, 0)),
+                               nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
+                       1);
+
+         nir_ssa_def *terminate_on_first_hit =
+            nir_ine(b,
+                    nir_iand(b, nir_load_var(b, vars->flags),
+                             nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
+                    nir_imm_int(b, 0));
+         nir_ssa_def *ray_terminated =
+            nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
+         nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
+         {
+            nir_jump(b, nir_jump_break);
+         }
+         nir_pop_if(b, NULL);
+      }
+      nir_pop_if(b, NULL);
+   }
+   nir_pop_if(b, NULL);
+}
+
+static void
+insert_traversal_aabb_case(struct radv_device *device,
+                           const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
+                           nir_ssa_def *result, const struct rt_variables *vars,
+                           const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
+{
+   RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
+
+   nir_ssa_def *node_addr = nir_build_node_to_addr(b, bvh_node);
+   nir_ssa_def *triangle_info = nir_build_load_global(
+      b, 2, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 24)), .align_mul = 4, .align_offset = 0);
+   nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
+   nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
+   nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
+   nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);
+
+   nir_ssa_def *not_cull =
+      nir_ieq(b,
+              nir_iand(b, nir_load_var(b, vars->flags),
+                       nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
+              nir_imm_int(b, 0));
+   nir_push_if(b, not_cull);
+   {
+      nir_ssa_def *sbt_idx =
+         nir_iadd(b,
+                  nir_iadd(b, nir_load_var(b, vars->sbt_offset),
+                           nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
+                                    nir_imm_int(b, 0xffffff))),
+                  nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
+
+      struct rt_variables inner_vars = create_inner_vars(b, vars);
+
+      /* For AABBs the intersection shader writes the hit kind, and only does it if it is the
+       * next closest hit candidate. */
+      inner_vars.hit_kind = vars->hit_kind;
+
+      nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
+      nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
+      nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1);
+      nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
+      nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
+      nir_store_var(b, inner_vars.custom_instance_and_mask,
+                    nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
+      nir_store_var(b, inner_vars.opaque, is_opaque, 1);
+
+      load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
+
+      nir_store_var(b, vars->ahit_status, nir_imm_int(b, 1), 1);
+
+      nir_push_if(b, nir_ine(b, nir_load_var(b, inner_vars.idx), nir_imm_int(b, 0)));
+      for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
+         const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
+         uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+         uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
+
+         switch (group_info->type) {
+         case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
+            shader_id = group_info->intersectionShader;
+            any_hit_shader_id = group_info->anyHitShader;
+            break;
+         default:
+            break;
+         }
+         if (shader_id == VK_SHADER_UNUSED_KHR)
+            continue;
+
+         const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
+         nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
+
+         nir_shader *any_hit_stage = NULL;
+         if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
+            stage = &pCreateInfo->pStages[any_hit_shader_id];
+            any_hit_stage = parse_rt_stage(device, layout, stage);
+
+            nir_lower_intersection_shader(nir_stage, any_hit_stage);
+            ralloc_free(any_hit_stage);
+         }
+
+         inner_vars.group_idx = i;
+         insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
+      }
+      nir_push_else(b, NULL);
+      {
+         nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7);
+         nir_ssa_def *vec3_inf =
+            nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7);
+
+         nir_ssa_def *bvh_lo =
+            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 0)),
+                                  .align_mul = 4, .align_offset = 0);
+         nir_ssa_def *bvh_hi =
+            nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 12)),
+                                  .align_mul = 4, .align_offset = 0);
+
+         bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin));
+         bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin));
+         nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
+                                       nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
+         nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
+                                        nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
+         /* If we run parallel to one of the edges the range should be [0, inf) not [0,0] */
+         t2_vec =
+            nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec);
+
+         nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1));
+         t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2));
+
+         nir_ssa_def *t_max = nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1));
+         t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2));
+
+         nir_push_if(b, nir_iand(b, nir_flt(b, t_min, nir_load_var(b, vars->tmax)),
+                                 nir_fge(b, t_max, nir_load_var(b, vars->tmin))));
+         {
+            nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
+            nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1);
+         }
+         nir_pop_if(b, NULL);
+      }
+      nir_pop_if(b, NULL);
+
+      nir_push_if(b, nir_ine(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
+      {
+         nir_store_var(b, vars->primitive_id, primitive_id, 1);
+         nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
+         nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
+         nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
+         nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
+         nir_store_var(b, vars->custom_instance_and_mask,
+                       nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
+
+         load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);
+
+         nir_store_var(b, trav_vars->should_return,
+                       nir_ior(b,
+                               nir_ine(b,
+                                       nir_iand(b, nir_load_var(b, vars->flags),
+                                                nir_imm_int(b, 8 /* SkipClosestHitShader */)),
+                                       nir_imm_int(b, 0)),
+                               nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
+                       1);
+
+         nir_ssa_def *terminate_on_first_hit =
+            nir_ine(b,
+                    nir_iand(b, nir_load_var(b, vars->flags),
+                             nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
+                    nir_imm_int(b, 0));
+         nir_ssa_def *ray_terminated =
+            nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
+         nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
+         {
+            nir_jump(b, nir_jump_break);
+         }
+         nir_pop_if(b, NULL);
+      }
+      nir_pop_if(b, NULL);
+   }
+   nir_pop_if(b, NULL);
+}
+
+static void
+insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                 nir_builder *b, const struct rt_variables *vars)
+{
+   unsigned stack_entry_size = 4;
+   unsigned lanes = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
+                    b->shader->info.workgroup_size[2];
+   unsigned stack_entry_stride = stack_entry_size * lanes;
+   nir_ssa_def *stack_entry_stride_def = nir_imm_int(b, stack_entry_stride);
+   nir_ssa_def *stack_base =
+      nir_iadd(b, nir_imm_int(b, b->shader->info.shared_size),
+               nir_imul(b, nir_load_subgroup_invocation(b), nir_imm_int(b, stack_entry_size)));
+
+   /*
+    * A top-level AS can contain 2^24 children and a bottom-level AS can contain 2^24 triangles. At
+    * a branching factor of 4, that means we may need up to 24 levels of box nodes + 1 triangle node
+    * + 1 instance node. Furthermore, when processing a box node, worst case we actually push all 4
+    * children and remove one, so the DFS stack depth is box nodes * 3 + 2.
+    */
+   b->shader->info.shared_size += stack_entry_stride * 76;
+   assert(b->shader->info.shared_size <= 32768);
+
+   nir_ssa_def *accel_struct = nir_load_var(b, vars->accel_struct);
+
+   struct rt_traversal_vars trav_vars = init_traversal_vars(b);
+
+   /* Initialize the follow-up shader idx to 0, to be replaced by the miss shader
+    * if we actually miss. */
+   nir_store_var(b, vars->idx, nir_imm_int(b, 0), 1);
+
+   nir_store_var(b, trav_vars.should_return, nir_imm_bool(b, false), 1);
+
+   nir_push_if(b, nir_ine(b, accel_struct, nir_imm_int64(b, 0)));
+   {
+      nir_store_var(b, trav_vars.bvh_base, nir_build_addr_to_node(b, accel_struct), 1);
+
+      nir_ssa_def *bvh_root =
+         nir_build_load_global(b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE,
+                               .align_mul = 64, .align_offset = 0);
+
+      /* We create a BVH descriptor that covers the entire memory range. That way we can always
+       * use the same descriptor, which avoids divergence when different rays hit different
+       * instances at the cost of having to use 64-bit node ids. */
+      const uint64_t bvh_size = 1ull << 42;
+      nir_ssa_def *desc = nir_imm_ivec4(
+         b, 0, 1u << 31 /* Enable box sorting */, (bvh_size - 1) & 0xFFFFFFFFu,
+         ((bvh_size - 1) >> 32) | (1u << 24 /* Return IJ for triangles */) | (1u << 31));
+
+      nir_ssa_def *vec3ones = nir_channels(b, nir_imm_vec4(b, 1.0, 1.0, 1.0, 1.0), 0x7);
+      nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
+      nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
+      nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
+      nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
+      nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
+
+      nir_store_var(b, trav_vars.stack, nir_iadd(b, stack_base, stack_entry_stride_def), 1);
+      nir_store_shared(b, bvh_root, stack_base, .base = 0, .write_mask = 0x1,
+                       .align_mul = stack_entry_size, .align_offset = 0);
+
+      nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
+
+      nir_push_loop(b);
+
+      nir_push_if(b, nir_ieq(b, nir_load_var(b, trav_vars.stack), stack_base));
+      nir_jump(b, nir_jump_break);
+      nir_pop_if(b, NULL);
+
+      nir_push_if(
+         b, nir_uge(b, nir_load_var(b, trav_vars.top_stack), nir_load_var(b, trav_vars.stack)));
+      nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
+      nir_store_var(b, trav_vars.bvh_base,
+                    nir_build_addr_to_node(b, nir_load_var(b, vars->accel_struct)), 1);
+      nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
+      nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
+      nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
+      nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
+
+      nir_pop_if(b, NULL);
+
+      nir_store_var(b, trav_vars.stack,
+                    nir_isub(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
+
+      nir_ssa_def *bvh_node = nir_load_shared(b, 1, 32, nir_load_var(b, trav_vars.stack), .base = 0,
+                                              .align_mul = stack_entry_size, .align_offset = 0);
+      nir_ssa_def *bvh_node_type = nir_iand(b, bvh_node, nir_imm_int(b, 7));
+
+      bvh_node = nir_iadd(b, nir_load_var(b, trav_vars.bvh_base), nir_u2u(b, bvh_node, 64));
+      nir_ssa_def *result = nir_bvh64_intersect_ray_amd(
+         b, 32, desc, nir_unpack_64_2x32(b, bvh_node), nir_load_var(b, vars->tmax),
+         nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir),
+         nir_load_var(b, trav_vars.inv_dir));
+
+      nir_push_if(b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 4)), nir_imm_int(b, 0)));
+      {
+         nir_push_if(b,
+                     nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 2)), nir_imm_int(b, 0)));
+         {
+            /* custom */
+            nir_push_if(
+               b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 1)), nir_imm_int(b, 0)));
+            {
+               insert_traversal_aabb_case(device, pCreateInfo, b, result, vars, &trav_vars,
+                                          bvh_node);
+            }
+            nir_push_else(b, NULL);
+            {
+               /* instance */
+               nir_ssa_def *instance_node_addr = nir_build_node_to_addr(b, bvh_node);
+               nir_ssa_def *instance_data = nir_build_load_global(
+                  b, 4, 32, instance_node_addr, .align_mul = 64, .align_offset = 0);
+               nir_ssa_def *wto_matrix[] = {
+                  nir_build_load_global(b, 4, 32,
+                                        nir_iadd(b, instance_node_addr, nir_imm_int64(b, 16)),
+                                        .align_mul = 64, .align_offset = 16),
+                  nir_build_load_global(b, 4, 32,
+                                        nir_iadd(b, instance_node_addr, nir_imm_int64(b, 32)),
+                                        .align_mul = 64, .align_offset = 32),
+                  nir_build_load_global(b, 4, 32,
+                                        nir_iadd(b, instance_node_addr, nir_imm_int64(b, 48)),
+                                        .align_mul = 64, .align_offset = 48)};
+               nir_ssa_def *instance_id = nir_build_load_global(
+                  b, 1, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 88)), .align_mul = 4,
+                  .align_offset = 0);
+               nir_ssa_def *instance_and_mask = nir_channel(b, instance_data, 2);
+               nir_ssa_def *instance_mask = nir_ushr(b, instance_and_mask, nir_imm_int(b, 24));
+
+               nir_push_if(b,
+                           nir_ieq(b, nir_iand(b, instance_mask, nir_load_var(b, vars->cull_mask)),
+                                   nir_imm_int(b, 0)));
+               nir_jump(b, nir_jump_continue);
+               nir_pop_if(b, NULL);
+
+               nir_store_var(b, trav_vars.top_stack, nir_load_var(b, trav_vars.stack), 1);
+               nir_store_var(b, trav_vars.bvh_base,
+                             nir_build_addr_to_node(
+                                b, nir_pack_64_2x32(b, nir_channels(b, instance_data, 0x3))),
+                             1);
+               nir_store_shared(b,
+                                nir_iand(b, nir_channel(b, instance_data, 0), nir_imm_int(b, 63)),
+                                nir_load_var(b, trav_vars.stack), .base = 0, .write_mask = 0x1,
+                                .align_mul = stack_entry_size, .align_offset = 0);
+               nir_store_var(b, trav_vars.stack,
+                             nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def),
+                             1);
+
+               nir_store_var(
+                  b, trav_vars.origin,
+                  nir_build_vec3_mat_mult_pre(b, nir_load_var(b, vars->origin), wto_matrix), 7);
+               nir_store_var(
+                  b, trav_vars.dir,
+                  nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false),
+                  7);
+               nir_store_var(b, trav_vars.inv_dir,
+                             nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
+               nir_store_var(b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
+               nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3),
+                             1);
+               nir_store_var(b, trav_vars.instance_id, instance_id, 1);
+               nir_store_var(b, trav_vars.instance_addr, instance_node_addr, 1);
+            }
+            nir_pop_if(b, NULL);
+         }
+         nir_push_else(b, NULL);
+         {
+            /* box */
+
+            for (unsigned i = 0; i < 4; ++i) {
+               nir_ssa_def *new_node = nir_vector_extract(b, result, nir_imm_int(b, i));
+               nir_push_if(b, nir_ine(b, new_node, nir_imm_int(b, 0xffffffff)));
+               {
+                  nir_store_shared(b, new_node, nir_load_var(b, trav_vars.stack), .base = 0,
+                                   .write_mask = 0x1, .align_mul = stack_entry_size,
+                                   .align_offset = 0);
+                  nir_store_var(
+                     b, trav_vars.stack,
+                     nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
+               }
+               nir_pop_if(b, NULL);
+            }
+         }
+         nir_pop_if(b, NULL);
+      }
+      nir_push_else(b, NULL);
+      {
+         insert_traversal_triangle_case(device, pCreateInfo, b, result, vars, &trav_vars, bvh_node);
+      }
+      nir_pop_if(b, NULL);
+
+      nir_pop_loop(b, NULL);
+   }
+   nir_pop_if(b, NULL);
+
+   /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence
+    * need to return immediately to the calling shader. */
+   nir_push_if(b, nir_load_var(b, trav_vars.should_return));
+   {
+      insert_rt_return(b, vars);
+   }
+   nir_push_else(b, NULL);
+   {
+      /* Only load the miss shader if we actually miss, which we determine by not having set
+       * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none
+       * of the rays miss. */
+      nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0)));
+      {
+         load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, 0);
+      }
+      nir_pop_if(b, NULL);
+   }
+   nir_pop_if(b, NULL);
+}
+
 static nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_pipeline_shader_stack_size *stack_sizes)