Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
462 changes: 462 additions & 0 deletions tests/test_controller.py

Large diffs are not rendered by default.

185 changes: 180 additions & 5 deletions tests/test_controller_data_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_data_partition_status():
1: {"input_ids": (512,), "attention_mask": (512,)},
2: {"input_ids": (512,), "attention_mask": (512,)},
},
custom_meta=None,
)

assert success
Expand Down Expand Up @@ -172,6 +173,7 @@ def test_dynamic_expansion_scenarios():
5: {"field_1": (32,)},
10: {"field_1": (32,)},
},
custom_meta=None,
)
assert partition.total_samples_num == 3
assert partition.allocated_samples_num >= 11 # Should accommodate index 10
Expand All @@ -180,7 +182,7 @@ def test_dynamic_expansion_scenarios():
# Scenario 2: Adding many fields dynamically
for i in range(15):
partition.update_production_status(
[0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}}
[0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}}, None
)

assert partition.total_fields_num == 16 # Original + 15 new fields
Expand Down Expand Up @@ -222,7 +224,7 @@ def test_data_partition_status_advanced():
# Add data to trigger expansion
dtypes = {i: {f"dynamic_field_{s}": "torch.bool" for s in ["a", "b", "c"]} for i in range(5)}
shapes = {i: {f"dynamic_field_{s}": (32,) for s in ["a", "b", "c"]} for i in range(5)}
partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes)
partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes, None)

# Properties should reflect current state
assert partition.total_samples_num >= 5 # At least 5 samples
Expand Down Expand Up @@ -253,7 +255,7 @@ def test_data_partition_status_advanced():
11: {"field_d": (32,)},
12: {"field_d": (32,)},
}
partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes) # Triggers sample expansion
partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes, None) # Triggers sample expansion
expanded_consumption = partition.get_consumption_status(task_name)
assert expanded_consumption[0] == 1 # Preserved
assert expanded_consumption[1] == 1 # Preserved
Expand All @@ -265,13 +267,13 @@ def test_data_partition_status_advanced():
# Start with some fields
dtypes = {0: {"initial_field": "torch.bool"}}
shapes = {0: {"field_d": (32,)}}
partition.update_production_status([0], ["initial_field"], dtypes, shapes)
partition.update_production_status([0], ["initial_field"], dtypes, shapes, None)

# Add many fields to trigger column expansion
new_fields = [f"dynamic_field_{i}" for i in range(20)]
dtypes = {1: {f"dynamic_field_{i}": "torch.bool" for i in range(20)}}
shapes = {1: {f"dynamic_field_{i}": (32,) for i in range(20)}}
partition.update_production_status([1], new_fields, dtypes, shapes)
partition.update_production_status([1], new_fields, dtypes, shapes, None)

# Verify all fields are registered and accessible
assert "initial_field" in partition.field_name_mapping
Expand Down Expand Up @@ -441,3 +443,176 @@ def test_performance_characteristics():
print("✓ Memory usage patterns reasonable")

print("Performance characteristics tests passed!\n")


