Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3340,6 +3340,8 @@ def split_embedding_weights(self) -> list[Tensor]:
# pyre-fixme[29]: `(self: TensorBase) -> int | Module | Tensor` is
# not a function.
if weights.dim() == 2:
# pyre-fixme[29]: `Union[(self: TensorBase, start_dim: int = ...,
# end_dim: int = ...) -> Tensor, Module, Tensor]` is not a function.
weights = weights.flatten()
splits.append(
weights.detach()[offset : offset + rows * dim].view(rows, dim)
Expand Down Expand Up @@ -4009,6 +4011,8 @@ def _update_cache_counter_and_locations(
context=self.step,
stream=torch.cuda.current_stream(),
):
# pyre-fixme[6]: For 1st argument expected `Union[Stream, Stream]`
# but got `Optional[Stream]`.
torch.cuda.current_stream().wait_stream(self.prefetch_stream)

torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
Expand Down
28 changes: 28 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,13 @@ def __init__(
# The max number of rows to be evicted is limited by the number of
# slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
# be the same shape as the L1 cache (lxu_cache_weights)
# pyre-fixme[4]: Attribute must be annotated.
self.lxu_cache_evicted_weights_list = []
# pyre-fixme[4]: Attribute must be annotated.
self.lxu_cache_evicted_indices_list = []
# pyre-fixme[4]: Attribute must be annotated.
self.lxu_cache_evicted_slots_list = []
# pyre-fixme[4]: Attribute must be annotated.
self.lxu_cache_evicted_count_list = []
for buf_idx in range(2):
evicted_weights = torch.ops.fbgemm.new_unified_tensor(
Expand Down Expand Up @@ -1933,6 +1937,8 @@ def _update_cache_counter_and_pointers(
"""
if self.prefetch_stream:
# Ensure that prefetch is done
# pyre-fixme[6]: For 1st argument expected `Union[Stream, Stream]` but
# got `Optional[Stream]`.
torch.cuda.current_stream().wait_stream(self.prefetch_stream)

assert self.current_iter_data is not None, "current_iter_data must be set"
Expand Down Expand Up @@ -3179,6 +3185,7 @@ def _split_optimizer_states_kv_zch_whole_row(
# _fetch_offloaded_optimizer_states can still handle the
# local_weight_counts > 0 case (filling with zeros).
if sorted_ids is None:
# pyre-fixme[9]: sorted_ids has type `Tensor`; used as `List[Tensor]`.
sorted_ids = [
torch.empty(0, 1, device=torch.device("cpu"), dtype=torch.int64)
for _ in dims_
Expand Down Expand Up @@ -3570,6 +3577,7 @@ def split_embedding_weights(
row_offset = table_offset
metaheader_dim = 0
if self.kv_zch_params:
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
# pyre-ignore
bucket_size = self.kv_zch_params.bucket_sizes[i]
Expand Down Expand Up @@ -4377,6 +4385,7 @@ def _report_l2_cache_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="l2_cache.hit_rate_pct",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=hit_rate,
)

Expand Down Expand Up @@ -4452,6 +4461,7 @@ def _report_eviction_stats(self) -> None:
return

# Skip metrics reporting when eviction is disabled
# pyre-fixme[16]: Optional type has no attribute `eviction_policy`.
if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
return

Expand Down Expand Up @@ -4619,6 +4629,7 @@ def _report_dram_kv_perf_stats(self) -> None:
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_missing_load",
enable_tb_metrics=True,
# pyre-fixme[6]: For 4th argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.READ_MISSING_LOAD],
)
stats_reporter.report_duration(
Expand Down Expand Up @@ -4671,6 +4682,7 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.FWD_L1_EVICTION_WRITE_MISSING_LOAD],
enable_tb_metrics=True,
)
Expand Down Expand Up @@ -4721,32 +4733,37 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.BWD_L1_CNFLCT_MISS_WRITE_MISSING_LOAD],
enable_tb_metrics=True,
)

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_kv_read_counts",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.KV_READ_COUNTS],
enable_tb_metrics=True,
)

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_allocated_bytes_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.KV_ALLOCATED_BYTES],
enable_tb_metrics=True,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.KV_ACTUAL_USED_CHUNK_BYTES],
enable_tb_metrics=True,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_mem_num_rows_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.KV_NUM_ROWS],
enable_tb_metrics=True,
)
Expand Down Expand Up @@ -4788,6 +4805,7 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.METADATA_WRITE_CACHE_MISS_AVG_COUNT],
enable_tb_metrics=True,
)
Expand Down Expand Up @@ -4831,6 +4849,7 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=stats[DramKvPerfStat.READ_METADATA_LOAD_SIZE],
enable_tb_metrics=True,
)
Expand All @@ -4847,12 +4866,14 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_hit_count_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=dram_read_hit_count,
enable_tb_metrics=True,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_miss_count_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=dram_read_miss_count,
enable_tb_metrics=True,
)
Expand All @@ -4862,13 +4883,15 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_hit_rate_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=hit_rate_pct,
enable_tb_metrics=True,
)
# Aggregate hit rate (kept for backward compat)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.hit_rate_pct",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=hit_rate_pct,
enable_tb_metrics=True,
)
Expand Down Expand Up @@ -4898,12 +4921,14 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.overall_hit_rate_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=overall_hit_rate_pct,
enable_tb_metrics=True,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd_tbe.overall_hit_rate_pct",
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=overall_hit_rate_pct,
enable_tb_metrics=True,
)
Expand All @@ -4918,12 +4943,14 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.enrichment_query_count_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=enrichment_query_count,
enable_tb_metrics=True,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.enrichment_empty_count_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=enrichment_empty_count,
enable_tb_metrics=True,
)
Expand All @@ -4936,6 +4963,7 @@ def _report_dram_kv_perf_stats(self) -> None:
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.enrichment_success_rate_stats_name,
# pyre-fixme[6]: For 3rd argument expected `int` but got `float`.
data_bytes=enrichment_success_rate,
enable_tb_metrics=True,
)
Expand Down
Loading