50 | 50 | lazy rope pattern needed for MLA models. |
51 | 51 | """ |
52 | 52 |
53 | | -import contextlib |
54 | 53 | import logging |
55 | 54 |
56 | 55 | import torch |
57 | 56 | import torch.nn.functional as F |
58 | | -from torch.nn import CrossEntropyLoss |
59 | 57 | from transformers import PreTrainedModel |
60 | 58 | from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config |
61 | 59 | from transformers.trainer_pt_utils import LabelSmoother |
62 | 60 | from transformers.utils import ModelOutput |
63 | 61 |
64 | 62 | from ..dflash.conversion import DFlashDMRegistry |
65 | 63 | from ..dflash.dflash_model import DFlashModel |
66 | | -from .modeling_dflash import ( # noqa: F401 |
67 | | - DFlashAttention, |
68 | | - DFlashBaseModelOutput, |
69 | | - DFlashModule, |
70 | | - build_target_layer_ids, |
71 | | -) |
| 64 | +from .modeling_dflash import DFlashAttention, DFlashModule, build_target_layer_ids # noqa: F401 |
72 | 65 | from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS |
73 | 66 |
74 | 67 | logger = logging.getLogger(__name__) |
@@ -121,25 +114,6 @@ def _find_base_model_parts(self): |
121 | 114 | else: |
122 | 115 | raise ValueError(f"Part {name} not found in model") |
123 | 116 |
124 | | - def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs): |
125 | | - """Run the base model forward pass with optional freeze and base-model loss.""" |
126 | | - ctx = torch.no_grad() if freeze else contextlib.nullcontext() |
127 | | - with ctx: |
128 | | - outputs = super().forward( |
129 | | - input_ids=input_ids, |
130 | | - attention_mask=attention_mask, |
131 | | - output_hidden_states=True, |
132 | | - **kwargs, |
133 | | - ) |
134 | | - base_loss = None |
135 | | - if not freeze and labels is not None: |
136 | | - loss_fct = CrossEntropyLoss() |
137 | | - base_loss = loss_fct( |
138 | | - outputs.logits.view(-1, outputs.logits.shape[-1]), |
139 | | - labels.view(-1), |
140 | | - ) |
141 | | - return outputs, base_loss |
142 | | - |
143 | 117 | def modify(self, config): |
144 | 118 | """Initialize DFlash draft module.""" |
145 | 119 | super().modify(config) |
@@ -406,16 +380,6 @@ def _compute_loss( |
406 | 380 |
407 | 381 | return loss, accuracy |
408 | 382 |
409 | | - def _dflash_base_model_forward( |
410 | | - self, input_ids, attention_mask, freeze=True |
411 | | - ) -> DFlashBaseModelOutput: |
412 | | - """Run base model and extract target hidden states for DFlash.""" |
413 | | - outputs, _ = self._base_model_forward(input_ids, attention_mask, freeze=freeze) |
414 | | - # hidden_states[0] is the embedding output; layer i output is at index i+1 |
415 | | - selected = [outputs.hidden_states[lid + 1] for lid in self.target_layer_ids] |
416 | | - target_hidden = torch.cat(selected, dim=-1) |
417 | | - return DFlashBaseModelOutput(target_hidden=target_hidden, logits=outputs.logits) |
418 | | - |
419 | 383 | def forward( |
420 | 384 | self, |
421 | 385 | input_ids=None, |
@@ -464,10 +428,18 @@ def forward( |
464 | 428 | f"Adjust training_seq_len or use padding." |
465 | 429 | ) |
466 | 430 |
467 | | - # 1. Run base model → extract target hidden states |
468 | | - base_outputs = self._dflash_base_model_forward( |
469 | | - input_ids, attention_mask, freeze=self.dflash_freeze_base_model |
470 | | - ) |
| 431 | + # 1. Run base model → hidden states |
| 432 | + # TODO: For co-training the base model, remove no_grad and eval() switch. |
| 433 | + with torch.no_grad(): |
| 434 | + base_outputs = super().forward( |
| 435 | + input_ids=input_ids, |
| 436 | + attention_mask=attention_mask, |
| 437 | + output_hidden_states=True, |
| 438 | + ) |
| 439 | + |
| 440 | + offset = 1 |
| 441 | + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] |
| 442 | + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] |
471 | 443 |
472 | 444 | # 2. Build loss mask. |
473 | 445 | # When labels are provided (answer_only_loss), they already encode both |
@@ -497,18 +469,13 @@ def forward( |
497 | 469 | ) |
498 | 470 | full_pos = self._build_position_ids(seq_len, anchor_positions, device) |
499 | 471 | attn_mask = self._build_draft_attention_mask( |
500 | | - seq_len, |
501 | | - anchor_positions, |
502 | | - block_keep_mask, |
503 | | - n_blocks, |
504 | | - base_outputs.target_hidden.dtype, |
505 | | - device, |
| 472 | + seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device |
506 | 473 | ) |
507 | 474 |
508 | 475 | # 5. Draft forward |
509 | 476 | hidden = self.dflash_module( |
510 | 477 | noise_embedding=noise_embedding, |
511 | | - target_hidden=base_outputs.target_hidden, |
| 478 | + target_hidden=target_hidden, |
512 | 479 | position_ids=full_pos, |
513 | 480 | attention_mask=attn_mask, |
514 | 481 | ) |
@@ -582,14 +549,29 @@ def pseudo_speculative_generate(self, input_ids, steps=1): |
582 | 549 | base_token: Next token from base model [B, 1]. |
583 | 550 | draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None if steps < 1. |
584 | 551 | """ |
585 | | - base_outputs = self._dflash_base_model_forward(input_ids, attention_mask=None, freeze=True) |
586 | | - assert base_outputs.logits is not None |
587 | | - base_token = base_outputs.logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) |
| 552 | + # Call the base model's inner model directly (avoids DynamicModule dispatch) |
| 553 | + model_output = self._base_model( |
| 554 | + input_ids=input_ids, |
| 555 | + output_hidden_states=True, |
| 556 | + ) |
| 557 | + # Compute logits via lm_head |
| 558 | + base_logits = self._base_model_lm_head(model_output.last_hidden_state) |
| 559 | + # Build output with hidden_states |
| 560 | + base_outputs = ModelOutput( |
| 561 | + logits=base_logits, |
| 562 | + hidden_states=model_output.hidden_states, |
| 563 | + ) |
| 564 | + base_logits = base_outputs.logits |
| 565 | + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) |
588 | 566 |
589 | 567 | if steps < 1: |
590 | 568 | return base_token, None |
591 | 569 |
592 | | - target_hidden = base_outputs.target_hidden |
| 570 | + # Extract target hidden states (raw, before FC projection) |
| 571 | + hid_offset = 1 |
| 572 | + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] |
| 573 | + target_hidden = torch.cat(selected, dim=-1) |
| 574 | + |
593 | 575 | block_size = self.dflash_block_size |
594 | 576 | bsz = input_ids.shape[0] |
595 | 577 | device = input_ids.device |
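
Illustrative sketch (not part of the commit): how the inlined flow in this diff extracts per-layer hidden states for the DFlash draft module. It assumes a Hugging Face-style causal LM whose forward returns .logits and, with output_hidden_states=True, a .hidden_states tuple where index 0 is the embedding output and index i + 1 is the output of decoder layer i. The helper name extract_target_hidden is hypothetical; the body mirrors the added lines.

import torch

def extract_target_hidden(base_model, input_ids, attention_mask, target_layer_ids):
    # Base model stays frozen; the draft module trains on top of its activations.
    with torch.no_grad():
        outputs = base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
    offset = 1  # hidden_states[0] is the embedding output; layer i is at index i + 1
    selected = [outputs.hidden_states[lid + offset] for lid in target_layer_ids]
    # Concatenate along the feature dim -> [batch, seq_len, num_layers * hidden_size]
    target_hidden = torch.cat(selected, dim=-1)
    return target_hidden, outputs.logits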