@@ -79,6 +79,14 @@ def get_labels(self, label_set: str) -> str | None:
7979 categories = [label .replace (" " , "_" ) for label in categories ]
8080 return " " .join (categories )
8181
82+ if label_set == "IMAGENET21K" :
83+ from timm .data import ImageNetInfo
84+
85+ info = ImageNetInfo ("imagenet21k" )
86+ categories = info .label_descriptions ()
87+ categories = [desc .split ("," )[0 ].strip ().replace (" " , "_" ) for desc in categories ]
88+ return " " .join (categories )
89+
8290 return None
8391
8492 def download_from_huggingface (
@@ -459,7 +467,7 @@ def create_calibration_dataset(
459467 return_labels: Whether to return labels along with images
460468
461469 Returns:
462- List of preprocessed image arrays, or tuple of (images, labels) if return_labels=True
470+ List of preprocessed image arrays, or tuple of (images, labels)
463471 """
464472 if not self .dataset_path or not self .dataset_path .exists ():
465473 self .logger .warning ("Dataset path not provided or doesn't exist. Skipping quantization." )
@@ -476,12 +484,12 @@ def create_calibration_dataset(
476484 image_dir = self .dataset_path
477485 if not image_dir .exists ():
478486 self .logger .error (f"Image directory not found: { image_dir } " )
479- return ([], []) if return_labels else []
487+ return ([], [])
480488
481489 image_entries = self ._collect_dataset_entries (image_dir )
482490 if not image_entries :
483491 self .logger .error ("No images found in dataset" )
484- return ([], []) if return_labels else []
492+ return ([], [])
485493
486494 self .logger .info (f"Found { len (image_entries )} images in dataset" )
487495 self .logger .info (f"Using { min (subset_size , len (image_entries ))} images for calibration" )
@@ -537,7 +545,7 @@ def create_calibration_dataset(
537545 continue
538546
539547 self .logger .info (f"✓ Created calibration dataset with { len (calibration_data )} images" )
540- return calibration_data
548+ return calibration_data , []
541549
542550 def validate_model (
543551 self ,
@@ -938,6 +946,7 @@ def process_model_config(self, config: dict[str, Any]) -> bool:
938946 # Quantize the model if dataset is available
939947 if self .dataset_path :
940948 self .logger .info ("Creating calibration dataset for INT8 quantization" )
949+ has_labels = bool (config .get ("labels" ))
941950
942951 self .logger .info ("Creating validation dataset for accuracy measurement" )
943952 validation_data , validation_labels = self .create_calibration_dataset (
@@ -946,7 +955,7 @@ def process_model_config(self, config: dict[str, Any]) -> bool:
946955 scale_values = scale_values ,
947956 reverse_input_channels = reverse_input_channels ,
948957 subset_size = 300 ,
949- return_labels = True ,
958+ return_labels = has_labels ,
950959 )
951960
952961 if validation_data :
@@ -957,7 +966,7 @@ def process_model_config(self, config: dict[str, Any]) -> bool:
957966 model_config = config ,
958967 preset = "mixed" ,
959968 validation_data = validation_data if validation_labels else None ,
960- validation_labels = validation_labels ,
969+ validation_labels = validation_labels or None ,
961970 )
962971
963972 # Clean up temporary FP32 model after quantization
0 commit comments