Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions openadapt_ml/datasets/next_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from dataclasses import dataclass
from typing import Any, Dict, List

from torch.utils.data import Dataset
try:
from torch.utils.data import Dataset as _TorchDataset
except ImportError:
_TorchDataset = None # type: ignore[assignment,misc]

from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement

Expand Down Expand Up @@ -523,10 +526,18 @@ class NextActionSample:
messages: List[Dict[str, str]]


class NextActionDataset(Dataset):
"""Thin PyTorch Dataset wrapper around pre-built SFT samples."""
class NextActionDataset(_TorchDataset if _TorchDataset is not None else object): # type: ignore[misc]
"""Thin PyTorch Dataset wrapper around pre-built SFT samples.

Requires torch to be installed (pip install openadapt-ml[training]).
"""

def __init__(self, samples: List[Dict[str, Any]]):
    """Store pre-built SFT ``samples`` for indexed Dataset access.

    Args:
        samples: Pre-built SFT sample dicts (built elsewhere; this wrapper
            does not inspect their schema).

    Raises:
        ImportError: If torch is not installed — the module-level
            ``_TorchDataset`` import failed, and this class is only usable
            as a ``torch.utils.data.Dataset``.
    """
    # Fail fast at construction time rather than deep inside a DataLoader.
    if _TorchDataset is None:
        raise ImportError(
            "torch is required for NextActionDataset. "
            "Install with: pip install openadapt-ml[training]"
        )
    self._samples = samples

def __len__(self) -> int: # type: ignore[override]
Expand Down
26 changes: 17 additions & 9 deletions openadapt_ml/models/api_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import base64
import os

import torch
try:
import torch
except ImportError:
torch = None # type: ignore[assignment]

