Skip to content

Commit 88069a4

Browse files
Merge pull request #2833 from AI-Hypercomputer:nightly
PiperOrigin-RevId: 846834192
2 parents 820c39f + ef37ff4 commit 88069a4

1 file changed

Lines changed: 143 additions & 153 deletions

File tree

tools/setup/setup.sh

Lines changed: 143 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,46 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
# Description:
18-
# bash setup.sh MODE={stable,nightly,libtpu-only} LIBTPU_GCS_PATH={gcs_path_to_custom_libtpu} DEVICE={tpu,gpu}
17+
# ==================================
18+
# TPU EXAMPLES
19+
# ==================================
1920

21+
# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt
22+
## bash tools/setup/setup.sh MODE=stable
2023

21-
# You need to specify a MODE, default value stable.
22-
# You have the option to provide a LIBTPU_GCS_PATH that points to a libtpu.so provided to you by Google.
23-
# In libtpu-only MODE, the LIBTPU_GCS_PATH is mandatory.
24-
# For MODE=stable you may additionally specify JAX_VERSION, e.g. JAX_VERSION=0.4.13
25-
# For DEVICE=gpu, you may also specify JAX_VERSION when MODE=nightly, e.g. JAX_VERSION=0.4.36.dev20241109
24+
# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + specified jax, jaxlib, libtpu
25+
## bash tools/setup/setup.sh MODE=stable JAX_VERSION=0.8.0
26+
27+
# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + custom libtpu
28+
## bash tools/setup/setup.sh MODE=stable LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so
29+
30+
# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + jax-nightly, jaxlib-nightly, libtpu-nightly
31+
## bash tools/setup/setup.sh MODE=nightly
32+
33+
# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + specified jax-nightly, jaxlib-nightly + latest libtpu-nightly
34+
## bash tools/setup/setup.sh MODE=nightly JAX_VERSION=0.8.2.dev20251211
35+
36+
# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + jax-nightly, jaxlib-nightly + custom libtpu
37+
## bash tools/setup/setup.sh MODE=nightly LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so
38+
39+
# Install custom libtpu only
40+
## bash tools/setup/setup.sh MODE=libtpu-only LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so
41+
42+
# ==================================
43+
# GPU EXAMPLES
44+
# ==================================
45+
46+
# Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt
47+
## bash tools/setup/setup.sh MODE=stable DEVICE=gpu
48+
49+
# Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt + specified jax, jaxlib, jax-cuda12-plugin, jax-cuda12-pjrt
50+
## bash tools/setup/setup.sh MODE=stable DEVICE=gpu JAX_VERSION=0.4.13
51+
52+
# Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt + jax-nightly, jaxlib-nightly
53+
## bash tools/setup/setup.sh MODE=nightly DEVICE=gpu
54+
55+
# Install dependencies in dependencies/generated_requirements/cuda12-requirements.txt + specified jax, jaxlib, jax-cuda12-plugin, jax-cuda12-pjrt
56+
## bash tools/setup/setup.sh MODE=nightly DEVICE=gpu JAX_VERSION=0.4.36.dev20241109
2657

2758

