Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/chronos/chronos2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ def get_train_dataloader(self) -> DataLoader:

return DataLoader(train_dataset, **dataloader_params) # type: ignore

def _move_model_to_device(self, model, device):
    """
    Keep the model on its existing CUDA device when fine-tuning a single-device model.

    The base ``Trainer`` would otherwise move a model loaded on e.g. ``cuda:5`` to
    ``args.device`` (often ``cuda:0``). We only override the target when the model
    is a plain single-device CUDA model; models sharded via ``hf_device_map`` are
    left to the base class untouched.
    """
    current = getattr(model, "device", None)

    # Redirect only when: no hf_device_map, model already sits on some CUDA
    # device, the requested target is also CUDA, and the two devices differ.
    if (
        getattr(model, "hf_device_map", None) is None
        and current is not None
        and getattr(current, "type", None) == "cuda"
        and getattr(device, "type", None) == "cuda"
        and current != device
    ):
        device = current

    super()._move_model_to_device(model, device)

def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
if self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
Expand Down
64 changes: 64 additions & 0 deletions test/test_chronos2_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import torch
from transformers import Trainer

from chronos.chronos2.trainer import Chronos2Trainer


class _DummyModel:
def __init__(self, device: torch.device, hf_device_map=None):
self.device = device
self.hf_device_map = hf_device_map


def test_move_model_to_device_preserves_loaded_cuda_device(monkeypatch):
    """When the model is on a single CUDA device, keep that device instead of forcing cuda:0."""
    seen = {}

    def record(self, model, device):
        seen["target"] = device

    # Intercept the base-class move so no real device transfer happens.
    monkeypatch.setattr(Trainer, "_move_model_to_device", record)

    trainer = object.__new__(Chronos2Trainer)
    dummy = _DummyModel(torch.device("cuda:5"))
    Chronos2Trainer._move_model_to_device(trainer, dummy, torch.device("cuda:0"))

    assert seen["target"] == torch.device("cuda:5")


def test_move_model_to_device_keeps_requested_cpu_device(monkeypatch):
    """CPU fine-tuning should preserve the existing Trainer behavior."""
    seen = {}

    def record(self, model, device):
        seen["target"] = device

    # Intercept the base-class move so no real device transfer happens.
    monkeypatch.setattr(Trainer, "_move_model_to_device", record)

    trainer = object.__new__(Chronos2Trainer)
    dummy = _DummyModel(torch.device("cpu"))
    Chronos2Trainer._move_model_to_device(trainer, dummy, torch.device("cpu"))

    assert seen["target"] == torch.device("cpu")


def test_move_model_to_device_keeps_requested_device_for_hf_device_map(monkeypatch):
    """Do not override device movement for models managed via hf_device_map."""
    seen = {}

    def record(self, model, device):
        seen["target"] = device

    # Intercept the base-class move so no real device transfer happens.
    monkeypatch.setattr(Trainer, "_move_model_to_device", record)

    trainer = object.__new__(Chronos2Trainer)
    dummy = _DummyModel(torch.device("cuda:5"), hf_device_map={"": "cuda:5"})
    Chronos2Trainer._move_model_to_device(trainer, dummy, torch.device("cuda:0"))

    assert seen["target"] == torch.device("cuda:0")