2424from typing import (
2525 Any ,
2626 Dict ,
27+ Generic ,
2728 Generator ,
2829 Iterable ,
2930 List ,
@@ -69,8 +70,7 @@ def pytest_addoption(parser: Parser):
6970 parser .addoption (
7071 "--controller" ,
7172 action = "store" ,
72- help = "Juju controller to use; if not provided, "
73- "will use the current controller" ,
73+ help = "Juju controller to use; if not provided, will use the current controller" ,
7474 )
7575 parser .addoption (
7676 "--model-alias" ,
@@ -172,10 +172,9 @@ def pytest_configure(config: Config):
172172 config .addinivalue_line ("markers" , "abort_on_fail" )
173173 config .addinivalue_line ("markers" , "skip_if_deployed" )
174174
175- if config .option .basetemp is None :
176- tox_dir = os .environ .get ("TOX_ENV_DIR" )
177- if tox_dir :
178- config .option .basetemp = Path (tox_dir ) / "tmp/pytest"
175+ if tox_dir := os .environ .get ("TOX_ENV_DIR" ):
176+ config .option .basetemp = Path (tox_dir ) / "tmp/pytest"
177+ log .info ("Using basetemp: %s" , config .option .basetemp )
179178
180179
181180def pytest_runtest_setup (item ):
@@ -211,16 +210,16 @@ def event_loop():
211210
212211
213212# Plugin load order can't be set, replace asyncio directly
214- pytest_asyncio .plugin .event_loop = event_loop
213+ pytest_asyncio .plugin .event_loop = event_loop # type: ignore
215214
216215
217216def pytest_collection_modifyitems (session , config , items ):
218- """Automatically apply the "asyncio" marker to any async test items ."""
217+ """Automatically apply the "pytest.mark. asyncio" marker to any async testitems ."""
219218 for item in items :
220219 is_async = inspect .iscoroutinefunction (getattr (item , "function" , None ))
221220 has_marker = item .get_closest_marker ("asyncio" )
222221 if is_async and not has_marker :
223- item .add_marker (" asyncio" )
222+ item .add_marker (pytest . mark . asyncio )
224223
225224
226225@pytest .hookimpl (tryfirst = True , hookwrapper = True )
@@ -442,12 +441,12 @@ def _connect_kwds(request) -> Dict[str, Any]:
442441
443442
444443@dataclasses .dataclass
445- class ModelState :
444+ class ModelState ( Generic [ Timeout ]) :
446445 model : Model
447446 keep : bool
448447 destroy_storage : bool
449448 controller_name : str
450- cloud_name : Optional [ str ]
449+ cloud_name : str
451450 model_name : str
452451 config : Optional [dict ] = None
453452 tmp_path : Optional [Path ] = None
@@ -459,7 +458,7 @@ def full_name(self) -> str:
459458
460459
461460@dataclasses .dataclass
462- class CloudState :
461+ class CloudState ( Generic [ Timeout ]) :
463462 cloud_name : str
464463 models : List [str ] = dataclasses .field (default_factory = list )
465464 timeout : Optional [Timeout ] = None
@@ -520,7 +519,7 @@ def juju_download_args(self):
520519 if field .default is not dataclasses .MISSING
521520 ]
522521
523- def __init__ (self , request , tmp_path_factory ):
522+ def __init__ (self , request , tmp_path_factory ) -> None :
524523 self .request = request
525524 self ._tmp_path_factory = tmp_path_factory
526525 self ._global_tmp_path = None
@@ -558,7 +557,7 @@ def __init__(self, request, tmp_path_factory):
558557
559558 # maintains a set of all models connected by this fixture
560559 # use an OrderedDict so that the first model made is destroyed last.
561- self ._current_alias = None
560+ self ._current_alias : Optional [ str ] = None
562561 self ._models : MutableMapping [str , ModelState ] = OrderedDict ()
563562 self ._clouds : MutableMapping [str , CloudState ] = OrderedDict ()
564563
@@ -575,9 +574,18 @@ def model_context(self, alias: str) -> Generator[Model, None, None]:
575574 # if the there's a failure after yielding, don't fail to
576575 # switch back to the prior alias but still raise whatever
577576 # error condition occurred through the context
578- self ._switch (prior , raise_not_found = False )
577+ if isinstance (prior , str ):
578+ self ._switch (prior , raise_not_found = False )
579+
580+ @overload
581+ def _switch (self , alias : str , raise_not_found : Literal [True ] = True ) -> Model : ...
579582
580- def _switch (self , alias : str , raise_not_found = True ) -> Model :
583+ @overload
584+ def _switch (
585+ self , alias : str , raise_not_found : Literal [False ] = False
586+ ) -> Optional [Model ]: ...
587+
588+ def _switch (self , alias : str , raise_not_found = True ) -> Optional [Model ]:
581589 if alias in self ._models :
582590 self ._current_alias = alias
583591 elif not raise_not_found :
@@ -777,6 +785,8 @@ async def _model_exists(self, model_name: str) -> bool:
777785 """
778786 returns True when the model_name exists in the model.
779787 """
788+ if not self ._controller :
789+ return False
780790 all_models = await self ._controller .list_models ()
781791 return model_name in all_models
782792
@@ -790,13 +800,16 @@ async def _connect_to_model(
790800 """
791801 model = Model ()
792802 state = ModelState (
793- model , keep , destroy_storage , controller_name , None , model_name
803+ model , keep , destroy_storage , controller_name , "" , model_name
794804 )
795805 log .info (
796806 "Connecting to existing model %s on unspecified cloud" , state .full_name
797807 )
798808 await model .connect (state .full_name , ** connect_kwargs )
799809 state .config = await model .get_config ()
810+ controller = await model .get_controller ()
811+ state .cloud_name = await controller .get_cloud ()
812+
800813 return state
801814
802815 @staticmethod
@@ -920,13 +933,13 @@ async def track_model(
920933 ** self ._juju_connect_kwds ,
921934 )
922935 else :
923- cloud_name = cloud_name or self .cloud_name
936+ cloud_name = cloud_name or self .cloud_name or ""
924937 model_name = model_name or self ._generate_name (kind = "model" )
925938 model_state = await self ._add_model (
926939 cloud_name , model_name , keep_val , destroy_storage_val , ** kwargs
927940 )
928941 self ._models [alias ] = model_state
929- if ops_cloud := self ._clouds .get (cloud_name ):
942+ if ops_cloud := self ._clouds .get (model_state . cloud_name ):
930943 ops_cloud .models .append (alias )
931944 return model_state .model
932945
@@ -986,7 +999,7 @@ async def forget_model(
986999 if not alias :
9871000 alias = self .current_alias
9881001
989- if alias not in self .models :
1002+ if not alias or alias not in self .models :
9901003 raise ModelNotFoundError (f"{ alias } not found" )
9911004
9921005 model_state : ModelState = self ._models [alias ]
@@ -1127,7 +1140,7 @@ async def build_charm(
11271140 async def build_charm (
11281141 self ,
11291142 charm_path ,
1130- bases_index : int = None ,
1143+ bases_index : Optional [ int ] = None ,
11311144 verbosity : Optional [
11321145 Literal ["quiet" , "brief" , "verbose" , "debug" , "trace" ]
11331146 ] = None ,
@@ -1736,6 +1749,9 @@ async def add_k8s(
17361749 juju_cloud_config ["operator-storage" ] = storage_class
17371750
17381751 controller = self ._controller
1752+ if not controller :
1753+ raise RuntimeError ("No controller currently set." )
1754+
17391755 cloud_name = cloud_name or self ._generate_name ("k8s-cloud" )
17401756 log .info (f"Adding k8s cloud { cloud_name } " )
17411757
@@ -1790,5 +1806,6 @@ async def forget_cloud(self, cloud_name: str):
17901806 for model in reversed (self ._clouds [cloud_name ].models ):
17911807 await self .forget_model (model , destroy_storage = True )
17921808 log .info (f"Forgetting cloud: { cloud_name } ..." )
1793- await self ._controller .remove_cloud (cloud_name )
1809+ if self ._controller :
1810+ await self ._controller .remove_cloud (cloud_name )
17941811 del self ._clouds [cloud_name ]
0 commit comments