Skip to content

Commit 10e1df5

Browse files
authored
feat: add vit classification models to model exporter (#564)
* add vit models * add vit models to converter
1 parent f07b91b commit 10e1df5

2 files changed

Lines changed: 54 additions & 6 deletions

File tree

tools/model_converter/config.json

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,45 @@
579579
"license": "apache-2.0",
580580
"license_link": "https://spdx.org/licenses/Apache-2.0.html",
581581
"labels": "IMAGENET1K_V1"
582+
},
583+
{
584+
"model_short_name": "vit_tiny_patch16_224_augreg_in21k",
585+
"huggingface_repo": "timm/vit_tiny_patch16_224.augreg_in21k",
586+
"huggingface_revision": "3d5f75e2fe58abe541d5651356278a1df3fd3ab3",
587+
"model_library": "timm",
588+
"model_full_name": "ViT-Tiny Patch16 224 AugReg ImageNet-21k",
589+
"description": "Vision Transformer Tiny with 16x16 patches trained on ImageNet-21k with augmentation and regularization",
590+
"docs": "https://huggingface.co/timm/vit_tiny_patch16_224.augreg_in21k",
591+
"input_shape": [1, 3, 224, 224],
592+
"input_names": ["image"],
593+
"output_names": ["logits"],
594+
"model_params": null,
595+
"model_type": "Classification",
596+
"reverse_input_channels": true,
597+
"mean_values": "123.675 116.28 103.53",
598+
"scale_values": "58.395 57.12 57.375",
599+
"license": "apache-2.0",
600+
"license_link": "https://spdx.org/licenses/Apache-2.0.html",
601+
"labels": "IMAGENET21K"
602+
},
603+
{
604+
"model_short_name": "vit_small_patch14_dinov2.lvd142m",
605+
"huggingface_repo": "timm/vit_small_patch14_dinov2.lvd142m",
606+
"huggingface_revision": "4610ca143709d58a633b6397a74412c2c3842454",
607+
"model_library": "timm",
608+
"model_full_name": "DINOv2-Small Patch14 518 LVD-142M",
609+
"description": "DINOv2 Small ViT backbone for image feature extraction with 384-dimensional features",
610+
"docs": "https://huggingface.co/timm/vit_small_patch14_dinov2.lvd142m",
611+
"input_shape": [1, 3, 518, 518],
612+
"input_names": ["image"],
613+
"output_names": ["output"],
614+
"model_params": null,
615+
"model_type": "Classification",
616+
"reverse_input_channels": true,
617+
"mean_values": "123.675 116.28 103.53",
618+
"scale_values": "58.395 57.12 57.375",
619+
"license": "apache-2.0",
620+
"license_link": "https://spdx.org/licenses/Apache-2.0.html"
582621
}
583622
]
584623
}

tools/model_converter/model_converter.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def get_labels(self, label_set: str) -> str | None:
7979
categories = [label.replace(" ", "_") for label in categories]
8080
return " ".join(categories)
8181

82+
if label_set == "IMAGENET21K":
83+
from timm.data import ImageNetInfo
84+
85+
info = ImageNetInfo("imagenet21k")
86+
categories = info.label_descriptions()
87+
categories = [desc.split(",")[0].strip().replace(" ", "_") for desc in categories]
88+
return " ".join(categories)
89+
8290
return None
8391

8492
def download_from_huggingface(
@@ -459,7 +467,7 @@ def create_calibration_dataset(
459467
return_labels: Whether to return labels along with images
460468
461469
Returns:
462-
List of preprocessed image arrays, or tuple of (images, labels) if return_labels=True
470+
List of preprocessed image arrays, or tuple of (images, labels)
463471
"""
464472
if not self.dataset_path or not self.dataset_path.exists():
465473
self.logger.warning("Dataset path not provided or doesn't exist. Skipping quantization.")
@@ -476,12 +484,12 @@ def create_calibration_dataset(
476484
image_dir = self.dataset_path
477485
if not image_dir.exists():
478486
self.logger.error(f"Image directory not found: {image_dir}")
479-
return ([], []) if return_labels else []
487+
return ([], [])
480488

481489
image_entries = self._collect_dataset_entries(image_dir)
482490
if not image_entries:
483491
self.logger.error("No images found in dataset")
484-
return ([], []) if return_labels else []
492+
return ([], [])
485493

486494
self.logger.info(f"Found {len(image_entries)} images in dataset")
487495
self.logger.info(f"Using {min(subset_size, len(image_entries))} images for calibration")
@@ -537,7 +545,7 @@ def create_calibration_dataset(
537545
continue
538546

539547
self.logger.info(f"✓ Created calibration dataset with {len(calibration_data)} images")
540-
return calibration_data
548+
return calibration_data, []
541549

542550
def validate_model(
543551
self,
@@ -938,6 +946,7 @@ def process_model_config(self, config: dict[str, Any]) -> bool:
938946
# Quantize the model if dataset is available
939947
if self.dataset_path:
940948
self.logger.info("Creating calibration dataset for INT8 quantization")
949+
has_labels = bool(config.get("labels"))
941950

942951
self.logger.info("Creating validation dataset for accuracy measurement")
943952
validation_data, validation_labels = self.create_calibration_dataset(
@@ -946,7 +955,7 @@ def process_model_config(self, config: dict[str, Any]) -> bool:
946955
scale_values=scale_values,
947956
reverse_input_channels=reverse_input_channels,
948957
subset_size=300,
949-
return_labels=True,
958+
return_labels=has_labels,
950959
)
951960

952961
if validation_data:
@@ -957,7 +966,7 @@ def process_model_config(self, config: dict[str, Any]) -> bool:
957966
model_config=config,
958967
preset="mixed",
959968
validation_data=validation_data if validation_labels else None,
960-
validation_labels=validation_labels,
969+
validation_labels=validation_labels or None,
961970
)
962971

963972
# Clean up temporary FP32 model after quantization

0 commit comments

Comments
 (0)