@@ -22,7 +22,7 @@ def __getstate__(self):
2222 return self .obj .items ()
2323
2424 def __setstate__ (self , items ):
25- if not hasattr (self , ' obj' ):
25+ if not hasattr (self , " obj" ):
2626 self .obj = {}
2727 for key , val in items :
2828 self .obj [key ] = val
@@ -43,11 +43,7 @@ def keys(self):
4343def download_from_huggingface (repo , filename , ** kwargs ):
4444 while True :
4545 try :
46- return huggingface_hub .hf_hub_download (
47- repo ,
48- filename = filename ,
49- ** kwargs
50- )
46+ return huggingface_hub .hf_hub_download (repo , filename = filename , ** kwargs )
5147 except HTTPError as e :
5248 if e .response .status_code == 401 :
5349 # Need to log into huggingface api
@@ -76,13 +72,17 @@ def download_from_huggingface(repo, filename, **kwargs):
7672def load_config_from_pretrained (pretrained_model_name_or_path ):
7773 if os .path .exists (pretrained_model_name_or_path ):
7874 if "config.json" not in pretrained_model_name_or_path :
79- pretrained_model_name_or_path = os .path .join (pretrained_model_name_or_path , "config.json" )
75+ pretrained_model_name_or_path = os .path .join (
76+ pretrained_model_name_or_path , "config.json"
77+ )
8078 else :
81- assert pretrained_model_name_or_path in MODELS , f"Choose from { list (MODELS .keys ())} "
79+ assert (
80+ pretrained_model_name_or_path in MODELS
81+ ), f"Choose from { list (MODELS .keys ())} "
8282 pretrained_model_name_or_path = download_from_huggingface (
8383 repo = MODELS [pretrained_model_name_or_path ]["repo" ],
8484 filename = "config.json" ,
85- subfolder = MODELS [pretrained_model_name_or_path ]["subfolder" ]
85+ subfolder = MODELS [pretrained_model_name_or_path ]["subfolder" ],
8686 )
8787 with open (pretrained_model_name_or_path , "r" , encoding = "utf-8" ) as f :
8888 pretrained_args = AttributeDict (json .load (f ))
@@ -91,9 +91,12 @@ def load_config_from_pretrained(pretrained_model_name_or_path):
9191
9292def load_e4t_unet (pretrained_model_name_or_path = None , ckpt_path = None , ** kwargs ):
9393 assert pretrained_model_name_or_path is not None or ckpt_path is not None
94- if pretrained_model_name_or_path is None or not os . path . exists ( ckpt_path ) :
94+ if pretrained_model_name_or_path is None :
9595 if os .path .exists (ckpt_path ):
96- assert os .path .basename (ckpt_path ) == "unet.pt" or os .path .basename (ckpt_path ) == "weight_offsets.pt" , "You must specify the filename! (`unet.pt` or `weight_offsets.pt`)"
96+ assert (
97+ os .path .basename (ckpt_path ) == "unet.pt"
98+ or os .path .basename (ckpt_path ) == "weight_offsets.pt"
99+ ), "You must specify the filename! (`unet.pt` or `weight_offsets.pt`)"
97100 config = load_config_from_pretrained (os .path .dirname (ckpt_path ))
98101 else :
99102 assert ckpt_path in MODELS , f"Choose from { list (MODELS .keys ())} "
@@ -102,16 +105,22 @@ def load_e4t_unet(pretrained_model_name_or_path=None, ckpt_path=None, **kwargs):
102105 ckpt_path = download_from_huggingface (
103106 repo = MODELS [ckpt_path ]["repo" ],
104107 filename = "weight_offsets.pt" ,
105- subfolder = MODELS [ckpt_path ]["subfolder" ]
108+ subfolder = MODELS [ckpt_path ]["subfolder" ],
106109 )
107110 except EntryNotFoundError :
108111 ckpt_path = download_from_huggingface (
109112 repo = MODELS [ckpt_path ]["repo" ],
110113 filename = "unet.pt" ,
111- subfolder = MODELS [ckpt_path ]["subfolder" ]
114+ subfolder = MODELS [ckpt_path ]["subfolder" ],
112115 )
113- pretrained_model_name_or_path = config .pretrained_model_name_or_path if config .pretrained_args is None else config .pretrained_args ["pretrained_model_name_or_path" ]
114- unet = OriginalUNet2DConditionModel .from_pretrained (pretrained_model_name_or_path , subfolder = "unet" , ** kwargs )
116+ pretrained_model_name_or_path = (
117+ config .pretrained_model_name_or_path
118+ if config .pretrained_args is None
119+ else config .pretrained_args ["pretrained_model_name_or_path" ]
120+ )
121+ unet = OriginalUNet2DConditionModel .from_pretrained (
122+ pretrained_model_name_or_path , subfolder = "unet" , ** kwargs
123+ )
115124 state_dict = dict (unet .state_dict ())
116125 if ckpt_path :
117126 ckpt_sd = torch .load (ckpt_path , map_location = "cpu" )
@@ -142,7 +151,7 @@ def load_e4t_encoder(ckpt_path=None, **kwargs):
142151 ckpt_path = download_from_huggingface (
143152 repo = MODELS [ckpt_path ]["repo" ],
144153 filename = "encoder.pt" ,
145- subfolder = MODELS [ckpt_path ]["subfolder" ]
154+ subfolder = MODELS [ckpt_path ]["subfolder" ],
146155 )
147156 state_dict = torch .load (ckpt_path , map_location = "cpu" )
148157 print (f"Resuming from { ckpt_path } " )
@@ -182,7 +191,7 @@ def image_grid(imgs, rows, cols):
182191 assert len (imgs ) == rows * cols
183192
184193 w , h = imgs [0 ].size
185- grid = Image .new (' RGB' , size = (cols * w , rows * h ))
194+ grid = Image .new (" RGB" , size = (cols * w , rows * h ))
186195 grid_w , grid_h = grid .size
187196
188197 for i , img in enumerate (imgs ):
0 commit comments