37 #ifndef INCLUDE_CPPFLOW_TENSOR_H_
38 #define INCLUDE_CPPFLOW_TENSOR_H_
41 #include <tensorflow/c/eager/c_api.h>
42 #include <tensorflow/c/tf_tensor.h>
72 tensor(
const std::vector<T>& values,
const std::vector<int64_t>&
shape);
80 tensor(
const std::initializer_list<T>& values);
// Wraps an existing eager tensor handle. The tensor takes ownership of
// `handle`; it is released with TFE_DeleteTensorHandle when the last copy of
// the shared tfe_handle goes away.
91 explicit tensor(TFE_TensorHandle* handle);
// Wraps an existing TF_Tensor. Ownership of `t` is taken (released with
// TF_DeleteTensor), and an eager handle is created for it with
// TFE_NewTensorHandle. The resulting status is passed to status_check.
92 explicit tensor(TF_Tensor* t);
// Returns the name of the device associated with this tensor.
// `on_memory` chooses between TFE_TensorHandleBackingDeviceName (the device in
// whose memory the tensor resides) and TFE_TensorHandleDeviceName.
// NOTE(review): presumably `true` selects the backing device; the branch
// condition in the definition is not visible here — confirm.
// Errors are reported through context::get_status() and status_check.
109 std::string
device(
bool on_memory =
false)
const;
// Returns the element data type of the tensor, as reported by
// TFE_TensorHandleDataType on the eager handle.
114 datatype
dtype()
const;
121 template <
typename T>
136 std::shared_ptr<TFE_TensorHandle> get_eager_handle()
const {
return tfe_handle; }
// Returns the tensor as a TF_Tensor. The definition resolves the eager handle
// with TFE_TensorHandleResolve and stores the result in the mutable tf_tensor
// member, which is why this can be a const member function.
// NOTE(review): whether the resolve is skipped when tf_tensor is already set
// depends on a line not visible in this listing — confirm.
145 std::shared_ptr<TF_Tensor> get_tensor()
const;
// Shared owner of the eager tensor handle. Every constructor that sets it
// installs TFE_DeleteTensorHandle as the deleter.
150 std::shared_ptr<TFE_TensorHandle> tfe_handle;
// Low-level constructor used by the typed constructors.
// It allocates a TF_Tensor of `type` with the given `shape` and `len` bytes,
// copies TF_TensorByteSize(...) bytes (the `len` requested at allocation) from
// `data`, and wraps the result in an eager handle.
// NOTE(review): assumes `data` points to at least `len` readable bytes — not
// checked here.
153 tensor(
enum TF_DataType type,
const void* data,
size_t len,
const std::vector<int64_t>&
shape);
// Cached TF_Tensor, owned with TF_DeleteTensor as the deleter. It is mutable
// so the const get_tensor() can fill it by resolving tfe_handle.
160 mutable std::shared_ptr<TF_Tensor> tf_tensor;
171 inline tensor::tensor(
enum TF_DataType type,
const void* data,
size_t len,
const std::vector<int64_t>& shape) {
172 this->tf_tensor = {TF_AllocateTensor(type,
shape.data(),
static_cast<int>(
shape.size()), len), TF_DeleteTensor};
173 memcpy(TF_TensorData(this->tf_tensor.get()), data, TF_TensorByteSize(this->tf_tensor.get()));
174 this->tfe_handle = {TFE_NewTensorHandle(this->tf_tensor.get(), context::get_status()), TFE_DeleteTensorHandle};
175 status_check(context::get_status());
178 template <
typename T>
179 tensor::tensor(
const std::vector<T>& values,
const std::vector<int64_t>& shape)
182 template <
typename T>
183 tensor::tensor(
const std::initializer_list<T>& values) :
tensor(std::vector<T>(values), {(int64_t)values.size()}) {}
185 template <
typename T>
186 tensor::tensor(
const T& value) :
tensor(std::vector<T>({value}), {}) {}
188 #ifdef TENSORFLOW_C_TF_TSTRING_H_
191 inline tensor::tensor(
const std::string& value) {
193 TF_TString_Init(&tstr[0]);
194 TF_TString_Copy(&tstr[0], value.c_str(), value.size());
198 *
this = tensor(
static_cast<enum TF_DataType
>(TF_STRING), (
void*)tstr,
sizeof(tstr), {});
202 inline tensor::tensor(
const std::string& value) {
203 size_t size = 8 + TF_StringEncodedSize(value.length());
204 char* data =
new char[value.size() + 8];
205 for (
int i = 0; i < 8; i++) {
208 TF_StringEncode(value.c_str(), value.size(), data + 8, size - 8, context::get_status());
209 status_check(context::get_status());
213 *
this = tensor(
static_cast<enum TF_DataType
>(TF_STRING), (
void*)data, size, {});
218 inline tensor::tensor(TFE_TensorHandle* handle) { this->tfe_handle = {handle, TFE_DeleteTensorHandle}; }
220 inline tensor::tensor(TF_Tensor* t) {
221 this->tf_tensor = {t, TF_DeleteTensor};
222 this->tfe_handle = {TFE_NewTensorHandle(this->tf_tensor.get(), context::get_status()), TFE_DeleteTensorHandle};
223 status_check(context::get_status());
227 auto op = TFE_NewOp(context::get_context(),
"Shape", context::get_status());
228 status_check(context::get_status());
230 TFE_OpAddInput(op, this->tfe_handle.get(), context::get_status());
231 status_check(context::get_status());
234 TFE_OpSetAttrType(op,
"out_type", cppflow::datatype::TF_INT64);
238 TFE_TensorHandle* res[1] = {
nullptr};
239 TFE_Execute(op, res, &n, context::get_status());
240 status_check(context::get_status());
249 res = TFE_TensorHandleBackingDeviceName(this->tfe_handle.get(), context::get_status());
251 res = std::string(TFE_TensorHandleDeviceName(this->tfe_handle.get(), context::get_status()));
253 status_check(context::get_status());
257 template <
typename T>
260 if (this->
dtype() != deduce_tf_type<T>()) {
263 auto error =
"Datatype in function get_data (" + type1 +
") does not match tensor datatype (" + type2 +
")";
264 throw std::runtime_error(error);
267 auto res_tensor = get_tensor();
270 auto raw_data = TF_TensorData(res_tensor.get());
273 size_t size = (TF_TensorByteSize(res_tensor.get()) / TF_DataTypeSize(TF_TensorType(res_tensor.get())));
276 const auto T_data =
static_cast<T*
>(raw_data);
277 std::vector<T> r(T_data, T_data + size);
282 inline datatype
tensor::dtype()
const {
return TFE_TensorHandleDataType(this->tfe_handle.get()); }
287 inline std::shared_ptr<TF_Tensor> tensor::get_tensor()
const {
289 tf_tensor = {TFE_TensorHandleResolve(tfe_handle.get(), context::get_status()), TF_DeleteTensor};
290 status_check(context::get_status());
A TensorFlow eager tensor wrapper.
Definition: tensor.h:61
tensor shape() const
Definition: tensor.h:226
datatype dtype() const
Definition: tensor.h:282
std::string device(bool on_memory=false) const
Definition: tensor.h:246
std::vector< T > get_data() const
Definition: tensor.h:258
TF_DataType deduce_tf_type()
Definition: datatype.h:107
std::string to_string(datatype dt)
Definition: datatype.h:48