Skip to content

Commit 59ea1ab

Browse files
authored
Fix imports for Model Customization interfaces (#5832)
1 parent ab5f478 commit 59ea1ab

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

sagemaker-train/src/sagemaker/train/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ def __getattr__(name):
2828
elif name == "ModelTrainer":
2929
from sagemaker.train.model_trainer import ModelTrainer
3030
return ModelTrainer
31+
elif name == "SFTTrainer":
32+
from sagemaker.train.sft_trainer import SFTTrainer
33+
return SFTTrainer
34+
elif name == "DPOTrainer":
35+
from sagemaker.train.dpo_trainer import DPOTrainer
36+
return DPOTrainer
37+
elif name == "RLVRTrainer":
38+
from sagemaker.train.rlvr_trainer import RLVRTrainer
39+
return RLVRTrainer
40+
elif name == "RLAIFTrainer":
41+
from sagemaker.train.rlaif_trainer import RLAIFTrainer
42+
return RLAIFTrainer
43+
elif name == "TrainingType":
44+
from sagemaker.train.common import TrainingType
45+
return TrainingType
46+
elif name == "CustomizationTechnique":
47+
from sagemaker.train.common import CustomizationTechnique
48+
return CustomizationTechnique
3149
elif name == "logger":
3250
from sagemaker.core.utils.utils import logger
3351
return logger

0 commit comments

Comments
 (0)