diff --git a/clarifai/cli/pipeline.py b/clarifai/cli/pipeline.py index 1cfe2e29..5c91b9ef 100644 --- a/clarifai/cli/pipeline.py +++ b/clarifai/cli/pipeline.py @@ -100,7 +100,17 @@ def pipeline(): is_flag=True, help='Skip creating config-lock.yaml file.', ) -def upload(path, no_lockfile): +@click.option( + '--user_id', + default=None, + help='Override the user_id from the Clarifai context.', +) +@click.option( + '--app_id', + default=None, + help='Override the app_id from the Clarifai context.', +) +def upload(path, no_lockfile, user_id, app_id): """Upload a pipeline with associated pipeline steps to Clarifai. PATH: Path to the pipeline configuration file or directory containing config.yaml. If not specified, the current directory is used by default. @@ -110,6 +120,10 @@ def upload(path, no_lockfile): if os.path.isfile(path) and path.endswith('.py'): pipeline_obj = load_pipeline_from_file(path) + if user_id: + pipeline_obj.user_id = user_id + if app_id: + pipeline_obj.app_id = app_id output_dir = os.path.join( os.path.dirname(os.path.abspath(path)), f'generated-{pipeline_obj.id}' ) @@ -128,16 +142,42 @@ def upload(path, no_lockfile): required=True, help='Directory to write the compiled pipeline config and step folders.', ) -def compile(path, output_dir): - """Compile YAML/config-based pipeline assets from a Python pipeline definition.""" +@click.option('--user_id', default=None, help='Override the user_id from the Clarifai context.') +@click.option('--app_id', default=None, help='Override the app_id from the Clarifai context.') +def compile(path, output_dir, user_id, app_id): + """Compile YAML/config-based pipeline assets from a Python pipeline definition. + + Generates config.yaml, step directories (with requirements.txt and + pipeline_step.py), and a Dockerfile for each locally managed step. + """ + from clarifai.runners.pipeline_steps.pipeline_step_builder import PipelineStepBuilder from clarifai.runners.pipelines import load_pipeline_from_file if not os.path.isfile(path) or not path.endswith('.py'): raise click.UsageError('clarifai pipeline compile expects a Python file path.') pipeline_obj = load_pipeline_from_file(path) + if user_id: + pipeline_obj.user_id = user_id + if app_id: + pipeline_obj.app_id = app_id config_path = pipeline_obj.generate(output_dir) - logger.info(f"Generated pipeline assets at {config_path}") + + # Generate Dockerfiles for all locally managed step directories. + seen: set = set() + step_ids = [] + for node in pipeline_obj.nodes: + sid = node.step_definition.id + if node.step_definition.is_managed and sid not in seen: + seen.add(sid) + step_ids.append(sid) + for step_id in step_ids: + step_dir = os.path.join(output_dir, step_id) + if os.path.isdir(step_dir): + PipelineStepBuilder(step_dir).create_dockerfile() + logger.info(f"Generated Dockerfile for step '{step_id}'") + + logger.info(f'Generated pipeline assets at {config_path}') @pipeline.command() @@ -1049,7 +1089,7 @@ def validate_lock(lockfile_path): raise click.Abort() -@pipeline.command(['ls']) +@pipeline.command(name='list', aliases=['ls']) @click.option('--page_no', required=False, help='Page number to list.', default=1) @click.option('--per_page', required=False, help='Number of items per page.', default=16) @click.option( @@ -1063,7 +1103,7 @@ def validate_lock(lockfile_path): help='User ID to list pipelines from. If not provided, uses current user.', ) @click.pass_context -def list(ctx, page_no, per_page, app_id, user_id): +def list_pipelines(ctx, page_no, per_page, app_id, user_id): """List all pipelines for the user.""" validate_context(ctx) diff --git a/clarifai/runners/pipeline_steps/pipeline_step_builder.py b/clarifai/runners/pipeline_steps/pipeline_step_builder.py index 6ac6947b..39631dd0 100644 --- a/clarifai/runners/pipeline_steps/pipeline_step_builder.py +++ b/clarifai/runners/pipeline_steps/pipeline_step_builder.py @@ -3,7 +3,6 @@ import sys import tarfile import time -from string import Template from typing import List, Optional import yaml @@ -240,41 +239,46 @@ def create_pipeline_step(self): def create_dockerfile(self): """Create a Dockerfile for the pipeline step.""" - # Use similar logic to model builder for dockerfile creation - dockerfile_template = """FROM --platform=$TARGETPLATFORM public.ecr.aws/clarifai-models/python-base:$PYTHON_VERSION-df565436eea93efb3e8d1eb558a0a46df29523ec as final - -COPY --link requirements.txt /home/nonroot/requirements.txt - -# Update clarifai package so we always have latest protocol to the API. Everything should land in /venv -RUN ["pip", "install", "--no-cache-dir", "-r", "/home/nonroot/requirements.txt"] - -# Copy in the actual files like config.yaml, requirements.txt, and most importantly 1/pipeline_step.py for the actual pipeline step. -COPY --link=true 1 /home/nonroot/main/1 -# At this point we only need these for validation in the SDK. -COPY --link=true requirements.txt config.yaml /home/nonroot/main/ -""" - - # Get Python version from config or use default build_info = self.config.get('build_info', {}) python_version = build_info.get('python_version', '3.12') + base_image = build_info.get('base_image') + platform = build_info.get('platform') # Ensure requirements.txt has clarifai self._ensure_clarifai_requirement() - # Replace placeholders - dockerfile_content = Template(dockerfile_template).safe_substitute( - PYTHON_VERSION=python_version + platform_str = f'--platform={platform}' if platform else '' + image = ( + base_image + or f'public.ecr.aws/clarifai-models/python-base:{python_version}-df565436eea93efb3e8d1eb558a0a46df29523ec' + ) + + dockerfile_content = ( + f'FROM {platform_str} {image} as final\n' + '\n' + 'COPY --link requirements.txt /home/nonroot/requirements.txt\n' + '\n' + '# Install uv, create a venv, and install requirements\n' + f'RUN pip install uv && uv venv /tmp/venv --python {python_version} --clear\n' + 'ENV VIRTUAL_ENV=/tmp/venv\n' + 'ENV PATH="/tmp/venv/bin:$PATH"\n' + 'RUN uv pip install --no-cache-dir -r /home/nonroot/requirements.txt\n' + '\n' + '# Copy in the actual files like config.yaml, requirements.txt, and most importantly 1/pipeline_step.py for the actual pipeline step.\n' + 'COPY --link=true 1 /home/nonroot/main/1\n' + '# At this point we only need these for validation in the SDK.\n' + 'COPY --link=true requirements.txt config.yaml /home/nonroot/main/\n' ) # Write Dockerfile if it doesn't exist dockerfile_path = os.path.join(self.folder, 'Dockerfile') if os.path.exists(dockerfile_path): - logger.info(f"Dockerfile already exists at {dockerfile_path}, skipping creation.") + logger.info(f'Dockerfile already exists at {dockerfile_path}, skipping creation.') return with open(dockerfile_path, 'w') as dockerfile: dockerfile.write(dockerfile_content) - logger.info(f"Created Dockerfile at {dockerfile_path}") + logger.info(f'Created Dockerfile at {dockerfile_path}') def _ensure_clarifai_requirement(self): """Ensure clarifai is in requirements.txt with proper version.""" diff --git a/clarifai/runners/pipelines/codegen.py b/clarifai/runners/pipelines/codegen.py index 8d699e2c..62224fc7 100644 --- a/clarifai/runners/pipelines/codegen.py +++ b/clarifai/runners/pipelines/codegen.py @@ -219,7 +219,15 @@ def generate_step_directory(step_definition, output_dir: str, user_id: str, app_ 'app_id': app_id, }, 'pipeline_step_input_params': step_definition.get_input_params(), - 'build_info': {'python_version': step_definition.python_version}, + 'build_info': { + k: v + for k, v in [ + ('python_version', step_definition.python_version), + ('base_image', step_definition.base_image), + ('platform', step_definition.platform), + ] + if v is not None + }, 'pipeline_step_compute_info': MessageToDict( step_definition.compute, preserving_proto_field_name=True ), diff --git a/clarifai/runners/pipelines/pipeline.py b/clarifai/runners/pipelines/pipeline.py index e22eae8f..8e162680 100644 --- a/clarifai/runners/pipelines/pipeline.py +++ b/clarifai/runners/pipelines/pipeline.py @@ -27,11 +27,6 @@ def __init__( visibility: str = 'PRIVATE', ): user_id, app_id = self._resolve_from_context(user_id, app_id) - if not user_id or not app_id: - raise ValueError( - "Pipeline(...) needs user_id and app_id. Pass them explicitly, " - "or run `clarifai login` to set them in your CLI context." - ) self.id = id self.user_id = user_id self.app_id = app_id @@ -112,6 +107,14 @@ def _generate_task_name(self, step_id: str) -> str: suffix += 1 return candidate + def _validate_identity(self): + """Raise if user_id/app_id are still unresolved at the time of use.""" + if not self.user_id or not self.app_id: + raise ValueError( + "Pipeline(...) needs user_id and app_id. Pass them explicitly, " + "set --user_id/--app_id on the CLI, or run `clarifai login`." + ) + def validate(self): nodes_by_name = {node.name: node for node in self.nodes} for node in self.nodes: @@ -247,6 +250,7 @@ def to_config(self) -> Dict[str, Any]: return config def generate(self, output_dir: str) -> str: + self._validate_identity() os.makedirs(output_dir, exist_ok=True) step_definitions = OrderedDict() for node in self.nodes: @@ -262,6 +266,7 @@ def generate(self, output_dir: str) -> str: return config_path def upload(self, no_lockfile: bool = False) -> Optional[str]: + self._validate_identity() from clarifai.runners.pipelines.pipeline_builder import PipelineBuilder with tempfile.TemporaryDirectory(prefix='clarifai-pipeline-') as temp_dir: diff --git a/clarifai/runners/pipelines/step.py b/clarifai/runners/pipelines/step.py index 45130045..59981686 100644 --- a/clarifai/runners/pipelines/step.py +++ b/clarifai/runners/pipelines/step.py @@ -98,6 +98,8 @@ def __init__( assets=None, compute: Optional[ComputeInfo] = None, python_version: str = '3.12', + base_image: Optional[str] = None, + platform: Optional[str] = None, secrets: Optional[Dict[str, str]] = None, ): self.func = func @@ -106,6 +108,8 @@ def __init__( self.assets = assets or [] self.compute = compute or ComputeInfo() self.python_version = python_version + self.base_image = base_image + self.platform = platform self.secrets = secrets or {} self.signature = inspect.signature(func) @@ -215,6 +219,8 @@ def step( assets=None, compute: Optional[ComputeInfo] = None, python_version: str = '3.12', + base_image: Optional[str] = None, + platform: Optional[str] = None, secrets: Optional[Dict[str, str]] = None, ): def decorator(func: Callable[..., Any]) -> StepDefinition: @@ -225,6 +231,8 @@ def decorator(func: Callable[..., Any]) -> StepDefinition: assets=assets, compute=compute, python_version=python_version, + base_image=base_image, + platform=platform, secrets=secrets, ) diff --git a/tests/cli/test_pipeline.py b/tests/cli/test_pipeline.py index 6f8fb6bb..1bf901fe 100644 --- a/tests/cli/test_pipeline.py +++ b/tests/cli/test_pipeline.py @@ -6,6 +6,7 @@ import yaml from click.testing import CliRunner +from clarifai.cli.pipeline import compile as compile_command from clarifai.cli.pipeline import init, run, upload from clarifai.cli.pipeline_template import info, list_templates from clarifai.runners.pipelines.pipeline_builder import ( @@ -384,6 +385,21 @@ def test_cli_upload_help(self): assert result.exit_code == 0 assert "Upload a pipeline with associated pipeline steps" in result.output assert "PATH" in result.output + assert '--user_id' in result.output + assert '--app_id' in result.output + assert '--user-id' not in result.output + assert '--app-id' not in result.output + + def test_cli_compile_help_uses_underscore_identity_flags(self): + """Test compile help uses the existing underscore flag convention.""" + runner = CliRunner() + result = runner.invoke(compile_command, ['--help']) + + assert result.exit_code == 0 + assert '--user_id' in result.output + assert '--app_id' in result.output + assert '--user-id' not in result.output + assert '--app-id' not in result.output def test_cli_upload_missing_config(self): """Test CLI upload with missing config file.""" @@ -2115,7 +2131,7 @@ def test_list_command_requires_app_id(self): ctx_obj.current.api_base = 'https://api.clarifai.com' # Import here to avoid circular imports in testing - from clarifai.cli.pipeline import list as list_command + from clarifai.cli.pipeline import list_pipelines as list_command result = runner.invoke( list_command, @@ -2152,7 +2168,7 @@ def test_list_command_success_with_app_id(self, mock_display, mock_app_class, mo ctx_obj.current.api_base = 'https://api.clarifai.com' # Import here to avoid circular imports in testing - from clarifai.cli.pipeline import list as list_command + from clarifai.cli.pipeline import list_pipelines as list_command result = runner.invoke( list_command, @@ -2185,7 +2201,7 @@ def test_list_command_default_parameters(self, mock_validate): ctx_obj.current.api_base = 'https://api.clarifai.com' # Import here to avoid circular imports in testing - from clarifai.cli.pipeline import list as list_command + from clarifai.cli.pipeline import list_pipelines as list_command with patch('clarifai.client.app.App') as mock_app_class: mock_app_instance = Mock() diff --git a/tests/cli/test_pipeline_dsl_cli.py b/tests/cli/test_pipeline_dsl_cli.py index 3d438b12..c252f94b 100644 --- a/tests/cli/test_pipeline_dsl_cli.py +++ b/tests/cli/test_pipeline_dsl_cli.py @@ -38,6 +38,7 @@ def test_generate_python_pipeline_file_writes_output(tmp_path: Path): with patch('clarifai.runners.pipelines.load_pipeline_from_file') as mock_loader: mock_pipeline = Mock() mock_pipeline.generate.return_value = str(output_dir / 'config.yaml') + mock_pipeline.nodes = [] # no managed steps → no Dockerfiles expected mock_loader.return_value = mock_pipeline result = runner.invoke(compile, [str(pipeline_file), '--output-dir', str(output_dir)]) @@ -68,3 +69,27 @@ def test_generate_real_example_pipeline_writes_mixed_step_config(tmp_path: Path) assert (output_dir / 'prepare-text' / '1' / 'text_utils.py').exists() assert not (output_dir / 'summarize').exists() assert not (output_dir / 'classify-sentiment').exists() + # compile must also generate Dockerfiles for locally managed steps + assert (output_dir / 'prepare-text' / 'Dockerfile').exists() + assert (output_dir / 'assemble-report' / 'Dockerfile').exists() + + +def test_compile_generates_dockerfiles_for_managed_steps(tmp_path: Path): + """compile writes a Dockerfile next to each locally managed step directory.""" + repo_root = Path(__file__).resolve().parents[2] + pipeline_file = repo_root / 'examples' / 'pipeline_dsl_text_pipeline.py' + output_dir = tmp_path / 'compiled' + runner = CliRunner() + + result = runner.invoke(compile, [str(pipeline_file), '--output-dir', str(output_dir)]) + + assert result.exit_code == 0, result.output + for step_id in ('prepare-text', 'assemble-report'): + dockerfile = output_dir / step_id / 'Dockerfile' + assert dockerfile.exists(), f'Dockerfile missing for step {step_id!r}' + content = dockerfile.read_text(encoding='utf-8') + assert 'FROM ' in content + assert 'COPY --link=true 1 /home/nonroot/main/1' in content + # Pre-existing (non-managed) steps must NOT get a Dockerfile. + assert not (output_dir / 'summarize').exists() + assert not (output_dir / 'classify-sentiment').exists() diff --git a/tests/runners/test_pipeline_dsl.py b/tests/runners/test_pipeline_dsl.py index ca158ece..88c72c38 100644 --- a/tests/runners/test_pipeline_dsl.py +++ b/tests/runners/test_pipeline_dsl.py @@ -236,11 +236,122 @@ def test_step_ref_from_url_parses_versioned_pipeline_step_url(): assert step_definition.secrets == {'OPENAI_API_KEY': 'users/demo-user/secrets/openai-key'} +def test_argo_spec_propagates_function_default_parameters(): + # NOTE: avoid using 'name' as a parameter name here — StepDefinition.__call__ + # does `kwargs.pop('name', None)` to extract the optional task-name override, + # which silently swallows any step parameter literally called 'name'. + @step(id='greet') + def greet(text: str, greeting: str = 'Hello', repeat: int = 1) -> str: + return f'{greeting} {text}' * repeat + + with Pipeline(id='defaults-pipeline', user_id='me', app_id='my-app') as pipeline: + raw_text = pipeline.input('input_text') + # Only pass 'text'; 'greeting' and 'repeat' should be filled from defaults. + greet(text=raw_text) + + argo_spec = pipeline.to_argo_spec() + step_groups = argo_spec['spec']['templates'][0]['steps'] + tasks = {entry['name']: entry for group in step_groups for entry in group} + + params = {p['name']: p['value'] for p in tasks['greet']['arguments']['parameters']} + assert params['text'] == '{{workflow.parameters.input_text}}' + assert params['greeting'] == 'Hello' + assert params['repeat'] == '1' + + def test_step_ref_from_url_requires_versioned_pipeline_step_path(): with pytest.raises(ValueError, match='versioned pipeline step URL or resource path'): step_ref.from_url('users/demo-user/apps/shared-app/pipeline_steps/summarize') +# Module-level step definitions so codegen can locate their source in this file. +@step( + id='custom-base-step', + base_image='my-registry.example.com/my-python:3.12-gpu', +) +def custom_base_step(text: str) -> str: + return text + + +@step(id='no-base-step') +def no_base_step(text: str) -> str: + return text + + +@step( + id='arm64-step', + platform='linux/arm64', +) +def arm64_step(text: str) -> str: + return text + + +@step( + id='custom-platform-and-image-step', + base_image='ghcr.io/org/ml-base:latest', + platform='linux/amd64', +) +def custom_platform_and_image_step(text: str) -> str: + return text + + +def test_step_with_base_image_writes_build_info(tmp_path: Path): + """base_image specified in @step is written to build_info in config.yaml.""" + with Pipeline(id='base-image-pipeline', user_id='me', app_id='my-app') as pipeline: + raw = pipeline.input('input_text') + custom_base_step(text=raw) + + pipeline.generate(str(tmp_path)) + + config_path = tmp_path / 'custom-base-step' / 'config.yaml' + assert config_path.exists() + config = yaml.safe_load(config_path.read_text(encoding='utf-8')) + + assert config['build_info']['base_image'] == 'my-registry.example.com/my-python:3.12-gpu' + + +def test_step_without_base_image_omits_field_from_build_info(tmp_path: Path): + """When base_image is not set the key is absent from build_info.""" + with Pipeline(id='no-base-image-pipeline', user_id='me', app_id='my-app') as pipeline: + raw = pipeline.input('input_text') + no_base_step(text=raw) + + pipeline.generate(str(tmp_path)) + + config_path = tmp_path / 'no-base-step' / 'config.yaml' + config = yaml.safe_load(config_path.read_text(encoding='utf-8')) + + assert 'base_image' not in config.get('build_info', {}) + + +def test_step_with_platform_writes_build_info(tmp_path: Path): + """platform specified in @step is written to build_info in config.yaml.""" + with Pipeline(id='arm64-pipeline', user_id='me', app_id='my-app') as pipeline: + raw = pipeline.input('input_text') + arm64_step(text=raw) + + pipeline.generate(str(tmp_path)) + + config = yaml.safe_load((tmp_path / 'arm64-step' / 'config.yaml').read_text(encoding='utf-8')) + assert config['build_info']['platform'] == 'linux/arm64' + assert 'base_image' not in config['build_info'] + + +def test_step_with_platform_and_base_image(tmp_path: Path): + """Both platform and base_image are written to build_info when set together.""" + with Pipeline(id='combo-pipeline', user_id='me', app_id='my-app') as pipeline: + raw = pipeline.input('input_text') + custom_platform_and_image_step(text=raw) + + pipeline.generate(str(tmp_path)) + + config = yaml.safe_load( + (tmp_path / 'custom-platform-and-image-step' / 'config.yaml').read_text(encoding='utf-8') + ) + assert config['build_info']['base_image'] == 'ghcr.io/org/ml-base:latest' + assert config['build_info']['platform'] == 'linux/amd64' + + def _mock_config_with_context(user_id, app_id): """Return a mock Config whose .current gives a context with the given ids.""" ctx = SimpleNamespace(user_id=user_id, app_id=app_id) @@ -265,8 +376,29 @@ def test_pipeline_explicit_args_override_context(): assert pipeline.app_id == 'explicit-app' -def test_pipeline_raises_without_context_or_explicit_args(): +def test_pipeline_no_context_or_explicit_args_defers_validation(): + # Validation of user_id/app_id is deferred to upload/compile time, + # so constructing a Pipeline without them should not raise. + config = _mock_config_with_context(None, None) + with patch('clarifai.utils.config.Config.from_yaml', return_value=config): + pipeline = Pipeline(id='p3') + assert pipeline.user_id is None + assert pipeline.app_id is None + + +def test_pipeline_generate_raises_without_user_id_or_app_id(tmp_path): + config = _mock_config_with_context(None, None) + with patch('clarifai.utils.config.Config.from_yaml', return_value=config): + pipeline = Pipeline(id='p3') + with pytest.raises(ValueError, match='clarifai login') as exc_info: + pipeline.generate(str(tmp_path)) + assert '--user_id/--app_id' in str(exc_info.value) + + +def test_pipeline_upload_raises_without_user_id_or_app_id(): config = _mock_config_with_context(None, None) with patch('clarifai.utils.config.Config.from_yaml', return_value=config): - with pytest.raises(ValueError, match='clarifai login'): - Pipeline(id='p3') + pipeline = Pipeline(id='p3') + with pytest.raises(ValueError, match='clarifai login') as exc_info: + pipeline.upload() + assert '--user_id/--app_id' in str(exc_info.value) diff --git a/tests/runners/test_pipeline_step_builder.py b/tests/runners/test_pipeline_step_builder.py index 146356e1..9addacbe 100644 --- a/tests/runners/test_pipeline_step_builder.py +++ b/tests/runners/test_pipeline_step_builder.py @@ -232,10 +232,100 @@ def test_create_dockerfile(self, mock_base_client, setup_test_folder): with open(dockerfile_path, 'r') as f: content = f.read() - assert "FROM --platform=$TARGETPLATFORM" in content + assert "FROM " in content + assert "--platform" not in content.split('\n')[0] # no platform when not specified assert "COPY --link requirements.txt" in content assert "COPY --link=true 1 /home/nonroot/main/1" in content assert "3.12" in content # Python version from config + assert "public.ecr.aws/clarifai-models/python-base" in content + + def test_create_dockerfile_custom_base_image(self, mock_base_client, temp_dir, valid_config): + """Test Dockerfile creation with a custom base image.""" + config = dict(valid_config) + config['build_info'] = { + 'python_version': '3.12', + 'base_image': 'my-registry/my-image:latest', + } + + config_path = os.path.join(temp_dir, "config.yaml") + with open(config_path, 'w') as f: + yaml.dump(config, f) + + os.makedirs(os.path.join(temp_dir, "1"), exist_ok=True) + with open(os.path.join(temp_dir, "1", "pipeline_step.py"), 'w') as f: + f.write("# Pipeline step implementation") + with open(os.path.join(temp_dir, "requirements.txt"), 'w') as f: + f.write("requests==2.32.0\n") + + builder = PipelineStepBuilder(temp_dir) + builder.create_dockerfile() + + dockerfile_path = os.path.join(temp_dir, "Dockerfile") + assert os.path.exists(dockerfile_path) + + with open(dockerfile_path, 'r') as f: + content = f.read() + + assert "FROM my-registry/my-image:latest as final" in content + assert "public.ecr.aws/clarifai-models/python-base" not in content + assert "COPY --link requirements.txt" in content + assert "COPY --link=true 1 /home/nonroot/main/1" in content + + def test_create_dockerfile_custom_platform(self, mock_base_client, temp_dir, valid_config): + """Test Dockerfile creation with an explicit build platform.""" + config = dict(valid_config) + config['build_info'] = {'python_version': '3.12', 'platform': 'linux/arm64'} + + config_path = os.path.join(temp_dir, "config.yaml") + with open(config_path, 'w') as f: + yaml.dump(config, f) + + os.makedirs(os.path.join(temp_dir, "1"), exist_ok=True) + with open(os.path.join(temp_dir, "1", "pipeline_step.py"), 'w') as f: + f.write("# Pipeline step implementation") + with open(os.path.join(temp_dir, "requirements.txt"), 'w') as f: + f.write("requests==2.32.0\n") + + builder = PipelineStepBuilder(temp_dir) + builder.create_dockerfile() + + with open(os.path.join(temp_dir, "Dockerfile"), 'r') as f: + content = f.read() + + assert "FROM --platform=linux/arm64" in content + assert "$TARGETPLATFORM" not in content + assert "public.ecr.aws/clarifai-models/python-base" in content + + def test_create_dockerfile_custom_platform_and_base_image( + self, mock_base_client, temp_dir, valid_config + ): + """Test Dockerfile creation with both a custom platform and custom base image.""" + config = dict(valid_config) + config['build_info'] = { + 'python_version': '3.12', + 'platform': 'linux/amd64', + 'base_image': 'ghcr.io/org/my-image:latest', + } + + config_path = os.path.join(temp_dir, "config.yaml") + with open(config_path, 'w') as f: + yaml.dump(config, f) + + os.makedirs(os.path.join(temp_dir, "1"), exist_ok=True) + with open(os.path.join(temp_dir, "1", "pipeline_step.py"), 'w') as f: + f.write("# Pipeline step implementation") + with open(os.path.join(temp_dir, "requirements.txt"), 'w') as f: + f.write("requests==2.32.0\n") + + builder = PipelineStepBuilder(temp_dir) + builder.create_dockerfile() + + with open(os.path.join(temp_dir, "Dockerfile"), 'r') as f: + content = f.read() + + assert "FROM --platform=linux/amd64 ghcr.io/org/my-image:latest as final" in content + assert "$TARGETPLATFORM" not in content + assert "public.ecr.aws/clarifai-models/python-base" not in content def test_tar_file_property(self, mock_base_client, setup_test_folder): """Test tar_file property returns correct path."""