About Social Code
aboutsummaryrefslogtreecommitdiff
path: root/src/compiler/nir/nir_opt_mqsad.c
blob: f1f24d5358f0f0cfd1ee1ddaf692502bca793164 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*
 * Copyright 2023 Valve Corporation
 * SPDX-License-Identifier: MIT
 */
#include "nir.h"
#include "nir_builder.h"
#include "nir_worklist.h"

/*
 * This pass recognizes certain patterns of nir_op_shfr and nir_op_msad_4x8 instructions
 * and replaces them with a single nir_op_mqsad_4x8 instruction.
 */

/*
 * State of the MQSAD candidate that is currently being assembled from up to
 * four scalar nir_op_msad_4x8 instructions of one block.
 *
 * Slot 0 is filled by an MSAD whose source 1 is used as-is; slots 1..3 are
 * filled by MSADs whose source 1 is a nir_op_shfr with a constant shift
 * amount of 8, 16 or 24 (slot = amount / 8).
 */
struct mqsad {
   /* Chased source 0 of the MSADs; has to be equal for every slot. */
   nir_scalar ref;
   /* src[0]: source 1 of the MSAD, or source 1 of the shfr feeding it.
    * src[1]: source 0 of the shfr; only written when a slot 1..3 is added. */
   nir_scalar src[2];

   /* Chased source 2 (the accumulator operand) of the MSAD in each slot. */
   nir_scalar accum[4];
   /* The MSAD instruction recorded for each slot. */
   nir_alu_instr *msad[4];
   /* instr.index of the MSAD that was added while the mask was still empty. */
   unsigned first_msad_index;
   /* Bit i is set once slot i has been filled; 0xf means all four are present. */
   uint8_t mask;
};

/*
 * Check whether an MSAD described by (ref, src0, src1, idx) can join the
 * candidate in *mqsad.
 *
 * The reference and first source have to match the recorded ones. The second
 * source is only compared when the candidate already holds a shifted slot
 * (1..3) and the new MSAD is itself a shifted one (idx != 0). In addition, the
 * new MSAD must not be reachable, through its SSA sources, from any MSAD that
 * is already part of the candidate.
 *
 * Returns true when the MSAD may be added.
 */
static bool
is_mqsad_compatible(struct mqsad *mqsad, nir_scalar ref, nir_scalar src0, nir_scalar src1,
                    unsigned idx, nir_alu_instr *msad)
{
   if (!nir_scalar_equal(ref, mqsad->ref))
      return false;
   if (!nir_scalar_equal(src0, mqsad->src[0]))
      return false;

   /* mqsad->src[1] only holds a value once one of the slots 1..3 is filled. */
   const bool has_shifted_slot = (mqsad->mask & 0b1110) != 0;
   if (has_shifted_slot && idx != 0 && !nir_scalar_equal(src1, mqsad->src[1]))
      return false;

   /* Walk the sources of this MSAD and make sure that none of the MSADs
    * which were already collected shows up among them.
    */
   bool independent = true;
   nir_instr_worklist worklist;
   nir_instr_worklist_init(&worklist);
   nir_instr_worklist_add_ssa_srcs(&worklist, &msad->instr);
   nir_foreach_instr_in_worklist(dep, &worklist) {
      /* Instructions from another block, or with an index below the first
       * collected MSAD, are not followed any further.
       */
      if (dep->block != msad->instr.block)
         continue;
      if (dep->index < mqsad->first_msad_index)
         continue;

      u_foreach_bit(slot, mqsad->mask) {
         if (dep == &mqsad->msad[slot]->instr) {
            independent = false;
            break;
         }
      }
      if (!independent)
         break;

      nir_instr_worklist_add_ssa_srcs(&worklist, dep);
   }
   nir_instr_worklist_fini(&worklist);

   return independent;
}

/*
 * Try to add a nir_op_msad_4x8 instruction to the in-progress MQSAD candidate.
 *
 * Non-scalar MSADs are ignored. Otherwise the MSAD is classified into one of
 * four slots: slot 0 when its source 1 is used directly, or slot amount / 8
 * when source 1 is a nir_op_shfr with a constant shift amount of 8, 16 or 24.
 * If the candidate is non-empty and the MSAD is not compatible with it, the
 * candidate is reset and this MSAD starts a new one.
 *
 * \param msad   the MSAD instruction to inspect
 * \param mqsad  the candidate state, updated in place
 */
