31 #ifndef INCLUDE_CPPFLOW_DATATYPE_H_
32 #define INCLUDE_CPPFLOW_DATATYPE_H_
38 #include <type_traits>
// Short alias for TF_DataType; the declarations in this header that take a
// `datatype` argument (to_string, operator<<) accept a TF_DataType value.
43 using datatype = TF_DataType;
65 return "TF_COMPLEX64";
85 return "TF_COMPLEX128";
97 return "DATATYPE_NOT_KNOWN";
106 template <
typename T>
108 if (std::is_same<T, float>::value)
return TF_FLOAT;
109 if (std::is_same<T, double>::value)
return TF_DOUBLE;
110 if (std::is_same<T, int32_t>::value)
return TF_INT32;
111 if (std::is_same<T, uint8_t>::value)
return TF_UINT8;
112 if (std::is_same<T, int16_t>::value)
return TF_INT16;
113 if (std::is_same<T, int8_t>::value)
return TF_INT8;
114 if (std::is_same<T, int64_t>::value)
return TF_INT64;
115 if (std::is_same<T, unsigned char>::value)
return TF_BOOL;
116 if (std::is_same<T, uint16_t>::value)
return TF_UINT16;
117 if (std::is_same<T, uint32_t>::value)
return TF_UINT32;
118 if (std::is_same<T, uint64_t>::value)
return TF_UINT64;
121 throw std::runtime_error{
"Could not deduce type! type_name: " + std::string(
typeid(T).name())};
127 inline std::ostream&
operator<<(std::ostream& os, datatype dt) {
std::ostream & operator<<(std::ostream &os, datatype dt)
Definition: datatype.h:127
TF_DataType deduce_tf_type()
Definition: datatype.h:107
std::string to_string(datatype dt)
Definition: datatype.h:48