Skip to content

Commit 3ef96a7

Browse files
Add "server_main.py" for graph store mode (#394)
Co-authored-by: kmontemayor <kyle.e.montemayor@gmail.com>
1 parent f9508e8 commit 3ef96a7

6 files changed

Lines changed: 436 additions & 0 deletions

File tree

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Built-in GiGL Graph Store Server.
2+
3+
Derivved from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py
4+
5+
"""
6+
import argparse
7+
import os
8+
9+
import graphlearn_torch as glt
10+
import torch
11+
12+
from gigl.common import Uri, UriFactory
13+
from gigl.common.logger import Logger
14+
from gigl.distributed import build_dataset_from_task_config_uri
15+
from gigl.distributed.dist_dataset import DistDataset
16+
from gigl.distributed.graph_store.remote_dataset import register_dataset
17+
from gigl.distributed.utils import get_graph_store_info
18+
from gigl.env.distributed import GraphStoreInfo
19+
20+
logger = Logger()
21+
22+
23+
def _run_storage_process(
24+
storage_rank: int,
25+
cluster_info: GraphStoreInfo,
26+
dataset: DistDataset,
27+
) -> None:
28+
logger.info(
29+
f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes } on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}. Cluster rank: {os.environ.get('RANK')}"
30+
)
31+
register_dataset(dataset)
32+
glt.distributed.init_server(
33+
num_servers=cluster_info.num_storage_nodes,
34+
server_rank=storage_rank,
35+
dataset=dataset,
36+
master_addr=cluster_info.cluster_master_ip,
37+
master_port=cluster_info.cluster_master_port,
38+
num_clients=cluster_info.compute_cluster_world_size,
39+
)
40+
41+
logger.info(
42+
f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit"
43+
)
44+
glt.distributed.wait_and_shutdown_server()
45+
logger.info(f"Storage node {storage_rank} exited")
46+
47+
48+
def storage_node_process(
49+
storage_rank: int,
50+
cluster_info: GraphStoreInfo,
51+
task_config_uri: Uri,
52+
is_inference: bool,
53+
tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$",
54+
) -> None:
55+
"""Run a storage node process
56+
57+
Should be called *once* per storage node (machine).
58+
59+
Args:
60+
storage_rank (int): The rank of the storage node.
61+
cluster_info (GraphStoreInfo): The cluster information.
62+
task_config_uri (Uri): The task config URI.
63+
is_inference (bool): Whether the process is an inference process.
64+
tf_record_uri_pattern (str): The TF Record URI pattern.
65+
"""
66+
init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}"
67+
logger.info(
68+
f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']} init method: {init_method}"
69+
)
70+
torch.distributed.init_process_group(
71+
backend="gloo",
72+
world_size=cluster_info.num_storage_nodes,
73+
rank=storage_rank,
74+
init_method=init_method,
75+
group_name="gigl_server_comms",
76+
)
77+
logger.info(
78+
f"Storage node {storage_rank} / {cluster_info.num_storage_nodes} process group initialized"
79+
)
80+
dataset = build_dataset_from_task_config_uri(
81+
task_config_uri=task_config_uri,
82+
is_inference=is_inference,
83+
_tfrecord_uri_pattern=tf_record_uri_pattern,
84+
)
85+
server_processes = []
86+
mp_context = torch.multiprocessing.get_context("spawn")
87+
# TODO(kmonte): Enable more than one server process per machine
88+
for i in range(1):
89+
server_process = mp_context.Process(
90+
target=_run_storage_process,
91+
args=(
92+
storage_rank + i, # storage_rank
93+
cluster_info, # cluster_info
94+
dataset, # dataset
95+
),
96+
)
97+
server_processes.append(server_process)
98+
for server_process in server_processes:
99+
server_process.start()
100+
for server_process in server_processes:
101+
server_process.join()
102+
103+
104+
if __name__ == "__main__":
105+
parser = argparse.ArgumentParser()
106+
parser.add_argument("--task_config_uri", type=str, required=True)
107+
parser.add_argument("--resource_config_uri", type=str, required=True)
108+
parser.add_argument("--is_inference", action="store_true")
109+
args = parser.parse_args()
110+
logger.info(f"Running storage node with arguments: {args}")
111+
112+
is_inference = args.is_inference
113+
torch.distributed.init_process_group()
114+
cluster_info = get_graph_store_info()
115+
# Tear down the """"global""" process group so we can have a server-specific process group.
116+
torch.distributed.destroy_process_group()
117+
storage_node_process(
118+
storage_rank=cluster_info.storage_node_rank,
119+
cluster_info=cluster_info,
120+
task_config_uri=UriFactory.create_uri(args.task_config_uri),
121+
is_inference=is_inference,
122+
)

python/gigl/env/distributed.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Information about distributed environments."""
22

3+
import os
34
from dataclasses import dataclass
45
from typing import Final
56

@@ -61,3 +62,35 @@ def num_cluster_nodes(self) -> int:
6162
@property
6263
def compute_cluster_world_size(self) -> int:
6364
return self.num_compute_nodes * self.num_processes_per_compute
65+
66+
@property
67+
def storage_node_rank(self) -> int:
68+
"""Get the rank of the storage node in the storage cluster.
69+
70+
Raises:
71+
ValueError: If the node is not in the storage cluster.
72+
"""
73+
global_rank = int(os.environ["RANK"])
74+
if not (
75+
self.num_compute_nodes
76+
<= global_rank
77+
< self.num_compute_nodes + self.num_storage_nodes
78+
):
79+
raise ValueError(
80+
f"Global rank {global_rank} is not a storage rank. Expected storage rank to be in [{self.num_compute_nodes}, {self.num_compute_nodes + self.num_storage_nodes})"
81+
)
82+
return global_rank - self.num_compute_nodes
83+
84+
@property
85+
def compute_node_rank(self) -> int:
86+
"""Get the rank of the compute node in the compute cluster.
87+
88+
Raises:
89+
ValueError: If the node is not in the compute cluster.
90+
"""
91+
global_rank = int(os.environ["RANK"])
92+
if not 0 <= global_rank < self.num_compute_nodes:
93+
raise ValueError(
94+
f"Global rank {global_rank} is not a compute rank. Expected compute rank to be in [0, {self.num_compute_nodes})"
95+
)
96+
return global_rank

python/tests/integration/distributed/graph_store/__init__.py

Whitespace-only changes.
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import os
2+
import unittest
3+
from unittest import mock
4+
5+
import torch
6+
import torch.multiprocessing as mp
7+
from graphlearn_torch.distributed import init_client, shutdown_client
8+
9+
from gigl.common import Uri
10+
from gigl.common.logger import Logger
11+
from gigl.distributed.graph_store.storage_main import storage_node_process
12+
from gigl.distributed.utils import get_free_port
13+
from gigl.env.distributed import (
14+
COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
15+
GraphStoreInfo,
16+
)
17+
from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata
18+
from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import (
19+
CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO,
20+
)
21+
22+
logger = Logger()
23+
24+
25+
def _run_client_process(
26+
client_rank: int,
27+
cluster_info: GraphStoreInfo,
28+
) -> None:
29+
client_global_rank = (
30+
cluster_info.compute_node_rank * cluster_info.num_processes_per_compute
31+
+ client_rank
32+
)
33+
logger.info(
34+
f"Initializing client process {client_global_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}. OS rank: {os.environ['RANK']}, local client rank: {client_rank} on port: {cluster_info.cluster_master_port}"
35+
)
36+
# TODO(kmonte): Add gigl.*.init_client as a helper function to do this.
37+
torch.distributed.init_process_group(
38+
backend="gloo",
39+
world_size=cluster_info.compute_cluster_world_size,
40+
rank=client_global_rank,
41+
init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}",
42+
group_name="gigl_client_comms",
43+
)
44+
logger.info(
45+
f"Client {client_global_rank} / {cluster_info.compute_cluster_world_size} process group initialized"
46+
)
47+
init_client(
48+
num_servers=cluster_info.num_storage_nodes,
49+
num_clients=cluster_info.compute_cluster_world_size,
50+
client_rank=client_global_rank,
51+
master_addr=cluster_info.cluster_master_ip,
52+
master_port=cluster_info.cluster_master_port,
53+
client_group_name="gigl_client_rpc",
54+
)
55+
56+
torch.distributed.barrier()
57+
logger.info(
58+
f"{client_global_rank} / {cluster_info.compute_cluster_world_size} Shutting down client"
59+
)
60+
shutdown_client()
61+
62+
63+
def _client_process(
64+
client_rank: int,
65+
cluster_info: GraphStoreInfo,
66+
) -> None:
67+
logger.info(
68+
f"Initializing client node {client_rank} / {cluster_info.num_compute_nodes}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}, local client rank: {client_rank}"
69+
)
70+
71+
mp_context = torch.multiprocessing.get_context("spawn")
72+
client_processes = []
73+
for i in range(cluster_info.num_processes_per_compute):
74+
client_process = mp_context.Process(
75+
target=_run_client_process,
76+
args=[
77+
i, # client_rank
78+
cluster_info, # cluster_info
79+
],
80+
)
81+
client_processes.append(client_process)
82+
for client_process in client_processes:
83+
client_process.start()
84+
for client_process in client_processes:
85+
client_process.join()
86+
87+
88+
def _run_server_processes(
89+
cluster_info: GraphStoreInfo,
90+
task_config_uri: Uri,
91+
is_inference: bool,
92+
) -> None:
93+
logger.info(
94+
f"Initializing server processes. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}"
95+
)
96+
storage_node_process(
97+
storage_rank=cluster_info.storage_node_rank,
98+
cluster_info=cluster_info,
99+
task_config_uri=task_config_uri,
100+
is_inference=is_inference,
101+
tf_record_uri_pattern=".*tfrecord",
102+
)
103+
104+
105+
class TestUtils(unittest.TestCase):
106+
def test_graph_store_locally(self):
107+
# Simulating two server machine, two compute machines.
108+
# Each machine has one process.
109+
cora_supervised_info = get_mocked_dataset_artifact_metadata()[
110+
CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name
111+
]
112+
task_config_uri = cora_supervised_info.frozen_gbml_config_uri
113+
cluster_info = GraphStoreInfo(
114+
num_storage_nodes=2,
115+
num_compute_nodes=2,
116+
num_processes_per_compute=2,
117+
cluster_master_ip="localhost",
118+
storage_cluster_master_ip="localhost",
119+
compute_cluster_master_ip="localhost",
120+
cluster_master_port=get_free_port(),
121+
storage_cluster_master_port=get_free_port(),
122+
compute_cluster_master_port=get_free_port(),
123+
)
124+
125+
master_port = get_free_port()
126+
ctx = mp.get_context("spawn")
127+
client_processes: list = []
128+
for i in range(cluster_info.num_compute_nodes):
129+
with mock.patch.dict(
130+
os.environ,
131+
{
132+
"MASTER_ADDR": "localhost",
133+
"MASTER_PORT": str(master_port),
134+
"RANK": str(i),
135+
"WORLD_SIZE": str(cluster_info.compute_cluster_world_size),
136+
COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str(
137+
cluster_info.num_processes_per_compute
138+
),
139+
},
140+
clear=False,
141+
):
142+
client_process = ctx.Process(
143+
target=_client_process,
144+
args=[
145+
i, # client_rank
146+
cluster_info, # cluster_info
147+
],
148+
)
149+
client_process.start()
150+
client_processes.append(client_process)
151+
# Start server process
152+
server_processes = []
153+
for i in range(cluster_info.num_storage_nodes):
154+
with mock.patch.dict(
155+
os.environ,
156+
{
157+
"MASTER_ADDR": "localhost",
158+
"MASTER_PORT": str(master_port),
159+
"RANK": str(i + cluster_info.num_compute_nodes),
160+
"WORLD_SIZE": str(cluster_info.compute_cluster_world_size),
161+
COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str(
162+
cluster_info.num_processes_per_compute
163+
),
164+
},
165+
clear=False,
166+
):
167+
server_process = ctx.Process(
168+
target=_run_server_processes,
169+
args=[
170+
cluster_info, # cluster_info
171+
task_config_uri, # task_config_uri
172+
True, # is_inference
173+
],
174+
)
175+
server_process.start()
176+
server_processes.append(server_process)
177+
178+
for client_process in client_processes:
179+
client_process.join()
180+
for server_process in server_processes:
181+
server_process.join()

python/tests/unit/env/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)