Skip to content

Commit 95497ab

Browse files
authored
Fixed the bug caused by cpu offloading (#4063)
1 parent 1b547e3 commit 95497ab

3 files changed

Lines changed: 17 additions & 11 deletions

File tree

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,21 @@ def cross_compile_for_windows(
373373
gm = exported_program.module()
374374
logger.debug("Input graph: " + str(gm.graph))
375375

376+
# Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support.
377+
if offload_module_to_cpu:
378+
deallocate_module(exported_program.module())
376379
# Apply lowering on the graph module
377380
gm = post_lowering(gm, settings)
381+
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
378382
logger.debug("Lowered Input graph: " + str(gm.graph))
383+
379384
# Move the weights in the state_dict to CPU
380385
if offload_module_to_cpu:
381-
deallocate_module(exported_program.module(), delete_module=False)
386+
deallocate_module(gm)
382387
logger.info(
383388
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
384389
)
390+
logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
385391
else:
386392
remaining_memory, total_memory = torch.cuda.mem_get_info()
387393
if remaining_memory < total_memory // 2:
@@ -766,15 +772,17 @@ def compile(
766772
# Move the weights in the state_dict to CPU
767773
logger.debug("Input graph: " + str(gm.graph))
768774

775+
# Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support.
776+
if offload_module_to_cpu:
777+
deallocate_module(exported_program.module())
769778
# Apply lowering on the graph module
770779
gm = post_lowering(gm, settings)
771780
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
772781
logger.debug("Lowered Input graph: " + str(gm.graph))
773782

774783
# Move the weights in the state_dict to CPU
775784
if offload_module_to_cpu:
776-
deallocate_module(gm, delete_module=False)
777-
deallocate_module(exported_program.module(), delete_module=False)
785+
deallocate_module(gm)
778786
logger.info(
779787
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
780788
)
@@ -1419,7 +1427,7 @@ def convert_exported_program_to_serialized_trt_engine(
14191427

14201428
# Move the weights in the state_dict to CPU
14211429
if offload_module_to_cpu:
1422-
deallocate_module(exported_program.module(), delete_module=False)
1430+
deallocate_module(exported_program.module())
14231431
logger.info(
14241432
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
14251433
)

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def update_refit_condition(self) -> None:
262262
args, kwargs, result = self.run_info
263263
self.original_model.to(to_torch_device(self.trt_device))
264264
new_result = self.original_model(*args, **kwargs)
265-
deallocate_module(self.original_model, delete_module=False)
265+
deallocate_module(self.original_model)
266266
if check_output_equal(result, new_result):
267267
self.refit_state.set_state(RefitFlag.LIVE)
268268
return
@@ -311,7 +311,7 @@ def refit_gm(self) -> None:
311311
in_place=True,
312312
)
313313

314-
deallocate_module(self.original_model, delete_module=False)
314+
deallocate_module(self.original_model)
315315

316316
def get_exported_program(self) -> torch.export.ExportedProgram:
317317

@@ -372,7 +372,7 @@ def compile(self) -> None:
372372
**self.additional_settings,
373373
)
374374
if self.additional_settings.get("offload_module_to_cpu", False):
375-
deallocate_module(self.original_model, delete_module=False)
375+
deallocate_module(self.original_model)
376376
if self.enable_weight_streaming:
377377
self.set_weight_streaming_ctx(self.weight_streaming_budget)
378378

@@ -738,7 +738,7 @@ def load(path: str) -> Any:
738738
module.exp_program = torch.export.export(
739739
module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs
740740
)
741-
deallocate_module(module.original_model, delete_module=False)
741+
deallocate_module(module.original_model)
742742
cls = module.__class__
743743
module.__class__ = type(
744744
module.original_model.__class__.__name__,

py/torch_tensorrt/dynamo/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,12 @@ def unified_dtype_converter(
127127
raise TypeError("%s is not a supported dtype" % dtype)
128128

129129

130-
def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
130+
def deallocate_module(module: torch.fx.GraphModule) -> None:
131131
"""
132132
This is a helper function to delete the instance of module. We first move it to CPU and then
133133
delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call
134134
"""
135135
module.to(CPU_DEVICE)
136-
if delete_module:
137-
del module
138136
torch.cuda.empty_cache()
139137
gc.collect()
140138

0 commit comments

Comments
 (0)