1313"""Distributed module."""
1414from __future__ import absolute_import
1515
16+ import os
17+
18+ from abc import ABC , abstractmethod
1619from typing import Optional , Dict , Any , List
17- from pydantic import PrivateAttr
20+ from pydantic import BaseModel
1821from sagemaker .modules .utils import safe_serialize
19- from sagemaker .modules .configs import BaseConfig
22+ from sagemaker .modules .constants import SM_DRIVERS_LOCAL_PATH
2023
2124
2225class SMP (BaseConfig ):
@@ -73,16 +76,39 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7376 return hyperparameters
7477
7578
76- class DistributedConfig (BaseConfig ):
77- """Base class for distributed training configurations."""
79+ class DistributedConfig (BaseModel , ABC ):
80+ """Abstract base class for distributed training configurations.
81+
82+ This class defines the interface that all distributed training configurations
83+ must implement. It provides a standardized way to specify driver scripts and
84+ their locations for distributed training jobs.
85+ """
86+
87+ @property
88+ @abstractmethod
89+ def driver_dir (self ) -> str :
90+ """Directory containing the driver script.
91+
92+ This property should return the path to the directory containing
93+ the driver script, relative to the container's working directory.
94+
95+ Returns:
96+ str: Path to directory containing the driver script
97+ """
98+ pass
99+
100+ @property
101+ @abstractmethod
102+ def driver_script (self ) -> str :
103+ """Name of the driver script.
78104
79- _type : str = PrivateAttr ()
105+ This property should return the name of the Python script that implements
106+ the distributed training driver logic.
80107
81- def model_dump (self , * args , ** kwargs ):
82- """Dump the model to a dictionary."""
83- result = super ().model_dump (* args , ** kwargs )
84- result ["_type" ] = self ._type
85- return result
108+ Returns:
109+ str: Name of the driver script file
110+ """
111+ pass
86112
87113
88114class Torchrun (DistributedConfig ):
@@ -99,11 +125,17 @@ class Torchrun(DistributedConfig):
99125 The SageMaker Model Parallelism v2 parameters.
100126 """
101127
102- _type : str = PrivateAttr (default = "torchrun" )
103-
104128 process_count_per_node : Optional [int ] = None
105129 smp : Optional ["SMP" ] = None
106130
131+ @property
132+ def driver_dir (self ) -> str :
133+ return os .path .join (SM_DRIVERS_LOCAL_PATH , "drivers" )
134+
135+ @property
136+ def driver_script (self ) -> str :
137+ return "torchrun_driver.py"
138+
107139
108140class MPI (DistributedConfig ):
109141 """MPI.
@@ -119,7 +151,13 @@ class MPI(DistributedConfig):
119151 The custom MPI options to use for the training job.
120152 """
121153
122- _type : str = PrivateAttr (default = "mpi" )
123-
124154 process_count_per_node : Optional [int ] = None
125155 mpi_additional_options : Optional [List [str ]] = None
156+
157+ @property
158+ def driver_dir (self ) -> str :
159+ return os .path .join (SM_DRIVERS_LOCAL_PATH , "drivers" )
160+
161+ @property
162+ def driver_script (self ) -> str :
163+ return "mpi_driver.py"
0 commit comments