|
176 | 176 | fi |
177 | 177 |
|
178 | 178 | printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" |
179 | | - printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n" |
180 | | - $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@" |
| 179 | + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n" |
| 180 | + $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@" |
181 | 181 | else |
182 | 182 | printf "INFO: No conda env provided. Invoking remote function\\n" |
183 | | - printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n" |
184 | | - python -m sagemaker.core.remote_function.invoke_function "$@" |
| 183 | + printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n" |
| 184 | + python -m sagemaker.train.remote_function.invoke_function "$@" |
185 | 185 | fi |
186 | 186 | """ |
187 | 187 |
|
|
235 | 235 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
236 | 236 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
237 | 237 |
|
238 | | - python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" |
| 238 | + python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" |
239 | 239 | $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ |
240 | 240 | --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ |
241 | 241 | -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ |
242 | 242 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
243 | 243 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
244 | 244 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
245 | | - python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" |
| 245 | + python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" |
246 | 246 |
|
247 | 247 | python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 |
248 | 248 | else |
|
260 | 260 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
261 | 261 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
262 | 262 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
263 | | - python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" |
| 263 | + python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" |
264 | 264 |
|
265 | 265 | mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ |
266 | 266 | --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ |
267 | 267 | -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ |
268 | 268 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
269 | 269 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
270 | 270 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
271 | | - python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" |
| 271 | + python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" |
272 | 272 |
|
273 | 273 | python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 |
274 | 274 | else |
|
321 | 321 | printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" |
322 | 322 | printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ |
323 | 323 | --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ |
324 | | - -m sagemaker.core.remote_function.invoke_function \\n" |
| 324 | + -m sagemaker.train.remote_function.invoke_function \\n" |
325 | 325 |
|
326 | 326 | $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ |
327 | 327 | --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ |
328 | | - -m sagemaker.core.remote_function.invoke_function "$@" |
| 328 | + -m sagemaker.train.remote_function.invoke_function "$@" |
329 | 329 | else |
330 | 330 | printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" |
331 | 331 | printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ |
332 | | - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n" |
| 332 | + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n" |
333 | 333 |
|
334 | 334 | torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ |
335 | | - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@" |
| 335 | + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@" |
336 | 336 | fi |
337 | 337 | """ |
338 | 338 |
|
@@ -1278,215 +1278,7 @@ def _prepare_and_upload_runtime_scripts( |
1278 | 1278 | return upload_path |
1279 | 1279 |
|
1280 | 1280 |
|
1281 | | -def _decrement_version(version_str: str) -> str: |
1282 | | - """Decrement a version string by one minor or patch version. |
1283 | | - |
1284 | | - Rules: |
1285 | | - - If patch version is 0 (e.g., 3.2.0), decrement minor: 3.2.0 -> 3.1.0 |
1286 | | - - If patch version is not 0 (e.g., 3.1.2), decrement patch: 3.1.2 -> 3.1.1 |
1287 | | - |
1288 | | - Args: |
1289 | | - version_str: Version string (e.g., "3.2.0") |
1290 | | - |
1291 | | - Returns: |
1292 | | - Decremented version string |
1293 | | - """ |
1294 | | - from packaging import version as pkg_version |
1295 | | - |
1296 | | - try: |
1297 | | - parsed = pkg_version.parse(version_str) |
1298 | | - major = parsed.major |
1299 | | - minor = parsed.minor |
1300 | | - patch = parsed.micro |
1301 | | - |
1302 | | - if patch == 0: |
1303 | | - # Decrement minor version |
1304 | | - minor = max(0, minor - 1) |
1305 | | - else: |
1306 | | - # Decrement patch version |
1307 | | - patch = max(0, patch - 1) |
1308 | | - |
1309 | | - return f"{major}.{minor}.{patch}" |
1310 | | - except Exception: |
1311 | | - return version_str |
1312 | | - |
1313 | | - |
1314 | | -def _resolve_version_from_specifier(specifier_str: str) -> str: |
1315 | | - """Resolve the version to check based on upper bounds. |
1316 | | - |
1317 | | - Upper bounds take priority. If upper bound is <4.0.0, it's safe (V3 only). |
1318 | | - If no upper bound exists, it's safe (unbounded). |
1319 | | - If the decremented upper bound is less than a lower bound, use the lower bound. |
1320 | | - |
1321 | | - Args: |
1322 | | - specifier_str: Version specifier string (e.g., ">=3.2.0", "<3.2.0", "==3.1.0") |
1323 | | - |
1324 | | - Returns: |
1325 | | - The resolved version string to check, or None if safe |
1326 | | - """ |
1327 | | - import re |
1328 | | - from packaging import version as pkg_version |
1329 | | - |
1330 | | - # Handle exact version pinning (==) |
1331 | | - match = re.search(r'==\s*([\d.]+)', specifier_str) |
1332 | | - if match: |
1333 | | - return match.group(1) |
1334 | | - |
1335 | | - # Extract lower bounds for comparison |
1336 | | - lower_bounds = [] |
1337 | | - for match in re.finditer(r'>=\s*([\d.]+)', specifier_str): |
1338 | | - lower_bounds.append(match.group(1)) |
1339 | | - |
1340 | | - # Handle upper bounds - find the most restrictive one |
1341 | | - upper_bounds = [] |
1342 | | - |
1343 | | - # Find all <= bounds |
1344 | | - for match in re.finditer(r'<=\s*([\d.]+)', specifier_str): |
1345 | | - upper_bounds.append(('<=', match.group(1))) |
1346 | | - |
1347 | | - # Find all < bounds |
1348 | | - for match in re.finditer(r'<\s*([\d.]+)', specifier_str): |
1349 | | - upper_bounds.append(('<', match.group(1))) |
1350 | | - |
1351 | | - if upper_bounds: |
1352 | | - # Sort by version to find the most restrictive (lowest) upper bound |
1353 | | - upper_bounds.sort(key=lambda x: pkg_version.parse(x[1])) |
1354 | | - operator, version = upper_bounds[0] |
1355 | | - |
1356 | | - # Special case: if upper bound is <4.0.0, it's safe (V3 only) |
1357 | | - try: |
1358 | | - parsed_upper = pkg_version.parse(version) |
1359 | | - if operator == '<' and parsed_upper.major == 4 and parsed_upper.minor == 0 and parsed_upper.micro == 0: |
1360 | | - # <4.0.0 means V3 only, which is safe |
1361 | | - return None |
1362 | | - except Exception: |
1363 | | - pass |
1364 | | - |
1365 | | - resolved_version = version |
1366 | | - if operator == '<': |
1367 | | - resolved_version = _decrement_version(version) |
1368 | | - |
1369 | | - # If we have a lower bound and the resolved version is less than it, use the lower bound |
1370 | | - if lower_bounds: |
1371 | | - try: |
1372 | | - resolved_parsed = pkg_version.parse(resolved_version) |
1373 | | - for lower_bound_str in lower_bounds: |
1374 | | - lower_parsed = pkg_version.parse(lower_bound_str) |
1375 | | - if resolved_parsed < lower_parsed: |
1376 | | - resolved_version = lower_bound_str |
1377 | | - except Exception: |
1378 | | - pass |
1379 | | - |
1380 | | - return resolved_version |
1381 | | - |
1382 | | - # For lower bounds only (>=, >), we don't check |
1383 | | - return None |
1384 | | - |
1385 | | - |
1386 | | -def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: |
1387 | | - """Check if the sagemaker version requirement uses incompatible hashing. |
1388 | | -
|
1389 | | - Raises ValueError if the requirement would install a version that uses HMAC hashing |
1390 | | - (which is incompatible with the current SHA256-based integrity checks). |
1391 | | -
|
1392 | | - Args: |
1393 | | - sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=3.2.0") |
1394 | | -
|
1395 | | - Raises: |
1396 | | - ValueError: If the requirement would install a version using HMAC hashing |
1397 | | - """ |
1398 | | - import re |
1399 | | - from packaging import version as pkg_version |
1400 | | - |
1401 | | - match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE) |
1402 | | - if not match: |
1403 | | - return |
1404 | | - |
1405 | | - specifier_str = match.group(1).strip() |
1406 | | - |
1407 | | - # Resolve the version that would be installed |
1408 | | - resolved_version_str = _resolve_version_from_specifier(specifier_str) |
1409 | | - if not resolved_version_str: |
1410 | | - # No upper bound or exact version, so we can't determine if it's bad |
1411 | | - return |
1412 | | - |
1413 | | - try: |
1414 | | - resolved_version = pkg_version.parse(resolved_version_str) |
1415 | | - except Exception: |
1416 | | - return |
1417 | | - |
1418 | | - # Define HMAC thresholds for each major version |
1419 | | - v2_hmac_threshold = pkg_version.parse("2.256.0") |
1420 | | - v3_hmac_threshold = pkg_version.parse("3.2.0") |
1421 | | - |
1422 | | - # Check if the resolved version uses HMAC hashing |
1423 | | - uses_hmac = False |
1424 | | - if resolved_version.major == 2 and resolved_version < v2_hmac_threshold: |
1425 | | - uses_hmac = True |
1426 | | - elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold: |
1427 | | - uses_hmac = True |
1428 | | - |
1429 | | - if uses_hmac: |
1430 | | - raise ValueError( |
1431 | | - f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " |
1432 | | - f"could install a version using HMAC-based integrity checks which are incompatible " |
1433 | | - f"with the current SHA256-based integrity checks. Please update to " |
1434 | | - f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." |
1435 | | - ) |
1436 | | - |
1437 | | - |
1438 | | -def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: |
1439 | | - """Ensure sagemaker>=3.2.0 is in the dependencies. |
1440 | | -
|
1441 | | - This function ensures that the remote environment has a compatible version of sagemaker |
1442 | | - that includes the fix for the HMAC key security issue. Versions < 3.2.0 use HMAC-based |
1443 | | - integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable. |
1444 | | - Versions >= 3.2.0 use SHA256-based integrity checks which are secure and don't require |
1445 | | - the secret key. |
1446 | | -
|
1447 | | - If no dependencies are provided, creates a temporary requirements.txt with sagemaker. |
1448 | | - If dependencies are provided, appends sagemaker if not already present. |
1449 | | -
|
1450 | | - Args: |
1451 | | - local_dependencies_path: Path to user's dependencies file or None |
1452 | | -
|
1453 | | - Returns: |
1454 | | - Path to the dependencies file (created or modified) |
1455 | | -
|
1456 | | - Raises: |
1457 | | - ValueError: If user has pinned sagemaker to a version using HMAC hashing |
1458 | | - """ |
1459 | | - import tempfile |
1460 | | - |
1461 | | - SAGEMAKER_MIN_VERSION = "sagemaker>=3.2.0,<4.0.0" |
1462 | | - |
1463 | | - if local_dependencies_path is None: |
1464 | | - fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_") |
1465 | | - os.close(fd) |
1466 | | - |
1467 | | - with open(req_file, "w") as f: |
1468 | | - f.write(f"{SAGEMAKER_MIN_VERSION}\n") |
1469 | | - logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION) |
1470 | | - return req_file |
1471 | | - |
1472 | | - if local_dependencies_path.endswith(".txt"): |
1473 | | - with open(local_dependencies_path, "r") as f: |
1474 | | - content = f.read() |
1475 | | - |
1476 | | - if "sagemaker" in content.lower(): |
1477 | | - for line in content.split('\n'): |
1478 | | - if 'sagemaker' in line.lower(): |
1479 | | - _check_sagemaker_version_compatibility(line.strip()) |
1480 | | - break |
1481 | | - else: |
1482 | | - with open(local_dependencies_path, "a") as f: |
1483 | | - f.write(f"\n{SAGEMAKER_MIN_VERSION}\n") |
1484 | | - logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION) |
1485 | | - |
1486 | | - return local_dependencies_path |
1487 | | - |
1488 | | - |
1489 | | -def _generate_input_data_config(job_settings, s3_base_uri): |
| 1281 | +def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): |
1490 | 1282 | """Generates input data config""" |
1491 | 1283 | from sagemaker.core.workflow.utilities import load_step_compilation_context |
1492 | 1284 |
|
@@ -1515,11 +1307,6 @@ def _generate_input_data_config(job_settings, s3_base_uri): |
1515 | 1307 |
|
1516 | 1308 | local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) |
1517 | 1309 |
|
1518 | | - # Ensure sagemaker dependency is included to prevent version mismatch issues |
1519 | | - # Resolves issue where computing hash for integrity check changed in 3.2.0 |
1520 | | - local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path) |
1521 | | - job_settings.dependencies = local_dependencies_path |
1522 | | - |
1523 | 1310 | if step_compilation_context: |
1524 | 1311 | with _tmpdir() as tmp_dir: |
1525 | 1312 | script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts( |
|
0 commit comments