2859
# Enable "exit immediately if any command fails" option
@@ -102,7 +133,6 @@ apt update -y && apt -y install gcsfuse
102133
rm -rf /var/lib/apt/lists/*
103134
EOF
104135

105-
106136
python3 -m pip install -U setuptools wheel uv
107137

108138
# Set environment variables
@@ -111,188 +141,148 @@ for ARGUMENT in "$@"; do
111141
export "$KEY"="$VALUE"
112142
done
113143

144+
# Set default value for $DEVICE
114145
if [[ -z "$DEVICE" ]]; then
115-
export DEVICE="tpu"
146+
export DEVICE=tpu
116147
fi
117148

118-
if [[ $JAX_VERSION == NONE ]]; then
119-
unset JAX_VERSION
149+
# Set default value for $MODE
150+
if [[ -z "$MODE" ]]; then
151+
export MODE=stable
120152
fi
121153

122-
if [[ $LIBTPU_GCS_PATH == NONE ]]; then
123-
unset LIBTPU_GCS_PATH
124-
fi
154+
# Unset optional variables if set to NONE
155+
unset_optional_vars() {
156+
local optional_vars=("JAX_VERSION" "LIBTPU_GCS_PATH")
157+
for var_name in "${optional_vars[@]}"; do
158+
if [[ ${!var_name} == NONE ]]; then
159+
unset "$var_name"
160+
fi
161+
done
162+
}
163+
unset_optional_vars
125164

126-
if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE || ($MODE == "nightly" && $DEVICE == "gpu")) ]]; then
127-
echo -e "\n\nError: You can only specify a JAX_VERSION with stable mode (plus nightly mode on GPU).\n\n"
128-
exit 1
129-
fi
165+
version_mismatch_warning() {
166+
echo -e "\n\nWARNING: You are installing a $1 version that is different from the one pinned by MaxText. This can lead to the following issues:"
167+
echo -e "1. Compatibility: The dependencies in the requirements file are tested and compatible with the pinned $1 version. We cannot guarantee that they will work correctly with a different $1 version."
168+
echo -e "2. Consistency: Installing a custom $1 version can pull in different transitive dependencies over time, making the environment non-reproducible and potentially affecting performance.\n\n"
169+
}
130170

131-
if [[ $DEVICE == "tpu" ]]; then
171+
install_custom_libtpu() {
132172
libtpu_path="$HOME/custom_libtpu/libtpu.so"
133-
if [[ "$MODE" == "libtpu-only" ]]; then
134-
# Only update custom libtpu.
135-
if [[ -n "$LIBTPU_GCS_PATH" ]]; then
136-
# Install custom libtpu
137-
echo "Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path"
138-
# Install required dependency
139-
python3 -m uv pip install -U crcmod
140-
# Copy libtpu.so from GCS path
141-
gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path"
142-
exit 0
143-
else
144-
echo -e "\n\nError: You must provide a custom libtpu for libtpu-only mode.\n\n"
145-
exit 1
146-
fi
173+
echo -e "\nInstalling libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path"
174+
version_mismatch_warning "libtpu"
175+
# Delete custom libtpu if it exists
176+
if [ -e "$libtpu_path" ]; then
177+
rm -v "$libtpu_path"
147178
fi
148-
fi
179+
# Install 'crcmod' to download 'libtpu.so' from GCS reliably
180+
python3 -m uv pip install -U crcmod
181+
# Copy libtpu.so from GCS path
182+
gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path"
183+
}
149184

150-
if [[ "$MODE" == "nightly" ]]; then
151-
if [ "$DEVICE" = "gpu" ]; then
152-
dep_name='dependencies/requirements/generated_requirements/cuda12-requirements.txt'
153-
else
154-
dep_name='dependencies/requirements/generated_requirements/'"${DEVICE?}"'-requirements.txt'
185+
install_maxtext_with_deps() {
186+
if [[ "$DEVICE" != "tpu" && "$DEVICE" != "gpu" ]]; then
187+
echo -e "\n\nError: DEVICE must be either 'tpu' or 'gpu'.\n\n"
188+
exit 1
155189
fi
156-
printf 'Nightly mode: Installing "%s", stripping commit pins from git+ repos.\n' "$dep_name"
157-
nightly_txt="${dep_name##*/}"
158-
nightly_txt="${nightly_txt%.txt}"'-nightly-temp.txt'
159-
160-
# Create a temp file, strip commit pins from git+ repos in requirements.txt
161-
# Remove/update this section based on the pinned github repo commit in requirements.txt
162-
sed -E \
163-
-e 's|^([^ ]*) @ https?://github.com/([^/]*\/[^/]*)/archive/.*\.zip$|\1@git+https://github.com/\2.git|' \
164-
-e '/JetStream/d' \
165-
-e '/mlperf-logging/d' \
166-
"$dep_name" > "$nightly_txt"
167-
168-
echo "--- Installing modified nightly requirements: ---"
169-
cat -- "$nightly_txt"
170-
echo "-------------------------------------------------"
171-
172-
python3 -m uv pip install --no-cache-dir -U -r "$nightly_txt" \
173-
-r 'src/install_maxtext_extra_deps/extra_deps_from_github.txt'
174-
rm -fv -- "$nightly_txt"
175-
else
176-
# stable or stable_stack mode: Install with pinned commits
190+
echo "Setting up MaxText in $MODE mode for $DEVICE device"
177191
if [ "$DEVICE" = "gpu" ]; then
178-
dep_basename='cuda12-requirements.txt'
192+
dep_name='dependencies/requirements/generated_requirements/cuda12-requirements.txt'
179193
else
180-
dep_basename="${DEVICE?}"'-requirements.txt'
194+
dep_name='dependencies/requirements/generated_requirements/tpu-requirements.txt'
181195
fi
196+
echo "Installing requirements from $dep_name"
197+
python3 -m uv pip install --resolution=lowest -r "$dep_name" \
198+
-r 'src/install_maxtext_extra_deps/extra_deps_from_github.txt'
182199

