Skip to content
This repository was archived by the owner on Apr 19, 2026. It is now read-only.

Commit 9978d53

Browse files
author
Shrestha Malik
authored
Merge pull request #412 from tensorflow/shrestha/upgrade_r22_to_master
Shrestha/upgrade r22 to master
2 parents 27d8a19 + f4b0a03 commit 9978d53

28 files changed

Lines changed: 321 additions & 83 deletions

bazel/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ cc_library(
4848
"ngraph_bridge/ngraph_tensor_manager.h",
4949
"ngraph_bridge/ngraph_timer.h",
5050
"ngraph_bridge/ngraph_utils.h",
51+
"ngraph_bridge/ngraph_var.h",
5152
"ngraph_bridge/ngraph_version_utils.h",
5253
"ngraph_bridge/tf_deadness_analysis.h",
5354
"ngraph_bridge/tf_graphcycles.h",
@@ -92,6 +93,7 @@ cc_library(
9293
"ngraph_bridge/ngraph_tensor_manager.cc",
9394
"ngraph_bridge/ngraph_tracked_variable.cc",
9495
"ngraph_bridge/ngraph_utils.cc",
96+
"ngraph_bridge/ngraph_var.cc",
9597
"ngraph_bridge/tf_deadness_analysis.cc",
9698
"ngraph_bridge/tf_graphcycles.cc",
9799
"ngraph_bridge/ops/ngraph_ops.cc",

ngraph_bridge/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ set(SRC
5757
ngraph_rewrite_pass.cc
5858
ngraph_tensor_manager.cc
5959
ngraph_tracked_variable.cc
60+
ngraph_var.cc
6061
ngraph_utils.cc
6162
tf_graphcycles.cc
6263
tf_deadness_analysis.cc
@@ -86,7 +87,6 @@ if(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS)
8687
list(APPEND SRC enable_variable_ops/ngraph_tracked_variable.cc)
8788

8889
# new files
89-
list(APPEND SRC enable_variable_ops/ngraph_var.cc)
9090
list(APPEND SRC enable_variable_ops/ngraph_assign_op.cc)
9191
list(APPEND SRC enable_variable_ops/ngraph_enter_in_catalog.cc)
9292
list(APPEND SRC enable_variable_ops/ngraph_remove_ngraphassigns.cc)

ngraph_bridge/enable_variable_ops/ngraph_assign_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
#include "ngraph/event_tracing.hpp"
2626
#include "ngraph/runtime/backend.hpp"
2727

28-
#include "ngraph_bridge/enable_variable_ops/ngraph_var.h"
2928
#include "ngraph_bridge/ngraph_catalog.h"
3029
#include "ngraph_bridge/ngraph_freshness_tracker.h"
3130
#include "ngraph_bridge/ngraph_timer.h"
3231
#include "ngraph_bridge/ngraph_utils.h"
32+
#include "ngraph_bridge/ngraph_var.h"
3333

3434
using namespace std;
3535
namespace ng = ngraph;
@@ -83,7 +83,7 @@ class NGraphAssignOp : public OpKernel {
8383

8484
void Compute(OpKernelContext* context) override {
8585
std::ostringstream oss;
86-
oss << "Execute: Assign_" << my_instance_id << ": " << name();
86+
oss << "NGAssign::Compute::" << name();
8787
ngraph::Event event_compute(oss.str(), name(), "");
8888

8989
NGRAPH_VLOG(4) << "NGraphAssign:: Compute called for: " << def().name()

ngraph_bridge/enable_variable_ops/ngraph_enter_in_catalog.cc

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,12 @@ Status EnterInCatalog(Graph* graph, int graph_id) {
160160
}
161161
}
162162

163-
// are there indexes that need copy
164-
if (op_index_to_copy.size() > 0) {
165-
try {
166-
NGraphCatalog::AddToEncapOutputCopyIndexesMap(graph_id, node->name(),
167-
op_index_to_copy);
168-
} catch (const std::exception& exp) {
169-
return errors::Internal(
170-
"Caught exception while entering in catalog: ", exp.what(), "\n");
171-
}
163+
try {
164+
NGraphCatalog::AddToEncapOutputCopyIndexesMap(graph_id, node->name(),
165+
op_index_to_copy);
166+
} catch (const std::exception& exp) {
167+
return errors::Internal("Caught exception while entering in catalog: ",
168+
exp.what(), "\n");
172169
}
173170

174171
} // end of node is type NGraphEncapsulate

ngraph_bridge/enable_variable_ops/ngraph_rewrite_pass.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "ngraph_bridge/ngraph_cluster_manager.h"
3131
#include "ngraph_bridge/ngraph_deassign_clusters.h"
3232
#include "ngraph_bridge/ngraph_encapsulate_clusters.h"
33+
#include "ngraph_bridge/ngraph_enter_prefetch_in_catalog.h"
3334
#include "ngraph_bridge/ngraph_mark_for_clustering.h"
3435
#include "ngraph_bridge/ngraph_rewrite_for_tracking.h"
3536
#include "ngraph_bridge/ngraph_utils.h"
@@ -255,6 +256,13 @@ class NGraphEncapsulationPass : public NGraphRewritePass {
255256
"Graph with NGraphAssigns Optimized/Removed");
256257
}
257258

259+
// 8. Enter Prefetch in catalog then.
260+
TF_RETURN_IF_ERROR(EnterPrefetchInCatalog(options.graph->get(), idx));
261+
if (DumpCatalogedGraphs()) {
262+
DumpGraphs(options, idx, "prefetch-cataloged",
263+
"Graph with Prefetched Inputs Entered in Catalog");
264+
}
265+
258266
return Status::OK();
259267
}
260268

ngraph_bridge/enable_variable_ops/ngraph_tracked_variable.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
#include "ngraph/event_tracing.hpp"
2424
#include "ngraph/runtime/backend.hpp"
2525

26-
#include "ngraph_bridge/enable_variable_ops/ngraph_var.h"
2726
#include "ngraph_bridge/ngraph_backend_manager.h"
2827
#include "ngraph_bridge/ngraph_catalog.h"
2928
#include "ngraph_bridge/ngraph_freshness_tracker.h"
3029
#include "ngraph_bridge/ngraph_utils.h"
30+
#include "ngraph_bridge/ngraph_var.h"
3131

3232
using namespace std;
3333
namespace ng = ngraph;
@@ -119,7 +119,7 @@ void NGraphVariableOp::Compute(OpKernelContext* ctx) {
119119
<< " ,backend_name " << ng_backend_name_;
120120

121121
std::ostringstream oss;
122-
oss << "NGraphVariable: " << my_instance_id << ": " << name();
122+
oss << "NGVariable::Compute::" << name();
123123
ngraph::Event event_compute(oss.str(), name(), "");
124124

125125
bool log_copies = false;
@@ -250,6 +250,7 @@ void NGraphVariableOp::Compute(OpKernelContext* ctx) {
250250
ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
251251
}
252252
var->Unref();
253+
event_compute.Stop();
253254
ngraph::Event::write_trace(event_compute);
254255
}
255256

ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626

2727
#include "ngraph/runtime/backend.hpp"
2828

29-
#include "ngraph_bridge/enable_variable_ops/ngraph_var.h"
3029
#include "ngraph_bridge/ngraph_backend_manager.h"
3130
#include "ngraph_bridge/ngraph_catalog.h"
3231
#include "ngraph_bridge/ngraph_freshness_tracker.h"
3332
#include "ngraph_bridge/ngraph_timer.h"
3433
#include "ngraph_bridge/ngraph_utils.h"
34+
#include "ngraph_bridge/ngraph_var.h"
3535

3636
using namespace std;
3737
namespace ng = ngraph;

ngraph_bridge/enable_variable_ops/ngraph_variable_update_ng_tensor_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424

2525
#include "ngraph/event_tracing.hpp"
2626

27-
#include "ngraph_bridge/enable_variable_ops/ngraph_var.h"
2827
#include "ngraph_bridge/enable_variable_ops/ngraph_variable_update_ng_tensor_op.h"
2928
#include "ngraph_bridge/ngraph_timer.h"
3029
#include "ngraph_bridge/ngraph_utils.h"
30+
#include "ngraph_bridge/ngraph_var.h"
3131

3232
using namespace std;
3333
namespace ng = ngraph;
@@ -67,6 +67,7 @@ NGraphVariableUpdateNGTensorOp::~NGraphVariableUpdateNGTensorOp() {
6767
void NGraphVariableUpdateNGTensorOp::Compute(OpKernelContext* context) {
6868
std::ostringstream oss;
6969
// Start event tracing
70+
oss << "NGVariableUpdateNGTensor::Compute::" << name();
7071
ngraph::Event event_compute(oss.str(), name(), "");
7172
bool log_copies = false;
7273
OP_REQUIRES_OK(context,

ngraph_bridge/ngraph_encapsulate_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
#include "ngraph_bridge/ngraph_timer.h"
4646
#include "ngraph_bridge/ngraph_utils.h"
4747

48+
#include "ngraph_bridge/ngraph_var.h"
4849
#if defined(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS)
49-
#include "ngraph_bridge/enable_variable_ops/ngraph_var.h"
5050
#include "ngraph_bridge/ngraph_catalog.h"
5151
#endif
5252

ngraph_bridge/ngraph_encapsulate_op.cc

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@
4949
#include "ngraph_bridge/ngraph_prefetch_shared_data.h"
5050
#include "ngraph_bridge/ngraph_timer.h"
5151
#include "ngraph_bridge/ngraph_utils.h"
52+
#include "ngraph_bridge/ngraph_var.h"
5253

5354
#if defined(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS)
54-
#include "ngraph_bridge/enable_variable_ops/ngraph_var.h"
5555
#include "ngraph_bridge/ngraph_catalog.h"
5656
#endif
5757

@@ -88,13 +88,8 @@ NGraphEncapsulateOp::NGraphEncapsulateOp(OpKernelConstruction* ctx)
8888
ctx, backend != nullptr,
8989
errors::Internal("Cannot get the backend object for BE: ", be_name));
9090

91-
// If we have the VARIABLE capture on then we can't use the
92-
// parallel executor until that support is added.
93-
#if !defined(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS)
91+
// If backend executable can create tensors we use parallel executor
9492
m_use_parallel_executor = backend->executable_can_create_tensors();
95-
#else
96-
m_use_parallel_executor = false;
97-
#endif
9893

9994
// Override the switch for debugging/testing
10095
if (std::getenv("NGRAPH_TF_USE_LEGACY_EXECUTOR") != nullptr) {
@@ -402,7 +397,7 @@ NGraphEncapsulateOp::~NGraphEncapsulateOp() {
402397
// OpKernel::Compute
403398
//---------------------------------------------------------------------------
404399
void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
405-
ngraph::Event event_compute("Compute", "", "");
400+
ngraph::Event event_compute("NGEncap::Compute::" + name(), name(), "");
406401

407402
if (m_use_parallel_executor) {
408403
NGRAPH_VLOG(1) << "NGraphEncapsulateOp::Compute: Using Parallel Executor";
@@ -459,6 +454,7 @@ void NGraphEncapsulateOp::ComputeUsingParallelExecutor(OpKernelContext* ctx) {
459454
m_parallel_executor->GetTensorPipelineDepth()));
460455

461456
// Get Tensor Manager and some error checking
457+
ngraph::Event event_prepare_ng_tensors("Prepare NG In/Out Tensors", "", "");
462458
auto tensor_manager = m_parallel_executor->GetTensorManager();
463459
int num_of_inputs = tensor_manager->GetNumberOfInputs();
464460
int num_of_outputs = tensor_manager->GetNumberOfOutputs();
@@ -499,14 +495,18 @@ void NGraphEncapsulateOp::ComputeUsingParallelExecutor(OpKernelContext* ctx) {
499495
vector<shared_ptr<ng::runtime::Tensor>> ng_inputs(num_of_inputs);
500496
vector<shared_ptr<ng::runtime::Tensor>> ng_outputs(num_of_outputs);
501497

502-
// All inputs and outputs are pipelined.
503-
// Of all these pipelined inputs some are prefetched
504-
// TODO: Fit in variables
505-
ng_inputs = get<1>(pipelined_io_tensors);
506-
ng_outputs = get<2>(pipelined_io_tensors);
498+
// Prepare NG Input Output Tensors
499+
// Assemble Variable tensors and pipelined tensors to ng_input and ng_outputs
500+
OP_REQUIRES_OK(ctx, GetIOTensorsReadyForExecution(
501+
ctx, tensor_manager, get<1>(pipelined_io_tensors),
502+
get<2>(pipelined_io_tensors), ng_inputs, ng_outputs));
503+
event_prepare_ng_tensors.Stop();
504+
ngraph::Event::write_trace(event_prepare_ng_tensors);
507505

508506
// And execute
509-
ngraph::Event event_execute_graph("Execute Graph", "", "");
507+
ngraph::Event event_execute_graph(
508+
"Execute Graph Pipeline Indx" + to_string(current_iter_pipeline_depth),
509+
"", "");
510510

511511
BackendManager::LockBackend(m_parallel_executor->GetOpBackendName());
512512
NGRAPH_VLOG(4) << "NGraphEncapsulateOp::Compute call starting for cluster "
@@ -540,12 +540,14 @@ void NGraphEncapsulateOp::ComputeUsingParallelExecutor(OpKernelContext* ctx) {
540540
ngraph::Event::write_trace(event_execute_graph);
541541

542542
// Now prepare the output
543-
ngraph::Event event_copy_output_tensor("Copy Output Tensor", "", "");
543+
// Allocate TF Tensors
544+
NGRAPH_VLOG(4) << "NGraphEncapsulateOp::Compute Allocating TF Output Tensors "
545+
<< m_parallel_executor->GetNgraphClusterId();
544546

545-
std::vector<std::unique_ptr<ngraph::Event>> output_copy_events;
547+
ngraph::Event event_prepare_tf_output_tensors("Prepare TF Output Tensor", "",
548+
"");
549+
vector<Tensor*> tf_output_tensors;
546550
for (auto i = 0; i < ng_exec->get_results().size(); i++) {
547-
std::unique_ptr<ngraph::Event> event_copy_prep(
548-
new ngraph::Event("Copy Prep", "", ""));
549551
auto ng_element = ng_exec->get_results()[i];
550552
auto ng_shape = ng_element->get_shape();
551553
auto ng_element_type = ng_element->get_element_type();
@@ -558,7 +560,7 @@ void NGraphEncapsulateOp::ComputeUsingParallelExecutor(OpKernelContext* ctx) {
558560
TensorShape tf_shape(dims);
559561
Tensor* tf_output_tensor = nullptr;
560562
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, tf_shape, &tf_output_tensor));
561-
563+
tf_output_tensors.push_back(tf_output_tensor);
562564
// Make sure the nGraph-inferred element type agrees with what TensorFlow
563565
// expected.
564566
ng::element::Type expected_elem_type;
@@ -569,28 +571,45 @@ void NGraphEncapsulateOp::ComputeUsingParallelExecutor(OpKernelContext* ctx) {
569571
ctx, ng_element_type == expected_elem_type,
570572
errors::Internal("Element type inferred by nGraph does not match "
571573
"the element type expected by TensorFlow"));
572-
event_copy_prep->Stop();
573-
output_copy_events.push_back(std::move(event_copy_prep));
574+
}
574575

575-
// Now copy the nGraph Tensor to Host Tensor
576-
std::unique_ptr<ngraph::Event> event_copy_d2h(
577-
new ngraph::Event("Device to Host Copy", "", ""));
578-
void* dst_ptr = DMAHelper::base(tf_output_tensor);
576+
// Copy Tensors that are required
577+
NGRAPH_VLOG(4) << "NGraphEncapsulateOp::Compute Read NG Output Tensors "
578+
<< m_parallel_executor->GetNgraphClusterId();
579579

580-
ng_outputs[i]->read(
581-
dst_ptr, ng_outputs[i]->get_element_count() * ng_element_type.size());
580+
std::vector<std::unique_ptr<ngraph::Event>> output_copy_events;
581+
582+
auto output_indexes_to_be_copied =
583+
tensor_manager->GetOutputIndexesThatNeedCopy();
584+
for (auto output_index : output_indexes_to_be_copied) {
585+
// Copy the nGraph Tensor to Host Tensor
586+
std::unique_ptr<ngraph::Event> event_copy_d2h(new ngraph::Event(
587+
"D2H_Output_" + std::to_string(output_index), "", ""));
588+
void* dst_ptr = (void*)DMAHelper::base(tf_output_tensors[output_index]);
589+
ng_outputs[output_index]->read(
590+
dst_ptr, ng_outputs[output_index]->get_element_count() *
591+
ng_outputs[output_index]->get_element_type().size());
582592
event_copy_d2h->Stop();
583593
output_copy_events.push_back(std::move(event_copy_d2h));
584594
}
585-
586595
for (auto& next : output_copy_events) {
587596
ngraph::Event::write_trace(*next.get());
588597
}
598+
event_prepare_tf_output_tensors.Stop();
599+
ngraph::Event::write_trace(event_prepare_tf_output_tensors);
589600

590-
event_copy_output_tensor.Stop();
591-
ngraph::Event::write_trace(event_copy_output_tensor);
601+
// Synch Var Output Tensors as required
602+
NGRAPH_VLOG(4)
603+
<< "NGraphEncapsulateOp::Compute Sync NG Output Variable Tensors "
604+
<< m_parallel_executor->GetNgraphClusterId();
605+
ngraph::Event event_update_ngvar_tensors("Update NGVar Tensors", "", "");
606+
OP_REQUIRES_OK(ctx, SyncOutputVarTensors(ctx, tensor_manager));
607+
event_update_ngvar_tensors.Stop();
608+
ngraph::Event::write_trace(event_update_ngvar_tensors);
592609

593610
// Now return them to the cache
611+
NGRAPH_VLOG(4) << "NGraphEncapsulateOp::Returning Tensors "
612+
<< m_parallel_executor->GetNgraphClusterId();
594613
ngraph::Event event_return_tensor("Return Tensor", "", "");
595614
pipelined_tensor_store->return_tensors(current_iter_pipeline_depth);
596615

0 commit comments

Comments
 (0)