static void
parse_msad(nir_alu_instr *msad, struct mqsad *mqsad)
{
   if (msad->def.num_components != 1)
      return;

   nir_scalar msad_s = nir_get_scalar(&msad->def, 0);
   nir_scalar ref = nir_scalar_chase_alu_src(msad_s, 0);
   nir_scalar accum = nir_scalar_chase_alu_src(msad_s, 2);

   unsigned idx = 0;
   nir_scalar src0 = nir_scalar_chase_alu_src(msad_s, 1);
   /* Only meaningful when idx != 0. It is zero-initialized because it is
    * passed by value to is_mqsad_compatible() even for idx == 0, and reading
    * an uninitialized automatic object is undefined behavior.
    */
   nir_scalar src1 = { 0 };
   if (nir_scalar_is_alu(src0) && nir_scalar_alu_op(src0) == nir_op_shfr) {
      nir_scalar amount_s = nir_scalar_chase_alu_src(src0, 2);
      /* A non-constant shift amount is treated like 0, i.e. "not a match". */
      uint32_t amount = nir_scalar_is_const(amount_s) ? nir_scalar_as_uint(amount_s) : 0;
      if (amount == 8 || amount == 16 || amount == 24) {
         idx = amount / 8;
         /* Read source 0 of the shfr before src0 is overwritten. */
         src1 = nir_scalar_chase_alu_src(src0, 0);
         src0 = nir_scalar_chase_alu_src(src0, 1);
      }
   }

   /* Start over if this MSAD cannot be part of the current candidate. */
   if (mqsad->mask && !is_mqsad_compatible(mqsad, ref, src0, src1, idx, msad))
      memset(mqsad, 0, sizeof(*mqsad));

   /* Add this instruction to the in-progress MQSAD. */
   mqsad->ref = ref;
   mqsad->src[0] = src0;
   if (idx)
      mqsad->src[1] = src1;

   mqsad->accum[idx] = accum;
   mqsad->msad[idx] = msad;
   /* Remember where the candidate starts; used to bound the dependency walk. */
   if (!mqsad->mask)
      mqsad->first_msad_index = msad->instr.index;
   mqsad->mask |= 1 << idx;
}

/*
 * Build one nir_op_mqsad_4x8 from a complete candidate and redirect the uses
 * of the four collected MSADs to the matching channel of its result.
 *
 * All four slots of *mqsad have to be filled, since msad[0..3] are
 * dereferenced unconditionally. The candidate is reset afterwards so that the
 * caller can keep collecting.
 *
 * \param b      builder used to create the new instructions
 * \param mqsad  the complete candidate; zeroed on return
 */
static void
create_msad(nir_builder *b, struct mqsad *mqsad)
{
   /* Build the operands in separate statements: the evaluation order of
    * function arguments is unspecified in C, and each of these helpers takes
    * the builder, so nesting them in one call would let the compiler pick the
    * order in which they run.
    */
   nir_def *ref = nir_mov_scalar(b, mqsad->ref);
   nir_def *src = nir_vec_scalars(b, mqsad->src, 2);
   nir_def *accum = nir_vec_scalars(b, mqsad->accum, 4);
   nir_def *mqsad_def = nir_mqsad_4x8(b, ref, src, accum);

   /* Channel i of the result replaces the MSAD recorded in slot i. */
   for (unsigned i = 0; i < 4; i++)
      nir_def_rewrite_uses(&mqsad->msad[i]->def, nir_channel(b, mqsad_def, i));

   memset(mqsad, 0, sizeof(*mqsad));
}

/*
 * Entry point of the pass: scan every block of every function implementation
 * for nir_op_msad_4x8 instructions and, whenever four compatible ones have
 * been collected, replace them with a single nir_op_mqsad_4x8.
 *
 * Returns true if at least one function implementation was changed.
 */
bool
nir_opt_mqsad(nir_shader *shader)
{
   bool any_progress = false;

   nir_foreach_function_impl(impl, shader) {
      bool impl_changed = false;

      /* The dependency check in is_mqsad_compatible() compares instr->index. */
      nir_metadata_require(impl, nir_metadata_instr_index);

      nir_foreach_block(block, impl) {
         /* A candidate never spans blocks, so start empty for each one. */
         struct mqsad state;
         memset(&state, 0, sizeof(state));

         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_alu)
               continue;

            nir_alu_instr *candidate = nir_instr_as_alu(instr);
            if (candidate->op != nir_op_msad_4x8)
               continue;

            parse_msad(candidate, &state);

            /* Keep collecting until all four slots are filled. */
            if (state.mask != 0xf)
               continue;

            /* The combined instruction goes right before the last MSAD. */
            nir_builder b = nir_builder_at(nir_before_instr(instr));
            create_msad(&b, &state);
            impl_changed = true;
         }
      }

      if (!impl_changed) {
         nir_progress(true, impl, nir_metadata_block_index);
      } else {
         any_progress = nir_progress(true, impl, nir_metadata_control_flow);
      }
   }

   return any_progress;
}