Diffstat (limited to 'src/kosmickrisp/compiler/nir_to_msl.c')
-rw-r--r--  src/kosmickrisp/compiler/nir_to_msl.c  2051
1 file changed, 2051 insertions, 0 deletions
diff --git a/src/kosmickrisp/compiler/nir_to_msl.c b/src/kosmickrisp/compiler/nir_to_msl.c
new file mode 100644
index 00000000000..51b96bb2c62
--- /dev/null
+++ b/src/kosmickrisp/compiler/nir_to_msl.c
@@ -0,0 +1,2051 @@
+/*
+ * Copyright 2025 LunarG, Inc.
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "nir_to_msl.h"
+#include "msl_private.h"
+#include "nir.h"
+
+static const char *
+get_stage_string(mesa_shader_stage stage)
+{
+ switch (stage) {
+ case MESA_SHADER_VERTEX:
+ return "vertex";
+ case MESA_SHADER_FRAGMENT:
+ return "fragment";
+ case MESA_SHADER_COMPUTE:
+ return "kernel";
+ default:
+ assert(0);
+ return "";
+ }
+}
+
+static const char *
+get_entrypoint_name(nir_shader *shader)
+{
+ return nir_shader_get_entrypoint(shader)->function->name;
+}
+
+static const char *sysval_table[SYSTEM_VALUE_MAX] = {
+ [SYSTEM_VALUE_SUBGROUP_SIZE] =
+ "uint gl_SubGroupSize [[threads_per_simdgroup]]",
+ [SYSTEM_VALUE_SUBGROUP_INVOCATION] =
+ "uint gl_SubGroupInvocation [[thread_index_in_simdgroup]]",
+ [SYSTEM_VALUE_NUM_SUBGROUPS] =
+ "uint gl_NumSubGroups [[simdgroups_per_threadgroup]]",
+ [SYSTEM_VALUE_SUBGROUP_ID] =
+ "uint gl_SubGroupID [[simdgroup_index_in_threadgroup]]",
+ [SYSTEM_VALUE_WORKGROUP_ID] =
+ "uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]",
+ [SYSTEM_VALUE_LOCAL_INVOCATION_ID] =
+ "uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]",
+ [SYSTEM_VALUE_GLOBAL_INVOCATION_ID] =
+ "uint3 gl_GlobalInvocationID [[thread_position_in_grid]]",
+ [SYSTEM_VALUE_NUM_WORKGROUPS] =
+ "uint3 gl_NumWorkGroups [[threadgroups_per_grid]]",
+ [SYSTEM_VALUE_LOCAL_INVOCATION_INDEX] =
+ "uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]]",
+ [SYSTEM_VALUE_VERTEX_ID] = "uint gl_VertexID [[vertex_id]]",
+ [SYSTEM_VALUE_INSTANCE_ID] = "uint gl_InstanceID [[instance_id]]",
+ [SYSTEM_VALUE_BASE_INSTANCE] = "uint gl_BaseInstance [[base_instance]]",
+ [SYSTEM_VALUE_FRAG_COORD] = "float4 gl_FragCoord [[position]]",
+ [SYSTEM_VALUE_POINT_COORD] = "float2 gl_PointCoord [[point_coord]]",
+ [SYSTEM_VALUE_FRONT_FACE] = "bool gl_FrontFacing [[front_facing]]",
+ [SYSTEM_VALUE_LAYER_ID] = "uint gl_Layer [[render_target_array_index]]",
+ [SYSTEM_VALUE_SAMPLE_ID] = "uint gl_SampleID [[sample_id]]",
+ [SYSTEM_VALUE_SAMPLE_MASK_IN] = "uint gl_SampleMask [[sample_mask]]",
+ [SYSTEM_VALUE_AMPLIFICATION_ID_KK] =
+ "uint mtl_AmplificationID [[amplification_id]]",
+ /* These are functions and not shader input variables */
+ [SYSTEM_VALUE_HELPER_INVOCATION] = "",
+};
+
+static void
+emit_sysvals(struct nir_to_msl_ctx *ctx, nir_shader *shader)
+{
+ unsigned i;
+ BITSET_FOREACH_SET(i, shader->info.system_values_read, SYSTEM_VALUE_MAX) {
+ assert(sysval_table[i]);
+ if (sysval_table[i] && sysval_table[i][0])
+ P_IND(ctx, "%s,\n", sysval_table[i]);
+ }
+}
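+
+/* For illustration: a compute shader whose info.system_values_read includes
+ * SYSTEM_VALUE_GLOBAL_INVOCATION_ID and SYSTEM_VALUE_LOCAL_INVOCATION_INDEX
+ * gets roughly these entrypoint parameters emitted:
+ *
+ *    uint3 gl_GlobalInvocationID [[thread_position_in_grid]],
+ *    uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]],
+ */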
+
+static void
+emit_inputs(struct nir_to_msl_ctx *ctx, nir_shader *shader)
+{
+ switch (shader->info.stage) {
+ case MESA_SHADER_FRAGMENT:
+ P_IND(ctx, "FragmentIn in [[stage_in]],\n");
+ break;
+ default:
+ break;
+ }
+ P_IND(ctx, "constant Buffer &buf0 [[buffer(0)]],\n");
+ P_IND(ctx, "constant SamplerTable &sampler_table [[buffer(1)]]\n");
+}
+
+static const char *
+output_type(nir_shader *shader)
+{
+ switch (shader->info.stage) {
+ case MESA_SHADER_VERTEX:
+ return "VertexOut";
+ case MESA_SHADER_FRAGMENT:
+ return "FragmentOut";
+ default:
+ return "void";
+ }
+}
+
+static void
+emit_local_vars(struct nir_to_msl_ctx *ctx, nir_shader *shader)
+{
+ if (shader->info.shared_size) {
+ P_IND(ctx, "threadgroup char shared_data[%d];\n",
+ shader->info.shared_size);
+ }
+ if (shader->scratch_size) {
+ P_IND(ctx, "uchar scratch[%d] = {0};\n", shader->scratch_size);
+ }
+ if (BITSET_TEST(shader->info.system_values_read,
+ SYSTEM_VALUE_HELPER_INVOCATION)) {
+ P_IND(ctx, "bool gl_HelperInvocation = simd_is_helper_thread();\n");
+ }
+}
+
+static bool
+is_register(nir_def *def)
+{
+ return ((def->parent_instr->type == nir_instr_type_intrinsic) &&
+ (nir_instr_as_intrinsic(def->parent_instr)->intrinsic ==
+ nir_intrinsic_load_reg));
+}
+
+static void
+writemask_to_msl(struct nir_to_msl_ctx *ctx, unsigned write_mask,
+ unsigned num_components)
+{
+ if (num_components != util_bitcount(write_mask)) {
+ P(ctx, ".");
+ for (unsigned i = 0; i < num_components; i++)
+ if ((write_mask >> i) & 1)
+ P(ctx, "%c", "xyzw"[i]);
+ }
+}
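+
+/* For illustration: write_mask 0x5 with num_components 4 appends ".xz", while
+ * a write mask covering every component appends nothing.
+ */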
+
+static void
+src_to_msl(struct nir_to_msl_ctx *ctx, nir_src *src)
+{
+ /* Pointer types cannot use as_type casting */
+ const char *bitcast = msl_bitcast_for_src(ctx->types, src);
+ if (nir_src_is_const(*src)) {
+ msl_src_as_const(ctx, src);
+ return;
+ }
+ if (nir_src_is_undef(*src)) {
+ if (src->ssa->num_components == 1) {
+ P(ctx, "00");
+ } else {
+ P(ctx, "%s(", msl_type_for_src(ctx->types, src));
+ for (int i = 0; i < src->ssa->num_components; i++) {
+ if (i)
+ P(ctx, ", ");
+ P(ctx, "00");
+ }
+ P(ctx, ")");
+ }
+ return;
+ }
+
+ if (bitcast)
+ P(ctx, "as_type<%s>(", bitcast);
+ if (is_register(src->ssa)) {
+ nir_intrinsic_instr *instr =
+ nir_instr_as_intrinsic(src->ssa->parent_instr);
+ if (src->ssa->bit_size != 1u) {
+ P(ctx, "as_type<%s>(r%d)", msl_type_for_def(ctx->types, src->ssa),
+ instr->src[0].ssa->index);
+ } else {
+ P(ctx, "%s(r%d)", msl_type_for_def(ctx->types, src->ssa),
+ instr->src[0].ssa->index);
+ }
+ } else if (nir_src_is_const(*src)) {
+ msl_src_as_const(ctx, src);
+ } else {
+ P(ctx, "t%d", src->ssa->index);
+ }
+ if (bitcast)
+ P(ctx, ")");
+}
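+
+/* For illustration (made-up SSA indices, assuming msl_infer_types picked
+ * "float"): a plain SSA source prints "t12", a constant is inlined by
+ * msl_src_as_const, a source fed by a register load prints "as_type<float>(r4)"
+ * ("bool(r4)" for 1-bit values), and a source needing a bitcast gets wrapped
+ * in "as_type<...>(...)".
+ */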
+
+static void
+alu_src_to_msl(struct nir_to_msl_ctx *ctx, nir_alu_instr *instr, int srcn)
+{
+ nir_alu_src *src = &instr->src[srcn];
+ src_to_msl(ctx, &src->src);
+ if (!nir_alu_src_is_trivial_ssa(instr, srcn) &&
+ src->src.ssa->num_components > 1) {
+ int num_components = nir_src_num_components(src->src);
+ assert(num_components <= 4);
+
+ P(ctx, ".");
+ for (int i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+ if (!nir_alu_instr_channel_used(instr, srcn, i))
+ continue;
+ P(ctx, "%c", "xyzw"[src->swizzle[i]]);
+ }
+ }
+}
+
+static void
+alu_funclike(struct nir_to_msl_ctx *ctx, nir_alu_instr *instr, const char *name)
+{
+ const nir_op_info *info = &nir_op_infos[instr->op];
+ P(ctx, "%s(", name);
+ for (int i = 0; i < info->num_inputs; i++) {
+ alu_src_to_msl(ctx, instr, i);
+ if (i < info->num_inputs - 1)
+ P(ctx, ", ");
+ }
+ P(ctx, ")");
+}
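+
+/* For illustration (made-up SSA indices, trivial swizzles): alu_funclike() for
+ * an nir_op_fmax with sources t1 and t2 prints "fmax(t1, t2)".
+ */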
+
+static void
+alu_to_msl(struct nir_to_msl_ctx *ctx, nir_alu_instr *instr)
+{
+
+#define ALU_BINOP(op) \
+ do { \
+ alu_src_to_msl(ctx, instr, 0); \
+ P(ctx, " %s ", op); \
+ alu_src_to_msl(ctx, instr, 1); \
+   } while (0)
+
+ switch (instr->op) {
+ case nir_op_isign:
+ alu_src_to_msl(ctx, instr, 0);
+ P(ctx, " == 0 ? 0.0 : ((");
+ alu_src_to_msl(ctx, instr, 0);
+ P(ctx, " < 0) ? -1 : 1)");
+ break;
+ case nir_op_iadd:
+ case nir_op_fadd:
+ ALU_BINOP("+");
+ break;
+ case nir_op_uadd_sat:
+ case nir_op_iadd_sat:
+ alu_funclike(ctx, instr, "addsat");
+ break;
+ case nir_op_isub:
+ case nir_op_fsub:
+ ALU_BINOP("-");
+ break;
+ case nir_op_imul:
+ case nir_op_fmul:
+ ALU_BINOP("*");
+ break;
+ case nir_op_idiv:
+ case nir_op_udiv:
+ case nir_op_fdiv:
+ ALU_BINOP("/");
+ break;
+ case nir_op_irem:
+ ALU_BINOP("%");
+ break;
+ case nir_op_ishl:
+ ALU_BINOP("<<");
+ break;
+ case nir_op_ishr:
+ case nir_op_ushr:
+ ALU_BINOP(">>");
+ break;
+ case nir_op_ige:
+ case nir_op_uge:
+ case nir_op_fge:
+ ALU_BINOP(">=");
+ break;
+ case nir_op_ilt:
+ case nir_op_ult:
+ case nir_op_flt:
+      ALU_BINOP("<");
+ break;
+ case nir_op_iand:
+ ALU_BINOP("&");
+ break;
+ case nir_op_ior:
+ ALU_BINOP("|");
+ break;
+ case nir_op_ixor:
+ ALU_BINOP("^");
+ break;
+ case nir_op_bitfield_insert:
+ alu_funclike(ctx, instr, "insert_bits");
+ break;
+ case nir_op_ibitfield_extract:
+ case nir_op_ubitfield_extract:
+ alu_funclike(ctx, instr, "extract_bits");
+ break;
+ case nir_op_bitfield_reverse:
+ alu_funclike(ctx, instr, "reverse_bits");
+ break;
+ case nir_op_bit_count:
+ alu_funclike(ctx, instr, "popcount");
+ break;
+ case nir_op_uclz:
+ alu_funclike(ctx, instr, "clz");
+ break;
+ case nir_op_ieq:
+ case nir_op_feq:
+ ALU_BINOP("==");
+ break;
+ case nir_op_ine:
+ case nir_op_fneu:
+ ALU_BINOP("!=");
+ break;
+ case nir_op_umax:
+ case nir_op_imax:
+ alu_funclike(ctx, instr, "max");
+ break;
+ case nir_op_umin:
+ case nir_op_imin:
+ alu_funclike(ctx, instr, "min");
+ break;
+ case nir_op_umod:
+ case nir_op_imod:
+ ALU_BINOP("%");
+ break;
+ case nir_op_imul_high:
+ case nir_op_umul_high:
+ alu_funclike(ctx, instr, "mulhi");
+ break;
+ case nir_op_usub_sat:
+ alu_funclike(ctx, instr, "subsat");
+ break;
+ case nir_op_fsat:
+ alu_funclike(ctx, instr, "saturate");
+ break;
+ /* Functions from <metal_relational> */
+ case nir_op_fisfinite:
+ alu_funclike(ctx, instr, "isfinite");
+ break;
+ case nir_op_fisnormal:
+ alu_funclike(ctx, instr, "isnormal");
+ break;
+ /* Functions from <metal_math> */
+ case nir_op_iabs:
+ case nir_op_fabs:
+ alu_funclike(ctx, instr, "abs");
+ break;
+ case nir_op_fceil:
+ alu_funclike(ctx, instr, "ceil");
+ break;
+ case nir_op_fcos:
+ alu_funclike(ctx, instr, "cos");
+ break;
+ case nir_op_fdot2:
+ case nir_op_fdot3:
+ case nir_op_fdot4:
+ alu_funclike(ctx, instr, "dot");
+ break;
+ case nir_op_fexp2:
+ alu_funclike(ctx, instr, "exp2");
+ break;
+ case nir_op_ffloor:
+ alu_funclike(ctx, instr, "floor");
+ break;
+ case nir_op_ffma:
+ alu_funclike(ctx, instr, "fma");
+ break;
+ case nir_op_ffract:
+ alu_funclike(ctx, instr, "fract");
+ break;
+ case nir_op_flog2:
+ alu_funclike(ctx, instr, "log2");
+ break;
+ case nir_op_flrp:
+ alu_funclike(ctx, instr, "mix");
+ break;
+ case nir_op_fmax:
+ alu_funclike(ctx, instr, "fmax");
+ break;
+ case nir_op_fmin:
+ alu_funclike(ctx, instr, "fmin");
+ break;
+ case nir_op_frem:
+ alu_funclike(ctx, instr, "fmod");
+ break;
+ case nir_op_fpow:
+ alu_funclike(ctx, instr, "pow");
+ break;
+ case nir_op_fround_even:
+ alu_funclike(ctx, instr, "rint");
+ break;
+ case nir_op_frsq:
+ alu_funclike(ctx, instr, "rsqrt");
+ break;
+ case nir_op_fsign:
+ alu_funclike(ctx, instr, "sign");
+ break;
+ case nir_op_fsqrt:
+ alu_funclike(ctx, instr, "sqrt");
+ break;
+ case nir_op_fsin:
+ alu_funclike(ctx, instr, "sin");
+ break;
+ case nir_op_ldexp:
+ alu_funclike(ctx, instr, "ldexp");
+ break;
+ case nir_op_ftrunc:
+ alu_funclike(ctx, instr, "trunc");
+ break;
+ case nir_op_pack_snorm_4x8:
+ alu_funclike(ctx, instr, "pack_float_to_snorm4x8");
+ break;
+ case nir_op_pack_unorm_4x8:
+ alu_funclike(ctx, instr, "pack_float_to_unorm4x8");
+ break;
+ case nir_op_pack_snorm_2x16:
+ alu_funclike(ctx, instr, "pack_float_to_snorm2x16");
+ break;
+ case nir_op_pack_unorm_2x16:
+ alu_funclike(ctx, instr, "pack_float_to_unorm2x16");
+ break;
+ case nir_op_unpack_snorm_4x8:
+ alu_funclike(ctx, instr, "unpack_snorm4x8_to_float");
+ break;
+ case nir_op_unpack_unorm_4x8:
+ alu_funclike(ctx, instr, "unpack_unorm4x8_to_float");
+ break;
+ case nir_op_unpack_snorm_2x16:
+ alu_funclike(ctx, instr, "unpack_snorm2x16_to_float");
+ break;
+ case nir_op_unpack_unorm_2x16:
+ alu_funclike(ctx, instr, "unpack_unorm2x16_to_float");
+ break;
+ case nir_op_vec2:
+ case nir_op_vec3:
+ case nir_op_vec4:
+ case nir_op_b2b1:
+ case nir_op_b2b32:
+ case nir_op_b2i8:
+ case nir_op_b2i16:
+ case nir_op_b2i32:
+ case nir_op_b2i64:
+ case nir_op_b2f16:
+ case nir_op_i2f16:
+ case nir_op_u2f16:
+ case nir_op_i2f32:
+ case nir_op_u2f32:
+ case nir_op_i2i8:
+ case nir_op_i2i16:
+ case nir_op_i2i32:
+ case nir_op_i2i64:
+ case nir_op_f2i8:
+ case nir_op_f2i16:
+ case nir_op_f2i32:
+ case nir_op_f2i64:
+ case nir_op_f2u8:
+ case nir_op_f2u16:
+ case nir_op_f2u32:
+ case nir_op_f2u64:
+ case nir_op_u2u8:
+ case nir_op_u2u16:
+ case nir_op_u2u32:
+ case nir_op_u2u64:
+ case nir_op_f2f16:
+ case nir_op_f2f16_rtne:
+ case nir_op_f2f32:
+ alu_funclike(ctx, instr, msl_type_for_def(ctx->types, &instr->def));
+ break;
+ case nir_op_unpack_half_2x16_split_x:
+ P(ctx, "float(as_type<half>(ushort(t%d & 0x0000ffff)))",
+ instr->src[0].src.ssa->index);
+ break;
+ case nir_op_frcp:
+ P(ctx, "1/");
+ alu_src_to_msl(ctx, instr, 0);
+ break;
+ case nir_op_inot:
+ if (instr->src[0].src.ssa->bit_size == 1) {
+ P(ctx, "!");
+ } else {
+ P(ctx, "~");
+ }
+ alu_src_to_msl(ctx, instr, 0);
+ break;
+ case nir_op_ineg:
+ case nir_op_fneg:
+ P(ctx, "-");
+ alu_src_to_msl(ctx, instr, 0);
+ break;
+ case nir_op_mov:
+ alu_src_to_msl(ctx, instr, 0);
+ break;
+ case nir_op_b2f32:
+ alu_src_to_msl(ctx, instr, 0);
+ P(ctx, " ? 1.0 : 0.0");
+ break;
+ case nir_op_bcsel:
+ alu_src_to_msl(ctx, instr, 0);
+ P(ctx, " ? ");
+ alu_src_to_msl(ctx, instr, 1);
+ P(ctx, " : ");
+ alu_src_to_msl(ctx, instr, 2);
+ break;
+ default:
+ P(ctx, "ALU %s", nir_op_infos[instr->op].name);
+ }
+}
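+
+/* For illustration (made-up SSA indices, trivial swizzles): instr_to_msl()
+ * prints the destination prefix, so nir_op_fadd becomes "t5 = t3 + t4;",
+ * nir_op_bcsel becomes "t5 = t1 ? t2 : t3;" and a conversion like
+ * nir_op_f2i32 becomes "t5 = int(t4);" (assuming the inferred type is "int").
+ */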
+
+static const char *
+texture_dim(enum glsl_sampler_dim dim)
+{
+ switch (dim) {
+ case GLSL_SAMPLER_DIM_1D:
+ return "1d";
+ case GLSL_SAMPLER_DIM_2D:
+ return "2d";
+ case GLSL_SAMPLER_DIM_3D:
+ return "3d";
+ case GLSL_SAMPLER_DIM_CUBE:
+ return "cube";
+ case GLSL_SAMPLER_DIM_BUF:
+ return "_buffer";
+ case GLSL_SAMPLER_DIM_MS:
+ return "2d_ms";
+ default:
+ fprintf(stderr, "Bad texture dim %d\n", dim);
+ assert(!"Bad texture dimension");
+ return "BAD";
+ }
+}
+
+static const char *
+tex_type_name(nir_alu_type ty)
+{
+ switch (ty) {
+ case nir_type_int16:
+ return "short";
+ case nir_type_int32:
+ return "int";
+ case nir_type_uint16:
+ return "ushort";
+ case nir_type_uint32:
+ return "uint";
+ case nir_type_float16:
+ return "half";
+ case nir_type_float32:
+ return "float";
+ default:
+ return "BAD";
+ }
+}
+
+static bool
+intrinsic_needs_dest_type(nir_intrinsic_instr *instr)
+{
+ const nir_intrinsic_info *info = &nir_intrinsic_infos[instr->intrinsic];
+ nir_intrinsic_op op = instr->intrinsic;
+ if (op == nir_intrinsic_decl_reg || op == nir_intrinsic_load_reg ||
+ op == nir_intrinsic_load_texture_handle_kk ||
+ op == nir_intrinsic_load_depth_texture_kk ||
+ /* Atomic swaps have a custom codegen */
+ op == nir_intrinsic_global_atomic_swap ||
+ op == nir_intrinsic_shared_atomic_swap ||
+ op == nir_intrinsic_bindless_image_atomic_swap)
+ return false;
+ return info->has_dest;
+}
+
+static const char *
+msl_pipe_format_to_msl_type(enum pipe_format format)
+{
+ switch (format) {
+ case PIPE_FORMAT_R16_FLOAT:
+ return "half";
+ case PIPE_FORMAT_R32_FLOAT:
+ return "float";
+ case PIPE_FORMAT_R8_UINT:
+ return "uchar";
+ case PIPE_FORMAT_R16_UINT:
+ return "ushort";
+ case PIPE_FORMAT_R32_UINT:
+ return "uint";
+ case PIPE_FORMAT_R64_UINT:
+ return "unsigned long";
+ case PIPE_FORMAT_R8_SINT:
+ return "char";
+ case PIPE_FORMAT_R16_SINT:
+ return "short";
+ case PIPE_FORMAT_R32_SINT:
+ return "int";
+ case PIPE_FORMAT_R64_SINT:
+ return "long";
+ default:
+ assert(0);
+ return "";
+ }
+}
+
+static const char *
+component_str(uint8_t num_components)
+{
+ switch (num_components) {
+ default:
+ case 1:
+ return "";
+ case 2:
+ return "2";
+ case 3:
+ return "3";
+ case 4:
+ return "4";
+ }
+}
+
+static void
+round_src_component_to_uint(struct nir_to_msl_ctx *ctx, nir_src *src,
+ char component)
+{
+ bool is_float = msl_src_is_float(ctx, src);
+ if (is_float) {
+ P(ctx, "uint(rint(");
+ }
+ src_to_msl(ctx, src);
+ P(ctx, ".%c", component);
+ if (is_float) {
+ P(ctx, "))");
+ }
+}
+
+static void
+texture_src_coord_swizzle(struct nir_to_msl_ctx *ctx, nir_src *coord,
+ uint32_t num_components, bool is_cube, bool is_array)
+{
+ src_to_msl(ctx, coord);
+
+ uint32_t coord_components =
+ num_components - (uint32_t)is_array - (uint32_t)is_cube;
+ if (coord_components < coord->ssa->num_components) {
+ const char *swizzle = "xyzw";
+ uint32_t i = 0;
+ P(ctx, ".");
+ for (i = 0; i < coord_components; i++)
+ P(ctx, "%c", swizzle[i]);
+
+ if (is_cube) {
+ P(ctx, ", ");
+ round_src_component_to_uint(ctx, coord, swizzle[i++]);
+ }
+ if (is_array) {
+ P(ctx, ", ");
+ round_src_component_to_uint(ctx, coord, swizzle[i++]);
+ }
+ }
+}
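+
+/* For illustration: for a sampled 2D array texture whose float3 coordinate is
+ * t7, this prints "t7.xy, uint(rint(t7.z))", i.e. the array layer is split
+ * off the coordinate vector and rounded to an integer.
+ */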
+
+static void
+image_coord_swizzle(struct nir_to_msl_ctx *ctx, nir_intrinsic_instr *instr)
+{
+ unsigned comps = 0;
+ bool is_array = nir_intrinsic_image_array(instr);
+ bool is_cube = false;
+ switch (nir_intrinsic_image_dim(instr)) {
+ case GLSL_SAMPLER_DIM_BUF:
+ case GLSL_SAMPLER_DIM_1D:
+ comps = 1;
+ break;
+ case GLSL_SAMPLER_DIM_2D:
+ case GLSL_SAMPLER_DIM_MS:
+ comps = 2;
+ break;
+ case GLSL_SAMPLER_DIM_3D:
+ comps = 3;
+ break;
+ case GLSL_SAMPLER_DIM_CUBE:
+ comps = 3;
+ is_cube = true;
+ break;
+ default:
+ assert(!"Bad dimension for image");
+ break;
+ }
+ if (is_array)
+ comps += 1;
+
+ texture_src_coord_swizzle(ctx, &instr->src[1], comps, is_cube, is_array);
+}
+
+/* Non-packed types have stricter alignment requirements than packed types.
+ * This helps us build a packed format for storage.
+ */
+static void
+src_to_packed(struct nir_to_msl_ctx *ctx, nir_src *src, const char *type,
+ uint32_t component_count)
+{
+ if (component_count == 1) {
+ P(ctx, "%s(", type);
+ } else {
+ P(ctx, "packed_%s(", type);
+ }
+ src_to_msl(ctx, src);
+ P(ctx, ")");
+}
+
+/* Non-packed types have stricter alignment requirements than packed types.
+ * This helps us cast the pointer to a packed type and then it builds the
+ * non-packed type for Metal usage.
+ */
+static void
+src_to_packed_load(struct nir_to_msl_ctx *ctx, nir_src *src,
+ const char *addressing, const char *type,
+ uint32_t component_count)
+{
+ if (component_count == 1) {
+ P(ctx, "*(%s %s*)(", addressing, type);
+ } else {
+ P(ctx, "%s(*(%s packed_%s*)", type, addressing, type);
+ }
+ src_to_msl(ctx, src);
+ P(ctx, ")");
+}
+
+/* Non-packed types have stricter alignment requirements than packed types.
+ * This helps us cast the pointer to a packed type and then it builds the
+ * non-packed type for Metal usage.
+ */
+static void
+src_to_packed_load_offset(struct nir_to_msl_ctx *ctx, nir_src *src,
+ nir_src *offset, const char *addressing,
+ const char *type, uint32_t component_count)
+{
+ if (component_count == 1) {
+ P(ctx, "*(%s %s*)((", addressing, type);
+ } else {
+ P(ctx, "%s(*(%s packed_%s*)(", type, addressing, type);
+ }
+ src_to_msl(ctx, src);
+ P(ctx, " + ");
+ src_to_msl(ctx, offset);
+ P(ctx, "))");
+}
+
+/* Non-packed types have stricter alignment requirements than packed types.
+ * This helps us cast the pointer to a packed type for storage.
+ */
+static void
+src_to_packed_store(struct nir_to_msl_ctx *ctx, nir_src *src,
+ const char *addressing, const char *type,
+ uint32_t num_components)
+{
+ if (num_components == 1) {
+ P_IND(ctx, "*(%s %s*)", addressing, type);
+ } else {
+ P_IND(ctx, "*(%s packed_%s*)", addressing, type);
+ }
+ src_to_msl(ctx, src);
+}
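+
+/* For illustration (made-up SSA indices): a 3-component float global load
+ * becomes "float3(*(device packed_float3*)t2)" and the matching store becomes
+ * "*(device packed_float3*)t5 = packed_float3(t8);"; single-component
+ * accesses skip the packed_ cast.
+ */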
+
+static const char *
+atomic_op_to_msl(nir_atomic_op op)
+{
+ switch (op) {
+ case nir_atomic_op_iadd:
+ case nir_atomic_op_fadd:
+ return "atomic_fetch_add";
+ case nir_atomic_op_umin:
+ case nir_atomic_op_imin:
+ case nir_atomic_op_fmin:
+ return "atomic_fetch_min";
+ case nir_atomic_op_umax:
+ case nir_atomic_op_imax:
+ case nir_atomic_op_fmax:
+ return "atomic_fetch_max";
+ case nir_atomic_op_iand:
+ return "atomic_fetch_and";
+ case nir_atomic_op_ior:
+ return "atomic_fetch_or";
+ case nir_atomic_op_ixor:
+ return "atomic_fetch_xor";
+ case nir_atomic_op_xchg:
+ return "atomic_exchange";
+ case nir_atomic_op_cmpxchg:
+ case nir_atomic_op_fcmpxchg:
+ return "atomic_compare_exchange_weak";
+ default:
+ UNREACHABLE("Unhandled atomic op");
+ }
+}
+
+static void
+atomic_to_msl(struct nir_to_msl_ctx *ctx, nir_intrinsic_instr *instr,
+ const char *scope, bool shared)
+{
+ const char *atomic_op = atomic_op_to_msl(nir_intrinsic_atomic_op(instr));
+ const char *mem_order = "memory_order_relaxed";
+
+ P(ctx, "%s_explicit((%s atomic_%s*)", atomic_op, scope,
+ msl_type_for_def(ctx->types, &instr->def));
+ if (shared)
+ P(ctx, "&shared_data[");
+ src_to_msl(ctx, &instr->src[0]);
+ if (shared)
+ P(ctx, "]");
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ", %s", mem_order);
+ P(ctx, ");\n");
+}
+
+static void
+atomic_swap_to_msl(struct nir_to_msl_ctx *ctx, nir_intrinsic_instr *instr,
+ const char *scope, bool shared)
+{
+ const char *atomic_op = atomic_op_to_msl(nir_intrinsic_atomic_op(instr));
+ const char *mem_order = "memory_order_relaxed";
+ const char *type = msl_type_for_def(ctx->types, &instr->def);
+
+ P_IND(ctx, "%s ta%d = ", type, instr->def.index);
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, "; %s_explicit((%s atomic_%s*)", atomic_op, scope, type);
+ if (shared)
+ P(ctx, "&shared_data[");
+ src_to_msl(ctx, &instr->src[0]);
+ if (shared)
+ P(ctx, "]");
+ P(ctx, ", ");
+ P(ctx, "&ta%d, ", instr->def.index);
+ src_to_msl(ctx, &instr->src[2]);
+ P(ctx, ", %s, %s);", mem_order, mem_order);
+ P(ctx, "%s t%d = ta%d;\n", type, instr->def.index, instr->def.index);
+}
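+
+/* For illustration (made-up SSA indices): a device uint global_atomic with
+ * nir_atomic_op_iadd becomes
+ *    t5 = atomic_fetch_add_explicit((device atomic_uint*)t1, t2, memory_order_relaxed);
+ * while a compare-and-swap goes through a temporary, roughly:
+ *    uint ta7 = t2; atomic_compare_exchange_weak_explicit((device atomic_uint*)t1,
+ *       &ta7, t3, memory_order_relaxed, memory_order_relaxed); uint t7 = ta7;
+ */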
+
+static void
+memory_modes_to_msl(struct nir_to_msl_ctx *ctx, nir_variable_mode modes)
+{
+ bool requires_or = false;
+ u_foreach_bit(i, modes) {
+ nir_variable_mode single_mode = (1 << i);
+ if (requires_or)
+ P(ctx, " | ");
+ switch (single_mode) {
+ case nir_var_image:
+ P(ctx, "mem_flags::mem_texture");
+ break;
+ case nir_var_mem_ssbo:
+ case nir_var_mem_global:
+ P(ctx, "mem_flags::mem_device");
+ break;
+ case nir_var_function_temp:
+ P(ctx, "mem_flags::mem_none");
+ break;
+ case nir_var_mem_shared:
+ P(ctx, "mem_flags::mem_threadgroup");
+ break;
+ default:
+ UNREACHABLE("bad_memory_mode");
+ }
+ requires_or = true;
+ }
+}
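+
+/* For illustration: a mode mask covering ssbo and shared memory prints
+ * something like "mem_flags::mem_device | mem_flags::mem_threadgroup" (order
+ * follows the mode bits).
+ */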
+
+static uint32_t
+get_input_num_components(struct nir_to_msl_ctx *ctx, uint32_t location)
+{
+ return ctx->inputs_info[location].num_components;
+}
+
+static uint32_t
+get_output_num_components(struct nir_to_msl_ctx *ctx, uint32_t location)
+{
+ return ctx->outputs_info[location].num_components;
+}
+
+static void
+intrinsic_to_msl(struct nir_to_msl_ctx *ctx, nir_intrinsic_instr *instr)
+{
+ /* These instructions are only used to understand interpolation modes, they
+ * don't generate any code. */
+ if (instr->intrinsic == nir_intrinsic_load_barycentric_pixel ||
+ instr->intrinsic == nir_intrinsic_load_barycentric_centroid ||
+ instr->intrinsic == nir_intrinsic_load_barycentric_sample)
+ return;
+
+ const nir_intrinsic_info *info = &nir_intrinsic_infos[instr->intrinsic];
+   if (intrinsic_needs_dest_type(instr)) {
+ P_IND(ctx, "t%d = ", instr->def.index);
+ }
+ switch (instr->intrinsic) {
+ case nir_intrinsic_decl_reg: {
+ const char *reg_type = msl_uint_type(nir_intrinsic_bit_size(instr),
+ nir_intrinsic_num_components(instr));
+ P_IND(ctx, "%s r%d = %s(0);\n", reg_type, instr->def.index, reg_type);
+ } break;
+ case nir_intrinsic_load_reg:
+ // register loads get inlined into the uses
+ break;
+ case nir_intrinsic_store_reg:
+ P_IND(ctx, "r%d", instr->src[1].ssa->index);
+ writemask_to_msl(ctx, nir_intrinsic_write_mask(instr),
+ instr->num_components);
+ /* Registers don't store the component count, so get it from the value we
+ * are assigning */
+ if (instr->src[0].ssa->bit_size == 1u) {
+ P(ctx, " = bool%s((", component_str(instr->num_components));
+ } else if (nir_src_is_const(instr->src[0])) {
+ /* Const vector types already build the type */
+ if (instr->src[0].ssa->num_components > 1) {
+ P(ctx, " = as_type<%s>((",
+ msl_uint_type(instr->src[0].ssa->bit_size,
+ instr->src[0].ssa->num_components));
+ } else {
+ P(ctx, " = as_type<%s>(%s(",
+ msl_uint_type(instr->src[0].ssa->bit_size,
+ instr->src[0].ssa->num_components),
+ msl_type_for_src(ctx->types, &instr->src[0]));
+ }
+ } else {
+ P(ctx, " = as_type<%s>((",
+ msl_uint_type(instr->src[0].ssa->bit_size,
+ instr->src[0].ssa->num_components));
+ }
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, "));\n");
+ break;
+ case nir_intrinsic_load_subgroup_size:
+ P(ctx, "gl_SubGroupSize;\n");
+ break;
+ case nir_intrinsic_load_subgroup_invocation:
+ P(ctx, "gl_SubGroupInvocation;\n");
+ break;
+ case nir_intrinsic_load_num_subgroups:
+ P(ctx, "gl_NumSubGroups;\n");
+ break;
+ case nir_intrinsic_load_subgroup_id:
+ P(ctx, "gl_SubGroupID;\n");
+ break;
+ case nir_intrinsic_load_workgroup_id:
+ P(ctx, "gl_WorkGroupID;\n");
+ break;
+ case nir_intrinsic_load_local_invocation_id:
+ P(ctx, "gl_LocalInvocationID;\n");
+ break;
+ case nir_intrinsic_load_global_invocation_id:
+ P(ctx, "gl_GlobalInvocationID;\n");
+ break;
+ case nir_intrinsic_load_num_workgroups:
+ P(ctx, "gl_NumWorkGroups;\n");
+ break;
+ case nir_intrinsic_load_local_invocation_index:
+ P(ctx, "gl_LocalInvocationIndex;\n");
+ break;
+ case nir_intrinsic_load_frag_coord:
+ P(ctx, "gl_FragCoord;\n");
+ break;
+ case nir_intrinsic_load_point_coord:
+ P(ctx, "gl_PointCoord;\n");
+ break;
+ case nir_intrinsic_load_vertex_id:
+ P(ctx, "gl_VertexID;\n");
+ break;
+ case nir_intrinsic_load_instance_id:
+ P(ctx, "gl_InstanceID;\n");
+ break;
+ case nir_intrinsic_load_base_instance:
+ P(ctx, "gl_BaseInstance;\n");
+ break;
+ case nir_intrinsic_load_helper_invocation:
+ P(ctx, "gl_HelperInvocation;\n");
+ break;
+ case nir_intrinsic_is_helper_invocation:
+ P(ctx, "simd_is_helper_thread();\n");
+ break;
+ case nir_intrinsic_ddx:
+ case nir_intrinsic_ddx_coarse:
+ case nir_intrinsic_ddx_fine:
+ P(ctx, "dfdx(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_ddy:
+ case nir_intrinsic_ddy_coarse:
+ case nir_intrinsic_ddy_fine:
+ P(ctx, "dfdy(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_load_front_face:
+ P(ctx, "gl_FrontFacing;\n");
+ break;
+ case nir_intrinsic_load_layer_id:
+ P(ctx, "gl_Layer;\n");
+ break;
+ case nir_intrinsic_load_sample_id:
+ P(ctx, "gl_SampleID;\n");
+ break;
+ case nir_intrinsic_load_sample_mask_in:
+ P(ctx, "gl_SampleMask;\n");
+ break;
+ case nir_intrinsic_load_amplification_id_kk:
+ P(ctx, "mtl_AmplificationID;\n");
+ break;
+ case nir_intrinsic_load_interpolated_input: {
+ unsigned idx = nir_src_as_uint(instr->src[1u]);
+ nir_io_semantics io = nir_intrinsic_io_semantics(instr);
+ uint32_t component = nir_intrinsic_component(instr);
+ uint32_t location = io.location + idx;
+ P(ctx, "in.%s", msl_input_name(ctx, location));
+ if (instr->num_components < get_input_num_components(ctx, location)) {
+ P(ctx, ".");
+ for (unsigned i = 0; i < instr->num_components; i++)
+ P(ctx, "%c", "xyzw"[component + i]);
+ }
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_load_input: {
+ unsigned idx = nir_src_as_uint(instr->src[0u]);
+ nir_io_semantics io = nir_intrinsic_io_semantics(instr);
+ uint32_t component = nir_intrinsic_component(instr);
+ uint32_t location = io.location + idx;
+ P(ctx, "in.%s", msl_input_name(ctx, location));
+ if (instr->num_components < get_input_num_components(ctx, location)) {
+ P(ctx, ".");
+ for (unsigned i = 0; i < instr->num_components; i++)
+ P(ctx, "%c", "xyzw"[component + i]);
+ }
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_load_output: {
+ unsigned idx = nir_src_as_uint(instr->src[0]);
+ nir_io_semantics io = nir_intrinsic_io_semantics(instr);
+ P(ctx, "out.%s;\n", msl_output_name(ctx, io.location + idx));
+ break;
+ }
+ case nir_intrinsic_store_output: {
+ uint32_t idx = nir_src_as_uint(instr->src[1]);
+ nir_io_semantics io = nir_intrinsic_io_semantics(instr);
+ uint32_t location = io.location + idx;
+ uint32_t write_mask = nir_intrinsic_write_mask(instr);
+ uint32_t component = nir_intrinsic_component(instr);
+ uint32_t dst_num_components = get_output_num_components(ctx, location);
+ uint32_t num_components = instr->num_components;
+
+ P_IND(ctx, "out.%s", msl_output_name(ctx, location));
+ if (dst_num_components > 1u) {
+ P(ctx, ".");
+ for (unsigned i = 0; i < num_components; i++)
+ if ((write_mask >> i) & 1)
+ P(ctx, "%c", "xyzw"[component + i]);
+ }
+ P(ctx, " = ");
+ src_to_msl(ctx, &instr->src[0]);
+ if (num_components > 1u) {
+ P(ctx, ".");
+ for (unsigned i = 0; i < num_components; i++)
+ if ((write_mask >> i) & 1)
+ P(ctx, "%c", "xyzw"[i]);
+ }
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_load_push_constant: {
+ const char *ty = msl_type_for_def(ctx->types, &instr->def);
+ assert(nir_intrinsic_base(instr) == 0);
+ P(ctx, "*((constant %s*)&buf.push_consts[", ty);
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, "]);\n");
+ break;
+ }
+ case nir_intrinsic_load_buffer_ptr_kk:
+ P(ctx, "(ulong)&buf%d.contents[0];\n", nir_intrinsic_binding(instr));
+ break;
+ case nir_intrinsic_load_global: {
+ src_to_packed_load(ctx, &instr->src[0], "device",
+ msl_type_for_def(ctx->types, &instr->def),
+ instr->def.num_components);
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_load_global_constant: {
+ src_to_packed_load(ctx, &instr->src[0], "constant",
+ msl_type_for_def(ctx->types, &instr->def),
+ instr->def.num_components);
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_load_global_constant_bounded: {
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, " < ");
+ src_to_msl(ctx, &instr->src[2]);
+ P(ctx, " ? ");
+ src_to_packed_load_offset(ctx, &instr->src[0], &instr->src[1], "constant",
+ msl_type_for_def(ctx->types, &instr->def),
+ instr->def.num_components);
+ P(ctx, " : 0;\n");
+ break;
+ }
+ case nir_intrinsic_load_global_constant_offset: {
+ src_to_packed_load_offset(ctx, &instr->src[0], &instr->src[1], "device",
+ msl_type_for_def(ctx->types, &instr->def),
+ instr->def.num_components);
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_global_atomic:
+ atomic_to_msl(ctx, instr, "device", false);
+ break;
+ case nir_intrinsic_global_atomic_swap:
+ atomic_swap_to_msl(ctx, instr, "device", false);
+ break;
+ case nir_intrinsic_shared_atomic:
+ atomic_to_msl(ctx, instr, "threadgroup", true);
+ break;
+ case nir_intrinsic_shared_atomic_swap:
+ atomic_swap_to_msl(ctx, instr, "threadgroup", true);
+ break;
+ case nir_intrinsic_store_global: {
+ const char *type = msl_type_for_src(ctx->types, &instr->src[0]);
+ src_to_packed_store(ctx, &instr->src[1], "device", type,
+ instr->src[0].ssa->num_components);
+ writemask_to_msl(ctx, nir_intrinsic_write_mask(instr),
+ instr->num_components);
+      P(ctx, " = ");
+ src_to_packed(ctx, &instr->src[0], type,
+ instr->src[0].ssa->num_components);
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_barrier: {
+ mesa_scope execution_scope = nir_intrinsic_execution_scope(instr);
+ nir_variable_mode memory_modes = nir_intrinsic_memory_modes(instr);
+ if (execution_scope == SCOPE_SUBGROUP) {
+ P_IND(ctx, "simdgroup_barrier(");
+ memory_modes_to_msl(ctx, memory_modes);
+ } else if (execution_scope == SCOPE_WORKGROUP) {
+ P_IND(ctx, "threadgroup_barrier(");
+ memory_modes_to_msl(ctx, memory_modes);
+ } else if (execution_scope == SCOPE_NONE) {
+ /* Empty barrier */
+ if (memory_modes == 0u)
+ break;
+
+ P_IND(ctx, "atomic_thread_fence(");
+ memory_modes_to_msl(ctx, memory_modes);
+ P(ctx, ", memory_order_seq_cst, ");
+ switch (nir_intrinsic_memory_scope(instr)) {
+ case SCOPE_SUBGROUP:
+ P(ctx, "thread_scope::thread_scope_simdgroup");
+ break;
+ case SCOPE_WORKGROUP:
+         /* TODO_KOSMICKRISP This case should not be needed, but without it we
+          * fail the following CTS tests:
+          * dEQP-VK.memory_model.*.ext.u32.*coherent.*.atomicwrite.workgroup.payload_*local.*.guard_local.*.comp
+          * The last two wildcards being either 'workgroup' or 'physbuffer'.
+          */
+ if (memory_modes &
+ (nir_var_mem_global | nir_var_mem_ssbo | nir_var_image)) {
+ P(ctx, "thread_scope::thread_scope_device");
+ } else {
+ P(ctx, "thread_scope::thread_scope_threadgroup");
+ }
+
+ break;
+ case SCOPE_QUEUE_FAMILY:
+ case SCOPE_DEVICE:
+ P(ctx, "thread_scope::thread_scope_device");
+ break;
+ default:
+ P(ctx, "bad_scope");
+ assert(!"bad scope");
+ break;
+ }
+ } else {
+ UNREACHABLE("bad_execution scope");
+ }
+ P(ctx, ");\n");
+ break;
+ }
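+   /* For illustration: a SCOPE_WORKGROUP execution barrier over shared memory
+    * emits "threadgroup_barrier(mem_flags::mem_threadgroup);", while a pure
+    * memory barrier (SCOPE_NONE) over device memory emits
+    * "atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst,
+    * thread_scope::thread_scope_device);".
+    */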
+ case nir_intrinsic_demote:
+ P_IND(ctx, "discard_fragment();\n");
+ break;
+ case nir_intrinsic_demote_if:
+      P_IND(ctx, "if (");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ")\n");
+ ctx->indentlevel++;
+ P_IND(ctx, "discard_fragment();\n");
+ ctx->indentlevel--;
+ break;
+ case nir_intrinsic_terminate:
+ P_IND(ctx, "discard_fragment();\n");
+ P_IND(ctx, "return {};\n");
+ break;
+ case nir_intrinsic_terminate_if:
+      P_IND(ctx, "if (");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ") {\n");
+ ctx->indentlevel++;
+ P_IND(ctx, "discard_fragment();\n");
+ P_IND(ctx, "return {};\n");
+ ctx->indentlevel--;
+ P_IND(ctx, "}\n");
+ break;
+ case nir_intrinsic_load_shared:
+ assert(nir_intrinsic_base(instr) == 0);
+ P(ctx, "*(threadgroup %s*)&shared_data[",
+ msl_type_for_def(ctx->types, &instr->def));
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, "];\n");
+ break;
+ case nir_intrinsic_store_shared:
+ assert(nir_intrinsic_base(instr) == 0);
+ P_IND(ctx, "(*(threadgroup %s*)&shared_data[",
+ msl_type_for_src(ctx->types, &instr->src[0]));
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, "])");
+ writemask_to_msl(ctx, nir_intrinsic_write_mask(instr),
+ instr->num_components);
+ P(ctx, " = ");
+ src_to_msl(ctx, &instr->src[0]);
+ if (instr->src[0].ssa->num_components > 1)
+ writemask_to_msl(ctx, nir_intrinsic_write_mask(instr),
+ instr->num_components);
+ P(ctx, ";\n");
+ break;
+ case nir_intrinsic_load_scratch:
+ P(ctx, "*(thread %s*)&scratch[",
+ msl_type_for_def(ctx->types, &instr->def));
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, "];\n");
+ break;
+ case nir_intrinsic_store_scratch:
+ P_IND(ctx, "(*(thread %s*)&scratch[",
+ msl_type_for_src(ctx->types, &instr->src[0]));
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, "])");
+ writemask_to_msl(ctx, nir_intrinsic_write_mask(instr),
+ instr->num_components);
+ P(ctx, " = ");
+ src_to_msl(ctx, &instr->src[0]);
+ if (instr->src[0].ssa->num_components > 1)
+ writemask_to_msl(ctx, nir_intrinsic_write_mask(instr),
+ instr->num_components);
+ P(ctx, ";\n");
+ break;
+ case nir_intrinsic_load_texture_handle_kk: {
+ const char *access = "";
+ switch (nir_intrinsic_flags(instr)) {
+ case MSL_ACCESS_READ:
+ access = ", access::read";
+ break;
+ case MSL_ACCESS_WRITE:
+ access = ", access::write";
+ break;
+ case MSL_ACCESS_READ_WRITE:
+ access = ", access::read_write";
+ break;
+ }
+ P_IND(ctx, "texture%s%s<%s%s> t%d = *(constant texture%s%s<%s%s>*)",
+ texture_dim(nir_intrinsic_image_dim(instr)),
+ nir_intrinsic_image_array(instr) ? "_array" : "",
+ tex_type_name(nir_intrinsic_dest_type(instr)), access,
+ instr->def.index, texture_dim(nir_intrinsic_image_dim(instr)),
+ nir_intrinsic_image_array(instr) ? "_array" : "",
+ tex_type_name(nir_intrinsic_dest_type(instr)), access);
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_intrinsic_load_depth_texture_kk:
+ P_IND(ctx, "depth%s%s<float> t%d = *(constant depth%s%s<float>*)",
+ texture_dim(nir_intrinsic_image_dim(instr)),
+ nir_intrinsic_image_array(instr) ? "_array" : "", instr->def.index,
+ texture_dim(nir_intrinsic_image_dim(instr)),
+ nir_intrinsic_image_array(instr) ? "_array" : "");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ";\n");
+ break;
+ case nir_intrinsic_load_sampler_handle_kk:
+ P(ctx, "sampler_table.handles[");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, "];\n");
+ break;
+ case nir_intrinsic_load_constant_agx: {
+ const char *type = msl_type_for_def(ctx->types, &instr->def);
+ const char *no_component_type =
+ msl_pipe_format_to_msl_type(nir_intrinsic_format(instr));
+ if (instr->def.num_components == 1) {
+ P(ctx, "(*(((constant %s*)", type);
+ } else {
+ P(ctx, "%s(*(constant packed_%s*)(((constant %s*)", type, type,
+ no_component_type);
+ }
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ") + ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, "));\n");
+ break;
+ }
+ case nir_intrinsic_bindless_image_load:
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ".read(");
+ image_coord_swizzle(ctx, instr);
+ if (nir_intrinsic_image_dim(instr) != GLSL_SAMPLER_DIM_BUF) {
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[3]);
+ }
+      /* read() always returns a vec4 and we may try to assign that to a uint,
+       * which is illegal. */
+ P(ctx, ").");
+ for (uint32_t i = 0u; i < instr->def.num_components; ++i) {
+ P(ctx, "%c", "xyzw"[i]);
+ }
+ P(ctx, ";\n");
+ break;
+ case nir_intrinsic_bindless_image_store:
+ P_INDENT(ctx);
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ".write(");
+ src_to_msl(ctx, &instr->src[3]);
+ P(ctx, ", ");
+ image_coord_swizzle(ctx, instr);
+ if (nir_intrinsic_image_dim(instr) != GLSL_SAMPLER_DIM_BUF) {
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[4]);
+ }
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_bindless_image_atomic:
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ".%s(", atomic_op_to_msl(nir_intrinsic_atomic_op(instr)));
+ image_coord_swizzle(ctx, instr);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[3]);
+ P(ctx, ").x;\n");
+ break;
+ case nir_intrinsic_bindless_image_atomic_swap: {
+ const char *type = msl_type_for_def(ctx->types, &instr->def);
+ P_IND(ctx, "%s4 ta%d = ", type, instr->def.index);
+ src_to_msl(ctx, &instr->src[3]);
+ P(ctx, "; ");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ".%s(", atomic_op_to_msl(nir_intrinsic_atomic_op(instr)));
+ image_coord_swizzle(ctx, instr);
+ P(ctx, ", &ta%d, ", instr->def.index);
+ src_to_msl(ctx, &instr->src[4]);
+ P(ctx, "); %s t%d = ta%d.x;\n", type, instr->def.index, instr->def.index);
+ break;
+ }
+ case nir_intrinsic_ballot:
+ P(ctx, "(ulong)simd_ballot(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_elect:
+      /* If we don't add "&& (ulong)simd_ballot(true)" the following CTS tests
+ * fail:
+ * dEQP-VK.subgroups.ballot_other.graphics.subgroupballotfindlsb
+ * dEQP-VK.subgroups.ballot_other.compute.subgroupballotfindlsb
+ * Weird Metal bug:
+ * if (simd_is_first())
+ * temp = 3u;
+ * else
+ * temp = simd_ballot(true); <- This will return all active threads...
+ */
+ P(ctx, "simd_is_first() && (ulong)simd_ballot(true);\n");
+ break;
+ case nir_intrinsic_read_first_invocation:
+ P(ctx, "simd_broadcast_first(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_read_invocation:
+ P(ctx, "simd_broadcast(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ");");
+ break;
+ case nir_intrinsic_shuffle:
+ P(ctx, "simd_shuffle(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_shuffle_xor:
+ P(ctx, "simd_shuffle_xor(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_shuffle_up:
+ P(ctx, "simd_shuffle_up(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_shuffle_down:
+ P(ctx, "simd_shuffle_down(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ");\n");
+ break;
+
+ case nir_intrinsic_vote_all:
+ P(ctx, "simd_all(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_vote_any:
+ P(ctx, "simd_any(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_quad_broadcast:
+ P(ctx, "quad_broadcast(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", ");
+ src_to_msl(ctx, &instr->src[1]);
+ P(ctx, ");\n");
+ break;
+ case nir_intrinsic_quad_swap_horizontal:
+ P(ctx, "quad_shuffle_xor(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", 1);\n");
+ break;
+ case nir_intrinsic_quad_swap_vertical:
+ P(ctx, "quad_shuffle_xor(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", 2);\n");
+ break;
+ case nir_intrinsic_quad_swap_diagonal:
+ P(ctx, "quad_shuffle_xor(");
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ", 3);\n");
+ break;
+ case nir_intrinsic_reduce:
+ switch (nir_intrinsic_reduction_op(instr)) {
+ case nir_op_iadd:
+ case nir_op_fadd:
+ P(ctx, "simd_sum(");
+ break;
+ case nir_op_imul:
+ case nir_op_fmul:
+ P(ctx, "simd_product(");
+ break;
+ case nir_op_imin:
+ case nir_op_umin:
+ case nir_op_fmin:
+ P(ctx, "simd_min(");
+ break;
+ case nir_op_imax:
+ case nir_op_umax:
+ case nir_op_fmax:
+ P(ctx, "simd_max(");
+ break;
+ case nir_op_iand:
+ P(ctx, "simd_and(");
+ break;
+ case nir_op_ior:
+ P(ctx, "simd_or(");
+ break;
+ case nir_op_ixor:
+ P(ctx, "simd_xor(");
+ break;
+ default:
+ UNREACHABLE("Bad reduction op");
+ }
+
+ src_to_msl(ctx, &instr->src[0]);
+ P(ctx, ");\n");
+ break;
+ default:
+ P_IND(ctx, "Unknown intrinsic %s\n", info->name);
+ }
+}
+
+static nir_src *
+nir_tex_get_src(struct nir_tex_instr *tex, nir_tex_src_type type)
+{
+ int idx = nir_tex_instr_src_index(tex, type);
+ if (idx == -1)
+ return NULL;
+ return &tex->src[idx].src;
+}
+
+static void
+tex_coord_swizzle(struct nir_to_msl_ctx *ctx, nir_tex_instr *tex)
+{
+ texture_src_coord_swizzle(ctx, nir_tex_get_src(tex, nir_tex_src_coord),
+ tex->coord_components, false, tex->is_array);
+}
+
+static void
+tex_to_msl(struct nir_to_msl_ctx *ctx, nir_tex_instr *tex)
+{
+ nir_src *texhandle = nir_tex_get_src(tex, nir_tex_src_texture_handle);
+ nir_src *sampler = nir_tex_get_src(tex, nir_tex_src_sampler_handle);
+ // Projectors have to be lowered away to regular arithmetic
+ assert(!nir_tex_get_src(tex, nir_tex_src_projector));
+
+ P_IND(ctx, "t%d = ", tex->def.index);
+
+ switch (tex->op) {
+ case nir_texop_tex:
+ case nir_texop_txb:
+ case nir_texop_txl:
+ case nir_texop_txd: {
+ nir_src *bias = nir_tex_get_src(tex, nir_tex_src_bias);
+ nir_src *lod = nir_tex_get_src(tex, nir_tex_src_lod);
+ nir_src *ddx = nir_tex_get_src(tex, nir_tex_src_ddx);
+ nir_src *ddy = nir_tex_get_src(tex, nir_tex_src_ddy);
+ nir_src *min_lod_clamp = nir_tex_get_src(tex, nir_tex_src_min_lod);
+ nir_src *offset = nir_tex_get_src(tex, nir_tex_src_offset);
+ nir_src *comparator = nir_tex_get_src(tex, nir_tex_src_comparator);
+ src_to_msl(ctx, texhandle);
+ if (comparator) {
+ P(ctx, ".sample_compare(");
+ } else {
+ P(ctx, ".sample(");
+ }
+ src_to_msl(ctx, sampler);
+ P(ctx, ", ");
+ tex_coord_swizzle(ctx, tex);
+ if (comparator) {
+ P(ctx, ", ");
+ src_to_msl(ctx, comparator);
+ }
+ if (bias) {
+ P(ctx, ", bias(");
+ src_to_msl(ctx, bias);
+ P(ctx, ")");
+ }
+ if (lod) {
+ P(ctx, ", level(");
+ src_to_msl(ctx, lod);
+ P(ctx, ")");
+ }
+ if (ddx) {
+ P(ctx, ", gradient%s(", texture_dim(tex->sampler_dim));
+ src_to_msl(ctx, ddx);
+ P(ctx, ", ");
+ src_to_msl(ctx, ddy);
+ P(ctx, ")");
+ }
+ if (min_lod_clamp) {
+ P(ctx, ", min_lod_clamp(");
+ src_to_msl(ctx, min_lod_clamp);
+ P(ctx, ")");
+ }
+ if (offset) {
+ P(ctx, ", ");
+ src_to_msl(ctx, offset);
+ }
+ P(ctx, ");\n");
+ break;
+ }
+ case nir_texop_txf: {
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".read(");
+ tex_coord_swizzle(ctx, tex);
+ nir_src *lod = nir_tex_get_src(tex, nir_tex_src_lod);
+ if (lod) {
+ P(ctx, ", ");
+ src_to_msl(ctx, lod);
+ }
+ P(ctx, ");\n");
+ break;
+ }
+ case nir_texop_txf_ms:
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".read(");
+ tex_coord_swizzle(ctx, tex);
+ P(ctx, ", ");
+ src_to_msl(ctx, nir_tex_get_src(tex, nir_tex_src_ms_index));
+ P(ctx, ");\n");
+ break;
+ case nir_texop_txs: {
+ nir_src *lod = nir_tex_get_src(tex, nir_tex_src_lod);
+ if (tex->def.num_components > 1u) {
+ P(ctx, "%s%d(", tex_type_name(tex->dest_type),
+ tex->def.num_components);
+ } else {
+ P(ctx, "%s(", tex_type_name(tex->dest_type));
+ }
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".get_width(")
+ if (lod && tex->sampler_dim != GLSL_SAMPLER_DIM_MS &&
+ tex->sampler_dim != GLSL_SAMPLER_DIM_BUF)
+ src_to_msl(ctx, lod);
+ P(ctx, ")");
+ if (tex->sampler_dim != GLSL_SAMPLER_DIM_1D &&
+ tex->sampler_dim != GLSL_SAMPLER_DIM_BUF) {
+ P(ctx, ", ");
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".get_height(");
+ if (lod && tex->sampler_dim != GLSL_SAMPLER_DIM_MS &&
+ tex->sampler_dim != GLSL_SAMPLER_DIM_BUF)
+ src_to_msl(ctx, lod);
+ P(ctx, ")");
+ }
+ if (tex->sampler_dim == GLSL_SAMPLER_DIM_3D) {
+ P(ctx, ", ");
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".get_depth(");
+ if (lod)
+ src_to_msl(ctx, lod);
+ P(ctx, ")");
+ }
+ if (tex->is_array) {
+ P(ctx, ", ");
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".get_array_size()");
+ }
+ P(ctx, ");\n")
+ break;
+ }
+ case nir_texop_query_levels:
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".get_num_mip_levels();\n");
+ break;
+ case nir_texop_tg4: {
+ nir_src *offset = nir_tex_get_src(tex, nir_tex_src_offset);
+ nir_src *comparator = nir_tex_get_src(tex, nir_tex_src_comparator);
+ src_to_msl(ctx, texhandle);
+ if (comparator) {
+ P(ctx, ".gather_compare(");
+ } else {
+ P(ctx, ".gather(");
+ }
+ src_to_msl(ctx, sampler);
+ P(ctx, ", ");
+ tex_coord_swizzle(ctx, tex);
+ if (comparator) {
+ P(ctx, ", ");
+ src_to_msl(ctx, comparator);
+ }
+ P(ctx, ", ");
+ if (offset)
+ src_to_msl(ctx, offset);
+ else
+ P(ctx, "int2(0)");
+
+ /* Non-depth textures require component */
+ if (!comparator) {
+ P(ctx, ", component::%c", "xyzw"[tex->component]);
+ }
+
+ P(ctx, ");\n");
+ break;
+ }
+
+ case nir_texop_texture_samples:
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".get_num_samples();\n");
+ break;
+ case nir_texop_lod: {
+ nir_src *coord = nir_tex_get_src(tex, nir_tex_src_coord);
+ nir_src *bias = nir_tex_get_src(tex, nir_tex_src_bias);
+ nir_src *min = nir_tex_get_src(tex, nir_tex_src_min_lod);
+ nir_src *max = nir_tex_get_src(tex, nir_tex_src_max_lod_kk);
+ P(ctx, "float2(round(clamp(")
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".calculate_unclamped_lod(");
+ src_to_msl(ctx, sampler);
+ P(ctx, ", ");
+ src_to_msl(ctx, coord);
+ P(ctx, ") + ");
+ src_to_msl(ctx, bias);
+ P(ctx, ", ");
+ src_to_msl(ctx, min);
+ P(ctx, ", ");
+ src_to_msl(ctx, max);
+ P(ctx, ")), ");
+ src_to_msl(ctx, texhandle);
+ P(ctx, ".calculate_unclamped_lod(");
+ src_to_msl(ctx, sampler);
+ P(ctx, ", ");
+ src_to_msl(ctx, coord);
+ P(ctx, ")");
+ P(ctx, ");\n");
+ break;
+ }
+ default:
+ assert(!"Unsupported texture op");
+ }
+}
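+
+/* For illustration (made-up SSA indices): a nir_texop_txb on a 2D texture
+ * emits roughly "t9 = t5.sample(t6, t7, bias(t8));", where t5 is the texture
+ * handle, t6 the sampler and t7 the coordinate; a shadow sampler uses
+ * sample_compare() with the comparator right after the coordinate.
+ */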
+
+static void
+jump_instr_to_msl(struct nir_to_msl_ctx *ctx, nir_jump_instr *jump)
+{
+ switch (jump->type) {
+ case nir_jump_halt:
+ P_IND(ctx, "TODO: halt\n");
+ assert(!"Unimplemented");
+ break;
+ case nir_jump_break:
+ P_IND(ctx, "break;\n");
+ break;
+ case nir_jump_continue:
+ P_IND(ctx, "continue;\n");
+ break;
+ case nir_jump_return:
+ assert(!"functions should have been inlined by now");
+ break;
+ case nir_jump_goto:
+ case nir_jump_goto_if:
+ assert(!"Unstructured control flow not supported");
+ break;
+ }
+}
+
+static void
+instr_to_msl(struct nir_to_msl_ctx *ctx, nir_instr *instr)
+{
+ switch (instr->type) {
+ case nir_instr_type_alu: {
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+ P_IND(ctx, "t%d = ", alu->def.index);
+ alu_to_msl(ctx, alu);
+ P(ctx, ";\n");
+ break;
+ }
+ case nir_instr_type_deref:
+ assert(!"We should have lowered derefs by now");
+ break;
+ case nir_instr_type_call:
+ assert(!"We should have inlined all functions by now");
+ break;
+ case nir_instr_type_tex:
+ tex_to_msl(ctx, nir_instr_as_tex(instr));
+ break;
+ case nir_instr_type_intrinsic:
+ intrinsic_to_msl(ctx, nir_instr_as_intrinsic(instr));
+ break;
+ case nir_instr_type_load_const:
+ // consts get inlined into their uses
+ break;
+ case nir_instr_type_jump:
+ jump_instr_to_msl(ctx, nir_instr_as_jump(instr));
+ break;
+ case nir_instr_type_undef:
+ // undefs get inlined into their uses (and we shouldn't see them hopefully)
+ break;
+ case nir_instr_type_phi:
+ case nir_instr_type_parallel_copy:
+ assert(!"NIR should be taken out of SSA");
+ break;
+ }
+}
+
+static void
+cf_node_to_metal(struct nir_to_msl_ctx *ctx, nir_cf_node *node)
+{
+ switch (node->type) {
+ case nir_cf_node_block: {
+ nir_block *block = nir_cf_node_as_block(node);
+ nir_foreach_instr(instr, block) {
+ instr_to_msl(ctx, instr);
+ }
+ break;
+ }
+ case nir_cf_node_if: {
+ nir_if *ifnode = nir_cf_node_as_if(node);
+ P_IND(ctx, "if (");
+ src_to_msl(ctx, &ifnode->condition);
+ P(ctx, ") {\n");
+ ctx->indentlevel++;
+ foreach_list_typed(nir_cf_node, node, node, &ifnode->then_list) {
+ cf_node_to_metal(ctx, node);
+ }
+ ctx->indentlevel--;
+ if (!nir_cf_list_is_empty_block(&ifnode->else_list)) {
+ P_IND(ctx, "} else {\n");
+ ctx->indentlevel++;
+ foreach_list_typed(nir_cf_node, node, node, &ifnode->else_list) {
+ cf_node_to_metal(ctx, node);
+ }
+ ctx->indentlevel--;
+ }
+ P_IND(ctx, "}\n");
+ break;
+ }
+ case nir_cf_node_loop: {
+ nir_loop *loop = nir_cf_node_as_loop(node);
+ assert(!nir_loop_has_continue_construct(loop));
+      /* We must avoid a literal infinite loop since the MSL compiler crashes
+       * on something like this (simplified version):
+       * // clang-format off
+       * while (true) {
+       *   if (some_conditional) {
+       *     break_loop = true;
+       *   } else {
+       *     break_loop = false;
+       *   }
+       *   if (break_loop) {
+       *     break;
+       *   }
+       * }
+       * // clang-format on
+       * The issue, I believe, is that some_conditional never changes value
+       * across iterations (something like fetching the same value from a
+       * buffer), which the MSL compiler dislikes to the point of crashing.
+       * With this for loop we trick the MSL compiler into believing the loop
+       * is not infinite (wink wink).
+       */
+ P_IND(ctx,
+ "for (uint64_t no_crash = 0u; no_crash < %" PRIu64
+ "; ++no_crash) {\n",
+ UINT64_MAX);
+ ctx->indentlevel++;
+ foreach_list_typed(nir_cf_node, node, node, &loop->body) {
+ cf_node_to_metal(ctx, node);
+ }
+ ctx->indentlevel--;
+ P_IND(ctx, "}\n");
+ break;
+ }
+ case nir_cf_node_function:
+ assert(!"All functions are supposed to be inlined");
+ }
+}
+
+static void
+emit_output_return(struct nir_to_msl_ctx *ctx, nir_shader *shader)
+{
+ if (shader->info.stage == MESA_SHADER_VERTEX ||
+ shader->info.stage == MESA_SHADER_FRAGMENT)
+ P_IND(ctx, "return out;\n");
+}
+
+static void
+rename_main_entrypoint(struct nir_shader *nir)
+{
+   /* Rename the entrypoint (now that all other functions have been removed)
+    * to avoid MSL limitations. We don't really care what it's named as long
+    * as it's not "main".
+    */
+ const char *entrypoint_name = "main_entrypoint";
+ nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
+ struct nir_function *function = entrypoint->function;
+ ralloc_free((void *)function->name);
+ function->name = ralloc_strdup(function, entrypoint_name);
+}
+
+static bool
+kk_scalarize_filter(const nir_instr *instr, const void *data)
+{
+ if (instr->type != nir_instr_type_alu)
+ return false;
+ return true;
+}
+
+void
+msl_preprocess_nir(struct nir_shader *nir)
+{
+ /* First, inline away all the functions */
+ NIR_PASS(_, nir, nir_lower_variable_initializers, nir_var_function_temp);
+ NIR_PASS(_, nir, nir_lower_returns);
+ NIR_PASS(_, nir, nir_inline_functions);
+ NIR_PASS(_, nir, nir_opt_deref);
+ nir_remove_non_entrypoints(nir);
+
+ NIR_PASS(_, nir, nir_lower_global_vars_to_local);
+ NIR_PASS(_, nir, nir_split_var_copies);
+ NIR_PASS(_, nir, nir_split_struct_vars, nir_var_function_temp);
+ NIR_PASS(_, nir, nir_split_array_vars, nir_var_function_temp);
+ NIR_PASS(_, nir, nir_split_per_member_structs);
+ NIR_PASS(_, nir, nir_lower_continue_constructs);
+
+ NIR_PASS(_, nir, nir_lower_frexp);
+
+ NIR_PASS(_, nir, nir_lower_vars_to_ssa);
+ NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp, NULL);
+ if (nir->info.stage == MESA_SHADER_FRAGMENT) {
+ nir_input_attachment_options input_attachment_options = {
+ .use_fragcoord_sysval = true,
+ .use_layer_id_sysval = true,
+ };
+ NIR_PASS(_, nir, nir_lower_input_attachments, &input_attachment_options);
+ }
+ NIR_PASS(_, nir, nir_opt_combine_barriers, NULL, NULL);
+ NIR_PASS(_, nir, nir_lower_var_copies);
+ NIR_PASS(_, nir, nir_split_var_copies);
+
+ NIR_PASS(_, nir, nir_split_array_vars,
+ nir_var_function_temp | nir_var_shader_in | nir_var_shader_out);
+ NIR_PASS(_, nir, nir_lower_alu_to_scalar, kk_scalarize_filter, NULL);
+
+ NIR_PASS(_, nir, nir_lower_indirect_derefs,
+ nir_var_shader_in | nir_var_shader_out, UINT32_MAX);
+ NIR_PASS(_, nir, nir_lower_vars_to_scratch, nir_var_function_temp, 0,
+ glsl_get_natural_size_align_bytes,
+ glsl_get_natural_size_align_bytes);
+
+ NIR_PASS(_, nir, nir_lower_system_values);
+
+ nir_lower_compute_system_values_options csv_options = {
+ .has_base_global_invocation_id = 0,
+ .has_base_workgroup_id = true,
+ };
+ NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options);
+
+ msl_nir_lower_subgroups(nir);
+}
+
+bool
+msl_optimize_nir(struct nir_shader *nir)
+{
+ bool progress;
+ NIR_PASS(_, nir, nir_lower_int64);
+ do {
+ progress = false;
+
+ NIR_PASS(progress, nir, nir_split_var_copies);
+ NIR_PASS(progress, nir, nir_split_struct_vars, nir_var_function_temp);
+ NIR_PASS(progress, nir, nir_lower_var_copies);
+ NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
+ NIR_PASS(progress, nir, nir_opt_undef);
+ NIR_PASS(progress, nir, nir_opt_dce);
+ NIR_PASS(progress, nir, nir_opt_cse);
+ NIR_PASS(progress, nir, nir_opt_dead_cf);
+ NIR_PASS(progress, nir, nir_copy_prop);
+ NIR_PASS(progress, nir, nir_opt_deref);
+ NIR_PASS(progress, nir, nir_opt_constant_folding);
+ NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
+ NIR_PASS(progress, nir, nir_opt_dead_write_vars);
+ NIR_PASS(progress, nir, nir_opt_combine_stores, nir_var_all);
+ NIR_PASS(progress, nir, nir_remove_dead_variables, nir_var_function_temp,
+ NULL);
+ NIR_PASS(progress, nir, nir_opt_algebraic);
+ NIR_PASS(progress, nir, nir_opt_if, 0);
+ NIR_PASS(progress, nir, nir_opt_remove_phis);
+ NIR_PASS(progress, nir, nir_opt_loop);
+ NIR_PASS(progress, nir, nir_lower_pack);
+ NIR_PASS(progress, nir, nir_lower_alu_to_scalar, kk_scalarize_filter,
+ NULL);
+ } while (progress);
+ NIR_PASS(_, nir, nir_lower_load_const_to_scalar);
+ NIR_PASS(_, nir, msl_nir_lower_algebraic_late);
+ NIR_PASS(_, nir, nir_convert_from_ssa, true, false);
+ nir_trivialize_registers(nir);
+ NIR_PASS(_, nir, nir_copy_prop);
+
+ return progress;
+}
+
+static void
+msl_gather_info(struct nir_to_msl_ctx *ctx)
+{
+ nir_function_impl *impl = nir_shader_get_entrypoint(ctx->shader);
+ ctx->types = msl_infer_types(ctx->shader);
+
+   /* TODO_KOSMICKRISP
+    * Reindex blocks and SSA defs. This allows optimizations we don't do at
+    * the moment. */
+ nir_index_blocks(impl);
+ nir_index_ssa_defs(impl);
+
+ if (ctx->shader->info.stage == MESA_SHADER_VERTEX ||
+ ctx->shader->info.stage == MESA_SHADER_FRAGMENT) {
+ msl_gather_io_info(ctx, ctx->inputs_info, ctx->outputs_info);
+ }
+}
+
+static void
+predeclare_ssa_values(struct nir_to_msl_ctx *ctx, nir_function_impl *impl)
+{
+ nir_foreach_block(block, impl) {
+ nir_foreach_instr(instr, block) {
+ nir_def *def;
+ switch (instr->type) {
+ case nir_instr_type_alu: {
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+ def = &alu->def;
+ break;
+ }
+ case nir_instr_type_intrinsic: {
+ nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+         if (!intrinsic_needs_dest_type(intr))
+ continue;
+ def = &intr->def;
+ break;
+ }
+ case nir_instr_type_tex: {
+ nir_tex_instr *tex = nir_instr_as_tex(instr);
+ def = &tex->def;
+ break;
+ }
+ default:
+ continue;
+ }
+ const char *type = msl_type_for_def(ctx->types, def);
+ if (!type)
+ continue;
+ if (msl_def_is_sampler(ctx, def)) {
+ P_IND(ctx, "%s t%u;\n", type, def->index);
+ } else
+ P_IND(ctx, "%s t%u = %s(0);\n", type, def->index, type);
+ }
+ }
+}
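+
+/* For illustration: this emits a zero-initialized declaration such as
+ * "float t12 = float(0);" for every def that will later be assigned, except
+ * defs that msl_def_is_sampler() flags, which are declared without an
+ * initializer.
+ */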
+
+char *
+nir_to_msl(nir_shader *shader, void *mem_ctx)
+{
+   /* We need to rename the entrypoint here because hardcoded shaders used by
+    * vk_meta don't go through the preprocess step, since we are the ones
+    * creating them.
+    */
+ rename_main_entrypoint(shader);
+
+ struct nir_to_msl_ctx ctx = {
+ .shader = shader,
+ .text = _mesa_string_buffer_create(mem_ctx, 1024),
+ };
+ nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+ msl_gather_info(&ctx);
+
+ P(&ctx, "// Generated by Mesa compiler\n");
+ if (shader->info.stage == MESA_SHADER_COMPUTE)
+ P(&ctx, "#include <metal_compute>\n");
+ P(&ctx, "#include <metal_stdlib>\n");
+ P(&ctx, "using namespace metal;\n");
+
+ msl_emit_io_blocks(&ctx, shader);
+ if (shader->info.stage == MESA_SHADER_FRAGMENT &&
+ shader->info.fs.early_fragment_tests)
+ P(&ctx, "[[early_fragment_tests]]\n");
+ P(&ctx, "%s %s %s(\n", get_stage_string(shader->info.stage),
+ output_type(shader), get_entrypoint_name(shader));
+ ctx.indentlevel++;
+ emit_sysvals(&ctx, shader);
+ emit_inputs(&ctx, shader);
+ ctx.indentlevel--;
+ P(&ctx, ")\n");
+ P(&ctx, "{\n");
+ ctx.indentlevel++;
+ msl_emit_output_var(&ctx, shader);
+ emit_local_vars(&ctx, shader);
+ predeclare_ssa_values(&ctx, impl);
+ foreach_list_typed(nir_cf_node, node, node, &impl->body) {
+ cf_node_to_metal(&ctx, node);
+ }
+ emit_output_return(&ctx, shader);
+ ctx.indentlevel--;
+ P(&ctx, "}\n");
+ char *ret = ctx.text->buf;
+ ralloc_steal(mem_ctx, ctx.text->buf);
+ ralloc_free(ctx.text);
+ return ret;
+}