@@ -176,33 +176,39 @@ def forward(
176176 raise ValueError (
177177 "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
178178 )
179+ assert hidden_states .dim () == 3 and hidden_states .size (0 ) == 1 , (
180+ "THD expects embeddings shaped [1, total_tokens, hidden_size]."
181+ )
182+ hidden_states = hidden_states .squeeze (0 )
179183
180184 elif self .config .attn_input_format == "bshd" :
181185 if any (x is not None for x in [cu_seq_lens_q , cu_seq_lens_k , max_length_q , max_length_k ]):
182186 raise ValueError (
183187 "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
184188 )
185189
186- if self .config .attn_input_format == "bshd" and self .te_rope_emb is not None :
187- te_rope_emb = self .te_rope_emb .to (
188- device = hidden_states .device , dtype = hidden_states .dtype , non_blocking = True
189- )
190- seq_len = hidden_states .shape [1 ]
191- if te_rope_emb .size (0 ) < seq_len :
192- raise RuntimeError (
193- f"ROPE length { te_rope_emb .size (0 )} < input seq length { seq_len } . "
194- f"Increase max_position_embeddings."
190+ te_rope_emb = None
191+ if self .config .position_embedding_type == "rotary" :
192+ if self .config .attn_input_format == "bshd" :
193+ te_rope_emb = self .te_rope_emb .to (
194+ device = hidden_states .device , dtype = hidden_states .dtype , non_blocking = True
195+ )
196+ seq_len = hidden_states .shape [1 ]
197+ if te_rope_emb .size (0 ) < seq_len :
198+ raise RuntimeError (
199+ f"ROPE length { te_rope_emb .size (0 )} < input seq length { seq_len } . "
200+ f"Increase max_position_embeddings."
201+ )
202+ te_rope_emb = te_rope_emb [:seq_len ]
203+
204+ elif self .config .attn_input_format == "thd" :
205+ assert cu_seq_lens_q is not None
206+ te_rope_emb = self .rotary_embeddings (max_seq_len = cu_seq_lens_q [- 1 ]).to (
207+ device = hidden_states .device , dtype = hidden_states .dtype , non_blocking = True
195208 )
196- te_rope_emb = te_rope_emb [:seq_len ]
197209
198- elif self .config .attn_input_format == "thd" :
199- assert cu_seq_lens_q is not None
200- te_rope_emb = self .rotary_embeddings (max_seq_len = cu_seq_lens_q [- 1 ]).to (
201- device = hidden_states .device , dtype = hidden_states .dtype , non_blocking = True
202- )
203- hidden_states = hidden_states .squeeze (0 )
204- else :
205- te_rope_emb = None
210+ else :
211+ raise ValueError (f"Unsupported attention input format: { self .config .attn_input_format } " )
206212
207213 for layer_module in self .layers :
208214 if output_hidden_states :
0 commit comments