Skip to content

Commit 7efd364

Browse files
authored
Set cache directory for Hugging Face models in ResNet pipelines.
1 parent 12c8cb8 commit 7efd364

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

experiments/datascope/experiments/pipelines/pipelines.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import joblib
44
import numpy as np
5+
import os
56
import re
67
import sklearn.pipeline
78
import transformers
@@ -38,6 +39,8 @@
3839

3940
transformers.utils.logging.set_verbosity_error()
4041

42+
HUGGINGFACE_CACHE_DIR = os.path.join("var", "hf-cache")
43+
4144

4245
class ProvenancePipeline(sklearn.pipeline.Pipeline):
4346

@@ -402,7 +405,7 @@ class ResNet18EmbeddingPipeline(
402405

403406
@classmethod
404407
def get_model(cls: Type["ResNet18EmbeddingPipeline"]) -> PreTrainedModel:
405-
model = ResNetModel.from_pretrained("microsoft/resnet-18")
408+
model = ResNetModel.from_pretrained("microsoft/resnet-18", cache_dir=HUGGINGFACE_CACHE_DIR)
406409
return model
407410

408411

@@ -415,7 +418,7 @@ class ResNet50EmbeddingPipeline(
415418

416419
@classmethod
417420
def get_model(cls: Type["ResNet50EmbeddingPipeline"]) -> PreTrainedModel:
418-
model = ResNetModel.from_pretrained("microsoft/resnet-50")
421+
model = ResNetModel.from_pretrained("microsoft/resnet-50", cache_dir=HUGGINGFACE_CACHE_DIR)
419422
return model
420423

421424

0 commit comments

Comments
 (0)