Skip to content

Commit 7099e75

Browse files
all_to_all model implementation first approximation in ISL
1 parent 5c836a0 commit 7099e75

5 files changed

Lines changed: 176 additions & 56 deletions

File tree

accelforge/model/_looptree/reuse/isl/distributed/bind.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Applies the binding layer into one that can be used for later analysis."""
22

3-
from accelforge.frontend.binding import Binding
3+
from accelforge.frontend._binding import Binding
44
from accelforge.frontend.mapping import Mapping
55
from accelforge.frontend.workload import Workload
66

notebooks/astrasim2_correlation/correlation.ipynb

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,60 +26,7 @@
2626
"id": "95cf92d9",
2727
"metadata": {},
2828
"source": [
29-
"## Astrasim-2.0\n",
30-
"\n",
31-
"The [Astrasim-2.0 paper](https://arxiv.org/abs/2303.14006) has correlation to a torus of 4 and 16 V100s on page 6 to the network latency of data transfers of an all-reduce.\n",
32-
"\n",
33-
"On top of this, on page 9 they have data on their simulation framework correlated to the above and the latencies of certain operation of their model.\n",
34-
"\n",
35-
"We aim to show that we can match these numbers with our analytical model."
36-
]
37-
},
38-
{
39-
"cell_type": "code",
40-
"execution_count": null,
41-
"id": "f70059c4",
42-
"metadata": {},
43-
"outputs": [],
44-
"source": [
45-
"from numbers import Number\n",
46-
"from typing import Dict, Sequence\n",
47-
"\n",
48-
"def ring_4xV100():\n",
49-
"\t\"\"\"\n",
50-
"\tGenerates a graph with the latencies of 4xV100s.\n",
51-
"\t\"\"\"\n",
52-
"\t# Collective size in MB to latency TODO: stop eyeballing latency.\n",
53-
"\tGROUND: Dict[int, int] = {\n",
54-
"\t\t64: 500,\n",
55-
"\t\t96: 750,\n",
56-
"\t\t128: 1000,\n",
57-
"\t\t192: 2000,\n",
58-
"\t\t750: 10_000,\n",
59-
"\t\t1500: 20_000\n",
60-
"\t}\n",
61-
"\n",
62-
"\tsize: Sequence[Number] = tuple(GROUND.keys())\n",
63-
"\ttruth: Sequence[Number] = tuple(GROUND.values())\n",
64-
"\testimate: Sequence[Number] = []\n",
65-
"\n",
66-
"\t# Generates estimates from model\n",
67-
"\tfor mb in size:\n",
68-
"\t\t# TODO: Read and Jinja2 these items.\n",
69-
"\t\t\n",
70-
"\n",
71-
"\tx = np.arange(len(size)) # positions\n",
72-
"\twidth = 0.35 # bar width\n",
73-
"\n",
74-
"\tfig, ax = plt.subplots()\n",
75-
"\t# Position bars side-by-side using offset\n",
76-
"\tax.bar(x - width/2, men, width, label='Ground', color=\"blue\")\n",
77-
"\tax.bar(x + width/2, women, width, label='Model', color=\"red\")\n",
78-
"\n",
79-
"\tax.set_xticks(x) # Center labels\n",
80-
"\tax.set_xticklabels(labels)\n",
81-
"\tax.legend()\n",
82-
"\tplt.show()"
29+
"## Testing an 8-GPU all-to-all, simulating an NVLink switch, to correlate later\n"
8330
]
8431
}
8532
],

tests/not_working/distribuffers/multicast/test_cases.yaml

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,4 +479,121 @@
479479
# dist_fn: *ring_dist_size_8
480480
# expected:
481481
# latency: 1
482-
# total_hops: 4
482+
# total_hops: 4
483+
###################################################
484+
# 8-GPU fully-connected (NVLink/NVSwitch-style) #
485+
# all-to-all, one-hot GPU encoding. #
486+
# #
487+
# GPU i sits at one-hot coordinate e_i so every #
488+
# src!=dst cast has extent 1 along exactly the #
489+
# src and dst dims: cost = (1+1)(1+1)-1 = 3, #
490+
# uniform across all pairs (fully-connected). #
491+
# Self-chunks never cross the fabric (cost 0). #
492+
# dist_fn is unit-cost (matching only). #
493+
###################################################
494+
495+
# All-to-all over 8 GPUs: each GPU holds chunks data[self, d] and requests data[s, self].
496+
- occ: |
497+
{
498+
noc[gs0, gs1, gs2, gs3, gs4, gs5, gs6, gs7] -> data[s, d] :
499+
0 <= gs0 <= 1 and 0 <= gs1 <= 1 and 0 <= gs2 <= 1 and 0 <= gs3 <= 1 and 0 <= gs4 <= 1 and 0 <= gs5 <= 1 and 0 <= gs6 <= 1 and 0 <= gs7 <= 1 and
500+
gs0 + gs1 + gs2 + gs3 + gs4 + gs5 + gs6 + gs7 = 1 and
501+
s = 1*gs1 + 2*gs2 + 3*gs3 + 4*gs4 + 5*gs5 + 6*gs6 + 7*gs7 and 0 <= d < 8
502+
}
503+
fill: |
504+
{
505+
noc[gd0, gd1, gd2, gd3, gd4, gd5, gd6, gd7] -> data[s, d] :
506+
0 <= gd0 <= 1 and 0 <= gd1 <= 1 and 0 <= gd2 <= 1 and 0 <= gd3 <= 1 and 0 <= gd4 <= 1 and 0 <= gd5 <= 1 and 0 <= gd6 <= 1 and 0 <= gd7 <= 1 and
507+
gd0 + gd1 + gd2 + gd3 + gd4 + gd5 + gd6 + gd7 = 1 and
508+
d = 1*gd1 + 2*gd2 + 3*gd3 + 4*gd4 + 5*gd5 + 6*gd6 + 7*gd7 and 0 <= s < 8
509+
}
510+
dims: &8d_onehot_spatial
511+
- type: Spatial
512+
spatial_dim: 0
513+
target: 0
514+
- type: Spatial
515+
spatial_dim: 1
516+
target: 0
517+
- type: Spatial
518+
spatial_dim: 2
519+
target: 0
520+
- type: Spatial
521+
spatial_dim: 3
522+
target: 0
523+
- type: Spatial
524+
spatial_dim: 4
525+
target: 0
526+
- type: Spatial
527+
spatial_dim: 5
528+
target: 0
529+
- type: Spatial
530+
spatial_dim: 6
531+
target: 0
532+
- type: Spatial
533+
spatial_dim: 7
534+
target: 0
535+
dist_fn: &fully_connected_unit |
536+
{
537+
[noc[xd0, xd1, xd2, xd3, xd4, xd5, xd6, xd7] -> noc[xs0, xs1, xs2, xs3, xs4, xs5, xs6, xs7]] -> hops[0] :
538+
xd0 = xs0 and xd1 = xs1 and xd2 = xs2 and xd3 = xs3 and xd4 = xs4 and xd5 = xs5 and xd6 = xs6 and xd7 = xs7;
539+
[noc[xd0, xd1, xd2, xd3, xd4, xd5, xd6, xd7] -> noc[xs0, xs1, xs2, xs3, xs4, xs5, xs6, xs7]] -> hops[1] :
540+
(xd0 < xs0) or (xd0 > xs0) or (xd1 < xs1) or (xd1 > xs1) or (xd2 < xs2) or (xd2 > xs2) or (xd3 < xs3) or (xd3 > xs3) or (xd4 < xs4) or (xd4 > xs4) or (xd5 < xs5) or (xd5 > xs5) or (xd6 < xs6) or (xd6 > xs6) or (xd7 < xs7) or (xd7 > xs7)
541+
}
542+
expected:
543+
latency: null
544+
total_hops: null
545+
multicast_hops: null
546+
hypercube_hops: 168
547+
extent_DOR_hops: null
548+
549+
# Single chunk GPU0 -> GPU3: one unicast cast, cost (1+1)(1+1)-1 = 3.
550+
- occ: |
551+
{
552+
noc[gs0, gs1, gs2, gs3, gs4, gs5, gs6, gs7] -> data[s, d] :
553+
0 <= gs0 <= 1 and 0 <= gs1 <= 1 and 0 <= gs2 <= 1 and 0 <= gs3 <= 1 and 0 <= gs4 <= 1 and 0 <= gs5 <= 1 and 0 <= gs6 <= 1 and 0 <= gs7 <= 1 and
554+
gs0 + gs1 + gs2 + gs3 + gs4 + gs5 + gs6 + gs7 = 1 and
555+
s = 1*gs1 + 2*gs2 + 3*gs3 + 4*gs4 + 5*gs5 + 6*gs6 + 7*gs7 and s = 0 and d = 3
556+
}
557+
fill: |
558+
{
559+
noc[gd0, gd1, gd2, gd3, gd4, gd5, gd6, gd7] -> data[s, d] :
560+
0 <= gd0 <= 1 and 0 <= gd1 <= 1 and 0 <= gd2 <= 1 and 0 <= gd3 <= 1 and 0 <= gd4 <= 1 and 0 <= gd5 <= 1 and 0 <= gd6 <= 1 and 0 <= gd7 <= 1 and
561+
gd0 + gd1 + gd2 + gd3 + gd4 + gd5 + gd6 + gd7 = 1 and
562+
d = 1*gd1 + 2*gd2 + 3*gd3 + 4*gd4 + 5*gd5 + 6*gd6 + 7*gd7 and s = 0 and d = 3
563+
}
564+
dims: *8d_onehot_spatial
565+
566+
dist_fn: *fully_connected_unit
567+
568+
expected:
569+
latency: null
570+
total_hops: null
571+
multicast_hops: null
572+
hypercube_hops: 3
573+
extent_DOR_hops: null
574+
575+
# Self chunk GPU5 -> GPU5: never crosses the fabric, cost 0.
576+
- occ: |
577+
{
578+
noc[gs0, gs1, gs2, gs3, gs4, gs5, gs6, gs7] -> data[s, d] :
579+
0 <= gs0 <= 1 and 0 <= gs1 <= 1 and 0 <= gs2 <= 1 and 0 <= gs3 <= 1 and 0 <= gs4 <= 1 and 0 <= gs5 <= 1 and 0 <= gs6 <= 1 and 0 <= gs7 <= 1 and
580+
gs0 + gs1 + gs2 + gs3 + gs4 + gs5 + gs6 + gs7 = 1 and
581+
s = 1*gs1 + 2*gs2 + 3*gs3 + 4*gs4 + 5*gs5 + 6*gs6 + 7*gs7 and s = 5 and d = 5
582+
}
583+
fill: |
584+
{
585+
noc[gd0, gd1, gd2, gd3, gd4, gd5, gd6, gd7] -> data[s, d] :
586+
0 <= gd0 <= 1 and 0 <= gd1 <= 1 and 0 <= gd2 <= 1 and 0 <= gd3 <= 1 and 0 <= gd4 <= 1 and 0 <= gd5 <= 1 and 0 <= gd6 <= 1 and 0 <= gd7 <= 1 and
587+
gd0 + gd1 + gd2 + gd3 + gd4 + gd5 + gd6 + gd7 = 1 and
588+
d = 1*gd1 + 2*gd2 + 3*gd3 + 4*gd4 + 5*gd5 + 6*gd6 + 7*gd7 and s = 5 and d = 5
589+
}
590+
dims: *8d_onehot_spatial
591+
592+
dist_fn: *fully_connected_unit
593+
594+
expected:
595+
latency: null
596+
total_hops: null
597+
multicast_hops: null
598+
hypercube_hops: 0
599+
extent_DOR_hops: null
File renamed without changes.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Utility functions common to testing the isl mapper functions.
3+
"""
4+
5+
from pathlib import Path
6+
import islpy as isl
7+
8+
from ruamel.yaml import YAML
9+
10+
11+
def to_isl_maps(obj: str | list | dict) -> dict:
12+
"""
13+
Given an object, attempt to reduce all strings in tree with isl.Map
14+
15+
Parameters
16+
----------
17+
obj:
18+
A DAG which can be explored and contains isl.Map strings within it.
19+
20+
Returns
21+
-------
22+
`obj` but all strings are converted to isl.Map.
23+
"""
24+
25+
def _to_isl_maps(obj: str | dict | list) -> isl.Map | dict | list:
26+
"""Recursively convert string ISL maps to isl.Map; leave others alone."""
27+
if isinstance(obj, str):
28+
return isl.Map.read_from_str(isl.DEFAULT_CONTEXT, obj)
29+
if isinstance(obj, dict):
30+
return {k: (_to_isl_maps(v) if k != "type" else v) for k, v in obj.items()}
31+
if isinstance(obj, list):
32+
return [_to_isl_maps(v) for v in obj]
33+
return obj
34+
35+
return _to_isl_maps(obj) # type: ignore
36+
37+
38+
def load_solutions(path: Path) -> dict:
    """
    Read a YAML file of isl solutions and convert its strings to isl.Map.

    Parameters
    ----------
    path:
        Location of the YAML solutions file.

    Returns
    -------
    The parsed YAML document after it has been passed through
    `to_isl_maps`: a dictionary relating the Python-based keys generated by
    the mapper (e.g., `BufferTensorEinsum`) to their corresponding isl.Map.
    """
    # The expected solutions are stored as YAML with the isl maps as strings.
    parser: YAML = YAML(typ="safe")

    with open(path, "r", encoding="utf-8") as stream:
        document = parser.load(stream)

    return to_isl_maps(document)

0 commit comments

Comments
 (0)