diff --git a/xtuner/v1/model/compose/intern_s1/modeling_vision.py b/xtuner/v1/model/compose/intern_s1/modeling_vision.py index 9aef186df..c9ea76c51 100644 --- a/xtuner/v1/model/compose/intern_s1/modeling_vision.py +++ b/xtuner/v1/model/compose/intern_s1/modeling_vision.py @@ -36,6 +36,8 @@ from xtuner.v1.ops.act_fn import get_act_fn from xtuner.v1.utils import get_logger from xtuner.v1.module import AttnOutputs +import os +from xtuner.v1.utils.activation_offload import async_save_on_cpu DEVICE = get_device() DEVICE_MODULE = get_torch_device_module() @@ -230,6 +232,7 @@ def __init__(self, config: InternS1VisionConfig) -> None: dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers)) self.layer = nn.ModuleList([ InternS1VisionLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]) + self.offload_stream = torch.cuda.Stream() def forward( self, @@ -241,8 +244,17 @@ def forward( for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore - - hidden_states = layer_module(hidden_states) + if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1: + with async_save_on_cpu( + h2d_stream=self.offload_stream, + d2h_stream=self.offload_stream, + block_idx=int(i), + group="vision", + custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(), + ): + hidden_states = layer_module(hidden_states) + else: + hidden_states = layer_module(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 29f5bc4a5..887fec942 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -432,7 +432,7 @@ def _micro_batch_forward( h2d_stream=self.offload_stream, d2h_stream=self.offload_stream, block_idx=layer_idx - self.config.first_k_dense_replace, - depth=len(self.layers) - self.config.first_k_dense_replace, + group="text", custom_check_fn=lambda x: x.data_ptr() in [hidden_states.data_ptr() for hidden_states in hidden_states_list], prefetch=True, @@ -579,7 +579,7 @@ def _forward( h2d_stream=self.offload_stream, d2h_stream=self.offload_stream, block_idx=int(idx), - depth=len(self.layers), + group="text", custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(), ): layer_results = decoder_layer( diff --git a/xtuner/v1/utils/activation_offload.py b/xtuner/v1/utils/activation_offload.py index d5a5055df..1fac7acdb 100644 --- a/xtuner/v1/utils/activation_offload.py +++ b/xtuner/v1/utils/activation_offload.py @@ -29,7 +29,9 @@ def __init__(self): self._block_tensor_nums = {} # offload tensors per block def get_cnt(self, block_idx): + prev_block_idx = None if self._block_idx == -1 else self._block_idx after_block = False + if block_idx > self._block_idx: self._block_tensor_nums[block_idx] = 1 if block_idx != 0: @@ -43,7 +45,7 @@ def get_cnt(self, block_idx): self._block_tensor_nums = {block_idx: 1} offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}" - return offload_tensor_key, after_block + return offload_tensor_key, after_block, prev_block_idx def get_prefetch_keys(self, block_idx, tensor_idx): prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None) @@ -193,11 +195,13 @@ def __init__(self, check=False): self.items = {} self.check = check self.device_item = [] - self.getcnt = GetCnt() + self.getcnt = {} self.may_npu_tensors = {} - def get_cnt(self, block_idx): - return self.getcnt.get_cnt(block_idx) + def get_cnt(self, block_idx, group="default"): + if group not in self.getcnt: + self.getcnt[group] = GetCnt() + return self.getcnt[group].get_cnt(block_idx) def assert_exist(self, key): if key not in self.items: @@ -249,16 +253,17 @@ def get(self, key): self.may_npu_tensors.update({key: self.items.pop(key)}) return act - def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream): - prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx) + def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream, group="default"): + if group not in self.getcnt: + return + prefetch_keys = self.getcnt[group].get_prefetch_keys(block_idx, tensor_idx) for prefetch_key in prefetch_keys: - if self.exist(prefetch_key): - prefetch_swap_tensor = self.get(prefetch_key) + full_key = f"{group}_{prefetch_key}" + if self.exist(full_key): + prefetch_swap_tensor = self.get(full_key) h2d_stream.wait_stream(d2h_stream) prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True) # prefetch_swap_tensor.tensor.record_stream(h2d_stream) - else: - torch.distributed.breakpoint() def empty(self): return len(self.items) == 0 @@ -291,7 +296,8 @@ def __init__( h2d_stream: torch.cuda.Stream, d2h_stream: torch.cuda.Stream, block_idx: int, - depth: int, + depth: int | None = None, + group: str = "default", custom_check_fn=None, prefetch=True, ) -> None: @@ -302,19 +308,21 @@ def _pack_to_cpu(tensor): if (custom_check_fn is not None) and (not custom_check_fn(tensor)): return tensor - key, after_block = OffloadManager().get_cnt(block_idx) + key, after_block, prev_block_idx = OffloadManager().get_cnt(block_idx, group=group) - if after_block: - OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream) + if after_block and (prev_block_idx is not None): + OffloadManager().del_npu_tensor(f"{group}_{prev_block_idx}_", d2h_stream) swap_tensor = SwapTensor(tensor, key) + full_key = f"{group}_{key}" - if block_idx <= depth - 1: + should_offload = depth is None or block_idx <= depth - 1 + if should_offload: working_stream = torch.cuda.current_stream() d2h_stream.wait_stream(working_stream) swap_tensor.launch_d2h(d2h_stream) - OffloadManager().put(key, swap_tensor) + OffloadManager().put(full_key, swap_tensor) return swap_tensor def _unpack_from_cpu(swap_tensor) -> torch.Tensor: @@ -328,14 +336,14 @@ def _unpack_from_cpu(swap_tensor) -> torch.Tensor: block_idx, tensor_idx = swap_tensor.key.split("_") - OffloadManager().del_may_npu_tensor(f"{int(block_idx) + 1}_", h2d_stream) + OffloadManager().del_may_npu_tensor(f"{group}_{int(block_idx) + 1}_", h2d_stream) swap_tensor.launch_h2d(h2d_stream, True, working_stream) # if block_idx in ["0", "2", "3"]: # if block_idx in ["0"]: # torch.cuda.synchronize() if prefetch and block_idx != 0: - OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream) + OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream, group=group) # if block_idx in ["0"] and tensor_idx == "1": # swap_tensor.load()