Skip to content

Commit db84fb3

Browse files
paynecltensorflower-gardener
authored andcommitted
Update graph analysis (via GetMlirBridgeRolloutPolicy) to take in a flag indicating which version of the Phase 1 pass was called and add missing logging for Session API/V1 Compat Pass.
PiperOrigin-RevId: 467271406
1 parent e9ce467 commit db84fb3

4 files changed

Lines changed: 64 additions & 13 deletions

File tree

tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
2323
const tensorflow::Graph& graph,
2424
const FunctionLibraryDefinition* function_library,
2525
std::optional<ConfigProto> config_proto,
26-
bool uses_uninitialized_resource_args, bool record_stats) {
26+
bool uses_uninitialized_resource_args, bool is_v1_compat,
27+
bool record_stats) {
2728
switch (GetMlirBridgeRolloutState(config_proto)) {
2829
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
2930
return MlirBridgeRolloutPolicy::kEnabledByUser;
@@ -36,4 +37,10 @@ MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
3637
}
3738
}
3839

40+
void LogGraphFeatures(const Graph& graph,
41+
const FunctionLibraryDefinition* function_library,
42+
std::optional<ConfigProto> config_proto,
43+
bool uses_uninitialized_resource_args,
44+
bool is_v1_compat) {}
45+
3946
} // namespace tensorflow

tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,25 @@ MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
5656
const tensorflow::Graph& graph,
5757
const FunctionLibraryDefinition* function_library,
5858
std::optional<tensorflow::ConfigProto> config_proto,
59-
bool uses_uninitialized_resource_args, bool record_stats = false);
59+
bool uses_uninitialized_resource_args, bool is_v1_compat,
60+
bool record_stats);
6061

6162
static inline MlirBridgeRolloutPolicy GetMlirBridge2ndPhaseRolloutPolicy(
6263
mlir::ModuleOp module) {
6364
return MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis;
6465
}
6566

67+
// Explicit Interface for when we want to log features vs test the validity of
68+
// the graph for MLIR bridge processing. Note that right now the logging
69+
// which is done in the logic used by GraphHasFeaturesUnsupportedByMlirBridge
70+
// has diverged and logs supported features as well. Parameters are the same
71+
// as for GetMlirBridgeRolloutPolicy with the exception of
72+
// record_stats, which isn't needed because this interface will always record.
73+
void LogGraphFeatures(const Graph& graph,
74+
const FunctionLibraryDefinition* function_library,
75+
std::optional<ConfigProto> config_proto,
76+
bool uses_uninitialized_resource_args, bool is_v1_compat);
77+
6678
} // namespace tensorflow
6779

6880
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_

tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <memory>
1919
#include <string>
2020

