diff --git a/pyproject.toml b/pyproject.toml index c3025c3fe..86ef9cdcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,10 +74,6 @@ dependencies = [ [tool.pyright] exclude = [ - # TODO(lucasagomes): This module was copied from road-core - # service/ols/src/auth/k8s.py and currently has 58 Pyright issues. It - # might need to be rewritten down the line. - "src/authentication/k8s.py", # Agent API v1 endpoints - deprecated API but still supported # Type errors due to llama-stack-client not exposing Agent API types "src/app/endpoints/conversations.py", diff --git a/src/authentication/k8s.py b/src/authentication/k8s.py index 399a3d40c..fba8336fe 100644 --- a/src/authentication/k8s.py +++ b/src/authentication/k8s.py @@ -1,8 +1,7 @@ """Manage authentication flow for FastAPI endpoints with K8S/OCP.""" import os -from pathlib import Path -from typing import Optional, Self +from typing import Optional, Self, cast import kubernetes.client from fastapi import HTTPException, Request @@ -80,19 +79,19 @@ def __new__(cls: type[Self]) -> Self: ce, ) - k8s_config.host = ( - configuration.authentication_configuration.k8s_cluster_api - or k8s_config.host - ) + k8s_api_url = configuration.authentication_configuration.k8s_cluster_api + if k8s_api_url is not None: + k8s_config.host = str(k8s_api_url) k8s_config.verify_ssl = ( not configuration.authentication_configuration.skip_tls_verification ) - k8s_config.ssl_ca_cert = ( + ca_cert_path = ( configuration.authentication_configuration.k8s_ca_cert_path - if configuration.authentication_configuration.k8s_ca_cert_path - not in {None, Path()} - else k8s_config.ssl_ca_cert ) + if ca_cert_path is not None: + # Kubernetes client library has incomplete type stubs for ssl_ca_cert + k8s_config.ssl_ca_cert = str(ca_cert_path) # type: ignore[assignment] + # else keep the default k8s_config.ssl_ca_cert api_client = kubernetes.client.ApiClient(k8s_config) cls._api_client = api_client cls._custom_objects_api = kubernetes.client.CustomObjectsApi(api_client) @@ -101,7 +100,8 @@ def __new__(cls: type[Self]) -> Self: except Exception as e: logger.info("Failed to initialize Kubernetes client: %s", e) raise - return cls._instance + # At this point _instance is guaranteed to be initialized + return cast(Self, cls._instance) @classmethod def get_authn_api(cls) -> kubernetes.client.AuthenticationV1Api: @@ -161,10 +161,23 @@ def _get_cluster_id(cls) -> str: """ try: custom_objects_api = cls.get_custom_objects_api() - version_data = custom_objects_api.get_cluster_custom_object( - "config.openshift.io", "v1", "clusterversions", "version" + # Kubernetes API always returns dict for custom objects + version_data = cast( + dict, + custom_objects_api.get_cluster_custom_object( + "config.openshift.io", "v1", "clusterversions", "version" + ), ) - cluster_id = version_data["spec"]["clusterID"] + spec = version_data.get("spec") + if not isinstance(spec, dict): + raise ClusterIDUnavailableError( + "Missing or invalid 'spec' in ClusterVersion" + ) + cluster_id = spec.get("clusterID") + if not isinstance(cluster_id, str) or not cluster_id.strip(): + raise ClusterIDUnavailableError( + "Missing or invalid 'clusterID' in ClusterVersion" + ) cls._cluster_id = cluster_id return cluster_id except KeyError as e: @@ -172,11 +185,6 @@ def _get_cluster_id(cls) -> str: "Failed to get cluster_id from cluster, missing keys in version object" ) raise ClusterIDUnavailableError("Failed to get cluster ID") from e - except TypeError as e: - logger.error( - "Failed to get cluster_id, version object is: %s", version_data - ) - raise ClusterIDUnavailableError("Failed to get cluster ID") from e except ApiException as e: logger.error("API exception during ClusterInfo: %s", e) raise ClusterIDUnavailableError("Failed to get cluster ID") from e @@ -212,14 +220,14 @@ def get_cluster_id(cls) -> str: return cls._cluster_id -def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReview]: +def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReviewStatus]: """Perform a Kubernetes TokenReview to validate a given token. Parameters: token: The bearer token to be validated. Returns: - The user information if the token is valid, None otherwise. + The V1TokenReviewStatus if the token is valid, None otherwise. Raises: HTTPException: If unable to connect to Kubernetes API or unexpected error occurs. @@ -238,9 +246,13 @@ def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReview]: spec=kubernetes.client.V1TokenReviewSpec(token=token) ) try: - response = auth_api.create_token_review(token_review) - if response.status.authenticated: - return response.status + response = cast( + kubernetes.client.V1TokenReview, + auth_api.create_token_review(token_review), + ) + status = response.status + if status is not None and status.authenticated: + return status return None except Exception as e: # pylint: disable=broad-exception-caught logger.error("API exception during TokenReview: %s", e) @@ -307,9 +319,12 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: response = UnauthorizedResponse(cause="Invalid or expired Kubernetes token") raise HTTPException(**response.model_dump()) - if user_info.user.username == "kube:admin": + # Cast user to proper type for type checking + user = cast(kubernetes.client.V1UserInfo, user_info.user) + + if user.username == "kube:admin": try: - user_info.user.uid = K8sClientSingleton.get_cluster_id() + user.uid = K8sClientSingleton.get_cluster_id() except ClusterIDUnavailableError as e: logger.error("Failed to get cluster ID: %s", e) response = InternalServerErrorResponse( @@ -322,14 +337,17 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: authorization_api = K8sClientSingleton.get_authz_api() sar = kubernetes.client.V1SubjectAccessReview( spec=kubernetes.client.V1SubjectAccessReviewSpec( - user=user_info.user.username, - groups=user_info.user.groups, + user=user.username, + groups=user.groups, non_resource_attributes=kubernetes.client.V1NonResourceAttributes( path=self.virtual_path, verb="get" ), ) ) - response = authorization_api.create_subject_access_review(sar) + sar_response = cast( + kubernetes.client.V1SubjectAccessReview, + authorization_api.create_subject_access_review(sar), + ) except Exception as e: logger.error("API exception during SubjectAccessReview: %s", e) @@ -339,13 +357,19 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: ) raise HTTPException(**response.model_dump()) from e - if not response.status.allowed: - response = ForbiddenResponse.endpoint(user_id=user_info.user.uid) + sar_status = cast( + kubernetes.client.V1SubjectAccessReviewStatus, sar_response.status + ) + user_uid = cast(str, user.uid) + username = cast(str, user.username) + + if not sar_status.allowed: + response = ForbiddenResponse.endpoint(user_id=user_uid) raise HTTPException(**response.model_dump()) return ( - user_info.user.uid, - user_info.user.username, + user_uid, + username, self.skip_userid_check, token, )