31 #ifndef INCLUDE_CPPFLOW_CONTEXT_H_
32 #define INCLUDE_CPPFLOW_CONTEXT_H_
35 #include <tensorflow/c/c_api.h>
36 #include <tensorflow/c/eager/c_api.h>
// Checks a TF_Status that a TensorFlow C API call has filled in.
// @param status status object written by the preceding TF_/TFE_ call.
// @throws std::runtime_error carrying TF_Message(status) when
//         TF_GetCode(status) is anything other than TF_OK.
// @return bool; NOTE(review): presumably true on success — TODO confirm.
45 inline bool status_check(TF_Status* status) {
46 if (TF_GetCode(status) != TF_OK) {
47 throw std::runtime_error(TF_Message(status));
// Creates a TFE_Context from `opts`; when `opts` is nullptr, a temporary
// default TFE_ContextOptions is created and used instead.
// Throws std::runtime_error (via status_check) if TensorFlow reports an error.
54 explicit context(TFE_ContextOptions* opts =
nullptr);
// Returns the TFE_Context held by the context object returned from
// get_global_context(). Non-owning: callers must not delete it.
64 static TFE_Context* get_context();
// Returns a TF_Status owned by the calling thread (a thread_local object
// freed with TF_DeleteStatus). Non-owning: callers must not delete it.
67 static TF_Status* get_status();
// Owned eager-context handle, released in the destructor.
// A move-constructed-from object holds nullptr. A move-assigned-from object
// instead receives the target's previous handle (the assignment swaps them).
70 TFE_Context* tfe_context{
nullptr};
// Returns a reference to the shared context object that context::get_context()
// reads its TFE_Context from.
// NOTE(review): `global_context` is declared on a line not shown here —
// presumably a function-local static, which would give thread-safe lazy
// initialization on first call; confirm.
77 inline context& get_global_context() {
79 return global_context;
85 inline TFE_Context* context::get_context() {
return get_global_context().tfe_context; }
// Returns this thread's reusable TF_Status, creating it the first time the
// calling thread reaches the declaration below.
// The status is owned by a thread_local unique_ptr whose deleter is
// TF_DeleteStatus. It is therefore freed at thread exit, is never shared
// between threads, and the returned raw pointer must not be deleted by callers.
87 inline TF_Status* context::get_status() {
88 thread_local std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> local_tf_status(TF_NewStatus(), &TF_DeleteStatus);
89 return local_tf_status.get();
// Builds the eager context and stores it in tfe_context.
// @param opts context options. If nullptr, default options are obtained from
//        TFE_NewContextOptions() and held in a unique_ptr with
//        TFE_DeleteContextOptions as deleter, so they are released when that
//        unique_ptr goes out of scope after TFE_NewContext has used them.
//        A non-null opts is passed straight to TFE_NewContext and is not
//        deleted by the statements shown here (NOTE(review): presumably the
//        caller keeps ownership — confirm).
// @throws std::runtime_error via status_check if TFE_NewContext reported an
//         error in this thread's shared TF_Status (context::get_status()).
92 inline context::context(TFE_ContextOptions* opts) {
93 auto tf_status = context::get_status();
94 if (opts ==
nullptr) {
95 std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> new_opts(TFE_NewContextOptions(),
96 &TFE_DeleteContextOptions);
97 this->tfe_context = TFE_NewContext(new_opts.get(), tf_status);
99 this->tfe_context = TFE_NewContext(opts, tf_status);
101 status_check(tf_status);
104 inline context::context(context&& ctx) noexcept : tfe_context(std::exchange(ctx.tfe_context,
nullptr)) {}
// Move assignment implemented as a swap of the raw handles.
// This object takes ctx's TFE_Context, and ctx receives the handle this
// object held before. The previous context is therefore released later by
// ctx's destructor rather than here.
// Self-move-assignment leaves the handle unchanged: exchanging a member with
// itself is a no-op.
106 inline context& context::operator=(context&& ctx) noexcept {
107 tfe_context = std::exchange(ctx.tfe_context, tfe_context);
111 inline context::~context() { TFE_DeleteContext(this->tfe_context); }