Kompute
Loading...
Searching...
No Matches
OpMult.hpp
1// SPDX-License-Identifier: Apache-2.0
2#pragma once
3
4#include <fstream>
5
6#include "kompute/Core.hpp"
7
8#include "ShaderOpMult.hpp"
9
10#include "kompute/Algorithm.hpp"
11#include "kompute/Tensor.hpp"
12
13#include "kompute/operations/OpAlgoDispatch.hpp"
14
15namespace kp {
16
21class OpMult : public OpAlgoDispatch
22{
23 public:
33 OpMult(std::vector<std::shared_ptr<Memory>> memObjects,
34 std::shared_ptr<Algorithm> algorithm)
35 : OpAlgoDispatch(algorithm)
36 {
37 KP_LOG_DEBUG("Kompute OpMult constructor with params");
38
39 if (memObjects.size() != 3) {
40 throw std::runtime_error(
41 "Kompute OpMult expected 3 mem objects but got " +
42 std::to_string(memObjects.size()));
43 }
44
45 const std::vector<uint32_t> spirv = std::vector<uint32_t>(
46 SHADEROPMULT_COMP_SPV.begin(), SHADEROPMULT_COMP_SPV.end());
47
48 algorithm->rebuild<>(memObjects, spirv);
49 }
50
55 OpMult(const OpMult&) = delete;
56 OpMult(const OpMult&&) = delete;
57 OpMult& operator=(const OpMult&) = delete;
58 OpMult& operator=(const OpMult&&) = delete;
59
64 ~OpMult() noexcept override { KP_LOG_DEBUG("Kompute OpMult destructor started"); }
65};
66
67} // End namespace kp
Definition OpAlgoDispatch.hpp:18
Definition OpMult.hpp:22
~OpMult() noexcept override
Definition OpMult.hpp:64
OpMult(std::vector< std::shared_ptr< Memory > > memObjects, std::shared_ptr< Algorithm > algorithm)
Definition OpMult.hpp:33
OpMult(const OpMult &)=delete
Make OpMult non-copyable.