forked from maybleMyers/chromaforge
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchromadct_optimization_integration.py
More file actions
83 lines (67 loc) · 2.41 KB
/
chromadct_optimization_integration.py
File metadata and controls
83 lines (67 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
ChromaDCT Memory Optimization Integration
Automatically applies memory optimizations to ChromaDCT models during loading
"""
import torch
from typing import Optional
# Optional dependency: the memory-optimization helpers live in the myerflow
# package. If it cannot be imported, the rest of this module degrades to no-ops.
try:
    from myerflow.src.models.chroma.memory_optimization_helper import (
        apply_memory_optimization,
        should_use_optimized_offloading,
        print_memory_optimization_info,
    )
except ImportError:
    # Helpers unavailable: flag it so callers can skip optimization.
    OPTIMIZATION_AVAILABLE = False
    print("Warning: ChromaDCT optimization components not found")
else:
    OPTIMIZATION_AVAILABLE = True
def patch_chromadct_model_loading():
    """
    Hook for wiring ChromaDCT optimizations into the model-loading path.

    Currently a placeholder: it performs no patching and always returns None,
    whether or not the optimization helpers were importable.
    """
    if OPTIMIZATION_AVAILABLE:
        # Intended to run during model initialization; kept as a reference
        # stub until a real patch is implemented.
        pass
    return None
def optimize_model_if_needed(model, device: torch.device, strategy: Optional[str] = None):
    """
    Run *model* through the ChromaDCT memory optimizer when it is available.

    This function does not itself check whether *model* is a ChromaDCT model;
    it forwards everything to ``apply_memory_optimization``. Whether that
    helper skips non-ChromaDCT models is not visible here (TODO: confirm).

    Args:
        model: Model to potentially optimize.
        device: Target device, passed through unchanged.
        strategy: Offloading strategy name, or None to let the helper choose
            (presumably auto-detection, per the original docs; verify).

    Returns:
        Whatever ``apply_memory_optimization`` returns when the helpers were
        imported successfully; otherwise the original *model*, untouched.
    """
    if OPTIMIZATION_AVAILABLE:
        return apply_memory_optimization(model, device, strategy)
    # Helper module failed to import at load time: hand the model back as-is.
    return model
# Example usage for manual optimization:
def example_optimize_chromadct():
    """
    Print a short banner about applying ChromaDCT optimization by hand.

    Nothing is loaded or optimized here; the function only prints three lines
    and returns None. A typical model-loading flow would look like the sketch
    below (``load_chromadct_model`` and the inference inputs are placeholder
    names, not defined in this module)::

        # Load your ChromaDCT model
        model = load_chromadct_model()
        device = torch.device('cuda')

        # Apply optimization
        optimized_model = optimize_model_if_needed(model, device, strategy='balanced')

        # Print optimization status
        print_memory_optimization_info(optimized_model)

        # Use optimized model for inference
        result = optimized_model(img, img_ids, txt, txt_ids, txt_mask, timesteps, guidance)

        # Print performance stats
        if hasattr(optimized_model, 'print_performance_summary'):
            optimized_model.print_performance_summary()
    """
    banner_lines = (
        "ChromaDCT Memory Optimization Example",
        "=====================================",
    )
    for banner_line in banner_lines:
        print(banner_line)
    print("See chromadct_optimization_integration.py for implementation details")
if __name__ == "__main__":
    # Running this file directly only prints the usage-example banner.
    example_optimize_chromadct()