Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 93 additions & 5 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,86 @@ def execute_task_cmd(
)


def _parse_fast_execute_args(task_execute_cmd: List[str], additional_distribution: str, dest_dir: str) -> Dict[str, str]:
"""Parse pyflyte-execute arguments from the raw command list."""
args: Dict[str, str] = {}
i = 0
cmd_list = list(task_execute_cmd)
if cmd_list and cmd_list[0] == "pyflyte-execute":
i = 1

resolver_args_list: List[str] = []
found_resolver_args = False

while i < len(cmd_list):
arg = cmd_list[i]
if arg.startswith("--"):
key = arg.lstrip("-").replace("-", "_")
if key == "test":
args["test"] = "true"
i += 1
continue
if i + 1 < len(cmd_list) and not cmd_list[i + 1].startswith("--"):
args[key] = cmd_list[i + 1]
if key == "resolver":
found_resolver_args = True
i += 2
continue
i += 1
else:
if found_resolver_args:
resolver_args_list.append(arg)
i += 1

args["dynamic_addl_distro"] = additional_distribution or ""
args["dynamic_dest_dir"] = dest_dir or ""
args["resolver_args"] = ",".join(resolver_args_list)

return args


def _execute_in_process(
task_execute_cmd: List[str],
additional_distribution: str,
dest_dir: str,
dest_dir_resolved: str,
) -> int:
"""Run the pyflyte-execute logic in-process instead of spawning a subprocess."""
if dest_dir_resolved:
if dest_dir_resolved not in sys.path:
sys.path.insert(0, dest_dir_resolved)
existing = os.environ.get("PYTHONPATH", "")
if existing:
os.environ["PYTHONPATH"] = existing + os.pathsep + dest_dir_resolved
else:
os.environ["PYTHONPATH"] = dest_dir_resolved

parsed = _parse_fast_execute_args(task_execute_cmd, additional_distribution, dest_dir)

raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs(
parsed.get("raw_output_data_prefix"),
parsed.get("checkpoint_path"),
parsed.get("prev_checkpoint"),
)

resolver_args_str = parsed.get("resolver_args", "")
resolver_args_tuple = tuple(resolver_args_str.split(",")) if resolver_args_str else ()

_execute_task(
inputs=parsed.get("inputs", ""),
output_prefix=parsed.get("output_prefix", ""),
raw_output_data_prefix=raw_output_data_prefix,
test=parsed.get("test") == "true",
resolver=parsed.get("resolver"),
resolver_args=resolver_args_tuple,
dynamic_addl_distro=parsed.get("dynamic_addl_distro") or None,
dynamic_dest_dir=parsed.get("dynamic_dest_dir") or None,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
)
return 0


@_pass_through.command("pyflyte-fast-execute")
@click.option("--additional-distribution", required=False)
@click.option("--dest-dir", required=False)
Expand All @@ -732,18 +812,26 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec
dest_dir = os.getcwd()
_download_distribution(additional_distribution, dest_dir)

# Insert the call to fast before the unbounded resolver args
dest_dir_resolved = ""
if dest_dir is not None:
dest_dir_resolved = os.path.realpath(os.path.expanduser(dest_dir))

try:
returncode = _execute_in_process(task_execute_cmd, additional_distribution, dest_dir, dest_dir_resolved)
exit(returncode)
except SystemExit:
raise
except Exception:
logger.warning("In-process execute failed, falling back to subprocess", exc_info=True)

cmd = []
for arg in task_execute_cmd:
if arg == "--resolver":
cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir])
cmd.append(arg)

# Use the commandline to run the task execute command rather than calling it directly in python code
# since the current runtime bytecode references the older user code, rather than the downloaded distribution.
env = os.environ.copy()
if dest_dir is not None:
dest_dir_resolved = os.path.realpath(os.path.expanduser(dest_dir))
if dest_dir_resolved:
if "PYTHONPATH" in env:
env["PYTHONPATH"] += os.pathsep + dest_dir_resolved
else:
Expand Down
16 changes: 11 additions & 5 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
_F_IMG_ID = "_F_IMG_ID"
FLYTE_FORCE_PUSH_IMAGE_SPEC = "FLYTE_FORCE_PUSH_IMAGE_SPEC"

_ecr_existence_cache: Dict[Tuple[str, str, str], bool] = {}


# Shared helpers for Nix flake path inputs and git root discovery
_NIX_FLAKE_PATH_INPUT_PATTERN = re.compile(r'url\s*=\s*"(path:([^"]+))"')
Expand Down Expand Up @@ -132,8 +134,12 @@ def check_ecr_image_exists(registry: str, repository: str, tag: str) -> Optional
f"Extracted - Account ID: {account_id}, Region: {region}, Repository: {repository}, Tag: {tag}", fg="cyan"
)

cache_key = (registry, repository, tag)
if cache_key in _ecr_existence_cache:
cached = _ecr_existence_cache[cache_key]
return cached

