From bbbdfac392e647afd85c57c62c472503796f64b6 Mon Sep 17 00:00:00 2001
From: dario-fumarola
Date: Thu, 26 Feb 2026 10:05:11 -0500
Subject: [PATCH] Fix Chronos2 fine-tuning to preserve loaded CUDA device

Hugging Face's `Trainer` moves the model to `args.device` during setup,
which relocates a model loaded on e.g. `cuda:5` to `cuda:0`. Override
`_move_model_to_device` so a single-device CUDA model stays on the device
it was loaded on; models managed via `hf_device_map` keep the default
behavior.
---
 src/chronos/chronos2/trainer.py | 22 ++++++++++++
 test/test_chronos2_trainer.py   | 64 +++++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+)
 create mode 100644 test/test_chronos2_trainer.py

diff --git a/src/chronos/chronos2/trainer.py b/src/chronos/chronos2/trainer.py
index bd4e4647..6ce7c910 100644
--- a/src/chronos/chronos2/trainer.py
+++ b/src/chronos/chronos2/trainer.py
@@ -74,6 +74,28 @@ def get_train_dataloader(self) -> DataLoader:
 
         return DataLoader(train_dataset, **dataloader_params)  # type: ignore
 
+    def _move_model_to_device(self, model, device):
+        """
+        Keep the model on its existing CUDA device when fine-tuning a single-device model.
+
+        `Trainer` may otherwise move a model loaded on e.g. `cuda:5` to `args.device` (often `cuda:0`).
+        """
+        model_device = getattr(model, "device", None)
+        model_device_type = getattr(model_device, "type", None)
+        target_device_type = getattr(device, "type", None)
+        has_hf_device_map = getattr(model, "hf_device_map", None) is not None
+
+        if (
+            not has_hf_device_map
+            and model_device is not None
+            and model_device_type == "cuda"
+            and target_device_type == "cuda"
+            and model_device != device
+        ):
+            device = model_device
+
+        super()._move_model_to_device(model, device)
+
     def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
         if self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
diff --git a/test/test_chronos2_trainer.py b/test/test_chronos2_trainer.py
new file mode 100644
index 00000000..5957df88
--- /dev/null
+++ b/test/test_chronos2_trainer.py
@@ -0,0 +1,64 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from transformers import Trainer
+
+from chronos.chronos2.trainer import Chronos2Trainer
+
+
+class _DummyModel:
+    def __init__(self, device: torch.device, hf_device_map=None):
+        self.device = device
+        self.hf_device_map = hf_device_map
+
+
+def test_move_model_to_device_preserves_loaded_cuda_device(monkeypatch):
+    """When the model is on a single CUDA device, keep that device instead of forcing cuda:0."""
+    captured = {}
+
+    def fake_move(self, model, device):
+        captured["device"] = device
+
+    monkeypatch.setattr(Trainer, "_move_model_to_device", fake_move)
+
+    trainer = object.__new__(Chronos2Trainer)
+    model = _DummyModel(torch.device("cuda:5"))
+
+    Chronos2Trainer._move_model_to_device(trainer, model, torch.device("cuda:0"))
+
+    assert captured["device"] == torch.device("cuda:5")
+
+
+def test_move_model_to_device_keeps_requested_cpu_device(monkeypatch):
+    """CPU fine-tuning should preserve the existing Trainer behavior."""
+    captured = {}
+
+    def fake_move(self, model, device):
+        captured["device"] = device
+
+    monkeypatch.setattr(Trainer, "_move_model_to_device", fake_move)
+
+    trainer = object.__new__(Chronos2Trainer)
+    model = _DummyModel(torch.device("cpu"))
+
+    Chronos2Trainer._move_model_to_device(trainer, model, torch.device("cpu"))
+
+    assert captured["device"] == torch.device("cpu")
+
+
+def test_move_model_to_device_keeps_requested_device_for_hf_device_map(monkeypatch):
+    """Do not override device movement for models managed via hf_device_map."""
+    captured = {}
+
+    def fake_move(self, model, device):
+        captured["device"] = device
+
+    monkeypatch.setattr(Trainer, "_move_model_to_device", fake_move)
+
+    trainer = object.__new__(Chronos2Trainer)
+    model = _DummyModel(torch.device("cuda:5"), hf_device_map={"": "cuda:5"})
+
+    Chronos2Trainer._move_model_to_device(trainer, model, torch.device("cuda:0"))
+
+    assert captured["device"] == torch.device("cuda:0")
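
Reviewer notes (not part of the commit): below is a minimal standalone sketch of the
device-resolution rule the override applies. `resolve_target_device` is an illustrative
helper written for this note, not code from the patch; it restates the condition in
`_move_model_to_device` so the three cases covered by the tests can be checked without
a GPU (constructing `torch.device` objects does not require CUDA).

import torch


def resolve_target_device(model_device, requested_device, has_hf_device_map=False):
    """Illustrative helper: return the device the model should end up on."""
    if (
        not has_hf_device_map                   # device-mapped models are left alone
        and model_device is not None
        and model_device.type == "cuda"         # model already sits on a CUDA device
        and requested_device.type == "cuda"     # and the trainer also targets CUDA
        and model_device != requested_device
    ):
        return model_device                     # keep the device the model was loaded on
    return requested_device


# Mirrors the three test cases above.
assert resolve_target_device(torch.device("cuda:5"), torch.device("cuda:0")) == torch.device("cuda:5")
assert resolve_target_device(torch.device("cpu"), torch.device("cpu")) == torch.device("cpu")
assert (
    resolve_target_device(torch.device("cuda:5"), torch.device("cuda:0"), has_hf_device_map=True)
    == torch.device("cuda:0")
)

The `hf_device_map` guard is the conservative case: when a model's placement is managed
via a device map, the override defers entirely to the requested device, presumably because
placement of device-mapped models is handled elsewhere in the transformers/accelerate stack.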