Skip to content

Commit dcc6a55

Browse files
committed
Fix Broken MaxText CI
1 parent 4ef205e commit dcc6a55

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242
from MaxText import sharding
4343
from MaxText.common_types import ShardMode
4444
from MaxText.globals import EPS
45-
# pylint: disable-next=unused-import
46-
from maxtext import maxtext_google
4745

4846
from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad
4947
from MaxText.vocabulary_tiling import vocab_tiling_linen_loss

tests/unit/llama4_layers_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727

2828
from MaxText.common_types import MODEL_MODE_TRAIN, AttentionType
2929
from MaxText import pyconfig
30-
from maxtext.layers import attentions, embeddings, llama4
30+
from maxtext.layers import attentions, embeddings
31+
from maxtext.models import llama4
3132
from maxtext.utils import maxtext_utils
3233
import numpy as np
3334
from tests.utils.test_helpers import get_test_config_path

tests/unit/pipeline_parallelism_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,16 @@
2626
import jax.numpy as jnp
2727
from jax.sharding import Mesh
2828
from MaxText import pyconfig
29-
from maxtext.common.gcloud_stub import is_decoupled
3029
from MaxText.common_types import MODEL_MODE_TRAIN
3130
from MaxText.globals import MAXTEXT_ASSETS_ROOT
31+
from maxtext.common.gcloud_stub import is_decoupled
3232
from maxtext.layers import nnx_wrappers
3333
from maxtext.layers import pipeline
34-
from maxtext.layers import simple_layer
34+
from maxtext.models import deepseek
35+
from maxtext.models import simple_layer
36+
from maxtext.utils import maxtext_utils
3537
from maxtext.trainers.pre_train.train import main as train_main
38+
from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory
3639
import pytest
3740

3841

0 commit comments

Comments
 (0)