Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/grpo_trl_rewrite_draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ def train_grpo(waa_config: WAATrainingConfig | None = None) -> str:
# to verify this works with Unsloth's patched model loading.
#
# If incompatible, we can:
# (a) Use standard HF model loading (AutoModelForVision2Seq)
# (a) Use standard HF model loading (AutoModelForImageTextToText)
# (b) Load with Unsloth, then pass the model to TRL
# (c) Use Unsloth's GRPOTrainer fork (if available)
#
Expand Down
7 changes: 5 additions & 2 deletions openadapt_ml/cloud/modal_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,12 @@ def infer(
if not hasattr(infer, "_model"):
print(f"Loading base model: {_base}")
try:
from transformers import AutoModelForVision2Seq
try:
from transformers import AutoModelForImageTextToText as AutoVLM
except ImportError:
from transformers import AutoModelForVision2Seq as AutoVLM

infer._model = AutoModelForVision2Seq.from_pretrained(
infer._model = AutoVLM.from_pretrained(
_base,
torch_dtype=torch.bfloat16,
device_map="auto",
Expand Down
15 changes: 10 additions & 5 deletions openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- beta=0.0 (no KL penalty) per DAPO/Open-Reasoner-Zero. Simpler, saves
memory (no reference model needed).
- Per-step backward to avoid OOM on long trajectories.
- Standard HF model loading: AutoModelForVision2Seq + AutoProcessor + PEFT.
- Standard HF model loading: AutoModelForImageTextToText + AutoProcessor + PEFT.
- Standard PEFT checkpointing: model.save_pretrained().
"""

Expand Down Expand Up @@ -222,7 +222,12 @@ def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]:
(model, processor) tuple. Model has LoRA adapters attached.
"""
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers import AutoProcessor

try:
from transformers import AutoModelForImageTextToText as AutoVLM
except ImportError:
from transformers import AutoModelForVision2Seq as AutoVLM

processor = AutoProcessor.from_pretrained(config.model_name)

Expand All @@ -239,7 +244,7 @@ def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]:
bnb_4bit_quant_type="nf4",
)

model = AutoModelForVision2Seq.from_pretrained(config.model_name, **load_kwargs)
model = AutoVLM.from_pretrained(config.model_name, **load_kwargs)

if config.lora_checkpoint:
logger.info("Loading existing LoRA from %s", config.lora_checkpoint)
Expand Down Expand Up @@ -322,7 +327,7 @@ def agent_fn(obs: Any) -> BenchmarkAction:
else:
text_input = messages[-1]["content"]

inputs = processor(text_input, images=[image], return_tensors="pt")
inputs = processor(text=[text_input], images=[image], return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
Expand Down Expand Up @@ -519,7 +524,7 @@ def _compute_rollout_loss(
text_input = messages[-1]["content"]

prompt_inputs = self._processor(
text_input, images=[image], return_tensors="pt"
text=[text_input], images=[image], return_tensors="pt"
)
prompt_ids = prompt_inputs["input_ids"]
prompt_len = prompt_ids.shape[1]
Expand Down
Loading