45 const std::vector<std::shared_ptr<Memory>>& memObjects = {},
46 const std::vector<uint32_t>& spirv = {},
47 const Workgroup& workgroup = {},
48 const std::vector<S>& specializationConstants = {},
49 const std::vector<P>& pushConstants = {})
noexcept
51 KP_LOG_DEBUG(
"Kompute Algorithm Constructor with device");
53 this->mDevice = device;
55 if (memObjects.size() && spirv.size()) {
57 "Kompute Algorithm initialising with tensor size: {} and "
64 specializationConstants,
68 "Kompute Algorithm constructor with empty mem objects and or "
69 "spirv so not rebuilding vulkan components");
/**
 * Rebuilds this algorithm's state from the supplied inputs.
 *
 * Steps visible in this excerpt, in order:
 *   1. stores memObjects in mMemObjects;
 *   2. when specializationConstants is non-empty, copies it into a freshly
 *      malloc'd buffer (releasing any previously held buffer first) and
 *      records the element size and element count;
 *   3. does the same for pushConstants;
 *   4. calls createParameters(), createShaderModule(spirv) and
 *      createPipeline().
 *
 * NOTE(review): several lines of this function (closing braces, the final
 * memcpy argument for the specialization constants, and the call that
 * consumes `workgroup`) are not visible in this excerpt; the comments below
 * describe only what is shown.
 *
 * @param memObjects Memory objects copied into mMemObjects.
 * @param spirv SPIR-V words forwarded to createShaderModule().
 * @param workgroup Workgroup configuration; its use is not visible here.
 * @param specializationConstants Values copied byte-wise into
 * mSpecializationConstantsData. S is presumably a template parameter
 * declared outside this excerpt - TODO confirm.
 * @param pushConstants Values copied byte-wise into mPushConstantsData.
 * P is presumably a template parameter declared outside this excerpt -
 * TODO confirm.
 */
90 void rebuild(
const std::vector<std::shared_ptr<Memory>>& memObjects,
91 const std::vector<uint32_t>& spirv,
92 const Workgroup& workgroup = {},
93 const std::vector<S>& specializationConstants = {},
94 const std::vector<P>& pushConstants = {})
96 KP_LOG_DEBUG(
"Kompute Algorithm rebuild started");
98 this->mMemObjects = memObjects;
// Specialization constants are only (re)captured when a non-empty vector is
// supplied.
100 if (specializationConstants.size()) {
// Release the previously held buffer before replacing the pointer.
101 if (this->mSpecializationConstantsData) {
102 free(this->mSpecializationConstantsData);
// decltype(vec.back()) is a reference type; sizeof applied to a reference
// type yields the size of the referenced type, i.e. one element.
104 uint32_t memorySize =
105 sizeof(
decltype(specializationConstants.back()));
// Element count and byte count are held in 32-bit unsigned integers.
106 uint32_t size = specializationConstants.size();
107 uint32_t totalSize = size * memorySize;
// NOTE(review): no null check on the malloc result is visible before the
// memcpy below - confirm against the full file.
108 this->mSpecializationConstantsData = malloc(totalSize);
// The byte-count argument of this memcpy is on a line not shown here.
109 memcpy(this->mSpecializationConstantsData,
110 specializationConstants.data(),
// Remember per-element size and element count alongside the raw buffer.
112 this->mSpecializationConstantsDataTypeMemorySize = memorySize;
113 this->mSpecializationConstantsSize = size;
// Push constants follow the same capture sequence as the specialization
// constants above.
116 if (pushConstants.size()) {
// Release the previously held buffer before replacing the pointer.
117 if (this->mPushConstantsData) {
118 free(this->mPushConstantsData);
// Size of one element (sizeof on the reference type returned by back()).
120 uint32_t memorySize =
sizeof(
decltype(pushConstants.back()));
121 uint32_t size = pushConstants.size();
122 uint32_t totalSize = size * memorySize;
// NOTE(review): no null check on the malloc result is visible before the
// memcpy below - confirm against the full file.
123 this->mPushConstantsData = malloc(totalSize);
124 memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
125 this->mPushConstantsDataTypeMemorySize = memorySize;
126 this->mPushConstantsSize = size;
// Final argument of a call whose start is not visible in this excerpt:
// size() of the first memory object, or 1 when mMemObjects is empty.
131 this->mMemObjects.size() ? this->mMemObjects[0]->size() : 1);
// Invoke the three create* steps in this order; only the shader-module step
// receives the SPIR-V.
139 this->createParameters();
140 this->createShaderModule(spirv);
141 this->createPipeline();
232 uint32_t totalSize = memorySize * size;
233 uint32_t previousTotalSize =
234 this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize;
236 if (totalSize != previousTotalSize) {
237 throw std::runtime_error(fmt::format(
238 "Kompute Algorithm push "
239 "constant total memory size provided is {} but expected {} bytes",
243 if (this->mPushConstantsData) {
244 free(this->mPushConstantsData);
247 this->mPushConstantsData = malloc(totalSize);
248 memcpy(this->mPushConstantsData, data, totalSize);
249 this->mPushConstantsDataTypeMemorySize = memorySize;
250 this->mPushConstantsSize = size;