diff --git a/mimir/models.py b/mimir/models.py index 11b4c2e..a782d4c 100644 --- a/mimir/models.py +++ b/mimir/models.py @@ -179,8 +179,7 @@ def load_base_model_and_tokenizer(self, model_kwargs): elif "llama" in self.name or "alpaca" in self.name: # TODO: This should be smth specified in config in case user has # llama is too big, gotta use device map - model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir) - self.device = 'cuda:1' + model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir) elif "stablelm" in self.name.lower(): # models requiring custom code model = transformers.AutoModelForCausalLM.from_pretrained( self.name, **model_kwargs, trust_remote_code=True, device_map=device_map, cache_dir=self.cache_dir)