7878from tensorflow_transform .tf_metadata import metadata_io
7979from tensorflow_transform .tf_metadata import schema_utils
8080from tfx_bsl .telemetry import collection as telemetry
81+ from tfx_bsl .telemetry import util as telemetry_util
8182from tfx_bsl .tfxio import tensor_representation_util
8283from tfx_bsl .tfxio import tensor_to_arrow
8384from tfx_bsl .tfxio import tf_example_record
@@ -1078,6 +1079,15 @@ def expand(self, dataset):
10781079 >> telemetry .TrackRecordBatchBytes (beam_common .METRICS_NAMESPACE ,
10791080 'analysis_input_bytes' ))
10801081
1082+ # Gather telemetry on types of input features.
1083+ _ = (
1084+ self .pipeline | 'CreateAnalyzeInputTensorRepresentations' >>
1085+ beam .Create ([input_tensor_adapter_config .tensor_representations ])
1086+ |
1087+ 'InstrumentAnalyzeInputTensors' >> telemetry .TrackTensorRepresentations (
1088+ telemetry_util .AppendToNamespace (beam_common .METRICS_NAMESPACE ,
1089+ ['analyze_input_tensors' ])))
1090+
10811091 asset_map = annotators .get_asset_annotations (graph )
10821092 # TF.HUB can error when unapproved collections are present. So we explicitly
10831093 # clear out the collections in the graph.
@@ -1351,6 +1361,20 @@ def _remove_columns_from_metadata(metadata, excluded_columns):
13511361 new_feature_spec , new_domains )
13521362
13531363
1364+ class _MaybeInferTensorRepresentationsDoFn (beam .DoFn ):
1365+ """Tries to infer TensorRepresentations from a Schema."""
1366+
1367+ def process (
1368+ self , schema : schema_pb2 .Schema
1369+ ) -> Iterable [Dict [str , schema_pb2 .TensorRepresentation ]]:
1370+ try :
1371+ yield (tensor_representation_util
1372+ .InferTensorRepresentationsFromMixedSchema (schema ))
1373+ except ValueError :
1374+ # Ignore any inference errors since the output is only used for metrics.
1375+ yield {}
1376+
1377+
13541378@beam .typehints .with_input_types (Union [_DatasetElementType , pa .RecordBatch ],
13551379 Union [dataset_metadata .DatasetMetadata ,
13561380 TensorAdapterConfig ,
@@ -1446,11 +1470,20 @@ def expand(self, dataset_and_transform_fn):
14461470 self .pipeline
14471471 | 'CreateDeferredSchema' >> beam .Create ([output_metadata .schema ]))
14481472
1473+ # Increment input metrics.
14491474 _ = (
14501475 input_values
14511476 | 'InstrumentInputBytes[Transform]' >> telemetry .TrackRecordBatchBytes (
14521477 beam_common .METRICS_NAMESPACE , 'transform_input_bytes' ))
14531478
1479+ _ = (
1480+ self .pipeline | 'CreateTransformInputTensorRepresentations' >>
1481+ beam .Create ([input_tensor_adapter_config .tensor_representations ])
1482+ | 'InstrumentTransformInputTensors' >>
1483+ telemetry .TrackTensorRepresentations (
1484+ telemetry_util .AppendToNamespace (beam_common .METRICS_NAMESPACE ,
1485+ ['transform_input_tensors' ])))
1486+
14541487 tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE .get (
14551488 type (self .pipeline .runner ))
14561489 output_batches = (
@@ -1471,20 +1504,38 @@ def expand(self, dataset_and_transform_fn):
14711504 converter_pcol = (
14721505 deferred_schema | 'MakeTensorToArrowConverter' >> beam .Map (
14731506 impl_helper .make_tensor_to_arrow_converter ))
1507+
1508+ output_tensor_representations = (
1509+ converter_pcol
1510+ | 'MapToTensorRepresentations' >>
1511+ beam .Map (lambda converter : converter .tensor_representations ()))
1512+
14741513 output_data = (
14751514 output_batches | 'ConvertToRecordBatch' >> beam .Map (
14761515 _convert_to_record_batch ,
14771516 schema = beam .pvalue .AsSingleton (deferred_schema ),
14781517 converter = beam .pvalue .AsSingleton (converter_pcol ),
14791518 passthrough_keys = Context .get_passthrough_keys (),
14801519 input_metadata = input_metadata ))
1520+
14811521 else :
1522+
1523+ output_tensor_representations = (
1524+ deferred_schema | 'MaybeInferTensorRepresentations' >> beam .ParDo (
1525+ _MaybeInferTensorRepresentationsDoFn ()))
14821526 output_data = (
14831527 output_batches | 'ConvertAndUnbatchToInstanceDicts' >> beam .FlatMap (
14841528 _convert_and_unbatch_to_instance_dicts ,
14851529 schema = beam .pvalue .AsSingleton (deferred_schema ),
14861530 passthrough_keys = Context .get_passthrough_keys ()))
14871531
1532+ # Increment output data metrics.
1533+ _ = (
1534+ output_tensor_representations
1535+ | 'InstrumentTransformOutputTensors' >>
1536+ telemetry .TrackTensorRepresentations (
1537+ telemetry_util .AppendToNamespace (beam_common .METRICS_NAMESPACE ,
1538+ ['transform_output_tensors' ])))
14881539 _clear_shared_state_after_barrier (self .pipeline , output_data )
14891540
14901541 return (output_data , output_metadata )
0 commit comments