|
| 1 | +import os, json, urllib, zipfile |
| 2 | +import urllib.request |
| 3 | +from PIL import Image |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pydynet as pdn |
| 7 | +import pydynet.nn.functional as F |
| 8 | + |
| 9 | +from .tokenizer import SimpleTokenizer |
| 10 | +from .model import CLIP |
| 11 | + |
| 12 | + |
def download(url: str, filename: str, chunk_size: int = 10**6) -> None:
    """Download ``url`` and store the response body at ``filename``.

    The body is streamed to ``filename + ".part"`` in pieces of at most
    ``chunk_size`` bytes and moved onto ``filename`` only after the whole
    response has been read. A failed or interrupted download therefore
    never leaves a truncated file at ``filename``.

    Args:
        url: Location to fetch; anything ``urllib.request.urlopen`` accepts.
        filename: Destination path. Missing parent directories are created.
        chunk_size: Maximum number of bytes requested per read.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
        OSError: If the destination cannot be written.
    """
    # Create directories if they don't exist yet
    directories = os.path.dirname(filename)
    if directories:
        os.makedirs(directories, exist_ok=True)

    partial = filename + ".part"
    try:
        with urllib.request.urlopen(url) as response, open(partial, "wb") as out:
            # The header is optional; without it only a byte count is shown.
            length = response.info()["Content-Length"]
            total = int(length) if length is not None else 0

            received = 0
            while True:
                data = response.read(chunk_size)
                if not data:
                    break
                out.write(data)
                received += len(data)
                if total:
                    print(f"Downloading {filename} {received / total * 100:.2f} %")
                else:
                    print(f"Downloading {filename} {received} bytes")

        # Publish the finished file in a single step.
        os.replace(partial, filename)
    finally:
        # Only still present when the download did not complete.
        if os.path.exists(partial):
            os.remove(partial)
| 34 | + |
| 35 | + |
def load_zip(path: str):
    """Read every member of the zip archive at ``path`` into memory.

    Returns:
        A dict mapping each member's name inside the archive to its
        uncompressed content as ``bytes``.
    """
    with zipfile.ZipFile(path) as archive:
        return {
            member.filename: archive.read(member)
            for member in archive.infolist()
        }
| 46 | + |
| 47 | + |
class Params:
    """Read-only access to the tensors of a CLIP checkpoint.

    The checkpoint (a zip archive with a ``.pt`` extension) is downloaded on
    first use and loaded fully into memory. A JSON file stored next to it
    describes each entry: either a scalar under ``"value"`` or a tensor
    given by ``"path"`` (archive member), ``"dtype"``, ``"shape"`` and the
    ``"start"``/``"end"`` byte offsets inside that member.
    """

    def __init__(self, name: str, download_root: str = None) -> None:
        """Load the checkpoint ``name`` and its metadata.

        Args:
            name: Model identifier. Only ``"ViT-B/32"`` is accepted.
            download_root: Directory holding ``<name>.pt`` and
                ``<name>.json``; defaults to ``~/.cache/clip``. The
                ``CLIP_DIR`` environment variable, when set, takes
                precedence over this argument.

        Raises:
            ValueError: If ``name`` is not a supported model.
            FileNotFoundError: If the JSON metadata file is missing.
        """
        # Explicit raise rather than assert: asserts vanish under ``python -O``.
        if name != "ViT-B/32":
            raise ValueError(
                f"Model {name} not supported yet. Only ViT-B/32 currently supported."
            )

        model_urls = {
            "RN50":
            "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
            "RN101":
            "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
            "RN50x4":
            "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
            "RN50x16":
            "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
            "RN50x64":
            "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
            "ViT-B/32":
            "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
            "ViT-B/16":
            "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
            "ViT-L/14":
            "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
            "ViT-L/14@336px":
            "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
        }

        model_url = model_urls[name]

        # "/" cannot appear in a file name, so "ViT-B/32" is stored as "ViT-B-32".
        name = name.replace("/", "-")

        if download_root is None:
            download_root = os.path.expanduser("~/.cache/clip")
        # The environment variable wins even over an explicit argument.
        download_root = os.environ.get("CLIP_DIR", download_root)

        model_path = os.path.join(download_root, f"{name}.pt")

        if not os.path.isfile(model_path):
            print(f"Downloading {model_path} from {model_url}")
            download(model_url, model_path)

        # Archive member name -> raw bytes of that member.
        self.files = load_zip(model_path)

        info_path = os.path.join(download_root, f"{name}.json")
        if not os.path.isfile(info_path):
            raise FileNotFoundError(
                f"Metadata file {info_path} describing the tensors in "
                f"{model_path} was not found."
            )
        with open(info_path) as f:
            # Entry name -> description dict (see class docstring).
            self.info = json.load(f)

    def get_int(self, name: str) -> int:
        """Return the scalar stored under ``"value"`` for entry ``name``.

        Raises:
            KeyError: If ``name`` is unknown or the entry has no ``"value"``.
        """
        value: int = self.info[name]["value"]
        return value

    def __getitem__(self, name: str):
        """Return tensor ``name`` as a new ``float32`` NumPy array.

        Raises:
            KeyError: If ``name`` is unknown or its description is incomplete.
            ValueError: If the stored dtype is neither float16 nor float32.
        """
        info = self.info[name]

        path = info["path"]
        dtype = info["dtype"]
        shape = info["shape"]
        start = info["start"]
        end = info["end"]

        if dtype not in ("float16", "float32"):
            raise ValueError(f"Unsupported dtype {dtype!r} for parameter {name!r}")

        # The tensor occupies bytes [start, end) of the archive member.
        data = self.files[path][start:end]

        # frombuffer yields a read-only view; astype returns a writable
        # float32 copy.
        arr = np.frombuffer(data, dtype=dtype).reshape(shape)
        return arr.astype(np.float32)
