|
17 | 17 |
|
18 | 18 | from sagemaker.modules.train import ModelTrainer |
19 | 19 | from sagemaker.modules.configs import SourceCode, Compute |
20 | | -from sagemaker.modules.distributed import MPI, Torchrun |
| 20 | +from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig |
21 | 21 |
|
22 | 22 | EXPECTED_HYPERPARAMETERS = { |
23 | 23 | "integer": 1, |
@@ -106,3 +106,35 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session): |
106 | 106 | ) |
107 | 107 |
|
108 | 108 | model_trainer.train() |
| 109 | + |
| 110 | + |
def test_custom_distributed_driver(modules_sagemaker_session):
    """Run a training job whose ``distributed`` config is a custom ``DistributedConfig``.

    A local ``CustomDriver`` subclass supplies its own ``driver_dir`` and
    ``driver_script`` and is passed to ``ModelTrainer`` through the
    ``distributed`` argument. The test makes no explicit assertions; it
    relies on ``model_trainer.train()`` raising if the job cannot be run.

    Args:
        modules_sagemaker_session: Session fixture forwarded to ``ModelTrainer``
            as ``sagemaker_session``.
    """
    # Imported locally so this test does not depend on the module-level imports.
    from typing import Optional

    class CustomDriver(DistributedConfig):
        """Distributed config that points at the driver under ``custom_drivers``."""

        # The default is None, so the annotation has to admit None as well as int.
        process_count_per_node: Optional[int] = None

        @property
        def driver_dir(self) -> str:
            """Directory holding the custom driver script."""
            return f"{DATA_DIR}/modules/custom_drivers"

        @property
        def driver_script(self) -> str:
            """File name of the driver script, relative to ``driver_dir``."""
            return "driver.py"

    source_code = SourceCode(
        source_dir=f"{DATA_DIR}/modules/scripts",
        entry_script="entry_script.py",
    )

    hyperparameters = {"epochs": 10}

    custom_driver = CustomDriver(process_count_per_node=2)

    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=hyperparameters,
        source_code=source_code,
        distributed=custom_driver,
        base_job_name="custom-distributed-driver",
    )
    # No assertions follow: a failure is expected to surface as an exception here.
    model_trainer.train()
0 commit comments