/home/docs/checkouts/readthedocs.org/user_builds/ratpac/checkouts/latest/src/external/cppflow/include/cppflow/context.h Source File

Ratpac-two: /home/docs/checkouts/readthedocs.org/user_builds/ratpac/checkouts/latest/src/external/cppflow/include/cppflow/context.h Source File
Ratpac-two
context.h
Go to the documentation of this file.
1 // MIT License
2 //
3 // Copyright (c) 2020 Sergio Izquierdo
4 // Copyright (c) 2020 Jiannan Liu
5 //
6 // Permission is hereby granted, free of charge, to any person obtaining a copy
7 // of this software and associated documentation files (the "Software"), to deal
8 // in the Software without restriction, including without limitation the rights
9 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 // copies of the Software, and to permit persons to whom the Software is
11 // furnished to do so, subject to the following conditions:
12 //
13 // The above copyright notice and this permission notice shall be included in
14 // all copies or substantial portions of the Software.
15 //
16 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 // SOFTWARE.
23 
31 #ifndef INCLUDE_CPPFLOW_CONTEXT_H_
32 #define INCLUDE_CPPFLOW_CONTEXT_H_
33 
34 // C headers
35 #include <tensorflow/c/c_api.h>
36 #include <tensorflow/c/eager/c_api.h>
37 
38 // C++ headers
39 #include <memory>
40 #include <stdexcept>
41 #include <utility>
42 
43 namespace cppflow {
44 
45 inline bool status_check(TF_Status* status) {
46  if (TF_GetCode(status) != TF_OK) {
47  throw std::runtime_error(TF_Message(status));
48  }
49  return true;
50 }
51 
52 class context {
53  public:
54  explicit context(TFE_ContextOptions* opts = nullptr);
55 
56  context(context const&) = delete;
57  context(context&&) noexcept;
58 
59  ~context();
60 
61  context& operator=(context const&) = delete;
62  context& operator=(context&&) noexcept;
63 
64  static TFE_Context* get_context();
65 
66  // only use get_status() for eager ops
67  static TF_Status* get_status();
68 
69  private:
70  TFE_Context* tfe_context{nullptr};
71 }; // Class context
72 
73 // @todo create ContextManager class if needed
74 // Set new context, thread unsafe, must be called at the beginning.
75 // TFE_ContextOptions* tfe_opts = ...
76 // cppflow::get_global_context() = cppflow::context(tfe_opts);
77 inline context& get_global_context() {
78  static context global_context;
79  return global_context;
80 }
81 } // namespace cppflow
82 
83 namespace cppflow {
84 
85 inline TFE_Context* context::get_context() { return get_global_context().tfe_context; }
86 
87 inline TF_Status* context::get_status() {
88  thread_local std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> local_tf_status(TF_NewStatus(), &TF_DeleteStatus);
89  return local_tf_status.get();
90 }
91 
92 inline context::context(TFE_ContextOptions* opts) {
93  auto tf_status = context::get_status();
94  if (opts == nullptr) {
95  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> new_opts(TFE_NewContextOptions(),
96  &TFE_DeleteContextOptions);
97  this->tfe_context = TFE_NewContext(new_opts.get(), tf_status);
98  } else {
99  this->tfe_context = TFE_NewContext(opts, tf_status);
100  }
101  status_check(tf_status);
102 }
103 
104 inline context::context(context&& ctx) noexcept : tfe_context(std::exchange(ctx.tfe_context, nullptr)) {}
105 
106 inline context& context::operator=(context&& ctx) noexcept {
107  tfe_context = std::exchange(ctx.tfe_context, tfe_context);
108  return *this;
109 }
110 
111 inline context::~context() { TFE_DeleteContext(this->tfe_context); }
112 
113 } // namespace cppflow
114 
115 #endif // INCLUDE_CPPFLOW_CONTEXT_H_
Definition: context.h:52