2727from unittest .mock import patch
2828
2929import pandas as pd
30+ from inference_endpoint .dataset_manager .dataset import Dataset
3031from inference_endpoint .dataset_manager .predefined .shopify_product_catalogue import (
3132 BaseShopifyProductCatalogue ,
3233 ShopifyProductCatalogue ,
3334 ShopifyProductCatalogue8k ,
3435)
35- from inference_endpoint .dataset_manager .dataset import Dataset
3636from inference_endpoint .dataset_manager .predefined .shopify_product_catalogue .presets import (
3737 ShopifyMultimodalFormatter ,
3838 q3vl ,
@@ -372,7 +372,9 @@ def test_class_inherits_from_base(self) -> None:
372372
373373 def test_has_correct_repo_id (self ) -> None :
374374 """REPO_ID points to nvidia/Shopify-product-catalogue-8k."""
375- assert ShopifyProductCatalogue8k .REPO_ID == "nvidia/Shopify-product-catalogue-8k"
375+ assert (
376+ ShopifyProductCatalogue8k .REPO_ID == "nvidia/Shopify-product-catalogue-8k"
377+ )
376378
377379 def test_has_correct_dataset_id (self ) -> None :
378380 """DATASET_ID is shopify_product_catalogue_8k."""
@@ -381,11 +383,17 @@ def test_has_correct_dataset_id(self) -> None:
381383 def test_registered_in_dataset_predefined (self ) -> None :
382384 """Class is auto-registered in Dataset.PREDEFINED."""
383385 assert "shopify_product_catalogue_8k" in Dataset .PREDEFINED
384- assert Dataset .PREDEFINED ["shopify_product_catalogue_8k" ] is ShopifyProductCatalogue8k
386+ assert (
387+ Dataset .PREDEFINED ["shopify_product_catalogue_8k" ]
388+ is ShopifyProductCatalogue8k
389+ )
385390
386391 def test_shares_column_names_with_base (self ) -> None :
387392 """Column names are identical to ShopifyProductCatalogue."""
388- assert ShopifyProductCatalogue8k .COLUMN_NAMES == ShopifyProductCatalogue .COLUMN_NAMES
393+ assert (
394+ ShopifyProductCatalogue8k .COLUMN_NAMES
395+ == ShopifyProductCatalogue .COLUMN_NAMES
396+ )
389397
390398 def test_shares_presets_with_base (self ) -> None :
391399 """Presets are shared with base class (q3vl works)."""
@@ -430,7 +438,12 @@ def test_generate_uses_correct_dataset_id_for_paths(
430438 force = True ,
431439 )
432440 # Verify cache path uses shopify_product_catalogue_8k
433- expected_path = tmp_path / "shopify_product_catalogue_8k" / "train" / "shopify_product_catalogue_8k_train.parquet"
441+ expected_path = (
442+ tmp_path
443+ / "shopify_product_catalogue_8k"
444+ / "train"
445+ / "shopify_product_catalogue_8k_train.parquet"
446+ )
434447 assert expected_path .exists ()
435448
436449 def test_get_dataloader_with_q3vl_preset (self , tmp_path : Path ) -> None :
0 commit comments