Skip to content

Commit 019057a

Browse files
NengXu001 and NENGXU003
authored and committed
feat: add ViT activation_offload for InternS1
1 parent 87e50ab commit 019057a

3 files changed

Lines changed: 43 additions & 23 deletions

File tree

xtuner/v1/model/compose/intern_s1/modeling_vision.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from xtuner.v1.ops.act_fn import get_act_fn
3737
from xtuner.v1.utils import get_logger
3838
from xtuner.v1.module import AttnOutputs
39+
import os
40+
from xtuner.v1.utils.activation_offload import async_save_on_cpu
3941

4042
DEVICE = get_device()
4143
DEVICE_MODULE = get_torch_device_module()
@@ -230,6 +232,7 @@ def __init__(self, config: InternS1VisionConfig) -> None:
230232
dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers))
231233
self.layer = nn.ModuleList([
232234
InternS1VisionLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
235+
self.offload_stream = torch.cuda.Stream()
233236

234237
def forward(
235238
self,
@@ -241,8 +244,17 @@ def forward(
241244
for i, layer_module in enumerate(self.layer):
242245
if output_hidden_states:
243246
all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore
244-
245-
hidden_states = layer_module(hidden_states)
247+
if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1:
248+
with async_save_on_cpu(
249+
h2d_stream=self.offload_stream,
250+
d2h_stream=self.offload_stream,
251+
block_idx=int(i),
252+
group="vision",
253+
custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
254+
):
255+
hidden_states = layer_module(hidden_states)
256+
else:
257+
hidden_states = layer_module(hidden_states)
246258

247259
if output_hidden_states:
248260
all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore

xtuner/v1/model/moe/moe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def _micro_batch_forward(
432432
h2d_stream=self.offload_stream,
433433
d2h_stream=self.offload_stream,
434434
block_idx=layer_idx - self.config.first_k_dense_replace,
435-
depth=len(self.layers) - self.config.first_k_dense_replace,
435+
group="text",
436436
custom_check_fn=lambda x: x.data_ptr()
437437
in [hidden_states.data_ptr() for hidden_states in hidden_states_list],
438438
prefetch=True,
@@ -579,7 +579,7 @@ def _forward(
579579
h2d_stream=self.offload_stream,
580580
d2h_stream=self.offload_stream,
581581
block_idx=int(idx),
582-
depth=len(self.layers),
582+
group="text",
583583
custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
584584
):
585585
layer_results = decoder_layer(

xtuner/v1/utils/activation_offload.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def __init__(self):
2929
self._block_tensor_nums = {} # offload tensors per block
3030

3131
def get_cnt(self, block_idx):
32+
prev_block_idx = None if self._block_idx == -1 else self._block_idx
3233
after_block = False
34+
3335
if block_idx > self._block_idx:
3436
self._block_tensor_nums[block_idx] = 1
3537
if block_idx != 0:
@@ -41,9 +43,9 @@ def get_cnt(self, block_idx):
4143
# one step end
4244
self._block_idx = block_idx
4345
self._block_tensor_nums = {block_idx: 1}
44-
46+
4547
offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}"
46-
return offload_tensor_key, after_block
48+
return offload_tensor_key, after_block, prev_block_idx
4749

4850
def get_prefetch_keys(self, block_idx, tensor_idx):
4951
prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None)
@@ -193,11 +195,13 @@ def __init__(self, check=False):
193195
self.items = {}
194196
self.check = check
195197
self.device_item = []
196-
self.getcnt = GetCnt()
198+
self.getcnt = {}
197199
self.may_npu_tensors = {}
198200

199-
def get_cnt(self, block_idx):
200-
return self.getcnt.get_cnt(block_idx)
201+
def get_cnt(self, block_idx, group="default"):
202+
if group not in self.getcnt:
203+
self.getcnt[group] = GetCnt()
204+
return self.getcnt[group].get_cnt(block_idx)
201205

202206
def assert_exist(self, key):
203207
if key not in self.items:
@@ -249,16 +253,17 @@ def get(self, key):
249253
self.may_npu_tensors.update({key: self.items.pop(key)})
250254
return act
251255

252-
def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream):
253-
prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx)
256+
def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream, group="default"):
257+
if group not in self.getcnt:
258+
return
259+
prefetch_keys = self.getcnt[group].get_prefetch_keys(block_idx, tensor_idx)
254260
for prefetch_key in prefetch_keys:
255-
if self.exist(prefetch_key):
256-
prefetch_swap_tensor = self.get(prefetch_key)
261+
full_key = f"{group}_{prefetch_key}"
262+
if self.exist(full_key):
263+
prefetch_swap_tensor = self.get(full_key)
257264
h2d_stream.wait_stream(d2h_stream)
258265
prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True)
259266
# prefetch_swap_tensor.tensor.record_stream(h2d_stream)
260-
else:
261-
torch.distributed.breakpoint()
262267

263268
def empty(self):
264269
return len(self.items) == 0
@@ -291,7 +296,8 @@ def __init__(
291296
h2d_stream: torch.cuda.Stream,
292297
d2h_stream: torch.cuda.Stream,
293298
block_idx: int,
294-
depth: int,
299+
depth: int | None = None,
300+
group: str = "default",
295301
custom_check_fn=None,
296302
prefetch=True,
297303
) -> None:
@@ -302,19 +308,21 @@ def _pack_to_cpu(tensor):
302308
if (custom_check_fn is not None) and (not custom_check_fn(tensor)):
303309
return tensor
304310

305-
key, after_block = OffloadManager().get_cnt(block_idx)
311+
key, after_block, prev_block_idx = OffloadManager().get_cnt(block_idx, group=group)
306312

307-
if after_block:
308-
OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream)
313+
if after_block and (prev_block_idx is not None):
314+
OffloadManager().del_npu_tensor(f"{group}_{prev_block_idx}_", d2h_stream)
309315

310316
swap_tensor = SwapTensor(tensor, key)
317+
full_key = f"{group}_{key}"
311318

312-
if block_idx <= depth - 1:
319+
should_offload = depth is None or block_idx <= depth - 1
320+
if should_offload:
313321
working_stream = torch.cuda.current_stream()
314322
d2h_stream.wait_stream(working_stream)
315323
swap_tensor.launch_d2h(d2h_stream)
316324

317-
OffloadManager().put(key, swap_tensor)
325+
OffloadManager().put(full_key, swap_tensor)
318326
return swap_tensor
319327

320328
def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
@@ -328,14 +336,14 @@ def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
328336

329337
block_idx, tensor_idx = swap_tensor.key.split("_")
330338

331-
OffloadManager().del_may_npu_tensor(f"{int(block_idx) + 1}_", h2d_stream)
339+
OffloadManager().del_may_npu_tensor(f"{group}_{int(block_idx) + 1}_", h2d_stream)
332340
swap_tensor.launch_h2d(h2d_stream, True, working_stream)
333341
# if block_idx in ["0", "2", "3"]:
334342
# if block_idx in ["0"]:
335343
# torch.cuda.synchronize()
336344

337345
if prefetch and block_idx != 0:
338-
OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream)
346+
OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream, group=group)
339347

340348
# if block_idx in ["0"] and tensor_idx == "1":
341349
# swap_tensor.load()

0 commit comments

Comments (0)