# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import contextlib
import json
import os
from typing import Any

import ray
import ray.exceptions
import torch
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.backends.selector import init_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.pytorch.ray import RayContext, get_device_str
from lmdeploy.pytorch.utils import wait_for_async_tasks
from lmdeploy.utils import get_logger, try_import_deeplink

from .base import ExecutorBase
from .base_worker import WorkerWrapperBase
from .dist_utils import find_available_port

logger = get_logger('lmdeploy')

def _get_master_addr():
    """Get master addr."""
    addr = _envs.dist_master_addr
    if addr is not None:
        return addr
    gcs_addr = ray.get_runtime_context().gcs_address
    master_addr = gcs_addr.split(':')[0]
    return master_addr


def _get_master_port():
    """Get master port."""
    port = _envs.dist_master_port
    if port is not None:
        return port
    return find_available_port()

def get_ascend_device_rank_mapping(master_addr):
    rank_table_file = _envs.ascend_rank_table_file
    if not rank_table_file:
        raise ValueError('ASCEND_RANK_TABLE_FILE_PATH is not set')
    with open(rank_table_file) as f:
        rank_table = json.load(f)
    try:
        assert master_addr == rank_table['server_list'][0]['server_id'], 'Master address does not match rank table'
        rank_mapping: dict[int, int] = {}
        worker_ip_by_rank: dict[int, str] = {}
        for server in rank_table['server_list']:
            node_ip = server['server_id']
            for idx, device in enumerate(server['device']):
                # Prefer explicit device_id if present; fall back to enumeration order.
                local_rank = int(device.get('device_id', idx))
                global_rank = int(device['rank_id'])
                rank_mapping[global_rank] = local_rank
                worker_ip_by_rank[global_rank] = node_ip
        if len(worker_ip_by_rank) == 0:
            raise ValueError('Rank table contains no devices.')
        ranks = sorted(worker_ip_by_rank.keys())
        if ranks[0] != 0 or ranks[-1] != len(ranks) - 1:
            raise ValueError(f'Rank ids are not contiguous starting from 0: {ranks[:8]}...{ranks[-8:]}')
        worker_ips = [worker_ip_by_rank[r] for r in range(len(ranks))]
    except Exception as e:
        logger.error(f'Parse rank table file({rank_table}) failed')
        raise e

    envs = {
        'ASCEND_RANK_TABLE_FILE_PATH': rank_table_file,
    }
    return rank_mapping, worker_ips, envs
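

# Illustrative sketch (not taken from this repo): a minimal rank table that the parser
# above would accept. Only the keys actually read by get_ascend_device_rank_mapping are
# shown; real Ascend rank tables carry additional fields, and the IPs below are made-up
# placeholders.
#
#   {
#       "server_list": [
#           {"server_id": "10.0.0.1",
#            "device": [{"device_id": "0", "rank_id": "0"},
#                       {"device_id": "1", "rank_id": "1"}]},
#           {"server_id": "10.0.0.2",
#            "device": [{"device_id": "0", "rank_id": "2"},
#                       {"device_id": "1", "rank_id": "3"}]}
#       ]
#   }
#
# With this table and master_addr == "10.0.0.1", the function would return
# rank_mapping == {0: 0, 1: 1, 2: 0, 3: 1} and
# worker_ips == ["10.0.0.1", "10.0.0.1", "10.0.0.2", "10.0.0.2"].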

def _update_env_cuda_alloc_conf(env_vars: dict):
    """Update runtime env for CUDA alloc conf."""
    cuda_alloc_conf = os.getenv('PYTORCH_CUDA_ALLOC_CONF', None)
    if cuda_alloc_conf is None:
        return

    # check and update conf, skip expandable_segments
    cuda_alloc_conf = cuda_alloc_conf.split(',')
    new_cuda_alloc_conf = []
    for conf in cuda_alloc_conf:
        if 'expandable_segments' in conf:
            if 'True' in conf:
                logger.warning('"expandable_segments:True" is not supported.')
            continue
        new_cuda_alloc_conf.append(conf)
    if len(new_cuda_alloc_conf) == 0:
        new_cuda_alloc_conf = ['expandable_segments:False']
    cuda_alloc_conf = ','.join(new_cuda_alloc_conf)

    # update env_vars
    env_vars['PYTORCH_CUDA_ALLOC_CONF'] = cuda_alloc_conf
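

# Example of the rewrite performed above (illustrative values only):
#   PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:512,expandable_segments:True"
#     -> env_vars['PYTORCH_CUDA_ALLOC_CONF'] == 'max_split_size_mb:512'
#   PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
#     -> env_vars['PYTORCH_CUDA_ALLOC_CONF'] == 'expandable_segments:False'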

def _update_runtime_envs(runtime_env: dict):
    """Update runtime envs."""
    new_envs = _envs.get_all_envs()
    env_vars: dict = runtime_env.get('env_vars', {})
    env_vars.update(new_envs)
    _update_env_cuda_alloc_conf(env_vars)
    runtime_env['env_vars'] = env_vars
    return runtime_env


def _update_runtime_env_nsys(runtime_env: dict):
    """Update runtime env for nsys."""
    nsight_env = {
        't': 'cuda,cudnn,cublas,nvtx',
        'o': "'worker_process_%p'",
        'stop-on-exit': 'true',
    }
    prefix_path = _envs.ray_nsys_output_prefix
    if prefix_path is not None:
        nsight_env['o'] = f'{prefix_path}%p'
    runtime_env['nsight'] = nsight_env
    return runtime_env
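

# Illustrative result of the helper above: with _envs.ray_nsys_output_prefix unset,
# runtime_env['nsight'] becomes
#   {'t': 'cuda,cudnn,cublas,nvtx', 'o': "'worker_process_%p'", 'stop-on-exit': 'true'}
# and with a (hypothetical) prefix such as '/tmp/nsys/prof_', the 'o' entry becomes
# '/tmp/nsys/prof_%p'.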

class RemoteLogger:
    """Remote logger."""

    def __init__(self):
        self._records = dict()
        self._next_handle = 0

    def start(self, msg: str):
        """Start remote log."""
        record = torch.profiler.record_function(msg)
        record.__enter__()
        handle = self._next_handle
        self._records[handle] = record
        self._next_handle += 1
        return handle

    def end(self, handle: int):
        """End remote log."""
        record = self._records.pop(handle, None)
        if record is not None:
            record.__exit__(None, None, None)
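

# Usage sketch (illustrative; mirrors how RayExecutor.remote_log drives this class below):
#   remote_logger = RemoteLogger()
#   handle = remote_logger.start('decode step')  # opens a torch.profiler.record_function range
#   ...                                          # profiled work happens here
#   remote_logger.end(handle)                    # closes the range; unknown handles are ignored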

class RayWorkerWrapper(WorkerWrapperBase):
    """Worker wrapper."""

    def __init__(
        self,
        model_path: str,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        model_config: ModelConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        adapters: dict[str, str] = None,
        device_type: str = 'cuda',
        dtype: str = 'auto',
        log_level: int = 30,
        specdecode_config: SpecDecodeConfig = None,
    ):
        init_backend(device_type)
        try_import_deeplink(device_type)
        super().__init__(
            model_path=model_path,
            cache_config=cache_config,
            backend_config=backend_config,
            model_config=model_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            log_level=log_level,
            specdecode_config=specdecode_config,
        )
        self.node_ip = ray.util.get_node_ip_address()
        self._remote_logger = RemoteLogger()

    def set_device(self, local_rank):
        """Set worker local rank."""
        torch.cuda.set_device(local_rank)

    def set_env(self, envs: dict[str, str]):
        for key, value in envs.items():
            os.environ[key] = value

    def get_node_ip(self):
        """Get worker ip."""
        return self.node_ip
def warmup_dist(self):
# None default CUDA_VISIBLE_DEVICES might leads to slow first time all_reduce
# WHY?
logger.debug('Warmup all_reduce.')
import torch
from lmdeploy.pytorch.distributed import all_reduce, get_dist_manager
with get_dist_manager().context(self.dist_ctx):
group = self.dist_ctx.tp_group.gpu_group
tmp = torch.empty((1, ), device='cuda')
all_reduce(tmp, group=group)
    def pack_output(self, output: dict):
        """Pack output."""
        return output.to_numpy()

    def remote_log_start(self, msg: str):
        """Remote log start."""
        return self._remote_logger.start(msg)

    def remote_log_end(self, handle: int):
        """Remote log end."""
        return self._remote_logger.end(handle)

    def exit(self):
        """Exit actor."""
        ray.actor.exit_actor()

class RayExecutor(ExecutorBase):
    """Ray executor."""

    def __init__(
        self,
        model_path: str,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        dist_config: DistConfig,
        misc_config: MiscConfig,
        adapters: dict[str, str] = None,
        device_type: str = 'cuda',
        dtype: str = 'auto',
        specdecode_config: SpecDecodeConfig = None,
    ):
        """Initialize Executor."""
        super().__init__(
            model_path=model_path,
            model_config=model_config,
            cache_config=cache_config,
            backend_config=backend_config,
            dist_config=dist_config,
            misc_config=misc_config,
            adapters=adapters,
            device_type=device_type,
            specdecode_config=specdecode_config,
        )

        device_ctx = DeviceContext(device_type)
        with get_device_manager().context(device_ctx):
            logger.info('Init ray cluster.')
            attn_tp = dist_config.attn_tp
            self.ray_ctx = RayContext(attn_tp, dp=dist_config.dp, device_type=device_type)
            placement_group = self.ray_ctx.get_placement_group()
            self.placement_group = placement_group

            if self.dp == 1:
                self.master_addr = _get_master_addr()
                self.master_port = _get_master_port()
            else:
                self.master_addr = _envs.dp_master_addr
                self.master_port = _envs.dp_master_port
                if self.master_addr is None or self.master_port is None:
                    raise RuntimeError('DP > 1 requires "LMDEPLOY_DP_MASTER_ADDR" and "LMDEPLOY_DP_MASTER_PORT".')

            # create workerwrapper actors
            worker_kwargs = dict(
                model_path=model_path,
                cache_config=cache_config,
                model_config=model_config,
                backend_config=backend_config,
                dist_config=dist_config,
                misc_config=misc_config,
                adapters=adapters,
                device_type=device_type,
                dtype=dtype,
                log_level=logger.level,
                specdecode_config=specdecode_config,
            )

            logger.info('Init ray workers.')
            self.workers = self._init_workers_ray(placement_group, worker_kwargs)
            self.dag = None
            self._prefetch_task: asyncio.Task = None
            self.remote_outs: asyncio.Queue = None

            logger.info('Init distributed environment by device.')
            self.rank_offset = dist_config.dp_rank * attn_tp
            self._init_distributed_environment_by_device(device_type)

            logger.info('Init distributed process group.')
            ray.get([
                worker.init_process_group.remote(rank + self.rank_offset, self.master_addr, self.master_port)
                for rank, worker in enumerate(self.workers)
            ])

            if self.dist_config.world_size > 1:
                logger.info('Warming up the distributed environment; this may take a long time, please wait...')
                ray.get([worker.warmup_dist.remote() for worker in self.workers])

    def collective_rpc(self,
                       method: str,
                       args: tuple[Any] = None,
                       kwargs: dict[str, Any] = None,
                       timeout: float = None):
        """Collective rpc."""
        if args is None:
            args = list()
        if kwargs is None:
            kwargs = dict()
        return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)

    async def collective_rpc_async(self,
                                   method: str,
                                   args: tuple[Any] = None,
                                   kwargs: dict[str, Any] = None):
        """Collective async rpc."""
        if args is None:
            args = list()
        if kwargs is None:
            kwargs = dict()
        tasks = [getattr(worker, method).remote(*args, **kwargs) for worker in self.workers]
        return await asyncio.gather(*tasks)
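
    # Illustrative note (not from the upstream code): both RPC helpers above fan a single
    # method call out to every worker actor and return one result per worker, in
    # self.workers order. For example, a hypothetical call
    #   free_mem_per_worker = executor.collective_rpc('get_free_mem')
    # yields a list with one entry per worker, which is exactly how gather_free_mem()
    # below is implemented.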

    def build_model(self):
        """Build model."""
        self.collective_rpc('build_model')

    def gather_free_mem(self):
        """Gather available memory."""
        return self.collective_rpc('get_free_mem')

    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):
        """Set all cache config."""
        self.collective_rpc('set_cache_config', (cache_config, spec_cache_config))

    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
        """Set all model config."""
        self.collective_rpc('set_model_config', (model_config, spec_model_config))

    def build_graph_runner(self):
        """Build graph runner."""
        self.collective_rpc('build_graph_runner')

    def build_cache_engine(self):
        """Build cache engine."""
        self.collective_rpc('build_cache_engine')

    def update_params(self, request: Any):
        """Update params."""
        self.collective_rpc('update_params', (request, ))

    def warmup(self):
        """Warmup."""
        self.collective_rpc('warmup')

    async def sleep(self, level: int = 1):
        """Sleep."""
        await self.collective_rpc_async('sleep', (level, ))

    def wakeup(self, tags: list[str] | None = None):
        """Wakeup."""
        if tags is None or 'kv_cache' in tags:
            self.update_configs()
        self.collective_rpc('wakeup', (tags, ))

    def get_input_processor(self):
        """Get input processor."""
        return ray.get(self.workers[0].get_input_processor.remote())

    def _prefetch_task_callback(self, task: asyncio.Task):
        try:
            task.result()
        except asyncio.CancelledError:
            logger.debug(f'{task.get_name()} cancelled.')
        except KeyboardInterrupt:
            logger.debug(f'{task.get_name()} KeyboardInterrupt.')
        except BaseException:
            logger.debug(f'{task.get_name()} task failed.')

    def start(self, forward_event: asyncio.Event):
        """Start engine loop."""
        self.forward_event = forward_event
        self.collective_rpc('start')
        self.remote_outs = asyncio.Queue()
        logger.info('Starting async task RayPrefetchOutput loop.')

    async def wait_tasks(self):
        """Wait tasks."""
        dp_rank = self.dist_config.dp_rank
        tasks_to_cancel = set()
        event_loop = asyncio.get_event_loop()

        async def _wait_single_worker(worker):
            try:
                task = worker.wait_tasks.remote()
                tasks_to_cancel.add(task)
                await task
            except ray.exceptions.ActorDiedError:
                # It is safe to ignore wait_tasks on a dead actor.
                logger.info('RayExecutor worker has been killed before finishing wait_tasks.')

        tasks = [
            event_loop.create_task(_wait_single_worker(worker), name=f'WorkerWaitTasks_{idx}')
            for idx, worker in enumerate(self.workers)
        ]
        if self._prefetch_task is not None:
            tasks.append(self._prefetch_task)
        try:
            await wait_for_async_tasks(tasks)
        except asyncio.CancelledError:
            logger.info(f'RayExecutor DP[{dp_rank}] wait_tasks cancelled.')
            raise
        except BaseException:
            logger.error(f'RayExecutor DP[{dp_rank}] wait_tasks failed.')
            raise
        finally:
            logger.debug(f'RayExecutor DP[{dp_rank}] wait_tasks cleanup.')
            for task in tasks_to_cancel:
                try:
                    ray.cancel(task)
                except ray.exceptions.ActorDiedError:
                    logger.debug('RayExecutor worker has been killed before the task could be cancelled.')
                except Exception as e:
                    logger.error(f'RayExecutor DP[{dp_rank}] Cancel wait_tasks failed: {e}')

    def stop(self):
        """Stop engine loop."""
        # TODO: For dp > 1 we currently rely on external teardown (e.g. Ray actor
        # destruction) instead of explicitly stopping worker loops here. Implementing
        # coordinated shutdown across multiple dp ranks is non-trivial, especially
        # when some ranks may have already failed. The explicit stop_async RPC is
        # therefore only issued when dp == 1.
        if self.dp == 1:
            try:
                # Adding a timeout here might disable the profile dump;
                # hopefully this does not lead to hanging.
                self.collective_rpc('stop_async')
            except ray.exceptions.ActorDiedError:
                logger.info('RayExecutor worker has been killed before finishing stop_async.')
            logger.debug('RayExecutor workers stopped.')

        if self._prefetch_task is not None:
            self._prefetch_task.cancel()

    def release(self):
        """Release."""
        if _envs.ray_timeline_enable:
            ray.timeline(_envs.ray_timeline_output_path)

        if self.dp == 1:
            try:
                self.collective_rpc('release', timeout=5.0)
                logger.debug('RayExecutor workers released.')
            except ray.exceptions.ActorDiedError:
                logger.info('RayExecutor worker has been killed before finishing release.')
                [ray.kill(worker) for worker in self.workers]
            except ray.exceptions.GetTimeoutError:
                logger.info('Ray release timeout, killing workers.')
                [ray.kill(worker) for worker in self.workers]
        else:
            [ray.kill(worker) for worker in self.workers]

        self.ray_ctx.shutdown()

    def _compile_dag(self):
        """Compile dag."""
        from ray.dag.input_node import InputNode
        from ray.dag.output_node import MultiOutputNode
        with InputNode() as input_data:
            outputs = [worker.forward_async.bind(input_data) for worker in self.workers]
            output = MultiOutputNode(outputs)
        return output

    async def forward_async(self, inputs):
        """Start forward."""
        if self.dag is None:
            self.dag = self._compile_dag()
            self._prev_inputs = None
            self._prev_out = None
        if self._prev_out is not None:
            try:
                ray.get(self._prev_out)
            except SystemExit:
                logger.error('Ray worker exited.')
                raise
            finally:
                # free the ray.put inputs of the previous step
                try:
                    ray.internal.free(self._prev_inputs, local_only=False)
                except Exception as e:
                    logger.warning(f'Free input ref failed: {e}')
        self._prev_inputs = ray.put(inputs)
        # The non-compiled dag path holds the input object ref, and that ref cannot be released from Python.
        self._prev_out = [
            worker.forward_async.remote(self._prev_inputs) for worker in self.workers
        ]

    async def get_output_async(self):
        """Get output async."""
        ret = await self.workers[0].get_outputs.remote()
        ret = ret.to_tensor()
        return ret

    @contextlib.contextmanager
    def remote_log(self, msg: str):
        """Send log for debugging.

        Do not use it in production.
        """
        handle_ref = self.workers[0].remote_log_start.remote(msg)
        yield
        handle = ray.get(handle_ref)
        ray.get(self.workers[0].remote_log_end.remote(handle))

    def _sort_workers(self, driver_ip: str, workers: list[RayWorkerWrapper]):
        """Sort workers."""
        # External bundle handling is only applicable when lmdeploy does NOT own
        # the placement group. If lmdeploy owns the PG, we should continue to
        # sort workers even if external bundle indices are specified.
        if (not _envs.ray_external_pg_bundles) or self.ray_ctx.owned_pg:
            return self._sort_workers_by_driver_then_worker_ip(driver_ip, workers)
        else:
            # do not sort when external bundle indices are specified and the
            # placement group is externally managed
            return workers

    def _sort_workers_by_driver_then_worker_ip(self, driver_ip: str, workers: list[RayWorkerWrapper]):
        """Sort workers by ip."""
        worker_ips = ray.get([worker.get_node_ip.remote() for worker in workers])

        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        worker_ip_map = list(zip(workers, worker_ips))

        def sort_by_driver_then_worker_ip(item):
            """Sort the workers based on 3 properties:

            1. If the worker is on the same node as the driver (vllm engine),
               it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address, it
               should be placed first.
            """
            ip = item[1]
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_ip_map = sorted(worker_ip_map, key=sort_by_driver_then_worker_ip)
        workers = [item[0] for item in sorted_worker_ip_map]
        return workers
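
    # Worked example of the ordering above (hypothetical IPs): with driver_ip = '10.0.0.1'
    # and worker IPs ['10.0.0.2', '10.0.0.1', '10.0.0.2', '10.0.0.3'], the sort keys are
    # (1, 2, '10.0.0.2'), (0, 1, '10.0.0.1'), (1, 2, '10.0.0.2'), (1, 1, '10.0.0.3'), so the
    # resulting order is the driver-node worker first, then the single worker on '10.0.0.3',
    # then the two workers that share '10.0.0.2'.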

    def _sort_workers_by_ip(self, ips, workers: list[RayWorkerWrapper]):
        worker_ips = ray.get([worker.get_node_ip.remote() for worker in workers])

        if len(ips) != len(workers):
            raise ValueError(f'The length of the ips list does not match the workers, '
                             f'ips length: {len(ips)}, workers length: {len(workers)}')
        # Check if all elements in ips are present in worker_ips and vice versa (ignoring order)
        if set(ips) != set(worker_ips):
            raise ValueError(f'The IP addresses in the ips list do not match the worker IPs. '
                             f'ips: {ips}, worker_ips: {worker_ips}')

        worker_ip_map = list(zip(workers, worker_ips))
        ip_priority = {ip: idx for idx, ip in enumerate(ips)}

        def get_priority(ip):
            return ip_priority.get(ip)

        sorted_worker_ip_map = sorted(worker_ip_map, key=lambda x: get_priority(x[1]))
        sorted_workers = [item[0] for item in sorted_worker_ip_map]
        return sorted_workers

    def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict):
        """Init worker ray."""
        device_str = get_device_str()
        bundle_indices = []
        if not _envs.ray_external_pg_bundles:
            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
                if bundle.get(device_str, 0):
                    bundle_indices.append(bundle_id)
        else:
            # use externally specified bundle indices, keeping their order
            bundle_indices = _envs.ray_external_pg_bundles.copy()
            # validate external bundle indices
            num_bundles = len(placement_group.bundle_specs)
            for bundle_id in bundle_indices:
                if bundle_id < 0 or bundle_id >= num_bundles:
                    raise ValueError(f'External bundle index {bundle_id} is out of range. '
                                     f'Placement group has {num_bundles} bundles (valid indices: 0-{num_bundles - 1}).')
                bundle = placement_group.bundle_specs[bundle_id]
                if not bundle.get(device_str, 0):
                    raise ValueError(
                        f'External bundle index {bundle_id} does not have required resource: {device_str}. '
                        f'Available resources in this bundle: {dict(bundle)}')

        attn_tp = self.dist_config.attn_tp
        if len(bundle_indices) < attn_tp:
            raise ValueError(f'Not enough bundle indices for attention tensor parallelism. '
                             f'Required: {attn_tp}, Provided: {len(bundle_indices)} '
                             f'(bundle_indices: {bundle_indices}).')
        bundle_indices = bundle_indices[:attn_tp]

        workers = list()
        for _, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            if device_str == 'GPU':
                runtime_env = dict()
                runtime_env = _update_runtime_envs(runtime_env)
                if _envs.ray_nsys_enable:
                    runtime_env = _update_runtime_env_nsys(runtime_env)
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0.01,
                    scheduling_strategy=scheduling_strategy,
                    runtime_env=runtime_env,
                )(RayWorkerWrapper).remote(**worker_kwargs)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={device_str: 0.01},
                    scheduling_strategy=scheduling_strategy,
                )(RayWorkerWrapper).remote(**worker_kwargs)
            workers.append(worker)
        return workers

    def _init_distributed_environment_by_device(self, device_str: str):
        """Init distributed environment."""
        driver_ip = _get_master_addr()
        if device_str == 'cuda':
            self.workers = self._sort_workers(driver_ip, self.workers)
        elif device_str == 'ascend':
            self._init_ascend_distributed_environment(driver_ip)
        elif device_str in ['camb', 'maca']:
            self.workers = self._sort_workers(driver_ip, self.workers)
            ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
        else:
            raise ValueError(f'Unsupported device type: {device_str}')

    def _init_ascend_distributed_environment(self, driver_ip):
        """Init ascend distributed environment."""
        rank_table_file = _envs.ascend_rank_table_file
        set_rt_visable_devices_by_ray = _envs.ascend_set_rt_visable_devices_by_ray
        if rank_table_file:
            # If a rank table file is set, use it to get the rank mapping (multi-node case).
            rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)
            rank_start = self.rank_offset
            rank_end = rank_start + len(self.workers)
            if rank_end > len(worker_ips):
                raise ValueError(
                    'Rank table world_size is smaller than required ranks for current dp_rank. '
                    f'rank_table_world_size={len(worker_ips)}, required_rank_range=[{rank_start}, {rank_end})')
            # In dp mode each process only owns a slice of the global ranks.
            expected_worker_ips = worker_ips[rank_start:rank_end]
            self.workers = self._sort_workers_by_ip(expected_worker_ips, self.workers)
            ray.get(
                [worker.set_device.remote(rank_mapping[rank_start + idx]) for idx, worker in enumerate(self.workers)])
            ray.get([worker.set_env.remote(envs) for worker in self.workers])
        elif not set_rt_visable_devices_by_ray:
            # If no rank table file is set, treat this as a single node with multiple
            # devices and simply assign each device by index.
            self.workers = self._sort_workers(driver_ip, self.workers)
            ray.get([worker.set_device.remote(idx + self.rank_offset) for idx, worker in enumerate(self.workers)])
        else:
            self.workers = self._sort_workers(driver_ip, self.workers)
""" PD Disaggregation API Begin """
def p2p_initialize(self, init_request: DistServeInitRequest):
return self.collective_rpc('p2p_initialize', (init_request, ))
def p2p_connect(self, remote_engine_id: str, conn_request: list[DistServeKVTransferEndpointInfo]):
"""Rdma connect."""
return self.collective_rpc('p2p_connect', (
remote_engine_id,
conn_request,
))
async def migrate(self, batch: MigrationExecutionBatch):
jobs = (worker.migrate.remote(batch) for worker in self.workers)
return await asyncio.gather(*jobs)
""" PD Disaggregation API Begin """