Kompute
Loading...
Searching...
No Matches
OpAlgoDispatch.hpp
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <type_traits>
#include <vector>

#include "kompute/Algorithm.hpp"
#include "kompute/Core.hpp"
#include "kompute/Tensor.hpp"
#include "kompute/operations/OpBase.hpp"

9namespace kp {
10
17class OpAlgoDispatch : public OpBase
18{
19 public:
27 template<typename T = float>
28 OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm,
29 const std::vector<T>& pushConstants = {}) noexcept
30 {
31 KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor");
32
33 this->mAlgorithm = algorithm;
34
35 if (pushConstants.size()) {
36 uint32_t memorySize = sizeof(decltype(pushConstants.back()));
37 uint32_t size = pushConstants.size();
38 uint32_t totalSize = size * memorySize;
39 this->mPushConstantsData = malloc(totalSize);
40 memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
41 this->mPushConstantsDataTypeMemorySize = memorySize;
42 this->mPushConstantsSize = size;
43 }
44 }
45
51 OpAlgoDispatch(const OpAlgoDispatch&&) = delete;
52 OpAlgoDispatch& operator=(const OpAlgoDispatch&) = delete;
53 OpAlgoDispatch& operator=(const OpAlgoDispatch&&) = delete;
54
55
60 virtual ~OpAlgoDispatch() noexcept override;
61
72 virtual void record(const vk::CommandBuffer& commandBuffer) override;
73
79 virtual void preEval(const vk::CommandBuffer& commandBuffer) override;
80
86 virtual void postEval(const vk::CommandBuffer& commandBuffer) override;
87
88 private:
89 // -------------- ALWAYS OWNED RESOURCES
90 std::shared_ptr<Algorithm> mAlgorithm;
91 void* mPushConstantsData = nullptr;
92 uint32_t mPushConstantsDataTypeMemorySize = 0;
93 uint32_t mPushConstantsSize = 0;
94};
95
96} // End namespace kp
Definition Algorithm.hpp:23
Definition OpAlgoDispatch.hpp:18
virtual void preEval(const vk::CommandBuffer &commandBuffer) override
OpAlgoDispatch(const OpAlgoDispatch &)=delete
Make OpAlgoDispatch non-copyable.
OpAlgoDispatch(const std::shared_ptr< kp::Algorithm > &algorithm, const std::vector< T > &pushConstants={}) noexcept
Definition OpAlgoDispatch.hpp:28
virtual ~OpAlgoDispatch() noexcept override
virtual void record(const vk::CommandBuffer &commandBuffer) override
virtual void postEval(const vk::CommandBuffer &commandBuffer) override
Definition OpBase.hpp:20