Skip to content

Commit 21489b5

Browse files
committed
fix: write ModelTrainer train script with LF line endings
1 parent a3a20c7 commit 21489b5

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def _prepare_train_script(
10611061
execute_driver=execute_driver,
10621062
)
10631063

1064-
with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f:
1064+
with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w", newline="\n") as f:
10651065
f.write(train_script)
10661066

10671067
@classmethod

sagemaker-train/tests/unit/train/test_model_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,9 +559,10 @@ def test_train_with_distributed_config(
559559
)
560560

561561
assert os.path.exists(expected_train_script_path)
562-
with open(expected_train_script_path, "r") as f:
562+
with open(expected_train_script_path, "rb") as f:
563563
train_script_content = f.read()
564-
assert test_case["expected_template"] in train_script_content
564+
assert test_case["expected_template"] in train_script_content.decode("utf-8")
565+
assert b"\r\n" not in train_script_content
565566

566567
assert os.path.exists(expected_runner_json_path)
567568
with open(expected_runner_json_path, "r") as f:
@@ -1529,4 +1530,4 @@ def test_llmft_recipe_missing_training_image_error(modules_session):
15291530
)
15301531

15311532
# Clean up the temporary file
1532-
os.unlink(recipe.name)
1533+
os.unlink(recipe.name)

0 commit comments

Comments
 (0)