/home/docs/checkouts/readthedocs.org/user_builds/ratpac/checkouts/latest/src/external/cppflow/include/cppflow/ops.h Source File

Ratpac-two: /home/docs/checkouts/readthedocs.org/user_builds/ratpac/checkouts/latest/src/external/cppflow/include/cppflow/ops.h Source File
Ratpac-two
ops.h
Go to the documentation of this file.
1 // MIT License
2 //
3 // Copyright (c) 2020 Sergio Izquierdo
4 // Copyright (c) 2020 Jiannan Liu
5 //
6 // Permission is hereby granted, free of charge, to any person obtaining a copy
7 // of this software and associated documentation files (the "Software"), to deal
8 // in the Software without restriction, including without limitation the rights
9 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 // copies of the Software, and to permit persons to whom the Software is
11 // furnished to do so, subject to the following conditions:
12 //
13 // The above copyright notice and this permission notice shall be included in
14 // all copies or substantial portions of the Software.
15 //
16 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 // SOFTWARE.
23 
31 #ifndef INCLUDE_CPPFLOW_OPS_H_
32 #define INCLUDE_CPPFLOW_OPS_H_
33 
34 // C++ headers
35 #include <string>
36 
37 // CppFlow headers
38 #include "cppflow/raw_ops.h"
39 #include "cppflow/tensor.h"
40 
41 namespace cppflow {
42 
47 
51 tensor operator+(const tensor& x, const tensor& y);
52 
56 tensor operator-(const tensor& x, const tensor& y);
57 
61 tensor operator*(const tensor& x, const tensor& y);
62 
66 tensor operator/(const tensor& x, const tensor& y);
67 
68 std::ostream& operator<<(std::ostream& os, const cppflow::tensor& t);
69 
71 
77 std::string to_string(const tensor& t);
78 } // namespace cppflow
79 
80 /******************************
81  * IMPLEMENTATION DETAILS *
82  ******************************/
83 
84 namespace cppflow {
85 
86 // Operators
87 
88 inline tensor operator+(const tensor& x, const tensor& y) { return add(x, y); }
89 
90 inline tensor operator-(const tensor& x, const tensor& y) { return sub(x, y); }
91 
92 inline tensor operator*(const tensor& x, const tensor& y) { return mul(x, y); }
93 
94 inline tensor operator/(const tensor& x, const tensor& y) { return div(x, y); }
95 
96 inline std::ostream& operator<<(std::ostream& os, const cppflow::tensor& t) {
97  std::string res = to_string(t);
98  return os << res;
99 }
100 
101 inline std::string to_string(const tensor& t) {
102  auto res_tensor = string_format({t.shape(), t}, "(tensor: shape=%s, dtype=" + to_string(t.dtype()) + ", data=\n%s)");
103  auto res_tensor_h = res_tensor.get_tensor();
104 
105 #ifdef TENSORFLOW_C_TF_TSTRING_H_
106  // For future version TensorFlow 2.4
107  // auto *t_str = reinterpret_cast<TF_TString *>(
108  // TF_TensorData(res_tensor_h.get()));
109  auto* t_str = (TF_TString*)(TF_TensorData(res_tensor_h.get()));
110  auto result = std::string(TF_TString_GetDataPointer(t_str), TF_TString_GetSize(t_str));
111 #else
112  const char* dst[1] = {nullptr};
113  size_t dst_len[1] = {3};
114  TF_StringDecode(static_cast<char*>(TF_TensorData(res_tensor_h.get())) + 8, TF_TensorByteSize(res_tensor_h.get()), dst,
115  dst_len, context::get_status());
116  status_check(context::get_status());
117  auto result = std::string(dst[0], *dst_len);
118 #endif // TENSORFLOW_C_TF_TSTRING_H_
119 
120  return result;
121 }
122 
123 } // namespace cppflow
124 
125 #endif // INCLUDE_CPPFLOW_OPS_H_
A TensorFlow eager tensor wrapper.
Definition: tensor.h:61
tensor shape() const
Definition: tensor.h:226
datatype dtype() const
Definition: tensor.h:282
std::ostream & operator<<(std::ostream &os, datatype dt)
Definition: datatype.h:127
std::string to_string(datatype dt)
Definition: datatype.h:48
tensor operator*(const tensor &x, const tensor &y)
Definition: ops.h:92
tensor operator+(const tensor &x, const tensor &y)
Definition: ops.h:88
tensor operator-(const tensor &x, const tensor &y)
Definition: ops.h:90
tensor operator/(const tensor &x, const tensor &y)
Definition: ops.h:94