@@ -400,30 +400,38 @@ def __init__(
             device=device,
             dtype=dtype,
         )
-        # The HF cache API pinned in CI expects scalar num_heads/head_dim in
-        # early_initialization(). Gemma 4-style hybrid layouts need per-layer
-        # shapes, so initialize each cache layer directly using the resolved
-        # backing-cache geometry instead of relying on the helper.
-        for layer, layer_num_heads, layer_head_dim in zip(
-            self.layers, num_heads, head_dims
-        ):
-            fake_keys_tensor = torch.zeros(
-                (max_batch_size, layer_num_heads, 0, layer_head_dim),
+        # Newer HF cache implementations already support per-layer layouts in
+        # early_initialization(). Keep that path for Gemma 4, and only fall
+        # back to manual layer initialization for the older CI-pinned API.
+        try:
+            self.early_initialization(
+                batch_size=max_batch_size,
+                num_heads=num_heads,
+                head_dim=head_dims,
                 dtype=dtype,
                 device=device,
             )
-            lazy_init_sig = inspect.signature(layer.lazy_initialization)
-            # Older pinned HF caches take a single fake tensor, while newer
-            # versions expect both key_states and value_states separately.
-            if len(lazy_init_sig.parameters) == 1:
-                layer.lazy_initialization(fake_keys_tensor)
-            else:
-                fake_values_tensor = torch.zeros(
+        except TypeError:
+            for layer, layer_num_heads, layer_head_dim in zip(
+                self.layers, num_heads, head_dims
+            ):
+                fake_keys_tensor = torch.zeros(
                     (max_batch_size, layer_num_heads, 0, layer_head_dim),
                     dtype=dtype,
                     device=device,
                 )
-                layer.lazy_initialization(fake_keys_tensor, fake_values_tensor)
+                lazy_init_sig = inspect.signature(layer.lazy_initialization)
+                # Older pinned HF caches take a single fake tensor, while newer
+                # versions expect both key_states and value_states separately.
+                if len(lazy_init_sig.parameters) == 1:
+                    layer.lazy_initialization(fake_keys_tensor)
+                else:
+                    fake_values_tensor = torch.zeros(
+                        (max_batch_size, layer_num_heads, 0, layer_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    layer.lazy_initialization(fake_keys_tensor, fake_values_tensor)

         # Some models (for example Gemma 4) only allocate cache entries for the
         # non-shared KV layers. Mirror the parent StaticCache layout exactly so
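For context on the fallback branch above, here is a minimal standalone sketch of the per-layer initialization pattern: build a zero-length fake tensor from each layer's resolved geometry and inspect the arity of lazy_initialization() to decide whether to pass one tensor or separate key/value states. The _OldStyleLayer, _NewStyleLayer, and init_cache_layers names below are hypothetical stand-ins for illustration, not the HF cache API.

import inspect
import torch


class _OldStyleLayer:
    """Hypothetical stand-in for an older cache layer that takes one fake tensor."""

    def lazy_initialization(self, fake_keys):
        self.shape = tuple(fake_keys.shape)


class _NewStyleLayer:
    """Hypothetical stand-in for a newer cache layer that takes keys and values."""

    def lazy_initialization(self, key_states, value_states):
        self.shape = tuple(key_states.shape)


def init_cache_layers(layers, max_batch_size, num_heads, head_dims, dtype, device):
    # Per-layer fallback: build a zero-length fake tensor per layer and
    # dispatch on the arity of the bound lazy_initialization() method.
    for layer, layer_num_heads, layer_head_dim in zip(layers, num_heads, head_dims):
        fake = torch.zeros(
            (max_batch_size, layer_num_heads, 0, layer_head_dim),
            dtype=dtype,
            device=device,
        )
        if len(inspect.signature(layer.lazy_initialization).parameters) == 1:
            layer.lazy_initialization(fake)
        else:
            layer.lazy_initialization(fake, fake.clone())


layers = [_OldStyleLayer(), _NewStyleLayer()]
init_cache_layers(
    layers,
    max_batch_size=2,
    num_heads=[8, 4],
    head_dims=[64, 128],
    dtype=torch.float32,
    device="cpu",
)
print([layer.shape for layer in layers])  # [(2, 8, 0, 64), (2, 4, 0, 128)]

Running the sketch prints the per-layer shapes above, showing that both layer styles receive zero-length fake tensors matching their own head count and head dimension.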