|
22 | 22 | import pytest |
23 | 23 |
|
24 | 24 |
|
25 | | -@pytest.mark.skip(reason="Skipping GPU resource intensive test") |
26 | 25 | def test_dpo_trainer_lora_complete_workflow(sagemaker_session): |
27 | 26 | """Test complete DPO training workflow with LORA.""" |
28 | 27 | # Create DPOTrainer instance with comprehensive configuration |
29 | 28 | trainer = DPOTrainer( |
30 | 29 | model="meta-textgeneration-llama-3-2-1b-instruct", |
31 | 30 | training_type=TrainingType.LORA, |
32 | 31 | model_package_group="sdk-test-finetuned-models", |
33 | | - training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1", |
| 32 | + training_dataset="s3://mc-flows-sdk-testing/input_data/dpo/preference_dataset_train_256.jsonl", |
34 | 33 | s3_output_path="s3://mc-flows-sdk-testing/output/", |
35 | 34 | accept_eula=True |
36 | 35 | ) |
@@ -61,16 +60,15 @@ def test_dpo_trainer_lora_complete_workflow(sagemaker_session): |
61 | 60 | assert training_job.output_model_package_arn is not None |
62 | 61 |
|
63 | 62 |
|
64 | | -@pytest.mark.skip(reason="Skipping GPU resource intensive test") |
65 | 63 | def test_dpo_trainer_with_validation_dataset(sagemaker_session): |
66 | 64 | """Test DPO trainer with both training and validation datasets.""" |
67 | 65 |
|
68 | 66 | dpo_trainer = DPOTrainer( |
69 | 67 | model="meta-textgeneration-llama-3-2-1b-instruct", |
70 | 68 | training_type=TrainingType.LORA, |
71 | 69 | model_package_group="sdk-test-finetuned-models", |
72 | | - training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1", |
73 | | - validation_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1", |
| 70 | + training_dataset="s3://mc-flows-sdk-testing/input_data/dpo/preference_dataset_train_256.jsonl", |
| 71 | + validation_dataset="s3://mc-flows-sdk-testing/input_data/dpo/preference_dataset_train_256.jsonl", |
74 | 72 | s3_output_path="s3://mc-flows-sdk-testing/output/", |
75 | 73 | accept_eula=True |
76 | 74 | ) |
|
0 commit comments