From bb0f956a502180a4faf607557586abd7d0762eb5 Mon Sep 17 00:00:00 2001 From: GuangyaoDou Date: Tue, 22 Oct 2024 21:16:14 -0400 Subject: [PATCH] fix device override in model.py --- mimir/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mimir/models.py b/mimir/models.py index 11b4c2e..a782d4c 100644 --- a/mimir/models.py +++ b/mimir/models.py @@ -179,8 +179,7 @@ def load_base_model_and_tokenizer(self, model_kwargs): elif "llama" in self.name or "alpaca" in self.name: # TODO: This should be smth specified in config in case user has # llama is too big, gotta use device map - model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir) - self.device = 'cuda:1' + model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir) elif "stablelm" in self.name.lower(): # models requiring custom code model = transformers.AutoModelForCausalLM.from_pretrained( self.name, **model_kwargs, trust_remote_code=True, device_map=device_map, cache_dir=self.cache_dir)