44from collections .abc import AsyncGenerator , Callable
55from contextlib import asynccontextmanager
66from typing import Any
7- from urllib .parse import urljoin
7+ from urllib .parse import urljoin , urlparse
88
99from authlib .integrations .base_client import InvalidTokenError
1010from authlib .integrations .httpx_client import AsyncOAuth2Client , OAuthError
1111from sqlalchemy import select
1212from sqlalchemy .ext .asyncio import AsyncSession
13- from sqlalchemy .orm import selectinload
13+ from sqlalchemy .orm import joinedload , selectinload
1414from ulid import ULID
1515
1616import renku_data_services .base_models as base_models
1717from renku_data_services import errors
1818from renku_data_services .app_config import logging
1919from renku_data_services .base_api .pagination import PaginationRequest
20- from renku_data_services .connected_services import apispec , models
20+ from renku_data_services .connected_services import models
2121from renku_data_services .connected_services import orm as schemas
2222from renku_data_services .connected_services .apispec import ConnectionStatus , ProviderKind
2323from renku_data_services .connected_services .provider_adapters import (
2626 get_provider_adapter ,
2727)
2828from renku_data_services .connected_services .utils import generate_code_verifier
29+ from renku_data_services .notebooks .api .classes .image import Image , ImageRepoDockerAPI
2930from renku_data_services .utils .cryptography import decrypt_string , encrypt_string
3031
3132logger = logging .getLogger (__name__ )
@@ -68,9 +69,7 @@ async def get_oauth2_client(self, provider_id: str, user: base_models.APIUser) -
6869 return client .dump (user_is_admin = user .is_admin )
6970
7071 async def insert_oauth2_client (
71- self ,
72- user : base_models .APIUser ,
73- new_client : apispec .ProviderPost ,
72+ self , user : base_models .APIUser , new_client : models .UnsavedOAuth2Client
7473 ) -> models .OAuth2Client :
7574 """Insert a new OAuth2 Client environment."""
7675 if user .id is None :
@@ -93,6 +92,7 @@ async def insert_oauth2_client(
9392 url = new_client .url ,
9493 use_pkce = new_client .use_pkce or False ,
9594 created_by_id = user .id ,
95+ image_registry_url = new_client .image_registry_url ,
9696 )
9797
9898 async with self .session_maker () as session , session .begin ():
@@ -144,6 +144,12 @@ async def update_oauth2_client(
144144 client .url = patch .url
145145 if patch .use_pkce is not None :
146146 client .use_pkce = patch .use_pkce
147+ if patch .image_registry_url :
148+ # Patching with a string of at least length 1 updates the value
149+ client .image_registry_url = patch .image_registry_url
150+ elif patch .image_registry_url == "" :
151+ # Patching with "", removes the value
152+ client .image_registry_url = None
147153
148154 await session .flush ()
149155 await session .refresh (client )
@@ -272,7 +278,7 @@ async def authorize_callback(self, state: str, raw_url: str, callback_url: str)
272278 adapter .token_endpoint_url , authorization_response = raw_url , code_verifier = code_verifier
273279 )
274280
275- logger .info (f"Token for client { client .id } has keys: { ", " .join (token .keys ())} " )
281+ logger .info (f"Token for client { client .id } has keys: { ', ' .join (token .keys ())} " )
276282
277283 next_url = connection .next_url
278284
@@ -356,6 +362,40 @@ async def get_oauth2_connection_token(
356362 token_model = models .OAuth2TokenSet .from_dict (oauth2_client .token )
357363 return token_model
358364
365+ async def get_docker_client (
366+ self , user : base_models .APIUser , image : Image
367+ ) -> tuple [ImageRepoDockerAPI , ULID ] | tuple [None , None ]:
368+ """Search for clients and connections that can work with the specific image and return a docker client."""
369+ async with self .session_maker () as session :
370+ registry_urls = [f"http://{ image .hostname } " , f"https://{ image .hostname } " ]
371+ stmt = (
372+ select (schemas .OAuth2ConnectionORM )
373+ .where (schemas .OAuth2ConnectionORM .user_id == user .id )
374+ .where (
375+ schemas .OAuth2ConnectionORM .client .has (
376+ schemas .OAuth2ClientORM .image_registry_url .in_ (registry_urls )
377+ )
378+ )
379+ .options (joinedload (schemas .OAuth2ConnectionORM .client ))
380+ )
381+ conn = await session .scalar (stmt )
382+ if not conn :
383+ return None , None
384+ if conn .client .kind != ProviderKind .gitlab :
385+ # NOTE: Only Gitlab is currently supported for this
386+ return None , None
387+ url = conn .client .image_registry_url
388+ if not url :
389+ return None , None
390+ token_set = await self .get_oauth2_connection_token (conn .id , user )
391+ url_parsed = urlparse (url )
392+ access_token = token_set .access_token
393+ if not access_token :
394+ return None , None
395+ return ImageRepoDockerAPI (
396+ hostname = url_parsed .netloc , scheme = url_parsed .scheme , oauth2_token = access_token
397+ ), conn .id
398+
359399 async def get_oauth2_app_installations (
360400 self , connection_id : ULID , user : base_models .APIUser , pagination : PaginationRequest
361401 ) -> models .AppInstallationList :
0 commit comments