
Commit bb133dd (parent: b0fdc8d)
feat: add ViT activation_offload for InternS1

3 files changed: 25 additions & 6 deletions


xtuner/v1/model/compose/intern_s1/intern_s1_config.py

Lines changed: 6 additions & 0 deletions
@@ -49,6 +49,7 @@ class InternS1VisionConfig(XTunerBaseModelConfig):
     use_mask_token: bool = False
     use_mean_pooling: bool = True
     attn_impl: Literal["flash_attention", "flex_attention", "eager_attention"] = "flash_attention"
+    text_hidden_layers: int = 0
 
     def model_post_init(self, _):
         if self.attn_impl == "flash_attention" and get_device() == "cuda":
@@ -143,6 +144,11 @@ class InternS1Config(InternS1BaseConfig):
         vocab_size=153216, hf_key_mapping={r"^model.": "model.language_model."}
     )
 
+    # For activation offload, the vision and text configs need to exchange num_hidden_layers with each other.
+    def model_post_init(self, __context) -> None:
+        self.vision_config.text_hidden_layers = self.text_config.num_hidden_layers
+        self.text_config.vision_hidden_layers = self.vision_config.num_hidden_layers
+
 
 class InternS1MiniConfig(InternS1BaseConfig):
     vision_config: InternS1VisionConfig = InternS1VisionConfig()
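The two assignments in model_post_init are what let each sub-model see the combined depth that activation offload schedules against. Below is a minimal sketch of the idea using simplified stand-in config classes and hypothetical layer counts (24 vision, 48 text), not xtuner's real InternS1 configs:

from dataclasses import dataclass

@dataclass
class VisionCfg:
    num_hidden_layers: int = 24      # hypothetical ViT depth
    text_hidden_layers: int = 0      # filled in by the composed config

@dataclass
class TextCfg:
    num_hidden_layers: int = 48      # hypothetical decoder depth
    vision_hidden_layers: int = 0    # filled in by the composed config

vision, text = VisionCfg(), TextCfg()

# What InternS1Config.model_post_init does: each side learns the other's depth.
vision.text_hidden_layers = text.num_hidden_layers
text.vision_hidden_layers = vision.num_hidden_layers

# Both sub-models now compute the same combined depth for the offload scheduler.
assert vision.num_hidden_layers + vision.text_hidden_layers == 72
assert text.num_hidden_layers + text.vision_hidden_layers == 72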

xtuner/v1/model/compose/intern_s1/modeling_vision.py

Lines changed: 14 additions & 2 deletions
@@ -36,6 +36,8 @@
 from xtuner.v1.ops.act_fn import get_act_fn
 from xtuner.v1.utils import get_logger
 from xtuner.v1.module import AttnOutputs
+import os
+from xtuner.v1.utils.activation_offload import async_save_on_cpu
 
 DEVICE = get_device()
 DEVICE_MODULE = get_torch_device_module()
@@ -230,6 +232,7 @@ def __init__(self, config: InternS1VisionConfig) -> None:
         dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers))
         self.layer = nn.ModuleList([
             InternS1VisionLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+        self.offload_stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -241,8 +244,17 @@
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)  # type: ignore
-
-            hidden_states = layer_module(hidden_states)
+            if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1:
+                with async_save_on_cpu(
+                    h2d_stream=self.offload_stream,
+                    d2h_stream=self.offload_stream,
+                    block_idx=int(i),
+                    depth=len(self.layer) + self.config.text_hidden_layers,
+                    custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
+                ):
+                    hidden_states = layer_module(hidden_states)
+            else:
+                hidden_states = layer_module(hidden_states)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)  # type: ignore
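The feature is opt-in: the existing code path is untouched unless the XTUNER_ACTIVATION_OFFLOAD environment variable is set to 1, in which case each ViT block runs under async_save_on_cpu so its input activation is parked in CPU memory between forward and backward. The sketch below is a simplified stand-in for that context manager, built on torch.autograd.graph.saved_tensors_hooks, to illustrate the pack/unpack mechanics; it omits the prefetching and per-block scheduling that the real utility's block_idx/depth arguments drive, and the helper name save_on_cpu_sketch is made up for illustration.

import contextlib
import torch

@contextlib.contextmanager
def save_on_cpu_sketch(offload_stream, check_fn=lambda t: True):
    """Park qualifying saved activations in pinned CPU memory during forward
    and copy them back to the GPU when backward needs them."""

    def pack(tensor):
        if not (tensor.is_cuda and check_fn(tensor)):
            return tensor  # keep small or non-matching tensors on device
        offload_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(offload_stream):
            cpu_copy = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
            cpu_copy.copy_(tensor, non_blocking=True)
        tensor.record_stream(offload_stream)  # keep the GPU buffer alive until the d2h copy finishes
        return tensor.device, cpu_copy

    def unpack(packed):
        if isinstance(packed, torch.Tensor):
            return packed
        device, cpu_copy = packed
        offload_stream.synchronize()  # make sure the d2h copy completed
        return cpu_copy.to(device, non_blocking=True)

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        yield

Matching the diff's custom_check_fn, a caller would pass check_fn=lambda t: t.data_ptr() == hidden_states.data_ptr() so that only the block's input activation is offloaded rather than every tensor autograd saves.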

xtuner/v1/model/moe/moe.py

Lines changed: 5 additions & 4 deletions
@@ -154,6 +154,7 @@ class MoEConfig(TransformerConfig):
     moe_bias: bool = False
     moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig()
     freeze_routers: bool = False
+    vision_hidden_layers: int = 0
 
     def build(self) -> "MoE":
         from xtuner.v1.model.moe.moe import MoE
@@ -430,8 +431,8 @@ def _micro_batch_forward(
             with async_save_on_cpu(
                 h2d_stream=self.offload_stream,
                 d2h_stream=self.offload_stream,
-                block_idx=layer_idx - self.config.first_k_dense_replace,
-                depth=len(self.layers) - self.config.first_k_dense_replace,
+                block_idx=layer_idx - self.config.first_k_dense_replace + self.config.vision_hidden_layers,
+                depth=len(self.layers) - self.config.first_k_dense_replace + self.config.vision_hidden_layers,
                 custom_check_fn=lambda x: x.data_ptr()
                 in [hidden_states.data_ptr() for hidden_states in hidden_states_list],
                 prefetch=True,
@@ -577,8 +578,8 @@ def _forward(
             with async_save_on_cpu(
                 h2d_stream=self.offload_stream,
                 d2h_stream=self.offload_stream,
-                block_idx=int(idx),
-                depth=len(self.layers),
+                block_idx=int(idx) + self.config.vision_hidden_layers,
+                depth=len(self.layers) + self.config.vision_hidden_layers,
                 custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
             ):
                 layer_results = decoder_layer(
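With vision_hidden_layers in hand, the language model's offload calls index into one global stack of blocks rather than restarting at zero: the ViT blocks occupy indices [0, vision_hidden_layers) and every decoder block is shifted by that offset, while depth grows to cover both. A small worked example with hypothetical layer counts (24 vision, 48 text; first_k_dense_replace ignored for simplicity):

# Hypothetical sizes for illustration; the real values come from the configs.
vision_hidden_layers = 24
text_hidden_layers = 48

depth = vision_hidden_layers + text_hidden_layers  # 72 blocks seen by the offload scheduler

def global_block_idx(local_idx: int, is_text_block: bool) -> int:
    """Vision blocks keep their local index; text blocks are shifted past them."""
    return local_idx + (vision_hidden_layers if is_text_block else 0)

assert global_block_idx(0, is_text_block=False) == 0    # first ViT block
assert global_block_idx(23, is_text_block=False) == 23  # last ViT block
assert global_block_idx(0, is_text_block=True) == 24    # first decoder block
assert global_block_idx(47, is_text_block=True) == 71   # last block == depth - 1

This matches the diff: the vision encoder passes block_idx=int(i) with depth=len(self.layer) + text_hidden_layers, and the MoE decoder passes block_idx=int(idx) + vision_hidden_layers with depth=len(self.layers) + vision_hidden_layers, so both sides agree on where each block sits in the combined schedule.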
