diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 01de763021..4b807909ad 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -718,6 +718,86 @@ def execute_task_cmd( ) +def _parse_fast_execute_args(task_execute_cmd: List[str], additional_distribution: str, dest_dir: str) -> Dict[str, str]: + """Parse pyflyte-execute arguments from the raw command list.""" + args: Dict[str, str] = {} + i = 0 + cmd_list = list(task_execute_cmd) + if cmd_list and cmd_list[0] == "pyflyte-execute": + i = 1 + + resolver_args_list: List[str] = [] + found_resolver_args = False + + while i < len(cmd_list): + arg = cmd_list[i] + if arg.startswith("--"): + key = arg.lstrip("-").replace("-", "_") + if key == "test": + args["test"] = "true" + i += 1 + continue + if i + 1 < len(cmd_list) and not cmd_list[i + 1].startswith("--"): + args[key] = cmd_list[i + 1] + if key == "resolver": + found_resolver_args = True + i += 2 + continue + i += 1 + else: + if found_resolver_args: + resolver_args_list.append(arg) + i += 1 + + args["dynamic_addl_distro"] = additional_distribution or "" + args["dynamic_dest_dir"] = dest_dir or "" + args["resolver_args"] = ",".join(resolver_args_list) + + return args + + +def _execute_in_process( + task_execute_cmd: List[str], + additional_distribution: str, + dest_dir: str, + dest_dir_resolved: str, +) -> int: + """Run the pyflyte-execute logic in-process instead of spawning a subprocess.""" + if dest_dir_resolved: + if dest_dir_resolved not in sys.path: + sys.path.insert(0, dest_dir_resolved) + existing = os.environ.get("PYTHONPATH", "") + if existing: + os.environ["PYTHONPATH"] = existing + os.pathsep + dest_dir_resolved + else: + os.environ["PYTHONPATH"] = dest_dir_resolved + + parsed = _parse_fast_execute_args(task_execute_cmd, additional_distribution, dest_dir) + + raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs( + parsed.get("raw_output_data_prefix"), + parsed.get("checkpoint_path"), + parsed.get("prev_checkpoint"), + ) + + resolver_args_str = parsed.get("resolver_args", "") + resolver_args_tuple = tuple(resolver_args_str.split(",")) if resolver_args_str else () + + _execute_task( + inputs=parsed.get("inputs", ""), + output_prefix=parsed.get("output_prefix", ""), + raw_output_data_prefix=raw_output_data_prefix, + test=parsed.get("test") == "true", + resolver=parsed.get("resolver"), + resolver_args=resolver_args_tuple, + dynamic_addl_distro=parsed.get("dynamic_addl_distro") or None, + dynamic_dest_dir=parsed.get("dynamic_dest_dir") or None, + checkpoint_path=checkpoint_path, + prev_checkpoint=prev_checkpoint, + ) + return 0 + + @_pass_through.command("pyflyte-fast-execute") @click.option("--additional-distribution", required=False) @click.option("--dest-dir", required=False) @@ -732,18 +812,26 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec dest_dir = os.getcwd() _download_distribution(additional_distribution, dest_dir) - # Insert the call to fast before the unbounded resolver args + dest_dir_resolved = "" + if dest_dir is not None: + dest_dir_resolved = os.path.realpath(os.path.expanduser(dest_dir)) + + try: + returncode = _execute_in_process(task_execute_cmd, additional_distribution, dest_dir, dest_dir_resolved) + exit(returncode) + except SystemExit: + raise + except Exception: + logger.warning("In-process execute failed, falling back to subprocess", exc_info=True) + cmd = [] for arg in task_execute_cmd: if arg == "--resolver": cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir]) cmd.append(arg) - # Use the commandline to run the task execute command rather than calling it directly in python code - # since the current runtime bytecode references the older user code, rather than the downloaded distribution. env = os.environ.copy() - if dest_dir is not None: - dest_dir_resolved = os.path.realpath(os.path.expanduser(dest_dir)) + if dest_dir_resolved: if "PYTHONPATH" in env: env["PYTHONPATH"] += os.pathsep + dest_dir_resolved else: diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 8f6addf76a..baaf3ab199 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -25,6 +25,8 @@ _F_IMG_ID = "_F_IMG_ID" FLYTE_FORCE_PUSH_IMAGE_SPEC = "FLYTE_FORCE_PUSH_IMAGE_SPEC" +_ecr_existence_cache: Dict[Tuple[str, str, str], bool] = {} + # Shared helpers for Nix flake path inputs and git root discovery _NIX_FLAKE_PATH_INPUT_PATTERN = re.compile(r'url\s*=\s*"(path:([^"]+))"') @@ -132,8 +134,12 @@ def check_ecr_image_exists(registry: str, repository: str, tag: str) -> Optional f"Extracted - Account ID: {account_id}, Region: {region}, Repository: {repository}, Tag: {tag}", fg="cyan" ) + cache_key = (registry, repository, tag) + if cache_key in _ecr_existence_cache: + cached = _ecr_existence_cache[cache_key] + return cached + try: - # Use AWS CLI to check if image exists image_ids_json = json.dumps([{"imageTag": tag}]) cmd = [ "aws", @@ -149,16 +155,16 @@ def check_ecr_image_exists(registry: str, repository: str, tag: str) -> Optional "json", ] - # Output the command being executed click.secho(f"Executing AWS ECR command: {' '.join(cmd)}", fg="blue") result = subprocess.run(cmd, capture_output=True, text=True, timeout=10) if result.returncode == 0: data = json.loads(result.stdout) - return len(data.get("imageDetails", [])) > 0 + exists = len(data.get("imageDetails", [])) > 0 + _ecr_existence_cache[cache_key] = exists + return exists elif result.returncode != 0: - # Check for various image not found scenarios if any( phrase in result.stderr for phrase in [ @@ -169,9 +175,9 @@ def check_ecr_image_exists(registry: str, repository: str, tag: str) -> Optional ] ): click.secho(f"Image not found in ECR: {result.stderr}", fg="yellow") + _ecr_existence_cache[cache_key] = False return False else: - # Some other error occurred click.secho(f"Failed to check ECR image: {result.stderr}", fg="yellow") return None except subprocess.TimeoutExpired: diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 8f3de3cd55..344c34cd11 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1262,6 +1262,30 @@ def _get_image_names(self, entity: typing.Union[PythonAutoContainerTask, Workflo return image_names return [] + def _get_image_specs(self, entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[ImageSpec]: + if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec): + return [entity.container_image] + if isinstance(entity, WorkflowBase): + specs: typing.List[ImageSpec] = [] + for n in entity.nodes: + specs.extend(self._get_image_specs(n.flyte_entity)) + return specs + return [] + + @staticmethod + def _prefetch_ecr_existence(image_specs: typing.List[ImageSpec]) -> None: + """Pre-warm the ECR existence cache for all ImageSpec objects.""" + from flytekit.image_spec.image_spec import check_ecr_image_exists, is_ecr_registry, check_aws_cli_and_creds + + for spec in image_specs: + if spec.registry and is_ecr_registry(spec.registry) and check_aws_cli_and_creds(): + registry_parts = spec.registry.split("/", 1) + if len(registry_parts) > 1: + repository = f"{registry_parts[1]}/{spec.name}" + else: + repository = spec.name + check_ecr_image_exists(spec.registry.split("/")[0], repository, spec.tag) + def register_script( self, entity: typing.Union[WorkflowBase, PythonTask, LaunchPlan], @@ -1295,6 +1319,8 @@ def register_script( :param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False. :return: """ + import concurrent.futures + if isinstance(entity, ReferenceWorkflow): return entity if copy_all: @@ -1310,6 +1336,13 @@ def register_script( if image_config is None: image_config = ImageConfig.auto_default_image() + image_specs = self._get_image_specs(entity) + ecr_future: typing.Optional[concurrent.futures.Future[None]] = None + ecr_executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = None + if image_specs: + ecr_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + ecr_future = ecr_executor.submit(self._prefetch_ecr_existence, image_specs) + with tempfile.TemporaryDirectory() as tmp_dir: if fast_package_options and fast_package_options.copy_style != CopyFileDetection.NO_COPY: md5_bytes, upload_native_url = self.fast_package( @@ -1341,20 +1374,43 @@ def register_script( if isinstance(entity, WorkflowBase): default_inputs = entity.python_interface.default_inputs_as_kwargs - # The md5 version that we send to S3/GCS has to match the file contents exactly, - # but we don't have to use it when registering with the Flyte backend. - # For that add the hash of the compilation settings to hash of file version = self._version_from_hash( md5_bytes, serialization_settings, default_inputs, *self._get_image_names(entity) ) + if isinstance(entity, WorkflowBase) and isinstance(version, str): + try: + if self._wf_exists( + name=entity.name, + version=version, + project=project or self.default_project, + domain=domain or self.default_domain, + ): + logger.info(f"Workflow {entity.name} version {version} already exists, skipping registration") + if ecr_future is not None: + ecr_future.cancel() + fwf = self.fetch_workflow( + project or self.default_project, + domain or self.default_domain, + entity.name, + version, + ) + fwf.python_interface = entity.python_interface + return fwf + except Exception as e: + logger.debug(f"Version-check-first lookup failed, proceeding with registration: {e}") + + if ecr_future is not None: + ecr_future.result() + if ecr_executor is not None: + ecr_executor.shutdown(wait=False) + if isinstance(entity, PythonTask): return self.register_task(entity, serialization_settings, version) if isinstance(entity, WorkflowBase): return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options) if isinstance(entity, LaunchPlan): - # If it's a launch plan, we need to register the workflow first return self.register_launch_plan(entity, version, project, domain, options, serialization_settings) raise ValueError(f"Unsupported entity type {type(entity)}")