Skip to content

Commit 57e5c5f

Browse files
abrichrclaude
andcommitted
feat(modal): add inference serving with call_inference API
- Add _build_inference_app() for Modal GPU inference with PEFT adapter - Add upload_adapter_to_volume() for uploading adapters to Modal volume - Add call_inference() as the primary API for remote inference - Add 'serve' CLI command for interactive model serving - Container caches model in memory across calls (container_idle_timeout=600) - Support --no-adapter for zero-shot base model serving Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8185d25 commit 57e5c5f

2 files changed

Lines changed: 456 additions & 1 deletion

File tree

openadapt_ml/cloud/modal_cloud.py

Lines changed: 344 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Modal cloud GPU integration for training.
1+
"""Modal cloud GPU integration for training and inference.
22
33
Modal is a Python-native serverless cloud platform:
44
- No SSH, no instances to manage
@@ -25,6 +25,11 @@
2525
# Download results
2626
python -m openadapt_ml.cloud.modal_cloud download --output ./results
2727
28+
# Serve fine-tuned model for inference
29+
python -m openadapt_ml.cloud.modal_cloud serve \
30+
--adapter /path/to/adapter \
31+
--base-model Qwen/Qwen3-VL-2B-Instruct
32+
2833
# List volumes
2934
python -m openadapt_ml.cloud.modal_cloud list-volumes
3035
"""
@@ -255,6 +260,224 @@ def train_model(
255260
return train_model
256261

257262

263+
# ---------------------------------------------------------------------------
264+
# Inference serving
265+
# ---------------------------------------------------------------------------
266+
267+
INFERENCE_APP_NAME = "openadapt-inference"
268+
269+
270+
def _build_inference_app(
271+
adapter_path: str | None = None,
272+
base_model: str = "Qwen/Qwen3-VL-2B-Instruct",
273+
gpu: str = "A10G",
274+
):
275+
"""Build Modal app for model inference.
276+
277+
Args:
278+
adapter_path: Path to PEFT adapter in the volume (e.g., /training/results/final).
279+
base_model: HuggingFace model ID for the base model.
280+
gpu: GPU type.
281+
282+
Returns:
283+
(app, infer_fn) - the app and the inference function handle.
284+
"""
285+
modal = _get_modal()
286+
287+
app = modal.App(INFERENCE_APP_NAME)
288+
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
289+
290+
inference_image = modal.Image.debian_slim(python_version="3.12").pip_install(
291+
"torch",
292+
"transformers",
293+
"peft",
294+
"accelerate",
295+
"pillow",
296+
"qwen-vl-utils",
297+
)
298+
299+
vol = volume
300+
_adapter = adapter_path
301+
_base = base_model
302+
303+
@app.function(
304+
gpu=gpu,
305+
image=inference_image,
306+
volumes={VOLUME_MOUNT: vol},
307+
timeout=300,
308+
serialized=True,
309+
container_idle_timeout=600,
310+
)
311+
def infer(
312+
messages_json: str,
313+
image_base64: str | None = None,
314+
max_new_tokens: int = 512,
315+
) -> str:
316+
"""Run inference on the fine-tuned model.
317+
318+
Args:
319+
messages_json: JSON-encoded list of messages (OpenAI chat format).
320+
image_base64: Base64-encoded screenshot image (optional).
321+
max_new_tokens: Maximum tokens to generate.
322+
323+
Returns:
324+
JSON string with 'response' key containing model output.
325+
"""
326+
import base64 as _base64
327+
import json as _json
328+
from io import BytesIO as _BytesIO
329+
330+
import torch
331+
from PIL import Image as _Image
332+
from transformers import AutoModelForVision2Seq, AutoProcessor
333+
334+
# Load model (cached in container memory across calls)
335+
if not hasattr(infer, "_model"):
336+
print(f"Loading base model: {_base}")
337+
infer._model = AutoModelForVision2Seq.from_pretrained(
338+
_base,
339+
torch_dtype=torch.bfloat16,
340+
device_map="auto",
341+
)
342+
343+
if _adapter:
344+
from peft import PeftModel
345+
346+
print(f"Loading PEFT adapter: {_adapter}")
347+
vol.reload()
348+
infer._model = PeftModel.from_pretrained(infer._model, _adapter)
349+
350+
infer._processor = AutoProcessor.from_pretrained(_base)
351+
print("Model ready for inference")
352+
353+
messages = _json.loads(messages_json)
354+
355+
# If image_base64 is provided, decode it
356+
image = None
357+
if image_base64:
358+
img_bytes = _base64.b64decode(image_base64)
359+
image = _Image.open(_BytesIO(img_bytes)).convert("RGB")
360+
361+
# Build inputs using the processor's chat template
362+
text = infer._processor.apply_chat_template(
363+
messages, tokenize=False, add_generation_prompt=True
364+
)
365+
366+
if image is not None:
367+
inputs = infer._processor(
368+
text=[text], images=[image], return_tensors="pt", padding=True
369+
)
370+
else:
371+
inputs = infer._processor(
372+
text=[text], return_tensors="pt", padding=True
373+
)
374+
375+
inputs = inputs.to(infer._model.device)
376+
377+
with torch.no_grad():
378+
output_ids = infer._model.generate(
379+
**inputs,
380+
max_new_tokens=max_new_tokens,
381+
do_sample=False,
382+
)
383+
384+
# Decode only the generated tokens (skip the input)
385+
generated_ids = output_ids[:, inputs["input_ids"].shape[1] :]
386+
response_text = infer._processor.batch_decode(
387+
generated_ids, skip_special_tokens=True
388+
)[0]
389+
390+
return _json.dumps({"response": response_text.strip()})
391+
392+
return app, infer
393+
394+
395+
def upload_adapter_to_volume(adapter_dir: str | Path) -> str:
396+
"""Upload a local PEFT adapter to the Modal volume.
397+
398+
Args:
399+
adapter_dir: Path to local adapter directory.
400+
401+
Returns:
402+
Remote path to the adapter in the volume.
403+
"""
404+
adapter_dir = Path(adapter_dir)
405+
if not adapter_dir.exists():
406+
raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
407+
if not (adapter_dir / "adapter_config.json").exists():
408+
raise FileNotFoundError(f"No adapter_config.json in: {adapter_dir}")
409+
410+
remote_path = "/adapter"
411+
412+
# Create volume if needed
413+
create_cmd = ["modal", "volume", "create", VOLUME_NAME]
414+
subprocess.run(create_cmd, capture_output=True, text=True)
415+
416+
cmd = [
417+
"modal",
418+
"volume",
419+
"put",
420+
VOLUME_NAME,
421+
str(adapter_dir),
422+
remote_path,
423+
"--force",
424+
]
425+
result = subprocess.run(cmd, capture_output=True, text=True)
426+
if result.returncode != 0:
427+
raise RuntimeError(f"Adapter upload failed: {result.stderr or result.stdout}")
428+
429+
full_remote = f"{VOLUME_MOUNT}{remote_path}"
430+
print(f"Adapter uploaded to volume at: {full_remote}")
431+
return full_remote
432+
433+
434+
def call_inference(
435+
messages: list[dict],
436+
image_base64: str | None = None,
437+
max_new_tokens: int = 512,
438+
adapter_path: str | None = None,
439+
base_model: str = "Qwen/Qwen3-VL-2B-Instruct",
440+
gpu: str = "A10G",
441+
) -> str:
442+
"""Call the Modal inference function remotely.
443+
444+
This is the primary API for external callers (e.g., Qwen3VLAgent).
445+
Builds and runs the Modal app, sends a single inference request,
446+
and returns the model output.
447+
448+
Args:
449+
messages: Chat messages in OpenAI format.
450+
image_base64: Base64-encoded image string.
451+
max_new_tokens: Maximum tokens to generate.
452+
adapter_path: Remote adapter path in the volume.
453+
base_model: HuggingFace model ID for the base model.
454+
gpu: GPU type.
455+
456+
Returns:
457+
Model response text.
458+
"""
459+
modal = _get_modal()
460+
modal.enable_output()
461+
462+
app, infer_fn = _build_inference_app(
463+
adapter_path=adapter_path,
464+
base_model=base_model,
465+
gpu=gpu,
466+
)
467+
468+
messages_json = json.dumps(messages)
469+
470+
with app.run():
471+
result_json = infer_fn.remote(
472+
messages_json=messages_json,
473+
image_base64=image_base64,
474+
max_new_tokens=max_new_tokens,
475+
)
476+
477+
result = json.loads(result_json)
478+
return result.get("response", "")
479+
480+
258481
# ---------------------------------------------------------------------------
259482
# Local helpers for CLI commands
260483
# ---------------------------------------------------------------------------
@@ -462,6 +685,34 @@ def cli_main(argv: list[str] | None = None) -> int:
462685
help="Local output directory (default: training_output/modal)",
463686
)
464687

