1414from __future__ import absolute_import
1515
1616import time
17+ import random
1718import boto3
1819from sagemaker .core .helper .session_helper import Session
1920from sagemaker .train .rlaif_trainer import RLAIFTrainer
2021from sagemaker .train .common import TrainingType
2122import pytest
2223
24+ pytestmark = pytest .mark .gpu_intensive
25+
2326
24- @pytest .mark .skip (reason = "Skipping GPU resource intensive test" )
2527def test_rlaif_trainer_lora_complete_workflow (sagemaker_session ):
2628 """Test complete RLAIF training workflow with LORA."""
29+ unique_id = f"{ int (time .time ())} -{ random .randint (1000 , 9999 )} "
2730
2831 rlaif_trainer = RLAIFTrainer (
2932 model = "meta-textgeneration-llama-3-2-1b-instruct" ,
@@ -33,9 +36,10 @@ def test_rlaif_trainer_lora_complete_workflow(sagemaker_session):
3336 reward_prompt = 'Builtin.Summarize' ,
3437 mlflow_experiment_name = "test-rlaif-finetuned-models-exp" ,
3538 mlflow_run_name = "test-rlaif-finetuned-models-run" ,
36- training_dataset = "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/ rlvr-rlaif-oss- test-data/0.0.1 " ,
39+ training_dataset = "s3://mc-flows-sdk-testing/input_data/ rlvr-rlaif-test-data/train_285.jsonl " ,
3740 s3_output_path = "s3://mc-flows-sdk-testing/output/" ,
38- accept_eula = True
41+ accept_eula = True ,
42+ base_job_name = f"rlaif-lora-integ-{ unique_id } " ,
3943 )
4044
4145 # Create training job
@@ -61,9 +65,9 @@ def test_rlaif_trainer_lora_complete_workflow(sagemaker_session):
6165 assert training_job .output_model_package_arn is not None
6266
6367
64- @pytest .mark .skip (reason = "Skipping GPU resource intensive test" )
6568def test_rlaif_trainer_with_custom_reward_settings (sagemaker_session ):
6669 """Test RLAIF trainer with different reward model and prompt."""
70+ unique_id = f"{ int (time .time ())} -{ random .randint (1000 , 9999 )} "
6771
6872 rlaif_trainer = RLAIFTrainer (
6973 model = "meta-textgeneration-llama-3-2-1b-instruct" ,
@@ -73,9 +77,10 @@ def test_rlaif_trainer_with_custom_reward_settings(sagemaker_session):
7377 reward_prompt = "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/rlaif-test-prompt/0.0.1" ,
7478 mlflow_experiment_name = "test-rlaif-finetuned-models-exp" ,
7579 mlflow_run_name = "test-rlaif-finetuned-models-run" ,
76- training_dataset = "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/ rlvr-rlaif-oss- test-data/0.0.1 " ,
80+ training_dataset = "s3://mc-flows-sdk-testing/input_data/ rlvr-rlaif-test-data/train_285.jsonl " ,
7781 s3_output_path = "s3://mc-flows-sdk-testing/output/" ,
78- accept_eula = True
82+ accept_eula = True ,
83+ base_job_name = f"rlaif-rwd-integ-{ unique_id } " ,
7984 )
8085
8186 training_job = rlaif_trainer .train (wait = False )
@@ -100,9 +105,9 @@ def test_rlaif_trainer_with_custom_reward_settings(sagemaker_session):
100105 assert training_job .output_model_package_arn is not None
101106
102107
103- @pytest .mark .skip (reason = "Skipping GPU resource intensive test" )
104108def test_rlaif_trainer_continued_finetuning (sagemaker_session ):
105109 """Test RLAIF continued fine-tuning starting from an existing model package."""
110+ unique_id = f"{ int (time .time ())} -{ random .randint (1000 , 9999 )} "
106111
107112 rlaif_trainer = RLAIFTrainer (
108113 model = "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1" ,
@@ -112,9 +117,10 @@ def test_rlaif_trainer_continued_finetuning(sagemaker_session):
112117 reward_prompt = 'Builtin.Summarize' ,
113118 mlflow_experiment_name = "test-rlaif-finetuned-models-exp" ,
114119 mlflow_run_name = "test-rlaif-finetuned-models-run" ,
115- training_dataset = "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/ rlvr-rlaif-oss- test-data/0.0.1 " ,
120+ training_dataset = "s3://mc-flows-sdk-testing/input_data/ rlvr-rlaif-test-data/train_285.jsonl " ,
116121 s3_output_path = "s3://mc-flows-sdk-testing/output/" ,
117- accept_eula = True
122+ accept_eula = True ,
123+ base_job_name = f"rlaif-cont-integ-{ unique_id } " ,
118124 )
119125
120126 # Create training job
0 commit comments