Skip to content

Commit 1e13fdc

Browse files
committed
add init.py
1 parent a0a8092 commit 1e13fdc

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

src/maxtext/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and targeting Google Cloud
17+
TPUs and GPUs for training and inference. MaxText achieves high MFUs and scales from single host to very large clusters
18+
while staying simple and "optimization-free" thanks to the power of Jax and the XLA compiler.
19+
"""
20+
21+
__author__ = "Google LLC"
22+
__version__ = "0.2.0"
23+
__description__ = (
24+
"MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and "
25+
"targeting Google Cloud TPUs and GPUs for training and **inference."
26+
)
27+
28+
from collections.abc import Sequence
29+
30+
import os
31+
# In order to have any effect on the C++ logging this has to be set before we import anything from jax.
32+
# When jax is imported, its `__init__.py` calls `cloud_tpu_init()`, which also initializes the C++ logger.
33+
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
34+
del os
35+
36+
from jax.sharding import Mesh
37+
38+
from maxtext.configs import pyconfig
39+
from maxtext.models import models
40+
from maxtext.trainers.post_train.dpo import dpo_utils
41+
from maxtext.utils import maxtext_utils
42+
from maxtext.utils.model_creation_utils import *
43+
44+
Transformer = models.Transformer
45+
transformer_as_linen = models.transformer_as_linen

0 commit comments

Comments
 (0)