@@ -373,15 +373,21 @@ def cross_compile_for_windows(
373373 gm = exported_program .module ()
374374 logger .debug ("Input graph: " + str (gm .graph ))
375375
376+ # Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support.
377+ if offload_module_to_cpu :
378+ deallocate_module (exported_program .module ())
376379 # Apply lowering on the graph module
377380 gm = post_lowering (gm , settings )
381+ logger .debug (f"CPU memory usage after post_lowering: { get_cpu_memory_usage ()} MB" )
378382 logger .debug ("Lowered Input graph: " + str (gm .graph ))
383+
379384 # Move the weights in the state_dict to CPU
380385 if offload_module_to_cpu :
381- deallocate_module (exported_program . module (), delete_module = False )
386+ deallocate_module (gm )
382387 logger .info (
383388 "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
384389 )
390+ logger .debug (f"CPU memory usage after CPU offload: { get_cpu_memory_usage ()} MB" )
385391 else :
386392 remaining_memory , total_memory = torch .cuda .mem_get_info ()
387393 if remaining_memory < total_memory // 2 :
@@ -766,15 +772,17 @@ def compile(
766772 # Move the weights in the state_dict to CPU
767773 logger .debug ("Input graph: " + str (gm .graph ))
768774
775+ # Move the weights in the state_dict to CPU. We should do this before post_lowering for KV cache support.
776+ if offload_module_to_cpu :
777+ deallocate_module (exported_program .module ())
769778 # Apply lowering on the graph module
770779 gm = post_lowering (gm , settings )
771780 logger .debug (f"CPU memory usage after post_lowering: { get_cpu_memory_usage ()} MB" )
772781 logger .debug ("Lowered Input graph: " + str (gm .graph ))
773782
774783 # Move the weights in the state_dict to CPU
775784 if offload_module_to_cpu :
776- deallocate_module (gm , delete_module = False )
777- deallocate_module (exported_program .module (), delete_module = False )
785+ deallocate_module (gm )
778786 logger .info (
779787 "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
780788 )
@@ -1419,7 +1427,7 @@ def convert_exported_program_to_serialized_trt_engine(
14191427
14201428 # Move the weights in the state_dict to CPU
14211429 if offload_module_to_cpu :
1422- deallocate_module (exported_program .module (), delete_module = False )
1430+ deallocate_module (exported_program .module ())
14231431 logger .info (
14241432 "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
14251433 )
0 commit comments