Skip to content

Commit 8266254

Browse files
author
maxtext authors
committed
Merge pull request #1668 from SamuelMarks:pylint-c
PiperOrigin-RevId: 756037012
2 parents 6247f44 + 683c221 commit 8266254

82 files changed

Lines changed: 705 additions & 379 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.github/workflows/CPUTests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ jobs:
3636
pytype --jobs auto --disable 'import-error,late-directive,wrong-arg-types,module-attr,unsupported-operands' MaxText/ || true
3737
- name: Analysing the code with pylint in Maxtext/
3838
run: |
39-
pylint --disable C0301,C3001,C0114,C0115,C0116,C0200,C0121,C0201,C0206,C0209,C0412,C0415,C2801,E0102,E0606,E1102,E1111,E1123,E1135,E1136,R0401,R1701,R1703,R1710,R1711,R1735,R0917,R1714,R1716,R1719,R1721,R1728,R1728,W0102,W0107,W0201,W0212,W0221,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
39+
pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable E0102,E0606,E0611,E1102,E1111,E1120,E1121,E1123,E1135,E1136,R0401,R1701,R1703,R1710,R1711,R1735,R0917,R1714,R1716,R1719,R1721,R1728,R1728,W0102,W0107,W0201,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
4040
echo 'Maxtext PyLint check successful' || { echo \
4141
'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
4242
- name: Analysing the code with pylint in pedagogical_examples/

MaxText/common_types.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,8 @@
7171

7272

7373
class DecoderBlockType(enum.Enum):
74+
"""Decoder block types."""
75+
7476
DEFAULT = "default"
7577
LLAMA2 = "llama2"
7678
MISTRAL = "mistral"

MaxText/convert_gpt3_ckpt_from_paxml.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@
1111
limitations under the License.
1212
"""
1313

14-
from MaxText.globals import PKG_DIR
15-
1614
# pylint: disable=line-too-long
1715
"""Convert weights from a paxml gpt3 model to a MaxText one.
1816
@@ -34,27 +32,31 @@
3432
--run-name=$RUN_NAME \
3533
--base-output-directory=$BASE_OUTPUT_DIR
3634
"""
37-
from MaxText import max_utils
38-
from MaxText import maxtext_utils
39-
from MaxText import optimizers
40-
from MaxText import pyconfig
35+
36+
import argparse
37+
import gc
4138
import os
39+
import sys
40+
41+
from psutil import Process
42+
43+
import numpy as np
44+
45+
import jax
4246
from jax import random
4347
from jax.sharding import Mesh
44-
from MaxText.layers.models import Transformer
45-
from MaxText.layers import quantizations
46-
from MaxText import checkpointing
4748

48-
import numpy as np
4949
import tensorstore as ts
5050

51-
import sys
52-
import jax
53-
import gc
51+
from MaxText import checkpointing
5452
from MaxText import max_logging
55-
from psutil import Process
53+
from MaxText import maxtext_utils
54+
from MaxText import optimizers
55+
from MaxText import pyconfig
56+
from MaxText.globals import PKG_DIR
57+
from MaxText.layers import quantizations
58+
from MaxText.layers.models import Transformer
5659
from MaxText.train import save_checkpoint
57-
import argparse
5860

5961

6062
def fmt_size(num_bytes: int) -> str:

MaxText/decode.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""CLI utility for running inference on a single/multi stream(s)"""
15+
"""CLI utility for running inference on a single/multi stream(s)."""
1616

1717
import os
1818
from typing import Sequence
@@ -65,7 +65,7 @@ def _pad_to_batch_size(data: jax.Array, batch_size: int):
6565
)
6666

6767
def _all_equals(elements: Sequence[jax.Array], target: jax.Array):
68-
"""Checks if each element equals the given target"""
68+
"""Checks if each element equals the given target."""
6969
stacked = jnp.stack(elements)
7070
row_comparisons = stacked == target
7171
return jnp.all(row_comparisons)

MaxText/deepseek_fp8_to_bf16.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128)
7070
return y
7171

7272

73-
def convert_fp8_to_bf16(fp8_path: string, bf16_path: string, cache_file_num: int = 2):
73+
def convert_fp8_to_bf16(fp8_path: str, bf16_path: str, cache_file_num: int = 2):
7474
"""
7575
Converts a FP8 model to a BF16 model and saves the converted weights.
7676

MaxText/elastic_train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,7 @@ def train_loop(config, elastic_manager, state=None):
187187
Args:
188188
config:
189189
state:
190-
ckpt_path:
190+
elastic_manager:
191191
Returns:
192192
"""
193193
# Create a GoodputRecorder to log information

MaxText/inference/kvcache.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,16 @@
1616

1717
from typing import Any, Optional, Tuple
1818

19+
import jax
20+
import jax.numpy as jnp
21+
22+
from flax import linen as nn
23+
1924
from aqt.jax.v2 import aqt_tensor
2025
from aqt.jax.v2 import config as aqt_config
2126
from aqt.jax.v2.flax import aqt_flax
27+
2228
from MaxText import common_types
23-
from flax import linen as nn
24-
import jax
25-
import jax.numpy as jnp
2629