| 117 | + |
| 118 | + |
def tokenize(texts: list[str], context_length: int = 77):
    """Encode ``texts`` into a zero-padded matrix of token ids.

    Each text is encoded with ``SimpleTokenizer`` and wrapped in the
    ``<|startoftext|>`` / ``<|endoftext|>`` ids.

    Args:
        texts: Strings to encode.
        context_length: Number of columns in the result.

    Returns:
        An ``int64`` array of shape ``(len(texts), context_length)``;
        positions after a text's last token hold 0.

    Raises:
        RuntimeError: If an encoded text, including the two marker tokens,
            is longer than ``context_length``.
    """
    bpe = SimpleTokenizer()
    start_id = bpe.encoder["<|startoftext|>"]
    end_id = bpe.encoder["<|endoftext|>"]

    # Encode everything first; lengths are validated afterwards.
    sequences = []
    for text in texts:
        sequences.append([start_id] + bpe.encode(text) + [end_id])

    padded = np.zeros((len(sequences), context_length), dtype=np.int64)
    for row, (text, ids) in enumerate(zip(texts, sequences)):
        count = len(ids)
        if count > context_length:
            raise RuntimeError(
                f"Input {text} is too long for context length {context_length}"
            )
        padded[row, :count] = ids

    return padded
| 139 | + |
| 140 | + |
def preprocess(image: Image.Image, image_size: int = 224):
    """Resize, centre-crop and normalise ``image`` for the model.

    Steps: bicubic resize so the shorter side equals ``image_size``, crop
    the central ``image_size`` x ``image_size`` square, convert to RGB,
    scale to [0, 1], normalise per channel, and move channels first.

    Args:
        image: Input picture.
        image_size: Side length, in pixels, of the square output.

    Returns:
        A ``pdn.Tensor`` wrapping a ``float32`` array laid out as
        ``(3, image_size, image_size)``.
    """
    # Scale image such that length of smaller side is image_size.
    # The shorter side is set to image_size directly: int(scale * side) can
    # land on image_size - 1 through float rounding, and the crop below
    # would then pad the missing row/column with black.
    width, height = image.size
    scale = image_size / min(width, height)
    new_width = image_size if width <= height else int(scale * width)
    new_height = image_size if height <= width else int(scale * height)

    # Some Pillow versions have different interface
    if hasattr(Image, "Resampling"):
        resample = Image.Resampling.BICUBIC
    else:
        resample = Image.BICUBIC
    image = image.resize((new_width, new_height), resample)

    # Crop center
    x0 = round((new_width - image_size) / 2)
    y0 = round((new_height - image_size) / 2)
    image = image.crop((x0, y0, x0 + image_size, y0 + image_size))

    image = image.convert("RGB")

    # Normalize with per-channel mean / standard deviation
    x = np.array(image, dtype=np.float32) / 255.0
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
    x = (x - mean) / std

    # (height, width, channel) -> (channel, height, width)
    x = x.transpose(2, 0, 1)

    return pdn.Tensor(x, copy=None)
