From bbfa35b5a199c75a93daf583a8222b72e91015eb Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:58:52 -0400 Subject: [PATCH] fix: Pipeline TypeError: can only concatenate list (not "NoneType") to list Using Sou (#5518) --- .../src/sagemaker/core/workflow/utilities.py | 523 ------------------ .../tests/unit/workflow/test_utilities.py | 489 ---------------- 2 files changed, 1012 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/workflow/utilities.py b/sagemaker-core/src/sagemaker/core/workflow/utilities.py index c07a31c51e..e69de29bb2 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/utilities.py +++ b/sagemaker-core/src/sagemaker/core/workflow/utilities.py @@ -1,523 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Utilities to support workflow.""" -from __future__ import absolute_import - -import inspect -import logging -from functools import wraps -from pathlib import Path -from typing import List, Sequence, Union, Set, TYPE_CHECKING, Optional -import hashlib -from urllib.parse import unquote, urlparse -from contextlib import contextmanager - -try: - # _hashlib is an internal python module, and is not present in - # statically linked interpreters. - from _hashlib import HASH as Hash -except ImportError: - import typing - - Hash = typing.Any - -from sagemaker.core.common_utils import base_from_name -from sagemaker.core.workflow.parameters import Parameter -from sagemaker.core.workflow.pipeline_context import _StepArguments, _PipelineConfig -from sagemaker.core.workflow.entities import Entity -from sagemaker.core.helper.pipeline_variable import RequestType -from sagemaker.core.workflow.pipeline_definition_config import PipelineDefinitionConfig - -logger = logging.getLogger(__name__) - -DEF_CONFIG_WARN_MSG_TEMPLATE = ( - "Popping out '%s' from the pipeline definition by default " - "since it will be overridden at pipeline execution time. Please utilize " - "the PipelineDefinitionConfig to persist this field in the pipeline definition " - "if desired." -) - -JOB_KEY_NONE_WARN_MSG_TEMPLATE = ( - "Invalid input: use_custom_job_prefix flag is set but the name field [{}] has not been " - "specified. Please refer to the AWS Docs to identify which field should be set to enable the " - "custom-prefixing feature for jobs created via a pipeline execution. " - "https://docs.aws.amazon.com/sagemaker/latest/dg/" - "build-and-manage-access.html#build-and-manage-step-permissions-prefix" -) - -if TYPE_CHECKING: - from sagemaker.mlops.workflow.step_collections import StepCollection - -BUF_SIZE = 65536 # 64KiB -_pipeline_config: _PipelineConfig = None - - -def list_to_request(entities: Sequence[Union[Entity, "StepCollection"]]) -> List[RequestType]: - """Get the request structure for list of entities. - - Args: - entities (Sequence[Entity]): A list of entities. - Returns: - list: A request structure for a workflow service call. - """ - from sagemaker.mlops.workflow.step_collections import StepCollection - - request_dicts = [] - for entity in entities: - if isinstance(entity, Entity): - request_dicts.append(entity.to_request()) - elif isinstance(entity, StepCollection): - request_dicts.extend(entity.request_dicts()) - return request_dicts - - -@contextmanager -def step_compilation_context_manager( - pipeline_name: str, - step_name: str, - sagemaker_session, - code_hash: str, - config_hash: str, - pipeline_definition_config: PipelineDefinitionConfig, - upload_runtime_scripts: bool, - upload_workspace: bool, - pipeline_build_time: str, - function_step_secret_token: Optional[str] = None, -): - """Expose static _pipeline_config variable to other modules - - Args: - pipeline_name (str): pipeline name - step_name (str): step name - sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session - code_hash (str): a hash of the code artifact for the particular step - config_hash (str): a hash of the config artifact for the particular step (Processing) - pipeline_definition_config (PipelineDefinitionConfig): a configuration used to toggle - feature flags persistent in a pipeline definition - upload_runtime_scripts (bool): flag used to manage upload of runtime scripts to s3 for - a _FunctionStep in pipeline - upload_workspace (bool): flag used to manage the upload of workspace to s3 for a - _FunctionStep in pipeline - pipeline_build_time (str): timestamp when the pipeline is being built - function_step_secret_token (str): secret token used for the function step checksum - """ - - # pylint: disable=W0603 - global _pipeline_config - _pipeline_config = _PipelineConfig( - pipeline_name=pipeline_name, - step_name=step_name, - sagemaker_session=sagemaker_session, - code_hash=code_hash, - config_hash=config_hash, - pipeline_definition_config=pipeline_definition_config, - upload_runtime_scripts=upload_runtime_scripts, - upload_workspace=upload_workspace, - pipeline_build_time=pipeline_build_time, - function_step_secret_token=function_step_secret_token, - ) - try: - yield _pipeline_config - finally: - _pipeline_config = None - - -def load_step_compilation_context(): - """Load the step compilation context from the static _pipeline_config variable - - Returns: - _PipelineConfig: a context object containing information about the current step - """ - return _pipeline_config - - -def get_code_hash(step: Entity) -> str: - """Get the hash of the code artifact(s) for the given step - - Args: - step (Entity): A pipeline step object (Entity type because Step causes circular import) - Returns: - str: A hash string representing the unique code artifact(s) for the step - """ - from sagemaker.mlops.workflow.steps import ProcessingStep, TrainingStep - - if isinstance(step, ProcessingStep) and step.step_args: - kwargs = step.step_args.func_kwargs - source_dir = kwargs.get("source_dir") - submit_class = kwargs.get("submit_class") - dependencies = get_processing_dependencies( - [ - kwargs.get("dependencies"), - kwargs.get("submit_py_files"), - [submit_class] if submit_class else None, - kwargs.get("submit_jars"), - kwargs.get("submit_files"), - ] - ) - code = kwargs.get("submit_app") or kwargs.get("code") - - return get_processing_code_hash(code, source_dir, dependencies) - - if isinstance(step, TrainingStep) and step.step_args: - model_trainer = step.step_args.func_args[0] - source_code = model_trainer.source_code - if source_code: - source_dir = source_code.source_dir - requirements = source_code.requirements - entry_point = source_code.entry_script - return get_training_code_hash(entry_point, source_dir, requirements) - return None - - -def get_processing_dependencies(dependency_args: List[List[str]]) -> List[str]: - """Get the Processing job dependencies from the processor run kwargs - - Args: - dependency_args: A list of dependency args from processor.run() - Returns: - List[str]: A list of code dependencies for the job - """ - - dependencies = [] - for arg in dependency_args: - if arg: - dependencies += arg - - return dependencies - - -def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]) -> str: - """Get the hash of a processing step's code artifact(s). - - Args: - code (str): Path to a file with the processing script to run - source_dir (str): Path to a directory with any other processing - source code dependencies aside from the entry point file - dependencies (str): A list of paths to directories (absolute - or relative) with any additional libraries that will be exported - to the container - Returns: - str: A hash string representing the unique code artifact(s) for the step - """ - - # FrameworkProcessor - if source_dir: - source_dir_url = urlparse(source_dir) - if source_dir_url.scheme == "" or source_dir_url.scheme == "file": - # Include code in the hash when possible - if code: - code_url = urlparse(code) - if code_url.scheme == "" or code_url.scheme == "file": - return hash_files_or_dirs([code, source_dir] + dependencies) - return hash_files_or_dirs([source_dir] + dependencies) - # Other Processors - Spark, Script, Base, etc. - if code: - code_url = urlparse(code) - if code_url.scheme == "" or code_url.scheme == "file": - return hash_files_or_dirs([code] + dependencies) - return None - - -def get_training_code_hash( - entry_point: str, source_dir: str, dependencies: Optional[str] = None -) -> str: - """Get the hash of a training step's code artifact(s). - - Args: - entry_point (str): The absolute or relative path to the local Python - source file that should be executed as the entry point to - training - source_dir (str): Path to a directory with any other training source - code dependencies aside from the entry point file - dependencies Optional[str]: The relative path within ``source_dir`` to a - ``requirements.txt`` file with any additional libraries that - will be exported to the container - Returns: - str: A hash string representing the unique code artifact(s) for the step - """ - from sagemaker.core.workflow import is_pipeline_variable - - if not is_pipeline_variable(source_dir) and not is_pipeline_variable(entry_point): - if source_dir: - source_dir_url = urlparse(source_dir) - if source_dir_url.scheme == "" or source_dir_url.scheme == "file": - if dependencies: - return hash_files_or_dirs([source_dir] + [dependencies]) - else: - return hash_files_or_dirs([source_dir]) - elif entry_point: - entry_point_url = urlparse(entry_point) - if entry_point_url.scheme == "" or entry_point_url.scheme == "file": - if dependencies: - return hash_files_or_dirs([entry_point] + [dependencies]) - else: - return hash_files_or_dirs([entry_point]) - return None - - -def get_config_hash(step: Entity): - """Get the hash of the config artifact(s) for the given step - - Args: - step (Entity): A pipeline step object (Entity type because Step causes circular import) - Returns: - str: A hash string representing the unique config artifact(s) for the step - """ - from sagemaker.mlops.workflow.steps import ProcessingStep - - if isinstance(step, ProcessingStep) and step.step_args: - config = step.step_args.func_kwargs.get("configuration") - if config: - return hash_object(config) - return None - - -def hash_object(obj) -> str: - """Get the SHA256 hash of an object. - - Args: - obj (dict): The object - Returns: - str: The SHA256 hash of the object - """ - return hashlib.sha256(str(obj).encode()).hexdigest() - - -def hash_file(path: str) -> str: - """Get the SHA256 hash of a file. - - Args: - path (str): The local path for the file. - Returns: - str: The SHA256 hash of the file. - """ - return _hash_file(path, hashlib.sha256()).hexdigest() - - -def hash_files_or_dirs(paths: List[str]) -> str: - """Get the SHA256 hash of the contents of a list of files or directories. - - Hash is changed if: - * input list is changed - * new nested directories/files are added to any directory in the input list - * nested directory/file names are changed for any of the inputted directories - * content of files is edited - - Args: - paths: List of file or directory paths - Returns: - str: The SHA256 hash of the list of files or directories. - """ - sha256 = hashlib.sha256() - for path in sorted(paths): - sha256 = _hash_file_or_dir(path, sha256) - return sha256.hexdigest() - - -def _hash_file_or_dir(path: str, sha256: Hash) -> Hash: - """Updates the inputted Hash with the contents of the current path. - - Args: - path: path of file or directory - Returns: - str: The SHA256 hash of the file or directory - """ - if isinstance(path, str) and path.lower().startswith("file://"): - path = unquote(urlparse(path).path) - sha256.update(path.encode()) - if Path(path).is_dir(): - sha256 = _hash_dir(path, sha256) - elif Path(path).is_file(): - sha256 = _hash_file(path, sha256) - return sha256 - - -def _hash_dir(directory: Union[str, Path], sha256: Hash) -> Hash: - """Updates the inputted Hash with the contents of the current path. - - Args: - directory: path of the directory - Returns: - str: The SHA256 hash of the directory - """ - if not Path(directory).is_dir(): - raise ValueError(str(directory) + " is not a valid directory") - for path in sorted(Path(directory).iterdir()): - sha256.update(path.name.encode()) - if path.is_file(): - sha256 = _hash_file(path, sha256) - elif path.is_dir(): - sha256 = _hash_dir(path, sha256) - return sha256 - - -def _hash_file(file: Union[str, Path], sha256: Hash) -> Hash: - """Updates the inputted Hash with the contents of the current path. - - Args: - file: path of the file - Returns: - str: The SHA256 hash of the file - """ - if isinstance(file, str) and file.lower().startswith("file://"): - file = unquote(urlparse(file).path) - if not Path(file).is_file(): - raise ValueError(str(file) + " is not a valid file") - with open(file, "rb") as f: - while True: - data = f.read(BUF_SIZE) - if not data: - break - sha256.update(data) - return sha256 - - -def validate_step_args_input( - step_args: _StepArguments, expected_caller: Set[str], error_message: str -): - """Validate the `_StepArguments` object which is passed into a pipeline step - - Args: - step_args (_StepArguments): A `_StepArguments` object to be used for composing - a pipeline step. - expected_caller (Set[str]): The expected name of the caller function which is - intercepted by the PipelineSession to get the step arguments. - error_message (str): The error message to be thrown if the validation fails. - """ - if not isinstance(step_args, _StepArguments): - raise TypeError(error_message) - if step_args.caller_name not in expected_caller: - raise ValueError(error_message) - - -def override_pipeline_parameter_var(func): - """A decorator to override pipeline Parameters passed into a function - - This is a temporary decorator to override pipeline Parameter objects with their default value - and display warning information to instruct users to update their code. - - This decorator can help to give a grace period for users to update their code when - we make changes to explicitly prevent passing any pipeline variables to a function. - - We should remove this decorator after the grace period. - """ - warning_msg_template = ( - "The input argument %s of function (%s) is a pipeline variable (%s), " - "which is interpreted in pipeline execution time only. " - "As the function needs to evaluate the argument value in SDK compile time, " - "the default_value of this Parameter object will be used to override it. " - "Please make sure the default_value is valid." - ) - - @wraps(func) - def wrapper(*args, **kwargs): - func_name = "{}.{}".format(func.__module__, func.__name__) - params = inspect.signature(func).parameters - args = list(args) - for i, (arg_name, _) in enumerate(params.items()): - if i >= len(args): - break - if isinstance(args[i], Parameter): - logger.warning(warning_msg_template, arg_name, func_name, type(args[i])) - args[i] = args[i].default_value - args = tuple(args) - - for arg_name, value in kwargs.items(): - if isinstance(value, Parameter): - logger.warning(warning_msg_template, arg_name, func_name, type(value)) - kwargs[arg_name] = value.default_value - return func(*args, **kwargs) - - return wrapper - - -def execute_job_functions(step_args: _StepArguments): - """Execute the job class functions during pipeline definition construction - - Executes the job functions such as run(), fit(), or transform() that have been - delayed until the pipeline gets built, for steps built with a PipelineSession. - - Handles multiple functions in instances where job functions are chained - together from the inheritance of different job classes (e.g. PySparkProcessor, - ScriptProcessor, and Processor). - - Args: - step_args (_StepArguments): A `_StepArguments` object to be used for composing - a pipeline step, contains the necessary function information - """ - - chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs) - if isinstance(chained_args, _StepArguments): - execute_job_functions(chained_args) - - -def trim_request_dict(request_dict, job_key, config): - """Trim request_dict for unwanted fields to not persist them in step arguments - - Trim the job_name field off request_dict in cases where we do not want to include it - in the pipeline definition. - - Args: - request_dict (dict): A dictionary used to build the arguments for a pipeline step, - containing fields that will be passed to job client during orchestration. - job_key (str): The key in a step's arguments to look up the base_job_name if it - exists - config (_pipeline_config) The config intercepted and set for a pipeline via the - context manager - """ - - if not config or not config.pipeline_definition_config.use_custom_job_prefix: - logger.warning(DEF_CONFIG_WARN_MSG_TEMPLATE, job_key) - request_dict.pop(job_key, None) # safely return null in case of KeyError - else: - if job_key in request_dict: - if request_dict[job_key] is None or len(request_dict[job_key]) == 0: - raise ValueError(JOB_KEY_NONE_WARN_MSG_TEMPLATE.format(job_key)) - request_dict[job_key] = base_from_name(request_dict[job_key]) # trim timestamp - - return request_dict - - -def _collect_parameters(func): - """The decorator function is to collect all the params passed into an invoked function of class. - - These parameters are set as properties of the class instance. The use case is to simplify - parameter collecting when they are passed into the step __init__ method. - - Usage: - class A: - @collect_parameters - def __init__(a, b='', c=None) - pass - - In above case, the A instance would have a, b, c set as instance properties. - None value will be set as well. If the property exists, it will be overridden. - """ - - @wraps(func) - def wrapper(self, *args, **kwargs): - # Get the parameters and values - signature = inspect.signature(func) - bound_args = signature.bind(self, *args, **kwargs) - bound_args.apply_defaults() - - # Create a dictionary of parameters and their values - parameters_and_values = dict(bound_args.arguments) - - for param, value in parameters_and_values.items(): - if param not in ("self", "depends_on"): - setattr(self, param, value) - - func(self, *args, **kwargs) - - return wrapper diff --git a/sagemaker-core/tests/unit/workflow/test_utilities.py b/sagemaker-core/tests/unit/workflow/test_utilities.py index 5e9ed7bbbd..e69de29bb2 100644 --- a/sagemaker-core/tests/unit/workflow/test_utilities.py +++ b/sagemaker-core/tests/unit/workflow/test_utilities.py @@ -1,489 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -import pytest -import tempfile -import os -from pathlib import Path -from unittest.mock import Mock, patch, MagicMock -from sagemaker.core.workflow.utilities import ( - list_to_request, - hash_file, - hash_files_or_dirs, - hash_object, - get_processing_dependencies, - get_processing_code_hash, - get_training_code_hash, - validate_step_args_input, - override_pipeline_parameter_var, - trim_request_dict, - _collect_parameters, -) -from sagemaker.core.workflow.entities import Entity -from sagemaker.core.workflow.parameters import Parameter -from sagemaker.core.workflow.pipeline_context import _StepArguments - - -class MockEntity(Entity): - """Mock entity for testing""" - - def to_request(self): - return {"Type": "MockEntity"} - - -class TestWorkflowUtilities: - """Test cases for workflow utility functions""" - - @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") - def test_list_to_request_with_entities(self): - """Test list_to_request with Entity objects""" - entities = [MockEntity(), MockEntity()] - - result = list_to_request(entities) - - assert len(result) == 2 - assert all(item["Type"] == "MockEntity" for item in result) - - @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") - def test_list_to_request_with_step_collection(self): - """Test list_to_request with StepCollection""" - from sagemaker.mlops.workflow.step_collections import StepCollection - - mock_collection = Mock(spec=StepCollection) - mock_collection.request_dicts.return_value = [{"Type": "Step1"}, {"Type": "Step2"}] - - result = list_to_request([mock_collection]) - - assert len(result) == 2 - - @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") - def test_list_to_request_mixed(self): - """Test list_to_request with mixed entities and collections""" - from sagemaker.mlops.workflow.step_collections import StepCollection - - mock_collection = Mock(spec=StepCollection) - mock_collection.request_dicts.return_value = [{"Type": "Step1"}] - - entities = [MockEntity(), mock_collection] - - result = list_to_request(entities) - - assert len(result) == 2 - - def test_hash_object(self): - """Test hash_object produces consistent hash""" - obj = {"key": "value", "number": 123} - - hash1 = hash_object(obj) - hash2 = hash_object(obj) - - assert hash1 == hash2 - assert len(hash1) == 64 # SHA256 produces 64 character hex string - - def test_hash_object_different_objects(self): - """Test hash_object produces different hashes for different objects""" - obj1 = {"key": "value1"} - obj2 = {"key": "value2"} - - hash1 = hash_object(obj1) - hash2 = hash_object(obj2) - - assert hash1 != hash2 - - def test_hash_file(self): - """Test hash_file produces consistent hash""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - f.write("test content") - temp_file = f.name - - try: - hash1 = hash_file(temp_file) - hash2 = hash_file(temp_file) - - assert hash1 == hash2 - assert len(hash1) == 64 - finally: - os.unlink(temp_file) - - def test_hash_file_different_content(self): - """Test hash_file produces different hashes for different content""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: - f1.write("content1") - temp_file1 = f1.name - - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - f2.write("content2") - temp_file2 = f2.name - - try: - hash1 = hash_file(temp_file1) - hash2 = hash_file(temp_file2) - - assert hash1 != hash2 - finally: - os.unlink(temp_file1) - os.unlink(temp_file2) - - def test_hash_files_or_dirs_single_file(self): - """Test hash_files_or_dirs with single file""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - f.write("test content") - temp_file = f.name - - try: - result = hash_files_or_dirs([temp_file]) - - assert len(result) == 64 - finally: - os.unlink(temp_file) - - def test_hash_files_or_dirs_multiple_files(self): - """Test hash_files_or_dirs with multiple files""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: - f1.write("content1") - temp_file1 = f1.name - - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - f2.write("content2") - temp_file2 = f2.name - - try: - result = hash_files_or_dirs([temp_file1, temp_file2]) - - assert len(result) == 64 - finally: - os.unlink(temp_file1) - os.unlink(temp_file2) - - def test_hash_files_or_dirs_directory(self): - """Test hash_files_or_dirs with directory""" - with tempfile.TemporaryDirectory() as temp_dir: - # Create some files in the directory - Path(temp_dir, "file1.txt").write_text("content1") - Path(temp_dir, "file2.txt").write_text("content2") - - result = hash_files_or_dirs([temp_dir]) - - assert len(result) == 64 - - def test_hash_files_or_dirs_order_matters(self): - """Test hash_files_or_dirs produces same hash regardless of input order""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: - f1.write("content1") - temp_file1 = f1.name - - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - f2.write("content2") - temp_file2 = f2.name - - try: - # Hash should be same regardless of order due to sorting - hash1 = hash_files_or_dirs([temp_file1, temp_file2]) - hash2 = hash_files_or_dirs([temp_file2, temp_file1]) - - assert hash1 == hash2 - finally: - os.unlink(temp_file1) - os.unlink(temp_file2) - - def test_get_processing_dependencies_empty(self): - """Test get_processing_dependencies with empty lists""" - result = get_processing_dependencies([None, None, None]) - - assert result == [] - - def test_get_processing_dependencies_single_list(self): - """Test get_processing_dependencies with single list""" - result = get_processing_dependencies([["dep1", "dep2"], None, None]) - - assert result == ["dep1", "dep2"] - - def test_get_processing_dependencies_multiple_lists(self): - """Test get_processing_dependencies with multiple lists""" - result = get_processing_dependencies([["dep1", "dep2"], ["dep3"], ["dep4", "dep5"]]) - - assert result == ["dep1", "dep2", "dep3", "dep4", "dep5"] - - def test_get_processing_code_hash_with_source_dir(self): - """Test get_processing_code_hash with source_dir""" - with tempfile.TemporaryDirectory() as temp_dir: - code_file = Path(temp_dir, "script.py") - code_file.write_text("print('hello')") - - result = get_processing_code_hash( - code=str(code_file), source_dir=temp_dir, dependencies=[] - ) - - assert result is not None - assert len(result) == 64 - - def test_get_processing_code_hash_code_only(self): - """Test get_processing_code_hash with code only""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: - f.write("print('hello')") - temp_file = f.name - - try: - result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[]) - - assert result is not None - assert len(result) == 64 - finally: - os.unlink(temp_file) - - def test_get_processing_code_hash_s3_uri(self): - """Test get_processing_code_hash with S3 URI returns None""" - result = get_processing_code_hash( - code="s3://bucket/script.py", source_dir=None, dependencies=[] - ) - - assert result is None - - def test_get_processing_code_hash_with_dependencies(self): - """Test get_processing_code_hash with dependencies""" - with tempfile.TemporaryDirectory() as temp_dir: - code_file = Path(temp_dir, "script.py") - code_file.write_text("print('hello')") - - dep_file = Path(temp_dir, "utils.py") - dep_file.write_text("def helper(): pass") - - result = get_processing_code_hash( - code=str(code_file), source_dir=temp_dir, dependencies=[str(dep_file)] - ) - - assert result is not None - - def test_get_training_code_hash_with_source_dir(self): - """Test get_training_code_hash with source_dir""" - with tempfile.TemporaryDirectory() as temp_dir: - entry_file = Path(temp_dir, "train.py") - entry_file.write_text("print('training')") - requirements_file = Path(temp_dir, "requirements.txt") - requirements_file.write_text("numpy==1.21.0") - - result_no_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=temp_dir, dependencies=None - ) - result_with_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=temp_dir, dependencies=str(requirements_file) - ) - - assert result_no_deps is not None - assert result_with_deps is not None - assert len(result_no_deps) == 64 - assert len(result_with_deps) == 64 - assert result_no_deps != result_with_deps - - def test_get_training_code_hash_entry_point_only(self): - """Test get_training_code_hash with entry_point only""" - with tempfile.TemporaryDirectory() as temp_dir: - entry_file = Path(temp_dir, "train.py") - entry_file.write_text("print('training')") - requirements_file = Path(temp_dir, "requirements.txt") - requirements_file.write_text("numpy==1.21.0") - - # Without dependencies - result_no_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=None, dependencies=None - ) - # With dependencies - result_with_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=None, dependencies=str(requirements_file) - ) - - assert result_no_deps is not None - assert result_with_deps is not None - assert len(result_no_deps) == 64 - assert len(result_with_deps) == 64 - assert result_no_deps != result_with_deps - - def test_get_training_code_hash_s3_uri(self): - """Test get_training_code_hash with S3 URI returns None""" - result = get_training_code_hash( - entry_point="s3://bucket/train.py", source_dir=None, dependencies=[] - ) - - assert result is None - - def test_get_training_code_hash_pipeline_variable(self): - """Test get_training_code_hash with pipeline variable returns None""" - with patch("sagemaker.core.workflow.is_pipeline_variable", return_value=True): - result = get_training_code_hash( - entry_point="train.py", source_dir="source", dependencies=[] - ) - - assert result is None - - def test_validate_step_args_input_valid(self): - """Test validate_step_args_input with valid input""" - step_args = _StepArguments( - caller_name="test_function", func=Mock(), func_args=[], func_kwargs={} - ) - - # Should not raise an error - validate_step_args_input( - step_args, expected_caller={"test_function"}, error_message="Invalid input" - ) - - def test_validate_step_args_input_invalid_type(self): - """Test validate_step_args_input with invalid type""" - with pytest.raises(TypeError): - validate_step_args_input( - "not_step_args", expected_caller={"test_function"}, error_message="Invalid input" - ) - - def test_validate_step_args_input_wrong_caller(self): - """Test validate_step_args_input with wrong caller""" - step_args = _StepArguments( - caller_name="wrong_function", func=Mock(), func_args=[], func_kwargs={} - ) - - with pytest.raises(ValueError): - validate_step_args_input( - step_args, expected_caller={"test_function"}, error_message="Invalid input" - ) - - def test_override_pipeline_parameter_var_decorator(self): - """Test override_pipeline_parameter_var decorator""" - from sagemaker.core.workflow.parameters import ParameterInteger - - @override_pipeline_parameter_var - def test_func(param1, param2=None): - return param1, param2 - - param = ParameterInteger(name="test", default_value=10) - - result = test_func(param, param2=20) - - assert result[0] == 10 # Should use default_value - assert result[1] == 20 - - def test_override_pipeline_parameter_var_decorator_kwargs(self): - """Test override_pipeline_parameter_var decorator with kwargs""" - from sagemaker.core.workflow.parameters import ParameterInteger - - @override_pipeline_parameter_var - def test_func(param1, param2=None): - return param1, param2 - - param = ParameterInteger(name="test", default_value=5) - - result = test_func(1, param2=param) - - assert result[0] == 1 - assert result[1] == 5 # Should use default_value - - def test_trim_request_dict_without_config(self): - """Test trim_request_dict without config removes job_name""" - request_dict = {"job_name": "test-job-123", "other": "value"} - - result = trim_request_dict(request_dict, "job_name", None) - - assert "job_name" not in result - assert result["other"] == "value" - - def test_trim_request_dict_with_config_use_custom_prefix(self): - """Test trim_request_dict with config and use_custom_job_prefix""" - from sagemaker.core.workflow.pipeline_definition_config import PipelineDefinitionConfig - - config = Mock() - config.pipeline_definition_config = PipelineDefinitionConfig(use_custom_job_prefix=True) - - request_dict = {"job_name": "test-job-123", "other": "value"} - - with patch("sagemaker.core.workflow.utilities.base_from_name", return_value="test-job"): - result = trim_request_dict(request_dict, "job_name", config) - - assert result["job_name"] == "test-job" - - def test_trim_request_dict_with_config_none_job_name(self): - """Test trim_request_dict raises error when job_name is None with use_custom_job_prefix""" - from sagemaker.core.workflow.pipeline_definition_config import PipelineDefinitionConfig - - config = Mock() - config.pipeline_definition_config = PipelineDefinitionConfig(use_custom_job_prefix=True) - - request_dict = {"job_name": None, "other": "value"} - - with pytest.raises(ValueError, match="name field .* has not been specified"): - trim_request_dict(request_dict, "job_name", config) - - def test_collect_parameters_decorator(self): - """Test _collect_parameters decorator""" - - class TestClass: - @_collect_parameters - def __init__(self, param1, param2, param3=None): - pass - - obj = TestClass("value1", "value2", param3="value3") - - assert obj.param1 == "value1" - assert obj.param2 == "value2" - assert obj.param3 == "value3" - - def test_collect_parameters_decorator_excludes_self(self): - """Test _collect_parameters decorator excludes self""" - - class TestClass: - @_collect_parameters - def __init__(self, param1): - pass - - obj = TestClass("value1") - - assert not hasattr(obj, "self") - assert obj.param1 == "value1" - - def test_collect_parameters_decorator_excludes_depends_on(self): - """Test _collect_parameters decorator excludes depends_on""" - - class TestClass: - @_collect_parameters - def __init__(self, param1, depends_on=None): - pass - - obj = TestClass("value1", depends_on=["step1"]) - - assert not hasattr(obj, "depends_on") - assert obj.param1 == "value1" - - def test_collect_parameters_decorator_with_defaults(self): - """Test _collect_parameters decorator with default values""" - - class TestClass: - @_collect_parameters - def __init__(self, param1, param2="default"): - pass - - obj = TestClass("value1") - - assert obj.param1 == "value1" - assert obj.param2 == "default" - - def test_collect_parameters_decorator_overrides_existing(self): - """Test _collect_parameters decorator overrides existing attributes""" - - class TestClass: - def __init__(self, param1): - self.param1 = "old_value" - - @_collect_parameters - def reinit(self, param1): - pass - - obj = TestClass("initial") - obj.reinit("new_value") - - assert obj.param1 == "new_value"