7 #ifndef CPPFLOW2_RAW_OPS_H
8 #define CPPFLOW2_RAW_OPS_H
10 #include <tensorflow/c/eager/c_api.h>
11 #include <tensorflow/c/tf_datatype.h>
12 #include <tensorflow/c/tf_tensor.h>
24 inline tensor abs(
const tensor& x) {
26 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Abs", context::get_status()),
28 status_check(context::get_status());
32 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
33 status_check(context::get_status());
38 int num_outputs_op = 1;
39 TFE_TensorHandle* res[1] = {
nullptr};
40 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
41 status_check(context::get_status());
42 return tensor(res[0]);
45 inline tensor accumulate_n_v2(
const std::vector<tensor>& inputs,
const std::vector<int64_t>& shape) {
47 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
48 TFE_NewOp(context::get_context(),
"AccumulateNV2", context::get_status()), &TFE_DeleteOp);
49 status_check(context::get_status());
53 std::vector<TFE_TensorHandle*> inputs_handles;
54 inputs_handles.reserve(inputs.size());
55 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
56 [](
const auto& t) { return t.tfe_handle.get(); });
57 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
58 status_check(context::get_status());
61 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
63 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
64 status_check(context::get_status());
67 int num_outputs_op = 1;
68 TFE_TensorHandle* res[1] = {
nullptr};
69 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
70 status_check(context::get_status());
71 return tensor(res[0]);
74 inline tensor accumulator_num_accumulated(
const tensor& handle) {
76 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
77 TFE_NewOp(context::get_context(),
"AccumulatorNumAccumulated", context::get_status()), &TFE_DeleteOp);
78 status_check(context::get_status());
82 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
83 status_check(context::get_status());
88 int num_outputs_op = 1;
89 TFE_TensorHandle* res[1] = {
nullptr};
90 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
91 status_check(context::get_status());
92 return tensor(res[0]);
95 inline tensor accumulator_take_gradient(
const tensor& handle,
const tensor& num_required, datatype dtype) {
97 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
98 TFE_NewOp(context::get_context(),
"AccumulatorTakeGradient", context::get_status()), &TFE_DeleteOp);
99 status_check(context::get_status());
103 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
104 status_check(context::get_status());
106 TFE_OpAddInput(op.get(), num_required.tfe_handle.get(), context::get_status());
107 status_check(context::get_status());
110 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
113 int num_outputs_op = 1;
114 TFE_TensorHandle* res[1] = {
nullptr};
115 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
116 status_check(context::get_status());
117 return tensor(res[0]);
120 inline tensor acos(
const tensor& x) {
122 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Acos", context::get_status()),
124 status_check(context::get_status());
128 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
129 status_check(context::get_status());
134 int num_outputs_op = 1;
135 TFE_TensorHandle* res[1] = {
nullptr};
136 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
137 status_check(context::get_status());
138 return tensor(res[0]);
141 inline tensor acosh(
const tensor& x) {
143 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Acosh", context::get_status()),
145 status_check(context::get_status());
149 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
150 status_check(context::get_status());
155 int num_outputs_op = 1;
156 TFE_TensorHandle* res[1] = {
nullptr};
157 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
158 status_check(context::get_status());
159 return tensor(res[0]);
162 inline tensor add(
const tensor& x,
const tensor& y) {
164 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Add", context::get_status()),
166 status_check(context::get_status());
170 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
171 status_check(context::get_status());
173 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
174 status_check(context::get_status());
179 int num_outputs_op = 1;
180 TFE_TensorHandle* res[1] = {
nullptr};
181 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
182 status_check(context::get_status());
183 return tensor(res[0]);
186 inline tensor add_many_sparse_to_tensors_map(
const tensor& sparse_indices,
const tensor& sparse_values,
187 const tensor& sparse_shape,
const std::string& container =
"",
188 const std::string& shared_name =
"") {
190 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
191 TFE_NewOp(context::get_context(),
"AddManySparseToTensorsMap", context::get_status()), &TFE_DeleteOp);
192 status_check(context::get_status());
196 TFE_OpAddInput(op.get(), sparse_indices.tfe_handle.get(), context::get_status());
197 status_check(context::get_status());
199 TFE_OpAddInput(op.get(), sparse_values.tfe_handle.get(), context::get_status());
200 status_check(context::get_status());
202 TFE_OpAddInput(op.get(), sparse_shape.tfe_handle.get(), context::get_status());
203 status_check(context::get_status());
206 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
207 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
210 int num_outputs_op = 1;
211 TFE_TensorHandle* res[1] = {
nullptr};
212 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
213 status_check(context::get_status());
214 return tensor(res[0]);
217 inline tensor add_n(
const std::vector<tensor>& inputs) {
219 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"AddN", context::get_status()),
221 status_check(context::get_status());
225 std::vector<TFE_TensorHandle*> inputs_handles;
226 inputs_handles.reserve(inputs.size());
227 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
228 [](
const auto& t) { return t.tfe_handle.get(); });
229 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
230 status_check(context::get_status());
233 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
236 int num_outputs_op = 1;
237 TFE_TensorHandle* res[1] = {
nullptr};
238 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
239 status_check(context::get_status());
240 return tensor(res[0]);
243 inline tensor add_sparse_to_tensors_map(
const tensor& sparse_indices,
const tensor& sparse_values,
244 const tensor& sparse_shape,
const std::string& container =
"",
245 const std::string& shared_name =
"") {
247 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
248 TFE_NewOp(context::get_context(),
"AddSparseToTensorsMap", context::get_status()), &TFE_DeleteOp);
249 status_check(context::get_status());
253 TFE_OpAddInput(op.get(), sparse_indices.tfe_handle.get(), context::get_status());
254 status_check(context::get_status());
256 TFE_OpAddInput(op.get(), sparse_values.tfe_handle.get(), context::get_status());
257 status_check(context::get_status());
259 TFE_OpAddInput(op.get(), sparse_shape.tfe_handle.get(), context::get_status());
260 status_check(context::get_status());
263 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
264 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
267 int num_outputs_op = 1;
268 TFE_TensorHandle* res[1] = {
nullptr};
269 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
270 status_check(context::get_status());
271 return tensor(res[0]);
274 inline tensor add_v2(
const tensor& x,
const tensor& y) {
276 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"AddV2", context::get_status()),
278 status_check(context::get_status());
282 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
283 status_check(context::get_status());
285 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
286 status_check(context::get_status());
291 int num_outputs_op = 1;
292 TFE_TensorHandle* res[1] = {
nullptr};
293 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
294 status_check(context::get_status());
295 return tensor(res[0]);
298 inline tensor adjust_contrast(
const tensor& images,
const tensor& contrast_factor,
const tensor& min_value,
299 const tensor& max_value) {
301 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
302 TFE_NewOp(context::get_context(),
"AdjustContrast", context::get_status()), &TFE_DeleteOp);
303 status_check(context::get_status());
307 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
308 status_check(context::get_status());
310 TFE_OpAddInput(op.get(), contrast_factor.tfe_handle.get(), context::get_status());
311 status_check(context::get_status());
313 TFE_OpAddInput(op.get(), min_value.tfe_handle.get(), context::get_status());
314 status_check(context::get_status());
316 TFE_OpAddInput(op.get(), max_value.tfe_handle.get(), context::get_status());
317 status_check(context::get_status());
322 int num_outputs_op = 1;
323 TFE_TensorHandle* res[1] = {
nullptr};
324 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
325 status_check(context::get_status());
326 return tensor(res[0]);
329 inline tensor adjust_contrastv2(
const tensor& images,
const tensor& contrast_factor) {
331 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
332 TFE_NewOp(context::get_context(),
"AdjustContrastv2", context::get_status()), &TFE_DeleteOp);
333 status_check(context::get_status());
337 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
338 status_check(context::get_status());
340 TFE_OpAddInput(op.get(), contrast_factor.tfe_handle.get(), context::get_status());
341 status_check(context::get_status());
346 int num_outputs_op = 1;
347 TFE_TensorHandle* res[1] = {
nullptr};
348 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
349 status_check(context::get_status());
350 return tensor(res[0]);
353 inline tensor adjust_hue(
const tensor& images,
const tensor& delta) {
355 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
356 TFE_NewOp(context::get_context(),
"AdjustHue", context::get_status()), &TFE_DeleteOp);
357 status_check(context::get_status());
361 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
362 status_check(context::get_status());
364 TFE_OpAddInput(op.get(), delta.tfe_handle.get(), context::get_status());
365 status_check(context::get_status());
370 int num_outputs_op = 1;
371 TFE_TensorHandle* res[1] = {
nullptr};
372 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
373 status_check(context::get_status());
374 return tensor(res[0]);
377 inline tensor adjust_saturation(
const tensor& images,
const tensor& scale) {
379 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
380 TFE_NewOp(context::get_context(),
"AdjustSaturation", context::get_status()), &TFE_DeleteOp);
381 status_check(context::get_status());
385 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
386 status_check(context::get_status());
388 TFE_OpAddInput(op.get(), scale.tfe_handle.get(), context::get_status());
389 status_check(context::get_status());
394 int num_outputs_op = 1;
395 TFE_TensorHandle* res[1] = {
nullptr};
396 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
397 status_check(context::get_status());
398 return tensor(res[0]);
401 inline tensor all(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
402 datatype Tidx =
static_cast<datatype
>(3)) {
404 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"All", context::get_status()),
406 status_check(context::get_status());
410 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
411 status_check(context::get_status());
413 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
414 status_check(context::get_status());
417 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
418 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
421 int num_outputs_op = 1;
422 TFE_TensorHandle* res[1] = {
nullptr};
423 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
424 status_check(context::get_status());
425 return tensor(res[0]);
428 inline tensor all_to_all(
const tensor& input,
const tensor& group_assignment, int64_t concat_dimension,
429 int64_t split_dimension, int64_t split_count) {
431 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
432 TFE_NewOp(context::get_context(),
"AllToAll", context::get_status()), &TFE_DeleteOp);
433 status_check(context::get_status());
437 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
438 status_check(context::get_status());
440 TFE_OpAddInput(op.get(), group_assignment.tfe_handle.get(), context::get_status());
441 status_check(context::get_status());
444 TFE_OpSetAttrInt(op.get(),
"concat_dimension", concat_dimension);
445 TFE_OpSetAttrInt(op.get(),
"split_dimension", split_dimension);
446 TFE_OpSetAttrInt(op.get(),
"split_count", split_count);
449 int num_outputs_op = 1;
450 TFE_TensorHandle* res[1] = {
nullptr};
451 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
452 status_check(context::get_status());
453 return tensor(res[0]);
456 inline tensor angle(
const tensor& input, datatype Tout =
static_cast<datatype
>(1)) {
458 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Angle", context::get_status()),
460 status_check(context::get_status());
464 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
465 status_check(context::get_status());
468 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
471 int num_outputs_op = 1;
472 TFE_TensorHandle* res[1] = {
nullptr};
473 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
474 status_check(context::get_status());
475 return tensor(res[0]);
478 inline tensor anonymous_iterator(
const std::vector<datatype>& output_types,
479 const std::vector<std::vector<int64_t>>& output_shapes) {
481 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
482 TFE_NewOp(context::get_context(),
"AnonymousIterator", context::get_status()), &TFE_DeleteOp);
483 status_check(context::get_status());
488 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
489 static_cast<int>(output_types.size()));
491 std::vector<const int64_t*> output_shapes_values;
492 output_shapes_values.reserve(output_shapes.size());
493 std::vector<int> output_shapes_ndims;
494 output_shapes_ndims.reserve(output_shapes.size());
495 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
496 [](
const auto& v) { return v.data(); });
497 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
498 [](
const auto& v) { return static_cast<int>(v.size()); });
499 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
500 static_cast<int>(output_shapes.size()), context::get_status());
501 status_check(context::get_status());
504 int num_outputs_op = 1;
505 TFE_TensorHandle* res[1] = {
nullptr};
506 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
507 status_check(context::get_status());
508 return tensor(res[0]);
511 inline tensor any(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
512 datatype Tidx =
static_cast<datatype
>(3)) {
514 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Any", context::get_status()),
516 status_check(context::get_status());
520 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
521 status_check(context::get_status());
523 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
524 status_check(context::get_status());
527 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
528 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
531 int num_outputs_op = 1;
532 TFE_TensorHandle* res[1] = {
nullptr};
533 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
534 status_check(context::get_status());
535 return tensor(res[0]);
538 inline tensor apply_ada_max(
const tensor& var,
const tensor& m,
const tensor& v,
const tensor& beta1_power,
539 const tensor& lr,
const tensor& beta1,
const tensor& beta2,
const tensor& epsilon,
540 const tensor& grad,
bool use_locking =
false) {
542 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
543 TFE_NewOp(context::get_context(),
"ApplyAdaMax", context::get_status()), &TFE_DeleteOp);
544 status_check(context::get_status());
548 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
549 status_check(context::get_status());
551 TFE_OpAddInput(op.get(), m.tfe_handle.get(), context::get_status());
552 status_check(context::get_status());
554 TFE_OpAddInput(op.get(), v.tfe_handle.get(), context::get_status());
555 status_check(context::get_status());
557 TFE_OpAddInput(op.get(), beta1_power.tfe_handle.get(), context::get_status());
558 status_check(context::get_status());
560 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
561 status_check(context::get_status());
563 TFE_OpAddInput(op.get(), beta1.tfe_handle.get(), context::get_status());
564 status_check(context::get_status());
566 TFE_OpAddInput(op.get(), beta2.tfe_handle.get(), context::get_status());
567 status_check(context::get_status());
569 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
570 status_check(context::get_status());
572 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
573 status_check(context::get_status());
576 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
579 int num_outputs_op = 1;
580 TFE_TensorHandle* res[1] = {
nullptr};
581 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
582 status_check(context::get_status());
583 return tensor(res[0]);
586 inline tensor apply_adadelta(
const tensor& var,
const tensor& accum,
const tensor& accum_update,
const tensor& lr,
587 const tensor& rho,
const tensor& epsilon,
const tensor& grad,
bool use_locking =
false) {
589 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
590 TFE_NewOp(context::get_context(),
"ApplyAdadelta", context::get_status()), &TFE_DeleteOp);
591 status_check(context::get_status());
595 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
596 status_check(context::get_status());
598 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
599 status_check(context::get_status());
601 TFE_OpAddInput(op.get(), accum_update.tfe_handle.get(), context::get_status());
602 status_check(context::get_status());
604 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
605 status_check(context::get_status());
607 TFE_OpAddInput(op.get(), rho.tfe_handle.get(), context::get_status());
608 status_check(context::get_status());
610 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
611 status_check(context::get_status());
613 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
614 status_check(context::get_status());
617 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
620 int num_outputs_op = 1;
621 TFE_TensorHandle* res[1] = {
nullptr};
622 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
623 status_check(context::get_status());
624 return tensor(res[0]);
627 inline tensor apply_adagrad(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& grad,
628 bool use_locking =
false,
bool update_slots =
true) {
630 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
631 TFE_NewOp(context::get_context(),
"ApplyAdagrad", context::get_status()), &TFE_DeleteOp);
632 status_check(context::get_status());
636 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
637 status_check(context::get_status());
639 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
640 status_check(context::get_status());
642 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
643 status_check(context::get_status());
645 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
646 status_check(context::get_status());
649 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
650 TFE_OpSetAttrBool(op.get(),
"update_slots", (
unsigned char)update_slots);
653 int num_outputs_op = 1;
654 TFE_TensorHandle* res[1] = {
nullptr};
655 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
656 status_check(context::get_status());
657 return tensor(res[0]);
660 inline tensor apply_adagrad_d_a(
const tensor& var,
const tensor& gradient_accumulator,
661 const tensor& gradient_squared_accumulator,
const tensor& grad,
const tensor& lr,
662 const tensor& l1,
const tensor& l2,
const tensor& global_step,
663 bool use_locking =
false) {
665 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
666 TFE_NewOp(context::get_context(),
"ApplyAdagradDA", context::get_status()), &TFE_DeleteOp);
667 status_check(context::get_status());
671 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
672 status_check(context::get_status());
674 TFE_OpAddInput(op.get(), gradient_accumulator.tfe_handle.get(), context::get_status());
675 status_check(context::get_status());
677 TFE_OpAddInput(op.get(), gradient_squared_accumulator.tfe_handle.get(), context::get_status());
678 status_check(context::get_status());
680 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
681 status_check(context::get_status());
683 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
684 status_check(context::get_status());
686 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
687 status_check(context::get_status());
689 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
690 status_check(context::get_status());
692 TFE_OpAddInput(op.get(), global_step.tfe_handle.get(), context::get_status());
693 status_check(context::get_status());
696 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
699 int num_outputs_op = 1;
700 TFE_TensorHandle* res[1] = {
nullptr};
701 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
702 status_check(context::get_status());
703 return tensor(res[0]);
706 inline tensor apply_adagrad_v2(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& epsilon,
707 const tensor& grad,
bool use_locking =
false,
bool update_slots =
true) {
709 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
710 TFE_NewOp(context::get_context(),
"ApplyAdagradV2", context::get_status()), &TFE_DeleteOp);
711 status_check(context::get_status());
715 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
716 status_check(context::get_status());
718 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
719 status_check(context::get_status());
721 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
722 status_check(context::get_status());
724 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
725 status_check(context::get_status());
727 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
728 status_check(context::get_status());
731 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
732 TFE_OpSetAttrBool(op.get(),
"update_slots", (
unsigned char)update_slots);
735 int num_outputs_op = 1;
736 TFE_TensorHandle* res[1] = {
nullptr};
737 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
738 status_check(context::get_status());
739 return tensor(res[0]);
742 inline tensor apply_adam(
const tensor& var,
const tensor& m,
const tensor& v,
const tensor& beta1_power,
743 const tensor& beta2_power,
const tensor& lr,
const tensor& beta1,
const tensor& beta2,
744 const tensor& epsilon,
const tensor& grad,
bool use_locking =
false,
745 bool use_nesterov =
false) {
747 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
748 TFE_NewOp(context::get_context(),
"ApplyAdam", context::get_status()), &TFE_DeleteOp);
749 status_check(context::get_status());
753 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
754 status_check(context::get_status());
756 TFE_OpAddInput(op.get(), m.tfe_handle.get(), context::get_status());
757 status_check(context::get_status());
759 TFE_OpAddInput(op.get(), v.tfe_handle.get(), context::get_status());
760 status_check(context::get_status());
762 TFE_OpAddInput(op.get(), beta1_power.tfe_handle.get(), context::get_status());
763 status_check(context::get_status());
765 TFE_OpAddInput(op.get(), beta2_power.tfe_handle.get(), context::get_status());
766 status_check(context::get_status());
768 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
769 status_check(context::get_status());
771 TFE_OpAddInput(op.get(), beta1.tfe_handle.get(), context::get_status());
772 status_check(context::get_status());
774 TFE_OpAddInput(op.get(), beta2.tfe_handle.get(), context::get_status());
775 status_check(context::get_status());
777 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
778 status_check(context::get_status());
780 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
781 status_check(context::get_status());
784 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
785 TFE_OpSetAttrBool(op.get(),
"use_nesterov", (
unsigned char)use_nesterov);
788 int num_outputs_op = 1;
789 TFE_TensorHandle* res[1] = {
nullptr};
790 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
791 status_check(context::get_status());
792 return tensor(res[0]);
795 inline tensor apply_add_sign(
const tensor& var,
const tensor& m,
const tensor& lr,
const tensor& alpha,
796 const tensor& sign_decay,
const tensor& beta,
const tensor& grad,
797 bool use_locking =
false) {
799 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
800 TFE_NewOp(context::get_context(),
"ApplyAddSign", context::get_status()), &TFE_DeleteOp);
801 status_check(context::get_status());
805 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
806 status_check(context::get_status());
808 TFE_OpAddInput(op.get(), m.tfe_handle.get(), context::get_status());
809 status_check(context::get_status());
811 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
812 status_check(context::get_status());
814 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
815 status_check(context::get_status());
817 TFE_OpAddInput(op.get(), sign_decay.tfe_handle.get(), context::get_status());
818 status_check(context::get_status());
820 TFE_OpAddInput(op.get(), beta.tfe_handle.get(), context::get_status());
821 status_check(context::get_status());
823 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
824 status_check(context::get_status());
827 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
830 int num_outputs_op = 1;
831 TFE_TensorHandle* res[1] = {
nullptr};
832 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
833 status_check(context::get_status());
834 return tensor(res[0]);
837 inline tensor apply_centered_r_m_s_prop(
const tensor& var,
const tensor& mg,
const tensor& ms,
const tensor& mom,
838 const tensor& lr,
const tensor& rho,
const tensor& momentum,
839 const tensor& epsilon,
const tensor& grad,
bool use_locking =
false) {
841 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
842 TFE_NewOp(context::get_context(),
"ApplyCenteredRMSProp", context::get_status()), &TFE_DeleteOp);
843 status_check(context::get_status());
847 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
848 status_check(context::get_status());
850 TFE_OpAddInput(op.get(), mg.tfe_handle.get(), context::get_status());
851 status_check(context::get_status());
853 TFE_OpAddInput(op.get(), ms.tfe_handle.get(), context::get_status());
854 status_check(context::get_status());
856 TFE_OpAddInput(op.get(), mom.tfe_handle.get(), context::get_status());
857 status_check(context::get_status());
859 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
860 status_check(context::get_status());
862 TFE_OpAddInput(op.get(), rho.tfe_handle.get(), context::get_status());
863 status_check(context::get_status());
865 TFE_OpAddInput(op.get(), momentum.tfe_handle.get(), context::get_status());
866 status_check(context::get_status());
868 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
869 status_check(context::get_status());
871 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
872 status_check(context::get_status());
875 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
878 int num_outputs_op = 1;
879 TFE_TensorHandle* res[1] = {
nullptr};
880 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
881 status_check(context::get_status());
882 return tensor(res[0]);
885 inline tensor apply_ftrl(
const tensor& var,
const tensor& accum,
const tensor& linear,
const tensor& grad,
886 const tensor& lr,
const tensor& l1,
const tensor& l2,
const tensor& lr_power,
887 bool use_locking =
false,
bool multiply_linear_by_lr =
false) {
889 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
890 TFE_NewOp(context::get_context(),
"ApplyFtrl", context::get_status()), &TFE_DeleteOp);
891 status_check(context::get_status());
895 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
896 status_check(context::get_status());
898 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
899 status_check(context::get_status());
901 TFE_OpAddInput(op.get(), linear.tfe_handle.get(), context::get_status());
902 status_check(context::get_status());
904 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
905 status_check(context::get_status());
907 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
908 status_check(context::get_status());
910 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
911 status_check(context::get_status());
913 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
914 status_check(context::get_status());
916 TFE_OpAddInput(op.get(), lr_power.tfe_handle.get(), context::get_status());
917 status_check(context::get_status());
920 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
921 TFE_OpSetAttrBool(op.get(),
"multiply_linear_by_lr", (
unsigned char)multiply_linear_by_lr);
924 int num_outputs_op = 1;
925 TFE_TensorHandle* res[1] = {
nullptr};
926 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
927 status_check(context::get_status());
928 return tensor(res[0]);
931 inline tensor apply_ftrl_v2(
const tensor& var,
const tensor& accum,
const tensor& linear,
const tensor& grad,
932 const tensor& lr,
const tensor& l1,
const tensor& l2,
const tensor& l2_shrinkage,
933 const tensor& lr_power,
bool use_locking =
false,
bool multiply_linear_by_lr =
false) {
935 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
936 TFE_NewOp(context::get_context(),
"ApplyFtrlV2", context::get_status()), &TFE_DeleteOp);
937 status_check(context::get_status());
941 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
942 status_check(context::get_status());
944 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
945 status_check(context::get_status());
947 TFE_OpAddInput(op.get(), linear.tfe_handle.get(), context::get_status());
948 status_check(context::get_status());
950 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
951 status_check(context::get_status());
953 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
954 status_check(context::get_status());
956 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
957 status_check(context::get_status());
959 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
960 status_check(context::get_status());
962 TFE_OpAddInput(op.get(), l2_shrinkage.tfe_handle.get(), context::get_status());
963 status_check(context::get_status());
965 TFE_OpAddInput(op.get(), lr_power.tfe_handle.get(), context::get_status());
966 status_check(context::get_status());
969 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
970 TFE_OpSetAttrBool(op.get(),
"multiply_linear_by_lr", (
unsigned char)multiply_linear_by_lr);
973 int num_outputs_op = 1;
974 TFE_TensorHandle* res[1] = {
nullptr};
975 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
976 status_check(context::get_status());
977 return tensor(res[0]);
980 inline tensor apply_gradient_descent(
const tensor& var,
const tensor& alpha,
const tensor& delta,
981 bool use_locking =
false) {
983 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
984 TFE_NewOp(context::get_context(),
"ApplyGradientDescent", context::get_status()), &TFE_DeleteOp);
985 status_check(context::get_status());
989 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
990 status_check(context::get_status());
992 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
993 status_check(context::get_status());
995 TFE_OpAddInput(op.get(), delta.tfe_handle.get(), context::get_status());
996 status_check(context::get_status());
999 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1002 int num_outputs_op = 1;
1003 TFE_TensorHandle* res[1] = {
nullptr};
1004 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1005 status_check(context::get_status());
1006 return tensor(res[0]);
1009 inline tensor apply_momentum(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& grad,
1010 const tensor& momentum,
bool use_locking =
false,
bool use_nesterov =
false) {
1012 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1013 TFE_NewOp(context::get_context(),
"ApplyMomentum", context::get_status()), &TFE_DeleteOp);
1014 status_check(context::get_status());
1018 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
1019 status_check(context::get_status());
1021 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
1022 status_check(context::get_status());
1024 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
1025 status_check(context::get_status());
1027 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1028 status_check(context::get_status());
1030 TFE_OpAddInput(op.get(), momentum.tfe_handle.get(), context::get_status());
1031 status_check(context::get_status());
1034 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1035 TFE_OpSetAttrBool(op.get(),
"use_nesterov", (
unsigned char)use_nesterov);
1038 int num_outputs_op = 1;
1039 TFE_TensorHandle* res[1] = {
nullptr};
1040 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1041 status_check(context::get_status());
1042 return tensor(res[0]);
1045 inline tensor apply_power_sign(
const tensor& var,
const tensor& m,
const tensor& lr,
const tensor& logbase,
1046 const tensor& sign_decay,
const tensor& beta,
const tensor& grad,
1047 bool use_locking =
false) {
1049 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1050 TFE_NewOp(context::get_context(),
"ApplyPowerSign", context::get_status()), &TFE_DeleteOp);
1051 status_check(context::get_status());
1055 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
1056 status_check(context::get_status());
1058 TFE_OpAddInput(op.get(), m.tfe_handle.get(), context::get_status());
1059 status_check(context::get_status());
1061 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
1062 status_check(context::get_status());
1064 TFE_OpAddInput(op.get(), logbase.tfe_handle.get(), context::get_status());
1065 status_check(context::get_status());
1067 TFE_OpAddInput(op.get(), sign_decay.tfe_handle.get(), context::get_status());
1068 status_check(context::get_status());
1070 TFE_OpAddInput(op.get(), beta.tfe_handle.get(), context::get_status());
1071 status_check(context::get_status());
1073 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1074 status_check(context::get_status());
1077 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1080 int num_outputs_op = 1;
1081 TFE_TensorHandle* res[1] = {
nullptr};
1082 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1083 status_check(context::get_status());
1084 return tensor(res[0]);
1087 inline tensor apply_proximal_adagrad(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& l1,
1088 const tensor& l2,
const tensor& grad,
bool use_locking =
false) {
1090 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1091 TFE_NewOp(context::get_context(),
"ApplyProximalAdagrad", context::get_status()), &TFE_DeleteOp);
1092 status_check(context::get_status());
1096 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
1097 status_check(context::get_status());
1099 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
1100 status_check(context::get_status());
1102 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
1103 status_check(context::get_status());
1105 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
1106 status_check(context::get_status());
1108 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
1109 status_check(context::get_status());
1111 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1112 status_check(context::get_status());
1115 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1118 int num_outputs_op = 1;
1119 TFE_TensorHandle* res[1] = {
nullptr};
1120 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1121 status_check(context::get_status());
1122 return tensor(res[0]);
1125 inline tensor apply_proximal_gradient_descent(
const tensor& var,
const tensor& alpha,
const tensor& l1,
1126 const tensor& l2,
const tensor& delta,
bool use_locking =
false) {
1128 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1129 TFE_NewOp(context::get_context(),
"ApplyProximalGradientDescent", context::get_status()), &TFE_DeleteOp);
1130 status_check(context::get_status());
1134 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
1135 status_check(context::get_status());
1137 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
1138 status_check(context::get_status());
1140 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
1141 status_check(context::get_status());
1143 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
1144 status_check(context::get_status());
1146 TFE_OpAddInput(op.get(), delta.tfe_handle.get(), context::get_status());
1147 status_check(context::get_status());
1150 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1153 int num_outputs_op = 1;
1154 TFE_TensorHandle* res[1] = {
nullptr};
1155 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1156 status_check(context::get_status());
1157 return tensor(res[0]);
1160 inline tensor apply_r_m_s_prop(
const tensor& var,
const tensor& ms,
const tensor& mom,
const tensor& lr,
1161 const tensor& rho,
const tensor& momentum,
const tensor& epsilon,
const tensor& grad,
1162 bool use_locking =
false) {
1164 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1165 TFE_NewOp(context::get_context(),
"ApplyRMSProp", context::get_status()), &TFE_DeleteOp);
1166 status_check(context::get_status());
1170 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
1171 status_check(context::get_status());
1173 TFE_OpAddInput(op.get(), ms.tfe_handle.get(), context::get_status());
1174 status_check(context::get_status());
1176 TFE_OpAddInput(op.get(), mom.tfe_handle.get(), context::get_status());
1177 status_check(context::get_status());
1179 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
1180 status_check(context::get_status());
1182 TFE_OpAddInput(op.get(), rho.tfe_handle.get(), context::get_status());
1183 status_check(context::get_status());
1185 TFE_OpAddInput(op.get(), momentum.tfe_handle.get(), context::get_status());
1186 status_check(context::get_status());
1188 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
1189 status_check(context::get_status());
1191 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1192 status_check(context::get_status());
1195 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1198 int num_outputs_op = 1;
1199 TFE_TensorHandle* res[1] = {
nullptr};
1200 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1201 status_check(context::get_status());
1202 return tensor(res[0]);
1205 inline tensor approximate_equal(
const tensor& x,
const tensor& y,
float tolerance = 1.0000e-05) {
1207 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1208 TFE_NewOp(context::get_context(),
"ApproximateEqual", context::get_status()), &TFE_DeleteOp);
1209 status_check(context::get_status());
1213 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
1214 status_check(context::get_status());
1216 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
1217 status_check(context::get_status());
1220 TFE_OpSetAttrFloat(op.get(),
"tolerance", tolerance);
1223 int num_outputs_op = 1;
1224 TFE_TensorHandle* res[1] = {
nullptr};
1225 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1226 status_check(context::get_status());
1227 return tensor(res[0]);
1230 inline tensor arg_max(
const tensor& input,
const tensor& dimension, datatype Tidx =
static_cast<datatype
>(3),
1231 datatype output_type =
static_cast<datatype
>(9)) {
1233 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1234 TFE_NewOp(context::get_context(),
"ArgMax", context::get_status()), &TFE_DeleteOp);
1235 status_check(context::get_status());
1239 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
1240 status_check(context::get_status());
1242 TFE_OpAddInput(op.get(), dimension.tfe_handle.get(), context::get_status());
1243 status_check(context::get_status());
1246 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
1247 TFE_OpSetAttrType(op.get(),
"output_type", output_type);
1250 int num_outputs_op = 1;
1251 TFE_TensorHandle* res[1] = {
nullptr};
1252 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1253 status_check(context::get_status());
1254 return tensor(res[0]);
1257 inline tensor arg_min(
const tensor& input,
const tensor& dimension, datatype Tidx =
static_cast<datatype
>(3),
1258 datatype output_type =
static_cast<datatype
>(9)) {
1260 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1261 TFE_NewOp(context::get_context(),
"ArgMin", context::get_status()), &TFE_DeleteOp);
1262 status_check(context::get_status());
1266 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
1267 status_check(context::get_status());
1269 TFE_OpAddInput(op.get(), dimension.tfe_handle.get(), context::get_status());
1270 status_check(context::get_status());
1273 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
1274 TFE_OpSetAttrType(op.get(),
"output_type", output_type);
1277 int num_outputs_op = 1;
1278 TFE_TensorHandle* res[1] = {
nullptr};
1279 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1280 status_check(context::get_status());
1281 return tensor(res[0]);
1284 inline tensor as_string(
const tensor& input, int64_t precision = -1,
bool scientific =
false,
bool shortest =
false,
1285 int64_t width = -1,
const std::string& fill =
"") {
1287 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1288 TFE_NewOp(context::get_context(),
"AsString", context::get_status()), &TFE_DeleteOp);
1289 status_check(context::get_status());
1293 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
1294 status_check(context::get_status());
1297 TFE_OpSetAttrInt(op.get(),
"precision", precision);
1298 TFE_OpSetAttrBool(op.get(),
"scientific", (
unsigned char)scientific);
1299 TFE_OpSetAttrBool(op.get(),
"shortest", (
unsigned char)shortest);
1300 TFE_OpSetAttrInt(op.get(),
"width", width);
1301 TFE_OpSetAttrString(op.get(),
"fill", (
void*)fill.c_str(), fill.size());
1304 int num_outputs_op = 1;
1305 TFE_TensorHandle* res[1] = {
nullptr};
1306 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1307 status_check(context::get_status());
1308 return tensor(res[0]);
1311 inline tensor asin(
const tensor& x) {
1313 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Asin", context::get_status()),
1315 status_check(context::get_status());
1319 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
1320 status_check(context::get_status());
1325 int num_outputs_op = 1;
1326 TFE_TensorHandle* res[1] = {
nullptr};
1327 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1328 status_check(context::get_status());
1329 return tensor(res[0]);
1332 inline tensor asinh(
const tensor& x) {
1334 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Asinh", context::get_status()),
1336 status_check(context::get_status());
1340 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
1341 status_check(context::get_status());
1346 int num_outputs_op = 1;
1347 TFE_TensorHandle* res[1] = {
nullptr};
1348 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1349 status_check(context::get_status());
1350 return tensor(res[0]);
1353 inline tensor assert_cardinality_dataset(
const tensor& input_dataset,
const tensor& cardinality,
1354 const std::vector<datatype>& output_types,
1355 const std::vector<std::vector<int64_t>>& output_shapes) {
1357 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1358 TFE_NewOp(context::get_context(),
"AssertCardinalityDataset", context::get_status()), &TFE_DeleteOp);
1359 status_check(context::get_status());
1363 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
1364 status_check(context::get_status());
1366 TFE_OpAddInput(op.get(), cardinality.tfe_handle.get(), context::get_status());
1367 status_check(context::get_status());
1370 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
1371 static_cast<int>(output_types.size()));
1373 std::vector<const int64_t*> output_shapes_values;
1374 output_shapes_values.reserve(output_shapes.size());
1375 std::vector<int> output_shapes_ndims;
1376 output_shapes_ndims.reserve(output_shapes.size());
1377 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
1378 [](
const auto& v) { return v.data(); });
1379 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
1380 [](
const auto& v) { return static_cast<int>(v.size()); });
1381 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
1382 static_cast<int>(output_shapes.size()), context::get_status());
1383 status_check(context::get_status());
1386 int num_outputs_op = 1;
1387 TFE_TensorHandle* res[1] = {
nullptr};
1388 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1389 status_check(context::get_status());
1390 return tensor(res[0]);
1393 inline tensor assert_next_dataset(
const tensor& input_dataset,
const tensor& transformations,
1394 const std::vector<datatype>& output_types,
1395 const std::vector<std::vector<int64_t>>& output_shapes) {
1397 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1398 TFE_NewOp(context::get_context(),
"AssertNextDataset", context::get_status()), &TFE_DeleteOp);
1399 status_check(context::get_status());
1403 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
1404 status_check(context::get_status());
1406 TFE_OpAddInput(op.get(), transformations.tfe_handle.get(), context::get_status());
1407 status_check(context::get_status());
1410 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
1411 static_cast<int>(output_types.size()));
1413 std::vector<const int64_t*> output_shapes_values;
1414 output_shapes_values.reserve(output_shapes.size());
1415 std::vector<int> output_shapes_ndims;
1416 output_shapes_ndims.reserve(output_shapes.size());
1417 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
1418 [](
const auto& v) { return v.data(); });
1419 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
1420 [](
const auto& v) { return static_cast<int>(v.size()); });
1421 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
1422 static_cast<int>(output_shapes.size()), context::get_status());
1423 status_check(context::get_status());
1426 int num_outputs_op = 1;
1427 TFE_TensorHandle* res[1] = {
nullptr};
1428 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1429 status_check(context::get_status());
1430 return tensor(res[0]);
1433 inline tensor assign(
const tensor& ref,
const tensor& value,
bool validate_shape =
true,
bool use_locking =
true) {
1435 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1436 TFE_NewOp(context::get_context(),
"Assign", context::get_status()), &TFE_DeleteOp);
1437 status_check(context::get_status());
1441 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
1442 status_check(context::get_status());
1444 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
1445 status_check(context::get_status());
1448 TFE_OpSetAttrBool(op.get(),
"validate_shape", (
unsigned char)validate_shape);
1449 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1452 int num_outputs_op = 1;
1453 TFE_TensorHandle* res[1] = {
nullptr};
1454 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1455 status_check(context::get_status());
1456 return tensor(res[0]);
1459 inline tensor assign_add(
const tensor& ref,
const tensor& value,
bool use_locking =
false) {
1461 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1462 TFE_NewOp(context::get_context(),
"AssignAdd", context::get_status()), &TFE_DeleteOp);
1463 status_check(context::get_status());
1467 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
1468 status_check(context::get_status());
1470 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
1471 status_check(context::get_status());
1474 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1477 int num_outputs_op = 1;
1478 TFE_TensorHandle* res[1] = {
nullptr};
1479 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1480 status_check(context::get_status());
1481 return tensor(res[0]);
1484 inline tensor assign_sub(
const tensor& ref,
const tensor& value,
bool use_locking =
false) {
1486 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1487 TFE_NewOp(context::get_context(),
"AssignSub", context::get_status()), &TFE_DeleteOp);
1488 status_check(context::get_status());
1492 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
1493 status_check(context::get_status());
1495 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
1496 status_check(context::get_status());
1499 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
1502 int num_outputs_op = 1;
1503 TFE_TensorHandle* res[1] = {
nullptr};
1504 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1505 status_check(context::get_status());
1506 return tensor(res[0]);
1509 inline tensor atan(
const tensor& x) {
1511 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Atan", context::get_status()),
1513 status_check(context::get_status());
1517 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
1518 status_check(context::get_status());
1523 int num_outputs_op = 1;
1524 TFE_TensorHandle* res[1] = {
nullptr};
1525 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1526 status_check(context::get_status());
1527 return tensor(res[0]);
1530 inline tensor atan2(
const tensor& y,
const tensor& x) {
1532 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Atan2", context::get_status()),
1534 status_check(context::get_status());
1538 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
1539 status_check(context::get_status());
1541 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
1542 status_check(context::get_status());
1547 int num_outputs_op = 1;
1548 TFE_TensorHandle* res[1] = {
nullptr};
1549 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1550 status_check(context::get_status());
1551 return tensor(res[0]);
1554 inline tensor atanh(
const tensor& x) {
1556 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Atanh", context::get_status()),
1558 status_check(context::get_status());
1562 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
1563 status_check(context::get_status());
1568 int num_outputs_op = 1;
1569 TFE_TensorHandle* res[1] = {
nullptr};
1570 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1571 status_check(context::get_status());
1572 return tensor(res[0]);
1575 inline tensor audio_spectrogram(
const tensor& input, int64_t window_size, int64_t stride,
1576 bool magnitude_squared =
false) {
1578 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1579 TFE_NewOp(context::get_context(),
"AudioSpectrogram", context::get_status()), &TFE_DeleteOp);
1580 status_check(context::get_status());
1584 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
1585 status_check(context::get_status());
1588 TFE_OpSetAttrInt(op.get(),
"window_size", window_size);
1589 TFE_OpSetAttrInt(op.get(),
"stride", stride);
1590 TFE_OpSetAttrBool(op.get(),
"magnitude_squared", (
unsigned char)magnitude_squared);
1593 int num_outputs_op = 1;
1594 TFE_TensorHandle* res[1] = {
nullptr};
1595 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1596 status_check(context::get_status());
1597 return tensor(res[0]);
1600 inline tensor audio_summary(
const tensor& tag,
const tensor& input_tensor,
float sample_rate, int64_t max_outputs = 3) {
1602 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1603 TFE_NewOp(context::get_context(),
"AudioSummary", context::get_status()), &TFE_DeleteOp);
1604 status_check(context::get_status());
1608 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
1609 status_check(context::get_status());
1611 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
1612 status_check(context::get_status());
1615 TFE_OpSetAttrFloat(op.get(),
"sample_rate", sample_rate);
1616 TFE_OpSetAttrInt(op.get(),
"max_outputs", max_outputs);
1619 int num_outputs_op = 1;
1620 TFE_TensorHandle* res[1] = {
nullptr};
1621 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1622 status_check(context::get_status());
1623 return tensor(res[0]);
1626 inline tensor audio_summary_v2(
const tensor& tag,
const tensor& input_tensor,
const tensor& sample_rate,
1627 int64_t max_outputs = 3) {
1629 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1630 TFE_NewOp(context::get_context(),
"AudioSummaryV2", context::get_status()), &TFE_DeleteOp);
1631 status_check(context::get_status());
1635 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
1636 status_check(context::get_status());
1638 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
1639 status_check(context::get_status());
1641 TFE_OpAddInput(op.get(), sample_rate.tfe_handle.get(), context::get_status());
1642 status_check(context::get_status());
1645 TFE_OpSetAttrInt(op.get(),
"max_outputs", max_outputs);
1648 int num_outputs_op = 1;
1649 TFE_TensorHandle* res[1] = {
nullptr};
1650 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1651 status_check(context::get_status());
1652 return tensor(res[0]);
1655 inline tensor auto_shard_dataset(
const tensor& input_dataset,
const tensor& num_workers,
const tensor& index,
1656 const std::vector<datatype>& output_types,
1657 const std::vector<std::vector<int64_t>>& output_shapes,
1658 int64_t auto_shard_policy = 0) {
1660 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1661 TFE_NewOp(context::get_context(),
"AutoShardDataset", context::get_status()), &TFE_DeleteOp);
1662 status_check(context::get_status());
1666 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
1667 status_check(context::get_status());
1669 TFE_OpAddInput(op.get(), num_workers.tfe_handle.get(), context::get_status());
1670 status_check(context::get_status());
1672 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
1673 status_check(context::get_status());
1676 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
1677 static_cast<int>(output_types.size()));
1679 std::vector<const int64_t*> output_shapes_values;
1680 output_shapes_values.reserve(output_shapes.size());
1681 std::vector<int> output_shapes_ndims;
1682 output_shapes_ndims.reserve(output_shapes.size());
1683 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
1684 [](
const auto& v) { return v.data(); });
1685 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
1686 [](
const auto& v) { return static_cast<int>(v.size()); });
1687 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
1688 static_cast<int>(output_shapes.size()), context::get_status());
1689 status_check(context::get_status());
1691 TFE_OpSetAttrInt(op.get(),
"auto_shard_policy", auto_shard_policy);
1694 int num_outputs_op = 1;
1695 TFE_TensorHandle* res[1] = {
nullptr};
1696 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1697 status_check(context::get_status());
1698 return tensor(res[0]);
1701 inline tensor avg_pool(
const tensor& value,
const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
1702 const std::string& padding,
const std::string& data_format =
"NHWC") {
1704 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1705 TFE_NewOp(context::get_context(),
"AvgPool", context::get_status()), &TFE_DeleteOp);
1706 status_check(context::get_status());
1710 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
1711 status_check(context::get_status());
1714 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
1715 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
1716 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
1717 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
1720 int num_outputs_op = 1;
1721 TFE_TensorHandle* res[1] = {
nullptr};
1722 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1723 status_check(context::get_status());
1724 return tensor(res[0]);
1727 inline tensor avg_pool3_d(
const tensor& input,
const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
1728 const std::string& padding,
const std::string& data_format =
"NDHWC") {
1730 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1731 TFE_NewOp(context::get_context(),
"AvgPool3D", context::get_status()), &TFE_DeleteOp);
1732 status_check(context::get_status());
1736 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
1737 status_check(context::get_status());
1740 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
1741 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
1742 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
1743 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
1746 int num_outputs_op = 1;
1747 TFE_TensorHandle* res[1] = {
nullptr};
1748 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1749 status_check(context::get_status());
1750 return tensor(res[0]);
1753 inline tensor avg_pool3_d_grad(
const tensor& orig_input_shape,
const tensor& grad,
const std::vector<int64_t>& ksize,
1754 const std::vector<int64_t>& strides,
const std::string& padding,
1755 const std::string& data_format =
"NDHWC") {
1757 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1758 TFE_NewOp(context::get_context(),
"AvgPool3DGrad", context::get_status()), &TFE_DeleteOp);
1759 status_check(context::get_status());
1763 TFE_OpAddInput(op.get(), orig_input_shape.tfe_handle.get(), context::get_status());
1764 status_check(context::get_status());
1766 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1767 status_check(context::get_status());
1770 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
1771 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
1772 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
1773 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
1776 int num_outputs_op = 1;
1777 TFE_TensorHandle* res[1] = {
nullptr};
1778 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1779 status_check(context::get_status());
1780 return tensor(res[0]);
1783 inline tensor avg_pool_grad(
const tensor& orig_input_shape,
const tensor& grad,
const std::vector<int64_t>& ksize,
1784 const std::vector<int64_t>& strides,
const std::string& padding,
1785 const std::string& data_format =
"NHWC") {
1787 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1788 TFE_NewOp(context::get_context(),
"AvgPoolGrad", context::get_status()), &TFE_DeleteOp);
1789 status_check(context::get_status());
1793 TFE_OpAddInput(op.get(), orig_input_shape.tfe_handle.get(), context::get_status());
1794 status_check(context::get_status());
1796 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1797 status_check(context::get_status());
1800 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
1801 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
1802 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
1803 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
1806 int num_outputs_op = 1;
1807 TFE_TensorHandle* res[1] = {
nullptr};
1808 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1809 status_check(context::get_status());
1810 return tensor(res[0]);
1813 inline tensor banded_triangular_solve(
const tensor& matrix,
const tensor& rhs,
bool lower =
true,
1814 bool adjoint =
false) {
1816 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1817 TFE_NewOp(context::get_context(),
"BandedTriangularSolve", context::get_status()), &TFE_DeleteOp);
1818 status_check(context::get_status());
1822 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
1823 status_check(context::get_status());
1825 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
1826 status_check(context::get_status());
1829 TFE_OpSetAttrBool(op.get(),
"lower", (
unsigned char)lower);
1830 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
1833 int num_outputs_op = 1;
1834 TFE_TensorHandle* res[1] = {
nullptr};
1835 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1836 status_check(context::get_status());
1837 return tensor(res[0]);
1840 inline tensor barrier(
const std::vector<datatype>& component_types,
const std::vector<std::vector<int64_t>>& shapes,
1841 int64_t capacity = -1,
const std::string& container =
"",
const std::string& shared_name =
"") {
1843 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1844 TFE_NewOp(context::get_context(),
"Barrier", context::get_status()), &TFE_DeleteOp);
1845 status_check(context::get_status());
1850 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
1851 static_cast<int>(component_types.size()));
1853 std::vector<const int64_t*> shapes_values;
1854 shapes_values.reserve(shapes.size());
1855 std::vector<int> shapes_ndims;
1856 shapes_ndims.reserve(shapes.size());
1857 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
1858 [](
const auto& v) { return v.data(); });
1859 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
1860 [](
const auto& v) { return static_cast<int>(v.size()); });
1861 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
1862 context::get_status());
1863 status_check(context::get_status());
1865 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
1866 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
1867 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
1870 int num_outputs_op = 1;
1871 TFE_TensorHandle* res[1] = {
nullptr};
1872 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1873 status_check(context::get_status());
1874 return tensor(res[0]);
1877 inline tensor barrier_incomplete_size(
const tensor& handle) {
1879 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1880 TFE_NewOp(context::get_context(),
"BarrierIncompleteSize", context::get_status()), &TFE_DeleteOp);
1881 status_check(context::get_status());
1885 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
1886 status_check(context::get_status());
1891 int num_outputs_op = 1;
1892 TFE_TensorHandle* res[1] = {
nullptr};
1893 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1894 status_check(context::get_status());
1895 return tensor(res[0]);
1898 inline tensor barrier_ready_size(
const tensor& handle) {
1900 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1901 TFE_NewOp(context::get_context(),
"BarrierReadySize", context::get_status()), &TFE_DeleteOp);
1902 status_check(context::get_status());
1906 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
1907 status_check(context::get_status());
1912 int num_outputs_op = 1;
1913 TFE_TensorHandle* res[1] = {
nullptr};
1914 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1915 status_check(context::get_status());
1916 return tensor(res[0]);
1919 inline tensor batch_cholesky(
const tensor& input) {
1921 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1922 TFE_NewOp(context::get_context(),
"BatchCholesky", context::get_status()), &TFE_DeleteOp);
1923 status_check(context::get_status());
1927 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
1928 status_check(context::get_status());
1933 int num_outputs_op = 1;
1934 TFE_TensorHandle* res[1] = {
nullptr};
1935 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1936 status_check(context::get_status());
1937 return tensor(res[0]);
1940 inline tensor batch_cholesky_grad(
const tensor& l,
const tensor& grad) {
1942 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1943 TFE_NewOp(context::get_context(),
"BatchCholeskyGrad", context::get_status()), &TFE_DeleteOp);
1944 status_check(context::get_status());
1948 TFE_OpAddInput(op.get(), l.tfe_handle.get(), context::get_status());
1949 status_check(context::get_status());
1951 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
1952 status_check(context::get_status());
1957 int num_outputs_op = 1;
1958 TFE_TensorHandle* res[1] = {
nullptr};
1959 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
1960 status_check(context::get_status());
1961 return tensor(res[0]);
1964 inline tensor batch_dataset(
const tensor& input_dataset,
const tensor& batch_size,
1965 const std::vector<datatype>& output_types,
1966 const std::vector<std::vector<int64_t>>& output_shapes) {
1968 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
1969 TFE_NewOp(context::get_context(),
"BatchDataset", context::get_status()), &TFE_DeleteOp);
1970 status_check(context::get_status());
1974 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
1975 status_check(context::get_status());
1977 TFE_OpAddInput(op.get(), batch_size.tfe_handle.get(), context::get_status());
1978 status_check(context::get_status());
1981 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
1982 static_cast<int>(output_types.size()));
1984 std::vector<const int64_t*> output_shapes_values;
1985 output_shapes_values.reserve(output_shapes.size());
1986 std::vector<int> output_shapes_ndims;
1987 output_shapes_ndims.reserve(output_shapes.size());
1988 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
1989 [](
const auto& v) { return v.data(); });
1990 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
1991 [](
const auto& v) { return static_cast<int>(v.size()); });
1992 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
1993 static_cast<int>(output_shapes.size()), context::get_status());
1994 status_check(context::get_status());
1997 int num_outputs_op = 1;
1998 TFE_TensorHandle* res[1] = {
nullptr};
1999 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2000 status_check(context::get_status());
2001 return tensor(res[0]);
2004 inline tensor batch_dataset_v2(
const tensor& input_dataset,
const tensor& batch_size,
const tensor& drop_remainder,
2005 const std::vector<datatype>& output_types,
2006 const std::vector<std::vector<int64_t>>& output_shapes,
bool parallel_copy =
false) {
2008 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2009 TFE_NewOp(context::get_context(),
"BatchDatasetV2", context::get_status()), &TFE_DeleteOp);
2010 status_check(context::get_status());
2014 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
2015 status_check(context::get_status());
2017 TFE_OpAddInput(op.get(), batch_size.tfe_handle.get(), context::get_status());
2018 status_check(context::get_status());
2020 TFE_OpAddInput(op.get(), drop_remainder.tfe_handle.get(), context::get_status());
2021 status_check(context::get_status());
2024 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
2025 static_cast<int>(output_types.size()));
2027 std::vector<const int64_t*> output_shapes_values;
2028 output_shapes_values.reserve(output_shapes.size());
2029 std::vector<int> output_shapes_ndims;
2030 output_shapes_ndims.reserve(output_shapes.size());
2031 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
2032 [](
const auto& v) { return v.data(); });
2033 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
2034 [](
const auto& v) { return static_cast<int>(v.size()); });
2035 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
2036 static_cast<int>(output_shapes.size()), context::get_status());
2037 status_check(context::get_status());
2039 TFE_OpSetAttrBool(op.get(),
"parallel_copy", (
unsigned char)parallel_copy);
2042 int num_outputs_op = 1;
2043 TFE_TensorHandle* res[1] = {
nullptr};
2044 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2045 status_check(context::get_status());
2046 return tensor(res[0]);
2049 inline tensor batch_f_f_t(
const tensor& input) {
2051 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2052 TFE_NewOp(context::get_context(),
"BatchFFT", context::get_status()), &TFE_DeleteOp);
2053 status_check(context::get_status());
2057 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2058 status_check(context::get_status());
2063 int num_outputs_op = 1;
2064 TFE_TensorHandle* res[1] = {
nullptr};
2065 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2066 status_check(context::get_status());
2067 return tensor(res[0]);
2070 inline tensor batch_f_f_t2_d(
const tensor& input) {
2072 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2073 TFE_NewOp(context::get_context(),
"BatchFFT2D", context::get_status()), &TFE_DeleteOp);
2074 status_check(context::get_status());
2078 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2079 status_check(context::get_status());
2084 int num_outputs_op = 1;
2085 TFE_TensorHandle* res[1] = {
nullptr};
2086 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2087 status_check(context::get_status());
2088 return tensor(res[0]);
2091 inline tensor batch_f_f_t3_d(
const tensor& input) {
2093 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2094 TFE_NewOp(context::get_context(),
"BatchFFT3D", context::get_status()), &TFE_DeleteOp);
2095 status_check(context::get_status());
2099 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2100 status_check(context::get_status());
2105 int num_outputs_op = 1;
2106 TFE_TensorHandle* res[1] = {
nullptr};
2107 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2108 status_check(context::get_status());
2109 return tensor(res[0]);
2112 inline tensor batch_i_f_f_t(
const tensor& input) {
2114 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2115 TFE_NewOp(context::get_context(),
"BatchIFFT", context::get_status()), &TFE_DeleteOp);
2116 status_check(context::get_status());
2120 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2121 status_check(context::get_status());
2126 int num_outputs_op = 1;
2127 TFE_TensorHandle* res[1] = {
nullptr};
2128 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2129 status_check(context::get_status());
2130 return tensor(res[0]);
2133 inline tensor batch_i_f_f_t2_d(
const tensor& input) {
2135 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2136 TFE_NewOp(context::get_context(),
"BatchIFFT2D", context::get_status()), &TFE_DeleteOp);
2137 status_check(context::get_status());
2141 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2142 status_check(context::get_status());
2147 int num_outputs_op = 1;
2148 TFE_TensorHandle* res[1] = {
nullptr};
2149 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2150 status_check(context::get_status());
2151 return tensor(res[0]);
2154 inline tensor batch_i_f_f_t3_d(
const tensor& input) {
2156 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2157 TFE_NewOp(context::get_context(),
"BatchIFFT3D", context::get_status()), &TFE_DeleteOp);
2158 status_check(context::get_status());
2162 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2163 status_check(context::get_status());
2168 int num_outputs_op = 1;
2169 TFE_TensorHandle* res[1] = {
nullptr};
2170 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2171 status_check(context::get_status());
2172 return tensor(res[0]);
2175 inline tensor batch_mat_mul(
const tensor& x,
const tensor& y,
bool adj_x =
false,
bool adj_y =
false) {
2177 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2178 TFE_NewOp(context::get_context(),
"BatchMatMul", context::get_status()), &TFE_DeleteOp);
2179 status_check(context::get_status());
2183 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2184 status_check(context::get_status());
2186 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
2187 status_check(context::get_status());
2190 TFE_OpSetAttrBool(op.get(),
"adj_x", (
unsigned char)adj_x);
2191 TFE_OpSetAttrBool(op.get(),
"adj_y", (
unsigned char)adj_y);
2194 int num_outputs_op = 1;
2195 TFE_TensorHandle* res[1] = {
nullptr};
2196 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2197 status_check(context::get_status());
2198 return tensor(res[0]);
2201 inline tensor batch_mat_mul_v2(
const tensor& x,
const tensor& y,
bool adj_x =
false,
bool adj_y =
false) {
2203 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2204 TFE_NewOp(context::get_context(),
"BatchMatMulV2", context::get_status()), &TFE_DeleteOp);
2205 status_check(context::get_status());
2209 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2210 status_check(context::get_status());
2212 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
2213 status_check(context::get_status());
2216 TFE_OpSetAttrBool(op.get(),
"adj_x", (
unsigned char)adj_x);
2217 TFE_OpSetAttrBool(op.get(),
"adj_y", (
unsigned char)adj_y);
2220 int num_outputs_op = 1;
2221 TFE_TensorHandle* res[1] = {
nullptr};
2222 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2223 status_check(context::get_status());
2224 return tensor(res[0]);
2227 inline tensor batch_matrix_band_part(
const tensor& input,
const tensor& num_lower,
const tensor& num_upper) {
2229 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2230 TFE_NewOp(context::get_context(),
"BatchMatrixBandPart", context::get_status()), &TFE_DeleteOp);
2231 status_check(context::get_status());
2235 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2236 status_check(context::get_status());
2238 TFE_OpAddInput(op.get(), num_lower.tfe_handle.get(), context::get_status());
2239 status_check(context::get_status());
2241 TFE_OpAddInput(op.get(), num_upper.tfe_handle.get(), context::get_status());
2242 status_check(context::get_status());
2247 int num_outputs_op = 1;
2248 TFE_TensorHandle* res[1] = {
nullptr};
2249 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2250 status_check(context::get_status());
2251 return tensor(res[0]);
2254 inline tensor batch_matrix_determinant(
const tensor& input) {
2256 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2257 TFE_NewOp(context::get_context(),
"BatchMatrixDeterminant", context::get_status()), &TFE_DeleteOp);
2258 status_check(context::get_status());
2262 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2263 status_check(context::get_status());
2268 int num_outputs_op = 1;
2269 TFE_TensorHandle* res[1] = {
nullptr};
2270 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2271 status_check(context::get_status());
2272 return tensor(res[0]);
2275 inline tensor batch_matrix_diag(
const tensor& diagonal) {
2277 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2278 TFE_NewOp(context::get_context(),
"BatchMatrixDiag", context::get_status()), &TFE_DeleteOp);
2279 status_check(context::get_status());
2283 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
2284 status_check(context::get_status());
2289 int num_outputs_op = 1;
2290 TFE_TensorHandle* res[1] = {
nullptr};
2291 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2292 status_check(context::get_status());
2293 return tensor(res[0]);
2296 inline tensor batch_matrix_diag_part(
const tensor& input) {
2298 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2299 TFE_NewOp(context::get_context(),
"BatchMatrixDiagPart", context::get_status()), &TFE_DeleteOp);
2300 status_check(context::get_status());
2304 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2305 status_check(context::get_status());
2310 int num_outputs_op = 1;
2311 TFE_TensorHandle* res[1] = {
nullptr};
2312 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2313 status_check(context::get_status());
2314 return tensor(res[0]);
2317 inline tensor batch_matrix_inverse(
const tensor& input,
bool adjoint =
false) {
2319 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2320 TFE_NewOp(context::get_context(),
"BatchMatrixInverse", context::get_status()), &TFE_DeleteOp);
2321 status_check(context::get_status());
2325 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2326 status_check(context::get_status());
2329 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
2332 int num_outputs_op = 1;
2333 TFE_TensorHandle* res[1] = {
nullptr};
2334 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2335 status_check(context::get_status());
2336 return tensor(res[0]);
2339 inline tensor batch_matrix_set_diag(
const tensor& input,
const tensor& diagonal) {
2341 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2342 TFE_NewOp(context::get_context(),
"BatchMatrixSetDiag", context::get_status()), &TFE_DeleteOp);
2343 status_check(context::get_status());
2347 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2348 status_check(context::get_status());
2350 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
2351 status_check(context::get_status());
2356 int num_outputs_op = 1;
2357 TFE_TensorHandle* res[1] = {
nullptr};
2358 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2359 status_check(context::get_status());
2360 return tensor(res[0]);
2363 inline tensor batch_matrix_solve(
const tensor& matrix,
const tensor& rhs,
bool adjoint =
false) {
2365 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2366 TFE_NewOp(context::get_context(),
"BatchMatrixSolve", context::get_status()), &TFE_DeleteOp);
2367 status_check(context::get_status());
2371 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
2372 status_check(context::get_status());
2374 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
2375 status_check(context::get_status());
2378 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
2381 int num_outputs_op = 1;
2382 TFE_TensorHandle* res[1] = {
nullptr};
2383 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2384 status_check(context::get_status());
2385 return tensor(res[0]);
2388 inline tensor batch_matrix_solve_ls(
const tensor& matrix,
const tensor& rhs,
const tensor& l2_regularizer,
2391 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2392 TFE_NewOp(context::get_context(),
"BatchMatrixSolveLs", context::get_status()), &TFE_DeleteOp);
2393 status_check(context::get_status());
2397 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
2398 status_check(context::get_status());
2400 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
2401 status_check(context::get_status());
2403 TFE_OpAddInput(op.get(), l2_regularizer.tfe_handle.get(), context::get_status());
2404 status_check(context::get_status());
2407 TFE_OpSetAttrBool(op.get(),
"fast", (
unsigned char)fast);
2410 int num_outputs_op = 1;
2411 TFE_TensorHandle* res[1] = {
nullptr};
2412 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2413 status_check(context::get_status());
2414 return tensor(res[0]);
2417 inline tensor batch_matrix_triangular_solve(
const tensor& matrix,
const tensor& rhs,
bool lower =
true,
2418 bool adjoint =
false) {
2420 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2421 TFE_NewOp(context::get_context(),
"BatchMatrixTriangularSolve", context::get_status()), &TFE_DeleteOp);
2422 status_check(context::get_status());
2426 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
2427 status_check(context::get_status());
2429 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
2430 status_check(context::get_status());
2433 TFE_OpSetAttrBool(op.get(),
"lower", (
unsigned char)lower);
2434 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
2437 int num_outputs_op = 1;
2438 TFE_TensorHandle* res[1] = {
nullptr};
2439 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2440 status_check(context::get_status());
2441 return tensor(res[0]);
2444 inline tensor batch_norm_with_global_normalization(
const tensor& t,
const tensor& m,
const tensor& v,
2445 const tensor& beta,
const tensor& gamma,
float variance_epsilon,
2446 bool scale_after_normalization) {
2448 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2449 TFE_NewOp(context::get_context(),
"BatchNormWithGlobalNormalization", context::get_status()), &TFE_DeleteOp);
2450 status_check(context::get_status());
2454 TFE_OpAddInput(op.get(), t.tfe_handle.get(), context::get_status());
2455 status_check(context::get_status());
2457 TFE_OpAddInput(op.get(), m.tfe_handle.get(), context::get_status());
2458 status_check(context::get_status());
2460 TFE_OpAddInput(op.get(), v.tfe_handle.get(), context::get_status());
2461 status_check(context::get_status());
2463 TFE_OpAddInput(op.get(), beta.tfe_handle.get(), context::get_status());
2464 status_check(context::get_status());
2466 TFE_OpAddInput(op.get(), gamma.tfe_handle.get(), context::get_status());
2467 status_check(context::get_status());
2470 TFE_OpSetAttrFloat(op.get(),
"variance_epsilon", variance_epsilon);
2471 TFE_OpSetAttrBool(op.get(),
"scale_after_normalization", (
unsigned char)scale_after_normalization);
2474 int num_outputs_op = 1;
2475 TFE_TensorHandle* res[1] = {
nullptr};
2476 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2477 status_check(context::get_status());
2478 return tensor(res[0]);
2481 inline tensor batch_self_adjoint_eig(
const tensor& input) {
2483 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2484 TFE_NewOp(context::get_context(),
"BatchSelfAdjointEig", context::get_status()), &TFE_DeleteOp);
2485 status_check(context::get_status());
2489 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2490 status_check(context::get_status());
2495 int num_outputs_op = 1;
2496 TFE_TensorHandle* res[1] = {
nullptr};
2497 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2498 status_check(context::get_status());
2499 return tensor(res[0]);
2502 inline tensor batch_to_space(
const tensor& input,
const tensor& crops, int64_t block_size,
2503 datatype Tidx =
static_cast<datatype
>(3)) {
2505 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2506 TFE_NewOp(context::get_context(),
"BatchToSpace", context::get_status()), &TFE_DeleteOp);
2507 status_check(context::get_status());
2511 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2512 status_check(context::get_status());
2514 TFE_OpAddInput(op.get(), crops.tfe_handle.get(), context::get_status());
2515 status_check(context::get_status());
2518 TFE_OpSetAttrInt(op.get(),
"block_size", block_size);
2519 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
2522 int num_outputs_op = 1;
2523 TFE_TensorHandle* res[1] = {
nullptr};
2524 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2525 status_check(context::get_status());
2526 return tensor(res[0]);
2529 inline tensor batch_to_space_n_d(
const tensor& input,
const tensor& block_shape,
const tensor& crops,
2530 datatype Tblock_shape =
static_cast<datatype
>(3),
2531 datatype Tcrops =
static_cast<datatype
>(3)) {
2533 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2534 TFE_NewOp(context::get_context(),
"BatchToSpaceND", context::get_status()), &TFE_DeleteOp);
2535 status_check(context::get_status());
2539 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2540 status_check(context::get_status());
2542 TFE_OpAddInput(op.get(), block_shape.tfe_handle.get(), context::get_status());
2543 status_check(context::get_status());
2545 TFE_OpAddInput(op.get(), crops.tfe_handle.get(), context::get_status());
2546 status_check(context::get_status());
2549 TFE_OpSetAttrType(op.get(),
"Tblock_shape", Tblock_shape);
2550 TFE_OpSetAttrType(op.get(),
"Tcrops", Tcrops);
2553 int num_outputs_op = 1;
2554 TFE_TensorHandle* res[1] = {
nullptr};
2555 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2556 status_check(context::get_status());
2557 return tensor(res[0]);
2560 inline tensor bessel_i0(
const tensor& x) {
2562 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2563 TFE_NewOp(context::get_context(),
"BesselI0", context::get_status()), &TFE_DeleteOp);
2564 status_check(context::get_status());
2568 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2569 status_check(context::get_status());
2574 int num_outputs_op = 1;
2575 TFE_TensorHandle* res[1] = {
nullptr};
2576 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2577 status_check(context::get_status());
2578 return tensor(res[0]);
2581 inline tensor bessel_i0e(
const tensor& x) {
2583 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2584 TFE_NewOp(context::get_context(),
"BesselI0e", context::get_status()), &TFE_DeleteOp);
2585 status_check(context::get_status());
2589 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2590 status_check(context::get_status());
2595 int num_outputs_op = 1;
2596 TFE_TensorHandle* res[1] = {
nullptr};
2597 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2598 status_check(context::get_status());
2599 return tensor(res[0]);
2602 inline tensor bessel_i1(
const tensor& x) {
2604 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2605 TFE_NewOp(context::get_context(),
"BesselI1", context::get_status()), &TFE_DeleteOp);
2606 status_check(context::get_status());
2610 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2611 status_check(context::get_status());
2616 int num_outputs_op = 1;
2617 TFE_TensorHandle* res[1] = {
nullptr};
2618 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2619 status_check(context::get_status());
2620 return tensor(res[0]);
2623 inline tensor bessel_i1e(
const tensor& x) {
2625 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2626 TFE_NewOp(context::get_context(),
"BesselI1e", context::get_status()), &TFE_DeleteOp);
2627 status_check(context::get_status());
2631 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2632 status_check(context::get_status());
2637 int num_outputs_op = 1;
2638 TFE_TensorHandle* res[1] = {
nullptr};
2639 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2640 status_check(context::get_status());
2641 return tensor(res[0]);
2644 inline tensor bessel_j0(
const tensor& x) {
2646 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2647 TFE_NewOp(context::get_context(),
"BesselJ0", context::get_status()), &TFE_DeleteOp);
2648 status_check(context::get_status());
2652 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2653 status_check(context::get_status());
2658 int num_outputs_op = 1;
2659 TFE_TensorHandle* res[1] = {
nullptr};
2660 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2661 status_check(context::get_status());
2662 return tensor(res[0]);
2665 inline tensor bessel_j1(
const tensor& x) {
2667 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2668 TFE_NewOp(context::get_context(),
"BesselJ1", context::get_status()), &TFE_DeleteOp);
2669 status_check(context::get_status());
2673 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2674 status_check(context::get_status());
2679 int num_outputs_op = 1;
2680 TFE_TensorHandle* res[1] = {
nullptr};
2681 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2682 status_check(context::get_status());
2683 return tensor(res[0]);
2686 inline tensor bessel_k0(
const tensor& x) {
2688 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2689 TFE_NewOp(context::get_context(),
"BesselK0", context::get_status()), &TFE_DeleteOp);
2690 status_check(context::get_status());
2694 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2695 status_check(context::get_status());
2700 int num_outputs_op = 1;
2701 TFE_TensorHandle* res[1] = {
nullptr};
2702 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2703 status_check(context::get_status());
2704 return tensor(res[0]);
2707 inline tensor bessel_k0e(
const tensor& x) {
2709 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2710 TFE_NewOp(context::get_context(),
"BesselK0e", context::get_status()), &TFE_DeleteOp);
2711 status_check(context::get_status());
2715 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2716 status_check(context::get_status());
2721 int num_outputs_op = 1;
2722 TFE_TensorHandle* res[1] = {
nullptr};
2723 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2724 status_check(context::get_status());
2725 return tensor(res[0]);
2728 inline tensor bessel_k1(
const tensor& x) {
2730 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2731 TFE_NewOp(context::get_context(),
"BesselK1", context::get_status()), &TFE_DeleteOp);
2732 status_check(context::get_status());
2736 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2737 status_check(context::get_status());
2742 int num_outputs_op = 1;
2743 TFE_TensorHandle* res[1] = {
nullptr};
2744 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2745 status_check(context::get_status());
2746 return tensor(res[0]);
2749 inline tensor bessel_k1e(
const tensor& x) {
2751 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2752 TFE_NewOp(context::get_context(),
"BesselK1e", context::get_status()), &TFE_DeleteOp);
2753 status_check(context::get_status());
2757 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2758 status_check(context::get_status());
2763 int num_outputs_op = 1;
2764 TFE_TensorHandle* res[1] = {
nullptr};
2765 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2766 status_check(context::get_status());
2767 return tensor(res[0]);
2770 inline tensor bessel_y0(
const tensor& x) {
2772 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2773 TFE_NewOp(context::get_context(),
"BesselY0", context::get_status()), &TFE_DeleteOp);
2774 status_check(context::get_status());
2778 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2779 status_check(context::get_status());
2784 int num_outputs_op = 1;
2785 TFE_TensorHandle* res[1] = {
nullptr};
2786 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2787 status_check(context::get_status());
2788 return tensor(res[0]);
2791 inline tensor bessel_y1(
const tensor& x) {
2793 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2794 TFE_NewOp(context::get_context(),
"BesselY1", context::get_status()), &TFE_DeleteOp);
2795 status_check(context::get_status());
2799 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2800 status_check(context::get_status());
2805 int num_outputs_op = 1;
2806 TFE_TensorHandle* res[1] = {
nullptr};
2807 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2808 status_check(context::get_status());
2809 return tensor(res[0]);
2812 inline tensor betainc(
const tensor& a,
const tensor& b,
const tensor& x) {
2814 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2815 TFE_NewOp(context::get_context(),
"Betainc", context::get_status()), &TFE_DeleteOp);
2816 status_check(context::get_status());
2820 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
2821 status_check(context::get_status());
2823 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
2824 status_check(context::get_status());
2826 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2827 status_check(context::get_status());
2832 int num_outputs_op = 1;
2833 TFE_TensorHandle* res[1] = {
nullptr};
2834 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2835 status_check(context::get_status());
2836 return tensor(res[0]);
2839 inline tensor bias_add(
const tensor& value,
const tensor& bias,
const std::string& data_format =
"NHWC") {
2841 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2842 TFE_NewOp(context::get_context(),
"BiasAdd", context::get_status()), &TFE_DeleteOp);
2843 status_check(context::get_status());
2847 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
2848 status_check(context::get_status());
2850 TFE_OpAddInput(op.get(), bias.tfe_handle.get(), context::get_status());
2851 status_check(context::get_status());
2854 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
2857 int num_outputs_op = 1;
2858 TFE_TensorHandle* res[1] = {
nullptr};
2859 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2860 status_check(context::get_status());
2861 return tensor(res[0]);
2864 inline tensor bias_add_grad(
const tensor& out_backprop,
const std::string& data_format =
"NHWC") {
2866 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2867 TFE_NewOp(context::get_context(),
"BiasAddGrad", context::get_status()), &TFE_DeleteOp);
2868 status_check(context::get_status());
2872 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
2873 status_check(context::get_status());
2876 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
2879 int num_outputs_op = 1;
2880 TFE_TensorHandle* res[1] = {
nullptr};
2881 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2882 status_check(context::get_status());
2883 return tensor(res[0]);
2886 inline tensor bias_add_v1(
const tensor& value,
const tensor& bias) {
2888 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2889 TFE_NewOp(context::get_context(),
"BiasAddV1", context::get_status()), &TFE_DeleteOp);
2890 status_check(context::get_status());
2894 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
2895 status_check(context::get_status());
2897 TFE_OpAddInput(op.get(), bias.tfe_handle.get(), context::get_status());
2898 status_check(context::get_status());
2903 int num_outputs_op = 1;
2904 TFE_TensorHandle* res[1] = {
nullptr};
2905 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2906 status_check(context::get_status());
2907 return tensor(res[0]);
2910 inline tensor bincount(
const tensor& arr,
const tensor& size,
const tensor& weights) {
2912 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2913 TFE_NewOp(context::get_context(),
"Bincount", context::get_status()), &TFE_DeleteOp);
2914 status_check(context::get_status());
2918 TFE_OpAddInput(op.get(), arr.tfe_handle.get(), context::get_status());
2919 status_check(context::get_status());
2921 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
2922 status_check(context::get_status());
2924 TFE_OpAddInput(op.get(), weights.tfe_handle.get(), context::get_status());
2925 status_check(context::get_status());
2930 int num_outputs_op = 1;
2931 TFE_TensorHandle* res[1] = {
nullptr};
2932 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2933 status_check(context::get_status());
2934 return tensor(res[0]);
2937 inline tensor bitcast(
const tensor& input, datatype type) {
2939 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2940 TFE_NewOp(context::get_context(),
"Bitcast", context::get_status()), &TFE_DeleteOp);
2941 status_check(context::get_status());
2945 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
2946 status_check(context::get_status());
2949 TFE_OpSetAttrType(op.get(),
"type", type);
2952 int num_outputs_op = 1;
2953 TFE_TensorHandle* res[1] = {
nullptr};
2954 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2955 status_check(context::get_status());
2956 return tensor(res[0]);
2959 inline tensor bitwise_and(
const tensor& x,
const tensor& y) {
2961 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2962 TFE_NewOp(context::get_context(),
"BitwiseAnd", context::get_status()), &TFE_DeleteOp);
2963 status_check(context::get_status());
2967 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2968 status_check(context::get_status());
2970 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
2971 status_check(context::get_status());
2976 int num_outputs_op = 1;
2977 TFE_TensorHandle* res[1] = {
nullptr};
2978 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
2979 status_check(context::get_status());
2980 return tensor(res[0]);
2983 inline tensor bitwise_or(
const tensor& x,
const tensor& y) {
2985 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
2986 TFE_NewOp(context::get_context(),
"BitwiseOr", context::get_status()), &TFE_DeleteOp);
2987 status_check(context::get_status());
2991 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
2992 status_check(context::get_status());
2994 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
2995 status_check(context::get_status());
3000 int num_outputs_op = 1;
3001 TFE_TensorHandle* res[1] = {
nullptr};
3002 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3003 status_check(context::get_status());
3004 return tensor(res[0]);
3007 inline tensor bitwise_xor(
const tensor& x,
const tensor& y) {
3009 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3010 TFE_NewOp(context::get_context(),
"BitwiseXor", context::get_status()), &TFE_DeleteOp);
3011 status_check(context::get_status());
3015 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
3016 status_check(context::get_status());
3018 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
3019 status_check(context::get_status());
3024 int num_outputs_op = 1;
3025 TFE_TensorHandle* res[1] = {
nullptr};
3026 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3027 status_check(context::get_status());
3028 return tensor(res[0]);
3031 inline tensor boosted_trees_aggregate_stats(
const tensor& node_ids,
const tensor& gradients,
const tensor& hessians,
3032 const tensor& feature, int64_t max_splits, int64_t num_buckets) {
3034 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3035 TFE_NewOp(context::get_context(),
"BoostedTreesAggregateStats", context::get_status()), &TFE_DeleteOp);
3036 status_check(context::get_status());
3040 TFE_OpAddInput(op.get(), node_ids.tfe_handle.get(), context::get_status());
3041 status_check(context::get_status());
3043 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
3044 status_check(context::get_status());
3046 TFE_OpAddInput(op.get(), hessians.tfe_handle.get(), context::get_status());
3047 status_check(context::get_status());
3049 TFE_OpAddInput(op.get(), feature.tfe_handle.get(), context::get_status());
3050 status_check(context::get_status());
3053 TFE_OpSetAttrInt(op.get(),
"max_splits", max_splits);
3054 TFE_OpSetAttrInt(op.get(),
"num_buckets", num_buckets);
3057 int num_outputs_op = 1;
3058 TFE_TensorHandle* res[1] = {
nullptr};
3059 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3060 status_check(context::get_status());
3061 return tensor(res[0]);
3064 inline tensor boosted_trees_bucketize(
const std::vector<tensor>& float_values,
3065 const std::vector<tensor>& bucket_boundaries) {
3067 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3068 TFE_NewOp(context::get_context(),
"BoostedTreesBucketize", context::get_status()), &TFE_DeleteOp);
3069 status_check(context::get_status());
3073 std::vector<TFE_TensorHandle*> float_values_handles;
3074 float_values_handles.reserve(float_values.size());
3075 std::transform(float_values.begin(), float_values.end(), std::back_inserter(float_values_handles),
3076 [](
const auto& t) { return t.tfe_handle.get(); });
3077 TFE_OpAddInputList(op.get(), float_values_handles.data(),
static_cast<int>(float_values.size()),
3078 context::get_status());
3079 status_check(context::get_status());
3081 std::vector<TFE_TensorHandle*> bucket_boundaries_handles;
3082 bucket_boundaries_handles.reserve(bucket_boundaries.size());
3083 std::transform(bucket_boundaries.begin(), bucket_boundaries.end(), std::back_inserter(bucket_boundaries_handles),
3084 [](
const auto& t) { return t.tfe_handle.get(); });
3085 TFE_OpAddInputList(op.get(), bucket_boundaries_handles.data(),
static_cast<int>(bucket_boundaries.size()),
3086 context::get_status());
3087 status_check(context::get_status());
3090 TFE_OpSetAttrInt(op.get(),
"num_features", float_values.size());
3093 int num_outputs_op = 1;
3094 TFE_TensorHandle* res[1] = {
nullptr};
3095 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3096 status_check(context::get_status());
3097 return tensor(res[0]);
3100 inline tensor boosted_trees_center_bias(
const tensor& tree_ensemble_handle,
const tensor& mean_gradients,
3101 const tensor& mean_hessians,
const tensor& l1,
const tensor& l2) {
3103 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3104 TFE_NewOp(context::get_context(),
"BoostedTreesCenterBias", context::get_status()), &TFE_DeleteOp);
3105 status_check(context::get_status());
3109 TFE_OpAddInput(op.get(), tree_ensemble_handle.tfe_handle.get(), context::get_status());
3110 status_check(context::get_status());
3112 TFE_OpAddInput(op.get(), mean_gradients.tfe_handle.get(), context::get_status());
3113 status_check(context::get_status());
3115 TFE_OpAddInput(op.get(), mean_hessians.tfe_handle.get(), context::get_status());
3116 status_check(context::get_status());
3118 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
3119 status_check(context::get_status());
3121 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
3122 status_check(context::get_status());
3127 int num_outputs_op = 1;
3128 TFE_TensorHandle* res[1] = {
nullptr};
3129 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3130 status_check(context::get_status());
3131 return tensor(res[0]);
3134 inline tensor boosted_trees_ensemble_resource_handle_op(
const std::string& container =
"",
3135 const std::string& shared_name =
"") {
3137 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3138 TFE_NewOp(context::get_context(),
"BoostedTreesEnsembleResourceHandleOp", context::get_status()), &TFE_DeleteOp);
3139 status_check(context::get_status());
3144 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
3145 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
3148 int num_outputs_op = 1;
3149 TFE_TensorHandle* res[1] = {
nullptr};
3150 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3151 status_check(context::get_status());
3152 return tensor(res[0]);
3155 inline tensor boosted_trees_example_debug_outputs(
const tensor& tree_ensemble_handle,
3156 const std::vector<tensor>& bucketized_features,
3157 int64_t logits_dimension) {
3159 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3160 TFE_NewOp(context::get_context(),
"BoostedTreesExampleDebugOutputs", context::get_status()), &TFE_DeleteOp);
3161 status_check(context::get_status());
3165 TFE_OpAddInput(op.get(), tree_ensemble_handle.tfe_handle.get(), context::get_status());
3166 status_check(context::get_status());
3168 std::vector<TFE_TensorHandle*> bucketized_features_handles;
3169 bucketized_features_handles.reserve(bucketized_features.size());
3170 std::transform(bucketized_features.begin(), bucketized_features.end(),
3171 std::back_inserter(bucketized_features_handles), [](
const auto& t) { return t.tfe_handle.get(); });
3172 TFE_OpAddInputList(op.get(), bucketized_features_handles.data(),
static_cast<int>(bucketized_features.size()),
3173 context::get_status());
3174 status_check(context::get_status());
3177 TFE_OpSetAttrInt(op.get(),
"num_bucketized_features", bucketized_features.size());
3178 TFE_OpSetAttrInt(op.get(),
"logits_dimension", logits_dimension);
3181 int num_outputs_op = 1;
3182 TFE_TensorHandle* res[1] = {
nullptr};
3183 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3184 status_check(context::get_status());
3185 return tensor(res[0]);
3188 inline tensor boosted_trees_flush_quantile_summaries(
const tensor& quantile_stream_resource_handle,
3189 int64_t num_features) {
3191 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3192 TFE_NewOp(context::get_context(),
"BoostedTreesFlushQuantileSummaries", context::get_status()), &TFE_DeleteOp);
3193 status_check(context::get_status());
3197 TFE_OpAddInput(op.get(), quantile_stream_resource_handle.tfe_handle.get(), context::get_status());
3198 status_check(context::get_status());
3201 TFE_OpSetAttrInt(op.get(),
"num_features", num_features);
3204 int num_outputs_op = 1;
3205 TFE_TensorHandle* res[1] = {
nullptr};
3206 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3207 status_check(context::get_status());
3208 return tensor(res[0]);
3211 inline tensor boosted_trees_make_quantile_summaries(
const std::vector<tensor>& float_values,
3212 const tensor& example_weights,
const tensor& epsilon) {
3214 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3215 TFE_NewOp(context::get_context(),
"BoostedTreesMakeQuantileSummaries", context::get_status()), &TFE_DeleteOp);
3216 status_check(context::get_status());
3220 std::vector<TFE_TensorHandle*> float_values_handles;
3221 float_values_handles.reserve(float_values.size());
3222 std::transform(float_values.begin(), float_values.end(), std::back_inserter(float_values_handles),
3223 [](
const auto& t) { return t.tfe_handle.get(); });
3224 TFE_OpAddInputList(op.get(), float_values_handles.data(),
static_cast<int>(float_values.size()),
3225 context::get_status());
3226 status_check(context::get_status());
3228 TFE_OpAddInput(op.get(), example_weights.tfe_handle.get(), context::get_status());
3229 status_check(context::get_status());
3231 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
3232 status_check(context::get_status());
3235 TFE_OpSetAttrInt(op.get(),
"num_features", float_values.size());
3238 int num_outputs_op = 1;
3239 TFE_TensorHandle* res[1] = {
nullptr};
3240 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3241 status_check(context::get_status());
3242 return tensor(res[0]);
3245 inline tensor boosted_trees_make_stats_summary(
const tensor& node_ids,
const tensor& gradients,
const tensor& hessians,
3246 const std::vector<tensor>& bucketized_features_list, int64_t max_splits,
3247 int64_t num_buckets) {
3249 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3250 TFE_NewOp(context::get_context(),
"BoostedTreesMakeStatsSummary", context::get_status()), &TFE_DeleteOp);
3251 status_check(context::get_status());
3255 TFE_OpAddInput(op.get(), node_ids.tfe_handle.get(), context::get_status());
3256 status_check(context::get_status());
3258 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
3259 status_check(context::get_status());
3261 TFE_OpAddInput(op.get(), hessians.tfe_handle.get(), context::get_status());
3262 status_check(context::get_status());
3264 std::vector<TFE_TensorHandle*> bucketized_features_list_handles;
3265 bucketized_features_list_handles.reserve(bucketized_features_list.size());
3266 std::transform(bucketized_features_list.begin(), bucketized_features_list.end(),
3267 std::back_inserter(bucketized_features_list_handles),
3268 [](
const auto& t) { return t.tfe_handle.get(); });
3269 TFE_OpAddInputList(op.get(), bucketized_features_list_handles.data(),
3270 static_cast<int>(bucketized_features_list.size()), context::get_status());
3271 status_check(context::get_status());
3274 TFE_OpSetAttrInt(op.get(),
"max_splits", max_splits);
3275 TFE_OpSetAttrInt(op.get(),
"num_buckets", num_buckets);
3276 TFE_OpSetAttrInt(op.get(),
"num_features", bucketized_features_list.size());
3279 int num_outputs_op = 1;
3280 TFE_TensorHandle* res[1] = {
nullptr};
3281 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3282 status_check(context::get_status());
3283 return tensor(res[0]);
3286 inline tensor boosted_trees_predict(
const tensor& tree_ensemble_handle,
const std::vector<tensor>& bucketized_features,
3287 int64_t logits_dimension) {
3289 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3290 TFE_NewOp(context::get_context(),
"BoostedTreesPredict", context::get_status()), &TFE_DeleteOp);
3291 status_check(context::get_status());
3295 TFE_OpAddInput(op.get(), tree_ensemble_handle.tfe_handle.get(), context::get_status());
3296 status_check(context::get_status());
3298 std::vector<TFE_TensorHandle*> bucketized_features_handles;
3299 bucketized_features_handles.reserve(bucketized_features.size());
3300 std::transform(bucketized_features.begin(), bucketized_features.end(),
3301 std::back_inserter(bucketized_features_handles), [](
const auto& t) { return t.tfe_handle.get(); });
3302 TFE_OpAddInputList(op.get(), bucketized_features_handles.data(),
static_cast<int>(bucketized_features.size()),
3303 context::get_status());
3304 status_check(context::get_status());
3307 TFE_OpSetAttrInt(op.get(),
"num_bucketized_features", bucketized_features.size());
3308 TFE_OpSetAttrInt(op.get(),
"logits_dimension", logits_dimension);
3311 int num_outputs_op = 1;
3312 TFE_TensorHandle* res[1] = {
nullptr};
3313 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3314 status_check(context::get_status());
3315 return tensor(res[0]);
3318 inline tensor boosted_trees_quantile_stream_resource_get_bucket_boundaries(
3319 const tensor& quantile_stream_resource_handle, int64_t num_features) {
3321 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3322 TFE_NewOp(context::get_context(),
"BoostedTreesQuantileStreamResourceGetBucketBoundaries", context::get_status()),
3324 status_check(context::get_status());
3328 TFE_OpAddInput(op.get(), quantile_stream_resource_handle.tfe_handle.get(), context::get_status());
3329 status_check(context::get_status());
3332 TFE_OpSetAttrInt(op.get(),
"num_features", num_features);
3335 int num_outputs_op = 1;
3336 TFE_TensorHandle* res[1] = {
nullptr};
3337 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3338 status_check(context::get_status());
3339 return tensor(res[0]);
3342 inline tensor boosted_trees_quantile_stream_resource_handle_op(
const std::string& container =
"",
3343 const std::string& shared_name =
"") {
3345 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3346 TFE_NewOp(context::get_context(),
"BoostedTreesQuantileStreamResourceHandleOp", context::get_status()),
3348 status_check(context::get_status());
3353 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
3354 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
3357 int num_outputs_op = 1;
3358 TFE_TensorHandle* res[1] = {
nullptr};
3359 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3360 status_check(context::get_status());
3361 return tensor(res[0]);
3364 inline tensor broadcast_args(
const tensor& s0,
const tensor& s1) {
3366 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3367 TFE_NewOp(context::get_context(),
"BroadcastArgs", context::get_status()), &TFE_DeleteOp);
3368 status_check(context::get_status());
3372 TFE_OpAddInput(op.get(), s0.tfe_handle.get(), context::get_status());
3373 status_check(context::get_status());
3375 TFE_OpAddInput(op.get(), s1.tfe_handle.get(), context::get_status());
3376 status_check(context::get_status());
3381 int num_outputs_op = 1;
3382 TFE_TensorHandle* res[1] = {
nullptr};
3383 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3384 status_check(context::get_status());
3385 return tensor(res[0]);
3388 inline tensor broadcast_to(
const tensor& input,
const tensor& shape, datatype Tidx =
static_cast<datatype
>(3)) {
3390 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3391 TFE_NewOp(context::get_context(),
"BroadcastTo", context::get_status()), &TFE_DeleteOp);
3392 status_check(context::get_status());
3396 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3397 status_check(context::get_status());
3399 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
3400 status_check(context::get_status());
3403 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
3406 int num_outputs_op = 1;
3407 TFE_TensorHandle* res[1] = {
nullptr};
3408 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3409 status_check(context::get_status());
3410 return tensor(res[0]);
3413 inline tensor bucketize(
const tensor& input,
const std::vector<float>& boundaries) {
3415 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3416 TFE_NewOp(context::get_context(),
"Bucketize", context::get_status()), &TFE_DeleteOp);
3417 status_check(context::get_status());
3421 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3422 status_check(context::get_status());
3425 TFE_OpSetAttrFloatList(op.get(),
"boundaries", boundaries.data(),
static_cast<int>(boundaries.size()));
3428 int num_outputs_op = 1;
3429 TFE_TensorHandle* res[1] = {
nullptr};
3430 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3431 status_check(context::get_status());
3432 return tensor(res[0]);
3435 inline tensor bytes_produced_stats_dataset(
const tensor& input_dataset,
const tensor& tag,
3436 const std::vector<datatype>& output_types,
3437 const std::vector<std::vector<int64_t>>& output_shapes) {
3439 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3440 TFE_NewOp(context::get_context(),
"BytesProducedStatsDataset", context::get_status()), &TFE_DeleteOp);
3441 status_check(context::get_status());
3445 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
3446 status_check(context::get_status());
3448 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
3449 status_check(context::get_status());
3452 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
3453 static_cast<int>(output_types.size()));
3455 std::vector<const int64_t*> output_shapes_values;
3456 output_shapes_values.reserve(output_shapes.size());
3457 std::vector<int> output_shapes_ndims;
3458 output_shapes_ndims.reserve(output_shapes.size());
3459 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
3460 [](
const auto& v) { return v.data(); });
3461 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
3462 [](
const auto& v) { return static_cast<int>(v.size()); });
3463 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
3464 static_cast<int>(output_shapes.size()), context::get_status());
3465 status_check(context::get_status());
3468 int num_outputs_op = 1;
3469 TFE_TensorHandle* res[1] = {
nullptr};
3470 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3471 status_check(context::get_status());
3472 return tensor(res[0]);
3475 inline tensor c_s_r_sparse_matrix_to_dense(
const tensor& sparse_input, datatype type) {
3477 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3478 TFE_NewOp(context::get_context(),
"CSRSparseMatrixToDense", context::get_status()), &TFE_DeleteOp);
3479 status_check(context::get_status());
3483 TFE_OpAddInput(op.get(), sparse_input.tfe_handle.get(), context::get_status());
3484 status_check(context::get_status());
3487 TFE_OpSetAttrType(op.get(),
"type", type);
3490 int num_outputs_op = 1;
3491 TFE_TensorHandle* res[1] = {
nullptr};
3492 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3493 status_check(context::get_status());
3494 return tensor(res[0]);
3497 inline tensor c_s_v_dataset(
const tensor& filenames,
const tensor& compression_type,
const tensor& buffer_size,
3498 const tensor& header,
const tensor& field_delim,
const tensor& use_quote_delim,
3499 const tensor& na_value,
const tensor& select_cols,
3500 const std::vector<tensor>& record_defaults,
const std::vector<datatype>& output_types,
3501 const std::vector<std::vector<int64_t>>& output_shapes) {
3503 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3504 TFE_NewOp(context::get_context(),
"CSVDataset", context::get_status()), &TFE_DeleteOp);
3505 status_check(context::get_status());
3509 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
3510 status_check(context::get_status());
3512 TFE_OpAddInput(op.get(), compression_type.tfe_handle.get(), context::get_status());
3513 status_check(context::get_status());
3515 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
3516 status_check(context::get_status());
3518 TFE_OpAddInput(op.get(), header.tfe_handle.get(), context::get_status());
3519 status_check(context::get_status());
3521 TFE_OpAddInput(op.get(), field_delim.tfe_handle.get(), context::get_status());
3522 status_check(context::get_status());
3524 TFE_OpAddInput(op.get(), use_quote_delim.tfe_handle.get(), context::get_status());
3525 status_check(context::get_status());
3527 TFE_OpAddInput(op.get(), na_value.tfe_handle.get(), context::get_status());
3528 status_check(context::get_status());
3530 TFE_OpAddInput(op.get(), select_cols.tfe_handle.get(), context::get_status());
3531 status_check(context::get_status());
3533 std::vector<TFE_TensorHandle*> record_defaults_handles;
3534 record_defaults_handles.reserve(record_defaults.size());
3535 std::transform(record_defaults.begin(), record_defaults.end(), std::back_inserter(record_defaults_handles),
3536 [](
const auto& t) { return t.tfe_handle.get(); });
3537 TFE_OpAddInputList(op.get(), record_defaults_handles.data(),
static_cast<int>(record_defaults.size()),
3538 context::get_status());
3539 status_check(context::get_status());
3542 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
3543 static_cast<int>(output_types.size()));
3545 std::vector<const int64_t*> output_shapes_values;
3546 output_shapes_values.reserve(output_shapes.size());
3547 std::vector<int> output_shapes_ndims;
3548 output_shapes_ndims.reserve(output_shapes.size());
3549 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
3550 [](
const auto& v) { return v.data(); });
3551 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
3552 [](
const auto& v) { return static_cast<int>(v.size()); });
3553 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
3554 static_cast<int>(output_shapes.size()), context::get_status());
3555 status_check(context::get_status());
3558 int num_outputs_op = 1;
3559 TFE_TensorHandle* res[1] = {
nullptr};
3560 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3561 status_check(context::get_status());
3562 return tensor(res[0]);
3565 inline tensor cache_dataset(
const tensor& input_dataset,
const tensor& filename,
3566 const std::vector<datatype>& output_types,
3567 const std::vector<std::vector<int64_t>>& output_shapes) {
3569 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3570 TFE_NewOp(context::get_context(),
"CacheDataset", context::get_status()), &TFE_DeleteOp);
3571 status_check(context::get_status());
3575 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
3576 status_check(context::get_status());
3578 TFE_OpAddInput(op.get(), filename.tfe_handle.get(), context::get_status());
3579 status_check(context::get_status());
3582 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
3583 static_cast<int>(output_types.size()));
3585 std::vector<const int64_t*> output_shapes_values;
3586 output_shapes_values.reserve(output_shapes.size());
3587 std::vector<int> output_shapes_ndims;
3588 output_shapes_ndims.reserve(output_shapes.size());
3589 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
3590 [](
const auto& v) { return v.data(); });
3591 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
3592 [](
const auto& v) { return static_cast<int>(v.size()); });
3593 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
3594 static_cast<int>(output_shapes.size()), context::get_status());
3595 status_check(context::get_status());
3598 int num_outputs_op = 1;
3599 TFE_TensorHandle* res[1] = {
nullptr};
3600 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3601 status_check(context::get_status());
3602 return tensor(res[0]);
3605 inline tensor cache_dataset_v2(
const tensor& input_dataset,
const tensor& filename,
const tensor& cache,
3606 const std::vector<datatype>& output_types,
3607 const std::vector<std::vector<int64_t>>& output_shapes) {
3609 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3610 TFE_NewOp(context::get_context(),
"CacheDatasetV2", context::get_status()), &TFE_DeleteOp);
3611 status_check(context::get_status());
3615 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
3616 status_check(context::get_status());
3618 TFE_OpAddInput(op.get(), filename.tfe_handle.get(), context::get_status());
3619 status_check(context::get_status());
3621 TFE_OpAddInput(op.get(), cache.tfe_handle.get(), context::get_status());
3622 status_check(context::get_status());
3625 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
3626 static_cast<int>(output_types.size()));
3628 std::vector<const int64_t*> output_shapes_values;
3629 output_shapes_values.reserve(output_shapes.size());
3630 std::vector<int> output_shapes_ndims;
3631 output_shapes_ndims.reserve(output_shapes.size());
3632 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
3633 [](
const auto& v) { return v.data(); });
3634 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
3635 [](
const auto& v) { return static_cast<int>(v.size()); });
3636 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
3637 static_cast<int>(output_shapes.size()), context::get_status());
3638 status_check(context::get_status());
3641 int num_outputs_op = 1;
3642 TFE_TensorHandle* res[1] = {
nullptr};
3643 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3644 status_check(context::get_status());
3645 return tensor(res[0]);
3648 inline tensor cast(
const tensor& x, datatype SrcT, datatype DstT,
bool Truncate =
false) {
3650 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Cast", context::get_status()),
3652 status_check(context::get_status());
3656 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
3657 status_check(context::get_status());
3660 TFE_OpSetAttrType(op.get(),
"SrcT", SrcT);
3661 TFE_OpSetAttrType(op.get(),
"DstT", DstT);
3662 TFE_OpSetAttrBool(op.get(),
"Truncate", (
unsigned char)Truncate);
3665 int num_outputs_op = 1;
3666 TFE_TensorHandle* res[1] = {
nullptr};
3667 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3668 status_check(context::get_status());
3669 return tensor(res[0]);
3672 inline tensor ceil(
const tensor& x) {
3674 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Ceil", context::get_status()),
3676 status_check(context::get_status());
3680 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
3681 status_check(context::get_status());
3686 int num_outputs_op = 1;
3687 TFE_TensorHandle* res[1] = {
nullptr};
3688 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3689 status_check(context::get_status());
3690 return tensor(res[0]);
3693 inline tensor check_numerics(
const tensor& input_tensor,
const std::string& message) {
3695 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3696 TFE_NewOp(context::get_context(),
"CheckNumerics", context::get_status()), &TFE_DeleteOp);
3697 status_check(context::get_status());
3701 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
3702 status_check(context::get_status());
3705 TFE_OpSetAttrString(op.get(),
"message", (
void*)message.c_str(), message.size());
3708 int num_outputs_op = 1;
3709 TFE_TensorHandle* res[1] = {
nullptr};
3710 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3711 status_check(context::get_status());
3712 return tensor(res[0]);
3715 inline tensor check_numerics_v2(
const tensor& input_tensor,
const std::string& message) {
3717 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3718 TFE_NewOp(context::get_context(),
"CheckNumericsV2", context::get_status()), &TFE_DeleteOp);
3719 status_check(context::get_status());
3723 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
3724 status_check(context::get_status());
3727 TFE_OpSetAttrString(op.get(),
"message", (
void*)message.c_str(), message.size());
3730 int num_outputs_op = 1;
3731 TFE_TensorHandle* res[1] = {
nullptr};
3732 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3733 status_check(context::get_status());
3734 return tensor(res[0]);
3737 inline tensor cholesky(
const tensor& input) {
3739 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3740 TFE_NewOp(context::get_context(),
"Cholesky", context::get_status()), &TFE_DeleteOp);
3741 status_check(context::get_status());
3745 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3746 status_check(context::get_status());
3751 int num_outputs_op = 1;
3752 TFE_TensorHandle* res[1] = {
nullptr};
3753 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3754 status_check(context::get_status());
3755 return tensor(res[0]);
3758 inline tensor cholesky_grad(
const tensor& l,
const tensor& grad) {
3760 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3761 TFE_NewOp(context::get_context(),
"CholeskyGrad", context::get_status()), &TFE_DeleteOp);
3762 status_check(context::get_status());
3766 TFE_OpAddInput(op.get(), l.tfe_handle.get(), context::get_status());
3767 status_check(context::get_status());
3769 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
3770 status_check(context::get_status());
3775 int num_outputs_op = 1;
3776 TFE_TensorHandle* res[1] = {
nullptr};
3777 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3778 status_check(context::get_status());
3779 return tensor(res[0]);
3782 inline tensor choose_fastest_dataset(
const std::vector<tensor>& input_datasets, int64_t num_experiments,
3783 const std::vector<datatype>& output_types,
3784 const std::vector<std::vector<int64_t>>& output_shapes) {
3786 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3787 TFE_NewOp(context::get_context(),
"ChooseFastestDataset", context::get_status()), &TFE_DeleteOp);
3788 status_check(context::get_status());
3792 std::vector<TFE_TensorHandle*> input_datasets_handles;
3793 input_datasets_handles.reserve(input_datasets.size());
3794 std::transform(input_datasets.begin(), input_datasets.end(), std::back_inserter(input_datasets_handles),
3795 [](
const auto& t) { return t.tfe_handle.get(); });
3796 TFE_OpAddInputList(op.get(), input_datasets_handles.data(),
static_cast<int>(input_datasets.size()),
3797 context::get_status());
3798 status_check(context::get_status());
3801 TFE_OpSetAttrInt(op.get(),
"N", input_datasets.size());
3802 TFE_OpSetAttrInt(op.get(),
"num_experiments", num_experiments);
3803 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
3804 static_cast<int>(output_types.size()));
3806 std::vector<const int64_t*> output_shapes_values;
3807 output_shapes_values.reserve(output_shapes.size());
3808 std::vector<int> output_shapes_ndims;
3809 output_shapes_ndims.reserve(output_shapes.size());
3810 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
3811 [](
const auto& v) { return v.data(); });
3812 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
3813 [](
const auto& v) { return static_cast<int>(v.size()); });
3814 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
3815 static_cast<int>(output_shapes.size()), context::get_status());
3816 status_check(context::get_status());
3819 int num_outputs_op = 1;
3820 TFE_TensorHandle* res[1] = {
nullptr};
3821 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3822 status_check(context::get_status());
3823 return tensor(res[0]);
3826 inline tensor clip_by_value(
const tensor& t,
const tensor& clip_value_min,
const tensor& clip_value_max) {
3828 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3829 TFE_NewOp(context::get_context(),
"ClipByValue", context::get_status()), &TFE_DeleteOp);
3830 status_check(context::get_status());
3834 TFE_OpAddInput(op.get(), t.tfe_handle.get(), context::get_status());
3835 status_check(context::get_status());
3837 TFE_OpAddInput(op.get(), clip_value_min.tfe_handle.get(), context::get_status());
3838 status_check(context::get_status());
3840 TFE_OpAddInput(op.get(), clip_value_max.tfe_handle.get(), context::get_status());
3841 status_check(context::get_status());
3846 int num_outputs_op = 1;
3847 TFE_TensorHandle* res[1] = {
nullptr};
3848 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3849 status_check(context::get_status());
3850 return tensor(res[0]);
3853 inline tensor collective_bcast_recv(int64_t group_size, int64_t group_key, int64_t instance_key,
3854 const std::vector<int64_t>& shape,
const std::string& communication_hint =
"auto",
3855 float timeout_seconds = 0.0000e+00) {
3857 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3858 TFE_NewOp(context::get_context(),
"CollectiveBcastRecv", context::get_status()), &TFE_DeleteOp);
3859 status_check(context::get_status());
3864 TFE_OpSetAttrInt(op.get(),
"group_size", group_size);
3865 TFE_OpSetAttrInt(op.get(),
"group_key", group_key);
3866 TFE_OpSetAttrInt(op.get(),
"instance_key", instance_key);
3868 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
3869 status_check(context::get_status());
3871 TFE_OpSetAttrString(op.get(),
"communication_hint", (
void*)communication_hint.c_str(), communication_hint.size());
3872 TFE_OpSetAttrFloat(op.get(),
"timeout_seconds", timeout_seconds);
3875 int num_outputs_op = 1;
3876 TFE_TensorHandle* res[1] = {
nullptr};
3877 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3878 status_check(context::get_status());
3879 return tensor(res[0]);
3882 inline tensor collective_bcast_send(
const tensor& input, int64_t group_size, int64_t group_key, int64_t instance_key,
3883 const std::vector<int64_t>& shape,
const std::string& communication_hint =
"auto",
3884 float timeout_seconds = 0.0000e+00) {
3886 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3887 TFE_NewOp(context::get_context(),
"CollectiveBcastSend", context::get_status()), &TFE_DeleteOp);
3888 status_check(context::get_status());
3892 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3893 status_check(context::get_status());
3896 TFE_OpSetAttrInt(op.get(),
"group_size", group_size);
3897 TFE_OpSetAttrInt(op.get(),
"group_key", group_key);
3898 TFE_OpSetAttrInt(op.get(),
"instance_key", instance_key);
3900 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
3901 status_check(context::get_status());
3903 TFE_OpSetAttrString(op.get(),
"communication_hint", (
void*)communication_hint.c_str(), communication_hint.size());
3904 TFE_OpSetAttrFloat(op.get(),
"timeout_seconds", timeout_seconds);
3907 int num_outputs_op = 1;
3908 TFE_TensorHandle* res[1] = {
nullptr};
3909 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3910 status_check(context::get_status());
3911 return tensor(res[0]);
3914 inline tensor collective_gather(
const tensor& input, int64_t group_size, int64_t group_key, int64_t instance_key,
3915 const std::vector<int64_t>& shape,
const std::string& communication_hint =
"auto",
3916 float timeout_seconds = 0.0000e+00) {
3918 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3919 TFE_NewOp(context::get_context(),
"CollectiveGather", context::get_status()), &TFE_DeleteOp);
3920 status_check(context::get_status());
3924 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3925 status_check(context::get_status());
3928 TFE_OpSetAttrInt(op.get(),
"group_size", group_size);
3929 TFE_OpSetAttrInt(op.get(),
"group_key", group_key);
3930 TFE_OpSetAttrInt(op.get(),
"instance_key", instance_key);
3932 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
3933 status_check(context::get_status());
3935 TFE_OpSetAttrString(op.get(),
"communication_hint", (
void*)communication_hint.c_str(), communication_hint.size());
3936 TFE_OpSetAttrFloat(op.get(),
"timeout_seconds", timeout_seconds);
3939 int num_outputs_op = 1;
3940 TFE_TensorHandle* res[1] = {
nullptr};
3941 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3942 status_check(context::get_status());
3943 return tensor(res[0]);
3946 inline tensor collective_permute(
const tensor& input,
const tensor& source_target_pairs) {
3948 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3949 TFE_NewOp(context::get_context(),
"CollectivePermute", context::get_status()), &TFE_DeleteOp);
3950 status_check(context::get_status());
3954 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3955 status_check(context::get_status());
3957 TFE_OpAddInput(op.get(), source_target_pairs.tfe_handle.get(), context::get_status());
3958 status_check(context::get_status());
3963 int num_outputs_op = 1;
3964 TFE_TensorHandle* res[1] = {
nullptr};
3965 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3966 status_check(context::get_status());
3967 return tensor(res[0]);
3970 inline tensor collective_reduce(
const tensor& input, int64_t group_size, int64_t group_key, int64_t instance_key,
3971 const std::string& merge_op,
const std::string& final_op,
3972 const std::vector<int64_t>& subdiv_offsets,
const std::vector<int64_t>& wait_for,
3973 const std::string& communication_hint =
"auto",
float timeout_seconds = 0.0000e+00) {
3975 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
3976 TFE_NewOp(context::get_context(),
"CollectiveReduce", context::get_status()), &TFE_DeleteOp);
3977 status_check(context::get_status());
3981 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
3982 status_check(context::get_status());
3985 TFE_OpSetAttrInt(op.get(),
"group_size", group_size);
3986 TFE_OpSetAttrInt(op.get(),
"group_key", group_key);
3987 TFE_OpSetAttrInt(op.get(),
"instance_key", instance_key);
3988 TFE_OpSetAttrString(op.get(),
"merge_op", (
void*)merge_op.c_str(), merge_op.size());
3989 TFE_OpSetAttrString(op.get(),
"final_op", (
void*)final_op.c_str(), final_op.size());
3990 TFE_OpSetAttrIntList(op.get(),
"subdiv_offsets", subdiv_offsets.data(),
static_cast<int>(subdiv_offsets.size()));
3991 TFE_OpSetAttrIntList(op.get(),
"wait_for", wait_for.data(),
static_cast<int>(wait_for.size()));
3992 TFE_OpSetAttrString(op.get(),
"communication_hint", (
void*)communication_hint.c_str(), communication_hint.size());
3993 TFE_OpSetAttrFloat(op.get(),
"timeout_seconds", timeout_seconds);
3996 int num_outputs_op = 1;
3997 TFE_TensorHandle* res[1] = {
nullptr};
3998 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
3999 status_check(context::get_status());
4000 return tensor(res[0]);
4003 inline tensor compare_and_bitpack(
const tensor& input,
const tensor& threshold) {
4005 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4006 TFE_NewOp(context::get_context(),
"CompareAndBitpack", context::get_status()), &TFE_DeleteOp);
4007 status_check(context::get_status());
4011 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4012 status_check(context::get_status());
4014 TFE_OpAddInput(op.get(), threshold.tfe_handle.get(), context::get_status());
4015 status_check(context::get_status());
4020 int num_outputs_op = 1;
4021 TFE_TensorHandle* res[1] = {
nullptr};
4022 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4023 status_check(context::get_status());
4024 return tensor(res[0]);
4027 inline tensor complex(
const tensor& real,
const tensor& imag, datatype Tout =
static_cast<datatype
>(8)) {
4029 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4030 TFE_NewOp(context::get_context(),
"Complex", context::get_status()), &TFE_DeleteOp);
4031 status_check(context::get_status());
4035 TFE_OpAddInput(op.get(), real.tfe_handle.get(), context::get_status());
4036 status_check(context::get_status());
4038 TFE_OpAddInput(op.get(), imag.tfe_handle.get(), context::get_status());
4039 status_check(context::get_status());
4042 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
4045 int num_outputs_op = 1;
4046 TFE_TensorHandle* res[1] = {
nullptr};
4047 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4048 status_check(context::get_status());
4049 return tensor(res[0]);
4052 inline tensor complex_abs(
const tensor& x, datatype Tout =
static_cast<datatype
>(1)) {
4054 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4055 TFE_NewOp(context::get_context(),
"ComplexAbs", context::get_status()), &TFE_DeleteOp);
4056 status_check(context::get_status());
4060 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
4061 status_check(context::get_status());
4064 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
4067 int num_outputs_op = 1;
4068 TFE_TensorHandle* res[1] = {
nullptr};
4069 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4070 status_check(context::get_status());
4071 return tensor(res[0]);
4074 inline tensor compress_element(
const std::vector<tensor>& components,
const std::vector<datatype>& input_types) {
4076 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4077 TFE_NewOp(context::get_context(),
"CompressElement", context::get_status()), &TFE_DeleteOp);
4078 status_check(context::get_status());
4082 std::vector<TFE_TensorHandle*> components_handles;
4083 components_handles.reserve(components.size());
4084 std::transform(components.begin(), components.end(), std::back_inserter(components_handles),
4085 [](
const auto& t) { return t.tfe_handle.get(); });
4086 TFE_OpAddInputList(op.get(), components_handles.data(),
static_cast<int>(components.size()), context::get_status());
4087 status_check(context::get_status());
4090 TFE_OpSetAttrTypeList(op.get(),
"input_types",
reinterpret_cast<const enum TF_DataType*
>(input_types.data()),
4091 static_cast<int>(input_types.size()));
4094 int num_outputs_op = 1;
4095 TFE_TensorHandle* res[1] = {
nullptr};
4096 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4097 status_check(context::get_status());
4098 return tensor(res[0]);
4101 inline tensor concat(
const tensor& concat_dim,
const std::vector<tensor>& values) {
4103 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4104 TFE_NewOp(context::get_context(),
"Concat", context::get_status()), &TFE_DeleteOp);
4105 status_check(context::get_status());
4109 TFE_OpAddInput(op.get(), concat_dim.tfe_handle.get(), context::get_status());
4110 status_check(context::get_status());
4112 std::vector<TFE_TensorHandle*> values_handles;
4113 values_handles.reserve(values.size());
4114 std::transform(values.begin(), values.end(), std::back_inserter(values_handles),
4115 [](
const auto& t) { return t.tfe_handle.get(); });
4116 TFE_OpAddInputList(op.get(), values_handles.data(),
static_cast<int>(values.size()), context::get_status());
4117 status_check(context::get_status());
4120 TFE_OpSetAttrInt(op.get(),
"N", values.size());
4123 int num_outputs_op = 1;
4124 TFE_TensorHandle* res[1] = {
nullptr};
4125 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4126 status_check(context::get_status());
4127 return tensor(res[0]);
4130 inline tensor concat_offset(
const tensor& concat_dim,
const std::vector<tensor>& shape) {
4132 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4133 TFE_NewOp(context::get_context(),
"ConcatOffset", context::get_status()), &TFE_DeleteOp);
4134 status_check(context::get_status());
4138 TFE_OpAddInput(op.get(), concat_dim.tfe_handle.get(), context::get_status());
4139 status_check(context::get_status());
4141 std::vector<TFE_TensorHandle*> shape_handles;
4142 shape_handles.reserve(shape.size());
4143 std::transform(shape.begin(), shape.end(), std::back_inserter(shape_handles),
4144 [](
const auto& t) { return t.tfe_handle.get(); });
4145 TFE_OpAddInputList(op.get(), shape_handles.data(),
static_cast<int>(shape.size()), context::get_status());
4146 status_check(context::get_status());
4149 TFE_OpSetAttrInt(op.get(),
"N", shape.size());
4152 int num_outputs_op = 1;
4153 TFE_TensorHandle* res[1] = {
nullptr};
4154 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4155 status_check(context::get_status());
4156 return tensor(res[0]);
4159 inline tensor concat_v2(
const std::vector<tensor>& values,
const tensor& axis,
4160 datatype Tidx =
static_cast<datatype
>(3)) {
4162 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4163 TFE_NewOp(context::get_context(),
"ConcatV2", context::get_status()), &TFE_DeleteOp);
4164 status_check(context::get_status());
4168 std::vector<TFE_TensorHandle*> values_handles;
4169 values_handles.reserve(values.size());
4170 std::transform(values.begin(), values.end(), std::back_inserter(values_handles),
4171 [](
const auto& t) { return t.tfe_handle.get(); });
4172 TFE_OpAddInputList(op.get(), values_handles.data(),
static_cast<int>(values.size()), context::get_status());
4173 status_check(context::get_status());
4175 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
4176 status_check(context::get_status());
4179 TFE_OpSetAttrInt(op.get(),
"N", values.size());
4180 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
4183 int num_outputs_op = 1;
4184 TFE_TensorHandle* res[1] = {
nullptr};
4185 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4186 status_check(context::get_status());
4187 return tensor(res[0]);
4190 inline tensor concatenate_dataset(
const tensor& input_dataset,
const tensor& another_dataset,
4191 const std::vector<datatype>& output_types,
4192 const std::vector<std::vector<int64_t>>& output_shapes) {
4194 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4195 TFE_NewOp(context::get_context(),
"ConcatenateDataset", context::get_status()), &TFE_DeleteOp);
4196 status_check(context::get_status());
4200 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
4201 status_check(context::get_status());
4203 TFE_OpAddInput(op.get(), another_dataset.tfe_handle.get(), context::get_status());
4204 status_check(context::get_status());
4207 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
4208 static_cast<int>(output_types.size()));
4210 std::vector<const int64_t*> output_shapes_values;
4211 output_shapes_values.reserve(output_shapes.size());
4212 std::vector<int> output_shapes_ndims;
4213 output_shapes_ndims.reserve(output_shapes.size());
4214 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
4215 [](
const auto& v) { return v.data(); });
4216 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
4217 [](
const auto& v) { return static_cast<int>(v.size()); });
4218 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
4219 static_cast<int>(output_shapes.size()), context::get_status());
4220 status_check(context::get_status());
4223 int num_outputs_op = 1;
4224 TFE_TensorHandle* res[1] = {
nullptr};
4225 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4226 status_check(context::get_status());
4227 return tensor(res[0]);
4230 inline tensor conditional_accumulator(datatype dtype,
const std::vector<int64_t>& shape,
4231 const std::string& container =
"",
const std::string& shared_name =
"",
4232 const std::string& reduction_type =
"MEAN") {
4234 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4235 TFE_NewOp(context::get_context(),
"ConditionalAccumulator", context::get_status()), &TFE_DeleteOp);
4236 status_check(context::get_status());
4241 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
4243 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
4244 status_check(context::get_status());
4246 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
4247 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
4248 TFE_OpSetAttrString(op.get(),
"reduction_type", (
void*)reduction_type.c_str(), reduction_type.size());
4251 int num_outputs_op = 1;
4252 TFE_TensorHandle* res[1] = {
nullptr};
4253 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4254 status_check(context::get_status());
4255 return tensor(res[0]);
4258 inline tensor configure_distributed_t_p_u(
const std::string& embedding_config =
"",
4259 const std::string& tpu_embedding_config =
"",
bool is_global_init =
false,
4260 bool enable_whole_mesh_compilations =
false,
4261 bool compilation_failure_closes_chips =
true) {
4263 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4264 TFE_NewOp(context::get_context(),
"ConfigureDistributedTPU", context::get_status()), &TFE_DeleteOp);
4265 status_check(context::get_status());
4270 TFE_OpSetAttrString(op.get(),
"embedding_config", (
void*)embedding_config.c_str(), embedding_config.size());
4271 TFE_OpSetAttrString(op.get(),
"tpu_embedding_config", (
void*)tpu_embedding_config.c_str(),
4272 tpu_embedding_config.size());
4273 TFE_OpSetAttrBool(op.get(),
"is_global_init", (
unsigned char)is_global_init);
4274 TFE_OpSetAttrBool(op.get(),
"enable_whole_mesh_compilations", (
unsigned char)enable_whole_mesh_compilations);
4275 TFE_OpSetAttrBool(op.get(),
"compilation_failure_closes_chips", (
unsigned char)compilation_failure_closes_chips);
4278 int num_outputs_op = 1;
4279 TFE_TensorHandle* res[1] = {
nullptr};
4280 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4281 status_check(context::get_status());
4282 return tensor(res[0]);
4285 inline tensor conj(
const tensor& input) {
4287 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Conj", context::get_status()),
4289 status_check(context::get_status());
4293 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4294 status_check(context::get_status());
4299 int num_outputs_op = 1;
4300 TFE_TensorHandle* res[1] = {
nullptr};
4301 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4302 status_check(context::get_status());
4303 return tensor(res[0]);
4306 inline tensor conjugate_transpose(
const tensor& x,
const tensor& perm, datatype Tperm =
static_cast<datatype
>(3)) {
4308 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4309 TFE_NewOp(context::get_context(),
"ConjugateTranspose", context::get_status()), &TFE_DeleteOp);
4310 status_check(context::get_status());
4314 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
4315 status_check(context::get_status());
4317 TFE_OpAddInput(op.get(), perm.tfe_handle.get(), context::get_status());
4318 status_check(context::get_status());
4321 TFE_OpSetAttrType(op.get(),
"Tperm", Tperm);
4324 int num_outputs_op = 1;
4325 TFE_TensorHandle* res[1] = {
nullptr};
4326 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4327 status_check(context::get_status());
4328 return tensor(res[0]);
4331 inline tensor const_tensor(
const tensor& value, datatype dtype) {
4333 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Const", context::get_status()),
4335 status_check(context::get_status());
4341 TFE_OpSetAttrTensor(op.get(),
"value", value.get_tensor().get(), context::get_status());
4342 status_check(context::get_status());
4344 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
4347 int num_outputs_op = 1;
4348 TFE_TensorHandle* res[1] = {
nullptr};
4349 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4350 status_check(context::get_status());
4351 return tensor(res[0]);
4354 inline tensor conv2_d(
const tensor& input,
const tensor& filter,
const std::vector<int64_t>& strides,
4355 const std::string& padding,
const std::vector<int64_t>& explicit_paddings,
4356 const std::vector<int64_t>& dilations,
bool use_cudnn_on_gpu =
true,
4357 const std::string& data_format =
"NHWC") {
4359 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4360 TFE_NewOp(context::get_context(),
"Conv2D", context::get_status()), &TFE_DeleteOp);
4361 status_check(context::get_status());
4365 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4366 status_check(context::get_status());
4368 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
4369 status_check(context::get_status());
4372 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4373 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4374 TFE_OpSetAttrIntList(op.get(),
"explicit_paddings", explicit_paddings.data(),
4375 static_cast<int>(explicit_paddings.size()));
4376 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4377 TFE_OpSetAttrBool(op.get(),
"use_cudnn_on_gpu", (
unsigned char)use_cudnn_on_gpu);
4378 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
4381 int num_outputs_op = 1;
4382 TFE_TensorHandle* res[1] = {
nullptr};
4383 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4384 status_check(context::get_status());
4385 return tensor(res[0]);
4388 inline tensor conv2_d_backprop_filter(
const tensor& input,
const tensor& filter_sizes,
const tensor& out_backprop,
4389 const std::vector<int64_t>& strides,
const std::string& padding,
4390 const std::vector<int64_t>& explicit_paddings,
4391 const std::vector<int64_t>& dilations,
bool use_cudnn_on_gpu =
true,
4392 const std::string& data_format =
"NHWC") {
4394 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4395 TFE_NewOp(context::get_context(),
"Conv2DBackpropFilter", context::get_status()), &TFE_DeleteOp);
4396 status_check(context::get_status());
4400 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4401 status_check(context::get_status());
4403 TFE_OpAddInput(op.get(), filter_sizes.tfe_handle.get(), context::get_status());
4404 status_check(context::get_status());
4406 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
4407 status_check(context::get_status());
4410 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4411 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4412 TFE_OpSetAttrIntList(op.get(),
"explicit_paddings", explicit_paddings.data(),
4413 static_cast<int>(explicit_paddings.size()));
4414 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4415 TFE_OpSetAttrBool(op.get(),
"use_cudnn_on_gpu", (
unsigned char)use_cudnn_on_gpu);
4416 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
4419 int num_outputs_op = 1;
4420 TFE_TensorHandle* res[1] = {
nullptr};
4421 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4422 status_check(context::get_status());
4423 return tensor(res[0]);
4426 inline tensor conv2_d_backprop_input(
const tensor& input_sizes,
const tensor& filter,
const tensor& out_backprop,
4427 const std::vector<int64_t>& strides,
const std::string& padding,
4428 const std::vector<int64_t>& explicit_paddings,
4429 const std::vector<int64_t>& dilations,
bool use_cudnn_on_gpu =
true,
4430 const std::string& data_format =
"NHWC") {
4432 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4433 TFE_NewOp(context::get_context(),
"Conv2DBackpropInput", context::get_status()), &TFE_DeleteOp);
4434 status_check(context::get_status());
4438 TFE_OpAddInput(op.get(), input_sizes.tfe_handle.get(), context::get_status());
4439 status_check(context::get_status());
4441 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
4442 status_check(context::get_status());
4444 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
4445 status_check(context::get_status());
4448 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4449 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4450 TFE_OpSetAttrIntList(op.get(),
"explicit_paddings", explicit_paddings.data(),
4451 static_cast<int>(explicit_paddings.size()));
4452 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4453 TFE_OpSetAttrBool(op.get(),
"use_cudnn_on_gpu", (
unsigned char)use_cudnn_on_gpu);
4454 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
4457 int num_outputs_op = 1;
4458 TFE_TensorHandle* res[1] = {
nullptr};
4459 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4460 status_check(context::get_status());
4461 return tensor(res[0]);
4464 inline tensor conv3_d(
const tensor& input,
const tensor& filter,
const std::vector<int64_t>& strides,
4465 const std::string& padding,
const std::vector<int64_t>& dilations,
4466 const std::string& data_format =
"NDHWC") {
4468 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4469 TFE_NewOp(context::get_context(),
"Conv3D", context::get_status()), &TFE_DeleteOp);
4470 status_check(context::get_status());
4474 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4475 status_check(context::get_status());
4477 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
4478 status_check(context::get_status());
4481 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4482 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4483 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4484 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
4487 int num_outputs_op = 1;
4488 TFE_TensorHandle* res[1] = {
nullptr};
4489 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4490 status_check(context::get_status());
4491 return tensor(res[0]);
4494 inline tensor conv3_d_backprop_filter(
const tensor& input,
const tensor& filter,
const tensor& out_backprop,
4495 const std::vector<int64_t>& strides,
const std::string& padding,
4496 const std::vector<int64_t>& dilations) {
4498 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4499 TFE_NewOp(context::get_context(),
"Conv3DBackpropFilter", context::get_status()), &TFE_DeleteOp);
4500 status_check(context::get_status());
4504 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4505 status_check(context::get_status());
4507 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
4508 status_check(context::get_status());
4510 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
4511 status_check(context::get_status());
4514 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4515 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4516 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4519 int num_outputs_op = 1;
4520 TFE_TensorHandle* res[1] = {
nullptr};
4521 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4522 status_check(context::get_status());
4523 return tensor(res[0]);
4526 inline tensor conv3_d_backprop_filter_v2(
const tensor& input,
const tensor& filter_sizes,
const tensor& out_backprop,
4527 const std::vector<int64_t>& strides,
const std::string& padding,
4528 const std::vector<int64_t>& dilations,
4529 const std::string& data_format =
"NDHWC") {
4531 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4532 TFE_NewOp(context::get_context(),
"Conv3DBackpropFilterV2", context::get_status()), &TFE_DeleteOp);
4533 status_check(context::get_status());
4537 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4538 status_check(context::get_status());
4540 TFE_OpAddInput(op.get(), filter_sizes.tfe_handle.get(), context::get_status());
4541 status_check(context::get_status());
4543 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
4544 status_check(context::get_status());
4547 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4548 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4549 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4550 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
4553 int num_outputs_op = 1;
4554 TFE_TensorHandle* res[1] = {
nullptr};
4555 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4556 status_check(context::get_status());
4557 return tensor(res[0]);
4560 inline tensor conv3_d_backprop_input(
const tensor& input,
const tensor& filter,
const tensor& out_backprop,
4561 const std::vector<int64_t>& strides,
const std::string& padding,
4562 const std::vector<int64_t>& dilations) {
4564 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4565 TFE_NewOp(context::get_context(),
"Conv3DBackpropInput", context::get_status()), &TFE_DeleteOp);
4566 status_check(context::get_status());
4570 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4571 status_check(context::get_status());
4573 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
4574 status_check(context::get_status());
4576 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
4577 status_check(context::get_status());
4580 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4581 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4582 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4585 int num_outputs_op = 1;
4586 TFE_TensorHandle* res[1] = {
nullptr};
4587 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4588 status_check(context::get_status());
4589 return tensor(res[0]);
4592 inline tensor conv3_d_backprop_input_v2(
const tensor& input_sizes,
const tensor& filter,
const tensor& out_backprop,
4593 const std::vector<int64_t>& strides,
const std::string& padding,
4594 const std::vector<int64_t>& dilations,
const std::string& data_format =
"NDHWC",
4595 datatype Tshape =
static_cast<datatype
>(3)) {
4597 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4598 TFE_NewOp(context::get_context(),
"Conv3DBackpropInputV2", context::get_status()), &TFE_DeleteOp);
4599 status_check(context::get_status());
4603 TFE_OpAddInput(op.get(), input_sizes.tfe_handle.get(), context::get_status());
4604 status_check(context::get_status());
4606 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
4607 status_check(context::get_status());
4609 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
4610 status_check(context::get_status());
4613 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
4614 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
4615 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
4616 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
4617 TFE_OpSetAttrType(op.get(),
"Tshape", Tshape);
4620 int num_outputs_op = 1;
4621 TFE_TensorHandle* res[1] = {
nullptr};
4622 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4623 status_check(context::get_status());
4624 return tensor(res[0]);
4627 inline tensor copy(
const tensor& input,
const std::vector<std::string>& debug_ops_spec,
4628 const std::string& tensor_name =
"") {
4630 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Copy", context::get_status()),
4632 status_check(context::get_status());
4636 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4637 status_check(context::get_status());
4641 std::vector<std::size_t> debug_ops_spec_sizes;
4642 debug_ops_spec_sizes.reserve(debug_ops_spec.size());
4643 std::transform(debug_ops_spec.begin(), debug_ops_spec.end(), std::back_inserter(debug_ops_spec_sizes),
4644 [](
const auto& s) { return s.size(); });
4645 TFE_OpSetAttrStringList(op.get(),
"debug_ops_spec",
reinterpret_cast<const void* const*
>(debug_ops_spec.data()),
4646 debug_ops_spec_sizes.data(),
static_cast<int>(debug_ops_spec.size()));
4648 TFE_OpSetAttrString(op.get(),
"tensor_name", (
void*)tensor_name.c_str(), tensor_name.size());
4651 int num_outputs_op = 1;
4652 TFE_TensorHandle* res[1] = {
nullptr};
4653 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4654 status_check(context::get_status());
4655 return tensor(res[0]);
4658 inline tensor copy_host(
const tensor& input,
const std::vector<std::string>& debug_ops_spec,
4659 const std::string& tensor_name =
"") {
4661 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4662 TFE_NewOp(context::get_context(),
"CopyHost", context::get_status()), &TFE_DeleteOp);
4663 status_check(context::get_status());
4667 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4668 status_check(context::get_status());
4672 std::vector<std::size_t> debug_ops_spec_sizes;
4673 debug_ops_spec_sizes.reserve(debug_ops_spec.size());
4674 std::transform(debug_ops_spec.begin(), debug_ops_spec.end(), std::back_inserter(debug_ops_spec_sizes),
4675 [](
const auto& s) { return s.size(); });
4676 TFE_OpSetAttrStringList(op.get(),
"debug_ops_spec",
reinterpret_cast<const void* const*
>(debug_ops_spec.data()),
4677 debug_ops_spec_sizes.data(),
static_cast<int>(debug_ops_spec.size()));
4679 TFE_OpSetAttrString(op.get(),
"tensor_name", (
void*)tensor_name.c_str(), tensor_name.size());
4682 int num_outputs_op = 1;
4683 TFE_TensorHandle* res[1] = {
nullptr};
4684 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4685 status_check(context::get_status());
4686 return tensor(res[0]);
4689 inline tensor cos(
const tensor& x) {
4691 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Cos", context::get_status()),
4693 status_check(context::get_status());
4697 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
4698 status_check(context::get_status());
4703 int num_outputs_op = 1;
4704 TFE_TensorHandle* res[1] = {
nullptr};
4705 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4706 status_check(context::get_status());
4707 return tensor(res[0]);
4710 inline tensor cosh(
const tensor& x) {
4712 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Cosh", context::get_status()),
4714 status_check(context::get_status());
4718 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
4719 status_check(context::get_status());
4724 int num_outputs_op = 1;
4725 TFE_TensorHandle* res[1] = {
nullptr};
4726 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4727 status_check(context::get_status());
4728 return tensor(res[0]);
4731 inline tensor count_up_to(
const tensor& ref, int64_t limit) {
4733 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4734 TFE_NewOp(context::get_context(),
"CountUpTo", context::get_status()), &TFE_DeleteOp);
4735 status_check(context::get_status());
4739 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
4740 status_check(context::get_status());
4743 TFE_OpSetAttrInt(op.get(),
"limit", limit);
4746 int num_outputs_op = 1;
4747 TFE_TensorHandle* res[1] = {
nullptr};
4748 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4749 status_check(context::get_status());
4750 return tensor(res[0]);
4753 inline tensor crop_and_resize(
const tensor& image,
const tensor& boxes,
const tensor& box_ind,
const tensor& crop_size,
4754 const std::string& method =
"bilinear",
float extrapolation_value = 0.0000e+00) {
4756 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4757 TFE_NewOp(context::get_context(),
"CropAndResize", context::get_status()), &TFE_DeleteOp);
4758 status_check(context::get_status());
4762 TFE_OpAddInput(op.get(), image.tfe_handle.get(), context::get_status());
4763 status_check(context::get_status());
4765 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
4766 status_check(context::get_status());
4768 TFE_OpAddInput(op.get(), box_ind.tfe_handle.get(), context::get_status());
4769 status_check(context::get_status());
4771 TFE_OpAddInput(op.get(), crop_size.tfe_handle.get(), context::get_status());
4772 status_check(context::get_status());
4775 TFE_OpSetAttrString(op.get(),
"method", (
void*)method.c_str(), method.size());
4776 TFE_OpSetAttrFloat(op.get(),
"extrapolation_value", extrapolation_value);
4779 int num_outputs_op = 1;
4780 TFE_TensorHandle* res[1] = {
nullptr};
4781 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4782 status_check(context::get_status());
4783 return tensor(res[0]);
4786 inline tensor crop_and_resize_grad_boxes(
const tensor& grads,
const tensor& image,
const tensor& boxes,
4787 const tensor& box_ind,
const std::string& method =
"bilinear") {
4789 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4790 TFE_NewOp(context::get_context(),
"CropAndResizeGradBoxes", context::get_status()), &TFE_DeleteOp);
4791 status_check(context::get_status());
4795 TFE_OpAddInput(op.get(), grads.tfe_handle.get(), context::get_status());
4796 status_check(context::get_status());
4798 TFE_OpAddInput(op.get(), image.tfe_handle.get(), context::get_status());
4799 status_check(context::get_status());
4801 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
4802 status_check(context::get_status());
4804 TFE_OpAddInput(op.get(), box_ind.tfe_handle.get(), context::get_status());
4805 status_check(context::get_status());
4808 TFE_OpSetAttrString(op.get(),
"method", (
void*)method.c_str(), method.size());
4811 int num_outputs_op = 1;
4812 TFE_TensorHandle* res[1] = {
nullptr};
4813 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4814 status_check(context::get_status());
4815 return tensor(res[0]);
4818 inline tensor crop_and_resize_grad_image(
const tensor& grads,
const tensor& boxes,
const tensor& box_ind,
4819 const tensor& image_size,
const std::string& method =
"bilinear") {
4821 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4822 TFE_NewOp(context::get_context(),
"CropAndResizeGradImage", context::get_status()), &TFE_DeleteOp);
4823 status_check(context::get_status());
4827 TFE_OpAddInput(op.get(), grads.tfe_handle.get(), context::get_status());
4828 status_check(context::get_status());
4830 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
4831 status_check(context::get_status());
4833 TFE_OpAddInput(op.get(), box_ind.tfe_handle.get(), context::get_status());
4834 status_check(context::get_status());
4836 TFE_OpAddInput(op.get(), image_size.tfe_handle.get(), context::get_status());
4837 status_check(context::get_status());
4840 TFE_OpSetAttrString(op.get(),
"method", (
void*)method.c_str(), method.size());
4843 int num_outputs_op = 1;
4844 TFE_TensorHandle* res[1] = {
nullptr};
4845 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4846 status_check(context::get_status());
4847 return tensor(res[0]);
4850 inline tensor cross(
const tensor& a,
const tensor& b) {
4852 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Cross", context::get_status()),
4854 status_check(context::get_status());
4858 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
4859 status_check(context::get_status());
4861 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
4862 status_check(context::get_status());
4867 int num_outputs_op = 1;
4868 TFE_TensorHandle* res[1] = {
nullptr};
4869 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4870 status_check(context::get_status());
4871 return tensor(res[0]);
4874 inline tensor cross_replica_sum(
const tensor& input,
const tensor& group_assignment) {
4876 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4877 TFE_NewOp(context::get_context(),
"CrossReplicaSum", context::get_status()), &TFE_DeleteOp);
4878 status_check(context::get_status());
4882 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
4883 status_check(context::get_status());
4885 TFE_OpAddInput(op.get(), group_assignment.tfe_handle.get(), context::get_status());
4886 status_check(context::get_status());
4891 int num_outputs_op = 1;
4892 TFE_TensorHandle* res[1] = {
nullptr};
4893 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4894 status_check(context::get_status());
4895 return tensor(res[0]);
4898 inline tensor cudnn_r_n_n_canonical_to_params(
const tensor& num_layers,
const tensor& num_units,
4899 const tensor& input_size,
const std::vector<tensor>& weights,
4900 const std::vector<tensor>& biases,
const std::string& rnn_mode =
"lstm",
4901 const std::string& input_mode =
"linear_input",
4902 const std::string& direction =
"unidirectional",
4903 float dropout = 0.0000e+00, int64_t seed = 0, int64_t seed2 = 0) {
4905 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4906 TFE_NewOp(context::get_context(),
"CudnnRNNCanonicalToParams", context::get_status()), &TFE_DeleteOp);
4907 status_check(context::get_status());
4911 TFE_OpAddInput(op.get(), num_layers.tfe_handle.get(), context::get_status());
4912 status_check(context::get_status());
4914 TFE_OpAddInput(op.get(), num_units.tfe_handle.get(), context::get_status());
4915 status_check(context::get_status());
4917 TFE_OpAddInput(op.get(), input_size.tfe_handle.get(), context::get_status());
4918 status_check(context::get_status());
4920 std::vector<TFE_TensorHandle*> weights_handles;
4921 weights_handles.reserve(weights.size());
4922 std::transform(weights.begin(), weights.end(), std::back_inserter(weights_handles),
4923 [](
const auto& t) { return t.tfe_handle.get(); });
4924 TFE_OpAddInputList(op.get(), weights_handles.data(),
static_cast<int>(weights.size()), context::get_status());
4925 status_check(context::get_status());
4927 std::vector<TFE_TensorHandle*> biases_handles;
4928 biases_handles.reserve(biases.size());
4929 std::transform(biases.begin(), biases.end(), std::back_inserter(biases_handles),
4930 [](
const auto& t) { return t.tfe_handle.get(); });
4931 TFE_OpAddInputList(op.get(), biases_handles.data(),
static_cast<int>(biases.size()), context::get_status());
4932 status_check(context::get_status());
4935 TFE_OpSetAttrInt(op.get(),
"num_params", weights.size());
4936 TFE_OpSetAttrString(op.get(),
"rnn_mode", (
void*)rnn_mode.c_str(), rnn_mode.size());
4937 TFE_OpSetAttrString(op.get(),
"input_mode", (
void*)input_mode.c_str(), input_mode.size());
4938 TFE_OpSetAttrString(op.get(),
"direction", (
void*)direction.c_str(), direction.size());
4939 TFE_OpSetAttrFloat(op.get(),
"dropout", dropout);
4940 TFE_OpSetAttrInt(op.get(),
"seed", seed);
4941 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
4944 int num_outputs_op = 1;
4945 TFE_TensorHandle* res[1] = {
nullptr};
4946 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
4947 status_check(context::get_status());
4948 return tensor(res[0]);
4951 inline tensor cudnn_r_n_n_canonical_to_params_v2(
4952 const tensor& num_layers,
const tensor& num_units,
const tensor& input_size,
const std::vector<tensor>& weights,
4953 const std::vector<tensor>& biases,
const std::string& rnn_mode =
"lstm",
4954 const std::string& input_mode =
"linear_input",
const std::string& direction =
"unidirectional",
4955 float dropout = 0.0000e+00, int64_t seed = 0, int64_t seed2 = 0, int64_t num_proj = 0) {
4957 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
4958 TFE_NewOp(context::get_context(),
"CudnnRNNCanonicalToParamsV2", context::get_status()), &TFE_DeleteOp);
4959 status_check(context::get_status());
4963 TFE_OpAddInput(op.get(), num_layers.tfe_handle.get(), context::get_status());
4964 status_check(context::get_status());
4966 TFE_OpAddInput(op.get(), num_units.tfe_handle.get(), context::get_status());
4967 status_check(context::get_status());
4969 TFE_OpAddInput(op.get(), input_size.tfe_handle.get(), context::get_status());
4970 status_check(context::get_status());
4972 std::vector<TFE_TensorHandle*> weights_handles;
4973 weights_handles.reserve(weights.size());
4974 std::transform(weights.begin(), weights.end(), std::back_inserter(weights_handles),
4975 [](
const auto& t) { return t.tfe_handle.get(); });
4976 TFE_OpAddInputList(op.get(), weights_handles.data(),
static_cast<int>(weights.size()), context::get_status());
4977 status_check(context::get_status());
4979 std::vector<TFE_TensorHandle*> biases_handles;
4980 biases_handles.reserve(biases.size());
4981 std::transform(biases.begin(), biases.end(), std::back_inserter(biases_handles),
4982 [](
const auto& t) { return t.tfe_handle.get(); });
4983 TFE_OpAddInputList(op.get(), biases_handles.data(),
static_cast<int>(biases.size()), context::get_status());
4984 status_check(context::get_status());
4987 TFE_OpSetAttrInt(op.get(),
"num_params_weights", weights.size());
4988 TFE_OpSetAttrInt(op.get(),
"num_params_biases", biases.size());
4989 TFE_OpSetAttrString(op.get(),
"rnn_mode", (
void*)rnn_mode.c_str(), rnn_mode.size());
4990 TFE_OpSetAttrString(op.get(),
"input_mode", (
void*)input_mode.c_str(), input_mode.size());
4991 TFE_OpSetAttrString(op.get(),
"direction", (
void*)direction.c_str(), direction.size());
4992 TFE_OpSetAttrFloat(op.get(),
"dropout", dropout);
4993 TFE_OpSetAttrInt(op.get(),
"seed", seed);
4994 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
4995 TFE_OpSetAttrInt(op.get(),
"num_proj", num_proj);
4998 int num_outputs_op = 1;
4999 TFE_TensorHandle* res[1] = {
nullptr};
5000 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5001 status_check(context::get_status());
5002 return tensor(res[0]);
5005 inline tensor cudnn_r_n_n_params_size(
const tensor& num_layers,
const tensor& num_units,
const tensor& input_size,
5006 datatype S,
const std::string& rnn_mode =
"lstm",
5007 const std::string& input_mode =
"linear_input",
5008 const std::string& direction =
"unidirectional",
float dropout = 0.0000e+00,
5009 int64_t seed = 0, int64_t seed2 = 0, int64_t num_proj = 0) {
5011 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5012 TFE_NewOp(context::get_context(),
"CudnnRNNParamsSize", context::get_status()), &TFE_DeleteOp);
5013 status_check(context::get_status());
5017 TFE_OpAddInput(op.get(), num_layers.tfe_handle.get(), context::get_status());
5018 status_check(context::get_status());
5020 TFE_OpAddInput(op.get(), num_units.tfe_handle.get(), context::get_status());
5021 status_check(context::get_status());
5023 TFE_OpAddInput(op.get(), input_size.tfe_handle.get(), context::get_status());
5024 status_check(context::get_status());
5027 TFE_OpSetAttrType(op.get(),
"S", S);
5028 TFE_OpSetAttrString(op.get(),
"rnn_mode", (
void*)rnn_mode.c_str(), rnn_mode.size());
5029 TFE_OpSetAttrString(op.get(),
"input_mode", (
void*)input_mode.c_str(), input_mode.size());
5030 TFE_OpSetAttrString(op.get(),
"direction", (
void*)direction.c_str(), direction.size());
5031 TFE_OpSetAttrFloat(op.get(),
"dropout", dropout);
5032 TFE_OpSetAttrInt(op.get(),
"seed", seed);
5033 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
5034 TFE_OpSetAttrInt(op.get(),
"num_proj", num_proj);
5037 int num_outputs_op = 1;
5038 TFE_TensorHandle* res[1] = {
nullptr};
5039 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5040 status_check(context::get_status());
5041 return tensor(res[0]);
5044 inline tensor cumprod(
const tensor& x,
const tensor& axis,
bool exclusive =
false,
bool reverse =
false,
5045 datatype Tidx =
static_cast<datatype
>(3)) {
5047 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5048 TFE_NewOp(context::get_context(),
"Cumprod", context::get_status()), &TFE_DeleteOp);
5049 status_check(context::get_status());
5053 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5054 status_check(context::get_status());
5056 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
5057 status_check(context::get_status());
5060 TFE_OpSetAttrBool(op.get(),
"exclusive", (
unsigned char)exclusive);
5061 TFE_OpSetAttrBool(op.get(),
"reverse", (
unsigned char)reverse);
5062 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
5065 int num_outputs_op = 1;
5066 TFE_TensorHandle* res[1] = {
nullptr};
5067 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5068 status_check(context::get_status());
5069 return tensor(res[0]);
5072 inline tensor cumsum(
const tensor& x,
const tensor& axis,
bool exclusive =
false,
bool reverse =
false,
5073 datatype Tidx =
static_cast<datatype
>(3)) {
5075 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5076 TFE_NewOp(context::get_context(),
"Cumsum", context::get_status()), &TFE_DeleteOp);
5077 status_check(context::get_status());
5081 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5082 status_check(context::get_status());
5084 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
5085 status_check(context::get_status());
5088 TFE_OpSetAttrBool(op.get(),
"exclusive", (
unsigned char)exclusive);
5089 TFE_OpSetAttrBool(op.get(),
"reverse", (
unsigned char)reverse);
5090 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
5093 int num_outputs_op = 1;
5094 TFE_TensorHandle* res[1] = {
nullptr};
5095 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5096 status_check(context::get_status());
5097 return tensor(res[0]);
5100 inline tensor cumulative_logsumexp(
const tensor& x,
const tensor& axis,
bool exclusive =
false,
bool reverse =
false,
5101 datatype Tidx =
static_cast<datatype
>(3)) {
5103 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5104 TFE_NewOp(context::get_context(),
"CumulativeLogsumexp", context::get_status()), &TFE_DeleteOp);
5105 status_check(context::get_status());
5109 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5110 status_check(context::get_status());
5112 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
5113 status_check(context::get_status());
5116 TFE_OpSetAttrBool(op.get(),
"exclusive", (
unsigned char)exclusive);
5117 TFE_OpSetAttrBool(op.get(),
"reverse", (
unsigned char)reverse);
5118 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
5121 int num_outputs_op = 1;
5122 TFE_TensorHandle* res[1] = {
nullptr};
5123 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5124 status_check(context::get_status());
5125 return tensor(res[0]);
5128 inline tensor data_format_dim_map(
const tensor& x,
const std::string& src_format =
"NHWC",
5129 const std::string& dst_format =
"NCHW") {
5131 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5132 TFE_NewOp(context::get_context(),
"DataFormatDimMap", context::get_status()), &TFE_DeleteOp);
5133 status_check(context::get_status());
5137 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5138 status_check(context::get_status());
5141 TFE_OpSetAttrString(op.get(),
"src_format", (
void*)src_format.c_str(), src_format.size());
5142 TFE_OpSetAttrString(op.get(),
"dst_format", (
void*)dst_format.c_str(), dst_format.size());
5145 int num_outputs_op = 1;
5146 TFE_TensorHandle* res[1] = {
nullptr};
5147 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5148 status_check(context::get_status());
5149 return tensor(res[0]);
5152 inline tensor data_format_vec_permute(
const tensor& x,
const std::string& src_format =
"NHWC",
5153 const std::string& dst_format =
"NCHW") {
5155 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5156 TFE_NewOp(context::get_context(),
"DataFormatVecPermute", context::get_status()), &TFE_DeleteOp);
5157 status_check(context::get_status());
5161 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5162 status_check(context::get_status());
5165 TFE_OpSetAttrString(op.get(),
"src_format", (
void*)src_format.c_str(), src_format.size());
5166 TFE_OpSetAttrString(op.get(),
"dst_format", (
void*)dst_format.c_str(), dst_format.size());
5169 int num_outputs_op = 1;
5170 TFE_TensorHandle* res[1] = {
nullptr};
5171 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5172 status_check(context::get_status());
5173 return tensor(res[0]);
5176 inline tensor data_service_dataset(
const tensor& dataset_id,
const tensor& processing_mode,
const tensor& address,
5177 const tensor& protocol,
const tensor& job_name,
5178 const tensor& max_outstanding_requests,
const tensor& iteration_counter,
5179 const std::vector<datatype>& output_types,
5180 const std::vector<std::vector<int64_t>>& output_shapes,
5181 int64_t task_refresh_interval_hint_ms = -1) {
5183 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5184 TFE_NewOp(context::get_context(),
"DataServiceDataset", context::get_status()), &TFE_DeleteOp);
5185 status_check(context::get_status());
5189 TFE_OpAddInput(op.get(), dataset_id.tfe_handle.get(), context::get_status());
5190 status_check(context::get_status());
5192 TFE_OpAddInput(op.get(), processing_mode.tfe_handle.get(), context::get_status());
5193 status_check(context::get_status());
5195 TFE_OpAddInput(op.get(), address.tfe_handle.get(), context::get_status());
5196 status_check(context::get_status());
5198 TFE_OpAddInput(op.get(), protocol.tfe_handle.get(), context::get_status());
5199 status_check(context::get_status());
5201 TFE_OpAddInput(op.get(), job_name.tfe_handle.get(), context::get_status());
5202 status_check(context::get_status());
5204 TFE_OpAddInput(op.get(), max_outstanding_requests.tfe_handle.get(), context::get_status());
5205 status_check(context::get_status());
5207 TFE_OpAddInput(op.get(), iteration_counter.tfe_handle.get(), context::get_status());
5208 status_check(context::get_status());
5211 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
5212 static_cast<int>(output_types.size()));
5214 std::vector<const int64_t*> output_shapes_values;
5215 output_shapes_values.reserve(output_shapes.size());
5216 std::vector<int> output_shapes_ndims;
5217 output_shapes_ndims.reserve(output_shapes.size());
5218 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
5219 [](
const auto& v) { return v.data(); });
5220 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
5221 [](
const auto& v) { return static_cast<int>(v.size()); });
5222 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
5223 static_cast<int>(output_shapes.size()), context::get_status());
5224 status_check(context::get_status());
5226 TFE_OpSetAttrInt(op.get(),
"task_refresh_interval_hint_ms", task_refresh_interval_hint_ms);
5229 int num_outputs_op = 1;
5230 TFE_TensorHandle* res[1] = {
nullptr};
5231 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5232 status_check(context::get_status());
5233 return tensor(res[0]);
5236 inline tensor dataset_cardinality(
const tensor& input_dataset) {
5238 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5239 TFE_NewOp(context::get_context(),
"DatasetCardinality", context::get_status()), &TFE_DeleteOp);
5240 status_check(context::get_status());
5244 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
5245 status_check(context::get_status());
5250 int num_outputs_op = 1;
5251 TFE_TensorHandle* res[1] = {
nullptr};
5252 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5253 status_check(context::get_status());
5254 return tensor(res[0]);
5257 inline tensor dataset_from_graph(
const tensor& graph_def) {
5259 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5260 TFE_NewOp(context::get_context(),
"DatasetFromGraph", context::get_status()), &TFE_DeleteOp);
5261 status_check(context::get_status());
5265 TFE_OpAddInput(op.get(), graph_def.tfe_handle.get(), context::get_status());
5266 status_check(context::get_status());
5271 int num_outputs_op = 1;
5272 TFE_TensorHandle* res[1] = {
nullptr};
5273 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5274 status_check(context::get_status());
5275 return tensor(res[0]);
5278 inline tensor dataset_to_graph(
const tensor& input_dataset,
const std::vector<std::string>& stateful_whitelist,
5279 bool allow_stateful =
false,
bool strip_device_assignment =
false) {
5281 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5282 TFE_NewOp(context::get_context(),
"DatasetToGraph", context::get_status()), &TFE_DeleteOp);
5283 status_check(context::get_status());
5287 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
5288 status_check(context::get_status());
5292 std::vector<std::size_t> stateful_whitelist_sizes;
5293 stateful_whitelist_sizes.reserve(stateful_whitelist.size());
5294 std::transform(stateful_whitelist.begin(), stateful_whitelist.end(), std::back_inserter(stateful_whitelist_sizes),
5295 [](
const auto& s) { return s.size(); });
5296 TFE_OpSetAttrStringList(op.get(),
"stateful_whitelist",
5297 reinterpret_cast<const void* const*
>(stateful_whitelist.data()),
5298 stateful_whitelist_sizes.data(),
static_cast<int>(stateful_whitelist.size()));
5300 TFE_OpSetAttrBool(op.get(),
"allow_stateful", (
unsigned char)allow_stateful);
5301 TFE_OpSetAttrBool(op.get(),
"strip_device_assignment", (
unsigned char)strip_device_assignment);
5304 int num_outputs_op = 1;
5305 TFE_TensorHandle* res[1] = {
nullptr};
5306 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5307 status_check(context::get_status());
5308 return tensor(res[0]);
5311 inline tensor dataset_to_graph_v2(
const tensor& input_dataset, int64_t external_state_policy = 0,
5312 bool strip_device_assignment =
false) {
5314 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5315 TFE_NewOp(context::get_context(),
"DatasetToGraphV2", context::get_status()), &TFE_DeleteOp);
5316 status_check(context::get_status());
5320 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
5321 status_check(context::get_status());
5324 TFE_OpSetAttrInt(op.get(),
"external_state_policy", external_state_policy);
5325 TFE_OpSetAttrBool(op.get(),
"strip_device_assignment", (
unsigned char)strip_device_assignment);
5328 int num_outputs_op = 1;
5329 TFE_TensorHandle* res[1] = {
nullptr};
5330 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5331 status_check(context::get_status());
5332 return tensor(res[0]);
5335 inline tensor dataset_to_single_element(
const tensor& dataset,
const std::vector<datatype>& output_types,
5336 const std::vector<std::vector<int64_t>>& output_shapes) {
5338 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5339 TFE_NewOp(context::get_context(),
"DatasetToSingleElement", context::get_status()), &TFE_DeleteOp);
5340 status_check(context::get_status());
5344 TFE_OpAddInput(op.get(), dataset.tfe_handle.get(), context::get_status());
5345 status_check(context::get_status());
5348 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
5349 static_cast<int>(output_types.size()));
5351 std::vector<const int64_t*> output_shapes_values;
5352 output_shapes_values.reserve(output_shapes.size());
5353 std::vector<int> output_shapes_ndims;
5354 output_shapes_ndims.reserve(output_shapes.size());
5355 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
5356 [](
const auto& v) { return v.data(); });
5357 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
5358 [](
const auto& v) { return static_cast<int>(v.size()); });
5359 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
5360 static_cast<int>(output_shapes.size()), context::get_status());
5361 status_check(context::get_status());
5364 int num_outputs_op = 1;
5365 TFE_TensorHandle* res[1] = {
nullptr};
5366 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5367 status_check(context::get_status());
5368 return tensor(res[0]);
5371 inline tensor dawsn(
const tensor& x) {
5373 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Dawsn", context::get_status()),
5375 status_check(context::get_status());
5379 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5380 status_check(context::get_status());
5385 int num_outputs_op = 1;
5386 TFE_TensorHandle* res[1] = {
nullptr};
5387 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5388 status_check(context::get_status());
5389 return tensor(res[0]);
5392 inline tensor debug_gradient_identity(
const tensor& input) {
5394 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5395 TFE_NewOp(context::get_context(),
"DebugGradientIdentity", context::get_status()), &TFE_DeleteOp);
5396 status_check(context::get_status());
5400 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5401 status_check(context::get_status());
5406 int num_outputs_op = 1;
5407 TFE_TensorHandle* res[1] = {
nullptr};
5408 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5409 status_check(context::get_status());
5410 return tensor(res[0]);
5413 inline tensor debug_gradient_ref_identity(
const tensor& input) {
5415 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5416 TFE_NewOp(context::get_context(),
"DebugGradientRefIdentity", context::get_status()), &TFE_DeleteOp);
5417 status_check(context::get_status());
5421 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5422 status_check(context::get_status());
5427 int num_outputs_op = 1;
5428 TFE_TensorHandle* res[1] = {
nullptr};
5429 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5430 status_check(context::get_status());
5431 return tensor(res[0]);
5434 inline tensor debug_identity(
const tensor& input,
const std::vector<std::string>& debug_urls,
5435 const std::string& device_name =
"",
const std::string& tensor_name =
"",
5436 bool gated_grpc =
false) {
5438 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5439 TFE_NewOp(context::get_context(),
"DebugIdentity", context::get_status()), &TFE_DeleteOp);
5440 status_check(context::get_status());
5444 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5445 status_check(context::get_status());
5449 std::vector<std::size_t> debug_urls_sizes;
5450 debug_urls_sizes.reserve(debug_urls.size());
5451 std::transform(debug_urls.begin(), debug_urls.end(), std::back_inserter(debug_urls_sizes),
5452 [](
const auto& s) { return s.size(); });
5453 TFE_OpSetAttrStringList(op.get(),
"debug_urls",
reinterpret_cast<const void* const*
>(debug_urls.data()),
5454 debug_urls_sizes.data(),
static_cast<int>(debug_urls.size()));
5456 TFE_OpSetAttrString(op.get(),
"device_name", (
void*)device_name.c_str(), device_name.size());
5457 TFE_OpSetAttrString(op.get(),
"tensor_name", (
void*)tensor_name.c_str(), tensor_name.size());
5458 TFE_OpSetAttrBool(op.get(),
"gated_grpc", (
unsigned char)gated_grpc);
5461 int num_outputs_op = 1;
5462 TFE_TensorHandle* res[1] = {
nullptr};
5463 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5464 status_check(context::get_status());
5465 return tensor(res[0]);
5468 inline tensor debug_identity_v2(
const tensor& input,
const std::vector<std::string>& debug_urls,
5469 const std::string& tfdbg_context_id =
"",
const std::string& op_name =
"",
5470 int64_t output_slot = -1, int64_t tensor_debug_mode = -1,
5471 int64_t circular_buffer_size = 1000,
const std::string& tfdbg_run_id =
"") {
5473 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5474 TFE_NewOp(context::get_context(),
"DebugIdentityV2", context::get_status()), &TFE_DeleteOp);
5475 status_check(context::get_status());
5479 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5480 status_check(context::get_status());
5484 std::vector<std::size_t> debug_urls_sizes;
5485 debug_urls_sizes.reserve(debug_urls.size());
5486 std::transform(debug_urls.begin(), debug_urls.end(), std::back_inserter(debug_urls_sizes),
5487 [](
const auto& s) { return s.size(); });
5488 TFE_OpSetAttrStringList(op.get(),
"debug_urls",
reinterpret_cast<const void* const*
>(debug_urls.data()),
5489 debug_urls_sizes.data(),
static_cast<int>(debug_urls.size()));
5491 TFE_OpSetAttrString(op.get(),
"tfdbg_context_id", (
void*)tfdbg_context_id.c_str(), tfdbg_context_id.size());
5492 TFE_OpSetAttrString(op.get(),
"op_name", (
void*)op_name.c_str(), op_name.size());
5493 TFE_OpSetAttrInt(op.get(),
"output_slot", output_slot);
5494 TFE_OpSetAttrInt(op.get(),
"tensor_debug_mode", tensor_debug_mode);
5495 TFE_OpSetAttrInt(op.get(),
"circular_buffer_size", circular_buffer_size);
5496 TFE_OpSetAttrString(op.get(),
"tfdbg_run_id", (
void*)tfdbg_run_id.c_str(), tfdbg_run_id.size());
5499 int num_outputs_op = 1;
5500 TFE_TensorHandle* res[1] = {
nullptr};
5501 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5502 status_check(context::get_status());
5503 return tensor(res[0]);
5506 inline tensor debug_nan_count(
const tensor& input,
const std::vector<std::string>& debug_urls,
5507 const std::string& device_name =
"",
const std::string& tensor_name =
"",
5508 bool gated_grpc =
false) {
5510 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5511 TFE_NewOp(context::get_context(),
"DebugNanCount", context::get_status()), &TFE_DeleteOp);
5512 status_check(context::get_status());
5516 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5517 status_check(context::get_status());
5521 std::vector<std::size_t> debug_urls_sizes;
5522 debug_urls_sizes.reserve(debug_urls.size());
5523 std::transform(debug_urls.begin(), debug_urls.end(), std::back_inserter(debug_urls_sizes),
5524 [](
const auto& s) { return s.size(); });
5525 TFE_OpSetAttrStringList(op.get(),
"debug_urls",
reinterpret_cast<const void* const*
>(debug_urls.data()),
5526 debug_urls_sizes.data(),
static_cast<int>(debug_urls.size()));
5528 TFE_OpSetAttrString(op.get(),
"device_name", (
void*)device_name.c_str(), device_name.size());
5529 TFE_OpSetAttrString(op.get(),
"tensor_name", (
void*)tensor_name.c_str(), tensor_name.size());
5530 TFE_OpSetAttrBool(op.get(),
"gated_grpc", (
unsigned char)gated_grpc);
5533 int num_outputs_op = 1;
5534 TFE_TensorHandle* res[1] = {
nullptr};
5535 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5536 status_check(context::get_status());
5537 return tensor(res[0]);
5540 inline tensor debug_numeric_summary(
const tensor& input,
const std::vector<std::string>& debug_urls,
5541 const std::string& device_name =
"",
const std::string& tensor_name =
"",
5542 float lower_bound = -std::numeric_limits<float>::infinity(),
5543 float upper_bound = std::numeric_limits<float>::infinity(),
5544 bool mute_if_healthy =
false,
bool gated_grpc =
false) {
5546 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5547 TFE_NewOp(context::get_context(),
"DebugNumericSummary", context::get_status()), &TFE_DeleteOp);
5548 status_check(context::get_status());
5552 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5553 status_check(context::get_status());
5557 std::vector<std::size_t> debug_urls_sizes;
5558 debug_urls_sizes.reserve(debug_urls.size());
5559 std::transform(debug_urls.begin(), debug_urls.end(), std::back_inserter(debug_urls_sizes),
5560 [](
const auto& s) { return s.size(); });
5561 TFE_OpSetAttrStringList(op.get(),
"debug_urls",
reinterpret_cast<const void* const*
>(debug_urls.data()),
5562 debug_urls_sizes.data(),
static_cast<int>(debug_urls.size()));
5564 TFE_OpSetAttrString(op.get(),
"device_name", (
void*)device_name.c_str(), device_name.size());
5565 TFE_OpSetAttrString(op.get(),
"tensor_name", (
void*)tensor_name.c_str(), tensor_name.size());
5566 TFE_OpSetAttrFloat(op.get(),
"lower_bound", lower_bound);
5567 TFE_OpSetAttrFloat(op.get(),
"upper_bound", upper_bound);
5568 TFE_OpSetAttrBool(op.get(),
"mute_if_healthy", (
unsigned char)mute_if_healthy);
5569 TFE_OpSetAttrBool(op.get(),
"gated_grpc", (
unsigned char)gated_grpc);
5572 int num_outputs_op = 1;
5573 TFE_TensorHandle* res[1] = {
nullptr};
5574 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5575 status_check(context::get_status());
5576 return tensor(res[0]);
5579 inline tensor debug_numeric_summary_v2(
const tensor& input, datatype output_dtype =
static_cast<datatype
>(1),
5580 int64_t tensor_debug_mode = -1, int64_t tensor_id = -1) {
5582 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5583 TFE_NewOp(context::get_context(),
"DebugNumericSummaryV2", context::get_status()), &TFE_DeleteOp);
5584 status_check(context::get_status());
5588 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5589 status_check(context::get_status());
5592 TFE_OpSetAttrType(op.get(),
"output_dtype", output_dtype);
5593 TFE_OpSetAttrInt(op.get(),
"tensor_debug_mode", tensor_debug_mode);
5594 TFE_OpSetAttrInt(op.get(),
"tensor_id", tensor_id);
5597 int num_outputs_op = 1;
5598 TFE_TensorHandle* res[1] = {
nullptr};
5599 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5600 status_check(context::get_status());
5601 return tensor(res[0]);
5604 inline tensor decode_and_crop_jpeg(
const tensor& contents,
const tensor& crop_window, int64_t channels = 0,
5605 int64_t ratio = 1,
bool fancy_upscaling =
true,
bool try_recover_truncated =
false,
5606 float acceptable_fraction = 1.0000e+00,
const std::string& dct_method =
"") {
5608 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5609 TFE_NewOp(context::get_context(),
"DecodeAndCropJpeg", context::get_status()), &TFE_DeleteOp);
5610 status_check(context::get_status());
5614 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
5615 status_check(context::get_status());
5617 TFE_OpAddInput(op.get(), crop_window.tfe_handle.get(), context::get_status());
5618 status_check(context::get_status());
5621 TFE_OpSetAttrInt(op.get(),
"channels", channels);
5622 TFE_OpSetAttrInt(op.get(),
"ratio", ratio);
5623 TFE_OpSetAttrBool(op.get(),
"fancy_upscaling", (
unsigned char)fancy_upscaling);
5624 TFE_OpSetAttrBool(op.get(),
"try_recover_truncated", (
unsigned char)try_recover_truncated);
5625 TFE_OpSetAttrFloat(op.get(),
"acceptable_fraction", acceptable_fraction);
5626 TFE_OpSetAttrString(op.get(),
"dct_method", (
void*)dct_method.c_str(), dct_method.size());
5629 int num_outputs_op = 1;
5630 TFE_TensorHandle* res[1] = {
nullptr};
5631 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5632 status_check(context::get_status());
5633 return tensor(res[0]);
5636 inline tensor decode_base64(
const tensor& input) {
5638 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5639 TFE_NewOp(context::get_context(),
"DecodeBase64", context::get_status()), &TFE_DeleteOp);
5640 status_check(context::get_status());
5644 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5645 status_check(context::get_status());
5650 int num_outputs_op = 1;
5651 TFE_TensorHandle* res[1] = {
nullptr};
5652 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5653 status_check(context::get_status());
5654 return tensor(res[0]);
5657 inline tensor decode_bmp(
const tensor& contents, int64_t channels = 0) {
5659 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5660 TFE_NewOp(context::get_context(),
"DecodeBmp", context::get_status()), &TFE_DeleteOp);
5661 status_check(context::get_status());
5665 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
5666 status_check(context::get_status());
5669 TFE_OpSetAttrInt(op.get(),
"channels", channels);
5672 int num_outputs_op = 1;
5673 TFE_TensorHandle* res[1] = {
nullptr};
5674 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5675 status_check(context::get_status());
5676 return tensor(res[0]);
5679 inline tensor decode_c_s_v(
const tensor& records,
const std::vector<tensor>& record_defaults,
5680 const std::vector<datatype>& OUT_TYPE,
const std::vector<int64_t>& select_cols,
5681 const std::string& field_delim =
",",
bool use_quote_delim =
true,
5682 const std::string& na_value =
"") {
5684 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5685 TFE_NewOp(context::get_context(),
"DecodeCSV", context::get_status()), &TFE_DeleteOp);
5686 status_check(context::get_status());
5690 TFE_OpAddInput(op.get(), records.tfe_handle.get(), context::get_status());
5691 status_check(context::get_status());
5693 std::vector<TFE_TensorHandle*> record_defaults_handles;
5694 record_defaults_handles.reserve(record_defaults.size());
5695 std::transform(record_defaults.begin(), record_defaults.end(), std::back_inserter(record_defaults_handles),
5696 [](
const auto& t) { return t.tfe_handle.get(); });
5697 TFE_OpAddInputList(op.get(), record_defaults_handles.data(),
static_cast<int>(record_defaults.size()),
5698 context::get_status());
5699 status_check(context::get_status());
5702 TFE_OpSetAttrTypeList(op.get(),
"OUT_TYPE",
reinterpret_cast<const enum TF_DataType*
>(OUT_TYPE.data()),
5703 static_cast<int>(OUT_TYPE.size()));
5704 TFE_OpSetAttrIntList(op.get(),
"select_cols", select_cols.data(),
static_cast<int>(select_cols.size()));
5705 TFE_OpSetAttrString(op.get(),
"field_delim", (
void*)field_delim.c_str(), field_delim.size());
5706 TFE_OpSetAttrBool(op.get(),
"use_quote_delim", (
unsigned char)use_quote_delim);
5707 TFE_OpSetAttrString(op.get(),
"na_value", (
void*)na_value.c_str(), na_value.size());
5710 int num_outputs_op = 1;
5711 TFE_TensorHandle* res[1] = {
nullptr};
5712 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5713 status_check(context::get_status());
5714 return tensor(res[0]);
5717 inline tensor decode_compressed(
const tensor& bytes,
const std::string& compression_type =
"") {
5719 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5720 TFE_NewOp(context::get_context(),
"DecodeCompressed", context::get_status()), &TFE_DeleteOp);
5721 status_check(context::get_status());
5725 TFE_OpAddInput(op.get(), bytes.tfe_handle.get(), context::get_status());
5726 status_check(context::get_status());
5729 TFE_OpSetAttrString(op.get(),
"compression_type", (
void*)compression_type.c_str(), compression_type.size());
5732 int num_outputs_op = 1;
5733 TFE_TensorHandle* res[1] = {
nullptr};
5734 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5735 status_check(context::get_status());
5736 return tensor(res[0]);
5739 inline tensor decode_gif(
const tensor& contents) {
5741 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5742 TFE_NewOp(context::get_context(),
"DecodeGif", context::get_status()), &TFE_DeleteOp);
5743 status_check(context::get_status());
5747 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
5748 status_check(context::get_status());
5753 int num_outputs_op = 1;
5754 TFE_TensorHandle* res[1] = {
nullptr};
5755 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5756 status_check(context::get_status());
5757 return tensor(res[0]);
5760 inline tensor decode_image(
const tensor& contents, int64_t channels = 0, datatype dtype =
static_cast<datatype
>(4),
5761 bool expand_animations =
true) {
5763 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5764 TFE_NewOp(context::get_context(),
"DecodeImage", context::get_status()), &TFE_DeleteOp);
5765 status_check(context::get_status());
5769 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
5770 status_check(context::get_status());
5773 TFE_OpSetAttrInt(op.get(),
"channels", channels);
5774 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
5775 TFE_OpSetAttrBool(op.get(),
"expand_animations", (
unsigned char)expand_animations);
5778 int num_outputs_op = 1;
5779 TFE_TensorHandle* res[1] = {
nullptr};
5780 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5781 status_check(context::get_status());
5782 return tensor(res[0]);
5785 inline tensor decode_j_s_o_n_example(
const tensor& json_examples) {
5787 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5788 TFE_NewOp(context::get_context(),
"DecodeJSONExample", context::get_status()), &TFE_DeleteOp);
5789 status_check(context::get_status());
5793 TFE_OpAddInput(op.get(), json_examples.tfe_handle.get(), context::get_status());
5794 status_check(context::get_status());
5799 int num_outputs_op = 1;
5800 TFE_TensorHandle* res[1] = {
nullptr};
5801 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5802 status_check(context::get_status());
5803 return tensor(res[0]);
5806 inline tensor decode_jpeg(
const tensor& contents, int64_t channels = 0, int64_t ratio = 1,
bool fancy_upscaling =
true,
5807 bool try_recover_truncated =
false,
float acceptable_fraction = 1.0000e+00,
5808 const std::string& dct_method =
"") {
5810 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5811 TFE_NewOp(context::get_context(),
"DecodeJpeg", context::get_status()), &TFE_DeleteOp);
5812 status_check(context::get_status());
5816 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
5817 status_check(context::get_status());
5820 TFE_OpSetAttrInt(op.get(),
"channels", channels);
5821 TFE_OpSetAttrInt(op.get(),
"ratio", ratio);
5822 TFE_OpSetAttrBool(op.get(),
"fancy_upscaling", (
unsigned char)fancy_upscaling);
5823 TFE_OpSetAttrBool(op.get(),
"try_recover_truncated", (
unsigned char)try_recover_truncated);
5824 TFE_OpSetAttrFloat(op.get(),
"acceptable_fraction", acceptable_fraction);
5825 TFE_OpSetAttrString(op.get(),
"dct_method", (
void*)dct_method.c_str(), dct_method.size());
5828 int num_outputs_op = 1;
5829 TFE_TensorHandle* res[1] = {
nullptr};
5830 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5831 status_check(context::get_status());
5832 return tensor(res[0]);
5835 inline tensor decode_padded_raw(
const tensor& input_bytes,
const tensor& fixed_length, datatype out_type,
5836 bool little_endian =
true) {
5838 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5839 TFE_NewOp(context::get_context(),
"DecodePaddedRaw", context::get_status()), &TFE_DeleteOp);
5840 status_check(context::get_status());
5844 TFE_OpAddInput(op.get(), input_bytes.tfe_handle.get(), context::get_status());
5845 status_check(context::get_status());
5847 TFE_OpAddInput(op.get(), fixed_length.tfe_handle.get(), context::get_status());
5848 status_check(context::get_status());
5851 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
5852 TFE_OpSetAttrBool(op.get(),
"little_endian", (
unsigned char)little_endian);
5855 int num_outputs_op = 1;
5856 TFE_TensorHandle* res[1] = {
nullptr};
5857 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5858 status_check(context::get_status());
5859 return tensor(res[0]);
5862 inline tensor decode_png(
const tensor& contents, int64_t channels = 0, datatype dtype =
static_cast<datatype
>(4)) {
5864 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5865 TFE_NewOp(context::get_context(),
"DecodePng", context::get_status()), &TFE_DeleteOp);
5866 status_check(context::get_status());
5870 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
5871 status_check(context::get_status());
5874 TFE_OpSetAttrInt(op.get(),
"channels", channels);
5875 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
5878 int num_outputs_op = 1;
5879 TFE_TensorHandle* res[1] = {
nullptr};
5880 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5881 status_check(context::get_status());
5882 return tensor(res[0]);
5885 inline tensor decode_raw(
const tensor& bytes, datatype out_type,
bool little_endian =
true) {
5887 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5888 TFE_NewOp(context::get_context(),
"DecodeRaw", context::get_status()), &TFE_DeleteOp);
5889 status_check(context::get_status());
5893 TFE_OpAddInput(op.get(), bytes.tfe_handle.get(), context::get_status());
5894 status_check(context::get_status());
5897 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
5898 TFE_OpSetAttrBool(op.get(),
"little_endian", (
unsigned char)little_endian);
5901 int num_outputs_op = 1;
5902 TFE_TensorHandle* res[1] = {
nullptr};
5903 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5904 status_check(context::get_status());
5905 return tensor(res[0]);
5908 inline tensor deep_copy(
const tensor& x) {
5910 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5911 TFE_NewOp(context::get_context(),
"DeepCopy", context::get_status()), &TFE_DeleteOp);
5912 status_check(context::get_status());
5916 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
5917 status_check(context::get_status());
5922 int num_outputs_op = 1;
5923 TFE_TensorHandle* res[1] = {
nullptr};
5924 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5925 status_check(context::get_status());
5926 return tensor(res[0]);
5929 inline tensor dense_bincount(
const tensor& input,
const tensor& size,
const tensor& weights, datatype Tidx,
5930 bool binary_output =
false) {
5932 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5933 TFE_NewOp(context::get_context(),
"DenseBincount", context::get_status()), &TFE_DeleteOp);
5934 status_check(context::get_status());
5938 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
5939 status_check(context::get_status());
5941 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
5942 status_check(context::get_status());
5944 TFE_OpAddInput(op.get(), weights.tfe_handle.get(), context::get_status());
5945 status_check(context::get_status());
5948 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
5949 TFE_OpSetAttrBool(op.get(),
"binary_output", (
unsigned char)binary_output);
5952 int num_outputs_op = 1;
5953 TFE_TensorHandle* res[1] = {
nullptr};
5954 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5955 status_check(context::get_status());
5956 return tensor(res[0]);
5959 inline tensor dense_to_c_s_r_sparse_matrix(
const tensor& dense_input,
const tensor& indices) {
5961 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5962 TFE_NewOp(context::get_context(),
"DenseToCSRSparseMatrix", context::get_status()), &TFE_DeleteOp);
5963 status_check(context::get_status());
5967 TFE_OpAddInput(op.get(), dense_input.tfe_handle.get(), context::get_status());
5968 status_check(context::get_status());
5970 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
5971 status_check(context::get_status());
5976 int num_outputs_op = 1;
5977 TFE_TensorHandle* res[1] = {
nullptr};
5978 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
5979 status_check(context::get_status());
5980 return tensor(res[0]);
5983 inline tensor dense_to_sparse_batch_dataset(
const tensor& input_dataset,
const tensor& batch_size,
5984 const tensor& row_shape,
const std::vector<datatype>& output_types,
5985 const std::vector<std::vector<int64_t>>& output_shapes) {
5987 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
5988 TFE_NewOp(context::get_context(),
"DenseToSparseBatchDataset", context::get_status()), &TFE_DeleteOp);
5989 status_check(context::get_status());
5993 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
5994 status_check(context::get_status());
5996 TFE_OpAddInput(op.get(), batch_size.tfe_handle.get(), context::get_status());
5997 status_check(context::get_status());
5999 TFE_OpAddInput(op.get(), row_shape.tfe_handle.get(), context::get_status());
6000 status_check(context::get_status());
6003 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
6004 static_cast<int>(output_types.size()));
6006 std::vector<const int64_t*> output_shapes_values;
6007 output_shapes_values.reserve(output_shapes.size());
6008 std::vector<int> output_shapes_ndims;
6009 output_shapes_ndims.reserve(output_shapes.size());
6010 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
6011 [](
const auto& v) { return v.data(); });
6012 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
6013 [](
const auto& v) { return static_cast<int>(v.size()); });
6014 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
6015 static_cast<int>(output_shapes.size()), context::get_status());
6016 status_check(context::get_status());
6019 int num_outputs_op = 1;
6020 TFE_TensorHandle* res[1] = {
nullptr};
6021 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6022 status_check(context::get_status());
6023 return tensor(res[0]);
6026 inline tensor depth_to_space(
const tensor& input, int64_t block_size,
const std::string& data_format =
"NHWC") {
6028 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6029 TFE_NewOp(context::get_context(),
"DepthToSpace", context::get_status()), &TFE_DeleteOp);
6030 status_check(context::get_status());
6034 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6035 status_check(context::get_status());
6038 TFE_OpSetAttrInt(op.get(),
"block_size", block_size);
6039 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
6042 int num_outputs_op = 1;
6043 TFE_TensorHandle* res[1] = {
nullptr};
6044 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6045 status_check(context::get_status());
6046 return tensor(res[0]);
6049 inline tensor depthwise_conv2d_native(
const tensor& input,
const tensor& filter,
const std::vector<int64_t>& strides,
6050 const std::string& padding,
const std::vector<int64_t>& explicit_paddings,
6051 const std::vector<int64_t>& dilations,
const std::string& data_format =
"NHWC") {
6053 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6054 TFE_NewOp(context::get_context(),
"DepthwiseConv2dNative", context::get_status()), &TFE_DeleteOp);
6055 status_check(context::get_status());
6059 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6060 status_check(context::get_status());
6062 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
6063 status_check(context::get_status());
6066 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
6067 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
6068 TFE_OpSetAttrIntList(op.get(),
"explicit_paddings", explicit_paddings.data(),
6069 static_cast<int>(explicit_paddings.size()));
6070 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
6071 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
6074 int num_outputs_op = 1;
6075 TFE_TensorHandle* res[1] = {
nullptr};
6076 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6077 status_check(context::get_status());
6078 return tensor(res[0]);
6081 inline tensor depthwise_conv2d_native_backprop_filter(
const tensor& input,
const tensor& filter_sizes,
6082 const tensor& out_backprop,
const std::vector<int64_t>& strides,
6083 const std::string& padding,
6084 const std::vector<int64_t>& explicit_paddings,
6085 const std::vector<int64_t>& dilations,
6086 const std::string& data_format =
"NHWC") {
6088 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6089 TFE_NewOp(context::get_context(),
"DepthwiseConv2dNativeBackpropFilter", context::get_status()), &TFE_DeleteOp);
6090 status_check(context::get_status());
6094 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6095 status_check(context::get_status());
6097 TFE_OpAddInput(op.get(), filter_sizes.tfe_handle.get(), context::get_status());
6098 status_check(context::get_status());
6100 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
6101 status_check(context::get_status());
6104 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
6105 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
6106 TFE_OpSetAttrIntList(op.get(),
"explicit_paddings", explicit_paddings.data(),
6107 static_cast<int>(explicit_paddings.size()));
6108 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
6109 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
6112 int num_outputs_op = 1;
6113 TFE_TensorHandle* res[1] = {
nullptr};
6114 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6115 status_check(context::get_status());
6116 return tensor(res[0]);
6119 inline tensor depthwise_conv2d_native_backprop_input(
const tensor& input_sizes,
const tensor& filter,
6120 const tensor& out_backprop,
const std::vector<int64_t>& strides,
6121 const std::string& padding,
6122 const std::vector<int64_t>& explicit_paddings,
6123 const std::vector<int64_t>& dilations,
6124 const std::string& data_format =
"NHWC") {
6126 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6127 TFE_NewOp(context::get_context(),
"DepthwiseConv2dNativeBackpropInput", context::get_status()), &TFE_DeleteOp);
6128 status_check(context::get_status());
6132 TFE_OpAddInput(op.get(), input_sizes.tfe_handle.get(), context::get_status());
6133 status_check(context::get_status());
6135 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
6136 status_check(context::get_status());
6138 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
6139 status_check(context::get_status());
6142 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
6143 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
6144 TFE_OpSetAttrIntList(op.get(),
"explicit_paddings", explicit_paddings.data(),
6145 static_cast<int>(explicit_paddings.size()));
6146 TFE_OpSetAttrIntList(op.get(),
"dilations", dilations.data(),
static_cast<int>(dilations.size()));
6147 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
6150 int num_outputs_op = 1;
6151 TFE_TensorHandle* res[1] = {
nullptr};
6152 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6153 status_check(context::get_status());
6154 return tensor(res[0]);
6157 inline tensor dequantize(
const tensor& input,
const tensor& min_range,
const tensor& max_range,
6158 const std::string& mode =
"MIN_COMBINED",
bool narrow_range =
false, int64_t axis = -1,
6159 datatype dtype =
static_cast<datatype
>(1)) {
6161 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6162 TFE_NewOp(context::get_context(),
"Dequantize", context::get_status()), &TFE_DeleteOp);
6163 status_check(context::get_status());
6167 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6168 status_check(context::get_status());
6170 TFE_OpAddInput(op.get(), min_range.tfe_handle.get(), context::get_status());
6171 status_check(context::get_status());
6173 TFE_OpAddInput(op.get(), max_range.tfe_handle.get(), context::get_status());
6174 status_check(context::get_status());
6177 TFE_OpSetAttrString(op.get(),
"mode", (
void*)mode.c_str(), mode.size());
6178 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
6179 TFE_OpSetAttrInt(op.get(),
"axis", axis);
6180 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
6183 int num_outputs_op = 1;
6184 TFE_TensorHandle* res[1] = {
nullptr};
6185 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6186 status_check(context::get_status());
6187 return tensor(res[0]);
6190 inline tensor destroy_temporary_variable(
const tensor& ref,
const std::string& var_name) {
6192 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6193 TFE_NewOp(context::get_context(),
"DestroyTemporaryVariable", context::get_status()), &TFE_DeleteOp);
6194 status_check(context::get_status());
6198 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
6199 status_check(context::get_status());
6202 TFE_OpSetAttrString(op.get(),
"var_name", (
void*)var_name.c_str(), var_name.size());
6205 int num_outputs_op = 1;
6206 TFE_TensorHandle* res[1] = {
nullptr};
6207 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6208 status_check(context::get_status());
6209 return tensor(res[0]);
6212 inline tensor device_index(
const std::vector<std::string>& device_names) {
6214 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6215 TFE_NewOp(context::get_context(),
"DeviceIndex", context::get_status()), &TFE_DeleteOp);
6216 status_check(context::get_status());
6222 std::vector<std::size_t> device_names_sizes;
6223 device_names_sizes.reserve(device_names.size());
6224 std::transform(device_names.begin(), device_names.end(), std::back_inserter(device_names_sizes),
6225 [](
const auto& s) { return s.size(); });
6226 TFE_OpSetAttrStringList(op.get(),
"device_names",
reinterpret_cast<const void* const*
>(device_names.data()),
6227 device_names_sizes.data(),
static_cast<int>(device_names.size()));
6230 int num_outputs_op = 1;
6231 TFE_TensorHandle* res[1] = {
nullptr};
6232 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6233 status_check(context::get_status());
6234 return tensor(res[0]);
6237 inline tensor diag(
const tensor& diagonal) {
6239 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Diag", context::get_status()),
6241 status_check(context::get_status());
6245 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
6246 status_check(context::get_status());
6251 int num_outputs_op = 1;
6252 TFE_TensorHandle* res[1] = {
nullptr};
6253 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6254 status_check(context::get_status());
6255 return tensor(res[0]);
6258 inline tensor diag_part(
const tensor& input) {
6260 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6261 TFE_NewOp(context::get_context(),
"DiagPart", context::get_status()), &TFE_DeleteOp);
6262 status_check(context::get_status());
6266 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6267 status_check(context::get_status());
6272 int num_outputs_op = 1;
6273 TFE_TensorHandle* res[1] = {
nullptr};
6274 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6275 status_check(context::get_status());
6276 return tensor(res[0]);
6279 inline tensor digamma(
const tensor& x) {
6281 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6282 TFE_NewOp(context::get_context(),
"Digamma", context::get_status()), &TFE_DeleteOp);
6283 status_check(context::get_status());
6287 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
6288 status_check(context::get_status());
6293 int num_outputs_op = 1;
6294 TFE_TensorHandle* res[1] = {
nullptr};
6295 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6296 status_check(context::get_status());
6297 return tensor(res[0]);
6300 inline tensor dilation2_d(
const tensor& input,
const tensor& filter,
const std::vector<int64_t>& strides,
6301 const std::vector<int64_t>& rates,
const std::string& padding) {
6303 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6304 TFE_NewOp(context::get_context(),
"Dilation2D", context::get_status()), &TFE_DeleteOp);
6305 status_check(context::get_status());
6309 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6310 status_check(context::get_status());
6312 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
6313 status_check(context::get_status());
6316 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
6317 TFE_OpSetAttrIntList(op.get(),
"rates", rates.data(),
static_cast<int>(rates.size()));
6318 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
6321 int num_outputs_op = 1;
6322 TFE_TensorHandle* res[1] = {
nullptr};
6323 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6324 status_check(context::get_status());
6325 return tensor(res[0]);
6328 inline tensor dilation2_d_backprop_filter(
const tensor& input,
const tensor& filter,
const tensor& out_backprop,
6329 const std::vector<int64_t>& strides,
const std::vector<int64_t>& rates,
6330 const std::string& padding) {
6332 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6333 TFE_NewOp(context::get_context(),
"Dilation2DBackpropFilter", context::get_status()), &TFE_DeleteOp);
6334 status_check(context::get_status());
6338 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6339 status_check(context::get_status());
6341 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
6342 status_check(context::get_status());
6344 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
6345 status_check(context::get_status());
6348 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
6349 TFE_OpSetAttrIntList(op.get(),
"rates", rates.data(),
static_cast<int>(rates.size()));
6350 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
6353 int num_outputs_op = 1;
6354 TFE_TensorHandle* res[1] = {
nullptr};
6355 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6356 status_check(context::get_status());
6357 return tensor(res[0]);
6360 inline tensor dilation2_d_backprop_input(
const tensor& input,
const tensor& filter,
const tensor& out_backprop,
6361 const std::vector<int64_t>& strides,
const std::vector<int64_t>& rates,
6362 const std::string& padding) {
6364 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6365 TFE_NewOp(context::get_context(),
"Dilation2DBackpropInput", context::get_status()), &TFE_DeleteOp);
6366 status_check(context::get_status());
6370 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6371 status_check(context::get_status());
6373 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
6374 status_check(context::get_status());
6376 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
6377 status_check(context::get_status());
6380 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
6381 TFE_OpSetAttrIntList(op.get(),
"rates", rates.data(),
static_cast<int>(rates.size()));
6382 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
6385 int num_outputs_op = 1;
6386 TFE_TensorHandle* res[1] = {
nullptr};
6387 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6388 status_check(context::get_status());
6389 return tensor(res[0]);
6392 inline tensor directed_interleave_dataset(
const tensor& selector_input_dataset,
6393 const std::vector<tensor>& data_input_datasets,
6394 const std::vector<datatype>& output_types,
6395 const std::vector<std::vector<int64_t>>& output_shapes) {
6397 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6398 TFE_NewOp(context::get_context(),
"DirectedInterleaveDataset", context::get_status()), &TFE_DeleteOp);
6399 status_check(context::get_status());
6403 TFE_OpAddInput(op.get(), selector_input_dataset.tfe_handle.get(), context::get_status());
6404 status_check(context::get_status());
6406 std::vector<TFE_TensorHandle*> data_input_datasets_handles;
6407 data_input_datasets_handles.reserve(data_input_datasets.size());
6408 std::transform(data_input_datasets.begin(), data_input_datasets.end(),
6409 std::back_inserter(data_input_datasets_handles), [](
const auto& t) { return t.tfe_handle.get(); });
6410 TFE_OpAddInputList(op.get(), data_input_datasets_handles.data(),
static_cast<int>(data_input_datasets.size()),
6411 context::get_status());
6412 status_check(context::get_status());
6415 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
6416 static_cast<int>(output_types.size()));
6418 std::vector<const int64_t*> output_shapes_values;
6419 output_shapes_values.reserve(output_shapes.size());
6420 std::vector<int> output_shapes_ndims;
6421 output_shapes_ndims.reserve(output_shapes.size());
6422 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
6423 [](
const auto& v) { return v.data(); });
6424 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
6425 [](
const auto& v) { return static_cast<int>(v.size()); });
6426 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
6427 static_cast<int>(output_shapes.size()), context::get_status());
6428 status_check(context::get_status());
6430 TFE_OpSetAttrInt(op.get(),
"N", data_input_datasets.size());
6433 int num_outputs_op = 1;
6434 TFE_TensorHandle* res[1] = {
nullptr};
6435 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6436 status_check(context::get_status());
6437 return tensor(res[0]);
6440 inline tensor div(
const tensor& x,
const tensor& y) {
6442 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Div", context::get_status()),
6444 status_check(context::get_status());
6448 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
6449 status_check(context::get_status());
6451 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
6452 status_check(context::get_status());
6457 int num_outputs_op = 1;
6458 TFE_TensorHandle* res[1] = {
nullptr};
6459 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6460 status_check(context::get_status());
6461 return tensor(res[0]);
6464 inline tensor div_no_nan(
const tensor& x,
const tensor& y) {
6466 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6467 TFE_NewOp(context::get_context(),
"DivNoNan", context::get_status()), &TFE_DeleteOp);
6468 status_check(context::get_status());
6472 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
6473 status_check(context::get_status());
6475 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
6476 status_check(context::get_status());
6481 int num_outputs_op = 1;
6482 TFE_TensorHandle* res[1] = {
nullptr};
6483 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6484 status_check(context::get_status());
6485 return tensor(res[0]);
6488 inline tensor draw_bounding_boxes(
const tensor& images,
const tensor& boxes) {
6490 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6491 TFE_NewOp(context::get_context(),
"DrawBoundingBoxes", context::get_status()), &TFE_DeleteOp);
6492 status_check(context::get_status());
6496 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
6497 status_check(context::get_status());
6499 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
6500 status_check(context::get_status());
6505 int num_outputs_op = 1;
6506 TFE_TensorHandle* res[1] = {
nullptr};
6507 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6508 status_check(context::get_status());
6509 return tensor(res[0]);
6512 inline tensor draw_bounding_boxes_v2(
const tensor& images,
const tensor& boxes,
const tensor& colors) {
6514 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6515 TFE_NewOp(context::get_context(),
"DrawBoundingBoxesV2", context::get_status()), &TFE_DeleteOp);
6516 status_check(context::get_status());
6520 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
6521 status_check(context::get_status());
6523 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
6524 status_check(context::get_status());
6526 TFE_OpAddInput(op.get(), colors.tfe_handle.get(), context::get_status());
6527 status_check(context::get_status());
6532 int num_outputs_op = 1;
6533 TFE_TensorHandle* res[1] = {
nullptr};
6534 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6535 status_check(context::get_status());
6536 return tensor(res[0]);
6539 inline tensor dummy_iteration_counter() {
6541 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6542 TFE_NewOp(context::get_context(),
"DummyIterationCounter", context::get_status()), &TFE_DeleteOp);
6543 status_check(context::get_status());
6550 int num_outputs_op = 1;
6551 TFE_TensorHandle* res[1] = {
nullptr};
6552 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6553 status_check(context::get_status());
6554 return tensor(res[0]);
6557 inline tensor dummy_memory_cache() {
6559 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6560 TFE_NewOp(context::get_context(),
"DummyMemoryCache", context::get_status()), &TFE_DeleteOp);
6561 status_check(context::get_status());
6568 int num_outputs_op = 1;
6569 TFE_TensorHandle* res[1] = {
nullptr};
6570 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6571 status_check(context::get_status());
6572 return tensor(res[0]);
6575 inline tensor dummy_seed_generator() {
6577 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6578 TFE_NewOp(context::get_context(),
"DummySeedGenerator", context::get_status()), &TFE_DeleteOp);
6579 status_check(context::get_status());
6586 int num_outputs_op = 1;
6587 TFE_TensorHandle* res[1] = {
nullptr};
6588 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6589 status_check(context::get_status());
6590 return tensor(res[0]);
6593 inline tensor dynamic_partition(
const tensor& data,
const tensor& partitions, int64_t num_partitions) {
6595 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6596 TFE_NewOp(context::get_context(),
"DynamicPartition", context::get_status()), &TFE_DeleteOp);
6597 status_check(context::get_status());
6601 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
6602 status_check(context::get_status());
6604 TFE_OpAddInput(op.get(), partitions.tfe_handle.get(), context::get_status());
6605 status_check(context::get_status());
6608 TFE_OpSetAttrInt(op.get(),
"num_partitions", num_partitions);
6611 int num_outputs_op = 1;
6612 TFE_TensorHandle* res[1] = {
nullptr};
6613 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6614 status_check(context::get_status());
6615 return tensor(res[0]);
6618 inline tensor dynamic_stitch(
const std::vector<tensor>& indices,
const std::vector<tensor>& data) {
6620 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6621 TFE_NewOp(context::get_context(),
"DynamicStitch", context::get_status()), &TFE_DeleteOp);
6622 status_check(context::get_status());
6626 std::vector<TFE_TensorHandle*> indices_handles;
6627 indices_handles.reserve(indices.size());
6628 std::transform(indices.begin(), indices.end(), std::back_inserter(indices_handles),
6629 [](
const auto& t) { return t.tfe_handle.get(); });
6630 TFE_OpAddInputList(op.get(), indices_handles.data(),
static_cast<int>(indices.size()), context::get_status());
6631 status_check(context::get_status());
6633 std::vector<TFE_TensorHandle*> data_handles;
6634 data_handles.reserve(data.size());
6635 std::transform(data.begin(), data.end(), std::back_inserter(data_handles),
6636 [](
const auto& t) { return t.tfe_handle.get(); });
6637 TFE_OpAddInputList(op.get(), data_handles.data(),
static_cast<int>(data.size()), context::get_status());
6638 status_check(context::get_status());
6641 TFE_OpSetAttrInt(op.get(),
"N", indices.size());
6644 int num_outputs_op = 1;
6645 TFE_TensorHandle* res[1] = {
nullptr};
6646 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6647 status_check(context::get_status());
6648 return tensor(res[0]);
6651 inline tensor eager_py_func(
const std::vector<tensor>& input,
const std::string& token,
6652 const std::vector<datatype>& Tin,
const std::vector<datatype>& Tout,
6653 bool is_async =
false) {
6655 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6656 TFE_NewOp(context::get_context(),
"EagerPyFunc", context::get_status()), &TFE_DeleteOp);
6657 status_check(context::get_status());
6661 std::vector<TFE_TensorHandle*> input_handles;
6662 input_handles.reserve(input.size());
6663 std::transform(input.begin(), input.end(), std::back_inserter(input_handles),
6664 [](
const auto& t) { return t.tfe_handle.get(); });
6665 TFE_OpAddInputList(op.get(), input_handles.data(),
static_cast<int>(input.size()), context::get_status());
6666 status_check(context::get_status());
6669 TFE_OpSetAttrString(op.get(),
"token", (
void*)token.c_str(), token.size());
6670 TFE_OpSetAttrTypeList(op.get(),
"Tin",
reinterpret_cast<const enum TF_DataType*
>(Tin.data()),
6671 static_cast<int>(Tin.size()));
6672 TFE_OpSetAttrTypeList(op.get(),
"Tout",
reinterpret_cast<const enum TF_DataType*
>(Tout.data()),
6673 static_cast<int>(Tout.size()));
6674 TFE_OpSetAttrBool(op.get(),
"is_async", (
unsigned char)is_async);
6677 int num_outputs_op = 1;
6678 TFE_TensorHandle* res[1] = {
nullptr};
6679 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6680 status_check(context::get_status());
6681 return tensor(res[0]);
6684 inline tensor edit_distance(
const tensor& hypothesis_indices,
const tensor& hypothesis_values,
6685 const tensor& hypothesis_shape,
const tensor& truth_indices,
const tensor& truth_values,
6686 const tensor& truth_shape,
bool normalize =
true) {
6688 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6689 TFE_NewOp(context::get_context(),
"EditDistance", context::get_status()), &TFE_DeleteOp);
6690 status_check(context::get_status());
6694 TFE_OpAddInput(op.get(), hypothesis_indices.tfe_handle.get(), context::get_status());
6695 status_check(context::get_status());
6697 TFE_OpAddInput(op.get(), hypothesis_values.tfe_handle.get(), context::get_status());
6698 status_check(context::get_status());
6700 TFE_OpAddInput(op.get(), hypothesis_shape.tfe_handle.get(), context::get_status());
6701 status_check(context::get_status());
6703 TFE_OpAddInput(op.get(), truth_indices.tfe_handle.get(), context::get_status());
6704 status_check(context::get_status());
6706 TFE_OpAddInput(op.get(), truth_values.tfe_handle.get(), context::get_status());
6707 status_check(context::get_status());
6709 TFE_OpAddInput(op.get(), truth_shape.tfe_handle.get(), context::get_status());
6710 status_check(context::get_status());
6713 TFE_OpSetAttrBool(op.get(),
"normalize", (
unsigned char)normalize);
6716 int num_outputs_op = 1;
6717 TFE_TensorHandle* res[1] = {
nullptr};
6718 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6719 status_check(context::get_status());
6720 return tensor(res[0]);
6723 inline tensor einsum(
const std::vector<tensor>& inputs,
const std::string& equation) {
6725 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6726 TFE_NewOp(context::get_context(),
"Einsum", context::get_status()), &TFE_DeleteOp);
6727 status_check(context::get_status());
6731 std::vector<TFE_TensorHandle*> inputs_handles;
6732 inputs_handles.reserve(inputs.size());
6733 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
6734 [](
const auto& t) { return t.tfe_handle.get(); });
6735 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
6736 status_check(context::get_status());
6739 TFE_OpSetAttrString(op.get(),
"equation", (
void*)equation.c_str(), equation.size());
6740 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
6743 int num_outputs_op = 1;
6744 TFE_TensorHandle* res[1] = {
nullptr};
6745 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6746 status_check(context::get_status());
6747 return tensor(res[0]);
6750 inline tensor elu(
const tensor& features) {
6752 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Elu", context::get_status()),
6754 status_check(context::get_status());
6758 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
6759 status_check(context::get_status());
6764 int num_outputs_op = 1;
6765 TFE_TensorHandle* res[1] = {
nullptr};
6766 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6767 status_check(context::get_status());
6768 return tensor(res[0]);
6771 inline tensor elu_grad(
const tensor& gradients,
const tensor& outputs) {
6773 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6774 TFE_NewOp(context::get_context(),
"EluGrad", context::get_status()), &TFE_DeleteOp);
6775 status_check(context::get_status());
6779 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
6780 status_check(context::get_status());
6782 TFE_OpAddInput(op.get(), outputs.tfe_handle.get(), context::get_status());
6783 status_check(context::get_status());
6788 int num_outputs_op = 1;
6789 TFE_TensorHandle* res[1] = {
nullptr};
6790 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6791 status_check(context::get_status());
6792 return tensor(res[0]);
6795 inline tensor empty(
const tensor& shape, datatype dtype,
bool init =
false) {
6797 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Empty", context::get_status()),
6799 status_check(context::get_status());
6803 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
6804 status_check(context::get_status());
6807 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
6808 TFE_OpSetAttrBool(op.get(),
"init", (
unsigned char)init);
6811 int num_outputs_op = 1;
6812 TFE_TensorHandle* res[1] = {
nullptr};
6813 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6814 status_check(context::get_status());
6815 return tensor(res[0]);
6818 inline tensor empty_tensor_list(
const tensor& element_shape,
const tensor& max_num_elements, datatype element_dtype,
6819 datatype shape_type) {
6821 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6822 TFE_NewOp(context::get_context(),
"EmptyTensorList", context::get_status()), &TFE_DeleteOp);
6823 status_check(context::get_status());
6827 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
6828 status_check(context::get_status());
6830 TFE_OpAddInput(op.get(), max_num_elements.tfe_handle.get(), context::get_status());
6831 status_check(context::get_status());
6834 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
6835 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
6838 int num_outputs_op = 1;
6839 TFE_TensorHandle* res[1] = {
nullptr};
6840 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6841 status_check(context::get_status());
6842 return tensor(res[0]);
6845 inline tensor encode_base64(
const tensor& input,
bool pad =
false) {
6847 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6848 TFE_NewOp(context::get_context(),
"EncodeBase64", context::get_status()), &TFE_DeleteOp);
6849 status_check(context::get_status());
6853 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
6854 status_check(context::get_status());
6857 TFE_OpSetAttrBool(op.get(),
"pad", (
unsigned char)pad);
6860 int num_outputs_op = 1;
6861 TFE_TensorHandle* res[1] = {
nullptr};
6862 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6863 status_check(context::get_status());
6864 return tensor(res[0]);
6867 inline tensor encode_jpeg(
const tensor& image,
const std::string& format =
"", int64_t quality = 95,
6868 bool progressive =
false,
bool optimize_size =
false,
bool chroma_downsampling =
true,
6869 const std::string& density_unit =
"in", int64_t x_density = 300, int64_t y_density = 300,
6870 const std::string& xmp_metadata =
"") {
6872 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6873 TFE_NewOp(context::get_context(),
"EncodeJpeg", context::get_status()), &TFE_DeleteOp);
6874 status_check(context::get_status());
6878 TFE_OpAddInput(op.get(), image.tfe_handle.get(), context::get_status());
6879 status_check(context::get_status());
6882 TFE_OpSetAttrString(op.get(),
"format", (
void*)format.c_str(), format.size());
6883 TFE_OpSetAttrInt(op.get(),
"quality", quality);
6884 TFE_OpSetAttrBool(op.get(),
"progressive", (
unsigned char)progressive);
6885 TFE_OpSetAttrBool(op.get(),
"optimize_size", (
unsigned char)optimize_size);
6886 TFE_OpSetAttrBool(op.get(),
"chroma_downsampling", (
unsigned char)chroma_downsampling);
6887 TFE_OpSetAttrString(op.get(),
"density_unit", (
void*)density_unit.c_str(), density_unit.size());
6888 TFE_OpSetAttrInt(op.get(),
"x_density", x_density);
6889 TFE_OpSetAttrInt(op.get(),
"y_density", y_density);
6890 TFE_OpSetAttrString(op.get(),
"xmp_metadata", (
void*)xmp_metadata.c_str(), xmp_metadata.size());
6893 int num_outputs_op = 1;
6894 TFE_TensorHandle* res[1] = {
nullptr};
6895 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6896 status_check(context::get_status());
6897 return tensor(res[0]);
6900 inline tensor encode_jpeg_variable_quality(
const tensor& images,
const tensor& quality) {
6902 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6903 TFE_NewOp(context::get_context(),
"EncodeJpegVariableQuality", context::get_status()), &TFE_DeleteOp);
6904 status_check(context::get_status());
6908 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
6909 status_check(context::get_status());
6911 TFE_OpAddInput(op.get(), quality.tfe_handle.get(), context::get_status());
6912 status_check(context::get_status());
6917 int num_outputs_op = 1;
6918 TFE_TensorHandle* res[1] = {
nullptr};
6919 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6920 status_check(context::get_status());
6921 return tensor(res[0]);
6924 inline tensor encode_png(
const tensor& image, int64_t compression = -1) {
6926 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6927 TFE_NewOp(context::get_context(),
"EncodePng", context::get_status()), &TFE_DeleteOp);
6928 status_check(context::get_status());
6932 TFE_OpAddInput(op.get(), image.tfe_handle.get(), context::get_status());
6933 status_check(context::get_status());
6936 TFE_OpSetAttrInt(op.get(),
"compression", compression);
6939 int num_outputs_op = 1;
6940 TFE_TensorHandle* res[1] = {
nullptr};
6941 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6942 status_check(context::get_status());
6943 return tensor(res[0]);
6946 inline tensor encode_proto(
const tensor& sizes,
const std::vector<tensor>& values,
6947 const std::vector<std::string>& field_names,
const std::string& message_type,
6948 const std::vector<datatype>& Tinput_types,
6949 const std::string& descriptor_source =
"local://") {
6951 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6952 TFE_NewOp(context::get_context(),
"EncodeProto", context::get_status()), &TFE_DeleteOp);
6953 status_check(context::get_status());
6957 TFE_OpAddInput(op.get(), sizes.tfe_handle.get(), context::get_status());
6958 status_check(context::get_status());
6960 std::vector<TFE_TensorHandle*> values_handles;
6961 values_handles.reserve(values.size());
6962 std::transform(values.begin(), values.end(), std::back_inserter(values_handles),
6963 [](
const auto& t) { return t.tfe_handle.get(); });
6964 TFE_OpAddInputList(op.get(), values_handles.data(),
static_cast<int>(values.size()), context::get_status());
6965 status_check(context::get_status());
6969 std::vector<std::size_t> field_names_sizes;
6970 field_names_sizes.reserve(field_names.size());
6971 std::transform(field_names.begin(), field_names.end(), std::back_inserter(field_names_sizes),
6972 [](
const auto& s) { return s.size(); });
6973 TFE_OpSetAttrStringList(op.get(),
"field_names",
reinterpret_cast<const void* const*
>(field_names.data()),
6974 field_names_sizes.data(),
static_cast<int>(field_names.size()));
6976 TFE_OpSetAttrString(op.get(),
"message_type", (
void*)message_type.c_str(), message_type.size());
6977 TFE_OpSetAttrTypeList(op.get(),
"Tinput_types",
reinterpret_cast<const enum TF_DataType*
>(Tinput_types.data()),
6978 static_cast<int>(Tinput_types.size()));
6979 TFE_OpSetAttrString(op.get(),
"descriptor_source", (
void*)descriptor_source.c_str(), descriptor_source.size());
6982 int num_outputs_op = 1;
6983 TFE_TensorHandle* res[1] = {
nullptr};
6984 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6985 status_check(context::get_status());
6986 return tensor(res[0]);
6989 inline tensor encode_wav(
const tensor& audio,
const tensor& sample_rate) {
6991 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
6992 TFE_NewOp(context::get_context(),
"EncodeWav", context::get_status()), &TFE_DeleteOp);
6993 status_check(context::get_status());
6997 TFE_OpAddInput(op.get(), audio.tfe_handle.get(), context::get_status());
6998 status_check(context::get_status());
7000 TFE_OpAddInput(op.get(), sample_rate.tfe_handle.get(), context::get_status());
7001 status_check(context::get_status());
7006 int num_outputs_op = 1;
7007 TFE_TensorHandle* res[1] = {
nullptr};
7008 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7009 status_check(context::get_status());
7010 return tensor(res[0]);
7013 inline tensor ensure_shape(
const tensor& input,
const std::vector<int64_t>& shape) {
7015 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7016 TFE_NewOp(context::get_context(),
"EnsureShape", context::get_status()), &TFE_DeleteOp);
7017 status_check(context::get_status());
7021 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
7022 status_check(context::get_status());
7026 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
7027 status_check(context::get_status());
7030 int num_outputs_op = 1;
7031 TFE_TensorHandle* res[1] = {
nullptr};
7032 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7033 status_check(context::get_status());
7034 return tensor(res[0]);
7037 inline tensor enter(
const tensor& data,
const std::string& frame_name,
bool is_constant =
false,
7038 int64_t parallel_iterations = 10) {
7040 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Enter", context::get_status()),
7042 status_check(context::get_status());
7046 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
7047 status_check(context::get_status());
7050 TFE_OpSetAttrString(op.get(),
"frame_name", (
void*)frame_name.c_str(), frame_name.size());
7051 TFE_OpSetAttrBool(op.get(),
"is_constant", (
unsigned char)is_constant);
7052 TFE_OpSetAttrInt(op.get(),
"parallel_iterations", parallel_iterations);
7055 int num_outputs_op = 1;
7056 TFE_TensorHandle* res[1] = {
nullptr};
7057 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7058 status_check(context::get_status());
7059 return tensor(res[0]);
7062 inline tensor equal(
const tensor& x,
const tensor& y,
bool incompatible_shape_error =
true) {
7064 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Equal", context::get_status()),
7066 status_check(context::get_status());
7070 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
7071 status_check(context::get_status());
7073 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
7074 status_check(context::get_status());
7077 TFE_OpSetAttrBool(op.get(),
"incompatible_shape_error", (
unsigned char)incompatible_shape_error);
7080 int num_outputs_op = 1;
7081 TFE_TensorHandle* res[1] = {
nullptr};
7082 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7083 status_check(context::get_status());
7084 return tensor(res[0]);
7087 inline tensor erf(
const tensor& x) {
7089 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Erf", context::get_status()),
7091 status_check(context::get_status());
7095 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
7096 status_check(context::get_status());
7101 int num_outputs_op = 1;
7102 TFE_TensorHandle* res[1] = {
nullptr};
7103 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7104 status_check(context::get_status());
7105 return tensor(res[0]);
7108 inline tensor erfc(
const tensor& x) {
7110 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Erfc", context::get_status()),
7112 status_check(context::get_status());
7116 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
7117 status_check(context::get_status());
7122 int num_outputs_op = 1;
7123 TFE_TensorHandle* res[1] = {
nullptr};
7124 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7125 status_check(context::get_status());
7126 return tensor(res[0]);
7129 inline tensor erfinv(
const tensor& x) {
7131 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7132 TFE_NewOp(context::get_context(),
"Erfinv", context::get_status()), &TFE_DeleteOp);
7133 status_check(context::get_status());
7137 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
7138 status_check(context::get_status());
7143 int num_outputs_op = 1;
7144 TFE_TensorHandle* res[1] = {
nullptr};
7145 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7146 status_check(context::get_status());
7147 return tensor(res[0]);
7150 inline tensor euclidean_norm(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
7151 datatype Tidx =
static_cast<datatype
>(3)) {
7153 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7154 TFE_NewOp(context::get_context(),
"EuclideanNorm", context::get_status()), &TFE_DeleteOp);
7155 status_check(context::get_status());
7159 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
7160 status_check(context::get_status());
7162 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
7163 status_check(context::get_status());
7166 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
7167 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
7170 int num_outputs_op = 1;
7171 TFE_TensorHandle* res[1] = {
nullptr};
7172 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7173 status_check(context::get_status());
7174 return tensor(res[0]);
7177 inline tensor exit(
const tensor& data) {
7179 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Exit", context::get_status()),
7181 status_check(context::get_status());
7185 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
7186 status_check(context::get_status());
7191 int num_outputs_op = 1;
7192 TFE_TensorHandle* res[1] = {
nullptr};
7193 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7194 status_check(context::get_status());
7195 return tensor(res[0]);
7198 inline tensor exp(
const tensor& x) {
7200 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Exp", context::get_status()),
7202 status_check(context::get_status());
7206 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
7207 status_check(context::get_status());
7212 int num_outputs_op = 1;
7213 TFE_TensorHandle* res[1] = {
nullptr};
7214 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7215 status_check(context::get_status());
7216 return tensor(res[0]);
7219 inline tensor expand_dims(
const tensor& input,
const tensor& dim, datatype Tdim =
static_cast<datatype
>(3)) {
7221 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7222 TFE_NewOp(context::get_context(),
"ExpandDims", context::get_status()), &TFE_DeleteOp);
7223 status_check(context::get_status());
7227 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
7228 status_check(context::get_status());
7230 TFE_OpAddInput(op.get(), dim.tfe_handle.get(), context::get_status());
7231 status_check(context::get_status());
7234 TFE_OpSetAttrType(op.get(),
"Tdim", Tdim);
7237 int num_outputs_op = 1;
7238 TFE_TensorHandle* res[1] = {
nullptr};
7239 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7240 status_check(context::get_status());
7241 return tensor(res[0]);
7244 inline tensor experimental_assert_next_dataset(
const tensor& input_dataset,
const tensor& transformations,
7245 const std::vector<datatype>& output_types,
7246 const std::vector<std::vector<int64_t>>& output_shapes) {
7248 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7249 TFE_NewOp(context::get_context(),
"ExperimentalAssertNextDataset", context::get_status()), &TFE_DeleteOp);
7250 status_check(context::get_status());
7254 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7255 status_check(context::get_status());
7257 TFE_OpAddInput(op.get(), transformations.tfe_handle.get(), context::get_status());
7258 status_check(context::get_status());
7261 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7262 static_cast<int>(output_types.size()));
7264 std::vector<const int64_t*> output_shapes_values;
7265 output_shapes_values.reserve(output_shapes.size());
7266 std::vector<int> output_shapes_ndims;
7267 output_shapes_ndims.reserve(output_shapes.size());
7268 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7269 [](
const auto& v) { return v.data(); });
7270 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7271 [](
const auto& v) { return static_cast<int>(v.size()); });
7272 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7273 static_cast<int>(output_shapes.size()), context::get_status());
7274 status_check(context::get_status());
7277 int num_outputs_op = 1;
7278 TFE_TensorHandle* res[1] = {
nullptr};
7279 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7280 status_check(context::get_status());
7281 return tensor(res[0]);
7284 inline tensor experimental_auto_shard_dataset(
const tensor& input_dataset,
const tensor& num_workers,
7285 const tensor& index,
const std::vector<datatype>& output_types,
7286 const std::vector<std::vector<int64_t>>& output_shapes,
7287 int64_t auto_shard_policy = 0) {
7289 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7290 TFE_NewOp(context::get_context(),
"ExperimentalAutoShardDataset", context::get_status()), &TFE_DeleteOp);
7291 status_check(context::get_status());
7295 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7296 status_check(context::get_status());
7298 TFE_OpAddInput(op.get(), num_workers.tfe_handle.get(), context::get_status());
7299 status_check(context::get_status());
7301 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
7302 status_check(context::get_status());
7305 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7306 static_cast<int>(output_types.size()));
7308 std::vector<const int64_t*> output_shapes_values;
7309 output_shapes_values.reserve(output_shapes.size());
7310 std::vector<int> output_shapes_ndims;
7311 output_shapes_ndims.reserve(output_shapes.size());
7312 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7313 [](
const auto& v) { return v.data(); });
7314 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7315 [](
const auto& v) { return static_cast<int>(v.size()); });
7316 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7317 static_cast<int>(output_shapes.size()), context::get_status());
7318 status_check(context::get_status());
7320 TFE_OpSetAttrInt(op.get(),
"auto_shard_policy", auto_shard_policy);
7323 int num_outputs_op = 1;
7324 TFE_TensorHandle* res[1] = {
nullptr};
7325 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7326 status_check(context::get_status());
7327 return tensor(res[0]);
7330 inline tensor experimental_bytes_produced_stats_dataset(
const tensor& input_dataset,
const tensor& tag,
7331 const std::vector<datatype>& output_types,
7332 const std::vector<std::vector<int64_t>>& output_shapes) {
7334 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7335 TFE_NewOp(context::get_context(),
"ExperimentalBytesProducedStatsDataset", context::get_status()), &TFE_DeleteOp);
7336 status_check(context::get_status());
7340 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7341 status_check(context::get_status());
7343 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
7344 status_check(context::get_status());
7347 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7348 static_cast<int>(output_types.size()));
7350 std::vector<const int64_t*> output_shapes_values;
7351 output_shapes_values.reserve(output_shapes.size());
7352 std::vector<int> output_shapes_ndims;
7353 output_shapes_ndims.reserve(output_shapes.size());
7354 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7355 [](
const auto& v) { return v.data(); });
7356 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7357 [](
const auto& v) { return static_cast<int>(v.size()); });
7358 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7359 static_cast<int>(output_shapes.size()), context::get_status());
7360 status_check(context::get_status());
7363 int num_outputs_op = 1;
7364 TFE_TensorHandle* res[1] = {
nullptr};
7365 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7366 status_check(context::get_status());
7367 return tensor(res[0]);
7370 inline tensor experimental_c_s_v_dataset(
const tensor& filenames,
const tensor& compression_type,
7371 const tensor& buffer_size,
const tensor& header,
const tensor& field_delim,
7372 const tensor& use_quote_delim,
const tensor& na_value,
7373 const tensor& select_cols,
const std::vector<tensor>& record_defaults,
7374 const std::vector<datatype>& output_types,
7375 const std::vector<std::vector<int64_t>>& output_shapes) {
7377 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7378 TFE_NewOp(context::get_context(),
"ExperimentalCSVDataset", context::get_status()), &TFE_DeleteOp);
7379 status_check(context::get_status());
7383 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
7384 status_check(context::get_status());
7386 TFE_OpAddInput(op.get(), compression_type.tfe_handle.get(), context::get_status());
7387 status_check(context::get_status());
7389 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
7390 status_check(context::get_status());
7392 TFE_OpAddInput(op.get(), header.tfe_handle.get(), context::get_status());
7393 status_check(context::get_status());
7395 TFE_OpAddInput(op.get(), field_delim.tfe_handle.get(), context::get_status());
7396 status_check(context::get_status());
7398 TFE_OpAddInput(op.get(), use_quote_delim.tfe_handle.get(), context::get_status());
7399 status_check(context::get_status());
7401 TFE_OpAddInput(op.get(), na_value.tfe_handle.get(), context::get_status());
7402 status_check(context::get_status());
7404 TFE_OpAddInput(op.get(), select_cols.tfe_handle.get(), context::get_status());
7405 status_check(context::get_status());
7407 std::vector<TFE_TensorHandle*> record_defaults_handles;
7408 record_defaults_handles.reserve(record_defaults.size());
7409 std::transform(record_defaults.begin(), record_defaults.end(), std::back_inserter(record_defaults_handles),
7410 [](
const auto& t) { return t.tfe_handle.get(); });
7411 TFE_OpAddInputList(op.get(), record_defaults_handles.data(),
static_cast<int>(record_defaults.size()),
7412 context::get_status());
7413 status_check(context::get_status());
7416 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7417 static_cast<int>(output_types.size()));
7419 std::vector<const int64_t*> output_shapes_values;
7420 output_shapes_values.reserve(output_shapes.size());
7421 std::vector<int> output_shapes_ndims;
7422 output_shapes_ndims.reserve(output_shapes.size());
7423 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7424 [](
const auto& v) { return v.data(); });
7425 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7426 [](
const auto& v) { return static_cast<int>(v.size()); });
7427 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7428 static_cast<int>(output_shapes.size()), context::get_status());
7429 status_check(context::get_status());
7432 int num_outputs_op = 1;
7433 TFE_TensorHandle* res[1] = {
nullptr};
7434 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7435 status_check(context::get_status());
7436 return tensor(res[0]);
7439 inline tensor experimental_choose_fastest_dataset(
const std::vector<tensor>& input_datasets, int64_t num_experiments,
7440 const std::vector<datatype>& output_types,
7441 const std::vector<std::vector<int64_t>>& output_shapes) {
7443 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7444 TFE_NewOp(context::get_context(),
"ExperimentalChooseFastestDataset", context::get_status()), &TFE_DeleteOp);
7445 status_check(context::get_status());
7449 std::vector<TFE_TensorHandle*> input_datasets_handles;
7450 input_datasets_handles.reserve(input_datasets.size());
7451 std::transform(input_datasets.begin(), input_datasets.end(), std::back_inserter(input_datasets_handles),
7452 [](
const auto& t) { return t.tfe_handle.get(); });
7453 TFE_OpAddInputList(op.get(), input_datasets_handles.data(),
static_cast<int>(input_datasets.size()),
7454 context::get_status());
7455 status_check(context::get_status());
7458 TFE_OpSetAttrInt(op.get(),
"N", input_datasets.size());
7459 TFE_OpSetAttrInt(op.get(),
"num_experiments", num_experiments);
7460 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7461 static_cast<int>(output_types.size()));
7463 std::vector<const int64_t*> output_shapes_values;
7464 output_shapes_values.reserve(output_shapes.size());
7465 std::vector<int> output_shapes_ndims;
7466 output_shapes_ndims.reserve(output_shapes.size());
7467 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7468 [](
const auto& v) { return v.data(); });
7469 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7470 [](
const auto& v) { return static_cast<int>(v.size()); });
7471 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7472 static_cast<int>(output_shapes.size()), context::get_status());
7473 status_check(context::get_status());
7476 int num_outputs_op = 1;
7477 TFE_TensorHandle* res[1] = {
nullptr};
7478 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7479 status_check(context::get_status());
7480 return tensor(res[0]);
7483 inline tensor experimental_dataset_cardinality(
const tensor& input_dataset) {
7485 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7486 TFE_NewOp(context::get_context(),
"ExperimentalDatasetCardinality", context::get_status()), &TFE_DeleteOp);
7487 status_check(context::get_status());
7491 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7492 status_check(context::get_status());
7497 int num_outputs_op = 1;
7498 TFE_TensorHandle* res[1] = {
nullptr};
7499 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7500 status_check(context::get_status());
7501 return tensor(res[0]);
7504 inline tensor experimental_dense_to_sparse_batch_dataset(
const tensor& input_dataset,
const tensor& batch_size,
7505 const tensor& row_shape,
7506 const std::vector<datatype>& output_types,
7507 const std::vector<std::vector<int64_t>>& output_shapes) {
7509 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7510 TFE_NewOp(context::get_context(),
"ExperimentalDenseToSparseBatchDataset", context::get_status()), &TFE_DeleteOp);
7511 status_check(context::get_status());
7515 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7516 status_check(context::get_status());
7518 TFE_OpAddInput(op.get(), batch_size.tfe_handle.get(), context::get_status());
7519 status_check(context::get_status());
7521 TFE_OpAddInput(op.get(), row_shape.tfe_handle.get(), context::get_status());
7522 status_check(context::get_status());
7525 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7526 static_cast<int>(output_types.size()));
7528 std::vector<const int64_t*> output_shapes_values;
7529 output_shapes_values.reserve(output_shapes.size());
7530 std::vector<int> output_shapes_ndims;
7531 output_shapes_ndims.reserve(output_shapes.size());
7532 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7533 [](
const auto& v) { return v.data(); });
7534 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7535 [](
const auto& v) { return static_cast<int>(v.size()); });
7536 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7537 static_cast<int>(output_shapes.size()), context::get_status());
7538 status_check(context::get_status());
7541 int num_outputs_op = 1;
7542 TFE_TensorHandle* res[1] = {
nullptr};
7543 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7544 status_check(context::get_status());
7545 return tensor(res[0]);
7548 inline tensor experimental_directed_interleave_dataset(
const tensor& selector_input_dataset,
7549 const std::vector<tensor>& data_input_datasets,
7550 const std::vector<datatype>& output_types,
7551 const std::vector<std::vector<int64_t>>& output_shapes) {
7553 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7554 TFE_NewOp(context::get_context(),
"ExperimentalDirectedInterleaveDataset", context::get_status()), &TFE_DeleteOp);
7555 status_check(context::get_status());
7559 TFE_OpAddInput(op.get(), selector_input_dataset.tfe_handle.get(), context::get_status());
7560 status_check(context::get_status());
7562 std::vector<TFE_TensorHandle*> data_input_datasets_handles;
7563 data_input_datasets_handles.reserve(data_input_datasets.size());
7564 std::transform(data_input_datasets.begin(), data_input_datasets.end(),
7565 std::back_inserter(data_input_datasets_handles), [](
const auto& t) { return t.tfe_handle.get(); });
7566 TFE_OpAddInputList(op.get(), data_input_datasets_handles.data(),
static_cast<int>(data_input_datasets.size()),
7567 context::get_status());
7568 status_check(context::get_status());
7571 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7572 static_cast<int>(output_types.size()));
7574 std::vector<const int64_t*> output_shapes_values;
7575 output_shapes_values.reserve(output_shapes.size());
7576 std::vector<int> output_shapes_ndims;
7577 output_shapes_ndims.reserve(output_shapes.size());
7578 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7579 [](
const auto& v) { return v.data(); });
7580 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7581 [](
const auto& v) { return static_cast<int>(v.size()); });
7582 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7583 static_cast<int>(output_shapes.size()), context::get_status());
7584 status_check(context::get_status());
7586 TFE_OpSetAttrInt(op.get(),
"N", data_input_datasets.size());
7589 int num_outputs_op = 1;
7590 TFE_TensorHandle* res[1] = {
nullptr};
7591 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7592 status_check(context::get_status());
7593 return tensor(res[0]);
7596 inline tensor experimental_ignore_errors_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
7597 const std::vector<std::vector<int64_t>>& output_shapes) {
7599 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7600 TFE_NewOp(context::get_context(),
"ExperimentalIgnoreErrorsDataset", context::get_status()), &TFE_DeleteOp);
7601 status_check(context::get_status());
7605 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7606 status_check(context::get_status());
7609 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7610 static_cast<int>(output_types.size()));
7612 std::vector<const int64_t*> output_shapes_values;
7613 output_shapes_values.reserve(output_shapes.size());
7614 std::vector<int> output_shapes_ndims;
7615 output_shapes_ndims.reserve(output_shapes.size());
7616 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7617 [](
const auto& v) { return v.data(); });
7618 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7619 [](
const auto& v) { return static_cast<int>(v.size()); });
7620 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7621 static_cast<int>(output_shapes.size()), context::get_status());
7622 status_check(context::get_status());
7625 int num_outputs_op = 1;
7626 TFE_TensorHandle* res[1] = {
nullptr};
7627 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7628 status_check(context::get_status());
7629 return tensor(res[0]);
7632 inline tensor experimental_iterator_get_device(
const tensor& resource) {
7634 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7635 TFE_NewOp(context::get_context(),
"ExperimentalIteratorGetDevice", context::get_status()), &TFE_DeleteOp);
7636 status_check(context::get_status());
7640 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
7641 status_check(context::get_status());
7646 int num_outputs_op = 1;
7647 TFE_TensorHandle* res[1] = {
nullptr};
7648 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7649 status_check(context::get_status());
7650 return tensor(res[0]);
7653 inline tensor experimental_l_m_d_b_dataset(
const tensor& filenames,
const std::vector<datatype>& output_types,
7654 const std::vector<std::vector<int64_t>>& output_shapes) {
7656 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7657 TFE_NewOp(context::get_context(),
"ExperimentalLMDBDataset", context::get_status()), &TFE_DeleteOp);
7658 status_check(context::get_status());
7662 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
7663 status_check(context::get_status());
7666 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7667 static_cast<int>(output_types.size()));
7669 std::vector<const int64_t*> output_shapes_values;
7670 output_shapes_values.reserve(output_shapes.size());
7671 std::vector<int> output_shapes_ndims;
7672 output_shapes_ndims.reserve(output_shapes.size());
7673 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7674 [](
const auto& v) { return v.data(); });
7675 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7676 [](
const auto& v) { return static_cast<int>(v.size()); });
7677 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7678 static_cast<int>(output_shapes.size()), context::get_status());
7679 status_check(context::get_status());
7682 int num_outputs_op = 1;
7683 TFE_TensorHandle* res[1] = {
nullptr};
7684 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7685 status_check(context::get_status());
7686 return tensor(res[0]);
7689 inline tensor experimental_latency_stats_dataset(
const tensor& input_dataset,
const tensor& tag,
7690 const std::vector<datatype>& output_types,
7691 const std::vector<std::vector<int64_t>>& output_shapes) {
7693 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7694 TFE_NewOp(context::get_context(),
"ExperimentalLatencyStatsDataset", context::get_status()), &TFE_DeleteOp);
7695 status_check(context::get_status());
7699 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7700 status_check(context::get_status());
7702 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
7703 status_check(context::get_status());
7706 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7707 static_cast<int>(output_types.size()));
7709 std::vector<const int64_t*> output_shapes_values;
7710 output_shapes_values.reserve(output_shapes.size());
7711 std::vector<int> output_shapes_ndims;
7712 output_shapes_ndims.reserve(output_shapes.size());
7713 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7714 [](
const auto& v) { return v.data(); });
7715 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7716 [](
const auto& v) { return static_cast<int>(v.size()); });
7717 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7718 static_cast<int>(output_shapes.size()), context::get_status());
7719 status_check(context::get_status());
7722 int num_outputs_op = 1;
7723 TFE_TensorHandle* res[1] = {
nullptr};
7724 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7725 status_check(context::get_status());
7726 return tensor(res[0]);
7729 inline tensor experimental_matching_files_dataset(
const tensor& patterns) {
7731 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7732 TFE_NewOp(context::get_context(),
"ExperimentalMatchingFilesDataset", context::get_status()), &TFE_DeleteOp);
7733 status_check(context::get_status());
7737 TFE_OpAddInput(op.get(), patterns.tfe_handle.get(), context::get_status());
7738 status_check(context::get_status());
7743 int num_outputs_op = 1;
7744 TFE_TensorHandle* res[1] = {
nullptr};
7745 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7746 status_check(context::get_status());
7747 return tensor(res[0]);
7750 inline tensor experimental_max_intra_op_parallelism_dataset(
const tensor& input_dataset,
7751 const tensor& max_intra_op_parallelism,
7752 const std::vector<datatype>& output_types,
7753 const std::vector<std::vector<int64_t>>& output_shapes) {
7755 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7756 TFE_NewOp(context::get_context(),
"ExperimentalMaxIntraOpParallelismDataset", context::get_status()),
7758 status_check(context::get_status());
7762 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7763 status_check(context::get_status());
7765 TFE_OpAddInput(op.get(), max_intra_op_parallelism.tfe_handle.get(), context::get_status());
7766 status_check(context::get_status());
7769 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7770 static_cast<int>(output_types.size()));
7772 std::vector<const int64_t*> output_shapes_values;
7773 output_shapes_values.reserve(output_shapes.size());
7774 std::vector<int> output_shapes_ndims;
7775 output_shapes_ndims.reserve(output_shapes.size());
7776 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7777 [](
const auto& v) { return v.data(); });
7778 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7779 [](
const auto& v) { return static_cast<int>(v.size()); });
7780 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7781 static_cast<int>(output_shapes.size()), context::get_status());
7782 status_check(context::get_status());
7785 int num_outputs_op = 1;
7786 TFE_TensorHandle* res[1] = {
nullptr};
7787 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7788 status_check(context::get_status());
7789 return tensor(res[0]);
7792 inline tensor experimental_non_serializable_dataset(
const tensor& input_dataset,
7793 const std::vector<datatype>& output_types,
7794 const std::vector<std::vector<int64_t>>& output_shapes) {
7796 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7797 TFE_NewOp(context::get_context(),
"ExperimentalNonSerializableDataset", context::get_status()), &TFE_DeleteOp);
7798 status_check(context::get_status());
7802 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7803 status_check(context::get_status());
7806 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7807 static_cast<int>(output_types.size()));
7809 std::vector<const int64_t*> output_shapes_values;
7810 output_shapes_values.reserve(output_shapes.size());
7811 std::vector<int> output_shapes_ndims;
7812 output_shapes_ndims.reserve(output_shapes.size());
7813 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7814 [](
const auto& v) { return v.data(); });
7815 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7816 [](
const auto& v) { return static_cast<int>(v.size()); });
7817 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7818 static_cast<int>(output_shapes.size()), context::get_status());
7819 status_check(context::get_status());
7822 int num_outputs_op = 1;
7823 TFE_TensorHandle* res[1] = {
nullptr};
7824 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7825 status_check(context::get_status());
7826 return tensor(res[0]);
7829 inline tensor experimental_parse_example_dataset(
7830 const tensor& input_dataset,
const tensor& num_parallel_calls,
const std::vector<tensor>& dense_defaults,
7831 const std::vector<std::string>& sparse_keys,
const std::vector<std::string>& dense_keys,
7832 const std::vector<datatype>& sparse_types,
const std::vector<datatype>& Tdense,
7833 const std::vector<std::vector<int64_t>>& dense_shapes,
const std::vector<datatype>& output_types,
7834 const std::vector<std::vector<int64_t>>& output_shapes,
bool sloppy =
false) {
7836 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7837 TFE_NewOp(context::get_context(),
"ExperimentalParseExampleDataset", context::get_status()), &TFE_DeleteOp);
7838 status_check(context::get_status());
7842 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7843 status_check(context::get_status());
7845 TFE_OpAddInput(op.get(), num_parallel_calls.tfe_handle.get(), context::get_status());
7846 status_check(context::get_status());
7848 std::vector<TFE_TensorHandle*> dense_defaults_handles;
7849 dense_defaults_handles.reserve(dense_defaults.size());
7850 std::transform(dense_defaults.begin(), dense_defaults.end(), std::back_inserter(dense_defaults_handles),
7851 [](
const auto& t) { return t.tfe_handle.get(); });
7852 TFE_OpAddInputList(op.get(), dense_defaults_handles.data(),
static_cast<int>(dense_defaults.size()),
7853 context::get_status());
7854 status_check(context::get_status());
7858 std::vector<std::size_t> sparse_keys_sizes;
7859 sparse_keys_sizes.reserve(sparse_keys.size());
7860 std::transform(sparse_keys.begin(), sparse_keys.end(), std::back_inserter(sparse_keys_sizes),
7861 [](
const auto& s) { return s.size(); });
7862 TFE_OpSetAttrStringList(op.get(),
"sparse_keys",
reinterpret_cast<const void* const*
>(sparse_keys.data()),
7863 sparse_keys_sizes.data(),
static_cast<int>(sparse_keys.size()));
7865 std::vector<std::size_t> dense_keys_sizes;
7866 dense_keys_sizes.reserve(dense_keys.size());
7867 std::transform(dense_keys.begin(), dense_keys.end(), std::back_inserter(dense_keys_sizes),
7868 [](
const auto& s) { return s.size(); });
7869 TFE_OpSetAttrStringList(op.get(),
"dense_keys",
reinterpret_cast<const void* const*
>(dense_keys.data()),
7870 dense_keys_sizes.data(),
static_cast<int>(dense_keys.size()));
7872 TFE_OpSetAttrTypeList(op.get(),
"sparse_types",
reinterpret_cast<const enum TF_DataType*
>(sparse_types.data()),
7873 static_cast<int>(sparse_types.size()));
7874 TFE_OpSetAttrTypeList(op.get(),
"Tdense",
reinterpret_cast<const enum TF_DataType*
>(Tdense.data()),
7875 static_cast<int>(Tdense.size()));
7877 std::vector<const int64_t*> dense_shapes_values;
7878 dense_shapes_values.reserve(dense_shapes.size());
7879 std::vector<int> dense_shapes_ndims;
7880 dense_shapes_ndims.reserve(dense_shapes.size());
7881 std::transform(dense_shapes.begin(), dense_shapes.end(), std::back_inserter(dense_shapes_values),
7882 [](
const auto& v) { return v.data(); });
7883 std::transform(dense_shapes.begin(), dense_shapes.end(), std::back_inserter(dense_shapes_ndims),
7884 [](
const auto& v) { return static_cast<int>(v.size()); });
7885 TFE_OpSetAttrShapeList(op.get(),
"dense_shapes", dense_shapes_values.data(), dense_shapes_ndims.data(),
7886 static_cast<int>(dense_shapes.size()), context::get_status());
7887 status_check(context::get_status());
7889 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7890 static_cast<int>(output_types.size()));
7892 std::vector<const int64_t*> output_shapes_values;
7893 output_shapes_values.reserve(output_shapes.size());
7894 std::vector<int> output_shapes_ndims;
7895 output_shapes_ndims.reserve(output_shapes.size());
7896 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7897 [](
const auto& v) { return v.data(); });
7898 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7899 [](
const auto& v) { return static_cast<int>(v.size()); });
7900 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7901 static_cast<int>(output_shapes.size()), context::get_status());
7902 status_check(context::get_status());
7904 TFE_OpSetAttrBool(op.get(),
"sloppy", (
unsigned char)sloppy);
7907 int num_outputs_op = 1;
7908 TFE_TensorHandle* res[1] = {
nullptr};
7909 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7910 status_check(context::get_status());
7911 return tensor(res[0]);
7914 inline tensor experimental_private_thread_pool_dataset(
const tensor& input_dataset,
const tensor& num_threads,
7915 const std::vector<datatype>& output_types,
7916 const std::vector<std::vector<int64_t>>& output_shapes) {
7918 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7919 TFE_NewOp(context::get_context(),
"ExperimentalPrivateThreadPoolDataset", context::get_status()), &TFE_DeleteOp);
7920 status_check(context::get_status());
7924 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
7925 status_check(context::get_status());
7927 TFE_OpAddInput(op.get(), num_threads.tfe_handle.get(), context::get_status());
7928 status_check(context::get_status());
7931 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7932 static_cast<int>(output_types.size()));
7934 std::vector<const int64_t*> output_shapes_values;
7935 output_shapes_values.reserve(output_shapes.size());
7936 std::vector<int> output_shapes_ndims;
7937 output_shapes_ndims.reserve(output_shapes.size());
7938 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7939 [](
const auto& v) { return v.data(); });
7940 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7941 [](
const auto& v) { return static_cast<int>(v.size()); });
7942 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7943 static_cast<int>(output_shapes.size()), context::get_status());
7944 status_check(context::get_status());
7947 int num_outputs_op = 1;
7948 TFE_TensorHandle* res[1] = {
nullptr};
7949 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7950 status_check(context::get_status());
7951 return tensor(res[0]);
7954 inline tensor experimental_random_dataset(
const tensor& seed,
const tensor& seed2,
7955 const std::vector<datatype>& output_types,
7956 const std::vector<std::vector<int64_t>>& output_shapes) {
7958 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
7959 TFE_NewOp(context::get_context(),
"ExperimentalRandomDataset", context::get_status()), &TFE_DeleteOp);
7960 status_check(context::get_status());
7964 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
7965 status_check(context::get_status());
7967 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
7968 status_check(context::get_status());
7971 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
7972 static_cast<int>(output_types.size()));
7974 std::vector<const int64_t*> output_shapes_values;
7975 output_shapes_values.reserve(output_shapes.size());
7976 std::vector<int> output_shapes_ndims;
7977 output_shapes_ndims.reserve(output_shapes.size());
7978 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
7979 [](
const auto& v) { return v.data(); });
7980 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
7981 [](
const auto& v) { return static_cast<int>(v.size()); });
7982 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
7983 static_cast<int>(output_shapes.size()), context::get_status());
7984 status_check(context::get_status());
7987 int num_outputs_op = 1;
7988 TFE_TensorHandle* res[1] = {
nullptr};
7989 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
7990 status_check(context::get_status());
7991 return tensor(res[0]);
7994 inline tensor experimental_rebatch_dataset(
const tensor& input_dataset,
const tensor& num_replicas,
7995 const std::vector<datatype>& output_types,
7996 const std::vector<std::vector<int64_t>>& output_shapes,
7997 bool use_fallback =
true) {
7999 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8000 TFE_NewOp(context::get_context(),
"ExperimentalRebatchDataset", context::get_status()), &TFE_DeleteOp);
8001 status_check(context::get_status());
8005 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8006 status_check(context::get_status());
8008 TFE_OpAddInput(op.get(), num_replicas.tfe_handle.get(), context::get_status());
8009 status_check(context::get_status());
8012 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8013 static_cast<int>(output_types.size()));
8015 std::vector<const int64_t*> output_shapes_values;
8016 output_shapes_values.reserve(output_shapes.size());
8017 std::vector<int> output_shapes_ndims;
8018 output_shapes_ndims.reserve(output_shapes.size());
8019 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8020 [](
const auto& v) { return v.data(); });
8021 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8022 [](
const auto& v) { return static_cast<int>(v.size()); });
8023 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8024 static_cast<int>(output_shapes.size()), context::get_status());
8025 status_check(context::get_status());
8027 TFE_OpSetAttrBool(op.get(),
"use_fallback", (
unsigned char)use_fallback);
8030 int num_outputs_op = 1;
8031 TFE_TensorHandle* res[1] = {
nullptr};
8032 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8033 status_check(context::get_status());
8034 return tensor(res[0]);
8037 inline tensor experimental_set_stats_aggregator_dataset(
const tensor& input_dataset,
const tensor& stats_aggregator,
8038 const tensor& tag,
const tensor& counter_prefix,
8039 const std::vector<datatype>& output_types,
8040 const std::vector<std::vector<int64_t>>& output_shapes) {
8042 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8043 TFE_NewOp(context::get_context(),
"ExperimentalSetStatsAggregatorDataset", context::get_status()), &TFE_DeleteOp);
8044 status_check(context::get_status());
8048 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8049 status_check(context::get_status());
8051 TFE_OpAddInput(op.get(), stats_aggregator.tfe_handle.get(), context::get_status());
8052 status_check(context::get_status());
8054 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
8055 status_check(context::get_status());
8057 TFE_OpAddInput(op.get(), counter_prefix.tfe_handle.get(), context::get_status());
8058 status_check(context::get_status());
8061 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8062 static_cast<int>(output_types.size()));
8064 std::vector<const int64_t*> output_shapes_values;
8065 output_shapes_values.reserve(output_shapes.size());
8066 std::vector<int> output_shapes_ndims;
8067 output_shapes_ndims.reserve(output_shapes.size());
8068 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8069 [](
const auto& v) { return v.data(); });
8070 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8071 [](
const auto& v) { return static_cast<int>(v.size()); });
8072 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8073 static_cast<int>(output_shapes.size()), context::get_status());
8074 status_check(context::get_status());
8077 int num_outputs_op = 1;
8078 TFE_TensorHandle* res[1] = {
nullptr};
8079 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8080 status_check(context::get_status());
8081 return tensor(res[0]);
8084 inline tensor experimental_sleep_dataset(
const tensor& input_dataset,
const tensor& sleep_microseconds,
8085 const std::vector<datatype>& output_types,
8086 const std::vector<std::vector<int64_t>>& output_shapes) {
8088 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8089 TFE_NewOp(context::get_context(),
"ExperimentalSleepDataset", context::get_status()), &TFE_DeleteOp);
8090 status_check(context::get_status());
8094 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8095 status_check(context::get_status());
8097 TFE_OpAddInput(op.get(), sleep_microseconds.tfe_handle.get(), context::get_status());
8098 status_check(context::get_status());
8101 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8102 static_cast<int>(output_types.size()));
8104 std::vector<const int64_t*> output_shapes_values;
8105 output_shapes_values.reserve(output_shapes.size());
8106 std::vector<int> output_shapes_ndims;
8107 output_shapes_ndims.reserve(output_shapes.size());
8108 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8109 [](
const auto& v) { return v.data(); });
8110 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8111 [](
const auto& v) { return static_cast<int>(v.size()); });
8112 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8113 static_cast<int>(output_shapes.size()), context::get_status());
8114 status_check(context::get_status());
8117 int num_outputs_op = 1;
8118 TFE_TensorHandle* res[1] = {
nullptr};
8119 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8120 status_check(context::get_status());
8121 return tensor(res[0]);
8124 inline tensor experimental_sliding_window_dataset(
const tensor& input_dataset,
const tensor& window_size,
8125 const tensor& window_shift,
const tensor& window_stride,
8126 const std::vector<datatype>& output_types,
8127 const std::vector<std::vector<int64_t>>& output_shapes) {
8129 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8130 TFE_NewOp(context::get_context(),
"ExperimentalSlidingWindowDataset", context::get_status()), &TFE_DeleteOp);
8131 status_check(context::get_status());
8135 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8136 status_check(context::get_status());
8138 TFE_OpAddInput(op.get(), window_size.tfe_handle.get(), context::get_status());
8139 status_check(context::get_status());
8141 TFE_OpAddInput(op.get(), window_shift.tfe_handle.get(), context::get_status());
8142 status_check(context::get_status());
8144 TFE_OpAddInput(op.get(), window_stride.tfe_handle.get(), context::get_status());
8145 status_check(context::get_status());
8148 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8149 static_cast<int>(output_types.size()));
8151 std::vector<const int64_t*> output_shapes_values;
8152 output_shapes_values.reserve(output_shapes.size());
8153 std::vector<int> output_shapes_ndims;
8154 output_shapes_ndims.reserve(output_shapes.size());
8155 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8156 [](
const auto& v) { return v.data(); });
8157 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8158 [](
const auto& v) { return static_cast<int>(v.size()); });
8159 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8160 static_cast<int>(output_shapes.size()), context::get_status());
8161 status_check(context::get_status());
8164 int num_outputs_op = 1;
8165 TFE_TensorHandle* res[1] = {
nullptr};
8166 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8167 status_check(context::get_status());
8168 return tensor(res[0]);
8171 inline tensor experimental_sql_dataset(
const tensor& driver_name,
const tensor& data_source_name,
const tensor& query,
8172 const std::vector<datatype>& output_types,
8173 const std::vector<std::vector<int64_t>>& output_shapes) {
8175 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8176 TFE_NewOp(context::get_context(),
"ExperimentalSqlDataset", context::get_status()), &TFE_DeleteOp);
8177 status_check(context::get_status());
8181 TFE_OpAddInput(op.get(), driver_name.tfe_handle.get(), context::get_status());
8182 status_check(context::get_status());
8184 TFE_OpAddInput(op.get(), data_source_name.tfe_handle.get(), context::get_status());
8185 status_check(context::get_status());
8187 TFE_OpAddInput(op.get(), query.tfe_handle.get(), context::get_status());
8188 status_check(context::get_status());
8191 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8192 static_cast<int>(output_types.size()));
8194 std::vector<const int64_t*> output_shapes_values;
8195 output_shapes_values.reserve(output_shapes.size());
8196 std::vector<int> output_shapes_ndims;
8197 output_shapes_ndims.reserve(output_shapes.size());
8198 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8199 [](
const auto& v) { return v.data(); });
8200 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8201 [](
const auto& v) { return static_cast<int>(v.size()); });
8202 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8203 static_cast<int>(output_shapes.size()), context::get_status());
8204 status_check(context::get_status());
8207 int num_outputs_op = 1;
8208 TFE_TensorHandle* res[1] = {
nullptr};
8209 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8210 status_check(context::get_status());
8211 return tensor(res[0]);
8214 inline tensor experimental_stats_aggregator_handle(
const std::string& container =
"",
8215 const std::string& shared_name =
"") {
8217 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8218 TFE_NewOp(context::get_context(),
"ExperimentalStatsAggregatorHandle", context::get_status()), &TFE_DeleteOp);
8219 status_check(context::get_status());
8224 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
8225 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
8228 int num_outputs_op = 1;
8229 TFE_TensorHandle* res[1] = {
nullptr};
8230 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8231 status_check(context::get_status());
8232 return tensor(res[0]);
8235 inline tensor experimental_stats_aggregator_summary(
const tensor& iterator) {
8237 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8238 TFE_NewOp(context::get_context(),
"ExperimentalStatsAggregatorSummary", context::get_status()), &TFE_DeleteOp);
8239 status_check(context::get_status());
8243 TFE_OpAddInput(op.get(), iterator.tfe_handle.get(), context::get_status());
8244 status_check(context::get_status());
8249 int num_outputs_op = 1;
8250 TFE_TensorHandle* res[1] = {
nullptr};
8251 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8252 status_check(context::get_status());
8253 return tensor(res[0]);
8256 inline tensor experimental_thread_pool_dataset(
const tensor& input_dataset,
const tensor& thread_pool,
8257 const std::vector<datatype>& output_types,
8258 const std::vector<std::vector<int64_t>>& output_shapes) {
8260 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8261 TFE_NewOp(context::get_context(),
"ExperimentalThreadPoolDataset", context::get_status()), &TFE_DeleteOp);
8262 status_check(context::get_status());
8266 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8267 status_check(context::get_status());
8269 TFE_OpAddInput(op.get(), thread_pool.tfe_handle.get(), context::get_status());
8270 status_check(context::get_status());
8273 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8274 static_cast<int>(output_types.size()));
8276 std::vector<const int64_t*> output_shapes_values;
8277 output_shapes_values.reserve(output_shapes.size());
8278 std::vector<int> output_shapes_ndims;
8279 output_shapes_ndims.reserve(output_shapes.size());
8280 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8281 [](
const auto& v) { return v.data(); });
8282 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8283 [](
const auto& v) { return static_cast<int>(v.size()); });
8284 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8285 static_cast<int>(output_shapes.size()), context::get_status());
8286 status_check(context::get_status());
8289 int num_outputs_op = 1;
8290 TFE_TensorHandle* res[1] = {
nullptr};
8291 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8292 status_check(context::get_status());
8293 return tensor(res[0]);
8296 inline tensor experimental_thread_pool_handle(int64_t num_threads,
const std::string& display_name,
8297 int64_t max_intra_op_parallelism = 1,
const std::string& container =
"",
8298 const std::string& shared_name =
"") {
8300 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8301 TFE_NewOp(context::get_context(),
"ExperimentalThreadPoolHandle", context::get_status()), &TFE_DeleteOp);
8302 status_check(context::get_status());
8307 TFE_OpSetAttrInt(op.get(),
"num_threads", num_threads);
8308 TFE_OpSetAttrString(op.get(),
"display_name", (
void*)display_name.c_str(), display_name.size());
8309 TFE_OpSetAttrInt(op.get(),
"max_intra_op_parallelism", max_intra_op_parallelism);
8310 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
8311 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
8314 int num_outputs_op = 1;
8315 TFE_TensorHandle* res[1] = {
nullptr};
8316 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8317 status_check(context::get_status());
8318 return tensor(res[0]);
8321 inline tensor experimental_unbatch_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
8322 const std::vector<std::vector<int64_t>>& output_shapes) {
8324 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8325 TFE_NewOp(context::get_context(),
"ExperimentalUnbatchDataset", context::get_status()), &TFE_DeleteOp);
8326 status_check(context::get_status());
8330 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8331 status_check(context::get_status());
8334 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8335 static_cast<int>(output_types.size()));
8337 std::vector<const int64_t*> output_shapes_values;
8338 output_shapes_values.reserve(output_shapes.size());
8339 std::vector<int> output_shapes_ndims;
8340 output_shapes_ndims.reserve(output_shapes.size());
8341 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8342 [](
const auto& v) { return v.data(); });
8343 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8344 [](
const auto& v) { return static_cast<int>(v.size()); });
8345 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8346 static_cast<int>(output_shapes.size()), context::get_status());
8347 status_check(context::get_status());
8350 int num_outputs_op = 1;
8351 TFE_TensorHandle* res[1] = {
nullptr};
8352 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8353 status_check(context::get_status());
8354 return tensor(res[0]);
8357 inline tensor experimental_unique_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
8358 const std::vector<std::vector<int64_t>>& output_shapes) {
8360 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8361 TFE_NewOp(context::get_context(),
"ExperimentalUniqueDataset", context::get_status()), &TFE_DeleteOp);
8362 status_check(context::get_status());
8366 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8367 status_check(context::get_status());
8370 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8371 static_cast<int>(output_types.size()));
8373 std::vector<const int64_t*> output_shapes_values;
8374 output_shapes_values.reserve(output_shapes.size());
8375 std::vector<int> output_shapes_ndims;
8376 output_shapes_ndims.reserve(output_shapes.size());
8377 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8378 [](
const auto& v) { return v.data(); });
8379 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8380 [](
const auto& v) { return static_cast<int>(v.size()); });
8381 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8382 static_cast<int>(output_shapes.size()), context::get_status());
8383 status_check(context::get_status());
8386 int num_outputs_op = 1;
8387 TFE_TensorHandle* res[1] = {
nullptr};
8388 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8389 status_check(context::get_status());
8390 return tensor(res[0]);
8393 inline tensor expint(
const tensor& x) {
8395 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8396 TFE_NewOp(context::get_context(),
"Expint", context::get_status()), &TFE_DeleteOp);
8397 status_check(context::get_status());
8401 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
8402 status_check(context::get_status());
8407 int num_outputs_op = 1;
8408 TFE_TensorHandle* res[1] = {
nullptr};
8409 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8410 status_check(context::get_status());
8411 return tensor(res[0]);
8414 inline tensor expm1(
const tensor& x) {
8416 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Expm1", context::get_status()),
8418 status_check(context::get_status());
8422 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
8423 status_check(context::get_status());
8428 int num_outputs_op = 1;
8429 TFE_TensorHandle* res[1] = {
nullptr};
8430 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8431 status_check(context::get_status());
8432 return tensor(res[0]);
8435 inline tensor extract_glimpse(
const tensor& input,
const tensor& size,
const tensor& offsets,
bool centered =
true,
8436 bool normalized =
true,
bool uniform_noise =
true,
const std::string& noise =
"uniform") {
8438 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8439 TFE_NewOp(context::get_context(),
"ExtractGlimpse", context::get_status()), &TFE_DeleteOp);
8440 status_check(context::get_status());
8444 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
8445 status_check(context::get_status());
8447 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
8448 status_check(context::get_status());
8450 TFE_OpAddInput(op.get(), offsets.tfe_handle.get(), context::get_status());
8451 status_check(context::get_status());
8454 TFE_OpSetAttrBool(op.get(),
"centered", (
unsigned char)centered);
8455 TFE_OpSetAttrBool(op.get(),
"normalized", (
unsigned char)normalized);
8456 TFE_OpSetAttrBool(op.get(),
"uniform_noise", (
unsigned char)uniform_noise);
8457 TFE_OpSetAttrString(op.get(),
"noise", (
void*)noise.c_str(), noise.size());
8460 int num_outputs_op = 1;
8461 TFE_TensorHandle* res[1] = {
nullptr};
8462 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8463 status_check(context::get_status());
8464 return tensor(res[0]);
8467 inline tensor extract_glimpse_v2(
const tensor& input,
const tensor& size,
const tensor& offsets,
bool centered =
true,
8468 bool normalized =
true,
bool uniform_noise =
true,
8469 const std::string& noise =
"uniform") {
8471 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8472 TFE_NewOp(context::get_context(),
"ExtractGlimpseV2", context::get_status()), &TFE_DeleteOp);
8473 status_check(context::get_status());
8477 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
8478 status_check(context::get_status());
8480 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
8481 status_check(context::get_status());
8483 TFE_OpAddInput(op.get(), offsets.tfe_handle.get(), context::get_status());
8484 status_check(context::get_status());
8487 TFE_OpSetAttrBool(op.get(),
"centered", (
unsigned char)centered);
8488 TFE_OpSetAttrBool(op.get(),
"normalized", (
unsigned char)normalized);
8489 TFE_OpSetAttrBool(op.get(),
"uniform_noise", (
unsigned char)uniform_noise);
8490 TFE_OpSetAttrString(op.get(),
"noise", (
void*)noise.c_str(), noise.size());
8493 int num_outputs_op = 1;
8494 TFE_TensorHandle* res[1] = {
nullptr};
8495 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8496 status_check(context::get_status());
8497 return tensor(res[0]);
8500 inline tensor extract_image_patches(
const tensor& images,
const std::vector<int64_t>& ksizes,
8501 const std::vector<int64_t>& strides,
const std::vector<int64_t>& rates,
8502 const std::string& padding) {
8504 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8505 TFE_NewOp(context::get_context(),
"ExtractImagePatches", context::get_status()), &TFE_DeleteOp);
8506 status_check(context::get_status());
8510 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
8511 status_check(context::get_status());
8514 TFE_OpSetAttrIntList(op.get(),
"ksizes", ksizes.data(),
static_cast<int>(ksizes.size()));
8515 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
8516 TFE_OpSetAttrIntList(op.get(),
"rates", rates.data(),
static_cast<int>(rates.size()));
8517 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
8520 int num_outputs_op = 1;
8521 TFE_TensorHandle* res[1] = {
nullptr};
8522 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8523 status_check(context::get_status());
8524 return tensor(res[0]);
8527 inline tensor extract_jpeg_shape(
const tensor& contents, datatype output_type =
static_cast<datatype
>(3)) {
8529 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8530 TFE_NewOp(context::get_context(),
"ExtractJpegShape", context::get_status()), &TFE_DeleteOp);
8531 status_check(context::get_status());
8535 TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
8536 status_check(context::get_status());
8539 TFE_OpSetAttrType(op.get(),
"output_type", output_type);
8542 int num_outputs_op = 1;
8543 TFE_TensorHandle* res[1] = {
nullptr};
8544 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8545 status_check(context::get_status());
8546 return tensor(res[0]);
8549 inline tensor extract_volume_patches(
const tensor& input,
const std::vector<int64_t>& ksizes,
8550 const std::vector<int64_t>& strides,
const std::string& padding) {
8552 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8553 TFE_NewOp(context::get_context(),
"ExtractVolumePatches", context::get_status()), &TFE_DeleteOp);
8554 status_check(context::get_status());
8558 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
8559 status_check(context::get_status());
8562 TFE_OpSetAttrIntList(op.get(),
"ksizes", ksizes.data(),
static_cast<int>(ksizes.size()));
8563 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
8564 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
8567 int num_outputs_op = 1;
8568 TFE_TensorHandle* res[1] = {
nullptr};
8569 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8570 status_check(context::get_status());
8571 return tensor(res[0]);
8574 inline tensor f_f_t(
const tensor& input, datatype Tcomplex =
static_cast<datatype
>(8)) {
8576 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"FFT", context::get_status()),
8578 status_check(context::get_status());
8582 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
8583 status_check(context::get_status());
8586 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
8589 int num_outputs_op = 1;
8590 TFE_TensorHandle* res[1] = {
nullptr};
8591 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8592 status_check(context::get_status());
8593 return tensor(res[0]);
8596 inline tensor f_f_t2_d(
const tensor& input, datatype Tcomplex =
static_cast<datatype
>(8)) {
8598 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"FFT2D", context::get_status()),
8600 status_check(context::get_status());
8604 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
8605 status_check(context::get_status());
8608 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
8611 int num_outputs_op = 1;
8612 TFE_TensorHandle* res[1] = {
nullptr};
8613 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8614 status_check(context::get_status());
8615 return tensor(res[0]);
8618 inline tensor f_f_t3_d(
const tensor& input, datatype Tcomplex =
static_cast<datatype
>(8)) {
8620 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"FFT3D", context::get_status()),
8622 status_check(context::get_status());
8626 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
8627 status_check(context::get_status());
8630 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
8633 int num_outputs_op = 1;
8634 TFE_TensorHandle* res[1] = {
nullptr};
8635 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8636 status_check(context::get_status());
8637 return tensor(res[0]);
8640 inline tensor f_i_f_o_queue(
const std::vector<datatype>& component_types,
8641 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
8642 const std::string& container =
"",
const std::string& shared_name =
"") {
8644 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8645 TFE_NewOp(context::get_context(),
"FIFOQueue", context::get_status()), &TFE_DeleteOp);
8646 status_check(context::get_status());
8651 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
8652 static_cast<int>(component_types.size()));
8654 std::vector<const int64_t*> shapes_values;
8655 shapes_values.reserve(shapes.size());
8656 std::vector<int> shapes_ndims;
8657 shapes_ndims.reserve(shapes.size());
8658 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
8659 [](
const auto& v) { return v.data(); });
8660 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
8661 [](
const auto& v) { return static_cast<int>(v.size()); });
8662 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
8663 context::get_status());
8664 status_check(context::get_status());
8666 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
8667 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
8668 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
8671 int num_outputs_op = 1;
8672 TFE_TensorHandle* res[1] = {
nullptr};
8673 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8674 status_check(context::get_status());
8675 return tensor(res[0]);
8678 inline tensor f_i_f_o_queue_v2(
const std::vector<datatype>& component_types,
8679 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
8680 const std::string& container =
"",
const std::string& shared_name =
"") {
8682 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8683 TFE_NewOp(context::get_context(),
"FIFOQueueV2", context::get_status()), &TFE_DeleteOp);
8684 status_check(context::get_status());
8689 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
8690 static_cast<int>(component_types.size()));
8692 std::vector<const int64_t*> shapes_values;
8693 shapes_values.reserve(shapes.size());
8694 std::vector<int> shapes_ndims;
8695 shapes_ndims.reserve(shapes.size());
8696 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
8697 [](
const auto& v) { return v.data(); });
8698 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
8699 [](
const auto& v) { return static_cast<int>(v.size()); });
8700 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
8701 context::get_status());
8702 status_check(context::get_status());
8704 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
8705 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
8706 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
8709 int num_outputs_op = 1;
8710 TFE_TensorHandle* res[1] = {
nullptr};
8711 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8712 status_check(context::get_status());
8713 return tensor(res[0]);
8716 inline tensor fact() {
8718 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Fact", context::get_status()),
8720 status_check(context::get_status());
8727 int num_outputs_op = 1;
8728 TFE_TensorHandle* res[1] = {
nullptr};
8729 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8730 status_check(context::get_status());
8731 return tensor(res[0]);
8734 inline tensor fake_param(datatype dtype,
const std::vector<int64_t>& shape) {
8736 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8737 TFE_NewOp(context::get_context(),
"FakeParam", context::get_status()), &TFE_DeleteOp);
8738 status_check(context::get_status());
8743 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
8745 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
8746 status_check(context::get_status());
8749 int num_outputs_op = 1;
8750 TFE_TensorHandle* res[1] = {
nullptr};
8751 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8752 status_check(context::get_status());
8753 return tensor(res[0]);
8756 inline tensor fake_quant_with_min_max_args(
const tensor& inputs,
float min = -6.0000e+00,
float max = 6.0000e+00,
8757 int64_t num_bits = 8,
bool narrow_range =
false) {
8759 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8760 TFE_NewOp(context::get_context(),
"FakeQuantWithMinMaxArgs", context::get_status()), &TFE_DeleteOp);
8761 status_check(context::get_status());
8765 TFE_OpAddInput(op.get(), inputs.tfe_handle.get(), context::get_status());
8766 status_check(context::get_status());
8769 TFE_OpSetAttrFloat(op.get(),
"min", min);
8770 TFE_OpSetAttrFloat(op.get(),
"max", max);
8771 TFE_OpSetAttrInt(op.get(),
"num_bits", num_bits);
8772 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
8775 int num_outputs_op = 1;
8776 TFE_TensorHandle* res[1] = {
nullptr};
8777 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8778 status_check(context::get_status());
8779 return tensor(res[0]);
8782 inline tensor fake_quant_with_min_max_args_gradient(
const tensor& gradients,
const tensor& inputs,
8783 float min = -6.0000e+00,
float max = 6.0000e+00,
8784 int64_t num_bits = 8,
bool narrow_range =
false) {
8786 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8787 TFE_NewOp(context::get_context(),
"FakeQuantWithMinMaxArgsGradient", context::get_status()), &TFE_DeleteOp);
8788 status_check(context::get_status());
8792 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
8793 status_check(context::get_status());
8795 TFE_OpAddInput(op.get(), inputs.tfe_handle.get(), context::get_status());
8796 status_check(context::get_status());
8799 TFE_OpSetAttrFloat(op.get(),
"min", min);
8800 TFE_OpSetAttrFloat(op.get(),
"max", max);
8801 TFE_OpSetAttrInt(op.get(),
"num_bits", num_bits);
8802 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
8805 int num_outputs_op = 1;
8806 TFE_TensorHandle* res[1] = {
nullptr};
8807 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8808 status_check(context::get_status());
8809 return tensor(res[0]);
8812 inline tensor fake_quant_with_min_max_vars(
const tensor& inputs,
const tensor& min,
const tensor& max,
8813 int64_t num_bits = 8,
bool narrow_range =
false) {
8815 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8816 TFE_NewOp(context::get_context(),
"FakeQuantWithMinMaxVars", context::get_status()), &TFE_DeleteOp);
8817 status_check(context::get_status());
8821 TFE_OpAddInput(op.get(), inputs.tfe_handle.get(), context::get_status());
8822 status_check(context::get_status());
8824 TFE_OpAddInput(op.get(), min.tfe_handle.get(), context::get_status());
8825 status_check(context::get_status());
8827 TFE_OpAddInput(op.get(), max.tfe_handle.get(), context::get_status());
8828 status_check(context::get_status());
8831 TFE_OpSetAttrInt(op.get(),
"num_bits", num_bits);
8832 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
8835 int num_outputs_op = 1;
8836 TFE_TensorHandle* res[1] = {
nullptr};
8837 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8838 status_check(context::get_status());
8839 return tensor(res[0]);
8842 inline tensor fake_quant_with_min_max_vars_per_channel(
const tensor& inputs,
const tensor& min,
const tensor& max,
8843 int64_t num_bits = 8,
bool narrow_range =
false) {
8845 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8846 TFE_NewOp(context::get_context(),
"FakeQuantWithMinMaxVarsPerChannel", context::get_status()), &TFE_DeleteOp);
8847 status_check(context::get_status());
8851 TFE_OpAddInput(op.get(), inputs.tfe_handle.get(), context::get_status());
8852 status_check(context::get_status());
8854 TFE_OpAddInput(op.get(), min.tfe_handle.get(), context::get_status());
8855 status_check(context::get_status());
8857 TFE_OpAddInput(op.get(), max.tfe_handle.get(), context::get_status());
8858 status_check(context::get_status());
8861 TFE_OpSetAttrInt(op.get(),
"num_bits", num_bits);
8862 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
8865 int num_outputs_op = 1;
8866 TFE_TensorHandle* res[1] = {
nullptr};
8867 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8868 status_check(context::get_status());
8869 return tensor(res[0]);
8872 inline tensor fake_queue(
const tensor& resource) {
8874 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8875 TFE_NewOp(context::get_context(),
"FakeQueue", context::get_status()), &TFE_DeleteOp);
8876 status_check(context::get_status());
8880 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
8881 status_check(context::get_status());
8886 int num_outputs_op = 1;
8887 TFE_TensorHandle* res[1] = {
nullptr};
8888 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8889 status_check(context::get_status());
8890 return tensor(res[0]);
8893 inline tensor fill(
const tensor& dims,
const tensor& value, datatype index_type =
static_cast<datatype
>(3)) {
8895 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Fill", context::get_status()),
8897 status_check(context::get_status());
8901 TFE_OpAddInput(op.get(), dims.tfe_handle.get(), context::get_status());
8902 status_check(context::get_status());
8904 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
8905 status_check(context::get_status());
8908 TFE_OpSetAttrType(op.get(),
"index_type", index_type);
8911 int num_outputs_op = 1;
8912 TFE_TensorHandle* res[1] = {
nullptr};
8913 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8914 status_check(context::get_status());
8915 return tensor(res[0]);
8918 inline tensor filter_by_last_component_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
8919 const std::vector<std::vector<int64_t>>& output_shapes) {
8921 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8922 TFE_NewOp(context::get_context(),
"FilterByLastComponentDataset", context::get_status()), &TFE_DeleteOp);
8923 status_check(context::get_status());
8927 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
8928 status_check(context::get_status());
8931 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
8932 static_cast<int>(output_types.size()));
8934 std::vector<const int64_t*> output_shapes_values;
8935 output_shapes_values.reserve(output_shapes.size());
8936 std::vector<int> output_shapes_ndims;
8937 output_shapes_ndims.reserve(output_shapes.size());
8938 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
8939 [](
const auto& v) { return v.data(); });
8940 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
8941 [](
const auto& v) { return static_cast<int>(v.size()); });
8942 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
8943 static_cast<int>(output_shapes.size()), context::get_status());
8944 status_check(context::get_status());
8947 int num_outputs_op = 1;
8948 TFE_TensorHandle* res[1] = {
nullptr};
8949 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8950 status_check(context::get_status());
8951 return tensor(res[0]);
8954 inline tensor fingerprint(
const tensor& data,
const tensor& method) {
8956 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8957 TFE_NewOp(context::get_context(),
"Fingerprint", context::get_status()), &TFE_DeleteOp);
8958 status_check(context::get_status());
8962 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
8963 status_check(context::get_status());
8965 TFE_OpAddInput(op.get(), method.tfe_handle.get(), context::get_status());
8966 status_check(context::get_status());
8971 int num_outputs_op = 1;
8972 TFE_TensorHandle* res[1] = {
nullptr};
8973 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
8974 status_check(context::get_status());
8975 return tensor(res[0]);
8978 inline tensor fixed_length_record_dataset(
const tensor& filenames,
const tensor& header_bytes,
8979 const tensor& record_bytes,
const tensor& footer_bytes,
8980 const tensor& buffer_size) {
8982 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
8983 TFE_NewOp(context::get_context(),
"FixedLengthRecordDataset", context::get_status()), &TFE_DeleteOp);
8984 status_check(context::get_status());
8988 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
8989 status_check(context::get_status());
8991 TFE_OpAddInput(op.get(), header_bytes.tfe_handle.get(), context::get_status());
8992 status_check(context::get_status());
8994 TFE_OpAddInput(op.get(), record_bytes.tfe_handle.get(), context::get_status());
8995 status_check(context::get_status());
8997 TFE_OpAddInput(op.get(), footer_bytes.tfe_handle.get(), context::get_status());
8998 status_check(context::get_status());
9000 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
9001 status_check(context::get_status());
9006 int num_outputs_op = 1;
9007 TFE_TensorHandle* res[1] = {
nullptr};
9008 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9009 status_check(context::get_status());
9010 return tensor(res[0]);
9013 inline tensor fixed_length_record_dataset_v2(
const tensor& filenames,
const tensor& header_bytes,
9014 const tensor& record_bytes,
const tensor& footer_bytes,
9015 const tensor& buffer_size,
const tensor& compression_type) {
9017 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9018 TFE_NewOp(context::get_context(),
"FixedLengthRecordDatasetV2", context::get_status()), &TFE_DeleteOp);
9019 status_check(context::get_status());
9023 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
9024 status_check(context::get_status());
9026 TFE_OpAddInput(op.get(), header_bytes.tfe_handle.get(), context::get_status());
9027 status_check(context::get_status());
9029 TFE_OpAddInput(op.get(), record_bytes.tfe_handle.get(), context::get_status());
9030 status_check(context::get_status());
9032 TFE_OpAddInput(op.get(), footer_bytes.tfe_handle.get(), context::get_status());
9033 status_check(context::get_status());
9035 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
9036 status_check(context::get_status());
9038 TFE_OpAddInput(op.get(), compression_type.tfe_handle.get(), context::get_status());
9039 status_check(context::get_status());
9044 int num_outputs_op = 1;
9045 TFE_TensorHandle* res[1] = {
nullptr};
9046 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9047 status_check(context::get_status());
9048 return tensor(res[0]);
9051 inline tensor fixed_length_record_reader(int64_t record_bytes, int64_t header_bytes = 0, int64_t footer_bytes = 0,
9052 int64_t hop_bytes = 0,
const std::string& container =
"",
9053 const std::string& shared_name =
"") {
9055 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9056 TFE_NewOp(context::get_context(),
"FixedLengthRecordReader", context::get_status()), &TFE_DeleteOp);
9057 status_check(context::get_status());
9062 TFE_OpSetAttrInt(op.get(),
"record_bytes", record_bytes);
9063 TFE_OpSetAttrInt(op.get(),
"header_bytes", header_bytes);
9064 TFE_OpSetAttrInt(op.get(),
"footer_bytes", footer_bytes);
9065 TFE_OpSetAttrInt(op.get(),
"hop_bytes", hop_bytes);
9066 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
9067 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
9070 int num_outputs_op = 1;
9071 TFE_TensorHandle* res[1] = {
nullptr};
9072 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9073 status_check(context::get_status());
9074 return tensor(res[0]);
9077 inline tensor fixed_length_record_reader_v2(int64_t record_bytes, int64_t header_bytes = 0, int64_t footer_bytes = 0,
9078 int64_t hop_bytes = 0,
const std::string& container =
"",
9079 const std::string& shared_name =
"",
const std::string& encoding =
"") {
9081 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9082 TFE_NewOp(context::get_context(),
"FixedLengthRecordReaderV2", context::get_status()), &TFE_DeleteOp);
9083 status_check(context::get_status());
9088 TFE_OpSetAttrInt(op.get(),
"record_bytes", record_bytes);
9089 TFE_OpSetAttrInt(op.get(),
"header_bytes", header_bytes);
9090 TFE_OpSetAttrInt(op.get(),
"footer_bytes", footer_bytes);
9091 TFE_OpSetAttrInt(op.get(),
"hop_bytes", hop_bytes);
9092 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
9093 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
9094 TFE_OpSetAttrString(op.get(),
"encoding", (
void*)encoding.c_str(), encoding.size());
9097 int num_outputs_op = 1;
9098 TFE_TensorHandle* res[1] = {
nullptr};
9099 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9100 status_check(context::get_status());
9101 return tensor(res[0]);
9104 inline tensor floor(
const tensor& x) {
9106 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Floor", context::get_status()),
9108 status_check(context::get_status());
9112 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9113 status_check(context::get_status());
9118 int num_outputs_op = 1;
9119 TFE_TensorHandle* res[1] = {
nullptr};
9120 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9121 status_check(context::get_status());
9122 return tensor(res[0]);
9125 inline tensor floor_div(
const tensor& x,
const tensor& y) {
9127 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9128 TFE_NewOp(context::get_context(),
"FloorDiv", context::get_status()), &TFE_DeleteOp);
9129 status_check(context::get_status());
9133 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9134 status_check(context::get_status());
9136 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
9137 status_check(context::get_status());
9142 int num_outputs_op = 1;
9143 TFE_TensorHandle* res[1] = {
nullptr};
9144 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9145 status_check(context::get_status());
9146 return tensor(res[0]);
9149 inline tensor floor_mod(
const tensor& x,
const tensor& y) {
9151 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9152 TFE_NewOp(context::get_context(),
"FloorMod", context::get_status()), &TFE_DeleteOp);
9153 status_check(context::get_status());
9157 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9158 status_check(context::get_status());
9160 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
9161 status_check(context::get_status());
9166 int num_outputs_op = 1;
9167 TFE_TensorHandle* res[1] = {
nullptr};
9168 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9169 status_check(context::get_status());
9170 return tensor(res[0]);
9173 inline tensor fractional_avg_pool_grad(
const tensor& orig_input_input_tensor_shape,
const tensor& out_backprop,
9174 const tensor& row_pooling_sequence,
const tensor& col_pooling_sequence,
9175 bool overlapping =
false) {
9177 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9178 TFE_NewOp(context::get_context(),
"FractionalAvgPoolGrad", context::get_status()), &TFE_DeleteOp);
9179 status_check(context::get_status());
9183 TFE_OpAddInput(op.get(), orig_input_input_tensor_shape.tfe_handle.get(), context::get_status());
9184 status_check(context::get_status());
9186 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
9187 status_check(context::get_status());
9189 TFE_OpAddInput(op.get(), row_pooling_sequence.tfe_handle.get(), context::get_status());
9190 status_check(context::get_status());
9192 TFE_OpAddInput(op.get(), col_pooling_sequence.tfe_handle.get(), context::get_status());
9193 status_check(context::get_status());
9196 TFE_OpSetAttrBool(op.get(),
"overlapping", (
unsigned char)overlapping);
9199 int num_outputs_op = 1;
9200 TFE_TensorHandle* res[1] = {
nullptr};
9201 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9202 status_check(context::get_status());
9203 return tensor(res[0]);
9206 inline tensor fractional_max_pool_grad(
const tensor& orig_input,
const tensor& orig_output,
const tensor& out_backprop,
9207 const tensor& row_pooling_sequence,
const tensor& col_pooling_sequence,
9208 bool overlapping =
false) {
9210 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9211 TFE_NewOp(context::get_context(),
"FractionalMaxPoolGrad", context::get_status()), &TFE_DeleteOp);
9212 status_check(context::get_status());
9216 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
9217 status_check(context::get_status());
9219 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
9220 status_check(context::get_status());
9222 TFE_OpAddInput(op.get(), out_backprop.tfe_handle.get(), context::get_status());
9223 status_check(context::get_status());
9225 TFE_OpAddInput(op.get(), row_pooling_sequence.tfe_handle.get(), context::get_status());
9226 status_check(context::get_status());
9228 TFE_OpAddInput(op.get(), col_pooling_sequence.tfe_handle.get(), context::get_status());
9229 status_check(context::get_status());
9232 TFE_OpSetAttrBool(op.get(),
"overlapping", (
unsigned char)overlapping);
9235 int num_outputs_op = 1;
9236 TFE_TensorHandle* res[1] = {
nullptr};
9237 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9238 status_check(context::get_status());
9239 return tensor(res[0]);
9242 inline tensor fresnel_cos(
const tensor& x) {
9244 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9245 TFE_NewOp(context::get_context(),
"FresnelCos", context::get_status()), &TFE_DeleteOp);
9246 status_check(context::get_status());
9250 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9251 status_check(context::get_status());
9256 int num_outputs_op = 1;
9257 TFE_TensorHandle* res[1] = {
nullptr};
9258 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9259 status_check(context::get_status());
9260 return tensor(res[0]);
9263 inline tensor fresnel_sin(
const tensor& x) {
9265 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9266 TFE_NewOp(context::get_context(),
"FresnelSin", context::get_status()), &TFE_DeleteOp);
9267 status_check(context::get_status());
9271 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9272 status_check(context::get_status());
9277 int num_outputs_op = 1;
9278 TFE_TensorHandle* res[1] = {
nullptr};
9279 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9280 status_check(context::get_status());
9281 return tensor(res[0]);
9284 inline tensor fused_pad_conv2_d(
const tensor& input,
const tensor& paddings,
const tensor& filter,
9285 const std::string& mode,
const std::vector<int64_t>& strides,
9286 const std::string& padding) {
9288 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9289 TFE_NewOp(context::get_context(),
"FusedPadConv2D", context::get_status()), &TFE_DeleteOp);
9290 status_check(context::get_status());
9294 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9295 status_check(context::get_status());
9297 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
9298 status_check(context::get_status());
9300 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
9301 status_check(context::get_status());
9304 TFE_OpSetAttrString(op.get(),
"mode", (
void*)mode.c_str(), mode.size());
9305 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
9306 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
9309 int num_outputs_op = 1;
9310 TFE_TensorHandle* res[1] = {
nullptr};
9311 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9312 status_check(context::get_status());
9313 return tensor(res[0]);
9316 inline tensor fused_resize_and_pad_conv2_d(
const tensor& input,
const tensor& size,
const tensor& paddings,
9317 const tensor& filter,
const std::string& mode,
9318 const std::vector<int64_t>& strides,
const std::string& padding,
9319 bool resize_align_corners =
false) {
9321 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9322 TFE_NewOp(context::get_context(),
"FusedResizeAndPadConv2D", context::get_status()), &TFE_DeleteOp);
9323 status_check(context::get_status());
9327 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9328 status_check(context::get_status());
9330 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
9331 status_check(context::get_status());
9333 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
9334 status_check(context::get_status());
9336 TFE_OpAddInput(op.get(), filter.tfe_handle.get(), context::get_status());
9337 status_check(context::get_status());
9340 TFE_OpSetAttrString(op.get(),
"mode", (
void*)mode.c_str(), mode.size());
9341 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
9342 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
9343 TFE_OpSetAttrBool(op.get(),
"resize_align_corners", (
unsigned char)resize_align_corners);
9346 int num_outputs_op = 1;
9347 TFE_TensorHandle* res[1] = {
nullptr};
9348 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9349 status_check(context::get_status());
9350 return tensor(res[0]);
9353 inline tensor gather(
const tensor& params,
const tensor& indices, datatype Tparams, datatype Tindices,
9354 bool validate_indices =
true) {
9356 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9357 TFE_NewOp(context::get_context(),
"Gather", context::get_status()), &TFE_DeleteOp);
9358 status_check(context::get_status());
9362 TFE_OpAddInput(op.get(), params.tfe_handle.get(), context::get_status());
9363 status_check(context::get_status());
9365 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
9366 status_check(context::get_status());
9369 TFE_OpSetAttrType(op.get(),
"Tparams", Tparams);
9370 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
9371 TFE_OpSetAttrBool(op.get(),
"validate_indices", (
unsigned char)validate_indices);
9374 int num_outputs_op = 1;
9375 TFE_TensorHandle* res[1] = {
nullptr};
9376 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9377 status_check(context::get_status());
9378 return tensor(res[0]);
9381 inline tensor gather_nd(
const tensor& params,
const tensor& indices, datatype Tparams, datatype Tindices) {
9383 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9384 TFE_NewOp(context::get_context(),
"GatherNd", context::get_status()), &TFE_DeleteOp);
9385 status_check(context::get_status());
9389 TFE_OpAddInput(op.get(), params.tfe_handle.get(), context::get_status());
9390 status_check(context::get_status());
9392 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
9393 status_check(context::get_status());
9396 TFE_OpSetAttrType(op.get(),
"Tparams", Tparams);
9397 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
9400 int num_outputs_op = 1;
9401 TFE_TensorHandle* res[1] = {
nullptr};
9402 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9403 status_check(context::get_status());
9404 return tensor(res[0]);
9407 inline tensor gather_v2(
const tensor& params,
const tensor& indices,
const tensor& axis, datatype Tparams,
9408 datatype Tindices, datatype Taxis, int64_t batch_dims = 0) {
9410 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9411 TFE_NewOp(context::get_context(),
"GatherV2", context::get_status()), &TFE_DeleteOp);
9412 status_check(context::get_status());
9416 TFE_OpAddInput(op.get(), params.tfe_handle.get(), context::get_status());
9417 status_check(context::get_status());
9419 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
9420 status_check(context::get_status());
9422 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
9423 status_check(context::get_status());
9426 TFE_OpSetAttrType(op.get(),
"Tparams", Tparams);
9427 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
9428 TFE_OpSetAttrType(op.get(),
"Taxis", Taxis);
9429 TFE_OpSetAttrInt(op.get(),
"batch_dims", batch_dims);
9432 int num_outputs_op = 1;
9433 TFE_TensorHandle* res[1] = {
nullptr};
9434 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9435 status_check(context::get_status());
9436 return tensor(res[0]);
9439 inline tensor get_session_handle(
const tensor& value) {
9441 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9442 TFE_NewOp(context::get_context(),
"GetSessionHandle", context::get_status()), &TFE_DeleteOp);
9443 status_check(context::get_status());
9447 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
9448 status_check(context::get_status());
9453 int num_outputs_op = 1;
9454 TFE_TensorHandle* res[1] = {
nullptr};
9455 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9456 status_check(context::get_status());
9457 return tensor(res[0]);
9460 inline tensor get_session_handle_v2(
const tensor& value) {
9462 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9463 TFE_NewOp(context::get_context(),
"GetSessionHandleV2", context::get_status()), &TFE_DeleteOp);
9464 status_check(context::get_status());
9468 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
9469 status_check(context::get_status());
9474 int num_outputs_op = 1;
9475 TFE_TensorHandle* res[1] = {
nullptr};
9476 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9477 status_check(context::get_status());
9478 return tensor(res[0]);
9481 inline tensor get_session_tensor(
const tensor& handle, datatype dtype) {
9483 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9484 TFE_NewOp(context::get_context(),
"GetSessionTensor", context::get_status()), &TFE_DeleteOp);
9485 status_check(context::get_status());
9489 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
9490 status_check(context::get_status());
9493 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
9496 int num_outputs_op = 1;
9497 TFE_TensorHandle* res[1] = {
nullptr};
9498 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9499 status_check(context::get_status());
9500 return tensor(res[0]);
9503 inline tensor greater(
const tensor& x,
const tensor& y) {
9505 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9506 TFE_NewOp(context::get_context(),
"Greater", context::get_status()), &TFE_DeleteOp);
9507 status_check(context::get_status());
9511 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9512 status_check(context::get_status());
9514 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
9515 status_check(context::get_status());
9520 int num_outputs_op = 1;
9521 TFE_TensorHandle* res[1] = {
nullptr};
9522 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9523 status_check(context::get_status());
9524 return tensor(res[0]);
9527 inline tensor greater_equal(
const tensor& x,
const tensor& y) {
9529 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9530 TFE_NewOp(context::get_context(),
"GreaterEqual", context::get_status()), &TFE_DeleteOp);
9531 status_check(context::get_status());
9535 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9536 status_check(context::get_status());
9538 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
9539 status_check(context::get_status());
9544 int num_outputs_op = 1;
9545 TFE_TensorHandle* res[1] = {
nullptr};
9546 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9547 status_check(context::get_status());
9548 return tensor(res[0]);
9551 inline tensor guarantee_const_tensor(
const tensor& input) {
9553 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9554 TFE_NewOp(context::get_context(),
"GuaranteeConst", context::get_status()), &TFE_DeleteOp);
9555 status_check(context::get_status());
9559 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9560 status_check(context::get_status());
9565 int num_outputs_op = 1;
9566 TFE_TensorHandle* res[1] = {
nullptr};
9567 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9568 status_check(context::get_status());
9569 return tensor(res[0]);
9572 inline tensor h_s_v_to_r_g_b(
const tensor& images) {
9574 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9575 TFE_NewOp(context::get_context(),
"HSVToRGB", context::get_status()), &TFE_DeleteOp);
9576 status_check(context::get_status());
9580 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
9581 status_check(context::get_status());
9586 int num_outputs_op = 1;
9587 TFE_TensorHandle* res[1] = {
nullptr};
9588 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9589 status_check(context::get_status());
9590 return tensor(res[0]);
9593 inline tensor hash_table(datatype key_dtype, datatype value_dtype,
const std::string& container =
"",
9594 const std::string& shared_name =
"",
bool use_node_name_sharing =
false) {
9596 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9597 TFE_NewOp(context::get_context(),
"HashTable", context::get_status()), &TFE_DeleteOp);
9598 status_check(context::get_status());
9603 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
9604 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
9605 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
9606 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
9607 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
9610 int num_outputs_op = 1;
9611 TFE_TensorHandle* res[1] = {
nullptr};
9612 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9613 status_check(context::get_status());
9614 return tensor(res[0]);
9617 inline tensor hash_table_v2(datatype key_dtype, datatype value_dtype,
const std::string& container =
"",
9618 const std::string& shared_name =
"",
bool use_node_name_sharing =
false) {
9620 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9621 TFE_NewOp(context::get_context(),
"HashTableV2", context::get_status()), &TFE_DeleteOp);
9622 status_check(context::get_status());
9627 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
9628 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
9629 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
9630 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
9631 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
9634 int num_outputs_op = 1;
9635 TFE_TensorHandle* res[1] = {
nullptr};
9636 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9637 status_check(context::get_status());
9638 return tensor(res[0]);
9641 inline tensor histogram_fixed_width(
const tensor& values,
const tensor& value_range,
const tensor& nbins,
9642 datatype dtype =
static_cast<datatype
>(3)) {
9644 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9645 TFE_NewOp(context::get_context(),
"HistogramFixedWidth", context::get_status()), &TFE_DeleteOp);
9646 status_check(context::get_status());
9650 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
9651 status_check(context::get_status());
9653 TFE_OpAddInput(op.get(), value_range.tfe_handle.get(), context::get_status());
9654 status_check(context::get_status());
9656 TFE_OpAddInput(op.get(), nbins.tfe_handle.get(), context::get_status());
9657 status_check(context::get_status());
9660 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
9663 int num_outputs_op = 1;
9664 TFE_TensorHandle* res[1] = {
nullptr};
9665 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9666 status_check(context::get_status());
9667 return tensor(res[0]);
9670 inline tensor histogram_summary(
const tensor& tag,
const tensor& values) {
9672 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9673 TFE_NewOp(context::get_context(),
"HistogramSummary", context::get_status()), &TFE_DeleteOp);
9674 status_check(context::get_status());
9678 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
9679 status_check(context::get_status());
9681 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
9682 status_check(context::get_status());
9687 int num_outputs_op = 1;
9688 TFE_TensorHandle* res[1] = {
nullptr};
9689 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9690 status_check(context::get_status());
9691 return tensor(res[0]);
9694 inline tensor i_f_f_t(
const tensor& input, datatype Tcomplex =
static_cast<datatype
>(8)) {
9696 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"IFFT", context::get_status()),
9698 status_check(context::get_status());
9702 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9703 status_check(context::get_status());
9706 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
9709 int num_outputs_op = 1;
9710 TFE_TensorHandle* res[1] = {
nullptr};
9711 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9712 status_check(context::get_status());
9713 return tensor(res[0]);
9716 inline tensor i_f_f_t2_d(
const tensor& input, datatype Tcomplex =
static_cast<datatype
>(8)) {
9718 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9719 TFE_NewOp(context::get_context(),
"IFFT2D", context::get_status()), &TFE_DeleteOp);
9720 status_check(context::get_status());
9724 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9725 status_check(context::get_status());
9728 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
9731 int num_outputs_op = 1;
9732 TFE_TensorHandle* res[1] = {
nullptr};
9733 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9734 status_check(context::get_status());
9735 return tensor(res[0]);
9738 inline tensor i_f_f_t3_d(
const tensor& input, datatype Tcomplex =
static_cast<datatype
>(8)) {
9740 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9741 TFE_NewOp(context::get_context(),
"IFFT3D", context::get_status()), &TFE_DeleteOp);
9742 status_check(context::get_status());
9746 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9747 status_check(context::get_status());
9750 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
9753 int num_outputs_op = 1;
9754 TFE_TensorHandle* res[1] = {
nullptr};
9755 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9756 status_check(context::get_status());
9757 return tensor(res[0]);
9760 inline tensor i_r_f_f_t(
const tensor& input,
const tensor& fft_length, datatype Treal =
static_cast<datatype
>(1),
9761 datatype Tcomplex =
static_cast<datatype
>(8)) {
9763 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"IRFFT", context::get_status()),
9765 status_check(context::get_status());
9769 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9770 status_check(context::get_status());
9772 TFE_OpAddInput(op.get(), fft_length.tfe_handle.get(), context::get_status());
9773 status_check(context::get_status());
9776 TFE_OpSetAttrType(op.get(),
"Treal", Treal);
9777 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
9780 int num_outputs_op = 1;
9781 TFE_TensorHandle* res[1] = {
nullptr};
9782 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9783 status_check(context::get_status());
9784 return tensor(res[0]);
9787 inline tensor i_r_f_f_t2_d(
const tensor& input,
const tensor& fft_length, datatype Treal =
static_cast<datatype
>(1),
9788 datatype Tcomplex =
static_cast<datatype
>(8)) {
9790 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9791 TFE_NewOp(context::get_context(),
"IRFFT2D", context::get_status()), &TFE_DeleteOp);
9792 status_check(context::get_status());
9796 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9797 status_check(context::get_status());
9799 TFE_OpAddInput(op.get(), fft_length.tfe_handle.get(), context::get_status());
9800 status_check(context::get_status());
9803 TFE_OpSetAttrType(op.get(),
"Treal", Treal);
9804 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
9807 int num_outputs_op = 1;
9808 TFE_TensorHandle* res[1] = {
nullptr};
9809 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9810 status_check(context::get_status());
9811 return tensor(res[0]);
9814 inline tensor i_r_f_f_t3_d(
const tensor& input,
const tensor& fft_length, datatype Treal =
static_cast<datatype
>(1),
9815 datatype Tcomplex =
static_cast<datatype
>(8)) {
9817 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9818 TFE_NewOp(context::get_context(),
"IRFFT3D", context::get_status()), &TFE_DeleteOp);
9819 status_check(context::get_status());
9823 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9824 status_check(context::get_status());
9826 TFE_OpAddInput(op.get(), fft_length.tfe_handle.get(), context::get_status());
9827 status_check(context::get_status());
9830 TFE_OpSetAttrType(op.get(),
"Treal", Treal);
9831 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
9834 int num_outputs_op = 1;
9835 TFE_TensorHandle* res[1] = {
nullptr};
9836 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9837 status_check(context::get_status());
9838 return tensor(res[0]);
9841 inline tensor identity(
const tensor& input) {
9843 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9844 TFE_NewOp(context::get_context(),
"Identity", context::get_status()), &TFE_DeleteOp);
9845 status_check(context::get_status());
9849 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
9850 status_check(context::get_status());
9855 int num_outputs_op = 1;
9856 TFE_TensorHandle* res[1] = {
nullptr};
9857 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9858 status_check(context::get_status());
9859 return tensor(res[0]);
9862 inline tensor identity_n(
const std::vector<tensor>& input) {
9864 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9865 TFE_NewOp(context::get_context(),
"IdentityN", context::get_status()), &TFE_DeleteOp);
9866 status_check(context::get_status());
9870 std::vector<TFE_TensorHandle*> input_handles;
9871 input_handles.reserve(input.size());
9872 std::transform(input.begin(), input.end(), std::back_inserter(input_handles),
9873 [](
const auto& t) { return t.tfe_handle.get(); });
9874 TFE_OpAddInputList(op.get(), input_handles.data(),
static_cast<int>(input.size()), context::get_status());
9875 status_check(context::get_status());
9880 int num_outputs_op = 1;
9881 TFE_TensorHandle* res[1] = {
nullptr};
9882 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9883 status_check(context::get_status());
9884 return tensor(res[0]);
9887 inline tensor identity_reader(
const std::string& container =
"",
const std::string& shared_name =
"") {
9889 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9890 TFE_NewOp(context::get_context(),
"IdentityReader", context::get_status()), &TFE_DeleteOp);
9891 status_check(context::get_status());
9896 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
9897 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
9900 int num_outputs_op = 1;
9901 TFE_TensorHandle* res[1] = {
nullptr};
9902 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9903 status_check(context::get_status());
9904 return tensor(res[0]);
9907 inline tensor identity_reader_v2(
const std::string& container =
"",
const std::string& shared_name =
"") {
9909 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9910 TFE_NewOp(context::get_context(),
"IdentityReaderV2", context::get_status()), &TFE_DeleteOp);
9911 status_check(context::get_status());
9916 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
9917 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
9920 int num_outputs_op = 1;
9921 TFE_TensorHandle* res[1] = {
nullptr};
9922 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9923 status_check(context::get_status());
9924 return tensor(res[0]);
9927 inline tensor igamma(
const tensor& a,
const tensor& x) {
9929 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9930 TFE_NewOp(context::get_context(),
"Igamma", context::get_status()), &TFE_DeleteOp);
9931 status_check(context::get_status());
9935 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
9936 status_check(context::get_status());
9938 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9939 status_check(context::get_status());
9944 int num_outputs_op = 1;
9945 TFE_TensorHandle* res[1] = {
nullptr};
9946 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9947 status_check(context::get_status());
9948 return tensor(res[0]);
9951 inline tensor igamma_grad_a(
const tensor& a,
const tensor& x) {
9953 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9954 TFE_NewOp(context::get_context(),
"IgammaGradA", context::get_status()), &TFE_DeleteOp);
9955 status_check(context::get_status());
9959 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
9960 status_check(context::get_status());
9962 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9963 status_check(context::get_status());
9968 int num_outputs_op = 1;
9969 TFE_TensorHandle* res[1] = {
nullptr};
9970 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9971 status_check(context::get_status());
9972 return tensor(res[0]);
9975 inline tensor igammac(
const tensor& a,
const tensor& x) {
9977 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
9978 TFE_NewOp(context::get_context(),
"Igammac", context::get_status()), &TFE_DeleteOp);
9979 status_check(context::get_status());
9983 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
9984 status_check(context::get_status());
9986 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
9987 status_check(context::get_status());
9992 int num_outputs_op = 1;
9993 TFE_TensorHandle* res[1] = {
nullptr};
9994 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
9995 status_check(context::get_status());
9996 return tensor(res[0]);
9999 inline tensor ignore_errors_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
10000 const std::vector<std::vector<int64_t>>& output_shapes) {
10002 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10003 TFE_NewOp(context::get_context(),
"IgnoreErrorsDataset", context::get_status()), &TFE_DeleteOp);
10004 status_check(context::get_status());
10008 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
10009 status_check(context::get_status());
10012 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10013 static_cast<int>(output_types.size()));
10015 std::vector<const int64_t*> output_shapes_values;
10016 output_shapes_values.reserve(output_shapes.size());
10017 std::vector<int> output_shapes_ndims;
10018 output_shapes_ndims.reserve(output_shapes.size());
10019 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10020 [](
const auto& v) { return v.data(); });
10021 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10022 [](
const auto& v) { return static_cast<int>(v.size()); });
10023 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10024 static_cast<int>(output_shapes.size()), context::get_status());
10025 status_check(context::get_status());
10028 int num_outputs_op = 1;
10029 TFE_TensorHandle* res[1] = {
nullptr};
10030 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10031 status_check(context::get_status());
10032 return tensor(res[0]);
10035 inline tensor imag(
const tensor& input, datatype Tout =
static_cast<datatype
>(1)) {
10037 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Imag", context::get_status()),
10039 status_check(context::get_status());
10043 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
10044 status_check(context::get_status());
10047 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
10050 int num_outputs_op = 1;
10051 TFE_TensorHandle* res[1] = {
nullptr};
10052 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10053 status_check(context::get_status());
10054 return tensor(res[0]);
10057 inline tensor image_projective_transform_v2(
const tensor& images,
const tensor& transforms,
const tensor& output_shape,
10058 datatype dtype,
const std::string& interpolation,
10059 const std::string& fill_mode =
"CONSTANT") {
10061 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10062 TFE_NewOp(context::get_context(),
"ImageProjectiveTransformV2", context::get_status()), &TFE_DeleteOp);
10063 status_check(context::get_status());
10067 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
10068 status_check(context::get_status());
10070 TFE_OpAddInput(op.get(), transforms.tfe_handle.get(), context::get_status());
10071 status_check(context::get_status());
10073 TFE_OpAddInput(op.get(), output_shape.tfe_handle.get(), context::get_status());
10074 status_check(context::get_status());
10077 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
10078 TFE_OpSetAttrString(op.get(),
"interpolation", (
void*)interpolation.c_str(), interpolation.size());
10079 TFE_OpSetAttrString(op.get(),
"fill_mode", (
void*)fill_mode.c_str(), fill_mode.size());
10082 int num_outputs_op = 1;
10083 TFE_TensorHandle* res[1] = {
nullptr};
10084 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10085 status_check(context::get_status());
10086 return tensor(res[0]);
10089 inline tensor image_summary(
const tensor& tag,
const tensor& input_tensor,
const tensor& bad_color,
10090 int64_t max_images = 3) {
10092 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10093 TFE_NewOp(context::get_context(),
"ImageSummary", context::get_status()), &TFE_DeleteOp);
10094 status_check(context::get_status());
10098 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
10099 status_check(context::get_status());
10101 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
10102 status_check(context::get_status());
10106 TFE_OpSetAttrTensor(op.get(),
"bad_color", bad_color.get_tensor().get(), context::get_status());
10107 status_check(context::get_status());
10109 TFE_OpSetAttrInt(op.get(),
"max_images", max_images);
10112 int num_outputs_op = 1;
10113 TFE_TensorHandle* res[1] = {
nullptr};
10114 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10115 status_check(context::get_status());
10116 return tensor(res[0]);
10119 inline tensor immutable_const_tensor(datatype dtype,
const std::vector<int64_t>& shape,
10120 const std::string& memory_region_name) {
10122 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10123 TFE_NewOp(context::get_context(),
"ImmutableConst", context::get_status()), &TFE_DeleteOp);
10124 status_check(context::get_status());
10129 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
10131 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
10132 status_check(context::get_status());
10134 TFE_OpSetAttrString(op.get(),
"memory_region_name", (
void*)memory_region_name.c_str(), memory_region_name.size());
10137 int num_outputs_op = 1;
10138 TFE_TensorHandle* res[1] = {
nullptr};
10139 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10140 status_check(context::get_status());
10141 return tensor(res[0]);
10144 inline tensor in_top_k(
const tensor& predictions,
const tensor& targets, int64_t k) {
10146 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10147 TFE_NewOp(context::get_context(),
"InTopK", context::get_status()), &TFE_DeleteOp);
10148 status_check(context::get_status());
10152 TFE_OpAddInput(op.get(), predictions.tfe_handle.get(), context::get_status());
10153 status_check(context::get_status());
10155 TFE_OpAddInput(op.get(), targets.tfe_handle.get(), context::get_status());
10156 status_check(context::get_status());
10159 TFE_OpSetAttrInt(op.get(),
"k", k);
10162 int num_outputs_op = 1;
10163 TFE_TensorHandle* res[1] = {
nullptr};
10164 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10165 status_check(context::get_status());
10166 return tensor(res[0]);
10169 inline tensor in_top_k_v2(
const tensor& predictions,
const tensor& targets,
const tensor& k) {
10171 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10172 TFE_NewOp(context::get_context(),
"InTopKV2", context::get_status()), &TFE_DeleteOp);
10173 status_check(context::get_status());
10177 TFE_OpAddInput(op.get(), predictions.tfe_handle.get(), context::get_status());
10178 status_check(context::get_status());
10180 TFE_OpAddInput(op.get(), targets.tfe_handle.get(), context::get_status());
10181 status_check(context::get_status());
10183 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
10184 status_check(context::get_status());
10189 int num_outputs_op = 1;
10190 TFE_TensorHandle* res[1] = {
nullptr};
10191 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10192 status_check(context::get_status());
10193 return tensor(res[0]);
10196 inline tensor infeed_dequeue(datatype dtype,
const std::vector<int64_t>& shape) {
10198 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10199 TFE_NewOp(context::get_context(),
"InfeedDequeue", context::get_status()), &TFE_DeleteOp);
10200 status_check(context::get_status());
10205 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
10207 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
10208 status_check(context::get_status());
10211 int num_outputs_op = 1;
10212 TFE_TensorHandle* res[1] = {
nullptr};
10213 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10214 status_check(context::get_status());
10215 return tensor(res[0]);
10218 inline tensor infeed_dequeue_tuple(
const std::vector<datatype>& dtypes,
10219 const std::vector<std::vector<int64_t>>& shapes) {
10221 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10222 TFE_NewOp(context::get_context(),
"InfeedDequeueTuple", context::get_status()), &TFE_DeleteOp);
10223 status_check(context::get_status());
10228 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
10229 static_cast<int>(dtypes.size()));
10231 std::vector<const int64_t*> shapes_values;
10232 shapes_values.reserve(shapes.size());
10233 std::vector<int> shapes_ndims;
10234 shapes_ndims.reserve(shapes.size());
10235 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
10236 [](
const auto& v) { return v.data(); });
10237 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
10238 [](
const auto& v) { return static_cast<int>(v.size()); });
10239 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
10240 context::get_status());
10241 status_check(context::get_status());
10244 int num_outputs_op = 1;
10245 TFE_TensorHandle* res[1] = {
nullptr};
10246 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10247 status_check(context::get_status());
10248 return tensor(res[0]);
10251 inline tensor inplace_add(
const tensor& x,
const tensor& i,
const tensor& v) {
10253 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10254 TFE_NewOp(context::get_context(),
"InplaceAdd", context::get_status()), &TFE_DeleteOp);
10255 status_check(context::get_status());
10259 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10260 status_check(context::get_status());
10262 TFE_OpAddInput(op.get(), i.tfe_handle.get(), context::get_status());
10263 status_check(context::get_status());
10265 TFE_OpAddInput(op.get(), v.tfe_handle.get(), context::get_status());
10266 status_check(context::get_status());
10271 int num_outputs_op = 1;
10272 TFE_TensorHandle* res[1] = {
nullptr};
10273 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10274 status_check(context::get_status());
10275 return tensor(res[0]);
10278 inline tensor inplace_sub(
const tensor& x,
const tensor& i,
const tensor& v) {
10280 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10281 TFE_NewOp(context::get_context(),
"InplaceSub", context::get_status()), &TFE_DeleteOp);
10282 status_check(context::get_status());
10286 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10287 status_check(context::get_status());
10289 TFE_OpAddInput(op.get(), i.tfe_handle.get(), context::get_status());
10290 status_check(context::get_status());
10292 TFE_OpAddInput(op.get(), v.tfe_handle.get(), context::get_status());
10293 status_check(context::get_status());
10298 int num_outputs_op = 1;
10299 TFE_TensorHandle* res[1] = {
nullptr};
10300 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10301 status_check(context::get_status());
10302 return tensor(res[0]);
10305 inline tensor inplace_update(
const tensor& x,
const tensor& i,
const tensor& v) {
10307 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10308 TFE_NewOp(context::get_context(),
"InplaceUpdate", context::get_status()), &TFE_DeleteOp);
10309 status_check(context::get_status());
10313 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10314 status_check(context::get_status());
10316 TFE_OpAddInput(op.get(), i.tfe_handle.get(), context::get_status());
10317 status_check(context::get_status());
10319 TFE_OpAddInput(op.get(), v.tfe_handle.get(), context::get_status());
10320 status_check(context::get_status());
10325 int num_outputs_op = 1;
10326 TFE_TensorHandle* res[1] = {
nullptr};
10327 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10328 status_check(context::get_status());
10329 return tensor(res[0]);
10332 inline tensor inv(
const tensor& x) {
10334 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Inv", context::get_status()),
10336 status_check(context::get_status());
10340 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10341 status_check(context::get_status());
10346 int num_outputs_op = 1;
10347 TFE_TensorHandle* res[1] = {
nullptr};
10348 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10349 status_check(context::get_status());
10350 return tensor(res[0]);
10353 inline tensor inv_grad(
const tensor& y,
const tensor& dy) {
10355 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10356 TFE_NewOp(context::get_context(),
"InvGrad", context::get_status()), &TFE_DeleteOp);
10357 status_check(context::get_status());
10361 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
10362 status_check(context::get_status());
10364 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
10365 status_check(context::get_status());
10370 int num_outputs_op = 1;
10371 TFE_TensorHandle* res[1] = {
nullptr};
10372 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10373 status_check(context::get_status());
10374 return tensor(res[0]);
10377 inline tensor invert(
const tensor& x) {
10379 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10380 TFE_NewOp(context::get_context(),
"Invert", context::get_status()), &TFE_DeleteOp);
10381 status_check(context::get_status());
10385 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10386 status_check(context::get_status());
10391 int num_outputs_op = 1;
10392 TFE_TensorHandle* res[1] = {
nullptr};
10393 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10394 status_check(context::get_status());
10395 return tensor(res[0]);
10398 inline tensor invert_permutation(
const tensor& x) {
10400 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10401 TFE_NewOp(context::get_context(),
"InvertPermutation", context::get_status()), &TFE_DeleteOp);
10402 status_check(context::get_status());
10406 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10407 status_check(context::get_status());
10412 int num_outputs_op = 1;
10413 TFE_TensorHandle* res[1] = {
nullptr};
10414 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10415 status_check(context::get_status());
10416 return tensor(res[0]);
10419 inline tensor is_boosted_trees_ensemble_initialized(
const tensor& tree_ensemble_handle) {
10421 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10422 TFE_NewOp(context::get_context(),
"IsBoostedTreesEnsembleInitialized", context::get_status()), &TFE_DeleteOp);
10423 status_check(context::get_status());
10427 TFE_OpAddInput(op.get(), tree_ensemble_handle.tfe_handle.get(), context::get_status());
10428 status_check(context::get_status());
10433 int num_outputs_op = 1;
10434 TFE_TensorHandle* res[1] = {
nullptr};
10435 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10436 status_check(context::get_status());
10437 return tensor(res[0]);
10440 inline tensor is_boosted_trees_quantile_stream_resource_initialized(
const tensor& quantile_stream_resource_handle) {
10442 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10443 TFE_NewOp(context::get_context(),
"IsBoostedTreesQuantileStreamResourceInitialized", context::get_status()),
10445 status_check(context::get_status());
10449 TFE_OpAddInput(op.get(), quantile_stream_resource_handle.tfe_handle.get(), context::get_status());
10450 status_check(context::get_status());
10455 int num_outputs_op = 1;
10456 TFE_TensorHandle* res[1] = {
nullptr};
10457 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10458 status_check(context::get_status());
10459 return tensor(res[0]);
10462 inline tensor is_finite(
const tensor& x) {
10464 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10465 TFE_NewOp(context::get_context(),
"IsFinite", context::get_status()), &TFE_DeleteOp);
10466 status_check(context::get_status());
10470 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10471 status_check(context::get_status());
10476 int num_outputs_op = 1;
10477 TFE_TensorHandle* res[1] = {
nullptr};
10478 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10479 status_check(context::get_status());
10480 return tensor(res[0]);
10483 inline tensor is_inf(
const tensor& x) {
10485 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"IsInf", context::get_status()),
10487 status_check(context::get_status());
10491 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10492 status_check(context::get_status());
10497 int num_outputs_op = 1;
10498 TFE_TensorHandle* res[1] = {
nullptr};
10499 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10500 status_check(context::get_status());
10501 return tensor(res[0]);
10504 inline tensor is_nan(
const tensor& x) {
10506 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"IsNan", context::get_status()),
10508 status_check(context::get_status());
10512 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
10513 status_check(context::get_status());
10518 int num_outputs_op = 1;
10519 TFE_TensorHandle* res[1] = {
nullptr};
10520 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10521 status_check(context::get_status());
10522 return tensor(res[0]);
10525 inline tensor is_variable_initialized(
const tensor& ref, datatype dtype) {
10527 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10528 TFE_NewOp(context::get_context(),
"IsVariableInitialized", context::get_status()), &TFE_DeleteOp);
10529 status_check(context::get_status());
10533 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
10534 status_check(context::get_status());
10537 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
10540 int num_outputs_op = 1;
10541 TFE_TensorHandle* res[1] = {
nullptr};
10542 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10543 status_check(context::get_status());
10544 return tensor(res[0]);
10547 inline tensor iterator(
const std::string& shared_name,
const std::string& container,
10548 const std::vector<datatype>& output_types,
10549 const std::vector<std::vector<int64_t>>& output_shapes) {
10551 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10552 TFE_NewOp(context::get_context(),
"Iterator", context::get_status()), &TFE_DeleteOp);
10553 status_check(context::get_status());
10558 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
10559 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
10560 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10561 static_cast<int>(output_types.size()));
10563 std::vector<const int64_t*> output_shapes_values;
10564 output_shapes_values.reserve(output_shapes.size());
10565 std::vector<int> output_shapes_ndims;
10566 output_shapes_ndims.reserve(output_shapes.size());
10567 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10568 [](
const auto& v) { return v.data(); });
10569 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10570 [](
const auto& v) { return static_cast<int>(v.size()); });
10571 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10572 static_cast<int>(output_shapes.size()), context::get_status());
10573 status_check(context::get_status());
10576 int num_outputs_op = 1;
10577 TFE_TensorHandle* res[1] = {
nullptr};
10578 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10579 status_check(context::get_status());
10580 return tensor(res[0]);
10583 inline tensor iterator_from_string_handle(
const tensor& string_handle,
const std::vector<datatype>& output_types,
10584 const std::vector<std::vector<int64_t>>& output_shapes) {
10586 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10587 TFE_NewOp(context::get_context(),
"IteratorFromStringHandle", context::get_status()), &TFE_DeleteOp);
10588 status_check(context::get_status());
10592 TFE_OpAddInput(op.get(), string_handle.tfe_handle.get(), context::get_status());
10593 status_check(context::get_status());
10596 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10597 static_cast<int>(output_types.size()));
10599 std::vector<const int64_t*> output_shapes_values;
10600 output_shapes_values.reserve(output_shapes.size());
10601 std::vector<int> output_shapes_ndims;
10602 output_shapes_ndims.reserve(output_shapes.size());
10603 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10604 [](
const auto& v) { return v.data(); });
10605 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10606 [](
const auto& v) { return static_cast<int>(v.size()); });
10607 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10608 static_cast<int>(output_shapes.size()), context::get_status());
10609 status_check(context::get_status());
10612 int num_outputs_op = 1;
10613 TFE_TensorHandle* res[1] = {
nullptr};
10614 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10615 status_check(context::get_status());
10616 return tensor(res[0]);
10619 inline tensor iterator_from_string_handle_v2(
const tensor& string_handle,
const std::vector<datatype>& output_types,
10620 const std::vector<std::vector<int64_t>>& output_shapes) {
10622 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10623 TFE_NewOp(context::get_context(),
"IteratorFromStringHandleV2", context::get_status()), &TFE_DeleteOp);
10624 status_check(context::get_status());
10628 TFE_OpAddInput(op.get(), string_handle.tfe_handle.get(), context::get_status());
10629 status_check(context::get_status());
10632 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10633 static_cast<int>(output_types.size()));
10635 std::vector<const int64_t*> output_shapes_values;
10636 output_shapes_values.reserve(output_shapes.size());
10637 std::vector<int> output_shapes_ndims;
10638 output_shapes_ndims.reserve(output_shapes.size());
10639 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10640 [](
const auto& v) { return v.data(); });
10641 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10642 [](
const auto& v) { return static_cast<int>(v.size()); });
10643 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10644 static_cast<int>(output_shapes.size()), context::get_status());
10645 status_check(context::get_status());
10648 int num_outputs_op = 1;
10649 TFE_TensorHandle* res[1] = {
nullptr};
10650 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10651 status_check(context::get_status());
10652 return tensor(res[0]);
10655 inline tensor iterator_get_device(
const tensor& resource) {
10657 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10658 TFE_NewOp(context::get_context(),
"IteratorGetDevice", context::get_status()), &TFE_DeleteOp);
10659 status_check(context::get_status());
10663 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
10664 status_check(context::get_status());
10669 int num_outputs_op = 1;
10670 TFE_TensorHandle* res[1] = {
nullptr};
10671 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10672 status_check(context::get_status());
10673 return tensor(res[0]);
10676 inline tensor iterator_get_next(
const tensor& iterator,
const std::vector<datatype>& output_types,
10677 const std::vector<std::vector<int64_t>>& output_shapes) {
10679 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10680 TFE_NewOp(context::get_context(),
"IteratorGetNext", context::get_status()), &TFE_DeleteOp);
10681 status_check(context::get_status());
10685 TFE_OpAddInput(op.get(), iterator.tfe_handle.get(), context::get_status());
10686 status_check(context::get_status());
10689 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10690 static_cast<int>(output_types.size()));
10692 std::vector<const int64_t*> output_shapes_values;
10693 output_shapes_values.reserve(output_shapes.size());
10694 std::vector<int> output_shapes_ndims;
10695 output_shapes_ndims.reserve(output_shapes.size());
10696 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10697 [](
const auto& v) { return v.data(); });
10698 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10699 [](
const auto& v) { return static_cast<int>(v.size()); });
10700 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10701 static_cast<int>(output_shapes.size()), context::get_status());
10702 status_check(context::get_status());
10705 int num_outputs_op = 1;
10706 TFE_TensorHandle* res[1] = {
nullptr};
10707 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10708 status_check(context::get_status());
10709 return tensor(res[0]);
10712 inline tensor iterator_get_next_as_optional(
const tensor& iterator,
const std::vector<datatype>& output_types,
10713 const std::vector<std::vector<int64_t>>& output_shapes) {
10715 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10716 TFE_NewOp(context::get_context(),
"IteratorGetNextAsOptional", context::get_status()), &TFE_DeleteOp);
10717 status_check(context::get_status());
10721 TFE_OpAddInput(op.get(), iterator.tfe_handle.get(), context::get_status());
10722 status_check(context::get_status());
10725 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10726 static_cast<int>(output_types.size()));
10728 std::vector<const int64_t*> output_shapes_values;
10729 output_shapes_values.reserve(output_shapes.size());
10730 std::vector<int> output_shapes_ndims;
10731 output_shapes_ndims.reserve(output_shapes.size());
10732 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10733 [](
const auto& v) { return v.data(); });
10734 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10735 [](
const auto& v) { return static_cast<int>(v.size()); });
10736 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10737 static_cast<int>(output_shapes.size()), context::get_status());
10738 status_check(context::get_status());
10741 int num_outputs_op = 1;
10742 TFE_TensorHandle* res[1] = {
nullptr};
10743 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10744 status_check(context::get_status());
10745 return tensor(res[0]);
10748 inline tensor iterator_get_next_sync(
const tensor& iterator,
const std::vector<datatype>& output_types,
10749 const std::vector<std::vector<int64_t>>& output_shapes) {
10751 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10752 TFE_NewOp(context::get_context(),
"IteratorGetNextSync", context::get_status()), &TFE_DeleteOp);
10753 status_check(context::get_status());
10757 TFE_OpAddInput(op.get(), iterator.tfe_handle.get(), context::get_status());
10758 status_check(context::get_status());
10761 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10762 static_cast<int>(output_types.size()));
10764 std::vector<const int64_t*> output_shapes_values;
10765 output_shapes_values.reserve(output_shapes.size());
10766 std::vector<int> output_shapes_ndims;
10767 output_shapes_ndims.reserve(output_shapes.size());
10768 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10769 [](
const auto& v) { return v.data(); });
10770 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10771 [](
const auto& v) { return static_cast<int>(v.size()); });
10772 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10773 static_cast<int>(output_shapes.size()), context::get_status());
10774 status_check(context::get_status());
10777 int num_outputs_op = 1;
10778 TFE_TensorHandle* res[1] = {
nullptr};
10779 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10780 status_check(context::get_status());
10781 return tensor(res[0]);
10784 inline tensor iterator_to_string_handle(
const tensor& resource_handle) {
10786 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10787 TFE_NewOp(context::get_context(),
"IteratorToStringHandle", context::get_status()), &TFE_DeleteOp);
10788 status_check(context::get_status());
10792 TFE_OpAddInput(op.get(), resource_handle.tfe_handle.get(), context::get_status());
10793 status_check(context::get_status());
10798 int num_outputs_op = 1;
10799 TFE_TensorHandle* res[1] = {
nullptr};
10800 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10801 status_check(context::get_status());
10802 return tensor(res[0]);
10805 inline tensor iterator_v2(
const std::string& shared_name,
const std::string& container,
10806 const std::vector<datatype>& output_types,
10807 const std::vector<std::vector<int64_t>>& output_shapes) {
10809 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10810 TFE_NewOp(context::get_context(),
"IteratorV2", context::get_status()), &TFE_DeleteOp);
10811 status_check(context::get_status());
10816 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
10817 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
10818 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10819 static_cast<int>(output_types.size()));
10821 std::vector<const int64_t*> output_shapes_values;
10822 output_shapes_values.reserve(output_shapes.size());
10823 std::vector<int> output_shapes_ndims;
10824 output_shapes_ndims.reserve(output_shapes.size());
10825 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10826 [](
const auto& v) { return v.data(); });
10827 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10828 [](
const auto& v) { return static_cast<int>(v.size()); });
10829 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10830 static_cast<int>(output_shapes.size()), context::get_status());
10831 status_check(context::get_status());
10834 int num_outputs_op = 1;
10835 TFE_TensorHandle* res[1] = {
nullptr};
10836 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10837 status_check(context::get_status());
10838 return tensor(res[0]);
10841 inline tensor l2_loss(
const tensor& t) {
10843 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10844 TFE_NewOp(context::get_context(),
"L2Loss", context::get_status()), &TFE_DeleteOp);
10845 status_check(context::get_status());
10849 TFE_OpAddInput(op.get(), t.tfe_handle.get(), context::get_status());
10850 status_check(context::get_status());
10855 int num_outputs_op = 1;
10856 TFE_TensorHandle* res[1] = {
nullptr};
10857 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10858 status_check(context::get_status());
10859 return tensor(res[0]);
10862 inline tensor l_m_d_b_dataset(
const tensor& filenames,
const std::vector<datatype>& output_types,
10863 const std::vector<std::vector<int64_t>>& output_shapes) {
10865 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10866 TFE_NewOp(context::get_context(),
"LMDBDataset", context::get_status()), &TFE_DeleteOp);
10867 status_check(context::get_status());
10871 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
10872 status_check(context::get_status());
10875 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10876 static_cast<int>(output_types.size()));
10878 std::vector<const int64_t*> output_shapes_values;
10879 output_shapes_values.reserve(output_shapes.size());
10880 std::vector<int> output_shapes_ndims;
10881 output_shapes_ndims.reserve(output_shapes.size());
10882 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
10883 [](
const auto& v) { return v.data(); });
10884 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
10885 [](
const auto& v) { return static_cast<int>(v.size()); });
10886 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
10887 static_cast<int>(output_shapes.size()), context::get_status());
10888 status_check(context::get_status());
10891 int num_outputs_op = 1;
10892 TFE_TensorHandle* res[1] = {
nullptr};
10893 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10894 status_check(context::get_status());
10895 return tensor(res[0]);
10898 inline tensor l_m_d_b_reader(
const std::string& container =
"",
const std::string& shared_name =
"") {
10900 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10901 TFE_NewOp(context::get_context(),
"LMDBReader", context::get_status()), &TFE_DeleteOp);
10902 status_check(context::get_status());
10907 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
10908 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
10911 int num_outputs_op = 1;
10912 TFE_TensorHandle* res[1] = {
nullptr};
10913 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10914 status_check(context::get_status());
10915 return tensor(res[0]);
10918 inline tensor l_r_n(
const tensor& input, int64_t depth_radius = 5,
float bias = 1.0000e+00,
float alpha = 1.0000e+00,
10919 float beta = 5.0000e-01) {
10921 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"LRN", context::get_status()),
10923 status_check(context::get_status());
10927 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
10928 status_check(context::get_status());
10931 TFE_OpSetAttrInt(op.get(),
"depth_radius", depth_radius);
10932 TFE_OpSetAttrFloat(op.get(),
"bias", bias);
10933 TFE_OpSetAttrFloat(op.get(),
"alpha", alpha);
10934 TFE_OpSetAttrFloat(op.get(),
"beta", beta);
10937 int num_outputs_op = 1;
10938 TFE_TensorHandle* res[1] = {
nullptr};
10939 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10940 status_check(context::get_status());
10941 return tensor(res[0]);
10944 inline tensor l_r_n_grad(
const tensor& input_grads,
const tensor& input_image,
const tensor& output_image,
10945 int64_t depth_radius = 5,
float bias = 1.0000e+00,
float alpha = 1.0000e+00,
10946 float beta = 5.0000e-01) {
10948 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10949 TFE_NewOp(context::get_context(),
"LRNGrad", context::get_status()), &TFE_DeleteOp);
10950 status_check(context::get_status());
10954 TFE_OpAddInput(op.get(), input_grads.tfe_handle.get(), context::get_status());
10955 status_check(context::get_status());
10957 TFE_OpAddInput(op.get(), input_image.tfe_handle.get(), context::get_status());
10958 status_check(context::get_status());
10960 TFE_OpAddInput(op.get(), output_image.tfe_handle.get(), context::get_status());
10961 status_check(context::get_status());
10964 TFE_OpSetAttrInt(op.get(),
"depth_radius", depth_radius);
10965 TFE_OpSetAttrFloat(op.get(),
"bias", bias);
10966 TFE_OpSetAttrFloat(op.get(),
"alpha", alpha);
10967 TFE_OpSetAttrFloat(op.get(),
"beta", beta);
10970 int num_outputs_op = 1;
10971 TFE_TensorHandle* res[1] = {
nullptr};
10972 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
10973 status_check(context::get_status());
10974 return tensor(res[0]);
10977 inline tensor latency_stats_dataset(
const tensor& input_dataset,
const tensor& tag,
10978 const std::vector<datatype>& output_types,
10979 const std::vector<std::vector<int64_t>>& output_shapes) {
10981 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
10982 TFE_NewOp(context::get_context(),
"LatencyStatsDataset", context::get_status()), &TFE_DeleteOp);
10983 status_check(context::get_status());
10987 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
10988 status_check(context::get_status());
10990 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
10991 status_check(context::get_status());
10994 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
10995 static_cast<int>(output_types.size()));
10997 std::vector<const int64_t*> output_shapes_values;
10998 output_shapes_values.reserve(output_shapes.size());
10999 std::vector<int> output_shapes_ndims;
11000 output_shapes_ndims.reserve(output_shapes.size());
11001 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
11002 [](
const auto& v) { return v.data(); });
11003 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
11004 [](
const auto& v) { return static_cast<int>(v.size()); });
11005 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
11006 static_cast<int>(output_shapes.size()), context::get_status());
11007 status_check(context::get_status());
11010 int num_outputs_op = 1;
11011 TFE_TensorHandle* res[1] = {
nullptr};
11012 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11013 status_check(context::get_status());
11014 return tensor(res[0]);
11017 inline tensor leaky_relu(
const tensor& features,
float alpha = 2.0000e-01) {
11019 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11020 TFE_NewOp(context::get_context(),
"LeakyRelu", context::get_status()), &TFE_DeleteOp);
11021 status_check(context::get_status());
11025 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
11026 status_check(context::get_status());
11029 TFE_OpSetAttrFloat(op.get(),
"alpha", alpha);
11032 int num_outputs_op = 1;
11033 TFE_TensorHandle* res[1] = {
nullptr};
11034 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11035 status_check(context::get_status());
11036 return tensor(res[0]);
11039 inline tensor leaky_relu_grad(
const tensor& gradients,
const tensor& features,
float alpha = 2.0000e-01) {
11041 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11042 TFE_NewOp(context::get_context(),
"LeakyReluGrad", context::get_status()), &TFE_DeleteOp);
11043 status_check(context::get_status());
11047 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
11048 status_check(context::get_status());
11050 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
11051 status_check(context::get_status());
11054 TFE_OpSetAttrFloat(op.get(),
"alpha", alpha);
11057 int num_outputs_op = 1;
11058 TFE_TensorHandle* res[1] = {
nullptr};
11059 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11060 status_check(context::get_status());
11061 return tensor(res[0]);
11064 inline tensor left_shift(
const tensor& x,
const tensor& y) {
11066 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11067 TFE_NewOp(context::get_context(),
"LeftShift", context::get_status()), &TFE_DeleteOp);
11068 status_check(context::get_status());
11072 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11073 status_check(context::get_status());
11075 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
11076 status_check(context::get_status());
11081 int num_outputs_op = 1;
11082 TFE_TensorHandle* res[1] = {
nullptr};
11083 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11084 status_check(context::get_status());
11085 return tensor(res[0]);
11088 inline tensor less(
const tensor& x,
const tensor& y) {
11090 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Less", context::get_status()),
11092 status_check(context::get_status());
11096 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11097 status_check(context::get_status());
11099 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
11100 status_check(context::get_status());
11105 int num_outputs_op = 1;
11106 TFE_TensorHandle* res[1] = {
nullptr};
11107 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11108 status_check(context::get_status());
11109 return tensor(res[0]);
11112 inline tensor less_equal(
const tensor& x,
const tensor& y) {
11114 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11115 TFE_NewOp(context::get_context(),
"LessEqual", context::get_status()), &TFE_DeleteOp);
11116 status_check(context::get_status());
11120 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11121 status_check(context::get_status());
11123 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
11124 status_check(context::get_status());
11129 int num_outputs_op = 1;
11130 TFE_TensorHandle* res[1] = {
nullptr};
11131 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11132 status_check(context::get_status());
11133 return tensor(res[0]);
11136 inline tensor lgamma(
const tensor& x) {
11138 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11139 TFE_NewOp(context::get_context(),
"Lgamma", context::get_status()), &TFE_DeleteOp);
11140 status_check(context::get_status());
11144 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11145 status_check(context::get_status());
11150 int num_outputs_op = 1;
11151 TFE_TensorHandle* res[1] = {
nullptr};
11152 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11153 status_check(context::get_status());
11154 return tensor(res[0]);
11157 inline tensor lin_space(
const tensor& start,
const tensor& stop,
const tensor& num,
11158 datatype Tidx =
static_cast<datatype
>(3)) {
11160 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11161 TFE_NewOp(context::get_context(),
"LinSpace", context::get_status()), &TFE_DeleteOp);
11162 status_check(context::get_status());
11166 TFE_OpAddInput(op.get(), start.tfe_handle.get(), context::get_status());
11167 status_check(context::get_status());
11169 TFE_OpAddInput(op.get(), stop.tfe_handle.get(), context::get_status());
11170 status_check(context::get_status());
11172 TFE_OpAddInput(op.get(), num.tfe_handle.get(), context::get_status());
11173 status_check(context::get_status());
11176 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
11179 int num_outputs_op = 1;
11180 TFE_TensorHandle* res[1] = {
nullptr};
11181 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11182 status_check(context::get_status());
11183 return tensor(res[0]);
11186 inline tensor load_and_remap_matrix(
const tensor& ckpt_path,
const tensor& old_input_tensor_name,
11187 const tensor& row_remapping,
const tensor& col_remapping,
11188 const tensor& initializing_values, int64_t num_rows, int64_t num_cols,
11189 int64_t max_rows_in_memory = -1) {
11191 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11192 TFE_NewOp(context::get_context(),
"LoadAndRemapMatrix", context::get_status()), &TFE_DeleteOp);
11193 status_check(context::get_status());
11197 TFE_OpAddInput(op.get(), ckpt_path.tfe_handle.get(), context::get_status());
11198 status_check(context::get_status());
11200 TFE_OpAddInput(op.get(), old_input_tensor_name.tfe_handle.get(), context::get_status());
11201 status_check(context::get_status());
11203 TFE_OpAddInput(op.get(), row_remapping.tfe_handle.get(), context::get_status());
11204 status_check(context::get_status());
11206 TFE_OpAddInput(op.get(), col_remapping.tfe_handle.get(), context::get_status());
11207 status_check(context::get_status());
11209 TFE_OpAddInput(op.get(), initializing_values.tfe_handle.get(), context::get_status());
11210 status_check(context::get_status());
11213 TFE_OpSetAttrInt(op.get(),
"num_rows", num_rows);
11214 TFE_OpSetAttrInt(op.get(),
"num_cols", num_cols);
11215 TFE_OpSetAttrInt(op.get(),
"max_rows_in_memory", max_rows_in_memory);
11218 int num_outputs_op = 1;
11219 TFE_TensorHandle* res[1] = {
nullptr};
11220 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11221 status_check(context::get_status());
11222 return tensor(res[0]);
11225 inline tensor log(
const tensor& x) {
11227 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Log", context::get_status()),
11229 status_check(context::get_status());
11233 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11234 status_check(context::get_status());
11239 int num_outputs_op = 1;
11240 TFE_TensorHandle* res[1] = {
nullptr};
11241 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11242 status_check(context::get_status());
11243 return tensor(res[0]);
11246 inline tensor log1p(
const tensor& x) {
11248 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Log1p", context::get_status()),
11250 status_check(context::get_status());
11254 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11255 status_check(context::get_status());
11260 int num_outputs_op = 1;
11261 TFE_TensorHandle* res[1] = {
nullptr};
11262 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11263 status_check(context::get_status());
11264 return tensor(res[0]);
11267 inline tensor log_softmax(
const tensor& logits) {
11269 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11270 TFE_NewOp(context::get_context(),
"LogSoftmax", context::get_status()), &TFE_DeleteOp);
11271 status_check(context::get_status());
11275 TFE_OpAddInput(op.get(), logits.tfe_handle.get(), context::get_status());
11276 status_check(context::get_status());
11281 int num_outputs_op = 1;
11282 TFE_TensorHandle* res[1] = {
nullptr};
11283 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11284 status_check(context::get_status());
11285 return tensor(res[0]);
11288 inline tensor logical_and(
const tensor& x,
const tensor& y) {
11290 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11291 TFE_NewOp(context::get_context(),
"LogicalAnd", context::get_status()), &TFE_DeleteOp);
11292 status_check(context::get_status());
11296 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11297 status_check(context::get_status());
11299 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
11300 status_check(context::get_status());
11305 int num_outputs_op = 1;
11306 TFE_TensorHandle* res[1] = {
nullptr};
11307 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11308 status_check(context::get_status());
11309 return tensor(res[0]);
11312 inline tensor logical_not(
const tensor& x) {
11314 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11315 TFE_NewOp(context::get_context(),
"LogicalNot", context::get_status()), &TFE_DeleteOp);
11316 status_check(context::get_status());
11320 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11321 status_check(context::get_status());
11326 int num_outputs_op = 1;
11327 TFE_TensorHandle* res[1] = {
nullptr};
11328 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11329 status_check(context::get_status());
11330 return tensor(res[0]);
11333 inline tensor logical_or(
const tensor& x,
const tensor& y) {
11335 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11336 TFE_NewOp(context::get_context(),
"LogicalOr", context::get_status()), &TFE_DeleteOp);
11337 status_check(context::get_status());
11341 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
11342 status_check(context::get_status());
11344 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
11345 status_check(context::get_status());
11350 int num_outputs_op = 1;
11351 TFE_TensorHandle* res[1] = {
nullptr};
11352 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11353 status_check(context::get_status());
11354 return tensor(res[0]);
11357 inline tensor lookup_table_find(
const tensor& table_handle,
const tensor& keys,
const tensor& default_value,
11358 datatype Tin, datatype Tout) {
11360 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11361 TFE_NewOp(context::get_context(),
"LookupTableFind", context::get_status()), &TFE_DeleteOp);
11362 status_check(context::get_status());
11366 TFE_OpAddInput(op.get(), table_handle.tfe_handle.get(), context::get_status());
11367 status_check(context::get_status());
11369 TFE_OpAddInput(op.get(), keys.tfe_handle.get(), context::get_status());
11370 status_check(context::get_status());
11372 TFE_OpAddInput(op.get(), default_value.tfe_handle.get(), context::get_status());
11373 status_check(context::get_status());
11376 TFE_OpSetAttrType(op.get(),
"Tin", Tin);
11377 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
11380 int num_outputs_op = 1;
11381 TFE_TensorHandle* res[1] = {
nullptr};
11382 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11383 status_check(context::get_status());
11384 return tensor(res[0]);
11387 inline tensor lookup_table_find_v2(
const tensor& table_handle,
const tensor& keys,
const tensor& default_value,
11388 datatype Tin, datatype Tout) {
11390 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11391 TFE_NewOp(context::get_context(),
"LookupTableFindV2", context::get_status()), &TFE_DeleteOp);
11392 status_check(context::get_status());
11396 TFE_OpAddInput(op.get(), table_handle.tfe_handle.get(), context::get_status());
11397 status_check(context::get_status());
11399 TFE_OpAddInput(op.get(), keys.tfe_handle.get(), context::get_status());
11400 status_check(context::get_status());
11402 TFE_OpAddInput(op.get(), default_value.tfe_handle.get(), context::get_status());
11403 status_check(context::get_status());
11406 TFE_OpSetAttrType(op.get(),
"Tin", Tin);
11407 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
11410 int num_outputs_op = 1;
11411 TFE_TensorHandle* res[1] = {
nullptr};
11412 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11413 status_check(context::get_status());
11414 return tensor(res[0]);
11417 inline tensor lookup_table_size(
const tensor& table_handle) {
11419 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11420 TFE_NewOp(context::get_context(),
"LookupTableSize", context::get_status()), &TFE_DeleteOp);
11421 status_check(context::get_status());
11425 TFE_OpAddInput(op.get(), table_handle.tfe_handle.get(), context::get_status());
11426 status_check(context::get_status());
11431 int num_outputs_op = 1;
11432 TFE_TensorHandle* res[1] = {
nullptr};
11433 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11434 status_check(context::get_status());
11435 return tensor(res[0]);
11438 inline tensor lookup_table_size_v2(
const tensor& table_handle) {
11440 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11441 TFE_NewOp(context::get_context(),
"LookupTableSizeV2", context::get_status()), &TFE_DeleteOp);
11442 status_check(context::get_status());
11446 TFE_OpAddInput(op.get(), table_handle.tfe_handle.get(), context::get_status());
11447 status_check(context::get_status());
11452 int num_outputs_op = 1;
11453 TFE_TensorHandle* res[1] = {
nullptr};
11454 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11455 status_check(context::get_status());
11456 return tensor(res[0]);
11459 inline tensor loop_cond(
const tensor& input) {
11461 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11462 TFE_NewOp(context::get_context(),
"LoopCond", context::get_status()), &TFE_DeleteOp);
11463 status_check(context::get_status());
11467 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11468 status_check(context::get_status());
11473 int num_outputs_op = 1;
11474 TFE_TensorHandle* res[1] = {
nullptr};
11475 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11476 status_check(context::get_status());
11477 return tensor(res[0]);
11480 inline tensor lower_bound(
const tensor& sorted_inputs,
const tensor& values,
11481 datatype out_type =
static_cast<datatype
>(3)) {
11483 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11484 TFE_NewOp(context::get_context(),
"LowerBound", context::get_status()), &TFE_DeleteOp);
11485 status_check(context::get_status());
11489 TFE_OpAddInput(op.get(), sorted_inputs.tfe_handle.get(), context::get_status());
11490 status_check(context::get_status());
11492 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
11493 status_check(context::get_status());
11496 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
11499 int num_outputs_op = 1;
11500 TFE_TensorHandle* res[1] = {
nullptr};
11501 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11502 status_check(context::get_status());
11503 return tensor(res[0]);
11506 inline tensor map_incomplete_size(
const std::vector<datatype>& dtypes, int64_t capacity = 0, int64_t memory_limit = 0,
11507 const std::string& container =
"",
const std::string& shared_name =
"") {
11509 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11510 TFE_NewOp(context::get_context(),
"MapIncompleteSize", context::get_status()), &TFE_DeleteOp);
11511 status_check(context::get_status());
11516 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
11517 static_cast<int>(dtypes.size()));
11518 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
11519 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
11520 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
11521 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
11524 int num_outputs_op = 1;
11525 TFE_TensorHandle* res[1] = {
nullptr};
11526 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11527 status_check(context::get_status());
11528 return tensor(res[0]);
11531 inline tensor map_peek(
const tensor& key,
const tensor& indices,
const std::vector<datatype>& dtypes,
11532 int64_t capacity = 0, int64_t memory_limit = 0,
const std::string& container =
"",
11533 const std::string& shared_name =
"") {
11535 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11536 TFE_NewOp(context::get_context(),
"MapPeek", context::get_status()), &TFE_DeleteOp);
11537 status_check(context::get_status());
11541 TFE_OpAddInput(op.get(), key.tfe_handle.get(), context::get_status());
11542 status_check(context::get_status());
11544 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
11545 status_check(context::get_status());
11548 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
11549 static_cast<int>(dtypes.size()));
11550 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
11551 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
11552 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
11553 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
11556 int num_outputs_op = 1;
11557 TFE_TensorHandle* res[1] = {
nullptr};
11558 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11559 status_check(context::get_status());
11560 return tensor(res[0]);
11563 inline tensor map_size(
const std::vector<datatype>& dtypes, int64_t capacity = 0, int64_t memory_limit = 0,
11564 const std::string& container =
"",
const std::string& shared_name =
"") {
11566 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11567 TFE_NewOp(context::get_context(),
"MapSize", context::get_status()), &TFE_DeleteOp);
11568 status_check(context::get_status());
11573 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
11574 static_cast<int>(dtypes.size()));
11575 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
11576 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
11577 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
11578 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
11581 int num_outputs_op = 1;
11582 TFE_TensorHandle* res[1] = {
nullptr};
11583 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11584 status_check(context::get_status());
11585 return tensor(res[0]);
11588 inline tensor map_unstage(
const tensor& key,
const tensor& indices,
const std::vector<datatype>& dtypes,
11589 int64_t capacity = 0, int64_t memory_limit = 0,
const std::string& container =
"",
11590 const std::string& shared_name =
"") {
11592 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11593 TFE_NewOp(context::get_context(),
"MapUnstage", context::get_status()), &TFE_DeleteOp);
11594 status_check(context::get_status());
11598 TFE_OpAddInput(op.get(), key.tfe_handle.get(), context::get_status());
11599 status_check(context::get_status());
11601 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
11602 status_check(context::get_status());
11605 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
11606 static_cast<int>(dtypes.size()));
11607 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
11608 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
11609 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
11610 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
11613 int num_outputs_op = 1;
11614 TFE_TensorHandle* res[1] = {
nullptr};
11615 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11616 status_check(context::get_status());
11617 return tensor(res[0]);
11620 inline tensor mat_mul(
const tensor& a,
const tensor& b,
bool transpose_a =
false,
bool transpose_b =
false) {
11622 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11623 TFE_NewOp(context::get_context(),
"MatMul", context::get_status()), &TFE_DeleteOp);
11624 status_check(context::get_status());
11628 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
11629 status_check(context::get_status());
11631 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
11632 status_check(context::get_status());
11635 TFE_OpSetAttrBool(op.get(),
"transpose_a", (
unsigned char)transpose_a);
11636 TFE_OpSetAttrBool(op.get(),
"transpose_b", (
unsigned char)transpose_b);
11639 int num_outputs_op = 1;
11640 TFE_TensorHandle* res[1] = {
nullptr};
11641 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11642 status_check(context::get_status());
11643 return tensor(res[0]);
11646 inline tensor matching_files(
const tensor& pattern) {
11648 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11649 TFE_NewOp(context::get_context(),
"MatchingFiles", context::get_status()), &TFE_DeleteOp);
11650 status_check(context::get_status());
11654 TFE_OpAddInput(op.get(), pattern.tfe_handle.get(), context::get_status());
11655 status_check(context::get_status());
11660 int num_outputs_op = 1;
11661 TFE_TensorHandle* res[1] = {
nullptr};
11662 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11663 status_check(context::get_status());
11664 return tensor(res[0]);
11667 inline tensor matching_files_dataset(
const tensor& patterns) {
11669 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11670 TFE_NewOp(context::get_context(),
"MatchingFilesDataset", context::get_status()), &TFE_DeleteOp);
11671 status_check(context::get_status());
11675 TFE_OpAddInput(op.get(), patterns.tfe_handle.get(), context::get_status());
11676 status_check(context::get_status());
11681 int num_outputs_op = 1;
11682 TFE_TensorHandle* res[1] = {
nullptr};
11683 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11684 status_check(context::get_status());
11685 return tensor(res[0]);
11688 inline tensor matrix_band_part(
const tensor& input,
const tensor& num_lower,
const tensor& num_upper,
11689 datatype Tindex =
static_cast<datatype
>(9)) {
11691 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11692 TFE_NewOp(context::get_context(),
"MatrixBandPart", context::get_status()), &TFE_DeleteOp);
11693 status_check(context::get_status());
11697 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11698 status_check(context::get_status());
11700 TFE_OpAddInput(op.get(), num_lower.tfe_handle.get(), context::get_status());
11701 status_check(context::get_status());
11703 TFE_OpAddInput(op.get(), num_upper.tfe_handle.get(), context::get_status());
11704 status_check(context::get_status());
11707 TFE_OpSetAttrType(op.get(),
"Tindex", Tindex);
11710 int num_outputs_op = 1;
11711 TFE_TensorHandle* res[1] = {
nullptr};
11712 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11713 status_check(context::get_status());
11714 return tensor(res[0]);
11717 inline tensor matrix_determinant(
const tensor& input) {
11719 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11720 TFE_NewOp(context::get_context(),
"MatrixDeterminant", context::get_status()), &TFE_DeleteOp);
11721 status_check(context::get_status());
11725 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11726 status_check(context::get_status());
11731 int num_outputs_op = 1;
11732 TFE_TensorHandle* res[1] = {
nullptr};
11733 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11734 status_check(context::get_status());
11735 return tensor(res[0]);
11738 inline tensor matrix_diag(
const tensor& diagonal) {
11740 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11741 TFE_NewOp(context::get_context(),
"MatrixDiag", context::get_status()), &TFE_DeleteOp);
11742 status_check(context::get_status());
11746 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
11747 status_check(context::get_status());
11752 int num_outputs_op = 1;
11753 TFE_TensorHandle* res[1] = {
nullptr};
11754 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11755 status_check(context::get_status());
11756 return tensor(res[0]);
11759 inline tensor matrix_diag_part(
const tensor& input) {
11761 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11762 TFE_NewOp(context::get_context(),
"MatrixDiagPart", context::get_status()), &TFE_DeleteOp);
11763 status_check(context::get_status());
11767 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11768 status_check(context::get_status());
11773 int num_outputs_op = 1;
11774 TFE_TensorHandle* res[1] = {
nullptr};
11775 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11776 status_check(context::get_status());
11777 return tensor(res[0]);
11780 inline tensor matrix_diag_part_v2(
const tensor& input,
const tensor& k,
const tensor& padding_value) {
11782 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11783 TFE_NewOp(context::get_context(),
"MatrixDiagPartV2", context::get_status()), &TFE_DeleteOp);
11784 status_check(context::get_status());
11788 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11789 status_check(context::get_status());
11791 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
11792 status_check(context::get_status());
11794 TFE_OpAddInput(op.get(), padding_value.tfe_handle.get(), context::get_status());
11795 status_check(context::get_status());
11800 int num_outputs_op = 1;
11801 TFE_TensorHandle* res[1] = {
nullptr};
11802 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11803 status_check(context::get_status());
11804 return tensor(res[0]);
11807 inline tensor matrix_diag_part_v3(
const tensor& input,
const tensor& k,
const tensor& padding_value,
11808 const std::string& align =
"RIGHT_LEFT") {
11810 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11811 TFE_NewOp(context::get_context(),
"MatrixDiagPartV3", context::get_status()), &TFE_DeleteOp);
11812 status_check(context::get_status());
11816 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11817 status_check(context::get_status());
11819 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
11820 status_check(context::get_status());
11822 TFE_OpAddInput(op.get(), padding_value.tfe_handle.get(), context::get_status());
11823 status_check(context::get_status());
11826 TFE_OpSetAttrString(op.get(),
"align", (
void*)align.c_str(), align.size());
11829 int num_outputs_op = 1;
11830 TFE_TensorHandle* res[1] = {
nullptr};
11831 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11832 status_check(context::get_status());
11833 return tensor(res[0]);
11836 inline tensor matrix_diag_v2(
const tensor& diagonal,
const tensor& k,
const tensor& num_rows,
const tensor& num_cols,
11837 const tensor& padding_value) {
11839 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11840 TFE_NewOp(context::get_context(),
"MatrixDiagV2", context::get_status()), &TFE_DeleteOp);
11841 status_check(context::get_status());
11845 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
11846 status_check(context::get_status());
11848 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
11849 status_check(context::get_status());
11851 TFE_OpAddInput(op.get(), num_rows.tfe_handle.get(), context::get_status());
11852 status_check(context::get_status());
11854 TFE_OpAddInput(op.get(), num_cols.tfe_handle.get(), context::get_status());
11855 status_check(context::get_status());
11857 TFE_OpAddInput(op.get(), padding_value.tfe_handle.get(), context::get_status());
11858 status_check(context::get_status());
11863 int num_outputs_op = 1;
11864 TFE_TensorHandle* res[1] = {
nullptr};
11865 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11866 status_check(context::get_status());
11867 return tensor(res[0]);
11870 inline tensor matrix_diag_v3(
const tensor& diagonal,
const tensor& k,
const tensor& num_rows,
const tensor& num_cols,
11871 const tensor& padding_value,
const std::string& align =
"RIGHT_LEFT") {
11873 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11874 TFE_NewOp(context::get_context(),
"MatrixDiagV3", context::get_status()), &TFE_DeleteOp);
11875 status_check(context::get_status());
11879 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
11880 status_check(context::get_status());
11882 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
11883 status_check(context::get_status());
11885 TFE_OpAddInput(op.get(), num_rows.tfe_handle.get(), context::get_status());
11886 status_check(context::get_status());
11888 TFE_OpAddInput(op.get(), num_cols.tfe_handle.get(), context::get_status());
11889 status_check(context::get_status());
11891 TFE_OpAddInput(op.get(), padding_value.tfe_handle.get(), context::get_status());
11892 status_check(context::get_status());
11895 TFE_OpSetAttrString(op.get(),
"align", (
void*)align.c_str(), align.size());
11898 int num_outputs_op = 1;
11899 TFE_TensorHandle* res[1] = {
nullptr};
11900 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11901 status_check(context::get_status());
11902 return tensor(res[0]);
11905 inline tensor matrix_exponential(
const tensor& input) {
11907 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11908 TFE_NewOp(context::get_context(),
"MatrixExponential", context::get_status()), &TFE_DeleteOp);
11909 status_check(context::get_status());
11913 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11914 status_check(context::get_status());
11919 int num_outputs_op = 1;
11920 TFE_TensorHandle* res[1] = {
nullptr};
11921 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11922 status_check(context::get_status());
11923 return tensor(res[0]);
11926 inline tensor matrix_inverse(
const tensor& input,
bool adjoint =
false) {
11928 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11929 TFE_NewOp(context::get_context(),
"MatrixInverse", context::get_status()), &TFE_DeleteOp);
11930 status_check(context::get_status());
11934 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11935 status_check(context::get_status());
11938 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
11941 int num_outputs_op = 1;
11942 TFE_TensorHandle* res[1] = {
nullptr};
11943 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11944 status_check(context::get_status());
11945 return tensor(res[0]);
11948 inline tensor matrix_logarithm(
const tensor& input) {
11950 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11951 TFE_NewOp(context::get_context(),
"MatrixLogarithm", context::get_status()), &TFE_DeleteOp);
11952 status_check(context::get_status());
11956 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11957 status_check(context::get_status());
11962 int num_outputs_op = 1;
11963 TFE_TensorHandle* res[1] = {
nullptr};
11964 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11965 status_check(context::get_status());
11966 return tensor(res[0]);
11969 inline tensor matrix_set_diag(
const tensor& input,
const tensor& diagonal) {
11971 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11972 TFE_NewOp(context::get_context(),
"MatrixSetDiag", context::get_status()), &TFE_DeleteOp);
11973 status_check(context::get_status());
11977 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
11978 status_check(context::get_status());
11980 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
11981 status_check(context::get_status());
11986 int num_outputs_op = 1;
11987 TFE_TensorHandle* res[1] = {
nullptr};
11988 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
11989 status_check(context::get_status());
11990 return tensor(res[0]);
11993 inline tensor matrix_set_diag_v2(
const tensor& input,
const tensor& diagonal,
const tensor& k) {
11995 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
11996 TFE_NewOp(context::get_context(),
"MatrixSetDiagV2", context::get_status()), &TFE_DeleteOp);
11997 status_check(context::get_status());
12001 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12002 status_check(context::get_status());
12004 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
12005 status_check(context::get_status());
12007 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
12008 status_check(context::get_status());
12013 int num_outputs_op = 1;
12014 TFE_TensorHandle* res[1] = {
nullptr};
12015 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12016 status_check(context::get_status());
12017 return tensor(res[0]);
12020 inline tensor matrix_set_diag_v3(
const tensor& input,
const tensor& diagonal,
const tensor& k,
12021 const std::string& align =
"RIGHT_LEFT") {
12023 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12024 TFE_NewOp(context::get_context(),
"MatrixSetDiagV3", context::get_status()), &TFE_DeleteOp);
12025 status_check(context::get_status());
12029 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12030 status_check(context::get_status());
12032 TFE_OpAddInput(op.get(), diagonal.tfe_handle.get(), context::get_status());
12033 status_check(context::get_status());
12035 TFE_OpAddInput(op.get(), k.tfe_handle.get(), context::get_status());
12036 status_check(context::get_status());
12039 TFE_OpSetAttrString(op.get(),
"align", (
void*)align.c_str(), align.size());
12042 int num_outputs_op = 1;
12043 TFE_TensorHandle* res[1] = {
nullptr};
12044 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12045 status_check(context::get_status());
12046 return tensor(res[0]);
12049 inline tensor matrix_solve(
const tensor& matrix,
const tensor& rhs,
bool adjoint =
false) {
12051 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12052 TFE_NewOp(context::get_context(),
"MatrixSolve", context::get_status()), &TFE_DeleteOp);
12053 status_check(context::get_status());
12057 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
12058 status_check(context::get_status());
12060 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
12061 status_check(context::get_status());
12064 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
12067 int num_outputs_op = 1;
12068 TFE_TensorHandle* res[1] = {
nullptr};
12069 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12070 status_check(context::get_status());
12071 return tensor(res[0]);
12074 inline tensor matrix_solve_ls(
const tensor& matrix,
const tensor& rhs,
const tensor& l2_regularizer,
bool fast =
true) {
12076 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12077 TFE_NewOp(context::get_context(),
"MatrixSolveLs", context::get_status()), &TFE_DeleteOp);
12078 status_check(context::get_status());
12082 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
12083 status_check(context::get_status());
12085 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
12086 status_check(context::get_status());
12088 TFE_OpAddInput(op.get(), l2_regularizer.tfe_handle.get(), context::get_status());
12089 status_check(context::get_status());
12092 TFE_OpSetAttrBool(op.get(),
"fast", (
unsigned char)fast);
12095 int num_outputs_op = 1;
12096 TFE_TensorHandle* res[1] = {
nullptr};
12097 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12098 status_check(context::get_status());
12099 return tensor(res[0]);
12102 inline tensor matrix_square_root(
const tensor& input) {
12104 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12105 TFE_NewOp(context::get_context(),
"MatrixSquareRoot", context::get_status()), &TFE_DeleteOp);
12106 status_check(context::get_status());
12110 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12111 status_check(context::get_status());
12116 int num_outputs_op = 1;
12117 TFE_TensorHandle* res[1] = {
nullptr};
12118 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12119 status_check(context::get_status());
12120 return tensor(res[0]);
12123 inline tensor matrix_triangular_solve(
const tensor& matrix,
const tensor& rhs,
bool lower =
true,
12124 bool adjoint =
false) {
12126 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12127 TFE_NewOp(context::get_context(),
"MatrixTriangularSolve", context::get_status()), &TFE_DeleteOp);
12128 status_check(context::get_status());
12132 TFE_OpAddInput(op.get(), matrix.tfe_handle.get(), context::get_status());
12133 status_check(context::get_status());
12135 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
12136 status_check(context::get_status());
12139 TFE_OpSetAttrBool(op.get(),
"lower", (
unsigned char)lower);
12140 TFE_OpSetAttrBool(op.get(),
"adjoint", (
unsigned char)adjoint);
12143 int num_outputs_op = 1;
12144 TFE_TensorHandle* res[1] = {
nullptr};
12145 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12146 status_check(context::get_status());
12147 return tensor(res[0]);
12150 inline tensor max(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
12151 datatype Tidx =
static_cast<datatype
>(3)) {
12153 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Max", context::get_status()),
12155 status_check(context::get_status());
12159 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12160 status_check(context::get_status());
12162 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
12163 status_check(context::get_status());
12166 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
12167 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
12170 int num_outputs_op = 1;
12171 TFE_TensorHandle* res[1] = {
nullptr};
12172 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12173 status_check(context::get_status());
12174 return tensor(res[0]);
12177 inline tensor max_intra_op_parallelism_dataset(
const tensor& input_dataset,
const tensor& max_intra_op_parallelism,
12178 const std::vector<datatype>& output_types,
12179 const std::vector<std::vector<int64_t>>& output_shapes) {
12181 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12182 TFE_NewOp(context::get_context(),
"MaxIntraOpParallelismDataset", context::get_status()), &TFE_DeleteOp);
12183 status_check(context::get_status());
12187 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
12188 status_check(context::get_status());
12190 TFE_OpAddInput(op.get(), max_intra_op_parallelism.tfe_handle.get(), context::get_status());
12191 status_check(context::get_status());
12194 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
12195 static_cast<int>(output_types.size()));
12197 std::vector<const int64_t*> output_shapes_values;
12198 output_shapes_values.reserve(output_shapes.size());
12199 std::vector<int> output_shapes_ndims;
12200 output_shapes_ndims.reserve(output_shapes.size());
12201 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
12202 [](
const auto& v) { return v.data(); });
12203 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
12204 [](
const auto& v) { return static_cast<int>(v.size()); });
12205 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
12206 static_cast<int>(output_shapes.size()), context::get_status());
12207 status_check(context::get_status());
12210 int num_outputs_op = 1;
12211 TFE_TensorHandle* res[1] = {
nullptr};
12212 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12213 status_check(context::get_status());
12214 return tensor(res[0]);
12217 inline tensor max_pool(
const tensor& input,
const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12218 const std::string& padding,
const std::string& data_format =
"NHWC") {
12220 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12221 TFE_NewOp(context::get_context(),
"MaxPool", context::get_status()), &TFE_DeleteOp);
12222 status_check(context::get_status());
12226 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12227 status_check(context::get_status());
12230 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12231 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12232 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12233 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12236 int num_outputs_op = 1;
12237 TFE_TensorHandle* res[1] = {
nullptr};
12238 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12239 status_check(context::get_status());
12240 return tensor(res[0]);
12243 inline tensor max_pool3_d(
const tensor& input,
const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12244 const std::string& padding,
const std::string& data_format =
"NDHWC") {
12246 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12247 TFE_NewOp(context::get_context(),
"MaxPool3D", context::get_status()), &TFE_DeleteOp);
12248 status_check(context::get_status());
12252 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12253 status_check(context::get_status());
12256 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12257 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12258 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12259 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12262 int num_outputs_op = 1;
12263 TFE_TensorHandle* res[1] = {
nullptr};
12264 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12265 status_check(context::get_status());
12266 return tensor(res[0]);
12269 inline tensor max_pool3_d_grad(
const tensor& orig_input,
const tensor& orig_output,
const tensor& grad,
12270 const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12271 const std::string& padding,
const std::string& data_format =
"NDHWC",
12272 datatype TInput =
static_cast<datatype
>(1)) {
12274 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12275 TFE_NewOp(context::get_context(),
"MaxPool3DGrad", context::get_status()), &TFE_DeleteOp);
12276 status_check(context::get_status());
12280 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
12281 status_check(context::get_status());
12283 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
12284 status_check(context::get_status());
12286 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12287 status_check(context::get_status());
12290 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12291 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12292 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12293 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12294 TFE_OpSetAttrType(op.get(),
"TInput", TInput);
12297 int num_outputs_op = 1;
12298 TFE_TensorHandle* res[1] = {
nullptr};
12299 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12300 status_check(context::get_status());
12301 return tensor(res[0]);
12304 inline tensor max_pool3_d_grad_grad(
const tensor& orig_input,
const tensor& orig_output,
const tensor& grad,
12305 const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12306 const std::string& padding,
const std::string& data_format =
"NDHWC") {
12308 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12309 TFE_NewOp(context::get_context(),
"MaxPool3DGradGrad", context::get_status()), &TFE_DeleteOp);
12310 status_check(context::get_status());
12314 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
12315 status_check(context::get_status());
12317 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
12318 status_check(context::get_status());
12320 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12321 status_check(context::get_status());
12324 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12325 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12326 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12327 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12330 int num_outputs_op = 1;
12331 TFE_TensorHandle* res[1] = {
nullptr};
12332 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12333 status_check(context::get_status());
12334 return tensor(res[0]);
12337 inline tensor max_pool_grad(
const tensor& orig_input,
const tensor& orig_output,
const tensor& grad,
12338 const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12339 const std::string& padding,
const std::string& data_format =
"NHWC") {
12341 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12342 TFE_NewOp(context::get_context(),
"MaxPoolGrad", context::get_status()), &TFE_DeleteOp);
12343 status_check(context::get_status());
12347 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
12348 status_check(context::get_status());
12350 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
12351 status_check(context::get_status());
12353 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12354 status_check(context::get_status());
12357 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12358 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12359 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12360 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12363 int num_outputs_op = 1;
12364 TFE_TensorHandle* res[1] = {
nullptr};
12365 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12366 status_check(context::get_status());
12367 return tensor(res[0]);
12370 inline tensor max_pool_grad_grad(
const tensor& orig_input,
const tensor& orig_output,
const tensor& grad,
12371 const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12372 const std::string& padding,
const std::string& data_format =
"NHWC") {
12374 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12375 TFE_NewOp(context::get_context(),
"MaxPoolGradGrad", context::get_status()), &TFE_DeleteOp);
12376 status_check(context::get_status());
12380 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
12381 status_check(context::get_status());
12383 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
12384 status_check(context::get_status());
12386 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12387 status_check(context::get_status());
12390 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12391 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12392 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12393 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12396 int num_outputs_op = 1;
12397 TFE_TensorHandle* res[1] = {
nullptr};
12398 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12399 status_check(context::get_status());
12400 return tensor(res[0]);
12403 inline tensor max_pool_grad_grad_v2(
const tensor& orig_input,
const tensor& orig_output,
const tensor& grad,
12404 const tensor& ksize,
const tensor& strides,
const std::string& padding,
12405 const std::string& data_format =
"NHWC") {
12407 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12408 TFE_NewOp(context::get_context(),
"MaxPoolGradGradV2", context::get_status()), &TFE_DeleteOp);
12409 status_check(context::get_status());
12413 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
12414 status_check(context::get_status());
12416 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
12417 status_check(context::get_status());
12419 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12420 status_check(context::get_status());
12422 TFE_OpAddInput(op.get(), ksize.tfe_handle.get(), context::get_status());
12423 status_check(context::get_status());
12425 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
12426 status_check(context::get_status());
12429 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12430 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12433 int num_outputs_op = 1;
12434 TFE_TensorHandle* res[1] = {
nullptr};
12435 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12436 status_check(context::get_status());
12437 return tensor(res[0]);
12440 inline tensor max_pool_grad_grad_with_argmax(
const tensor& input,
const tensor& grad,
const tensor& argmax,
12441 const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12442 const std::string& padding, datatype Targmax,
12443 bool include_batch_in_index =
false) {
12445 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12446 TFE_NewOp(context::get_context(),
"MaxPoolGradGradWithArgmax", context::get_status()), &TFE_DeleteOp);
12447 status_check(context::get_status());
12451 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12452 status_check(context::get_status());
12454 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12455 status_check(context::get_status());
12457 TFE_OpAddInput(op.get(), argmax.tfe_handle.get(), context::get_status());
12458 status_check(context::get_status());
12461 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12462 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12463 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12464 TFE_OpSetAttrType(op.get(),
"Targmax", Targmax);
12465 TFE_OpSetAttrBool(op.get(),
"include_batch_in_index", (
unsigned char)include_batch_in_index);
12468 int num_outputs_op = 1;
12469 TFE_TensorHandle* res[1] = {
nullptr};
12470 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12471 status_check(context::get_status());
12472 return tensor(res[0]);
12475 inline tensor max_pool_grad_v2(
const tensor& orig_input,
const tensor& orig_output,
const tensor& grad,
12476 const tensor& ksize,
const tensor& strides,
const std::string& padding,
12477 const std::string& data_format =
"NHWC") {
12479 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12480 TFE_NewOp(context::get_context(),
"MaxPoolGradV2", context::get_status()), &TFE_DeleteOp);
12481 status_check(context::get_status());
12485 TFE_OpAddInput(op.get(), orig_input.tfe_handle.get(), context::get_status());
12486 status_check(context::get_status());
12488 TFE_OpAddInput(op.get(), orig_output.tfe_handle.get(), context::get_status());
12489 status_check(context::get_status());
12491 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12492 status_check(context::get_status());
12494 TFE_OpAddInput(op.get(), ksize.tfe_handle.get(), context::get_status());
12495 status_check(context::get_status());
12497 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
12498 status_check(context::get_status());
12501 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12502 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12505 int num_outputs_op = 1;
12506 TFE_TensorHandle* res[1] = {
nullptr};
12507 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12508 status_check(context::get_status());
12509 return tensor(res[0]);
12512 inline tensor max_pool_grad_with_argmax(
const tensor& input,
const tensor& grad,
const tensor& argmax,
12513 const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides,
12514 const std::string& padding, datatype Targmax,
12515 bool include_batch_in_index =
false) {
12517 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12518 TFE_NewOp(context::get_context(),
"MaxPoolGradWithArgmax", context::get_status()), &TFE_DeleteOp);
12519 status_check(context::get_status());
12523 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12524 status_check(context::get_status());
12526 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
12527 status_check(context::get_status());
12529 TFE_OpAddInput(op.get(), argmax.tfe_handle.get(), context::get_status());
12530 status_check(context::get_status());
12533 TFE_OpSetAttrIntList(op.get(),
"ksize", ksize.data(),
static_cast<int>(ksize.size()));
12534 TFE_OpSetAttrIntList(op.get(),
"strides", strides.data(),
static_cast<int>(strides.size()));
12535 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12536 TFE_OpSetAttrType(op.get(),
"Targmax", Targmax);
12537 TFE_OpSetAttrBool(op.get(),
"include_batch_in_index", (
unsigned char)include_batch_in_index);
12540 int num_outputs_op = 1;
12541 TFE_TensorHandle* res[1] = {
nullptr};
12542 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12543 status_check(context::get_status());
12544 return tensor(res[0]);
12547 inline tensor max_pool_v2(
const tensor& input,
const tensor& ksize,
const tensor& strides,
const std::string& padding,
12548 const std::string& data_format =
"NHWC") {
12550 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12551 TFE_NewOp(context::get_context(),
"MaxPoolV2", context::get_status()), &TFE_DeleteOp);
12552 status_check(context::get_status());
12556 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12557 status_check(context::get_status());
12559 TFE_OpAddInput(op.get(), ksize.tfe_handle.get(), context::get_status());
12560 status_check(context::get_status());
12562 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
12563 status_check(context::get_status());
12566 TFE_OpSetAttrString(op.get(),
"padding", (
void*)padding.c_str(), padding.size());
12567 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
12570 int num_outputs_op = 1;
12571 TFE_TensorHandle* res[1] = {
nullptr};
12572 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12573 status_check(context::get_status());
12574 return tensor(res[0]);
12577 inline tensor maximum(
const tensor& x,
const tensor& y) {
12579 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12580 TFE_NewOp(context::get_context(),
"Maximum", context::get_status()), &TFE_DeleteOp);
12581 status_check(context::get_status());
12585 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
12586 status_check(context::get_status());
12588 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
12589 status_check(context::get_status());
12594 int num_outputs_op = 1;
12595 TFE_TensorHandle* res[1] = {
nullptr};
12596 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12597 status_check(context::get_status());
12598 return tensor(res[0]);
12601 inline tensor mean(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
12602 datatype Tidx =
static_cast<datatype
>(3)) {
12604 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Mean", context::get_status()),
12606 status_check(context::get_status());
12610 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12611 status_check(context::get_status());
12613 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
12614 status_check(context::get_status());
12617 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
12618 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
12621 int num_outputs_op = 1;
12622 TFE_TensorHandle* res[1] = {
nullptr};
12623 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12624 status_check(context::get_status());
12625 return tensor(res[0]);
12628 inline tensor merge_summary(
const std::vector<tensor>& inputs) {
12630 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12631 TFE_NewOp(context::get_context(),
"MergeSummary", context::get_status()), &TFE_DeleteOp);
12632 status_check(context::get_status());
12636 std::vector<TFE_TensorHandle*> inputs_handles;
12637 inputs_handles.reserve(inputs.size());
12638 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
12639 [](
const auto& t) { return t.tfe_handle.get(); });
12640 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
12641 status_check(context::get_status());
12644 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
12647 int num_outputs_op = 1;
12648 TFE_TensorHandle* res[1] = {
nullptr};
12649 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12650 status_check(context::get_status());
12651 return tensor(res[0]);
12654 inline tensor mfcc(
const tensor& spectrogram,
const tensor& sample_rate,
float upper_frequency_limit = 4.0000e+03,
12655 float lower_frequency_limit = 2.0000e+01, int64_t filterbank_channel_count = 40,
12656 int64_t dct_coefficient_count = 13) {
12658 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Mfcc", context::get_status()),
12660 status_check(context::get_status());
12664 TFE_OpAddInput(op.get(), spectrogram.tfe_handle.get(), context::get_status());
12665 status_check(context::get_status());
12667 TFE_OpAddInput(op.get(), sample_rate.tfe_handle.get(), context::get_status());
12668 status_check(context::get_status());
12671 TFE_OpSetAttrFloat(op.get(),
"upper_frequency_limit", upper_frequency_limit);
12672 TFE_OpSetAttrFloat(op.get(),
"lower_frequency_limit", lower_frequency_limit);
12673 TFE_OpSetAttrInt(op.get(),
"filterbank_channel_count", filterbank_channel_count);
12674 TFE_OpSetAttrInt(op.get(),
"dct_coefficient_count", dct_coefficient_count);
12677 int num_outputs_op = 1;
12678 TFE_TensorHandle* res[1] = {
nullptr};
12679 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12680 status_check(context::get_status());
12681 return tensor(res[0]);
12684 inline tensor min(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
12685 datatype Tidx =
static_cast<datatype
>(3)) {
12687 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Min", context::get_status()),
12689 status_check(context::get_status());
12693 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12694 status_check(context::get_status());
12696 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
12697 status_check(context::get_status());
12700 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
12701 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
12704 int num_outputs_op = 1;
12705 TFE_TensorHandle* res[1] = {
nullptr};
12706 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12707 status_check(context::get_status());
12708 return tensor(res[0]);
12711 inline tensor minimum(
const tensor& x,
const tensor& y) {
12713 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12714 TFE_NewOp(context::get_context(),
"Minimum", context::get_status()), &TFE_DeleteOp);
12715 status_check(context::get_status());
12719 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
12720 status_check(context::get_status());
12722 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
12723 status_check(context::get_status());
12728 int num_outputs_op = 1;
12729 TFE_TensorHandle* res[1] = {
nullptr};
12730 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12731 status_check(context::get_status());
12732 return tensor(res[0]);
12735 inline tensor mirror_pad(
const tensor& input,
const tensor& paddings,
const std::string& mode,
12736 datatype Tpaddings =
static_cast<datatype
>(3)) {
12738 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12739 TFE_NewOp(context::get_context(),
"MirrorPad", context::get_status()), &TFE_DeleteOp);
12740 status_check(context::get_status());
12744 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12745 status_check(context::get_status());
12747 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
12748 status_check(context::get_status());
12751 TFE_OpSetAttrString(op.get(),
"mode", (
void*)mode.c_str(), mode.size());
12752 TFE_OpSetAttrType(op.get(),
"Tpaddings", Tpaddings);
12755 int num_outputs_op = 1;
12756 TFE_TensorHandle* res[1] = {
nullptr};
12757 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12758 status_check(context::get_status());
12759 return tensor(res[0]);
12762 inline tensor mirror_pad_grad(
const tensor& input,
const tensor& paddings,
const std::string& mode,
12763 datatype Tpaddings =
static_cast<datatype
>(3)) {
12765 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12766 TFE_NewOp(context::get_context(),
"MirrorPadGrad", context::get_status()), &TFE_DeleteOp);
12767 status_check(context::get_status());
12771 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
12772 status_check(context::get_status());
12774 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
12775 status_check(context::get_status());
12778 TFE_OpSetAttrString(op.get(),
"mode", (
void*)mode.c_str(), mode.size());
12779 TFE_OpSetAttrType(op.get(),
"Tpaddings", Tpaddings);
12782 int num_outputs_op = 1;
12783 TFE_TensorHandle* res[1] = {
nullptr};
12784 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12785 status_check(context::get_status());
12786 return tensor(res[0]);
12789 inline tensor mod(
const tensor& x,
const tensor& y) {
12791 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Mod", context::get_status()),
12793 status_check(context::get_status());
12797 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
12798 status_check(context::get_status());
12800 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
12801 status_check(context::get_status());
12806 int num_outputs_op = 1;
12807 TFE_TensorHandle* res[1] = {
nullptr};
12808 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12809 status_check(context::get_status());
12810 return tensor(res[0]);
12813 inline tensor model_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
12814 const std::vector<std::vector<int64_t>>& output_shapes, int64_t algorithm = 0,
12815 int64_t cpu_budget = 0) {
12817 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12818 TFE_NewOp(context::get_context(),
"ModelDataset", context::get_status()), &TFE_DeleteOp);
12819 status_check(context::get_status());
12823 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
12824 status_check(context::get_status());
12827 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
12828 static_cast<int>(output_types.size()));
12830 std::vector<const int64_t*> output_shapes_values;
12831 output_shapes_values.reserve(output_shapes.size());
12832 std::vector<int> output_shapes_ndims;
12833 output_shapes_ndims.reserve(output_shapes.size());
12834 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
12835 [](
const auto& v) { return v.data(); });
12836 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
12837 [](
const auto& v) { return static_cast<int>(v.size()); });
12838 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
12839 static_cast<int>(output_shapes.size()), context::get_status());
12840 status_check(context::get_status());
12842 TFE_OpSetAttrInt(op.get(),
"algorithm", algorithm);
12843 TFE_OpSetAttrInt(op.get(),
"cpu_budget", cpu_budget);
12846 int num_outputs_op = 1;
12847 TFE_TensorHandle* res[1] = {
nullptr};
12848 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12849 status_check(context::get_status());
12850 return tensor(res[0]);
12853 inline tensor mul(
const tensor& x,
const tensor& y) {
12855 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Mul", context::get_status()),
12857 status_check(context::get_status());
12861 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
12862 status_check(context::get_status());
12864 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
12865 status_check(context::get_status());
12870 int num_outputs_op = 1;
12871 TFE_TensorHandle* res[1] = {
nullptr};
12872 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12873 status_check(context::get_status());
12874 return tensor(res[0]);
12877 inline tensor mul_no_nan(
const tensor& x,
const tensor& y) {
12879 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12880 TFE_NewOp(context::get_context(),
"MulNoNan", context::get_status()), &TFE_DeleteOp);
12881 status_check(context::get_status());
12885 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
12886 status_check(context::get_status());
12888 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
12889 status_check(context::get_status());
12894 int num_outputs_op = 1;
12895 TFE_TensorHandle* res[1] = {
nullptr};
12896 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12897 status_check(context::get_status());
12898 return tensor(res[0]);
12901 inline tensor multi_device_iterator(
const std::vector<std::string>& devices,
const std::string& shared_name,
12902 const std::string& container,
const std::vector<datatype>& output_types,
12903 const std::vector<std::vector<int64_t>>& output_shapes) {
12905 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12906 TFE_NewOp(context::get_context(),
"MultiDeviceIterator", context::get_status()), &TFE_DeleteOp);
12907 status_check(context::get_status());
12913 std::vector<std::size_t> devices_sizes;
12914 devices_sizes.reserve(devices.size());
12915 std::transform(devices.begin(), devices.end(), std::back_inserter(devices_sizes),
12916 [](
const auto& s) { return s.size(); });
12917 TFE_OpSetAttrStringList(op.get(),
"devices",
reinterpret_cast<const void* const*
>(devices.data()),
12918 devices_sizes.data(),
static_cast<int>(devices.size()));
12920 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
12921 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
12922 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
12923 static_cast<int>(output_types.size()));
12925 std::vector<const int64_t*> output_shapes_values;
12926 output_shapes_values.reserve(output_shapes.size());
12927 std::vector<int> output_shapes_ndims;
12928 output_shapes_ndims.reserve(output_shapes.size());
12929 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
12930 [](
const auto& v) { return v.data(); });
12931 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
12932 [](
const auto& v) { return static_cast<int>(v.size()); });
12933 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
12934 static_cast<int>(output_shapes.size()), context::get_status());
12935 status_check(context::get_status());
12938 int num_outputs_op = 1;
12939 TFE_TensorHandle* res[1] = {
nullptr};
12940 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12941 status_check(context::get_status());
12942 return tensor(res[0]);
12945 inline tensor multi_device_iterator_from_string_handle(
const tensor& string_handle,
12946 const std::vector<datatype>& output_types,
12947 const std::vector<std::vector<int64_t>>& output_shapes) {
12949 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12950 TFE_NewOp(context::get_context(),
"MultiDeviceIteratorFromStringHandle", context::get_status()), &TFE_DeleteOp);
12951 status_check(context::get_status());
12955 TFE_OpAddInput(op.get(), string_handle.tfe_handle.get(), context::get_status());
12956 status_check(context::get_status());
12959 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
12960 static_cast<int>(output_types.size()));
12962 std::vector<const int64_t*> output_shapes_values;
12963 output_shapes_values.reserve(output_shapes.size());
12964 std::vector<int> output_shapes_ndims;
12965 output_shapes_ndims.reserve(output_shapes.size());
12966 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
12967 [](
const auto& v) { return v.data(); });
12968 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
12969 [](
const auto& v) { return static_cast<int>(v.size()); });
12970 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
12971 static_cast<int>(output_shapes.size()), context::get_status());
12972 status_check(context::get_status());
12975 int num_outputs_op = 1;
12976 TFE_TensorHandle* res[1] = {
nullptr};
12977 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
12978 status_check(context::get_status());
12979 return tensor(res[0]);
12982 inline tensor multi_device_iterator_get_next_from_shard(
const tensor& multi_device_iterator,
const tensor& shard_num,
12983 const tensor& incarnation_id,
12984 const std::vector<datatype>& output_types,
12985 const std::vector<std::vector<int64_t>>& output_shapes) {
12987 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
12988 TFE_NewOp(context::get_context(),
"MultiDeviceIteratorGetNextFromShard", context::get_status()), &TFE_DeleteOp);
12989 status_check(context::get_status());
12993 TFE_OpAddInput(op.get(), multi_device_iterator.tfe_handle.get(), context::get_status());
12994 status_check(context::get_status());
12996 TFE_OpAddInput(op.get(), shard_num.tfe_handle.get(), context::get_status());
12997 status_check(context::get_status());
12999 TFE_OpAddInput(op.get(), incarnation_id.tfe_handle.get(), context::get_status());
13000 status_check(context::get_status());
13003 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
13004 static_cast<int>(output_types.size()));
13006 std::vector<const int64_t*> output_shapes_values;
13007 output_shapes_values.reserve(output_shapes.size());
13008 std::vector<int> output_shapes_ndims;
13009 output_shapes_ndims.reserve(output_shapes.size());
13010 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
13011 [](
const auto& v) { return v.data(); });
13012 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
13013 [](
const auto& v) { return static_cast<int>(v.size()); });
13014 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
13015 static_cast<int>(output_shapes.size()), context::get_status());
13016 status_check(context::get_status());
13019 int num_outputs_op = 1;
13020 TFE_TensorHandle* res[1] = {
nullptr};
13021 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13022 status_check(context::get_status());
13023 return tensor(res[0]);
13026 inline tensor multi_device_iterator_init(
const tensor& dataset,
const tensor& multi_device_iterator,
13027 const tensor& max_buffer_size) {
13029 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13030 TFE_NewOp(context::get_context(),
"MultiDeviceIteratorInit", context::get_status()), &TFE_DeleteOp);
13031 status_check(context::get_status());
13035 TFE_OpAddInput(op.get(), dataset.tfe_handle.get(), context::get_status());
13036 status_check(context::get_status());
13038 TFE_OpAddInput(op.get(), multi_device_iterator.tfe_handle.get(), context::get_status());
13039 status_check(context::get_status());
13041 TFE_OpAddInput(op.get(), max_buffer_size.tfe_handle.get(), context::get_status());
13042 status_check(context::get_status());
13047 int num_outputs_op = 1;
13048 TFE_TensorHandle* res[1] = {
nullptr};
13049 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13050 status_check(context::get_status());
13051 return tensor(res[0]);
13054 inline tensor multi_device_iterator_to_string_handle(
const tensor& multi_device_iterator) {
13056 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13057 TFE_NewOp(context::get_context(),
"MultiDeviceIteratorToStringHandle", context::get_status()), &TFE_DeleteOp);
13058 status_check(context::get_status());
13062 TFE_OpAddInput(op.get(), multi_device_iterator.tfe_handle.get(), context::get_status());
13063 status_check(context::get_status());
13068 int num_outputs_op = 1;
13069 TFE_TensorHandle* res[1] = {
nullptr};
13070 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13071 status_check(context::get_status());
13072 return tensor(res[0]);
13075 inline tensor multinomial(
const tensor& logits,
const tensor& num_samples, int64_t seed = 0, int64_t seed2 = 0,
13076 datatype output_dtype =
static_cast<datatype
>(9)) {
13078 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13079 TFE_NewOp(context::get_context(),
"Multinomial", context::get_status()), &TFE_DeleteOp);
13080 status_check(context::get_status());
13084 TFE_OpAddInput(op.get(), logits.tfe_handle.get(), context::get_status());
13085 status_check(context::get_status());
13087 TFE_OpAddInput(op.get(), num_samples.tfe_handle.get(), context::get_status());
13088 status_check(context::get_status());
13091 TFE_OpSetAttrInt(op.get(),
"seed", seed);
13092 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
13093 TFE_OpSetAttrType(op.get(),
"output_dtype", output_dtype);
13096 int num_outputs_op = 1;
13097 TFE_TensorHandle* res[1] = {
nullptr};
13098 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13099 status_check(context::get_status());
13100 return tensor(res[0]);
13103 inline tensor mutable_dense_hash_table(
const tensor& empty_key, datatype key_dtype, datatype value_dtype,
13104 const std::vector<int64_t>& value_shape,
const std::string& container =
"",
13105 const std::string& shared_name =
"",
bool use_node_name_sharing =
false,
13106 int64_t initial_num_buckets = 131072,
float max_load_factor = 8.0000e-01) {
13108 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13109 TFE_NewOp(context::get_context(),
"MutableDenseHashTable", context::get_status()), &TFE_DeleteOp);
13110 status_check(context::get_status());
13114 TFE_OpAddInput(op.get(), empty_key.tfe_handle.get(), context::get_status());
13115 status_check(context::get_status());
13118 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
13119 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
13121 TFE_OpSetAttrShape(op.get(),
"value_shape", value_shape.data(),
static_cast<int>(value_shape.size()),
13122 context::get_status());
13123 status_check(context::get_status());
13125 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13126 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13127 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
13128 TFE_OpSetAttrInt(op.get(),
"initial_num_buckets", initial_num_buckets);
13129 TFE_OpSetAttrFloat(op.get(),
"max_load_factor", max_load_factor);
13132 int num_outputs_op = 1;
13133 TFE_TensorHandle* res[1] = {
nullptr};
13134 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13135 status_check(context::get_status());
13136 return tensor(res[0]);
13139 inline tensor mutable_dense_hash_table_v2(
const tensor& empty_key,
const tensor& deleted_key, datatype key_dtype,
13140 datatype value_dtype,
const std::vector<int64_t>& value_shape,
13141 const std::string& container =
"",
const std::string& shared_name =
"",
13142 bool use_node_name_sharing =
false, int64_t initial_num_buckets = 131072,
13143 float max_load_factor = 8.0000e-01) {
13145 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13146 TFE_NewOp(context::get_context(),
"MutableDenseHashTableV2", context::get_status()), &TFE_DeleteOp);
13147 status_check(context::get_status());
13151 TFE_OpAddInput(op.get(), empty_key.tfe_handle.get(), context::get_status());
13152 status_check(context::get_status());
13154 TFE_OpAddInput(op.get(), deleted_key.tfe_handle.get(), context::get_status());
13155 status_check(context::get_status());
13158 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
13159 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
13161 TFE_OpSetAttrShape(op.get(),
"value_shape", value_shape.data(),
static_cast<int>(value_shape.size()),
13162 context::get_status());
13163 status_check(context::get_status());
13165 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13166 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13167 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
13168 TFE_OpSetAttrInt(op.get(),
"initial_num_buckets", initial_num_buckets);
13169 TFE_OpSetAttrFloat(op.get(),
"max_load_factor", max_load_factor);
13172 int num_outputs_op = 1;
13173 TFE_TensorHandle* res[1] = {
nullptr};
13174 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13175 status_check(context::get_status());
13176 return tensor(res[0]);
13179 inline tensor mutable_hash_table(datatype key_dtype, datatype value_dtype,
const std::string& container =
"",
13180 const std::string& shared_name =
"",
bool use_node_name_sharing =
false) {
13182 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13183 TFE_NewOp(context::get_context(),
"MutableHashTable", context::get_status()), &TFE_DeleteOp);
13184 status_check(context::get_status());
13189 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
13190 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
13191 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13192 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13193 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
13196 int num_outputs_op = 1;
13197 TFE_TensorHandle* res[1] = {
nullptr};
13198 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13199 status_check(context::get_status());
13200 return tensor(res[0]);
13203 inline tensor mutable_hash_table_of_tensors(datatype key_dtype, datatype value_dtype,
13204 const std::vector<int64_t>& value_shape,
const std::string& container =
"",
13205 const std::string& shared_name =
"",
bool use_node_name_sharing =
false) {
13207 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13208 TFE_NewOp(context::get_context(),
"MutableHashTableOfTensors", context::get_status()), &TFE_DeleteOp);
13209 status_check(context::get_status());
13214 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
13215 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
13217 TFE_OpSetAttrShape(op.get(),
"value_shape", value_shape.data(),
static_cast<int>(value_shape.size()),
13218 context::get_status());
13219 status_check(context::get_status());
13221 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13222 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13223 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
13226 int num_outputs_op = 1;
13227 TFE_TensorHandle* res[1] = {
nullptr};
13228 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13229 status_check(context::get_status());
13230 return tensor(res[0]);
13233 inline tensor mutable_hash_table_of_tensors_v2(datatype key_dtype, datatype value_dtype,
13234 const std::vector<int64_t>& value_shape,
13235 const std::string& container =
"",
const std::string& shared_name =
"",
13236 bool use_node_name_sharing =
false) {
13238 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13239 TFE_NewOp(context::get_context(),
"MutableHashTableOfTensorsV2", context::get_status()), &TFE_DeleteOp);
13240 status_check(context::get_status());
13245 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
13246 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
13248 TFE_OpSetAttrShape(op.get(),
"value_shape", value_shape.data(),
static_cast<int>(value_shape.size()),
13249 context::get_status());
13250 status_check(context::get_status());
13252 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13253 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13254 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
13257 int num_outputs_op = 1;
13258 TFE_TensorHandle* res[1] = {
nullptr};
13259 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13260 status_check(context::get_status());
13261 return tensor(res[0]);
13264 inline tensor mutable_hash_table_v2(datatype key_dtype, datatype value_dtype,
const std::string& container =
"",
13265 const std::string& shared_name =
"",
bool use_node_name_sharing =
false) {
13267 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13268 TFE_NewOp(context::get_context(),
"MutableHashTableV2", context::get_status()), &TFE_DeleteOp);
13269 status_check(context::get_status());
13274 TFE_OpSetAttrType(op.get(),
"key_dtype", key_dtype);
13275 TFE_OpSetAttrType(op.get(),
"value_dtype", value_dtype);
13276 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13277 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13278 TFE_OpSetAttrBool(op.get(),
"use_node_name_sharing", (
unsigned char)use_node_name_sharing);
13281 int num_outputs_op = 1;
13282 TFE_TensorHandle* res[1] = {
nullptr};
13283 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13284 status_check(context::get_status());
13285 return tensor(res[0]);
13288 inline tensor mutex_lock(
const tensor& mutex) {
13290 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13291 TFE_NewOp(context::get_context(),
"MutexLock", context::get_status()), &TFE_DeleteOp);
13292 status_check(context::get_status());
13296 TFE_OpAddInput(op.get(), mutex.tfe_handle.get(), context::get_status());
13297 status_check(context::get_status());
13302 int num_outputs_op = 1;
13303 TFE_TensorHandle* res[1] = {
nullptr};
13304 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13305 status_check(context::get_status());
13306 return tensor(res[0]);
13309 inline tensor mutex_v2(
const std::string& container =
"",
const std::string& shared_name =
"") {
13311 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13312 TFE_NewOp(context::get_context(),
"MutexV2", context::get_status()), &TFE_DeleteOp);
13313 status_check(context::get_status());
13318 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13319 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13322 int num_outputs_op = 1;
13323 TFE_TensorHandle* res[1] = {
nullptr};
13324 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13325 status_check(context::get_status());
13326 return tensor(res[0]);
13329 inline tensor nccl_all_reduce(
const tensor& input,
const std::string& reduction, int64_t num_devices,
13330 const std::string& shared_name) {
13332 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13333 TFE_NewOp(context::get_context(),
"NcclAllReduce", context::get_status()), &TFE_DeleteOp);
13334 status_check(context::get_status());
13338 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
13339 status_check(context::get_status());
13342 TFE_OpSetAttrString(op.get(),
"reduction", (
void*)reduction.c_str(), reduction.size());
13343 TFE_OpSetAttrInt(op.get(),
"num_devices", num_devices);
13344 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13347 int num_outputs_op = 1;
13348 TFE_TensorHandle* res[1] = {
nullptr};
13349 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13350 status_check(context::get_status());
13351 return tensor(res[0]);
13354 inline tensor nccl_broadcast(
const tensor& input,
const std::vector<int64_t>& shape) {
13356 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13357 TFE_NewOp(context::get_context(),
"NcclBroadcast", context::get_status()), &TFE_DeleteOp);
13358 status_check(context::get_status());
13362 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
13363 status_check(context::get_status());
13367 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
13368 status_check(context::get_status());
13371 int num_outputs_op = 1;
13372 TFE_TensorHandle* res[1] = {
nullptr};
13373 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13374 status_check(context::get_status());
13375 return tensor(res[0]);
13378 inline tensor nccl_reduce(
const std::vector<tensor>& input,
const std::string& reduction) {
13380 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13381 TFE_NewOp(context::get_context(),
"NcclReduce", context::get_status()), &TFE_DeleteOp);
13382 status_check(context::get_status());
13386 std::vector<TFE_TensorHandle*> input_handles;
13387 input_handles.reserve(input.size());
13388 std::transform(input.begin(), input.end(), std::back_inserter(input_handles),
13389 [](
const auto& t) { return t.tfe_handle.get(); });
13390 TFE_OpAddInputList(op.get(), input_handles.data(),
static_cast<int>(input.size()), context::get_status());
13391 status_check(context::get_status());
13394 TFE_OpSetAttrString(op.get(),
"reduction", (
void*)reduction.c_str(), reduction.size());
13395 TFE_OpSetAttrInt(op.get(),
"num_devices", input.size());
13398 int num_outputs_op = 1;
13399 TFE_TensorHandle* res[1] = {
nullptr};
13400 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13401 status_check(context::get_status());
13402 return tensor(res[0]);
13405 inline tensor ndtri(
const tensor& x) {
13407 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Ndtri", context::get_status()),
13409 status_check(context::get_status());
13413 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
13414 status_check(context::get_status());
13419 int num_outputs_op = 1;
13420 TFE_TensorHandle* res[1] = {
nullptr};
13421 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13422 status_check(context::get_status());
13423 return tensor(res[0]);
13426 inline tensor neg(
const tensor& x) {
13428 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Neg", context::get_status()),
13430 status_check(context::get_status());
13434 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
13435 status_check(context::get_status());
13440 int num_outputs_op = 1;
13441 TFE_TensorHandle* res[1] = {
nullptr};
13442 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13443 status_check(context::get_status());
13444 return tensor(res[0]);
13447 inline tensor next_after(
const tensor& x1,
const tensor& x2) {
13449 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13450 TFE_NewOp(context::get_context(),
"NextAfter", context::get_status()), &TFE_DeleteOp);
13451 status_check(context::get_status());
13455 TFE_OpAddInput(op.get(), x1.tfe_handle.get(), context::get_status());
13456 status_check(context::get_status());
13458 TFE_OpAddInput(op.get(), x2.tfe_handle.get(), context::get_status());
13459 status_check(context::get_status());
13464 int num_outputs_op = 1;
13465 TFE_TensorHandle* res[1] = {
nullptr};
13466 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13467 status_check(context::get_status());
13468 return tensor(res[0]);
13471 inline tensor next_iteration(
const tensor& data) {
13473 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13474 TFE_NewOp(context::get_context(),
"NextIteration", context::get_status()), &TFE_DeleteOp);
13475 status_check(context::get_status());
13479 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
13480 status_check(context::get_status());
13485 int num_outputs_op = 1;
13486 TFE_TensorHandle* res[1] = {
nullptr};
13487 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13488 status_check(context::get_status());
13489 return tensor(res[0]);
13492 inline tensor non_deterministic_ints(
const tensor& shape, datatype dtype =
static_cast<datatype
>(9),
13493 datatype shape_dtype =
static_cast<datatype
>(9)) {
13495 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13496 TFE_NewOp(context::get_context(),
"NonDeterministicInts", context::get_status()), &TFE_DeleteOp);
13497 status_check(context::get_status());
13501 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
13502 status_check(context::get_status());
13505 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
13506 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
13509 int num_outputs_op = 1;
13510 TFE_TensorHandle* res[1] = {
nullptr};
13511 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13512 status_check(context::get_status());
13513 return tensor(res[0]);
13516 inline tensor non_max_suppression(
const tensor& boxes,
const tensor& scores,
const tensor& max_output_size,
13517 float iou_threshold = 5.0000e-01) {
13519 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13520 TFE_NewOp(context::get_context(),
"NonMaxSuppression", context::get_status()), &TFE_DeleteOp);
13521 status_check(context::get_status());
13525 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
13526 status_check(context::get_status());
13528 TFE_OpAddInput(op.get(), scores.tfe_handle.get(), context::get_status());
13529 status_check(context::get_status());
13531 TFE_OpAddInput(op.get(), max_output_size.tfe_handle.get(), context::get_status());
13532 status_check(context::get_status());
13535 TFE_OpSetAttrFloat(op.get(),
"iou_threshold", iou_threshold);
13538 int num_outputs_op = 1;
13539 TFE_TensorHandle* res[1] = {
nullptr};
13540 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13541 status_check(context::get_status());
13542 return tensor(res[0]);
13545 inline tensor non_max_suppression_v2(
const tensor& boxes,
const tensor& scores,
const tensor& max_output_size,
13546 const tensor& iou_threshold, datatype T_threshold =
static_cast<datatype
>(1)) {
13548 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13549 TFE_NewOp(context::get_context(),
"NonMaxSuppressionV2", context::get_status()), &TFE_DeleteOp);
13550 status_check(context::get_status());
13554 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
13555 status_check(context::get_status());
13557 TFE_OpAddInput(op.get(), scores.tfe_handle.get(), context::get_status());
13558 status_check(context::get_status());
13560 TFE_OpAddInput(op.get(), max_output_size.tfe_handle.get(), context::get_status());
13561 status_check(context::get_status());
13563 TFE_OpAddInput(op.get(), iou_threshold.tfe_handle.get(), context::get_status());
13564 status_check(context::get_status());
13567 TFE_OpSetAttrType(op.get(),
"T_threshold", T_threshold);
13570 int num_outputs_op = 1;
13571 TFE_TensorHandle* res[1] = {
nullptr};
13572 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13573 status_check(context::get_status());
13574 return tensor(res[0]);
13577 inline tensor non_max_suppression_v3(
const tensor& boxes,
const tensor& scores,
const tensor& max_output_size,
13578 const tensor& iou_threshold,
const tensor& score_threshold,
13579 datatype T_threshold =
static_cast<datatype
>(1)) {
13581 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13582 TFE_NewOp(context::get_context(),
"NonMaxSuppressionV3", context::get_status()), &TFE_DeleteOp);
13583 status_check(context::get_status());
13587 TFE_OpAddInput(op.get(), boxes.tfe_handle.get(), context::get_status());
13588 status_check(context::get_status());
13590 TFE_OpAddInput(op.get(), scores.tfe_handle.get(), context::get_status());
13591 status_check(context::get_status());
13593 TFE_OpAddInput(op.get(), max_output_size.tfe_handle.get(), context::get_status());
13594 status_check(context::get_status());
13596 TFE_OpAddInput(op.get(), iou_threshold.tfe_handle.get(), context::get_status());
13597 status_check(context::get_status());
13599 TFE_OpAddInput(op.get(), score_threshold.tfe_handle.get(), context::get_status());
13600 status_check(context::get_status());
13603 TFE_OpSetAttrType(op.get(),
"T_threshold", T_threshold);
13606 int num_outputs_op = 1;
13607 TFE_TensorHandle* res[1] = {
nullptr};
13608 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13609 status_check(context::get_status());
13610 return tensor(res[0]);
13613 inline tensor non_max_suppression_with_overlaps(
const tensor& overlaps,
const tensor& scores,
13614 const tensor& max_output_size,
const tensor& overlap_threshold,
13615 const tensor& score_threshold) {
13617 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13618 TFE_NewOp(context::get_context(),
"NonMaxSuppressionWithOverlaps", context::get_status()), &TFE_DeleteOp);
13619 status_check(context::get_status());
13623 TFE_OpAddInput(op.get(), overlaps.tfe_handle.get(), context::get_status());
13624 status_check(context::get_status());
13626 TFE_OpAddInput(op.get(), scores.tfe_handle.get(), context::get_status());
13627 status_check(context::get_status());
13629 TFE_OpAddInput(op.get(), max_output_size.tfe_handle.get(), context::get_status());
13630 status_check(context::get_status());
13632 TFE_OpAddInput(op.get(), overlap_threshold.tfe_handle.get(), context::get_status());
13633 status_check(context::get_status());
13635 TFE_OpAddInput(op.get(), score_threshold.tfe_handle.get(), context::get_status());
13636 status_check(context::get_status());
13641 int num_outputs_op = 1;
13642 TFE_TensorHandle* res[1] = {
nullptr};
13643 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13644 status_check(context::get_status());
13645 return tensor(res[0]);
13648 inline tensor non_serializable_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
13649 const std::vector<std::vector<int64_t>>& output_shapes) {
13651 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13652 TFE_NewOp(context::get_context(),
"NonSerializableDataset", context::get_status()), &TFE_DeleteOp);
13653 status_check(context::get_status());
13657 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
13658 status_check(context::get_status());
13661 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
13662 static_cast<int>(output_types.size()));
13664 std::vector<const int64_t*> output_shapes_values;
13665 output_shapes_values.reserve(output_shapes.size());
13666 std::vector<int> output_shapes_ndims;
13667 output_shapes_ndims.reserve(output_shapes.size());
13668 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
13669 [](
const auto& v) { return v.data(); });
13670 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
13671 [](
const auto& v) { return static_cast<int>(v.size()); });
13672 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
13673 static_cast<int>(output_shapes.size()), context::get_status());
13674 status_check(context::get_status());
13677 int num_outputs_op = 1;
13678 TFE_TensorHandle* res[1] = {
nullptr};
13679 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13680 status_check(context::get_status());
13681 return tensor(res[0]);
13684 inline tensor not_equal(
const tensor& x,
const tensor& y,
bool incompatible_shape_error =
true) {
13686 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13687 TFE_NewOp(context::get_context(),
"NotEqual", context::get_status()), &TFE_DeleteOp);
13688 status_check(context::get_status());
13692 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
13693 status_check(context::get_status());
13695 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
13696 status_check(context::get_status());
13699 TFE_OpSetAttrBool(op.get(),
"incompatible_shape_error", (
unsigned char)incompatible_shape_error);
13702 int num_outputs_op = 1;
13703 TFE_TensorHandle* res[1] = {
nullptr};
13704 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13705 status_check(context::get_status());
13706 return tensor(res[0]);
13709 inline tensor nth_element(
const tensor& input,
const tensor& n,
bool reverse =
false) {
13711 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13712 TFE_NewOp(context::get_context(),
"NthElement", context::get_status()), &TFE_DeleteOp);
13713 status_check(context::get_status());
13717 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
13718 status_check(context::get_status());
13720 TFE_OpAddInput(op.get(), n.tfe_handle.get(), context::get_status());
13721 status_check(context::get_status());
13724 TFE_OpSetAttrBool(op.get(),
"reverse", (
unsigned char)reverse);
13727 int num_outputs_op = 1;
13728 TFE_TensorHandle* res[1] = {
nullptr};
13729 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13730 status_check(context::get_status());
13731 return tensor(res[0]);
13734 inline tensor one_hot(
const tensor& indices,
const tensor& depth,
const tensor& on_value,
const tensor& off_value,
13735 int64_t axis = -1, datatype TI =
static_cast<datatype
>(9)) {
13737 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13738 TFE_NewOp(context::get_context(),
"OneHot", context::get_status()), &TFE_DeleteOp);
13739 status_check(context::get_status());
13743 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
13744 status_check(context::get_status());
13746 TFE_OpAddInput(op.get(), depth.tfe_handle.get(), context::get_status());
13747 status_check(context::get_status());
13749 TFE_OpAddInput(op.get(), on_value.tfe_handle.get(), context::get_status());
13750 status_check(context::get_status());
13752 TFE_OpAddInput(op.get(), off_value.tfe_handle.get(), context::get_status());
13753 status_check(context::get_status());
13756 TFE_OpSetAttrInt(op.get(),
"axis", axis);
13757 TFE_OpSetAttrType(op.get(),
"TI", TI);
13760 int num_outputs_op = 1;
13761 TFE_TensorHandle* res[1] = {
nullptr};
13762 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13763 status_check(context::get_status());
13764 return tensor(res[0]);
13767 inline tensor ones_like(
const tensor& x) {
13769 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13770 TFE_NewOp(context::get_context(),
"OnesLike", context::get_status()), &TFE_DeleteOp);
13771 status_check(context::get_status());
13775 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
13776 status_check(context::get_status());
13781 int num_outputs_op = 1;
13782 TFE_TensorHandle* res[1] = {
nullptr};
13783 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13784 status_check(context::get_status());
13785 return tensor(res[0]);
13788 inline tensor optimize_dataset(
const tensor& input_dataset,
const tensor& optimizations,
13789 const std::vector<datatype>& output_types,
13790 const std::vector<std::vector<int64_t>>& output_shapes,
13791 const std::vector<std::string>& optimization_configs) {
13793 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13794 TFE_NewOp(context::get_context(),
"OptimizeDataset", context::get_status()), &TFE_DeleteOp);
13795 status_check(context::get_status());
13799 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
13800 status_check(context::get_status());
13802 TFE_OpAddInput(op.get(), optimizations.tfe_handle.get(), context::get_status());
13803 status_check(context::get_status());
13806 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
13807 static_cast<int>(output_types.size()));
13809 std::vector<const int64_t*> output_shapes_values;
13810 output_shapes_values.reserve(output_shapes.size());
13811 std::vector<int> output_shapes_ndims;
13812 output_shapes_ndims.reserve(output_shapes.size());
13813 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
13814 [](
const auto& v) { return v.data(); });
13815 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
13816 [](
const auto& v) { return static_cast<int>(v.size()); });
13817 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
13818 static_cast<int>(output_shapes.size()), context::get_status());
13819 status_check(context::get_status());
13821 std::vector<std::size_t> optimization_configs_sizes;
13822 optimization_configs_sizes.reserve(optimization_configs.size());
13823 std::transform(optimization_configs.begin(), optimization_configs.end(),
13824 std::back_inserter(optimization_configs_sizes), [](
const auto& s) { return s.size(); });
13825 TFE_OpSetAttrStringList(op.get(),
"optimization_configs",
13826 reinterpret_cast<const void* const*
>(optimization_configs.data()),
13827 optimization_configs_sizes.data(),
static_cast<int>(optimization_configs.size()));
13830 int num_outputs_op = 1;
13831 TFE_TensorHandle* res[1] = {
nullptr};
13832 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13833 status_check(context::get_status());
13834 return tensor(res[0]);
13837 inline tensor optional_from_value(
const std::vector<tensor>& components,
const std::vector<datatype>& Toutput_types) {
13839 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13840 TFE_NewOp(context::get_context(),
"OptionalFromValue", context::get_status()), &TFE_DeleteOp);
13841 status_check(context::get_status());
13845 std::vector<TFE_TensorHandle*> components_handles;
13846 components_handles.reserve(components.size());
13847 std::transform(components.begin(), components.end(), std::back_inserter(components_handles),
13848 [](
const auto& t) { return t.tfe_handle.get(); });
13849 TFE_OpAddInputList(op.get(), components_handles.data(),
static_cast<int>(components.size()), context::get_status());
13850 status_check(context::get_status());
13853 TFE_OpSetAttrTypeList(op.get(),
"Toutput_types",
reinterpret_cast<const enum TF_DataType*
>(Toutput_types.data()),
13854 static_cast<int>(Toutput_types.size()));
13857 int num_outputs_op = 1;
13858 TFE_TensorHandle* res[1] = {
nullptr};
13859 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13860 status_check(context::get_status());
13861 return tensor(res[0]);
13864 inline tensor optional_get_value(
const tensor& optional,
const std::vector<datatype>& output_types,
13865 const std::vector<std::vector<int64_t>>& output_shapes) {
13867 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13868 TFE_NewOp(context::get_context(),
"OptionalGetValue", context::get_status()), &TFE_DeleteOp);
13869 status_check(context::get_status());
13873 TFE_OpAddInput(op.get(), optional.tfe_handle.get(), context::get_status());
13874 status_check(context::get_status());
13877 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
13878 static_cast<int>(output_types.size()));
13880 std::vector<const int64_t*> output_shapes_values;
13881 output_shapes_values.reserve(output_shapes.size());
13882 std::vector<int> output_shapes_ndims;
13883 output_shapes_ndims.reserve(output_shapes.size());
13884 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
13885 [](
const auto& v) { return v.data(); });
13886 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
13887 [](
const auto& v) { return static_cast<int>(v.size()); });
13888 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
13889 static_cast<int>(output_shapes.size()), context::get_status());
13890 status_check(context::get_status());
13893 int num_outputs_op = 1;
13894 TFE_TensorHandle* res[1] = {
nullptr};
13895 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13896 status_check(context::get_status());
13897 return tensor(res[0]);
13900 inline tensor optional_has_value(
const tensor& optional) {
13902 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13903 TFE_NewOp(context::get_context(),
"OptionalHasValue", context::get_status()), &TFE_DeleteOp);
13904 status_check(context::get_status());
13908 TFE_OpAddInput(op.get(), optional.tfe_handle.get(), context::get_status());
13909 status_check(context::get_status());
13914 int num_outputs_op = 1;
13915 TFE_TensorHandle* res[1] = {
nullptr};
13916 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13917 status_check(context::get_status());
13918 return tensor(res[0]);
13921 inline tensor optional_none() {
13923 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13924 TFE_NewOp(context::get_context(),
"OptionalNone", context::get_status()), &TFE_DeleteOp);
13925 status_check(context::get_status());
13932 int num_outputs_op = 1;
13933 TFE_TensorHandle* res[1] = {
nullptr};
13934 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13935 status_check(context::get_status());
13936 return tensor(res[0]);
13939 inline tensor ordered_map_incomplete_size(
const std::vector<datatype>& dtypes, int64_t capacity = 0,
13940 int64_t memory_limit = 0,
const std::string& container =
"",
13941 const std::string& shared_name =
"") {
13943 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13944 TFE_NewOp(context::get_context(),
"OrderedMapIncompleteSize", context::get_status()), &TFE_DeleteOp);
13945 status_check(context::get_status());
13950 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
13951 static_cast<int>(dtypes.size()));
13952 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
13953 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
13954 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13955 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13958 int num_outputs_op = 1;
13959 TFE_TensorHandle* res[1] = {
nullptr};
13960 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13961 status_check(context::get_status());
13962 return tensor(res[0]);
13965 inline tensor ordered_map_peek(
const tensor& key,
const tensor& indices,
const std::vector<datatype>& dtypes,
13966 int64_t capacity = 0, int64_t memory_limit = 0,
const std::string& container =
"",
13967 const std::string& shared_name =
"") {
13969 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
13970 TFE_NewOp(context::get_context(),
"OrderedMapPeek", context::get_status()), &TFE_DeleteOp);
13971 status_check(context::get_status());
13975 TFE_OpAddInput(op.get(), key.tfe_handle.get(), context::get_status());
13976 status_check(context::get_status());
13978 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
13979 status_check(context::get_status());
13982 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
13983 static_cast<int>(dtypes.size()));
13984 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
13985 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
13986 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
13987 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
13990 int num_outputs_op = 1;
13991 TFE_TensorHandle* res[1] = {
nullptr};
13992 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
13993 status_check(context::get_status());
13994 return tensor(res[0]);
13997 inline tensor ordered_map_size(
const std::vector<datatype>& dtypes, int64_t capacity = 0, int64_t memory_limit = 0,
13998 const std::string& container =
"",
const std::string& shared_name =
"") {
14000 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14001 TFE_NewOp(context::get_context(),
"OrderedMapSize", context::get_status()), &TFE_DeleteOp);
14002 status_check(context::get_status());
14007 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
14008 static_cast<int>(dtypes.size()));
14009 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
14010 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
14011 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
14012 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
14015 int num_outputs_op = 1;
14016 TFE_TensorHandle* res[1] = {
nullptr};
14017 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14018 status_check(context::get_status());
14019 return tensor(res[0]);
14022 inline tensor ordered_map_unstage(
const tensor& key,
const tensor& indices,
const std::vector<datatype>& dtypes,
14023 int64_t capacity = 0, int64_t memory_limit = 0,
const std::string& container =
"",
14024 const std::string& shared_name =
"") {
14026 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14027 TFE_NewOp(context::get_context(),
"OrderedMapUnstage", context::get_status()), &TFE_DeleteOp);
14028 status_check(context::get_status());
14032 TFE_OpAddInput(op.get(), key.tfe_handle.get(), context::get_status());
14033 status_check(context::get_status());
14035 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
14036 status_check(context::get_status());
14039 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
14040 static_cast<int>(dtypes.size()));
14041 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
14042 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
14043 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
14044 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
14047 int num_outputs_op = 1;
14048 TFE_TensorHandle* res[1] = {
nullptr};
14049 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14050 status_check(context::get_status());
14051 return tensor(res[0]);
14054 inline tensor outfeed_dequeue(datatype dtype,
const std::vector<int64_t>& shape, int64_t device_ordinal = -1) {
14056 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14057 TFE_NewOp(context::get_context(),
"OutfeedDequeue", context::get_status()), &TFE_DeleteOp);
14058 status_check(context::get_status());
14063 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
14065 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
14066 status_check(context::get_status());
14068 TFE_OpSetAttrInt(op.get(),
"device_ordinal", device_ordinal);
14071 int num_outputs_op = 1;
14072 TFE_TensorHandle* res[1] = {
nullptr};
14073 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14074 status_check(context::get_status());
14075 return tensor(res[0]);
14078 inline tensor outfeed_dequeue_tuple(
const std::vector<datatype>& dtypes,
14079 const std::vector<std::vector<int64_t>>& shapes, int64_t device_ordinal = -1) {
14081 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14082 TFE_NewOp(context::get_context(),
"OutfeedDequeueTuple", context::get_status()), &TFE_DeleteOp);
14083 status_check(context::get_status());
14088 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
14089 static_cast<int>(dtypes.size()));
14091 std::vector<const int64_t*> shapes_values;
14092 shapes_values.reserve(shapes.size());
14093 std::vector<int> shapes_ndims;
14094 shapes_ndims.reserve(shapes.size());
14095 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
14096 [](
const auto& v) { return v.data(); });
14097 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
14098 [](
const auto& v) { return static_cast<int>(v.size()); });
14099 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
14100 context::get_status());
14101 status_check(context::get_status());
14103 TFE_OpSetAttrInt(op.get(),
"device_ordinal", device_ordinal);
14106 int num_outputs_op = 1;
14107 TFE_TensorHandle* res[1] = {
nullptr};
14108 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14109 status_check(context::get_status());
14110 return tensor(res[0]);
14113 inline tensor pack(
const std::vector<tensor>& values, int64_t axis = 0) {
14115 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Pack", context::get_status()),
14117 status_check(context::get_status());
14121 std::vector<TFE_TensorHandle*> values_handles;
14122 values_handles.reserve(values.size());
14123 std::transform(values.begin(), values.end(), std::back_inserter(values_handles),
14124 [](
const auto& t) { return t.tfe_handle.get(); });
14125 TFE_OpAddInputList(op.get(), values_handles.data(),
static_cast<int>(values.size()), context::get_status());
14126 status_check(context::get_status());
14129 TFE_OpSetAttrInt(op.get(),
"N", values.size());
14130 TFE_OpSetAttrInt(op.get(),
"axis", axis);
14133 int num_outputs_op = 1;
14134 TFE_TensorHandle* res[1] = {
nullptr};
14135 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14136 status_check(context::get_status());
14137 return tensor(res[0]);
14140 inline tensor pad(
const tensor& input,
const tensor& paddings, datatype Tpaddings =
static_cast<datatype
>(3)) {
14142 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Pad", context::get_status()),
14144 status_check(context::get_status());
14148 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
14149 status_check(context::get_status());
14151 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
14152 status_check(context::get_status());
14155 TFE_OpSetAttrType(op.get(),
"Tpaddings", Tpaddings);
14158 int num_outputs_op = 1;
14159 TFE_TensorHandle* res[1] = {
nullptr};
14160 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14161 status_check(context::get_status());
14162 return tensor(res[0]);
14165 inline tensor pad_v2(
const tensor& input,
const tensor& paddings,
const tensor& constant_values,
14166 datatype Tpaddings =
static_cast<datatype
>(3)) {
14168 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"PadV2", context::get_status()),
14170 status_check(context::get_status());
14174 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
14175 status_check(context::get_status());
14177 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
14178 status_check(context::get_status());
14180 TFE_OpAddInput(op.get(), constant_values.tfe_handle.get(), context::get_status());
14181 status_check(context::get_status());
14184 TFE_OpSetAttrType(op.get(),
"Tpaddings", Tpaddings);
14187 int num_outputs_op = 1;
14188 TFE_TensorHandle* res[1] = {
nullptr};
14189 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14190 status_check(context::get_status());
14191 return tensor(res[0]);
14194 inline tensor padded_batch_dataset(
const tensor& input_dataset,
const tensor& batch_size,
14195 const std::vector<tensor>& padded_shapes,
const std::vector<tensor>& padding_values,
14196 const std::vector<datatype>& Toutput_types,
14197 const std::vector<std::vector<int64_t>>& output_shapes) {
14199 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14200 TFE_NewOp(context::get_context(),
"PaddedBatchDataset", context::get_status()), &TFE_DeleteOp);
14201 status_check(context::get_status());
14205 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
14206 status_check(context::get_status());
14208 TFE_OpAddInput(op.get(), batch_size.tfe_handle.get(), context::get_status());
14209 status_check(context::get_status());
14211 std::vector<TFE_TensorHandle*> padded_shapes_handles;
14212 padded_shapes_handles.reserve(padded_shapes.size());
14213 std::transform(padded_shapes.begin(), padded_shapes.end(), std::back_inserter(padded_shapes_handles),
14214 [](
const auto& t) { return t.tfe_handle.get(); });
14215 TFE_OpAddInputList(op.get(), padded_shapes_handles.data(),
static_cast<int>(padded_shapes.size()),
14216 context::get_status());
14217 status_check(context::get_status());
14219 std::vector<TFE_TensorHandle*> padding_values_handles;
14220 padding_values_handles.reserve(padding_values.size());
14221 std::transform(padding_values.begin(), padding_values.end(), std::back_inserter(padding_values_handles),
14222 [](
const auto& t) { return t.tfe_handle.get(); });
14223 TFE_OpAddInputList(op.get(), padding_values_handles.data(),
static_cast<int>(padding_values.size()),
14224 context::get_status());
14225 status_check(context::get_status());
14228 TFE_OpSetAttrTypeList(op.get(),
"Toutput_types",
reinterpret_cast<const enum TF_DataType*
>(Toutput_types.data()),
14229 static_cast<int>(Toutput_types.size()));
14231 std::vector<const int64_t*> output_shapes_values;
14232 output_shapes_values.reserve(output_shapes.size());
14233 std::vector<int> output_shapes_ndims;
14234 output_shapes_ndims.reserve(output_shapes.size());
14235 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
14236 [](
const auto& v) { return v.data(); });
14237 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
14238 [](
const auto& v) { return static_cast<int>(v.size()); });
14239 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
14240 static_cast<int>(output_shapes.size()), context::get_status());
14241 status_check(context::get_status());
14243 TFE_OpSetAttrInt(op.get(),
"N", padded_shapes.size());
14246 int num_outputs_op = 1;
14247 TFE_TensorHandle* res[1] = {
nullptr};
14248 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14249 status_check(context::get_status());
14250 return tensor(res[0]);
14253 inline tensor padded_batch_dataset_v2(
const tensor& input_dataset,
const tensor& batch_size,
14254 const std::vector<tensor>& padded_shapes,
14255 const std::vector<tensor>& padding_values,
const tensor& drop_remainder,
14256 const std::vector<datatype>& Toutput_types,
14257 const std::vector<std::vector<int64_t>>& output_shapes,
14258 bool parallel_copy =
false) {
14260 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14261 TFE_NewOp(context::get_context(),
"PaddedBatchDatasetV2", context::get_status()), &TFE_DeleteOp);
14262 status_check(context::get_status());
14266 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
14267 status_check(context::get_status());
14269 TFE_OpAddInput(op.get(), batch_size.tfe_handle.get(), context::get_status());
14270 status_check(context::get_status());
14272 std::vector<TFE_TensorHandle*> padded_shapes_handles;
14273 padded_shapes_handles.reserve(padded_shapes.size());
14274 std::transform(padded_shapes.begin(), padded_shapes.end(), std::back_inserter(padded_shapes_handles),
14275 [](
const auto& t) { return t.tfe_handle.get(); });
14276 TFE_OpAddInputList(op.get(), padded_shapes_handles.data(),
static_cast<int>(padded_shapes.size()),
14277 context::get_status());
14278 status_check(context::get_status());
14280 std::vector<TFE_TensorHandle*> padding_values_handles;
14281 padding_values_handles.reserve(padding_values.size());
14282 std::transform(padding_values.begin(), padding_values.end(), std::back_inserter(padding_values_handles),
14283 [](
const auto& t) { return t.tfe_handle.get(); });
14284 TFE_OpAddInputList(op.get(), padding_values_handles.data(),
static_cast<int>(padding_values.size()),
14285 context::get_status());
14286 status_check(context::get_status());
14288 TFE_OpAddInput(op.get(), drop_remainder.tfe_handle.get(), context::get_status());
14289 status_check(context::get_status());
14292 TFE_OpSetAttrTypeList(op.get(),
"Toutput_types",
reinterpret_cast<const enum TF_DataType*
>(Toutput_types.data()),
14293 static_cast<int>(Toutput_types.size()));
14295 std::vector<const int64_t*> output_shapes_values;
14296 output_shapes_values.reserve(output_shapes.size());
14297 std::vector<int> output_shapes_ndims;
14298 output_shapes_ndims.reserve(output_shapes.size());
14299 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
14300 [](
const auto& v) { return v.data(); });
14301 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
14302 [](
const auto& v) { return static_cast<int>(v.size()); });
14303 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
14304 static_cast<int>(output_shapes.size()), context::get_status());
14305 status_check(context::get_status());
14307 TFE_OpSetAttrInt(op.get(),
"N", padded_shapes.size());
14308 TFE_OpSetAttrBool(op.get(),
"parallel_copy", (
unsigned char)parallel_copy);
14311 int num_outputs_op = 1;
14312 TFE_TensorHandle* res[1] = {
nullptr};
14313 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14314 status_check(context::get_status());
14315 return tensor(res[0]);
14318 inline tensor padding_f_i_f_o_queue(
const std::vector<datatype>& component_types,
14319 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
14320 const std::string& container =
"",
const std::string& shared_name =
"") {
14322 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14323 TFE_NewOp(context::get_context(),
"PaddingFIFOQueue", context::get_status()), &TFE_DeleteOp);
14324 status_check(context::get_status());
14329 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
14330 static_cast<int>(component_types.size()));
14332 std::vector<const int64_t*> shapes_values;
14333 shapes_values.reserve(shapes.size());
14334 std::vector<int> shapes_ndims;
14335 shapes_ndims.reserve(shapes.size());
14336 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
14337 [](
const auto& v) { return v.data(); });
14338 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
14339 [](
const auto& v) { return static_cast<int>(v.size()); });
14340 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
14341 context::get_status());
14342 status_check(context::get_status());
14344 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
14345 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
14346 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
14349 int num_outputs_op = 1;
14350 TFE_TensorHandle* res[1] = {
nullptr};
14351 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14352 status_check(context::get_status());
14353 return tensor(res[0]);
14356 inline tensor padding_f_i_f_o_queue_v2(
const std::vector<datatype>& component_types,
14357 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
14358 const std::string& container =
"",
const std::string& shared_name =
"") {
14360 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14361 TFE_NewOp(context::get_context(),
"PaddingFIFOQueueV2", context::get_status()), &TFE_DeleteOp);
14362 status_check(context::get_status());
14367 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
14368 static_cast<int>(component_types.size()));
14370 std::vector<const int64_t*> shapes_values;
14371 shapes_values.reserve(shapes.size());
14372 std::vector<int> shapes_ndims;
14373 shapes_ndims.reserve(shapes.size());
14374 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
14375 [](
const auto& v) { return v.data(); });
14376 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
14377 [](
const auto& v) { return static_cast<int>(v.size()); });
14378 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
14379 context::get_status());
14380 status_check(context::get_status());
14382 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
14383 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
14384 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
14387 int num_outputs_op = 1;
14388 TFE_TensorHandle* res[1] = {
nullptr};
14389 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14390 status_check(context::get_status());
14391 return tensor(res[0]);
14394 inline tensor parallel_concat(
const std::vector<tensor>& values,
const std::vector<int64_t>& shape) {
14396 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14397 TFE_NewOp(context::get_context(),
"ParallelConcat", context::get_status()), &TFE_DeleteOp);
14398 status_check(context::get_status());
14402 std::vector<TFE_TensorHandle*> values_handles;
14403 values_handles.reserve(values.size());
14404 std::transform(values.begin(), values.end(), std::back_inserter(values_handles),
14405 [](
const auto& t) { return t.tfe_handle.get(); });
14406 TFE_OpAddInputList(op.get(), values_handles.data(),
static_cast<int>(values.size()), context::get_status());
14407 status_check(context::get_status());
14410 TFE_OpSetAttrInt(op.get(),
"N", values.size());
14412 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
14413 status_check(context::get_status());
14416 int num_outputs_op = 1;
14417 TFE_TensorHandle* res[1] = {
nullptr};
14418 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14419 status_check(context::get_status());
14420 return tensor(res[0]);
14423 inline tensor parallel_dynamic_stitch(
const std::vector<tensor>& indices,
const std::vector<tensor>& data) {
14425 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14426 TFE_NewOp(context::get_context(),
"ParallelDynamicStitch", context::get_status()), &TFE_DeleteOp);
14427 status_check(context::get_status());
14431 std::vector<TFE_TensorHandle*> indices_handles;
14432 indices_handles.reserve(indices.size());
14433 std::transform(indices.begin(), indices.end(), std::back_inserter(indices_handles),
14434 [](
const auto& t) { return t.tfe_handle.get(); });
14435 TFE_OpAddInputList(op.get(), indices_handles.data(),
static_cast<int>(indices.size()), context::get_status());
14436 status_check(context::get_status());
14438 std::vector<TFE_TensorHandle*> data_handles;
14439 data_handles.reserve(data.size());
14440 std::transform(data.begin(), data.end(), std::back_inserter(data_handles),
14441 [](
const auto& t) { return t.tfe_handle.get(); });
14442 TFE_OpAddInputList(op.get(), data_handles.data(),
static_cast<int>(data.size()), context::get_status());
14443 status_check(context::get_status());
14446 TFE_OpSetAttrInt(op.get(),
"N", indices.size());
14449 int num_outputs_op = 1;
14450 TFE_TensorHandle* res[1] = {
nullptr};
14451 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14452 status_check(context::get_status());
14453 return tensor(res[0]);
14456 inline tensor parameterized_truncated_normal(
const tensor& shape,
const tensor& means,
const tensor& stdevs,
14457 const tensor& minvals,
const tensor& maxvals, datatype dtype,
14458 int64_t seed = 0, int64_t seed2 = 0) {
14460 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14461 TFE_NewOp(context::get_context(),
"ParameterizedTruncatedNormal", context::get_status()), &TFE_DeleteOp);
14462 status_check(context::get_status());
14466 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
14467 status_check(context::get_status());
14469 TFE_OpAddInput(op.get(), means.tfe_handle.get(), context::get_status());
14470 status_check(context::get_status());
14472 TFE_OpAddInput(op.get(), stdevs.tfe_handle.get(), context::get_status());
14473 status_check(context::get_status());
14475 TFE_OpAddInput(op.get(), minvals.tfe_handle.get(), context::get_status());
14476 status_check(context::get_status());
14478 TFE_OpAddInput(op.get(), maxvals.tfe_handle.get(), context::get_status());
14479 status_check(context::get_status());
14482 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
14483 TFE_OpSetAttrInt(op.get(),
"seed", seed);
14484 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
14487 int num_outputs_op = 1;
14488 TFE_TensorHandle* res[1] = {
nullptr};
14489 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14490 status_check(context::get_status());
14491 return tensor(res[0]);
14494 inline tensor parse_example_dataset(
14495 const tensor& input_dataset,
const tensor& num_parallel_calls,
const std::vector<tensor>& dense_defaults,
14496 const std::vector<std::string>& sparse_keys,
const std::vector<std::string>& dense_keys,
14497 const std::vector<datatype>& sparse_types,
const std::vector<datatype>& Tdense,
14498 const std::vector<std::vector<int64_t>>& dense_shapes,
const std::vector<datatype>& output_types,
14499 const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<std::string>& ragged_keys,
14500 const std::vector<datatype>& ragged_value_types,
const std::vector<datatype>& ragged_split_types,
14501 bool sloppy =
false) {
14503 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14504 TFE_NewOp(context::get_context(),
"ParseExampleDataset", context::get_status()), &TFE_DeleteOp);
14505 status_check(context::get_status());
14509 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
14510 status_check(context::get_status());
14512 TFE_OpAddInput(op.get(), num_parallel_calls.tfe_handle.get(), context::get_status());
14513 status_check(context::get_status());
14515 std::vector<TFE_TensorHandle*> dense_defaults_handles;
14516 dense_defaults_handles.reserve(dense_defaults.size());
14517 std::transform(dense_defaults.begin(), dense_defaults.end(), std::back_inserter(dense_defaults_handles),
14518 [](
const auto& t) { return t.tfe_handle.get(); });
14519 TFE_OpAddInputList(op.get(), dense_defaults_handles.data(),
static_cast<int>(dense_defaults.size()),
14520 context::get_status());
14521 status_check(context::get_status());
14525 std::vector<std::size_t> sparse_keys_sizes;
14526 sparse_keys_sizes.reserve(sparse_keys.size());
14527 std::transform(sparse_keys.begin(), sparse_keys.end(), std::back_inserter(sparse_keys_sizes),
14528 [](
const auto& s) { return s.size(); });
14529 TFE_OpSetAttrStringList(op.get(),
"sparse_keys",
reinterpret_cast<const void* const*
>(sparse_keys.data()),
14530 sparse_keys_sizes.data(),
static_cast<int>(sparse_keys.size()));
14532 std::vector<std::size_t> dense_keys_sizes;
14533 dense_keys_sizes.reserve(dense_keys.size());
14534 std::transform(dense_keys.begin(), dense_keys.end(), std::back_inserter(dense_keys_sizes),
14535 [](
const auto& s) { return s.size(); });
14536 TFE_OpSetAttrStringList(op.get(),
"dense_keys",
reinterpret_cast<const void* const*
>(dense_keys.data()),
14537 dense_keys_sizes.data(),
static_cast<int>(dense_keys.size()));
14539 TFE_OpSetAttrTypeList(op.get(),
"sparse_types",
reinterpret_cast<const enum TF_DataType*
>(sparse_types.data()),
14540 static_cast<int>(sparse_types.size()));
14541 TFE_OpSetAttrTypeList(op.get(),
"Tdense",
reinterpret_cast<const enum TF_DataType*
>(Tdense.data()),
14542 static_cast<int>(Tdense.size()));
14544 std::vector<const int64_t*> dense_shapes_values;
14545 dense_shapes_values.reserve(dense_shapes.size());
14546 std::vector<int> dense_shapes_ndims;
14547 dense_shapes_ndims.reserve(dense_shapes.size());
14548 std::transform(dense_shapes.begin(), dense_shapes.end(), std::back_inserter(dense_shapes_values),
14549 [](
const auto& v) { return v.data(); });
14550 std::transform(dense_shapes.begin(), dense_shapes.end(), std::back_inserter(dense_shapes_ndims),
14551 [](
const auto& v) { return static_cast<int>(v.size()); });
14552 TFE_OpSetAttrShapeList(op.get(),
"dense_shapes", dense_shapes_values.data(), dense_shapes_ndims.data(),
14553 static_cast<int>(dense_shapes.size()), context::get_status());
14554 status_check(context::get_status());
14556 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
14557 static_cast<int>(output_types.size()));
14559 std::vector<const int64_t*> output_shapes_values;
14560 output_shapes_values.reserve(output_shapes.size());
14561 std::vector<int> output_shapes_ndims;
14562 output_shapes_ndims.reserve(output_shapes.size());
14563 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
14564 [](
const auto& v) { return v.data(); });
14565 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
14566 [](
const auto& v) { return static_cast<int>(v.size()); });
14567 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
14568 static_cast<int>(output_shapes.size()), context::get_status());
14569 status_check(context::get_status());
14571 std::vector<std::size_t> ragged_keys_sizes;
14572 ragged_keys_sizes.reserve(ragged_keys.size());
14573 std::transform(ragged_keys.begin(), ragged_keys.end(), std::back_inserter(ragged_keys_sizes),
14574 [](
const auto& s) { return s.size(); });
14575 TFE_OpSetAttrStringList(op.get(),
"ragged_keys",
reinterpret_cast<const void* const*
>(ragged_keys.data()),
14576 ragged_keys_sizes.data(),
static_cast<int>(ragged_keys.size()));
14578 TFE_OpSetAttrTypeList(op.get(),
"ragged_value_types",
14579 reinterpret_cast<const enum TF_DataType*
>(ragged_value_types.data()),
14580 static_cast<int>(ragged_value_types.size()));
14581 TFE_OpSetAttrTypeList(op.get(),
"ragged_split_types",
14582 reinterpret_cast<const enum TF_DataType*
>(ragged_split_types.data()),
14583 static_cast<int>(ragged_split_types.size()));
14584 TFE_OpSetAttrBool(op.get(),
"sloppy", (
unsigned char)sloppy);
14587 int num_outputs_op = 1;
14588 TFE_TensorHandle* res[1] = {
nullptr};
14589 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14590 status_check(context::get_status());
14591 return tensor(res[0]);
14594 inline tensor parse_example_dataset_v2(
14595 const tensor& input_dataset,
const tensor& num_parallel_calls,
const std::vector<tensor>& dense_defaults,
14596 const std::vector<std::string>& sparse_keys,
const std::vector<std::string>& dense_keys,
14597 const std::vector<datatype>& sparse_types,
const std::vector<datatype>& Tdense,
14598 const std::vector<std::vector<int64_t>>& dense_shapes,
const std::vector<datatype>& output_types,
14599 const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<std::string>& ragged_keys,
14600 const std::vector<datatype>& ragged_value_types,
const std::vector<datatype>& ragged_split_types,
14601 const std::string& deterministic =
"default") {
14603 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14604 TFE_NewOp(context::get_context(),
"ParseExampleDatasetV2", context::get_status()), &TFE_DeleteOp);
14605 status_check(context::get_status());
14609 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
14610 status_check(context::get_status());
14612 TFE_OpAddInput(op.get(), num_parallel_calls.tfe_handle.get(), context::get_status());
14613 status_check(context::get_status());
14615 std::vector<TFE_TensorHandle*> dense_defaults_handles;
14616 dense_defaults_handles.reserve(dense_defaults.size());
14617 std::transform(dense_defaults.begin(), dense_defaults.end(), std::back_inserter(dense_defaults_handles),
14618 [](
const auto& t) { return t.tfe_handle.get(); });
14619 TFE_OpAddInputList(op.get(), dense_defaults_handles.data(),
static_cast<int>(dense_defaults.size()),
14620 context::get_status());
14621 status_check(context::get_status());
14625 std::vector<std::size_t> sparse_keys_sizes;
14626 sparse_keys_sizes.reserve(sparse_keys.size());
14627 std::transform(sparse_keys.begin(), sparse_keys.end(), std::back_inserter(sparse_keys_sizes),
14628 [](
const auto& s) { return s.size(); });
14629 TFE_OpSetAttrStringList(op.get(),
"sparse_keys",
reinterpret_cast<const void* const*
>(sparse_keys.data()),
14630 sparse_keys_sizes.data(),
static_cast<int>(sparse_keys.size()));
14632 std::vector<std::size_t> dense_keys_sizes;
14633 dense_keys_sizes.reserve(dense_keys.size());
14634 std::transform(dense_keys.begin(), dense_keys.end(), std::back_inserter(dense_keys_sizes),
14635 [](
const auto& s) { return s.size(); });
14636 TFE_OpSetAttrStringList(op.get(),
"dense_keys",
reinterpret_cast<const void* const*
>(dense_keys.data()),
14637 dense_keys_sizes.data(),
static_cast<int>(dense_keys.size()));
14639 TFE_OpSetAttrTypeList(op.get(),
"sparse_types",
reinterpret_cast<const enum TF_DataType*
>(sparse_types.data()),
14640 static_cast<int>(sparse_types.size()));
14641 TFE_OpSetAttrTypeList(op.get(),
"Tdense",
reinterpret_cast<const enum TF_DataType*
>(Tdense.data()),
14642 static_cast<int>(Tdense.size()));
14644 std::vector<const int64_t*> dense_shapes_values;
14645 dense_shapes_values.reserve(dense_shapes.size());
14646 std::vector<int> dense_shapes_ndims;
14647 dense_shapes_ndims.reserve(dense_shapes.size());
14648 std::transform(dense_shapes.begin(), dense_shapes.end(), std::back_inserter(dense_shapes_values),
14649 [](
const auto& v) { return v.data(); });
14650 std::transform(dense_shapes.begin(), dense_shapes.end(), std::back_inserter(dense_shapes_ndims),
14651 [](
const auto& v) { return static_cast<int>(v.size()); });
14652 TFE_OpSetAttrShapeList(op.get(),
"dense_shapes", dense_shapes_values.data(), dense_shapes_ndims.data(),
14653 static_cast<int>(dense_shapes.size()), context::get_status());
14654 status_check(context::get_status());
14656 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
14657 static_cast<int>(output_types.size()));
14659 std::vector<const int64_t*> output_shapes_values;
14660 output_shapes_values.reserve(output_shapes.size());
14661 std::vector<int> output_shapes_ndims;
14662 output_shapes_ndims.reserve(output_shapes.size());
14663 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
14664 [](
const auto& v) { return v.data(); });
14665 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
14666 [](
const auto& v) { return static_cast<int>(v.size()); });
14667 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
14668 static_cast<int>(output_shapes.size()), context::get_status());
14669 status_check(context::get_status());
14671 std::vector<std::size_t> ragged_keys_sizes;
14672 ragged_keys_sizes.reserve(ragged_keys.size());
14673 std::transform(ragged_keys.begin(), ragged_keys.end(), std::back_inserter(ragged_keys_sizes),
14674 [](
const auto& s) { return s.size(); });
14675 TFE_OpSetAttrStringList(op.get(),
"ragged_keys",
reinterpret_cast<const void* const*
>(ragged_keys.data()),
14676 ragged_keys_sizes.data(),
static_cast<int>(ragged_keys.size()));
14678 TFE_OpSetAttrTypeList(op.get(),
"ragged_value_types",
14679 reinterpret_cast<const enum TF_DataType*
>(ragged_value_types.data()),
14680 static_cast<int>(ragged_value_types.size()));
14681 TFE_OpSetAttrTypeList(op.get(),
"ragged_split_types",
14682 reinterpret_cast<const enum TF_DataType*
>(ragged_split_types.data()),
14683 static_cast<int>(ragged_split_types.size()));
14684 TFE_OpSetAttrString(op.get(),
"deterministic", (
void*)deterministic.c_str(), deterministic.size());
14687 int num_outputs_op = 1;
14688 TFE_TensorHandle* res[1] = {
nullptr};
14689 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14690 status_check(context::get_status());
14691 return tensor(res[0]);
14694 inline tensor parse_tensor(
const tensor& serialized, datatype out_type) {
14696 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14697 TFE_NewOp(context::get_context(),
"ParseTensor", context::get_status()), &TFE_DeleteOp);
14698 status_check(context::get_status());
14702 TFE_OpAddInput(op.get(), serialized.tfe_handle.get(), context::get_status());
14703 status_check(context::get_status());
14706 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
14709 int num_outputs_op = 1;
14710 TFE_TensorHandle* res[1] = {
nullptr};
14711 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14712 status_check(context::get_status());
14713 return tensor(res[0]);
14716 inline tensor placeholder(datatype dtype,
const std::vector<int64_t>& shape) {
14718 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14719 TFE_NewOp(context::get_context(),
"Placeholder", context::get_status()), &TFE_DeleteOp);
14720 status_check(context::get_status());
14725 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
14727 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
14728 status_check(context::get_status());
14731 int num_outputs_op = 1;
14732 TFE_TensorHandle* res[1] = {
nullptr};
14733 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14734 status_check(context::get_status());
14735 return tensor(res[0]);
14738 inline tensor placeholder_v2(datatype dtype,
const std::vector<int64_t>& shape) {
14740 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14741 TFE_NewOp(context::get_context(),
"PlaceholderV2", context::get_status()), &TFE_DeleteOp);
14742 status_check(context::get_status());
14747 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
14749 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
14750 status_check(context::get_status());
14753 int num_outputs_op = 1;
14754 TFE_TensorHandle* res[1] = {
nullptr};
14755 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14756 status_check(context::get_status());
14757 return tensor(res[0]);
14760 inline tensor placeholder_with_default(
const tensor& input, datatype dtype,
const std::vector<int64_t>& shape) {
14762 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14763 TFE_NewOp(context::get_context(),
"PlaceholderWithDefault", context::get_status()), &TFE_DeleteOp);
14764 status_check(context::get_status());
14768 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
14769 status_check(context::get_status());
14772 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
14774 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
14775 status_check(context::get_status());
14778 int num_outputs_op = 1;
14779 TFE_TensorHandle* res[1] = {
nullptr};
14780 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14781 status_check(context::get_status());
14782 return tensor(res[0]);
14785 inline tensor polygamma(
const tensor& a,
const tensor& x) {
14787 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14788 TFE_NewOp(context::get_context(),
"Polygamma", context::get_status()), &TFE_DeleteOp);
14789 status_check(context::get_status());
14793 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
14794 status_check(context::get_status());
14796 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
14797 status_check(context::get_status());
14802 int num_outputs_op = 1;
14803 TFE_TensorHandle* res[1] = {
nullptr};
14804 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14805 status_check(context::get_status());
14806 return tensor(res[0]);
14809 inline tensor population_count(
const tensor& x) {
14811 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14812 TFE_NewOp(context::get_context(),
"PopulationCount", context::get_status()), &TFE_DeleteOp);
14813 status_check(context::get_status());
14817 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
14818 status_check(context::get_status());
14823 int num_outputs_op = 1;
14824 TFE_TensorHandle* res[1] = {
nullptr};
14825 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14826 status_check(context::get_status());
14827 return tensor(res[0]);
14830 inline tensor pow(
const tensor& x,
const tensor& y) {
14832 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Pow", context::get_status()),
14834 status_check(context::get_status());
14838 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
14839 status_check(context::get_status());
14841 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
14842 status_check(context::get_status());
14847 int num_outputs_op = 1;
14848 TFE_TensorHandle* res[1] = {
nullptr};
14849 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14850 status_check(context::get_status());
14851 return tensor(res[0]);
14854 inline tensor prefetch_dataset(
const tensor& input_dataset,
const tensor& buffer_size,
14855 const std::vector<datatype>& output_types,
14856 const std::vector<std::vector<int64_t>>& output_shapes, int64_t slack_period = 0,
14857 bool legacy_autotune =
true) {
14859 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14860 TFE_NewOp(context::get_context(),
"PrefetchDataset", context::get_status()), &TFE_DeleteOp);
14861 status_check(context::get_status());
14865 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
14866 status_check(context::get_status());
14868 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
14869 status_check(context::get_status());
14872 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
14873 static_cast<int>(output_types.size()));
14875 std::vector<const int64_t*> output_shapes_values;
14876 output_shapes_values.reserve(output_shapes.size());
14877 std::vector<int> output_shapes_ndims;
14878 output_shapes_ndims.reserve(output_shapes.size());
14879 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
14880 [](
const auto& v) { return v.data(); });
14881 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
14882 [](
const auto& v) { return static_cast<int>(v.size()); });
14883 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
14884 static_cast<int>(output_shapes.size()), context::get_status());
14885 status_check(context::get_status());
14887 TFE_OpSetAttrInt(op.get(),
"slack_period", slack_period);
14888 TFE_OpSetAttrBool(op.get(),
"legacy_autotune", (
unsigned char)legacy_autotune);
14891 int num_outputs_op = 1;
14892 TFE_TensorHandle* res[1] = {
nullptr};
14893 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14894 status_check(context::get_status());
14895 return tensor(res[0]);
14898 inline tensor prelinearize(
const tensor& input, datatype dtype,
const std::vector<int64_t>& shape,
14899 const std::vector<int64_t>& layout) {
14901 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14902 TFE_NewOp(context::get_context(),
"Prelinearize", context::get_status()), &TFE_DeleteOp);
14903 status_check(context::get_status());
14907 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
14908 status_check(context::get_status());
14911 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
14913 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
14914 status_check(context::get_status());
14916 TFE_OpSetAttrIntList(op.get(),
"layout", layout.data(),
static_cast<int>(layout.size()));
14919 int num_outputs_op = 1;
14920 TFE_TensorHandle* res[1] = {
nullptr};
14921 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14922 status_check(context::get_status());
14923 return tensor(res[0]);
14926 inline tensor prelinearize_tuple(
const std::vector<tensor>& inputs,
const std::vector<datatype>& dtypes,
14927 const std::vector<std::vector<int64_t>>& shapes,
const std::vector<int64_t>& layouts) {
14929 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14930 TFE_NewOp(context::get_context(),
"PrelinearizeTuple", context::get_status()), &TFE_DeleteOp);
14931 status_check(context::get_status());
14935 std::vector<TFE_TensorHandle*> inputs_handles;
14936 inputs_handles.reserve(inputs.size());
14937 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
14938 [](
const auto& t) { return t.tfe_handle.get(); });
14939 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
14940 status_check(context::get_status());
14943 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
14944 static_cast<int>(dtypes.size()));
14946 std::vector<const int64_t*> shapes_values;
14947 shapes_values.reserve(shapes.size());
14948 std::vector<int> shapes_ndims;
14949 shapes_ndims.reserve(shapes.size());
14950 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
14951 [](
const auto& v) { return v.data(); });
14952 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
14953 [](
const auto& v) { return static_cast<int>(v.size()); });
14954 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
14955 context::get_status());
14956 status_check(context::get_status());
14958 TFE_OpSetAttrIntList(op.get(),
"layouts", layouts.data(),
static_cast<int>(layouts.size()));
14961 int num_outputs_op = 1;
14962 TFE_TensorHandle* res[1] = {
nullptr};
14963 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14964 status_check(context::get_status());
14965 return tensor(res[0]);
14968 inline tensor prevent_gradient(
const tensor& input,
const std::string& message =
"") {
14970 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
14971 TFE_NewOp(context::get_context(),
"PreventGradient", context::get_status()), &TFE_DeleteOp);
14972 status_check(context::get_status());
14976 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
14977 status_check(context::get_status());
14980 TFE_OpSetAttrString(op.get(),
"message", (
void*)message.c_str(), message.size());
14983 int num_outputs_op = 1;
14984 TFE_TensorHandle* res[1] = {
nullptr};
14985 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
14986 status_check(context::get_status());
14987 return tensor(res[0]);
14990 inline tensor print(
const tensor& input,
const std::vector<tensor>& data,
const std::vector<datatype>& U,
14991 const std::string& message =
"", int64_t first_n = -1, int64_t summarize = 3) {
14993 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Print", context::get_status()),
14995 status_check(context::get_status());
14999 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15000 status_check(context::get_status());
15002 std::vector<TFE_TensorHandle*> data_handles;
15003 data_handles.reserve(data.size());
15004 std::transform(data.begin(), data.end(), std::back_inserter(data_handles),
15005 [](
const auto& t) { return t.tfe_handle.get(); });
15006 TFE_OpAddInputList(op.get(), data_handles.data(),
static_cast<int>(data.size()), context::get_status());
15007 status_check(context::get_status());
15010 TFE_OpSetAttrTypeList(op.get(),
"U",
reinterpret_cast<const enum TF_DataType*
>(U.data()),
static_cast<int>(U.size()));
15011 TFE_OpSetAttrString(op.get(),
"message", (
void*)message.c_str(), message.size());
15012 TFE_OpSetAttrInt(op.get(),
"first_n", first_n);
15013 TFE_OpSetAttrInt(op.get(),
"summarize", summarize);
15016 int num_outputs_op = 1;
15017 TFE_TensorHandle* res[1] = {
nullptr};
15018 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15019 status_check(context::get_status());
15020 return tensor(res[0]);
15023 inline tensor priority_queue(
const std::vector<datatype>& component_types,
15024 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
15025 const std::string& container =
"",
const std::string& shared_name =
"") {
15027 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15028 TFE_NewOp(context::get_context(),
"PriorityQueue", context::get_status()), &TFE_DeleteOp);
15029 status_check(context::get_status());
15034 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15035 static_cast<int>(component_types.size()));
15037 std::vector<const int64_t*> shapes_values;
15038 shapes_values.reserve(shapes.size());
15039 std::vector<int> shapes_ndims;
15040 shapes_ndims.reserve(shapes.size());
15041 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
15042 [](
const auto& v) { return v.data(); });
15043 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
15044 [](
const auto& v) { return static_cast<int>(v.size()); });
15045 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
15046 context::get_status());
15047 status_check(context::get_status());
15049 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
15050 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
15051 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
15054 int num_outputs_op = 1;
15055 TFE_TensorHandle* res[1] = {
nullptr};
15056 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15057 status_check(context::get_status());
15058 return tensor(res[0]);
15061 inline tensor priority_queue_v2(
const std::vector<datatype>& component_types,
15062 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
15063 const std::string& container =
"",
const std::string& shared_name =
"") {
15065 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15066 TFE_NewOp(context::get_context(),
"PriorityQueueV2", context::get_status()), &TFE_DeleteOp);
15067 status_check(context::get_status());
15072 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15073 static_cast<int>(component_types.size()));
15075 std::vector<const int64_t*> shapes_values;
15076 shapes_values.reserve(shapes.size());
15077 std::vector<int> shapes_ndims;
15078 shapes_ndims.reserve(shapes.size());
15079 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
15080 [](
const auto& v) { return v.data(); });
15081 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
15082 [](
const auto& v) { return static_cast<int>(v.size()); });
15083 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
15084 context::get_status());
15085 status_check(context::get_status());
15087 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
15088 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
15089 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
15092 int num_outputs_op = 1;
15093 TFE_TensorHandle* res[1] = {
nullptr};
15094 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15095 status_check(context::get_status());
15096 return tensor(res[0]);
15099 inline tensor private_thread_pool_dataset(
const tensor& input_dataset,
const tensor& num_threads,
15100 const std::vector<datatype>& output_types,
15101 const std::vector<std::vector<int64_t>>& output_shapes) {
15103 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15104 TFE_NewOp(context::get_context(),
"PrivateThreadPoolDataset", context::get_status()), &TFE_DeleteOp);
15105 status_check(context::get_status());
15109 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
15110 status_check(context::get_status());
15112 TFE_OpAddInput(op.get(), num_threads.tfe_handle.get(), context::get_status());
15113 status_check(context::get_status());
15116 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
15117 static_cast<int>(output_types.size()));
15119 std::vector<const int64_t*> output_shapes_values;
15120 output_shapes_values.reserve(output_shapes.size());
15121 std::vector<int> output_shapes_ndims;
15122 output_shapes_ndims.reserve(output_shapes.size());
15123 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
15124 [](
const auto& v) { return v.data(); });
15125 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
15126 [](
const auto& v) { return static_cast<int>(v.size()); });
15127 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
15128 static_cast<int>(output_shapes.size()), context::get_status());
15129 status_check(context::get_status());
15132 int num_outputs_op = 1;
15133 TFE_TensorHandle* res[1] = {
nullptr};
15134 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15135 status_check(context::get_status());
15136 return tensor(res[0]);
15139 inline tensor prod(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
15140 datatype Tidx =
static_cast<datatype
>(3)) {
15142 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Prod", context::get_status()),
15144 status_check(context::get_status());
15148 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15149 status_check(context::get_status());
15151 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
15152 status_check(context::get_status());
15155 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
15156 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
15159 int num_outputs_op = 1;
15160 TFE_TensorHandle* res[1] = {
nullptr};
15161 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15162 status_check(context::get_status());
15163 return tensor(res[0]);
15166 inline tensor py_func(
const std::vector<tensor>& input,
const std::string& token,
const std::vector<datatype>& Tin,
15167 const std::vector<datatype>& Tout) {
15169 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15170 TFE_NewOp(context::get_context(),
"PyFunc", context::get_status()), &TFE_DeleteOp);
15171 status_check(context::get_status());
15175 std::vector<TFE_TensorHandle*> input_handles;
15176 input_handles.reserve(input.size());
15177 std::transform(input.begin(), input.end(), std::back_inserter(input_handles),
15178 [](
const auto& t) { return t.tfe_handle.get(); });
15179 TFE_OpAddInputList(op.get(), input_handles.data(),
static_cast<int>(input.size()), context::get_status());
15180 status_check(context::get_status());
15183 TFE_OpSetAttrString(op.get(),
"token", (
void*)token.c_str(), token.size());
15184 TFE_OpSetAttrTypeList(op.get(),
"Tin",
reinterpret_cast<const enum TF_DataType*
>(Tin.data()),
15185 static_cast<int>(Tin.size()));
15186 TFE_OpSetAttrTypeList(op.get(),
"Tout",
reinterpret_cast<const enum TF_DataType*
>(Tout.data()),
15187 static_cast<int>(Tout.size()));
15190 int num_outputs_op = 1;
15191 TFE_TensorHandle* res[1] = {
nullptr};
15192 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15193 status_check(context::get_status());
15194 return tensor(res[0]);
15197 inline tensor py_func_stateless(
const std::vector<tensor>& input,
const std::string& token,
15198 const std::vector<datatype>& Tin,
const std::vector<datatype>& Tout) {
15200 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15201 TFE_NewOp(context::get_context(),
"PyFuncStateless", context::get_status()), &TFE_DeleteOp);
15202 status_check(context::get_status());
15206 std::vector<TFE_TensorHandle*> input_handles;
15207 input_handles.reserve(input.size());
15208 std::transform(input.begin(), input.end(), std::back_inserter(input_handles),
15209 [](
const auto& t) { return t.tfe_handle.get(); });
15210 TFE_OpAddInputList(op.get(), input_handles.data(),
static_cast<int>(input.size()), context::get_status());
15211 status_check(context::get_status());
15214 TFE_OpSetAttrString(op.get(),
"token", (
void*)token.c_str(), token.size());
15215 TFE_OpSetAttrTypeList(op.get(),
"Tin",
reinterpret_cast<const enum TF_DataType*
>(Tin.data()),
15216 static_cast<int>(Tin.size()));
15217 TFE_OpSetAttrTypeList(op.get(),
"Tout",
reinterpret_cast<const enum TF_DataType*
>(Tout.data()),
15218 static_cast<int>(Tout.size()));
15221 int num_outputs_op = 1;
15222 TFE_TensorHandle* res[1] = {
nullptr};
15223 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15224 status_check(context::get_status());
15225 return tensor(res[0]);
15228 inline tensor quantize_and_dequantize(
const tensor& input,
bool signed_input =
true, int64_t num_bits = 8,
15229 bool range_given =
false,
float input_min = 0.0000e+00,
15230 float input_max = 0.0000e+00) {
15232 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15233 TFE_NewOp(context::get_context(),
"QuantizeAndDequantize", context::get_status()), &TFE_DeleteOp);
15234 status_check(context::get_status());
15238 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15239 status_check(context::get_status());
15242 TFE_OpSetAttrBool(op.get(),
"signed_input", (
unsigned char)signed_input);
15243 TFE_OpSetAttrInt(op.get(),
"num_bits", num_bits);
15244 TFE_OpSetAttrBool(op.get(),
"range_given", (
unsigned char)range_given);
15245 TFE_OpSetAttrFloat(op.get(),
"input_min", input_min);
15246 TFE_OpSetAttrFloat(op.get(),
"input_max", input_max);
15249 int num_outputs_op = 1;
15250 TFE_TensorHandle* res[1] = {
nullptr};
15251 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15252 status_check(context::get_status());
15253 return tensor(res[0]);
15256 inline tensor quantize_and_dequantize_v2(
const tensor& input,
const tensor& input_min,
const tensor& input_max,
15257 bool signed_input =
true, int64_t num_bits = 8,
bool range_given =
false,
15258 const std::string& round_mode =
"HALF_TO_EVEN",
bool narrow_range =
false,
15259 int64_t axis = -1) {
15261 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15262 TFE_NewOp(context::get_context(),
"QuantizeAndDequantizeV2", context::get_status()), &TFE_DeleteOp);
15263 status_check(context::get_status());
15267 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15268 status_check(context::get_status());
15270 TFE_OpAddInput(op.get(), input_min.tfe_handle.get(), context::get_status());
15271 status_check(context::get_status());
15273 TFE_OpAddInput(op.get(), input_max.tfe_handle.get(), context::get_status());
15274 status_check(context::get_status());
15277 TFE_OpSetAttrBool(op.get(),
"signed_input", (
unsigned char)signed_input);
15278 TFE_OpSetAttrInt(op.get(),
"num_bits", num_bits);
15279 TFE_OpSetAttrBool(op.get(),
"range_given", (
unsigned char)range_given);
15280 TFE_OpSetAttrString(op.get(),
"round_mode", (
void*)round_mode.c_str(), round_mode.size());
15281 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
15282 TFE_OpSetAttrInt(op.get(),
"axis", axis);
15285 int num_outputs_op = 1;
15286 TFE_TensorHandle* res[1] = {
nullptr};
15287 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15288 status_check(context::get_status());
15289 return tensor(res[0]);
15292 inline tensor quantize_and_dequantize_v3(
const tensor& input,
const tensor& input_min,
const tensor& input_max,
15293 const tensor& num_bits,
bool signed_input =
true,
bool range_given =
true,
15294 bool narrow_range =
false, int64_t axis = -1) {
15296 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15297 TFE_NewOp(context::get_context(),
"QuantizeAndDequantizeV3", context::get_status()), &TFE_DeleteOp);
15298 status_check(context::get_status());
15302 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15303 status_check(context::get_status());
15305 TFE_OpAddInput(op.get(), input_min.tfe_handle.get(), context::get_status());
15306 status_check(context::get_status());
15308 TFE_OpAddInput(op.get(), input_max.tfe_handle.get(), context::get_status());
15309 status_check(context::get_status());
15311 TFE_OpAddInput(op.get(), num_bits.tfe_handle.get(), context::get_status());
15312 status_check(context::get_status());
15315 TFE_OpSetAttrBool(op.get(),
"signed_input", (
unsigned char)signed_input);
15316 TFE_OpSetAttrBool(op.get(),
"range_given", (
unsigned char)range_given);
15317 TFE_OpSetAttrBool(op.get(),
"narrow_range", (
unsigned char)narrow_range);
15318 TFE_OpSetAttrInt(op.get(),
"axis", axis);
15321 int num_outputs_op = 1;
15322 TFE_TensorHandle* res[1] = {
nullptr};
15323 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15324 status_check(context::get_status());
15325 return tensor(res[0]);
15328 inline tensor quantized_mat_mul_with_bias_and_dequantize(
const tensor& a,
const tensor& b,
const tensor& bias,
15329 const tensor& min_a,
const tensor& max_a,
const tensor& min_b,
15330 const tensor& max_b,
const tensor& min_freezed_output,
15331 const tensor& max_freezed_output, datatype T1, datatype T2,
15332 datatype Tbias, datatype Toutput,
bool transpose_a =
false,
15333 bool transpose_b =
false,
15334 const std::string& input_quant_mode =
"MIN_FIRST") {
15336 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15337 TFE_NewOp(context::get_context(),
"QuantizedMatMulWithBiasAndDequantize", context::get_status()), &TFE_DeleteOp);
15338 status_check(context::get_status());
15342 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
15343 status_check(context::get_status());
15345 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
15346 status_check(context::get_status());
15348 TFE_OpAddInput(op.get(), bias.tfe_handle.get(), context::get_status());
15349 status_check(context::get_status());
15351 TFE_OpAddInput(op.get(), min_a.tfe_handle.get(), context::get_status());
15352 status_check(context::get_status());
15354 TFE_OpAddInput(op.get(), max_a.tfe_handle.get(), context::get_status());
15355 status_check(context::get_status());
15357 TFE_OpAddInput(op.get(), min_b.tfe_handle.get(), context::get_status());
15358 status_check(context::get_status());
15360 TFE_OpAddInput(op.get(), max_b.tfe_handle.get(), context::get_status());
15361 status_check(context::get_status());
15363 TFE_OpAddInput(op.get(), min_freezed_output.tfe_handle.get(), context::get_status());
15364 status_check(context::get_status());
15366 TFE_OpAddInput(op.get(), max_freezed_output.tfe_handle.get(), context::get_status());
15367 status_check(context::get_status());
15370 TFE_OpSetAttrType(op.get(),
"T1", T1);
15371 TFE_OpSetAttrType(op.get(),
"T2", T2);
15372 TFE_OpSetAttrType(op.get(),
"Tbias", Tbias);
15373 TFE_OpSetAttrType(op.get(),
"Toutput", Toutput);
15374 TFE_OpSetAttrBool(op.get(),
"transpose_a", (
unsigned char)transpose_a);
15375 TFE_OpSetAttrBool(op.get(),
"transpose_b", (
unsigned char)transpose_b);
15376 TFE_OpSetAttrString(op.get(),
"input_quant_mode", (
void*)input_quant_mode.c_str(), input_quant_mode.size());
15379 int num_outputs_op = 1;
15380 TFE_TensorHandle* res[1] = {
nullptr};
15381 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15382 status_check(context::get_status());
15383 return tensor(res[0]);
15386 inline tensor queue_dequeue(
const tensor& handle,
const std::vector<datatype>& component_types,
15387 int64_t timeout_ms = -1) {
15389 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15390 TFE_NewOp(context::get_context(),
"QueueDequeue", context::get_status()), &TFE_DeleteOp);
15391 status_check(context::get_status());
15395 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15396 status_check(context::get_status());
15399 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15400 static_cast<int>(component_types.size()));
15401 TFE_OpSetAttrInt(op.get(),
"timeout_ms", timeout_ms);
15404 int num_outputs_op = 1;
15405 TFE_TensorHandle* res[1] = {
nullptr};
15406 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15407 status_check(context::get_status());
15408 return tensor(res[0]);
15411 inline tensor queue_dequeue_many(
const tensor& handle,
const tensor& n,
const std::vector<datatype>& component_types,
15412 int64_t timeout_ms = -1) {
15414 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15415 TFE_NewOp(context::get_context(),
"QueueDequeueMany", context::get_status()), &TFE_DeleteOp);
15416 status_check(context::get_status());
15420 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15421 status_check(context::get_status());
15423 TFE_OpAddInput(op.get(), n.tfe_handle.get(), context::get_status());
15424 status_check(context::get_status());
15427 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15428 static_cast<int>(component_types.size()));
15429 TFE_OpSetAttrInt(op.get(),
"timeout_ms", timeout_ms);
15432 int num_outputs_op = 1;
15433 TFE_TensorHandle* res[1] = {
nullptr};
15434 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15435 status_check(context::get_status());
15436 return tensor(res[0]);
15439 inline tensor queue_dequeue_many_v2(
const tensor& handle,
const tensor& n,
const std::vector<datatype>& component_types,
15440 int64_t timeout_ms = -1) {
15442 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15443 TFE_NewOp(context::get_context(),
"QueueDequeueManyV2", context::get_status()), &TFE_DeleteOp);
15444 status_check(context::get_status());
15448 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15449 status_check(context::get_status());
15451 TFE_OpAddInput(op.get(), n.tfe_handle.get(), context::get_status());
15452 status_check(context::get_status());
15455 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15456 static_cast<int>(component_types.size()));
15457 TFE_OpSetAttrInt(op.get(),
"timeout_ms", timeout_ms);
15460 int num_outputs_op = 1;
15461 TFE_TensorHandle* res[1] = {
nullptr};
15462 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15463 status_check(context::get_status());
15464 return tensor(res[0]);
15467 inline tensor queue_dequeue_up_to(
const tensor& handle,
const tensor& n,
const std::vector<datatype>& component_types,
15468 int64_t timeout_ms = -1) {
15470 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15471 TFE_NewOp(context::get_context(),
"QueueDequeueUpTo", context::get_status()), &TFE_DeleteOp);
15472 status_check(context::get_status());
15476 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15477 status_check(context::get_status());
15479 TFE_OpAddInput(op.get(), n.tfe_handle.get(), context::get_status());
15480 status_check(context::get_status());
15483 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15484 static_cast<int>(component_types.size()));
15485 TFE_OpSetAttrInt(op.get(),
"timeout_ms", timeout_ms);
15488 int num_outputs_op = 1;
15489 TFE_TensorHandle* res[1] = {
nullptr};
15490 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15491 status_check(context::get_status());
15492 return tensor(res[0]);
15495 inline tensor queue_dequeue_up_to_v2(
const tensor& handle,
const tensor& n,
15496 const std::vector<datatype>& component_types, int64_t timeout_ms = -1) {
15498 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15499 TFE_NewOp(context::get_context(),
"QueueDequeueUpToV2", context::get_status()), &TFE_DeleteOp);
15500 status_check(context::get_status());
15504 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15505 status_check(context::get_status());
15507 TFE_OpAddInput(op.get(), n.tfe_handle.get(), context::get_status());
15508 status_check(context::get_status());
15511 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15512 static_cast<int>(component_types.size()));
15513 TFE_OpSetAttrInt(op.get(),
"timeout_ms", timeout_ms);
15516 int num_outputs_op = 1;
15517 TFE_TensorHandle* res[1] = {
nullptr};
15518 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15519 status_check(context::get_status());
15520 return tensor(res[0]);
15523 inline tensor queue_dequeue_v2(
const tensor& handle,
const std::vector<datatype>& component_types,
15524 int64_t timeout_ms = -1) {
15526 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15527 TFE_NewOp(context::get_context(),
"QueueDequeueV2", context::get_status()), &TFE_DeleteOp);
15528 status_check(context::get_status());
15532 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15533 status_check(context::get_status());
15536 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
15537 static_cast<int>(component_types.size()));
15538 TFE_OpSetAttrInt(op.get(),
"timeout_ms", timeout_ms);
15541 int num_outputs_op = 1;
15542 TFE_TensorHandle* res[1] = {
nullptr};
15543 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15544 status_check(context::get_status());
15545 return tensor(res[0]);
15548 inline tensor queue_is_closed(
const tensor& handle) {
15550 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15551 TFE_NewOp(context::get_context(),
"QueueIsClosed", context::get_status()), &TFE_DeleteOp);
15552 status_check(context::get_status());
15556 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15557 status_check(context::get_status());
15562 int num_outputs_op = 1;
15563 TFE_TensorHandle* res[1] = {
nullptr};
15564 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15565 status_check(context::get_status());
15566 return tensor(res[0]);
15569 inline tensor queue_is_closed_v2(
const tensor& handle) {
15571 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15572 TFE_NewOp(context::get_context(),
"QueueIsClosedV2", context::get_status()), &TFE_DeleteOp);
15573 status_check(context::get_status());
15577 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15578 status_check(context::get_status());
15583 int num_outputs_op = 1;
15584 TFE_TensorHandle* res[1] = {
nullptr};
15585 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15586 status_check(context::get_status());
15587 return tensor(res[0]);
15590 inline tensor queue_size(
const tensor& handle) {
15592 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15593 TFE_NewOp(context::get_context(),
"QueueSize", context::get_status()), &TFE_DeleteOp);
15594 status_check(context::get_status());
15598 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15599 status_check(context::get_status());
15604 int num_outputs_op = 1;
15605 TFE_TensorHandle* res[1] = {
nullptr};
15606 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15607 status_check(context::get_status());
15608 return tensor(res[0]);
15611 inline tensor queue_size_v2(
const tensor& handle) {
15613 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15614 TFE_NewOp(context::get_context(),
"QueueSizeV2", context::get_status()), &TFE_DeleteOp);
15615 status_check(context::get_status());
15619 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
15620 status_check(context::get_status());
15625 int num_outputs_op = 1;
15626 TFE_TensorHandle* res[1] = {
nullptr};
15627 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15628 status_check(context::get_status());
15629 return tensor(res[0]);
15632 inline tensor r_f_f_t(
const tensor& input,
const tensor& fft_length, datatype Treal =
static_cast<datatype
>(1),
15633 datatype Tcomplex =
static_cast<datatype
>(8)) {
15635 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"RFFT", context::get_status()),
15637 status_check(context::get_status());
15641 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15642 status_check(context::get_status());
15644 TFE_OpAddInput(op.get(), fft_length.tfe_handle.get(), context::get_status());
15645 status_check(context::get_status());
15648 TFE_OpSetAttrType(op.get(),
"Treal", Treal);
15649 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
15652 int num_outputs_op = 1;
15653 TFE_TensorHandle* res[1] = {
nullptr};
15654 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15655 status_check(context::get_status());
15656 return tensor(res[0]);
15659 inline tensor r_f_f_t2_d(
const tensor& input,
const tensor& fft_length, datatype Treal =
static_cast<datatype
>(1),
15660 datatype Tcomplex =
static_cast<datatype
>(8)) {
15662 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15663 TFE_NewOp(context::get_context(),
"RFFT2D", context::get_status()), &TFE_DeleteOp);
15664 status_check(context::get_status());
15668 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15669 status_check(context::get_status());
15671 TFE_OpAddInput(op.get(), fft_length.tfe_handle.get(), context::get_status());
15672 status_check(context::get_status());
15675 TFE_OpSetAttrType(op.get(),
"Treal", Treal);
15676 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
15679 int num_outputs_op = 1;
15680 TFE_TensorHandle* res[1] = {
nullptr};
15681 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15682 status_check(context::get_status());
15683 return tensor(res[0]);
15686 inline tensor r_f_f_t3_d(
const tensor& input,
const tensor& fft_length, datatype Treal =
static_cast<datatype
>(1),
15687 datatype Tcomplex =
static_cast<datatype
>(8)) {
15689 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15690 TFE_NewOp(context::get_context(),
"RFFT3D", context::get_status()), &TFE_DeleteOp);
15691 status_check(context::get_status());
15695 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
15696 status_check(context::get_status());
15698 TFE_OpAddInput(op.get(), fft_length.tfe_handle.get(), context::get_status());
15699 status_check(context::get_status());
15702 TFE_OpSetAttrType(op.get(),
"Treal", Treal);
15703 TFE_OpSetAttrType(op.get(),
"Tcomplex", Tcomplex);
15706 int num_outputs_op = 1;
15707 TFE_TensorHandle* res[1] = {
nullptr};
15708 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15709 status_check(context::get_status());
15710 return tensor(res[0]);
15713 inline tensor r_g_b_to_h_s_v(
const tensor& images) {
15715 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15716 TFE_NewOp(context::get_context(),
"RGBToHSV", context::get_status()), &TFE_DeleteOp);
15717 status_check(context::get_status());
15721 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
15722 status_check(context::get_status());
15727 int num_outputs_op = 1;
15728 TFE_TensorHandle* res[1] = {
nullptr};
15729 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15730 status_check(context::get_status());
15731 return tensor(res[0]);
15734 inline tensor ragged_bincount(
const tensor& splits,
const tensor& values,
const tensor& size,
const tensor& weights,
15735 datatype Tidx,
bool binary_output =
false) {
15737 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15738 TFE_NewOp(context::get_context(),
"RaggedBincount", context::get_status()), &TFE_DeleteOp);
15739 status_check(context::get_status());
15743 TFE_OpAddInput(op.get(), splits.tfe_handle.get(), context::get_status());
15744 status_check(context::get_status());
15746 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
15747 status_check(context::get_status());
15749 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
15750 status_check(context::get_status());
15752 TFE_OpAddInput(op.get(), weights.tfe_handle.get(), context::get_status());
15753 status_check(context::get_status());
15756 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
15757 TFE_OpSetAttrBool(op.get(),
"binary_output", (
unsigned char)binary_output);
15760 int num_outputs_op = 1;
15761 TFE_TensorHandle* res[1] = {
nullptr};
15762 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15763 status_check(context::get_status());
15764 return tensor(res[0]);
15767 inline tensor ragged_tensor_to_tensor(
const tensor& shape,
const tensor& values,
const tensor& default_value,
15768 const std::vector<tensor>& row_partition_tensors, datatype Tindex,
15769 datatype Tshape,
const std::vector<std::string>& row_partition_types) {
15771 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15772 TFE_NewOp(context::get_context(),
"RaggedTensorToTensor", context::get_status()), &TFE_DeleteOp);
15773 status_check(context::get_status());
15777 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
15778 status_check(context::get_status());
15780 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
15781 status_check(context::get_status());
15783 TFE_OpAddInput(op.get(), default_value.tfe_handle.get(), context::get_status());
15784 status_check(context::get_status());
15786 std::vector<TFE_TensorHandle*> row_partition_tensors_handles;
15787 row_partition_tensors_handles.reserve(row_partition_tensors.size());
15788 std::transform(row_partition_tensors.begin(), row_partition_tensors.end(),
15789 std::back_inserter(row_partition_tensors_handles), [](
const auto& t) { return t.tfe_handle.get(); });
15790 TFE_OpAddInputList(op.get(), row_partition_tensors_handles.data(),
static_cast<int>(row_partition_tensors.size()),
15791 context::get_status());
15792 status_check(context::get_status());
15795 TFE_OpSetAttrType(op.get(),
"Tindex", Tindex);
15796 TFE_OpSetAttrType(op.get(),
"Tshape", Tshape);
15797 TFE_OpSetAttrInt(op.get(),
"num_row_partition_tensors", row_partition_tensors.size());
15799 std::vector<std::size_t> row_partition_types_sizes;
15800 row_partition_types_sizes.reserve(row_partition_types.size());
15801 std::transform(row_partition_types.begin(), row_partition_types.end(), std::back_inserter(row_partition_types_sizes),
15802 [](
const auto& s) { return s.size(); });
15803 TFE_OpSetAttrStringList(op.get(),
"row_partition_types",
15804 reinterpret_cast<const void* const*
>(row_partition_types.data()),
15805 row_partition_types_sizes.data(),
static_cast<int>(row_partition_types.size()));
15808 int num_outputs_op = 1;
15809 TFE_TensorHandle* res[1] = {
nullptr};
15810 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15811 status_check(context::get_status());
15812 return tensor(res[0]);
15815 inline tensor ragged_tensor_to_variant(
const std::vector<tensor>& rt_nested_splits,
const tensor& rt_dense_values,
15816 datatype Tvalues,
bool batched_input,
15817 datatype Tsplits =
static_cast<datatype
>(9)) {
15819 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15820 TFE_NewOp(context::get_context(),
"RaggedTensorToVariant", context::get_status()), &TFE_DeleteOp);
15821 status_check(context::get_status());
15825 std::vector<TFE_TensorHandle*> rt_nested_splits_handles;
15826 rt_nested_splits_handles.reserve(rt_nested_splits.size());
15827 std::transform(rt_nested_splits.begin(), rt_nested_splits.end(), std::back_inserter(rt_nested_splits_handles),
15828 [](
const auto& t) { return t.tfe_handle.get(); });
15829 TFE_OpAddInputList(op.get(), rt_nested_splits_handles.data(),
static_cast<int>(rt_nested_splits.size()),
15830 context::get_status());
15831 status_check(context::get_status());
15833 TFE_OpAddInput(op.get(), rt_dense_values.tfe_handle.get(), context::get_status());
15834 status_check(context::get_status());
15837 TFE_OpSetAttrInt(op.get(),
"RAGGED_RANK", rt_nested_splits.size());
15838 TFE_OpSetAttrType(op.get(),
"Tvalues", Tvalues);
15839 TFE_OpSetAttrBool(op.get(),
"batched_input", (
unsigned char)batched_input);
15840 TFE_OpSetAttrType(op.get(),
"Tsplits", Tsplits);
15843 int num_outputs_op = 1;
15844 TFE_TensorHandle* res[1] = {
nullptr};
15845 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15846 status_check(context::get_status());
15847 return tensor(res[0]);
15850 inline tensor random_crop(
const tensor& image,
const tensor& size, int64_t seed = 0, int64_t seed2 = 0) {
15852 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15853 TFE_NewOp(context::get_context(),
"RandomCrop", context::get_status()), &TFE_DeleteOp);
15854 status_check(context::get_status());
15858 TFE_OpAddInput(op.get(), image.tfe_handle.get(), context::get_status());
15859 status_check(context::get_status());
15861 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
15862 status_check(context::get_status());
15865 TFE_OpSetAttrInt(op.get(),
"seed", seed);
15866 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
15869 int num_outputs_op = 1;
15870 TFE_TensorHandle* res[1] = {
nullptr};
15871 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15872 status_check(context::get_status());
15873 return tensor(res[0]);
15876 inline tensor random_dataset(
const tensor& seed,
const tensor& seed2,
const std::vector<datatype>& output_types,
15877 const std::vector<std::vector<int64_t>>& output_shapes) {
15879 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15880 TFE_NewOp(context::get_context(),
"RandomDataset", context::get_status()), &TFE_DeleteOp);
15881 status_check(context::get_status());
15885 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
15886 status_check(context::get_status());
15888 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
15889 status_check(context::get_status());
15892 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
15893 static_cast<int>(output_types.size()));
15895 std::vector<const int64_t*> output_shapes_values;
15896 output_shapes_values.reserve(output_shapes.size());
15897 std::vector<int> output_shapes_ndims;
15898 output_shapes_ndims.reserve(output_shapes.size());
15899 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
15900 [](
const auto& v) { return v.data(); });
15901 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
15902 [](
const auto& v) { return static_cast<int>(v.size()); });
15903 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
15904 static_cast<int>(output_shapes.size()), context::get_status());
15905 status_check(context::get_status());
15908 int num_outputs_op = 1;
15909 TFE_TensorHandle* res[1] = {
nullptr};
15910 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15911 status_check(context::get_status());
15912 return tensor(res[0]);
15915 inline tensor random_gamma(
const tensor& shape,
const tensor& alpha, datatype S, int64_t seed = 0, int64_t seed2 = 0) {
15917 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15918 TFE_NewOp(context::get_context(),
"RandomGamma", context::get_status()), &TFE_DeleteOp);
15919 status_check(context::get_status());
15923 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
15924 status_check(context::get_status());
15926 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
15927 status_check(context::get_status());
15930 TFE_OpSetAttrType(op.get(),
"S", S);
15931 TFE_OpSetAttrInt(op.get(),
"seed", seed);
15932 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
15935 int num_outputs_op = 1;
15936 TFE_TensorHandle* res[1] = {
nullptr};
15937 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15938 status_check(context::get_status());
15939 return tensor(res[0]);
15942 inline tensor random_gamma_grad(
const tensor& alpha,
const tensor& sample) {
15944 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15945 TFE_NewOp(context::get_context(),
"RandomGammaGrad", context::get_status()), &TFE_DeleteOp);
15946 status_check(context::get_status());
15950 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
15951 status_check(context::get_status());
15953 TFE_OpAddInput(op.get(), sample.tfe_handle.get(), context::get_status());
15954 status_check(context::get_status());
15959 int num_outputs_op = 1;
15960 TFE_TensorHandle* res[1] = {
nullptr};
15961 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15962 status_check(context::get_status());
15963 return tensor(res[0]);
15966 inline tensor random_poisson(
const tensor& shape,
const tensor& rate, datatype S, datatype dtype, int64_t seed = 0,
15967 int64_t seed2 = 0) {
15969 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
15970 TFE_NewOp(context::get_context(),
"RandomPoisson", context::get_status()), &TFE_DeleteOp);
15971 status_check(context::get_status());
15975 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
15976 status_check(context::get_status());
15978 TFE_OpAddInput(op.get(), rate.tfe_handle.get(), context::get_status());
15979 status_check(context::get_status());
15982 TFE_OpSetAttrType(op.get(),
"S", S);
15983 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
15984 TFE_OpSetAttrInt(op.get(),
"seed", seed);
15985 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
15988 int num_outputs_op = 1;
15989 TFE_TensorHandle* res[1] = {
nullptr};
15990 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
15991 status_check(context::get_status());
15992 return tensor(res[0]);
15995 inline tensor random_poisson_v2(
const tensor& shape,
const tensor& rate, datatype S, int64_t seed = 0,
15996 int64_t seed2 = 0, datatype R =
static_cast<datatype
>(2),
15997 datatype dtype =
static_cast<datatype
>(9)) {
15999 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16000 TFE_NewOp(context::get_context(),
"RandomPoissonV2", context::get_status()), &TFE_DeleteOp);
16001 status_check(context::get_status());
16005 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
16006 status_check(context::get_status());
16008 TFE_OpAddInput(op.get(), rate.tfe_handle.get(), context::get_status());
16009 status_check(context::get_status());
16012 TFE_OpSetAttrType(op.get(),
"S", S);
16013 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16014 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16015 TFE_OpSetAttrType(op.get(),
"R", R);
16016 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
16019 int num_outputs_op = 1;
16020 TFE_TensorHandle* res[1] = {
nullptr};
16021 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16022 status_check(context::get_status());
16023 return tensor(res[0]);
16026 inline tensor random_shuffle(
const tensor& value, int64_t seed = 0, int64_t seed2 = 0) {
16028 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16029 TFE_NewOp(context::get_context(),
"RandomShuffle", context::get_status()), &TFE_DeleteOp);
16030 status_check(context::get_status());
16034 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
16035 status_check(context::get_status());
16038 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16039 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16042 int num_outputs_op = 1;
16043 TFE_TensorHandle* res[1] = {
nullptr};
16044 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16045 status_check(context::get_status());
16046 return tensor(res[0]);
16049 inline tensor random_shuffle_queue(
const std::vector<datatype>& component_types,
16050 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
16051 int64_t min_after_dequeue = 0, int64_t seed = 0, int64_t seed2 = 0,
16052 const std::string& container =
"",
const std::string& shared_name =
"") {
16054 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16055 TFE_NewOp(context::get_context(),
"RandomShuffleQueue", context::get_status()), &TFE_DeleteOp);
16056 status_check(context::get_status());
16061 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
16062 static_cast<int>(component_types.size()));
16064 std::vector<const int64_t*> shapes_values;
16065 shapes_values.reserve(shapes.size());
16066 std::vector<int> shapes_ndims;
16067 shapes_ndims.reserve(shapes.size());
16068 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
16069 [](
const auto& v) { return v.data(); });
16070 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
16071 [](
const auto& v) { return static_cast<int>(v.size()); });
16072 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
16073 context::get_status());
16074 status_check(context::get_status());
16076 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
16077 TFE_OpSetAttrInt(op.get(),
"min_after_dequeue", min_after_dequeue);
16078 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16079 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16080 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
16081 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
16084 int num_outputs_op = 1;
16085 TFE_TensorHandle* res[1] = {
nullptr};
16086 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16087 status_check(context::get_status());
16088 return tensor(res[0]);
16091 inline tensor random_shuffle_queue_v2(
const std::vector<datatype>& component_types,
16092 const std::vector<std::vector<int64_t>>& shapes, int64_t capacity = -1,
16093 int64_t min_after_dequeue = 0, int64_t seed = 0, int64_t seed2 = 0,
16094 const std::string& container =
"",
const std::string& shared_name =
"") {
16096 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16097 TFE_NewOp(context::get_context(),
"RandomShuffleQueueV2", context::get_status()), &TFE_DeleteOp);
16098 status_check(context::get_status());
16103 TFE_OpSetAttrTypeList(op.get(),
"component_types",
reinterpret_cast<const enum TF_DataType*
>(component_types.data()),
16104 static_cast<int>(component_types.size()));
16106 std::vector<const int64_t*> shapes_values;
16107 shapes_values.reserve(shapes.size());
16108 std::vector<int> shapes_ndims;
16109 shapes_ndims.reserve(shapes.size());
16110 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_values),
16111 [](
const auto& v) { return v.data(); });
16112 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_ndims),
16113 [](
const auto& v) { return static_cast<int>(v.size()); });
16114 TFE_OpSetAttrShapeList(op.get(),
"shapes", shapes_values.data(), shapes_ndims.data(),
static_cast<int>(shapes.size()),
16115 context::get_status());
16116 status_check(context::get_status());
16118 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
16119 TFE_OpSetAttrInt(op.get(),
"min_after_dequeue", min_after_dequeue);
16120 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16121 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16122 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
16123 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
16126 int num_outputs_op = 1;
16127 TFE_TensorHandle* res[1] = {
nullptr};
16128 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16129 status_check(context::get_status());
16130 return tensor(res[0]);
16133 inline tensor random_standard_normal(
const tensor& shape, datatype dtype, int64_t seed = 0, int64_t seed2 = 0) {
16135 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16136 TFE_NewOp(context::get_context(),
"RandomStandardNormal", context::get_status()), &TFE_DeleteOp);
16137 status_check(context::get_status());
16141 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
16142 status_check(context::get_status());
16145 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
16146 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16147 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16150 int num_outputs_op = 1;
16151 TFE_TensorHandle* res[1] = {
nullptr};
16152 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16153 status_check(context::get_status());
16154 return tensor(res[0]);
16157 inline tensor random_uniform(
const tensor& shape, datatype dtype, int64_t seed = 0, int64_t seed2 = 0) {
16159 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16160 TFE_NewOp(context::get_context(),
"RandomUniform", context::get_status()), &TFE_DeleteOp);
16161 status_check(context::get_status());
16165 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
16166 status_check(context::get_status());
16169 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
16170 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16171 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16174 int num_outputs_op = 1;
16175 TFE_TensorHandle* res[1] = {
nullptr};
16176 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16177 status_check(context::get_status());
16178 return tensor(res[0]);
16181 inline tensor random_uniform_int(
const tensor& shape,
const tensor& minval,
const tensor& maxval, datatype Tout,
16182 int64_t seed = 0, int64_t seed2 = 0) {
16184 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16185 TFE_NewOp(context::get_context(),
"RandomUniformInt", context::get_status()), &TFE_DeleteOp);
16186 status_check(context::get_status());
16190 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
16191 status_check(context::get_status());
16193 TFE_OpAddInput(op.get(), minval.tfe_handle.get(), context::get_status());
16194 status_check(context::get_status());
16196 TFE_OpAddInput(op.get(), maxval.tfe_handle.get(), context::get_status());
16197 status_check(context::get_status());
16200 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
16201 TFE_OpSetAttrInt(op.get(),
"seed", seed);
16202 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
16205 int num_outputs_op = 1;
16206 TFE_TensorHandle* res[1] = {
nullptr};
16207 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16208 status_check(context::get_status());
16209 return tensor(res[0]);
16212 inline tensor range(
const tensor& start,
const tensor& limit,
const tensor& delta,
16213 datatype Tidx =
static_cast<datatype
>(3)) {
16215 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Range", context::get_status()),
16217 status_check(context::get_status());
16221 TFE_OpAddInput(op.get(), start.tfe_handle.get(), context::get_status());
16222 status_check(context::get_status());
16224 TFE_OpAddInput(op.get(), limit.tfe_handle.get(), context::get_status());
16225 status_check(context::get_status());
16227 TFE_OpAddInput(op.get(), delta.tfe_handle.get(), context::get_status());
16228 status_check(context::get_status());
16231 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
16234 int num_outputs_op = 1;
16235 TFE_TensorHandle* res[1] = {
nullptr};
16236 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16237 status_check(context::get_status());
16238 return tensor(res[0]);
16241 inline tensor range_dataset(
const tensor& start,
const tensor& stop,
const tensor& step,
16242 const std::vector<datatype>& output_types,
16243 const std::vector<std::vector<int64_t>>& output_shapes) {
16245 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16246 TFE_NewOp(context::get_context(),
"RangeDataset", context::get_status()), &TFE_DeleteOp);
16247 status_check(context::get_status());
16251 TFE_OpAddInput(op.get(), start.tfe_handle.get(), context::get_status());
16252 status_check(context::get_status());
16254 TFE_OpAddInput(op.get(), stop.tfe_handle.get(), context::get_status());
16255 status_check(context::get_status());
16257 TFE_OpAddInput(op.get(), step.tfe_handle.get(), context::get_status());
16258 status_check(context::get_status());
16261 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
16262 static_cast<int>(output_types.size()));
16264 std::vector<const int64_t*> output_shapes_values;
16265 output_shapes_values.reserve(output_shapes.size());
16266 std::vector<int> output_shapes_ndims;
16267 output_shapes_ndims.reserve(output_shapes.size());
16268 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
16269 [](
const auto& v) { return v.data(); });
16270 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
16271 [](
const auto& v) { return static_cast<int>(v.size()); });
16272 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
16273 static_cast<int>(output_shapes.size()), context::get_status());
16274 status_check(context::get_status());
16277 int num_outputs_op = 1;
16278 TFE_TensorHandle* res[1] = {
nullptr};
16279 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16280 status_check(context::get_status());
16281 return tensor(res[0]);
16284 inline tensor rank(
const tensor& input) {
16286 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Rank", context::get_status()),
16288 status_check(context::get_status());
16292 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
16293 status_check(context::get_status());
16298 int num_outputs_op = 1;
16299 TFE_TensorHandle* res[1] = {
nullptr};
16300 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16301 status_check(context::get_status());
16302 return tensor(res[0]);
16305 inline tensor read_file(
const tensor& filename) {
16307 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16308 TFE_NewOp(context::get_context(),
"ReadFile", context::get_status()), &TFE_DeleteOp);
16309 status_check(context::get_status());
16313 TFE_OpAddInput(op.get(), filename.tfe_handle.get(), context::get_status());
16314 status_check(context::get_status());
16319 int num_outputs_op = 1;
16320 TFE_TensorHandle* res[1] = {
nullptr};
16321 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16322 status_check(context::get_status());
16323 return tensor(res[0]);
16326 inline tensor read_variable_op(
const tensor& resource, datatype dtype) {
16328 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16329 TFE_NewOp(context::get_context(),
"ReadVariableOp", context::get_status()), &TFE_DeleteOp);
16330 status_check(context::get_status());
16334 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
16335 status_check(context::get_status());
16338 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
16341 int num_outputs_op = 1;
16342 TFE_TensorHandle* res[1] = {
nullptr};
16343 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16344 status_check(context::get_status());
16345 return tensor(res[0]);
16348 inline tensor reader_num_records_produced(
const tensor& reader_handle) {
16350 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16351 TFE_NewOp(context::get_context(),
"ReaderNumRecordsProduced", context::get_status()), &TFE_DeleteOp);
16352 status_check(context::get_status());
16356 TFE_OpAddInput(op.get(), reader_handle.tfe_handle.get(), context::get_status());
16357 status_check(context::get_status());
16362 int num_outputs_op = 1;
16363 TFE_TensorHandle* res[1] = {
nullptr};
16364 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16365 status_check(context::get_status());
16366 return tensor(res[0]);
16369 inline tensor reader_num_records_produced_v2(
const tensor& reader_handle) {
16371 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16372 TFE_NewOp(context::get_context(),
"ReaderNumRecordsProducedV2", context::get_status()), &TFE_DeleteOp);
16373 status_check(context::get_status());
16377 TFE_OpAddInput(op.get(), reader_handle.tfe_handle.get(), context::get_status());
16378 status_check(context::get_status());
16383 int num_outputs_op = 1;
16384 TFE_TensorHandle* res[1] = {
nullptr};
16385 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16386 status_check(context::get_status());
16387 return tensor(res[0]);
16390 inline tensor reader_num_work_units_completed(
const tensor& reader_handle) {
16392 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16393 TFE_NewOp(context::get_context(),
"ReaderNumWorkUnitsCompleted", context::get_status()), &TFE_DeleteOp);
16394 status_check(context::get_status());
16398 TFE_OpAddInput(op.get(), reader_handle.tfe_handle.get(), context::get_status());
16399 status_check(context::get_status());
16404 int num_outputs_op = 1;
16405 TFE_TensorHandle* res[1] = {
nullptr};
16406 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16407 status_check(context::get_status());
16408 return tensor(res[0]);
16411 inline tensor reader_num_work_units_completed_v2(
const tensor& reader_handle) {
16413 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16414 TFE_NewOp(context::get_context(),
"ReaderNumWorkUnitsCompletedV2", context::get_status()), &TFE_DeleteOp);
16415 status_check(context::get_status());
16419 TFE_OpAddInput(op.get(), reader_handle.tfe_handle.get(), context::get_status());
16420 status_check(context::get_status());
16425 int num_outputs_op = 1;
16426 TFE_TensorHandle* res[1] = {
nullptr};
16427 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16428 status_check(context::get_status());
16429 return tensor(res[0]);
16432 inline tensor reader_serialize_state(
const tensor& reader_handle) {
16434 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16435 TFE_NewOp(context::get_context(),
"ReaderSerializeState", context::get_status()), &TFE_DeleteOp);
16436 status_check(context::get_status());
16440 TFE_OpAddInput(op.get(), reader_handle.tfe_handle.get(), context::get_status());
16441 status_check(context::get_status());
16446 int num_outputs_op = 1;
16447 TFE_TensorHandle* res[1] = {
nullptr};
16448 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16449 status_check(context::get_status());
16450 return tensor(res[0]);
16453 inline tensor reader_serialize_state_v2(
const tensor& reader_handle) {
16455 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16456 TFE_NewOp(context::get_context(),
"ReaderSerializeStateV2", context::get_status()), &TFE_DeleteOp);
16457 status_check(context::get_status());
16461 TFE_OpAddInput(op.get(), reader_handle.tfe_handle.get(), context::get_status());
16462 status_check(context::get_status());
16467 int num_outputs_op = 1;
16468 TFE_TensorHandle* res[1] = {
nullptr};
16469 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16470 status_check(context::get_status());
16471 return tensor(res[0]);
16474 inline tensor real(
const tensor& input, datatype Tout =
static_cast<datatype
>(1)) {
16476 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Real", context::get_status()),
16478 status_check(context::get_status());
16482 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
16483 status_check(context::get_status());
16486 TFE_OpSetAttrType(op.get(),
"Tout", Tout);
16489 int num_outputs_op = 1;
16490 TFE_TensorHandle* res[1] = {
nullptr};
16491 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16492 status_check(context::get_status());
16493 return tensor(res[0]);
16496 inline tensor real_div(
const tensor& x,
const tensor& y) {
16498 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16499 TFE_NewOp(context::get_context(),
"RealDiv", context::get_status()), &TFE_DeleteOp);
16500 status_check(context::get_status());
16504 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
16505 status_check(context::get_status());
16507 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
16508 status_check(context::get_status());
16513 int num_outputs_op = 1;
16514 TFE_TensorHandle* res[1] = {
nullptr};
16515 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16516 status_check(context::get_status());
16517 return tensor(res[0]);
16520 inline tensor rebatch_dataset(
const tensor& input_dataset,
const tensor& num_replicas,
16521 const std::vector<datatype>& output_types,
16522 const std::vector<std::vector<int64_t>>& output_shapes,
bool use_fallback =
true) {
16524 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16525 TFE_NewOp(context::get_context(),
"RebatchDataset", context::get_status()), &TFE_DeleteOp);
16526 status_check(context::get_status());
16530 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
16531 status_check(context::get_status());
16533 TFE_OpAddInput(op.get(), num_replicas.tfe_handle.get(), context::get_status());
16534 status_check(context::get_status());
16537 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
16538 static_cast<int>(output_types.size()));
16540 std::vector<const int64_t*> output_shapes_values;
16541 output_shapes_values.reserve(output_shapes.size());
16542 std::vector<int> output_shapes_ndims;
16543 output_shapes_ndims.reserve(output_shapes.size());
16544 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
16545 [](
const auto& v) { return v.data(); });
16546 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
16547 [](
const auto& v) { return static_cast<int>(v.size()); });
16548 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
16549 static_cast<int>(output_shapes.size()), context::get_status());
16550 status_check(context::get_status());
16552 TFE_OpSetAttrBool(op.get(),
"use_fallback", (
unsigned char)use_fallback);
16555 int num_outputs_op = 1;
16556 TFE_TensorHandle* res[1] = {
nullptr};
16557 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16558 status_check(context::get_status());
16559 return tensor(res[0]);
16562 inline tensor reciprocal(
const tensor& x) {
16564 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16565 TFE_NewOp(context::get_context(),
"Reciprocal", context::get_status()), &TFE_DeleteOp);
16566 status_check(context::get_status());
16570 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
16571 status_check(context::get_status());
16576 int num_outputs_op = 1;
16577 TFE_TensorHandle* res[1] = {
nullptr};
16578 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16579 status_check(context::get_status());
16580 return tensor(res[0]);
16583 inline tensor reciprocal_grad(
const tensor& y,
const tensor& dy) {
16585 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16586 TFE_NewOp(context::get_context(),
"ReciprocalGrad", context::get_status()), &TFE_DeleteOp);
16587 status_check(context::get_status());
16591 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
16592 status_check(context::get_status());
16594 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
16595 status_check(context::get_status());
16600 int num_outputs_op = 1;
16601 TFE_TensorHandle* res[1] = {
nullptr};
16602 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16603 status_check(context::get_status());
16604 return tensor(res[0]);
16607 inline tensor record_input(
const std::string& file_pattern, int64_t file_random_seed = 301,
16608 float file_shuffle_shift_ratio = 0.0000e+00, int64_t file_buffer_size = 10000,
16609 int64_t file_parallelism = 16, int64_t batch_size = 32,
16610 const std::string& compression_type =
"") {
16612 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16613 TFE_NewOp(context::get_context(),
"RecordInput", context::get_status()), &TFE_DeleteOp);
16614 status_check(context::get_status());
16619 TFE_OpSetAttrString(op.get(),
"file_pattern", (
void*)file_pattern.c_str(), file_pattern.size());
16620 TFE_OpSetAttrInt(op.get(),
"file_random_seed", file_random_seed);
16621 TFE_OpSetAttrFloat(op.get(),
"file_shuffle_shift_ratio", file_shuffle_shift_ratio);
16622 TFE_OpSetAttrInt(op.get(),
"file_buffer_size", file_buffer_size);
16623 TFE_OpSetAttrInt(op.get(),
"file_parallelism", file_parallelism);
16624 TFE_OpSetAttrInt(op.get(),
"batch_size", batch_size);
16625 TFE_OpSetAttrString(op.get(),
"compression_type", (
void*)compression_type.c_str(), compression_type.size());
16628 int num_outputs_op = 1;
16629 TFE_TensorHandle* res[1] = {
nullptr};
16630 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16631 status_check(context::get_status());
16632 return tensor(res[0]);
16635 inline tensor recv(datatype tensor_type,
const std::string& tensor_name,
const std::string& send_device,
16636 int64_t send_device_incarnation,
const std::string& recv_device,
bool client_terminated =
false) {
16638 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Recv", context::get_status()),
16640 status_check(context::get_status());
16645 TFE_OpSetAttrType(op.get(),
"tensor_type", tensor_type);
16646 TFE_OpSetAttrString(op.get(),
"tensor_name", (
void*)tensor_name.c_str(), tensor_name.size());
16647 TFE_OpSetAttrString(op.get(),
"send_device", (
void*)send_device.c_str(), send_device.size());
16648 TFE_OpSetAttrInt(op.get(),
"send_device_incarnation", send_device_incarnation);
16649 TFE_OpSetAttrString(op.get(),
"recv_device", (
void*)recv_device.c_str(), recv_device.size());
16650 TFE_OpSetAttrBool(op.get(),
"client_terminated", (
unsigned char)client_terminated);
16653 int num_outputs_op = 1;
16654 TFE_TensorHandle* res[1] = {
nullptr};
16655 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16656 status_check(context::get_status());
16657 return tensor(res[0]);
16660 inline tensor recv_t_p_u_embedding_activations(int64_t num_outputs,
const std::string& config) {
16662 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16663 TFE_NewOp(context::get_context(),
"RecvTPUEmbeddingActivations", context::get_status()), &TFE_DeleteOp);
16664 status_check(context::get_status());
16669 TFE_OpSetAttrInt(op.get(),
"num_outputs", num_outputs);
16670 TFE_OpSetAttrString(op.get(),
"config", (
void*)config.c_str(), config.size());
16673 int num_outputs_op = 1;
16674 TFE_TensorHandle* res[1] = {
nullptr};
16675 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16676 status_check(context::get_status());
16677 return tensor(res[0]);
16680 inline tensor reduce_join(
const tensor& inputs,
const tensor& reduction_indices,
bool keep_dims =
false,
16681 const std::string& separator =
"") {
16683 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16684 TFE_NewOp(context::get_context(),
"ReduceJoin", context::get_status()), &TFE_DeleteOp);
16685 status_check(context::get_status());
16689 TFE_OpAddInput(op.get(), inputs.tfe_handle.get(), context::get_status());
16690 status_check(context::get_status());
16692 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
16693 status_check(context::get_status());
16696 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
16697 TFE_OpSetAttrString(op.get(),
"separator", (
void*)separator.c_str(), separator.size());
16700 int num_outputs_op = 1;
16701 TFE_TensorHandle* res[1] = {
nullptr};
16702 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16703 status_check(context::get_status());
16704 return tensor(res[0]);
16707 inline tensor ref_enter(
const tensor& data,
const std::string& frame_name,
bool is_constant =
false,
16708 int64_t parallel_iterations = 10) {
16710 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16711 TFE_NewOp(context::get_context(),
"RefEnter", context::get_status()), &TFE_DeleteOp);
16712 status_check(context::get_status());
16716 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
16717 status_check(context::get_status());
16720 TFE_OpSetAttrString(op.get(),
"frame_name", (
void*)frame_name.c_str(), frame_name.size());
16721 TFE_OpSetAttrBool(op.get(),
"is_constant", (
unsigned char)is_constant);
16722 TFE_OpSetAttrInt(op.get(),
"parallel_iterations", parallel_iterations);
16725 int num_outputs_op = 1;
16726 TFE_TensorHandle* res[1] = {
nullptr};
16727 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16728 status_check(context::get_status());
16729 return tensor(res[0]);
16732 inline tensor ref_exit(
const tensor& data) {
16734 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16735 TFE_NewOp(context::get_context(),
"RefExit", context::get_status()), &TFE_DeleteOp);
16736 status_check(context::get_status());
16740 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
16741 status_check(context::get_status());
16746 int num_outputs_op = 1;
16747 TFE_TensorHandle* res[1] = {
nullptr};
16748 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16749 status_check(context::get_status());
16750 return tensor(res[0]);
16753 inline tensor ref_identity(
const tensor& input) {
16755 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16756 TFE_NewOp(context::get_context(),
"RefIdentity", context::get_status()), &TFE_DeleteOp);
16757 status_check(context::get_status());
16761 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
16762 status_check(context::get_status());
16767 int num_outputs_op = 1;
16768 TFE_TensorHandle* res[1] = {
nullptr};
16769 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16770 status_check(context::get_status());
16771 return tensor(res[0]);
16774 inline tensor ref_next_iteration(
const tensor& data) {
16776 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16777 TFE_NewOp(context::get_context(),
"RefNextIteration", context::get_status()), &TFE_DeleteOp);
16778 status_check(context::get_status());
16782 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
16783 status_check(context::get_status());
16788 int num_outputs_op = 1;
16789 TFE_TensorHandle* res[1] = {
nullptr};
16790 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16791 status_check(context::get_status());
16792 return tensor(res[0]);
16795 inline tensor ref_select(
const tensor& index,
const std::vector<tensor>& inputs) {
16797 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16798 TFE_NewOp(context::get_context(),
"RefSelect", context::get_status()), &TFE_DeleteOp);
16799 status_check(context::get_status());
16803 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
16804 status_check(context::get_status());
16806 std::vector<TFE_TensorHandle*> inputs_handles;
16807 inputs_handles.reserve(inputs.size());
16808 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
16809 [](
const auto& t) { return t.tfe_handle.get(); });
16810 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
16811 status_check(context::get_status());
16814 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
16817 int num_outputs_op = 1;
16818 TFE_TensorHandle* res[1] = {
nullptr};
16819 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16820 status_check(context::get_status());
16821 return tensor(res[0]);
16824 inline tensor regex_full_match(
const tensor& input,
const tensor& pattern) {
16826 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16827 TFE_NewOp(context::get_context(),
"RegexFullMatch", context::get_status()), &TFE_DeleteOp);
16828 status_check(context::get_status());
16832 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
16833 status_check(context::get_status());
16835 TFE_OpAddInput(op.get(), pattern.tfe_handle.get(), context::get_status());
16836 status_check(context::get_status());
16841 int num_outputs_op = 1;
16842 TFE_TensorHandle* res[1] = {
nullptr};
16843 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16844 status_check(context::get_status());
16845 return tensor(res[0]);
16848 inline tensor regex_replace(
const tensor& input,
const tensor& pattern,
const tensor& rewrite,
16849 bool replace_global =
true) {
16851 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16852 TFE_NewOp(context::get_context(),
"RegexReplace", context::get_status()), &TFE_DeleteOp);
16853 status_check(context::get_status());
16857 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
16858 status_check(context::get_status());
16860 TFE_OpAddInput(op.get(), pattern.tfe_handle.get(), context::get_status());
16861 status_check(context::get_status());
16863 TFE_OpAddInput(op.get(), rewrite.tfe_handle.get(), context::get_status());
16864 status_check(context::get_status());
16867 TFE_OpSetAttrBool(op.get(),
"replace_global", (
unsigned char)replace_global);
16870 int num_outputs_op = 1;
16871 TFE_TensorHandle* res[1] = {
nullptr};
16872 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16873 status_check(context::get_status());
16874 return tensor(res[0]);
16877 inline tensor register_dataset(
const tensor& dataset,
const tensor& address,
const tensor& protocol,
16878 int64_t external_state_policy) {
16880 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16881 TFE_NewOp(context::get_context(),
"RegisterDataset", context::get_status()), &TFE_DeleteOp);
16882 status_check(context::get_status());
16886 TFE_OpAddInput(op.get(), dataset.tfe_handle.get(), context::get_status());
16887 status_check(context::get_status());
16889 TFE_OpAddInput(op.get(), address.tfe_handle.get(), context::get_status());
16890 status_check(context::get_status());
16892 TFE_OpAddInput(op.get(), protocol.tfe_handle.get(), context::get_status());
16893 status_check(context::get_status());
16896 TFE_OpSetAttrInt(op.get(),
"external_state_policy", external_state_policy);
16899 int num_outputs_op = 1;
16900 TFE_TensorHandle* res[1] = {
nullptr};
16901 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16902 status_check(context::get_status());
16903 return tensor(res[0]);
16906 inline tensor relu(
const tensor& features) {
16908 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Relu", context::get_status()),
16910 status_check(context::get_status());
16914 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
16915 status_check(context::get_status());
16920 int num_outputs_op = 1;
16921 TFE_TensorHandle* res[1] = {
nullptr};
16922 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16923 status_check(context::get_status());
16924 return tensor(res[0]);
16927 inline tensor relu6(
const tensor& features) {
16929 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Relu6", context::get_status()),
16931 status_check(context::get_status());
16935 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
16936 status_check(context::get_status());
16941 int num_outputs_op = 1;
16942 TFE_TensorHandle* res[1] = {
nullptr};
16943 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16944 status_check(context::get_status());
16945 return tensor(res[0]);
16948 inline tensor relu6_grad(
const tensor& gradients,
const tensor& features) {
16950 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16951 TFE_NewOp(context::get_context(),
"Relu6Grad", context::get_status()), &TFE_DeleteOp);
16952 status_check(context::get_status());
16956 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
16957 status_check(context::get_status());
16959 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
16960 status_check(context::get_status());
16965 int num_outputs_op = 1;
16966 TFE_TensorHandle* res[1] = {
nullptr};
16967 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16968 status_check(context::get_status());
16969 return tensor(res[0]);
16972 inline tensor relu_grad(
const tensor& gradients,
const tensor& features) {
16974 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
16975 TFE_NewOp(context::get_context(),
"ReluGrad", context::get_status()), &TFE_DeleteOp);
16976 status_check(context::get_status());
16980 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
16981 status_check(context::get_status());
16983 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
16984 status_check(context::get_status());
16989 int num_outputs_op = 1;
16990 TFE_TensorHandle* res[1] = {
nullptr};
16991 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
16992 status_check(context::get_status());
16993 return tensor(res[0]);
16996 inline tensor repeat_dataset(
const tensor& input_dataset,
const tensor& count,
16997 const std::vector<datatype>& output_types,
16998 const std::vector<std::vector<int64_t>>& output_shapes) {
17000 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17001 TFE_NewOp(context::get_context(),
"RepeatDataset", context::get_status()), &TFE_DeleteOp);
17002 status_check(context::get_status());
17006 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
17007 status_check(context::get_status());
17009 TFE_OpAddInput(op.get(), count.tfe_handle.get(), context::get_status());
17010 status_check(context::get_status());
17013 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
17014 static_cast<int>(output_types.size()));
17016 std::vector<const int64_t*> output_shapes_values;
17017 output_shapes_values.reserve(output_shapes.size());
17018 std::vector<int> output_shapes_ndims;
17019 output_shapes_ndims.reserve(output_shapes.size());
17020 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
17021 [](
const auto& v) { return v.data(); });
17022 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
17023 [](
const auto& v) { return static_cast<int>(v.size()); });
17024 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
17025 static_cast<int>(output_shapes.size()), context::get_status());
17026 status_check(context::get_status());
17029 int num_outputs_op = 1;
17030 TFE_TensorHandle* res[1] = {
nullptr};
17031 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17032 status_check(context::get_status());
17033 return tensor(res[0]);
17036 inline tensor reshape(
const tensor& input_tensor,
const tensor& shape, datatype Tshape =
static_cast<datatype
>(3)) {
17038 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17039 TFE_NewOp(context::get_context(),
"Reshape", context::get_status()), &TFE_DeleteOp);
17040 status_check(context::get_status());
17044 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
17045 status_check(context::get_status());
17047 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
17048 status_check(context::get_status());
17051 TFE_OpSetAttrType(op.get(),
"Tshape", Tshape);
17054 int num_outputs_op = 1;
17055 TFE_TensorHandle* res[1] = {
nullptr};
17056 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17057 status_check(context::get_status());
17058 return tensor(res[0]);
17061 inline tensor resize_area(
const tensor& images,
const tensor& size,
bool align_corners =
false) {
17063 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17064 TFE_NewOp(context::get_context(),
"ResizeArea", context::get_status()), &TFE_DeleteOp);
17065 status_check(context::get_status());
17069 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
17070 status_check(context::get_status());
17072 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
17073 status_check(context::get_status());
17076 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17079 int num_outputs_op = 1;
17080 TFE_TensorHandle* res[1] = {
nullptr};
17081 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17082 status_check(context::get_status());
17083 return tensor(res[0]);
17086 inline tensor resize_bicubic(
const tensor& images,
const tensor& size,
bool align_corners =
false,
17087 bool half_pixel_centers =
false) {
17089 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17090 TFE_NewOp(context::get_context(),
"ResizeBicubic", context::get_status()), &TFE_DeleteOp);
17091 status_check(context::get_status());
17095 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
17096 status_check(context::get_status());
17098 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
17099 status_check(context::get_status());
17102 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17103 TFE_OpSetAttrBool(op.get(),
"half_pixel_centers", (
unsigned char)half_pixel_centers);
17106 int num_outputs_op = 1;
17107 TFE_TensorHandle* res[1] = {
nullptr};
17108 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17109 status_check(context::get_status());
17110 return tensor(res[0]);
17113 inline tensor resize_bicubic_grad(
const tensor& grads,
const tensor& original_image,
bool align_corners =
false,
17114 bool half_pixel_centers =
false) {
17116 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17117 TFE_NewOp(context::get_context(),
"ResizeBicubicGrad", context::get_status()), &TFE_DeleteOp);
17118 status_check(context::get_status());
17122 TFE_OpAddInput(op.get(), grads.tfe_handle.get(), context::get_status());
17123 status_check(context::get_status());
17125 TFE_OpAddInput(op.get(), original_image.tfe_handle.get(), context::get_status());
17126 status_check(context::get_status());
17129 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17130 TFE_OpSetAttrBool(op.get(),
"half_pixel_centers", (
unsigned char)half_pixel_centers);
17133 int num_outputs_op = 1;
17134 TFE_TensorHandle* res[1] = {
nullptr};
17135 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17136 status_check(context::get_status());
17137 return tensor(res[0]);
17140 inline tensor resize_bilinear(
const tensor& images,
const tensor& size,
bool align_corners =
false,
17141 bool half_pixel_centers =
false) {
17143 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17144 TFE_NewOp(context::get_context(),
"ResizeBilinear", context::get_status()), &TFE_DeleteOp);
17145 status_check(context::get_status());
17149 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
17150 status_check(context::get_status());
17152 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
17153 status_check(context::get_status());
17156 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17157 TFE_OpSetAttrBool(op.get(),
"half_pixel_centers", (
unsigned char)half_pixel_centers);
17160 int num_outputs_op = 1;
17161 TFE_TensorHandle* res[1] = {
nullptr};
17162 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17163 status_check(context::get_status());
17164 return tensor(res[0]);
17167 inline tensor resize_bilinear_grad(
const tensor& grads,
const tensor& original_image,
bool align_corners =
false,
17168 bool half_pixel_centers =
false) {
17170 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17171 TFE_NewOp(context::get_context(),
"ResizeBilinearGrad", context::get_status()), &TFE_DeleteOp);
17172 status_check(context::get_status());
17176 TFE_OpAddInput(op.get(), grads.tfe_handle.get(), context::get_status());
17177 status_check(context::get_status());
17179 TFE_OpAddInput(op.get(), original_image.tfe_handle.get(), context::get_status());
17180 status_check(context::get_status());
17183 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17184 TFE_OpSetAttrBool(op.get(),
"half_pixel_centers", (
unsigned char)half_pixel_centers);
17187 int num_outputs_op = 1;
17188 TFE_TensorHandle* res[1] = {
nullptr};
17189 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17190 status_check(context::get_status());
17191 return tensor(res[0]);
17194 inline tensor resize_nearest_neighbor(
const tensor& images,
const tensor& size,
bool align_corners =
false,
17195 bool half_pixel_centers =
false) {
17197 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17198 TFE_NewOp(context::get_context(),
"ResizeNearestNeighbor", context::get_status()), &TFE_DeleteOp);
17199 status_check(context::get_status());
17203 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
17204 status_check(context::get_status());
17206 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
17207 status_check(context::get_status());
17210 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17211 TFE_OpSetAttrBool(op.get(),
"half_pixel_centers", (
unsigned char)half_pixel_centers);
17214 int num_outputs_op = 1;
17215 TFE_TensorHandle* res[1] = {
nullptr};
17216 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17217 status_check(context::get_status());
17218 return tensor(res[0]);
17221 inline tensor resize_nearest_neighbor_grad(
const tensor& grads,
const tensor& size,
bool align_corners =
false,
17222 bool half_pixel_centers =
false) {
17224 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17225 TFE_NewOp(context::get_context(),
"ResizeNearestNeighborGrad", context::get_status()), &TFE_DeleteOp);
17226 status_check(context::get_status());
17230 TFE_OpAddInput(op.get(), grads.tfe_handle.get(), context::get_status());
17231 status_check(context::get_status());
17233 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
17234 status_check(context::get_status());
17237 TFE_OpSetAttrBool(op.get(),
"align_corners", (
unsigned char)align_corners);
17238 TFE_OpSetAttrBool(op.get(),
"half_pixel_centers", (
unsigned char)half_pixel_centers);
17241 int num_outputs_op = 1;
17242 TFE_TensorHandle* res[1] = {
nullptr};
17243 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17244 status_check(context::get_status());
17245 return tensor(res[0]);
17248 inline tensor resource_accumulator_num_accumulated(
const tensor& handle) {
17250 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17251 TFE_NewOp(context::get_context(),
"ResourceAccumulatorNumAccumulated", context::get_status()), &TFE_DeleteOp);
17252 status_check(context::get_status());
17256 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
17257 status_check(context::get_status());
17262 int num_outputs_op = 1;
17263 TFE_TensorHandle* res[1] = {
nullptr};
17264 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17265 status_check(context::get_status());
17266 return tensor(res[0]);
17269 inline tensor resource_accumulator_take_gradient(
const tensor& handle,
const tensor& num_required, datatype dtype) {
17271 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17272 TFE_NewOp(context::get_context(),
"ResourceAccumulatorTakeGradient", context::get_status()), &TFE_DeleteOp);
17273 status_check(context::get_status());
17277 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
17278 status_check(context::get_status());
17280 TFE_OpAddInput(op.get(), num_required.tfe_handle.get(), context::get_status());
17281 status_check(context::get_status());
17284 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
17287 int num_outputs_op = 1;
17288 TFE_TensorHandle* res[1] = {
nullptr};
17289 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17290 status_check(context::get_status());
17291 return tensor(res[0]);
17294 inline tensor resource_conditional_accumulator(datatype dtype,
const std::vector<int64_t>& shape,
17295 const std::string& container =
"",
const std::string& shared_name =
"",
17296 const std::string& reduction_type =
"MEAN") {
17298 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17299 TFE_NewOp(context::get_context(),
"ResourceConditionalAccumulator", context::get_status()), &TFE_DeleteOp);
17300 status_check(context::get_status());
17305 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
17307 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
17308 status_check(context::get_status());
17310 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
17311 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
17312 TFE_OpSetAttrString(op.get(),
"reduction_type", (
void*)reduction_type.c_str(), reduction_type.size());
17315 int num_outputs_op = 1;
17316 TFE_TensorHandle* res[1] = {
nullptr};
17317 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17318 status_check(context::get_status());
17319 return tensor(res[0]);
17322 inline tensor resource_count_up_to(
const tensor& resource, int64_t limit) {
17324 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17325 TFE_NewOp(context::get_context(),
"ResourceCountUpTo", context::get_status()), &TFE_DeleteOp);
17326 status_check(context::get_status());
17330 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
17331 status_check(context::get_status());
17334 TFE_OpSetAttrInt(op.get(),
"limit", limit);
17337 int num_outputs_op = 1;
17338 TFE_TensorHandle* res[1] = {
nullptr};
17339 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17340 status_check(context::get_status());
17341 return tensor(res[0]);
17344 inline tensor resource_gather(
const tensor& resource,
const tensor& indices, datatype dtype, datatype Tindices,
17345 int64_t batch_dims = 0,
bool validate_indices =
true) {
17347 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17348 TFE_NewOp(context::get_context(),
"ResourceGather", context::get_status()), &TFE_DeleteOp);
17349 status_check(context::get_status());
17353 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
17354 status_check(context::get_status());
17356 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
17357 status_check(context::get_status());
17360 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
17361 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
17362 TFE_OpSetAttrInt(op.get(),
"batch_dims", batch_dims);
17363 TFE_OpSetAttrBool(op.get(),
"validate_indices", (
unsigned char)validate_indices);
17366 int num_outputs_op = 1;
17367 TFE_TensorHandle* res[1] = {
nullptr};
17368 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17369 status_check(context::get_status());
17370 return tensor(res[0]);
17373 inline tensor resource_gather_nd(
const tensor& resource,
const tensor& indices, datatype dtype, datatype Tindices) {
17375 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17376 TFE_NewOp(context::get_context(),
"ResourceGatherNd", context::get_status()), &TFE_DeleteOp);
17377 status_check(context::get_status());
17381 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
17382 status_check(context::get_status());
17384 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
17385 status_check(context::get_status());
17388 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
17389 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
17392 int num_outputs_op = 1;
17393 TFE_TensorHandle* res[1] = {
nullptr};
17394 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17395 status_check(context::get_status());
17396 return tensor(res[0]);
17399 inline tensor restore(
const tensor& file_pattern,
const tensor& input_tensor_name, datatype dt,
17400 int64_t preferred_shard = -1) {
17402 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17403 TFE_NewOp(context::get_context(),
"Restore", context::get_status()), &TFE_DeleteOp);
17404 status_check(context::get_status());
17408 TFE_OpAddInput(op.get(), file_pattern.tfe_handle.get(), context::get_status());
17409 status_check(context::get_status());
17411 TFE_OpAddInput(op.get(), input_tensor_name.tfe_handle.get(), context::get_status());
17412 status_check(context::get_status());
17415 TFE_OpSetAttrType(op.get(),
"dt", dt);
17416 TFE_OpSetAttrInt(op.get(),
"preferred_shard", preferred_shard);
17419 int num_outputs_op = 1;
17420 TFE_TensorHandle* res[1] = {
nullptr};
17421 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17422 status_check(context::get_status());
17423 return tensor(res[0]);
17426 inline tensor restore_slice(
const tensor& file_pattern,
const tensor& input_tensor_name,
const tensor& shape_and_slice,
17427 datatype dt, int64_t preferred_shard = -1) {
17429 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17430 TFE_NewOp(context::get_context(),
"RestoreSlice", context::get_status()), &TFE_DeleteOp);
17431 status_check(context::get_status());
17435 TFE_OpAddInput(op.get(), file_pattern.tfe_handle.get(), context::get_status());
17436 status_check(context::get_status());
17438 TFE_OpAddInput(op.get(), input_tensor_name.tfe_handle.get(), context::get_status());
17439 status_check(context::get_status());
17441 TFE_OpAddInput(op.get(), shape_and_slice.tfe_handle.get(), context::get_status());
17442 status_check(context::get_status());
17445 TFE_OpSetAttrType(op.get(),
"dt", dt);
17446 TFE_OpSetAttrInt(op.get(),
"preferred_shard", preferred_shard);
17449 int num_outputs_op = 1;
17450 TFE_TensorHandle* res[1] = {
nullptr};
17451 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17452 status_check(context::get_status());
17453 return tensor(res[0]);
17456 inline tensor restore_v2(
const tensor& prefix,
const tensor& input_tensor_names,
const tensor& shape_and_slices,
17457 const std::vector<datatype>& dtypes) {
17459 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17460 TFE_NewOp(context::get_context(),
"RestoreV2", context::get_status()), &TFE_DeleteOp);
17461 status_check(context::get_status());
17465 TFE_OpAddInput(op.get(), prefix.tfe_handle.get(), context::get_status());
17466 status_check(context::get_status());
17468 TFE_OpAddInput(op.get(), input_tensor_names.tfe_handle.get(), context::get_status());
17469 status_check(context::get_status());
17471 TFE_OpAddInput(op.get(), shape_and_slices.tfe_handle.get(), context::get_status());
17472 status_check(context::get_status());
17475 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
17476 static_cast<int>(dtypes.size()));
17479 int num_outputs_op = 1;
17480 TFE_TensorHandle* res[1] = {
nullptr};
17481 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17482 status_check(context::get_status());
17483 return tensor(res[0]);
17486 inline tensor retrieve_t_p_u_embedding_stochastic_gradient_descent_parameters(int64_t num_shards, int64_t shard_id,
17487 int64_t table_id = -1,
17488 const std::string& table_name =
"",
17489 const std::string& config =
"") {
17491 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17492 TFE_NewOp(context::get_context(),
"RetrieveTPUEmbeddingStochasticGradientDescentParameters",
17493 context::get_status()),
17495 status_check(context::get_status());
17500 TFE_OpSetAttrInt(op.get(),
"num_shards", num_shards);
17501 TFE_OpSetAttrInt(op.get(),
"shard_id", shard_id);
17502 TFE_OpSetAttrInt(op.get(),
"table_id", table_id);
17503 TFE_OpSetAttrString(op.get(),
"table_name", (
void*)table_name.c_str(), table_name.size());
17504 TFE_OpSetAttrString(op.get(),
"config", (
void*)config.c_str(), config.size());
17507 int num_outputs_op = 1;
17508 TFE_TensorHandle* res[1] = {
nullptr};
17509 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17510 status_check(context::get_status());
17511 return tensor(res[0]);
17514 inline tensor reverse(
const tensor& input_tensor,
const tensor& dims) {
17516 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17517 TFE_NewOp(context::get_context(),
"Reverse", context::get_status()), &TFE_DeleteOp);
17518 status_check(context::get_status());
17522 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
17523 status_check(context::get_status());
17525 TFE_OpAddInput(op.get(), dims.tfe_handle.get(), context::get_status());
17526 status_check(context::get_status());
17531 int num_outputs_op = 1;
17532 TFE_TensorHandle* res[1] = {
nullptr};
17533 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17534 status_check(context::get_status());
17535 return tensor(res[0]);
17538 inline tensor reverse_sequence(
const tensor& input,
const tensor& seq_lengths, int64_t seq_dim, int64_t batch_dim = 0,
17539 datatype Tlen =
static_cast<datatype
>(9)) {
17541 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17542 TFE_NewOp(context::get_context(),
"ReverseSequence", context::get_status()), &TFE_DeleteOp);
17543 status_check(context::get_status());
17547 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
17548 status_check(context::get_status());
17550 TFE_OpAddInput(op.get(), seq_lengths.tfe_handle.get(), context::get_status());
17551 status_check(context::get_status());
17554 TFE_OpSetAttrInt(op.get(),
"seq_dim", seq_dim);
17555 TFE_OpSetAttrInt(op.get(),
"batch_dim", batch_dim);
17556 TFE_OpSetAttrType(op.get(),
"Tlen", Tlen);
17559 int num_outputs_op = 1;
17560 TFE_TensorHandle* res[1] = {
nullptr};
17561 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17562 status_check(context::get_status());
17563 return tensor(res[0]);
17566 inline tensor reverse_v2(
const tensor& input_tensor,
const tensor& axis, datatype Tidx =
static_cast<datatype
>(3)) {
17568 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17569 TFE_NewOp(context::get_context(),
"ReverseV2", context::get_status()), &TFE_DeleteOp);
17570 status_check(context::get_status());
17574 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
17575 status_check(context::get_status());
17577 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
17578 status_check(context::get_status());
17581 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
17584 int num_outputs_op = 1;
17585 TFE_TensorHandle* res[1] = {
nullptr};
17586 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17587 status_check(context::get_status());
17588 return tensor(res[0]);
17591 inline tensor right_shift(
const tensor& x,
const tensor& y) {
17593 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17594 TFE_NewOp(context::get_context(),
"RightShift", context::get_status()), &TFE_DeleteOp);
17595 status_check(context::get_status());
17599 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
17600 status_check(context::get_status());
17602 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
17603 status_check(context::get_status());
17608 int num_outputs_op = 1;
17609 TFE_TensorHandle* res[1] = {
nullptr};
17610 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17611 status_check(context::get_status());
17612 return tensor(res[0]);
17615 inline tensor rint(
const tensor& x) {
17617 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Rint", context::get_status()),
17619 status_check(context::get_status());
17623 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
17624 status_check(context::get_status());
17629 int num_outputs_op = 1;
17630 TFE_TensorHandle* res[1] = {
nullptr};
17631 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17632 status_check(context::get_status());
17633 return tensor(res[0]);
17636 inline tensor roll(
const tensor& input,
const tensor& shift,
const tensor& axis, datatype Tshift, datatype Taxis) {
17638 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Roll", context::get_status()),
17640 status_check(context::get_status());
17644 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
17645 status_check(context::get_status());
17647 TFE_OpAddInput(op.get(), shift.tfe_handle.get(), context::get_status());
17648 status_check(context::get_status());
17650 TFE_OpAddInput(op.get(), axis.tfe_handle.get(), context::get_status());
17651 status_check(context::get_status());
17654 TFE_OpSetAttrType(op.get(),
"Tshift", Tshift);
17655 TFE_OpSetAttrType(op.get(),
"Taxis", Taxis);
17658 int num_outputs_op = 1;
17659 TFE_TensorHandle* res[1] = {
nullptr};
17660 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17661 status_check(context::get_status());
17662 return tensor(res[0]);
17665 inline tensor round(
const tensor& x) {
17667 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Round", context::get_status()),
17669 status_check(context::get_status());
17673 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
17674 status_check(context::get_status());
17679 int num_outputs_op = 1;
17680 TFE_TensorHandle* res[1] = {
nullptr};
17681 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17682 status_check(context::get_status());
17683 return tensor(res[0]);
17686 inline tensor rsqrt(
const tensor& x) {
17688 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Rsqrt", context::get_status()),
17690 status_check(context::get_status());
17694 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
17695 status_check(context::get_status());
17700 int num_outputs_op = 1;
17701 TFE_TensorHandle* res[1] = {
nullptr};
17702 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17703 status_check(context::get_status());
17704 return tensor(res[0]);
17707 inline tensor rsqrt_grad(
const tensor& y,
const tensor& dy) {
17709 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17710 TFE_NewOp(context::get_context(),
"RsqrtGrad", context::get_status()), &TFE_DeleteOp);
17711 status_check(context::get_status());
17715 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
17716 status_check(context::get_status());
17718 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
17719 status_check(context::get_status());
17724 int num_outputs_op = 1;
17725 TFE_TensorHandle* res[1] = {
nullptr};
17726 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17727 status_check(context::get_status());
17728 return tensor(res[0]);
17731 inline tensor sampling_dataset(
const tensor& input_dataset,
const tensor& rate,
const tensor& seed,
const tensor& seed2,
17732 const std::vector<datatype>& output_types,
17733 const std::vector<std::vector<int64_t>>& output_shapes) {
17735 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17736 TFE_NewOp(context::get_context(),
"SamplingDataset", context::get_status()), &TFE_DeleteOp);
17737 status_check(context::get_status());
17741 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
17742 status_check(context::get_status());
17744 TFE_OpAddInput(op.get(), rate.tfe_handle.get(), context::get_status());
17745 status_check(context::get_status());
17747 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
17748 status_check(context::get_status());
17750 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
17751 status_check(context::get_status());
17754 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
17755 static_cast<int>(output_types.size()));
17757 std::vector<const int64_t*> output_shapes_values;
17758 output_shapes_values.reserve(output_shapes.size());
17759 std::vector<int> output_shapes_ndims;
17760 output_shapes_ndims.reserve(output_shapes.size());
17761 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
17762 [](
const auto& v) { return v.data(); });
17763 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
17764 [](
const auto& v) { return static_cast<int>(v.size()); });
17765 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
17766 static_cast<int>(output_shapes.size()), context::get_status());
17767 status_check(context::get_status());
17770 int num_outputs_op = 1;
17771 TFE_TensorHandle* res[1] = {
nullptr};
17772 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17773 status_check(context::get_status());
17774 return tensor(res[0]);
17777 inline tensor scalar_summary(
const tensor& tags,
const tensor& values) {
17779 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17780 TFE_NewOp(context::get_context(),
"ScalarSummary", context::get_status()), &TFE_DeleteOp);
17781 status_check(context::get_status());
17785 TFE_OpAddInput(op.get(), tags.tfe_handle.get(), context::get_status());
17786 status_check(context::get_status());
17788 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
17789 status_check(context::get_status());
17794 int num_outputs_op = 1;
17795 TFE_TensorHandle* res[1] = {
nullptr};
17796 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17797 status_check(context::get_status());
17798 return tensor(res[0]);
17801 inline tensor scale_and_translate(
const tensor& images,
const tensor& size,
const tensor& scale,
17802 const tensor& translation,
const std::string& kernel_type =
"lanczos3",
17803 bool antialias =
true) {
17805 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17806 TFE_NewOp(context::get_context(),
"ScaleAndTranslate", context::get_status()), &TFE_DeleteOp);
17807 status_check(context::get_status());
17811 TFE_OpAddInput(op.get(), images.tfe_handle.get(), context::get_status());
17812 status_check(context::get_status());
17814 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
17815 status_check(context::get_status());
17817 TFE_OpAddInput(op.get(), scale.tfe_handle.get(), context::get_status());
17818 status_check(context::get_status());
17820 TFE_OpAddInput(op.get(), translation.tfe_handle.get(), context::get_status());
17821 status_check(context::get_status());
17824 TFE_OpSetAttrString(op.get(),
"kernel_type", (
void*)kernel_type.c_str(), kernel_type.size());
17825 TFE_OpSetAttrBool(op.get(),
"antialias", (
unsigned char)antialias);
17828 int num_outputs_op = 1;
17829 TFE_TensorHandle* res[1] = {
nullptr};
17830 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17831 status_check(context::get_status());
17832 return tensor(res[0]);
17835 inline tensor scale_and_translate_grad(
const tensor& grads,
const tensor& original_image,
const tensor& scale,
17836 const tensor& translation,
const std::string& kernel_type =
"lanczos3",
17837 bool antialias =
true) {
17839 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17840 TFE_NewOp(context::get_context(),
"ScaleAndTranslateGrad", context::get_status()), &TFE_DeleteOp);
17841 status_check(context::get_status());
17845 TFE_OpAddInput(op.get(), grads.tfe_handle.get(), context::get_status());
17846 status_check(context::get_status());
17848 TFE_OpAddInput(op.get(), original_image.tfe_handle.get(), context::get_status());
17849 status_check(context::get_status());
17851 TFE_OpAddInput(op.get(), scale.tfe_handle.get(), context::get_status());
17852 status_check(context::get_status());
17854 TFE_OpAddInput(op.get(), translation.tfe_handle.get(), context::get_status());
17855 status_check(context::get_status());
17858 TFE_OpSetAttrString(op.get(),
"kernel_type", (
void*)kernel_type.c_str(), kernel_type.size());
17859 TFE_OpSetAttrBool(op.get(),
"antialias", (
unsigned char)antialias);
17862 int num_outputs_op = 1;
17863 TFE_TensorHandle* res[1] = {
nullptr};
17864 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17865 status_check(context::get_status());
17866 return tensor(res[0]);
17869 inline tensor scatter_add(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
17870 bool use_locking =
false) {
17872 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17873 TFE_NewOp(context::get_context(),
"ScatterAdd", context::get_status()), &TFE_DeleteOp);
17874 status_check(context::get_status());
17878 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
17879 status_check(context::get_status());
17881 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
17882 status_check(context::get_status());
17884 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
17885 status_check(context::get_status());
17888 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
17889 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
17892 int num_outputs_op = 1;
17893 TFE_TensorHandle* res[1] = {
nullptr};
17894 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17895 status_check(context::get_status());
17896 return tensor(res[0]);
17899 inline tensor scatter_div(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
17900 bool use_locking =
false) {
17902 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17903 TFE_NewOp(context::get_context(),
"ScatterDiv", context::get_status()), &TFE_DeleteOp);
17904 status_check(context::get_status());
17908 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
17909 status_check(context::get_status());
17911 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
17912 status_check(context::get_status());
17914 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
17915 status_check(context::get_status());
17918 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
17919 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
17922 int num_outputs_op = 1;
17923 TFE_TensorHandle* res[1] = {
nullptr};
17924 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17925 status_check(context::get_status());
17926 return tensor(res[0]);
17929 inline tensor scatter_max(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
17930 bool use_locking =
false) {
17932 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17933 TFE_NewOp(context::get_context(),
"ScatterMax", context::get_status()), &TFE_DeleteOp);
17934 status_check(context::get_status());
17938 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
17939 status_check(context::get_status());
17941 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
17942 status_check(context::get_status());
17944 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
17945 status_check(context::get_status());
17948 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
17949 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
17952 int num_outputs_op = 1;
17953 TFE_TensorHandle* res[1] = {
nullptr};
17954 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17955 status_check(context::get_status());
17956 return tensor(res[0]);
17959 inline tensor scatter_min(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
17960 bool use_locking =
false) {
17962 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17963 TFE_NewOp(context::get_context(),
"ScatterMin", context::get_status()), &TFE_DeleteOp);
17964 status_check(context::get_status());
17968 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
17969 status_check(context::get_status());
17971 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
17972 status_check(context::get_status());
17974 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
17975 status_check(context::get_status());
17978 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
17979 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
17982 int num_outputs_op = 1;
17983 TFE_TensorHandle* res[1] = {
nullptr};
17984 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
17985 status_check(context::get_status());
17986 return tensor(res[0]);
17989 inline tensor scatter_mul(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
17990 bool use_locking =
false) {
17992 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
17993 TFE_NewOp(context::get_context(),
"ScatterMul", context::get_status()), &TFE_DeleteOp);
17994 status_check(context::get_status());
17998 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
17999 status_check(context::get_status());
18001 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18002 status_check(context::get_status());
18004 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18005 status_check(context::get_status());
18008 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18009 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18012 int num_outputs_op = 1;
18013 TFE_TensorHandle* res[1] = {
nullptr};
18014 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18015 status_check(context::get_status());
18016 return tensor(res[0]);
18019 inline tensor scatter_nd(
const tensor& indices,
const tensor& updates,
const tensor& shape, datatype Tindices) {
18021 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18022 TFE_NewOp(context::get_context(),
"ScatterNd", context::get_status()), &TFE_DeleteOp);
18023 status_check(context::get_status());
18027 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18028 status_check(context::get_status());
18030 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18031 status_check(context::get_status());
18033 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
18034 status_check(context::get_status());
18037 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18040 int num_outputs_op = 1;
18041 TFE_TensorHandle* res[1] = {
nullptr};
18042 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18043 status_check(context::get_status());
18044 return tensor(res[0]);
18047 inline tensor scatter_nd_add(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18048 bool use_locking =
false) {
18050 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18051 TFE_NewOp(context::get_context(),
"ScatterNdAdd", context::get_status()), &TFE_DeleteOp);
18052 status_check(context::get_status());
18056 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18057 status_check(context::get_status());
18059 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18060 status_check(context::get_status());
18062 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18063 status_check(context::get_status());
18066 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18067 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18070 int num_outputs_op = 1;
18071 TFE_TensorHandle* res[1] = {
nullptr};
18072 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18073 status_check(context::get_status());
18074 return tensor(res[0]);
18077 inline tensor scatter_nd_max(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18078 bool use_locking =
false) {
18080 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18081 TFE_NewOp(context::get_context(),
"ScatterNdMax", context::get_status()), &TFE_DeleteOp);
18082 status_check(context::get_status());
18086 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18087 status_check(context::get_status());
18089 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18090 status_check(context::get_status());
18092 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18093 status_check(context::get_status());
18096 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18097 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18100 int num_outputs_op = 1;
18101 TFE_TensorHandle* res[1] = {
nullptr};
18102 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18103 status_check(context::get_status());
18104 return tensor(res[0]);
18107 inline tensor scatter_nd_min(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18108 bool use_locking =
false) {
18110 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18111 TFE_NewOp(context::get_context(),
"ScatterNdMin", context::get_status()), &TFE_DeleteOp);
18112 status_check(context::get_status());
18116 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18117 status_check(context::get_status());
18119 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18120 status_check(context::get_status());
18122 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18123 status_check(context::get_status());
18126 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18127 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18130 int num_outputs_op = 1;
18131 TFE_TensorHandle* res[1] = {
nullptr};
18132 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18133 status_check(context::get_status());
18134 return tensor(res[0]);
18137 inline tensor scatter_nd_non_aliasing_add(
const tensor& input,
const tensor& indices,
const tensor& updates,
18138 datatype Tindices) {
18140 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18141 TFE_NewOp(context::get_context(),
"ScatterNdNonAliasingAdd", context::get_status()), &TFE_DeleteOp);
18142 status_check(context::get_status());
18146 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
18147 status_check(context::get_status());
18149 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18150 status_check(context::get_status());
18152 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18153 status_check(context::get_status());
18156 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18159 int num_outputs_op = 1;
18160 TFE_TensorHandle* res[1] = {
nullptr};
18161 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18162 status_check(context::get_status());
18163 return tensor(res[0]);
18166 inline tensor scatter_nd_sub(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18167 bool use_locking =
false) {
18169 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18170 TFE_NewOp(context::get_context(),
"ScatterNdSub", context::get_status()), &TFE_DeleteOp);
18171 status_check(context::get_status());
18175 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18176 status_check(context::get_status());
18178 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18179 status_check(context::get_status());
18181 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18182 status_check(context::get_status());
18185 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18186 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18189 int num_outputs_op = 1;
18190 TFE_TensorHandle* res[1] = {
nullptr};
18191 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18192 status_check(context::get_status());
18193 return tensor(res[0]);
18196 inline tensor scatter_nd_update(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18197 bool use_locking =
true) {
18199 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18200 TFE_NewOp(context::get_context(),
"ScatterNdUpdate", context::get_status()), &TFE_DeleteOp);
18201 status_check(context::get_status());
18205 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18206 status_check(context::get_status());
18208 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18209 status_check(context::get_status());
18211 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18212 status_check(context::get_status());
18215 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18216 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18219 int num_outputs_op = 1;
18220 TFE_TensorHandle* res[1] = {
nullptr};
18221 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18222 status_check(context::get_status());
18223 return tensor(res[0]);
18226 inline tensor scatter_sub(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18227 bool use_locking =
false) {
18229 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18230 TFE_NewOp(context::get_context(),
"ScatterSub", context::get_status()), &TFE_DeleteOp);
18231 status_check(context::get_status());
18235 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18236 status_check(context::get_status());
18238 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18239 status_check(context::get_status());
18241 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18242 status_check(context::get_status());
18245 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18246 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18249 int num_outputs_op = 1;
18250 TFE_TensorHandle* res[1] = {
nullptr};
18251 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18252 status_check(context::get_status());
18253 return tensor(res[0]);
18256 inline tensor scatter_update(
const tensor& ref,
const tensor& indices,
const tensor& updates, datatype Tindices,
18257 bool use_locking =
true) {
18259 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18260 TFE_NewOp(context::get_context(),
"ScatterUpdate", context::get_status()), &TFE_DeleteOp);
18261 status_check(context::get_status());
18265 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
18266 status_check(context::get_status());
18268 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
18269 status_check(context::get_status());
18271 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
18272 status_check(context::get_status());
18275 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18276 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
18279 int num_outputs_op = 1;
18280 TFE_TensorHandle* res[1] = {
nullptr};
18281 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18282 status_check(context::get_status());
18283 return tensor(res[0]);
18286 inline tensor sdca_fprint(
const tensor& input) {
18288 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18289 TFE_NewOp(context::get_context(),
"SdcaFprint", context::get_status()), &TFE_DeleteOp);
18290 status_check(context::get_status());
18294 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
18295 status_check(context::get_status());
18300 int num_outputs_op = 1;
18301 TFE_TensorHandle* res[1] = {
nullptr};
18302 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18303 status_check(context::get_status());
18304 return tensor(res[0]);
18307 inline tensor segment_max(
const tensor& data,
const tensor& segment_ids, datatype Tindices) {
18309 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18310 TFE_NewOp(context::get_context(),
"SegmentMax", context::get_status()), &TFE_DeleteOp);
18311 status_check(context::get_status());
18315 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
18316 status_check(context::get_status());
18318 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
18319 status_check(context::get_status());
18322 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18325 int num_outputs_op = 1;
18326 TFE_TensorHandle* res[1] = {
nullptr};
18327 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18328 status_check(context::get_status());
18329 return tensor(res[0]);
18332 inline tensor segment_mean(
const tensor& data,
const tensor& segment_ids, datatype Tindices) {
18334 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18335 TFE_NewOp(context::get_context(),
"SegmentMean", context::get_status()), &TFE_DeleteOp);
18336 status_check(context::get_status());
18340 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
18341 status_check(context::get_status());
18343 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
18344 status_check(context::get_status());
18347 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18350 int num_outputs_op = 1;
18351 TFE_TensorHandle* res[1] = {
nullptr};
18352 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18353 status_check(context::get_status());
18354 return tensor(res[0]);
18357 inline tensor segment_min(
const tensor& data,
const tensor& segment_ids, datatype Tindices) {
18359 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18360 TFE_NewOp(context::get_context(),
"SegmentMin", context::get_status()), &TFE_DeleteOp);
18361 status_check(context::get_status());
18365 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
18366 status_check(context::get_status());
18368 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
18369 status_check(context::get_status());
18372 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18375 int num_outputs_op = 1;
18376 TFE_TensorHandle* res[1] = {
nullptr};
18377 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18378 status_check(context::get_status());
18379 return tensor(res[0]);
18382 inline tensor segment_prod(
const tensor& data,
const tensor& segment_ids, datatype Tindices) {
18384 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18385 TFE_NewOp(context::get_context(),
"SegmentProd", context::get_status()), &TFE_DeleteOp);
18386 status_check(context::get_status());
18390 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
18391 status_check(context::get_status());
18393 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
18394 status_check(context::get_status());
18397 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18400 int num_outputs_op = 1;
18401 TFE_TensorHandle* res[1] = {
nullptr};
18402 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18403 status_check(context::get_status());
18404 return tensor(res[0]);
18407 inline tensor segment_sum(
const tensor& data,
const tensor& segment_ids, datatype Tindices) {
18409 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18410 TFE_NewOp(context::get_context(),
"SegmentSum", context::get_status()), &TFE_DeleteOp);
18411 status_check(context::get_status());
18415 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
18416 status_check(context::get_status());
18418 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
18419 status_check(context::get_status());
18422 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
18425 int num_outputs_op = 1;
18426 TFE_TensorHandle* res[1] = {
nullptr};
18427 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18428 status_check(context::get_status());
18429 return tensor(res[0]);
18432 inline tensor select(
const tensor& condition,
const tensor& t,
const tensor& e) {
18434 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18435 TFE_NewOp(context::get_context(),
"Select", context::get_status()), &TFE_DeleteOp);
18436 status_check(context::get_status());
18440 TFE_OpAddInput(op.get(), condition.tfe_handle.get(), context::get_status());
18441 status_check(context::get_status());
18443 TFE_OpAddInput(op.get(), t.tfe_handle.get(), context::get_status());
18444 status_check(context::get_status());
18446 TFE_OpAddInput(op.get(), e.tfe_handle.get(), context::get_status());
18447 status_check(context::get_status());
18452 int num_outputs_op = 1;
18453 TFE_TensorHandle* res[1] = {
nullptr};
18454 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18455 status_check(context::get_status());
18456 return tensor(res[0]);
18459 inline tensor select_v2(
const tensor& condition,
const tensor& t,
const tensor& e) {
18461 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18462 TFE_NewOp(context::get_context(),
"SelectV2", context::get_status()), &TFE_DeleteOp);
18463 status_check(context::get_status());
18467 TFE_OpAddInput(op.get(), condition.tfe_handle.get(), context::get_status());
18468 status_check(context::get_status());
18470 TFE_OpAddInput(op.get(), t.tfe_handle.get(), context::get_status());
18471 status_check(context::get_status());
18473 TFE_OpAddInput(op.get(), e.tfe_handle.get(), context::get_status());
18474 status_check(context::get_status());
18479 int num_outputs_op = 1;
18480 TFE_TensorHandle* res[1] = {
nullptr};
18481 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18482 status_check(context::get_status());
18483 return tensor(res[0]);
18486 inline tensor self_adjoint_eig(
const tensor& input) {
18488 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18489 TFE_NewOp(context::get_context(),
"SelfAdjointEig", context::get_status()), &TFE_DeleteOp);
18490 status_check(context::get_status());
18494 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
18495 status_check(context::get_status());
18500 int num_outputs_op = 1;
18501 TFE_TensorHandle* res[1] = {
nullptr};
18502 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18503 status_check(context::get_status());
18504 return tensor(res[0]);
18507 inline tensor selu(
const tensor& features) {
18509 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Selu", context::get_status()),
18511 status_check(context::get_status());
18515 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
18516 status_check(context::get_status());
18521 int num_outputs_op = 1;
18522 TFE_TensorHandle* res[1] = {
nullptr};
18523 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18524 status_check(context::get_status());
18525 return tensor(res[0]);
18528 inline tensor selu_grad(
const tensor& gradients,
const tensor& outputs) {
18530 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18531 TFE_NewOp(context::get_context(),
"SeluGrad", context::get_status()), &TFE_DeleteOp);
18532 status_check(context::get_status());
18536 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
18537 status_check(context::get_status());
18539 TFE_OpAddInput(op.get(), outputs.tfe_handle.get(), context::get_status());
18540 status_check(context::get_status());
18545 int num_outputs_op = 1;
18546 TFE_TensorHandle* res[1] = {
nullptr};
18547 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18548 status_check(context::get_status());
18549 return tensor(res[0]);
18552 inline tensor serialize_iterator(
const tensor& resource_handle, int64_t external_state_policy = 0) {
18554 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18555 TFE_NewOp(context::get_context(),
"SerializeIterator", context::get_status()), &TFE_DeleteOp);
18556 status_check(context::get_status());
18560 TFE_OpAddInput(op.get(), resource_handle.tfe_handle.get(), context::get_status());
18561 status_check(context::get_status());
18564 TFE_OpSetAttrInt(op.get(),
"external_state_policy", external_state_policy);
18567 int num_outputs_op = 1;
18568 TFE_TensorHandle* res[1] = {
nullptr};
18569 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18570 status_check(context::get_status());
18571 return tensor(res[0]);
18574 inline tensor serialize_many_sparse(
const tensor& sparse_indices,
const tensor& sparse_values,
18575 const tensor& sparse_shape, datatype out_type =
static_cast<datatype
>(7)) {
18577 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18578 TFE_NewOp(context::get_context(),
"SerializeManySparse", context::get_status()), &TFE_DeleteOp);
18579 status_check(context::get_status());
18583 TFE_OpAddInput(op.get(), sparse_indices.tfe_handle.get(), context::get_status());
18584 status_check(context::get_status());
18586 TFE_OpAddInput(op.get(), sparse_values.tfe_handle.get(), context::get_status());
18587 status_check(context::get_status());
18589 TFE_OpAddInput(op.get(), sparse_shape.tfe_handle.get(), context::get_status());
18590 status_check(context::get_status());
18593 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
18596 int num_outputs_op = 1;
18597 TFE_TensorHandle* res[1] = {
nullptr};
18598 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18599 status_check(context::get_status());
18600 return tensor(res[0]);
18603 inline tensor serialize_sparse(
const tensor& sparse_indices,
const tensor& sparse_values,
const tensor& sparse_shape,
18604 datatype out_type =
static_cast<datatype
>(7)) {
18606 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18607 TFE_NewOp(context::get_context(),
"SerializeSparse", context::get_status()), &TFE_DeleteOp);
18608 status_check(context::get_status());
18612 TFE_OpAddInput(op.get(), sparse_indices.tfe_handle.get(), context::get_status());
18613 status_check(context::get_status());
18615 TFE_OpAddInput(op.get(), sparse_values.tfe_handle.get(), context::get_status());
18616 status_check(context::get_status());
18618 TFE_OpAddInput(op.get(), sparse_shape.tfe_handle.get(), context::get_status());
18619 status_check(context::get_status());
18622 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
18625 int num_outputs_op = 1;
18626 TFE_TensorHandle* res[1] = {
nullptr};
18627 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18628 status_check(context::get_status());
18629 return tensor(res[0]);
18632 inline tensor serialize_tensor(
const tensor& input_tensor) {
18634 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18635 TFE_NewOp(context::get_context(),
"SerializeTensor", context::get_status()), &TFE_DeleteOp);
18636 status_check(context::get_status());
18640 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
18641 status_check(context::get_status());
18646 int num_outputs_op = 1;
18647 TFE_TensorHandle* res[1] = {
nullptr};
18648 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18649 status_check(context::get_status());
18650 return tensor(res[0]);
18653 inline tensor set_size(
const tensor& set_indices,
const tensor& set_values,
const tensor& set_shape,
18654 bool validate_indices =
true) {
18656 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18657 TFE_NewOp(context::get_context(),
"SetSize", context::get_status()), &TFE_DeleteOp);
18658 status_check(context::get_status());
18662 TFE_OpAddInput(op.get(), set_indices.tfe_handle.get(), context::get_status());
18663 status_check(context::get_status());
18665 TFE_OpAddInput(op.get(), set_values.tfe_handle.get(), context::get_status());
18666 status_check(context::get_status());
18668 TFE_OpAddInput(op.get(), set_shape.tfe_handle.get(), context::get_status());
18669 status_check(context::get_status());
18672 TFE_OpSetAttrBool(op.get(),
"validate_indices", (
unsigned char)validate_indices);
18675 int num_outputs_op = 1;
18676 TFE_TensorHandle* res[1] = {
nullptr};
18677 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18678 status_check(context::get_status());
18679 return tensor(res[0]);
18682 inline tensor set_stats_aggregator_dataset(
const tensor& input_dataset,
const tensor& stats_aggregator,
18683 const tensor& tag,
const tensor& counter_prefix,
18684 const std::vector<datatype>& output_types,
18685 const std::vector<std::vector<int64_t>>& output_shapes) {
18687 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18688 TFE_NewOp(context::get_context(),
"SetStatsAggregatorDataset", context::get_status()), &TFE_DeleteOp);
18689 status_check(context::get_status());
18693 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
18694 status_check(context::get_status());
18696 TFE_OpAddInput(op.get(), stats_aggregator.tfe_handle.get(), context::get_status());
18697 status_check(context::get_status());
18699 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
18700 status_check(context::get_status());
18702 TFE_OpAddInput(op.get(), counter_prefix.tfe_handle.get(), context::get_status());
18703 status_check(context::get_status());
18706 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
18707 static_cast<int>(output_types.size()));
18709 std::vector<const int64_t*> output_shapes_values;
18710 output_shapes_values.reserve(output_shapes.size());
18711 std::vector<int> output_shapes_ndims;
18712 output_shapes_ndims.reserve(output_shapes.size());
18713 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
18714 [](
const auto& v) { return v.data(); });
18715 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
18716 [](
const auto& v) { return static_cast<int>(v.size()); });
18717 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
18718 static_cast<int>(output_shapes.size()), context::get_status());
18719 status_check(context::get_status());
18722 int num_outputs_op = 1;
18723 TFE_TensorHandle* res[1] = {
nullptr};
18724 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18725 status_check(context::get_status());
18726 return tensor(res[0]);
18729 inline tensor shape(
const tensor& input, datatype out_type =
static_cast<datatype
>(3)) {
18731 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Shape", context::get_status()),
18733 status_check(context::get_status());
18737 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
18738 status_check(context::get_status());
18741 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
18744 int num_outputs_op = 1;
18745 TFE_TensorHandle* res[1] = {
nullptr};
18746 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18747 status_check(context::get_status());
18748 return tensor(res[0]);
18751 inline tensor shape_n(
const std::vector<tensor>& input, datatype out_type =
static_cast<datatype
>(3)) {
18753 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18754 TFE_NewOp(context::get_context(),
"ShapeN", context::get_status()), &TFE_DeleteOp);
18755 status_check(context::get_status());
18759 std::vector<TFE_TensorHandle*> input_handles;
18760 input_handles.reserve(input.size());
18761 std::transform(input.begin(), input.end(), std::back_inserter(input_handles),
18762 [](
const auto& t) { return t.tfe_handle.get(); });
18763 TFE_OpAddInputList(op.get(), input_handles.data(),
static_cast<int>(input.size()), context::get_status());
18764 status_check(context::get_status());
18767 TFE_OpSetAttrInt(op.get(),
"N", input.size());
18768 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
18771 int num_outputs_op = 1;
18772 TFE_TensorHandle* res[1] = {
nullptr};
18773 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18774 status_check(context::get_status());
18775 return tensor(res[0]);
18778 inline tensor shard_dataset(
const tensor& input_dataset,
const tensor& num_shards,
const tensor& index,
18779 const std::vector<datatype>& output_types,
18780 const std::vector<std::vector<int64_t>>& output_shapes,
bool require_non_empty =
false) {
18782 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18783 TFE_NewOp(context::get_context(),
"ShardDataset", context::get_status()), &TFE_DeleteOp);
18784 status_check(context::get_status());
18788 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
18789 status_check(context::get_status());
18791 TFE_OpAddInput(op.get(), num_shards.tfe_handle.get(), context::get_status());
18792 status_check(context::get_status());
18794 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
18795 status_check(context::get_status());
18798 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
18799 static_cast<int>(output_types.size()));
18801 std::vector<const int64_t*> output_shapes_values;
18802 output_shapes_values.reserve(output_shapes.size());
18803 std::vector<int> output_shapes_ndims;
18804 output_shapes_ndims.reserve(output_shapes.size());
18805 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
18806 [](
const auto& v) { return v.data(); });
18807 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
18808 [](
const auto& v) { return static_cast<int>(v.size()); });
18809 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
18810 static_cast<int>(output_shapes.size()), context::get_status());
18811 status_check(context::get_status());
18813 TFE_OpSetAttrBool(op.get(),
"require_non_empty", (
unsigned char)require_non_empty);
18816 int num_outputs_op = 1;
18817 TFE_TensorHandle* res[1] = {
nullptr};
18818 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18819 status_check(context::get_status());
18820 return tensor(res[0]);
18823 inline tensor sharded_filename(
const tensor& basename,
const tensor& shard,
const tensor& num_shards) {
18825 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18826 TFE_NewOp(context::get_context(),
"ShardedFilename", context::get_status()), &TFE_DeleteOp);
18827 status_check(context::get_status());
18831 TFE_OpAddInput(op.get(), basename.tfe_handle.get(), context::get_status());
18832 status_check(context::get_status());
18834 TFE_OpAddInput(op.get(), shard.tfe_handle.get(), context::get_status());
18835 status_check(context::get_status());
18837 TFE_OpAddInput(op.get(), num_shards.tfe_handle.get(), context::get_status());
18838 status_check(context::get_status());
18843 int num_outputs_op = 1;
18844 TFE_TensorHandle* res[1] = {
nullptr};
18845 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18846 status_check(context::get_status());
18847 return tensor(res[0]);
18850 inline tensor sharded_filespec(
const tensor& basename,
const tensor& num_shards) {
18852 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18853 TFE_NewOp(context::get_context(),
"ShardedFilespec", context::get_status()), &TFE_DeleteOp);
18854 status_check(context::get_status());
18858 TFE_OpAddInput(op.get(), basename.tfe_handle.get(), context::get_status());
18859 status_check(context::get_status());
18861 TFE_OpAddInput(op.get(), num_shards.tfe_handle.get(), context::get_status());
18862 status_check(context::get_status());
18867 int num_outputs_op = 1;
18868 TFE_TensorHandle* res[1] = {
nullptr};
18869 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18870 status_check(context::get_status());
18871 return tensor(res[0]);
18874 inline tensor shuffle_and_repeat_dataset(
const tensor& input_dataset,
const tensor& buffer_size,
const tensor& seed,
18875 const tensor& seed2,
const tensor& count,
18876 const std::vector<datatype>& output_types,
18877 const std::vector<std::vector<int64_t>>& output_shapes,
18878 bool reshuffle_each_iteration =
true) {
18880 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18881 TFE_NewOp(context::get_context(),
"ShuffleAndRepeatDataset", context::get_status()), &TFE_DeleteOp);
18882 status_check(context::get_status());
18886 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
18887 status_check(context::get_status());
18889 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
18890 status_check(context::get_status());
18892 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
18893 status_check(context::get_status());
18895 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
18896 status_check(context::get_status());
18898 TFE_OpAddInput(op.get(), count.tfe_handle.get(), context::get_status());
18899 status_check(context::get_status());
18902 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
18903 static_cast<int>(output_types.size()));
18905 std::vector<const int64_t*> output_shapes_values;
18906 output_shapes_values.reserve(output_shapes.size());
18907 std::vector<int> output_shapes_ndims;
18908 output_shapes_ndims.reserve(output_shapes.size());
18909 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
18910 [](
const auto& v) { return v.data(); });
18911 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
18912 [](
const auto& v) { return static_cast<int>(v.size()); });
18913 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
18914 static_cast<int>(output_shapes.size()), context::get_status());
18915 status_check(context::get_status());
18917 TFE_OpSetAttrBool(op.get(),
"reshuffle_each_iteration", (
unsigned char)reshuffle_each_iteration);
18920 int num_outputs_op = 1;
18921 TFE_TensorHandle* res[1] = {
nullptr};
18922 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18923 status_check(context::get_status());
18924 return tensor(res[0]);
18927 inline tensor shuffle_and_repeat_dataset_v2(
const tensor& input_dataset,
const tensor& buffer_size,
const tensor& seed,
18928 const tensor& seed2,
const tensor& count,
const tensor& seed_generator,
18929 const std::vector<datatype>& output_types,
18930 const std::vector<std::vector<int64_t>>& output_shapes,
18931 bool reshuffle_each_iteration =
true) {
18933 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18934 TFE_NewOp(context::get_context(),
"ShuffleAndRepeatDatasetV2", context::get_status()), &TFE_DeleteOp);
18935 status_check(context::get_status());
18939 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
18940 status_check(context::get_status());
18942 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
18943 status_check(context::get_status());
18945 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
18946 status_check(context::get_status());
18948 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
18949 status_check(context::get_status());
18951 TFE_OpAddInput(op.get(), count.tfe_handle.get(), context::get_status());
18952 status_check(context::get_status());
18954 TFE_OpAddInput(op.get(), seed_generator.tfe_handle.get(), context::get_status());
18955 status_check(context::get_status());
18958 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
18959 static_cast<int>(output_types.size()));
18961 std::vector<const int64_t*> output_shapes_values;
18962 output_shapes_values.reserve(output_shapes.size());
18963 std::vector<int> output_shapes_ndims;
18964 output_shapes_ndims.reserve(output_shapes.size());
18965 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
18966 [](
const auto& v) { return v.data(); });
18967 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
18968 [](
const auto& v) { return static_cast<int>(v.size()); });
18969 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
18970 static_cast<int>(output_shapes.size()), context::get_status());
18971 status_check(context::get_status());
18973 TFE_OpSetAttrBool(op.get(),
"reshuffle_each_iteration", (
unsigned char)reshuffle_each_iteration);
18976 int num_outputs_op = 1;
18977 TFE_TensorHandle* res[1] = {
nullptr};
18978 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
18979 status_check(context::get_status());
18980 return tensor(res[0]);
18983 inline tensor shuffle_dataset(
const tensor& input_dataset,
const tensor& buffer_size,
const tensor& seed,
18984 const tensor& seed2,
const std::vector<datatype>& output_types,
18985 const std::vector<std::vector<int64_t>>& output_shapes,
18986 bool reshuffle_each_iteration =
true) {
18988 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
18989 TFE_NewOp(context::get_context(),
"ShuffleDataset", context::get_status()), &TFE_DeleteOp);
18990 status_check(context::get_status());
18994 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
18995 status_check(context::get_status());
18997 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
18998 status_check(context::get_status());
19000 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
19001 status_check(context::get_status());
19003 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
19004 status_check(context::get_status());
19007 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19008 static_cast<int>(output_types.size()));
19010 std::vector<const int64_t*> output_shapes_values;
19011 output_shapes_values.reserve(output_shapes.size());
19012 std::vector<int> output_shapes_ndims;
19013 output_shapes_ndims.reserve(output_shapes.size());
19014 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19015 [](
const auto& v) { return v.data(); });
19016 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19017 [](
const auto& v) { return static_cast<int>(v.size()); });
19018 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19019 static_cast<int>(output_shapes.size()), context::get_status());
19020 status_check(context::get_status());
19022 TFE_OpSetAttrBool(op.get(),
"reshuffle_each_iteration", (
unsigned char)reshuffle_each_iteration);
19025 int num_outputs_op = 1;
19026 TFE_TensorHandle* res[1] = {
nullptr};
19027 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19028 status_check(context::get_status());
19029 return tensor(res[0]);
19032 inline tensor shuffle_dataset_v2(
const tensor& input_dataset,
const tensor& buffer_size,
const tensor& seed_generator,
19033 const std::vector<datatype>& output_types,
19034 const std::vector<std::vector<int64_t>>& output_shapes) {
19036 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19037 TFE_NewOp(context::get_context(),
"ShuffleDatasetV2", context::get_status()), &TFE_DeleteOp);
19038 status_check(context::get_status());
19042 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
19043 status_check(context::get_status());
19045 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
19046 status_check(context::get_status());
19048 TFE_OpAddInput(op.get(), seed_generator.tfe_handle.get(), context::get_status());
19049 status_check(context::get_status());
19052 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19053 static_cast<int>(output_types.size()));
19055 std::vector<const int64_t*> output_shapes_values;
19056 output_shapes_values.reserve(output_shapes.size());
19057 std::vector<int> output_shapes_ndims;
19058 output_shapes_ndims.reserve(output_shapes.size());
19059 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19060 [](
const auto& v) { return v.data(); });
19061 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19062 [](
const auto& v) { return static_cast<int>(v.size()); });
19063 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19064 static_cast<int>(output_shapes.size()), context::get_status());
19065 status_check(context::get_status());
19068 int num_outputs_op = 1;
19069 TFE_TensorHandle* res[1] = {
nullptr};
19070 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19071 status_check(context::get_status());
19072 return tensor(res[0]);
19075 inline tensor shuffle_dataset_v3(
const tensor& input_dataset,
const tensor& buffer_size,
const tensor& seed,
19076 const tensor& seed2,
const tensor& seed_generator,
19077 const std::vector<datatype>& output_types,
19078 const std::vector<std::vector<int64_t>>& output_shapes,
19079 bool reshuffle_each_iteration =
true) {
19081 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19082 TFE_NewOp(context::get_context(),
"ShuffleDatasetV3", context::get_status()), &TFE_DeleteOp);
19083 status_check(context::get_status());
19087 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
19088 status_check(context::get_status());
19090 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
19091 status_check(context::get_status());
19093 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
19094 status_check(context::get_status());
19096 TFE_OpAddInput(op.get(), seed2.tfe_handle.get(), context::get_status());
19097 status_check(context::get_status());
19099 TFE_OpAddInput(op.get(), seed_generator.tfe_handle.get(), context::get_status());
19100 status_check(context::get_status());
19103 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19104 static_cast<int>(output_types.size()));
19106 std::vector<const int64_t*> output_shapes_values;
19107 output_shapes_values.reserve(output_shapes.size());
19108 std::vector<int> output_shapes_ndims;
19109 output_shapes_ndims.reserve(output_shapes.size());
19110 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19111 [](
const auto& v) { return v.data(); });
19112 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19113 [](
const auto& v) { return static_cast<int>(v.size()); });
19114 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19115 static_cast<int>(output_shapes.size()), context::get_status());
19116 status_check(context::get_status());
19118 TFE_OpSetAttrBool(op.get(),
"reshuffle_each_iteration", (
unsigned char)reshuffle_each_iteration);
19121 int num_outputs_op = 1;
19122 TFE_TensorHandle* res[1] = {
nullptr};
19123 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19124 status_check(context::get_status());
19125 return tensor(res[0]);
19128 inline tensor sigmoid(
const tensor& x) {
19130 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19131 TFE_NewOp(context::get_context(),
"Sigmoid", context::get_status()), &TFE_DeleteOp);
19132 status_check(context::get_status());
19136 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
19137 status_check(context::get_status());
19142 int num_outputs_op = 1;
19143 TFE_TensorHandle* res[1] = {
nullptr};
19144 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19145 status_check(context::get_status());
19146 return tensor(res[0]);
19149 inline tensor sigmoid_grad(
const tensor& y,
const tensor& dy) {
19151 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19152 TFE_NewOp(context::get_context(),
"SigmoidGrad", context::get_status()), &TFE_DeleteOp);
19153 status_check(context::get_status());
19157 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
19158 status_check(context::get_status());
19160 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
19161 status_check(context::get_status());
19166 int num_outputs_op = 1;
19167 TFE_TensorHandle* res[1] = {
nullptr};
19168 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19169 status_check(context::get_status());
19170 return tensor(res[0]);
19173 inline tensor sign(
const tensor& x) {
19175 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Sign", context::get_status()),
19177 status_check(context::get_status());
19181 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
19182 status_check(context::get_status());
19187 int num_outputs_op = 1;
19188 TFE_TensorHandle* res[1] = {
nullptr};
19189 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19190 status_check(context::get_status());
19191 return tensor(res[0]);
19194 inline tensor sin(
const tensor& x) {
19196 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Sin", context::get_status()),
19198 status_check(context::get_status());
19202 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
19203 status_check(context::get_status());
19208 int num_outputs_op = 1;
19209 TFE_TensorHandle* res[1] = {
nullptr};
19210 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19211 status_check(context::get_status());
19212 return tensor(res[0]);
19215 inline tensor sinh(
const tensor& x) {
19217 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Sinh", context::get_status()),
19219 status_check(context::get_status());
19223 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
19224 status_check(context::get_status());
19229 int num_outputs_op = 1;
19230 TFE_TensorHandle* res[1] = {
nullptr};
19231 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19232 status_check(context::get_status());
19233 return tensor(res[0]);
19236 inline tensor size(
const tensor& input, datatype out_type =
static_cast<datatype
>(3)) {
19238 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Size", context::get_status()),
19240 status_check(context::get_status());
19244 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
19245 status_check(context::get_status());
19248 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
19251 int num_outputs_op = 1;
19252 TFE_TensorHandle* res[1] = {
nullptr};
19253 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19254 status_check(context::get_status());
19255 return tensor(res[0]);
19258 inline tensor skip_dataset(
const tensor& input_dataset,
const tensor& count,
const std::vector<datatype>& output_types,
19259 const std::vector<std::vector<int64_t>>& output_shapes) {
19261 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19262 TFE_NewOp(context::get_context(),
"SkipDataset", context::get_status()), &TFE_DeleteOp);
19263 status_check(context::get_status());
19267 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
19268 status_check(context::get_status());
19270 TFE_OpAddInput(op.get(), count.tfe_handle.get(), context::get_status());
19271 status_check(context::get_status());
19274 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19275 static_cast<int>(output_types.size()));
19277 std::vector<const int64_t*> output_shapes_values;
19278 output_shapes_values.reserve(output_shapes.size());
19279 std::vector<int> output_shapes_ndims;
19280 output_shapes_ndims.reserve(output_shapes.size());
19281 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19282 [](
const auto& v) { return v.data(); });
19283 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19284 [](
const auto& v) { return static_cast<int>(v.size()); });
19285 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19286 static_cast<int>(output_shapes.size()), context::get_status());
19287 status_check(context::get_status());
19290 int num_outputs_op = 1;
19291 TFE_TensorHandle* res[1] = {
nullptr};
19292 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19293 status_check(context::get_status());
19294 return tensor(res[0]);
19297 inline tensor sleep_dataset(
const tensor& input_dataset,
const tensor& sleep_microseconds,
19298 const std::vector<datatype>& output_types,
19299 const std::vector<std::vector<int64_t>>& output_shapes) {
19301 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19302 TFE_NewOp(context::get_context(),
"SleepDataset", context::get_status()), &TFE_DeleteOp);
19303 status_check(context::get_status());
19307 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
19308 status_check(context::get_status());
19310 TFE_OpAddInput(op.get(), sleep_microseconds.tfe_handle.get(), context::get_status());
19311 status_check(context::get_status());
19314 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19315 static_cast<int>(output_types.size()));
19317 std::vector<const int64_t*> output_shapes_values;
19318 output_shapes_values.reserve(output_shapes.size());
19319 std::vector<int> output_shapes_ndims;
19320 output_shapes_ndims.reserve(output_shapes.size());
19321 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19322 [](
const auto& v) { return v.data(); });
19323 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19324 [](
const auto& v) { return static_cast<int>(v.size()); });
19325 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19326 static_cast<int>(output_shapes.size()), context::get_status());
19327 status_check(context::get_status());
19330 int num_outputs_op = 1;
19331 TFE_TensorHandle* res[1] = {
nullptr};
19332 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19333 status_check(context::get_status());
19334 return tensor(res[0]);
19337 inline tensor slice(
const tensor& input,
const tensor& begin,
const tensor& size, datatype Index) {
19339 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Slice", context::get_status()),
19341 status_check(context::get_status());
19345 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
19346 status_check(context::get_status());
19348 TFE_OpAddInput(op.get(), begin.tfe_handle.get(), context::get_status());
19349 status_check(context::get_status());
19351 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
19352 status_check(context::get_status());
19355 TFE_OpSetAttrType(op.get(),
"Index", Index);
19358 int num_outputs_op = 1;
19359 TFE_TensorHandle* res[1] = {
nullptr};
19360 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19361 status_check(context::get_status());
19362 return tensor(res[0]);
19365 inline tensor sliding_window_dataset(
const tensor& input_dataset,
const tensor& window_size,
const tensor& window_shift,
19366 const tensor& window_stride,
const std::vector<datatype>& output_types,
19367 const std::vector<std::vector<int64_t>>& output_shapes) {
19369 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19370 TFE_NewOp(context::get_context(),
"SlidingWindowDataset", context::get_status()), &TFE_DeleteOp);
19371 status_check(context::get_status());
19375 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
19376 status_check(context::get_status());
19378 TFE_OpAddInput(op.get(), window_size.tfe_handle.get(), context::get_status());
19379 status_check(context::get_status());
19381 TFE_OpAddInput(op.get(), window_shift.tfe_handle.get(), context::get_status());
19382 status_check(context::get_status());
19384 TFE_OpAddInput(op.get(), window_stride.tfe_handle.get(), context::get_status());
19385 status_check(context::get_status());
19388 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19389 static_cast<int>(output_types.size()));
19391 std::vector<const int64_t*> output_shapes_values;
19392 output_shapes_values.reserve(output_shapes.size());
19393 std::vector<int> output_shapes_ndims;
19394 output_shapes_ndims.reserve(output_shapes.size());
19395 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19396 [](
const auto& v) { return v.data(); });
19397 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19398 [](
const auto& v) { return static_cast<int>(v.size()); });
19399 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19400 static_cast<int>(output_shapes.size()), context::get_status());
19401 status_check(context::get_status());
19404 int num_outputs_op = 1;
19405 TFE_TensorHandle* res[1] = {
nullptr};
19406 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19407 status_check(context::get_status());
19408 return tensor(res[0]);
19411 inline tensor snapshot(
const tensor& input) {
19413 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19414 TFE_NewOp(context::get_context(),
"Snapshot", context::get_status()), &TFE_DeleteOp);
19415 status_check(context::get_status());
19419 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
19420 status_check(context::get_status());
19425 int num_outputs_op = 1;
19426 TFE_TensorHandle* res[1] = {
nullptr};
19427 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19428 status_check(context::get_status());
19429 return tensor(res[0]);
19432 inline tensor snapshot_dataset(
const tensor& input_dataset,
const tensor& path,
19433 const std::vector<datatype>& output_types,
19434 const std::vector<std::vector<int64_t>>& output_shapes,
19435 const std::string& compression =
"",
const std::string& reader_path_prefix =
"",
19436 const std::string& writer_path_prefix =
"", int64_t shard_size_bytes = 10737418240,
19437 int64_t pending_snapshot_expiry_seconds = 86400, int64_t num_reader_threads = 1,
19438 int64_t reader_buffer_size = 1, int64_t num_writer_threads = 1,
19439 int64_t writer_buffer_size = 1,
bool shuffle_on_read =
false, int64_t seed = 0,
19440 int64_t seed2 = 0,
const std::string& mode =
"auto",
19441 const std::string& snapshot_name =
"") {
19443 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19444 TFE_NewOp(context::get_context(),
"SnapshotDataset", context::get_status()), &TFE_DeleteOp);
19445 status_check(context::get_status());
19449 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
19450 status_check(context::get_status());
19452 TFE_OpAddInput(op.get(), path.tfe_handle.get(), context::get_status());
19453 status_check(context::get_status());
19456 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
19457 static_cast<int>(output_types.size()));
19459 std::vector<const int64_t*> output_shapes_values;
19460 output_shapes_values.reserve(output_shapes.size());
19461 std::vector<int> output_shapes_ndims;
19462 output_shapes_ndims.reserve(output_shapes.size());
19463 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
19464 [](
const auto& v) { return v.data(); });
19465 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
19466 [](
const auto& v) { return static_cast<int>(v.size()); });
19467 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
19468 static_cast<int>(output_shapes.size()), context::get_status());
19469 status_check(context::get_status());
19471 TFE_OpSetAttrString(op.get(),
"compression", (
void*)compression.c_str(), compression.size());
19472 TFE_OpSetAttrString(op.get(),
"reader_path_prefix", (
void*)reader_path_prefix.c_str(), reader_path_prefix.size());
19473 TFE_OpSetAttrString(op.get(),
"writer_path_prefix", (
void*)writer_path_prefix.c_str(), writer_path_prefix.size());
19474 TFE_OpSetAttrInt(op.get(),
"shard_size_bytes", shard_size_bytes);
19475 TFE_OpSetAttrInt(op.get(),
"pending_snapshot_expiry_seconds", pending_snapshot_expiry_seconds);
19476 TFE_OpSetAttrInt(op.get(),
"num_reader_threads", num_reader_threads);
19477 TFE_OpSetAttrInt(op.get(),
"reader_buffer_size", reader_buffer_size);
19478 TFE_OpSetAttrInt(op.get(),
"num_writer_threads", num_writer_threads);
19479 TFE_OpSetAttrInt(op.get(),
"writer_buffer_size", writer_buffer_size);
19480 TFE_OpSetAttrBool(op.get(),
"shuffle_on_read", (
unsigned char)shuffle_on_read);
19481 TFE_OpSetAttrInt(op.get(),
"seed", seed);
19482 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
19483 TFE_OpSetAttrString(op.get(),
"mode", (
void*)mode.c_str(), mode.size());
19484 TFE_OpSetAttrString(op.get(),
"snapshot_name", (
void*)snapshot_name.c_str(), snapshot_name.size());
19487 int num_outputs_op = 1;
19488 TFE_TensorHandle* res[1] = {
nullptr};
19489 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19490 status_check(context::get_status());
19491 return tensor(res[0]);
19494 inline tensor sobol_sample(
const tensor& dim,
const tensor& num_results,
const tensor& skip,
19495 datatype dtype =
static_cast<datatype
>(1)) {
19497 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19498 TFE_NewOp(context::get_context(),
"SobolSample", context::get_status()), &TFE_DeleteOp);
19499 status_check(context::get_status());
19503 TFE_OpAddInput(op.get(), dim.tfe_handle.get(), context::get_status());
19504 status_check(context::get_status());
19506 TFE_OpAddInput(op.get(), num_results.tfe_handle.get(), context::get_status());
19507 status_check(context::get_status());
19509 TFE_OpAddInput(op.get(), skip.tfe_handle.get(), context::get_status());
19510 status_check(context::get_status());
19513 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
19516 int num_outputs_op = 1;
19517 TFE_TensorHandle* res[1] = {
nullptr};
19518 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19519 status_check(context::get_status());
19520 return tensor(res[0]);
19523 inline tensor softmax(
const tensor& logits) {
19525 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19526 TFE_NewOp(context::get_context(),
"Softmax", context::get_status()), &TFE_DeleteOp);
19527 status_check(context::get_status());
19531 TFE_OpAddInput(op.get(), logits.tfe_handle.get(), context::get_status());
19532 status_check(context::get_status());
19537 int num_outputs_op = 1;
19538 TFE_TensorHandle* res[1] = {
nullptr};
19539 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19540 status_check(context::get_status());
19541 return tensor(res[0]);
19544 inline tensor softplus(
const tensor& features) {
19546 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19547 TFE_NewOp(context::get_context(),
"Softplus", context::get_status()), &TFE_DeleteOp);
19548 status_check(context::get_status());
19552 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
19553 status_check(context::get_status());
19558 int num_outputs_op = 1;
19559 TFE_TensorHandle* res[1] = {
nullptr};
19560 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19561 status_check(context::get_status());
19562 return tensor(res[0]);
19565 inline tensor softplus_grad(
const tensor& gradients,
const tensor& features) {
19567 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19568 TFE_NewOp(context::get_context(),
"SoftplusGrad", context::get_status()), &TFE_DeleteOp);
19569 status_check(context::get_status());
19573 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
19574 status_check(context::get_status());
19576 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
19577 status_check(context::get_status());
19582 int num_outputs_op = 1;
19583 TFE_TensorHandle* res[1] = {
nullptr};
19584 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19585 status_check(context::get_status());
19586 return tensor(res[0]);
19589 inline tensor softsign(
const tensor& features) {
19591 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19592 TFE_NewOp(context::get_context(),
"Softsign", context::get_status()), &TFE_DeleteOp);
19593 status_check(context::get_status());
19597 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
19598 status_check(context::get_status());
19603 int num_outputs_op = 1;
19604 TFE_TensorHandle* res[1] = {
nullptr};
19605 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19606 status_check(context::get_status());
19607 return tensor(res[0]);
19610 inline tensor softsign_grad(
const tensor& gradients,
const tensor& features) {
19612 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19613 TFE_NewOp(context::get_context(),
"SoftsignGrad", context::get_status()), &TFE_DeleteOp);
19614 status_check(context::get_status());
19618 TFE_OpAddInput(op.get(), gradients.tfe_handle.get(), context::get_status());
19619 status_check(context::get_status());
19621 TFE_OpAddInput(op.get(), features.tfe_handle.get(), context::get_status());
19622 status_check(context::get_status());
19627 int num_outputs_op = 1;
19628 TFE_TensorHandle* res[1] = {
nullptr};
19629 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19630 status_check(context::get_status());
19631 return tensor(res[0]);
19634 inline tensor space_to_batch(
const tensor& input,
const tensor& paddings, int64_t block_size,
19635 datatype Tpaddings =
static_cast<datatype
>(3)) {
19637 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19638 TFE_NewOp(context::get_context(),
"SpaceToBatch", context::get_status()), &TFE_DeleteOp);
19639 status_check(context::get_status());
19643 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
19644 status_check(context::get_status());
19646 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
19647 status_check(context::get_status());
19650 TFE_OpSetAttrInt(op.get(),
"block_size", block_size);
19651 TFE_OpSetAttrType(op.get(),
"Tpaddings", Tpaddings);
19654 int num_outputs_op = 1;
19655 TFE_TensorHandle* res[1] = {
nullptr};
19656 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19657 status_check(context::get_status());
19658 return tensor(res[0]);
19661 inline tensor space_to_batch_n_d(
const tensor& input,
const tensor& block_shape,
const tensor& paddings,
19662 datatype Tblock_shape =
static_cast<datatype
>(3),
19663 datatype Tpaddings =
static_cast<datatype
>(3)) {
19665 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19666 TFE_NewOp(context::get_context(),
"SpaceToBatchND", context::get_status()), &TFE_DeleteOp);
19667 status_check(context::get_status());
19671 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
19672 status_check(context::get_status());
19674 TFE_OpAddInput(op.get(), block_shape.tfe_handle.get(), context::get_status());
19675 status_check(context::get_status());
19677 TFE_OpAddInput(op.get(), paddings.tfe_handle.get(), context::get_status());
19678 status_check(context::get_status());
19681 TFE_OpSetAttrType(op.get(),
"Tblock_shape", Tblock_shape);
19682 TFE_OpSetAttrType(op.get(),
"Tpaddings", Tpaddings);
19685 int num_outputs_op = 1;
19686 TFE_TensorHandle* res[1] = {
nullptr};
19687 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19688 status_check(context::get_status());
19689 return tensor(res[0]);
19692 inline tensor space_to_depth(
const tensor& input, int64_t block_size,
const std::string& data_format =
"NHWC") {
19694 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19695 TFE_NewOp(context::get_context(),
"SpaceToDepth", context::get_status()), &TFE_DeleteOp);
19696 status_check(context::get_status());
19700 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
19701 status_check(context::get_status());
19704 TFE_OpSetAttrInt(op.get(),
"block_size", block_size);
19705 TFE_OpSetAttrString(op.get(),
"data_format", (
void*)data_format.c_str(), data_format.size());
19708 int num_outputs_op = 1;
19709 TFE_TensorHandle* res[1] = {
nullptr};
19710 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19711 status_check(context::get_status());
19712 return tensor(res[0]);
19715 inline tensor sparse_apply_adadelta(
const tensor& var,
const tensor& accum,
const tensor& accum_update,
19716 const tensor& lr,
const tensor& rho,
const tensor& epsilon,
const tensor& grad,
19717 const tensor& indices, datatype Tindices,
bool use_locking =
false) {
19719 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19720 TFE_NewOp(context::get_context(),
"SparseApplyAdadelta", context::get_status()), &TFE_DeleteOp);
19721 status_check(context::get_status());
19725 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
19726 status_check(context::get_status());
19728 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
19729 status_check(context::get_status());
19731 TFE_OpAddInput(op.get(), accum_update.tfe_handle.get(), context::get_status());
19732 status_check(context::get_status());
19734 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
19735 status_check(context::get_status());
19737 TFE_OpAddInput(op.get(), rho.tfe_handle.get(), context::get_status());
19738 status_check(context::get_status());
19740 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
19741 status_check(context::get_status());
19743 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
19744 status_check(context::get_status());
19746 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
19747 status_check(context::get_status());
19750 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
19751 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
19754 int num_outputs_op = 1;
19755 TFE_TensorHandle* res[1] = {
nullptr};
19756 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19757 status_check(context::get_status());
19758 return tensor(res[0]);
19761 inline tensor sparse_apply_adagrad(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& grad,
19762 const tensor& indices, datatype Tindices,
bool use_locking =
false,
19763 bool update_slots =
true) {
19765 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19766 TFE_NewOp(context::get_context(),
"SparseApplyAdagrad", context::get_status()), &TFE_DeleteOp);
19767 status_check(context::get_status());
19771 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
19772 status_check(context::get_status());
19774 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
19775 status_check(context::get_status());
19777 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
19778 status_check(context::get_status());
19780 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
19781 status_check(context::get_status());
19783 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
19784 status_check(context::get_status());
19787 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
19788 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
19789 TFE_OpSetAttrBool(op.get(),
"update_slots", (
unsigned char)update_slots);
19792 int num_outputs_op = 1;
19793 TFE_TensorHandle* res[1] = {
nullptr};
19794 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19795 status_check(context::get_status());
19796 return tensor(res[0]);
19799 inline tensor sparse_apply_adagrad_d_a(
const tensor& var,
const tensor& gradient_accumulator,
19800 const tensor& gradient_squared_accumulator,
const tensor& grad,
19801 const tensor& indices,
const tensor& lr,
const tensor& l1,
const tensor& l2,
19802 const tensor& global_step, datatype Tindices,
bool use_locking =
false) {
19804 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19805 TFE_NewOp(context::get_context(),
"SparseApplyAdagradDA", context::get_status()), &TFE_DeleteOp);
19806 status_check(context::get_status());
19810 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
19811 status_check(context::get_status());
19813 TFE_OpAddInput(op.get(), gradient_accumulator.tfe_handle.get(), context::get_status());
19814 status_check(context::get_status());
19816 TFE_OpAddInput(op.get(), gradient_squared_accumulator.tfe_handle.get(), context::get_status());
19817 status_check(context::get_status());
19819 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
19820 status_check(context::get_status());
19822 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
19823 status_check(context::get_status());
19825 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
19826 status_check(context::get_status());
19828 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
19829 status_check(context::get_status());
19831 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
19832 status_check(context::get_status());
19834 TFE_OpAddInput(op.get(), global_step.tfe_handle.get(), context::get_status());
19835 status_check(context::get_status());
19838 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
19839 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
19842 int num_outputs_op = 1;
19843 TFE_TensorHandle* res[1] = {
nullptr};
19844 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19845 status_check(context::get_status());
19846 return tensor(res[0]);
19849 inline tensor sparse_apply_adagrad_v2(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& epsilon,
19850 const tensor& grad,
const tensor& indices, datatype Tindices,
19851 bool use_locking =
false,
bool update_slots =
true) {
19853 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19854 TFE_NewOp(context::get_context(),
"SparseApplyAdagradV2", context::get_status()), &TFE_DeleteOp);
19855 status_check(context::get_status());
19859 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
19860 status_check(context::get_status());
19862 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
19863 status_check(context::get_status());
19865 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
19866 status_check(context::get_status());
19868 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
19869 status_check(context::get_status());
19871 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
19872 status_check(context::get_status());
19874 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
19875 status_check(context::get_status());
19878 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
19879 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
19880 TFE_OpSetAttrBool(op.get(),
"update_slots", (
unsigned char)update_slots);
19883 int num_outputs_op = 1;
19884 TFE_TensorHandle* res[1] = {
nullptr};
19885 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19886 status_check(context::get_status());
19887 return tensor(res[0]);
19890 inline tensor sparse_apply_centered_r_m_s_prop(
const tensor& var,
const tensor& mg,
const tensor& ms,
const tensor& mom,
19891 const tensor& lr,
const tensor& rho,
const tensor& momentum,
19892 const tensor& epsilon,
const tensor& grad,
const tensor& indices,
19893 datatype Tindices,
bool use_locking =
false) {
19895 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19896 TFE_NewOp(context::get_context(),
"SparseApplyCenteredRMSProp", context::get_status()), &TFE_DeleteOp);
19897 status_check(context::get_status());
19901 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
19902 status_check(context::get_status());
19904 TFE_OpAddInput(op.get(), mg.tfe_handle.get(), context::get_status());
19905 status_check(context::get_status());
19907 TFE_OpAddInput(op.get(), ms.tfe_handle.get(), context::get_status());
19908 status_check(context::get_status());
19910 TFE_OpAddInput(op.get(), mom.tfe_handle.get(), context::get_status());
19911 status_check(context::get_status());
19913 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
19914 status_check(context::get_status());
19916 TFE_OpAddInput(op.get(), rho.tfe_handle.get(), context::get_status());
19917 status_check(context::get_status());
19919 TFE_OpAddInput(op.get(), momentum.tfe_handle.get(), context::get_status());
19920 status_check(context::get_status());
19922 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
19923 status_check(context::get_status());
19925 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
19926 status_check(context::get_status());
19928 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
19929 status_check(context::get_status());
19932 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
19933 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
19936 int num_outputs_op = 1;
19937 TFE_TensorHandle* res[1] = {
nullptr};
19938 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19939 status_check(context::get_status());
19940 return tensor(res[0]);
19943 inline tensor sparse_apply_ftrl(
const tensor& var,
const tensor& accum,
const tensor& linear,
const tensor& grad,
19944 const tensor& indices,
const tensor& lr,
const tensor& l1,
const tensor& l2,
19945 const tensor& lr_power, datatype Tindices,
bool use_locking =
false,
19946 bool multiply_linear_by_lr =
false) {
19948 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
19949 TFE_NewOp(context::get_context(),
"SparseApplyFtrl", context::get_status()), &TFE_DeleteOp);
19950 status_check(context::get_status());
19954 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
19955 status_check(context::get_status());
19957 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
19958 status_check(context::get_status());
19960 TFE_OpAddInput(op.get(), linear.tfe_handle.get(), context::get_status());
19961 status_check(context::get_status());
19963 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
19964 status_check(context::get_status());
19966 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
19967 status_check(context::get_status());
19969 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
19970 status_check(context::get_status());
19972 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
19973 status_check(context::get_status());
19975 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
19976 status_check(context::get_status());
19978 TFE_OpAddInput(op.get(), lr_power.tfe_handle.get(), context::get_status());
19979 status_check(context::get_status());
19982 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
19983 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
19984 TFE_OpSetAttrBool(op.get(),
"multiply_linear_by_lr", (
unsigned char)multiply_linear_by_lr);
19987 int num_outputs_op = 1;
19988 TFE_TensorHandle* res[1] = {
nullptr};
19989 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
19990 status_check(context::get_status());
19991 return tensor(res[0]);
19994 inline tensor sparse_apply_ftrl_v2(
const tensor& var,
const tensor& accum,
const tensor& linear,
const tensor& grad,
19995 const tensor& indices,
const tensor& lr,
const tensor& l1,
const tensor& l2,
19996 const tensor& l2_shrinkage,
const tensor& lr_power, datatype Tindices,
19997 bool use_locking =
false,
bool multiply_linear_by_lr =
false) {
19999 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20000 TFE_NewOp(context::get_context(),
"SparseApplyFtrlV2", context::get_status()), &TFE_DeleteOp);
20001 status_check(context::get_status());
20005 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
20006 status_check(context::get_status());
20008 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
20009 status_check(context::get_status());
20011 TFE_OpAddInput(op.get(), linear.tfe_handle.get(), context::get_status());
20012 status_check(context::get_status());
20014 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20015 status_check(context::get_status());
20017 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20018 status_check(context::get_status());
20020 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
20021 status_check(context::get_status());
20023 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
20024 status_check(context::get_status());
20026 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
20027 status_check(context::get_status());
20029 TFE_OpAddInput(op.get(), l2_shrinkage.tfe_handle.get(), context::get_status());
20030 status_check(context::get_status());
20032 TFE_OpAddInput(op.get(), lr_power.tfe_handle.get(), context::get_status());
20033 status_check(context::get_status());
20036 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
20037 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
20038 TFE_OpSetAttrBool(op.get(),
"multiply_linear_by_lr", (
unsigned char)multiply_linear_by_lr);
20041 int num_outputs_op = 1;
20042 TFE_TensorHandle* res[1] = {
nullptr};
20043 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20044 status_check(context::get_status());
20045 return tensor(res[0]);
20048 inline tensor sparse_apply_momentum(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& grad,
20049 const tensor& indices,
const tensor& momentum, datatype Tindices,
20050 bool use_locking =
false,
bool use_nesterov =
false) {
20052 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20053 TFE_NewOp(context::get_context(),
"SparseApplyMomentum", context::get_status()), &TFE_DeleteOp);
20054 status_check(context::get_status());
20058 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
20059 status_check(context::get_status());
20061 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
20062 status_check(context::get_status());
20064 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
20065 status_check(context::get_status());
20067 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20068 status_check(context::get_status());
20070 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20071 status_check(context::get_status());
20073 TFE_OpAddInput(op.get(), momentum.tfe_handle.get(), context::get_status());
20074 status_check(context::get_status());
20077 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
20078 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
20079 TFE_OpSetAttrBool(op.get(),
"use_nesterov", (
unsigned char)use_nesterov);
20082 int num_outputs_op = 1;
20083 TFE_TensorHandle* res[1] = {
nullptr};
20084 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20085 status_check(context::get_status());
20086 return tensor(res[0]);
20089 inline tensor sparse_apply_proximal_adagrad(
const tensor& var,
const tensor& accum,
const tensor& lr,
const tensor& l1,
20090 const tensor& l2,
const tensor& grad,
const tensor& indices,
20091 datatype Tindices,
bool use_locking =
false) {
20093 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20094 TFE_NewOp(context::get_context(),
"SparseApplyProximalAdagrad", context::get_status()), &TFE_DeleteOp);
20095 status_check(context::get_status());
20099 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
20100 status_check(context::get_status());
20102 TFE_OpAddInput(op.get(), accum.tfe_handle.get(), context::get_status());
20103 status_check(context::get_status());
20105 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
20106 status_check(context::get_status());
20108 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
20109 status_check(context::get_status());
20111 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
20112 status_check(context::get_status());
20114 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20115 status_check(context::get_status());
20117 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20118 status_check(context::get_status());
20121 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
20122 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
20125 int num_outputs_op = 1;
20126 TFE_TensorHandle* res[1] = {
nullptr};
20127 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20128 status_check(context::get_status());
20129 return tensor(res[0]);
20132 inline tensor sparse_apply_proximal_gradient_descent(
const tensor& var,
const tensor& alpha,
const tensor& l1,
20133 const tensor& l2,
const tensor& grad,
const tensor& indices,
20134 datatype Tindices,
bool use_locking =
false) {
20136 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20137 TFE_NewOp(context::get_context(),
"SparseApplyProximalGradientDescent", context::get_status()), &TFE_DeleteOp);
20138 status_check(context::get_status());
20142 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
20143 status_check(context::get_status());
20145 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
20146 status_check(context::get_status());
20148 TFE_OpAddInput(op.get(), l1.tfe_handle.get(), context::get_status());
20149 status_check(context::get_status());
20151 TFE_OpAddInput(op.get(), l2.tfe_handle.get(), context::get_status());
20152 status_check(context::get_status());
20154 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20155 status_check(context::get_status());
20157 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20158 status_check(context::get_status());
20161 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
20162 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
20165 int num_outputs_op = 1;
20166 TFE_TensorHandle* res[1] = {
nullptr};
20167 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20168 status_check(context::get_status());
20169 return tensor(res[0]);
20172 inline tensor sparse_apply_r_m_s_prop(
const tensor& var,
const tensor& ms,
const tensor& mom,
const tensor& lr,
20173 const tensor& rho,
const tensor& momentum,
const tensor& epsilon,
20174 const tensor& grad,
const tensor& indices, datatype Tindices,
20175 bool use_locking =
false) {
20177 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20178 TFE_NewOp(context::get_context(),
"SparseApplyRMSProp", context::get_status()), &TFE_DeleteOp);
20179 status_check(context::get_status());
20183 TFE_OpAddInput(op.get(), var.tfe_handle.get(), context::get_status());
20184 status_check(context::get_status());
20186 TFE_OpAddInput(op.get(), ms.tfe_handle.get(), context::get_status());
20187 status_check(context::get_status());
20189 TFE_OpAddInput(op.get(), mom.tfe_handle.get(), context::get_status());
20190 status_check(context::get_status());
20192 TFE_OpAddInput(op.get(), lr.tfe_handle.get(), context::get_status());
20193 status_check(context::get_status());
20195 TFE_OpAddInput(op.get(), rho.tfe_handle.get(), context::get_status());
20196 status_check(context::get_status());
20198 TFE_OpAddInput(op.get(), momentum.tfe_handle.get(), context::get_status());
20199 status_check(context::get_status());
20201 TFE_OpAddInput(op.get(), epsilon.tfe_handle.get(), context::get_status());
20202 status_check(context::get_status());
20204 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20205 status_check(context::get_status());
20207 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20208 status_check(context::get_status());
20211 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
20212 TFE_OpSetAttrBool(op.get(),
"use_locking", (
unsigned char)use_locking);
20215 int num_outputs_op = 1;
20216 TFE_TensorHandle* res[1] = {
nullptr};
20217 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20218 status_check(context::get_status());
20219 return tensor(res[0]);
20222 inline tensor sparse_bincount(
const tensor& indices,
const tensor& values,
const tensor& dense_shape,
20223 const tensor& size,
const tensor& weights, datatype Tidx,
bool binary_output =
false) {
20225 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20226 TFE_NewOp(context::get_context(),
"SparseBincount", context::get_status()), &TFE_DeleteOp);
20227 status_check(context::get_status());
20231 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20232 status_check(context::get_status());
20234 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
20235 status_check(context::get_status());
20237 TFE_OpAddInput(op.get(), dense_shape.tfe_handle.get(), context::get_status());
20238 status_check(context::get_status());
20240 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
20241 status_check(context::get_status());
20243 TFE_OpAddInput(op.get(), weights.tfe_handle.get(), context::get_status());
20244 status_check(context::get_status());
20247 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20248 TFE_OpSetAttrBool(op.get(),
"binary_output", (
unsigned char)binary_output);
20251 int num_outputs_op = 1;
20252 TFE_TensorHandle* res[1] = {
nullptr};
20253 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20254 status_check(context::get_status());
20255 return tensor(res[0]);
20258 inline tensor sparse_conditional_accumulator(datatype dtype,
const std::vector<int64_t>& shape,
20259 const std::string& container =
"",
const std::string& shared_name =
"",
20260 const std::string& reduction_type =
"MEAN") {
20262 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20263 TFE_NewOp(context::get_context(),
"SparseConditionalAccumulator", context::get_status()), &TFE_DeleteOp);
20264 status_check(context::get_status());
20269 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
20271 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
20272 status_check(context::get_status());
20274 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
20275 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
20276 TFE_OpSetAttrString(op.get(),
"reduction_type", (
void*)reduction_type.c_str(), reduction_type.size());
20279 int num_outputs_op = 1;
20280 TFE_TensorHandle* res[1] = {
nullptr};
20281 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20282 status_check(context::get_status());
20283 return tensor(res[0]);
20286 inline tensor sparse_dense_cwise_add(
const tensor& sp_indices,
const tensor& sp_values,
const tensor& sp_shape,
20287 const tensor& dense) {
20289 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20290 TFE_NewOp(context::get_context(),
"SparseDenseCwiseAdd", context::get_status()), &TFE_DeleteOp);
20291 status_check(context::get_status());
20295 TFE_OpAddInput(op.get(), sp_indices.tfe_handle.get(), context::get_status());
20296 status_check(context::get_status());
20298 TFE_OpAddInput(op.get(), sp_values.tfe_handle.get(), context::get_status());
20299 status_check(context::get_status());
20301 TFE_OpAddInput(op.get(), sp_shape.tfe_handle.get(), context::get_status());
20302 status_check(context::get_status());
20304 TFE_OpAddInput(op.get(), dense.tfe_handle.get(), context::get_status());
20305 status_check(context::get_status());
20310 int num_outputs_op = 1;
20311 TFE_TensorHandle* res[1] = {
nullptr};
20312 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20313 status_check(context::get_status());
20314 return tensor(res[0]);
20317 inline tensor sparse_dense_cwise_div(
const tensor& sp_indices,
const tensor& sp_values,
const tensor& sp_shape,
20318 const tensor& dense) {
20320 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20321 TFE_NewOp(context::get_context(),
"SparseDenseCwiseDiv", context::get_status()), &TFE_DeleteOp);
20322 status_check(context::get_status());
20326 TFE_OpAddInput(op.get(), sp_indices.tfe_handle.get(), context::get_status());
20327 status_check(context::get_status());
20329 TFE_OpAddInput(op.get(), sp_values.tfe_handle.get(), context::get_status());
20330 status_check(context::get_status());
20332 TFE_OpAddInput(op.get(), sp_shape.tfe_handle.get(), context::get_status());
20333 status_check(context::get_status());
20335 TFE_OpAddInput(op.get(), dense.tfe_handle.get(), context::get_status());
20336 status_check(context::get_status());
20341 int num_outputs_op = 1;
20342 TFE_TensorHandle* res[1] = {
nullptr};
20343 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20344 status_check(context::get_status());
20345 return tensor(res[0]);
20348 inline tensor sparse_dense_cwise_mul(
const tensor& sp_indices,
const tensor& sp_values,
const tensor& sp_shape,
20349 const tensor& dense) {
20351 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20352 TFE_NewOp(context::get_context(),
"SparseDenseCwiseMul", context::get_status()), &TFE_DeleteOp);
20353 status_check(context::get_status());
20357 TFE_OpAddInput(op.get(), sp_indices.tfe_handle.get(), context::get_status());
20358 status_check(context::get_status());
20360 TFE_OpAddInput(op.get(), sp_values.tfe_handle.get(), context::get_status());
20361 status_check(context::get_status());
20363 TFE_OpAddInput(op.get(), sp_shape.tfe_handle.get(), context::get_status());
20364 status_check(context::get_status());
20366 TFE_OpAddInput(op.get(), dense.tfe_handle.get(), context::get_status());
20367 status_check(context::get_status());
20372 int num_outputs_op = 1;
20373 TFE_TensorHandle* res[1] = {
nullptr};
20374 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20375 status_check(context::get_status());
20376 return tensor(res[0]);
20379 inline tensor sparse_mat_mul(
const tensor& a,
const tensor& b,
bool transpose_a =
false,
bool transpose_b =
false,
20380 bool a_is_sparse =
false,
bool b_is_sparse =
false, datatype Ta =
static_cast<datatype
>(1),
20381 datatype Tb =
static_cast<datatype
>(1)) {
20383 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20384 TFE_NewOp(context::get_context(),
"SparseMatMul", context::get_status()), &TFE_DeleteOp);
20385 status_check(context::get_status());
20389 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
20390 status_check(context::get_status());
20392 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
20393 status_check(context::get_status());
20396 TFE_OpSetAttrBool(op.get(),
"transpose_a", (
unsigned char)transpose_a);
20397 TFE_OpSetAttrBool(op.get(),
"transpose_b", (
unsigned char)transpose_b);
20398 TFE_OpSetAttrBool(op.get(),
"a_is_sparse", (
unsigned char)a_is_sparse);
20399 TFE_OpSetAttrBool(op.get(),
"b_is_sparse", (
unsigned char)b_is_sparse);
20400 TFE_OpSetAttrType(op.get(),
"Ta", Ta);
20401 TFE_OpSetAttrType(op.get(),
"Tb", Tb);
20404 int num_outputs_op = 1;
20405 TFE_TensorHandle* res[1] = {
nullptr};
20406 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20407 status_check(context::get_status());
20408 return tensor(res[0]);
20411 inline tensor sparse_matrix_add(
const tensor& a,
const tensor& b,
const tensor& alpha,
const tensor& beta) {
20413 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20414 TFE_NewOp(context::get_context(),
"SparseMatrixAdd", context::get_status()), &TFE_DeleteOp);
20415 status_check(context::get_status());
20419 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
20420 status_check(context::get_status());
20422 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
20423 status_check(context::get_status());
20425 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
20426 status_check(context::get_status());
20428 TFE_OpAddInput(op.get(), beta.tfe_handle.get(), context::get_status());
20429 status_check(context::get_status());
20434 int num_outputs_op = 1;
20435 TFE_TensorHandle* res[1] = {
nullptr};
20436 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20437 status_check(context::get_status());
20438 return tensor(res[0]);
20441 inline tensor sparse_matrix_mat_mul(
const tensor& a,
const tensor& b,
bool transpose_a =
false,
20442 bool transpose_b =
false,
bool adjoint_a =
false,
bool adjoint_b =
false,
20443 bool transpose_output =
false,
bool conjugate_output =
false) {
20445 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20446 TFE_NewOp(context::get_context(),
"SparseMatrixMatMul", context::get_status()), &TFE_DeleteOp);
20447 status_check(context::get_status());
20451 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
20452 status_check(context::get_status());
20454 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
20455 status_check(context::get_status());
20458 TFE_OpSetAttrBool(op.get(),
"transpose_a", (
unsigned char)transpose_a);
20459 TFE_OpSetAttrBool(op.get(),
"transpose_b", (
unsigned char)transpose_b);
20460 TFE_OpSetAttrBool(op.get(),
"adjoint_a", (
unsigned char)adjoint_a);
20461 TFE_OpSetAttrBool(op.get(),
"adjoint_b", (
unsigned char)adjoint_b);
20462 TFE_OpSetAttrBool(op.get(),
"transpose_output", (
unsigned char)transpose_output);
20463 TFE_OpSetAttrBool(op.get(),
"conjugate_output", (
unsigned char)conjugate_output);
20466 int num_outputs_op = 1;
20467 TFE_TensorHandle* res[1] = {
nullptr};
20468 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20469 status_check(context::get_status());
20470 return tensor(res[0]);
20473 inline tensor sparse_matrix_mul(
const tensor& a,
const tensor& b) {
20475 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20476 TFE_NewOp(context::get_context(),
"SparseMatrixMul", context::get_status()), &TFE_DeleteOp);
20477 status_check(context::get_status());
20481 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
20482 status_check(context::get_status());
20484 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
20485 status_check(context::get_status());
20490 int num_outputs_op = 1;
20491 TFE_TensorHandle* res[1] = {
nullptr};
20492 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20493 status_check(context::get_status());
20494 return tensor(res[0]);
20497 inline tensor sparse_matrix_n_n_z(
const tensor& sparse_matrix) {
20499 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20500 TFE_NewOp(context::get_context(),
"SparseMatrixNNZ", context::get_status()), &TFE_DeleteOp);
20501 status_check(context::get_status());
20505 TFE_OpAddInput(op.get(), sparse_matrix.tfe_handle.get(), context::get_status());
20506 status_check(context::get_status());
20511 int num_outputs_op = 1;
20512 TFE_TensorHandle* res[1] = {
nullptr};
20513 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20514 status_check(context::get_status());
20515 return tensor(res[0]);
20518 inline tensor sparse_matrix_ordering_a_m_d(
const tensor& input) {
20520 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20521 TFE_NewOp(context::get_context(),
"SparseMatrixOrderingAMD", context::get_status()), &TFE_DeleteOp);
20522 status_check(context::get_status());
20526 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
20527 status_check(context::get_status());
20532 int num_outputs_op = 1;
20533 TFE_TensorHandle* res[1] = {
nullptr};
20534 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20535 status_check(context::get_status());
20536 return tensor(res[0]);
20539 inline tensor sparse_matrix_softmax(
const tensor& logits, datatype type) {
20541 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20542 TFE_NewOp(context::get_context(),
"SparseMatrixSoftmax", context::get_status()), &TFE_DeleteOp);
20543 status_check(context::get_status());
20547 TFE_OpAddInput(op.get(), logits.tfe_handle.get(), context::get_status());
20548 status_check(context::get_status());
20551 TFE_OpSetAttrType(op.get(),
"type", type);
20554 int num_outputs_op = 1;
20555 TFE_TensorHandle* res[1] = {
nullptr};
20556 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20557 status_check(context::get_status());
20558 return tensor(res[0]);
20561 inline tensor sparse_matrix_softmax_grad(
const tensor& softmax,
const tensor& grad_softmax, datatype type) {
20563 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20564 TFE_NewOp(context::get_context(),
"SparseMatrixSoftmaxGrad", context::get_status()), &TFE_DeleteOp);
20565 status_check(context::get_status());
20569 TFE_OpAddInput(op.get(), softmax.tfe_handle.get(), context::get_status());
20570 status_check(context::get_status());
20572 TFE_OpAddInput(op.get(), grad_softmax.tfe_handle.get(), context::get_status());
20573 status_check(context::get_status());
20576 TFE_OpSetAttrType(op.get(),
"type", type);
20579 int num_outputs_op = 1;
20580 TFE_TensorHandle* res[1] = {
nullptr};
20581 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20582 status_check(context::get_status());
20583 return tensor(res[0]);
20586 inline tensor sparse_matrix_sparse_cholesky(
const tensor& input,
const tensor& permutation, datatype type) {
20588 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20589 TFE_NewOp(context::get_context(),
"SparseMatrixSparseCholesky", context::get_status()), &TFE_DeleteOp);
20590 status_check(context::get_status());
20594 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
20595 status_check(context::get_status());
20597 TFE_OpAddInput(op.get(), permutation.tfe_handle.get(), context::get_status());
20598 status_check(context::get_status());
20601 TFE_OpSetAttrType(op.get(),
"type", type);
20604 int num_outputs_op = 1;
20605 TFE_TensorHandle* res[1] = {
nullptr};
20606 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20607 status_check(context::get_status());
20608 return tensor(res[0]);
20611 inline tensor sparse_matrix_sparse_mat_mul(
const tensor& a,
const tensor& b, datatype type,
bool transpose_a =
false,
20612 bool transpose_b =
false,
bool adjoint_a =
false,
bool adjoint_b =
false) {
20614 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20615 TFE_NewOp(context::get_context(),
"SparseMatrixSparseMatMul", context::get_status()), &TFE_DeleteOp);
20616 status_check(context::get_status());
20620 TFE_OpAddInput(op.get(), a.tfe_handle.get(), context::get_status());
20621 status_check(context::get_status());
20623 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
20624 status_check(context::get_status());
20627 TFE_OpSetAttrType(op.get(),
"type", type);
20628 TFE_OpSetAttrBool(op.get(),
"transpose_a", (
unsigned char)transpose_a);
20629 TFE_OpSetAttrBool(op.get(),
"transpose_b", (
unsigned char)transpose_b);
20630 TFE_OpSetAttrBool(op.get(),
"adjoint_a", (
unsigned char)adjoint_a);
20631 TFE_OpSetAttrBool(op.get(),
"adjoint_b", (
unsigned char)adjoint_b);
20634 int num_outputs_op = 1;
20635 TFE_TensorHandle* res[1] = {
nullptr};
20636 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20637 status_check(context::get_status());
20638 return tensor(res[0]);
20641 inline tensor sparse_matrix_transpose(
const tensor& input, datatype type,
bool conjugate =
false) {
20643 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20644 TFE_NewOp(context::get_context(),
"SparseMatrixTranspose", context::get_status()), &TFE_DeleteOp);
20645 status_check(context::get_status());
20649 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
20650 status_check(context::get_status());
20653 TFE_OpSetAttrType(op.get(),
"type", type);
20654 TFE_OpSetAttrBool(op.get(),
"conjugate", (
unsigned char)conjugate);
20657 int num_outputs_op = 1;
20658 TFE_TensorHandle* res[1] = {
nullptr};
20659 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20660 status_check(context::get_status());
20661 return tensor(res[0]);
20664 inline tensor sparse_matrix_zeros(
const tensor& dense_shape, datatype type) {
20666 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20667 TFE_NewOp(context::get_context(),
"SparseMatrixZeros", context::get_status()), &TFE_DeleteOp);
20668 status_check(context::get_status());
20672 TFE_OpAddInput(op.get(), dense_shape.tfe_handle.get(), context::get_status());
20673 status_check(context::get_status());
20676 TFE_OpSetAttrType(op.get(),
"type", type);
20679 int num_outputs_op = 1;
20680 TFE_TensorHandle* res[1] = {
nullptr};
20681 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20682 status_check(context::get_status());
20683 return tensor(res[0]);
20686 inline tensor sparse_reduce_max(
const tensor& input_indices,
const tensor& input_values,
const tensor& input_shape,
20687 const tensor& reduction_axes,
bool keep_dims =
false) {
20689 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20690 TFE_NewOp(context::get_context(),
"SparseReduceMax", context::get_status()), &TFE_DeleteOp);
20691 status_check(context::get_status());
20695 TFE_OpAddInput(op.get(), input_indices.tfe_handle.get(), context::get_status());
20696 status_check(context::get_status());
20698 TFE_OpAddInput(op.get(), input_values.tfe_handle.get(), context::get_status());
20699 status_check(context::get_status());
20701 TFE_OpAddInput(op.get(), input_shape.tfe_handle.get(), context::get_status());
20702 status_check(context::get_status());
20704 TFE_OpAddInput(op.get(), reduction_axes.tfe_handle.get(), context::get_status());
20705 status_check(context::get_status());
20708 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
20711 int num_outputs_op = 1;
20712 TFE_TensorHandle* res[1] = {
nullptr};
20713 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20714 status_check(context::get_status());
20715 return tensor(res[0]);
20718 inline tensor sparse_reduce_sum(
const tensor& input_indices,
const tensor& input_values,
const tensor& input_shape,
20719 const tensor& reduction_axes,
bool keep_dims =
false) {
20721 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20722 TFE_NewOp(context::get_context(),
"SparseReduceSum", context::get_status()), &TFE_DeleteOp);
20723 status_check(context::get_status());
20727 TFE_OpAddInput(op.get(), input_indices.tfe_handle.get(), context::get_status());
20728 status_check(context::get_status());
20730 TFE_OpAddInput(op.get(), input_values.tfe_handle.get(), context::get_status());
20731 status_check(context::get_status());
20733 TFE_OpAddInput(op.get(), input_shape.tfe_handle.get(), context::get_status());
20734 status_check(context::get_status());
20736 TFE_OpAddInput(op.get(), reduction_axes.tfe_handle.get(), context::get_status());
20737 status_check(context::get_status());
20740 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
20743 int num_outputs_op = 1;
20744 TFE_TensorHandle* res[1] = {
nullptr};
20745 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20746 status_check(context::get_status());
20747 return tensor(res[0]);
20750 inline tensor sparse_segment_mean(
const tensor& data,
const tensor& indices,
const tensor& segment_ids,
20751 datatype Tidx =
static_cast<datatype
>(3),
20752 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20754 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20755 TFE_NewOp(context::get_context(),
"SparseSegmentMean", context::get_status()), &TFE_DeleteOp);
20756 status_check(context::get_status());
20760 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
20761 status_check(context::get_status());
20763 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20764 status_check(context::get_status());
20766 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20767 status_check(context::get_status());
20770 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20771 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20774 int num_outputs_op = 1;
20775 TFE_TensorHandle* res[1] = {
nullptr};
20776 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20777 status_check(context::get_status());
20778 return tensor(res[0]);
20781 inline tensor sparse_segment_mean_grad(
const tensor& grad,
const tensor& indices,
const tensor& segment_ids,
20782 const tensor& output_dim0, datatype Tidx =
static_cast<datatype
>(3),
20783 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20785 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20786 TFE_NewOp(context::get_context(),
"SparseSegmentMeanGrad", context::get_status()), &TFE_DeleteOp);
20787 status_check(context::get_status());
20791 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20792 status_check(context::get_status());
20794 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20795 status_check(context::get_status());
20797 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20798 status_check(context::get_status());
20800 TFE_OpAddInput(op.get(), output_dim0.tfe_handle.get(), context::get_status());
20801 status_check(context::get_status());
20804 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20805 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20808 int num_outputs_op = 1;
20809 TFE_TensorHandle* res[1] = {
nullptr};
20810 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20811 status_check(context::get_status());
20812 return tensor(res[0]);
20815 inline tensor sparse_segment_mean_with_num_segments(
const tensor& data,
const tensor& indices,
20816 const tensor& segment_ids,
const tensor& num_segments,
20817 datatype Tidx =
static_cast<datatype
>(3),
20818 datatype Tnumsegments =
static_cast<datatype
>(3),
20819 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20821 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20822 TFE_NewOp(context::get_context(),
"SparseSegmentMeanWithNumSegments", context::get_status()), &TFE_DeleteOp);
20823 status_check(context::get_status());
20827 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
20828 status_check(context::get_status());
20830 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20831 status_check(context::get_status());
20833 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20834 status_check(context::get_status());
20836 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
20837 status_check(context::get_status());
20840 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20841 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
20842 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20845 int num_outputs_op = 1;
20846 TFE_TensorHandle* res[1] = {
nullptr};
20847 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20848 status_check(context::get_status());
20849 return tensor(res[0]);
20852 inline tensor sparse_segment_sqrt_n(
const tensor& data,
const tensor& indices,
const tensor& segment_ids,
20853 datatype Tidx =
static_cast<datatype
>(3),
20854 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20856 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20857 TFE_NewOp(context::get_context(),
"SparseSegmentSqrtN", context::get_status()), &TFE_DeleteOp);
20858 status_check(context::get_status());
20862 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
20863 status_check(context::get_status());
20865 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20866 status_check(context::get_status());
20868 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20869 status_check(context::get_status());
20872 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20873 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20876 int num_outputs_op = 1;
20877 TFE_TensorHandle* res[1] = {
nullptr};
20878 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20879 status_check(context::get_status());
20880 return tensor(res[0]);
20883 inline tensor sparse_segment_sqrt_n_grad(
const tensor& grad,
const tensor& indices,
const tensor& segment_ids,
20884 const tensor& output_dim0, datatype Tidx =
static_cast<datatype
>(3),
20885 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20887 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20888 TFE_NewOp(context::get_context(),
"SparseSegmentSqrtNGrad", context::get_status()), &TFE_DeleteOp);
20889 status_check(context::get_status());
20893 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
20894 status_check(context::get_status());
20896 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20897 status_check(context::get_status());
20899 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20900 status_check(context::get_status());
20902 TFE_OpAddInput(op.get(), output_dim0.tfe_handle.get(), context::get_status());
20903 status_check(context::get_status());
20906 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20907 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20910 int num_outputs_op = 1;
20911 TFE_TensorHandle* res[1] = {
nullptr};
20912 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20913 status_check(context::get_status());
20914 return tensor(res[0]);
20917 inline tensor sparse_segment_sqrt_n_with_num_segments(
const tensor& data,
const tensor& indices,
20918 const tensor& segment_ids,
const tensor& num_segments,
20919 datatype Tidx =
static_cast<datatype
>(3),
20920 datatype Tnumsegments =
static_cast<datatype
>(3),
20921 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20923 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20924 TFE_NewOp(context::get_context(),
"SparseSegmentSqrtNWithNumSegments", context::get_status()), &TFE_DeleteOp);
20925 status_check(context::get_status());
20929 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
20930 status_check(context::get_status());
20932 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20933 status_check(context::get_status());
20935 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20936 status_check(context::get_status());
20938 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
20939 status_check(context::get_status());
20942 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20943 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
20944 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20947 int num_outputs_op = 1;
20948 TFE_TensorHandle* res[1] = {
nullptr};
20949 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20950 status_check(context::get_status());
20951 return tensor(res[0]);
20954 inline tensor sparse_segment_sum(
const tensor& data,
const tensor& indices,
const tensor& segment_ids,
20955 datatype Tidx =
static_cast<datatype
>(3),
20956 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20958 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20959 TFE_NewOp(context::get_context(),
"SparseSegmentSum", context::get_status()), &TFE_DeleteOp);
20960 status_check(context::get_status());
20964 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
20965 status_check(context::get_status());
20967 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
20968 status_check(context::get_status());
20970 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
20971 status_check(context::get_status());
20974 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
20975 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
20978 int num_outputs_op = 1;
20979 TFE_TensorHandle* res[1] = {
nullptr};
20980 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
20981 status_check(context::get_status());
20982 return tensor(res[0]);
20985 inline tensor sparse_segment_sum_with_num_segments(
const tensor& data,
const tensor& indices,
const tensor& segment_ids,
20986 const tensor& num_segments, datatype Tidx =
static_cast<datatype
>(3),
20987 datatype Tnumsegments =
static_cast<datatype
>(3),
20988 datatype Tsegmentids =
static_cast<datatype
>(3)) {
20990 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
20991 TFE_NewOp(context::get_context(),
"SparseSegmentSumWithNumSegments", context::get_status()), &TFE_DeleteOp);
20992 status_check(context::get_status());
20996 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
20997 status_check(context::get_status());
20999 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
21000 status_check(context::get_status());
21002 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
21003 status_check(context::get_status());
21005 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
21006 status_check(context::get_status());
21009 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
21010 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
21011 TFE_OpSetAttrType(op.get(),
"Tsegmentids", Tsegmentids);
21014 int num_outputs_op = 1;
21015 TFE_TensorHandle* res[1] = {
nullptr};
21016 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21017 status_check(context::get_status());
21018 return tensor(res[0]);
21021 inline tensor sparse_slice_grad(
const tensor& backprop_val_grad,
const tensor& input_indices,
const tensor& input_start,
21022 const tensor& output_indices) {
21024 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21025 TFE_NewOp(context::get_context(),
"SparseSliceGrad", context::get_status()), &TFE_DeleteOp);
21026 status_check(context::get_status());
21030 TFE_OpAddInput(op.get(), backprop_val_grad.tfe_handle.get(), context::get_status());
21031 status_check(context::get_status());
21033 TFE_OpAddInput(op.get(), input_indices.tfe_handle.get(), context::get_status());
21034 status_check(context::get_status());
21036 TFE_OpAddInput(op.get(), input_start.tfe_handle.get(), context::get_status());
21037 status_check(context::get_status());
21039 TFE_OpAddInput(op.get(), output_indices.tfe_handle.get(), context::get_status());
21040 status_check(context::get_status());
21045 int num_outputs_op = 1;
21046 TFE_TensorHandle* res[1] = {
nullptr};
21047 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21048 status_check(context::get_status());
21049 return tensor(res[0]);
21052 inline tensor sparse_softmax(
const tensor& sp_indices,
const tensor& sp_values,
const tensor& sp_shape) {
21054 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21055 TFE_NewOp(context::get_context(),
"SparseSoftmax", context::get_status()), &TFE_DeleteOp);
21056 status_check(context::get_status());
21060 TFE_OpAddInput(op.get(), sp_indices.tfe_handle.get(), context::get_status());
21061 status_check(context::get_status());
21063 TFE_OpAddInput(op.get(), sp_values.tfe_handle.get(), context::get_status());
21064 status_check(context::get_status());
21066 TFE_OpAddInput(op.get(), sp_shape.tfe_handle.get(), context::get_status());
21067 status_check(context::get_status());
21072 int num_outputs_op = 1;
21073 TFE_TensorHandle* res[1] = {
nullptr};
21074 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21075 status_check(context::get_status());
21076 return tensor(res[0]);
21079 inline tensor sparse_tensor_dense_add(
const tensor& a_indices,
const tensor& a_values,
const tensor& a_shape,
21080 const tensor& b, datatype Tindices) {
21082 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21083 TFE_NewOp(context::get_context(),
"SparseTensorDenseAdd", context::get_status()), &TFE_DeleteOp);
21084 status_check(context::get_status());
21088 TFE_OpAddInput(op.get(), a_indices.tfe_handle.get(), context::get_status());
21089 status_check(context::get_status());
21091 TFE_OpAddInput(op.get(), a_values.tfe_handle.get(), context::get_status());
21092 status_check(context::get_status());
21094 TFE_OpAddInput(op.get(), a_shape.tfe_handle.get(), context::get_status());
21095 status_check(context::get_status());
21097 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
21098 status_check(context::get_status());
21101 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
21104 int num_outputs_op = 1;
21105 TFE_TensorHandle* res[1] = {
nullptr};
21106 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21107 status_check(context::get_status());
21108 return tensor(res[0]);
21111 inline tensor sparse_tensor_dense_mat_mul(
const tensor& a_indices,
const tensor& a_values,
const tensor& a_shape,
21112 const tensor& b, datatype Tindices =
static_cast<datatype
>(9),
21113 bool adjoint_a =
false,
bool adjoint_b =
false) {
21115 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21116 TFE_NewOp(context::get_context(),
"SparseTensorDenseMatMul", context::get_status()), &TFE_DeleteOp);
21117 status_check(context::get_status());
21121 TFE_OpAddInput(op.get(), a_indices.tfe_handle.get(), context::get_status());
21122 status_check(context::get_status());
21124 TFE_OpAddInput(op.get(), a_values.tfe_handle.get(), context::get_status());
21125 status_check(context::get_status());
21127 TFE_OpAddInput(op.get(), a_shape.tfe_handle.get(), context::get_status());
21128 status_check(context::get_status());
21130 TFE_OpAddInput(op.get(), b.tfe_handle.get(), context::get_status());
21131 status_check(context::get_status());
21134 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
21135 TFE_OpSetAttrBool(op.get(),
"adjoint_a", (
unsigned char)adjoint_a);
21136 TFE_OpSetAttrBool(op.get(),
"adjoint_b", (
unsigned char)adjoint_b);
21139 int num_outputs_op = 1;
21140 TFE_TensorHandle* res[1] = {
nullptr};
21141 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21142 status_check(context::get_status());
21143 return tensor(res[0]);
21146 inline tensor sparse_tensor_slice_dataset(
const tensor& indices,
const tensor& values,
const tensor& dense_shape,
21147 datatype Tvalues) {
21149 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21150 TFE_NewOp(context::get_context(),
"SparseTensorSliceDataset", context::get_status()), &TFE_DeleteOp);
21151 status_check(context::get_status());
21155 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
21156 status_check(context::get_status());
21158 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
21159 status_check(context::get_status());
21161 TFE_OpAddInput(op.get(), dense_shape.tfe_handle.get(), context::get_status());
21162 status_check(context::get_status());
21165 TFE_OpSetAttrType(op.get(),
"Tvalues", Tvalues);
21168 int num_outputs_op = 1;
21169 TFE_TensorHandle* res[1] = {
nullptr};
21170 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21171 status_check(context::get_status());
21172 return tensor(res[0]);
21175 inline tensor sparse_tensor_to_c_s_r_sparse_matrix(
const tensor& indices,
const tensor& values,
21176 const tensor& dense_shape) {
21178 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21179 TFE_NewOp(context::get_context(),
"SparseTensorToCSRSparseMatrix", context::get_status()), &TFE_DeleteOp);
21180 status_check(context::get_status());
21184 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
21185 status_check(context::get_status());
21187 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
21188 status_check(context::get_status());
21190 TFE_OpAddInput(op.get(), dense_shape.tfe_handle.get(), context::get_status());
21191 status_check(context::get_status());
21196 int num_outputs_op = 1;
21197 TFE_TensorHandle* res[1] = {
nullptr};
21198 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21199 status_check(context::get_status());
21200 return tensor(res[0]);
21203 inline tensor sparse_to_dense(
const tensor& sparse_indices,
const tensor& output_shape,
const tensor& sparse_values,
21204 const tensor& default_value, datatype Tindices,
bool validate_indices =
true) {
21206 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21207 TFE_NewOp(context::get_context(),
"SparseToDense", context::get_status()), &TFE_DeleteOp);
21208 status_check(context::get_status());
21212 TFE_OpAddInput(op.get(), sparse_indices.tfe_handle.get(), context::get_status());
21213 status_check(context::get_status());
21215 TFE_OpAddInput(op.get(), output_shape.tfe_handle.get(), context::get_status());
21216 status_check(context::get_status());
21218 TFE_OpAddInput(op.get(), sparse_values.tfe_handle.get(), context::get_status());
21219 status_check(context::get_status());
21221 TFE_OpAddInput(op.get(), default_value.tfe_handle.get(), context::get_status());
21222 status_check(context::get_status());
21225 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
21226 TFE_OpSetAttrBool(op.get(),
"validate_indices", (
unsigned char)validate_indices);
21229 int num_outputs_op = 1;
21230 TFE_TensorHandle* res[1] = {
nullptr};
21231 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21232 status_check(context::get_status());
21233 return tensor(res[0]);
21236 inline tensor spence(
const tensor& x) {
21238 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21239 TFE_NewOp(context::get_context(),
"Spence", context::get_status()), &TFE_DeleteOp);
21240 status_check(context::get_status());
21244 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
21245 status_check(context::get_status());
21250 int num_outputs_op = 1;
21251 TFE_TensorHandle* res[1] = {
nullptr};
21252 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21253 status_check(context::get_status());
21254 return tensor(res[0]);
21257 inline tensor split(
const tensor& split_dim,
const tensor& value, int64_t num_split) {
21259 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Split", context::get_status()),
21261 status_check(context::get_status());
21265 TFE_OpAddInput(op.get(), split_dim.tfe_handle.get(), context::get_status());
21266 status_check(context::get_status());
21268 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
21269 status_check(context::get_status());
21272 TFE_OpSetAttrInt(op.get(),
"num_split", num_split);
21275 int num_outputs_op = 1;
21276 TFE_TensorHandle* res[1] = {
nullptr};
21277 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21278 status_check(context::get_status());
21279 return tensor(res[0]);
21282 inline tensor split_v(
const tensor& value,
const tensor& size_splits,
const tensor& split_dim, int64_t num_split,
21283 datatype Tlen =
static_cast<datatype
>(9)) {
21285 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21286 TFE_NewOp(context::get_context(),
"SplitV", context::get_status()), &TFE_DeleteOp);
21287 status_check(context::get_status());
21291 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
21292 status_check(context::get_status());
21294 TFE_OpAddInput(op.get(), size_splits.tfe_handle.get(), context::get_status());
21295 status_check(context::get_status());
21297 TFE_OpAddInput(op.get(), split_dim.tfe_handle.get(), context::get_status());
21298 status_check(context::get_status());
21301 TFE_OpSetAttrInt(op.get(),
"num_split", num_split);
21302 TFE_OpSetAttrType(op.get(),
"Tlen", Tlen);
21305 int num_outputs_op = 1;
21306 TFE_TensorHandle* res[1] = {
nullptr};
21307 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21308 status_check(context::get_status());
21309 return tensor(res[0]);
21312 inline tensor sql_dataset(
const tensor& driver_name,
const tensor& data_source_name,
const tensor& query,
21313 const std::vector<datatype>& output_types,
21314 const std::vector<std::vector<int64_t>>& output_shapes) {
21316 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21317 TFE_NewOp(context::get_context(),
"SqlDataset", context::get_status()), &TFE_DeleteOp);
21318 status_check(context::get_status());
21322 TFE_OpAddInput(op.get(), driver_name.tfe_handle.get(), context::get_status());
21323 status_check(context::get_status());
21325 TFE_OpAddInput(op.get(), data_source_name.tfe_handle.get(), context::get_status());
21326 status_check(context::get_status());
21328 TFE_OpAddInput(op.get(), query.tfe_handle.get(), context::get_status());
21329 status_check(context::get_status());
21332 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
21333 static_cast<int>(output_types.size()));
21335 std::vector<const int64_t*> output_shapes_values;
21336 output_shapes_values.reserve(output_shapes.size());
21337 std::vector<int> output_shapes_ndims;
21338 output_shapes_ndims.reserve(output_shapes.size());
21339 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
21340 [](
const auto& v) { return v.data(); });
21341 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
21342 [](
const auto& v) { return static_cast<int>(v.size()); });
21343 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
21344 static_cast<int>(output_shapes.size()), context::get_status());
21345 status_check(context::get_status());
21348 int num_outputs_op = 1;
21349 TFE_TensorHandle* res[1] = {
nullptr};
21350 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21351 status_check(context::get_status());
21352 return tensor(res[0]);
21355 inline tensor sqrt(
const tensor& x) {
21357 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Sqrt", context::get_status()),
21359 status_check(context::get_status());
21363 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
21364 status_check(context::get_status());
21369 int num_outputs_op = 1;
21370 TFE_TensorHandle* res[1] = {
nullptr};
21371 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21372 status_check(context::get_status());
21373 return tensor(res[0]);
21376 inline tensor sqrt_grad(
const tensor& y,
const tensor& dy) {
21378 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21379 TFE_NewOp(context::get_context(),
"SqrtGrad", context::get_status()), &TFE_DeleteOp);
21380 status_check(context::get_status());
21384 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
21385 status_check(context::get_status());
21387 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
21388 status_check(context::get_status());
21393 int num_outputs_op = 1;
21394 TFE_TensorHandle* res[1] = {
nullptr};
21395 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21396 status_check(context::get_status());
21397 return tensor(res[0]);
21400 inline tensor square(
const tensor& x) {
21402 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21403 TFE_NewOp(context::get_context(),
"Square", context::get_status()), &TFE_DeleteOp);
21404 status_check(context::get_status());
21408 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
21409 status_check(context::get_status());
21414 int num_outputs_op = 1;
21415 TFE_TensorHandle* res[1] = {
nullptr};
21416 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21417 status_check(context::get_status());
21418 return tensor(res[0]);
21421 inline tensor squared_difference(
const tensor& x,
const tensor& y) {
21423 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21424 TFE_NewOp(context::get_context(),
"SquaredDifference", context::get_status()), &TFE_DeleteOp);
21425 status_check(context::get_status());
21429 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
21430 status_check(context::get_status());
21432 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
21433 status_check(context::get_status());
21438 int num_outputs_op = 1;
21439 TFE_TensorHandle* res[1] = {
nullptr};
21440 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21441 status_check(context::get_status());
21442 return tensor(res[0]);
21445 inline tensor squeeze(
const tensor& input,
const std::vector<int64_t>& squeeze_dims) {
21447 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21448 TFE_NewOp(context::get_context(),
"Squeeze", context::get_status()), &TFE_DeleteOp);
21449 status_check(context::get_status());
21453 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
21454 status_check(context::get_status());
21457 TFE_OpSetAttrIntList(op.get(),
"squeeze_dims", squeeze_dims.data(),
static_cast<int>(squeeze_dims.size()));
21460 int num_outputs_op = 1;
21461 TFE_TensorHandle* res[1] = {
nullptr};
21462 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21463 status_check(context::get_status());
21464 return tensor(res[0]);
21467 inline tensor stack(datatype elem_type,
const std::string& stack_name =
"") {
21469 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Stack", context::get_status()),
21471 status_check(context::get_status());
21476 TFE_OpSetAttrType(op.get(),
"elem_type", elem_type);
21477 TFE_OpSetAttrString(op.get(),
"stack_name", (
void*)stack_name.c_str(), stack_name.size());
21480 int num_outputs_op = 1;
21481 TFE_TensorHandle* res[1] = {
nullptr};
21482 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21483 status_check(context::get_status());
21484 return tensor(res[0]);
21487 inline tensor stack_pop(
const tensor& handle, datatype elem_type) {
21489 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21490 TFE_NewOp(context::get_context(),
"StackPop", context::get_status()), &TFE_DeleteOp);
21491 status_check(context::get_status());
21495 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
21496 status_check(context::get_status());
21499 TFE_OpSetAttrType(op.get(),
"elem_type", elem_type);
21502 int num_outputs_op = 1;
21503 TFE_TensorHandle* res[1] = {
nullptr};
21504 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21505 status_check(context::get_status());
21506 return tensor(res[0]);
21509 inline tensor stack_pop_v2(
const tensor& handle, datatype elem_type) {
21511 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21512 TFE_NewOp(context::get_context(),
"StackPopV2", context::get_status()), &TFE_DeleteOp);
21513 status_check(context::get_status());
21517 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
21518 status_check(context::get_status());
21521 TFE_OpSetAttrType(op.get(),
"elem_type", elem_type);
21524 int num_outputs_op = 1;
21525 TFE_TensorHandle* res[1] = {
nullptr};
21526 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21527 status_check(context::get_status());
21528 return tensor(res[0]);
21531 inline tensor stack_push(
const tensor& handle,
const tensor& elem,
bool swap_memory =
false) {
21533 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21534 TFE_NewOp(context::get_context(),
"StackPush", context::get_status()), &TFE_DeleteOp);
21535 status_check(context::get_status());
21539 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
21540 status_check(context::get_status());
21542 TFE_OpAddInput(op.get(), elem.tfe_handle.get(), context::get_status());
21543 status_check(context::get_status());
21546 TFE_OpSetAttrBool(op.get(),
"swap_memory", (
unsigned char)swap_memory);
21549 int num_outputs_op = 1;
21550 TFE_TensorHandle* res[1] = {
nullptr};
21551 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21552 status_check(context::get_status());
21553 return tensor(res[0]);
21556 inline tensor stack_push_v2(
const tensor& handle,
const tensor& elem,
bool swap_memory =
false) {
21558 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21559 TFE_NewOp(context::get_context(),
"StackPushV2", context::get_status()), &TFE_DeleteOp);
21560 status_check(context::get_status());
21564 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
21565 status_check(context::get_status());
21567 TFE_OpAddInput(op.get(), elem.tfe_handle.get(), context::get_status());
21568 status_check(context::get_status());
21571 TFE_OpSetAttrBool(op.get(),
"swap_memory", (
unsigned char)swap_memory);
21574 int num_outputs_op = 1;
21575 TFE_TensorHandle* res[1] = {
nullptr};
21576 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21577 status_check(context::get_status());
21578 return tensor(res[0]);
21581 inline tensor stack_v2(
const tensor& max_size, datatype elem_type,
const std::string& stack_name =
"") {
21583 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21584 TFE_NewOp(context::get_context(),
"StackV2", context::get_status()), &TFE_DeleteOp);
21585 status_check(context::get_status());
21589 TFE_OpAddInput(op.get(), max_size.tfe_handle.get(), context::get_status());
21590 status_check(context::get_status());
21593 TFE_OpSetAttrType(op.get(),
"elem_type", elem_type);
21594 TFE_OpSetAttrString(op.get(),
"stack_name", (
void*)stack_name.c_str(), stack_name.size());
21597 int num_outputs_op = 1;
21598 TFE_TensorHandle* res[1] = {
nullptr};
21599 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21600 status_check(context::get_status());
21601 return tensor(res[0]);
21604 inline tensor stage_peek(
const tensor& index,
const std::vector<datatype>& dtypes, int64_t capacity = 0,
21605 int64_t memory_limit = 0,
const std::string& container =
"",
21606 const std::string& shared_name =
"") {
21608 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21609 TFE_NewOp(context::get_context(),
"StagePeek", context::get_status()), &TFE_DeleteOp);
21610 status_check(context::get_status());
21614 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
21615 status_check(context::get_status());
21618 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
21619 static_cast<int>(dtypes.size()));
21620 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
21621 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
21622 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
21623 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
21626 int num_outputs_op = 1;
21627 TFE_TensorHandle* res[1] = {
nullptr};
21628 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21629 status_check(context::get_status());
21630 return tensor(res[0]);
21633 inline tensor stage_size(
const std::vector<datatype>& dtypes, int64_t capacity = 0, int64_t memory_limit = 0,
21634 const std::string& container =
"",
const std::string& shared_name =
"") {
21636 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21637 TFE_NewOp(context::get_context(),
"StageSize", context::get_status()), &TFE_DeleteOp);
21638 status_check(context::get_status());
21643 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
21644 static_cast<int>(dtypes.size()));
21645 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
21646 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
21647 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
21648 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
21651 int num_outputs_op = 1;
21652 TFE_TensorHandle* res[1] = {
nullptr};
21653 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21654 status_check(context::get_status());
21655 return tensor(res[0]);
21658 inline tensor stateful_random_binomial(
const tensor& resource,
const tensor& algorithm,
const tensor& shape,
21659 const tensor& counts,
const tensor& probs, datatype S,
21660 datatype dtype =
static_cast<datatype
>(9)) {
21662 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21663 TFE_NewOp(context::get_context(),
"StatefulRandomBinomial", context::get_status()), &TFE_DeleteOp);
21664 status_check(context::get_status());
21668 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21669 status_check(context::get_status());
21671 TFE_OpAddInput(op.get(), algorithm.tfe_handle.get(), context::get_status());
21672 status_check(context::get_status());
21674 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21675 status_check(context::get_status());
21677 TFE_OpAddInput(op.get(), counts.tfe_handle.get(), context::get_status());
21678 status_check(context::get_status());
21680 TFE_OpAddInput(op.get(), probs.tfe_handle.get(), context::get_status());
21681 status_check(context::get_status());
21684 TFE_OpSetAttrType(op.get(),
"S", S);
21685 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21688 int num_outputs_op = 1;
21689 TFE_TensorHandle* res[1] = {
nullptr};
21690 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21691 status_check(context::get_status());
21692 return tensor(res[0]);
21695 inline tensor stateful_standard_normal(
const tensor& resource,
const tensor& shape,
21696 datatype dtype =
static_cast<datatype
>(1),
21697 datatype shape_dtype =
static_cast<datatype
>(9)) {
21699 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21700 TFE_NewOp(context::get_context(),
"StatefulStandardNormal", context::get_status()), &TFE_DeleteOp);
21701 status_check(context::get_status());
21705 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21706 status_check(context::get_status());
21708 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21709 status_check(context::get_status());
21712 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21713 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
21716 int num_outputs_op = 1;
21717 TFE_TensorHandle* res[1] = {
nullptr};
21718 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21719 status_check(context::get_status());
21720 return tensor(res[0]);
21723 inline tensor stateful_standard_normal_v2(
const tensor& resource,
const tensor& algorithm,
const tensor& shape,
21724 datatype dtype =
static_cast<datatype
>(1),
21725 datatype shape_dtype =
static_cast<datatype
>(9)) {
21727 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21728 TFE_NewOp(context::get_context(),
"StatefulStandardNormalV2", context::get_status()), &TFE_DeleteOp);
21729 status_check(context::get_status());
21733 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21734 status_check(context::get_status());
21736 TFE_OpAddInput(op.get(), algorithm.tfe_handle.get(), context::get_status());
21737 status_check(context::get_status());
21739 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21740 status_check(context::get_status());
21743 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21744 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
21747 int num_outputs_op = 1;
21748 TFE_TensorHandle* res[1] = {
nullptr};
21749 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21750 status_check(context::get_status());
21751 return tensor(res[0]);
21754 inline tensor stateful_truncated_normal(
const tensor& resource,
const tensor& algorithm,
const tensor& shape,
21755 datatype dtype =
static_cast<datatype
>(1),
21756 datatype shape_dtype =
static_cast<datatype
>(9)) {
21758 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21759 TFE_NewOp(context::get_context(),
"StatefulTruncatedNormal", context::get_status()), &TFE_DeleteOp);
21760 status_check(context::get_status());
21764 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21765 status_check(context::get_status());
21767 TFE_OpAddInput(op.get(), algorithm.tfe_handle.get(), context::get_status());
21768 status_check(context::get_status());
21770 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21771 status_check(context::get_status());
21774 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21775 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
21778 int num_outputs_op = 1;
21779 TFE_TensorHandle* res[1] = {
nullptr};
21780 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21781 status_check(context::get_status());
21782 return tensor(res[0]);
21785 inline tensor stateful_uniform(
const tensor& resource,
const tensor& algorithm,
const tensor& shape,
21786 datatype dtype =
static_cast<datatype
>(1),
21787 datatype shape_dtype =
static_cast<datatype
>(9)) {
21789 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21790 TFE_NewOp(context::get_context(),
"StatefulUniform", context::get_status()), &TFE_DeleteOp);
21791 status_check(context::get_status());
21795 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21796 status_check(context::get_status());
21798 TFE_OpAddInput(op.get(), algorithm.tfe_handle.get(), context::get_status());
21799 status_check(context::get_status());
21801 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21802 status_check(context::get_status());
21805 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21806 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
21809 int num_outputs_op = 1;
21810 TFE_TensorHandle* res[1] = {
nullptr};
21811 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21812 status_check(context::get_status());
21813 return tensor(res[0]);
21816 inline tensor stateful_uniform_full_int(
const tensor& resource,
const tensor& algorithm,
const tensor& shape,
21817 datatype dtype =
static_cast<datatype
>(23),
21818 datatype shape_dtype =
static_cast<datatype
>(9)) {
21820 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21821 TFE_NewOp(context::get_context(),
"StatefulUniformFullInt", context::get_status()), &TFE_DeleteOp);
21822 status_check(context::get_status());
21826 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21827 status_check(context::get_status());
21829 TFE_OpAddInput(op.get(), algorithm.tfe_handle.get(), context::get_status());
21830 status_check(context::get_status());
21832 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21833 status_check(context::get_status());
21836 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21837 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
21840 int num_outputs_op = 1;
21841 TFE_TensorHandle* res[1] = {
nullptr};
21842 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21843 status_check(context::get_status());
21844 return tensor(res[0]);
21847 inline tensor stateful_uniform_int(
const tensor& resource,
const tensor& algorithm,
const tensor& shape,
21848 const tensor& minval,
const tensor& maxval,
21849 datatype dtype =
static_cast<datatype
>(9),
21850 datatype shape_dtype =
static_cast<datatype
>(9)) {
21852 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21853 TFE_NewOp(context::get_context(),
"StatefulUniformInt", context::get_status()), &TFE_DeleteOp);
21854 status_check(context::get_status());
21858 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
21859 status_check(context::get_status());
21861 TFE_OpAddInput(op.get(), algorithm.tfe_handle.get(), context::get_status());
21862 status_check(context::get_status());
21864 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21865 status_check(context::get_status());
21867 TFE_OpAddInput(op.get(), minval.tfe_handle.get(), context::get_status());
21868 status_check(context::get_status());
21870 TFE_OpAddInput(op.get(), maxval.tfe_handle.get(), context::get_status());
21871 status_check(context::get_status());
21874 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21875 TFE_OpSetAttrType(op.get(),
"shape_dtype", shape_dtype);
21878 int num_outputs_op = 1;
21879 TFE_TensorHandle* res[1] = {
nullptr};
21880 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21881 status_check(context::get_status());
21882 return tensor(res[0]);
21885 inline tensor stateless_multinomial(
const tensor& logits,
const tensor& num_samples,
const tensor& seed,
21886 datatype Tseed =
static_cast<datatype
>(9),
21887 datatype output_dtype =
static_cast<datatype
>(9)) {
21889 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21890 TFE_NewOp(context::get_context(),
"StatelessMultinomial", context::get_status()), &TFE_DeleteOp);
21891 status_check(context::get_status());
21895 TFE_OpAddInput(op.get(), logits.tfe_handle.get(), context::get_status());
21896 status_check(context::get_status());
21898 TFE_OpAddInput(op.get(), num_samples.tfe_handle.get(), context::get_status());
21899 status_check(context::get_status());
21901 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
21902 status_check(context::get_status());
21905 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
21906 TFE_OpSetAttrType(op.get(),
"output_dtype", output_dtype);
21909 int num_outputs_op = 1;
21910 TFE_TensorHandle* res[1] = {
nullptr};
21911 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21912 status_check(context::get_status());
21913 return tensor(res[0]);
21916 inline tensor stateless_parameterized_truncated_normal(
const tensor& shape,
const tensor& seed,
const tensor& means,
21917 const tensor& stddevs,
const tensor& minvals,
21918 const tensor& maxvals, datatype S, datatype dtype,
21919 datatype Tseed =
static_cast<datatype
>(9)) {
21921 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21922 TFE_NewOp(context::get_context(),
"StatelessParameterizedTruncatedNormal", context::get_status()), &TFE_DeleteOp);
21923 status_check(context::get_status());
21927 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21928 status_check(context::get_status());
21930 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
21931 status_check(context::get_status());
21933 TFE_OpAddInput(op.get(), means.tfe_handle.get(), context::get_status());
21934 status_check(context::get_status());
21936 TFE_OpAddInput(op.get(), stddevs.tfe_handle.get(), context::get_status());
21937 status_check(context::get_status());
21939 TFE_OpAddInput(op.get(), minvals.tfe_handle.get(), context::get_status());
21940 status_check(context::get_status());
21942 TFE_OpAddInput(op.get(), maxvals.tfe_handle.get(), context::get_status());
21943 status_check(context::get_status());
21946 TFE_OpSetAttrType(op.get(),
"S", S);
21947 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21948 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
21951 int num_outputs_op = 1;
21952 TFE_TensorHandle* res[1] = {
nullptr};
21953 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21954 status_check(context::get_status());
21955 return tensor(res[0]);
21958 inline tensor stateless_random_binomial(
const tensor& shape,
const tensor& seed,
const tensor& counts,
21959 const tensor& probs, datatype S, datatype Tseed =
static_cast<datatype
>(9),
21960 datatype dtype =
static_cast<datatype
>(9)) {
21962 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21963 TFE_NewOp(context::get_context(),
"StatelessRandomBinomial", context::get_status()), &TFE_DeleteOp);
21964 status_check(context::get_status());
21968 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
21969 status_check(context::get_status());
21971 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
21972 status_check(context::get_status());
21974 TFE_OpAddInput(op.get(), counts.tfe_handle.get(), context::get_status());
21975 status_check(context::get_status());
21977 TFE_OpAddInput(op.get(), probs.tfe_handle.get(), context::get_status());
21978 status_check(context::get_status());
21981 TFE_OpSetAttrType(op.get(),
"S", S);
21982 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
21983 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
21986 int num_outputs_op = 1;
21987 TFE_TensorHandle* res[1] = {
nullptr};
21988 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
21989 status_check(context::get_status());
21990 return tensor(res[0]);
21993 inline tensor stateless_random_gamma_v2(
const tensor& shape,
const tensor& seed,
const tensor& alpha, datatype dtype,
21994 datatype Tseed =
static_cast<datatype
>(9)) {
21996 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
21997 TFE_NewOp(context::get_context(),
"StatelessRandomGammaV2", context::get_status()), &TFE_DeleteOp);
21998 status_check(context::get_status());
22002 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22003 status_check(context::get_status());
22005 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22006 status_check(context::get_status());
22008 TFE_OpAddInput(op.get(), alpha.tfe_handle.get(), context::get_status());
22009 status_check(context::get_status());
22012 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22013 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22016 int num_outputs_op = 1;
22017 TFE_TensorHandle* res[1] = {
nullptr};
22018 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22019 status_check(context::get_status());
22020 return tensor(res[0]);
22023 inline tensor stateless_random_normal(
const tensor& shape,
const tensor& seed,
22024 datatype dtype =
static_cast<datatype
>(1),
22025 datatype Tseed =
static_cast<datatype
>(9)) {
22027 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22028 TFE_NewOp(context::get_context(),
"StatelessRandomNormal", context::get_status()), &TFE_DeleteOp);
22029 status_check(context::get_status());
22033 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22034 status_check(context::get_status());
22036 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22037 status_check(context::get_status());
22040 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22041 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22044 int num_outputs_op = 1;
22045 TFE_TensorHandle* res[1] = {
nullptr};
22046 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22047 status_check(context::get_status());
22048 return tensor(res[0]);
22051 inline tensor stateless_random_poisson(
const tensor& shape,
const tensor& seed,
const tensor& lam, datatype Rtype,
22052 datatype dtype, datatype Tseed =
static_cast<datatype
>(9)) {
22054 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22055 TFE_NewOp(context::get_context(),
"StatelessRandomPoisson", context::get_status()), &TFE_DeleteOp);
22056 status_check(context::get_status());
22060 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22061 status_check(context::get_status());
22063 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22064 status_check(context::get_status());
22066 TFE_OpAddInput(op.get(), lam.tfe_handle.get(), context::get_status());
22067 status_check(context::get_status());
22070 TFE_OpSetAttrType(op.get(),
"Rtype", Rtype);
22071 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22072 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22075 int num_outputs_op = 1;
22076 TFE_TensorHandle* res[1] = {
nullptr};
22077 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22078 status_check(context::get_status());
22079 return tensor(res[0]);
22082 inline tensor stateless_random_uniform(
const tensor& shape,
const tensor& seed,
22083 datatype dtype =
static_cast<datatype
>(1),
22084 datatype Tseed =
static_cast<datatype
>(9)) {
22086 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22087 TFE_NewOp(context::get_context(),
"StatelessRandomUniform", context::get_status()), &TFE_DeleteOp);
22088 status_check(context::get_status());
22092 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22093 status_check(context::get_status());
22095 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22096 status_check(context::get_status());
22099 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22100 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22103 int num_outputs_op = 1;
22104 TFE_TensorHandle* res[1] = {
nullptr};
22105 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22106 status_check(context::get_status());
22107 return tensor(res[0]);
22110 inline tensor stateless_random_uniform_full_int(
const tensor& shape,
const tensor& seed,
22111 datatype dtype =
static_cast<datatype
>(23),
22112 datatype Tseed =
static_cast<datatype
>(9)) {
22114 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22115 TFE_NewOp(context::get_context(),
"StatelessRandomUniformFullInt", context::get_status()), &TFE_DeleteOp);
22116 status_check(context::get_status());
22120 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22121 status_check(context::get_status());
22123 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22124 status_check(context::get_status());
22127 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22128 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22131 int num_outputs_op = 1;
22132 TFE_TensorHandle* res[1] = {
nullptr};
22133 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22134 status_check(context::get_status());
22135 return tensor(res[0]);
22138 inline tensor stateless_random_uniform_int(
const tensor& shape,
const tensor& seed,
const tensor& minval,
22139 const tensor& maxval, datatype dtype,
22140 datatype Tseed =
static_cast<datatype
>(9)) {
22142 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22143 TFE_NewOp(context::get_context(),
"StatelessRandomUniformInt", context::get_status()), &TFE_DeleteOp);
22144 status_check(context::get_status());
22148 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22149 status_check(context::get_status());
22151 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22152 status_check(context::get_status());
22154 TFE_OpAddInput(op.get(), minval.tfe_handle.get(), context::get_status());
22155 status_check(context::get_status());
22157 TFE_OpAddInput(op.get(), maxval.tfe_handle.get(), context::get_status());
22158 status_check(context::get_status());
22161 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22162 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22165 int num_outputs_op = 1;
22166 TFE_TensorHandle* res[1] = {
nullptr};
22167 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22168 status_check(context::get_status());
22169 return tensor(res[0]);
22172 inline tensor stateless_truncated_normal(
const tensor& shape,
const tensor& seed,
22173 datatype dtype =
static_cast<datatype
>(1),
22174 datatype Tseed =
static_cast<datatype
>(9)) {
22176 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22177 TFE_NewOp(context::get_context(),
"StatelessTruncatedNormal", context::get_status()), &TFE_DeleteOp);
22178 status_check(context::get_status());
22182 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22183 status_check(context::get_status());
22185 TFE_OpAddInput(op.get(), seed.tfe_handle.get(), context::get_status());
22186 status_check(context::get_status());
22189 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
22190 TFE_OpSetAttrType(op.get(),
"Tseed", Tseed);
22193 int num_outputs_op = 1;
22194 TFE_TensorHandle* res[1] = {
nullptr};
22195 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22196 status_check(context::get_status());
22197 return tensor(res[0]);
22200 inline tensor static_regex_full_match(
const tensor& input,
const std::string& pattern) {
22202 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22203 TFE_NewOp(context::get_context(),
"StaticRegexFullMatch", context::get_status()), &TFE_DeleteOp);
22204 status_check(context::get_status());
22208 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22209 status_check(context::get_status());
22212 TFE_OpSetAttrString(op.get(),
"pattern", (
void*)pattern.c_str(), pattern.size());
22215 int num_outputs_op = 1;
22216 TFE_TensorHandle* res[1] = {
nullptr};
22217 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22218 status_check(context::get_status());
22219 return tensor(res[0]);
22222 inline tensor static_regex_replace(
const tensor& input,
const std::string& pattern,
const std::string& rewrite,
22223 bool replace_global =
true) {
22225 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22226 TFE_NewOp(context::get_context(),
"StaticRegexReplace", context::get_status()), &TFE_DeleteOp);
22227 status_check(context::get_status());
22231 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22232 status_check(context::get_status());
22235 TFE_OpSetAttrString(op.get(),
"pattern", (
void*)pattern.c_str(), pattern.size());
22236 TFE_OpSetAttrString(op.get(),
"rewrite", (
void*)rewrite.c_str(), rewrite.size());
22237 TFE_OpSetAttrBool(op.get(),
"replace_global", (
unsigned char)replace_global);
22240 int num_outputs_op = 1;
22241 TFE_TensorHandle* res[1] = {
nullptr};
22242 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22243 status_check(context::get_status());
22244 return tensor(res[0]);
22247 inline tensor stats_aggregator_handle(
const std::string& container =
"",
const std::string& shared_name =
"") {
22249 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22250 TFE_NewOp(context::get_context(),
"StatsAggregatorHandle", context::get_status()), &TFE_DeleteOp);
22251 status_check(context::get_status());
22256 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
22257 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
22260 int num_outputs_op = 1;
22261 TFE_TensorHandle* res[1] = {
nullptr};
22262 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22263 status_check(context::get_status());
22264 return tensor(res[0]);
22267 inline tensor stats_aggregator_handle_v2(
const std::string& container =
"",
const std::string& shared_name =
"") {
22269 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22270 TFE_NewOp(context::get_context(),
"StatsAggregatorHandleV2", context::get_status()), &TFE_DeleteOp);
22271 status_check(context::get_status());
22276 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
22277 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
22280 int num_outputs_op = 1;
22281 TFE_TensorHandle* res[1] = {
nullptr};
22282 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22283 status_check(context::get_status());
22284 return tensor(res[0]);
22287 inline tensor stats_aggregator_summary(
const tensor& iterator) {
22289 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22290 TFE_NewOp(context::get_context(),
"StatsAggregatorSummary", context::get_status()), &TFE_DeleteOp);
22291 status_check(context::get_status());
22295 TFE_OpAddInput(op.get(), iterator.tfe_handle.get(), context::get_status());
22296 status_check(context::get_status());
22301 int num_outputs_op = 1;
22302 TFE_TensorHandle* res[1] = {
nullptr};
22303 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22304 status_check(context::get_status());
22305 return tensor(res[0]);
22308 inline tensor stop_gradient(
const tensor& input) {
22310 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22311 TFE_NewOp(context::get_context(),
"StopGradient", context::get_status()), &TFE_DeleteOp);
22312 status_check(context::get_status());
22316 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22317 status_check(context::get_status());
22322 int num_outputs_op = 1;
22323 TFE_TensorHandle* res[1] = {
nullptr};
22324 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22325 status_check(context::get_status());
22326 return tensor(res[0]);
22329 inline tensor strided_slice(
const tensor& input,
const tensor& begin,
const tensor& end,
const tensor& strides,
22330 datatype Index, int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0,
22331 int64_t new_axis_mask = 0, int64_t shrink_axis_mask = 0) {
22333 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22334 TFE_NewOp(context::get_context(),
"StridedSlice", context::get_status()), &TFE_DeleteOp);
22335 status_check(context::get_status());
22339 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22340 status_check(context::get_status());
22342 TFE_OpAddInput(op.get(), begin.tfe_handle.get(), context::get_status());
22343 status_check(context::get_status());
22345 TFE_OpAddInput(op.get(), end.tfe_handle.get(), context::get_status());
22346 status_check(context::get_status());
22348 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
22349 status_check(context::get_status());
22352 TFE_OpSetAttrType(op.get(),
"Index", Index);
22353 TFE_OpSetAttrInt(op.get(),
"begin_mask", begin_mask);
22354 TFE_OpSetAttrInt(op.get(),
"end_mask", end_mask);
22355 TFE_OpSetAttrInt(op.get(),
"ellipsis_mask", ellipsis_mask);
22356 TFE_OpSetAttrInt(op.get(),
"new_axis_mask", new_axis_mask);
22357 TFE_OpSetAttrInt(op.get(),
"shrink_axis_mask", shrink_axis_mask);
22360 int num_outputs_op = 1;
22361 TFE_TensorHandle* res[1] = {
nullptr};
22362 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22363 status_check(context::get_status());
22364 return tensor(res[0]);
22367 inline tensor strided_slice_assign(
const tensor& ref,
const tensor& begin,
const tensor& end,
const tensor& strides,
22368 const tensor& value, datatype Index, int64_t begin_mask = 0, int64_t end_mask = 0,
22369 int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0, int64_t shrink_axis_mask = 0) {
22371 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22372 TFE_NewOp(context::get_context(),
"StridedSliceAssign", context::get_status()), &TFE_DeleteOp);
22373 status_check(context::get_status());
22377 TFE_OpAddInput(op.get(), ref.tfe_handle.get(), context::get_status());
22378 status_check(context::get_status());
22380 TFE_OpAddInput(op.get(), begin.tfe_handle.get(), context::get_status());
22381 status_check(context::get_status());
22383 TFE_OpAddInput(op.get(), end.tfe_handle.get(), context::get_status());
22384 status_check(context::get_status());
22386 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
22387 status_check(context::get_status());
22389 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
22390 status_check(context::get_status());
22393 TFE_OpSetAttrType(op.get(),
"Index", Index);
22394 TFE_OpSetAttrInt(op.get(),
"begin_mask", begin_mask);
22395 TFE_OpSetAttrInt(op.get(),
"end_mask", end_mask);
22396 TFE_OpSetAttrInt(op.get(),
"ellipsis_mask", ellipsis_mask);
22397 TFE_OpSetAttrInt(op.get(),
"new_axis_mask", new_axis_mask);
22398 TFE_OpSetAttrInt(op.get(),
"shrink_axis_mask", shrink_axis_mask);
22401 int num_outputs_op = 1;
22402 TFE_TensorHandle* res[1] = {
nullptr};
22403 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22404 status_check(context::get_status());
22405 return tensor(res[0]);
22408 inline tensor strided_slice_grad(
const tensor& shape,
const tensor& begin,
const tensor& end,
const tensor& strides,
22409 const tensor& dy, datatype Index, int64_t begin_mask = 0, int64_t end_mask = 0,
22410 int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0, int64_t shrink_axis_mask = 0) {
22412 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22413 TFE_NewOp(context::get_context(),
"StridedSliceGrad", context::get_status()), &TFE_DeleteOp);
22414 status_check(context::get_status());
22418 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
22419 status_check(context::get_status());
22421 TFE_OpAddInput(op.get(), begin.tfe_handle.get(), context::get_status());
22422 status_check(context::get_status());
22424 TFE_OpAddInput(op.get(), end.tfe_handle.get(), context::get_status());
22425 status_check(context::get_status());
22427 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
22428 status_check(context::get_status());
22430 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
22431 status_check(context::get_status());
22434 TFE_OpSetAttrType(op.get(),
"Index", Index);
22435 TFE_OpSetAttrInt(op.get(),
"begin_mask", begin_mask);
22436 TFE_OpSetAttrInt(op.get(),
"end_mask", end_mask);
22437 TFE_OpSetAttrInt(op.get(),
"ellipsis_mask", ellipsis_mask);
22438 TFE_OpSetAttrInt(op.get(),
"new_axis_mask", new_axis_mask);
22439 TFE_OpSetAttrInt(op.get(),
"shrink_axis_mask", shrink_axis_mask);
22442 int num_outputs_op = 1;
22443 TFE_TensorHandle* res[1] = {
nullptr};
22444 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22445 status_check(context::get_status());
22446 return tensor(res[0]);
22449 inline tensor string_format(
const std::vector<tensor>& inputs,
const std::string& template_arg =
"%s",
22450 const std::string& placeholder =
"%s", int64_t summarize = 3) {
22452 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22453 TFE_NewOp(context::get_context(),
"StringFormat", context::get_status()), &TFE_DeleteOp);
22454 status_check(context::get_status());
22458 std::vector<TFE_TensorHandle*> inputs_handles;
22459 inputs_handles.reserve(inputs.size());
22460 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
22461 [](
const auto& t) { return t.tfe_handle.get(); });
22462 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
22463 status_check(context::get_status());
22466 TFE_OpSetAttrString(op.get(),
"template", (
void*)template_arg.c_str(), template_arg.size());
22467 TFE_OpSetAttrString(op.get(),
"placeholder", (
void*)placeholder.c_str(), placeholder.size());
22468 TFE_OpSetAttrInt(op.get(),
"summarize", summarize);
22471 int num_outputs_op = 1;
22472 TFE_TensorHandle* res[1] = {
nullptr};
22473 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22474 status_check(context::get_status());
22475 return tensor(res[0]);
22478 inline tensor string_join(
const std::vector<tensor>& inputs,
const std::string& separator =
"") {
22480 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22481 TFE_NewOp(context::get_context(),
"StringJoin", context::get_status()), &TFE_DeleteOp);
22482 status_check(context::get_status());
22486 std::vector<TFE_TensorHandle*> inputs_handles;
22487 inputs_handles.reserve(inputs.size());
22488 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
22489 [](
const auto& t) { return t.tfe_handle.get(); });
22490 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
22491 status_check(context::get_status());
22494 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
22495 TFE_OpSetAttrString(op.get(),
"separator", (
void*)separator.c_str(), separator.size());
22498 int num_outputs_op = 1;
22499 TFE_TensorHandle* res[1] = {
nullptr};
22500 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22501 status_check(context::get_status());
22502 return tensor(res[0]);
22505 inline tensor string_length(
const tensor& input,
const std::string& unit =
"BYTE") {
22507 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22508 TFE_NewOp(context::get_context(),
"StringLength", context::get_status()), &TFE_DeleteOp);
22509 status_check(context::get_status());
22513 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22514 status_check(context::get_status());
22517 TFE_OpSetAttrString(op.get(),
"unit", (
void*)unit.c_str(), unit.size());
22520 int num_outputs_op = 1;
22521 TFE_TensorHandle* res[1] = {
nullptr};
22522 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22523 status_check(context::get_status());
22524 return tensor(res[0]);
22527 inline tensor string_lower(
const tensor& input,
const std::string& encoding =
"") {
22529 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22530 TFE_NewOp(context::get_context(),
"StringLower", context::get_status()), &TFE_DeleteOp);
22531 status_check(context::get_status());
22535 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22536 status_check(context::get_status());
22539 TFE_OpSetAttrString(op.get(),
"encoding", (
void*)encoding.c_str(), encoding.size());
22542 int num_outputs_op = 1;
22543 TFE_TensorHandle* res[1] = {
nullptr};
22544 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22545 status_check(context::get_status());
22546 return tensor(res[0]);
22549 inline tensor string_strip(
const tensor& input) {
22551 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22552 TFE_NewOp(context::get_context(),
"StringStrip", context::get_status()), &TFE_DeleteOp);
22553 status_check(context::get_status());
22557 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22558 status_check(context::get_status());
22563 int num_outputs_op = 1;
22564 TFE_TensorHandle* res[1] = {
nullptr};
22565 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22566 status_check(context::get_status());
22567 return tensor(res[0]);
22570 inline tensor string_to_hash_bucket(
const tensor& string_input_tensor, int64_t num_buckets) {
22572 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22573 TFE_NewOp(context::get_context(),
"StringToHashBucket", context::get_status()), &TFE_DeleteOp);
22574 status_check(context::get_status());
22578 TFE_OpAddInput(op.get(), string_input_tensor.tfe_handle.get(), context::get_status());
22579 status_check(context::get_status());
22582 TFE_OpSetAttrInt(op.get(),
"num_buckets", num_buckets);
22585 int num_outputs_op = 1;
22586 TFE_TensorHandle* res[1] = {
nullptr};
22587 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22588 status_check(context::get_status());
22589 return tensor(res[0]);
22592 inline tensor string_to_hash_bucket_fast(
const tensor& input, int64_t num_buckets) {
22594 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22595 TFE_NewOp(context::get_context(),
"StringToHashBucketFast", context::get_status()), &TFE_DeleteOp);
22596 status_check(context::get_status());
22600 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22601 status_check(context::get_status());
22604 TFE_OpSetAttrInt(op.get(),
"num_buckets", num_buckets);
22607 int num_outputs_op = 1;
22608 TFE_TensorHandle* res[1] = {
nullptr};
22609 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22610 status_check(context::get_status());
22611 return tensor(res[0]);
22614 inline tensor string_to_hash_bucket_strong(
const tensor& input, int64_t num_buckets,
const std::vector<int64_t>& key) {
22616 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22617 TFE_NewOp(context::get_context(),
"StringToHashBucketStrong", context::get_status()), &TFE_DeleteOp);
22618 status_check(context::get_status());
22622 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22623 status_check(context::get_status());
22626 TFE_OpSetAttrInt(op.get(),
"num_buckets", num_buckets);
22627 TFE_OpSetAttrIntList(op.get(),
"key", key.data(),
static_cast<int>(key.size()));
22630 int num_outputs_op = 1;
22631 TFE_TensorHandle* res[1] = {
nullptr};
22632 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22633 status_check(context::get_status());
22634 return tensor(res[0]);
22637 inline tensor string_to_number(
const tensor& string_input_tensor, datatype out_type =
static_cast<datatype
>(1)) {
22639 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22640 TFE_NewOp(context::get_context(),
"StringToNumber", context::get_status()), &TFE_DeleteOp);
22641 status_check(context::get_status());
22645 TFE_OpAddInput(op.get(), string_input_tensor.tfe_handle.get(), context::get_status());
22646 status_check(context::get_status());
22649 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
22652 int num_outputs_op = 1;
22653 TFE_TensorHandle* res[1] = {
nullptr};
22654 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22655 status_check(context::get_status());
22656 return tensor(res[0]);
22659 inline tensor string_upper(
const tensor& input,
const std::string& encoding =
"") {
22661 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22662 TFE_NewOp(context::get_context(),
"StringUpper", context::get_status()), &TFE_DeleteOp);
22663 status_check(context::get_status());
22667 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22668 status_check(context::get_status());
22671 TFE_OpSetAttrString(op.get(),
"encoding", (
void*)encoding.c_str(), encoding.size());
22674 int num_outputs_op = 1;
22675 TFE_TensorHandle* res[1] = {
nullptr};
22676 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22677 status_check(context::get_status());
22678 return tensor(res[0]);
22681 inline tensor sub(
const tensor& x,
const tensor& y) {
22683 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Sub", context::get_status()),
22685 status_check(context::get_status());
22689 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
22690 status_check(context::get_status());
22692 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
22693 status_check(context::get_status());
22698 int num_outputs_op = 1;
22699 TFE_TensorHandle* res[1] = {
nullptr};
22700 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22701 status_check(context::get_status());
22702 return tensor(res[0]);
22705 inline tensor substr(
const tensor& input,
const tensor& pos,
const tensor& len,
const std::string& unit =
"BYTE") {
22707 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22708 TFE_NewOp(context::get_context(),
"Substr", context::get_status()), &TFE_DeleteOp);
22709 status_check(context::get_status());
22713 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22714 status_check(context::get_status());
22716 TFE_OpAddInput(op.get(), pos.tfe_handle.get(), context::get_status());
22717 status_check(context::get_status());
22719 TFE_OpAddInput(op.get(), len.tfe_handle.get(), context::get_status());
22720 status_check(context::get_status());
22723 TFE_OpSetAttrString(op.get(),
"unit", (
void*)unit.c_str(), unit.size());
22726 int num_outputs_op = 1;
22727 TFE_TensorHandle* res[1] = {
nullptr};
22728 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22729 status_check(context::get_status());
22730 return tensor(res[0]);
22733 inline tensor sum(
const tensor& input,
const tensor& reduction_indices,
bool keep_dims =
false,
22734 datatype Tidx =
static_cast<datatype
>(3)) {
22736 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Sum", context::get_status()),
22738 status_check(context::get_status());
22742 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22743 status_check(context::get_status());
22745 TFE_OpAddInput(op.get(), reduction_indices.tfe_handle.get(), context::get_status());
22746 status_check(context::get_status());
22749 TFE_OpSetAttrBool(op.get(),
"keep_dims", (
unsigned char)keep_dims);
22750 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
22753 int num_outputs_op = 1;
22754 TFE_TensorHandle* res[1] = {
nullptr};
22755 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22756 status_check(context::get_status());
22757 return tensor(res[0]);
22760 inline tensor summary_writer(
const std::string& shared_name =
"",
const std::string& container =
"") {
22762 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22763 TFE_NewOp(context::get_context(),
"SummaryWriter", context::get_status()), &TFE_DeleteOp);
22764 status_check(context::get_status());
22769 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
22770 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
22773 int num_outputs_op = 1;
22774 TFE_TensorHandle* res[1] = {
nullptr};
22775 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22776 status_check(context::get_status());
22777 return tensor(res[0]);
22780 inline tensor t_f_record_dataset(
const tensor& filenames,
const tensor& compression_type,
const tensor& buffer_size) {
22782 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22783 TFE_NewOp(context::get_context(),
"TFRecordDataset", context::get_status()), &TFE_DeleteOp);
22784 status_check(context::get_status());
22788 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
22789 status_check(context::get_status());
22791 TFE_OpAddInput(op.get(), compression_type.tfe_handle.get(), context::get_status());
22792 status_check(context::get_status());
22794 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
22795 status_check(context::get_status());
22800 int num_outputs_op = 1;
22801 TFE_TensorHandle* res[1] = {
nullptr};
22802 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22803 status_check(context::get_status());
22804 return tensor(res[0]);
22807 inline tensor t_f_record_reader(
const std::string& container =
"",
const std::string& shared_name =
"",
22808 const std::string& compression_type =
"") {
22810 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22811 TFE_NewOp(context::get_context(),
"TFRecordReader", context::get_status()), &TFE_DeleteOp);
22812 status_check(context::get_status());
22817 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
22818 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
22819 TFE_OpSetAttrString(op.get(),
"compression_type", (
void*)compression_type.c_str(), compression_type.size());
22822 int num_outputs_op = 1;
22823 TFE_TensorHandle* res[1] = {
nullptr};
22824 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22825 status_check(context::get_status());
22826 return tensor(res[0]);
22829 inline tensor t_f_record_reader_v2(
const std::string& container =
"",
const std::string& shared_name =
"",
22830 const std::string& compression_type =
"") {
22832 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22833 TFE_NewOp(context::get_context(),
"TFRecordReaderV2", context::get_status()), &TFE_DeleteOp);
22834 status_check(context::get_status());
22839 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
22840 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
22841 TFE_OpSetAttrString(op.get(),
"compression_type", (
void*)compression_type.c_str(), compression_type.size());
22844 int num_outputs_op = 1;
22845 TFE_TensorHandle* res[1] = {
nullptr};
22846 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22847 status_check(context::get_status());
22848 return tensor(res[0]);
22851 inline tensor t_p_u_compilation_result() {
22853 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22854 TFE_NewOp(context::get_context(),
"TPUCompilationResult", context::get_status()), &TFE_DeleteOp);
22855 status_check(context::get_status());
22862 int num_outputs_op = 1;
22863 TFE_TensorHandle* res[1] = {
nullptr};
22864 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22865 status_check(context::get_status());
22866 return tensor(res[0]);
22869 inline tensor t_p_u_embedding_activations(
const tensor& embedding_variable,
const tensor& sliced_activations,
22870 int64_t table_id, int64_t lookup_id) {
22872 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22873 TFE_NewOp(context::get_context(),
"TPUEmbeddingActivations", context::get_status()), &TFE_DeleteOp);
22874 status_check(context::get_status());
22878 TFE_OpAddInput(op.get(), embedding_variable.tfe_handle.get(), context::get_status());
22879 status_check(context::get_status());
22881 TFE_OpAddInput(op.get(), sliced_activations.tfe_handle.get(), context::get_status());
22882 status_check(context::get_status());
22885 TFE_OpSetAttrInt(op.get(),
"table_id", table_id);
22886 TFE_OpSetAttrInt(op.get(),
"lookup_id", lookup_id);
22889 int num_outputs_op = 1;
22890 TFE_TensorHandle* res[1] = {
nullptr};
22891 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22892 status_check(context::get_status());
22893 return tensor(res[0]);
22896 inline tensor t_p_u_ordinal_selector() {
22898 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22899 TFE_NewOp(context::get_context(),
"TPUOrdinalSelector", context::get_status()), &TFE_DeleteOp);
22900 status_check(context::get_status());
22907 int num_outputs_op = 1;
22908 TFE_TensorHandle* res[1] = {
nullptr};
22909 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22910 status_check(context::get_status());
22911 return tensor(res[0]);
22914 inline tensor t_p_u_replicated_input(
const std::vector<tensor>& inputs,
bool is_mirrored_variable =
false,
22915 int64_t index = -1,
bool is_packed =
false) {
22917 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22918 TFE_NewOp(context::get_context(),
"TPUReplicatedInput", context::get_status()), &TFE_DeleteOp);
22919 status_check(context::get_status());
22923 std::vector<TFE_TensorHandle*> inputs_handles;
22924 inputs_handles.reserve(inputs.size());
22925 std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_handles),
22926 [](
const auto& t) { return t.tfe_handle.get(); });
22927 TFE_OpAddInputList(op.get(), inputs_handles.data(),
static_cast<int>(inputs.size()), context::get_status());
22928 status_check(context::get_status());
22931 TFE_OpSetAttrInt(op.get(),
"N", inputs.size());
22932 TFE_OpSetAttrBool(op.get(),
"is_mirrored_variable", (
unsigned char)is_mirrored_variable);
22933 TFE_OpSetAttrInt(op.get(),
"index", index);
22934 TFE_OpSetAttrBool(op.get(),
"is_packed", (
unsigned char)is_packed);
22937 int num_outputs_op = 1;
22938 TFE_TensorHandle* res[1] = {
nullptr};
22939 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22940 status_check(context::get_status());
22941 return tensor(res[0]);
22944 inline tensor t_p_u_replicated_output(
const tensor& input, int64_t num_replicas) {
22946 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22947 TFE_NewOp(context::get_context(),
"TPUReplicatedOutput", context::get_status()), &TFE_DeleteOp);
22948 status_check(context::get_status());
22952 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
22953 status_check(context::get_status());
22956 TFE_OpSetAttrInt(op.get(),
"num_replicas", num_replicas);
22959 int num_outputs_op = 1;
22960 TFE_TensorHandle* res[1] = {
nullptr};
22961 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
22962 status_check(context::get_status());
22963 return tensor(res[0]);
22966 inline tensor take_dataset(
const tensor& input_dataset,
const tensor& count,
const std::vector<datatype>& output_types,
22967 const std::vector<std::vector<int64_t>>& output_shapes) {
22969 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
22970 TFE_NewOp(context::get_context(),
"TakeDataset", context::get_status()), &TFE_DeleteOp);
22971 status_check(context::get_status());
22975 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
22976 status_check(context::get_status());
22978 TFE_OpAddInput(op.get(), count.tfe_handle.get(), context::get_status());
22979 status_check(context::get_status());
22982 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
22983 static_cast<int>(output_types.size()));
22985 std::vector<const int64_t*> output_shapes_values;
22986 output_shapes_values.reserve(output_shapes.size());
22987 std::vector<int> output_shapes_ndims;
22988 output_shapes_ndims.reserve(output_shapes.size());
22989 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
22990 [](
const auto& v) { return v.data(); });
22991 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
22992 [](
const auto& v) { return static_cast<int>(v.size()); });
22993 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
22994 static_cast<int>(output_shapes.size()), context::get_status());
22995 status_check(context::get_status());
22998 int num_outputs_op = 1;
22999 TFE_TensorHandle* res[1] = {
nullptr};
23000 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23001 status_check(context::get_status());
23002 return tensor(res[0]);
23005 inline tensor tan(
const tensor& x) {
23007 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Tan", context::get_status()),
23009 status_check(context::get_status());
23013 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
23014 status_check(context::get_status());
23019 int num_outputs_op = 1;
23020 TFE_TensorHandle* res[1] = {
nullptr};
23021 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23022 status_check(context::get_status());
23023 return tensor(res[0]);
23026 inline tensor tanh(
const tensor& x) {
23028 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Tanh", context::get_status()),
23030 status_check(context::get_status());
23034 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
23035 status_check(context::get_status());
23040 int num_outputs_op = 1;
23041 TFE_TensorHandle* res[1] = {
nullptr};
23042 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23043 status_check(context::get_status());
23044 return tensor(res[0]);
23047 inline tensor tanh_grad(
const tensor& y,
const tensor& dy) {
23049 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23050 TFE_NewOp(context::get_context(),
"TanhGrad", context::get_status()), &TFE_DeleteOp);
23051 status_check(context::get_status());
23055 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
23056 status_check(context::get_status());
23058 TFE_OpAddInput(op.get(), dy.tfe_handle.get(), context::get_status());
23059 status_check(context::get_status());
23064 int num_outputs_op = 1;
23065 TFE_TensorHandle* res[1] = {
nullptr};
23066 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23067 status_check(context::get_status());
23068 return tensor(res[0]);
23071 inline tensor temporary_variable(
const std::vector<int64_t>& shape, datatype dtype,
const std::string& var_name =
"") {
23073 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23074 TFE_NewOp(context::get_context(),
"TemporaryVariable", context::get_status()), &TFE_DeleteOp);
23075 status_check(context::get_status());
23081 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
23082 status_check(context::get_status());
23084 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23085 TFE_OpSetAttrString(op.get(),
"var_name", (
void*)var_name.c_str(), var_name.size());
23088 int num_outputs_op = 1;
23089 TFE_TensorHandle* res[1] = {
nullptr};
23090 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23091 status_check(context::get_status());
23092 return tensor(res[0]);
23095 inline tensor tensor_array(
const tensor& size, datatype dtype,
const std::vector<int64_t>& element_shape,
23096 bool dynamic_size =
false,
bool clear_after_read =
true,
23097 const std::string& tensor_array_name =
"") {
23099 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23100 TFE_NewOp(context::get_context(),
"TensorArray", context::get_status()), &TFE_DeleteOp);
23101 status_check(context::get_status());
23105 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
23106 status_check(context::get_status());
23109 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23111 TFE_OpSetAttrShape(op.get(),
"element_shape", element_shape.data(),
static_cast<int>(element_shape.size()),
23112 context::get_status());
23113 status_check(context::get_status());
23115 TFE_OpSetAttrBool(op.get(),
"dynamic_size", (
unsigned char)dynamic_size);
23116 TFE_OpSetAttrBool(op.get(),
"clear_after_read", (
unsigned char)clear_after_read);
23117 TFE_OpSetAttrString(op.get(),
"tensor_array_name", (
void*)tensor_array_name.c_str(), tensor_array_name.size());
23120 int num_outputs_op = 1;
23121 TFE_TensorHandle* res[1] = {
nullptr};
23122 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23123 status_check(context::get_status());
23124 return tensor(res[0]);
23127 inline tensor tensor_array_gather(
const tensor& handle,
const tensor& indices,
const tensor& flow_in, datatype dtype,
23128 const std::vector<int64_t>& element_shape) {
23130 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23131 TFE_NewOp(context::get_context(),
"TensorArrayGather", context::get_status()), &TFE_DeleteOp);
23132 status_check(context::get_status());
23136 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23137 status_check(context::get_status());
23139 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23140 status_check(context::get_status());
23142 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23143 status_check(context::get_status());
23146 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23148 TFE_OpSetAttrShape(op.get(),
"element_shape", element_shape.data(),
static_cast<int>(element_shape.size()),
23149 context::get_status());
23150 status_check(context::get_status());
23153 int num_outputs_op = 1;
23154 TFE_TensorHandle* res[1] = {
nullptr};
23155 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23156 status_check(context::get_status());
23157 return tensor(res[0]);
23160 inline tensor tensor_array_gather_v2(
const tensor& handle,
const tensor& indices,
const tensor& flow_in, datatype dtype,
23161 const std::vector<int64_t>& element_shape) {
23163 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23164 TFE_NewOp(context::get_context(),
"TensorArrayGatherV2", context::get_status()), &TFE_DeleteOp);
23165 status_check(context::get_status());
23169 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23170 status_check(context::get_status());
23172 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23173 status_check(context::get_status());
23175 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23176 status_check(context::get_status());
23179 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23181 TFE_OpSetAttrShape(op.get(),
"element_shape", element_shape.data(),
static_cast<int>(element_shape.size()),
23182 context::get_status());
23183 status_check(context::get_status());
23186 int num_outputs_op = 1;
23187 TFE_TensorHandle* res[1] = {
nullptr};
23188 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23189 status_check(context::get_status());
23190 return tensor(res[0]);
23193 inline tensor tensor_array_gather_v3(
const tensor& handle,
const tensor& indices,
const tensor& flow_in, datatype dtype,
23194 const std::vector<int64_t>& element_shape) {
23196 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23197 TFE_NewOp(context::get_context(),
"TensorArrayGatherV3", context::get_status()), &TFE_DeleteOp);
23198 status_check(context::get_status());
23202 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23203 status_check(context::get_status());
23205 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23206 status_check(context::get_status());
23208 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23209 status_check(context::get_status());
23212 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23214 TFE_OpSetAttrShape(op.get(),
"element_shape", element_shape.data(),
static_cast<int>(element_shape.size()),
23215 context::get_status());
23216 status_check(context::get_status());
23219 int num_outputs_op = 1;
23220 TFE_TensorHandle* res[1] = {
nullptr};
23221 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23222 status_check(context::get_status());
23223 return tensor(res[0]);
23226 inline tensor tensor_array_grad(
const tensor& handle,
const tensor& flow_in,
const std::string& source) {
23228 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23229 TFE_NewOp(context::get_context(),
"TensorArrayGrad", context::get_status()), &TFE_DeleteOp);
23230 status_check(context::get_status());
23234 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23235 status_check(context::get_status());
23237 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23238 status_check(context::get_status());
23241 TFE_OpSetAttrString(op.get(),
"source", (
void*)source.c_str(), source.size());
23244 int num_outputs_op = 1;
23245 TFE_TensorHandle* res[1] = {
nullptr};
23246 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23247 status_check(context::get_status());
23248 return tensor(res[0]);
23251 inline tensor tensor_array_grad_v2(
const tensor& handle,
const tensor& flow_in,
const std::string& source) {
23253 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23254 TFE_NewOp(context::get_context(),
"TensorArrayGradV2", context::get_status()), &TFE_DeleteOp);
23255 status_check(context::get_status());
23259 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23260 status_check(context::get_status());
23262 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23263 status_check(context::get_status());
23266 TFE_OpSetAttrString(op.get(),
"source", (
void*)source.c_str(), source.size());
23269 int num_outputs_op = 1;
23270 TFE_TensorHandle* res[1] = {
nullptr};
23271 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23272 status_check(context::get_status());
23273 return tensor(res[0]);
23276 inline tensor tensor_array_pack(
const tensor& handle,
const tensor& flow_in, datatype dtype,
23277 const std::vector<int64_t>& element_shape) {
23279 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23280 TFE_NewOp(context::get_context(),
"TensorArrayPack", context::get_status()), &TFE_DeleteOp);
23281 status_check(context::get_status());
23285 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23286 status_check(context::get_status());
23288 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23289 status_check(context::get_status());
23292 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23294 TFE_OpSetAttrShape(op.get(),
"element_shape", element_shape.data(),
static_cast<int>(element_shape.size()),
23295 context::get_status());
23296 status_check(context::get_status());
23299 int num_outputs_op = 1;
23300 TFE_TensorHandle* res[1] = {
nullptr};
23301 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23302 status_check(context::get_status());
23303 return tensor(res[0]);
23306 inline tensor tensor_array_read(
const tensor& handle,
const tensor& index,
const tensor& flow_in, datatype dtype) {
23308 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23309 TFE_NewOp(context::get_context(),
"TensorArrayRead", context::get_status()), &TFE_DeleteOp);
23310 status_check(context::get_status());
23314 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23315 status_check(context::get_status());
23317 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23318 status_check(context::get_status());
23320 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23321 status_check(context::get_status());
23324 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23327 int num_outputs_op = 1;
23328 TFE_TensorHandle* res[1] = {
nullptr};
23329 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23330 status_check(context::get_status());
23331 return tensor(res[0]);
23334 inline tensor tensor_array_read_v2(
const tensor& handle,
const tensor& index,
const tensor& flow_in, datatype dtype) {
23336 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23337 TFE_NewOp(context::get_context(),
"TensorArrayReadV2", context::get_status()), &TFE_DeleteOp);
23338 status_check(context::get_status());
23342 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23343 status_check(context::get_status());
23345 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23346 status_check(context::get_status());
23348 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23349 status_check(context::get_status());
23352 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23355 int num_outputs_op = 1;
23356 TFE_TensorHandle* res[1] = {
nullptr};
23357 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23358 status_check(context::get_status());
23359 return tensor(res[0]);
23362 inline tensor tensor_array_read_v3(
const tensor& handle,
const tensor& index,
const tensor& flow_in, datatype dtype) {
23364 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23365 TFE_NewOp(context::get_context(),
"TensorArrayReadV3", context::get_status()), &TFE_DeleteOp);
23366 status_check(context::get_status());
23370 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23371 status_check(context::get_status());
23373 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23374 status_check(context::get_status());
23376 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23377 status_check(context::get_status());
23380 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23383 int num_outputs_op = 1;
23384 TFE_TensorHandle* res[1] = {
nullptr};
23385 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23386 status_check(context::get_status());
23387 return tensor(res[0]);
23390 inline tensor tensor_array_scatter(
const tensor& handle,
const tensor& indices,
const tensor& value,
23391 const tensor& flow_in) {
23393 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23394 TFE_NewOp(context::get_context(),
"TensorArrayScatter", context::get_status()), &TFE_DeleteOp);
23395 status_check(context::get_status());
23399 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23400 status_check(context::get_status());
23402 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23403 status_check(context::get_status());
23405 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23406 status_check(context::get_status());
23408 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23409 status_check(context::get_status());
23414 int num_outputs_op = 1;
23415 TFE_TensorHandle* res[1] = {
nullptr};
23416 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23417 status_check(context::get_status());
23418 return tensor(res[0]);
23421 inline tensor tensor_array_scatter_v2(
const tensor& handle,
const tensor& indices,
const tensor& value,
23422 const tensor& flow_in) {
23424 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23425 TFE_NewOp(context::get_context(),
"TensorArrayScatterV2", context::get_status()), &TFE_DeleteOp);
23426 status_check(context::get_status());
23430 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23431 status_check(context::get_status());
23433 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23434 status_check(context::get_status());
23436 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23437 status_check(context::get_status());
23439 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23440 status_check(context::get_status());
23445 int num_outputs_op = 1;
23446 TFE_TensorHandle* res[1] = {
nullptr};
23447 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23448 status_check(context::get_status());
23449 return tensor(res[0]);
23452 inline tensor tensor_array_scatter_v3(
const tensor& handle,
const tensor& indices,
const tensor& value,
23453 const tensor& flow_in) {
23455 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23456 TFE_NewOp(context::get_context(),
"TensorArrayScatterV3", context::get_status()), &TFE_DeleteOp);
23457 status_check(context::get_status());
23461 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23462 status_check(context::get_status());
23464 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23465 status_check(context::get_status());
23467 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23468 status_check(context::get_status());
23470 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23471 status_check(context::get_status());
23476 int num_outputs_op = 1;
23477 TFE_TensorHandle* res[1] = {
nullptr};
23478 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23479 status_check(context::get_status());
23480 return tensor(res[0]);
23483 inline tensor tensor_array_size(
const tensor& handle,
const tensor& flow_in) {
23485 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23486 TFE_NewOp(context::get_context(),
"TensorArraySize", context::get_status()), &TFE_DeleteOp);
23487 status_check(context::get_status());
23491 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23492 status_check(context::get_status());
23494 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23495 status_check(context::get_status());
23500 int num_outputs_op = 1;
23501 TFE_TensorHandle* res[1] = {
nullptr};
23502 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23503 status_check(context::get_status());
23504 return tensor(res[0]);
23507 inline tensor tensor_array_size_v2(
const tensor& handle,
const tensor& flow_in) {
23509 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23510 TFE_NewOp(context::get_context(),
"TensorArraySizeV2", context::get_status()), &TFE_DeleteOp);
23511 status_check(context::get_status());
23515 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23516 status_check(context::get_status());
23518 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23519 status_check(context::get_status());
23524 int num_outputs_op = 1;
23525 TFE_TensorHandle* res[1] = {
nullptr};
23526 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23527 status_check(context::get_status());
23528 return tensor(res[0]);
23531 inline tensor tensor_array_size_v3(
const tensor& handle,
const tensor& flow_in) {
23533 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23534 TFE_NewOp(context::get_context(),
"TensorArraySizeV3", context::get_status()), &TFE_DeleteOp);
23535 status_check(context::get_status());
23539 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23540 status_check(context::get_status());
23542 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23543 status_check(context::get_status());
23548 int num_outputs_op = 1;
23549 TFE_TensorHandle* res[1] = {
nullptr};
23550 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23551 status_check(context::get_status());
23552 return tensor(res[0]);
23555 inline tensor tensor_array_split(
const tensor& handle,
const tensor& value,
const tensor& lengths,
23556 const tensor& flow_in) {
23558 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23559 TFE_NewOp(context::get_context(),
"TensorArraySplit", context::get_status()), &TFE_DeleteOp);
23560 status_check(context::get_status());
23564 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23565 status_check(context::get_status());
23567 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23568 status_check(context::get_status());
23570 TFE_OpAddInput(op.get(), lengths.tfe_handle.get(), context::get_status());
23571 status_check(context::get_status());
23573 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23574 status_check(context::get_status());
23579 int num_outputs_op = 1;
23580 TFE_TensorHandle* res[1] = {
nullptr};
23581 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23582 status_check(context::get_status());
23583 return tensor(res[0]);
23586 inline tensor tensor_array_split_v2(
const tensor& handle,
const tensor& value,
const tensor& lengths,
23587 const tensor& flow_in) {
23589 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23590 TFE_NewOp(context::get_context(),
"TensorArraySplitV2", context::get_status()), &TFE_DeleteOp);
23591 status_check(context::get_status());
23595 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23596 status_check(context::get_status());
23598 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23599 status_check(context::get_status());
23601 TFE_OpAddInput(op.get(), lengths.tfe_handle.get(), context::get_status());
23602 status_check(context::get_status());
23604 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23605 status_check(context::get_status());
23610 int num_outputs_op = 1;
23611 TFE_TensorHandle* res[1] = {
nullptr};
23612 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23613 status_check(context::get_status());
23614 return tensor(res[0]);
23617 inline tensor tensor_array_split_v3(
const tensor& handle,
const tensor& value,
const tensor& lengths,
23618 const tensor& flow_in) {
23620 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23621 TFE_NewOp(context::get_context(),
"TensorArraySplitV3", context::get_status()), &TFE_DeleteOp);
23622 status_check(context::get_status());
23626 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23627 status_check(context::get_status());
23629 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23630 status_check(context::get_status());
23632 TFE_OpAddInput(op.get(), lengths.tfe_handle.get(), context::get_status());
23633 status_check(context::get_status());
23635 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23636 status_check(context::get_status());
23641 int num_outputs_op = 1;
23642 TFE_TensorHandle* res[1] = {
nullptr};
23643 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23644 status_check(context::get_status());
23645 return tensor(res[0]);
23648 inline tensor tensor_array_unpack(
const tensor& handle,
const tensor& value,
const tensor& flow_in) {
23650 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23651 TFE_NewOp(context::get_context(),
"TensorArrayUnpack", context::get_status()), &TFE_DeleteOp);
23652 status_check(context::get_status());
23656 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23657 status_check(context::get_status());
23659 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23660 status_check(context::get_status());
23662 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23663 status_check(context::get_status());
23668 int num_outputs_op = 1;
23669 TFE_TensorHandle* res[1] = {
nullptr};
23670 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23671 status_check(context::get_status());
23672 return tensor(res[0]);
23675 inline tensor tensor_array_v2(
const tensor& size, datatype dtype,
const std::vector<int64_t>& element_shape,
23676 bool dynamic_size =
false,
bool clear_after_read =
true,
23677 const std::string& tensor_array_name =
"") {
23679 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23680 TFE_NewOp(context::get_context(),
"TensorArrayV2", context::get_status()), &TFE_DeleteOp);
23681 status_check(context::get_status());
23685 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
23686 status_check(context::get_status());
23689 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
23691 TFE_OpSetAttrShape(op.get(),
"element_shape", element_shape.data(),
static_cast<int>(element_shape.size()),
23692 context::get_status());
23693 status_check(context::get_status());
23695 TFE_OpSetAttrBool(op.get(),
"dynamic_size", (
unsigned char)dynamic_size);
23696 TFE_OpSetAttrBool(op.get(),
"clear_after_read", (
unsigned char)clear_after_read);
23697 TFE_OpSetAttrString(op.get(),
"tensor_array_name", (
void*)tensor_array_name.c_str(), tensor_array_name.size());
23700 int num_outputs_op = 1;
23701 TFE_TensorHandle* res[1] = {
nullptr};
23702 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23703 status_check(context::get_status());
23704 return tensor(res[0]);
23707 inline tensor tensor_array_write(
const tensor& handle,
const tensor& index,
const tensor& value,
23708 const tensor& flow_in) {
23710 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23711 TFE_NewOp(context::get_context(),
"TensorArrayWrite", context::get_status()), &TFE_DeleteOp);
23712 status_check(context::get_status());
23716 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23717 status_check(context::get_status());
23719 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23720 status_check(context::get_status());
23722 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23723 status_check(context::get_status());
23725 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23726 status_check(context::get_status());
23731 int num_outputs_op = 1;
23732 TFE_TensorHandle* res[1] = {
nullptr};
23733 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23734 status_check(context::get_status());
23735 return tensor(res[0]);
23738 inline tensor tensor_array_write_v2(
const tensor& handle,
const tensor& index,
const tensor& value,
23739 const tensor& flow_in) {
23741 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23742 TFE_NewOp(context::get_context(),
"TensorArrayWriteV2", context::get_status()), &TFE_DeleteOp);
23743 status_check(context::get_status());
23747 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23748 status_check(context::get_status());
23750 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23751 status_check(context::get_status());
23753 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23754 status_check(context::get_status());
23756 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23757 status_check(context::get_status());
23762 int num_outputs_op = 1;
23763 TFE_TensorHandle* res[1] = {
nullptr};
23764 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23765 status_check(context::get_status());
23766 return tensor(res[0]);
23769 inline tensor tensor_array_write_v3(
const tensor& handle,
const tensor& index,
const tensor& value,
23770 const tensor& flow_in) {
23772 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23773 TFE_NewOp(context::get_context(),
"TensorArrayWriteV3", context::get_status()), &TFE_DeleteOp);
23774 status_check(context::get_status());
23778 TFE_OpAddInput(op.get(), handle.tfe_handle.get(), context::get_status());
23779 status_check(context::get_status());
23781 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23782 status_check(context::get_status());
23784 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
23785 status_check(context::get_status());
23787 TFE_OpAddInput(op.get(), flow_in.tfe_handle.get(), context::get_status());
23788 status_check(context::get_status());
23793 int num_outputs_op = 1;
23794 TFE_TensorHandle* res[1] = {
nullptr};
23795 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23796 status_check(context::get_status());
23797 return tensor(res[0]);
23800 inline tensor tensor_dataset(
const std::vector<tensor>& components,
const std::vector<datatype>& Toutput_types,
23801 const std::vector<std::vector<int64_t>>& output_shapes) {
23803 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23804 TFE_NewOp(context::get_context(),
"TensorDataset", context::get_status()), &TFE_DeleteOp);
23805 status_check(context::get_status());
23809 std::vector<TFE_TensorHandle*> components_handles;
23810 components_handles.reserve(components.size());
23811 std::transform(components.begin(), components.end(), std::back_inserter(components_handles),
23812 [](
const auto& t) { return t.tfe_handle.get(); });
23813 TFE_OpAddInputList(op.get(), components_handles.data(),
static_cast<int>(components.size()), context::get_status());
23814 status_check(context::get_status());
23817 TFE_OpSetAttrTypeList(op.get(),
"Toutput_types",
reinterpret_cast<const enum TF_DataType*
>(Toutput_types.data()),
23818 static_cast<int>(Toutput_types.size()));
23820 std::vector<const int64_t*> output_shapes_values;
23821 output_shapes_values.reserve(output_shapes.size());
23822 std::vector<int> output_shapes_ndims;
23823 output_shapes_ndims.reserve(output_shapes.size());
23824 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
23825 [](
const auto& v) { return v.data(); });
23826 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
23827 [](
const auto& v) { return static_cast<int>(v.size()); });
23828 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
23829 static_cast<int>(output_shapes.size()), context::get_status());
23830 status_check(context::get_status());
23833 int num_outputs_op = 1;
23834 TFE_TensorHandle* res[1] = {
nullptr};
23835 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23836 status_check(context::get_status());
23837 return tensor(res[0]);
23840 inline tensor tensor_list_concat_lists(
const tensor& input_a,
const tensor& input_b, datatype element_dtype) {
23842 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23843 TFE_NewOp(context::get_context(),
"TensorListConcatLists", context::get_status()), &TFE_DeleteOp);
23844 status_check(context::get_status());
23848 TFE_OpAddInput(op.get(), input_a.tfe_handle.get(), context::get_status());
23849 status_check(context::get_status());
23851 TFE_OpAddInput(op.get(), input_b.tfe_handle.get(), context::get_status());
23852 status_check(context::get_status());
23855 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
23858 int num_outputs_op = 1;
23859 TFE_TensorHandle* res[1] = {
nullptr};
23860 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23861 status_check(context::get_status());
23862 return tensor(res[0]);
23865 inline tensor tensor_list_element_shape(
const tensor& input_handle, datatype shape_type) {
23867 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23868 TFE_NewOp(context::get_context(),
"TensorListElementShape", context::get_status()), &TFE_DeleteOp);
23869 status_check(context::get_status());
23873 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
23874 status_check(context::get_status());
23877 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
23880 int num_outputs_op = 1;
23881 TFE_TensorHandle* res[1] = {
nullptr};
23882 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23883 status_check(context::get_status());
23884 return tensor(res[0]);
23887 inline tensor tensor_list_from_tensor(
const tensor& input_tensor,
const tensor& element_shape, datatype element_dtype,
23888 datatype shape_type) {
23890 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23891 TFE_NewOp(context::get_context(),
"TensorListFromTensor", context::get_status()), &TFE_DeleteOp);
23892 status_check(context::get_status());
23896 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
23897 status_check(context::get_status());
23899 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
23900 status_check(context::get_status());
23903 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
23904 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
23907 int num_outputs_op = 1;
23908 TFE_TensorHandle* res[1] = {
nullptr};
23909 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23910 status_check(context::get_status());
23911 return tensor(res[0]);
23914 inline tensor tensor_list_gather(
const tensor& input_handle,
const tensor& indices,
const tensor& element_shape,
23915 datatype element_dtype) {
23917 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23918 TFE_NewOp(context::get_context(),
"TensorListGather", context::get_status()), &TFE_DeleteOp);
23919 status_check(context::get_status());
23923 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
23924 status_check(context::get_status());
23926 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
23927 status_check(context::get_status());
23929 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
23930 status_check(context::get_status());
23933 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
23936 int num_outputs_op = 1;
23937 TFE_TensorHandle* res[1] = {
nullptr};
23938 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23939 status_check(context::get_status());
23940 return tensor(res[0]);
23943 inline tensor tensor_list_get_item(
const tensor& input_handle,
const tensor& index,
const tensor& element_shape,
23944 datatype element_dtype) {
23946 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23947 TFE_NewOp(context::get_context(),
"TensorListGetItem", context::get_status()), &TFE_DeleteOp);
23948 status_check(context::get_status());
23952 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
23953 status_check(context::get_status());
23955 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
23956 status_check(context::get_status());
23958 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
23959 status_check(context::get_status());
23962 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
23965 int num_outputs_op = 1;
23966 TFE_TensorHandle* res[1] = {
nullptr};
23967 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23968 status_check(context::get_status());
23969 return tensor(res[0]);
23972 inline tensor tensor_list_length(
const tensor& input_handle) {
23974 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23975 TFE_NewOp(context::get_context(),
"TensorListLength", context::get_status()), &TFE_DeleteOp);
23976 status_check(context::get_status());
23980 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
23981 status_check(context::get_status());
23986 int num_outputs_op = 1;
23987 TFE_TensorHandle* res[1] = {
nullptr};
23988 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
23989 status_check(context::get_status());
23990 return tensor(res[0]);
23993 inline tensor tensor_list_push_back(
const tensor& input_handle,
const tensor& input_tensor, datatype element_dtype) {
23995 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
23996 TFE_NewOp(context::get_context(),
"TensorListPushBack", context::get_status()), &TFE_DeleteOp);
23997 status_check(context::get_status());
24001 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
24002 status_check(context::get_status());
24004 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24005 status_check(context::get_status());
24008 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24011 int num_outputs_op = 1;
24012 TFE_TensorHandle* res[1] = {
nullptr};
24013 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24014 status_check(context::get_status());
24015 return tensor(res[0]);
24018 inline tensor tensor_list_push_back_batch(
const tensor& input_handles,
const tensor& input_tensor,
24019 datatype element_dtype) {
24021 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24022 TFE_NewOp(context::get_context(),
"TensorListPushBackBatch", context::get_status()), &TFE_DeleteOp);
24023 status_check(context::get_status());
24027 TFE_OpAddInput(op.get(), input_handles.tfe_handle.get(), context::get_status());
24028 status_check(context::get_status());
24030 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24031 status_check(context::get_status());
24034 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24037 int num_outputs_op = 1;
24038 TFE_TensorHandle* res[1] = {
nullptr};
24039 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24040 status_check(context::get_status());
24041 return tensor(res[0]);
24044 inline tensor tensor_list_reserve(
const tensor& element_shape,
const tensor& num_elements, datatype element_dtype,
24045 datatype shape_type) {
24047 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24048 TFE_NewOp(context::get_context(),
"TensorListReserve", context::get_status()), &TFE_DeleteOp);
24049 status_check(context::get_status());
24053 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
24054 status_check(context::get_status());
24056 TFE_OpAddInput(op.get(), num_elements.tfe_handle.get(), context::get_status());
24057 status_check(context::get_status());
24060 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24061 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
24064 int num_outputs_op = 1;
24065 TFE_TensorHandle* res[1] = {
nullptr};
24066 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24067 status_check(context::get_status());
24068 return tensor(res[0]);
24071 inline tensor tensor_list_resize(
const tensor& input_handle,
const tensor& size) {
24073 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24074 TFE_NewOp(context::get_context(),
"TensorListResize", context::get_status()), &TFE_DeleteOp);
24075 status_check(context::get_status());
24079 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
24080 status_check(context::get_status());
24082 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
24083 status_check(context::get_status());
24088 int num_outputs_op = 1;
24089 TFE_TensorHandle* res[1] = {
nullptr};
24090 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24091 status_check(context::get_status());
24092 return tensor(res[0]);
24095 inline tensor tensor_list_scatter(
const tensor& input_tensor,
const tensor& indices,
const tensor& element_shape,
24096 datatype element_dtype, datatype shape_type) {
24098 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24099 TFE_NewOp(context::get_context(),
"TensorListScatter", context::get_status()), &TFE_DeleteOp);
24100 status_check(context::get_status());
24104 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24105 status_check(context::get_status());
24107 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24108 status_check(context::get_status());
24110 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
24111 status_check(context::get_status());
24114 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24115 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
24118 int num_outputs_op = 1;
24119 TFE_TensorHandle* res[1] = {
nullptr};
24120 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24121 status_check(context::get_status());
24122 return tensor(res[0]);
24125 inline tensor tensor_list_scatter_into_existing_list(
const tensor& input_handle,
const tensor& input_tensor,
24126 const tensor& indices, datatype element_dtype) {
24128 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24129 TFE_NewOp(context::get_context(),
"TensorListScatterIntoExistingList", context::get_status()), &TFE_DeleteOp);
24130 status_check(context::get_status());
24134 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
24135 status_check(context::get_status());
24137 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24138 status_check(context::get_status());
24140 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24141 status_check(context::get_status());
24144 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24147 int num_outputs_op = 1;
24148 TFE_TensorHandle* res[1] = {
nullptr};
24149 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24150 status_check(context::get_status());
24151 return tensor(res[0]);
24154 inline tensor tensor_list_scatter_v2(
const tensor& input_tensor,
const tensor& indices,
const tensor& element_shape,
24155 const tensor& num_elements, datatype element_dtype, datatype shape_type) {
24157 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24158 TFE_NewOp(context::get_context(),
"TensorListScatterV2", context::get_status()), &TFE_DeleteOp);
24159 status_check(context::get_status());
24163 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24164 status_check(context::get_status());
24166 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24167 status_check(context::get_status());
24169 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
24170 status_check(context::get_status());
24172 TFE_OpAddInput(op.get(), num_elements.tfe_handle.get(), context::get_status());
24173 status_check(context::get_status());
24176 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24177 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
24180 int num_outputs_op = 1;
24181 TFE_TensorHandle* res[1] = {
nullptr};
24182 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24183 status_check(context::get_status());
24184 return tensor(res[0]);
24187 inline tensor tensor_list_set_item(
const tensor& input_handle,
const tensor& index,
const tensor& item,
24188 datatype element_dtype) {
24190 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24191 TFE_NewOp(context::get_context(),
"TensorListSetItem", context::get_status()), &TFE_DeleteOp);
24192 status_check(context::get_status());
24196 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
24197 status_check(context::get_status());
24199 TFE_OpAddInput(op.get(), index.tfe_handle.get(), context::get_status());
24200 status_check(context::get_status());
24202 TFE_OpAddInput(op.get(), item.tfe_handle.get(), context::get_status());
24203 status_check(context::get_status());
24206 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24209 int num_outputs_op = 1;
24210 TFE_TensorHandle* res[1] = {
nullptr};
24211 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24212 status_check(context::get_status());
24213 return tensor(res[0]);
24216 inline tensor tensor_list_split(
const tensor& input_tensor,
const tensor& element_shape,
const tensor& lengths,
24217 datatype element_dtype, datatype shape_type) {
24219 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24220 TFE_NewOp(context::get_context(),
"TensorListSplit", context::get_status()), &TFE_DeleteOp);
24221 status_check(context::get_status());
24225 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24226 status_check(context::get_status());
24228 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
24229 status_check(context::get_status());
24231 TFE_OpAddInput(op.get(), lengths.tfe_handle.get(), context::get_status());
24232 status_check(context::get_status());
24235 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24236 TFE_OpSetAttrType(op.get(),
"shape_type", shape_type);
24239 int num_outputs_op = 1;
24240 TFE_TensorHandle* res[1] = {
nullptr};
24241 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24242 status_check(context::get_status());
24243 return tensor(res[0]);
24246 inline tensor tensor_list_stack(
const tensor& input_handle,
const tensor& element_shape, datatype element_dtype,
24247 int64_t num_elements = -1) {
24249 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24250 TFE_NewOp(context::get_context(),
"TensorListStack", context::get_status()), &TFE_DeleteOp);
24251 status_check(context::get_status());
24255 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
24256 status_check(context::get_status());
24258 TFE_OpAddInput(op.get(), element_shape.tfe_handle.get(), context::get_status());
24259 status_check(context::get_status());
24262 TFE_OpSetAttrType(op.get(),
"element_dtype", element_dtype);
24263 TFE_OpSetAttrInt(op.get(),
"num_elements", num_elements);
24266 int num_outputs_op = 1;
24267 TFE_TensorHandle* res[1] = {
nullptr};
24268 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24269 status_check(context::get_status());
24270 return tensor(res[0]);
24273 inline tensor tensor_scatter_add(
const tensor& input_tensor,
const tensor& indices,
const tensor& updates,
24274 datatype Tindices) {
24276 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24277 TFE_NewOp(context::get_context(),
"TensorScatterAdd", context::get_status()), &TFE_DeleteOp);
24278 status_check(context::get_status());
24282 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24283 status_check(context::get_status());
24285 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24286 status_check(context::get_status());
24288 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
24289 status_check(context::get_status());
24292 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
24295 int num_outputs_op = 1;
24296 TFE_TensorHandle* res[1] = {
nullptr};
24297 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24298 status_check(context::get_status());
24299 return tensor(res[0]);
24302 inline tensor tensor_scatter_max(
const tensor& input_tensor,
const tensor& indices,
const tensor& updates,
24303 datatype Tindices) {
24305 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24306 TFE_NewOp(context::get_context(),
"TensorScatterMax", context::get_status()), &TFE_DeleteOp);
24307 status_check(context::get_status());
24311 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24312 status_check(context::get_status());
24314 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24315 status_check(context::get_status());
24317 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
24318 status_check(context::get_status());
24321 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
24324 int num_outputs_op = 1;
24325 TFE_TensorHandle* res[1] = {
nullptr};
24326 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24327 status_check(context::get_status());
24328 return tensor(res[0]);
24331 inline tensor tensor_scatter_min(
const tensor& input_tensor,
const tensor& indices,
const tensor& updates,
24332 datatype Tindices) {
24334 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24335 TFE_NewOp(context::get_context(),
"TensorScatterMin", context::get_status()), &TFE_DeleteOp);
24336 status_check(context::get_status());
24340 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24341 status_check(context::get_status());
24343 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24344 status_check(context::get_status());
24346 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
24347 status_check(context::get_status());
24350 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
24353 int num_outputs_op = 1;
24354 TFE_TensorHandle* res[1] = {
nullptr};
24355 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24356 status_check(context::get_status());
24357 return tensor(res[0]);
24360 inline tensor tensor_scatter_sub(
const tensor& input_tensor,
const tensor& indices,
const tensor& updates,
24361 datatype Tindices) {
24363 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24364 TFE_NewOp(context::get_context(),
"TensorScatterSub", context::get_status()), &TFE_DeleteOp);
24365 status_check(context::get_status());
24369 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24370 status_check(context::get_status());
24372 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24373 status_check(context::get_status());
24375 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
24376 status_check(context::get_status());
24379 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
24382 int num_outputs_op = 1;
24383 TFE_TensorHandle* res[1] = {
nullptr};
24384 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24385 status_check(context::get_status());
24386 return tensor(res[0]);
24389 inline tensor tensor_scatter_update(
const tensor& input_tensor,
const tensor& indices,
const tensor& updates,
24390 datatype Tindices) {
24392 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24393 TFE_NewOp(context::get_context(),
"TensorScatterUpdate", context::get_status()), &TFE_DeleteOp);
24394 status_check(context::get_status());
24398 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24399 status_check(context::get_status());
24401 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
24402 status_check(context::get_status());
24404 TFE_OpAddInput(op.get(), updates.tfe_handle.get(), context::get_status());
24405 status_check(context::get_status());
24408 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
24411 int num_outputs_op = 1;
24412 TFE_TensorHandle* res[1] = {
nullptr};
24413 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24414 status_check(context::get_status());
24415 return tensor(res[0]);
24418 inline tensor tensor_slice_dataset(
const std::vector<tensor>& components,
const std::vector<datatype>& Toutput_types,
24419 const std::vector<std::vector<int64_t>>& output_shapes) {
24421 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24422 TFE_NewOp(context::get_context(),
"TensorSliceDataset", context::get_status()), &TFE_DeleteOp);
24423 status_check(context::get_status());
24427 std::vector<TFE_TensorHandle*> components_handles;
24428 components_handles.reserve(components.size());
24429 std::transform(components.begin(), components.end(), std::back_inserter(components_handles),
24430 [](
const auto& t) { return t.tfe_handle.get(); });
24431 TFE_OpAddInputList(op.get(), components_handles.data(),
static_cast<int>(components.size()), context::get_status());
24432 status_check(context::get_status());
24435 TFE_OpSetAttrTypeList(op.get(),
"Toutput_types",
reinterpret_cast<const enum TF_DataType*
>(Toutput_types.data()),
24436 static_cast<int>(Toutput_types.size()));
24438 std::vector<const int64_t*> output_shapes_values;
24439 output_shapes_values.reserve(output_shapes.size());
24440 std::vector<int> output_shapes_ndims;
24441 output_shapes_ndims.reserve(output_shapes.size());
24442 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
24443 [](
const auto& v) { return v.data(); });
24444 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
24445 [](
const auto& v) { return static_cast<int>(v.size()); });
24446 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
24447 static_cast<int>(output_shapes.size()), context::get_status());
24448 status_check(context::get_status());
24451 int num_outputs_op = 1;
24452 TFE_TensorHandle* res[1] = {
nullptr};
24453 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24454 status_check(context::get_status());
24455 return tensor(res[0]);
24458 inline tensor tensor_strided_slice_update(
const tensor& input,
const tensor& begin,
const tensor& end,
24459 const tensor& strides,
const tensor& value, datatype Index,
24460 int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0,
24461 int64_t new_axis_mask = 0, int64_t shrink_axis_mask = 0) {
24463 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24464 TFE_NewOp(context::get_context(),
"TensorStridedSliceUpdate", context::get_status()), &TFE_DeleteOp);
24465 status_check(context::get_status());
24469 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
24470 status_check(context::get_status());
24472 TFE_OpAddInput(op.get(), begin.tfe_handle.get(), context::get_status());
24473 status_check(context::get_status());
24475 TFE_OpAddInput(op.get(), end.tfe_handle.get(), context::get_status());
24476 status_check(context::get_status());
24478 TFE_OpAddInput(op.get(), strides.tfe_handle.get(), context::get_status());
24479 status_check(context::get_status());
24481 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
24482 status_check(context::get_status());
24485 TFE_OpSetAttrType(op.get(),
"Index", Index);
24486 TFE_OpSetAttrInt(op.get(),
"begin_mask", begin_mask);
24487 TFE_OpSetAttrInt(op.get(),
"end_mask", end_mask);
24488 TFE_OpSetAttrInt(op.get(),
"ellipsis_mask", ellipsis_mask);
24489 TFE_OpSetAttrInt(op.get(),
"new_axis_mask", new_axis_mask);
24490 TFE_OpSetAttrInt(op.get(),
"shrink_axis_mask", shrink_axis_mask);
24493 int num_outputs_op = 1;
24494 TFE_TensorHandle* res[1] = {
nullptr};
24495 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24496 status_check(context::get_status());
24497 return tensor(res[0]);
24500 inline tensor tensor_summary(
const tensor& input_tensor,
const std::vector<std::string>& labels,
24501 const std::string& description =
"",
const std::string& display_name =
"") {
24503 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24504 TFE_NewOp(context::get_context(),
"TensorSummary", context::get_status()), &TFE_DeleteOp);
24505 status_check(context::get_status());
24509 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24510 status_check(context::get_status());
24514 std::vector<std::size_t> labels_sizes;
24515 labels_sizes.reserve(labels.size());
24516 std::transform(labels.begin(), labels.end(), std::back_inserter(labels_sizes),
24517 [](
const auto& s) { return s.size(); });
24518 TFE_OpSetAttrStringList(op.get(),
"labels",
reinterpret_cast<const void* const*
>(labels.data()), labels_sizes.data(),
24519 static_cast<int>(labels.size()));
24521 TFE_OpSetAttrString(op.get(),
"description", (
void*)description.c_str(), description.size());
24522 TFE_OpSetAttrString(op.get(),
"display_name", (
void*)display_name.c_str(), display_name.size());
24525 int num_outputs_op = 1;
24526 TFE_TensorHandle* res[1] = {
nullptr};
24527 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24528 status_check(context::get_status());
24529 return tensor(res[0]);
24532 inline tensor tensor_summary_v2(
const tensor& tag,
const tensor& input_tensor,
24533 const tensor& serialized_summary_metadata) {
24535 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24536 TFE_NewOp(context::get_context(),
"TensorSummaryV2", context::get_status()), &TFE_DeleteOp);
24537 status_check(context::get_status());
24541 TFE_OpAddInput(op.get(), tag.tfe_handle.get(), context::get_status());
24542 status_check(context::get_status());
24544 TFE_OpAddInput(op.get(), input_tensor.tfe_handle.get(), context::get_status());
24545 status_check(context::get_status());
24547 TFE_OpAddInput(op.get(), serialized_summary_metadata.tfe_handle.get(), context::get_status());
24548 status_check(context::get_status());
24553 int num_outputs_op = 1;
24554 TFE_TensorHandle* res[1] = {
nullptr};
24555 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24556 status_check(context::get_status());
24557 return tensor(res[0]);
24560 inline tensor text_line_dataset(
const tensor& filenames,
const tensor& compression_type,
const tensor& buffer_size) {
24562 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24563 TFE_NewOp(context::get_context(),
"TextLineDataset", context::get_status()), &TFE_DeleteOp);
24564 status_check(context::get_status());
24568 TFE_OpAddInput(op.get(), filenames.tfe_handle.get(), context::get_status());
24569 status_check(context::get_status());
24571 TFE_OpAddInput(op.get(), compression_type.tfe_handle.get(), context::get_status());
24572 status_check(context::get_status());
24574 TFE_OpAddInput(op.get(), buffer_size.tfe_handle.get(), context::get_status());
24575 status_check(context::get_status());
24580 int num_outputs_op = 1;
24581 TFE_TensorHandle* res[1] = {
nullptr};
24582 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24583 status_check(context::get_status());
24584 return tensor(res[0]);
24587 inline tensor text_line_reader(int64_t skip_header_lines = 0,
const std::string& container =
"",
24588 const std::string& shared_name =
"") {
24590 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24591 TFE_NewOp(context::get_context(),
"TextLineReader", context::get_status()), &TFE_DeleteOp);
24592 status_check(context::get_status());
24597 TFE_OpSetAttrInt(op.get(),
"skip_header_lines", skip_header_lines);
24598 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
24599 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
24602 int num_outputs_op = 1;
24603 TFE_TensorHandle* res[1] = {
nullptr};
24604 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24605 status_check(context::get_status());
24606 return tensor(res[0]);
24609 inline tensor text_line_reader_v2(int64_t skip_header_lines = 0,
const std::string& container =
"",
24610 const std::string& shared_name =
"") {
24612 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24613 TFE_NewOp(context::get_context(),
"TextLineReaderV2", context::get_status()), &TFE_DeleteOp);
24614 status_check(context::get_status());
24619 TFE_OpSetAttrInt(op.get(),
"skip_header_lines", skip_header_lines);
24620 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
24621 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
24624 int num_outputs_op = 1;
24625 TFE_TensorHandle* res[1] = {
nullptr};
24626 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24627 status_check(context::get_status());
24628 return tensor(res[0]);
24631 inline tensor thread_pool_dataset(
const tensor& input_dataset,
const tensor& thread_pool,
24632 const std::vector<datatype>& output_types,
24633 const std::vector<std::vector<int64_t>>& output_shapes) {
24635 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24636 TFE_NewOp(context::get_context(),
"ThreadPoolDataset", context::get_status()), &TFE_DeleteOp);
24637 status_check(context::get_status());
24641 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
24642 status_check(context::get_status());
24644 TFE_OpAddInput(op.get(), thread_pool.tfe_handle.get(), context::get_status());
24645 status_check(context::get_status());
24648 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
24649 static_cast<int>(output_types.size()));
24651 std::vector<const int64_t*> output_shapes_values;
24652 output_shapes_values.reserve(output_shapes.size());
24653 std::vector<int> output_shapes_ndims;
24654 output_shapes_ndims.reserve(output_shapes.size());
24655 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
24656 [](
const auto& v) { return v.data(); });
24657 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
24658 [](
const auto& v) { return static_cast<int>(v.size()); });
24659 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
24660 static_cast<int>(output_shapes.size()), context::get_status());
24661 status_check(context::get_status());
24664 int num_outputs_op = 1;
24665 TFE_TensorHandle* res[1] = {
nullptr};
24666 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24667 status_check(context::get_status());
24668 return tensor(res[0]);
24671 inline tensor thread_pool_handle(int64_t num_threads,
const std::string& display_name,
24672 int64_t max_intra_op_parallelism = 1,
const std::string& container =
"",
24673 const std::string& shared_name =
"") {
24675 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24676 TFE_NewOp(context::get_context(),
"ThreadPoolHandle", context::get_status()), &TFE_DeleteOp);
24677 status_check(context::get_status());
24682 TFE_OpSetAttrInt(op.get(),
"num_threads", num_threads);
24683 TFE_OpSetAttrString(op.get(),
"display_name", (
void*)display_name.c_str(), display_name.size());
24684 TFE_OpSetAttrInt(op.get(),
"max_intra_op_parallelism", max_intra_op_parallelism);
24685 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
24686 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
24689 int num_outputs_op = 1;
24690 TFE_TensorHandle* res[1] = {
nullptr};
24691 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24692 status_check(context::get_status());
24693 return tensor(res[0]);
24696 inline tensor tile(
const tensor& input,
const tensor& multiples, datatype Tmultiples =
static_cast<datatype
>(3)) {
24698 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Tile", context::get_status()),
24700 status_check(context::get_status());
24704 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
24705 status_check(context::get_status());
24707 TFE_OpAddInput(op.get(), multiples.tfe_handle.get(), context::get_status());
24708 status_check(context::get_status());
24711 TFE_OpSetAttrType(op.get(),
"Tmultiples", Tmultiples);
24714 int num_outputs_op = 1;
24715 TFE_TensorHandle* res[1] = {
nullptr};
24716 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24717 status_check(context::get_status());
24718 return tensor(res[0]);
24721 inline tensor tile_grad(
const tensor& input,
const tensor& multiples) {
24723 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24724 TFE_NewOp(context::get_context(),
"TileGrad", context::get_status()), &TFE_DeleteOp);
24725 status_check(context::get_status());
24729 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
24730 status_check(context::get_status());
24732 TFE_OpAddInput(op.get(), multiples.tfe_handle.get(), context::get_status());
24733 status_check(context::get_status());
24738 int num_outputs_op = 1;
24739 TFE_TensorHandle* res[1] = {
nullptr};
24740 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24741 status_check(context::get_status());
24742 return tensor(res[0]);
24745 inline tensor timestamp() {
24747 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24748 TFE_NewOp(context::get_context(),
"Timestamp", context::get_status()), &TFE_DeleteOp);
24749 status_check(context::get_status());
24756 int num_outputs_op = 1;
24757 TFE_TensorHandle* res[1] = {
nullptr};
24758 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24759 status_check(context::get_status());
24760 return tensor(res[0]);
24763 inline tensor to_bool(
const tensor& input) {
24765 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24766 TFE_NewOp(context::get_context(),
"ToBool", context::get_status()), &TFE_DeleteOp);
24767 status_check(context::get_status());
24771 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
24772 status_check(context::get_status());
24777 int num_outputs_op = 1;
24778 TFE_TensorHandle* res[1] = {
nullptr};
24779 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24780 status_check(context::get_status());
24781 return tensor(res[0]);
24784 inline tensor transpose(
const tensor& x,
const tensor& perm, datatype Tperm =
static_cast<datatype
>(3)) {
24786 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24787 TFE_NewOp(context::get_context(),
"Transpose", context::get_status()), &TFE_DeleteOp);
24788 status_check(context::get_status());
24792 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
24793 status_check(context::get_status());
24795 TFE_OpAddInput(op.get(), perm.tfe_handle.get(), context::get_status());
24796 status_check(context::get_status());
24799 TFE_OpSetAttrType(op.get(),
"Tperm", Tperm);
24802 int num_outputs_op = 1;
24803 TFE_TensorHandle* res[1] = {
nullptr};
24804 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24805 status_check(context::get_status());
24806 return tensor(res[0]);
24809 inline tensor tridiagonal_mat_mul(
const tensor& superdiag,
const tensor& maindiag,
const tensor& subdiag,
24810 const tensor& rhs) {
24812 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24813 TFE_NewOp(context::get_context(),
"TridiagonalMatMul", context::get_status()), &TFE_DeleteOp);
24814 status_check(context::get_status());
24818 TFE_OpAddInput(op.get(), superdiag.tfe_handle.get(), context::get_status());
24819 status_check(context::get_status());
24821 TFE_OpAddInput(op.get(), maindiag.tfe_handle.get(), context::get_status());
24822 status_check(context::get_status());
24824 TFE_OpAddInput(op.get(), subdiag.tfe_handle.get(), context::get_status());
24825 status_check(context::get_status());
24827 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
24828 status_check(context::get_status());
24833 int num_outputs_op = 1;
24834 TFE_TensorHandle* res[1] = {
nullptr};
24835 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24836 status_check(context::get_status());
24837 return tensor(res[0]);
24840 inline tensor tridiagonal_solve(
const tensor& diagonals,
const tensor& rhs,
bool partial_pivoting =
true) {
24842 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24843 TFE_NewOp(context::get_context(),
"TridiagonalSolve", context::get_status()), &TFE_DeleteOp);
24844 status_check(context::get_status());
24848 TFE_OpAddInput(op.get(), diagonals.tfe_handle.get(), context::get_status());
24849 status_check(context::get_status());
24851 TFE_OpAddInput(op.get(), rhs.tfe_handle.get(), context::get_status());
24852 status_check(context::get_status());
24855 TFE_OpSetAttrBool(op.get(),
"partial_pivoting", (
unsigned char)partial_pivoting);
24858 int num_outputs_op = 1;
24859 TFE_TensorHandle* res[1] = {
nullptr};
24860 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24861 status_check(context::get_status());
24862 return tensor(res[0]);
24865 inline tensor truncate_div(
const tensor& x,
const tensor& y) {
24867 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24868 TFE_NewOp(context::get_context(),
"TruncateDiv", context::get_status()), &TFE_DeleteOp);
24869 status_check(context::get_status());
24873 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
24874 status_check(context::get_status());
24876 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
24877 status_check(context::get_status());
24882 int num_outputs_op = 1;
24883 TFE_TensorHandle* res[1] = {
nullptr};
24884 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24885 status_check(context::get_status());
24886 return tensor(res[0]);
24889 inline tensor truncate_mod(
const tensor& x,
const tensor& y) {
24891 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24892 TFE_NewOp(context::get_context(),
"TruncateMod", context::get_status()), &TFE_DeleteOp);
24893 status_check(context::get_status());
24897 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
24898 status_check(context::get_status());
24900 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
24901 status_check(context::get_status());
24906 int num_outputs_op = 1;
24907 TFE_TensorHandle* res[1] = {
nullptr};
24908 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24909 status_check(context::get_status());
24910 return tensor(res[0]);
24913 inline tensor truncated_normal(
const tensor& shape, datatype dtype, int64_t seed = 0, int64_t seed2 = 0) {
24915 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24916 TFE_NewOp(context::get_context(),
"TruncatedNormal", context::get_status()), &TFE_DeleteOp);
24917 status_check(context::get_status());
24921 TFE_OpAddInput(op.get(), shape.tfe_handle.get(), context::get_status());
24922 status_check(context::get_status());
24925 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
24926 TFE_OpSetAttrInt(op.get(),
"seed", seed);
24927 TFE_OpSetAttrInt(op.get(),
"seed2", seed2);
24930 int num_outputs_op = 1;
24931 TFE_TensorHandle* res[1] = {
nullptr};
24932 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24933 status_check(context::get_status());
24934 return tensor(res[0]);
24937 inline tensor unbatch(
const tensor& batched_input_tensor,
const tensor& batch_index,
const tensor&
id,
24938 int64_t timeout_micros,
const std::string& container =
"",
const std::string& shared_name =
"") {
24940 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24941 TFE_NewOp(context::get_context(),
"Unbatch", context::get_status()), &TFE_DeleteOp);
24942 status_check(context::get_status());
24946 TFE_OpAddInput(op.get(), batched_input_tensor.tfe_handle.get(), context::get_status());
24947 status_check(context::get_status());
24949 TFE_OpAddInput(op.get(), batch_index.tfe_handle.get(), context::get_status());
24950 status_check(context::get_status());
24952 TFE_OpAddInput(op.get(),
id.tfe_handle.get(), context::get_status());
24953 status_check(context::get_status());
24956 TFE_OpSetAttrInt(op.get(),
"timeout_micros", timeout_micros);
24957 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
24958 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
24961 int num_outputs_op = 1;
24962 TFE_TensorHandle* res[1] = {
nullptr};
24963 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
24964 status_check(context::get_status());
24965 return tensor(res[0]);
24968 inline tensor unbatch_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
24969 const std::vector<std::vector<int64_t>>& output_shapes) {
24971 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
24972 TFE_NewOp(context::get_context(),
"UnbatchDataset", context::get_status()), &TFE_DeleteOp);
24973 status_check(context::get_status());
24977 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
24978 status_check(context::get_status());
24981 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
24982 static_cast<int>(output_types.size()));
24984 std::vector<const int64_t*> output_shapes_values;
24985 output_shapes_values.reserve(output_shapes.size());
24986 std::vector<int> output_shapes_ndims;
24987 output_shapes_ndims.reserve(output_shapes.size());
24988 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
24989 [](
const auto& v) { return v.data(); });
24990 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
24991 [](
const auto& v) { return static_cast<int>(v.size()); });
24992 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
24993 static_cast<int>(output_shapes.size()), context::get_status());
24994 status_check(context::get_status());
24997 int num_outputs_op = 1;
24998 TFE_TensorHandle* res[1] = {
nullptr};
24999 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25000 status_check(context::get_status());
25001 return tensor(res[0]);
25004 inline tensor unbatch_grad(
const tensor& original_input,
const tensor& batch_index,
const tensor& grad,
25005 const tensor&
id,
const std::string& container =
"",
const std::string& shared_name =
"") {
25007 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25008 TFE_NewOp(context::get_context(),
"UnbatchGrad", context::get_status()), &TFE_DeleteOp);
25009 status_check(context::get_status());
25013 TFE_OpAddInput(op.get(), original_input.tfe_handle.get(), context::get_status());
25014 status_check(context::get_status());
25016 TFE_OpAddInput(op.get(), batch_index.tfe_handle.get(), context::get_status());
25017 status_check(context::get_status());
25019 TFE_OpAddInput(op.get(), grad.tfe_handle.get(), context::get_status());
25020 status_check(context::get_status());
25022 TFE_OpAddInput(op.get(),
id.tfe_handle.get(), context::get_status());
25023 status_check(context::get_status());
25026 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25027 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25030 int num_outputs_op = 1;
25031 TFE_TensorHandle* res[1] = {
nullptr};
25032 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25033 status_check(context::get_status());
25034 return tensor(res[0]);
25037 inline tensor uncompress_element(
const tensor& compressed,
const std::vector<datatype>& output_types,
25038 const std::vector<std::vector<int64_t>>& output_shapes) {
25040 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25041 TFE_NewOp(context::get_context(),
"UncompressElement", context::get_status()), &TFE_DeleteOp);
25042 status_check(context::get_status());
25046 TFE_OpAddInput(op.get(), compressed.tfe_handle.get(), context::get_status());
25047 status_check(context::get_status());
25050 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
25051 static_cast<int>(output_types.size()));
25053 std::vector<const int64_t*> output_shapes_values;
25054 output_shapes_values.reserve(output_shapes.size());
25055 std::vector<int> output_shapes_ndims;
25056 output_shapes_ndims.reserve(output_shapes.size());
25057 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
25058 [](
const auto& v) { return v.data(); });
25059 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
25060 [](
const auto& v) { return static_cast<int>(v.size()); });
25061 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
25062 static_cast<int>(output_shapes.size()), context::get_status());
25063 status_check(context::get_status());
25066 int num_outputs_op = 1;
25067 TFE_TensorHandle* res[1] = {
nullptr};
25068 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25069 status_check(context::get_status());
25070 return tensor(res[0]);
25073 inline tensor unicode_encode(
const tensor& input_values,
const tensor& input_splits,
const std::string& output_encoding,
25074 const std::string& errors =
"replace", int64_t replacement_char = 65533,
25075 datatype Tsplits =
static_cast<datatype
>(9)) {
25077 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25078 TFE_NewOp(context::get_context(),
"UnicodeEncode", context::get_status()), &TFE_DeleteOp);
25079 status_check(context::get_status());
25083 TFE_OpAddInput(op.get(), input_values.tfe_handle.get(), context::get_status());
25084 status_check(context::get_status());
25086 TFE_OpAddInput(op.get(), input_splits.tfe_handle.get(), context::get_status());
25087 status_check(context::get_status());
25090 TFE_OpSetAttrString(op.get(),
"output_encoding", (
void*)output_encoding.c_str(), output_encoding.size());
25091 TFE_OpSetAttrString(op.get(),
"errors", (
void*)errors.c_str(), errors.size());
25092 TFE_OpSetAttrInt(op.get(),
"replacement_char", replacement_char);
25093 TFE_OpSetAttrType(op.get(),
"Tsplits", Tsplits);
25096 int num_outputs_op = 1;
25097 TFE_TensorHandle* res[1] = {
nullptr};
25098 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25099 status_check(context::get_status());
25100 return tensor(res[0]);
25103 inline tensor unicode_script(
const tensor& input) {
25105 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25106 TFE_NewOp(context::get_context(),
"UnicodeScript", context::get_status()), &TFE_DeleteOp);
25107 status_check(context::get_status());
25111 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
25112 status_check(context::get_status());
25117 int num_outputs_op = 1;
25118 TFE_TensorHandle* res[1] = {
nullptr};
25119 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25120 status_check(context::get_status());
25121 return tensor(res[0]);
25124 inline tensor unicode_transcode(
const tensor& input,
const std::string& input_encoding,
25125 const std::string& output_encoding,
const std::string& errors =
"replace",
25126 int64_t replacement_char = 65533,
bool replace_control_characters =
false) {
25128 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25129 TFE_NewOp(context::get_context(),
"UnicodeTranscode", context::get_status()), &TFE_DeleteOp);
25130 status_check(context::get_status());
25134 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
25135 status_check(context::get_status());
25138 TFE_OpSetAttrString(op.get(),
"input_encoding", (
void*)input_encoding.c_str(), input_encoding.size());
25139 TFE_OpSetAttrString(op.get(),
"output_encoding", (
void*)output_encoding.c_str(), output_encoding.size());
25140 TFE_OpSetAttrString(op.get(),
"errors", (
void*)errors.c_str(), errors.size());
25141 TFE_OpSetAttrInt(op.get(),
"replacement_char", replacement_char);
25142 TFE_OpSetAttrBool(op.get(),
"replace_control_characters", (
unsigned char)replace_control_characters);
25145 int num_outputs_op = 1;
25146 TFE_TensorHandle* res[1] = {
nullptr};
25147 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25148 status_check(context::get_status());
25149 return tensor(res[0]);
25152 inline tensor unique_dataset(
const tensor& input_dataset,
const std::vector<datatype>& output_types,
25153 const std::vector<std::vector<int64_t>>& output_shapes) {
25155 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25156 TFE_NewOp(context::get_context(),
"UniqueDataset", context::get_status()), &TFE_DeleteOp);
25157 status_check(context::get_status());
25161 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
25162 status_check(context::get_status());
25165 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
25166 static_cast<int>(output_types.size()));
25168 std::vector<const int64_t*> output_shapes_values;
25169 output_shapes_values.reserve(output_shapes.size());
25170 std::vector<int> output_shapes_ndims;
25171 output_shapes_ndims.reserve(output_shapes.size());
25172 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
25173 [](
const auto& v) { return v.data(); });
25174 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
25175 [](
const auto& v) { return static_cast<int>(v.size()); });
25176 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
25177 static_cast<int>(output_shapes.size()), context::get_status());
25178 status_check(context::get_status());
25181 int num_outputs_op = 1;
25182 TFE_TensorHandle* res[1] = {
nullptr};
25183 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25184 status_check(context::get_status());
25185 return tensor(res[0]);
25188 inline tensor unpack(
const tensor& value, int64_t num, int64_t axis = 0) {
25190 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25191 TFE_NewOp(context::get_context(),
"Unpack", context::get_status()), &TFE_DeleteOp);
25192 status_check(context::get_status());
25196 TFE_OpAddInput(op.get(), value.tfe_handle.get(), context::get_status());
25197 status_check(context::get_status());
25200 TFE_OpSetAttrInt(op.get(),
"num", num);
25201 TFE_OpSetAttrInt(op.get(),
"axis", axis);
25204 int num_outputs_op = 1;
25205 TFE_TensorHandle* res[1] = {
nullptr};
25206 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25207 status_check(context::get_status());
25208 return tensor(res[0]);
25211 inline tensor unravel_index(
const tensor& indices,
const tensor& dims, datatype Tidx =
static_cast<datatype
>(3)) {
25213 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25214 TFE_NewOp(context::get_context(),
"UnravelIndex", context::get_status()), &TFE_DeleteOp);
25215 status_check(context::get_status());
25219 TFE_OpAddInput(op.get(), indices.tfe_handle.get(), context::get_status());
25220 status_check(context::get_status());
25222 TFE_OpAddInput(op.get(), dims.tfe_handle.get(), context::get_status());
25223 status_check(context::get_status());
25226 TFE_OpSetAttrType(op.get(),
"Tidx", Tidx);
25229 int num_outputs_op = 1;
25230 TFE_TensorHandle* res[1] = {
nullptr};
25231 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25232 status_check(context::get_status());
25233 return tensor(res[0]);
25236 inline tensor unsorted_segment_join(
const tensor& inputs,
const tensor& segment_ids,
const tensor& num_segments,
25237 datatype Tindices,
const std::string& separator =
"",
25238 datatype Tnumsegments =
static_cast<datatype
>(3)) {
25240 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25241 TFE_NewOp(context::get_context(),
"UnsortedSegmentJoin", context::get_status()), &TFE_DeleteOp);
25242 status_check(context::get_status());
25246 TFE_OpAddInput(op.get(), inputs.tfe_handle.get(), context::get_status());
25247 status_check(context::get_status());
25249 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
25250 status_check(context::get_status());
25252 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
25253 status_check(context::get_status());
25256 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
25257 TFE_OpSetAttrString(op.get(),
"separator", (
void*)separator.c_str(), separator.size());
25258 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
25261 int num_outputs_op = 1;
25262 TFE_TensorHandle* res[1] = {
nullptr};
25263 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25264 status_check(context::get_status());
25265 return tensor(res[0]);
25268 inline tensor unsorted_segment_max(
const tensor& data,
const tensor& segment_ids,
const tensor& num_segments,
25269 datatype Tindices, datatype Tnumsegments =
static_cast<datatype
>(3)) {
25271 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25272 TFE_NewOp(context::get_context(),
"UnsortedSegmentMax", context::get_status()), &TFE_DeleteOp);
25273 status_check(context::get_status());
25277 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
25278 status_check(context::get_status());
25280 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
25281 status_check(context::get_status());
25283 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
25284 status_check(context::get_status());
25287 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
25288 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
25291 int num_outputs_op = 1;
25292 TFE_TensorHandle* res[1] = {
nullptr};
25293 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25294 status_check(context::get_status());
25295 return tensor(res[0]);
25298 inline tensor unsorted_segment_min(
const tensor& data,
const tensor& segment_ids,
const tensor& num_segments,
25299 datatype Tindices, datatype Tnumsegments =
static_cast<datatype
>(3)) {
25301 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25302 TFE_NewOp(context::get_context(),
"UnsortedSegmentMin", context::get_status()), &TFE_DeleteOp);
25303 status_check(context::get_status());
25307 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
25308 status_check(context::get_status());
25310 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
25311 status_check(context::get_status());
25313 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
25314 status_check(context::get_status());
25317 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
25318 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
25321 int num_outputs_op = 1;
25322 TFE_TensorHandle* res[1] = {
nullptr};
25323 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25324 status_check(context::get_status());
25325 return tensor(res[0]);
25328 inline tensor unsorted_segment_prod(
const tensor& data,
const tensor& segment_ids,
const tensor& num_segments,
25329 datatype Tindices, datatype Tnumsegments =
static_cast<datatype
>(3)) {
25331 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25332 TFE_NewOp(context::get_context(),
"UnsortedSegmentProd", context::get_status()), &TFE_DeleteOp);
25333 status_check(context::get_status());
25337 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
25338 status_check(context::get_status());
25340 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
25341 status_check(context::get_status());
25343 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
25344 status_check(context::get_status());
25347 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
25348 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
25351 int num_outputs_op = 1;
25352 TFE_TensorHandle* res[1] = {
nullptr};
25353 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25354 status_check(context::get_status());
25355 return tensor(res[0]);
25358 inline tensor unsorted_segment_sum(
const tensor& data,
const tensor& segment_ids,
const tensor& num_segments,
25359 datatype Tindices, datatype Tnumsegments =
static_cast<datatype
>(3)) {
25361 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25362 TFE_NewOp(context::get_context(),
"UnsortedSegmentSum", context::get_status()), &TFE_DeleteOp);
25363 status_check(context::get_status());
25367 TFE_OpAddInput(op.get(), data.tfe_handle.get(), context::get_status());
25368 status_check(context::get_status());
25370 TFE_OpAddInput(op.get(), segment_ids.tfe_handle.get(), context::get_status());
25371 status_check(context::get_status());
25373 TFE_OpAddInput(op.get(), num_segments.tfe_handle.get(), context::get_status());
25374 status_check(context::get_status());
25377 TFE_OpSetAttrType(op.get(),
"Tindices", Tindices);
25378 TFE_OpSetAttrType(op.get(),
"Tnumsegments", Tnumsegments);
25381 int num_outputs_op = 1;
25382 TFE_TensorHandle* res[1] = {
nullptr};
25383 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25384 status_check(context::get_status());
25385 return tensor(res[0]);
25388 inline tensor unstage(
const std::vector<datatype>& dtypes, int64_t capacity = 0, int64_t memory_limit = 0,
25389 const std::string& container =
"",
const std::string& shared_name =
"") {
25391 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25392 TFE_NewOp(context::get_context(),
"Unstage", context::get_status()), &TFE_DeleteOp);
25393 status_check(context::get_status());
25398 TFE_OpSetAttrTypeList(op.get(),
"dtypes",
reinterpret_cast<const enum TF_DataType*
>(dtypes.data()),
25399 static_cast<int>(dtypes.size()));
25400 TFE_OpSetAttrInt(op.get(),
"capacity", capacity);
25401 TFE_OpSetAttrInt(op.get(),
"memory_limit", memory_limit);
25402 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25403 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25406 int num_outputs_op = 1;
25407 TFE_TensorHandle* res[1] = {
nullptr};
25408 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25409 status_check(context::get_status());
25410 return tensor(res[0]);
25413 inline tensor unwrap_dataset_variant(
const tensor& input_handle) {
25415 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25416 TFE_NewOp(context::get_context(),
"UnwrapDatasetVariant", context::get_status()), &TFE_DeleteOp);
25417 status_check(context::get_status());
25421 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
25422 status_check(context::get_status());
25427 int num_outputs_op = 1;
25428 TFE_TensorHandle* res[1] = {
nullptr};
25429 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25430 status_check(context::get_status());
25431 return tensor(res[0]);
25434 inline tensor upper_bound(
const tensor& sorted_inputs,
const tensor& values,
25435 datatype out_type =
static_cast<datatype
>(3)) {
25437 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25438 TFE_NewOp(context::get_context(),
"UpperBound", context::get_status()), &TFE_DeleteOp);
25439 status_check(context::get_status());
25443 TFE_OpAddInput(op.get(), sorted_inputs.tfe_handle.get(), context::get_status());
25444 status_check(context::get_status());
25446 TFE_OpAddInput(op.get(), values.tfe_handle.get(), context::get_status());
25447 status_check(context::get_status());
25450 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
25453 int num_outputs_op = 1;
25454 TFE_TensorHandle* res[1] = {
nullptr};
25455 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25456 status_check(context::get_status());
25457 return tensor(res[0]);
25460 inline tensor var_handle_op(datatype dtype,
const std::vector<int64_t>& shape,
25461 const std::vector<std::string>& allowed_devices,
const std::string& container =
"",
25462 const std::string& shared_name =
"") {
25464 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25465 TFE_NewOp(context::get_context(),
"VarHandleOp", context::get_status()), &TFE_DeleteOp);
25466 status_check(context::get_status());
25471 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
25473 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
25474 status_check(context::get_status());
25476 std::vector<std::size_t> allowed_devices_sizes;
25477 allowed_devices_sizes.reserve(allowed_devices.size());
25478 std::transform(allowed_devices.begin(), allowed_devices.end(), std::back_inserter(allowed_devices_sizes),
25479 [](
const auto& s) { return s.size(); });
25480 TFE_OpSetAttrStringList(op.get(),
"allowed_devices",
reinterpret_cast<const void* const*
>(allowed_devices.data()),
25481 allowed_devices_sizes.data(),
static_cast<int>(allowed_devices.size()));
25483 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25484 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25487 int num_outputs_op = 1;
25488 TFE_TensorHandle* res[1] = {
nullptr};
25489 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25490 status_check(context::get_status());
25491 return tensor(res[0]);
25494 inline tensor var_is_initialized_op(
const tensor& resource) {
25496 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25497 TFE_NewOp(context::get_context(),
"VarIsInitializedOp", context::get_status()), &TFE_DeleteOp);
25498 status_check(context::get_status());
25502 TFE_OpAddInput(op.get(), resource.tfe_handle.get(), context::get_status());
25503 status_check(context::get_status());
25508 int num_outputs_op = 1;
25509 TFE_TensorHandle* res[1] = {
nullptr};
25510 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25511 status_check(context::get_status());
25512 return tensor(res[0]);
25515 inline tensor variable(
const std::vector<int64_t>& shape, datatype dtype,
const std::string& container =
"",
25516 const std::string& shared_name =
"") {
25518 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25519 TFE_NewOp(context::get_context(),
"Variable", context::get_status()), &TFE_DeleteOp);
25520 status_check(context::get_status());
25526 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
25527 status_check(context::get_status());
25529 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
25530 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25531 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25534 int num_outputs_op = 1;
25535 TFE_TensorHandle* res[1] = {
nullptr};
25536 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25537 status_check(context::get_status());
25538 return tensor(res[0]);
25541 inline tensor variable_shape(
const tensor& input, datatype out_type =
static_cast<datatype
>(3)) {
25543 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25544 TFE_NewOp(context::get_context(),
"VariableShape", context::get_status()), &TFE_DeleteOp);
25545 status_check(context::get_status());
25549 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
25550 status_check(context::get_status());
25553 TFE_OpSetAttrType(op.get(),
"out_type", out_type);
25556 int num_outputs_op = 1;
25557 TFE_TensorHandle* res[1] = {
nullptr};
25558 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25559 status_check(context::get_status());
25560 return tensor(res[0]);
25563 inline tensor variable_v2(
const std::vector<int64_t>& shape, datatype dtype,
const std::string& container =
"",
25564 const std::string& shared_name =
"") {
25566 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25567 TFE_NewOp(context::get_context(),
"VariableV2", context::get_status()), &TFE_DeleteOp);
25568 status_check(context::get_status());
25574 TFE_OpSetAttrShape(op.get(),
"shape", shape.data(),
static_cast<int>(shape.size()), context::get_status());
25575 status_check(context::get_status());
25577 TFE_OpSetAttrType(op.get(),
"dtype", dtype);
25578 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25579 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25582 int num_outputs_op = 1;
25583 TFE_TensorHandle* res[1] = {
nullptr};
25584 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25585 status_check(context::get_status());
25586 return tensor(res[0]);
25589 inline tensor where(
const tensor& input) {
25591 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Where", context::get_status()),
25593 status_check(context::get_status());
25597 TFE_OpAddInput(op.get(), input.tfe_handle.get(), context::get_status());
25598 status_check(context::get_status());
25603 int num_outputs_op = 1;
25604 TFE_TensorHandle* res[1] = {
nullptr};
25605 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25606 status_check(context::get_status());
25607 return tensor(res[0]);
25610 inline tensor whole_file_reader(
const std::string& container =
"",
const std::string& shared_name =
"") {
25612 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25613 TFE_NewOp(context::get_context(),
"WholeFileReader", context::get_status()), &TFE_DeleteOp);
25614 status_check(context::get_status());
25619 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25620 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25623 int num_outputs_op = 1;
25624 TFE_TensorHandle* res[1] = {
nullptr};
25625 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25626 status_check(context::get_status());
25627 return tensor(res[0]);
25630 inline tensor whole_file_reader_v2(
const std::string& container =
"",
const std::string& shared_name =
"") {
25632 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25633 TFE_NewOp(context::get_context(),
"WholeFileReaderV2", context::get_status()), &TFE_DeleteOp);
25634 status_check(context::get_status());
25639 TFE_OpSetAttrString(op.get(),
"container", (
void*)container.c_str(), container.size());
25640 TFE_OpSetAttrString(op.get(),
"shared_name", (
void*)shared_name.c_str(), shared_name.size());
25643 int num_outputs_op = 1;
25644 TFE_TensorHandle* res[1] = {
nullptr};
25645 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25646 status_check(context::get_status());
25647 return tensor(res[0]);
25650 inline tensor window_dataset(
const tensor& input_dataset,
const tensor& size,
const tensor& shift,
const tensor& stride,
25651 const tensor& drop_remainder,
const std::vector<datatype>& output_types,
25652 const std::vector<std::vector<int64_t>>& output_shapes) {
25654 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25655 TFE_NewOp(context::get_context(),
"WindowDataset", context::get_status()), &TFE_DeleteOp);
25656 status_check(context::get_status());
25660 TFE_OpAddInput(op.get(), input_dataset.tfe_handle.get(), context::get_status());
25661 status_check(context::get_status());
25663 TFE_OpAddInput(op.get(), size.tfe_handle.get(), context::get_status());
25664 status_check(context::get_status());
25666 TFE_OpAddInput(op.get(), shift.tfe_handle.get(), context::get_status());
25667 status_check(context::get_status());
25669 TFE_OpAddInput(op.get(), stride.tfe_handle.get(), context::get_status());
25670 status_check(context::get_status());
25672 TFE_OpAddInput(op.get(), drop_remainder.tfe_handle.get(), context::get_status());
25673 status_check(context::get_status());
25676 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
25677 static_cast<int>(output_types.size()));
25679 std::vector<const int64_t*> output_shapes_values;
25680 output_shapes_values.reserve(output_shapes.size());
25681 std::vector<int> output_shapes_ndims;
25682 output_shapes_ndims.reserve(output_shapes.size());
25683 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
25684 [](
const auto& v) { return v.data(); });
25685 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
25686 [](
const auto& v) { return static_cast<int>(v.size()); });
25687 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
25688 static_cast<int>(output_shapes.size()), context::get_status());
25689 status_check(context::get_status());
25692 int num_outputs_op = 1;
25693 TFE_TensorHandle* res[1] = {
nullptr};
25694 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25695 status_check(context::get_status());
25696 return tensor(res[0]);
25699 inline tensor worker_heartbeat(
const tensor& request) {
25701 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25702 TFE_NewOp(context::get_context(),
"WorkerHeartbeat", context::get_status()), &TFE_DeleteOp);
25703 status_check(context::get_status());
25707 TFE_OpAddInput(op.get(), request.tfe_handle.get(), context::get_status());
25708 status_check(context::get_status());
25713 int num_outputs_op = 1;
25714 TFE_TensorHandle* res[1] = {
nullptr};
25715 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25716 status_check(context::get_status());
25717 return tensor(res[0]);
25720 inline tensor wrap_dataset_variant(
const tensor& input_handle) {
25722 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25723 TFE_NewOp(context::get_context(),
"WrapDatasetVariant", context::get_status()), &TFE_DeleteOp);
25724 status_check(context::get_status());
25728 TFE_OpAddInput(op.get(), input_handle.tfe_handle.get(), context::get_status());
25729 status_check(context::get_status());
25734 int num_outputs_op = 1;
25735 TFE_TensorHandle* res[1] = {
nullptr};
25736 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25737 status_check(context::get_status());
25738 return tensor(res[0]);
25741 inline tensor xdivy(
const tensor& x,
const tensor& y) {
25743 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Xdivy", context::get_status()),
25745 status_check(context::get_status());
25749 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
25750 status_check(context::get_status());
25752 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
25753 status_check(context::get_status());
25758 int num_outputs_op = 1;
25759 TFE_TensorHandle* res[1] = {
nullptr};
25760 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25761 status_check(context::get_status());
25762 return tensor(res[0]);
25765 inline tensor xlog1py(
const tensor& x,
const tensor& y) {
25767 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25768 TFE_NewOp(context::get_context(),
"Xlog1py", context::get_status()), &TFE_DeleteOp);
25769 status_check(context::get_status());
25773 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
25774 status_check(context::get_status());
25776 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
25777 status_check(context::get_status());
25782 int num_outputs_op = 1;
25783 TFE_TensorHandle* res[1] = {
nullptr};
25784 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25785 status_check(context::get_status());
25786 return tensor(res[0]);
25789 inline tensor xlogy(
const tensor& x,
const tensor& y) {
25791 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Xlogy", context::get_status()),
25793 status_check(context::get_status());
25797 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
25798 status_check(context::get_status());
25800 TFE_OpAddInput(op.get(), y.tfe_handle.get(), context::get_status());
25801 status_check(context::get_status());
25806 int num_outputs_op = 1;
25807 TFE_TensorHandle* res[1] = {
nullptr};
25808 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25809 status_check(context::get_status());
25810 return tensor(res[0]);
25813 inline tensor zeros_like(
const tensor& x) {
25815 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25816 TFE_NewOp(context::get_context(),
"ZerosLike", context::get_status()), &TFE_DeleteOp);
25817 status_check(context::get_status());
25821 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
25822 status_check(context::get_status());
25827 int num_outputs_op = 1;
25828 TFE_TensorHandle* res[1] = {
nullptr};
25829 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25830 status_check(context::get_status());
25831 return tensor(res[0]);
25834 inline tensor zeta(
const tensor& x,
const tensor& q) {
25836 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(),
"Zeta", context::get_status()),
25838 status_check(context::get_status());
25842 TFE_OpAddInput(op.get(), x.tfe_handle.get(), context::get_status());
25843 status_check(context::get_status());
25845 TFE_OpAddInput(op.get(), q.tfe_handle.get(), context::get_status());
25846 status_check(context::get_status());
25851 int num_outputs_op = 1;
25852 TFE_TensorHandle* res[1] = {
nullptr};
25853 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25854 status_check(context::get_status());
25855 return tensor(res[0]);
25858 inline tensor zip_dataset(
const std::vector<tensor>& input_datasets,
const std::vector<datatype>& output_types,
25859 const std::vector<std::vector<int64_t>>& output_shapes) {
25861 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
25862 TFE_NewOp(context::get_context(),
"ZipDataset", context::get_status()), &TFE_DeleteOp);
25863 status_check(context::get_status());
25867 std::vector<TFE_TensorHandle*> input_datasets_handles;
25868 input_datasets_handles.reserve(input_datasets.size());
25869 std::transform(input_datasets.begin(), input_datasets.end(), std::back_inserter(input_datasets_handles),
25870 [](
const auto& t) { return t.tfe_handle.get(); });
25871 TFE_OpAddInputList(op.get(), input_datasets_handles.data(),
static_cast<int>(input_datasets.size()),
25872 context::get_status());
25873 status_check(context::get_status());
25876 TFE_OpSetAttrTypeList(op.get(),
"output_types",
reinterpret_cast<const enum TF_DataType*
>(output_types.data()),
25877 static_cast<int>(output_types.size()));
25879 std::vector<const int64_t*> output_shapes_values;
25880 output_shapes_values.reserve(output_shapes.size());
25881 std::vector<int> output_shapes_ndims;
25882 output_shapes_ndims.reserve(output_shapes.size());
25883 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_values),
25884 [](
const auto& v) { return v.data(); });
25885 std::transform(output_shapes.begin(), output_shapes.end(), std::back_inserter(output_shapes_ndims),
25886 [](
const auto& v) { return static_cast<int>(v.size()); });
25887 TFE_OpSetAttrShapeList(op.get(),
"output_shapes", output_shapes_values.data(), output_shapes_ndims.data(),
25888 static_cast<int>(output_shapes.size()), context::get_status());
25889 status_check(context::get_status());
25891 TFE_OpSetAttrInt(op.get(),
"N", input_datasets.size());
25894 int num_outputs_op = 1;
25895 TFE_TensorHandle* res[1] = {
nullptr};
25896 TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
25897 status_check(context::get_status());
25898 return tensor(res[0]);