| 171 | + |
| 172 | + |
@pdn.no_grad()
def load_model(model: CLIP, param: Params):
    """Copy the checkpoint tensors in ``param`` into ``model``.

    Fills the embeddings, the patch kernel, the layer norms, the projection
    matrices, and the 12 transformer blocks of both the image and the text
    encoder. Checkpoint matrices accessed with ``.T`` are transposed before
    being stored.

    NOTE(review): some parameters are written through ``.data[...]`` and
    others by indexing the attribute directly; this mirrors the existing
    behaviour and is presumably equivalent -- confirm against pydynet.

    Returns:
        The same ``model`` object, updated in place.
    """
    vision = model.image_encoder
    language = model.text_encoder

    model.scale = pdn.exp(param["logit_scale"].astype(np.float32))
    model.class_embed.data[0, 0] = param["visual.class_embedding"]
    model.v_pos_emb.data[...] = param["visual.positional_embedding"]
    model.t_pos_emb.data[...] = param["positional_embedding"]

    # Image encoder: patch kernel, norms around the transformer stack, projection.
    vision.kernel.data[...] = param["visual.conv1.weight"]
    vision.pre_norm.scale[...] = param["visual.ln_pre.weight"]
    vision.pre_norm.shift[...] = param["visual.ln_pre.bias"]
    vision.post_norm.scale[...] = param["visual.ln_post.weight"]
    vision.post_norm.shift[...] = param["visual.ln_post.bias"]
    vision.proj.weight[...] = param["visual.proj"]

    # Text encoder: token table, final norm, projection.
    language.token_embed.weight[...] = param["token_embedding.weight"]
    language.post_norm.scale[...] = param["ln_final.weight"]
    language.post_norm.shift[...] = param["ln_final.bias"]
    language.proj.weight[...] = param["text_projection"]

    prefix = "transformer.resblocks."
    encoders = ((vision, "visual." + prefix), (language, prefix))
    for i in range(12):
        # Collect every (destination, value) pair of this layer first, so
        # all checkpoint reads happen before the first write.
        pending = []
        for encoder, root in encoders:
            block = encoder.transformers[i]
            key = f"{root}{i}."
            pending.extend([
                (block.mha.QKV.weight, param[key + "attn.in_proj_weight"].T),
                (block.mha.QKV.bias, param[key + "attn.in_proj_bias"]),
                (block.mha.O.weight, param[key + "attn.out_proj.weight"].T),
                (block.mha.O.bias, param[key + "attn.out_proj.bias"]),
                (block.layer_norm1.scale, param[key + "ln_1.weight"]),
                (block.layer_norm1.shift, param[key + "ln_1.bias"]),
                (block.layer_norm2.scale, param[key + "ln_2.weight"]),
                (block.layer_norm2.shift, param[key + "ln_2.bias"]),
                (block.mlp.fc1.weight, param[key + "mlp.c_fc.weight"].T),
                (block.mlp.fc1.bias, param[key + "mlp.c_fc.bias"]),
                (block.mlp.fc2.weight, param[key + "mlp.c_proj.weight"].T),
                (block.mlp.fc2.bias, param[key + "mlp.c_proj.bias"]),
            ])
        for target, value in pending:
            target.data[...] = value

    return model
| 250 | + |
| 251 | + |
# Example run: score one picture against three captions.
# NOTE: these statements sit at module level, so they execute whenever this
# module is run or imported.

# Load and preprocess the picture, then add a leading axis of size 1.
image = preprocess(Image.open("llm/clip/picture.png"))[np.newaxis, :, :, :]
# Token-id matrix with one row per caption.
text = tokenize(["a fish", "a dog", "a cat"])
# Build the model and fill it with the ViT-B/32 checkpoint tensors.
clip = load_model(CLIP(), Params("ViT-B/32", download_root='llm/clip/data'))

# Forward pass with gradient tracking disabled.
with pdn.no_grad():
    logits_per_image = clip(image, text)
    # Softmax over the last axis, then print the first row.
    probs = F.softmax(logits_per_image, axis=-1)
    print("Label probs:", probs.numpy()[0])
0 commit comments