Kompute — Algorithm.hpp
1// SPDX-License-Identifier: Apache-2.0
2#pragma once
3
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <new>
#include <utility>

#include "kompute/Core.hpp"

#if KOMPUTE_OPT_USE_SPDLOG
#include <spdlog/fmt/fmt.h>
#else
#include <fmt/format.h>
#endif

#include "kompute/Tensor.hpp"
#include "kompute/Shader.hpp"
#include "logger/Logger.hpp"
15
16namespace kp {
17
23{
24 public:
43 template<typename S = float, typename P = float>
44 Algorithm(std::shared_ptr<vk::Device> device,
45 const std::vector<std::shared_ptr<Memory>>& memObjects = {},
46 const std::vector<uint32_t>& spirv = {},
47 const Workgroup& workgroup = {},
48 const std::vector<S>& specializationConstants = {},
49 const std::vector<P>& pushConstants = {}) noexcept
50 {
51 KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
52
53 this->mDevice = device;
54
55 if (memObjects.size() && spirv.size()) {
56 KP_LOG_INFO(
57 "Kompute Algorithm initialising with tensor size: {} and "
58 "spirv size: {}",
59 memObjects.size(),
60 spirv.size());
61 this->rebuild(memObjects,
62 spirv,
63 workgroup,
64 specializationConstants,
65 pushConstants);
66 } else {
67 KP_LOG_INFO(
68 "Kompute Algorithm constructor with empty mem objects and or "
69 "spirv so not rebuilding vulkan components");
70 }
71 }
72
89 template<typename S = float, typename P = float>
90 void rebuild(const std::vector<std::shared_ptr<Memory>>& memObjects,
91 const std::vector<uint32_t>& spirv,
92 const Workgroup& workgroup = {},
93 const std::vector<S>& specializationConstants = {},
94 const std::vector<P>& pushConstants = {})
95 {
96 KP_LOG_DEBUG("Kompute Algorithm rebuild started");
97
98 this->mMemObjects = memObjects;
99
100 if (specializationConstants.size()) {
101 if (this->mSpecializationConstantsData) {
102 free(this->mSpecializationConstantsData);
103 }
104 uint32_t memorySize =
105 sizeof(decltype(specializationConstants.back()));
106 uint32_t size = specializationConstants.size();
107 uint32_t totalSize = size * memorySize;
108 this->mSpecializationConstantsData = malloc(totalSize);
109 memcpy(this->mSpecializationConstantsData,
110 specializationConstants.data(),
111 totalSize);
112 this->mSpecializationConstantsDataTypeMemorySize = memorySize;
113 this->mSpecializationConstantsSize = size;
114 }
115
116 if (pushConstants.size()) {
117 if (this->mPushConstantsData) {
118 free(this->mPushConstantsData);
119 }
120 uint32_t memorySize = sizeof(decltype(pushConstants.back()));
121 uint32_t size = pushConstants.size();
122 uint32_t totalSize = size * memorySize;
123 this->mPushConstantsData = malloc(totalSize);
124 memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
125 this->mPushConstantsDataTypeMemorySize = memorySize;
126 this->mPushConstantsSize = size;
127 }
128
129 this->setWorkgroup(
130 workgroup,
131 this->mMemObjects.size() ? this->mMemObjects[0]->size() : 1);
132
133 // Descriptor pool is created first so if available then destroy all
134 // before rebuild
135 if (this->isInit()) {
136 this->destroy();
137 }
138
139 this->createParameters();
140 this->createShaderModule(spirv);
141 this->createPipeline();
142 }
143
148 Algorithm(const Algorithm&) = delete;
149 Algorithm(const Algorithm&&) = delete;
150 Algorithm& operator=(const Algorithm&) = delete;
151 Algorithm& operator=(const Algorithm&&) = delete;
152
153
/**
 * Destructor. Its implementation is not visible in this header.
 * NOTE(review): presumably releases resources via destroy(); TODO confirm in
 * the implementation file.
 */
158 ~Algorithm() noexcept;
159
/**
 * Records the compute dispatch for this algorithm into `commandBuffer`.
 * NOTE(review): presumably dispatches using the workgroup configured via
 * setWorkgroup(); confirm in the implementation.
 */
166 void recordDispatch(const vk::CommandBuffer& commandBuffer);
167
/**
 * Records binding of this algorithm's core Vulkan state into `commandBuffer`.
 * NOTE(review): presumably binds the pipeline and descriptor set; confirm.
 */
174 void recordBindCore(const vk::CommandBuffer& commandBuffer);
175
/**
 * Records this algorithm's push constants into `commandBuffer`.
 * NOTE(review): presumably pushes the bytes stored by rebuild() /
 * setPushConstants(); confirm in the implementation.
 */
184 void recordBindPush(const vk::CommandBuffer& commandBuffer);
185
/**
 * Reports whether the algorithm's Vulkan components currently exist.
 * rebuild() calls destroy() before re-creating them when this returns true.
 */
192 bool isInit();
193
/**
 * Sets the dispatch workgroup.
 *
 * @param workgroup Requested workgroup.
 * @param minSize rebuild() passes the size() of the first memory object, or 1
 *        when there are none. Presumably used to derive a default workgroup
 *        when `workgroup` is empty; TODO confirm.
 */
202 void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1);
211 template<typename T>
212 void setPushConstants(const std::vector<T>& pushConstants)
213 {
214 uint32_t memorySize = sizeof(decltype(pushConstants.back()));
215 uint32_t size = pushConstants.size();
216
217 this->setPushConstants(pushConstants.data(), size, memorySize);
218 }
219
229 void setPushConstants(void* data, uint32_t size, uint32_t memorySize)
230 {
231
232 uint32_t totalSize = memorySize * size;
233 uint32_t previousTotalSize =
234 this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize;
235
236 if (totalSize != previousTotalSize) {
237 throw std::runtime_error(fmt::format(
238 "Kompute Algorithm push "
239 "constant total memory size provided is {} but expected {} bytes",
240 totalSize,
241 previousTotalSize));
242 }
243 if (this->mPushConstantsData) {
244 free(this->mPushConstantsData);
245 }
246
247 this->mPushConstantsData = malloc(totalSize);
248 memcpy(this->mPushConstantsData, data, totalSize);
249 this->mPushConstantsDataTypeMemorySize = memorySize;
250 this->mPushConstantsSize = size;
251 }
252
/**
 * Returns the workgroup configured for dispatch.
 * NOTE(review): presumably a reference to the value stored by setWorkgroup();
 * confirm in the implementation.
 */
260 const Workgroup& getWorkgroup();
267 template<typename T>
268 const std::vector<T> getSpecializationConstants()
269 {
270 return { (T*)this->mSpecializationConstantsData,
271 ((T*)this->mSpecializationConstantsData) +
272 this->mSpecializationConstantsSize };
273 }
279 template<typename T>
280 const std::vector<T> getPushConstants()
281 {
282 return { (T*)this->mPushConstantsData,
283 ((T*)this->mPushConstantsData) + this->mPushConstantsSize };
284 }
/**
 * Returns the memory objects bound to this algorithm.
 * NOTE(review): presumably the list stored by rebuild(); confirm in the
 * implementation.
 */
290 const std::vector<std::shared_ptr<Memory>>& getMemObjects();
291
/**
 * Tears down the algorithm's Vulkan components. rebuild() calls this before
 * re-creating them whenever isInit() returns true.
 * NOTE(review): whether it also frees the push/specialization constant
 * buffers is not visible here; confirm in the implementation.
 */
292 void destroy();
293
294 private:
295 // -------------- NEVER OWNED RESOURCES
// Device handed to the constructor.
296 std::shared_ptr<vk::Device> mDevice;
// Memory objects bound to the shader; overwritten on every rebuild().
297 std::vector<std::shared_ptr<Memory>> mMemObjects;
298
299 // -------------- OPTIONALLY OWNED RESOURCES
// Each Vulkan handle below is paired with an mFree* flag.
// NOTE(review): the flags presumably record whether destroy() must free the
// matching handle; confirm in the implementation file.
300 std::shared_ptr<vk::DescriptorSetLayout> mDescriptorSetLayout;
301 bool mFreeDescriptorSetLayout = false;
302 std::shared_ptr<vk::DescriptorPool> mDescriptorPool;
303 bool mFreeDescriptorPool = false;
304 std::shared_ptr<vk::DescriptorSet> mDescriptorSet;
305 bool mFreeDescriptorSet = false;
306 std::shared_ptr<vk::PipelineLayout> mPipelineLayout;
307 bool mFreePipelineLayout = false;
308 std::shared_ptr<vk::PipelineCache> mPipelineCache;
309 bool mFreePipelineCache = false;
310 std::shared_ptr<vk::Pipeline> mPipeline;
311 bool mFreePipeline = false;
312
313 // -------------- ALWAYS OWNED RESOURCES
// Specialization constants: malloc'd copy of the values passed to rebuild().
// Stored with the byte size of one element and the element count.
314 void* mSpecializationConstantsData = nullptr;
315 uint32_t mSpecializationConstantsDataTypeMemorySize = 0;
316 uint32_t mSpecializationConstantsSize = 0;
// Push constants: malloc'd buffer written by rebuild() and setPushConstants().
// setPushConstants() requires the total byte size (element size * count) to
// stay unchanged.
317 void* mPushConstantsData = nullptr;
318 uint32_t mPushConstantsDataTypeMemorySize = 0;
319 uint32_t mPushConstantsSize = 0;
// Dispatch workgroup; presumably set by setWorkgroup() (confirm).
320 Workgroup mWorkgroup;
// NOTE(review): presumably the shader built by createShaderModule(); confirm.
321 std::shared_ptr<Shader> mShader = nullptr;
322
323 // Create util functions
// Called by rebuild() with the SPIR-V binary, after createParameters().
324 void createShaderModule(const std::vector<uint32_t>& spirv);
// Called by rebuild() as its final creation step.
325 void createPipeline();
326
327 // Parameters
// Called by rebuild() first, before the shader module and the pipeline.
328 void createParameters();
329};
330
331} // End namespace kp
Definition Algorithm.hpp:23
void recordBindCore(const vk::CommandBuffer &commandBuffer)
const std::vector< T > getSpecializationConstants()
Definition Algorithm.hpp:268
~Algorithm() noexcept
const std::vector< std::shared_ptr< Memory > > & getMemObjects()
Algorithm(const Algorithm &)=delete
Make Algorithm uncopyable.
void setWorkgroup(const Workgroup &workgroup, uint32_t minSize=1)
void setPushConstants(const std::vector< T > &pushConstants)
Definition Algorithm.hpp:212
void recordBindPush(const vk::CommandBuffer &commandBuffer)
const Workgroup & getWorkgroup()
void rebuild(const std::vector< std::shared_ptr< Memory > > &memObjects, const std::vector< uint32_t > &spirv, const Workgroup &workgroup={}, const std::vector< S > &specializationConstants={}, const std::vector< P > &pushConstants={})
Definition Algorithm.hpp:90
void recordDispatch(const vk::CommandBuffer &commandBuffer)
void setPushConstants(void *data, uint32_t size, uint32_t memorySize)
Definition Algorithm.hpp:229
Algorithm(std::shared_ptr< vk::Device > device, const std::vector< std::shared_ptr< Memory > > &memObjects={}, const std::vector< uint32_t > &spirv={}, const Workgroup &workgroup={}, const std::vector< S > &specializationConstants={}, const std::vector< P > &pushConstants={}) noexcept
Definition Algorithm.hpp:44
const std::vector< T > getPushConstants()
Definition Algorithm.hpp:280