1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17- # Description:
18- # bash setup.sh MODE={stable,nightly,libtpu-only} LIBTPU_GCS_PATH={gcs_path_to_custom_libtpu} DEVICE={tpu,gpu}
17+ # ==================================
18+ # TPU EXAMPLES
19+ # ==================================
1920
21+ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt
22+ # # bash tools/setup/setup.sh MODE=stable
2023
21- # You need to specify a MODE, default value stable.
22- # You have the option to provide a LIBTPU_GCS_PATH that points to a libtpu.so provided to you by Google.
23- # In libtpu-only MODE, the LIBTPU_GCS_PATH is mandatory.
24- # For MODE=stable you may additionally specify JAX_VERSION, e.g. JAX_VERSION=0.4.13
25- # For DEVICE=gpu, you may also specify JAX_VERSION when MODE=nightly, e.g. JAX_VERSION=0.4.36.dev20241109
24+ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + specified jax, jaxlib, libtpu
25+ # # bash tools/setup/setup.sh MODE=stable JAX_VERSION=0.8.0
26+
27+ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + custom libtpu
28+ # # bash tools/setup/setup.sh MODE=stable LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so
29+
30+ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + jax-nightly, jaxlib-nightly, libtpu-nightly
31+ # # bash tools/setup/setup.sh MODE=nightly
32+
33+ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + specified jax-nightly, jaxlib-nightly + latest libtpu-nightly
34+ # # bash tools/setup/setup.sh MODE=nightly JAX_VERSION=0.8.2.dev20251211
35+
36+ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + jax-nightly, jaxlib-nightly + custom libtpu
37+ # # bash tools/setup/setup.sh MODE=nightly LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so
38+
39+ # Install custom libtpu only
40+ # # bash tools/setup/setup.sh MODE=libtpu-only LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so
41+
42+ # ==================================
43+ # GPU EXAMPLES
44+ # ==================================
45+
46+ # Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt
47+ # # bash tools/setup/setup.sh MODE=stable DEVICE=gpu
48+
49+ # Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt + specified jax, jaxlib, jax-cuda12-plugin, jax-cuda12-pjrt
50+ # # bash tools/setup/setup.sh MODE=stable DEVICE=gpu JAX_VERSION=0.4.13
51+
52+ # Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt + jax-nightly, jaxlib-nightly
53+ # # bash tools/setup/setup.sh MODE=nightly DEVICE=gpu
54+
55+ # Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt + specified jax, jaxlib, jax-cuda12-plugin, jax-cuda12-pjrt
56+ # # bash tools/setup/setup.sh MODE=nightly DEVICE=gpu JAX_VERSION=0.4.36.dev20241109
2657
2758
2859# Enable "exit immediately if any command fails" option
@@ -102,7 +133,6 @@ apt update -y && apt -y install gcsfuse
102133rm -rf /var/lib/apt/lists/*
103134EOF
104135
105-
106136python3 -m pip install -U setuptools wheel uv
107137
108138# Set environment variables
@@ -111,188 +141,148 @@ for ARGUMENT in "$@"; do
111141 export " $KEY " =" $VALUE "
112142done
113143
144+ # Set default value for $DEVICE
114145if [[ -z " $DEVICE " ]]; then
115- export DEVICE=" tpu"
146+ export DEVICE=tpu
116147fi
117148
118- if [[ $JAX_VERSION == NONE ]]; then
119- unset JAX_VERSION
149+ # Set default value for $MODE
150+ if [[ -z " $MODE " ]]; then
151+ export MODE=stable
120152fi
121153
122- if [[ $LIBTPU_GCS_PATH == NONE ]]; then
123- unset LIBTPU_GCS_PATH
124- fi
154+ # Unset optional variables if set to NONE
155+ unset_optional_vars () {
156+ local optional_vars=(" JAX_VERSION" " LIBTPU_GCS_PATH" )
157+ for var_name in " ${optional_vars[@]} " ; do
158+ if [[ ${! var_name} == NONE ]]; then
159+ unset " $var_name "
160+ fi
161+ done
162+ }
163+ unset_optional_vars
125164
126- if [[ -n $JAX_VERSION && ! ($MODE == " stable" || -z $MODE || ($MODE == " nightly" && $DEVICE == " gpu" )) ]]; then
127- echo -e " \n\nError: You can only specify a JAX_VERSION with stable mode (plus nightly mode on GPU).\n\n"
128- exit 1
129- fi
165+ version_mismatch_warning () {
166+ echo -e " \n\nWARNING: You are installing a $1 version that is different from the one pinned by MaxText. This can lead to the following issues:"
167+ echo -e " 1. Compatibility: The dependencies in the requirements file are tested and compatible with the pinned $1 version. We cannot guarantee that they will work correctly with a different $1 version."
168+ echo -e " 2. Consistency: Installing a custom $1 version can pull in different transitive dependencies over time, making the environment non-reproducible and potentially affecting performance.\n\n"
169+ }
130170
131- if [[ $DEVICE == " tpu " ]] ; then
171+ install_custom_libtpu () {
132172 libtpu_path=" $HOME /custom_libtpu/libtpu.so"
133- if [[ " $MODE " == " libtpu-only" ]]; then
134- # Only update custom libtpu.
135- if [[ -n " $LIBTPU_GCS_PATH " ]]; then
136- # Install custom libtpu
137- echo " Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path "
138- # Install required dependency
139- python3 -m uv pip install -U crcmod
140- # Copy libtpu.so from GCS path
141- gsutil cp " $LIBTPU_GCS_PATH " " $libtpu_path "
142- exit 0
143- else
144- echo -e " \n\nError: You must provide a custom libtpu for libtpu-only mode.\n\n"
145- exit 1
146- fi
173+ echo -e " \nInstalling libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path "
174+ version_mismatch_warning " libtpu"
175+ # Delete custom libtpu if it exists
176+ if [ -e " $libtpu_path " ]; then
177+ rm -v " $libtpu_path "
147178 fi
148- fi
179+ # Install 'crcmod' to download 'libtpu.so' from GCS reliably
180+ python3 -m uv pip install -U crcmod
181+ # Copy libtpu.so from GCS path
182+ gsutil cp " $LIBTPU_GCS_PATH " " $libtpu_path "
183+ }
149184
150- if [[ " $MODE " == " nightly" ]]; then
151- if [ " $DEVICE " = " gpu" ]; then
152- dep_name=' dependencies/requirements/generated_requirements/cuda12-requirements.txt'
153- else
154- dep_name=' dependencies/requirements/generated_requirements/' " ${DEVICE?} " ' -requirements.txt'
185+ install_maxtext_with_deps () {
186+ if [[ " $DEVICE " != " tpu" && " $DEVICE " != " gpu" ]]; then
187+ echo -e " \n\nError: DEVICE must be either 'tpu' or 'gpu'.\n\n"
188+ exit 1
155189 fi
156- printf ' Nightly mode: Installing "%s", stripping commit pins from git+ repos.\n' " $dep_name "
157- nightly_txt=" ${dep_name##*/ } "
158- nightly_txt=" ${nightly_txt% .txt} " ' -nightly-temp.txt'
159-
160- # Create a temp file, strip commit pins from git+ repos in requirements.txt
161- # Remove/update this section based on the pinned github repo commit in requirements.txt
162- sed -E \
163- -e ' s|^([^ ]*) @ https?://github.com/([^/]*\/[^/]*)/archive/.*\.zip$|\1@git+https://github.com/\2.git|' \
164- -e ' /JetStream/d' \
165- -e ' /mlperf-logging/d' \
166- " $dep_name " > " $nightly_txt "
167-
168- echo " --- Installing modified nightly requirements: ---"
169- cat -- " $nightly_txt "
170- echo " -------------------------------------------------"
171-
172- python3 -m uv pip install --no-cache-dir -U -r " $nightly_txt " \
173- -r ' src/install_maxtext_extra_deps/extra_deps_from_github.txt'
174- rm -fv -- " $nightly_txt "
175- else
176- # stable or stable_stack mode: Install with pinned commits
190+ echo " Setting up MaxText in $MODE mode for $DEVICE device"
177191 if [ " $DEVICE " = " gpu" ]; then
178- dep_basename= ' cuda12-requirements.txt'
192+ dep_name= ' dependencies/requirements/generated_requirements/ cuda12-requirements.txt'
179193 else
180- dep_basename= " ${DEVICE?} " ' -requirements.txt'
194+ dep_name= ' dependencies/requirements/generated_requirements/tpu -requirements.txt'
181195 fi
196+ echo " Installing requirements from $dep_name "
197+ python3 -m uv pip install --resolution=lowest -r " $dep_name " \
198+ -r ' src/install_maxtext_extra_deps/extra_deps_from_github.txt'
182199
183- printf ' Installing "%s" with pinned commits.\n' " $dep_basename "
184- requirements_txt=
185- for candidate in ' dependencies/requirements/generated_requirements' ' dependencies/requirements' " ${MAXTEXT_REPO_ROOT} " ' /dependencies/requirements' " $PWD " ; do
186- if [ -f " $candidate " ' /' " $dep_basename " ]; then
187- requirements_txt=" $candidate " ' /' " $dep_basename "
188- break
189- else
190- searched=" $searched " ' :'
191- fi
192- done
193- if [ -z " ${requirements_txt} " ]; then
194- >&2 printf ' Could not find "%s", looked in: %s\n' " $dep_basename " " ${searched% ?} "
195- exit 2
196- else
197- python3 -m uv pip install --resolution=lowest -r " $requirements_txt " \
198- -r ' src/install_maxtext_extra_deps/extra_deps_from_github.txt'
200+ # The MaxText package is installed separately from its dependencies to optimize
201+ # docker image rebuild times by leveraging docker's layer caching.
202+ # Dependencies are installed in a separate step before MaxText code is
203+ # copied. This means that if MaxText code changes, but the
204+ # dependencies do not, docker can reuse the cached dependency layer, leading
205+ # to significantly faster image builds.
206+ if [ -f ' pyproject.toml' ]; then
207+ echo " Installing MaxText package without installing the dependencies (already installed)"
208+ python3 -m uv pip install --no-deps -e .
199209 fi
200- fi
210+ }
201211
202- # Install maxtext package
203- if [ -f ' pyproject.toml' ]; then
204- case " $DEVICE " in
205- ' gpu' ) python3 -m uv pip install -e .[cuda12] --no-deps --resolution=lowest ;;
206- ' tpu' ) python3 -m uv pip install -e .[tpu] --no-deps --resolution=lowest ;;
207- * )
208- >&2 printf ' Unsupported device\n'
209- exit 6
210- ;;
211- esac
212- python3 -m uv pip install --resolution=lowest -r ' src/install_maxtext_extra_deps/extra_deps_from_github.txt'
213- fi
212+ # stable mode installation
213+ if [[ " $MODE " == " stable" ]]; then
214+ install_maxtext_with_deps
214215
215- # Delete custom libtpu if it exists
216- if [ -e " $libtpu_path " ]; then
217- rm -v " $libtpu_path "
218- fi
219-
220- if [[ " $MODE " == " stable" || ! -v MODE ]]; then
221- # Stable mode
222216 if [[ $DEVICE == " tpu" ]]; then
223- echo " Installing stable jax, jaxlib for tpu"
224217 if [[ -n " $JAX_VERSION " ]]; then
225- echo " Installing stable jax, jaxlib, libtpu version ${JAX_VERSION} "
226- python3 -m uv pip install -U jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
218+ echo -e " \nInstalling stable jax, jaxlib, libtpu version ${JAX_VERSION} "
219+ version_mismatch_warning " jax"
220+ python3 -m uv pip install -U jax[tpu]==${JAX_VERSION}
227221 fi
228222 if [[ -n " $LIBTPU_GCS_PATH " ]]; then
229- # Install custom libtpu
230- echo " Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path "
231- # Install required dependency
232- python3 -m uv pip install -U crcmod
233- # Copy libtpu.so from GCS path
234- gsutil cp " $LIBTPU_GCS_PATH " " $libtpu_path "
223+ install_custom_libtpu
235224 fi
236225 elif [[ $DEVICE == " gpu" ]]; then
237- echo " Installing stable jax, jaxlib for NVIDIA gpu"
238226 if [[ -n " $JAX_VERSION " ]]; then
239- echo " Installing stable jax, jaxlib ${JAX_VERSION} "
227+ echo -e " \nInstalling stable jax, jaxlib ${JAX_VERSION} "
228+ version_mismatch_warning " jax"
240229 python3 -m uv pip install -U " jax[cuda12]==${JAX_VERSION} "
241- else
242- echo " Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
243- python3 -m uv pip install " jax[cuda12]"
244- fi
245- export NVTE_FRAMEWORK=jax
246- if [[ -n " $JAX_VERSION " && " $JAX_VERSION " != " 0.7.0" ]]; then
247- python3 -m uv pip install transformer-engine[jax]
248- else
249- python3 -m uv pip install git+https://github.com/NVIDIA/TransformerEngine.git@9d031f
250230 fi
251231 fi
252- elif [[ $MODE == " nightly " ]] ; then
253- # Nightly mode
232+ exit 0
233+ fi
254234
255- # Uninstall existing jax, jaxlib and libtpu-nightly
235+ # nightly mode installation
236+ if [[ $MODE == " nightly" ]]; then
237+ install_maxtext_with_deps
238+
239+ # Uninstall existing jax, jaxlib and libtpu
256240 python3 -m uv pip show jax && python3 -m uv pip uninstall jax
257241 python3 -m uv pip show jaxlib && python3 -m uv pip uninstall jaxlib
258- python3 -m uv pip show libtpu-nightly && python3 -m uv pip uninstall libtpu-nightly
242+ python3 -m uv pip show libtpu && python3 -m uv pip uninstall libtpu
259243
260- if [[ $DEVICE == " gpu" ]]; then
261- # Install jax-nightly
244+ if [[ $DEVICE == " tpu" ]]; then
262245 if [[ -n " $JAX_VERSION " ]]; then
263- echo " Installing jax-nightly, jaxlib-nightly ${JAX_VERSION} "
264- python3 -m uv pip install -U --pre jax==${JAX_VERSION} jaxlib==${JAX_VERSION} jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
246+ echo -e " \nInstalling jax-nightly, jaxlib-nightly ${JAX_VERSION} "
247+ python3 -m uv pip install -U --pre --no-deps jax==${JAX_VERSION} jaxlib==${JAX_VERSION} -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
265248 else
266- echo " Installing latest jax-nightly, jaxlib-nightly"
267- python3 -m uv pip install -U -- pre jax jaxlib jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
249+ echo -e " \nInstalling the latest jax-nightly, jaxlib-nightly"
250+ python3 -m uv pip install -- pre -U --no-deps jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
268251 fi
269- # Install Transformer Engine
270- export NVTE_FRAMEWORK=jax
271- python3 -m uv pip install https://github.com/NVIDIA/TransformerEngine/archive/9d031f.zip
272- elif [[ $DEVICE == " tpu" ]]; then
273- echo " Installing nightly tensorboard plugin profile"
274- python3 -m uv pip install tbp-nightly --upgrade
275- echo " Installing jax-nightly, jaxlib-nightly"
276- # Install jax-nightly
277- python3 -m uv pip install --pre -U jax -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
278- # Install jaxlib-nightly
279- python3 -m uv pip install --pre -U jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
252+
280253 if [[ -n " $LIBTPU_GCS_PATH " ]]; then
281- # Install custom libtpu
282- echo " Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path "
283- # Install required dependency
284- python3 -m uv pip install -U crcmod
285- # Copy libtpu.so from GCS path
286- gsutil cp " $LIBTPU_GCS_PATH " " $libtpu_path "
254+ install_custom_libtpu
287255 else
288- # Install libtpu-nightly
289- echo " Installing libtpu-nightly"
290- python3 -m uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
256+ echo -e " \nInstalling the latest libtpu-nightly"
257+ python3 -m uv pip install -U --pre --no-deps libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
258+ fi
259+ elif [[ $DEVICE == " gpu" ]]; then
260+ if [[ -n " $JAX_VERSION " ]]; then
261+ echo -e " \nInstalling jax-nightly, jaxlib-nightly ${JAX_VERSION} "
262+ python3 -m uv pip install -U --pre --no-deps jax==${JAX_VERSION} jaxlib==${JAX_VERSION} \
263+ jax-cuda12-plugin[with-cuda]==${JAX_VERSION} jax-cuda12-pjrt==${JAX_VERSION} -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
264+ else
265+ echo -e " \nInstalling the latest jax-nightly, jaxlib-nightly"
266+ python3 -m uv pip install -U --pre --no-deps jax jaxlib \
267+ jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
291268 fi
292269 fi
293- echo " Installing nightly tensorboard plugin profile"
294- python3 -m uv pip install tbp-nightly --upgrade
295- else
296- echo -e " \n\nError: You can only set MODE to [stable,nightly,libtpu-only].\n\n"
297- exit 1
270+ exit 0
298271fi
272+
273+ # libtpu-only mode installation
274+ if [[ " $MODE " == " libtpu-only" ]]; then
275+ if [[ " $DEVICE " != " tpu" ]]; then
276+ echo -e " \n\nError: MODE=libtpu-only is only supported for DEVICE=tpu.\n\n"
277+ exit 1
278+ fi
279+ if [[ -z " $LIBTPU_GCS_PATH " ]]; then
280+ echo -e " \n\nError: LIBTPU_GCS_PATH must be set when MODE is libtpu-only.\n\n"
281+ exit 1
282+ fi
283+ install_custom_libtpu
284+ exit 0
285+ fi
286+
287+ echo -e " \n\nError: MODE must be either 'stable', 'nightly', or 'libtpu-only'.\n\n"
288+ exit 1
0 commit comments