path: root/src/kosmickrisp/vulkan/kk_cmd_dispatch.c
/*
 * Copyright © 2025 LunarG, Inc
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: MIT
 */

#include "vulkan/vulkan_core.h"

#include "kk_buffer.h"
#include "kk_cmd_buffer.h"
#include "kk_descriptor_set_layout.h"
#include "kk_device.h"
#include "kk_encoder.h"
#include "kk_entrypoints.h"
#include "kk_shader.h"

#include "kosmickrisp/bridge/mtl_bridge.h"

#include "vk_common_entrypoints.h"

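/* Dispatch an internal compute pipeline. Allocates a transient root
 * descriptor table, copies the caller's push data into it, binds it along
 * with the given pipeline state, and launches a groupCountX x groupCountY x
 * groupCountZ grid of single-thread threadgroups.
 */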
void
kk_cmd_dispatch_pipeline(struct kk_cmd_buffer *cmd,
                         mtl_compute_encoder *encoder,
                         mtl_compute_pipeline_state *pipeline,
                         const void *push_data, size_t push_size,
                         uint32_t groupCountX, uint32_t groupCountY,
                         uint32_t groupCountZ)
{
   struct kk_root_descriptor_table *root = NULL;
   struct kk_bo *bo = kk_cmd_allocate_buffer(cmd, sizeof(*root), 8u);
   /* kk_cmd_allocate_buffer has already set the error, so we can just exit */
   if (!bo)
      return;

   root = bo->cpu;
   assert(push_size <= sizeof(root->push));
   memcpy(root->push, push_data, push_size);
   root->cs.base_group[0] = 1; /* TODO_KOSMICKRISP This is hard-coded because we
                                  know this is the size we create them with */
   root->cs.base_group[1] = 1;
   root->cs.base_group[2] = 1;

   mtl_compute_set_buffer(encoder, bo->map, 0, 0);
   mtl_compute_set_pipeline_state(encoder, pipeline);

   struct mtl_size grid_size = {
      .x = groupCountX,
      .y = groupCountY,
      .z = groupCountZ,
   };
   struct mtl_size local_size = {
      .x = 1,
      .y = 1,
      .z = 1,
   };
   mtl_dispatch_threads(encoder, grid_size, local_size);
}

VKAPI_ATTR void VKAPI_CALL
kk_CmdDispatch(VkCommandBuffer commandBuffer, uint32_t groupCountX,
               uint32_t groupCountY, uint32_t groupCountZ)
{
   kk_CmdDispatchBase(commandBuffer, 0, 0, 0, groupCountX, groupCountY,
                      groupCountZ);
}

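/* Flush dirty compute state to the current compute encoder: push
 * descriptors, descriptor residency, and the root descriptor table are
 * updated as needed, user-allocated heaps are made resident when the heap
 * cache has changed, and the root buffer plus pipeline state are bound.
 */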
static void
kk_flush_compute_state(struct kk_cmd_buffer *cmd)
{
   mtl_compute_encoder *enc = kk_compute_encoder(cmd);
   struct kk_device *dev = kk_cmd_buffer_device(cmd);

   /* Fill Metal argument buffer with descriptor set addresses */
   struct kk_descriptor_state *desc = &cmd->state.cs.descriptors;

   if (desc->push_dirty)
      kk_cmd_buffer_flush_push_descriptors(cmd, desc);
   /* This must happen after the push descriptors' buffers are created;
    * otherwise the buffer backing them does not exist yet and cannot be made
    * resident. */
   if (desc->sets_not_resident)
      kk_make_descriptor_resources_resident(cmd,
                                            VK_PIPELINE_BIND_POINT_COMPUTE);
   if (desc->root_dirty)
      kk_upload_descriptor_root(cmd, VK_PIPELINE_BIND_POINT_COMPUTE);

   /* Make user allocated heaps resident */
   simple_mtx_lock(&dev->user_heap_cache.mutex);
   if (cmd->encoder->main.user_heap_hash != dev->user_heap_cache.hash) {
      cmd->encoder->main.user_heap_hash = dev->user_heap_cache.hash;
      mtl_heap **heaps = util_dynarray_begin(&dev->user_heap_cache.handles);
      uint32_t count =
         util_dynarray_num_elements(&dev->user_heap_cache.handles, mtl_heap *);
      mtl_compute_use_heaps(enc, heaps, count);
   }
   simple_mtx_unlock(&dev->user_heap_cache.mutex);

   struct kk_bo *root_buffer = desc->root.root_buffer;
   if (root_buffer)
      mtl_compute_set_buffer(enc, root_buffer->map, 0, 0);

   mtl_compute_set_pipeline_state(enc, cmd->state.cs.pipeline_state);
   cmd->state.cs.dirty = 0u;
}

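/* vkCmdDispatchBase: record the base workgroup in the root descriptor table
 * (marking it dirty only when it actually changes), flush compute state, and
 * dispatch the grid using the pipeline's local size.
 */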
VKAPI_ATTR void VKAPI_CALL
kk_CmdDispatchBase(VkCommandBuffer commandBuffer, uint32_t baseGroupX,
                   uint32_t baseGroupY, uint32_t baseGroupZ,
                   uint32_t groupCountX, uint32_t groupCountY,
                   uint32_t groupCountZ)
{
   VK_FROM_HANDLE(kk_cmd_buffer, cmd, commandBuffer);

   struct kk_descriptor_state *desc = &cmd->state.cs.descriptors;
   desc->root_dirty |= desc->root.cs.base_group[0] != baseGroupX;
   desc->root_dirty |= desc->root.cs.base_group[1] != baseGroupY;
   desc->root_dirty |= desc->root.cs.base_group[2] != baseGroupZ;
   desc->root.cs.base_group[0] = baseGroupX;
   desc->root.cs.base_group[1] = baseGroupY;
   desc->root.cs.base_group[2] = baseGroupZ;

   kk_flush_compute_state(cmd);

   struct mtl_size grid_size = {
      .x = groupCountX,
      .y = groupCountY,
      .z = groupCountZ,
   };
   mtl_compute_encoder *enc = kk_compute_encoder(cmd);
   mtl_dispatch_threads(enc, grid_size, cmd->state.cs.local_size);
}

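/* vkCmdDispatchIndirect: the base workgroup is always (0, 0, 0) for indirect
 * dispatches, so reset it in the root table before flushing state and
 * dispatching threadgroups from the indirect buffer at the given offset.
 */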
VKAPI_ATTR void VKAPI_CALL
kk_CmdDispatchIndirect(VkCommandBuffer commandBuffer, VkBuffer _buffer,
                       VkDeviceSize offset)
{
   VK_FROM_HANDLE(kk_cmd_buffer, cmd, commandBuffer);
   VK_FROM_HANDLE(kk_buffer, buffer, _buffer);

   struct kk_descriptor_state *desc = &cmd->state.cs.descriptors;
   desc->root_dirty |= desc->root.cs.base_group[0] != 0;
   desc->root_dirty |= desc->root.cs.base_group[1] != 0;
   desc->root_dirty |= desc->root.cs.base_group[2] != 0;
   desc->root.cs.base_group[0] = 0;
   desc->root.cs.base_group[1] = 0;
   desc->root.cs.base_group[2] = 0;

   kk_flush_compute_state(cmd);

   mtl_compute_encoder *enc = kk_compute_encoder(cmd);
   mtl_dispatch_threadgroups_with_indirect_buffer(
      enc, buffer->mtl_handle, offset, cmd->state.cs.local_size);
}