Skip to content

Commit c056c4e

Browse files
authored
Compile the function ahead of time in the JAX example (#6286)
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent ea93334 commit c056c4e

1 file changed

Lines changed: 54 additions & 52 deletions

File tree

docs/examples/frameworks/jax/jax-basic_example.ipynb

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@
1515
{
1616
"cell_type": "code",
1717
"execution_count": 1,
18-
"metadata": {
19-
"execution": {
20-
"iopub.execute_input": "2023-07-28T07:43:41.850101Z",
21-
"iopub.status.busy": "2023-07-28T07:43:41.849672Z",
22-
"iopub.status.idle": "2023-07-28T07:43:41.853520Z",
23-
"shell.execute_reply": "2023-07-28T07:43:41.852990Z"
24-
}
25-
},
18+
"metadata": {},
2619
"outputs": [],
2720
"source": [
2821
"import os\n",
@@ -50,14 +43,7 @@
5043
{
5144
"cell_type": "code",
5245
"execution_count": 2,
53-
"metadata": {
54-
"execution": {
55-
"iopub.execute_input": "2023-07-28T07:43:41.855441Z",
56-
"iopub.status.busy": "2023-07-28T07:43:41.855301Z",
57-
"iopub.status.idle": "2023-07-28T07:43:41.986406Z",
58-
"shell.execute_reply": "2023-07-28T07:43:41.985500Z"
59-
}
60-
},
46+
"metadata": {},
6147
"outputs": [],
6248
"source": [
6349
"from nvidia.dali.plugin.jax import data_iterator\n",
@@ -98,22 +84,15 @@
9884
{
9985
"cell_type": "code",
10086
"execution_count": 3,
101-
"metadata": {
102-
"execution": {
103-
"iopub.execute_input": "2023-07-28T07:43:41.989183Z",
104-
"iopub.status.busy": "2023-07-28T07:43:41.988964Z",
105-
"iopub.status.idle": "2023-07-28T07:43:42.104446Z",
106-
"shell.execute_reply": "2023-07-28T07:43:42.103668Z"
107-
}
108-
},
87+
"metadata": {},
10988
"outputs": [
11089
{
11190
"name": "stdout",
11291
"output_type": "stream",
11392
"text": [
11493
"Creating iterators\n",
115-
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>\n",
116-
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>\n"
94+
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d7397b4b790>\n",
95+
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d739800e530>\n"
11796
]
11897
}
11998
],
@@ -145,19 +124,50 @@
145124
{
146125
"cell_type": "code",
147126
"execution_count": 4,
148-
"metadata": {
149-
"execution": {
150-
"iopub.execute_input": "2023-07-28T07:43:43.559575Z",
151-
"iopub.status.busy": "2023-07-28T07:43:43.559420Z",
152-
"iopub.status.idle": "2023-07-28T07:43:43.618221Z",
153-
"shell.execute_reply": "2023-07-28T07:43:43.617532Z"
154-
}
155-
},
127+
"metadata": {},
156128
"outputs": [],
157129
"source": [
158130
"from model import init_model, update, accuracy"
159131
]
160132
},
133+
{
134+
"cell_type": "markdown",
135+
"metadata": {},
136+
"source": [
137+
"`jax.jit` traces, compiles, and caches functions lazily on first invocation for a given input signature. During this process, XLA may capture CUDA graphs, which forbids some CUDA calls that DALI's background thread uses internally. Since subsequent calls to the JAX function with inputs of the same shape and dtype don't trigger compilation again, we can work around this by warming up with dummy inputs before starting any DALI workload:"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": 5,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"import jax.numpy as jnp\n",
147+
"\n",
148+
"model = init_model()\n",
149+
"dummy_images = jnp.empty(\n",
150+
" (batch_size, image_size * image_size), dtype=jnp.float32\n",
151+
")\n",
152+
"dummy_labels = jnp.empty((batch_size, num_classes), dtype=jnp.float32)\n",
153+
"_ = update(model, {\"images\": dummy_images, \"labels\": dummy_labels})"
154+
]
155+
},
156+
{
157+
"cell_type": "markdown",
158+
"metadata": {},
159+
"source": [
160+
"<div class=\"alert alert-warning\">\n",
161+
"\n",
162+
" Warning<br>\n",
163+
" \n",
164+
" If you skip this step, CUDA graph capture will happen on the first call to `update` and may overlap with DALI's execution, causing CUDA errors in JAX.\n",
165+
" \n",
166+
" Alternatively, you can disable XLA command buffers entirely by setting `XLA_FLAGS=\"--xla_gpu_enable_command_buffer=\"`, at the cost of some performance.\n",
167+
" \n",
168+
"</div>"
169+
]
170+
},
161171
{
162172
"attachments": {},
163173
"cell_type": "markdown",
@@ -168,38 +178,30 @@
168178
},
169179
{
170180
"cell_type": "code",
171-
"execution_count": 5,
172-
"metadata": {
173-
"execution": {
174-
"iopub.execute_input": "2023-07-28T07:43:43.622376Z",
175-
"iopub.status.busy": "2023-07-28T07:43:43.621205Z",
176-
"iopub.status.idle": "2023-07-28T07:43:58.016073Z",
177-
"shell.execute_reply": "2023-07-28T07:43:58.015333Z"
178-
}
179-
},
181+
"execution_count": 6,
182+
"metadata": {},
180183
"outputs": [
181184
{
182185
"name": "stdout",
183186
"output_type": "stream",
184187
"text": [
185188
"Starting training\n",
186189
"Epoch 0 sec\n",
187-
"Test set accuracy 0.67330002784729\n",
190+
"Test set accuracy 0.674500048160553\n",
188191
"Epoch 1 sec\n",
189-
"Test set accuracy 0.7855000495910645\n",
192+
"Test set accuracy 0.7854000329971313\n",
190193
"Epoch 2 sec\n",
191-
"Test set accuracy 0.8251000642776489\n",
194+
"Test set accuracy 0.8252000212669373\n",
192195
"Epoch 3 sec\n",
193-
"Test set accuracy 0.8469000458717346\n",
196+
"Test set accuracy 0.847100019454956\n",
194197
"Epoch 4 sec\n",
195-
"Test set accuracy 0.8616000413894653\n"
198+
"Test set accuracy 0.8618000149726868\n"
196199
]
197200
}
198201
],
199202
"source": [
200203
"print(\"Starting training\")\n",
201204
"\n",
202-
"model = init_model()\n",
203205
"num_epochs = 5\n",
204206
"\n",
205207
"for epoch in range(num_epochs):\n",
@@ -215,7 +217,7 @@
215217
"metadata": {
216218
"celltoolbar": "Raw Cell Format",
217219
"kernelspec": {
218-
"display_name": "Python 3",
220+
"display_name": "Python 3 (ipykernel)",
219221
"language": "python",
220222
"name": "python3"
221223
},
@@ -229,9 +231,9 @@
229231
"name": "python",
230232
"nbconvert_exporter": "python",
231233
"pygments_lexer": "ipython3",
232-
"version": "3.10.12"
234+
"version": "3.10.20"
233235
}
234236
},
235237
"nbformat": 4,
236-
"nbformat_minor": 2
238+
"nbformat_minor": 4
237239
}

0 commit comments

Comments
 (0)