@@ -135,3 +135,39 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1):
135135 assert training_job .training_job_status == "Completed"
136136 assert hasattr (training_job , 'output_model_package_arn' )
137137 assert training_job .output_model_package_arn is not None
138+
139+
140+ @pytest .mark .gpu_intensive
141+ def test_sft_trainer_lora_with_sequence_length (sagemaker_session ):
142+ """Test SFT training workflow with LORA and sequence_length specified."""
143+ unique_id = f"{ int (time .time ())} -{ random .randint (1000 , 9999 )} "
144+
145+ sft_trainer = SFTTrainer (
146+ model = "meta-textgeneration-llama-3-2-1b-instruct" ,
147+ training_type = TrainingType .LORA ,
148+ model_package_group = "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models" ,
149+ training_dataset = "s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl" ,
150+ s3_output_path = "s3://mc-flows-sdk-testing/output/" ,
151+ accept_eula = True ,
152+ sequence_length = "8K" ,
153+ base_job_name = f"sft-seqlen-integ-{ unique_id } " ,
154+ )
155+
156+ training_job = sft_trainer .train (wait = False )
157+
158+ max_wait_time = 3600
159+ poll_interval = 30
160+ start_time = time .time ()
161+
162+ while time .time () - start_time < max_wait_time :
163+ training_job .refresh ()
164+ status = training_job .training_job_status
165+
166+ if status in ["Completed" , "Failed" , "Stopped" ]:
167+ break
168+
169+ time .sleep (poll_interval )
170+
171+ assert training_job .training_job_status == "Completed"
172+ assert hasattr (training_job , 'output_model_package_arn' )
173+ assert training_job .output_model_package_arn is not None
0 commit comments