2730
Array = common_types.Array
2831
AxisNames = common_types.AxisNames
@@ -104,6 +107,7 @@ def einsum_fn_with_rhs_qtensor(
104107
lhs_dequant_mode=None,
105108
lhs_calibration_mode=None,
106109
):
110+
"""einsum function where QTensor is the right-hand-side"""
107111
# Assumes kv is already quantized.
108112
einsum = jnp.einsum
109113
if isinstance(kv, aqt_tensor.QTensor):
@@ -141,6 +145,7 @@ def einsum_fn_with_rhs_qtensor(
141145
return einsum
142146

143147
def einsum_fn_with_rhs_qtensor_and_dequant(self, value):
148+
"""Get einstein summation for different dequant modes."""
144149
if self.dtype == jnp.float8_e4m3fn:
145150
return self.einsum_fn_with_rhs_qtensor(
146151
value,
@@ -184,6 +189,7 @@ def _get_cache_scale_logical_shape(self, batch, heads, cache_length):
184189
raise f"Invalid config for kv_quant_axis:{self.kv_quant.axis_cfg}"
185190

186191
def _get_prefill_cache_vars(self, batch, key_heads, value_heads, key_head_size, value_head_size, model_mode):
192+
"""Get a shaped abstraction of the state"""
187193

188194
cache_length = self.max_prefill_length
189195
dtype = self._get_cached_kv_dtype()
@@ -258,6 +264,7 @@ def _get_prefill_cache_vars(self, batch, key_heads, value_heads, key_head_size,
258264
return key_vars, value_vars, cached_segment_id_var
259265

260266
def _get_ar_cache_vars(self, batch, key_heads, value_heads, key_head_size, value_head_size, model_mode):
267+
"""get ar cache vars"""
261268

262269
dtype = self._get_cached_kv_dtype()
263270
if self.max_target_length <= self.max_prefill_length:
@@ -602,6 +609,7 @@ def value_body(i, val):
602609
)
603610

604611
def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.Array | KVTensor:
612+
"""get cached values"""
605613
cache_var, cache_scale_var = cache_vars
606614
cache_value = cache_var.value
607615
if cache_scale_var is not None:

MaxText/inference/page_manager.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,13 @@
2323
"""
2424

2525
from functools import partial
26+
from typing import Tuple
27+
2628
import jax
2729
import jax.numpy as jnp
30+
2831
from flax import struct
29-
from typing import Tuple
32+
3033
from jaxtyping import Array, Integer, Bool
3134

3235
from MaxText import common_types
@@ -477,7 +480,8 @@ def _validate_init_params(self) -> None:
477480
min_required = (self.max_target_length + self.tokens_per_page - 1) // self.tokens_per_page
478481
if self.max_pages_per_group < min_required:
479482
raise ValueError(
480-
f"`pagedattn_max_pages_per_group` ({self.max_pages_per_group}) is insufficient for `max_target_length` ({self.max_target_length}). Needs {min_required}."
483+
f"`pagedattn_max_pages_per_group` ({self.max_pages_per_group}) is insufficient for `max_target_length` "
484+
f"({self.max_target_length}). Needs {min_required}."
481485
)
482486
# Check > 1 due to potential page 0 workaround
483487
if self.num_pages <= 1:

MaxText/inference/paged_attention.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,8 @@
5252

5353

5454
class PagedAttentionOp(nn.Module):
55+
"""paged-attention op"""
56+
5557
mesh: Mesh
5658
num_pages: int
5759
tokens_per_page: int
@@ -101,6 +103,7 @@ def init_or_get_kv_pages(self, model_mode: str):
101103
return key_pages_var, value_pages_var
102104

103105
def paged_dot_product_attention_with_max_and_sum(self, query, key, value):
106+
"""paged dot product attention with max & sum"""
104107
b, t, n, d = query.shape
105108
_, s, n_kv, _ = key.shape
106109
query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
@@ -307,6 +310,7 @@ def update_prefill_step_pages(
307310
value_pages_var.value = nn.with_logical_constraint(value, self.kv_pages_axis_names)
308311

309312
def update_decode_step_pages(self, key_pages_var, value_pages_var, key, value, page_state):
313+
"""Update decode-step pages"""
310314
key_pages = key_pages_var.value
311315
value_pages = value_pages_var.value
312316

MaxText/inference/paged_attention_kernel_v2.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ def ref_ragged_paged_attention(
7979
sm_scale: float = 1.0,
8080
mask_value: float = DEFAULT_MASK_VALUE,
8181
):
82+
"""Ref ragged paged attention."""
8283
_, _, num_kv_heads, head_dim = k_pages.shape
8384
num_q_heads = queries.shape[1]
8485
assert num_q_heads % num_kv_heads == 0
@@ -117,6 +118,7 @@ def validate_inputs_on_runtime(
117118
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
118119
num_seqs, # i32[1]
119120
):
121+
"""validate inputs on runtime"""
120122
check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
121123
max_num_batched_tokens = q.shape[0]
122124
page_size = k_pages.shape[1]
@@ -148,6 +150,7 @@ def check_inputs_shapes(
148150
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
149151
num_seqs, # i32[1]
150152
):
153+
"""check shapes of inputs"""
151154
_, num_q_heads, head_dim = q.shape
152155
_, _, num_kv_heads, head_dim_k = k_pages.shape
153156
max_num_seqs, _ = page_indices.shape
@@ -199,6 +202,7 @@ def ragged_paged_attention_kernel(
199202
sm_scale: float,
200203
mask_value: float,
201204
):
205+
"""ragged paged-attention kernel"""
202206
num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
203207
num_seqs = num_seqs_ref[0]
204208
_, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape
@@ -522,6 +526,7 @@ def get_dtype_packing(dtype):
522526

523527

524528
def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype):
529+
"""get min heads per block"""
525530
q_packing = get_dtype_packing(q_dtype)
526531
kv_packing = get_dtype_packing(kv_dtype)
527532

0 commit comments

Comments
 (0)