Skip to content

Commit e3174f8

Browse files
committed
Fix sft_qwen3_demo jupyter notebook
1 parent c0b15ef commit e3174f8

2 files changed

Lines changed: 5 additions & 19 deletions

File tree

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ jobs:
8484
8585
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
8686
filename=$(basename "$notebook")
87-
if [[ "$filename" == "sft_qwen3_demo.ipynb" || "$filename" == "sft_llama3_demo_gpu.ipynb" ]]; then
87+
if [[ "$filename" == "sft_llama3_demo_gpu.ipynb" ]]; then
8888
echo "Skipping $filename"
8989
continue
9090
fi

src/maxtext/examples/sft_qwen3_demo.ipynb

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,6 @@
9595
" IN_COLAB = False"
9696
]
9797
},
98-
{
99-
"cell_type": "code",
100-
"execution_count": null,
101-
"metadata": {},
102-
"outputs": [],
103-
"source": [
104-
"try:\n",
105-
" import google.colab\n",
106-
" print(\"Running the notebook in Google Colab\")\n",
107-
" IN_COLAB = True\n",
108-
"except ImportError:\n",
109-
" print(\"Running the notebook on JupyterLab\")\n",
110-
" IN_COLAB = False"
111-
]
112-
},
11398
{
11499
"cell_type": "markdown",
115100
"metadata": {
@@ -168,6 +153,7 @@
168153
"from datetime import datetime\n",
169154
"from flax import nnx\n",
170155
"from huggingface_hub import login\n",
156+
"from jax import numpy as jnp\n",
171157
"\n",
172158
"print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
173159
]
@@ -192,10 +178,9 @@
192178
},
193179
"outputs": [],
194180
"source": [
195-
"try:\n",
196-
" from google.colab import userdata\n",
181+
"if IN_COLAB:\n",
197182
" HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
198-
"except ImportError:\n",
183+
"else:\n",
199184
" HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
200185
"\n",
201186
"# If not found in the environment, prompt the user for input securely\n",
@@ -438,6 +423,7 @@
438423
" rollout_vllm_hbm_utilization=0.8,\n",
439424
" rollout_vllm_init_with_random_weights=True,\n",
440425
" rollout_vllm_tpu_backend_type=\"jax\",\n",
426+
" data_type=jnp.bfloat16,\n",
441427
" ),\n",
442428
")"
443429
]

0 commit comments

Comments
 (0)