Kompute
Loading...
Searching...
No Matches
Tensor.hpp
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include "kompute/Core.hpp"
#include "kompute/Memory.hpp"
#include "logger/Logger.hpp"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

10namespace kp {
11
12// Forward-declare the Image class
13class Image;
22class Tensor : public Memory
23{
24 public:
35 Tensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
36 std::shared_ptr<vk::Device> device,
37 void* data,
40 const DataTypes& dataType,
42
53 Tensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
54 std::shared_ptr<vk::Device> device,
57 const DataTypes& dataType,
59
64 Tensor(const Tensor&) = delete;
65 Tensor(const Tensor&&) = delete;
66 Tensor& operator=(const Tensor&) = delete;
67 Tensor& operator=(const Tensor&&) = delete;
68
73 virtual ~Tensor();
74
78 void destroy() override;
79
85 bool isInit() override;
86
95 void recordCopyFrom(const vk::CommandBuffer& commandBuffer,
96 std::shared_ptr<Tensor> copyFromTensor) override;
97
106 void recordCopyFrom(const vk::CommandBuffer& commandBuffer,
107 std::shared_ptr<Image> copyFromImage) override;
108
117 const vk::CommandBuffer& commandBuffer) override;
118
127 const vk::CommandBuffer& commandBuffer) override;
128
141 const vk::CommandBuffer& commandBuffer,
142 vk::AccessFlagBits srcAccessMask,
143 vk::AccessFlagBits dstAccessMask,
144 vk::PipelineStageFlagBits srcStageMask,
145 vk::PipelineStageFlagBits dstStageMask) override;
158 const vk::CommandBuffer& commandBuffer,
159 vk::AccessFlagBits srcAccessMask,
160 vk::AccessFlagBits dstAccessMask,
161 vk::PipelineStageFlagBits srcStageMask,
162 vk::PipelineStageFlagBits dstStageMask) override;
163
171 vk::WriteDescriptorSet constructDescriptorSet(
172 vk::DescriptorSet descriptorSet,
173 uint32_t binding) override;
174
175 std::shared_ptr<vk::Buffer> getPrimaryBuffer();
176
177 Type type() override { return Type::eTensor; }
178
179 protected:
180 // -------------- ALWAYS OWNED RESOURCES
181 vk::DescriptorBufferInfo mDescriptorBufferInfo;
182
183 private:
184 // -------------- OPTIONALLY OWNED RESOURCES
185 std::shared_ptr<vk::Buffer> mPrimaryBuffer;
186 bool mFreePrimaryBuffer = false;
187 std::shared_ptr<vk::Buffer> mStagingBuffer;
188 bool mFreeStagingBuffer = false;
189
190 void allocateMemoryCreateGPUResources(); // Creates the vulkan buffer
191 void createBuffer(std::shared_ptr<vk::Buffer> buffer,
192 vk::BufferUsageFlags bufferUsageFlags);
193 void allocateBindMemory(std::shared_ptr<vk::Buffer> buffer,
194 std::shared_ptr<vk::DeviceMemory> memory,
195 vk::MemoryPropertyFlags memoryPropertyFlags);
196 void recordCopyBuffer(const vk::CommandBuffer& commandBuffer,
197 std::shared_ptr<vk::Buffer> bufferFrom,
198 std::shared_ptr<vk::Buffer> bufferTo,
199 vk::DeviceSize bufferSize,
200 vk::BufferCopy copyRegion);
201 void recordCopyBufferFromImage(const vk::CommandBuffer& commandBuffer,
202 std::shared_ptr<vk::Image> imageFrom,
203 std::shared_ptr<vk::Buffer> bufferTo,
204 vk::ImageLayout fromLayout,
205 vk::DeviceSize /*bufferSize*/,
206 vk::BufferImageCopy copyRegion);
207 void recordBufferMemoryBarrier(const vk::CommandBuffer& commandBuffer,
208 const vk::Buffer& buffer,
209 vk::AccessFlagBits srcAccessMask,
210 vk::AccessFlagBits dstAccessMask,
211 vk::PipelineStageFlagBits srcStageMask,
212 vk::PipelineStageFlagBits dstStageMask);
213
214 // Private util functions
215 vk::BufferUsageFlags getPrimaryBufferUsageFlags();
216 vk::BufferUsageFlags getStagingBufferUsageFlags();
217
218 vk::DescriptorBufferInfo constructDescriptorBufferInfo();
219
224 void reserve();
225};
226
227template<typename T>
228class TensorT : public Tensor
229{
230 public:
231 TensorT(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
232 std::shared_ptr<vk::Device> device,
233 const size_t size,
234 const MemoryTypes& tensorType = MemoryTypes::eDevice)
235 : Tensor(physicalDevice,
236 device,
237 size,
238 sizeof(T),
240 tensorType)
241 {
242 KP_LOG_DEBUG("Kompute TensorT constructor with data size {}", size);
243 }
244
245 TensorT(
246 std::shared_ptr<vk::PhysicalDevice> physicalDevice,
247 std::shared_ptr<vk::Device> device,
248 const std::vector<T>& data,
250 : Tensor(physicalDevice,
251 device,
252 (void*)data.data(),
253 static_cast<uint32_t>(data.size()),
254 sizeof(T),
256 tensorType)
257 {
258 KP_LOG_DEBUG("Kompute TensorT filling constructor with data size {}",
259 data.size());
260 }
261
266 TensorT(const TensorT&) = delete;
267 TensorT(const TensorT&&) = delete;
268 TensorT& operator=(const TensorT&) = delete;
269 TensorT& operator=(const TensorT&&) = delete;
270
271
272 ~TensorT() { KP_LOG_DEBUG("Kompute TensorT destructor"); }
273
274 DataTypes dataType() { return Memory::dataType<T>(); }
275 std::vector<T> vector() { return Memory::vector<T>(); }
276 T* data() { return Memory::data<T>(); }
277};
278
279} // End namespace kp
Definition Memory.hpp:16
uint32_t size()
static constexpr DataTypes dataType()
std::vector< T > vector()
Definition Memory.hpp:282
T * data()
Definition Memory.hpp:265
MemoryTypes
Definition Memory.hpp:28
@ eDevice
Type is device memory, source and destination.
MemoryTypes memoryType()
Definition Tensor.hpp:229
TensorT(const TensorT &)=delete
Make TensorT uncopyable.
Definition Tensor.hpp:23
Type type() override
Definition Tensor.hpp:177
Tensor(const Tensor &)=delete
Make Tensor uncopyable.
void recordStagingMemoryBarrier(const vk::CommandBuffer &commandBuffer, vk::AccessFlagBits srcAccessMask, vk::AccessFlagBits dstAccessMask, vk::PipelineStageFlagBits srcStageMask, vk::PipelineStageFlagBits dstStageMask) override
void recordCopyFromStagingToDevice(const vk::CommandBuffer &commandBuffer) override
Tensor(std::shared_ptr< vk::PhysicalDevice > physicalDevice, std::shared_ptr< vk::Device > device, uint32_t elementTotalCount, uint32_t elementMemorySize, const DataTypes &dataType, const MemoryTypes &memoryType=MemoryTypes::eDevice)
void recordPrimaryMemoryBarrier(const vk::CommandBuffer &commandBuffer, vk::AccessFlagBits srcAccessMask, vk::AccessFlagBits dstAccessMask, vk::PipelineStageFlagBits srcStageMask, vk::PipelineStageFlagBits dstStageMask) override
void recordCopyFrom(const vk::CommandBuffer &commandBuffer, std::shared_ptr< Tensor > copyFromTensor) override
virtual ~Tensor()
void recordCopyFrom(const vk::CommandBuffer &commandBuffer, std::shared_ptr< Image > copyFromImage) override
bool isInit() override
Tensor(std::shared_ptr< vk::PhysicalDevice > physicalDevice, std::shared_ptr< vk::Device > device, void *data, uint32_t elementTotalCount, uint32_t elementMemorySize, const DataTypes &dataType, const MemoryTypes &tensorType=MemoryTypes::eDevice)
vk::WriteDescriptorSet constructDescriptorSet(vk::DescriptorSet descriptorSet, uint32_t binding) override
void recordCopyFromDeviceToStaging(const vk::CommandBuffer &commandBuffer) override
void destroy() override