@@ -412,3 +412,152 @@ def test_sequence_ranges_with_bos_token(temp_dir):
412412 # Verify sequence ranges were NOT stored
413413 sequence_ranges = cache .sequence_ranges
414414 assert sequence_ranges is None , "sequence ranges should not be stored for model with BOS token"
415+
416+
def test_activation_cache_slice_indexing_cross_shard(temp_dir):
    """Check that slicing an ActivationCache agrees with per-index access.

    Activations of one GPT-2 block are collected into ``temp_dir`` with a
    small ``shard_size`` and ``shuffle_shards=False``, the cache is reloaded,
    and several slices positioned around shard boundaries are compared with
    the result of stacking ``cache[i]`` for every index in the slice.

    The test is skipped when CUDA is unavailable.
    """
    # Set flag to handle meta tensors properly
    th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True

    # Skip test if CUDA not available to avoid device mapping issues
    if not th.cuda.is_available():
        pytest.skip("CUDA not available, skipping test to avoid device mapping issues")

    # Enough samples that the collected activations span several shards.
    dataset = [
        f"This is test sentence number {i} with some content to fill up the cache."
        for i in range(20)
    ]

    # Load GPT-2 model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", device_map="auto", torch_dtype=th.float32
    )
    model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
    model.tokenizer.pad_token = model.tokenizer.eos_token

    # Get a transformer block to extract activations from
    target_layer = model.transformer.h[6]  # Middle layer of GPT-2
    submodule_name = "transformer_h_6"

    # Parameters for activation collection
    batch_size = 3
    context_len = 32
    d_model = 768  # GPT-2 hidden size
    shard_size = 50  # Small shard size to force multiple shards

    ActivationCache.collect(
        data=dataset,
        submodules=(target_layer,),
        submodule_names=(submodule_name,),
        model=model,
        store_dir=temp_dir,
        batch_size=batch_size,
        context_len=context_len,
        shard_size=shard_size,
        d_model=d_model,
        io="out",
        max_total_tokens=5000,
        store_tokens=True,
        shuffle_shards=False,  # Important: don't shuffle so we can predict shard boundaries
    )

    # Load the cached activations
    cache = ActivationCache(temp_dir, submodule_name + "_out")

    # Every check below relies on there being at least one shard boundary.
    assert len(cache.shards) >= 2, f"Expected at least 2 shards, got {len(cache.shards)}"

    total_size = len(cache)
    print(f"Cache has {len(cache.shards)} shards with total size {total_size}")

    # NOTE(review): this test indexes `_range_to_shard_idx` as cumulative shard
    # offsets (entry k = first index after shard k-1) - confirm against
    # ActivationCache if that private attribute changes.
    shard_boundaries = cache._range_to_shard_idx
    print(f"Shard boundaries: {shard_boundaries}")

    def check_slice(start, stop, step=None, label="Slice"):
        """Assert cache[start:stop:step] equals the stacked per-index results."""
        indices = range(start, stop, 1 if step is None else step)
        sliced = cache[start:stop] if step is None else cache[start:stop:step]
        stacked = th.stack([cache[i] for i in indices], dim=0)

        # Length first: a wrong length would otherwise surface as an opaque
        # broadcasting error (or a silent broadcast) inside allclose.
        assert sliced.shape[0] == len(indices), (
            f"{label}: expected length {len(indices)}, got {sliced.shape[0]}"
        )
        assert th.allclose(sliced, stacked, atol=1e-5, rtol=1e-5), (
            f"{label} result doesn't match individual indexing "
            f"for indices {start}:{stop}:{step}"
        )

    first_boundary = shard_boundaries[1]

    # Test 1: Slice that crosses exactly one shard boundary
    start_idx = max(0, first_boundary - 10)
    end_idx = min(total_size, first_boundary + 10)
    check_slice(start_idx, end_idx, label="Cross-shard slice")
    print(f"✓ Cross-shard slice test 1 passed: indices {start_idx}:{end_idx}")

    # Test 2: Slice that spans multiple shards (first -> third)
    if len(cache.shards) >= 3:
        start_idx = max(0, first_boundary - 5)  # Start near end of first shard
        end_idx = min(total_size, shard_boundaries[2] + 5)  # End in third shard
        check_slice(start_idx, end_idx, label="Multi-shard slice")
        print(f"✓ Multi-shard slice test passed: indices {start_idx}:{end_idx}")

    # Test 3a: Stepped slice over a fixed window (may lie inside one shard)
    step = 3
    if total_size >= 50:
        start_idx = 5
        end_idx = min(total_size, 45)
        check_slice(start_idx, end_idx, step, label="Stepped slice")
        print(f"✓ Stepped slice test passed: indices {start_idx}:{end_idx}:{step}")

    # Test 3b: Stepped slice anchored on the first boundary, so it really
    # crosses shards. Starting a whole number of steps before the boundary
    # means the boundary index itself is visited whenever it is < end_idx.
    steps_before = min(6, first_boundary // step)
    start_idx = first_boundary - step * steps_before
    end_idx = min(total_size, first_boundary + 6 * step)
    check_slice(start_idx, end_idx, step, label="Cross-shard stepped slice")
    print(
        f"✓ Cross-shard stepped slice test passed: indices {start_idx}:{end_idx}:{step}"
    )

    # Test 4: Edge case - slice starting exactly at a shard boundary
    if first_boundary < total_size - 5:
        check_slice(first_boundary, first_boundary + 5, label="Boundary slice")
        print(f"✓ Boundary slice test passed: starting at shard boundary {first_boundary}")

    # Test 5: Empty slice
    empty_slice = cache[10:10]
    assert empty_slice.shape[0] == 0, f"Expected empty slice, got shape {empty_slice.shape}"
    print("✓ Empty slice test passed")

    print(f"✓ All slice indexing tests passed for cache with {len(cache.shards)} shards")
0 commit comments