@@ -1860,6 +1860,77 @@ async def add_worker(self, worker_address: str):
18601860 self ._worker_address_to_worker [worker_address ] = worker_ref
18611861 logger .debug ("Worker %s has been added successfully" , worker_address )
18621862
1863+ @log_async (logger = logger )
1864+ async def ensure_worker (
1865+ self , worker_address : str
1866+ ) -> xo .ActorRefType ["WorkerActor" ]:
1867+ from .worker import WorkerActor
1868+
1869+ worker_ref = await xo .actor_ref (
1870+ address = worker_address , uid = WorkerActor .default_uid ()
1871+ )
1872+ if worker_address in self ._worker_address_to_worker :
1873+ self ._worker_address_to_worker [worker_address ] = worker_ref
1874+ logger .debug ("Worker %s already registered, refreshed ref" , worker_address )
1875+ else :
1876+ self ._worker_address_to_worker [worker_address ] = worker_ref
1877+ logger .debug ("Worker %s has been added successfully" , worker_address )
1878+ return worker_ref
1879+
1880+ @log_async (logger = logger )
1881+ async def restore_worker_models (
1882+ self , worker_address : str , models : Dict [str , Dict [str , Any ]]
1883+ ):
1884+ if not models :
1885+ return
1886+ worker_ref = await self .ensure_worker (worker_address )
1887+ restored = 0
1888+ for replica_model_uid in models .keys ():
1889+ model_uid , rep_id = parse_replica_model_uid (replica_model_uid )
1890+ if rep_id < 0 :
1891+ rep_id = 0
1892+
1893+ replica_info = self ._model_uid_to_replica_info .get (model_uid , None )
1894+ if replica_info is None :
1895+ replica_count = rep_id + 1
1896+ replica_info = ReplicaInfo (
1897+ replica = replica_count ,
1898+ scheduler = itertools .cycle (range (replica_count )),
1899+ )
1900+ self ._model_uid_to_replica_info [model_uid ] = replica_info
1901+ elif rep_id + 1 > replica_info .replica :
1902+ replica_info .replica = rep_id + 1
1903+ replica_info .scheduler = itertools .cycle (range (replica_info .replica ))
1904+
1905+ if all (
1906+ w .address != worker_ref .address
1907+ for w in replica_info .replica_to_worker_refs [rep_id ]
1908+ ):
1909+ replica_info .replica_to_worker_refs [rep_id ].append (worker_ref )
1910+
1911+ existing = self ._replica_model_uid_to_worker .get (replica_model_uid , None )
1912+ if existing is None :
1913+ self ._replica_model_uid_to_worker [replica_model_uid ] = worker_ref
1914+ elif isinstance (existing , (list , tuple )):
1915+ if all (w .address != worker_ref .address for w in existing ):
1916+ if isinstance (existing , tuple ):
1917+ self ._replica_model_uid_to_worker [replica_model_uid ] = [
1918+ * existing ,
1919+ worker_ref ,
1920+ ]
1921+ else :
1922+ existing .append (worker_ref )
1923+ else :
1924+ if existing .address != worker_ref .address :
1925+ self ._replica_model_uid_to_worker [replica_model_uid ] = [
1926+ existing ,
1927+ worker_ref ,
1928+ ]
1929+ restored += 1
1930+ logger .info (
1931+ "Restored %s model replicas for worker %s" , restored , worker_address
1932+ )
1933+
18631934 @log_async (logger = logger )
18641935 async def remove_worker (self , worker_address : str ):
18651936 uids_to_remove = []
0 commit comments