Skip to content

Commit aa954ba

Browse files
abrichr and claude authored
fix: make heavy ML dependencies optional for lightweight installs (#57)
* fix: make heavy ML dependencies optional for lightweight installs

  Move torch, torchvision, bitsandbytes, peft, and transformers from required
  dependencies to [project.optional-dependencies.training]. Wrap all top-level
  imports of these packages in try/except ImportError so the package can be
  imported without them installed.

  This unblocks lightweight consumers (e.g. Wright worker installing
  openadapt-evals) that don't need local model training/inference. Users who
  need training can install with:

      pip install openadapt-ml[training]

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* style: fix ruff formatting in qwen_vl.py

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 942a2f3 commit aa954ba

File tree

8 files changed

+146
-49
lines changed

8 files changed

+146
-49
lines changed

openadapt_ml/datasets/next_action.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from dataclasses import dataclass
44
from typing import Any, Dict, List
55

6-
from torch.utils.data import Dataset
6+
try:
7+
from torch.utils.data import Dataset as _TorchDataset
8+
except ImportError:
9+
_TorchDataset = None # type: ignore[assignment,misc]
710

811
from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement
912

@@ -523,10 +526,18 @@ class NextActionSample:
523526
messages: List[Dict[str, str]]
524527

525528

526-
class NextActionDataset(Dataset):
527-
"""Thin PyTorch Dataset wrapper around pre-built SFT samples."""
529+
class NextActionDataset(_TorchDataset if _TorchDataset is not None else object): # type: ignore[misc]
530+
"""Thin PyTorch Dataset wrapper around pre-built SFT samples.
531+
532+
Requires torch to be installed (pip install openadapt-ml[training]).
533+
"""
528534

529535
def __init__(self, samples: List[Dict[str, Any]]):
536+
if _TorchDataset is None:
537+
raise ImportError(
538+
"torch is required for NextActionDataset. "
539+
"Install with: pip install openadapt-ml[training]"
540+
)
530541
self._samples = samples
531542

532543
def __len__(self) -> int: # type: ignore[override]

openadapt_ml/models/api_adapter.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import base64
66
import os
77

8-
import torch
8+
try:
9+
import torch
10+
except ImportError:
11+
torch = None # type: ignore[assignment]
912

1013
from openadapt_ml.config import settings
1114
from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device
@@ -21,7 +24,7 @@ class ApiVLMAdapter(BaseVLMAdapter):
2124
def __init__(
2225
self,
2326
provider: str,
24-
device: Optional[torch.device] = None,
27+
device: Optional[Any] = None,
2528
api_key: Optional[str] = None,
2629
model_name: Optional[str] = None,
2730
) -> None:
@@ -84,22 +87,27 @@ def __init__(
8487
else:
8588
raise ValueError(f"Unsupported provider: {provider}")
8689

87-
if device is None:
88-
device = get_default_device()
89-
9090
# Store client separately; BaseVLMAdapter expects a model + processor, so
9191
# we pass a tiny dummy module and the client as the "processor".
9292
self._client = client
93-
model = torch.nn.Identity()
94-
processor: Any = client
95-
super().__init__(model=model, processor=processor, device=device)
93+
if torch is not None:
94+
if device is None:
95+
device = get_default_device()
96+
model = torch.nn.Identity()
97+
processor: Any = client
98+
super().__init__(model=model, processor=processor, device=device)
99+
else:
100+
# Lightweight mode: skip torch-based init for API-only usage
101+
self.model = None
102+
self.processor = client
103+
self.device = None
96104

97105
def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # type: ignore[override]
98106
raise NotImplementedError(
99107
"ApiVLMAdapter does not support training (prepare_inputs)"
100108
)
101109

102-
def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
110+
def compute_loss(self, inputs: Dict[str, Any]) -> Any: # type: ignore[override]
103111
raise NotImplementedError(
104112
"ApiVLMAdapter does not support training (compute_loss)"
105113
)

openadapt_ml/models/base_adapter.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, List, Optional
4+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
55

6-
import torch
6+
try:
7+
import torch
8+
except ImportError:
9+
torch = None # type: ignore[assignment]
710

11+
if TYPE_CHECKING:
12+
import torch
813

9-
def get_default_device() -> torch.device:
14+
15+
def get_default_device() -> "torch.device":
1016
"""Select cuda, then mps, then cpu.
1117
1218
This is used as a fallback when no explicit device is provided.
19+
Requires torch to be installed.
1320
"""
21+
if torch is None:
22+
raise ImportError(
23+
"torch is required for model operations. "
24+
"Install with: pip install openadapt-ml[training]"
25+
)
1426

1527
if torch.cuda.is_available():
1628
return torch.device("cuda")
@@ -29,14 +41,21 @@ class BaseVLMAdapter(ABC):
2941
- converting SFT-style samples into model inputs (tokenization, image processing)
3042
- computing supervised training loss
3143
- generating assistant text given a single sample at inference time
44+
45+
Requires torch to be installed (pip install openadapt-ml[training]).
3246
"""
3347

3448
def __init__(
3549
self,
36-
model: torch.nn.Module,
50+
model: Any,
3751
processor: Any,
38-
device: Optional[torch.device] = None,
52+
device: Optional[Any] = None,
3953
) -> None:
54+
if torch is None:
55+
raise ImportError(
56+
"torch is required for BaseVLMAdapter. "
57+
"Install with: pip install openadapt-ml[training]"
58+
)
4059
self.model = model
4160
self.processor = processor
4261
self.device = device or get_default_device()
@@ -53,7 +72,7 @@ def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
5372
"""
5473

5574
@abstractmethod
56-
def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor:
75+
def compute_loss(self, inputs: Dict[str, Any]) -> Any:
5776
"""Run the model forward and return a scalar loss tensor."""
5877

5978
@abstractmethod

openadapt_ml/models/dummy_adapter.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22

33
from typing import Any, Dict, List, Optional
44

5-
import torch
6-
from torch import nn
5+
try:
6+
import torch
7+
from torch import nn
8+
except ImportError:
9+
torch = None # type: ignore[assignment]
10+
nn = None # type: ignore[assignment]
711

812
from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device
913

@@ -14,9 +18,16 @@ class DummyAdapter(BaseVLMAdapter):
1418
- Ignores images/messages content.
1519
- Uses a tiny linear model and returns a simple MSE loss.
1620
- generate() returns a fixed string.
21+
22+
Requires torch to be installed (pip install openadapt-ml[training]).
1723
"""
1824

19-
def __init__(self, device: Optional[torch.device] = None) -> None:
25+
def __init__(self, device: Optional[Any] = None) -> None:
26+
if torch is None:
27+
raise ImportError(
28+
"torch is required for DummyAdapter. "
29+
"Install with: pip install openadapt-ml[training]"
30+
)
2031
if device is None:
2132
device = get_default_device()
2233
# Tiny dummy model with a few parameters
@@ -32,7 +43,7 @@ def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # type
3243
y = torch.zeros(batch_size, 1, device=self.device)
3344
return {"inputs": x, "targets": y}
3445

35-
def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
46+
def compute_loss(self, inputs: Dict[str, Any]) -> Any: # type: ignore[override]
3647
x = inputs["inputs"]
3748
y = inputs["targets"]
3849
preds = self.model(x)

openadapt_ml/models/qwen_vl.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,28 @@
22

33
from typing import Any, Dict, List, Optional
44

5-
import torch
6-
from peft import LoraConfig, PeftModel, get_peft_model
7-
from transformers import (
8-
AutoProcessor,
9-
Qwen3VLForConditionalGeneration,
10-
Qwen2_5_VLForConditionalGeneration,
11-
)
5+
try:
6+
import torch
7+
except ImportError:
8+
torch = None # type: ignore[assignment]
9+
10+
try:
11+
from peft import LoraConfig, PeftModel, get_peft_model
12+
except ImportError:
13+
LoraConfig = None # type: ignore[assignment,misc]
14+
PeftModel = None # type: ignore[assignment,misc]
15+
get_peft_model = None # type: ignore[assignment]
16+
17+
try:
18+
from transformers import (
19+
AutoProcessor,
20+
Qwen3VLForConditionalGeneration,
21+
Qwen2_5_VLForConditionalGeneration,
22+
)
23+
except ImportError:
24+
AutoProcessor = None # type: ignore[assignment,misc]
25+
Qwen3VLForConditionalGeneration = None # type: ignore[assignment,misc]
26+
Qwen2_5_VLForConditionalGeneration = None # type: ignore[assignment,misc]
1227

1328
from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device
1429

@@ -67,9 +82,9 @@ class QwenVLAdapter(BaseVLMAdapter):
6782

6883
def __init__(
6984
self,
70-
model: torch.nn.Module,
85+
model: Any,
7186
processor: Any,
72-
device: Optional[torch.device] = None,
87+
device: Optional[Any] = None,
7388
version: str = "qwen3",
7489
) -> None:
7590
super().__init__(model=model, processor=processor, device=device)
@@ -79,9 +94,9 @@ def __init__(
7994
def from_pretrained(
8095
cls,
8196
model_name: str,
82-
lora_config: Optional[LoraConfig | Dict[str, Any]] = None,
97+
lora_config: Optional[Any] = None,
8398
load_in_4bit: bool = False,
84-
device: Optional[torch.device] = None,
99+
device: Optional[Any] = None,
85100
max_pixels: Optional[int] = None,
86101
min_pixels: Optional[int] = None,
87102
) -> "QwenVLAdapter":
@@ -91,7 +106,22 @@ def from_pretrained(
91106
max_pixels: Maximum image size in pixels (e.g., 512*512=262144 for faster training).
92107
If None, uses model default (very large).
93108
min_pixels: Minimum image size in pixels. If None, uses model default.
109+
110+
Requires torch, transformers, and peft to be installed
111+
(pip install openadapt-ml[training]).
94112
"""
113+
_missing = []
114+
if torch is None:
115+
_missing.append("torch")
116+
if AutoProcessor is None:
117+
_missing.append("transformers")
118+
if LoraConfig is None:
119+
_missing.append("peft")
120+
if _missing:
121+
raise ImportError(
122+
f"{', '.join(_missing)} required for QwenVLAdapter. "
123+
"Install with: pip install openadapt-ml[training]"
124+
)
95125

96126
if "Qwen3-VL" in model_name or "Qwen3VL" in model_name:
97127
version = "qwen3"
@@ -368,9 +398,11 @@ def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # type
368398
inputs["labels"] = labels
369399
return inputs
370400

371-
def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
401+
def compute_loss(self, inputs: Dict[str, Any]) -> Any: # type: ignore[override]
372402
inputs = {
373-
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
403+
k: v.to(self.device)
404+
if torch is not None and isinstance(v, torch.Tensor)
405+
else v
374406
for k, v in inputs.items()
375407
}
376408
outputs = self.model(**inputs)

openadapt_ml/training/grpo/__init__.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
from __future__ import annotations
4444

45+
# Lightweight imports (no torch required)
4546
from openadapt_ml.training.grpo.config import GRPOConfig
4647
from openadapt_ml.training.grpo.reward import (
4748
binary_task_success,
@@ -51,13 +52,6 @@
5152
GRPORolloutCollector,
5253
Rollout,
5354
)
54-
from openadapt_ml.training.grpo.trainer import (
55-
GRPOTrainer,
56-
policy_gradient_loss,
57-
grpo_loss,
58-
parse_vlm_output_to_action,
59-
format_action_as_text,
60-
)
6155
from openadapt_ml.training.grpo.cot_warmup import (
6256
build_cot_sft_samples,
6357
generate_cot_annotations,
@@ -67,6 +61,24 @@
6761
train_with_verl,
6862
)
6963

64+
# Lazy imports for torch-dependent modules
65+
_TRAINER_NAMES = {
66+
"GRPOTrainer",
67+
"policy_gradient_loss",
68+
"grpo_loss",
69+
"parse_vlm_output_to_action",
70+
"format_action_as_text",
71+
}
72+
73+
74+
def __getattr__(name: str):
75+
if name in _TRAINER_NAMES:
76+
from openadapt_ml.training.grpo import trainer as _trainer
77+
78+
return getattr(_trainer, name)
79+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
80+
81+
7082
__all__ = [
7183
"GRPOConfig",
7284
"GRPOTrainer",

openadapt_ml/training/grpo/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
from pathlib import Path
3737
from typing import Any, Callable
3838

39-
import torch
39+
try:
40+
import torch
41+
except ImportError:
42+
torch = None # type: ignore[assignment]
43+
4044
from PIL import Image
4145

4246
from openadapt_ml.datasets.next_action import SYSTEM_PROMPT

pyproject.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,16 @@ classifiers = [
2323

2424
dependencies = [
2525
"anthropic>=0.75.0",
26-
"bitsandbytes>=0.41.0", # For 4-bit quantization
2726
"click>=8.1.0", # CLI framework
2827
"google-generativeai>=0.8.5",
2928
"matplotlib>=3.10.7",
3029
"modal>=1.3.4",
3130
"openadapt-capture>=0.3.0",
32-
"peft>=0.18.0",
3331
"pillow>=12.0.0",
3432
"pyautogui>=0.9.54",
3533
"pydantic-settings>=2.0.0",
3634
"pytest>=9.0.2",
3735
"pyyaml>=6.0.3",
38-
"torch>=2.8.0",
39-
"torchvision>=0.24.1",
40-
"transformers>=4.57.3",
4136
]
4237

4338
[project.optional-dependencies]
@@ -57,8 +52,13 @@ lambda-labs = [
5752
parquet = [
5853
"pyarrow>=14.0.0",
5954
]
60-
# TRL + Unsloth for optimized training (2x faster, 50% less VRAM)
55+
# Heavy ML dependencies for local model training and inference
6156
training = [
57+
"bitsandbytes>=0.41.0", # For 4-bit quantization
58+
"peft>=0.18.0",
59+
"torch>=2.8.0",
60+
"torchvision>=0.24.1",
61+
"transformers>=4.57.3",
6262
"trl>=0.12.0",
6363
"datasets>=2.18.0",
6464
]

0 commit comments

Comments (0)