|
95 | 95 | " IN_COLAB = False" |
96 | 96 | ] |
97 | 97 | }, |
98 | | - { |
99 | | - "cell_type": "code", |
100 | | - "execution_count": null, |
101 | | - "metadata": {}, |
102 | | - "outputs": [], |
103 | | - "source": [ |
104 | | - "try:\n", |
105 | | - " import google.colab\n", |
106 | | - " print(\"Running the notebook in Google Colab\")\n", |
107 | | - " IN_COLAB = True\n", |
108 | | - "except ImportError:\n", |
109 | | - " print(\"Running the notebook on JupyterLab\")\n", |
110 | | - " IN_COLAB = False" |
111 | | - ] |
112 | | - }, |
113 | 98 | { |
114 | 99 | "cell_type": "markdown", |
115 | 100 | "metadata": { |
|
168 | 153 | "from datetime import datetime\n", |
169 | 154 | "from flax import nnx\n", |
170 | 155 | "from huggingface_hub import login\n", |
| 156 | + "from jax import numpy as jnp\n", |
171 | 157 | "\n", |
172 | 158 | "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" |
173 | 159 | ] |
|
192 | 178 | }, |
193 | 179 | "outputs": [], |
194 | 180 | "source": [ |
195 | | - "try:\n", |
196 | | - " from google.colab import userdata\n", |
| 181 | + "if IN_COLAB:\n", |
197 | 182 | " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", |
198 | | - "except ImportError:\n", |
| 183 | + "else:\n", |
199 | 184 | " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", |
200 | 185 | "\n", |
201 | 186 | "# If not found in the environment, prompt the user for input securely\n", |
|
438 | 423 | " rollout_vllm_hbm_utilization=0.8,\n", |
439 | 424 | " rollout_vllm_init_with_random_weights=True,\n", |
440 | 425 | " rollout_vllm_tpu_backend_type=\"jax\",\n", |
| 426 | + " data_type=jnp.bfloat16,\n", |
441 | 427 | " ),\n", |
442 | 428 | ")" |
443 | 429 | ] |
|
0 commit comments