def test_custom_meta_in_data_partition_status():
    """Exercise custom_meta handling on DataPartitionStatus.

    The scenarios run in order and partly build on each other:

    1. custom_meta passed to ``update_production_status`` is stored in
       ``field_custom_metas`` keyed by global index, then field name.
    2. ``get_field_custom_meta`` returns the stored entries.
    3. ``get_field_custom_meta`` only returns the requested fields.
    4. Unknown global indices are omitted from the result.
    5. A later update on an existing index adds to the earlier entries.
    6. ``clear_data`` drops the custom_meta of the cleared indices only.
    7. ``custom_meta=None`` creates no entries.
    8. custom_meta that does not cover every global index makes the update
       return ``False``.
    9. Arbitrarily nested values are stored unchanged.
    10. ``to_snapshot`` carries the custom_meta over.
    """
    print("Testing custom_meta in DataPartitionStatus...")

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="custom_meta_test")

    # Test 1: Basic custom_meta storage via update_production_status
    global_indices = [0, 1, 2]
    field_names = ["input_ids", "attention_mask"]
    dtypes = {
        0: {"input_ids": "torch.int32", "attention_mask": "torch.bool"},
        1: {"input_ids": "torch.int32", "attention_mask": "torch.bool"},
        2: {"input_ids": "torch.int32", "attention_mask": "torch.bool"},
    }
    shapes = {
        0: {"input_ids": (512,), "attention_mask": (512,)},
        1: {"input_ids": (512,), "attention_mask": (512,)},
        2: {"input_ids": (512,), "attention_mask": (512,)},
    }
    custom_meta = {
        0: {"input_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}},
        1: {"input_ids": {"token_count": 200}, "attention_mask": {"mask_ratio": 0.2}},
        2: {"input_ids": {"token_count": 300}, "attention_mask": {"mask_ratio": 0.3}},
    }

    success = partition.update_production_status(
        global_indices=global_indices,
        field_names=field_names,
        dtypes=dtypes,
        shapes=shapes,
        custom_meta=custom_meta,
    )

    assert success
    assert len(partition.field_custom_metas) == 3

    # Verify custom_meta is stored correctly
    assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100
    assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2
    assert partition.field_custom_metas[2]["input_ids"]["token_count"] == 300

    print("✓ Basic custom_meta storage works")

    # Test 2: get_field_custom_meta retrieval
    retrieved_meta = partition.get_field_custom_meta([0, 1, 2], ["input_ids", "attention_mask"])

    assert 0 in retrieved_meta
    assert 1 in retrieved_meta
    assert 2 in retrieved_meta
    assert retrieved_meta[0]["input_ids"]["token_count"] == 100
    assert retrieved_meta[1]["attention_mask"]["mask_ratio"] == 0.2

    print("✓ get_field_custom_meta retrieval works")

    # Test 3: get_field_custom_meta with partial field filter
    partial_meta = partition.get_field_custom_meta([0, 1], ["input_ids"])

    assert 0 in partial_meta
    assert 1 in partial_meta
    assert "input_ids" in partial_meta[0]
    assert "attention_mask" not in partial_meta[0]  # Should not include non-requested fields

    print("✓ get_field_custom_meta with partial fields works")

    # Test 4: get_field_custom_meta with non-existent global_index
    # NOTE(review): a reviewer asked why this situation can occur at all, given
    # that custom_meta is required to contain exactly one entry per sample (see
    # Test 8). Index 999 was never produced here, so this only pins the lookup
    # behaviour for unknown indices - confirm whether the case is still needed.
    empty_meta = partition.get_field_custom_meta([999], ["input_ids"])

    assert 999 not in empty_meta  # Should not include non-existent indices

    print("✓ get_field_custom_meta handles non-existent indices correctly")

    # Test 5: custom_meta update (merge) on same global_index
    additional_custom_meta = {
        0: {"new_field": {"new_key": "new_value"}},
    }
    success = partition.update_production_status(
        global_indices=[0],
        field_names=["new_field"],
        dtypes={0: {"new_field": "torch.float32"}},
        shapes={0: {"new_field": (64,)}},
        custom_meta=additional_custom_meta,
    )

    assert success
    # Original custom_meta should be preserved
    assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100
    # New custom_meta should be merged
    assert partition.field_custom_metas[0]["new_field"]["new_key"] == "new_value"

    print("✓ Custom_meta merge on update works")

    # Test 6: custom_meta cleared on clear_data
    partition.clear_data([0], clear_consumption=True)

    assert 0 not in partition.field_custom_metas
    assert 1 in partition.field_custom_metas  # Other samples should remain
    assert 2 in partition.field_custom_metas

    print("✓ Custom_meta cleared on clear_data works")

    # Test 7: custom_meta None does not create entries
    partition2 = DataPartitionStatus(partition_id="custom_meta_test_2")
    success = partition2.update_production_status(
        global_indices=[0, 1],
        field_names=["field1"],
        dtypes={0: {"field1": "torch.int32"}, 1: {"field1": "torch.int32"}},
        shapes={0: {"field1": (32,)}, 1: {"field1": (32,)}},
        custom_meta=None,
    )

    assert success
    assert len(partition2.field_custom_metas) == 0

    print("✓ Custom_meta None handling works")

    # Test 8: custom_meta length mismatch is rejected (update returns False)
    partition3 = DataPartitionStatus(partition_id="custom_meta_test_3")
    mismatched_custom_meta = {
        0: {"field1": {"key": "value"}},
        # Missing entries for 1 and 2
    }
    success = partition3.update_production_status(
        global_indices=[0, 1, 2],
        field_names=["field1"],
        dtypes={0: {"field1": "torch.int32"}, 1: {"field1": "torch.int32"}, 2: {"field1": "torch.int32"}},
        shapes={0: {"field1": (32,)}, 1: {"field1": (32,)}, 2: {"field1": (32,)}},
        custom_meta=mismatched_custom_meta,
    )

    # The caller observes False rather than an exception; presumably the
    # mismatch error is caught inside update_production_status - verify there.
    assert success is False

    print("✓ Custom_meta length mismatch error handling works")

    # Test 9: Complex nested custom_meta
    partition4 = DataPartitionStatus(partition_id="custom_meta_test_4")
    complex_custom_meta = {
        0: {
            "field1": {
                "nested": {"level1": {"level2": {"value": 42}}},
                "list_data": [1, 2, 3],
                "mixed": {"str": "test", "int": 100, "float": 3.14, "bool": True},
            }
        },
    }
    success = partition4.update_production_status(
        global_indices=[0],
        field_names=["field1"],
        dtypes={0: {"field1": "torch.int32"}},
        shapes={0: {"field1": (32,)}},
        custom_meta=complex_custom_meta,
    )

    assert success
    stored_meta = partition4.field_custom_metas[0]["field1"]
    assert stored_meta["nested"]["level1"]["level2"]["value"] == 42
    assert stored_meta["list_data"] == [1, 2, 3]
    assert stored_meta["mixed"]["str"] == "test"
    assert stored_meta["mixed"]["bool"] is True

    print("✓ Complex nested custom_meta storage works")

    # Test 10: custom_meta preserved in snapshot
    snapshot = partition4.to_snapshot()
    assert 0 in snapshot.field_custom_metas
    assert snapshot.field_custom_metas[0]["field1"]["nested"]["level1"]["level2"]["value"] == 42

    print("✓ Custom_meta preserved in snapshot")

    print("Custom_meta in DataPartitionStatus tests passed!\n")
Loading
Loading