688+
# --- serve ---
689+
serve_parser = subparsers.add_parser(
690+
"serve", help="Serve fine-tuned model for inference on Modal GPU"
691+
)
692+
serve_parser.add_argument(
693+
"--adapter",
694+
help="Local adapter directory to upload and serve",
695+
)
696+
serve_parser.add_argument(
697+
"--adapter-remote",
698+
help="Remote adapter path already in the volume (e.g., /training/results/final)",
699+
)
700+
serve_parser.add_argument(
701+
"--base-model",
702+
default="Qwen/Qwen3-VL-2B-Instruct",
703+
help="Base model HuggingFace ID (default: Qwen/Qwen3-VL-2B-Instruct)",
704+
)
705+
serve_parser.add_argument(
706+
"--gpu",
707+
default="A10G",
708+
help="GPU type (default: A10G)",
709+
)
710+
serve_parser.add_argument(
711+
"--no-adapter",
712+
action="store_true",
713+
help="Serve base model without adapter (zero-shot)",
714+
)
715+
465716
# --- list-volumes ---
466717
subparsers.add_parser("list-volumes", help="List Modal volumes")
467718

@@ -477,6 +728,8 @@ def cli_main(argv: list[str] | None = None) -> int:
477728
return _cmd_status(args)
478729
elif args.command == "download":
479730
return _cmd_download(args)
731+
elif args.command == "serve":
732+
return _cmd_serve(args)
480733
elif args.command == "list-volumes":
481734
return _cmd_list_volumes(args)
482735
else:
@@ -626,6 +879,96 @@ def _cmd_download(args: argparse.Namespace) -> int:
626879
return 1
627880

