Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/dl_api_lib_testing/dl_api_lib_testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def sync_us_manager(
ca_data=ca_data,
).make_service_registry(request_context_info=bi_context),
us_base_url=us_config.us_host,
us_auth_context=USAuthContextMaster(us_config.us_master_token),
us_auth_context=USAuthContextMaster(us_master_token=us_config.us_master_token),
crypto_keys_config=core_test_config.get_crypto_keys_config(),
retry_policy_factory=dl_retrier.DefaultRetryPolicyFactory(),
)
Expand Down
10 changes: 5 additions & 5 deletions lib/dl_core/dl_core/aio/middlewares/us_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def actual_public_usm_workaround_middleware(
sr = DummyServiceRegistry(rci=rci)
assert sr is not None

usm = usm_factory.get_async_usm(rci=rci, services_registry=sr, us_api_type=us_api_type)
usm = await usm_factory.get_async_usm(rci=rci, services_registry=sr, us_api_type=us_api_type)
try:
try:
# TODO: context_name not passed due to target type unknown
Expand Down Expand Up @@ -147,12 +147,12 @@ async def actual_us_manager_middleware(dl_request: DLRequestDataCore, handler: H
return await handler(dl_request.request)

if embed:
usm = usm_factory.get_embed_async_usm(
usm = await usm_factory.get_embed_async_usm(
rci=dl_request.rci,
services_registry=dl_request.services_registry,
)
else:
usm = usm_factory.get_regular_async_usm(
usm = await usm_factory.get_regular_async_usm(
rci=dl_request.rci,
services_registry=dl_request.services_registry,
)
Expand Down Expand Up @@ -192,7 +192,7 @@ async def actual_public_us_manager_middleware(
if aiohttp_wrappers.RequiredResourceCommon.US_MANAGER not in dl_request.required_resources:
return await handler(dl_request.request)

usm = usm_factory.get_public_async_usm(
usm = await usm_factory.get_public_async_usm(
rci=dl_request.rci,
services_registry=dl_request.services_registry,
)
Expand Down Expand Up @@ -237,7 +237,7 @@ async def actual_service_us_manager_middleware(
if target_resource not in dl_request.required_resources:
return await handler(dl_request.request)

usm = usm_factory.get_master_async_usm(
usm = await usm_factory.get_master_async_usm(
rci=dl_request.rci,
services_registry=dl_request.services_registry,
)
Expand Down
8 changes: 7 additions & 1 deletion lib/dl_core/dl_core/united_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def get_outbound_cookies(self) -> dict[DLCookies, str]:
return {}


@attr.s(frozen=True, kw_only=True)
class USAuthContextPrivateBase(USAuthContextBase):
"""
Common base class for environment-specific US authentication contexts.
Expand All @@ -143,6 +144,11 @@ class USAuthContextPrivateBase(USAuthContextBase):
DEFAULT_TENANT = TenantCommon()
IS_TENANT_ID_MUTABLE = True

us_master_token: str | None = attr.ib(
repr=False,
default=None,
) # TODO: Remove after US migration to dynamic authorization DLPROJECTS-500

def get_tenant(self) -> TenantDef:
return self.DEFAULT_TENANT

Expand All @@ -154,7 +160,7 @@ def get_outbound_cookies(self) -> dict[DLCookies, str]:
return {}


@attr.s(frozen=True)
@attr.s(frozen=True, kw_only=True)
class USAuthContextMaster(USAuthContextPrivateBase):
us_master_token: str = attr.ib(repr=False)

Expand Down
21 changes: 12 additions & 9 deletions lib/dl_core/dl_core/us_manager/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ def def_embed_us_auth_ctx_from_rci(self, rci: RequestContextInfo) -> USAuthConte
tenant=tenant,
)

def get_master_auth_context(self) -> USAuthContextPrivateBase:
def get_master_auth_context_sync(self) -> USAuthContextPrivateBase:
assert self.us_master_token is not None, "US master token must be set in factory to create USAuthContextMaster"
return USAuthContextMaster(us_master_token=self.us_master_token)

async def get_master_auth_context_async(self) -> USAuthContextPrivateBase:
return self.get_master_auth_context_sync()

def get_master_auth_context_from_headers(self) -> USAuthContextPrivateBase:
us_master_token = flask.request.headers.get(DLHeadersCommon.US_MASTER_TOKEN.value)
if us_master_token is None:
Expand All @@ -70,7 +73,7 @@ def get_ca_data(self) -> bytes:
return self.ca_data

# Async
def get_regular_async_usm(
async def get_regular_async_usm(
self,
rci: RequestContextInfo,
services_registry: ServicesRegistry,
Expand All @@ -85,13 +88,13 @@ def get_regular_async_usm(
retry_policy_factory=self.retry_policy_factory,
)

def get_master_async_usm(
async def get_master_async_usm(
self,
rci: RequestContextInfo,
services_registry: ServicesRegistry,
) -> AsyncUSManager:
return AsyncUSManager(
us_auth_context=self.get_master_auth_context(),
us_auth_context=await self.get_master_auth_context_async(),
us_base_url=self.us_base_url,
bi_context=rci,
crypto_keys_config=self.crypto_keys_config,
Expand All @@ -100,7 +103,7 @@ def get_master_async_usm(
retry_policy_factory=self.retry_policy_factory,
)

def get_public_async_usm(
async def get_public_async_usm(
self,
rci: RequestContextInfo,
services_registry: ServicesRegistry,
Expand Down Expand Up @@ -138,7 +141,7 @@ def get_master_sync_usm(
self, rci: RequestContextInfo, services_registry: ServicesRegistry, is_token_stored: bool | None = True
) -> SyncUSManager:
return SyncUSManager(
us_auth_context=self.get_master_auth_context()
us_auth_context=self.get_master_auth_context_sync()
if is_token_stored
else self.get_master_auth_context_from_headers(),
us_base_url=self.us_base_url,
Expand All @@ -148,7 +151,7 @@ def get_master_sync_usm(
retry_policy_factory=self.retry_policy_factory,
)

def get_embed_async_usm(
async def get_embed_async_usm(
self,
rci: RequestContextInfo,
services_registry: ServicesRegistry,
Expand All @@ -163,7 +166,7 @@ def get_embed_async_usm(
retry_policy_factory=self.retry_policy_factory,
)

def get_async_usm(
async def get_async_usm(
self,
rci: RequestContextInfo,
services_registry: ServicesRegistry,
Expand All @@ -176,4 +179,4 @@ def get_async_usm(
USApiType.embeds: self.get_embed_async_usm,
}
get_usm = usm_getters.get(us_api_type, self.get_regular_async_usm)
return get_usm(rci=rci, services_registry=services_registry)
return await get_usm(rci=rci, services_registry=services_registry)
2 changes: 1 addition & 1 deletion lib/dl_core/dl_core/us_manager/us_manager_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
map_id_key={"dummy_usm_key": fernet.Fernet.generate_key().decode("ascii")},
actual_key_id="dummy_usm_key",
),
us_auth_context=USAuthContextMaster("FakeKey"),
us_auth_context=USAuthContextMaster(us_master_token="FakeKey"),
services_registry=services_registry,
retry_policy_factory=dl_retrier.DefaultRetryPolicyFactory(),
)
2 changes: 1 addition & 1 deletion lib/dl_core/dl_core/us_manager/us_manager_sync_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
)
if crypto_keys_config is None
else crypto_keys_config,
us_auth_context=USAuthContextMaster("FakeKey"),
us_auth_context=USAuthContextMaster(us_master_token="FakeKey"),
services_registry=services_registry,
retry_policy_factory=dl_retrier.DefaultRetryPolicyFactory(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def always_400_mw(request, handler):
mock = await aiohttp_client(app)

us_client = UStorageClientAIO(
auth_ctx=USAuthContextMaster("fake_token"),
auth_ctx=USAuthContextMaster(us_master_token="fake_token"),
host=f"http://{mock.host}:{mock.port}",
prefix="/api/private",
ca_data=root_certificates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def conn_us_config(self) -> USConfig:
us_env_config = self.core_test_config.get_us_config()
return USConfig(
us_base_url=us_env_config.us_host,
us_auth_context=USAuthContextMaster(us_env_config.us_master_token),
us_auth_context=USAuthContextMaster(us_master_token=us_env_config.us_master_token),
us_crypto_keys_config=self.core_test_config.get_crypto_keys_config(),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ async def default_async_usm_per_test(bi_context, prepare_us, us_config, root_cer
rci = dl_api_commons.RequestContextInfo.create_empty()
return AsyncUSManager(
us_base_url=us_config.base_url,
us_auth_context=USAuthContextMaster(us_config.master_token),
us_auth_context=USAuthContextMaster(us_master_token=us_config.master_token),
crypto_keys_config=us_config.crypto_keys_config,
bi_context=bi_context,
services_registry=DummyServiceRegistry(rci=rci),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import Optional
from typing import (
Any,
Callable,
Coroutine,
Optional,
)

import arq
import attr
Expand Down Expand Up @@ -36,6 +41,7 @@ class SecureReaderSettings:
@attr.s
class FileUploaderTaskContext(BaseContext):
settings: FileUploaderWorkerSettings = attr.ib()
us_auth_context_factory: Callable[[], Coroutine[Any, Any, USAuthContextPrivateBase]] = attr.ib()
tpe: ContextVarExecutor = attr.ib()
redis_service: RedisBaseService = attr.ib()
s3_service: S3Service = attr.ib()
Expand All @@ -45,7 +51,6 @@ class FileUploaderTaskContext(BaseContext):
secure_reader_settings: SecureReaderSettings = attr.ib()
tenant_resolver: TenantResolver = attr.ib()
ca_data: bytes = attr.ib()
us_auth_context: USAuthContextPrivateBase = attr.ib()

def get_rci(self) -> RequestContextInfo:
return RequestContextInfo.create_empty()
Expand All @@ -60,13 +65,13 @@ def get_service_registry(self, rci: Optional[RequestContextInfo] = None) -> Serv
def get_retry_policy_factory(self) -> dl_retrier.BaseRetryPolicyFactory:
return dl_retrier.RetryPolicyFactory.from_settings(self.settings.US_CLIENT.RETRY_POLICY)

def get_async_usm(self, rci: Optional[RequestContextInfo] = None) -> AsyncUSManager:
async def get_async_usm(self, rci: Optional[RequestContextInfo] = None) -> AsyncUSManager:
rci = rci or RequestContextInfo.create_empty()
services_registry = self.get_service_registry(rci=rci)
retry_policy_factory = self.get_retry_policy_factory()
return get_async_service_us_manager(
return await get_async_service_us_manager(
us_host=self.settings.US_BASE_URL,
us_auth_context=self.us_auth_context,
us_auth_context_factory=self.us_auth_context_factory,
services_registry=services_registry,
bi_context=rci,
crypto_keys_config=self.crypto_keys_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import os
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Optional,
)

Expand Down Expand Up @@ -57,9 +60,9 @@ def get_conn_options(conn: ConnectionBase) -> Optional[ConnectOptions]:
)


def get_async_service_us_manager(
async def get_async_service_us_manager(
us_host: str,
us_auth_context: USAuthContextPrivateBase,
us_auth_context_factory: Callable[[], Coroutine[Any, Any, USAuthContextPrivateBase]],
ca_data: bytes,
crypto_keys_config: CryptoKeysConfig,
services_registry: ServicesRegistry,
Expand All @@ -69,7 +72,7 @@ def get_async_service_us_manager(
usm = AsyncUSManager(
us_api_prefix="private",
us_base_url=us_host,
us_auth_context=us_auth_context,
us_auth_context=await us_auth_context_factory(),
crypto_keys_config=crypto_keys_config,
bi_context=bi_context or RequestContextInfo.create_empty(),
services_registry=services_registry,
Expand Down
18 changes: 12 additions & 6 deletions lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import abc
import os
from typing import (
Any,
Callable,
Coroutine,
Generic,
Optional,
TypeVar,
Expand Down Expand Up @@ -56,7 +59,7 @@
@attr.s
class FileUploaderContextFab(BaseContextFabric):
_settings: FileUploaderWorkerSettings = attr.ib()
_us_auth_context: USAuthContextPrivateBase = attr.ib()
_us_auth_context_factory: Callable[[], Coroutine[Any, Any, USAuthContextPrivateBase]] = attr.ib()
_ca_data: bytes = attr.ib()
_tenant_resolver: TenantResolver = attr.ib(factory=lambda: CommonTenantResolver())

Expand Down Expand Up @@ -96,7 +99,7 @@ async def make(self) -> FileUploaderTaskContext:
),
tenant_resolver=self._tenant_resolver,
ca_data=self._ca_data,
us_auth_context=self._us_auth_context,
us_auth_context_factory=self._us_auth_context_factory,
)

async def tear_down(self, inst: FileUploaderTaskContext) -> None: # type: ignore # 2024-01-30 # TODO: Argument 1 of "tear_down" is incompatible with supertype "BaseContextFabric"; supertype defines the argument type as "BaseContext" [override]
Expand All @@ -110,16 +113,19 @@ class FileUploaderWorkerFactory(Generic[_TSettings], abc.ABC):
_ca_data: bytes = attr.ib()
_settings: _TSettings = attr.ib()

def _get_us_auth_context(self) -> USAuthContextPrivateBase:
return USAuthContextMaster(us_master_token=self._settings.US_MASTER_TOKEN)

@abc.abstractmethod
def _get_tenant_resolver(self) -> TenantResolver:
raise NotImplementedError()

def _get_metrics_sender(self) -> Optional[WorkerMetricsSenderProtocol]:
return None

def _get_us_auth_context_factory(self) -> Callable[[], Coroutine[Any, Any, USAuthContextPrivateBase]]:
async def get_us_auth_context() -> USAuthContextPrivateBase:
return USAuthContextMaster(us_master_token=self._settings.US_MASTER_TOKEN)

return get_us_auth_context

def create_worker(self, state: Optional[TaskState] = None) -> ArqWorker:
if state is None:
state = TaskState(DummyStateImpl())
Expand All @@ -141,7 +147,7 @@ def create_worker(self, state: Optional[TaskState] = None) -> ArqWorker:
settings=self._settings,
ca_data=self._ca_data,
tenant_resolver=self._get_tenant_resolver(),
us_auth_context=self._get_us_auth_context(),
us_auth_context_factory=self._get_us_auth_context_factory(),
),
worker_settings=WorkerSettings(max_concurrent_jobs=self._settings.MAX_CONCURRENT_JOBS),
cron_tasks=cron_tasks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def run(self) -> TaskResult:
)
await RenameTenantStatusModel(manager=rmm, id=tenant_id, status=RenameTenantStatus.started).save()
try:
usm = self._ctx.get_async_usm()
usm = await self._ctx.get_async_usm()
usm.set_tenant_override(self._ctx.tenant_resolver.resolve_tenant_def_by_tenant_id(self.meta.tenant_id))
s3_service = self._ctx.s3_service
s3_client = s3_service.get_client()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class DownloadGSheetTask(BaseExecutorTask[task_interface.DownloadGSheetTask, Fil
async def run(self) -> TaskResult:
dfile: Optional[DataFile] = None
sources_to_update_by_sheet_id: dict[int, list[DataSource]] = defaultdict(list)
usm = self._ctx.get_async_usm()
usm = await self._ctx.get_async_usm()
usm.set_tenant_override(self._ctx.tenant_resolver.resolve_tenant_def_by_tenant_id(self.meta.tenant_id))
task_processor = self._ctx.make_task_processor(self._request_id)
redis = self._ctx.redis_service.get_redis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def run(self) -> TaskResult:
dfile: Optional[DataFile] = None
redis = self._ctx.redis_service.get_redis()
task_processor = self._ctx.make_task_processor(self._request_id)
usm = self._ctx.get_async_usm()
usm = await self._ctx.get_async_usm()
usm.set_tenant_override(self._ctx.tenant_resolver.resolve_tenant_def_by_tenant_id(self.meta.tenant_id))
connection_error_tracker = FileConnectionDataSourceErrorTracker(
usm=usm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ProcessExcelTask(BaseExecutorTask[task_interface.ProcessExcelTask, FileUpl
async def run(self) -> TaskResult:
dfile: Optional[DataFile] = None
sources_to_update_by_sheet_id: dict[int, list[DataSource]] = defaultdict(list)
usm = self._ctx.get_async_usm()
usm = await self._ctx.get_async_usm()
usm.set_tenant_override(self._ctx.tenant_resolver.resolve_tenant_def_by_tenant_id(self.meta.tenant_id))
task_processor = self._ctx.make_task_processor(self._request_id)
redis = self._ctx.redis_service.get_redis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ParseFileTask(BaseExecutorTask[task_interface.ParseFileTask, FileUploaderT

async def run(self) -> TaskResult:
dfile: Optional[DataFile] = None
usm = self._ctx.get_async_usm()
usm = await self._ctx.get_async_usm()
usm.set_tenant_override(self._ctx.tenant_resolver.resolve_tenant_def_by_tenant_id(self.meta.tenant_id))
task_processor = self._ctx.make_task_processor(self._request_id)
redis = self._ctx.redis_service.get_redis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ async def run(self) -> TaskResult:

# TODO: init all this stuff in a proper place, not in task
rci = self._ctx.get_rci()
usm = self._ctx.get_async_usm(rci=rci)
usm = await self._ctx.get_async_usm(rci=rci)
usm.set_tenant_override(self._ctx.tenant_resolver.resolve_tenant_def_by_tenant_id(self.meta.tenant_id))
service_registry = self._ctx.get_service_registry(rci=rci)
release_update_source_lock_flag = False
Expand Down
Loading