Skip to content

Commit dbded70

Browse files
committed
add test
Signed-off-by: oliver könig <okoenig@nvidia.com>
1 parent 7342913 commit dbded70

1 file changed

Lines changed: 321 additions & 1 deletion

File tree

test/core/execution/test_dgxcloud.py

Lines changed: 321 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1143,4 +1143,324 @@ def test_default_headers_with_token(self):
11431143
assert headers["Content-Type"] == "application/json"
11441144
assert "Authorization" in headers
11451145
assert headers["Authorization"] == "Bearer test_token"
1146-
assert headers["Authorization"] == "Bearer test_token"
1146+
1147+
1148+
class TestDGXCloudRequest:
1149+
"""Test DGXCloudRequest dataclass and its methods."""
1150+
1151+
@pytest.fixture
1152+
def basic_executor(self):
1153+
"""Create a basic DGXCloudExecutor for testing."""
1154+
return DGXCloudExecutor(
1155+
base_url="https://dgxapi.example.com",
1156+
kube_apiserver_url="https://127.0.0.1:443",
1157+
app_id="test_app_id",
1158+
app_secret="test_app_secret",
1159+
project_name="test_project",
1160+
container_image="nvcr.io/nvidia/test:latest",
1161+
pvc_nemo_run_dir="/workspace/nemo_run",
1162+
)
1163+
1164+
@pytest.fixture
1165+
def executor_with_env_vars(self):
1166+
"""Create a DGXCloudExecutor with environment variables."""
1167+
return DGXCloudExecutor(
1168+
base_url="https://dgxapi.example.com",
1169+
kube_apiserver_url="https://127.0.0.1:443",
1170+
app_id="test_app_id",
1171+
app_secret="test_app_secret",
1172+
project_name="test_project",
1173+
container_image="nvcr.io/nvidia/test:latest",
1174+
pvc_nemo_run_dir="/workspace/nemo_run",
1175+
env_vars={"EXECUTOR_VAR": "executor_value", "SHARED_VAR": "from_executor"},
1176+
)
1177+
1178+
def test_dgxcloud_request_init(self, basic_executor):
1179+
"""Test basic initialization of DGXCloudRequest."""
1180+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1181+
1182+
request = DGXCloudRequest(
1183+
launch_cmd=["python", "train.py"],
1184+
jobs=["job1", "job2"],
1185+
executor=basic_executor,
1186+
max_retries=3,
1187+
extra_env={"EXTRA_VAR": "extra_value"},
1188+
)
1189+
1190+
assert request.launch_cmd == ["python", "train.py"]
1191+
assert request.jobs == ["job1", "job2"]
1192+
assert request.executor == basic_executor
1193+
assert request.max_retries == 3
1194+
assert request.extra_env == {"EXTRA_VAR": "extra_value"}
1195+
assert request.launcher is None
1196+
1197+
def test_dgxcloud_request_with_launcher(self, basic_executor):
1198+
"""Test DGXCloudRequest with a launcher."""
1199+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1200+
from nemo_run.core.execution.launcher import Torchrun
1201+
1202+
launcher = Torchrun()
1203+
request = DGXCloudRequest(
1204+
launch_cmd=["python", "train.py"],
1205+
jobs=["job1"],
1206+
executor=basic_executor,
1207+
max_retries=5,
1208+
extra_env={},
1209+
launcher=launcher,
1210+
)
1211+
1212+
assert request.launcher == launcher
1213+
assert isinstance(request.launcher, Torchrun)
1214+
1215+
def test_materialize_basic(self, basic_executor):
1216+
"""Test materialization of a basic request without fault tolerance."""
1217+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1218+
1219+
request = DGXCloudRequest(
1220+
launch_cmd=["python", "train.py", "--epochs", "10"],
1221+
jobs=["job1"],
1222+
executor=basic_executor,
1223+
max_retries=3,
1224+
extra_env={"MY_VAR": "my_value"},
1225+
)
1226+
1227+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1228+
mock_fill.return_value = "#!/bin/bash\necho 'test script'"
1229+
script = request.materialize()
1230+
1231+
# Verify fill_template was called
1232+
mock_fill.assert_called_once()
1233+
args, kwargs = mock_fill.call_args
1234+
assert args[0] == "dgxc.sh.j2"
1235+
1236+
template_vars = args[1]
1237+
assert template_vars["max_retries"] == 3
1238+
assert template_vars["training_command"] == "python train.py --epochs 10"
1239+
assert template_vars["ft_enabled"] is False
1240+
assert "export MY_VAR=my_value" in template_vars["env_vars"]
1241+
1242+
assert script == "#!/bin/bash\necho 'test script'"
1243+
1244+
def test_materialize_with_env_vars(self, executor_with_env_vars):
1245+
"""Test that environment variables from executor and extra_env are merged."""
1246+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1247+
1248+
request = DGXCloudRequest(
1249+
launch_cmd=["python", "train.py"],
1250+
jobs=["job1"],
1251+
executor=executor_with_env_vars,
1252+
max_retries=1,
1253+
extra_env={"EXTRA_VAR": "extra_value", "SHARED_VAR": "from_extra"},
1254+
)
1255+
1256+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1257+
mock_fill.return_value = "mock_script"
1258+
request.materialize()
1259+
1260+
template_vars = mock_fill.call_args[0][1]
1261+
env_vars = template_vars["env_vars"]
1262+
1263+
# Check that variables are present (order may vary due to dict merge)
1264+
assert "export EXECUTOR_VAR=executor_value" in env_vars
1265+
assert "export EXTRA_VAR=extra_value" in env_vars
1266+
# extra_env should override executor.env_vars for SHARED_VAR
1267+
assert "export SHARED_VAR=from_extra" in env_vars
1268+
assert "export SHARED_VAR=from_executor" not in env_vars
1269+
1270+
def test_materialize_with_fault_tolerance(self, basic_executor):
1271+
"""Test materialization with fault tolerance enabled."""
1272+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1273+
from nemo_run.core.execution.launcher import FaultTolerance
1274+
1275+
ft_launcher = FaultTolerance(
1276+
cfg_path="/workspace/ft_config.yaml",
1277+
finished_flag_file="/workspace/.ft_finished",
1278+
job_results_file="/workspace/ft_results.json",
1279+
)
1280+
1281+
request = DGXCloudRequest(
1282+
launch_cmd=["python", "train.py"],
1283+
jobs=["job1"],
1284+
executor=basic_executor,
1285+
max_retries=5,
1286+
extra_env={},
1287+
launcher=ft_launcher,
1288+
)
1289+
1290+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1291+
mock_fill.return_value = "ft_script"
1292+
_ = request.materialize()
1293+
1294+
template_vars = mock_fill.call_args[0][1]
1295+
assert template_vars["ft_enabled"] is True
1296+
assert template_vars["fault_tol_cfg_path"] == "/workspace/ft_config.yaml"
1297+
assert template_vars["fault_tol_finished_flag_file"] == "/workspace/.ft_finished"
1298+
assert template_vars["fault_tol_job_results_file"] == "/workspace/ft_results.json"
1299+
1300+
def test_materialize_fault_tolerance_missing_fields(self, basic_executor):
1301+
"""Test that fault tolerance with missing required fields raises an error."""
1302+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1303+
from nemo_run.core.execution.launcher import FaultTolerance
1304+
1305+
# Create FaultTolerance with missing required fields
1306+
ft_launcher = FaultTolerance(
1307+
cfg_path="/workspace/ft_config.yaml",
1308+
# Missing finished_flag_file and job_results_file
1309+
)
1310+
1311+
request = DGXCloudRequest(
1312+
launch_cmd=["python", "train.py"],
1313+
jobs=["job1"],
1314+
executor=basic_executor,
1315+
max_retries=5,
1316+
extra_env={},
1317+
launcher=ft_launcher,
1318+
)
1319+
1320+
with pytest.raises(AssertionError) as exc_info:
1321+
with patch("nemo_run.core.execution.dgxcloud.fill_template"):
1322+
request.materialize()
1323+
1324+
assert "Fault Tolerance requires" in str(exc_info.value)
1325+
1326+
def test_materialize_with_non_fault_tolerance_launcher(self, basic_executor):
1327+
"""Test materialization with a non-FaultTolerance launcher (e.g., Torchrun)."""
1328+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1329+
from nemo_run.core.execution.launcher import Torchrun
1330+
1331+
launcher = Torchrun()
1332+
request = DGXCloudRequest(
1333+
launch_cmd=["python", "train.py"],
1334+
jobs=["job1"],
1335+
executor=basic_executor,
1336+
max_retries=2,
1337+
extra_env={},
1338+
launcher=launcher,
1339+
)
1340+
1341+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1342+
mock_fill.return_value = "torchrun_script"
1343+
_ = request.materialize()
1344+
1345+
template_vars = mock_fill.call_args[0][1]
1346+
# FT should be disabled for non-FaultTolerance launchers
1347+
assert template_vars["ft_enabled"] is False
1348+
# FT-specific fields should not be in template_vars
1349+
assert "fault_tol_cfg_path" not in template_vars
1350+
assert "fault_tol_finished_flag_file" not in template_vars
1351+
assert "fault_tol_job_results_file" not in template_vars
1352+
1353+
def test_materialize_empty_extra_env(self, basic_executor):
1354+
"""Test materialization with empty extra_env."""
1355+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1356+
1357+
request = DGXCloudRequest(
1358+
launch_cmd=["python", "train.py"],
1359+
jobs=["job1"],
1360+
executor=basic_executor,
1361+
max_retries=1,
1362+
extra_env={},
1363+
)
1364+
1365+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1366+
mock_fill.return_value = "script"
1367+
request.materialize()
1368+
1369+
template_vars = mock_fill.call_args[0][1]
1370+
assert template_vars["env_vars"] == []
1371+
1372+
def test_materialize_uppercase_env_vars(self, basic_executor):
1373+
"""Test that environment variable keys are uppercased."""
1374+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1375+
1376+
request = DGXCloudRequest(
1377+
launch_cmd=["python", "train.py"],
1378+
jobs=["job1"],
1379+
executor=basic_executor,
1380+
max_retries=1,
1381+
extra_env={"lowercase_var": "value", "MixedCase": "value2"},
1382+
)
1383+
1384+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1385+
mock_fill.return_value = "script"
1386+
request.materialize()
1387+
1388+
template_vars = mock_fill.call_args[0][1]
1389+
env_vars = template_vars["env_vars"]
1390+
1391+
# Keys should be uppercased
1392+
assert "export LOWERCASE_VAR=value" in env_vars
1393+
assert "export MIXEDCASE=value2" in env_vars
1394+
1395+
def test_repr(self, basic_executor):
1396+
"""Test the __repr__ method."""
1397+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1398+
1399+
request = DGXCloudRequest(
1400+
launch_cmd=["python", "train.py"],
1401+
jobs=["job1", "job2"],
1402+
executor=basic_executor,
1403+
max_retries=3,
1404+
extra_env={},
1405+
)
1406+
1407+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1408+
mock_fill.return_value = "#!/bin/bash\necho 'script content'"
1409+
repr_str = repr(request)
1410+
1411+
assert "# DGXC Entrypoint Script Request" in repr_str
1412+
assert "# Executor: DGXCloudExecutor" in repr_str
1413+
assert "# Jobs: ['job1', 'job2']" in repr_str
1414+
assert "#!/bin/bash" in repr_str
1415+
assert "echo 'script content'" in repr_str
1416+
1417+
def test_complex_launch_command(self, basic_executor):
1418+
"""Test materialization with a complex multi-argument launch command."""
1419+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1420+
1421+
request = DGXCloudRequest(
1422+
launch_cmd=[
1423+
"torchrun",
1424+
"--nproc_per_node=8",
1425+
"--nnodes=2",
1426+
"train.py",
1427+
"--batch-size",
1428+
"32",
1429+
"--lr",
1430+
"0.001",
1431+
],
1432+
jobs=["job1"],
1433+
executor=basic_executor,
1434+
max_retries=1,
1435+
extra_env={},
1436+
)
1437+
1438+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1439+
mock_fill.return_value = "script"
1440+
request.materialize()
1441+
1442+
template_vars = mock_fill.call_args[0][1]
1443+
expected_cmd = (
1444+
"torchrun --nproc_per_node=8 --nnodes=2 train.py --batch-size 32 --lr 0.001"
1445+
)
1446+
assert template_vars["training_command"] == expected_cmd
1447+
1448+
def test_max_retries_values(self, basic_executor):
1449+
"""Test different max_retries values."""
1450+
from nemo_run.core.execution.dgxcloud import DGXCloudRequest
1451+
1452+
for retries in [0, 1, 10, 100]:
1453+
request = DGXCloudRequest(
1454+
launch_cmd=["python", "train.py"],
1455+
jobs=["job1"],
1456+
executor=basic_executor,
1457+
max_retries=retries,
1458+
extra_env={},
1459+
)
1460+
1461+
with patch("nemo_run.core.execution.dgxcloud.fill_template") as mock_fill:
1462+
mock_fill.return_value = "script"
1463+
request.materialize()
1464+
1465+
template_vars = mock_fill.call_args[0][1]
1466+
assert template_vars["max_retries"] == retries

0 commit comments

Comments
 (0)