Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/4191.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix wrong `Accept` header handling in image rescanning.
349 changes: 279 additions & 70 deletions src/ai/backend/manager/container_registry/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
import json
import copy
import logging
from abc import ABCMeta, abstractmethod
from contextlib import asynccontextmanager as actxmgr
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
self.registry_url = registry_info[""]
self.max_concurrency_per_registry = max_concurrency_per_registry
self.base_hdrs = {
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
"Accept": self.MEDIA_TYPE_DOCKER_MANIFEST,
}
self.credentials = {}
self.ssl_verify = ssl_verify
Expand Down Expand Up @@ -158,9 +158,9 @@ async def scan_single_ref(self, image_ref: str) -> None:
self.credentials,
f"repository:{image}:pull",
)
rqst_args["headers"].update(**self.base_hdrs)
await self._scan_tag(sess, rqst_args, image, tag)
await self.commit_rescan_result()
await self._scan_tag(sess, rqst_args, project_and_image_name, tag)
scanned_images = await self.commit_rescan_result()
return RescanImagesResult(images=scanned_images)
finally:
concurrency_sema.reset(sema_token)
all_updates.reset(all_updates_token)
Expand Down Expand Up @@ -213,6 +213,7 @@ async def _scan_tag(
) -> None:
manifests = {}
async with concurrency_sema.get():
rqst_args = copy.deepcopy(rqst_args)
rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST_LIST
async with sess.get(
self.registry_url / f"v2/{image}/manifests/{tag}", **rqst_args
Expand All @@ -223,72 +224,280 @@ async def _scan_tag(
return
content_type = resp.headers["Content-Type"]
resp.raise_for_status()
resp_json = await resp.json()
match content_type:
case self.MEDIA_TYPE_DOCKER_MANIFEST_LIST:
manifest_list = resp_json["manifests"]
request_type = self.MEDIA_TYPE_DOCKER_MANIFEST
case self.MEDIA_TYPE_OCI_INDEX:
manifest_list = [
item
for item in resp_json["manifests"]
if "annotations" not in item # skip attestation manifests
]
request_type = self.MEDIA_TYPE_OCI_MANIFEST
case _:
raise RuntimeError(
"The registry does not support the standard way of "
"listing multiarch images."
)
rqst_args["headers"]["Accept"] = request_type
for manifest in manifest_list:
platform_arg = (
f"{manifest['platform']['os']}/{manifest['platform']['architecture']}"
resp_json = await read_json(resp)

try:
async with aiotools.TaskGroup() as tg:
match content_type:
case self.MEDIA_TYPE_DOCKER_MANIFEST:
await self._process_docker_v2_image(
tg, sess, rqst_args, image, tag, resp_json
)
case self.MEDIA_TYPE_DOCKER_MANIFEST_LIST:
await self._process_docker_v2_multiplatform_image(
tg, sess, rqst_args, image, tag, resp_json
)
case self.MEDIA_TYPE_OCI_INDEX:
await self._process_oci_index(
tg, sess, rqst_args, image, tag, resp_json
)
case self.MEDIA_TYPE_OCI_MANIFEST:
await self._process_oci_manifest(
tg, sess, rqst_args, image, tag, resp_json
)
case (
self.MEDIA_TYPE_DOCKER_MANIFEST_V1_PRETTY_JWS
| self.MEDIA_TYPE_DOCKER_MANIFEST_V1_JSON
):
await self._process_docker_v1_image(
tg, sess, rqst_args, image, tag, resp_json
)

case _:
log.warning("Unknown content type: {}", content_type)
raise RuntimeError(
"The registry does not support the standard way of "
"listing multiarch images."
)
except aiotools.TaskGroupError as e:
raise ScanTagError(
f"Tag scan failed, Details: {cast(ExceptionGroup, e).exceptions}"
) from e

async def _read_manifest_list(
self,
sess: aiohttp.ClientSession,
manifest_list: Sequence[Any],
rqst_args: dict[str, Any],
image: str,
tag: str,
) -> None:
"""
Understands images defined under [OCI image manifest](https://github.com/opencontainers/image-spec/blob/main/manifest.md#example-image-manifest) or
[Docker image manifest list](https://github.com/openshift/docker-distribution/blob/master/docs/spec/manifest-v2-2.md#example-manifest-list)
and imports Backend.AI compatible images.
"""
manifests = {}
for manifest in manifest_list:
platform_arg = f"{manifest['platform']['os']}/{manifest['platform']['architecture']}"
if variant := manifest["platform"].get("variant", None):
platform_arg += f"/{variant}"
architecture = manifest["platform"]["architecture"]
architecture = arch_name_aliases.get(architecture, architecture)

async with sess.get(
self.registry_url / f"v2/{image}/manifests/{manifest['digest']}",
**rqst_args,
) as resp:
manifest_info = await resp.json()

manifests[architecture] = await self._preprocess_manifest(
sess, manifest_info, rqst_args, image
)

if not manifests[architecture]["labels"]:
log.warning(
"The image {}:{}/{} has no metadata labels -> treating as vanilla image",
image,
tag,
architecture,
)
if variant := manifest["platform"].get("variant", None):
platform_arg += f"/{variant}"
architecture = manifest["platform"]["architecture"]
architecture = arch_name_aliases.get(architecture, architecture)
async with sess.get(
self.registry_url / f"v2/{image}/manifests/{manifest['digest']}", **rqst_args
) as resp:
data = await resp.json()
config_digest = data["config"]["digest"]
size_bytes = sum(layer["size"] for layer in data["layers"]) + data["config"]["size"]
async with sess.get(
self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args
) as resp:
resp.raise_for_status()
data = json.loads(await resp.read())
labels = {}
if "container_config" in data:
raw_labels = data["container_config"].get("Labels")
if raw_labels:
labels.update(raw_labels)
else:
log.warning(
"label not found on image {}:{}/{}",
image,
tag,
architecture,
)
else:
raw_labels = data["config"].get("Labels")
if raw_labels:
labels.update(raw_labels)
else:
log.warning(
"label not found on image {}:{}/{}",
image,
tag,
architecture,
)
manifests[architecture] = {
"size": size_bytes,
"labels": labels,
"digest": config_digest,
}
await self._read_manifest(image, tag, manifests)
manifests[architecture]["labels"] = {}

await self._read_manifest(image, tag, manifests)

async def _preprocess_manifest(
self,
sess: aiohttp.ClientSession,
manifest: Mapping[str, Any],
rqst_args: dict[str, Any],
image: str,
) -> dict[str, Any]:
"""
Extracts informations from
[Docker iamge manifest](https://github.com/openshift/docker-distribution/blob/master/docs/spec/manifest-v2-2.md#example-image-manifest)
required by Backend.AI.
"""
config_digest = manifest["config"]["digest"]
size_bytes = sum(layer["size"] for layer in manifest["layers"]) + manifest["config"]["size"]

async with sess.get(
self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args
) as resp:
resp.raise_for_status()
data = await read_json(resp)
labels = {}

# we should favor `config` instead of `container_config` since `config` can contain additional datas
# set when commiting image via `--change` flag
if _config_labels := data.get("config", {}).get("Labels"):
labels = _config_labels
elif _container_config_labels := data.get("container_config", {}).get("Labels"):
labels = _container_config_labels

return {
"size": size_bytes,
"labels": labels,
"digest": config_digest,
}

async def _process_oci_index(
self,
tg: aiotools.TaskGroup,
sess: aiohttp.ClientSession,
rqst_args: dict[str, Any],
image: str,
tag: str,
image_info: Mapping[str, Any],
) -> None:
manifest_list = [
item
for item in image_info["manifests"]
if "annotations" not in item # skip attestation manifests
]
rqst_args = copy.deepcopy(rqst_args)
rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_OCI_MANIFEST

await self._read_manifest_list(sess, manifest_list, rqst_args, image, tag)

async def _process_oci_manifest(
self,
tg: aiotools.TaskGroup,
sess: aiohttp.ClientSession,
rqst_args: dict[str, Any],
image: str,
tag: str,
image_info: Mapping[str, Any],
) -> None:
rqst_args = copy.deepcopy(rqst_args)
rqst_args["headers"] = rqst_args.get("headers", {})
rqst_args["headers"].update({
"Accept": self.MEDIA_TYPE_OCI_MANIFEST,
})

if (reporter := progress_reporter.get()) is not None:
reporter.total_progress += 1

async with concurrency_sema.get():
config_digest = image_info["config"]["digest"]
size_bytes = (
sum(layer["size"] for layer in image_info["layers"]) + image_info["config"]["size"]
)

async with sess.get(
self.registry_url / f"v2/{image}/blobs/{config_digest}",
**rqst_args,
) as resp:
resp.raise_for_status()
config_data = await read_json(resp)

labels = {}
if _config_labels := config_data.get("config", {}).get("Labels"):
labels = _config_labels
elif _container_config_labels := config_data.get("container_config", {}).get("Labels"):
labels = _container_config_labels

if not labels:
log.warning(
"The image {}:{} has no metadata labels -> treating as vanilla image",
image,
tag,
)
labels = {}

architecture = config_data.get("architecture")
if architecture:
architecture = arch_name_aliases.get(architecture, architecture)
else:
if tag.endswith("-arm64") or tag.endswith("-aarch64"):
architecture = "aarch64"
else:
architecture = "x86_64"

manifests = {
architecture: {
"size": size_bytes,
"labels": labels,
"digest": config_digest,
}
}
await self._read_manifest(image, tag, manifests)

async def _process_docker_v2_multiplatform_image(
self,
tg: aiotools.TaskGroup,
sess: aiohttp.ClientSession,
rqst_args: dict[str, Any],
image: str,
tag: str,
image_info: Mapping[str, Any],
) -> None:
manifest_list = image_info["manifests"]
rqst_args = copy.deepcopy(rqst_args)
rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST

await self._read_manifest_list(
sess,
manifest_list,
rqst_args,
image,
tag,
)

async def _process_docker_v2_image(
self,
tg: aiotools.TaskGroup,
sess: aiohttp.ClientSession,
rqst_args: dict[str, Any],
image: str,
tag: str,
image_info: Mapping[str, Any],
) -> None:
config_digest = image_info["config"]["digest"]
rqst_args = copy.deepcopy(rqst_args)
rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST

async with sess.get(
self.registry_url / f"v2/{image}/blobs/{config_digest}",
**rqst_args,
) as resp:
resp.raise_for_status()
blob_data = await read_json(resp)

manifest_arch = blob_data["architecture"]
architecture = arch_name_aliases.get(manifest_arch, manifest_arch)

manifests = {
architecture: await self._preprocess_manifest(sess, image_info, rqst_args, image),
}
await self._read_manifest(image, tag, manifests)

async def _process_docker_v1_image(
self,
tg: aiotools.TaskGroup,
sess: aiohttp.ClientSession,
rqst_args: dict[str, Any],
image: str,
tag: str,
image_info: Mapping[str, Any],
) -> None:
log.warning("Docker image manifest v1 is deprecated.")

architecture = image_info["architecture"]

manifest_list = [
{
"platform": {
"os": "linux",
"architecture": architecture,
},
"digest": tag,
}
]

rqst_args = copy.deepcopy(rqst_args)
rqst_args["headers"]["Accept"] = self.MEDIA_TYPE_DOCKER_MANIFEST
await self._read_manifest_list(sess, manifest_list, rqst_args, image, tag)

async def _read_manifest(
self,
Expand Down
Loading