183-
printf 'Installing "%s" with pinned commits.\n' "$dep_basename"
184-
requirements_txt=
185-
for candidate in 'dependencies/requirements/generated_requirements' 'dependencies/requirements' "${MAXTEXT_REPO_ROOT}"'/dependencies/requirements' "$PWD"; do
186-
if [ -f "$candidate"'/'"$dep_basename" ]; then
187-
requirements_txt="$candidate"'/'"$dep_basename"
188-
break
189-
else
190-
searched="$searched"':'
191-
fi
192-
done
193-
if [ -z "${requirements_txt}" ]; then
194-
>&2 printf 'Could not find "%s", looked in: %s\n' "$dep_basename" "${searched%?}"
195-
exit 2
196-
else
197-
python3 -m uv pip install --resolution=lowest -r "$requirements_txt" \
198-
-r 'src/install_maxtext_extra_deps/extra_deps_from_github.txt'
200+
# The MaxText package is installed separately from its dependencies to optimize
201+
# docker image rebuild times by leveraging docker's layer caching.
202+
# Dependencies are installed in a separate step before MaxText code is
203+
# copied. This means that if MaxText code changes, but the
204+
# dependencies do not, docker can reuse the cached dependency layer, leading
205+
# to significantly faster image builds.
206+
if [ -f 'pyproject.toml' ]; then
207+
echo "Installing MaxText package without installing the dependencies (already installed)"
208+
python3 -m uv pip install --no-deps -e .
199209
fi
200-
fi
210+
}
201211

202-
# Install maxtext package
203-
if [ -f 'pyproject.toml' ]; then
204-
case "$DEVICE" in
205-
'gpu') python3 -m uv pip install -e .[cuda12] --no-deps --resolution=lowest ;;
206-
'tpu') python3 -m uv pip install -e .[tpu] --no-deps --resolution=lowest ;;
207-
*)
208-
>&2 printf 'Unsupported device\n'
209-
exit 6
210-
;;
211-
esac
212-
python3 -m uv pip install --resolution=lowest -r 'src/install_maxtext_extra_deps/extra_deps_from_github.txt'
213-
fi
212+
# stable mode installation
213+
if [[ "$MODE" == "stable" ]]; then
214+
install_maxtext_with_deps
214215

215-
# Delete custom libtpu if it exists
216-
if [ -e "$libtpu_path" ]; then
217-
rm -v "$libtpu_path"
218-
fi
219-
220-
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
221-
# Stable mode
222216
if [[ $DEVICE == "tpu" ]]; then
223-
echo "Installing stable jax, jaxlib for tpu"
224217
if [[ -n "$JAX_VERSION" ]]; then
225-
echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}"
226-
python3 -m uv pip install -U jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
218+
echo -e "\nInstalling stable jax, jaxlib, libtpu version ${JAX_VERSION}"
219+
version_mismatch_warning "jax"
220+
python3 -m uv pip install -U jax[tpu]==${JAX_VERSION}
227221
fi
228222
if [[ -n "$LIBTPU_GCS_PATH" ]]; then
229-
# Install custom libtpu
230-
echo "Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path"
231-
# Install required dependency
232-
python3 -m uv pip install -U crcmod
233-
# Copy libtpu.so from GCS path
234-
gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path"
223+
install_custom_libtpu
235224
fi
236225
elif [[ $DEVICE == "gpu" ]]; then
237-
echo "Installing stable jax, jaxlib for NVIDIA gpu"
238226
if [[ -n "$JAX_VERSION" ]]; then
239-
echo "Installing stable jax, jaxlib ${JAX_VERSION}"
227+
echo -e "\nInstalling stable jax, jaxlib ${JAX_VERSION}"
228+
version_mismatch_warning "jax"
240229
python3 -m uv pip install -U "jax[cuda12]==${JAX_VERSION}"
241-
else
242-
echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
243-
python3 -m uv pip install "jax[cuda12]"
244-
fi
245-
export NVTE_FRAMEWORK=jax
246-
if [[ -n "$JAX_VERSION" && "$JAX_VERSION" != "0.7.0" ]]; then
247-
python3 -m uv pip install transformer-engine[jax]
248-
else
249-
python3 -m uv pip install git+https://github.com/NVIDIA/TransformerEngine.git@9d031f
250230
fi
251231
fi
252-
elif [[ $MODE == "nightly" ]]; then
253-
# Nightly mode
232+
exit 0
233+
fi
254234

