From 5e547465f0a0986912e788a1072d237e6da43349 Mon Sep 17 00:00:00 2001 From: "Pyre Bot Jr." Date: Wed, 10 Jun 2026 13:48:07 -0700 Subject: [PATCH] Suppress type errors for Pyre upgrade Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2792 This diff was automatically generated by the Pyre per-target upgrade tool. It adds `# pyre-fixme` or `pyrefly: ignore` comments to suppress type errors that will be introduced by an upcoming Pyre or Pyrefly release. These suppressions allow the upgrade to proceed without breaking existing code. wed - upgrade new suppression fix #pyreupgrade Differential Revision: D108191557 --- ...t_table_batched_embeddings_ops_training.py | 4 +++ fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 28 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 137259cf1b..dd1903d2d7 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -3340,6 +3340,8 @@ def split_embedding_weights(self) -> list[Tensor]: # pyre-fixme[29]: `(self: TensorBase) -> int | Module | Tensor` is # not a function. if weights.dim() == 2: + # pyre-fixme[29]: `Union[(self: TensorBase, start_dim: int = ..., + # end_dim: int = ...) -> Tensor, Module, Tensor]` is not a function. weights = weights.flatten() splits.append( weights.detach()[offset : offset + rows * dim].view(rows, dim) @@ -4009,6 +4011,8 @@ def _update_cache_counter_and_locations( context=self.step, stream=torch.cuda.current_stream(), ): + # pyre-fixme[6]: For 1st argument expected `Union[Stream, Stream]` + # but got `Optional[Stream]`. torch.cuda.current_stream().wait_stream(self.prefetch_stream) torch.ops.fbgemm.lxu_cache_locking_counter_decrement( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index b6bbb88c9f..0c38e18ef3 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -645,9 +645,13 @@ def __init__( # The max number of rows to be evicted is limited by the number of # slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to # be the same shape as the L1 cache (lxu_cache_weights) + # pyre-fixme[4]: Attribute must be annotated. self.lxu_cache_evicted_weights_list = [] + # pyre-fixme[4]: Attribute must be annotated. self.lxu_cache_evicted_indices_list = [] + # pyre-fixme[4]: Attribute must be annotated. self.lxu_cache_evicted_slots_list = [] + # pyre-fixme[4]: Attribute must be annotated. self.lxu_cache_evicted_count_list = [] for buf_idx in range(2): evicted_weights = torch.ops.fbgemm.new_unified_tensor( @@ -1933,6 +1937,8 @@ def _update_cache_counter_and_pointers( """ if self.prefetch_stream: # Ensure that prefetch is done + # pyre-fixme[6]: For 1st argument expected `Union[Stream, Stream]` but + # got `Optional[Stream]`. torch.cuda.current_stream().wait_stream(self.prefetch_stream) assert self.current_iter_data is not None, "current_iter_data must be set" @@ -3179,6 +3185,7 @@ def _split_optimizer_states_kv_zch_whole_row( # _fetch_offloaded_optimizer_states can still handle the # local_weight_counts > 0 case (filling with zeros). if sorted_ids is None: + # pyre-fixme[9]: sorted_ids has type `Tensor`; used as `List[Tensor]`. sorted_ids = [ torch.empty(0, 1, device=torch.device("cpu"), dtype=torch.int64) for _ in dims_ @@ -3570,6 +3577,7 @@ def split_embedding_weights( row_offset = table_offset metaheader_dim = 0 if self.kv_zch_params: + # pyre-fixme[16]: Optional type has no attribute `__getitem__`. bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i] # pyre-ignore bucket_size = self.kv_zch_params.bucket_sizes[i] @@ -4377,6 +4385,7 @@ def _report_l2_cache_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="l2_cache.hit_rate_pct", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=hit_rate, ) @@ -4452,6 +4461,7 @@ def _report_eviction_stats(self) -> None: return # skip metrics reporting when evicting disabled + # pyre-fixme[16]: Optional type has no attribute `eviction_policy`. if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0: return @@ -4619,6 +4629,7 @@ def _report_dram_kv_perf_stats(self) -> None: iteration_step=self.step, event_name="dram_kv.perf.get.dram_read_missing_load", enable_tb_metrics=True, + # pyre-fixme[6]: For 4th argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.READ_MISSING_LOAD], ) stats_reporter.report_duration( @@ -4671,6 +4682,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.FWD_L1_EVICTION_WRITE_MISSING_LOAD], enable_tb_metrics=True, ) @@ -4721,6 +4733,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.BWD_L1_CNFLCT_MISS_WRITE_MISSING_LOAD], enable_tb_metrics=True, ) @@ -4728,6 +4741,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="dram_kv.perf.get.dram_kv_read_counts", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.KV_READ_COUNTS], enable_tb_metrics=True, ) @@ -4735,18 +4749,21 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.dram_kv_allocated_bytes_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.KV_ALLOCATED_BYTES], enable_tb_metrics=True, ) stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.dram_kv_actual_used_chunk_bytes_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.KV_ACTUAL_USED_CHUNK_BYTES], enable_tb_metrics=True, ) stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.dram_kv_mem_num_rows_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.KV_NUM_ROWS], enable_tb_metrics=True, ) @@ -4788,6 +4805,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.METADATA_WRITE_CACHE_MISS_AVG_COUNT], enable_tb_metrics=True, ) @@ -4831,6 +4849,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="dram_kv.perf.get.dram_eviction_score_read_load_size", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=stats[DramKvPerfStat.READ_METADATA_LOAD_SIZE], enable_tb_metrics=True, ) @@ -4847,12 +4866,14 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.dram_kv_hit_count_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=dram_read_hit_count, enable_tb_metrics=True, ) stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.dram_kv_miss_count_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=dram_read_miss_count, enable_tb_metrics=True, ) @@ -4862,6 +4883,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.dram_kv_hit_rate_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=hit_rate_pct, enable_tb_metrics=True, ) @@ -4869,6 +4891,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name="dram_kv.hit_rate_pct", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=hit_rate_pct, enable_tb_metrics=True, ) @@ -4898,12 +4921,14 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.overall_hit_rate_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=overall_hit_rate_pct, enable_tb_metrics=True, ) stats_reporter.report_data_amount( iteration_step=self.step, event_name="ssd_tbe.overall_hit_rate_pct", + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=overall_hit_rate_pct, enable_tb_metrics=True, ) @@ -4918,12 +4943,14 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.enrichment_query_count_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=enrichment_query_count, enable_tb_metrics=True, ) stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.enrichment_empty_count_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=enrichment_empty_count, enable_tb_metrics=True, ) @@ -4936,6 +4963,7 @@ def _report_dram_kv_perf_stats(self) -> None: stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.enrichment_success_rate_stats_name, + # pyre-fixme[6]: For 3rd argument expected `int` but got `float`. data_bytes=enrichment_success_rate, enable_tb_metrics=True, )