diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..1764d3e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ + +src +*.egg-info +__pycache__ +*/**/__pycache__ +outputs +train.bat +logs +gen.bat +gen_ref.bat diff --git a/README.md b/README.md index 0b6a8424..5d2b3dca 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is an implementtaion of Google's [Dreambooth](https://arxiv.org/abs/2208.12242) with [Stable Diffusion](https://github.com/CompVis/stable-diffusion). The original Dreambooth is based on [Imagen](https://imagen.research.google/) text-to-image model. However, neither the model nor the pre-trained weights of Imagen is available. To enable people to fine-tune a text-to-image model with a few examples, I implemented the idea of Dreambooth on Stable diffusion. -This code repository is based on that of [Textual Inversion](https://github.com/rinongal/textual_inversion). Note that Textual Inversion only optimizes word ebedding, while dreambooth fine-tunes the whole diffusion model. +This code repository is based on that of [Textual Inversion](https://github.com/rinongal/textual_inversion). Note that Textual Inversion only optimizes word embedding, while dreambooth fine-tunes the whole diffusion model. The implementation makes minimum changes over the official codebase of Textual Inversion. In fact, due to lazyness, some components in Textual Inversion, such as the embedding manager, are not deleted, although they will never be used here. ## Update diff --git a/configs/stable-diffusion/v1-finetune_unfrozen.yaml b/configs/stable-diffusion/v1-finetune_unfrozen.yaml index 780282f3..e327367e 100644 --- a/configs/stable-diffusion/v1-finetune_unfrozen.yaml +++ b/configs/stable-diffusion/v1-finetune_unfrozen.yaml @@ -77,7 +77,7 @@ data: target: main.DataModuleFromConfig params: batch_size: 1 - num_workers: 2 + num_workers: 1 wrap: false train: target: ldm.data.personalized.PersonalizedBase @@ -111,10 +111,11 @@ lightning: image_logger: target: main.ImageLogger params: - batch_frequency: 500 + batch_frequency: 200 max_images: 8 increase_log_steps: False trainer: benchmark: True max_steps: 800 +# precision: 'bf16' diff --git a/evaluation/__pycache__/clip_eval.cpython-36.pyc b/evaluation/__pycache__/clip_eval.cpython-36.pyc deleted file mode 100644 index d8f156dd..00000000 Binary files a/evaluation/__pycache__/clip_eval.cpython-36.pyc and /dev/null differ diff --git a/evaluation/__pycache__/clip_eval.cpython-38.pyc b/evaluation/__pycache__/clip_eval.cpython-38.pyc deleted file mode 100644 index 890bb4ef..00000000 Binary files a/evaluation/__pycache__/clip_eval.cpython-38.pyc and /dev/null differ diff --git a/ldm/__pycache__/util.cpython-36.pyc b/ldm/__pycache__/util.cpython-36.pyc deleted file mode 100644 index cd5ac82b..00000000 Binary files a/ldm/__pycache__/util.cpython-36.pyc and /dev/null differ diff --git a/ldm/__pycache__/util.cpython-38.pyc b/ldm/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 874c370f..00000000 Binary files a/ldm/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/__init__.cpython-36.pyc b/ldm/data/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index fa8037e3..00000000 Binary files a/ldm/data/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/__init__.cpython-38.pyc b/ldm/data/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index c576f9d6..00000000 Binary files a/ldm/data/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/base.cpython-36.pyc b/ldm/data/__pycache__/base.cpython-36.pyc deleted file mode 100644 index 62b31c28..00000000 Binary files a/ldm/data/__pycache__/base.cpython-36.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/base.cpython-38.pyc b/ldm/data/__pycache__/base.cpython-38.pyc deleted file mode 100644 index 47ff11cd..00000000 Binary files a/ldm/data/__pycache__/base.cpython-38.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/personalized.cpython-36.pyc b/ldm/data/__pycache__/personalized.cpython-36.pyc deleted file mode 100644 index dbc38564..00000000 Binary files a/ldm/data/__pycache__/personalized.cpython-36.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/personalized.cpython-38.pyc b/ldm/data/__pycache__/personalized.cpython-38.pyc deleted file mode 100644 index 4dac4480..00000000 Binary files a/ldm/data/__pycache__/personalized.cpython-38.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/personalized_compose.cpython-38.pyc b/ldm/data/__pycache__/personalized_compose.cpython-38.pyc deleted file mode 100644 index 41adc506..00000000 Binary files a/ldm/data/__pycache__/personalized_compose.cpython-38.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc b/ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc deleted file mode 100644 index 40586e19..00000000 Binary files a/ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/personalized_style.cpython-36.pyc b/ldm/data/__pycache__/personalized_style.cpython-36.pyc deleted file mode 100644 index 58daf2a7..00000000 Binary files a/ldm/data/__pycache__/personalized_style.cpython-36.pyc and /dev/null differ diff --git a/ldm/data/__pycache__/personalized_style.cpython-38.pyc b/ldm/data/__pycache__/personalized_style.cpython-38.pyc deleted file mode 100644 index 919d42fe..00000000 Binary files a/ldm/data/__pycache__/personalized_style.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/__pycache__/autoencoder.cpython-36.pyc b/ldm/models/__pycache__/autoencoder.cpython-36.pyc deleted file mode 100644 index 3ee1d08f..00000000 Binary files a/ldm/models/__pycache__/autoencoder.cpython-36.pyc and /dev/null differ diff --git a/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/ldm/models/__pycache__/autoencoder.cpython-38.pyc deleted file mode 100644 index 7a2b263f..00000000 Binary files a/ldm/models/__pycache__/autoencoder.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc b/ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index 5fa5d9e4..00000000 Binary files a/ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index fac69f73..00000000 Binary files a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc b/ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc deleted file mode 100644 index ae85930e..00000000 Binary files a/ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc deleted file mode 100644 index 3581d481..00000000 Binary files a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc deleted file mode 100644 index 7134101e..00000000 Binary files a/ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc b/ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc deleted file mode 100644 index 8d7e6895..00000000 Binary files a/ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc deleted file mode 100644 index 6bb4defb..00000000 Binary files a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc deleted file mode 100644 index 57b7785e..00000000 Binary files a/ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/plms.cpython-36.pyc b/ldm/models/diffusion/__pycache__/plms.cpython-36.pyc deleted file mode 100644 index f04c88a7..00000000 Binary files a/ldm/models/diffusion/__pycache__/plms.cpython-36.pyc and /dev/null differ diff --git a/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc b/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc deleted file mode 100644 index 66f45472..00000000 Binary files a/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/attention.cpython-36.pyc b/ldm/modules/__pycache__/attention.cpython-36.pyc deleted file mode 100644 index ba18a7e2..00000000 Binary files a/ldm/modules/__pycache__/attention.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/attention.cpython-38.pyc b/ldm/modules/__pycache__/attention.cpython-38.pyc deleted file mode 100644 index dddaa85f..00000000 Binary files a/ldm/modules/__pycache__/attention.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/ema.cpython-36.pyc b/ldm/modules/__pycache__/ema.cpython-36.pyc deleted file mode 100644 index e6a17e24..00000000 Binary files a/ldm/modules/__pycache__/ema.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/ema.cpython-38.pyc b/ldm/modules/__pycache__/ema.cpython-38.pyc deleted file mode 100644 index ad63cba3..00000000 Binary files a/ldm/modules/__pycache__/ema.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/embedding_manager.cpython-36.pyc b/ldm/modules/__pycache__/embedding_manager.cpython-36.pyc deleted file mode 100644 index 14a1688a..00000000 Binary files a/ldm/modules/__pycache__/embedding_manager.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/embedding_manager.cpython-38.pyc b/ldm/modules/__pycache__/embedding_manager.cpython-38.pyc deleted file mode 100644 index 3517a502..00000000 Binary files a/ldm/modules/__pycache__/embedding_manager.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/x_transformer.cpython-36.pyc b/ldm/modules/__pycache__/x_transformer.cpython-36.pyc deleted file mode 100644 index 48f9d94d..00000000 Binary files a/ldm/modules/__pycache__/x_transformer.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/__pycache__/x_transformer.cpython-38.pyc b/ldm/modules/__pycache__/x_transformer.cpython-38.pyc deleted file mode 100644 index 799863b7..00000000 Binary files a/ldm/modules/__pycache__/x_transformer.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39c..58d6a084 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,8 +1,8 @@ from inspect import isfunction import math -import torch +import torch, gc import torch.nn.functional as F -from torch import nn, einsum +from torch import nn, einsum, autocast from einops import rearrange, repeat from ldm.modules.diffusionmodules.util import checkpoint @@ -170,28 +170,58 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k_in = self.to_k(context) + v_in = self.to_v(context) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): diff --git a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index 6aad9041..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index d1e31225..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc b/ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc deleted file mode 100644 index 3e8b5530..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc deleted file mode 100644 index ad09f37a..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc deleted file mode 100644 index b399666c..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc deleted file mode 100644 index 4e18bc60..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc b/ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc deleted file mode 100644 index 4c636b05..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 20254a96..00000000 Binary files a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a..96e8fa5b 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,6 +1,7 @@ # pytorch_diffusion + derived encoder decoder import math -import torch +from re import T +import torch, gc import torch.nn as nn import numpy as np from einops import rearrange @@ -32,7 +33,10 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + t = torch.sigmoid(x) + x *= t + del t + return x def Normalize(in_channels, num_groups=32): @@ -119,18 +123,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +154,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h8 class LinAttnBlock(LinearAttention): @@ -178,28 +194,65 @@ def __init__(self, in_channels): def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - h_ = self.proj_out(h_) + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 - return x+h_ + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +593,52 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: + t = h h = torch.tanh(h) + del t return h diff --git a/ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc b/ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index e6de2112..00000000 Binary files a/ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 4679c917..00000000 Binary files a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc b/ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc deleted file mode 100644 index b730ba09..00000000 Binary files a/ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc deleted file mode 100644 index 3967ec46..00000000 Binary files a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc b/ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index 5535ae3b..00000000 Binary files a/ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 3a7163a4..00000000 Binary files a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/ldm/modules/encoders/__pycache__/modules.cpython-36.pyc b/ldm/modules/encoders/__pycache__/modules.cpython-36.pyc deleted file mode 100644 index 069706b5..00000000 Binary files a/ldm/modules/encoders/__pycache__/modules.cpython-36.pyc and /dev/null differ diff --git a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc deleted file mode 100644 index 18edf47e..00000000 Binary files a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc and /dev/null differ diff --git a/ldm/util.py b/ldm/util.py index aa0963a0..39885c35 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -22,7 +22,8 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + #font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) diff --git a/main.py b/main.py index 506b44b9..c605a99e 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -import argparse, os, sys, datetime, glob, importlib, csv +import argparse, os, sys, datetime, glob, importlib, csv, gc import numpy as np import time import torch @@ -11,6 +11,7 @@ from torch.utils.data import random_split, DataLoader, Dataset, Subset from functools import partial from PIL import Image +from torch import autocast from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer @@ -272,6 +273,7 @@ def _train_dataloader(self): reg_set = self.datasets["reg"] concat_dataset = ConcatDataset(train_set, reg_set) return DataLoader(concat_dataset, batch_size=self.batch_size, + pin_memory=False, num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, worker_init_fn=init_fn) @@ -427,9 +429,9 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): - images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) + images[k] = images[k].detach().cpu() self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx) @@ -452,6 +454,7 @@ def check_frequency(self, check_idx): return False def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + pass if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): self.log_img(pl_module, batch, batch_idx, split="train") @@ -472,6 +475,8 @@ def on_train_epoch_start(self, trainer, pl_module): self.start_time = time.time() def on_train_epoch_end(self, trainer, pl_module): + gc.collect() + torch.cuda.empty_cache() torch.cuda.synchronize(trainer.root_gpu) max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 epoch_time = time.time() - self.start_time @@ -608,7 +613,7 @@ def on_train_epoch_start(self, trainer, pl_module): # merge trainer cli with config trainer_config = lightning_config.get("trainer", OmegaConf.create()) # default to ddp - trainer_config["accelerator"] = "ddp" + #trainer_config["accelerator"] = "ddp" for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) if not "gpus" in trainer_config: @@ -713,7 +718,7 @@ def on_train_epoch_start(self, trainer, pl_module): "params": { "batch_frequency": 750, "max_images": 4, - "clamp": True + "clamp": True, } }, "learning_rate_logger": { @@ -821,18 +826,20 @@ def divein(*args, **kwargs): import signal - signal.signal(signal.SIGUSR1, melk) - signal.signal(signal.SIGUSR2, divein) + signal.signal(signal.SIGTERM, melk) + signal.signal(signal.SIGTERM, divein) # run if opt.train: try: - trainer.fit(model, data) +# with autocast('cuda'): + trainer.fit(model, data) except Exception: melk() raise if not opt.no_test and not trainer.interrupted: - trainer.test(model, data) + pass + #trainer.test(model, data) except Exception: if opt.debug and trainer.global_rank == 0: try: