Skip to content

Commit 27fe6dc

Browse files
committed
fix: address PR Ascend#39 review comments (round 1 & 2)
Round 1 (Copilot Ascend#5, Ascend#9, Ascend#10, Ascend#17, Ascend#20):
- fix(serial_utils): preserve scalar shape in round-trip serialization
- fix(zmq_utils): add strict=True to zip calls
- fix(metadata): use warnings.warn for parse_dtype fallback
- refactor(metadata): remove dead _convert_legacy_sample_meta code
- fix(simple_backend): downgrade storage unit log to DEBUG

Round 2 (Copilot Ascend#24, 0oshowero0 Ascend#27, Ascend#28):
- perf(simple_backend): track active keys with _active_keys set for O(1) capacity check, replacing O(K×F) existing_keys scan in put_data
- docs(tutorial): merge demonstrate_batch_meta_construction() and demonstrate_batch_meta() into demonstrate_batch_meta_operations(), eliminating 3 duplicate demos and fixing misleading function name

Signed-off-by: 看我72遍 <m.pb@msn.com>
1 parent e25c386 commit 27fe6dc

7 files changed

Lines changed: 193 additions & 156 deletions

File tree

tests/test_metadata.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,44 @@ def test_chunk_concat_roundtrip_preserves_extra_info(self):
368368
assert len(restored) == 6
369369
assert restored.global_indexes == list(range(6))
370370

371+
def test_to_dict_from_dict_scalar_shape_roundtrip(self):
    """Scalar shape ``()`` must survive a to_dict/from_dict round trip.

    An empty tuple is falsy, so a truthiness check such as
    `if meta.get('shape')` would collapse it to None; the serializer
    has to test `is not None` instead.
    """
    scalar_schema = {
        "dtype": torch.float32,
        "shape": (),  # empty tuple: a valid shape that is nevertheless falsy
        "is_nested": False,
        "is_non_tensor": False,
    }
    original = BatchMeta(
        global_indexes=[0],
        partition_ids=["p0"],
        field_schema={"scalar_field": scalar_schema},
        production_status=np.ones(1, dtype=np.int8),
    )

    payload = original.to_dict()
    serialized_shape = payload["field_schema"]["scalar_field"]["shape"]
    # The serialized form is an empty list, never None.
    assert serialized_shape == []

    roundtripped = BatchMeta.from_dict(payload)
    # The deserialized form is an empty tuple again, never None.
    assert roundtripped.field_schema["scalar_field"]["shape"] == ()
397+
398+
def test_parse_dtype_unknown_string_logs_warning(self, caplog):
    """An unrecognised dtype string is returned unchanged and a warning is logged."""
    import logging

    from transfer_queue.metadata import _parse_dtype

    unknown_dtype = "<class 'int'>"
    with caplog.at_level(logging.WARNING, logger="transfer_queue.metadata"):
        parsed = _parse_dtype(unknown_dtype)

    # The fallback hands the raw string back as-is ...
    assert parsed == unknown_dtype
    # ... and leaves a trace in the captured log output.
    assert "Unknown dtype string" in caplog.text
408+
371409

372410
# ==============================================================================
373411
# KVBatchMeta Tests (all migrated from main with no modification)
@@ -674,3 +712,21 @@ def test_kv_batch_meta_concat_extra_info_conflict_raises(self):
674712
)
675713
with pytest.raises(ValueError, match="conflicting"):
676714
KVBatchMeta.concat([kv1, kv2])
715+
716+
717+
# ==============================================================================
718+
# StorageUnitData Tests
719+
# ==============================================================================
720+
721+
722+
class TestStorageUnitDataStrict:
    """Tests for StorageUnitData length validation."""

    def test_put_data_length_mismatch_raises(self):
        """put_data must reject mismatched lengths and store nothing.

        Checks both that a ValueError mentioning "length mismatch" is
        raised and that the rejected call left the storage unit empty,
        so a partial write that still raised would be caught.
        """
        from transfer_queue.storage.simple_backend import StorageUnitData

        sud = StorageUnitData(storage_size=10)
        # 3 indexes but only 2 values — must raise, not silently drop
        with pytest.raises(ValueError, match="length mismatch"):
            sud.put_data({"field_a": [1, 2]}, global_indexes=[0, 1, 2])

        # A rejected put must not leave values or reserved keys behind.
        assert sud.field_data == {}
        assert sud._active_keys == set()

tests/test_simple_storage_unit.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,54 @@ def test_storage_unit_data_partial_consume_safety():
467467
storage.put_data({"f": [torch.tensor([9.0])]}, global_indexes=[1])
468468
torch.testing.assert_close(storage.field_data["f"][0], torch.tensor([0.0]))
469469
torch.testing.assert_close(storage.field_data["f"][1], torch.tensor([9.0]))
470+
471+
472+
def test_storage_unit_data_active_keys_tracking():
    """_active_keys must follow the put/clear lifecycle of stored global indexes."""
    from transfer_queue.storage.simple_backend import StorageUnitData

    unit = StorageUnitData(storage_size=10)

    # A fresh unit exposes the attribute and tracks nothing yet.
    assert hasattr(unit, "_active_keys"), "StorageUnitData must have _active_keys attribute"
    assert unit._active_keys == set()

    # Writing three samples registers all three indexes.
    unit.put_data(
        {"f1": [1, 2, 3], "f2": ["a", "b", "c"]},
        global_indexes=[10, 20, 30],
    )
    assert unit._active_keys == {10, 20, 30}

    # Clearing a single key removes only that key.
    unit.clear(keys=[20])
    assert unit._active_keys == {10, 30}

    # Overwriting an already-tracked key leaves the set unchanged.
    unit.put_data({"f1": [99], "f2": ["z"]}, global_indexes=[10])
    assert unit._active_keys == {10, 30}

    # Clearing the remaining keys empties the set again.
    unit.clear(keys=[10, 30])
    assert unit._active_keys == set()
500+
501+
502+
def test_storage_unit_data_capacity_uses_active_keys():
    """The capacity limit must be enforced against the tracked _active_keys set."""
    from transfer_queue.storage.simple_backend import StorageUnitData

    unit = StorageUnitData(storage_size=3)

    # Occupy every available slot.
    unit.put_data({"f": [1, 2, 3]}, global_indexes=[0, 1, 2])
    assert len(unit._active_keys) == 3

    # A fourth distinct key no longer fits.
    with pytest.raises(ValueError, match="Storage capacity exceeded"):
        unit.put_data({"f": [4]}, global_indexes=[3])

    # Releasing one key frees one slot, which the new key can then take.
    unit.clear(keys=[2])
    assert len(unit._active_keys) == 2
    unit.put_data({"f": [4]}, global_indexes=[3])
    assert unit._active_keys == {0, 1, 3}

transfer_queue/metadata.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def _parse_dtype(dtype_str: str) -> Any:
6262
except TypeError:
6363
pass
6464
# Fallback: return as-is (e.g. plain Python type repr like "<class 'int'>")
65+
logger.warning("Unknown dtype string '%s', returning as-is", dtype_str)
6566
return dtype_str
6667

6768

@@ -744,7 +745,7 @@ def to_dict(self) -> dict:
744745
dtype = meta.get("dtype")
745746
serialized_schema[field_name] = {
746747
"dtype": str(dtype) if dtype is not None else None,
747-
"shape": list(meta["shape"]) if meta.get("shape") else None,
748+
"shape": list(meta["shape"]) if meta.get("shape") is not None else None,
748749
"is_nested": meta.get("is_nested", False),
749750
"is_non_tensor": meta.get("is_non_tensor", False),
750751
}
@@ -773,7 +774,7 @@ def from_dict(cls, data: dict) -> "BatchMeta":
773774
dtype = _parse_dtype(dtype_str) if dtype_str is not None else None
774775
field_schema[field_name] = {
775776
"dtype": dtype,
776-
"shape": tuple(meta["shape"]) if meta.get("shape") else None,
777+
"shape": tuple(meta["shape"]) if meta.get("shape") is not None else None,
777778
"is_nested": meta.get("is_nested", False),
778779
"is_non_tensor": meta.get("is_non_tensor", False),
779780
}

transfer_queue/storage/simple_backend.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(self, storage_size: int):
6161
self.field_data: dict[str, dict] = {}
6262
# Capacity upper bound (not pre-allocated list length)
6363
self.storage_size = storage_size
64+
# Track active global_index keys: O(1) membership/len for capacity checks, no field_data rescan
65+
self._active_keys: set = set()
6466

6567
def get_data(self, fields: list[str], global_indexes: list) -> dict[str, list]:
6668
"""Get data by global index keys.
@@ -92,20 +94,23 @@ def put_data(self, field_data: dict[str, Any], global_indexes: list) -> None:
9294
global_indexes: Global indexes to use as dict keys.
9395
"""
9496
# Capacity is enforced per unique sample key, not counted per-field
95-
existing_keys: set = set()
96-
for fd in self.field_data.values():
97-
existing_keys.update(fd.keys())
98-
new_global_keys = [k for k in global_indexes if k not in existing_keys]
99-
if len(existing_keys) + len(new_global_keys) > self.storage_size:
97+
new_global_keys = [k for k in global_indexes if k not in self._active_keys]
98+
if len(self._active_keys) + len(new_global_keys) > self.storage_size:
10099
raise ValueError(
101-
f"Storage capacity exceeded: {len(existing_keys)} existing + "
100+
f"Storage capacity exceeded: {len(self._active_keys)} existing + "
102101
f"{len(new_global_keys)} new > {self.storage_size}"
103102
)
104103
for f, values in field_data.items():
104+
if len(values) != len(global_indexes):
105+
raise ValueError(
106+
f"StorageUnitData put_data: field '{f}' values length {len(values)} "
107+
f"!= global_indexes length {len(global_indexes)}, length mismatch"
108+
)
105109
if f not in self.field_data:
106110
self.field_data[f] = {}
107-
for key, val in zip(global_indexes, values, strict=False):
111+
for key, val in zip(global_indexes, values, strict=True):
108112
self.field_data[f][key] = val
113+
self._active_keys.update(global_indexes)
109114

110115
def clear(self, keys: list[int]) -> None:
111116
"""Remove data at given global index keys, immediately freeing memory.
@@ -116,6 +121,7 @@ def clear(self, keys: list[int]) -> None:
116121
for f in self.field_data:
117122
for key in keys:
118123
self.field_data[f].pop(key, None)
124+
self._active_keys -= set(keys)
119125

120126

121127
@ray.remote(num_cpus=1)

transfer_queue/utils/serial_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def encode_with_fallback(obj: Any) -> list[bytestr]:
411411
try:
412412
return list(_encoder.encode(obj))
413413
except (TypeError, ValueError) as e:
414-
logger.info(
414+
logger.debug(
415415
"encode_with_fallback: msgpack failed (%s), falling back to pickle.",
416416
type(e).__name__,
417417
)

transfer_queue/utils/zmq_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,6 @@ def deserialize(cls, frames: list) -> "ZMQMessage":
182182
raise ValueError("Empty frames received")
183183

184184
result = decode_with_fallback(frames)
185-
# Pickle fallback path: serialize() pickled the ZMQMessage directly.
186-
if isinstance(result, cls):
187-
return result
188185
return cls(
189186
request_type=ZMQRequestType(result["request_type"]),
190187
sender_id=result["sender_id"],

0 commit comments

Comments
 (0)