39 #ifndef INCLUDE_CPPFLOW_MODEL_H_
40 #define INCLUDE_CPPFLOW_MODEL_H_
43 #include <tensorflow/c/c_api.h>
67 explicit model(
const std::string& filename,
const TYPE type = TYPE::SAVED_MODEL);
73 model& operator=(
const model& other) =
default;
75 std::vector<tensor> operator()(std::vector<std::tuple<std::string, tensor>> inputs, std::vector<std::string> outputs);
78 std::vector<std::string> get_operations()
const;
79 std::vector<int64_t> get_operation_shape(
const std::string& operation)
const;
82 TF_Buffer* readGraph(
const std::string& filename);
84 std::shared_ptr<TF_Status> status;
85 std::shared_ptr<TF_Graph> graph;
86 std::shared_ptr<TF_Session> session;
93 inline model::model(
const std::string& filename,
const TYPE type) {
94 this->status = {TF_NewStatus(), &TF_DeleteStatus};
95 this->graph = {TF_NewGraph(), TF_DeleteGraph};
98 std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions(),
99 TF_DeleteSessionOptions};
101 auto session_deleter = [
this](TF_Session* sess) {
102 TF_DeleteSession(sess, this->status.get());
103 status_check(this->status.get());
106 if (type == TYPE::SAVED_MODEL) {
107 std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString(
"", 0),
109 std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};
112 const char* tag =
"serve";
113 this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(), &tag,
114 tag_len, this->graph.get(), meta_graph.get(), this->status.get()),
116 }
else if (type == TYPE::FROZEN_GRAPH) {
117 this->session = {TF_NewSession(this->graph.get(), session_options.get(), this->status.get()), session_deleter};
118 status_check(this->status.get());
121 TF_Buffer* def = readGraph(filename);
122 if (def ==
nullptr) {
123 throw std::runtime_error(
"Failed to import graph def from file");
126 std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {
127 TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
128 TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), this->status.get());
129 TF_DeleteBuffer(def);
131 throw std::runtime_error(
"Model type unknown");
134 status_check(this->status.get());
137 inline std::vector<std::string> model::get_operations()
const {
138 std::vector<std::string> result;
143 while ((oper = TF_GraphNextOperation(this->graph.get(), &pos)) !=
nullptr) {
144 result.emplace_back(TF_OperationName(oper));
149 inline std::vector<int64_t> model::get_operation_shape(
const std::string& operation)
const {
152 out_op.oper = TF_GraphOperationByName(this->graph.get(), operation.c_str());
155 std::vector<int64_t> shape;
158 if (!out_op.oper)
throw std::runtime_error(
"No operation named \"" + operation +
"\" exists");
160 if (operation ==
"NoOp")
throw std::runtime_error(
"NoOp doesn't have a shape");
165 int n_dims = TF_GraphGetTensorNumDims(this->graph.get(), out_op, this->status.get());
170 auto* dims =
new int64_t[n_dims];
171 TF_GraphGetTensorShape(this->graph.get(), out_op, dims, n_dims, this->status.get());
174 status_check(this->status.get());
176 shape = std::vector<int64_t>(dims, dims + n_dims);
/// Splits an endpoint name of the form "op" or "op:index" into its two parts.
///
/// @param name Operation name, optionally followed by ':' and an output index.
/// @return (operation name, output index). The index is 0 when `name` contains no
///         ':'; only the first ':' acts as the separator.
/// @throws std::invalid_argument or std::out_of_range (from std::stoi) when the text
///         after the ':' does not start with a parsable int.
inline std::tuple<std::string, int> parse_name(const std::string& name) {
    const std::string::size_type colon = name.find(':');
    if (colon == std::string::npos) {
        return std::make_tuple(name, 0);
    }
    const std::string op_part = name.substr(0, colon);
    const int output_index = std::stoi(name.substr(colon + 1));
    return std::make_tuple(op_part, output_index);
}
190 inline std::vector<tensor> model::operator()(std::vector<std::tuple<std::string, tensor>> inputs,
191 std::vector<std::string> outputs) {
192 std::vector<TF_Output> inp_ops(inputs.size());
193 std::vector<TF_Tensor*> inp_val(inputs.size(),
nullptr);
195 for (decltype(inputs.size()) i = 0; i < inputs.size(); i++) {
197 const auto [op_name, op_idx] = parse_name(std::get<0>(inputs[i]));
198 inp_ops[i].oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str());
199 inp_ops[i].index = op_idx;
201 if (!inp_ops[i].oper)
throw std::runtime_error(
"No operation named \"" + op_name +
"\" exists");
204 inp_val[i] = std::get<1>(inputs[i]).get_tensor().get();
207 std::vector<TF_Output> out_ops(outputs.size());
208 auto out_val = std::make_unique<TF_Tensor*[]>(outputs.size());
209 for (decltype(outputs.size()) i = 0; i < outputs.size(); i++) {
210 const auto [op_name, op_idx] = parse_name(outputs[i]);
211 out_ops[i].oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str());
212 out_ops[i].index = op_idx;
214 if (!out_ops[i].oper)
throw std::runtime_error(
"No operation named \"" + op_name +
"\" exists");
217 TF_SessionRun(this->session.get(), NULL, inp_ops.data(), inp_val.data(),
218 static_cast<int>(inputs.size()), out_ops.data(), out_val.get(),
static_cast<int>(outputs.size()),
219 NULL, 0, NULL, this->status.get());
220 status_check(this->status.get());
222 std::vector<tensor> result;
223 result.reserve(outputs.size());
224 for (decltype(outputs.size()) i = 0; i < outputs.size(); i++) {
225 result.emplace_back(tensor(out_val[i]));
231 inline tensor model::operator()(
const tensor& input) {
232 return (*
this)({{
"serving_default_input_1", input}}, {
"StatefulPartitionedCall"})[0];
235 inline TF_Buffer* model::readGraph(
const std::string& filename) {
236 std::ifstream file(filename, std::ios::binary | std::ios::ate);
239 if (!file.is_open()) {
240 std::cerr <<
"Unable to open file: " << filename << std::endl;
245 auto size = file.tellg();
247 file.seekg(0, std::ios::beg);
250 auto data = std::make_unique<char[]>(size);
251 file.seekg(0, std::ios::beg);
252 file.read(data.get(), size);
256 std::cerr <<
"Unable to read the full file: " << filename << std::endl;
261 TF_Buffer* buffer = TF_NewBufferFromString(data.get(), size);
A TensorFlow eager tensor wrapper.
Definition: tensor.h:61