from openadapt_ml.config import settings
from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device
Expand All @@ -21,7 +24,7 @@ class ApiVLMAdapter(BaseVLMAdapter):
def __init__(
self,
provider: str,
device: Optional[torch.device] = None,
device: Optional[Any] = None,
api_key: Optional[str] = None,
model_name: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -84,22 +87,27 @@ def __init__(
else:
raise ValueError(f"Unsupported provider: {provider}")

if device is None:
device = get_default_device()

# Store client separately; BaseVLMAdapter expects a model + processor, so
# we pass a tiny dummy module and the client as the "processor".
self._client = client
model = torch.nn.Identity()
processor: Any = client
super().__init__(model=model, processor=processor, device=device)
if torch is not None:
if device is None:
device = get_default_device()
model = torch.nn.Identity()
processor: Any = client
super().__init__(model=model, processor=processor, device=device)
else:
# Lightweight mode: skip torch-based init for API-only usage
self.model = None
self.processor = client
self.device = None

def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:  # type: ignore[override]
    """Unsupported: this adapter is inference-only (no training path)."""
    raise NotImplementedError("ApiVLMAdapter does not support training (prepare_inputs)")

def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
def compute_loss(self, inputs: Dict[str, Any]) -> Any:  # type: ignore[override]
    """Unsupported: this adapter is inference-only (no training path)."""
    raise NotImplementedError("ApiVLMAdapter does not support training (compute_loss)")
Expand Down
31 changes: 25 additions & 6 deletions openadapt_ml/models/base_adapter.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
try:
import torch
except ImportError:
torch = None # type: ignore[assignment]

if TYPE_CHECKING:
import torch

def get_default_device() -> torch.device:

def get_default_device() -> "torch.device":
"""Select cuda, then mps, then cpu.

This is used as a fallback when no explicit device is provided.
Requires torch to be installed.
"""
if torch is None:
raise ImportError(
"torch is required for model operations. "
"Install with: pip install openadapt-ml[training]"
)

if torch.cuda.is_available():
return torch.device("cuda")
Expand All @@ -29,14 +41,21 @@ class BaseVLMAdapter(ABC):
- converting SFT-style samples into model inputs (tokenization, image processing)
- computing supervised training loss
- generating assistant text given a single sample at inference time

Requires torch to be installed (pip install openadapt-ml[training]).
"""

def __init__(
self,
model: torch.nn.Module,
model: Any,
processor: Any,
device: Optional[torch.device] = None,
device: Optional[Any] = None,
) -> None:
if torch is None:
raise ImportError(
"torch is required for BaseVLMAdapter. "
"Install with: pip install openadapt-ml[training]"
)
self.model = model
self.processor = processor
self.device = device or get_default_device()
Expand All @@ -53,7 +72,7 @@ def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""

@abstractmethod
def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor:
def compute_loss(self, inputs: Dict[str, Any]) -> Any:
"""Run the model forward and return a scalar loss tensor."""

@abstractmethod
Expand Down
19 changes: 15 additions & 4 deletions openadapt_ml/models/dummy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

from typing import Any, Dict, List, Optional

import torch
from torch import nn
try:
import torch
from torch import nn
except ImportError:
torch = None # type: ignore[assignment]
nn = None # type: ignore[assignment]

from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device

Expand All @@ -14,9 +18,16 @@ class DummyAdapter(BaseVLMAdapter):
- Ignores images/messages content.
- Uses a tiny linear model and returns a simple MSE loss.
- generate() returns a fixed string.

Requires torch to be installed (pip install openadapt-ml[training]).
"""

def __init__(self, device: Optional[torch.device] = None) -> None:
def __init__(self, device: Optional[Any] = None) -> None:
if torch is None:
raise ImportError(
"torch is required for DummyAdapter. "
"Install with: pip install openadapt-ml[training]"
)
if device is None:
device = get_default_device()
# Tiny dummy model with a few parameters
Expand All @@ -32,7 +43,7 @@ def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # type
y = torch.zeros(batch_size, 1, device=self.device)
return {"inputs": x, "targets": y}

def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
def compute_loss(self, inputs: Dict[str, Any]) -> Any: # type: ignore[override]
x = inputs["inputs"]
y = inputs["targets"]
preds = self.model(x)
Expand Down
58 changes: 45 additions & 13 deletions openadapt_ml/models/qwen_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@

from typing import Any, Dict, List, Optional

import torch
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
AutoProcessor,
Qwen3VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
try:
import torch
except ImportError:
torch = None # type: ignore[assignment]

try:
from peft import LoraConfig, PeftModel, get_peft_model
except ImportError:
LoraConfig = None # type: ignore[assignment,misc]
PeftModel = None # type: ignore[assignment,misc]
get_peft_model = None # type: ignore[assignment]

try:
from transformers import (
AutoProcessor,
Qwen3VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
except ImportError:
AutoProcessor = None # type: ignore[assignment,misc]
Qwen3VLForConditionalGeneration = None # type: ignore[assignment,misc]
Qwen2_5_VLForConditionalGeneration = None # type: ignore[assignment,misc]

from openadapt_ml.models.base_adapter import BaseVLMAdapter, get_default_device

Expand Down Expand Up @@ -67,9 +82,9 @@ class QwenVLAdapter(BaseVLMAdapter):

def __init__(
self,
model: torch.nn.Module,
model: Any,
processor: Any,
device: Optional[torch.device] = None,
device: Optional[Any] = None,
version: str = "qwen3",
) -> None:
super().__init__(model=model, processor=processor, device=device)
Expand All @@ -79,9 +94,9 @@ def __init__(
def from_pretrained(
cls,
model_name: str,
lora_config: Optional[LoraConfig | Dict[str, Any]] = None,
lora_config: Optional[Any] = None,
load_in_4bit: bool = False,
device: Optional[torch.device] = None,
device: Optional[Any] = None,
max_pixels: Optional[int] = None,
min_pixels: Optional[int] = None,
) -> "QwenVLAdapter":
Expand All @@ -91,7 +106,22 @@ def from_pretrained(
max_pixels: Maximum image size in pixels (e.g., 512*512=262144 for faster training).
If None, uses model default (very large).
min_pixels: Minimum image size in pixels. If None, uses model default.

Requires torch, transformers, and peft to be installed
(pip install openadapt-ml[training]).
"""
_missing = []
if torch is None:
_missing.append("torch")
if AutoProcessor is None:
_missing.append("transformers")
if LoraConfig is None:
_missing.append("peft")
if _missing:
raise ImportError(
f"{', '.join(_missing)} required for QwenVLAdapter. "
"Install with: pip install openadapt-ml[training]"
)

if "Qwen3-VL" in model_name or "Qwen3VL" in model_name:
version = "qwen3"
Expand Down Expand Up @@ -368,9 +398,11 @@ def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # type
inputs["labels"] = labels
return inputs

def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
def compute_loss(self, inputs: Dict[str, Any]) -> Any: # type: ignore[override]
inputs = {
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
k: v.to(self.device)
if torch is not None and isinstance(v, torch.Tensor)
else v
for k, v in inputs.items()
}
outputs = self.model(**inputs)
Expand Down
26 changes: 19 additions & 7 deletions openadapt_ml/training/grpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from __future__ import annotations

# Lightweight imports (no torch required)
from openadapt_ml.training.grpo.config import GRPOConfig
from openadapt_ml.training.grpo.reward import (
binary_task_success,
Expand All @@ -51,13 +52,6 @@
GRPORolloutCollector,
Rollout,
)
from openadapt_ml.training.grpo.trainer import (
GRPOTrainer,
policy_gradient_loss,
grpo_loss,
parse_vlm_output_to_action,
format_action_as_text,
)
from openadapt_ml.training.grpo.cot_warmup import (
build_cot_sft_samples,
generate_cot_annotations,
Expand All @@ -67,6 +61,24 @@
train_with_verl,
)

# Lazy imports for torch-dependent modules
_TRAINER_NAMES = {
"GRPOTrainer",
"policy_gradient_loss",
"grpo_loss",
"parse_vlm_output_to_action",
"format_action_as_text",
}


def __getattr__(name: str):
if name in _TRAINER_NAMES:
from openadapt_ml.training.grpo import trainer as _trainer

return getattr(_trainer, name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = [
"GRPOConfig",
"GRPOTrainer",
Expand Down
6 changes: 5 additions & 1 deletion openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
from pathlib import Path
from typing import Any, Callable

import torch
try:
import torch
except ImportError:
torch = None # type: ignore[assignment]

from PIL import Image

from openadapt_ml.datasets.next_action import SYSTEM_PROMPT
Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,16 @@ classifiers = [

dependencies = [
"anthropic>=0.75.0",
"bitsandbytes>=0.41.0", # For 4-bit quantization
"click>=8.1.0", # CLI framework
"google-generativeai>=0.8.5",
"matplotlib>=3.10.7",
"modal>=1.3.4",
"openadapt-capture>=0.3.0",
"peft>=0.18.0",
"pillow>=12.0.0",
"pyautogui>=0.9.54",
"pydantic-settings>=2.0.0",
"pytest>=9.0.2",
"pyyaml>=6.0.3",
"torch>=2.8.0",
"torchvision>=0.24.1",
"transformers>=4.57.3",
]

[project.optional-dependencies]
Expand All @@ -57,8 +52,13 @@ lambda-labs = [
parquet = [
"pyarrow>=14.0.0",
]
# TRL + Unsloth for optimized training (2x faster, 50% less VRAM)
# Heavy ML dependencies for local model training and inference
training = [
"bitsandbytes>=0.41.0", # For 4-bit quantization
"peft>=0.18.0",
"torch>=2.8.0",
"torchvision>=0.24.1",
"transformers>=4.57.3",
"trl>=0.12.0",
"datasets>=2.18.0",
]
Expand Down
Loading