Kompute
Loading...
Searching...
No Matches
Memory.hpp
1// SPDX-License-Identifier: Apache-2.0
2#pragma once
3
#include "kompute/Core.hpp"
#include "logger/Logger.hpp"
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
8
9namespace kp {
10
// Forward declarations of the concrete memory kinds. Memory only refers to
// them through std::shared_ptr (in its protected recordCopyFrom overloads),
// so complete definitions are not needed here.
class Tensor;
class Image;
14
15class Memory
16{
17 // This is the base class for Tensors and Images.
18 // It's required so that algorithms and sequences can mix tensors and
19 // images.
20 public:
27 enum class MemoryTypes
28 {
29 eDevice = 0,
30 eHost = 1,
31 eStorage = 2,
33 3,
34 };
35
36 enum class DataTypes
37 {
38 eBool = 0,
39 eInt = 1,
40 eUnsignedInt = 2,
41 eFloat = 3,
42 eDouble = 4,
43 eCustom = 5,
44 eShort = 6,
45 eUnsignedShort = 7,
46 eChar = 8,
47 eUnsignedChar = 9
48 };
49
50 enum class Type
51 {
52 eTensor = 0,
53 eImage = 1
54 };
55
56 static std::string toString(MemoryTypes dt);
57 static std::string toString(DataTypes dt);
58
59 Memory(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
60 std::shared_ptr<vk::Device> device,
61 const DataTypes& dataType,
63 uint32_t x,
64 uint32_t y);
65
66
71 Memory(const Memory&) = delete;
72 Memory(const Memory&&) = delete;
73 Memory& operator=(const Memory&) = delete;
74 Memory& operator=(const Memory&&) = delete;
75
80 virtual ~Memory(){};
81
85 virtual void destroy();
86
93
99 template<typename T>
100 static constexpr DataTypes dataType();
101
107 DataTypes dataType();
108
115 virtual bool isInit() = 0;
116
125 const vk::CommandBuffer& commandBuffer) = 0;
126
135 const vk::CommandBuffer& commandBuffer) = 0;
148 const vk::CommandBuffer& commandBuffer,
149 vk::AccessFlagBits srcAccessMask,
150 vk::AccessFlagBits dstAccessMask,
151 vk::PipelineStageFlagBits srcStageMask,
152 vk::PipelineStageFlagBits dstStageMask) = 0;
165 const vk::CommandBuffer& commandBuffer,
166 vk::AccessFlagBits srcAccessMask,
167 vk::AccessFlagBits dstAccessMask,
168 vk::PipelineStageFlagBits srcStageMask,
169 vk::PipelineStageFlagBits dstStageMask) = 0;
170
179 void recordCopyFrom(const vk::CommandBuffer& commandBuffer,
180 std::shared_ptr<Memory> copyFromMemory);
181
189 virtual vk::WriteDescriptorSet constructDescriptorSet(
190 vk::DescriptorSet descriptorSet,
191 uint32_t binding) = 0;
192
200
209
217
226
227 vk::DescriptorType getDescriptorType() { return mDescriptorType; }
228
236 void* rawData();
237
242 void setData(const void* data, size_t size);
243
248 template<typename T>
249 void setData(const std::vector<T>& data)
250 {
251 KP_LOG_DEBUG("Kompute Memory setting data with data size {}",
252 data.size() * sizeof(T));
253
254 this->setData(data.data(), data.size() * sizeof(T));
255 }
256
264 template<typename T>
266 {
267 if (this->mRawData == nullptr) {
268 this->mapRawData();
269 }
270
271 return (T*)this->mRawData;
272 }
273
281 template<typename T>
282 std::vector<T> vector()
283 {
284 if (this->mRawData == nullptr) {
285 this->mapRawData();
286 }
287
288 return { (T*)this->mRawData, ((T*)this->mRawData) + this->size() };
289 }
290
291 /***
292 * Retreive the size of the x-dimension of the memory
293 *
294 * @return Size of the x-dimension of the memory
295 */
296 uint32_t getX() { return this->mX; }
297
298 /***
299 * Retreive the size of the y-dimension of the memory
300 *
301 * @return Size of the y-dimension of the memory
302 */
303 uint32_t getY() { return this->mY; };
304
310 virtual Type type() = 0;
311
312 protected:
313 // -------------- ALWAYS OWNED RESOURCES
314 MemoryTypes mMemoryType;
315 DataTypes mDataType;
316 uint32_t mSize;
317 uint32_t mDataTypeMemorySize;
318 void* mRawData = nullptr;
319 vk::DescriptorType mDescriptorType;
320 bool mUnmapMemory = false;
321 uint32_t mX;
322 uint32_t mY;
323
324 // -------------- NEVER OWNED RESOURCES
325 std::shared_ptr<vk::PhysicalDevice> mPhysicalDevice;
326 std::shared_ptr<vk::Device> mDevice;
327
328 // -------------- OPTIONALLY OWNED RESOURCES
329 std::shared_ptr<vk::DeviceMemory> mPrimaryMemory;
330 bool mFreePrimaryMemory = false;
331 std::shared_ptr<vk::DeviceMemory> mStagingMemory;
332 bool mFreeStagingMemory = false;
333
334 // Private util functions
335 void mapRawData();
336 void unmapRawData();
337 void updateRawData(void* data);
338 vk::MemoryPropertyFlags getPrimaryMemoryPropertyFlags();
339 vk::MemoryPropertyFlags getStagingMemoryPropertyFlags();
340
341 virtual void recordCopyFrom(const vk::CommandBuffer& commandBuffer,
342 std::shared_ptr<Tensor> copyFromMemory) = 0;
343 virtual void recordCopyFrom(const vk::CommandBuffer& commandBuffer,
344 std::shared_ptr<Image> copyFromMemory) = 0;
345};
346
347template<>
348constexpr Memory::DataTypes
350{
351 return DataTypes::eBool;
352}
353
354template<>
355constexpr Memory::DataTypes
357{
358 return DataTypes::eChar;
359}
360
361template<>
362constexpr Memory::DataTypes
364{
365 return DataTypes::eUnsignedChar;
366}
367
368template<>
369constexpr Memory::DataTypes
371{
372 return DataTypes::eShort;
373}
374
375template<>
376constexpr Memory::DataTypes
378{
379 return DataTypes::eUnsignedShort;
380}
381
382template<>
383constexpr Memory::DataTypes
385{
386 return DataTypes::eInt;
387}
388
389template<>
390constexpr Memory::DataTypes
392{
393 return DataTypes::eUnsignedInt;
394}
395
396template<>
397constexpr Memory::DataTypes
399{
400 return DataTypes::eFloat;
401}
402
403template<>
404constexpr Memory::DataTypes
406{
407 return DataTypes::eDouble;
408}
409
410} // End namespace kp
Definition Memory.hpp:16
void setData(const void *data, size_t size)
uint32_t size()
uint32_t dataTypeMemorySize()
void setData(const std::vector< T > &data)
Definition Memory.hpp:249
virtual void recordStagingMemoryBarrier(const vk::CommandBuffer &commandBuffer, vk::AccessFlagBits srcAccessMask, vk::AccessFlagBits dstAccessMask, vk::PipelineStageFlagBits srcStageMask, vk::PipelineStageFlagBits dstStageMask)=0
Memory(const Memory &)=delete
Make Memory uncopyable.
virtual vk::WriteDescriptorSet constructDescriptorSet(vk::DescriptorSet descriptorSet, uint32_t binding)=0
virtual void recordCopyFromStagingToDevice(const vk::CommandBuffer &commandBuffer)=0
virtual ~Memory()
Definition Memory.hpp:80
virtual void recordPrimaryMemoryBarrier(const vk::CommandBuffer &commandBuffer, vk::AccessFlagBits srcAccessMask, vk::AccessFlagBits dstAccessMask, vk::PipelineStageFlagBits srcStageMask, vk::PipelineStageFlagBits dstStageMask)=0
static uint32_t dataTypeMemorySize(DataTypes dataType)
static constexpr DataTypes dataType()
void * rawData()
uint32_t memorySize()
std::vector< T > vector()
Definition Memory.hpp:282
virtual void recordCopyFromDeviceToStaging(const vk::CommandBuffer &commandBuffer)=0
DataTypes dataType()
T * data()
Definition Memory.hpp:265
MemoryTypes
Definition Memory.hpp:28
@ eDeviceAndHost
Type is host-visible and host-coherent device memory.
@ eDevice
Type is device memory, source and destination.
@ eHost
Type is host memory, source and destination.
@ eStorage
Type is Device memory (only)
virtual void destroy()
MemoryTypes memoryType()
virtual Type type()=0
void recordCopyFrom(const vk::CommandBuffer &commandBuffer, std::shared_ptr< Memory > copyFromMemory)
virtual bool isInit()=0