21+
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
2122
#include "absl/container/flat_hash_set.h"
2223
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/Support/FormatVariadic.h"
@@ -178,12 +179,7 @@ Status MlirFunctionOptimizationPass::Run(
178179
tensorflow::metrics::ScopedCounter<2> timings(
179180
tensorflow::metrics::GetGraphOptimizationCounter(),
180181
{kTfMlirCategory, "graph_analysis"});
181-
// Capture stats on graph properties analyzed before running the MLIR bridge.
182-
// We set `uses_uninitialized_resource_args` to false here because function
183-
// optimization is not affected by uninitialized resource args.
184-
GetMlirBridgeRolloutPolicy(**graph, flib_def, config_proto,
185-
/*uses_uninitialized_resource_args=*/false,
186-
/*record_stats=*/true);
182+
187183
timings.ReportAndStop();
188184

189185
if (overall_state == MlirOptimizationPassState::Disabled) {
@@ -192,6 +188,13 @@ Status MlirFunctionOptimizationPass::Run(
192188
<< "None of the MLIR Optimization Passes are enabled "
193189
<< "(registered " << registry_->passes().size() << ")";
194190
}
191+
// Capture stats on graph properties analyzed before running the MLIR
192+
// bridge. We set `uses_uninitialized_resource_args` to false here because
193+
// function optimization is not affected by uninitialized resource args.
194+
// TODO(b/241853328): Remove LogGraphFeatures when fixed
195+
LogGraphFeatures(**graph, flib_def, config_proto,
196+
/*uses_uninitialized_resource_args=*/false,
197+
/*is_v1_compat=*/false);
195198
return OkStatus();
196199
}
197200

@@ -244,6 +247,14 @@ Status MlirFunctionOptimizationPass::Run(
244247
std::move(module_ref_status.ValueOrDie());
245248
AddDevicesToOp(*module_ref, &device_set);
246249

250+
// Capture stats on graph properties analyzed before running the MLIR
251+
// bridge. We set `uses_uninitialized_resource_args` to false here because
252+
// function optimization is not affected by uninitialized resource args.
253+
// TODO (b/241853328) Remove LogGraphFeatures when fixed
254+
LogGraphFeatures(**graph, flib_def, config_proto,
255+
/*uses_uninitialized_resource_args=*/false,
256+
/*is_v1_compat=*/false);
257+
247258
int per_pass_state_index = 0;
248259
for (auto& pass_registration : registry_->passes()) {
249260
llvm::StringRef name = pass_registration.pass->name();
@@ -381,11 +392,18 @@ Status MlirV1CompatGraphOptimizationPass::Run(
381392

382393
llvm::StringRef name = pass->name();
383394
VLOG(2) << "Run MLIR V1 graph optimization pass: " << StringRefToView(name);
395+
// If we ever have more than one MlirV1CompatOptimization pass we need to
396+
// ensure the logging only happens once per graph to avoid redundant logging
397+
// (see how it is used in the MLIRFunctionOptimizationPass as an example)
398+
// TODO(b/241853328): Remove LogGraphFeatures when fixed
399+
LogGraphFeatures(**options.graph, options.flib_def,
400+
options.session_options->config,
401+
/*uses_uninitialized_resource_args=*/false,
402+
/*is_v1_compat=*/true);
384403

385404
if (VLOG_IS_ON(1)) {
386405
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
387406
}
388-
389407
Status pass_status = pass->Run(options, *module_ref);
390408

391409
if (!pass_status.ok()) {

tensorflow/compiler/tf2xla/mlir_bridge_pass.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include <string>
1919

20+
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
2021
#include "absl/base/call_once.h"
2122
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
2223
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
@@ -175,9 +176,15 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
175176

176177
// We set `uses_uninitialized_resource_args` to false here because the first
177178
// phase of the bridge is not affected by uninitialized resource args.
178-
MlirBridgeRolloutPolicy policy =
179-
GetMlirBridgeRolloutPolicy(graph, &function_library, config_proto,
180-
/*uses_uninitialized_resource_args=*/false);
179+
// Note we are recording the stats using LogGraphFeatures in the pass
180+
// that calls this one to avoid duplicate logging due to
181+
// GetMlirBridgeRolloutPolicy being called multiple times for the same graph.
182+
// TODO(b/241853328): Add caching of pass state and call logging/metrics
183+
// related to graph analysis from here.
184+
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
185+
graph, &function_library, config_proto,
186+
/*uses_uninitialized_resource_args=*/false,
187+
/*is_v1_compat=*/false, /*record_stats=*/false);
181188
switch (policy) {
182189
case MlirBridgeRolloutPolicy::kEnabledByUser:
183190
return MlirOptimizationPassState::Enabled;
@@ -230,6 +237,8 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
230237

231238
// Set device_set to nullptr here as the device specific checks are performed
232239
// based on the devices in the module.
240+
// TODO(b/241853328): Add caching of pass state and call logging/metrics
241+
// related to graph analysis from here.
233242
auto pass_state = GetPassState(/*device_set=*/nullptr, config_proto, graph,
234243
function_library);
235244

@@ -265,9 +274,14 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState(
265274
// only run if it's enabled by the user explicitly.
266275
// We set `uses_uninitialized_resource_args` to false here because the first
267276
// phase of the bridge is not affected by uninitialized resource args.
277+
// Note we are recording the stats using LogGraphFeatures in the pass
278+
// that calls this one.
279+
// TODO(b/241853328): Add caching of pass state and call logging/metrics
280+
// related to graph analysis from here.
268281
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
269282
graph, /*function_library=*/&function_library, config_proto,
270-
/*uses_uninitialized_resource_args=*/false);
283+
/*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/true,
284+
/*record_stats=*/false);
271285
switch (policy) {
272286
case MlirBridgeRolloutPolicy::kEnabledByUser:
273287
return MlirOptimizationPassState::Enabled;

0 commit comments

Comments
 (0)