@@ -114,6 +114,66 @@ def _make_model(config: DeepseekV4Config) -> DeepseekV4ForCausalLM:
114114
115115
116116class TestDeepseekV4ModelSmoke :
117+ def test_hc_comb_transpose_used_at_attn_and_mlp_sites (self ):
118+ """Both HC expand sites mix residual streams as ``comb.T @ x``."""
119+
120+ class _FixedHC (torch .nn .Module ):
121+ def __init__ (self , comb ):
122+ super ().__init__ ()
123+ self .register_buffer ("comb" , comb )
124+
125+ def compute_weights (self , hidden_streams ):
126+ bsz , seq , hc_mult = hidden_streams .shape [:3 ]
127+ pre = torch .zeros (bsz , seq , hc_mult , dtype = torch .float32 , device = hidden_streams .device )
128+ post = torch .zeros_like (pre )
129+ return pre , post , self .comb .expand (bsz , seq , - 1 , - 1 )
130+
131+ class _ZeroAttention (torch .nn .Module ):
132+ def forward (self , hidden_states , ** kwargs ):
133+ return torch .zeros_like (hidden_states ), None
134+
135+ class _ZeroMLP (torch .nn .Module ):
136+ def forward (self , hidden_states , padding_mask = None ):
137+ return torch .zeros_like (hidden_states )
138+
139+ cfg = _tiny_config (num_hidden_layers = 1 , num_hash_layers = 0 , compress_ratios = [0 ])
140+ model = _make_model (cfg )
141+ block = model .model .layers ["0" ]
142+ block .attn_hc = _FixedHC (
143+ torch .tensor (
144+ [
145+ [1.0 , 2.0 , 0.0 , 1.0 ],
146+ [0.0 , 1.0 , 3.0 , 0.0 ],
147+ [4.0 , 0.0 , 1.0 , 2.0 ],
148+ [0.0 , 5.0 , 0.0 , 1.0 ],
149+ ]
150+ )
151+ )
152+ block .ffn_hc = _FixedHC (
153+ torch .tensor (
154+ [
155+ [2.0 , 0.0 , 1.0 , 0.0 ],
156+ [1.0 , 3.0 , 0.0 , 2.0 ],
157+ [0.0 , 1.0 , 4.0 , 0.0 ],
158+ [5.0 , 0.0 , 1.0 , 1.0 ],
159+ ]
160+ )
161+ )
162+ block .self_attn = _ZeroAttention ()
163+ block .mlp = _ZeroMLP ()
164+ block .input_layernorm = torch .nn .Identity ()
165+ block .post_attention_layernorm = torch .nn .Identity ()
166+
167+ x = torch .arange (cfg .hc_mult * cfg .hidden_size , dtype = torch .float32 ).view (1 , 1 , cfg .hc_mult , cfg .hidden_size )
168+ expected = torch .matmul (block .attn_hc .comb .transpose (- 1 , - 2 ), x )
169+ expected = torch .matmul (block .ffn_hc .comb .transpose (- 1 , - 2 ), expected )
170+ wrong_orientation = torch .matmul (block .ffn_hc .comb , torch .matmul (block .attn_hc .comb , x ))
171+
172+ actual = block (x , position_embeddings = (torch .empty (0 ), torch .empty (0 )))
173+
174+ torch .testing .assert_close (actual , expected )
175+ assert not torch .allclose (expected , wrong_orientation )
176+
117177 @_REQUIRES_CUDA
118178 def test_forward_shape (self ):
119179 """Forward pass produces logits of the right shape."""
0 commit comments