628881

882+
def _cmd_serve(args: argparse.Namespace) -> int:
883+
"""Serve a fine-tuned model on Modal GPU for inference.
884+
885+
Uploads the adapter (if local path provided), then starts the
886+
inference function that clients can call via Modal's .remote() API.
887+
Alternatively, clients can use the HTTP wrapper in Qwen3VLAgent.
888+
"""
889+
modal = _get_modal()
890+
891+
adapter_remote = None
892+
893+
if args.no_adapter:
894+
print(f"Serving base model: {args.base_model} (no adapter)")
895+
elif args.adapter:
896+
# Upload local adapter to volume
897+
print("Uploading adapter to Modal volume...")
898+
try:
899+
adapter_remote = upload_adapter_to_volume(args.adapter)
900+
except (FileNotFoundError, RuntimeError) as e:
901+
print(f"Error: {e}")
902+
return 1
903+
elif args.adapter_remote:
904+
adapter_remote = args.adapter_remote
905+
print(f"Using remote adapter: {adapter_remote}")
906+
else:
907+
# Default: use the latest training results
908+
adapter_remote = f"{RESULTS_REMOTE_PATH}/final"
909+
print(f"Using default adapter: {adapter_remote}")
910+
911+
print(f"Base model: {args.base_model}")
912+
print(f"GPU: {args.gpu}")
913+
print()
914+
915+
try:
916+
modal.enable_output()
917+
918+
app, infer_fn = _build_inference_app(
919+
adapter_path=adapter_remote,
920+
base_model=args.base_model,
921+
gpu=args.gpu,
922+
)
923+
924+
print("Starting inference server on Modal...")
925+
print("Press Ctrl+C to stop.\n")
926+
927+
with app.run():
928+
# Test with a simple warmup call
929+
test_messages = json.dumps(
930+
[
931+
{
932+
"role": "system",
933+
"content": "You are a GUI automation agent.",
934+
},
935+
{
936+
"role": "user",
937+
"content": "Respond with: ready",
938+
},
939+
]
940+
)
941+
result = infer_fn.remote(messages_json=test_messages)
942+
result_data = json.loads(result)
943+
print(f"Model ready. Test response: {result_data.get('response', '')}")
944+
print()
945+
print("=" * 50)
946+
print("INFERENCE SERVER RUNNING")
947+
print("=" * 50)
948+
print()
949+
print(
950+
"To run inference from another process, use:\n"
951+
" from openadapt_ml.cloud.modal_cloud import call_inference\n"
952+
" result = call_inference(messages, image_base64)\n"
953+
)
954+
print("Or use Qwen3VLAgent with --model-endpoint modal\n")
955+
956+
# Keep the app running until Ctrl+C
957+
import time as _time
958+
959+
try:
960+
while True:
961+
_time.sleep(1)
962+
except KeyboardInterrupt:
963+
print("\nShutting down inference server...")
964+
965+
except Exception as e:
966+
print(f"Serve failed: {e}")
967+
return 1
968+
969+
return 0
970+
971+
629972
def _cmd_list_volumes(args: argparse.Namespace) -> int:
630973
"""List Modal volumes."""
631974
list_volumes()

0 commit comments

Comments
 (0)