99from transformers import AutoModelForCausalLM , AutoTokenizer
1010import numpy as np
1111
12+
1213@pytest .fixture
1314def temp_dir ():
1415 """Create a temporary directory for test files."""
@@ -274,7 +275,7 @@ def test_activation_cache_with_normalizer(temp_dir):
274275def test_sequence_ranges_no_bos_token (temp_dir ):
275276 """Test that sequence ranges are stored when model has no BOS token."""
276277 # Set flag to handle meta tensors properly
277- if hasattr (th .fx , ' experimental' ):
278+ if hasattr (th .fx , " experimental" ):
278279 th .fx .experimental ._config .meta_nonzero_assume_all_nonzero = True
279280
280281 # Skip test if CUDA not available
@@ -296,12 +297,18 @@ def test_sequence_ranges_no_bos_token(temp_dir):
296297 )
297298 model = LanguageModel (model , torch_dtype = th .float32 , tokenizer = tokenizer )
298299 model .tokenizer .pad_token = model .tokenizer .eos_token
299-
300+
300301 # Simulate model without BOS token
301302 original_bos_token_id = model .tokenizer .bos_token_id
302303 model .tokenizer .bos_token_id = None
303304
304- tokens = model .tokenizer (test_strings , add_special_tokens = True , return_tensors = "pt" , padding = True , truncation = True )
305+ tokens = model .tokenizer (
306+ test_strings ,
307+ add_special_tokens = True ,
308+ return_tensors = "pt" ,
309+ padding = True ,
310+ truncation = True ,
311+ )
305312 lengths = tokens ["attention_mask" ].sum (dim = 1 ).tolist ()
306313 ranges = np .cumsum ([0 ] + lengths )
307314 try :
@@ -335,28 +342,40 @@ def test_sequence_ranges_no_bos_token(temp_dir):
335342
336343 # Verify sequence ranges were stored
337344 sequence_ranges = cache .sequence_ranges
338- assert sequence_ranges is not None , "sequence ranges should be stored for model without BOS token"
339-
345+ assert (
346+ sequence_ranges is not None
347+ ), "sequence ranges should be stored for model without BOS token"
348+
340349 # Should have one sequence start per input string plus one for the last sequence
341- assert len (sequence_ranges ) == len (test_strings ) + 1 , f"Expected { len (test_strings )} sequence ranges, got { len (sequence_ranges )} "
342-
350+ assert (
351+ len (sequence_ranges ) == len (test_strings ) + 1
352+ ), f"Expected { len (test_strings )} sequence ranges, got { len (sequence_ranges )} "
353+
343354 # First sequence should start at position 0
344- assert sequence_ranges [0 ].item () == 0 , "First sequence should start at position 0"
355+ assert (
356+ sequence_ranges [0 ].item () == 0
357+ ), "First sequence should start at position 0"
345358
346359 # sequence ranges should be the same as the ranges computed from the tokens
347- assert np .allclose (sequence_ranges , ranges ), "sequence ranges should be the same as the ranges computed from the tokens"
348-
360+ assert np .allclose (
361+ sequence_ranges , ranges
362+ ), "sequence ranges should be the same as the ranges computed from the tokens"
363+
349364 # sequence ranges should be in ascending order
350365 for i in range (1 , len (sequence_ranges )):
351- assert sequence_ranges [i ] > sequence_ranges [i - 1 ], f"sequence ranges should be ascending: { sequence_ranges } "
366+ assert (
367+ sequence_ranges [i ] > sequence_ranges [i - 1 ]
368+ ), f"sequence ranges should be ascending: { sequence_ranges } "
352369
353370 # Verify sequence ranges align with token boundaries
354371 tokens = cache .tokens
355372 total_tokens = len (tokens )
356-
373+
357374 # All sequence ranges should be valid indices
358375 for start_idx in sequence_ranges :
359- assert 0 <= start_idx <= total_tokens , f"Invalid sequence start index: { start_idx } "
376+ assert (
377+ 0 <= start_idx <= total_tokens
378+ ), f"Invalid sequence start index: { start_idx } "
360379
361380 finally :
362381 # Restore original BOS token
@@ -366,7 +385,7 @@ def test_sequence_ranges_no_bos_token(temp_dir):
366385def test_sequence_ranges_with_bos_token (temp_dir ):
367386 """Test that sequence ranges are NOT stored when model has BOS token."""
368387 # Set flag to handle meta tensors properly
369- if hasattr (th .fx , ' experimental' ):
388+ if hasattr (th .fx , " experimental" ):
370389 th .fx .experimental ._config .meta_nonzero_assume_all_nonzero = True
371390
372391 # Skip test if CUDA not available
@@ -382,7 +401,7 @@ def test_sequence_ranges_with_bos_token(temp_dir):
382401 )
383402 model = LanguageModel (model , torch_dtype = th .float32 , tokenizer = tokenizer )
384403 model .tokenizer .pad_token = model .tokenizer .eos_token
385-
404+
386405 # Ensure model has BOS token (set it explicitly)
387406 model .tokenizer .bos_token_id = model .tokenizer .eos_token_id
388407
@@ -411,7 +430,9 @@ def test_sequence_ranges_with_bos_token(temp_dir):
411430
412431 # Verify sequence ranges were NOT stored
413432 sequence_ranges = cache .sequence_ranges
414- assert sequence_ranges is None , "sequence ranges should not be stored for model with BOS token"
433+ assert (
434+ sequence_ranges is None
435+ ), "sequence ranges should not be stored for model with BOS token"
415436
416437
417438def test_activation_cache_slice_indexing_cross_shard (temp_dir ):
@@ -469,39 +490,45 @@ def test_activation_cache_slice_indexing_cross_shard(temp_dir):
469490
470491 # Load the cached activations
471492 cache = ActivationCache (temp_dir , submodule_name + "_out" )
472-
493+
473494 # Verify we have multiple shards
474- assert len (cache .shards ) >= 2 , f"Expected at least 2 shards, got { len (cache .shards )} "
475-
495+ assert (
496+ len (cache .shards ) >= 2
497+ ), f"Expected at least 2 shards, got { len (cache .shards )} "
498+
476499 total_size = len (cache )
477500 print (f"Cache has { len (cache .shards )} shards with total size { total_size } " )
478-
501+
479502 # Print shard boundaries for debugging
480503 shard_boundaries = cache ._range_to_shard_idx
481504 print (f"Shard boundaries: { shard_boundaries } " )
482-
505+
483506 # Test 1: Slice that crosses exactly one shard boundary
484507 if len (cache .shards ) >= 2 :
485508 # Find a slice that starts in first shard and ends in second shard
486509 first_shard_end = shard_boundaries [1 ]
487510 start_idx = max (0 , first_shard_end - 10 )
488511 end_idx = min (total_size , first_shard_end + 10 )
489-
512+
490513 # Get slice result
491514 slice_result = cache [start_idx :end_idx ]
492-
515+
493516 # Get individual results for comparison
494- individual_results = th .stack ([cache [i ] for i in range (start_idx , end_idx )], dim = 0 )
495-
517+ individual_results = th .stack (
518+ [cache [i ] for i in range (start_idx , end_idx )], dim = 0
519+ )
520+
496521 # Verify they match
497- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
498- f"Slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
499-
522+ assert th .allclose (
523+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
524+ ), f"Slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
525+
500526 # Verify correct shape
501527 expected_length = end_idx - start_idx
502- assert slice_result .shape [0 ] == expected_length , \
503- f"Expected slice length { expected_length } , got { slice_result .shape [0 ]} "
504-
528+ assert (
529+ slice_result .shape [0 ] == expected_length
530+ ), f"Expected slice length { expected_length } , got { slice_result .shape [0 ]} "
531+
505532 print (f"✓ Cross-shard slice test 1 passed: indices { start_idx } :{ end_idx } " )
506533
507534 # Test 2: Slice that spans multiple shards
@@ -510,54 +537,70 @@ def test_activation_cache_slice_indexing_cross_shard(temp_dir):
510537 second_shard_end = shard_boundaries [2 ]
511538 start_idx = max (0 , shard_boundaries [1 ] - 5 ) # Start near end of first shard
512539 end_idx = min (total_size , second_shard_end + 5 ) # End in third shard
513-
540+
514541 slice_result = cache [start_idx :end_idx ]
515- individual_results = th .stack ([cache [i ] for i in range (start_idx , end_idx )], dim = 0 )
516-
517- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
518- f"Multi-shard slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
519-
542+ individual_results = th .stack (
543+ [cache [i ] for i in range (start_idx , end_idx )], dim = 0
544+ )
545+
546+ assert th .allclose (
547+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
548+ ), f"Multi-shard slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
549+
520550 expected_length = end_idx - start_idx
521- assert slice_result .shape [0 ] == expected_length , \
522- f"Expected multi-shard slice length { expected_length } , got { slice_result .shape [0 ]} "
523-
551+ assert (
552+ slice_result .shape [0 ] == expected_length
553+ ), f"Expected multi-shard slice length { expected_length } , got { slice_result .shape [0 ]} "
554+
524555 print (f"✓ Multi-shard slice test passed: indices { start_idx } :{ end_idx } " )
525556
526557 # Test 3: Slice with step parameter across shards
527558 if total_size >= 50 :
528559 start_idx = 5
529560 end_idx = min (total_size , 45 )
530561 step = 3
531-
562+
532563 slice_result = cache [start_idx :end_idx :step ]
533- individual_results = th .stack ([cache [i ] for i in range (start_idx , end_idx , step )], dim = 0 )
534-
535- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
536- f"Stepped slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } :{ step } "
537-
564+ individual_results = th .stack (
565+ [cache [i ] for i in range (start_idx , end_idx , step )], dim = 0
566+ )
567+
568+ assert th .allclose (
569+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
570+ ), f"Stepped slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } :{ step } "
571+
538572 expected_length = len (range (start_idx , end_idx , step ))
539- assert slice_result .shape [0 ] == expected_length , \
540- f"Expected stepped slice length { expected_length } , got { slice_result .shape [0 ]} "
541-
573+ assert (
574+ slice_result .shape [0 ] == expected_length
575+ ), f"Expected stepped slice length { expected_length } , got { slice_result .shape [0 ]} "
576+
542577 print (f"✓ Stepped slice test passed: indices { start_idx } :{ end_idx } :{ step } " )
543578
544579 # Test 4: Edge cases - slice at boundaries
545580 if len (cache .shards ) >= 2 :
546581 # Test slice starting exactly at shard boundary
547582 boundary_idx = shard_boundaries [1 ]
548583 if boundary_idx < total_size - 5 :
549- slice_result = cache [boundary_idx :boundary_idx + 5 ]
550- individual_results = th .stack ([cache [i ] for i in range (boundary_idx , boundary_idx + 5 )], dim = 0 )
551-
552- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
553- f"Boundary slice result doesn't match individual indexing"
554-
555- print (f"✓ Boundary slice test passed: starting at shard boundary { boundary_idx } " )
584+ slice_result = cache [boundary_idx : boundary_idx + 5 ]
585+ individual_results = th .stack (
586+ [cache [i ] for i in range (boundary_idx , boundary_idx + 5 )], dim = 0
587+ )
588+
589+ assert th .allclose (
590+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
591+ ), f"Boundary slice result doesn't match individual indexing"
592+
593+ print (
594+ f"✓ Boundary slice test passed: starting at shard boundary { boundary_idx } "
595+ )
556596
557597 # Test 5: Empty slice
558598 empty_slice = cache [10 :10 ]
559- assert empty_slice .shape [0 ] == 0 , f"Expected empty slice, got shape { empty_slice .shape } "
599+ assert (
600+ empty_slice .shape [0 ] == 0
601+ ), f"Expected empty slice, got shape { empty_slice .shape } "
560602 print ("✓ Empty slice test passed" )
561-
562603
563- print (f"✓ All slice indexing tests passed for cache with { len (cache .shards )} shards" )
604+ print (
605+ f"✓ All slice indexing tests passed for cache with { len (cache .shards )} shards"
606+ )
0 commit comments