diff --git a/docs/grpo_trl_rewrite_draft.py b/docs/grpo_trl_rewrite_draft.py
index 0b2f69c..06c144a 100644
--- a/docs/grpo_trl_rewrite_draft.py
+++ b/docs/grpo_trl_rewrite_draft.py
@@ -952,7 +952,7 @@ def train_grpo(waa_config: WAATrainingConfig | None = None) -> str:
 # to verify this works with Unsloth's patched model loading.
 #
 # If incompatible, we can:
-# (a) Use standard HF model loading (AutoModelForVision2Seq)
+# (a) Use standard HF model loading (AutoModelForImageTextToText)
 # (b) Load with Unsloth, then pass the model to TRL
 # (c) Use Unsloth's GRPOTrainer fork (if available)
 #
diff --git a/openadapt_ml/cloud/modal_cloud.py b/openadapt_ml/cloud/modal_cloud.py
index 3ce8bc1..7fedee4 100644
--- a/openadapt_ml/cloud/modal_cloud.py
+++ b/openadapt_ml/cloud/modal_cloud.py
@@ -336,9 +336,12 @@ def infer(
     if not hasattr(infer, "_model"):
         print(f"Loading base model: {_base}")
         try:
-            from transformers import AutoModelForVision2Seq
+            try:
+                from transformers import AutoModelForImageTextToText as AutoVLM
+            except ImportError:
+                from transformers import AutoModelForVision2Seq as AutoVLM
 
-            infer._model = AutoModelForVision2Seq.from_pretrained(
+            infer._model = AutoVLM.from_pretrained(
                 _base,
                 torch_dtype=torch.bfloat16,
                 device_map="auto",
diff --git a/openadapt_ml/training/grpo/trainer.py b/openadapt_ml/training/grpo/trainer.py
index ec5966f..cb0290a 100644
--- a/openadapt_ml/training/grpo/trainer.py
+++ b/openadapt_ml/training/grpo/trainer.py
@@ -23,7 +23,7 @@
 - beta=0.0 (no KL penalty) per DAPO/Open-Reasoner-Zero. Simpler, saves
   memory (no reference model needed).
 - Per-step backward to avoid OOM on long trajectories.
-- Standard HF model loading: AutoModelForVision2Seq + AutoProcessor + PEFT.
+- Standard HF model loading: AutoModelForImageTextToText + AutoProcessor + PEFT.
 - Standard PEFT checkpointing: model.save_pretrained().
 """
 
@@ -222,7 +222,12 @@ def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]:
         (model, processor) tuple. Model has LoRA adapters attached.
     """
     from peft import LoraConfig, PeftModel, get_peft_model
-    from transformers import AutoModelForVision2Seq, AutoProcessor
+    from transformers import AutoProcessor
+
+    try:
+        from transformers import AutoModelForImageTextToText as AutoVLM
+    except ImportError:
+        from transformers import AutoModelForVision2Seq as AutoVLM
 
     processor = AutoProcessor.from_pretrained(config.model_name)
 
@@ -239,7 +244,7 @@ def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]:
         bnb_4bit_quant_type="nf4",
     )
 
-    model = AutoModelForVision2Seq.from_pretrained(config.model_name, **load_kwargs)
+    model = AutoVLM.from_pretrained(config.model_name, **load_kwargs)
 
     if config.lora_checkpoint:
         logger.info("Loading existing LoRA from %s", config.lora_checkpoint)
@@ -322,7 +327,7 @@ def agent_fn(obs: Any) -> BenchmarkAction:
         else:
             text_input = messages[-1]["content"]
 
-        inputs = processor(text_input, images=[image], return_tensors="pt")
+        inputs = processor(text=[text_input], images=[image], return_tensors="pt")
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
         with torch.no_grad():
@@ -519,7 +524,7 @@ def _compute_rollout_loss(
         text_input = messages[-1]["content"]
 
         prompt_inputs = self._processor(
-            text_input, images=[image], return_tensors="pt"
+            text=[text_input], images=[image], return_tensors="pt"
         )
         prompt_ids = prompt_inputs["input_ids"]
         prompt_len = prompt_ids.shape[1]
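
For reference, a minimal standalone sketch (outside the patch itself) of the version-tolerant loading pattern the diff introduces. It assumes a recent transformers release that exposes AutoModelForImageTextToText; older releases fall back to the deprecated AutoModelForVision2Seq. The model id below is a placeholder for illustration, not the project's actual base model.

import torch
from transformers import AutoProcessor

try:
    # Newer transformers releases expose the renamed auto class.
    from transformers import AutoModelForImageTextToText as AutoVLM
except ImportError:
    # Older releases only provide the deprecated name.
    from transformers import AutoModelForVision2Seq as AutoVLM

MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"  # placeholder model id

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoVLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# VLM processors take text/images as keyword arguments; passing batched lists
# mirrors the processor(text=[...], images=[...]) calls in the patch.
# `prompt` (a chat-templated string) and `image` (a PIL image) are assumed here.
# inputs = processor(text=[prompt], images=[image], return_tensors="pt")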