/home/docs/checkouts/readthedocs.org/user_builds/ratpac/checkouts/latest/src/external/cppflow/include/cppflow/tensor.h Source File

Ratpac-two: /home/docs/checkouts/readthedocs.org/user_builds/ratpac/checkouts/latest/src/external/cppflow/include/cppflow/tensor.h Source File
Ratpac-two
tensor.h
Go to the documentation of this file.
1 // MIT License
2 //
3 // Copyright (c) 2020 Sergio Izquierdo
4 // Copyright (c) 2020 CarlPoirier
5 // Copyright (c) 2020 Jiannan Liu
6 // Copyright (c) 2020 liufeng27
7 // Copyright (c) 2022 Alfredo Rodriguez
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining a copy
10 // of this software and associated documentation files (the "Software"), to deal
11 // in the Software without restriction, including without limitation the rights
12 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 // copies of the Software, and to permit persons to whom the Software is
14 // furnished to do so, subject to the following conditions:
15 //
16 // The above copyright notice and this permission notice shall be included in
17 // all copies or substantial portions of the Software.
18 //
19 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 // SOFTWARE.
26 
37 #ifndef INCLUDE_CPPFLOW_TENSOR_H_
38 #define INCLUDE_CPPFLOW_TENSOR_H_
39 
40 // C headers
41 #include <tensorflow/c/eager/c_api.h>
42 #include <tensorflow/c/tf_tensor.h>
43 
// C++ headers
#include <cstring>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
49 
50 // CppFlow headers
51 #include "cppflow/context.h"
52 #include "cppflow/datatype.h"
53 
54 namespace cppflow {
55 
61 class tensor {
62  public:
63  tensor() = default;
64 
71  template <typename T>
72  tensor(const std::vector<T>& values, const std::vector<int64_t>& shape);
73 
79  template <typename T>
80  tensor(const std::initializer_list<T>& values);
81 
87  template <typename T>
88  tensor(const T& value);
89  tensor(const tensor& tensor) = default;
90  tensor(tensor&& tensor) = default;
91  explicit tensor(TFE_TensorHandle* handle);
92  explicit tensor(TF_Tensor* t);
93 
94  ~tensor() = default;
95 
96  tensor& operator=(const tensor& other) = default;
97  tensor& operator=(tensor&& other) = default;
98 
102  tensor shape() const;
103 
109  std::string device(bool on_memory = false) const;
110 
114  datatype dtype() const;
115 
121  template <typename T>
122  std::vector<T> get_data() const;
123 
124  // NOTE:
125  // Usually, one should not call get_eager_handle() or get_tensor() below.
126  // They are designed for implementation details in cppflow.
127  // If you are calling them directly, it is likely that you are using some
128  // tenforflow APIs not supported in cppflow.
129 
130  // Additional NOTE:
131  // TF_Tensor is an immutable tensor inside tensorflow.
132  // TFE_TensorHandle is a TF_Tensor and the associated device,
133  // plus some data cache
134 
135  // @todo Need to determine if we can mark the return value or *this as const
136  std::shared_ptr<TFE_TensorHandle> get_eager_handle() const { return tfe_handle; }
137 
138  // Get the TF_Tensor data from the eager handle
139  // Call `get_data<T>()` instead if possible
140  // NOTE:
141  // Changes to the returned TF_Tensor may not be reflected in the
142  // actual device memory!
143  // Do *NOT* modify the returned TF_Tensor!
144  // See comments of `tf_tensor` for more details.
145  std::shared_ptr<TF_Tensor> get_tensor() const;
146 
147  // DO NOT directly access this member, call get_eager_handle() instead
148  // @todo This is kept as public to be compatible with existing code and
149  // should be mark as private
150  std::shared_ptr<TFE_TensorHandle> tfe_handle;
151 
152  private:
153  tensor(enum TF_DataType type, const void* data, size_t len, const std::vector<int64_t>& shape);
154 
155  // This member serves as a local cache of the data in tfe_handle.
156  // It refers to `local_mirrors_` if on device, or `data_` if on host CPU.
157  // Changes to this variable may not be reflected in the actual device memory,
158  // e.g. on GPUs or on remote nodes.
159  // Access it via get_tensor() if not in constructor
160  mutable std::shared_ptr<TF_Tensor> tf_tensor;
161 };
162 
163 } // namespace cppflow
164 
165 /******************************
166  * IMPLEMENTATION DETAILS *
167  ******************************/
168 
169 namespace cppflow {
170 
171 inline tensor::tensor(enum TF_DataType type, const void* data, size_t len, const std::vector<int64_t>& shape) {
172  this->tf_tensor = {TF_AllocateTensor(type, shape.data(), static_cast<int>(shape.size()), len), TF_DeleteTensor};
173  memcpy(TF_TensorData(this->tf_tensor.get()), data, TF_TensorByteSize(this->tf_tensor.get()));
174  this->tfe_handle = {TFE_NewTensorHandle(this->tf_tensor.get(), context::get_status()), TFE_DeleteTensorHandle};
175  status_check(context::get_status());
176 }
177 
178 template <typename T>
179 tensor::tensor(const std::vector<T>& values, const std::vector<int64_t>& shape)
180  : tensor(deduce_tf_type<T>(), values.data(), values.size() * sizeof(T), shape) {}
181 
182 template <typename T>
183 tensor::tensor(const std::initializer_list<T>& values) : tensor(std::vector<T>(values), {(int64_t)values.size()}) {}
184 
185 template <typename T>
186 tensor::tensor(const T& value) : tensor(std::vector<T>({value}), {}) {}
187 
188 #ifdef TENSORFLOW_C_TF_TSTRING_H_
189 // For future version TensorFlow 2.4
190 template <>
191 inline tensor::tensor(const std::string& value) {
192  TF_TString tstr[1];
193  TF_TString_Init(&tstr[0]);
194  TF_TString_Copy(&tstr[0], value.c_str(), value.size());
195 
196  // *this = tensor(static_cast<enum TF_DataType>(TF_STRING),
197  // reinterpret_cast<void *>(tstr), sizeof(tstr), /*shape*/ {});
198  *this = tensor(static_cast<enum TF_DataType>(TF_STRING), (void*)tstr, sizeof(tstr), /*shape*/ {});
199 }
200 #else
201 template <>
202 inline tensor::tensor(const std::string& value) {
203  size_t size = 8 + TF_StringEncodedSize(value.length());
204  char* data = new char[value.size() + 8];
205  for (int i = 0; i < 8; i++) {
206  data[i] = 0;
207  }
208  TF_StringEncode(value.c_str(), value.size(), data + 8, size - 8, context::get_status());
209  status_check(context::get_status());
210 
211  // *this = tensor(static_cast<enum TF_DataType>(TF_STRING),
212  // reinterpret_cast<void *>(data), size, /*shape*/ {});
213  *this = tensor(static_cast<enum TF_DataType>(TF_STRING), (void*)data, size, /*shape*/ {});
214  delete[] data;
215 }
216 #endif // TENSORFLOW_C_TF_TSTRING_H_
217 
218 inline tensor::tensor(TFE_TensorHandle* handle) { this->tfe_handle = {handle, TFE_DeleteTensorHandle}; }
219 
220 inline tensor::tensor(TF_Tensor* t) {
221  this->tf_tensor = {t, TF_DeleteTensor};
222  this->tfe_handle = {TFE_NewTensorHandle(this->tf_tensor.get(), context::get_status()), TFE_DeleteTensorHandle};
223  status_check(context::get_status());
224 }
225 
226 inline tensor tensor::shape() const {
227  auto op = TFE_NewOp(context::get_context(), "Shape", context::get_status());
228  status_check(context::get_status());
229 
230  TFE_OpAddInput(op, this->tfe_handle.get(), context::get_status());
231  status_check(context::get_status());
232 
233  // Output type should be int64_t
234  TFE_OpSetAttrType(op, "out_type", cppflow::datatype::TF_INT64);
235 
236  // EXECUTE
237  int n = 1;
238  TFE_TensorHandle* res[1] = {nullptr};
239  TFE_Execute(op, res, &n, context::get_status());
240  status_check(context::get_status());
241  TFE_DeleteOp(op);
242 
243  return tensor(res[0]);
244 }
245 
246 inline std::string tensor::device(bool on_memory) const {
247  std::string res;
248  if (on_memory)
249  res = TFE_TensorHandleBackingDeviceName(this->tfe_handle.get(), context::get_status());
250  else
251  res = std::string(TFE_TensorHandleDeviceName(this->tfe_handle.get(), context::get_status()));
252 
253  status_check(context::get_status());
254  return res;
255 }
256 
257 template <typename T>
258 std::vector<T> tensor::get_data() const {
259  // Check if asked datatype and tensor datatype match
260  if (this->dtype() != deduce_tf_type<T>()) {
261  auto type1 = cppflow::to_string(deduce_tf_type<T>());
262  auto type2 = cppflow::to_string(this->dtype());
263  auto error = "Datatype in function get_data (" + type1 + ") does not match tensor datatype (" + type2 + ")";
264  throw std::runtime_error(error);
265  }
266 
267  auto res_tensor = get_tensor();
268 
269  // Check tensor data is not empty
270  auto raw_data = TF_TensorData(res_tensor.get());
271  // this->error_check(raw_data != nullptr, "Tensor data is empty");
272 
273  size_t size = (TF_TensorByteSize(res_tensor.get()) / TF_DataTypeSize(TF_TensorType(res_tensor.get())));
274 
275  // Convert to correct type
276  const auto T_data = static_cast<T*>(raw_data);
277  std::vector<T> r(T_data, T_data + size);
278 
279  return r;
280 }
281 
282 inline datatype tensor::dtype() const { return TFE_TensorHandleDataType(this->tfe_handle.get()); }
283 
284 // NOTE:
285 // Changes to the returned TF_Tensor are not reflected in
286 // the actual device memory!
287 inline std::shared_ptr<TF_Tensor> tensor::get_tensor() const {
288  if (!tf_tensor) {
289  tf_tensor = {TFE_TensorHandleResolve(tfe_handle.get(), context::get_status()), TF_DeleteTensor};
290  status_check(context::get_status());
291  }
292  return tf_tensor;
293 }
294 } // namespace cppflow
295 
296 #endif // INCLUDE_CPPFLOW_TENSOR_H_
A TensorFlow eager tensor wrapper.
Definition: tensor.h:61
tensor shape() const
Definition: tensor.h:226
datatype dtype() const
Definition: tensor.h:282
std::string device(bool on_memory=false) const
Definition: tensor.h:246
std::vector< T > get_data() const
Definition: tensor.h:258
TF_DataType deduce_tf_type()
Definition: datatype.h:107
std::string to_string(datatype dt)
Definition: datatype.h:48