|
175 | 175 | fi |
176 | 176 |
|
177 | 177 | printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" |
178 | | - printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n" |
179 | | - $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@" |
| 178 | + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n" |
| 179 | + $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@" |
180 | 180 | else |
181 | 181 | printf "INFO: No conda env provided. Invoking remote function\\n" |
182 | | - printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n" |
183 | | - python -m sagemaker.train.remote_function.invoke_function "$@" |
| 182 | + printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n" |
| 183 | + python -m sagemaker.core.remote_function.invoke_function "$@" |
184 | 184 | fi |
185 | 185 | """ |
186 | 186 |
|
|
234 | 234 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
235 | 235 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
236 | 236 |
|
237 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" |
| 237 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" |
238 | 238 | $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ |
239 | 239 | --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ |
240 | 240 | -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ |
241 | 241 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
242 | 242 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
243 | 243 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
244 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" |
| 244 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" |
245 | 245 |
|
246 | 246 | python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 |
247 | 247 | else |
|
259 | 259 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
260 | 260 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
261 | 261 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
262 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" |
| 262 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" |
263 | 263 |
|
264 | 264 | mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ |
265 | 265 | --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ |
266 | 266 | -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ |
267 | 267 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
268 | 268 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
269 | 269 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
270 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" |
| 270 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" |
271 | 271 |
|
272 | 272 | python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 |
273 | 273 | else |
|
320 | 320 | printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" |
321 | 321 | printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ |
322 | 322 | --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ |
323 | | - -m sagemaker.train.remote_function.invoke_function \\n" |
| 323 | + -m sagemaker.core.remote_function.invoke_function \\n" |
324 | 324 |
|
325 | 325 | $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ |
326 | 326 | --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ |
327 | | - -m sagemaker.train.remote_function.invoke_function "$@" |
| 327 | + -m sagemaker.core.remote_function.invoke_function "$@" |
328 | 328 | else |
329 | 329 | printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" |
330 | 330 | printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ |
331 | | - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n" |
| 331 | + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n" |
332 | 332 |
|
333 | 333 | torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ |
334 | | - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@" |
| 334 | + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@" |
335 | 335 | fi |
336 | 336 | """ |
337 | 337 |
|
@@ -1259,7 +1259,215 @@ def _prepare_and_upload_runtime_scripts( |
1259 | 1259 | return upload_path |
1260 | 1260 |
|
1261 | 1261 |
|
1262 | | -def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): |
| 1262 | +def _decrement_version(version_str: str) -> str: |
| 1263 | + """Decrement a version string by one minor or patch version. |
| 1264 | + |
| 1265 | + Rules: |
| 1266 | + - If patch version is 0 (e.g., 3.2.0), decrement minor: 3.2.0 -> 3.1.0 |
| 1267 | + - If patch version is not 0 (e.g., 3.1.2), decrement patch: 3.1.2 -> 3.1.1 |
| 1268 | + |
| 1269 | + Args: |
| 1270 | + version_str: Version string (e.g., "3.2.0") |
| 1271 | + |
| 1272 | + Returns: |
| 1273 | + Decremented version string |
| 1274 | + """ |
| 1275 | + from packaging import version as pkg_version |
| 1276 | + |
| 1277 | + try: |
| 1278 | + parsed = pkg_version.parse(version_str) |
| 1279 | + major = parsed.major |
| 1280 | + minor = parsed.minor |
| 1281 | + patch = parsed.micro |
| 1282 | + |
| 1283 | + if patch == 0: |
| 1284 | + # Decrement minor version |
| 1285 | + minor = max(0, minor - 1) |
| 1286 | + else: |
| 1287 | + # Decrement patch version |
| 1288 | + patch = max(0, patch - 1) |
| 1289 | + |
| 1290 | + return f"{major}.{minor}.{patch}" |
| 1291 | + except Exception: |
| 1292 | + return version_str |
| 1293 | + |
| 1294 | + |
| 1295 | +def _resolve_version_from_specifier(specifier_str: str) -> str: |
| 1296 | + """Resolve the version to check based on upper bounds. |
| 1297 | + |
| 1298 | + Upper bounds take priority. If upper bound is <4.0.0, it's safe (V3 only). |
| 1299 | + If no upper bound exists, it's safe (unbounded). |
| 1300 | + If the decremented upper bound is less than a lower bound, use the lower bound. |
| 1301 | + |
| 1302 | + Args: |
| 1303 | + specifier_str: Version specifier string (e.g., ">=3.2.0", "<3.2.0", "==3.1.0") |
| 1304 | + |
| 1305 | + Returns: |
| 1306 | + The resolved version string to check, or None if safe |
| 1307 | + """ |
| 1308 | + import re |
| 1309 | + from packaging import version as pkg_version |
| 1310 | + |
| 1311 | + # Handle exact version pinning (==) |
| 1312 | + match = re.search(r'==\s*([\d.]+)', specifier_str) |
| 1313 | + if match: |
| 1314 | + return match.group(1) |
| 1315 | + |
| 1316 | + # Extract lower bounds for comparison |
| 1317 | + lower_bounds = [] |
| 1318 | + for match in re.finditer(r'>=\s*([\d.]+)', specifier_str): |
| 1319 | + lower_bounds.append(match.group(1)) |
| 1320 | + |
| 1321 | + # Handle upper bounds - find the most restrictive one |
| 1322 | + upper_bounds = [] |
| 1323 | + |
| 1324 | + # Find all <= bounds |
| 1325 | + for match in re.finditer(r'<=\s*([\d.]+)', specifier_str): |
| 1326 | + upper_bounds.append(('<=', match.group(1))) |
| 1327 | + |
| 1328 | + # Find all < bounds |
| 1329 | + for match in re.finditer(r'<\s*([\d.]+)', specifier_str): |
| 1330 | + upper_bounds.append(('<', match.group(1))) |
| 1331 | + |
| 1332 | + if upper_bounds: |
| 1333 | + # Sort by version to find the most restrictive (lowest) upper bound |
| 1334 | + upper_bounds.sort(key=lambda x: pkg_version.parse(x[1])) |
| 1335 | + operator, version = upper_bounds[0] |
| 1336 | + |
| 1337 | + # Special case: if upper bound is <4.0.0, it's safe (V3 only) |
| 1338 | + try: |
| 1339 | + parsed_upper = pkg_version.parse(version) |
| 1340 | + if operator == '<' and parsed_upper.major == 4 and parsed_upper.minor == 0 and parsed_upper.micro == 0: |
| 1341 | + # <4.0.0 means V3 only, which is safe |
| 1342 | + return None |
| 1343 | + except Exception: |
| 1344 | + pass |
| 1345 | + |
| 1346 | + resolved_version = version |
| 1347 | + if operator == '<': |
| 1348 | + resolved_version = _decrement_version(version) |
| 1349 | + |
| 1350 | + # If we have a lower bound and the resolved version is less than it, use the lower bound |
| 1351 | + if lower_bounds: |
| 1352 | + try: |
| 1353 | + resolved_parsed = pkg_version.parse(resolved_version) |
| 1354 | + for lower_bound_str in lower_bounds: |
| 1355 | + lower_parsed = pkg_version.parse(lower_bound_str) |
| 1356 | + if resolved_parsed < lower_parsed: |
| 1357 | + resolved_version = lower_bound_str |
| 1358 | + except Exception: |
| 1359 | + pass |
| 1360 | + |
| 1361 | + return resolved_version |
| 1362 | + |
| 1363 | + # For lower bounds only (>=, >), we don't check |
| 1364 | + return None |
| 1365 | + |
| 1366 | + |
| 1367 | +def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: |
| 1368 | + """Check if the sagemaker version requirement uses incompatible hashing. |
| 1369 | +
|
| 1370 | + Raises ValueError if the requirement would install a version that uses HMAC hashing |
| 1371 | + (which is incompatible with the current SHA256-based integrity checks). |
| 1372 | +
|
| 1373 | + Args: |
| 1374 | + sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=3.2.0") |
| 1375 | +
|
| 1376 | + Raises: |
| 1377 | + ValueError: If the requirement would install a version using HMAC hashing |
| 1378 | + """ |
| 1379 | + import re |
| 1380 | + from packaging import version as pkg_version |
| 1381 | + |
| 1382 | + match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE) |
| 1383 | + if not match: |
| 1384 | + return |
| 1385 | + |
| 1386 | + specifier_str = match.group(1).strip() |
| 1387 | + |
| 1388 | + # Resolve the version that would be installed |
| 1389 | + resolved_version_str = _resolve_version_from_specifier(specifier_str) |
| 1390 | + if not resolved_version_str: |
| 1391 | + # No upper bound or exact version, so we can't determine if it's bad |
| 1392 | + return |
| 1393 | + |
| 1394 | + try: |
| 1395 | + resolved_version = pkg_version.parse(resolved_version_str) |
| 1396 | + except Exception: |
| 1397 | + return |
| 1398 | + |
| 1399 | + # Define HMAC thresholds for each major version |
| 1400 | + v2_hmac_threshold = pkg_version.parse("2.256.0") |
| 1401 | + v3_hmac_threshold = pkg_version.parse("3.2.0") |
| 1402 | + |
| 1403 | + # Check if the resolved version uses HMAC hashing |
| 1404 | + uses_hmac = False |
| 1405 | + if resolved_version.major == 2 and resolved_version < v2_hmac_threshold: |
| 1406 | + uses_hmac = True |
| 1407 | + elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold: |
| 1408 | + uses_hmac = True |
| 1409 | + |
| 1410 | + if uses_hmac: |
| 1411 | + raise ValueError( |
| 1412 | + f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " |
| 1413 | + f"could install a version using HMAC-based integrity checks which are incompatible " |
| 1414 | + f"with the current SHA256-based integrity checks. Please update to " |
| 1415 | + f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." |
| 1416 | + ) |
| 1417 | + |
| 1418 | + |
| 1419 | +def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: |
| 1420 | + """Ensure sagemaker>=3.2.0 is in the dependencies. |
| 1421 | +
|
| 1422 | + This function ensures that the remote environment has a compatible version of sagemaker |
| 1423 | + that includes the fix for the HMAC key security issue. Versions < 3.2.0 use HMAC-based |
| 1424 | + integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable. |
| 1425 | + Versions >= 3.2.0 use SHA256-based integrity checks which are secure and don't require |
| 1426 | + the secret key. |
| 1427 | +
|
| 1428 | + If no dependencies are provided, creates a temporary requirements.txt with sagemaker. |
| 1429 | + If dependencies are provided, appends sagemaker if not already present. |
| 1430 | +
|
| 1431 | + Args: |
| 1432 | + local_dependencies_path: Path to user's dependencies file or None |
| 1433 | +
|
| 1434 | + Returns: |
| 1435 | + Path to the dependencies file (created or modified) |
| 1436 | +
|
| 1437 | + Raises: |
| 1438 | + ValueError: If user has pinned sagemaker to a version using HMAC hashing |
| 1439 | + """ |
| 1440 | + import tempfile |
| 1441 | + |
| 1442 | + SAGEMAKER_MIN_VERSION = "sagemaker>=3.2.0,<4.0.0" |
| 1443 | + |
| 1444 | + if local_dependencies_path is None: |
| 1445 | + fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_") |
| 1446 | + os.close(fd) |
| 1447 | + |
| 1448 | + with open(req_file, "w") as f: |
| 1449 | + f.write(f"{SAGEMAKER_MIN_VERSION}\n") |
| 1450 | + logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION) |
| 1451 | + return req_file |
| 1452 | + |
| 1453 | + if local_dependencies_path.endswith(".txt"): |
| 1454 | + with open(local_dependencies_path, "r") as f: |
| 1455 | + content = f.read() |
| 1456 | + |
| 1457 | + if "sagemaker" in content.lower(): |
| 1458 | + for line in content.split('\n'): |
| 1459 | + if 'sagemaker' in line.lower(): |
| 1460 | + _check_sagemaker_version_compatibility(line.strip()) |
| 1461 | + break |
| 1462 | + else: |
| 1463 | + with open(local_dependencies_path, "a") as f: |
| 1464 | + f.write(f"\n{SAGEMAKER_MIN_VERSION}\n") |
| 1465 | + logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION) |
| 1466 | + |
| 1467 | + return local_dependencies_path |
| 1468 | + |
| 1469 | + |
| 1470 | +def _generate_input_data_config(job_settings, s3_base_uri): |
1263 | 1471 | """Generates input data config""" |
1264 | 1472 | from sagemaker.core.workflow.utilities import load_step_compilation_context |
1265 | 1473 |
|
@@ -1288,6 +1496,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): |
1288 | 1496 |
|
1289 | 1497 | local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) |
1290 | 1498 |
|
| 1499 | + # Ensure sagemaker dependency is included to prevent version mismatch issues |
| 1500 | + # Resolves issue where computing hash for integrity check changed in 3.2.0 |
| 1501 | + local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path) |
| 1502 | + job_settings.dependencies = local_dependencies_path |
| 1503 | + |
1291 | 1504 | if step_compilation_context: |
1292 | 1505 | with _tmpdir() as tmp_dir: |
1293 | 1506 | script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts( |
|
0 commit comments