-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: Add recipe validation integ test for HP-ModelCustomization-RecipeValidator pipeline #5779
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Recipe validation integ test for the HP-ModelCustomization-RecipeValidator pipeline. | ||
|
|
||
| Iterates through every model in the private hub referenced by the ``SAGEMAKER_HUB_NAME`` | ||
| env var and validates that each fine-tuning recipe can be used to instantiate the | ||
| appropriate ``sagemaker.train`` Trainer class (SFT/DPO/RLAIF/RLVR). | ||
| """ | ||
| from __future__ import absolute_import | ||
|
|
||
| import json | ||
| import os | ||
|
|
||
| import boto3 | ||
|
|
||
| from sagemaker.train.common import TrainingType | ||
| from sagemaker.train.dpo_trainer import DPOTrainer | ||
| from sagemaker.train.rlaif_trainer import RLAIFTrainer | ||
| from sagemaker.train.rlvr_trainer import RLVRTrainer | ||
| from sagemaker.train.sft_trainer import SFTTrainer | ||
|
|
||
| TRAINER_MAPPING = { | ||
| "sft": SFTTrainer, | ||
| "dpo": DPOTrainer, | ||
| "rlaif": RLAIFTrainer, | ||
| "rlvr": RLVRTrainer, | ||
| } | ||
|
|
||
| DUMMY_DATASET = "s3://placeholder/validation-data" | ||
| DUMMY_MODEL_PACKAGE_GROUP = "recipe-validation-test" | ||
|
|
||
|
|
||
| def detect_training_type(recipe_path: str) -> str: | ||
| """Detect SFT/DPO/RLAIF/RLVR from the recipe name; default to SFT.""" | ||
| if not recipe_path: | ||
| return "sft" | ||
| lower = recipe_path.lower() | ||
| if "rlvr" in lower: | ||
| return "rlvr" | ||
| if "rlaif" in lower: | ||
| return "rlaif" | ||
| if "dpo" in lower: | ||
| return "dpo" | ||
| return "sft" | ||
|
|
||
|
|
||
| def detect_lora_or_full(recipe_path: str) -> TrainingType: | ||
| """Detect LoRA vs full fine-tuning from the recipe name; default to LoRA.""" | ||
| if not recipe_path: | ||
| return TrainingType.LORA | ||
| lower = recipe_path.lower() | ||
| if "_fft" in lower or "full_fine_tuning" in lower: | ||
| return TrainingType.FULL | ||
| return TrainingType.LORA | ||
|
|
||
|
|
||
| def test_new_recipes_create_valid_trainers(): | ||
| """Validate every new/modified recipe in the private hub yields a valid Trainer.""" | ||
| hub_name = os.environ.get("SAGEMAKER_HUB_NAME") | ||
| assert hub_name, "SAGEMAKER_HUB_NAME environment variable must be set" | ||
|
|
||
| sm = boto3.client("sagemaker", region_name="us-west-2") | ||
|
|
||
| models = [] | ||
| kwargs = {"HubName": hub_name, "HubContentType": "Model"} | ||
| while True: | ||
| response = sm.list_hub_contents(**kwargs) | ||
| models.extend([item["HubContentName"] for item in response["HubContentSummaries"]]) | ||
| next_token = response.get("NextToken") | ||
| if not next_token: | ||
| break | ||
| kwargs["NextToken"] = next_token | ||
|
|
||
| if not models: | ||
| return | ||
|
|
||
| errors = [] | ||
| for model_name in models: | ||
| try: | ||
| response = sm.describe_hub_content( | ||
| HubName=hub_name, | ||
| HubContentType="Model", | ||
| HubContentName=model_name, | ||
| ) | ||
| doc = json.loads(response.get("HubContentDocument", "{}")) | ||
| recipes = doc.get("RecipeCollection", []) | ||
|
|
||
| ft_recipes = [r for r in recipes if r.get("Type") == "FineTuning"] | ||
| for i, recipe in enumerate(ft_recipes): | ||
| recipe_name = recipe.get("Name", f"recipe-{i}") | ||
| training_type = detect_training_type(recipe_name) | ||
| training_type_enum = detect_lora_or_full(recipe_name) | ||
| trainer_class = TRAINER_MAPPING[training_type] | ||
|
|
||
| trainer = trainer_class( | ||
| model=model_name, | ||
| training_type=training_type_enum, | ||
| training_dataset=DUMMY_DATASET, | ||
| model_package_group=DUMMY_MODEL_PACKAGE_GROUP, | ||
| accept_eula=True, | ||
| ) | ||
| assert trainer is not None, ( | ||
| f"{model_name}: {trainer_class.__name__} returned None" | ||
| ) | ||
| except Exception as e: # noqa: BLE001 - collect all errors across all models | ||
| errors.append(f"{model_name}: {e}") | ||
|
|
||
| assert not errors, "Recipe validation failures:\n" + "\n".join(errors) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it sufficient to check here that we can instantiate a trainer class? Could we also submit a test job and verify that interaction with smjobs/k8s will work?
We can potentially use a small/dummy dataset so that the job doesn't run for long but still verify that the end to customer interaction via PySDK will work for new recipes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instantiation-only is the right scope for this validation step — it catches the most likely breakages: schema mismatches, missing fields, and unsupported training types in the hub content fetch → recipe parsing → Trainer construction path.
Running real training jobs would require significant infrastructure changes to the validation account — GPU instance quotas, CreateTrainingJob permissions, per-technique dummy datasets, and cleanup logic, none of which exist today. We do already have e2e integ tests in the PySDK repo that submit real training jobs for a subset of recipes, so the full job path is partially covered. If we want broader e2e coverage for all new recipes, I'd suggest scoping that as a follow-up with its own infrastructure workstream.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we do want to be able to test that the job is able to start/run to verify the customer workflow before launch. Could you please add a Note here as a follow up task?