255-
# Uninstall existing jax, jaxlib and libtpu-nightly
235+
# nightly mode installation
236+
if [[ $MODE == "nightly" ]]; then
237+
install_maxtext_with_deps
238+
239+
# Uninstall existing jax, jaxlib and libtpu
256240
python3 -m uv pip show jax && python3 -m uv pip uninstall jax
257241
python3 -m uv pip show jaxlib && python3 -m uv pip uninstall jaxlib
258-
python3 -m uv pip show libtpu-nightly && python3 -m uv pip uninstall libtpu-nightly
242+
python3 -m uv pip show libtpu && python3 -m uv pip uninstall libtpu
259243

260-
if [[ $DEVICE == "gpu" ]]; then
261-
# Install jax-nightly
244+
if [[ $DEVICE == "tpu" ]]; then
262245
if [[ -n "$JAX_VERSION" ]]; then
263-
echo "Installing jax-nightly, jaxlib-nightly ${JAX_VERSION}"
264-
python3 -m uv pip install -U --pre jax==${JAX_VERSION} jaxlib==${JAX_VERSION} jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
246+
echo -e "\nInstalling jax-nightly, jaxlib-nightly ${JAX_VERSION}"
247+
python3 -m uv pip install -U --pre --no-deps jax==${JAX_VERSION} jaxlib==${JAX_VERSION} -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
265248
else
266-
echo "Installing latest jax-nightly, jaxlib-nightly"
267-
python3 -m uv pip install -U --pre jax jaxlib jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
249+
echo -e "\nInstalling the latest jax-nightly, jaxlib-nightly"
250+
python3 -m uv pip install --pre -U --no-deps jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
268251
fi
269-
# Install Transformer Engine
270-
export NVTE_FRAMEWORK=jax
271-
python3 -m uv pip install https://github.com/NVIDIA/TransformerEngine/archive/9d031f.zip
272-
elif [[ $DEVICE == "tpu" ]]; then
273-
echo "Installing nightly tensorboard plugin profile"
274-
python3 -m uv pip install tbp-nightly --upgrade
275-
echo "Installing jax-nightly, jaxlib-nightly"
276-
# Install jax-nightly
277-
python3 -m uv pip install --pre -U jax -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
278-
# Install jaxlib-nightly
279-
python3 -m uv pip install --pre -U jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
252+
280253
if [[ -n "$LIBTPU_GCS_PATH" ]]; then
281-
# Install custom libtpu
282-
echo "Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path"
283-
# Install required dependency
284-
python3 -m uv pip install -U crcmod
285-
# Copy libtpu.so from GCS path
286-
gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path"
254+
install_custom_libtpu
287255
else
288-
# Install libtpu-nightly
289-
echo "Installing libtpu-nightly"
290-
python3 -m uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
256+
echo -e "\nInstalling the latest libtpu-nightly"
257+
python3 -m uv pip install -U --pre --no-deps libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
258+
fi
259+
elif [[ $DEVICE == "gpu" ]]; then
260+
if [[ -n "$JAX_VERSION" ]]; then
261+
echo -e "\nInstalling jax-nightly, jaxlib-nightly ${JAX_VERSION}"
262+
python3 -m uv pip install -U --pre --no-deps jax==${JAX_VERSION} jaxlib==${JAX_VERSION} \
263+
jax-cuda12-plugin[with-cuda]==${JAX_VERSION} jax-cuda12-pjrt==${JAX_VERSION} -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
264+
else
265+
echo -e "\nInstalling the latest jax-nightly, jaxlib-nightly"
266+
python3 -m uv pip install -U --pre --no-deps jax jaxlib \
267+
jax-cuda12-plugin[with-cuda] jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
291268
fi
292269
fi
293-
echo "Installing nightly tensorboard plugin profile"
294-
python3 -m uv pip install tbp-nightly --upgrade
295-
else
296-
echo -e "\n\nError: You can only set MODE to [stable,nightly,libtpu-only].\n\n"
297-
exit 1
270+
exit 0
298271
fi
272+
273+
# libtpu-only mode installation
274+
if [[ "$MODE" == "libtpu-only" ]]; then
275+
if [[ "$DEVICE" != "tpu" ]]; then
276+
echo -e "\n\nError: MODE=libtpu-only is only supported for DEVICE=tpu.\n\n"
277+
exit 1
278+
fi
279+
if [[ -z "$LIBTPU_GCS_PATH" ]]; then
280+
echo -e "\n\nError: LIBTPU_GCS_PATH must be set when MODE is libtpu-only.\n\n"
281+
exit 1
282+
fi
283+
install_custom_libtpu
284+
exit 0
285+
fi
286+
287+
echo -e "\n\nError: MODE must be either 'stable', 'nightly', or 'libtpu-only'.\n\n"
288+
exit 1

0 commit comments

Comments
 (0)