try:
# Use AWS CLI to check if image exists
image_ids_json = json.dumps([{"imageTag": tag}])
cmd = [
"aws",
Expand All @@ -149,16 +155,16 @@ def check_ecr_image_exists(registry: str, repository: str, tag: str) -> Optional
"json",
]

# Output the command being executed
click.secho(f"Executing AWS ECR command: {' '.join(cmd)}", fg="blue")

result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

if result.returncode == 0:
data = json.loads(result.stdout)
return len(data.get("imageDetails", [])) > 0
exists = len(data.get("imageDetails", [])) > 0
_ecr_existence_cache[cache_key] = exists
return exists
elif result.returncode != 0:
# Check for various image not found scenarios
if any(
phrase in result.stderr
for phrase in [
Expand All @@ -169,9 +175,9 @@ def check_ecr_image_exists(registry: str, repository: str, tag: str) -> Optional
]
):
click.secho(f"Image not found in ECR: {result.stderr}", fg="yellow")
_ecr_existence_cache[cache_key] = False
return False
else:
# Some other error occurred
click.secho(f"Failed to check ECR image: {result.stderr}", fg="yellow")
return None
except subprocess.TimeoutExpired:
Expand Down
64 changes: 60 additions & 4 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,30 @@ def _get_image_names(self, entity: typing.Union[PythonAutoContainerTask, Workflo
return image_names
return []

def _get_image_specs(self, entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[ImageSpec]:
if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec):
return [entity.container_image]
if isinstance(entity, WorkflowBase):
specs: typing.List[ImageSpec] = []
for n in entity.nodes:
specs.extend(self._get_image_specs(n.flyte_entity))
return specs
return []

@staticmethod
def _prefetch_ecr_existence(image_specs: typing.List[ImageSpec]) -> None:
"""Pre-warm the ECR existence cache for all ImageSpec objects."""
from flytekit.image_spec.image_spec import check_ecr_image_exists, is_ecr_registry, check_aws_cli_and_creds

for spec in image_specs:
if spec.registry and is_ecr_registry(spec.registry) and check_aws_cli_and_creds():
registry_parts = spec.registry.split("/", 1)
if len(registry_parts) > 1:
repository = f"{registry_parts[1]}/{spec.name}"
else:
repository = spec.name
check_ecr_image_exists(spec.registry.split("/")[0], repository, spec.tag)

def register_script(
self,
entity: typing.Union[WorkflowBase, PythonTask, LaunchPlan],
Expand Down Expand Up @@ -1295,6 +1319,8 @@ def register_script(
:param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False.
:return:
"""
import concurrent.futures

if isinstance(entity, ReferenceWorkflow):
return entity
if copy_all:
Expand All @@ -1310,6 +1336,13 @@ def register_script(
if image_config is None:
image_config = ImageConfig.auto_default_image()

image_specs = self._get_image_specs(entity)
ecr_future: typing.Optional[concurrent.futures.Future[None]] = None
ecr_executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = None
if image_specs:
ecr_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
ecr_future = ecr_executor.submit(self._prefetch_ecr_existence, image_specs)

with tempfile.TemporaryDirectory() as tmp_dir:
if fast_package_options and fast_package_options.copy_style != CopyFileDetection.NO_COPY:
md5_bytes, upload_native_url = self.fast_package(
Expand Down Expand Up @@ -1341,20 +1374,43 @@ def register_script(
if isinstance(entity, WorkflowBase):
default_inputs = entity.python_interface.default_inputs_as_kwargs

# The md5 version that we send to S3/GCS has to match the file contents exactly,
# but we don't have to use it when registering with the Flyte backend.
# For that add the hash of the compilation settings to hash of file
version = self._version_from_hash(
md5_bytes, serialization_settings, default_inputs, *self._get_image_names(entity)
)

if isinstance(entity, WorkflowBase) and isinstance(version, str):
try:
if self._wf_exists(
name=entity.name,
version=version,
project=project or self.default_project,
domain=domain or self.default_domain,
):
logger.info(f"Workflow {entity.name} version {version} already exists, skipping registration")
if ecr_future is not None:
ecr_future.cancel()
fwf = self.fetch_workflow(
project or self.default_project,
domain or self.default_domain,
entity.name,
version,
)
fwf.python_interface = entity.python_interface
return fwf
except Exception as e:
logger.debug(f"Version-check-first lookup failed, proceeding with registration: {e}")

if ecr_future is not None:
ecr_future.result()
if ecr_executor is not None:
ecr_executor.shutdown(wait=False)

if isinstance(entity, PythonTask):
return self.register_task(entity, serialization_settings, version)

if isinstance(entity, WorkflowBase):
return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options)
if isinstance(entity, LaunchPlan):
# If it's a launch plan, we need to register the workflow first
return self.register_launch_plan(entity, version, project, domain, options, serialization_settings)
raise ValueError(f"Unsupported entity type {type(entity)}")

Expand Down