Skip to content

Commit e809336

Browse files
committed
feat: recipe validation integ test
1 parent 4c184d4 commit e809336

2 files changed

Lines changed: 119 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Recipe validation integ test for the HP-ModelCustomization-RecipeValidator pipeline.
14+
15+
Iterates through every model in the private hub referenced by the ``HYPERPOD_HUB_NAME``
16+
env var and validates that each fine-tuning recipe can be used to instantiate the
17+
appropriate ``sagemaker.train`` Trainer class (SFT/DPO/RLAIF/RLVR).
18+
"""
19+
from __future__ import absolute_import
20+
21+
import json
22+
import os
23+
24+
import boto3
25+
26+
from sagemaker.train.common import TrainingType
27+
from sagemaker.train.dpo_trainer import DPOTrainer
28+
from sagemaker.train.rlaif_trainer import RLAIFTrainer
29+
from sagemaker.train.rlvr_trainer import RLVRTrainer
30+
from sagemaker.train.sft_trainer import SFTTrainer
31+
32+
TRAINER_MAPPING = {
33+
"sft": SFTTrainer,
34+
"dpo": DPOTrainer,
35+
"rlaif": RLAIFTrainer,
36+
"rlvr": RLVRTrainer,
37+
}
38+
39+
DUMMY_DATASET = "s3://placeholder/validation-data"
40+
DUMMY_MODEL_PACKAGE_GROUP = "recipe-validation-test"
41+
42+
43+
def detect_training_type(recipe_path: str) -> str:
44+
"""Detect SFT/DPO/RLAIF/RLVR from the recipe name; default to SFT."""
45+
if not recipe_path:
46+
return "sft"
47+
lower = recipe_path.lower()
48+
if "rlvr" in lower:
49+
return "rlvr"
50+
if "rlaif" in lower:
51+
return "rlaif"
52+
if "dpo" in lower:
53+
return "dpo"
54+
return "sft"
55+
56+
57+
def detect_lora_or_full(recipe_path: str) -> TrainingType:
58+
"""Detect LoRA vs full fine-tuning from the recipe name; default to LoRA."""
59+
if not recipe_path:
60+
return TrainingType.LORA
61+
lower = recipe_path.lower()
62+
if "_fft" in lower or "full_fine_tuning" in lower:
63+
return TrainingType.FULL
64+
return TrainingType.LORA
65+
66+
67+
def test_new_recipes_create_valid_trainers():
68+
"""Validate every new/modified recipe in the private hub yields a valid Trainer."""
69+
hub_name = os.environ.get("HYPERPOD_HUB_NAME")
70+
assert hub_name, "HYPERPOD_HUB_NAME environment variable must be set"
71+
72+
sm = boto3.client("sagemaker", region_name="us-west-2")
73+
74+
models = []
75+
kwargs = {"HubName": hub_name, "HubContentType": "Model"}
76+
while True:
77+
response = sm.list_hub_contents(**kwargs)
78+
models.extend([item["HubContentName"] for item in response["HubContentSummaries"]])
79+
next_token = response.get("NextToken")
80+
if not next_token:
81+
break
82+
kwargs["NextToken"] = next_token
83+
84+
if not models:
85+
return
86+
87+
errors = []
88+
for model_name in models:
89+
try:
90+
response = sm.describe_hub_content(
91+
HubName=hub_name,
92+
HubContentType="Model",
93+
HubContentName=model_name,
94+
)
95+
doc = json.loads(response.get("HubContentDocument", "{}"))
96+
recipes = doc.get("RecipeCollection", [])
97+
98+
ft_recipes = [r for r in recipes if r.get("Type") == "FineTuning"]
99+
for i, recipe in enumerate(ft_recipes):
100+
recipe_name = recipe.get("Name", f"recipe-{i}")
101+
training_type = detect_training_type(recipe_name)
102+
training_type_enum = detect_lora_or_full(recipe_name)
103+
trainer_class = TRAINER_MAPPING[training_type]
104+
105+
trainer = trainer_class(
106+
model=model_name,
107+
training_type=training_type_enum,
108+
training_dataset=DUMMY_DATASET,
109+
model_package_group=DUMMY_MODEL_PACKAGE_GROUP,
110+
accept_eula=True,
111+
)
112+
assert trainer is not None, (
113+
f"{model_name}: {trainer_class.__name__} returned None"
114+
)
115+
except Exception as e: # noqa: BLE001 - collect all errors across all models
116+
errors.append(f"{model_name}: {e}")
117+
118+
assert not errors, "Recipe validation failures:\n" + "\n".join(errors)

0 commit comments

Comments
 (0)