diff --git a/MaxCode/rag/rag_agent.py b/MaxCode/rag/rag_agent.py index 45d651e..0288dbf 100644 --- a/MaxCode/rag/rag_agent.py +++ b/MaxCode/rag/rag_agent.py @@ -1,8 +1,10 @@ """Tool for performing retrieval augmented generation.""" +import ast +import logging import os import sqlite3 -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import models from agents import base @@ -11,6 +13,8 @@ from rag import vector_db import numpy as np +logger = logging.getLogger(__name__) + # We use a hardcoded character limit for the full code context to avoid # exceeding the model's token limit. While the Gemini API does not provide a @@ -19,6 +23,74 @@ # when considering that the prompt sends file content in two fields. _MAX_CONTEXT_LENGTH = 100_000 +# Corpus tags supported by the RAG layer. Files whose basename starts with +# "maxtext_" are tagged "maxtext"; everything else falls back to "jax". +_KNOWN_CORPORA = ("jax", "maxtext") + + +def _query_prefix_for_target(target: str) -> str: + """Returns the human-readable prefix used in component-signature queries.""" + if target == "maxtext": + return "MaxText" + return "JAX Flax" + + +def _corpus_for_filename(filename: str) -> str: + """Auto-tags a file by its basename. `maxtext_*.py` -> 'maxtext', else 'jax'.""" + base = os.path.basename(filename) + if base.startswith("maxtext_"): + return "maxtext" + return "jax" + + +def _extract_component_signatures(code: str, target: str = "jax") -> list[str]: + """Extracts focused query strings per top-level class/function using AST. + + For classes: "{prefix} {ClassName} {base_classes} {method_names} {init_params}" + For functions: "{prefix} {func_name} {param_names}" + + The prefix is "JAX Flax" for `target="jax"` and "MaxText" for + `target="maxtext"`. + + Args: + code: Python source code to parse. + target: Conversion target — selects the query prefix. + + Returns: + A list of query strings, one per top-level component. 
+ """ + try: + tree = ast.parse(code) + except SyntaxError: + return [] + + prefix = _query_prefix_for_target(target) + signatures = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + bases = [ + ast.unparse(b) if hasattr(ast, "unparse") else getattr(b, "id", "") + for b in node.bases + ] + methods = [ + n.name for n in ast.iter_child_nodes(node) + if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + init_params = [] + for n in ast.iter_child_nodes(node): + if isinstance(n, ast.FunctionDef) and n.name == "__init__": + init_params = [ + a.arg for a in n.args.args if a.arg != "self" + ] + break + parts = [prefix, node.name] + bases + methods + init_params + signatures.append(" ".join(parts)) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + params = [a.arg for a in node.args.args if a.arg != "self"] + parts = [prefix, node.name] + params + signatures.append(" ".join(parts)) + return signatures + class RAGAgent(base.Agent): """Tool for performing retrieval augmented generation.""" @@ -29,6 +101,7 @@ def __init__( embedding_model_name: models.EmbeddingModel, db_path: str = vector_db.RAG_DB_FILE, api_key: str | None = None, + target: str = "jax", ): """Initializes the agent. @@ -37,32 +110,59 @@ def __init__( embedding_model_name: Name of the embedding model to use. db_path: Path to the RAG SQLite database. api_key: The API key for Google AI services. + target: Conversion target ("jax" or "maxtext"). Selects which corpus + the agent retrieves from. 
""" super().__init__(model=model) self._db_path = db_path + self._target = target self._embedding_agent = embedding.EmbeddingAgent( model_name=embedding_model_name.value, api_key=api_key ) vector_db.create_db(db_path) - ( - self._ids, - self._names, - self._texts, - self._files, - self._index, - ) = vector_db.make_embedding_index(db_path) + self._index_by_corpus: Dict[str, Dict[str, Any]] = {} + self._refresh_indexes() + + def _refresh_indexes(self) -> None: + """Rebuilds per-corpus indexes from the database.""" + self._index_by_corpus = {} + for corpus in _KNOWN_CORPORA: + ids, names, texts, files, index = vector_db.make_embedding_index( + self._db_path, corpus=corpus + ) + self._index_by_corpus[corpus] = { + "ids": ids, + "names": names, + "texts": texts, + "files": files, + "index": index, + } + + def _active_corpus(self) -> Dict[str, Any]: + """Returns the corpus payload selected by `self._target`, with JAX fallback.""" + if self._target in self._index_by_corpus: + return self._index_by_corpus[self._target] + return self._index_by_corpus.get("jax", { + "ids": [], "names": [], "texts": [], "files": [], "index": None, + }) def build_from_directory(self, source_path: str): - """Builds RAG database from files in a source directory.""" + """Builds RAG database from files in a source directory. + + Each file is auto-tagged with a corpus based on its basename — files + starting with `maxtext_` are stored in the 'maxtext' corpus, all others + in 'jax'. Retrieval at query time is filtered by the agent's `target`. 
+ """ for root, _, files in os.walk(source_path): for filename in files: if filename.endswith(".py"): file_path = os.path.join(root, filename) + corpus_tag = _corpus_for_filename(filename) try: with open(file_path, "r", encoding="utf-8", errors="replace") as f: content = f.read() doc_name = os.path.relpath(file_path, source_path) - print(f"Adding {doc_name} to RAG database...") + print(f"Adding {doc_name} to RAG database (corpus={corpus_tag})...") description = self.generate( prompts.CODE_DESCRIPTION, { @@ -78,13 +178,12 @@ def build_from_directory(self, source_path: str): file=file_path, embedding=np.array(embedding_vector), db_path=self._db_path, + corpus=corpus_tag, ) except (OSError, sqlite3.Error) as e: print(f"Skipping {file_path}: {e}") - # Refresh index - self._ids, self._names, self._texts, self._files, self._index = ( - vector_db.make_embedding_index(self._db_path) - ) + # Refresh per-corpus indexes + self._refresh_indexes() print("Finished building RAG database.") def retrieve_context( @@ -92,6 +191,8 @@ def retrieve_context( ) -> List[Dict[str, Any]]: """Retrieves relevant context from the vector DB based on the query. + Filters to the corpus selected by the agent's `target`. + Args: query: The query string to search for. top_k: The number of top results to return. @@ -100,22 +201,83 @@ def retrieve_context( A list of dictionaries, each containing 'name', 'text', 'file', and 'distance' for a retrieved document. 
""" - if self._index is None: + payload = self._active_corpus() + index = payload.get("index") + if index is None: return [] query_embedding = self._embedding_agent.embed(query) results = vector_db.search_embedding( - np.array(query_embedding), self._index, self._texts, top_k=top_k + np.array(query_embedding), index, payload["texts"], top_k=top_k ) + names = payload["names"] + files = payload["files"] retrieved_context = [] for text, distance, i in results: retrieved_context.append({ - "name": self._names[i], + "name": names[i], "text": text, - "file": self._files[i], + "file": files[i], "distance": distance, }) return retrieved_context + def retrieve_per_component_context( + self, + source_code: str, + top_k_per_component: int = 3, + max_total: int = 15, + ) -> List[Dict[str, Any]]: + """Retrieves RAG context using a hybrid full-file + per-component strategy. + + Combines broad domain context from the full source code query with + targeted results from per-component queries. This ensures the LLM gets + both the overall architectural patterns AND component-specific examples. + + Args: + source_code: The full Python source code to retrieve context for. + top_k_per_component: Number of results per component query. + max_total: Maximum total results to return after deduplication. + + Returns: + A deduplicated, distance-sorted list of retrieved documents. 
+ """ + signatures = _extract_component_signatures(source_code, self._target) + + # Fall back to single-query if AST parsing yielded nothing + if not signatures: + logger.info("Per-component extraction failed, falling back to single query") + return self.retrieve_context(source_code, top_k=max_total) + + # Start with full-file query for broad domain context + best_by_file: Dict[str, Dict[str, Any]] = {} + full_results = self.retrieve_context(source_code, top_k=max_total) + for doc in full_results: + best_by_file[doc["file"]] = doc + + # If >12 components, batch into groups of 3-4 to cap embedding calls + if len(signatures) > 12: + batched = [] + for i in range(0, len(signatures), 4): + batched.append(" ".join(signatures[i:i + 4])) + queries = batched + else: + queries = signatures + + logger.info("Per-component RAG: %d queries from %d components (+ full-file)", + len(queries), len(signatures)) + + # Add per-component results, keeping best distance per file + for query in queries: + results = self.retrieve_context(query, top_k=top_k_per_component) + for doc in results: + fpath = doc["file"] + if fpath not in best_by_file or doc["distance"] < best_by_file[fpath]["distance"]: + best_by_file[fpath] = doc + + # Sort by distance, truncate to max_total + sorted_docs = sorted(best_by_file.values(), key=lambda d: d["distance"]) + return sorted_docs[:max_total] + def run(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]: """Runs RAG to retrieve context for a query.""" return self.retrieve_context(query, top_k) diff --git a/MaxCode/rag/sources/generic/docs_flax_basics.py b/MaxCode/rag/sources/generic/docs_flax_basics.py new file mode 100644 index 0000000..648ca0e --- /dev/null +++ b/MaxCode/rag/sources/generic/docs_flax_basics.py @@ -0,0 +1,125 @@ +# Flax Linen Documentation: Flax Basics +# Source: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html +""" +Flax Basics: Complete Reference Documentation + +Core Workflow Components 
+======================== + +1. Model Instantiation and Parameter Initialization +---------------------------------------------------- +Flax uses nn.Module base class for all models. Parameters are NOT stored with models +themselves but rather initialized separately through the init() method using a PRNG key +and dummy input data. + +Key concept: The dummy input data triggers shape inference - you only declare the number +of features wanted in the output, and Flax automatically determines kernel dimensions +from input specifications alone. + +Parameters are returned as a pytree structure matching the model's architecture. + + import flax.linen as nn + import jax + import jax.numpy as jnp + + model = nn.Dense(features=5) + key = jax.random.PRNGKey(0) + params = model.init(key, jnp.ones((1, 3))) # shape inference from dummy input + +2. Forward Passes with apply() +------------------------------ +Models cannot be called directly. Use apply() with parameters: + + output = model.apply(params, x) + +3. Training with Gradient Descent +--------------------------------- +- Define loss function with jax.vmap() for vectorization +- Compute gradients using jax.value_and_grad() +- Update parameters iteratively with learning rate scaling + +4. 
Optimization with Optax +-------------------------- + import optax + tx = optax.adam(learning_rate=1e-3) + opt_state = tx.init(params) + grads = jax.grad(loss_fn)(params, x, y) + updates, opt_state = tx.update(grads, opt_state) + params = optax.apply_updates(params, updates) + +Defining Custom Models +====================== + +Module Basics +------------- +Custom models extend nn.Module (a Python dataclass) with: +- Data fields for configuration +- setup() method for submodule registration +- __call__() method for forward computation + +Explicit approach (using setup): + + class ExplicitMLP(nn.Module): + features: Sequence[int] + + def setup(self): + self.layers = [nn.Dense(feat) for feat in self.features] + + def __call__(self, inputs): + x = inputs + for i, layer in enumerate(self.layers[:-1]): + x = nn.relu(layer(x)) + x = self.layers[-1](x) + return x + +Compact approach (using @nn.compact): + + class SimpleMLP(nn.Module): + features: Sequence[int] + + @nn.compact + def __call__(self, inputs): + x = inputs + for i, feat in enumerate(self.features[:-1]): + x = nn.relu(nn.Dense(feat, name=f'layers_{i}')(x)) + x = nn.Dense(self.features[-1], name=f'layers_{len(self.features)-1}')(x) + return x + +Parameter Declaration +--------------------- +Custom parameters use self.param() within modules: + + kernel = self.param('kernel', + self.kernel_init, + (inputs.shape[-1], self.features)) + +Arguments: +- Name for parameter identification in pytree +- Initialization function with signature (PRNGKey, *args, **kwargs) +- Shape and dtype arguments passed to init function + +Variables and State Management +------------------------------ +Beyond parameters, modules can maintain mutable state through variables: + +Pattern: self.variable(collection_name, variable_name, init_fn, *args) + +Usage example - batch normalization with running mean: +- Detect initialization via self.has_variable() +- Create tracked variables with self.variable() +- Update during apply() with 
mutable=['collection_name'] +- Extract and update state between training steps + +State update pattern: + + y, updated_state = model.apply(variables, x, mutable=['batch_stats']) + variables = flax.core.freeze({'params': params, **updated_state}) + +This separates mutable state from frozen parameters for explicit control during training. + +Serialization +------------- +- serialization.to_bytes() - convert parameters to byte representation +- serialization.to_state_dict() - convert to dictionary format +- serialization.from_bytes() - restore from bytes using a template structure +""" diff --git a/MaxCode/rag/sources/generic/docs_flax_layers_api.py b/MaxCode/rag/sources/generic/docs_flax_layers_api.py new file mode 100644 index 0000000..a18b0bc --- /dev/null +++ b/MaxCode/rag/sources/generic/docs_flax_layers_api.py @@ -0,0 +1,157 @@ +# Flax Linen Layers API Reference +# Source: https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html +""" +Flax Linen Layers API Reference +================================ + +Linear Modules +-------------- + +Dense(features, use_bias=True, dtype=None, param_dtype=float32, + kernel_init=variance_scaling, bias_init=zeros) + + A linear transformation applied over the last dimension of the input. + + layer = nn.Dense(features=4) + params = layer.init(jax.random.key(0), jnp.ones((1, 3))) + output = layer.apply(params, x) # x: [..., in_features] -> [..., 4] + +DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, + kernel_init=variance_scaling, bias_init=zeros) + + A linear transformation with flexible axes. Can contract over multiple axes. 
+ + # Contract over axes 1 and -1, output features (4, 5) + layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1)) + params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7))) + +Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, + kernel_dilation=1, feature_group_count=1, use_bias=True, dtype=None) + + Convolution layer wrapping lax.conv_general_dilated. + + # 1D convolution + layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID') + out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) + + # Causal 1D convolution (pad left only) + layer = nn.Conv(features=4, kernel_size=(3,), padding=((2, 0),)) + +Embedding Module +----------------- + +Embed(num_embeddings, features, dtype=None, param_dtype=float32, + embedding_init=variance_scaling) + + A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. + + layer = nn.Embed(num_embeddings=50000, features=768) + variables = layer.init(jax.random.key(0), jnp.array([[0, 1, 2]])) + embeddings = layer.apply(variables, input_ids) # [batch, seq_len, features] + + # attend() method for output projection (weight tying): + logits = layer.attend(hidden_states) # [batch, seq_len, num_embeddings] + # Note: For exact PyTorch weight-tying equivalence, prefer explicit matmul: x @ embed.embedding.T + +Normalization Layers +--------------------- + +LayerNorm(epsilon=1e-6, dtype=None, use_bias=True, use_scale=True, + reduction_axes=-1, feature_axes=-1) + + Layer normalization. Normalizes over the last axis by default. + + norm = nn.LayerNorm() + variables = norm.init(jax.random.key(0), x) + y = norm.apply(variables, x) + +RMSNorm(epsilon=1e-6, dtype=None, use_scale=True, scale_init=ones, + reduction_axes=-1, feature_axes=-1) + + RMS Layer normalization. Normalizes by root mean square without re-centering. + More efficient than LayerNorm as it skips the mean computation. 
+ + norm = nn.RMSNorm() + variables = norm.init(jax.random.key(0), x) + y = norm.apply(variables, x) + + # Custom implementation pattern (common in LLMs): + class CustomRMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + + @nn.compact + def __call__(self, x): + weight = self.param('weight', nn.initializers.ones, (self.dim,)) + variance = jnp.mean(x ** 2, axis=-1, keepdims=True) + x = x * jax.lax.rsqrt(variance + self.eps) + return weight * x + +GroupNorm(num_groups=32, epsilon=1e-6, use_bias=True, use_scale=True) + + Group normalization. Statistics shared across equally-sized groups of channels. + +Attention Modules +------------------ + +MultiHeadDotProductAttention(num_heads, dtype=None, qkv_features=None, + out_features=None, dropout_rate=0.0, deterministic=None, + kernel_init=variance_scaling, use_bias=True, + attention_fn=dot_product_attention, decode=False, normalize_qk=False) + + Multi-head dot-product attention mechanism. + + layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=64) + + # Self-attention + variables = layer.init(jax.random.key(0), x) + out = layer.apply(variables, x) + + # Cross-attention + out = layer.apply(variables, query, key, value) + + # With causal mask + mask = nn.make_causal_mask(jnp.ones((batch, seq_len))) + out = layer.apply(variables, x, mask=mask, deterministic=True) + + # Autoregressive decoding with KV cache + layer = nn.MultiHeadDotProductAttention(num_heads=8, decode=True) + variables = layer.init(jax.random.key(0), x) + # variables['cache'] contains cached keys and values + # Note: For PyTorch->JAX migrations, prefer pre-allocated NamedTuple buffers + # over Flax's decode=True mutable cache (see targeted_kvcache_prefill_decode_jax.py) + + Key parameters: + - decode=True: enables autoregressive KV caching + - normalize_qk=True: applies QK normalization + - deterministic=True: disables dropout + +Mask Utilities +--------------- + +make_causal_mask(x, extra_batch_dims=0, dtype=bool) + Creates a causal attention 
mask from input shape. + + mask = nn.make_causal_mask(jnp.ones((1, seq_len))) + # Returns [1, 1, seq_len, seq_len] boolean mask + +make_attention_mask(query_input, key_input, pairwise_fn=jnp.multiply, + extra_batch_dims=0, dtype=bool) + Creates an attention mask from query and key padding masks. + + query_mask = jnp.array([1, 1, 1, 0]) # 1=valid, 0=padded + key_mask = jnp.array([1, 1, 0, 0]) + mask = nn.make_attention_mask(query_mask, key_mask) + +Activation Functions +--------------------- +nn.relu, nn.gelu, nn.silu (swish), nn.softmax, nn.tanh, nn.sigmoid, nn.elu + + x = nn.silu(x) # SiLU/Swish activation, common in modern LLMs + x = nn.gelu(x, approximate=False) + +Pooling Functions +------------------ +nn.max_pool(inputs, window_shape, strides=None, padding='VALID') +nn.avg_pool(inputs, window_shape, strides=None, padding='VALID') +""" diff --git a/MaxCode/rag/sources/generic/docs_flax_module_api.py b/MaxCode/rag/sources/generic/docs_flax_module_api.py new file mode 100644 index 0000000..213efad --- /dev/null +++ b/MaxCode/rag/sources/generic/docs_flax_module_api.py @@ -0,0 +1,180 @@ +# Flax Linen Module API Reference +# Source: https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html +""" +Complete Flax Linen Module API Reference +========================================= + +flax.linen.Module is the foundational base class for all neural network modules in Flax. +All Flax Modules are Python 3.7 dataclasses and should override setup() rather than __init__. + +Setup vs Compact Patterns +-------------------------- + +Setup Pattern:: + + class MyModule(nn.Module): + features: Tuple[int, ...] 
= (16, 4) + + def setup(self): + self.dense1 = nn.Dense(self.features[0]) + self.dense2 = nn.Dense(self.features[1]) + + def __call__(self, x): + return self.dense2(nn.relu(self.dense1(x))) + +Compact Pattern:: + + class MyModule(nn.Module): + features: int = 16 + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.features)(x) + x = nn.relu(x) + return nn.Dense(4)(x) + +Initialization Methods +----------------------- + +init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs) + Initializes module variables. A single PRNGKey is treated as {'params': key}. + For multiple RNG streams, pass a dict: {'params': key1, 'dropout': key2}. + + model = MyModule() + variables = model.init(jax.random.key(0), dummy_input) + +init_with_output(rngs, *args, ...) + Returns both the output and variables as a tuple: (output, vars). + +lazy_init(rngs, *args, ...) + Initializes variables without computing on actual data. + Accepts jax.ShapeDtypeStruct for memory-efficient initialization. + +Execution Methods +------------------ + +apply(variables, *args, rngs=None, method=None, mutable=False, **kwargs) + Applies a module method to variables and returns output. + If mutable collections specified, returns (output, updated_state). + + output = model.apply(variables, x) + output, state = model.apply(variables, x, mutable=['batch_stats']) + +bind(variables, *args, rngs=None, mutable=False) + Creates an interactive Module instance. Useful for debugging. + +Variable Management +-------------------- + +param(name, init_fn, *init_args, unbox=True, **init_kwargs) + Declares read-only parameters in the "params" collection. + init_fn receives PRNG key automatically as first argument. 
+ + # Inside @nn.compact or setup(): + kernel = self.param('kernel', nn.initializers.lecun_normal(), (in_feat, out_feat)) + bias = self.param('bias', nn.initializers.zeros, (out_feat,)) + +variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs) + Declares mutable or immutable variables in named collections. + Unlike param(), PRNG keys must be passed explicitly. + + # For KV cache or running statistics: + cache_key = self.variable('cache', 'cached_key', jnp.zeros, (max_len, head_dim)) + cache_key.value = updated_value # update during forward pass + +get_variable(col, name, default=None) + Retrieves variable values from specified collections. + +put_variable(col, name, value) + Updates mutable variable values. + +has_variable(col, name) + Checks variable existence. Useful for conditional initialization. + + is_initialized = self.has_variable('cache', 'cached_key') + +RNG Management +--------------- + +make_rng(name='params') + Returns a new PRNG key from a named RNG sequence. + Each call splits the previous key for new values. + + dropout_key = self.make_rng('dropout') + +Inspection Methods +------------------- + +is_initializing() + Returns True when running under module.init() or nn.init()(). + + if self.is_initializing(): + # Do initialization-specific logic + cache = jnp.zeros((max_len, features)) + +is_mutable_collection(col) + Checks if a variable collection is mutable during current execution. + +path (property) + Returns the module's path as a tuple. + +Intermediate Value Capture +--------------------------- + +sow(col, name, value, reduce_fn=, init_fn=) + Stores intermediate values without explicit container passing. 
+ + self.sow('intermediates', 'attention_weights', attn_weights) + # Later: y, state = model.apply(variables, x, mutable=['intermediates']) + +Complete Training Pattern +-------------------------- + +:: + + class Transformer(nn.Module): + config: TransformerConfig + + @nn.compact + def __call__(self, x, train=False): + x = nn.Dense(self.config.hidden_size)(x) + x = nn.Dropout(rate=0.1, deterministic=not train)(x) + x = nn.LayerNorm()(x) + return nn.Dense(self.config.vocab_size)(x) + + model = Transformer(config=config) + variables = model.init({'params': key1, 'dropout': key2}, dummy_input) + + # Training step + def train_step(variables, batch, dropout_rng): + def loss_fn(params): + logits = model.apply( + {'params': params}, + batch['input'], + train=True, + rngs={'dropout': dropout_rng} + ) + return cross_entropy_loss(logits, batch['labels']) + + grads = jax.grad(loss_fn)(variables['params']) + return grads + +Multiple RNG Streams +--------------------- + +:: + + class NoisyModel(nn.Module): + @nn.compact + def __call__(self, x, add_noise=False): + x = nn.Dense(16)(x) + if add_noise: + noise_key = self.make_rng('noise') + x = x + jax.random.normal(noise_key, x.shape) + return nn.Dense(1)(x) + + model = NoisyModel() + rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)} + variables = model.init(rngs, x) + out = model.apply(variables, x, add_noise=True, rngs=rngs) +""" diff --git a/MaxCode/rag/sources/generic/docs_flax_setup_vs_compact.py b/MaxCode/rag/sources/generic/docs_flax_setup_vs_compact.py new file mode 100644 index 0000000..edaf2d0 --- /dev/null +++ b/MaxCode/rag/sources/generic/docs_flax_setup_vs_compact.py @@ -0,0 +1,66 @@ +# Flax Linen Documentation: setup vs nn.compact +# Source: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/setup_or_nncompact.html +""" +Flax Linen: setup vs compact Documentation + +Overview +-------- +Flax's module system provides two distinct approaches for defining submodules and variables: + 
+Explicit Definition (setup): Variables and submodules are assigned to self. within a +setup() method, mirroring PyTorch's conventional pattern. Forward pass logic is then +implemented in separate methods. + +Inline Definition (nn.compact): Network architecture is written directly within a single +method marked with the @nn.compact decorator, collocating component definitions with +their usage points. + +Both methods are functionally equivalent and fully interoperable throughout Flax. + +Code Examples +------------- + +Setup Approach:: + + class MLP(nn.Module): + def setup(self): + self.dense1 = nn.Dense(32) + self.dense2 = nn.Dense(32) + + def __call__(self, x): + x = self.dense1(x) + x = nn.relu(x) + x = self.dense2(x) + return x + +Compact Approach:: + + class MLP(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(32, name="dense1")(x) + x = nn.relu(x) + x = nn.Dense(32, name="dense2")(x) + return x + +When to Choose Each Approach +---------------------------- + +Prefer nn.compact when: +- Reducing navigation between variable definitions and usage sites +- Handling conditional logic or loops that affect variable creation +- Aligning code structure with mathematical notation +- Implementing shape inference dependent on input dimensions + +Prefer setup when: +- Maintaining PyTorch compatibility conventions +- Preferring explicit separation between definitions and application +- Requiring multiple distinct forward pass methods + +Key patterns for nn.compact: +- Submodules are instantiated inline: nn.Dense(features, name="layer_name")(x) +- Parameters declared via self.param('name', init_fn, shape) +- Variables declared via self.variable('collection', 'name', init_fn) +- Only one method per module can use @nn.compact +- Auto-naming: if no name= is provided, Flax assigns Dense_0, Dense_1, etc. 
+""" diff --git a/MaxCode/rag/sources/generic/docs_jax_gotchas.py b/MaxCode/rag/sources/generic/docs_jax_gotchas.py new file mode 100644 index 0000000..cbe30a1 --- /dev/null +++ b/MaxCode/rag/sources/generic/docs_jax_gotchas.py @@ -0,0 +1,133 @@ +# JAX Common Gotchas and Patterns +# Source: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html +""" +JAX Sharp Bits: Common Gotchas and Patterns +============================================= + +Pure Functions +-------------- +JAX transforms and compilation work exclusively on functionally pure Python functions. +A pure function must satisfy: +- All input data enters through function parameters +- All results exit through function returns +- Invoking with identical inputs always produces identical outputs + +Side effects (print, global state, iterators) only execute on first JIT call: + + # BAD: print only runs on first call + @jit + def f(x): + print("called") # only prints once! + return x + 1 + + # BAD: global variable captured at trace time + g = 0. + @jit + def f(x): + return x + g # uses g=0 forever, even if g changes later + + # BAD: iterators have state + iterator = iter(range(10)) + jax.lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0) # WRONG + +Immutable Arrays and .at[] Updates +------------------------------------ +JAX arrays are immutable. Direct index assignment fails: + + jax_array[1, :] = 1.0 # TypeError! + +Use functional .at API instead: + + updated = jax_array.at[1, :].set(1.0) # set values + updated = jax_array.at[1, :].add(1.0) # add to values + updated = jax_array.at[1, :].mul(2.0) # multiply values + updated = jax_array.at[::2, 3:].add(7.) # slice indexing + +IMPORTANT: Inside JIT, the compiler optimizes .at[] to in-place when input isn't reused. +IMPORTANT: Slice sizes in JIT must be static (can't depend on array values). 
+ +Random Numbers +-------------- +JAX uses explicit key-based state management (no global RNG state): + + key = jax.random.key(0) + key, subkey = jax.random.split(key) + x = jax.random.normal(subkey, (5, 5)) + + # Split for multiple independent uses + key, *subkeys = jax.random.split(key, num=4) + +Never reuse the same key for different random operations. + +Control Flow in JIT +-------------------- +Python if/else and for loops are traced once. Use JAX primitives for dynamic control: + + # Instead of: if x > 0: ... + result = jax.lax.cond(x > 0, true_fn, false_fn, x) + + # Instead of: for i in range(n): ... + result = jax.lax.fori_loop(0, n, body_fn, init_val) + + # For sequential state + accumulation: + final_carry, outputs = jax.lax.scan(step_fn, init_carry, xs) + + # For parallel prefix operations: + result = jax.lax.associative_scan(binary_fn, elems) + + # Dynamic while loop: + result = jax.lax.while_loop(cond_fn, body_fn, init_val) + +Static vs Dynamic Shapes +-------------------------- +All output and intermediate arrays must have static shape in JIT: + + # BAD: shape depends on values + x_filtered = x[~jnp.isnan(x)] # dynamic shape! + + # GOOD: use where to maintain static shape + x_clean = jnp.where(~jnp.isnan(x), x, 0) + +Out-of-Bounds Indexing +----------------------- +JAX can't raise errors from accelerators. Instead: +- Retrieval: indices clamped to bounds (returns last element) +- Updates: out-of-bounds ops silently skipped + + jnp.arange(10)[11] # Returns 9, not error + jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan) # Returns nan + +Double Precision (64-bit) +-------------------------- +JAX defaults to float32. Enable float64 explicitly: + + jax.config.update("jax_enable_x64", True) # must run at startup + # Or: JAX_ENABLE_X64=True python script.py + +PyTree Patterns +---------------- +JAX operates on pytrees - nested structures of arrays. 
Common patterns: + + # Pytrees can be dicts, lists, tuples, NamedTuples, dataclasses + params = {'dense': {'kernel': w, 'bias': b}} + + # tree_map applies a function to all leaves + doubled = jax.tree_util.tree_map(lambda x: 2 * x, params) + + # Custom pytrees via register_pytree_node + from jax import tree_util + tree_util.register_pytree_node( + MyClass, + lambda obj: ((obj.dynamic_field,), {'static': obj.static_field}), + lambda aux, children: MyClass(*children, **aux) + ) + +Key Differences from NumPy +---------------------------- +- Arrays are immutable (use .at[] for updates) +- No in-place operations (+=, *= create new arrays) +- Explicit PRNG key management (no global state) +- Type promotion rules differ +- No dynamic shapes in JIT +- Out-of-bounds indexing clamps instead of raising +""" diff --git a/MaxCode/rag/sources/generic/docs_jax_lax_primitives.py b/MaxCode/rag/sources/generic/docs_jax_lax_primitives.py new file mode 100644 index 0000000..1f948e1 --- /dev/null +++ b/MaxCode/rag/sources/generic/docs_jax_lax_primitives.py @@ -0,0 +1,155 @@ +# JAX LAX Primitive Functions Documentation +# Source: https://docs.jax.dev/en/latest/jax.lax.html +""" +JAX LAX Primitive Functions +=========================== + +jax.lax.scan +------------- +Signature: scan(f, init, xs=None, length=None, reverse=False, unroll=1) + +Scan a function over leading array axes while carrying along state. +This enables sequential operations with accumulated results, similar to +a fold operation in functional programming. 
+ +Parameters: +- f: Function taking (carry, x) and returning (new_carry, y) +- init: Initial carry value +- xs: Input sequence (optional, stacked along axis 0) +- length: Iteration count (optional, inferred from xs) +- reverse: Process in reverse order +- unroll: Loop unrolling factor + +Returns: (final_carry, stacked_ys) + +Example:: + + def cumsum(carry, x): + new_carry = carry + x + return new_carry, new_carry + + final, history = jax.lax.scan(cumsum, 0, jnp.array([1, 2, 3, 4])) + # final = 10, history = [1, 3, 6, 10] + +Use for recurrent computations, RNN cells, sequential state updates. +Inside nn.compact, use nn.scan to lift scan over Flax modules. + +jax.lax.associative_scan +-------------------------- +Signature: associative_scan(fn, elems, reverse=False, axis=0) + +Performs a scan with an associative binary operation, in parallel. +Unlike sequential scan, this exploits associativity for O(log n) depth. + +Parameters: +- fn: Binary associative function f(a, b) where f(f(a,b), c) == f(a, f(b,c)) +- elems: Array elements to process +- reverse: Reverse processing direction +- axis: Dimension along which to scan + +Example:: + + # Parallel prefix sum + result = jax.lax.associative_scan(jnp.add, jnp.array([1, 2, 3, 4])) + # result = [1, 3, 6, 10] + +jax.lax.dynamic_update_slice +------------------------------ +Signature: dynamic_update_slice(operand, update, start_indices) + +Wraps XLA's DynamicUpdateSlice operator. Updates a slice at dynamically +determined indices within a larger array. Useful for KV-cache updates. 
+ +Example:: + + arr = jnp.zeros((5, 3)) + update = jnp.ones((2, 3)) + result = jax.lax.dynamic_update_slice(arr, update, (1, 0)) + # Updates rows 1-2 with ones + +Common pattern for KV cache:: + + cache = jax.lax.dynamic_update_slice( + cache, # existing cache [max_len, features] + new_kv[None], # new entry [1, features] + (cache_index, 0) # write position + ) + +jax.lax.dynamic_slice +----------------------- +Signature: dynamic_slice(operand, start_indices, slice_sizes) + +Wraps XLA's DynamicSlice operator. Extracts array slices using +runtime-determined start positions. + +Parameters: +- operand: Source array +- start_indices: Runtime start positions (one per dimension) +- slice_sizes: Static slice sizes (must be constants) + +Example:: + + arr = jnp.arange(10) + result = jax.lax.dynamic_slice(arr, (3,), (4,)) + # result = [3, 4, 5, 6] + +jax.lax.conv_general_dilated +------------------------------ +Signature: conv_general_dilated(lhs, rhs, window_strides, padding, + lhs_dilation=None, rhs_dilation=None, + dimension_numbers=None, precision=None) + +General n-dimensional convolution operator with optional dilation. + +Parameters: +- lhs: Input array +- rhs: Kernel weights +- window_strides: Stride configuration +- padding: 'SAME', 'VALID', or explicit padding pairs +- dimension_numbers: Tuple of (lhs_spec, rhs_spec, out_spec) strings + +Example for 1D causal convolution:: + + # Input: [batch, length, channels] -> need ('NHC', 'HIO', 'NHC') + out = jax.lax.conv_general_dilated( + x, kernel, + window_strides=(1,), + padding=((kernel_size - 1, 0),), # causal: pad left only + dimension_numbers=('NHC', 'HIO', 'NHC') + ) + +jax.lax.cond +-------------- +Signature: cond(pred, true_fun, false_fun, *operands) + +Conditionally apply true_fun or false_fun based on a boolean predicate. +Both branches are traced; use instead of Python if/else in JIT code. 
+ +Example:: + + result = jax.lax.cond( + x > 0, + lambda x: x + 1, # true branch + lambda x: x - 1, # false branch + x + ) + +jax.lax.fori_loop +------------------- +Signature: fori_loop(lower, upper, body_fun, init_val) + +Loop from lower to upper by reduction to jax.lax.while_loop(). +Implements bounded iteration with state accumulation. + +Parameters: +- lower: Loop start index +- upper: Loop end index (exclusive) +- body_fun: Function(i, carry) -> new_carry +- init_val: Initial carry state + +Example:: + + def body(i, carry): + return carry + i + result = jax.lax.fori_loop(0, 10, body, 0) # 45 +""" diff --git a/MaxCode/rag/sources/generic/fla_layers_gated_deltanet.py b/MaxCode/rag/sources/generic/fla_layers_gated_deltanet.py new file mode 100644 index 0000000..967724b --- /dev/null +++ b/MaxCode/rag/sources/generic/fla_layers_gated_deltanet.py @@ -0,0 +1,316 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from torch.nn import functional as F + +from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +@torch.compile +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +@torch.compile +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +class GatedDeltaNet(nn.Module): + """ + The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa + + Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. 
Parameter allocation when use_gate=True: + - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each + - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each + - Others are ignorably small. + - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size + NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. + + Parameter allocation when use_gate=False: + - 1 * hidden_size * hidden_size for the q_proj and k_proj each + - 2 * hidden_size * hidden_size for the v_proj and o_proj each + - Others are ignorably small. + - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size + + Args: + hidden_size (int, Optional): + The hidden size of the input. Default: 2048. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + head_dim (int, Optional): + The dimension of each head. Default: 256. + num_heads (int, Optional): + The number of heads. Default: 6.
+ conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the normalization layer. Default: 1e-5. + """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + num_v_heads: int = None, + mode: str = 'chunk', + use_gate: bool = True, + use_short_conv: bool = True, + allow_neg_eigval: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + norm_eps: float = 1e-5, + **kwargs, + ) -> GatedDeltaNet: + super().__init__() + + self.mode = mode + self.allow_neg_eigval = allow_neg_eigval + self.hidden_size = hidden_size + self.expand_v = expand_v + + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.head_dim = head_dim + self.num_heads = num_heads + self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads + + self.head_k_dim = head_dim + self.head_v_dim = int(self.head_dim * self.expand_v) + self.key_dim = int(self.num_heads * self.head_k_dim) + self.value_dim = int(self.num_v_heads * self.head_v_dim) + self.layer_idx = layer_idx + + # Consistency check: Ensure expand_v produces integer values + if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5): + raise ValueError( + f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. 
" + f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear.", + ) + if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0: + raise ValueError( + f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.", + ) + + if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5): + raise ValueError( + f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. " + f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated.", + ) + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) + self.b_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) + + A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min), + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + bias=conv_bias, + activation='silu', + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + bias=conv_bias, + activation='silu', + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + bias=conv_bias, + activation='silu', + ) + else: + warnings.warn( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing.", + ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps, dtype=torch.float32) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + output_attentions: bool | None = False, + **kwargs: Unpack[dict], + ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + # change to inference mode. + mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." 
+ + last_state = get_layer_cache(self, past_key_values) + + cu_seqlens = kwargs.get('cu_seqlens') + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + q = F.silu(self.q_proj(hidden_states)) + k = F.silu(self.k_proj(hidden_states)) + v = F.silu(self.v_proj(hidden_states)) + + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + + if self.num_v_heads > self.num_heads: + q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads), (q, k)) + + beta = self.b_proj(hidden_states).sigmoid() + if self.allow_neg_eigval: + beta = beta * 2. 
+ + g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + update_layer_cache( + self, + past_key_values, + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + offset=q_len, + ) + + if self.use_gate: + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... 
h d', d=self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/MaxCode/rag/sources/generic/fla_models_gated_deltanet.py b/MaxCode/rag/sources/generic/fla_models_gated_deltanet.py new file mode 100644 index 0000000..a4823d4 --- /dev/null +++ b/MaxCode/rag/sources/generic/fla_models_gated_deltanet.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Optional + +import torch +import torch.nn as nn +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.gated_deltanet import GatedDeltaNet +from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig +from fla.models.utils import Cache, FLAGenerationMixin +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm +from fla.modules import GatedMLP as GatedDeltaNetMLP +from fla.modules.l2warp import l2_warp + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +try: + from transformers.modeling_layers import GradientCheckpointingLayer +except ImportError: + from fla.models.modeling_layers import GradientCheckpointingLayer + +logger = logging.get_logger(__name__) + + +class GatedDeltaNetBlock(GradientCheckpointingLayer): + + def __init__(self, config: GatedDeltaNetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in 
config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx, + ) + else: + self.attn = GatedDeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + num_v_heads=config.num_v_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + allow_neg_eigval=config.allow_neg_eigval, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + layer_idx=layer_idx, + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = GatedDeltaNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + use_cache: bool | None = False, + output_attentions: bool | None = False, + **kwargs: Unpack[dict], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + 
hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GatedDeltaNetPreTrainedModel(PreTrainedModel): + + config_class = GatedDeltaNetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['GatedDeltaNetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: str | None = None, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, GatedDeltaNet) and next(module.parameters()).device.type != 'meta': + with torch.no_grad(): + if not getattr(module.A_log, '_is_hf_initialized', False): + module.A_log.copy_(nn.init.uniform_(module.A_log, a=0, b=16).log()) + module.A_log._no_weight_decay = True + if not getattr(module.dt_bias, '_is_hf_initialized', False): + dt = torch.exp( + nn.init.uniform_(module.dt_bias) * (math.log(0.1) - math.log(0.001)) + math.log(0.001), + ).clamp(min=1e-4) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_weight_decay = True + + elif isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts 
for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel): + + def __init__(self, config: GatedDeltaNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + 
inputs_embeds: torch.FloatTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs: Unpack[dict], + ) -> tuple | BaseModelOutputWithPast: + if output_attentions: + warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden 
states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + ) + + +class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, FLAGenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GatedDeltaNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. 
" + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies", + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | None = 0, + **kwargs: Unpack[dict], + ) -> tuple | CausalLMOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + hidden_states = outputs[0] + + loss, logits = None, None + if not self.config.fuse_linear_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if self.config.fuse_linear_cross_entropy: + criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp) + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = 
self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if self.config.fuse_linear_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + loss = l2_warp(loss, logits) if self.config.use_l2warp else loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/MaxCode/rag/sources/generic/fla_modules_l2norm.py b/MaxCode/rag/sources/generic/fla_modules_l2norm.py new file mode 100644 index 0000000..06f4a45 --- /dev/null +++ b/MaxCode/rag/sources/generic/fla_modules_l2norm.py @@ -0,0 +1,282 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from fla.utils import IS_AMD, autotune_cache_kwargs, input_guard + +BT_LIST = [8, 16, 32, 64, 128] +NUM_WARPS_AUTOTUNE = [1, 2, 4, 8, 16] if IS_AMD else [1, 2, 4, 8, 16, 32] + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE], + key=["D"], + **autotune_cache_kwargs, +) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + rstd, + eps, + D, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_rstd = 1 / tl.sqrt(tl.sum(b_x * b_x) + eps) + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + tl.store(rstd + i_t, b_rstd) + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE], + key=["D"], + 
**autotune_cache_kwargs, +) +@triton.jit +def l2norm_bwd_kernel1( + y, + rstd, + dy, + dx, + eps, + D, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + y += i_t * D + dx += i_t * D + dy += i_t * D + + cols = tl.arange(0, BD) + mask = cols < D + b_y = tl.load(y + cols, mask=mask, other=0.0).to(tl.float32) + b_rstd = tl.load(rstd + i_t).to(tl.float32) + b_dy = tl.load(dy + cols, mask=mask, other=0.0).to(tl.float32) + b_dx = b_dy * b_rstd - tl.sum(b_dy * b_y) * b_y * b_rstd + tl.store(dx + cols, b_dx, mask=mask) + + +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D", "NB"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def l2norm_fwd_kernel( + x, + y, + rstd, + eps, + T, + D: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + BT: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,)) + + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_rstd = 1 / tl.sqrt(tl.sum(b_x * b_x, 1) + eps) + b_y = b_x * b_rstd[:, None] + + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,)) + + +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D", "NB"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def l2norm_bwd_kernel( + y, + rstd, + dy, + dx, + eps, + T, + D: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + BT: tl.constexpr, +): + i_t = tl.program_id(0) + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_dy = tl.make_block_ptr(dy, (T, D), 
(D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + + b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32) + b_rstd = tl.load(p_rstd, boundary_check=(0,)).to(tl.float32) + b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32) + b_dx = b_dy * b_rstd[:, None] - tl.sum(b_dy * b_y, 1)[:, None] * b_y * b_rstd[:, None] + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + + +def l2norm_fwd( + x: torch.Tensor, + eps: float = 1e-6, + output_dtype: torch.dtype | None = None, +): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + if D <= 512: + # NOTE(tylerr): Avoid excessive recompilation and autotuning by tolerating a larger range + # of T before recompiling the kernel. 
+ # NB = triton.cdiv(T, 2048) + NB = triton.cdiv(T, 2048 * 32) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_fwd_kernel[grid]( + x=x, + y=y, + rstd=rstd, + eps=eps, + T=T, + D=D, + BD=BD, + NB=NB, + ) + else: + l2norm_fwd_kernel1[(T,)]( + x=x, + y=y, + rstd=rstd, + eps=eps, + D=D, + BD=BD, + ) + return y.view(x_shape_og), rstd.view(x_shape_og[:-1]) + + +def l2norm_bwd( + y: torch.Tensor, + rstd: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-6, +): + y_shape_og = y.shape + y = y.view(-1, dy.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + assert dy.shape == y.shape + # allocate output + dx = torch.empty_like(y) + T, D = y.shape[0], y.shape[-1] + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // y.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if D <= 512: + # NOTE(tylerr): Avoid excessive recompilation and autotuning by tolerating a larger range + # of T before recompiling the kernel. 
+ # NB = triton.cdiv(T, 2048) + NB = triton.cdiv(T, 2048 * 32) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_bwd_kernel[grid]( + y=y, + rstd=rstd, + dy=dy, + dx=dx, + eps=eps, + T=T, + D=D, + BD=BD, + NB=NB, + ) + else: + l2norm_bwd_kernel1[(T,)]( + y=y, + rstd=rstd, + dy=dy, + dx=dx, + eps=eps, + D=D, + BD=BD, + ) + + return dx.view(y_shape_og) + + +class L2NormFunction(torch.autograd.Function): + @staticmethod + @input_guard + def forward( + ctx, + x, + eps=1e-6, + output_dtype=None, + ): + y, rstd = l2norm_fwd(x, eps, output_dtype) + ctx.eps = eps + ctx.x_dtype = x.dtype + ctx.save_for_backward(y, rstd) + return y + + @staticmethod + @input_guard + def backward(ctx, dy): + y, rstd = ctx.saved_tensors + dx = l2norm_bwd(y, rstd, dy, ctx.eps) + return dx, None, None + + +def l2norm( + x: torch.Tensor, + eps: float = 1e-6, + output_dtype: torch.dtype | None = None, +) -> torch.Tensor: + return L2NormFunction.apply(x, eps, output_dtype) + + +l2_norm = l2norm + + +class L2Norm(nn.Module): + def __init__( + self, + eps: float = 1e-6, + output_dtype: torch.dtype | None = None, + ): + super().__init__() + self.eps = eps + self.output_dtype = output_dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return l2norm(x, self.eps, self.output_dtype) diff --git a/MaxCode/rag/sources/generic/fla_modules_layernorm_gated.py b/MaxCode/rag/sources/generic/fla_modules_layernorm_gated.py new file mode 100644 index 0000000..7702653 --- /dev/null +++ b/MaxCode/rag/sources/generic/fla_modules_layernorm_gated.py @@ -0,0 +1,527 @@ +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. 
Llama 70B), so this is fine. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + +from fla.utils import get_multiprocessor_count, input_guard + + +def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True): + dtype = x.dtype + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, +}) +@triton.jit +def layer_norm_fwd_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def layer_norm_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor = None, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = 
torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, + "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None, +}) +@triton.jit +def layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DZ, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_z_row, + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dz_row, + stride_dw_row, + stride_db_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: 
tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + group = tl.program_id(1) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + group * N + if HAS_Z: + Z += row_start * stride_z_row + group * N + DZ += row_start * stride_dz_row + group * N + DY += row_start * stride_dy_row + group * N + DX += row_start * stride_dx_row + group * N + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: + B += group * N + b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + x_og = x + x = x_og * z * tl.sigmoid(z) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.) 
+ if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + y = xhat * w + b if HAS_BIAS else xhat * w + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y * z * z_sigmoid, mask=mask) + dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dy *= z * z_sigmoid + else: + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + c1 = tl.sum(xhat * wdy, axis=0) / N + if not IS_RMS_NORM: + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + dx = (wdy - xhat * c1) * rstd + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_Z and not NORM_BEFORE_GATE: + z_sigmoid = tl.sigmoid(z) + dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dx *= z * z_sigmoid + # Write dx + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_Z: + Z += stride_z_row + DZ += stride_dz_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) + + +def layer_norm_bwd( + dy: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + mean: torch.Tensor, + rstd: torch.Tensor, + z: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, + recompute_output: bool = False, + dz: torch.Tensor = None, + out: torch.Tensor = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + 
assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = torch.empty_like(x) + if dz is not None: + assert z is not None + assert dz.shape == z.shape + assert dz.stride(-1) == 1 + else: + dz = torch.empty_like(z) if z is not None else None + if recompute_output: + if out is None: + out = torch.empty_like(x) + assert out.shape == x.shape + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + sm_count = get_multiprocessor_count(x.device.index) + # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs + # would limit the occupancy. + nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) + _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / nrow_groups) + grid = (nrow_groups, ngroups) + layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + z, + out if recompute_output else None, + dy, + dx, + _dw, + _db, + dz, + mean, + rstd, + x.stride(0), + z.stride(0) if z is not None else 0, + 0 if not recompute_output else out.stride(0), + dy.stride(0), + dx.stride(0), + dz.stride(0) if dz is not None else 0, + _dw.stride(0), + _db.stride(0) if _db is not None else 0, + M, group_size, eps, + rows_per_program, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) + + +class LayerNormFn(torch.autograd.Function): + + 
@input_guard + @staticmethod + def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + @input_guard + @staticmethod + def backward(ctx, dy): + x, weight, bias, mean, rstd, z = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + dx, dw, db, dz = layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + z, + ctx.group_size, + ctx.norm_before_gate, + ctx.is_rms_norm, + ) + dx = dx.reshape(ctx.x_shape_og) + dz = dz.reshape(ctx.x_shape_og) if dz is not None else None + return dx, dw, db, dz, None, None, None, None + + +def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) + + +class LayerNormGated(nn.Module): + + def __init__( + self, + hidden_size, + 
eps: float = 1e-5, + group_size: int | None = None, + norm_before_gate: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps, + norm_before_gate=self.norm_before_gate) + + +class RMSNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: int | None = None, + norm_before_gate: bool = False, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size, + norm_before_gate=self.norm_before_gate) diff --git a/MaxCode/rag/sources/generic/fla_modules_rotary.py b/MaxCode/rag/sources/generic/fla_modules_rotary.py new file mode 100644 index 0000000..6f43be7 --- /dev/null +++ b/MaxCode/rag/sources/generic/fla_modules_rotary.py @@ -0,0 +1,511 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import torch.nn as nn +import triton +import triton.language as tl +from einops import rearrange, repeat + +from fla.ops.utils import prepare_chunk_indices +from fla.utils import IS_AMD, autotune_cache_kwargs, get_multiprocessor_count, input_guard + +NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if IS_AMD else [2, 4, 8, 16, 32] + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2) + + +def rotary_embedding_ref(x, cos, sin, interleaved=False): + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)') + sin = repeat(sin, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 
1 (d 2)') + return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], -1) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS_AUTOTUNE + for num_stages in [2, 3, 4] + ], + key=['B', 'H', 'D', 'INTERLEAVED'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def rotary_embedding_kernel( + x, + cos, + sin, + y, + cu_seqlens, + chunk_indices, + seq_offsets, + T, + B: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + R: tl.constexpr, + TR: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, +): + i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1) + T = eos - bos + x = x + bos * H*D + i_h * D + y = y + bos * H*D + i_h * D + else: + i_n = i_b + x = x + i_n * T*H*D + i_h * D + y = y + i_n * T*H*D + i_h * D + + if i_t * BT >= T: + return + + o_t = i_t * BT + tl.arange(0, BT) + if not IS_SEQLEN_OFFSETS_TENSOR: + o_cs = o_t + seq_offsets + else: + o_cs = o_t + tl.load(seq_offsets + i_n) + m_t = (o_t >= 0) & (o_t < T) & (o_cs >= 0) & (o_cs < TR) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of x, do calculation, then store to 1st and 2nd halves of out + o_r = tl.arange(0, BD // 2) + p_x = x + o_t[:, None] * H*D + o_r[None, :] + p_cos = cos + (o_cs[:, None] * R + o_r[None, :]) + p_sin = sin + (o_cs[:, None] * R + o_r[None, :]) + mask = m_t[:, None] & (o_r < R)[None, :] + + b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32) + b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32) + b_x0 = tl.load(p_x, mask=mask, other=0.0).to(tl.float32) + b_x1 = tl.load(p_x + R, 
mask=mask, other=0.0).to(tl.float32) + if CONJUGATE: + b_sin = -b_sin + b_o0 = b_x0 * b_cos - b_x1 * b_sin + b_o1 = b_x0 * b_sin + b_x1 * b_cos + # write back result + p_y = y + (o_t[:, None] * H*D + o_r[None, :]) + tl.store(p_y, b_o0, mask=mask) + tl.store(p_y + R, b_o1, mask=mask) + else: + # We don't want to load x[0, 2, 4, ...] and x[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x[0, 1, 2, 3, ...] and x1 = x[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = cos[0, 0, 1, 1, ...] and sin = sin[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + o_d = tl.arange(0, BD) + o_d_swap = o_d + ((o_d + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + o_d_repeat = tl.arange(0, BD) // 2 + p_x0 = x + o_t[:, None] * H*D + o_d[None, :] + p_x1 = x + o_t[:, None] * H*D + o_d_swap[None, :] + p_cos = cos + (o_cs[:, None] * R + o_d_repeat[None, :]) + p_sin = sin + (o_cs[:, None] * R + o_d_repeat[None, :]) + mask = m_t[:, None] & (o_d_repeat < R)[None, :] + + b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32) + b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32) + b_x0 = tl.load(p_x0, mask=mask, other=0.0).to(tl.float32) + b_x1 = tl.load(p_x1, mask=mask, other=0.0).to(tl.float32) + if CONJUGATE: + b_sin = -b_sin + b_o0 = b_x0 * b_cos + b_o1 = b_x1 * b_sin + b_y = tl.where(o_d[None, :] % 2 == 0, b_o0 - b_o1, b_o0 + b_o1) + p_y = y + (o_t[:, None] * H*D + o_d[None, :]) + tl.store(p_y, b_y, mask=mask) + + +def rotary_embedding_fwdbwd( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: int | torch.Tensor = 0, + cu_seqlens: torch.Tensor | None = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, + chunk_indices: torch.LongTensor | None = None, +) -> torch.Tensor: + """ + Args: + x: [B, T, H, D]. 
+ cos: [TR, R / 2] + sin: [TR, R / 2] + seqlen_offsets: integer or integer tensor of size [N] + cu_seqlens: [N + 1,] or None + + Returns: + y: [B, T, H, D] + """ + is_varlen = cu_seqlens is not None + + B, T, H, D = x.shape + N = B if not is_varlen else cu_seqlens.shape[0] - 1 + TR, R = cos.shape + R2 = R * 2 + + assert D <= 256, "Only support D <= 256" + assert TR >= T, f"TR must be >= T, got {TR} and {T}" + + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (N,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + else: + assert seqlen_offsets + T <= TR + + y = torch.empty_like(x) if not inplace else x + if R2 < D and not inplace: + y[..., R2:].copy_(x[..., R2:]) + + BD = triton.next_power_of_2(R2) + BT = min(128, triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index)))) + if chunk_indices is None and is_varlen: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = len(chunk_indices) if is_varlen else triton.cdiv(T, BT) + + grid = (NT, B, H) + rotary_embedding_kernel[grid]( + x, + cos, + sin, + y, + cu_seqlens, + chunk_indices, + seqlen_offsets, + B=B, + T=T, + H=H, + D=D, + R=R, + TR=TR, + BT=BT, + BD=BD, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + ) + return y + + +class RotaryEmbeddingFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: int | torch.Tensor = 0, + cu_seqlens: torch.Tensor | None = None, + chunk_indices: torch.LongTensor | None = None, + ): + y = rotary_embedding_fwdbwd( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + 
interleaved=interleaved, + inplace=inplace, + chunk_indices=chunk_indices, + ) + if isinstance(seqlen_offsets, int): + # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.chunk_indices = chunk_indices + return y if not inplace else x + + @staticmethod + @input_guard + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = rotary_embedding_fwdbwd( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + chunk_indices=ctx.chunk_indices, + ) + return dx, None, None, None, None, None, None, None + + +def rotary_embedding( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: int | torch.Tensor = 0, + cu_seqlens: torch.Tensor | None = None, + chunk_indices: torch.LongTensor | None = None, +): + """ + Args: + x: [B, T, H, D] + cos, sin: [TR, R//2] + interleaved: + If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). + inplace: + If True, apply rotary embedding in-place. + seqlen_offsets: [N,] or int. + Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. 
class RotaryEmbedding(nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        scale_base: float | None = None,
        interleaved: bool = False,
        pos_idx_in_fp32: bool = True,
        device: torch.device | None = None,
    ):
        """
        interleaved:
            If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32:
            If True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq.
            In most cases this would be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16, we add this option.
        """
        super().__init__()

        self.dim = dim
        self.base = float(base)
        self.scale_base = scale_base
        self.interleaved = interleaved
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.device = device

        # Generate and save the inverse frequency buffer (non-trainable).
        # -(dim // -2) is ceil(dim / 2): one frequency per rotated pair of dims.
        self.register_buffer("inv_freq", torch.empty(-(dim // -2), dtype=torch.float32, device=device), persistent=False)

        # XPos per-dimension scale buffer; remains None unless scale_base is set.
        scale = None
        if scale_base is not None:
            scale = torch.empty(-(dim // -2), dtype=torch.float32, device=device)
        self.register_buffer("scale", scale, persistent=False)

        # Lazily grown cos/sin lookup tables, (re)built by _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

        self.reset_parameters()

    def reset_parameters(self):
        # Buffers are allocated empty in __init__; fill in their actual values here.
        with torch.no_grad():
            self.inv_freq.copy_(self._compute_inv_freq(device=self.inv_freq.device))
            if self.scale_base is not None:
                self.scale.copy_(self._compute_scale(device=self.scale.device))

    def __repr__(self):
        s = f"{self.__class__.__name__}("
        s += f"dim={self.dim}, "
        s += f"base={self.base}, "
        s += f"interleaved={self.interleaved}, "
        if self.scale_base is not None:
            s += f"scale_base={self.scale_base}, "
        s += f"pos_idx_in_fp32={self.pos_idx_in_fp32})"
        return s

    def _compute_inv_freq(self, device=None):
        # Standard RoPE frequencies: 1 / base^(2i/dim) for even dimension indices.
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _compute_scale(self, device=None):
        # XPos per-pair scale, linear in the even dimension index.
        return (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + 0.4 * self.dim) / (1.4 * self.dim)

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: positions are centered at seqlen // 2 before taking the scale power.
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                # Keys use the reciprocal scale so the q.k product carries the relative XPos decay.
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seqlen_offset: int | torch.Tensor = 0,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: int | None = None,
        chunk_indices: torch.LongTensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        q: [B, T, H, D]
        k: [B, T, H, D]
        seqlen_offset:
            [N] or int.
            Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: [N + 1] or None
        max_seqlen: int
        """
        # NOTE(review): when max_seqlen is None and seqlen_offset is a tensor, neither
        # branch below refreshes the cos/sin cache — presumably callers must pass
        # max_seqlen in that case; confirm against call sites.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
        if self.scale is None:
            # Plain RoPE: queries and keys share the same cos/sin tables.
            q = rotary_embedding(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
                cu_seqlens=cu_seqlens,
                chunk_indices=chunk_indices,
            )
            k = rotary_embedding(
                k,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
                cu_seqlens=cu_seqlens,
                chunk_indices=chunk_indices,
            )

        else:
            # XPos: queries use the scaled tables, keys the reciprocally scaled ones.
            q = rotary_embedding(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
                cu_seqlens=cu_seqlens,
                chunk_indices=chunk_indices,
            )
            k = rotary_embedding(
                k,
                self._cos_k_cached,
                self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
                cu_seqlens=cu_seqlens,
                chunk_indices=chunk_indices,
            )

        return q, k
causal_conv1d_update_cuda +except ImportError: + causal_conv1d_fn_cuda = None + causal_conv1d_update_cuda = None + + +class ShortConvolution(nn.Conv1d): + """Short convolution layer for efficient causal convolution operations. + + This class implements a depthwise 1D convolution with causal padding, + designed for efficient sequence processing. It supports multiple backends (Triton/CUDA) + and optional activation functions. + + Args: + hidden_size (int): Number of input/output channels (must be equal for depthwise conv) + kernel_size (int): Size of the convolution kernel + bias (bool, optional): Whether to include learnable bias. Defaults to False. + activation (Optional[str], optional): Activation function ('silu' or 'swish'). Defaults to 'silu'. + backend (Optional[str], optional): Backend implementation ('triton' or 'cuda'). Defaults to 'triton'. + device (Optional[torch.device], optional): Device to place the layer on. Defaults to None. + dtype (Optional[torch.dtype], optional): Data type for layer parameters. Defaults to None. 
+ **kwargs: Additional keyword arguments (deprecated 'use_fast_conv1d' supported for compatibility) + + Attributes: + hidden_size (int): Number of channels + activation (Optional[str]): Selected activation function + backend (str): Actual backend being used (may differ from input due to availability) + + Note: + - Uses depthwise convolution (groups=hidden_size) for efficiency + - Applies causal padding (kernel_size-1) to ensure no future information leakage + - Falls back to Triton backend if CUDA backend is unavailable + """ + + def __init__( + self, + hidden_size: int, + kernel_size: int, + bias: bool = False, + activation: str | None = 'silu', + backend: str | None = 'triton', + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs, + ): + super().__init__( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + groups=hidden_size, + bias=bias, + padding=kernel_size - 1, + device=device, + dtype=dtype, + ) + + self.hidden_size = hidden_size + self.activation = None + + if activation is not None: + assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet." + self.activation = activation + + if 'use_fast_conv1d' in kwargs: + warnings.warn( + "The `use_fast_conv1d` parameter is deprecated and will be ignored. " + "Please use the `backend` parameter instead.", + ) + import os + self.backend = os.environ.get('FLA_CONV_BACKEND', backend) + if backend not in ['cuda', 'triton']: + raise ValueError(f"Invalid backend: {backend}, must be one of ['cuda', 'triton']") + if backend == 'cuda': + if causal_conv1d_fn_cuda is None: + warnings.warn( + "The `backend` parameter is set to `cuda`, but `causal_conv1d_fn` is not available. " + "Switching to the Triton implementation instead. 
" + "Consider installing `causal_conv1d` to enable the CUDA backend.", + ) + self.backend = 'triton' + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + if self.activation is not None: + s += ', activation={activation}' + s += f', backend={self.backend}' + return s.format(**self.__dict__) + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + cache: torch.Tensor | None = None, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.LongTensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x (`torch.Tensor`): + Tensor of shape `[B, T, D]`. `B` must be 1 if `cu_seqlens` is provided. + residual (`Optional[torch.Tensor]`): + Residual tensor of shape `[B, T, D]`. Default: `None`. + mask (`Optional[torch.Tensor]`): + Attention mask dealing with padded positions. + cache (`Optional[torch.Tensor]`): + Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size. + If provided, the cache is updated **inplace**. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, D, W]`. Default: `False`. + cu_seqlens (Optional[torch.LongTensor]): + Cumulative sequence lengths for each batch. Used for varlen. Default: `None`. + Shape: [B+1] + chunk_indices (Optional[torch.LongTensor]): + Chunk indices for variable-length sequences. Default: `None`. + + Returns: + Tensor of shape `[B, T, D]`. 
+ """ + # Import here to avoid circular dependency + from fla.modules.conv.causal_conv1d import causal_conv1d + + B, T, *_ = x.shape + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + if mask is not None: + if cu_seqlens is not None: + raise ValueError("`mask` and `cu_seqlens` cannot be provided at the same time") + x = x.mul_(mask.unsqueeze(-1)) + + # in decoding phase, the cache (if provided) is updated inplace + if B * T == N: + y, cache = self.step( + x=x, + residual=residual, + cache=cache, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + return y, cache + + # cuda backend do not support: + # 1. both `cu_seqlens` and `cache` being provided + # 2. both `cu_seqlens` and `output_final_state` being provided + # and other small issues + # to simplify the implementation, we just switch to triton backend + if self.backend == 'cuda' and cache is not None: + warnings.warn( + "The CUDA backend does not support both `cu_seqlens` and `cache` being provided, " + "or both `cu_seqlens` and `output_final_state` being provided. " + "Switching to the Triton backend instead. 
", + stacklevel=2, + ) + self.backend = 'triton' + + return causal_conv1d( + x=x, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + residual=residual, + initial_state=cache, + output_final_state=output_final_state, + activation=self.activation, + backend=self.backend, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + **kwargs, + ) + + def step( + self, + x: torch.Tensor, + residual: torch.Tensor, + cache: torch.Tensor, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + ): + from fla.modules.conv.triton.ops import causal_conv1d_update + + B, _, D, W = *x.shape, self.kernel_size[0] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + if output_final_state and cache is None: + cache = x.new_zeros(N, D, W) + # NOTE: we follow the fast mode that updates the cache in-place + if self.backend == 'triton': + return causal_conv1d_update( + x=x, + cache=cache, + residual=residual, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + + shape = x.shape + x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1) + # equivalent to: + # cache.copy_(cache.roll(shifts=-1, dims=-1)) + # cache[:, :, -1] = x + # y = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1) + y = causal_conv1d_update_cuda( + x=x, + conv_state=cache, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + y = y.view(shape) + if residual is not None: + y.add_(residual) + return y, cache + + @property + def state_size(self) -> int: + return self.hidden_size * self.kernel_size diff --git a/MaxCode/rag/sources/generic/fla_ops_gated_delta_rule_naive.py b/MaxCode/rag/sources/generic/fla_ops_gated_delta_rule_naive.py new file mode 100644 index 0000000..0747e9b --- /dev/null +++ b/MaxCode/rag/sources/generic/fla_ops_gated_delta_rule_naive.py @@ -0,0 +1,156 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import 
torch.nn.functional as F +from einops import rearrange + + +def naive_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +): + """ + Reference PyTorch implementation of recurrent gated delta rule. + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + v: [B, T, H, V] + beta: [B, T, H] + g: [B, T, H] + scale: float, optional + initial_state: [B, H, K, V], optional + output_final_state: bool + + Returns: + o: [B, T, H, V] + final_state: [B, H, K, V] if output_final_state else None + """ + q, k, v, beta, g = map(lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]) + B, H, T, K, V = *k.shape, v.shape[-1] + o = torch.zeros(B, H, T, V).to(v) + h = torch.zeros(B, H, K, V).to(v) + if initial_state is not None: + h = initial_state.to(torch.float32) + if scale is None: + scale = 1 / (q.shape[-1] ** 0.5) + q = q * scale + + for i in range(T): + b_q = q[:, :, i] + b_k = k[:, :, i] + b_v = v[:, :, i].clone() + h = h.clone() * g[:, :, i].exp()[..., None, None] + b_beta = beta[:, :, i] + b_v = b_v - (h.clone() * b_k[..., None]).sum(-2) + b_v = b_v * b_beta[..., None] + h = h.clone() + b_k.unsqueeze(-1) * b_v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', b_q, h) + + if not output_final_state: + h = None + o = o.transpose(1, 2).contiguous() + return o, h + + +def naive_chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + chunk_size: int = 64, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +): + """ + Reference PyTorch implementation of chunk gated delta rule. 
+ + Args: + q: [B, T, H, K] + k: [B, T, H, K] + v: [B, T, H, V] + g: [B, T, H] + beta: [B, T, H] + chunk_size: int + scale: float, optional + initial_state: [B, H, K, V], optional + output_final_state: bool + + Returns: + o: [B, T, H, V] + final_state: [B, H, K, V] if output_final_state else None + """ + BT = chunk_size + if scale is None: + scale = 1 / (q.shape[-1] ** 0.5) + + q, k, v, beta, g = map(lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]) + + T = q.shape[-2] + pad_len = (BT - (T % BT)) % BT + if pad_len > 0: + q = F.pad(q, (0, 0, 0, pad_len)) + k = F.pad(k, (0, 0, 0, pad_len)) + v = F.pad(v, (0, 0, 0, pad_len)) + beta = F.pad(beta, (0, pad_len)) + g = F.pad(g, (0, pad_len)) + + q, k, v, beta, g = map(lambda x: x.to(torch.float32), [q, k, v, beta, g]) + decay = g + chunk_size = BT + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * scale + v = v * beta[..., None] + k_beta = k * beta[..., None] + assert l % chunk_size == 0 + + # note that diagonal is masked. 
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, k_beta, decay = map( + lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), + [q, k, v, k_beta, decay.unsqueeze(-1)], + ) + decay = decay.squeeze(-1).cumsum(-1) + decay_exp = decay.exp()[..., None] + L_mask = ((decay.unsqueeze(-1) - decay.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ k.transpose(-1, -2)) * L_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i].clone() + (attn[..., i, :i, None].clone() * attn[..., :i, :i].clone()).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + attn = attn + k_cumsum = attn @ v + k_cumdecay = attn @ (k_beta * decay_exp) + v = k_cumsum + + S = k.new_zeros(b, h, d_k, d_v) + if initial_state is not None: + S = initial_state.to(torch.float32) + + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * L_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ S + v_new = v_i - v_prime + o_inter = (q_i * decay[:, :, i, :, None].exp()) @ S + o[:, :, i] = o_inter + attn @ v_new + S = S * decay[:, :, i, -1, None, None].exp() + (k_i * (decay[:, :, i, -1, None] - decay[:, :, i]).exp() + [..., None]).transpose(-1, -2) @ v_new + if not output_final_state: + S = None + + # unpad + o = rearrange(o, 'b h n c d -> b h (n c) d') + o = o[:, :, :T] + o = o.transpose(1, 2) + return o, S diff --git a/MaxCode/rag/sources/generic/flax_example_attention.py b/MaxCode/rag/sources/generic/flax_example_attention.py new file mode 100644 index 0000000..05d5378 --- /dev/null +++ b/MaxCode/rag/sources/generic/flax_example_attention.py @@ -0,0 +1,219 @@ +# Copyright 2024 The Flax Authors. 
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from pprint import pprint
from typing import Any, Optional
from collections.abc import Callable, Sequence
from flax.core.frozen_dict import unfreeze
from flax.linen import initializers
from flax.linen import Module, compact, vmap
from flax.linen.linear import PrecisionLike
import jax
from jax import lax, numpy as jnp, random


class Dense(Module):
  """Minimal dense (affine) layer built directly on `lax.dot_general`."""

  features: int
  use_bias: bool = True
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros_init()
  dtype: Any = jnp.float32
  precision: PrecisionLike = None

  @compact
  def __call__(self, inputs):
    inputs = jnp.asarray(inputs, self.dtype)
    kernel = self.param(
      'kernel', self.kernel_init, (inputs.shape[-1], self.features)
    )
    kernel = jnp.asarray(kernel, self.dtype)
    # Contract the last input axis against the first kernel axis.
    y = lax.dot_general(
      inputs,
      kernel,
      (((inputs.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
    )
    if self.use_bias:
      bias = self.param('bias', self.bias_init, (self.features,))
      bias = jnp.asarray(bias, self.dtype)
      y = y + bias
    return y


class SoftmaxAttn(Module):
  """Softmax normalization over the trailing half of the weight axes."""

  @compact
  def __call__(self, weights):
    # Normalize jointly over the last ndim // 2 axes.
    norm_dims = tuple(range(weights.ndim // 2, weights.ndim))
    return jax.nn.softmax(weights, axis=norm_dims)


class Dropout(Module):
  """Inverted dropout; draws its RNG from the 'dropout' stream when none is given."""

  rate: float

  @compact
  def __call__(self, x, deterministic=False, rng=None):
    if self.rate == 0.0:
      return x
    keep_prob = 1.0 - self.rate

    if deterministic:
      return x
    else:
      if rng is None:
        rng = self.scope.make_rng('dropout')
      mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
      # Scale kept activations by 1/keep_prob so the expectation is unchanged.
      return lax.select(mask, x / keep_prob, jnp.zeros_like(x))


class SoftmaxAttnWDropout(Module):
  """Softmax attention normalization followed by dropout on the weights."""

  rate: float = 0.0
  deterministic: bool = False

  @compact
  def __call__(self, x):
    x = SoftmaxAttn()(x)
    x = Dropout(self.rate)(x, deterministic=self.deterministic)
    return x


class RawDotProductAttention(Module):
  """Dot-product attention on raw q/k/v tensors, with a pluggable normalizer."""

  attn_module: Callable = SoftmaxAttn

  @compact
  def __call__(self, query, key, value, bias=None, dtype=jnp.float32):
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim

    n = query.ndim
    # Contract the feature axis of query and key to form the weights.
    attn_weights = lax.dot_general(query, key, (((n - 1,), (n - 1,)), ((), ())))
    if bias is not None:
      attn_weights += bias
    attn_weights = self.attn_module()(attn_weights)
    attn_weights = attn_weights.astype(dtype)

    # Contract the key axes of the weights against the leading axes of value.
    contract_dims = (
      tuple(range(n - 1, attn_weights.ndim)),
      tuple(range(0, n - 1)),
    )
    y = lax.dot_general(attn_weights, value, (contract_dims, ((), ())))
    return y


class DotProductAttention(Module):
  """Single-head attention: q/k/v projections, raw attention, output projection."""

  qkv_features: int | None = None
  out_features: int | None = None
  attn_module: Callable = SoftmaxAttn

  @compact
  def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    out_features = self.out_features or inputs_q.shape[-1]

    QKVDense = functools.partial(
      Dense, features=qkv_features, use_bias=False, dtype=dtype
    )
    query = QKVDense(name='query')(inputs_q)
    key = QKVDense(name='key')(inputs_kv)
    value = QKVDense(name='value')(inputs_kv)

    y = RawDotProductAttention(attn_module=self.attn_module)(
      query, key, value, bias=bias, dtype=dtype
    )
    y = Dense(features=out_features, dtype=dtype, name='out')(y)
    return y


# Trying out a slightly more compact vmap notation:


def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs):
  """Wraps `nn.vmap`, taking each variable spec as a (variable_axis, split_rng) pair."""
  variable_axes = {
    k: v[0] for k, v in var_specs.items() if isinstance(v, Sequence)
  }
  splits = {k: v[1] for k, v in var_specs.items() if isinstance(v, Sequence)}
  return vmap(
    module,
    in_axes=in_axes,
    out_axes=out_axes,
    variable_axes=variable_axes,
    split_rngs=splits,
    axis_size=axis_size,
  )


class MultiHeadDotProductAttention(Module):
  """Multi-head attention built by lifting DotProductAttention with vmap."""

  qkv_features: int | None = None
  out_features: int | None = None
  attn_module: Callable = SoftmaxAttn
  batch_axes: Sequence[int] = (0,)
  num_heads: int = 1
  broadcast_dropout: bool = False

  @compact
  def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    out_features = self.out_features or inputs_q.shape[-1]

    # Now, vmap attn.__call__ along heads and spatial dims.
    # The head vmap stacks outputs on axis -2 and splits params per head.
    Attn = concise_vmap(
      DotProductAttention,
      (None, None, None),
      -2,
      param=(0, True),
      dropout=(None, not self.broadcast_dropout),
      axis_size=self.num_heads,
    )
    for axis in reversed(sorted(self.batch_axes)):
      Attn = concise_vmap(
        Attn,
        (axis, axis, axis),
        axis,
        param=(None, False),
        dropout=(None, not self.broadcast_dropout),
      )

    attn = Attn(
      attn_module=self.attn_module,
      qkv_features=qkv_features // self.num_heads,
      out_features=out_features,
    )

    # evaluate multi-headed-attention.
    # Averages over the head axis (-2, where the head vmap stacked outputs).
    y = attn(inputs_q, inputs_kv, bias)
    return y.mean(axis=-2)


# run it.


if __name__ == '__main__':
  # Small smoke-test: self-attention over an (8, 97, 256) batch.
  inputs = jnp.ones((8, 97, 256))
  rngs = {'params': random.key(0), 'dropout': random.key(1)}
  model = MultiHeadDotProductAttention(
    broadcast_dropout=False,
    qkv_features=256,
    out_features=256,
    attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1),
    num_heads=8,
    batch_axes=(0,),
  )

  y, params = model.init_with_output(rngs, inputs, inputs)

  print('input shape: ', inputs.shape)
  print('parameter shapes:')
  pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
  print('output shape: ', y.shape)
"""Attention core modules for Flax."""
from __future__ import annotations

import functools
import inspect
import warnings
from typing import Any, overload
from collections.abc import Callable

import jax
import jax.numpy as jnp
from jax import lax, random

from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.linear import (
  DenseGeneral,
  default_kernel_init,
)
from flax.linen.module import Module, compact, merge_param
from flax.linen.normalization import LayerNorm
from flax.typing import (
  Array,
  PRNGKey,
  Dtype,
  Shape as Shape,
  Initializer,
  PrecisionLike,
  DotGeneralT,
)


def dot_product_attention_weights(
  query: Array,
  key: Array,
  bias: Array | None = None,
  mask: Array | None = None,
  broadcast_dropout: bool = True,
  dropout_rng: PRNGKey | None = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Dtype | None = None,
  precision: PrecisionLike = None,
  module: Module | None = None,
  force_fp32_for_softmax: bool = False,
  einsum_dot_general: Callable[..., Array] | None = None,
  einsum: Callable[..., Array] | None = None,
):
  """Computes dot-product attention weights given query and key.

  Used by :func:`dot_product_attention`, which is what you'll most likely use.
  But if you want access to the attention weights for introspection, then
  you can directly call this function and call einsum yourself.

  Args:
    query: queries for calculating attention with shape of ``[batch...,
      q_length, num_heads, qk_depth_per_head]``.
    key: keys for calculating attention with shape of ``[batch..., kv_length,
      num_heads, qk_depth_per_head]``.
    bias: bias for the attention weights. This should be broadcastable to the
      shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is ``False``.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs and params)
    precision: numerical precision of the computation; see ``jax.lax.Precision``
      for details.
    module: the Module that will sow the attention weights into the
      'intermediates' collection. Remember to mark 'intermediates' as mutable
      via ``mutable=['intermediates']`` in order to have that collection
      returned. If ``module`` is None, the attention weights will not be sowed.
    force_fp32_for_softmax: bool, whether to force the softmax to be computed in
      fp32. This is useful for mixed-precision training where higher precision
      is desired for numerical stability.
    einsum_dot_general: the dot_general to use in einsum.
    einsum: If unspecified, default `jnp.einsum` will be used. This argument is
      mutually exclusive with `precision` and `einsum_dot_general`.

  Raises:
    ValueError: if both `precision`/`einsum_dot_general` and `einsum` are
      specified.

  Returns:
    Output of shape ``[batch..., num_heads, q_length, kv_length]``.
  """
  if (precision or einsum_dot_general) and einsum:
    raise ValueError(
      'precision/einsum_dot_general and einsum are mutually exclusive. Please'
      ' specify only one of them.'
    )
  if not einsum:
    einsum = functools.partial(
      jnp.einsum,
      precision=precision,
      _dot_general=einsum_dot_general
      if einsum_dot_general
      else jax.lax.dot_general,
    )

  query, key = promote_dtype(query, key, dtype=dtype)
  dtype = query.dtype

  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # calculate attention matrix
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  # attn weight shape is (batch..., num_heads, q_length, kv_length)
  attn_weights = einsum('...qhd,...khd->...hqk', query, key)

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias
  # apply attention mask
  if mask is not None:
    # Masked positions get the most negative finite value so softmax sends
    # their weight to ~0 without producing -inf/NaN.
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)

  # normalize the attention weights
  if force_fp32_for_softmax and dtype != jnp.float32:
    attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32))
  else:
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  if module:
    module.sow('intermediates', 'attention_weights', attn_weights)

  # apply attention dropout
  if not deterministic and dropout_rate > 0.0:
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # dropout is broadcast across the batch + head dimensions
      dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
    multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
    attn_weights = attn_weights * multiplier

  return attn_weights


def dot_product_attention(
  query: Array,
  key: Array,
  value: Array,
  bias: Array | None = None,
  mask: Array | None = None,
  broadcast_dropout: bool = True,
  dropout_rng: PRNGKey | None = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Dtype | None = None,
  precision: PrecisionLike = None,
  module: Module | None = None,
  force_fp32_for_softmax: bool = False,
  einsum_dot_general: Callable[..., Array] | None = None,
  qk_attn_weights_einsum: Callable[..., Array] | None = None,
  attn_weights_value_einsum: Callable[..., Array] | None = None,
):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  .. note::
    ``query``, ``key``, ``value`` needn't have any batch dimensions.

  Args:
    query: queries for calculating attention with shape of ``[batch...,
      q_length, num_heads, qk_depth_per_head]``.
    key: keys for calculating attention with shape of ``[batch..., kv_length,
      num_heads, qk_depth_per_head]``.
    value: values to be used in attention with shape of ``[batch..., kv_length,
      num_heads, v_depth_per_head]``.
    bias: bias for the attention weights. This should be broadcastable to the
      shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is ``False``.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs)
    precision: numerical precision of the computation; see ``jax.lax.Precision``
      for details.
    module: the Module that will sow the attention weights into the
      'intermediates' collection. Remember to mark 'intermediates' as mutable
      via ``mutable=['intermediates']`` in order to have that collection
      returned. If ``module`` is None, the attention weights will not be sowed.
    force_fp32_for_softmax: bool, whether to force the softmax to be computed in
      fp32. This is useful for mixed-precision training where higher precision
      is desired for numerical stability.
    einsum_dot_general: the dot_general to use in `jnp.einsum`.
    qk_attn_weights_einsum: the einsum for computing the attention weights. When
      unspecified, the default `jnp.einsum` will be used. This argument is
      mutually exclusive with `precision` and `einsum_dot_general`.
    attn_weights_value_einsum: the einsum for computing the product of the
      attention weights and the values. When unspecified, the default
      `jnp.einsum` will be used. This argument is mutually exclusive with
      `precision` and `einsum_dot_general`.

  Returns:
    Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.

  Raises:
    ValueError: if both `precision`/`einsum_dot_general` and
      `qk_attn_weights_einsum`/`attn_weights_value_einsum` are
      specified.
  """
  # The two custom einsums must be supplied together (or not at all).
  if (qk_attn_weights_einsum and not attn_weights_value_einsum) or (
    not qk_attn_weights_einsum and attn_weights_value_einsum
  ):
    raise ValueError(
      'qk_attn_weights_einsum and attn_weights_value_einsum must be specified'
      ' together.'
    )
  if (precision or einsum_dot_general) and (
    qk_attn_weights_einsum or attn_weights_value_einsum
  ):
    raise ValueError(
      'precision/einsum_dot_general and'
      ' qk_attn_weights_einsum/attn_weights_value_einsum are mutually'
      ' exclusive. Please specify only one of them.'
    )

  query, key, value = promote_dtype(query, key, value, dtype=dtype)
  dtype = query.dtype
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
    query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert (
    query.shape[-2] == key.shape[-2] == value.shape[-2]
  ), 'q, k, v num_heads must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

  # compute attention weights
  attn_weights = dot_product_attention_weights(
    query,
    key,
    bias,
    mask,
    broadcast_dropout,
    dropout_rng,
    dropout_rate,
    deterministic,
    dtype,
    precision,
    module,
    force_fp32_for_softmax,
    einsum_dot_general=einsum_dot_general,
    einsum=qk_attn_weights_einsum,
  )
  if not attn_weights_value_einsum:
    attn_weights_value_einsum = functools.partial(
      jnp.einsum,
      precision=precision,
      _dot_general=einsum_dot_general
      if einsum_dot_general
      else jax.lax.dot_general,
    )
  # return weighted sum over values for each query position
  return attn_weights_value_einsum(
    '...hqk,...khd->...qhd',
    attn_weights,
    value,
  )
out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) + ... out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) + ... return out1, out2 + >>> module = Module(attention_kwargs) + >>> variables = module.init({'params': key1, 'dropout': key2}, q) + + >>> # out1 and out2 are different. + >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3}) + >>> # out3 and out4 are different. + >>> # out1 and out3 are different. out2 and out4 are different. + >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4}) + >>> # out1 and out2 are the same. + >>> out1, out2 = module.apply(variables, q, dropout_rng=key5) + >>> # out1 and out2 are the same as out3 and out4. + >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply` + >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5) + + Attributes: + num_heads: Number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: The dtype of the computation (default: infer from inputs and params) + param_dtype: The dtype passed to parameter initializers (default: float32) + qkv_features: Dimension of the key, query, and value. + out_features: Dimension of the last projection + broadcast_dropout: Use a broadcasted dropout along batch dims. + dropout_rate: Dropout rate. + deterministic: If False, the attention weight is masked randomly using + dropout, whereas if True, the attention weights are deterministic. + precision: Numerical precision of the computation see ``jax.lax.Precision`` + for details. + kernel_init: Initializer for the kernel of the Dense layers. + out_kernel_init: Optional Initializer for the kernel of the output Dense layer, + if None, ``kernel_init`` will be used. + bias_init: Initializer for the bias of the Dense layers. 
+
+    out_bias_init: Optional Initializer for the bias of the output Dense layer,
+      if None, ``bias_init`` will be used.
+    use_bias: Whether pointwise QKVO dense transforms use bias.
+    attention_fn: dot_product_attention or compatible function. Accepts query,
+      key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
+      num_heads, value_channels]``
+    decode: Whether to prepare and use an autoregressive cache.
+    normalize_qk: Should QK normalization be applied (arxiv.org/abs/2302.05442).
+    qk_attn_weights_einsum_cls: factory function to create the einsum for
+      computing the attention weights.
+    attn_weights_value_einsum_cls: factory function to create the einsum for
+      computing the product of the attention weights and the values.
+  """
+
+  num_heads: int
+  dtype: Dtype | None = None
+  param_dtype: Dtype = jnp.float32
+  qkv_features: int | None = None
+  out_features: int | None = None
+  broadcast_dropout: bool = True
+  dropout_rate: float = 0.0
+  deterministic: bool | None = None
+  precision: PrecisionLike = None
+  kernel_init: Initializer = default_kernel_init
+  out_kernel_init: Initializer | None = None
+  bias_init: Initializer = initializers.zeros_init()
+  out_bias_init: Initializer | None = None
+  use_bias: bool = True
+  attention_fn: Callable[..., Array] = dot_product_attention
+  decode: bool = False
+  normalize_qk: bool = False
+  force_fp32_for_softmax: bool = False
+  # Deprecated, will be removed.
+ qkv_dot_general: DotGeneralT | None = None + out_dot_general: DotGeneralT | None = None + qkv_dot_general_cls: Any = None + out_dot_general_cls: Any = None + qk_attn_weights_einsum_cls: Callable[..., Callable[..., Array]] | None = None + attn_weights_value_einsum_cls: Callable[..., Callable[..., Array]] | None = ( + None + ) + + @overload + def __call__( + self, + inputs_q: Array, + inputs_k: Array | None = None, + inputs_v: Array | None = None, + *, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, + sow_weights: bool = False, + ): + ... + + @overload + def __call__( + self, + inputs_q: Array, + *, + inputs_kv: Array | None = None, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, + sow_weights: bool = False, + ): + ... + + @compact + def __call__( + self, + inputs_q: Array, + inputs_k: Array | None = None, + inputs_v: Array | None = None, + *, + inputs_kv: Array | None = None, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, + sow_weights: bool = False, + ): + """Applies multi-head dot product attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + If both inputs_k and inputs_v are None, they will both copy the value of + inputs_q (self attention). + If only inputs_v is None, it will copy the value of inputs_k. + + Args: + inputs_q: input queries of shape ``[batch_sizes..., length, features]``. + inputs_k: key of shape ``[batch_sizes..., length, features]``. If None, + inputs_k will copy the value of inputs_q. + inputs_v: values of shape ``[batch_sizes..., length, features]``. If None, + inputs_v will copy the value of inputs_k. + inputs_kv: key/values of shape ``[batch_sizes..., length, features]``. If + None, inputs_kv will copy the value of inputs_q. 
This arg will be + deprecated soon. Use inputs_k and inputs_v instead. + mask: attention mask of shape ``[batch_sizes..., num_heads, query_length, + key/value_length]``. Attention weights are masked out if their + corresponding mask value is ``False``. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. + dropout_rng: optional rng key to pass to the attention layer's dropout + mask. Otherwise, self.make_rng('dropout') is used instead. + sow_weights: if ``True``, the attention weights are sowed into the + 'intermediates' collection. Remember to mark 'intermediates' as + mutable via ``mutable=['intermediates']`` in order to have that + collection returned. + + Returns: + output of shape ``[batch_sizes..., length, features]``. + """ + if inputs_kv is not None: + if inputs_k is not None or inputs_v is not None: + raise ValueError( + 'If either `inputs_k` or `inputs_v` is not None, ' + '`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` ' + 'and `inputs_v` must be None. We recommend using `inputs_k` and ' + '`inputs_v` args, since `inputs_kv` will be deprecated soon. See ' + 'https://github.com/google/flax/discussions/3389 for more ' + 'information.' + ) + inputs_k = inputs_v = inputs_kv + warnings.warn( + 'The inputs_kv arg will be deprecated soon. ' + 'Use inputs_k and inputs_v instead. See ' + 'https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning, + ) + else: + if inputs_k is None: + if inputs_v is not None: + raise ValueError( + '`inputs_k` cannot be None if `inputs_v` is not None. ' + 'To have both `inputs_k` and `inputs_v` be the same value, pass in the ' + 'value to `inputs_k` and leave `inputs_v` as None.' 
+ ) + inputs_k = inputs_q + if inputs_v is None: + inputs_v = inputs_k + elif inputs_v.shape[-1] == inputs_v.shape[-2]: + warnings.warn( + f'You are passing an array of shape {inputs_v.shape} ' + 'to the `inputs_v` arg, when you may have intended ' + 'to pass it to the `mask` arg. As of Flax version ' + '0.7.4, the function signature of ' + "MultiHeadDotProductAttention's `__call__` method " + 'has changed to `__call__(inputs_q, inputs_k=None, ' + 'inputs_v=None, *, inputs_kv=None, mask=None, ' + 'deterministic=None)`. Use the kwarg `mask` instead. ' + 'See https://github.com/google/flax/discussions/3389 ' + 'and read the docstring for more information.', + DeprecationWarning, + ) + + features = self.out_features or inputs_q.shape[-1] + qkv_features = self.qkv_features or inputs_q.shape[-1] + assert qkv_features % self.num_heads == 0, ( + f'Memory dimension ({qkv_features}) must be divisible by number of' + f' heads ({self.num_heads}).' + ) + head_dim = qkv_features // self.num_heads + + dense = functools.partial( + DenseGeneral, + axis=-1, + dtype=self.dtype, + param_dtype=self.param_dtype, + features=(self.num_heads, head_dim), + kernel_init=self.kernel_init, + bias_init=self.bias_init, + use_bias=self.use_bias, + precision=self.precision, + dot_general=self.qkv_dot_general, + dot_general_cls=self.qkv_dot_general_cls, + ) + # project inputs_q to multi-headed q/k/v + # dimensions are then [batch..., length, n_heads, n_features_per_head] + query, key, value = ( + dense(name='query')(inputs_q), + dense(name='key')(inputs_k), + dense(name='value')(inputs_v), + ) + + if self.normalize_qk: + # Normalizing query and key projections stabilizes training with higher + # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis. 
+ query = LayerNorm( + name='query_ln', + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + )(query) # type: ignore[call-arg] + key = LayerNorm( + name='key_ln', + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + )(key) # type: ignore[call-arg] + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.decode: + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable('cache', 'cached_key') + cached_key = self.variable( + 'cache', 'cached_key', jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + 'cache', 'cached_value', jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) + ) + if is_initialized: + ( + *batch_dims, + max_length, + num_heads, + depth_per_head, + ) = cached_key.value.shape + # shape check of cached keys against query input + expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) + if expected_shape != query.shape: + raise ValueError( + 'Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' + % (expected_shape, query.shape) + ) + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype)) + indices: tuple[int | jax.Array, ...] = (zero,) * len( + batch_dims + ) + ( + cur_index, + zero, + zero, + ) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + cache_index.value = cache_index.value + 1 + # causal mask for cached decoder self-attention: + # our single query position should only attend to those key + # positions that have already been generated and cached, + # not the remaining zero elements. 
+ mask = combine_masks( + mask, + jnp.broadcast_to( + jnp.arange(max_length) <= cur_index, + tuple(batch_dims) + (1, 1, max_length), + ), + ) + + if ( + self.dropout_rate > 0.0 + ): # Require `deterministic` only if using dropout. + m_deterministic = merge_param( + 'deterministic', self.deterministic, deterministic + ) + if not m_deterministic and dropout_rng is None: + dropout_rng = self.make_rng('dropout') + else: + m_deterministic = True + + # `qk_attn_weights_einsum` and `attn_weights_value_einsum` are optional + # arguments that can be used to override the default `jnp.einsum`. They + # exist for quantized einsum support in AQT. + qk_attn_weights_einsum = ( + self.qk_attn_weights_einsum_cls() + if self.qk_attn_weights_einsum_cls + else None + ) + attn_weights_value_einsum = ( + self.attn_weights_value_einsum_cls() + if self.attn_weights_value_einsum_cls + else None + ) + # apply attention + attn_args = (query, key, value) + # This kwargs list match the default nn.dot_product_attention. + # For custom `attention_fn`s, invalid kwargs will be filtered. 
+ attn_kwargs = dict( + mask=mask, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + broadcast_dropout=self.broadcast_dropout, + deterministic=m_deterministic, + dtype=self.dtype, + precision=self.precision, + force_fp32_for_softmax=self.force_fp32_for_softmax, + qk_attn_weights_einsum=qk_attn_weights_einsum, + attn_weights_value_einsum=attn_weights_value_einsum, + ) + attn_kwargs = { + k: v + for k, v in attn_kwargs.items() + if k in inspect.signature(self.attention_fn).parameters + } + if sow_weights: + x = self.attention_fn(*attn_args, **attn_kwargs, module=self) + else: + x = self.attention_fn(*attn_args, **attn_kwargs) + # back to the original inputs dimensions + out = DenseGeneral( + features=features, + axis=(-2, -1), + kernel_init=self.out_kernel_init or self.kernel_init, + bias_init=self.out_bias_init or self.bias_init, + use_bias=self.use_bias, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + dot_general=self.out_dot_general, + dot_general_cls=self.out_dot_general_cls, + name='out', # type: ignore[call-arg] + )(x) + return out + + +class MultiHeadAttention(MultiHeadDotProductAttention): + """Multi-head dot-product attention. + Alias for ``MultiHeadDotProductAttention``. + + **NOTE**: ``MultiHeadAttention`` is a wrapper of ``MultiHeadDotProductAttention``, + and so their implementations are identical. However ``MultiHeadAttention`` layers + will, by default, be named ``MultiHeadAttention_{index}``, whereas ``MultiHeadDotProductAttention`` + will be named ``MultiHeadDotProductAttention_{index}``. Therefore, this could affect + checkpointing, param collection names and RNG threading (since the layer name is + used when generating new RNG's) within the module. 
+ + Example usage:: + + >>> import flax.linen as nn + >>> import jax + + >>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16) + >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6) + >>> shape = (4, 3, 2, 5) + >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape) + >>> variables = layer.init(jax.random.key(0), q) + + >>> # different inputs for inputs_q, inputs_k and inputs_v + >>> out = layer.apply(variables, q, k, v) + >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k) + >>> out = layer.apply(variables, q, k) + >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q) + >>> out = layer.apply(variables, q) + + >>> attention_kwargs = dict( + ... num_heads=8, + ... qkv_features=16, + ... kernel_init=nn.initializers.ones, + ... bias_init=nn.initializers.zeros, + ... dropout_rate=0.5, + ... deterministic=False, + ... ) + >>> class Module(nn.Module): + ... attention_kwargs: dict + ... + ... @nn.compact + ... def __call__(self, x, dropout_rng=None): + ... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) + ... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) + ... return out1, out2 + >>> module = Module(attention_kwargs) + >>> variables = module.init({'params': key1, 'dropout': key2}, q) + + >>> # out1 and out2 are different. + >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3}) + >>> # out3 and out4 are different. + >>> # out1 and out3 are different. out2 and out4 are different. + >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4}) + >>> # out1 and out2 are the same. + >>> out1, out2 = module.apply(variables, q, dropout_rng=key5) + >>> # out1 and out2 are the same as out3 and out4. 
+    >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
+    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
+
+  Attributes:
+    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
+      should be divisible by the number of heads.
+    dtype: the dtype of the computation (default: infer from inputs and params)
+    param_dtype: the dtype passed to parameter initializers (default: float32)
+    qkv_features: dimension of the key, query, and value.
+    out_features: dimension of the last projection
+    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
+    dropout_rate: dropout rate
+    deterministic: if false, the attention weight is masked randomly using
+      dropout, whereas if true, the attention weights are deterministic.
+    precision: numerical precision of the computation see ``jax.lax.Precision``
+      for details.
+    kernel_init: initializer for the kernel of the Dense layers.
+    bias_init: initializer for the bias of the Dense layers.
+    use_bias: bool: whether pointwise QKVO dense transforms use bias.
+    attention_fn: dot_product_attention or compatible function. Accepts query,
+      key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
+      num_heads, value_channels]``
+    decode: whether to prepare and use an autoregressive cache.
+    normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
+  """
+
+
+class SelfAttention(MultiHeadDotProductAttention):
+  """Self-attention special case of multi-head dot-product attention.
+  This layer is deprecated in favor of ``MultiHeadDotProductAttention``.
+ + Example usage:: + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16) + >>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5))) + """ + + @compact + def __call__( # type: ignore + self, + inputs_q: Array, + mask: Array | None = None, + deterministic: bool | None = None, + dropout_rng: PRNGKey | None = None, + sow_weights: bool = False, + ): + """Applies multi-head dot product self-attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + Args: + inputs_q: input queries of shape ``[batch_sizes..., length, features]``. + mask: attention mask of shape ``[batch_sizes..., num_heads, query_length, + key/value_length]``. Attention weights are masked out if their + corresponding mask value is ``False``. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. + + Returns: + output of shape ``[batch_sizes..., length, features]``. + """ + warnings.warn( + 'SelfAttention will be deprecated soon. Use ' + '`MultiHeadDotProductAttention.__call__(inputs_q)` instead. ' + 'See https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning, + ) + return super().__call__( + inputs_q, + mask=mask, + deterministic=deterministic, + dropout_rng=dropout_rng, + sow_weights=sow_weights, + ) + + +# mask-making utility functions + + +def make_attention_mask( + query_input: Array, + key_input: Array, + pairwise_fn: Callable[..., Any] = jnp.multiply, + extra_batch_dims: int = 0, + dtype: Dtype = jnp.float32, +): + """Mask-making helper for attention weights. 
+
+  In case of 1d inputs (i.e., ``[batch..., len_q]``, ``[batch..., len_kv]``), the
+  attention weights will be ``[batch..., heads, len_q, len_kv]`` and this
+  function will produce ``[batch..., 1, len_q, len_kv]``.
+
+  Args:
+    query_input: a batched, flat input of query_length size
+    key_input: a batched, flat input of key_length size
+    pairwise_fn: broadcasting elementwise comparison function
+    extra_batch_dims: number of extra batch dims to add singleton axes for, none
+      by default
+    dtype: mask return dtype
+
+  Returns:
+    A ``[batch..., 1, len_q, len_kv]`` shaped mask for 1d attention.
+  """
+  mask = pairwise_fn(
+    jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)
+  )
+  mask = jnp.expand_dims(mask, axis=-3)
+  mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
+  return mask.astype(dtype)
+
+
+def make_causal_mask(
+  x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
+) -> Array:
+  """Make a causal mask for self-attention.
+
+  In case of 1d inputs (i.e., ``[batch..., len]``), the self-attention weights
+  will be ``[batch..., heads, len, len]`` and this function will produce a
+  causal mask of shape ``[batch..., 1, len, len]``.
+
+  Args:
+    x: input array of shape ``[batch..., len]``
+    extra_batch_dims: number of batch dims to add singleton axes for, none by
+      default
+    dtype: mask return dtype
+
+  Returns:
+    A ``[batch..., 1, len, len]`` shaped causal mask for 1d attention.
+  """
+  idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
+  return make_attention_mask(
+    idxs,
+    idxs,
+    jnp.greater_equal,
+    extra_batch_dims=extra_batch_dims,
+    dtype=dtype,
+  )
+
+
+def combine_masks(
+  *masks: Array | None, dtype: Dtype = jnp.float32
+) -> Array | None:
+  """Combine attention masks.
+
+  Args:
+    *masks: set of attention mask arguments to combine, some can be None.
+    dtype: dtype for the returned mask.
+
+  Returns:
+    Combined mask, reduced by logical and, returns None if no masks given.
+ """ + masks_list = [m for m in masks if m is not None] + if not masks_list: + return None + assert all( + map(lambda x: x.ndim == masks_list[0].ndim, masks_list) + ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}' + mask, *other_masks = masks_list + for other_mask in other_masks: + mask = jnp.logical_and(mask, other_mask) + return mask.astype(dtype) diff --git a/MaxCode/rag/sources/generic/maxtext_layers_attentions.py b/MaxCode/rag/sources/generic/maxtext_layers_attentions.py new file mode 100644 index 0000000..813cb33 --- /dev/null +++ b/MaxCode/rag/sources/generic/maxtext_layers_attentions.py @@ -0,0 +1,1177 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Attentions Layers.""" + +import dataclasses +import functools +from typing import Any, Iterable, Optional, Tuple, Union, cast + +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh, NamedSharding +import jax +import jax.numpy as jnp + +from flax import nnx + +from maxtext.common.common_types import ( + DecoderBlockType, + BATCH, + BATCH_NO_EXP, + HEAD, + PREFILL_LENGTH, + D_KV, + AxisNames, + AxisIdxes, + ATTN_LENGTH, + ATTN_LENGTH_NO_EXP, + DType, + Config, + Array, + DECODE_LENGTH, + DECODE_BATCH, + PREFILL_KV_BATCH, + KV_HEAD, + KV_HEAD_DIM, + KV_BATCH, + KV_BATCH_NO_EXP, + ATTN_EMBED, + MODEL_MODE_AUTOREGRESSIVE, + MODEL_MODE_TRAIN, + MODEL_MODE_PREFILL, + EP_AS_CONTEXT, + AttentionType, +) +from maxtext.layers import nnx_wrappers +from maxtext.layers.attention_op import AttentionOp +from maxtext.layers.embeddings import ( + LLaMARotaryEmbedding, + LlamaVisionRotaryEmbedding, + Qwen3OmniMoeThinkerTextRotaryEmbedding, + Qwen3OmniMoeVisionRotaryEmbedding, + RotaryEmbedding, + YarnRotaryEmbedding, + PartialRotaryEmbedding, +) +from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init +from maxtext.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes +from maxtext.layers.normalizations import RMSNorm, Qwen3NextRMSNorm, GlobalRMSNorm +from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.inference import kvcache, page_manager, paged_attention +from maxtext.inference.kvcache import KVQuant +from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding + +# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes +# pytype: disable=attribute-error + + +@dataclasses.dataclass(repr=False) +class L2Norm(nnx.Module): + """ + Implementation of L2Norm in JAX. + + Args: + eps: float, epsilon used for numerical stability (default value should be ok for most cases). 
+ """ + + eps: float = 1e-6 + rngs: nnx.Rngs = None # Not used in L2Norm but passed in by nnx.bridge.to_linen + + def __call__(self, x): + return x * jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps) + + +def l2_norm_as_linen(self, eps: float = 1e-6): + """ + Initializes the L2Norm module and returns it as a Linen module. + + Args: + eps: float, epsilon used for numerical stability (default value should be ok for most cases). + """ + return nnx_wrappers.to_linen(L2Norm, eps=eps, metadata_fn=variable_to_logically_partitioned) + + +def attention_as_linen( + *, + config: Config, + num_query_heads: int, + num_kv_heads: int, + head_dim: int, + max_target_length: int, + mesh: Mesh, + attention_kernel: str, + inputs_q_shape: Tuple, + inputs_kv_shape: Tuple, + dtype: DType = jnp.float32, + weight_dtype: DType = jnp.float32, + max_prefill_predict_length: int = -1, + dropout_rate: float = 0.0, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"), + float32_qk_product: bool = False, # computes logits in float32 for stability. + float32_logits: bool = False, # cast logits in float32 for stability. + quant: Optional[Quant] = None, + kv_quant: Optional[KVQuant] = None, + attention_type: AttentionType = AttentionType.GLOBAL, # Default to global attention + attn_logits_soft_cap: float | None = None, + sliding_window_size: int | None = None, + use_ragged_attention: bool = False, + ragged_block_size: int = 256, + use_qk_norm: bool = False, + query_pre_attn_scalar: float | None = None, + use_bias_in_projections: bool = False, # Set to True will enable bias in q, k, v, o projections + share_kv_projections: bool = False, # If true, Key and Value use the same projection + # Temperature tuning parameters used for Llama4 + temperature_tuning: bool = False, + temperature_tuning_scale: float = 0.1, + temperature_tuning_floor_scale: float = 8192.0, + # Shard the query activation as the same as the key and value. + # TODO: Find a better sharding axis name. 
+ # TODO: Further break down the Training and Inference axes for the q, k, v. + prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), + prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), + prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), + ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), + ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED), + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), + ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV), + prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), + decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), + prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), + decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV), + prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3), + ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3), + compute_axis_order: AxisIdxes = (0, 1, 2, 3), + reshape_q: bool = False, + is_nope_layer: bool = False, + is_vision: bool = False, + model_mode: str = MODEL_MODE_TRAIN, + use_mrope: bool = False, + mrope_section: tuple[int, int, int] | None = None, + name: str | None = None, + rope_type: str | None = None, +): + """A factory function to create an Attention as a Linen 
module. + + This function serves as a bridge to use the NNX-based `Attention` within a + Linen model. + """ + return nnx_wrappers.to_linen( + Attention, + config=config, + num_query_heads=num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + max_target_length=max_target_length, + mesh=mesh, + attention_kernel=attention_kernel, + inputs_q_shape=inputs_q_shape, + inputs_kv_shape=inputs_kv_shape, + dtype=dtype, + weight_dtype=weight_dtype, + max_prefill_predict_length=max_prefill_predict_length, + dropout_rate=dropout_rate, + kernel_init=kernel_init, + float32_qk_product=float32_qk_product, + float32_logits=float32_logits, + quant=quant, + kv_quant=kv_quant, + attention_type=attention_type, + attn_logits_soft_cap=attn_logits_soft_cap, + sliding_window_size=sliding_window_size, + use_ragged_attention=use_ragged_attention, + ragged_block_size=ragged_block_size, + use_qk_norm=use_qk_norm, + query_pre_attn_scalar=query_pre_attn_scalar, + use_bias_in_projections=use_bias_in_projections, + share_kv_projections=share_kv_projections, + temperature_tuning=temperature_tuning, + temperature_tuning_scale=temperature_tuning_scale, + temperature_tuning_floor_scale=temperature_tuning_floor_scale, + prefill_query_axis_names=prefill_query_axis_names, + prefill_key_axis_names=prefill_key_axis_names, + prefill_value_axis_names=prefill_value_axis_names, + query_axis_names=query_axis_names, + key_axis_names=key_axis_names, + value_axis_names=value_axis_names, + ep_query_axis_names=ep_query_axis_names, + ep_key_axis_names=ep_key_axis_names, + ep_value_axis_names=ep_value_axis_names, + input_axis_names=input_axis_names, + ep_input_axis_names=ep_input_axis_names, + out_axis_names=out_axis_names, + ep_out_axis_names=ep_out_axis_names, + prefill_input_axis_names=prefill_input_axis_names, + decode_input_axis_names=decode_input_axis_names, + prefill_out_axis_names=prefill_out_axis_names, + decode_out_axis_names=decode_out_axis_names, + 
prefill_cache_axis_order=prefill_cache_axis_order, + ar_cache_axis_order=ar_cache_axis_order, + compute_axis_order=compute_axis_order, + reshape_q=reshape_q, + is_nope_layer=is_nope_layer, + is_vision=is_vision, + model_mode=model_mode, + use_mrope=use_mrope, + mrope_section=mrope_section, + name=name, + rope_type=rope_type, + metadata_fn=variable_to_logically_partitioned, + abstract_init=False, + ) + + +class Attention(nnx.Module): + """Attention Module. + + This module implements multi-headed attention as described in the + original Transformer paper. It projects the inputs into query, key, and + value vectors, applies the attention mechanism, and projects the results to + an output vector. + + Attributes: + config: The model configuration. + num_query_heads: Number of query attention heads. + num_kv_heads: Number of key-value attention heads. + head_dim: The dimension of each attention head. + max_target_length: Maximum sequence length. + mesh: The device mesh. + attention_kernel: The attention kernel to use (e.g., 'dot_product', 'flash'). + inputs_q_shape: Query inputs shape for initialization, required by NNX. + inputs_kv_shape: Key/value inputs shape for initialization, required by NNX. + dtype: The data type for computation. + weight_dtype: The data type for weights. + max_prefill_predict_length: Maximum length for prefill. + dropout_rate: The dropout rate. + kernel_init: Initializer for the kernel of the dense layers. + float32_qk_product: If True, compute query-key product in float32. + float32_logits: If True, cast logits to float32 before softmax. + quant: Quantization configuration. + kv_quant: KV cache quantization configuration. + attention_type: The type of attention (e.g., 'global', 'local_sliding'). + attn_logits_soft_cap: Soft cap for attention logits. + ... and other configuration parameters. 
+ """ + + def __init__( + self, + config: Config, + num_query_heads: int, + num_kv_heads: int, + head_dim: int, + max_target_length: int, + mesh: Mesh, + attention_kernel: str, + inputs_q_shape: Tuple, + inputs_kv_shape: Tuple, + dtype: DType = jnp.float32, + weight_dtype: DType = jnp.float32, + max_prefill_predict_length: int = -1, + dropout_rate: float = 0.0, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"), + float32_qk_product: bool = False, # computes logits in float32 for stability. + float32_logits: bool = False, # cast logits in float32 for stability. + quant: Optional[Quant] = None, + kv_quant: Optional[KVQuant] = None, + attention_type: AttentionType = AttentionType.GLOBAL, # Default to global attention + attn_logits_soft_cap: float | None = None, + sliding_window_size: int | None = None, + use_ragged_attention: bool = False, + ragged_block_size: int = 256, + use_qk_norm: bool = False, + query_pre_attn_scalar: float | None = None, + use_bias_in_projections: bool = False, # Set to True will enable bias in q, k, v, o projections + share_kv_projections: bool = False, # If true, Key and Value use the same projection + # Temperature tuning parameters used for Llama4 + temperature_tuning: bool = False, + temperature_tuning_scale: float = 0.1, + temperature_tuning_floor_scale: float = 8192.0, + # Shard the query activation as the same as the key and value. + # TODO: Find a better sharding axis name. + # TODO: Further break down the Training and Inference axes for the q, k, v. 
+ prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), + prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), + prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), + ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), + ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED), + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), + ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV), + prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), + decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), + prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), + decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV), + prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3), + ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3), + compute_axis_order: AxisIdxes = (0, 1, 2, 3), + reshape_q: bool = False, + is_nope_layer: bool = False, + is_vision: bool = False, + model_mode: str = MODEL_MODE_TRAIN, + base_kv_cache: bool = True, + use_mrope: bool = False, + mrope_section: tuple[int, int, int] | None = None, + name: str | None = None, + rope_type: str | None = None, + rngs: Optional[nnx.Rngs] = None, + ): + """Initializes the Attention module. 
+
+    Args:
+      config: The model configuration.
+      num_query_heads: Number of query attention heads.
+      num_kv_heads: Number of key-value attention heads.
+      head_dim: The dimension of each attention head.
+      max_target_length: Maximum sequence length.
+      mesh: The device mesh.
+      attention_kernel: The attention kernel to use (e.g., 'dot_product', 'flash').
+      inputs_q_shape: Query inputs shape for initialization, required by NNX.
+      inputs_kv_shape: Key/value inputs shape for initialization, required by NNX.
+      dtype: The data type for computation.
+      weight_dtype: The data type for weights.
+      max_prefill_predict_length: Maximum length for prefill.
+      dropout_rate: The dropout rate.
+      kernel_init: Initializer for the kernel of the dense layers.
+      float32_qk_product: If True, compute query-key product in float32.
+      float32_logits: If True, cast logits to float32 before softmax.
+      quant: Quantization configuration.
+      kv_quant: KV cache quantization configuration.
+      attention_type: The type of attention (e.g., 'global', 'local_sliding').
+      attn_logits_soft_cap: Soft cap for attention logits.
+      sliding_window_size: The size of the sliding window for local attention.
+      use_ragged_attention: Whether to use ragged attention for decoding.
+      ragged_block_size: The block size for ragged attention.
+      use_qk_norm: Whether to apply normalization to query and key.
+      query_pre_attn_scalar: Scalar to apply to query before attention.
+      use_bias_in_projections: Whether to use bias in Q, K, V, and output projections.
+      temperature_tuning: Whether to use temperature tuning for attention.
+      temperature_tuning_scale: The scale for temperature tuning.
+      temperature_tuning_floor_scale: The floor scale for temperature tuning.
+      ... other configuration parameters.
+      is_nope_layer: Whether this is a "NoPE" (No Position-Embedding) layer.
+      is_vision: Whether this is a vision attention layer.
+      model_mode: The model's operational mode (e.g., 'train', 'prefill').
+ base_kv_cache: Whether to use base (non-MLA) kv cache, if KVCache is used + rope_type: Optional override for the RoPE type (e.g., 'default', 'yarn'). + If provided, this takes precedence over `config.rope_type`. + rngs: RNG state for initialization, passed by the nnx.to_linen wrapper. + """ + + self.config = config + self.num_query_heads = num_query_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.max_target_length = max_target_length + self.mesh = mesh + self.attention_kernel = attention_kernel + self.dtype = dtype + self.weight_dtype = weight_dtype + self.max_prefill_predict_length = max_prefill_predict_length + self.dropout_rate = dropout_rate + self.kernel_init = kernel_init + self.float32_qk_product = float32_qk_product + self.float32_logits = float32_logits + self.quant = quant + self.kv_quant = kv_quant + self.attention_type = attention_type + self.attn_logits_soft_cap = attn_logits_soft_cap + self.sliding_window_size = sliding_window_size + self.use_ragged_attention = use_ragged_attention + self.ragged_block_size = ragged_block_size + self.use_qk_norm = use_qk_norm + self.query_pre_attn_scalar = query_pre_attn_scalar + self.use_bias_in_projections = use_bias_in_projections + self.share_kv_projections = share_kv_projections + self.temperature_tuning = temperature_tuning + self.temperature_tuning_scale = temperature_tuning_scale + self.temperature_tuning_floor_scale = temperature_tuning_floor_scale + self.prefill_query_axis_names = prefill_query_axis_names + self.prefill_key_axis_names = prefill_key_axis_names + self.prefill_value_axis_names = prefill_value_axis_names + self.query_axis_names = query_axis_names + self.key_axis_names = key_axis_names + self.value_axis_names = value_axis_names + self.ep_query_axis_names = ep_query_axis_names + self.ep_key_axis_names = ep_key_axis_names + self.ep_value_axis_names = ep_value_axis_names + self.input_axis_names = input_axis_names + self.ep_input_axis_names = ep_input_axis_names + 
self.out_axis_names = out_axis_names + self.ep_out_axis_names = ep_out_axis_names + self.prefill_input_axis_names = prefill_input_axis_names + self.decode_input_axis_names = decode_input_axis_names + self.prefill_out_axis_names = prefill_out_axis_names + self.decode_out_axis_names = decode_out_axis_names + self.prefill_cache_axis_order = prefill_cache_axis_order + self.ar_cache_axis_order = ar_cache_axis_order + self.compute_axis_order = compute_axis_order + self.reshape_q = reshape_q + self.is_nope_layer = is_nope_layer + self.is_vision = is_vision + self.model_mode = model_mode + self.use_mrope = use_mrope + self.mrope_section = mrope_section + self.rngs = rngs + # Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config. + self.rope_type = (rope_type or self.config.rope_type).lower() + + self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2 + self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT + + # Module attribute names must match names previously passed to Linen for checkpointing + self.KVCache_0 = ( + self.init_kv_caches(inputs_kv_shape=inputs_kv_shape) + if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache and config.attention != "vllm_rpa" + else None + ) + + self.rotary_embedding = self.init_rotary_embedding() + + self.attention_op = AttentionOp( + config=self.config, + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + kv_quant=self.kv_quant, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + compute_axis_order=self.compute_axis_order, + reshape_q=self.reshape_q, + attention_type=self.attention_type, + attn_logits_soft_cap=self.attn_logits_soft_cap, + 
sliding_window_size=self.sliding_window_size, + chunk_attn_window_size=self.config.chunk_attn_window_size, + use_ragged_attention=self.use_ragged_attention, + ragged_block_size=self.ragged_block_size, + rngs=self.rngs, + ) + # When paged attention is enabled, paged attention op is used for all model modes except TRAIN, + # which uses default attention op. + if self.config.attention == "paged": + self.paged_attention_op = paged_attention.PagedAttentionOp( + mesh=self.mesh, + num_pages=self.config.pagedattn_num_pages, + tokens_per_page=self.config.pagedattn_tokens_per_page, + max_pages_per_slot=(self.config.max_target_length + self.config.pagedattn_tokens_per_page - 1) + // self.config.pagedattn_tokens_per_page, + max_pages_per_prefill=(self.config.max_prefill_predict_length + self.config.pagedattn_tokens_per_page - 1) + // self.config.pagedattn_tokens_per_page, + pages_per_compute_block=self.config.pagedattn_pages_per_compute_block, + num_kv_heads=self.num_kv_heads, + kv_head_dim_size=self.head_dim, + dtype=self.dtype, + attn_logits_soft_cap=self.attn_logits_soft_cap, + rngs=self.rngs, + ) + + self._init_projections(inputs_q_shape, inputs_kv_shape) + + if self.config.attention_sink: + self.sinks = nnx.Param( + default_bias_init(self.rngs.params(), (self.config.num_query_heads,), self.weight_dtype), + sharding=(None,), + ) + else: + self.sinks = None + + is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4 + + if self.use_qk_norm and not is_llama4_decoder_block: + # Check if this is Olmo3, which uses a unique "Global" QK Norm strategy. + # GlobalRMSNorm flattens (Heads, Dim) to normalize across the entire hidden state. + use_global_qk_norm = self.config.model_name.startswith("olmo3") + qk_norm_cls = GlobalRMSNorm if use_global_qk_norm else RMSNorm + + # For RMSNorm use `head_dim` (per-head normalization), while for GlobalRMSNorm use `num_heads * head_dim` (global normalization). 
+ q_features = (self.num_query_heads * self.head_dim) if use_global_qk_norm else self.head_dim + k_features = (self.num_kv_heads * self.head_dim) if use_global_qk_norm else self.head_dim + + self.query_norm = qk_norm_cls( + num_features=q_features, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + shard_mode=self.config.shard_mode, + epsilon=self.config.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=self.rngs, + ) + self.key_norm = qk_norm_cls( + num_features=k_features, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + shard_mode=self.config.shard_mode, + epsilon=self.config.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=self.rngs, + ) + elif self.is_qwen3_next: + self.query_norm = Qwen3NextRMSNorm( + num_features=self.config.head_dim, + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + rngs=self.rngs, + ) + self.key_norm = Qwen3NextRMSNorm( + num_features=self.config.head_dim, + eps=self.config.normalization_layer_epsilon, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + rngs=self.rngs, + ) + else: + self.query_norm = None + self.key_norm = None + + self._maybe_shard_with_logical = functools.partial( + maybe_shard_with_logical, + mesh=mesh, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + ) + + def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: + """Initializes the query, key, value, and output projections.""" + if self.config.fused_qkv: + self.qkv_proj = self.init_qkv_w(inputs_shape=inputs_q_shape) + else: + self.query = self.init_query_w(inputs_q_shape=inputs_q_shape) + self.key = self.init_kv_w(inputs_kv_shape=inputs_kv_shape) + if not self.share_kv_projections: + self.value = self.init_kv_w(inputs_kv_shape=inputs_kv_shape) + self.out = self.init_out_w(output_dim=inputs_q_shape[-1]) + + def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: + 
"""Query projection initialization.""" + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + # We disable depth_scaling when using qk_norm or a query_pre_attn_scalar + # to avoid applying scaling twice. + if self.config.use_qk_norm or (self.query_pre_attn_scalar is not None and self.query_pre_attn_scalar != 1.0): + depth_scaling = 1.0 + else: + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + + def query_init(*args): + # pylint: disable=no-value-for-parameter + return self.kernel_init(*args) / depth_scaling + + kernel_axes = ( + (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv") + ) + in_features = self.convert_dense_general_inputs_shape(inputs_q_shape) + out_features = (self.num_query_heads, self.head_dim) + + if self.is_qwen3_next: + out_features = (self.num_query_heads, self.head_dim * 2) + + return DenseGeneral( + in_features_shape=in_features, + out_features_shape=out_features, + axis=-1, + kernel_init=query_init, + kernel_axes=kernel_axes, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + use_bias=self.use_bias_in_projections, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) + + def query_projection(self, inputs_q: Array, out_sharding: NamedSharding | None = None) -> Array: + """Query projection.""" + + return self.query(inputs_q, out_sharding=out_sharding) + + def init_kv_w(self, inputs_kv_shape: Tuple) -> nnx.Module: + """Initializes the key or value projection. + + Args: + inputs_kv_shape: Key/value inputs shape for initialization. + + Returns: + A DenseGeneral module that performs the key or value projection. 
+ """ + if self.num_kv_heads == -1: + raise ValueError("num_kv_heads is not defined.") + + if self.num_query_heads % self.num_kv_heads != 0: + raise ValueError("Invalid num_kv_heads for GQA.") + + kernel_axes = ( + (None, None, None) + if self.config.ici_context_autoregressive_parallelism > 1 + else ("embed", "kv_heads", "kv_head_dim") + ) + + return DenseGeneral( + in_features_shape=self.convert_dense_general_inputs_shape(inputs_kv_shape), + out_features_shape=(self.num_kv_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=kernel_axes, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + use_bias=self.use_bias_in_projections, + rngs=self.rngs, + ) + + def kv_projection(self, inputs_kv: Array, proj_name: str, out_sharding: NamedSharding | None = None) -> nnx.Module: + """Applies the key or value projection. + + Args: + inputs_kv: The input tensor to project. + proj_name: The name of the projection ("key" or "value"). + + Returns: + The projected key or value tensor. + + Raises: + ValueError: If `proj_name` is not one of the supported values + ("key", "value"). 
+ + """ + if proj_name == "key": + return self.key(inputs_kv, out_sharding=out_sharding) + elif proj_name == "value": + return self.value(inputs_kv, out_sharding=out_sharding) + else: + raise ValueError(f"proj_name must be 'key' or 'value', but got {proj_name}") + + def init_qkv_w(self, inputs_shape: Tuple) -> nnx.Module: + return DenseGeneral( + in_features_shape=self.convert_dense_general_inputs_shape(inputs_shape), + out_features_shape=(3, self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + use_bias=self.use_bias_in_projections, + rngs=self.rngs, + ) + + def qkv_projection(self, inputs: Array, proj_name: str, out_sharding: NamedSharding | None = None): + """Fused QKV projection""" + + qkv_proj = self.qkv_proj(inputs, out_sharding) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] 
+    return query, key, value
+
+  def init_out_w(self, output_dim: int) -> nnx.Module:
+    """out projection"""
+    in_features = (self.num_query_heads, self.head_dim)
+    out_features = output_dim
+    out_kernel_axis = (
+        (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
+    )
+    axis = (-2, -1)
+
+    if self.is_qwen3_next:
+      in_features = self.num_query_heads * self.head_dim
+      out_kernel_axis = ("mlp", "embed")
+      axis = (-1,)
+
+    return DenseGeneral(
+        in_features_shape=in_features,
+        out_features_shape=out_features,
+        axis=axis,
+        kernel_init=self.kernel_init,
+        kernel_axes=out_kernel_axis,  # trade speed with memory
+        dtype=self.dtype,
+        weight_dtype=self.weight_dtype,
+        quant=self.quant,
+        shard_mode=self.config.shard_mode,
+        matmul_precision=self.config.matmul_precision,
+        use_bias=False if self.is_qwen2 else self.use_bias_in_projections,
+        rngs=self.rngs,
+    )
+
+  def out_projection(self, out: Array, out_sharding: NamedSharding | None = None) -> Array:
+    """out projection"""
+    return self.out(out, out_sharding=out_sharding)
+
+  def convert_dense_general_inputs_shape(
+      self,
+      inputs_shape: tuple[int, ...] | None = None,
+      axis: Union[Iterable[int], int] = -1,
+  ) -> Union[Iterable[int], int]:
+    axis = canonicalize_tuple(axis)
+    return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))
+
+  def init_rotary_embedding(self):
+    """Initializes the rotary embeddings, handling different model types.
+
+    Returns:
+      The rotary embedding module that will be used in the model.
+    """
+    if self.config.attention_type == AttentionType.MLA.value:
+      # For MLA attention RoPE is applied to only `self.qk_rope_head_dim` portion of the heads.
+ rope_embedding_dims = self.qk_rope_head_dim + else: + rope_embedding_dims = self.head_dim + + rope_type = self.rope_type + rope_use_scale = self.config.rope_use_scale + if self.is_vision: + if self.config.model_name.startswith("qwen3-omni"): + rotary_embedding = Qwen3OmniMoeVisionRotaryEmbedding( + hidden_size=self.config.hidden_size_for_vit, + num_attention_heads=self.config.num_attention_heads_for_vit, + spatial_merge_size=self.config.spatial_merge_size_for_vit, + rope_theta=self.config.rope_theta_for_vit, + fprop_dtype=self.dtype, + rngs=self.rngs, + ) + elif self.config.model_name.startswith("llama4"): + rotary_embedding = LlamaVisionRotaryEmbedding( + image_size=self.config.image_size_for_vit, + patch_size=self.config.patch_size_for_vit, + hidden_size=self.config.hidden_size_for_vit, + num_attention_heads=self.config.num_attention_heads_for_vit, + rope_theta=self.config.rope_theta_for_vit, + cast_as_fprop_dtype=True, + fprop_dtype=self.dtype, + rngs=self.rngs, + ) + else: + raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}") + + elif self.use_mrope: + rotary_embedding = Qwen3OmniMoeThinkerTextRotaryEmbedding( + min_timescale=self.config.rope_min_timescale, + max_timescale=self.config.rope_max_timescale, + embedding_dims=rope_embedding_dims, + cast_as_fprop_dtype=True, + fprop_dtype=self.dtype, + mrope_section=self.mrope_section, + rngs=self.rngs, + ) + + elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"): + rotary_embedding = LLaMARotaryEmbedding( + min_timescale=self.config.rope_min_timescale, + max_timescale=self.config.rope_max_timescale, + mesh=self.mesh, + embedding_dims=rope_embedding_dims, + fprop_dtype=self.dtype, + use_scale=rope_use_scale, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) + elif rope_type.startswith("yarn"): + rotary_embedding = YarnRotaryEmbedding( + max_position_embeddings=self.config.max_position_embeddings, + mesh=self.mesh, + 
original_max_position_embeddings=self.config.original_max_position_embeddings,
+          beta_fast=self.config.beta_fast,
+          beta_slow=self.config.beta_slow,
+          rope_theta=self.config.rope_max_timescale,
+          rope_factor=self.config.rope_factor,
+          embedding_dims=rope_embedding_dims,
+          fprop_dtype=self.dtype,
+          interleave=self.config.rope_interleave,
+          truncate=self.config.rope_truncate,
+          attention_scaling=self.config.rope_attention_scaling,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
+    elif self.is_qwen3_next:
+      rotary_embedding = PartialRotaryEmbedding(
+          min_timescale=self.config.rope_min_timescale,
+          max_timescale=self.config.rope_max_timescale,
+          mesh=self.mesh,
+          embedding_dims=self.config.head_dim,
+          partial_rotary_factor=self.config.partial_rotary_factor,
+          cast_as_fprop_dtype=True,
+          fprop_dtype=self.config.dtype,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
+    else:
+      max_timescale = self.config.rope_max_timescale
+      # For local attention use local_rope_max_timescale if it is positive
+      if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.local_rope_max_timescale > 0:
+        max_timescale = self.config.local_rope_max_timescale
+
+      rope_linear_scaling_factor = self.config.rope_linear_scaling_factor
+      # In gemma3, linear scaling factor does not apply to local sliding layers.
+ if self.config.model_name.startswith("gemma3") and self.attention_type == AttentionType.LOCAL_SLIDING: + rope_linear_scaling_factor = 1.0 + + rotary_embedding = RotaryEmbedding( + min_timescale=self.config.rope_min_timescale, + max_timescale=max_timescale, + mesh=self.mesh, + embedding_dims=rope_embedding_dims, + fprop_dtype=self.dtype, + rope_linear_scaling_factor=rope_linear_scaling_factor, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) + return rotary_embedding + + def apply_rotary_embedding( + self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict | None = None + ): + """Applies rotary embeddings, handling different model types. + + Args: + inputs: The input tensor to apply rotary embeddings to. + inputs_positions: The positions of the inputs. + rope_kwargs: A dictionary of keyword arguments for the rotary embedding. + + Returns: + The input tensor with rotary embeddings applied. + """ + if isinstance(self.rotary_embedding, Qwen3OmniMoeVisionRotaryEmbedding): + # For Qwen3OmniMoe vision, pass static dimensions from kwargs. + num_frames = rope_kwargs.get("num_frames") + height = rope_kwargs.get("height") + width = rope_kwargs.get("width") + # Type cast required: Omni rotary embedding uses different __call__ parameters than other embeddings. + return cast(Qwen3OmniMoeVisionRotaryEmbedding, self.rotary_embedding)(inputs, num_frames, height, width) + else: + return self.rotary_embedding(inputs, inputs_positions) + + def init_kv_caches(self, inputs_kv_shape: Tuple): + """Initializes KVCache. + + Args: + inputs_kv_shape: Key/value inputs shape for initialization. + + Returns: + A KVCache module instance. + + """ + batch_size, _, _ = inputs_kv_shape + # During initialization, seq_len of inputs_kv is max_target_length, + # which is not always correct for some functions in KVCache. + # However, KVCache internal cache shapes are based on max_prefill_length + # and max_target_length, not the passed seq_len. 
+ # We can use a placeholder value. The correct fix might involve refactoring + # KVCache. + placeholder_seq_len = 1 + + return kvcache.KVCache( + max_prefill_length=self.max_prefill_predict_length, + max_target_length=self.max_target_length, + batch=batch_size, + key_seq_len=placeholder_seq_len, + value_seq_len=placeholder_seq_len, + key_heads=self.num_kv_heads, + value_heads=self.num_kv_heads, + key_head_size=self.head_dim, + value_head_size=self.head_dim, + dtype=self.dtype, + kv_quant=self.kv_quant, + prefill_cache_axis_order=self.prefill_cache_axis_order, + ar_cache_axis_order=self.ar_cache_axis_order, + use_chunked_prefill=self.config.use_chunked_prefill, + model_mode=self.model_mode, + rngs=self.rngs, + ) + + def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous_chunk): + """Updates the KV caches for prefill and autoregressive modes. + + This method uses a kvcache module to update and retrieve the key-value + caches based on the current operational mode. + + Args: + key: The key tensor for the current attention computation. + value: The value tensor for the current attention computation. + decoder_segment_ids: Segment IDs for the decoder, used for masking. + model_mode: The operational mode ('train', 'prefill', 'autoregressive'). + previous_chunk: Information about previously processed chunks, used for + chunked prefill. + + Returns: + A list containing two elements: + - The prefill key-value cache, or None. + - The autoregressive key-value cache, or None. 
+ """ + prefill_kv_cache, ar_kv_cache = self.KVCache_0( + key=key, + value=value, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + use_ragged_attention=self.use_ragged_attention, + previous_chunk=previous_chunk, + ) + return [prefill_kv_cache, ar_kv_cache] + + def forward_serve_vllm( + self, + query: Array, + key: Array, + value: Array, + rpa_kv_cache: list[Array] | None = None, + rpa_metadata: dict[str, Any] | None = None, + ) -> tuple[list[Array], Array]: + """Forward function for vLLM serving with RPA attention.""" + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + from tpu_inference.layers.common.attention_interface import sharded_ragged_paged_attention as rpa_ops + except ImportError as e: + raise ImportError( + "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`." + ) from e + + if rpa_kv_cache is None or rpa_metadata is None: + raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.") + + query = query.reshape(-1, query.shape[2], query.shape[3]) + key = key.reshape(-1, key.shape[2], key.shape[3]) + value = value.reshape(-1, value.shape[2], value.shape[3]) + + if self.config.sliding_window_size > 0: + attention_chunk_size = self.config.sliding_window_size + else: + # Chunked attention currently not used in vLLM RPA. 
  def __call__(
      self,
      inputs_q: Array,
      inputs_kv: Array,
      inputs_positions: Array | None = None,
      decoder_segment_ids: Array | None = None,
      out_sharding: NamedSharding | None = None,
      *,
      model_mode: str = MODEL_MODE_TRAIN,
      deterministic: bool = False,
      previous_chunk: Any = None,
      slot: Optional[int] = None,
      page_state: Optional[page_manager.PageState] = None,
      bidirectional_mask: Any = None,
      rope_kwargs: dict | None = None,
      kv_cache: Optional[Array] = None,
      attention_metadata: Optional[dict[str, Any]] = None,
  ):
    """Applies Attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and project the results to an output vector.

    This method handles three modes:
    1. **Training**: The KV cache is ignored.
    2. **Prefill**: The KV cache is filled with the key-value pairs from the input sequence.
    3. **Autoregressive Decoding**: The KV cache is used to provide context from previous steps.

    In the cache initialization call, `inputs_q` has a shape [batch, length,
    q_features] and `inputs_kv`: [batch, length, kv_features]. During the
    incremental decoding stage, query, key and value all have the shape [batch,
    1, qkv_features] corresponding to a single step.

    Args:
      inputs_q: Input queries of shape `[batch, q_length, q_features]`.
      inputs_kv: Key/values of shape `[batch, kv_length, kv_features]`.
      inputs_positions: Input positions for rotary embeddings.
      decoder_segment_ids: Segment IDs for masking.
      out_sharding: Optional explicit sharding for the final output projection.
      model_mode: The operational mode ('train', 'prefill', 'autoregressive').
      deterministic: If True, disables dropout. NOTE(review): not referenced in
        this method body — presumably consumed by sub-modules; confirm.
      previous_chunk: Information about previously processed chunks for chunked prefill.
      slot: The batch slot index for paged attention.
      page_state: The current state of the paged attention manager.
      bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
      rope_kwargs: Optional extra keyword arguments forwarded to the rotary embedding.
      kv_cache: Optional KV cache input, used when invoking from vLLM.
      attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.

    Returns:
      output of shape `[batch, length, q_features]`.
    """
    # Select the logical sharding axes for the activations based on mode;
    # train mode additionally distinguishes expert-parallel-as-context layout.
    if model_mode == MODEL_MODE_PREFILL:
      input_axis_names = self.prefill_input_axis_names
    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
      input_axis_names = self.ep_input_axis_names
    elif model_mode == MODEL_MODE_TRAIN:
      input_axis_names = self.input_axis_names
    else:
      input_axis_names = self.decode_input_axis_names

    inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names)
    inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names)
    qkv_sharding = create_sharding(self.mesh, input_axis_names)

    # apply projection.
    if self.config.fused_qkv:
      query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
    else:
      query = self.query_projection(inputs_q, out_sharding=qkv_sharding)
      key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=qkv_sharding)
      if self.share_kv_projections:
        # K and V share one projection: value aliases key.
        value = key
      else:
        value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=qkv_sharding)

    gate = None
    if self.is_qwen3_next:
      # Split query into query & gate. The gate is flattened to
      # [batch, seq, num_query_heads * head_dim] and applied after attention.
      query, gate = jnp.split(query, 2, axis=-1)
      batch_size, seq_len, _, _ = gate.shape
      gate = gate.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)

    is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
    # NOTE: llama 4 does L2 normalization after RoPE
    # Apply Qwen3Next specific RMS Norm
    if (self.use_qk_norm and not is_llama4_decoder_block) or self.is_qwen3_next:
      query = self.query_norm(query)
      key = self.key_norm(key)

    # NOTE: is_nope_layer should be used in attention mask and also used in attention tuning
    use_rope = not self.is_nope_layer
    use_qk_norm = self.use_qk_norm and use_rope

    if use_rope:
      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

    if use_qk_norm and is_llama4_decoder_block:
      # Llama-4 path: L2-normalize Q/K after RoPE instead of RMS-norm before it.
      l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
      query = l2_norm(query)
      key = l2_norm(key)

    # apply query_pre_attn_scalar if it's present.
    if self.query_pre_attn_scalar and self.query_pre_attn_scalar != 1.0:
      query = query * self.query_pre_attn_scalar

    if self.temperature_tuning and not use_rope:
      # Position-dependent attention temperature for NoPE layers.
      attn_scales = (
          jnp.log(jnp.floor((inputs_positions.astype(self.dtype) + 1.0) / self.temperature_tuning_floor_scale) + 1.0)
          * self.temperature_tuning_scale
          + 1.0
      )
      query = (query * attn_scales[:, :, jnp.newaxis, jnp.newaxis]).astype(self.dtype)

    # Re-shard the projected Q/K/V for the chosen mode before the attention op.
    if model_mode == MODEL_MODE_PREFILL:
      query = self._maybe_shard_with_logical(query, self.prefill_query_axis_names)
      key = self._maybe_shard_with_logical(key, self.prefill_key_axis_names)
      value = self._maybe_shard_with_logical(value, self.prefill_value_axis_names)
    elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
      query = self._maybe_shard_with_logical(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
      key = self._maybe_shard_with_logical(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
      value = self._maybe_shard_with_logical(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
      query = self._maybe_shard_with_logical(query, self.ep_query_axis_names)
      key = self._maybe_shard_with_logical(key, self.ep_key_axis_names)
      value = self._maybe_shard_with_logical(value, self.ep_value_axis_names)
    else:
      query = self._maybe_shard_with_logical(query, self.query_axis_names)
      key = self._maybe_shard_with_logical(key, self.key_axis_names)
      value = self._maybe_shard_with_logical(value, self.value_axis_names)

    # Name the activations for rematerialization policies.
    query = checkpoint_name(query, "query_proj")
    key = checkpoint_name(key, "key_proj")
    value = checkpoint_name(value, "value_proj")

    assert not self.config.quantize_kvcache or self.kv_quant

    if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
      # Paged attention returns an unnormalized output plus the softmax
      # denominator; normalize here (epsilon guards against division by zero).
      unnormalized_out, _, exp_sum = self.paged_attention_op(
          query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
      )
      out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out

    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
      # vLLM ragged-paged-attention path: the external KV cache is threaded
      # through and returned to the caller.
      batch, seq_len, num_heads, head_dim = query.shape
      updated_kv, attn_out = self.forward_serve_vllm(
          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
      )
      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
      kv_cache = updated_kv

    else:
      cached_values = [None, None]
      if model_mode != MODEL_MODE_TRAIN:
        cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
      out = self.attention_op(
          query,
          key,
          value,
          decoder_segment_ids,
          model_mode,
          cached_values,
          previous_chunk,
          bidirectional_mask,
          self.sinks,
      )
    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
    if model_mode == MODEL_MODE_PREFILL:
      out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
      out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
    elif model_mode == MODEL_MODE_TRAIN:
      out = self._maybe_shard_with_logical(out, self.out_axis_names)
    else:
      out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
    if self.is_qwen3_next:
      # Qwen3-Next output gating: sigmoid(gate) elementwise on the flattened heads.
      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
      out = out * jax.nn.sigmoid(gate)
    out = self.out_projection(out, out_sharding=out_sharding)
    if self.config.distill_beta > 0.0:
      self.sow(nnx.Intermediate, "out_projection_activations", out)
    out = checkpoint_name(out, "out_proj")
    return out, kv_cache
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedding Layers.""" + +import dataclasses +import math + +import jax +from jax import lax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding + +from flax import nnx + +from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType +from maxtext.layers import nnx_wrappers +from maxtext.layers.initializers import Initializer, default_embed_init, variable_to_logically_partitioned +from maxtext.utils import max_logging +from maxtext.utils import max_utils +from maxtext.utils.sharding import logical_to_mesh_axes, create_sharding + +_MAX_WAVELENGTH = 10_000 + + +def _maybe_move_embedding_to_device(embedding_table: Array, config: Config) -> Array: + """Moves embedding table to device if parameter offloading is enabled.""" + if config.parameter_memory_host_offload: + max_logging.log("embeddings.py: Moving embedding parameter to device") + return jax.device_put(embedding_table, max_utils.device_space()) + return embedding_table + + +def embed_as_linen( + *, + num_embeddings: int, + num_features: int, + config: Config, + mesh: Mesh, + cast_input_dtype: None | DType = None, + dtype: DType = jnp.float32, + attend_dtype: None | DType = None, + embedding_init: Initializer = default_embed_init, + name: str | None = None, +): + """Initializes the Embed NNX module and returns it as a Linen module. 
  This function serves as a bridge to use the NNX-based `Embed` module within
  a Linen model. It wraps the `Embed` module using `nnx.bridge.to_linen`,
  making it compatible with the Linen API.

  Args:
    num_embeddings: The number of embeddings.
    num_features: The number of feature dimensions for each embedding.
    config: The model configuration.
    mesh: The device mesh used to derive output shardings.
    cast_input_dtype: The dtype to cast the input to, if any.
    dtype: The dtype of the embedding vectors.
    attend_dtype: The dtype for the `attend` method.
    embedding_init: The initializer for the embedding matrix.
    name: The name of the Linen module.

  Returns:
    A Linen module that wraps the NNX `Embed` module.
  """
  return nnx_wrappers.to_linen(
      Embed,
      num_embeddings=num_embeddings,
      num_features=num_features,
      config=config,
      mesh=mesh,
      cast_input_dtype=cast_input_dtype,
      dtype=dtype,
      attend_dtype=attend_dtype,
      embedding_init=embedding_init,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )


class Embed(nnx.Module):
  """A parameterized function from integers [0, n) to d-dimensional vectors."""

  def __init__(
      self,
      num_embeddings: int,
      num_features: int,
      config: Config,
      mesh: Mesh,
      cast_input_dtype: None | DType = None,
      dtype: DType = jnp.float32,
      attend_dtype: None | DType = None,
      embedding_init: Initializer = default_embed_init,
      *,
      # Not used in Embed but passed in by nnx.bridge.to_linen.
      # TODO: Remove when bridge no longer needed
      rngs: nnx.Rngs,
  ):
    """Initializes the Embed module.

    Args:
      num_embeddings: The number of embeddings.
      num_features: The number of feature dimensions for each embedding.
      config: The model configuration.
      mesh: The device mesh used to derive output shardings.
      cast_input_dtype: The dtype to cast the input to, if any.
      dtype: The dtype of the embedding vectors.
      attend_dtype: The dtype for the `attend` method.
      embedding_init: The initializer for the embedding matrix.
      rngs: The random number generators for initialization.
    """
    self.num_embeddings = num_embeddings
    self.num_features = num_features
    self.config = config
    self.mesh = mesh
    self.cast_input_dtype = cast_input_dtype
    self.dtype = dtype
    self.attend_dtype = attend_dtype

    # Embedding table is stored in `config.weight_dtype` and logically
    # sharded over ("vocab", "embed").
    self.embedding = nnx.Param(
        embedding_init(
            rngs.params(),
            (self.num_embeddings, self.num_features),
            self.config.weight_dtype,
        ),
        sharding=("vocab", "embed"),
    )

  def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
    """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.
      model_mode: Selects the prefill vs. train/decode output sharding axes.

    Returns:
      Output which is embedded input data. The output shape follows the input,
      with an additional `num_features` dimension appended.
    """
    cfg = self.config
    if self.cast_input_dtype:
      inputs = inputs.astype(self.cast_input_dtype)
    if not jnp.issubdtype(inputs.dtype, jnp.integer):
      raise ValueError("Input type must be an integer or unsigned integer.")

    embedding = jnp.asarray(
        _maybe_move_embedding_to_device(self.embedding.value, self.config),
        self.dtype,
    )

    # Prefill uses a dedicated length axis; all other modes share one layout.
    output_axis_names = (
        (
            "activation_embed_and_logits_batch",
            "prefill_activation_length",
            "activation_embed",
        )
        if model_mode == MODEL_MODE_PREFILL
        else (
            "activation_embed_and_logits_batch",
            "activation_length_no_exp",
            "activation_embed",
        )
    )
    out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh)

    # Explicit shard mode passes a concrete NamedSharding; auto mode lets the
    # compiler decide (out_sharding=None).
    out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None

    if cfg.use_iota_embed:
      # One-hot matmul formulation of the lookup (iota + dot) instead of a
      # gather; mathematically equivalent to indexing into the table.
      iota = lax.iota(jnp.int32, self.num_embeddings)
      one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
      output = jnp.dot(one_hot, embedding, out_sharding=out_sharding)
    else:
      output = embedding.at[inputs].get(out_sharding=out_sharding)

    return output

  def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array:
    """Attend over the embedding using a query array.

    Args:
      query: array with last dimension equal the feature depth `num_features` of the
        embedding.
      out_sharding: NamedSharding object indicating how the output gets sharded

    Returns:
      An array with final dim `num_embeddings` corresponding to the batched
      inner-product of the array of query vectors against each embedding.
      Commonly used for weight-sharing between embeddings and logit transform
      in NLP models.
    """
    embedding = self.embedding.value
    attend_dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
    return attend_on_embedding(query, embedding, attend_dtype, self.config, out_sharding)


def attend_on_embedding(
    query: Array,
    embedding_table: Array,
    attend_dtype: DType,
    config: Config,
    out_sharding: NamedSharding | None = None,
) -> Array:
  """Attend over an embedding table using a query array.

  TODO: Remove this method when Embed bridge to Linen is no longer needed

  Args:
    query: An array with a last dimension equal to the feature depth of the embedding.
    embedding_table: The embedding table to attend over.
    attend_dtype: The data type for the attention computation.
    config: The model configuration, used to check for parameter offloading.
    out_sharding: NamedSharding object indicating the output sharding

  Returns:
    An array with a final dimension equal to `num_embeddings`, corresponding to the
    batched inner-product of the query vectors against each embedding.
+ """ + # out_sharding must be None under auto shard_mode + if config.shard_mode != ShardMode.EXPLICIT: + out_sharding = None + embedding_table = _maybe_move_embedding_to_device(embedding_table, config) + return jnp.dot( + query, + jnp.asarray(embedding_table, jnp.bfloat16).T, + preferred_element_type=attend_dtype, + out_sharding=out_sharding, + ) + + +def rotary_embedding_as_linen( + *, + min_timescale: int, + max_timescale: int, + embedding_dims: int = 0, + cast_as_fprop_dtype: bool = True, + fprop_dtype: DType = jnp.bfloat16, + name: str | None = None, +): + """Initializes the RotaryEmbedding module and returns it as a Linen module. + + Args: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + cast_as_fprop_dtype: Whether to cast the output to the fprop dtype. + fprop_dtype: The dtype of the output. + name: Name of the Linen module. + """ + return nnx_wrappers.to_linen( + RotaryEmbedding, + min_timescale=min_timescale, + max_timescale=max_timescale, + embedding_dims=embedding_dims, + cast_as_fprop_dtype=cast_as_fprop_dtype, + fprop_dtype=fprop_dtype, + metadata_fn=variable_to_logically_partitioned, + name=name, + ) + + +class RotaryEmbedding(nnx.Module): + """Rotary Position Embedding.""" + + def __init__( + self, + min_timescale: int, + max_timescale: int, + mesh: Mesh, + embedding_dims: int = 0, + cast_as_fprop_dtype: bool = True, + fprop_dtype: DType = jnp.bfloat16, + shard_mode: ShardMode = ShardMode.AUTO, + # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen. + # TODO: Remove when bridge no longer needed + rope_linear_scaling_factor: float = 1.0, + rngs: nnx.Rngs = None, + ): + """Initializes the RotaryEmbedding module. + + Args: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. 
+ max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + cast_as_fprop_dtype: Whether to cast the output to the fprop dtype. + fprop_dtype: The dtype of the output. + rngs: rng keys passed in by nnx.bridge.to_linen. + """ + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.mesh = mesh + self.embedding_dims = embedding_dims + self.cast_as_fprop_dtype = cast_as_fprop_dtype + self.fprop_dtype = fprop_dtype + self.shard_mode = shard_mode + self.rope_linear_scaling_factor = rope_linear_scaling_factor + + if self.embedding_dims % 2: + raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") + + @property + def timescale(self): + """Returns the timescale for the rotary embedding.""" + half_embedding_dim = self.embedding_dims // 2 + fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims + timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction + if self.rope_linear_scaling_factor != 1.0: + timescale = timescale * self.rope_linear_scaling_factor + return timescale + + def _rotate_half(self, x: jax.Array) -> jax.Array: + """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1).""" + x1, x2 = jnp.split(x, 2, axis=-1) + return jnp.concatenate((-x2, x1), axis=-1) + + def apply_rotary(self, inputs: jax.Array, cos: jax.Array, sin: jax.Array) -> jax.Array: + """Applies the rotary transformation logic.""" + return (inputs * cos) + (self._rotate_half(inputs) * sin) + + def __call__( + self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks + inputs: jax.Array, + position: None | jax.Array = None, + ) -> jax.Array: + """Generates a jax.Array of sinusoids with different frequencies. + + Args: + inputs: The input sequence on which to apply the Rotary position + embedding. 
Since rotary position embeddings are applied to query and + keys after projection, it is assumed of shape [B, S, N, H]. + position: Optional position jax.Array which denotes the position of each + token in the sequence. This only needs to be supplied when the sequence + is packed. It is of shape [B, S]. + + Returns: + a jax.Array of shape [B, S, N, H] which includes the inputs together with + the rotary position embedding incorporated in it. + """ + assert position is not None + if len(inputs.shape) != 4: + raise ValueError("Input is assumed to be a rank 4 tensor of shape" "[batch, sequence, heads, dims].") + if self.embedding_dims != inputs.shape[3]: + raise ValueError( + "The embedding dims of the rotary position embedding" "must match the hidden dimension of the inputs." + ) + + position = position[:, :, jnp.newaxis, jnp.newaxis] + sinusoid_inp = position / self.timescale + sin_half = jnp.sin(sinusoid_inp).astype(inputs.dtype) + cos_half = jnp.cos(sinusoid_inp).astype(inputs.dtype) + + sin = jnp.concatenate([sin_half, sin_half], axis=-1) + cos = jnp.concatenate([cos_half, cos_half], axis=-1) + + x_out = self.apply_rotary(inputs, cos, sin) + + if self.cast_as_fprop_dtype: + x_out = x_out.astype(self.fprop_dtype) + return x_out + + +def llama_rotary_embedding_as_linen( + *, + min_timescale: int, + max_timescale: int, + embedding_dims: int = 0, + cast_as_fprop_dtype: bool = True, + fprop_dtype: DType = jnp.bfloat16, + use_scale: bool = True, + name: str | None = None, +): + """Initializes the LLaMARotaryEmbedding module and returns it as a Linen module. + + Args: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + cast_as_fprop_dtype: Whether to cast the output to the fprop dtype. + fprop_dtype: The dtype of the output. 
    use_scale: Whether to apply LLaMA3.1 scaling factor.
    name: Name of the Linen module.
  """
  return nnx_wrappers.to_linen(
      LLaMARotaryEmbedding,
      min_timescale=min_timescale,
      max_timescale=max_timescale,
      embedding_dims=embedding_dims,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      use_scale=use_scale,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )


def partial_rotary_embedding_as_linen(
    *,
    min_timescale: int,
    max_timescale: int,
    mesh: Mesh,
    embedding_dims: int = 0,
    partial_rotary_factor: float = 0.25,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    shard_mode: ShardMode = ShardMode.AUTO,
    name: str | None = None,
):
  """Initializes the PartialRotaryEmbedding module and returns it as a Linen module.

  Args:
    min_timescale: Start of the geometric index. Determines the periodicity of
      the added signal.
    max_timescale: End of the geometric index. Determines the frequency of the
      added signal.
    mesh: The device mesh, forwarded to the underlying RotaryEmbedding.
    embedding_dims: Dimension of the embedding to be generated.
    partial_rotary_factor: Ratio of dimensions to apply ROPE to.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    name: Name of the Linen module.
  """
  return nnx_wrappers.to_linen(
      PartialRotaryEmbedding,
      min_timescale=min_timescale,
      max_timescale=max_timescale,
      mesh=mesh,
      embedding_dims=embedding_dims,
      partial_rotary_factor=partial_rotary_factor,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      shard_mode=shard_mode,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )


class PartialRotaryEmbedding(RotaryEmbedding):
  """Rotary Position Embedding applied to a partial fraction of dimensions."""

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      mesh: Mesh,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      partial_rotary_factor: float = 0.25,
      shard_mode: ShardMode = ShardMode.AUTO,
      rngs: nnx.Rngs = None,
  ):
    """Initializes the PartialRotaryEmbedding module.

    Args:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      embedding_dims: Dimension of the embedding to be generated. Treated as
        the full head dimension; RoPE is applied to the first
        `int(embedding_dims * partial_rotary_factor)` dims only.
      partial_rotary_factor: Ratio of dimensions to apply ROPE to
      rngs: rng keys passed in by nnx.bridge.to_linen.
    """
    self.head_dim = embedding_dims
    self.partial_rotary_factor = partial_rotary_factor
    # Number of leading dims that receive RoPE; the remainder pass through.
    self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)

    # Initialize the base class with only the rotary_dim
    super().__init__(
        min_timescale=min_timescale,
        max_timescale=max_timescale,
        mesh=mesh,
        embedding_dims=self.rotary_dim,
        cast_as_fprop_dtype=cast_as_fprop_dtype,
        fprop_dtype=fprop_dtype,
        shard_mode=shard_mode,
        rngs=rngs,
    )

  def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
    """Applies Partial variant of rotary position embedding.

    Args:
      inputs: The input sequence on which to apply the Rotary position
        embedding. It is assumed of shape [B, S, H, D].
      position: Optional position array [B, S]. Only needed when the sequence
        is packed.

    Returns:
      A jax.Array of shape [B, S, H, D] with rotary position embeddings applied
      to the first `rotary_dim` dims (the rest pass through unchanged).
    """
    # Split, apply base RoPE to the first fraction, and concatenate
    inputs_rot, inputs_pass = jnp.split(inputs, [self.rotary_dim], axis=-1)
    inputs_rot = super().__call__(inputs_rot, position)
    inputs = jnp.concatenate([inputs_rot, inputs_pass], axis=-1)
    return inputs


class LLaMARotaryEmbedding(RotaryEmbedding):
  """LLaMA variant of ROPE."""

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      mesh: Mesh,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      use_scale: bool = True,
      shard_mode: ShardMode = ShardMode.AUTO,
      # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen.
      # TODO: Remove when bridge no longer needed
      rngs: nnx.Rngs = None,
  ):
    """Initializes the LLaMARotaryEmbedding module.

    Args:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      embedding_dims: Dimension of the embedding to be generated.
      cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
      fprop_dtype: The dtype of the output.
      use_scale: Whether to apply LLaMA3.1 scaling factor.
      rngs: rng keys passed in by nnx.bridge.to_linen.
    """
    super().__init__(
        min_timescale=min_timescale,
        max_timescale=max_timescale,
        mesh=mesh,
        embedding_dims=embedding_dims,
        cast_as_fprop_dtype=cast_as_fprop_dtype,
        fprop_dtype=fprop_dtype,
        shard_mode=shard_mode,
        rngs=rngs,
    )

    # LLaMA3.1 ROPE scaling, see the original pytorch implementation:
    # https://github.com/meta-llama/llama-models/blob/301ca3a2b3b10e94ddcd1fdd2c57e52f812e1cac/models/llama3/reference_impl/model.py#L45C5-L45C18
    self.use_scale = use_scale

  @property
  def timescale(self):
    """Timescale with interleaved (repeated) fractions, LLaMA layout."""
    half_embedding_dim = self.embedding_dims // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
    # Repeat each fraction so adjacent (even, odd) dims share a frequency.
    fraction = jnp.repeat(fraction, 2)
    timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction

    # Apply scaling factor if enabled
    if self.use_scale:
      timescale = 1.0 / jax.vmap(self._apply_scaling_factor)(1.0 / timescale)

    # Expand timescale dimensions for broadcasting
    return timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]

  def _apply_scaling_factor(self, freq):
    """apply scaling factor to rotary position embedding."""
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * jnp.pi / freq

    def lower_wavelen(freq):
      return freq

    def bigger_or_equal_wavelen(freq):
      def bigger_wavelen(freq):
        return freq / scale_factor

      def equal_wavelen(freq):
        # Smooth interpolation between scaled and unscaled frequencies.
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        return (1 - smooth) * freq / scale_factor + smooth * freq

      bigger_wavelen_cond = wavelen > low_freq_wavelen
      return jax.lax.cond(bigger_wavelen_cond, bigger_wavelen, equal_wavelen, freq)

    lower_wavelen_cond = wavelen < high_freq_wavelen
    return jax.lax.cond(lower_wavelen_cond, lower_wavelen, bigger_or_equal_wavelen, freq)

  def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
    """Applies LLaMA variant of rotary position embedding.

    Args:
      inputs: The input sequence on which to apply the Rotary position
        embedding. It is assumed of shape [B, S, N, H].
      position: Optional position array [B, S]. Only needed when the sequence
        is packed.

    Returns:
      A jax.Array of shape [B, S, N, H] with rotary position embeddings applied.
    """
    # Ensure input is 4D
    if len(inputs.shape) != 4:
      raise ValueError("Input is assumed to be a rank 4 tensor of shape [B, S, N, H].")
    if self.embedding_dims != inputs.shape[3]:
      raise ValueError(
          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
      )

    # Shift the inputs left and right as per LLaMA's specific behavior
    inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
    inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
    # Even dims pick the right-shifted neighbor, odd dims the left-shifted one,
    # emulating the interleaved (pairwise) complex rotation.
    inputs_shifted = jax.lax.select(
        jnp.tile(
            jnp.mod(jnp.arange(self.embedding_dims, dtype=jnp.int32), 2),
            inputs.shape[:-1] + (1,),
        ),
        inputs_shifted_right,
        inputs_shifted_left,
    )

    # Determine positions if not provided
    if position is None:
      seq_length = inputs.shape[1]
      position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]

    # Calculate sinusoidal input
    position = position[:, :, jnp.newaxis, jnp.newaxis]
    sinusoid_inp = position / self.timescale

    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    # Apply alternating sign
    sign = jnp.tile(jnp.array([-1, 1]), self.embedding_dims // 2)

    # Combine original inputs with sinusoidal information
    outputs = inputs * cos + inputs_shifted * sin * sign

    if self.cast_as_fprop_dtype:
      outputs = outputs.astype(self.fprop_dtype)

    return outputs


def yarn_rotary_embedding_as_linen(
    *,
    embedding_dims: int,
    mesh: Mesh,
    max_position_embeddings: int = 4096 * 4,
    original_max_position_embeddings: int = 4096,
    beta_fast: float = 32,
    beta_slow: float = 1,
    rope_theta: float = 10000.0,
    rope_factor: float = 40,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
    interleave: bool = True,
    truncate: bool = True,
    attention_scaling: bool = False,
    shard_mode: ShardMode = ShardMode.AUTO,
):
  """Initializes the YarnRotaryEmbedding module and returns it as a Linen module.

  Args:
    embedding_dims: The dimension of the embeddings.
    max_position_embeddings: The maximum number of positions.
    original_max_position_embeddings: The original maximum number of positions.
    beta_fast: The fast beta parameter for YaRN.
    beta_slow: The slow beta parameter for YaRN.
    rope_theta: The base for the rotary frequencies.
    rope_factor: The scaling factor for RoPE.
    cast_as_fprop_dtype: Whether to cast the output to `fprop_dtype`.
    fprop_dtype: The forward pass dtype.
    name: The name of the module.
  """
  return nnx_wrappers.to_linen(
      YarnRotaryEmbedding,
      embedding_dims=embedding_dims,
      max_position_embeddings=max_position_embeddings,
      mesh=mesh,
      original_max_position_embeddings=original_max_position_embeddings,
      beta_fast=beta_fast,
      beta_slow=beta_slow,
      rope_theta=rope_theta,
      rope_factor=rope_factor,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      interleave=interleave,
      truncate=truncate,
      attention_scaling=attention_scaling,
      shard_mode=shard_mode,
  )


class YarnRotaryEmbedding(nnx.Module):
  """Yarn rotary embedding.

  Based on https://arxiv.org/abs/2309.00071
  This implementation uses DeepSeek-v3 PyTorch as reference
  https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L294

  Implementation Notes:
  - YaRN vs. Standard RoPE:
    1. Frequency Initialization: YaRN modifies how frequencies are computed.
    2. Attention Scaling: YaRN typically scales embeddings by `0.1 * ln(rope_factor) + 1.0`
       when `rope_factor > 1`. This scaling can be applied within this layer (if `attention_scaling=True`)
       or externally.
  - RoPE Implementation Details (General):
    - Arithmetic: Uses complex number arithmetic. Real number arithmetic is not implemented here,
      though the resulting embeddings would be equivalent.
    - Input Layout: Supports both interleaved (`interleave=True`, e.g., [real1, img1, real2, img2]) and
      concatenated (`interleave=False`, e.g., [real1, real2, img1, img2]) formats.
    - Output Layout: Always returns concatenated format ([real, imag]). Interleaved output is not
      implemented: While the embedding is different, attention scores are invariant, as long as we apply
      the same output layout for Q and K.

  Attributes:
    embedding_dims: Dimension of the embedding to be generated.
    max_position_embeddings: The maximum sequence length that will be encountered.
    original_max_position_embeddings: The sequence length for which the base frequencies were defined.
    beta_fast: Lower bound parameter for correction.
    beta_slow: Upper bound parameter for correction.
    rope_theta: The base theta value for the frequency computation.
    rope_factor: Factor applied to adjust the frequencies.
    cast_as_fprop_dtype: Whether to cast the output to `fprop_dtype`.
    fprop_dtype: The forward pass dtype.
    rope_interleave: Whether complex representation is interleaved or concatenated.
    rope_truncate: Whether or not to floor lower bound and ceil upper bound for correction range.
    rope_attention_scaling: Whether or not to scale the rotary embedding output.
    rngs: rng keys passed in by nnx.bridge.to_linen.
  """

  def __init__(
      self,
      embedding_dims: int,
      mesh: Mesh,
      max_position_embeddings: int = 4096 * 4,
      original_max_position_embeddings: int = 4096,
      beta_fast: float = 32,
      beta_slow: float = 1,
      rope_theta: float = 10000.0,
      rope_factor: float = 40,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      shard_mode: ShardMode = ShardMode.AUTO,
      interleave=True,
      truncate=True,
      attention_scaling=False,
      # Not used in YarnRotaryEmbedding but passed in by nnx.bridge.to_linen.
      # TODO: Remove when bridge no longer needed
      rngs: nnx.Rngs = None,
  ):
    """Initializes the YarnRotaryEmbedding module."""
    self.embedding_dims = embedding_dims
    self.max_position_embeddings = max_position_embeddings
    self.original_max_position_embeddings = original_max_position_embeddings
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    self.rope_theta = rope_theta
    self.rope_factor = rope_factor
    self.cast_as_fprop_dtype = cast_as_fprop_dtype
    self.fprop_dtype = fprop_dtype
    self.interleave = interleave
    self.truncate = truncate
    self.mesh = mesh
    self.shard_mode = shard_mode
    self.attention_scaling = attention_scaling

    # Only explicit shard mode pins a sharding for the frequency tensor.
    self.freqs_sharding = (
        create_sharding(mesh, ("activation_batch", "activation_length_no_exp", "q_heads"))
        if shard_mode == ShardMode.EXPLICIT
        else None
    )

    if self.embedding_dims % 2:
      raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")

  @property
  def freqs_cis(self):
    """Frequencies for rotary embedding."""
    half_dim = self.embedding_dims // 2
    # Compute base frequencies for each (even-indexed) dimension.
    # (Note: We use jnp.arange with float32 for precision.)
    freqs = 1.0 / (self.rope_theta ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / self.embedding_dims))

    low, high = self._find_correction_range(
        self.beta_fast,
        self.beta_slow,
        self.embedding_dims,
        self.rope_theta,
        self.original_max_position_embeddings,
        self.truncate,
    )
    smooth = 1 - self._linear_ramp_factor(low, high, half_dim)
    # The corrected frequency is a weighted mix of the scaled and base values.
    freqs = freqs / self.rope_factor * (1 - smooth) + freqs * smooth

    # Precompute frequencies for all positions by taking the outer product.
    t = jnp.arange(self.max_position_embeddings, dtype=jnp.float32)  # shape [max_position_embeddings]
    # This gives a [max_position_embeddings, half_dim] tensor with rows as time steps.
    freqs = jnp.outer(t, freqs)

    # Compute the complex "cis" values: exp(i * theta).
    return jnp.exp(1j * freqs)  # shape [max_position_embeddings, half_dim]

  def _find_correction_dim(self, num_rotations: float, dim: int, base: float, max_position_embeddings: int) -> float:
    """Compute the correction dimension for a given number of rotations."""
    return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

  def _find_correction_range(
      self,
      low_rot: float,
      high_rot: float,
      dim: int,
      base: float,
      max_position_embeddings: int,
      truncate: bool,
  ):
    """Computes the range of correction dimensions for rotary positional embeddings.

    Args:
      low_rot (float): Lower bound for the number of rotations.
      high_rot (float): Upper bound for the number of rotations.
      dim (int): Dimensionality of the embedding space.
      base (float): Base value for the exponential computation.
      max_position_embeddings (int): Maximum sequence length.
      truncate (bool): Whether to floor lower bound and ceil upper bound.

    Returns:
      tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
    """
    low = self._find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high = self._find_correction_dim(high_rot, dim, base, max_position_embeddings)
    if truncate:
      low = math.floor(low)
      high = math.ceil(high)
    # Clamp to valid dimension indices.
    low = max(low, 0)
    high = min(high, dim - 1)
    return low, high

  def _linear_ramp_factor(self, min_val: float, max_val: float, dim: int) -> Array:
    """Computes a linear ramp over the dimension.

    Returns a jax.Array of shape (dim,) with values between 0 and 1.
    """
    if min_val == max_val:
      max_val += 0.001  # Avoid division by zero.
    linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val)
    return jnp.clip(linear_func, 0, 1)

  def __call__(self, inputs: Array, position: None | Array = None) -> Array:
    """Applies the rotary positional embedding using the precomputed complex frequencies.

    Args:
      inputs: jax.Array of shape [B, S, N, H]. (H must equal self.embedding_dims.)
      position: jax.Array of shape [B, S] with integer positions (indexes into precomputed freqs).

    Returns:
      jax.Array of shape [B, S, N, H] with the rotary embedding applied.
    """
    if len(inputs.shape) != 4:
      raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
    if self.embedding_dims != inputs.shape[3]:
      raise ValueError(
          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
      )

    # Determine positions if not provided
    if position is None:
      seq_length = inputs.shape[1]
      position = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :]
    else:
      position = position.astype(jnp.int32)

    # Lookup the precomputed frequencies using the position indices.
    # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
    # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
+ freqs = self.freqs_cis.at[position].get(out_sharding=self.freqs_sharding) # shape: [B, S, half_dim] + freqs = freqs[:, :, jnp.newaxis, :] # shape: [B, S, 1, half_dim] + + if self.interleave: + # Inputs with interleaved format [real1, img1, real2, img2, ...] at last dimension + # Convert the last dimension into a complex representation. + # First reshape so that each pair of numbers represents the real and imaginary parts. + B, S, N, H = inputs.shape + half_dim = H // 2 + inputs_reshaped = inputs.reshape(B, S, N, half_dim, 2) + first_half, second_half = inputs_reshaped[..., 0], inputs_reshaped[..., 1] + else: + # Inputs with concatenated format [real1, real2, ..., img1, img2, ...] at last dimension + first_half, second_half = jnp.split(inputs, 2, axis=-1) + + inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim] + # Apply the rotary transformation via complex multiplication. + rotated_sharding = ( + create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", None, None)) + if self.shard_mode == ShardMode.EXPLICIT + else None + ) + freqs = jnp.broadcast_to(freqs, inputs_complex.shape, out_sharding=rotated_sharding) + rotated = jnp.multiply(inputs_complex, freqs) # shape: [B, S, N, half_dim] + + # Convert the complex result back to a real tensor. + # Split the complex number into its real and imaginary parts. + # [real1, real2, ..., img1, img2, ...] 
+ output = jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1) + + if self.attention_scaling: + attention_scaling = 1.0 if self.rope_factor <= 1 else (0.1 * math.log(self.rope_factor) + 1.0) + output = output * attention_scaling + + if self.cast_as_fprop_dtype: + output = output.astype(self.fprop_dtype) + return output + + +def positional_embedding_as_linen( + *, + embedding_dims: int, + max_wavelength: int = _MAX_WAVELENGTH, + cast_as_fprop_dtype: bool = False, + fprop_dtype: DType = jnp.bfloat16, +): + """Initializes the PositionalEmbedding module and returns it as a Linen module. + + Args: + embedding_dims: The dimension of the embeddings. + max_wavelength: The maximum wavelength for the sinusoidal positional embeddings. + cast_as_fprop_dtype: Whether to cast output to fprop_dtype. + fprop_dtype: The dtype of the output when cast_as_fprop_dtype is True. + """ + return nnx_wrappers.to_linen( + PositionalEmbedding, + embedding_dims=embedding_dims, + max_wavelength=max_wavelength, + cast_as_fprop_dtype=cast_as_fprop_dtype, + fprop_dtype=fprop_dtype, + metadata_fn=variable_to_logically_partitioned, + ) + + +@dataclasses.dataclass(repr=False) +class PositionalEmbedding(nnx.Module): + """Sinusoidal positional embeddings supporting both uniform and per-batch positions. + + This module computes sinusoidal positional embeddings and supports two use cases: + + 1. Uniform positions across batch: All batch elements share the same position sequence. + Pass position as 1D array (seq_len,) or None for sequential [0,1,2,...]. + Returns (seq_len, embedding_dims), caller broadcasts to batch. + Example: pos_emb = layer(seq_len) # Sequential positions + pos_emb = layer(seq_len, position_1d) # Custom 1D positions + + 2. Per-batch positions (packed sequences): Each batch element has different positions. + Pass position as 2D array (batch, seq_len). + Returns (batch, seq_len, embedding_dims). 
+ Example: pos_emb = layer(seq_len, position_2d) + + As a side effect, the uniform case is more efficient since sin/cos are computed once + and broadcasted, rather than per batch element. + """ + + #: The dimension of the embeddings. + embedding_dims: int + #: The maximum wavelength for the sinusoidal positional embeddings. + max_wavelength: int = _MAX_WAVELENGTH + #: Whether to cast output to fprop_dtype. + cast_as_fprop_dtype: bool = False + #: The dtype of the output when cast_as_fprop_dtype is True. + fprop_dtype: DType = jnp.bfloat16 + #: RNG state passed in by nnx.bridge.to_linen, not used in this module. + rngs: nnx.Rngs = None # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen + + def _compute_embeddings(self, position: Array) -> Array: + """Compute sinusoidal embeddings for given positions. + + Args: + position: Either (seq_len,) for efficient path or (batch, seq_len) for full path. + + Returns: + Embeddings of shape (seq_len, embedding_dims) or (batch, seq_len, embedding_dims). 
+ """ + num_timescales = self.embedding_dims // 2 + log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum( + jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1 + ) + inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) + + if position.ndim == 1: + # use the same position for the whole batch when position is (seq_len,) + scaled_time = position[:, jnp.newaxis] * inv_timescales[jnp.newaxis, :] + else: + # when position is (batch, seq_len) + position = position[:, :, jnp.newaxis] + inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :] + scaled_time = position * inv_timescales + + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1) + + if self.cast_as_fprop_dtype: + return signal.astype(self.fprop_dtype) + else: + return signal.astype(jnp.float32) + + def __call__( + self, + seq_len: int, + position: Array | None = None, + ) -> Array: + """Compute positional embeddings. + + Args: + seq_len: Sequence length for computing embeddings. + position: Optional position array. If None, uses sequential [0,1,2,...]. + Shape can be (seq_len,) or (batch, seq_len) for packed sequences. + + Returns: + Positional embeddings of shape (seq_len, embedding_dims) or + (batch, seq_len, embedding_dims) if position has batch dimension. + """ + if position is None: + position = jnp.arange(seq_len, dtype=jnp.float32) + + return self._compute_embeddings(position) + + +def llama_vision_rotary_embedding_as_linen( + *, + image_size: int, + patch_size: int, + hidden_size: int, + num_attention_heads: int, + rope_theta: float = 10000.0, + cast_as_fprop_dtype: bool = True, + fprop_dtype: DType = jnp.bfloat16, + name: str | None = None, +): + """Initializes the LlamaVisionRotaryEmbedding module and returns it as a Linen module. + + Args: + image_size: The size of the input image. + patch_size: The size of the image patches. + hidden_size: The size of the hidden dimension. 
+ num_attention_heads: The number of attention heads. + rope_theta: The base theta value for the frequency computation. + cast_as_fprop_dtype: Whether to cast the output to the fprop dtype. + fprop_dtype: The dtype of the output. + name: The name of the Linen module. + """ + return nnx_wrappers.to_linen( + LlamaVisionRotaryEmbedding, + image_size=image_size, + patch_size=patch_size, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + rope_theta=rope_theta, + cast_as_fprop_dtype=cast_as_fprop_dtype, + fprop_dtype=fprop_dtype, + metadata_fn=variable_to_logically_partitioned, + name=name, + ) + + +@dataclasses.dataclass(repr=False) +class LlamaVisionRotaryEmbedding(nnx.Module): + """Rotary position embedding for Llama4 vision encoder. + + Based on Pytorch Reference + https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py + This implementation follows the Llama4 vision encoder's rotary embedding approach, + which uses 2D coordinates (x, y) to generate rotary position embeddings. + """ + + #: size of the input image + image_size: int + #: size of the image patches + patch_size: int + #: size of the hidden dimension + hidden_size: int + #: number of attention heads + num_attention_heads: int + #: base theta value for the frequency computation + rope_theta: float = 10000.0 + #: whether to cast the output to the fprop dtype + cast_as_fprop_dtype: bool = True + #: the dtype of the output + fprop_dtype: DType = jnp.bfloat16 + # Not used in LlamaVisionRotaryEmbedding but passed in by nnx.bridge.to_linen. 
+ # TODO: Remove when bridge no longer needed + #: RNG state passed in by nnx.bridge.to_linen, not used in this module + rngs: nnx.Rngs = None + + @property + def freqs_cis(self): + """Frequencies for rotary embedding.""" + idx = self.image_size // self.patch_size + img_idx = jnp.arange(idx**2, dtype=jnp.int32).reshape(idx**2, 1) + img_idx = jnp.concatenate([img_idx, img_idx[:1]], axis=0) + img_idx = img_idx.at[-1, -1].set(-2) # ID_CLS_TOKEN + + # Get 2D coordinates + frequencies_x = img_idx % idx # x coordinates + frequencies_y = img_idx // idx # y coordinates + + # Compute frequency dimensions + freq_dim = self.hidden_size // self.num_attention_heads // 2 + rope_freq = 1.0 / (self.rope_theta ** (jnp.arange(0, freq_dim, 2)[: (freq_dim // 2)].astype(jnp.float32) / freq_dim)) + + # Compute frequencies for x and y coordinates + freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :] + freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :] + + # Interleave x and y frequencies + freqs_x = jnp.repeat(freqs_x, 2, axis=-1) + freqs_y = jnp.repeat(freqs_y, 2, axis=-1) + + # Combine frequencies + freqs = jnp.concatenate([freqs_x, freqs_y], axis=-1).astype(jnp.float32) + freqs = freqs[..., ::2] + + # Mask out invalid positions + freqs = jnp.where(img_idx.reshape(-1, 1, 1) < 0, 0, freqs) + # Convert to complex representation + return jnp.exp(1j * freqs) + + def __call__(self, inputs: Array, position: None | Array = None) -> Array: + """Applies rotary embeddings to the input tensor for Llama4 vision encoder. + + Args: + inputs: Input tensor of shape [batch_size_times_tiles, num_patches_incl_cls, num_heads, head_dim] + + Returns: + Tensor with rotary embeddings applied, maintaining the same shape as input. 
+ """ + if len(inputs.shape) != 4: + raise ValueError( + """Input is assumed to be a rank 4 tensor of shape [batch_size_times_tiles, num_patches_incl_cls, + num_heads, head_dim].""" + ) + + # Reshape inputs to complex representation + B, S, N, H = inputs.shape + half_dim = H // 2 + + # Convert the last dimension into a complex representation. + # First reshape so that each pair of numbers represents the real and imaginary parts. + inputs_reshaped = inputs.reshape(B, S, N, half_dim, 2) + inputs_complex = inputs_reshaped[..., 0] + 1j * inputs_reshaped[..., 1] + + # Reshape freqs_ci for broadcasting + freqs_ci = self.freqs_cis[jnp.newaxis, :, :, :] + + # Apply rotary transformation + rotated = inputs_complex * freqs_ci + + # Convert the complex result back to a real tensor. + # Split the complex number into its real and imaginary parts. + rotated_real = jnp.stack([jnp.real(rotated), jnp.imag(rotated)], axis=-1) + output = rotated_real.reshape(B, S, N, H) + + if self.cast_as_fprop_dtype: + output = output.astype(self.fprop_dtype) + + return output + + +class Qwen3OmniMoeVisionRotaryEmbedding(nnx.Module): + """Rotary position embedding for Qwen3OmniMoe vision encoder. + + Attributes: + hidden_size: Hidden dimension size + num_attention_heads: Number of attention heads + spatial_merge_size: Spatial merge block size (e.g., 2 for 2x2 blocks) + rope_theta: Base theta for frequency computation (default 10000.0) + cast_as_fprop_dtype: Whether to cast to fprop dtype + fprop_dtype: Output dtype + rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + spatial_merge_size: int, + rope_theta: float = 10000.0, + cast_as_fprop_dtype: bool = True, + fprop_dtype: DType = jnp.bfloat16, + rngs: nnx.Rngs = None, + ): + """Initializes the Qwen3OmniMoe vision rotary embedding. 
+ + Args: + hidden_size: Hidden dimension size + num_attention_heads: Number of attention heads + spatial_merge_size: Spatial merge block size (e.g., 2 for 2x2 blocks) + rope_theta: Base theta for frequency computation (default 10000.0) + cast_as_fprop_dtype: Whether to cast to fprop dtype + fprop_dtype: Output dtype + rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module + """ + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.spatial_merge_size = spatial_merge_size + self.rope_theta = rope_theta + self.cast_as_fprop_dtype = cast_as_fprop_dtype + self.fprop_dtype = fprop_dtype + self.rngs = rngs + self.head_dim = self.hidden_size // self.num_attention_heads + + def _compute_freq_table(self, max_hw: int) -> Array: + """Precompute frequency table for positions up to max_hw. + + Args: + max_hw: Maximum height or width dimension + + Returns: + Array of shape [max_hw, head_dim//4] containing frequencies for each position + """ + + inv_freq = 1.0 / (self.rope_theta ** (jnp.arange(0, self.head_dim // 2, 2, dtype=jnp.float32) / (self.head_dim // 2))) + # Compute for all positions [0, max_hw) + positions = jnp.arange(max_hw, dtype=jnp.float32) + freqs = jnp.outer(positions, inv_freq) # [max_hw, head_dim//4] + return freqs + + def _generate_position_ids_single(self, num_frames: int, height: int, width: int) -> Array: + """Generate 2D position IDs for a single image or video. 
+ + Args: + num_frames: Number of temporal frames (1 for images, >1 for videos) + height: Height in patches + width: Width in patches + + Returns: + Array of shape [num_frames * height * width, 2] with (row_id, col_id) + """ + merge_size = self.spatial_merge_size + merged_h = height // merge_size + merged_w = width // merge_size + + # Block indices + block_rows = jnp.arange(merged_h) # [merged_h] + block_cols = jnp.arange(merged_w) # [merged_w] + + # Intra-block offsets + intra_row = jnp.arange(merge_size) # [merge_size] + intra_col = jnp.arange(merge_size) # [merge_size] + + # Full resolution positions using broadcasting + # Shape: [merged_h, 1, merge_size, 1] + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + # Shape: [1, merged_w, 1, merge_size] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + # Expand to full grid and flatten + row_idx = jnp.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1) + col_idx = jnp.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1) + + coords = jnp.stack([row_idx, col_idx], axis=-1) # [h*w, 2] + + # Repeat for video frames + if num_frames > 1: + coords = jnp.tile(coords, (num_frames, 1)) + + return coords + + def compute_cos_sin(self, num_frames: int, height: int, width: int) -> tuple[Array, Array]: + """Compute cos and sin embeddings for given static grid dimensions. 
+ + Args: + num_frames: Number of temporal frames + height: Height in patches + width: Width in patches + + Returns: + Tuple of (cos_emb, sin_emb) each of shape [num_frames * height * width, head_dim] + """ + max_hw = max(height, width) + freq_table = self._compute_freq_table(max_hw) # [max_hw, head_dim//4] + coords = self._generate_position_ids_single(num_frames, height, width) # [T*H*W, 2] + + row_freqs = freq_table[coords[:, 0]] # [T*H*W, head_dim//4] + col_freqs = freq_table[coords[:, 1]] # [T*H*W, head_dim//4] + + # Concatenate row and column frequencies + embeddings = jnp.concatenate([row_freqs, col_freqs], axis=-1) # [T*H*W, head_dim//2] + + # Double the embeddings to match head_dim + embeddings = jnp.concatenate([embeddings, embeddings], axis=-1) # [T*H*W, head_dim] + + cos_emb = jnp.cos(embeddings) + sin_emb = jnp.sin(embeddings) + + if self.cast_as_fprop_dtype: + cos_emb = cos_emb.astype(self.fprop_dtype) + sin_emb = sin_emb.astype(self.fprop_dtype) + + return cos_emb, sin_emb + + def _rotate_half(self, x: Array) -> Array: + """Rotates half the hidden dims of the input. + + Args: + x: Input tensor of any shape with last dimension divisible by 2 + + Returns: + Rotated tensor where (x1, x2) -> (-x2, x1) + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return jnp.concatenate([-x2, x1], axis=-1) + + def __call__(self, inputs: Array, num_frames: int, height: int, width: int) -> Array: + """Apply rotary position embeddings directly to inputs (Q or K tensors). 
def qwen3omnimoe_vision_pos_embed_interpolate_as_linen(
    *,
    num_position_embeddings: int,
    hidden_size: int,
    spatial_merge_size: int,
    dtype: DType = jnp.float32,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
):
  """Builds Qwen3OmniMoe bilinear position-embedding interpolation as a Linen module.

  The wrapped module bilinearly interpolates learned 2D positional embeddings
  (defined on a fixed grid) to dynamic image/video dimensions.

  Args:
    num_position_embeddings: Number of position embeddings in the fixed grid
      (e.g., 1024 for a 32x32 grid).
    hidden_size: Hidden dimension size.
    spatial_merge_size: Size of spatial merging blocks.
    dtype: Data type for the embedding table.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    name: Module name.

  Returns:
    A Linen module wrapping the NNX Qwen3OmniMoeVisionPosEmbedInterpolate module.
  """
  bridge_kwargs = dict(
      num_position_embeddings=num_position_embeddings,
      hidden_size=hidden_size,
      spatial_merge_size=spatial_merge_size,
      dtype=dtype,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )
  return nnx_wrappers.to_linen(Qwen3OmniMoeVisionPosEmbedInterpolate, **bridge_kwargs)
+ + Args: + num_position_embeddings: Number of position embeddings in the fixed grid + hidden_size: Hidden dimension size + spatial_merge_size: Spatial merge block size + dtype: Data type for embeddings + cast_as_fprop_dtype: Whether to cast to fprop dtype + fprop_dtype: Output dtype + rngs: RNG state passed in by nnx.bridge.to_linen + """ + self.num_position_embeddings = num_position_embeddings + self.hidden_size = hidden_size + self.spatial_merge_size = spatial_merge_size + self.dtype = dtype + self.cast_as_fprop_dtype = cast_as_fprop_dtype + self.fprop_dtype = fprop_dtype + self.rngs = rngs + + # Initialize the learned position embedding table + if self.rngs is not None: + # Initialize with normal distribution scaled by hidden_size^(-0.5) + init_fn = nnx.initializers.normal(stddev=self.hidden_size**-0.5) + self.pos_embed = nnx.Param( + init_fn( + self.rngs.params(), + (self.num_position_embeddings, self.hidden_size), + self.dtype, + ), + ) + self.num_grid_per_side = int(self.num_position_embeddings**0.5) + + def _interpolate_single(self, t: int, h: int, w: int) -> tuple[Array, Array]: + """Compute bilinear interpolation indices and weights for a single image/video. 
+ + Args: + t: Number of temporal frames + h: Target height in patches + w: Target width in patches + + Returns: + Tuple of (indices, weights) where: + - indices: [4, h*w] indices into pos_embed for 4 corners + - weights: [4, h*w] bilinear weights for 4 corners + """ + N = self.num_grid_per_side + + # Create interpolation coordinates + h_idxs = jnp.linspace(0, N - 1, h) + w_idxs = jnp.linspace(0, N - 1, w) + + # Floor and ceiling indices + h_idxs_floor = jnp.floor(h_idxs).astype(jnp.int32) + w_idxs_floor = jnp.floor(w_idxs).astype(jnp.int32) + h_idxs_ceil = jnp.minimum(h_idxs_floor + 1, N - 1) + w_idxs_ceil = jnp.minimum(w_idxs_floor + 1, N - 1) + + # Fractional parts for interpolation weights + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + # Compute flat indices for 2D grid + base_h = h_idxs_floor * N + base_h_ceil = h_idxs_ceil * N + + # 4 corner indices: (floor_h, floor_w), (floor_h, ceil_w), (ceil_h, floor_w), (ceil_h, ceil_w) + indices = jnp.stack( + [ + (base_h[:, None] + w_idxs_floor[None, :]).reshape(-1), + (base_h[:, None] + w_idxs_ceil[None, :]).reshape(-1), + (base_h_ceil[:, None] + w_idxs_floor[None, :]).reshape(-1), + (base_h_ceil[:, None] + w_idxs_ceil[None, :]).reshape(-1), + ], + axis=0, + ) # [4, h*w] + + # Bilinear weights + weights = jnp.stack( + [ + ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1), + ((1 - dh)[:, None] * dw[None, :]).reshape(-1), + (dh[:, None] * (1 - dw)[None, :]).reshape(-1), + (dh[:, None] * dw[None, :]).reshape(-1), + ], + axis=0, + ) # [4, h*w] + + return indices, weights + + def __call__(self, num_frames: int, height: int, width: int) -> Array: + """Interpolate positional embeddings for given static grid dimensions. 
+ + Args: + num_frames: Number of temporal frames (static) + height: Height in patches (static) + width: Width in patches (static) + + Returns: + Interpolated positional embeddings of shape [num_frames * height * width, hidden_size] + """ + # Get interpolation indices and weights + indices, weights = self._interpolate_single(num_frames, height, width) # [4, h*w], [4, h*w] + + # Lookup embeddings for all 4 corners + corner_embeds = self.pos_embed.value[indices] # [4, h*w, hidden_size] + + # Apply bilinear weights and sum + weighted_embeds = corner_embeds * weights[:, :, None] # [4, h*w, hidden_size] + interpolated = jnp.sum(weighted_embeds, axis=0) # [h*w, hidden_size] + + # Repeat for temporal frames + if num_frames > 1: + interpolated = jnp.tile(interpolated, (num_frames, 1)) # [t*h*w, hidden_size] + + # Apply spatial merge permutation + # Reshape to [t, h, w, hidden_size] then permute for block-based processing + merge_size = self.spatial_merge_size + merged_h = height // merge_size + merged_w = width // merge_size + + # Reshape: [t*h*w, hidden_size] -> [t, h, w, hidden_size] + interpolated = interpolated.reshape(num_frames, height, width, self.hidden_size) + + # Permute for spatial merging: [t, merged_h, merge_size, merged_w, merge_size, hidden_size] + interpolated = interpolated.reshape(num_frames, merged_h, merge_size, merged_w, merge_size, self.hidden_size) + # -> [t, merged_h, merged_w, merge_size, merge_size, hidden_size] + interpolated = jnp.transpose(interpolated, (0, 1, 3, 2, 4, 5)) + # Flatten back to [t*merged_h*merged_w*merge_size*merge_size, hidden_size] + interpolated = interpolated.reshape(-1, self.hidden_size) + + if self.cast_as_fprop_dtype: + interpolated = interpolated.astype(self.fprop_dtype) + + return interpolated + + +class Qwen3OmniMoeThinkerTextRotaryEmbedding(RotaryEmbedding): + """Multi-dimensional Rotary Position Embedding (MRoPE) for Qwen3-Omni Thinker. 
+ + This implements MRoPE which extends standard RoPE to handle 3D position IDs + (temporal, height, width) for multimodal sequences containing text and vision tokens. + + For text-only sequences, it uses standard 2D position IDs. + For sequences with vision tokens, it uses 3D position IDs where: + - Dimension 0: Temporal position + - Dimension 1: Height position (spatial) + - Dimension 2: Width position (spatial) + + The implementation uses an interleaved pattern that reorganizes frequency + components from chunked [TTT...HHH...WWW] to interleaved [THTHWHTHW...]. + """ + + def __init__( + self, + min_timescale: int, + max_timescale: int, + embedding_dims: int = 0, + cast_as_fprop_dtype: bool = True, + fprop_dtype: DType = jnp.bfloat16, + mrope_section: tuple[int, int, int] | None = None, + attention_scaling: float = 1.0, + rngs: nnx.Rngs = None, + ): + """Initializes the Qwen3OmniMoeThinkerTextRotaryEmbedding module. + + Args: + min_timescale: Start of the geometric index (typically 1). + max_timescale: End of the geometric index (rope_theta, e.g., 1000000). + embedding_dims: Dimension of the embedding (head_dim). + cast_as_fprop_dtype: Whether to cast output to fprop dtype. + fprop_dtype: The dtype of the output. + mrope_section: Tuple of (temporal_dim, height_dim, width_dim) for MRoPE. + Defaults to [24, 20, 20] if None. + attention_scaling: Scaling factor applied to cos/sin embeddings. Defaults to 1.0. + rngs: rng keys passed in by nnx.bridge.to_linen. 
+ """ + super().__init__( + min_timescale=min_timescale, + max_timescale=max_timescale, + mesh=None, + embedding_dims=embedding_dims, + cast_as_fprop_dtype=cast_as_fprop_dtype, + fprop_dtype=fprop_dtype, + rngs=rngs, + ) + self.mrope_section = mrope_section if mrope_section is not None else (24, 20, 20) + self.attention_scaling = attention_scaling + + if self.embedding_dims % 2: + raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") + + def _apply_interleaved_mrope(self, freqs: jax.Array) -> jax.Array: + """Apply interleaved MRoPE pattern to 3D rotary embeddings. + + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...], preserving frequency continuity. + + Args: + freqs: Shape (3, batch, seq_len, head_dim // 2) + Dimension 0: temporal frequencies + Dimension 1: height frequencies + Dimension 2: width frequencies + + Returns: + freqs_t: Shape (batch, seq_len, head_dim // 2) with interleaved pattern + """ + # Start with temporal frequencies (dimension 0) + freqs_t = freqs[0] # (batch, seq_len, head_dim // 2) + + # Create interleaved pattern + # For each spatial dimension (H, W), place frequencies at positions: + # offset=1 for H, offset=2 for W, with stride=3 + for dim_idx, offset in enumerate([1, 2], start=1): # H=1, W=2 + section_size = self.mrope_section[dim_idx] * 3 # Total positions for this dimension + # Select positions with stride 3, starting at offset + # Use slice syntax to match PyTorch behavior + idx = slice(offset, section_size, 3) + # Replace those positions with the corresponding spatial frequencies + freqs_t = freqs_t.at[..., idx].set(freqs[dim_idx, ..., idx]) + + return freqs_t + + def __call__( + self, + inputs: jax.Array, + position: jax.Array, + ) -> jax.Array: + """Generates rotary position embeddings for multimodal sequences. + + Args: + inputs: Input tensor of shape [batch, sequence, heads, head_dim]. 
+ position: Position IDs with shape: + - [batch, sequence] for text-only (2D) + - [3, batch, sequence] for multimodal with vision (3D) + where dim 0 = temporal, dim 1 = height, dim 2 = width + + Returns: + Tensor of shape [batch, sequence, heads, head_dim] with RoPE applied. + """ + if len(inputs.shape) != 4: + raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, head_dim].") + if self.embedding_dims != inputs.shape[3]: + raise ValueError( + "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs." + ) + + # Handle both 2D (text-only) and 3D (multimodal) position IDs + if position.ndim == 2: + # Text-only: expand (batch, seq) -> (3, batch, seq) with same positions + position = jnp.broadcast_to(position[jnp.newaxis, ...], (3,) + position.shape) + elif position.ndim != 3 or position.shape[0] != 3: + raise ValueError(f"Position IDs must be 2D (batch, seq) or 3D (3, batch, seq), got shape {position.shape}") + + # Compute frequencies: (3, batch, seq, 1) @ (head_dim // 2, 1) -> (3, batch, seq, head_dim // 2) + inv_freq_expanded = (1.0 / self.timescale)[jnp.newaxis, jnp.newaxis, jnp.newaxis, :] # (1, 1, 1, head_dim//2) + position_expanded = position[..., jnp.newaxis] # (3, batch, seq, 1) + freqs = position_expanded * inv_freq_expanded # (3, batch, seq, head_dim//2) + + # Apply interleaved MRoPE pattern for 3D positions + freqs = self._apply_interleaved_mrope(freqs) # (batch, seq, head_dim//2) + + # Compute sin and cos + # Concatenate to get full head_dim: (batch, seq, head_dim//2) -> (batch, seq, head_dim) + emb = jnp.concatenate([freqs, freqs], axis=-1) # Duplicate for both halves + cos_emb = jnp.cos(emb) * self.attention_scaling # (batch, seq, head_dim) + sin_emb = jnp.sin(emb) * self.attention_scaling # (batch, seq, head_dim) + + # Expand for heads dimension: (batch, seq, head_dim) -> (batch, seq, 1, head_dim) + cos_emb = cos_emb[:, :, jnp.newaxis, :] + sin_emb = sin_emb[:, :, jnp.newaxis, 
:]

    x_out = self.apply_rotary(inputs, cos_emb, sin_emb)

    # Optionally downcast to the forward-pass dtype (bfloat16 by default).
    if self.cast_as_fprop_dtype:
      x_out = x_out.astype(self.fprop_dtype)

    return x_out


def qwen3_omni_mrope_embedding_as_linen(
    *,
    min_timescale: int,
    max_timescale: int,
    embedding_dims: int = 0,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    mrope_section: tuple[int, int, int] | None = None,
    name: str | None = None,
):
  """Initializes Qwen3OmniMoeThinkerTextRotaryEmbedding and returns it as a Linen module.

  Args:
    min_timescale: Start of the geometric index.
    max_timescale: End of the geometric index (rope_theta).
    embedding_dims: Dimension of the embedding (head_dim).
    cast_as_fprop_dtype: Whether to cast output to fprop dtype.
    fprop_dtype: The dtype of the output.
    mrope_section: Tuple of (temporal_dim, height_dim, width_dim) for MRoPE.
    name: Name of the Linen module.

  Returns:
    The result of wrapping the NNX rotary-embedding class via
    ``nnx_wrappers.to_linen``, with variables surfaced through
    ``variable_to_logically_partitioned`` metadata.
  """
  return nnx_wrappers.to_linen(
      Qwen3OmniMoeThinkerTextRotaryEmbedding,
      min_timescale=min_timescale,
      max_timescale=max_timescale,
      embedding_dims=embedding_dims,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      mrope_section=mrope_section,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )
diff --git a/MaxCode/rag/sources/generic/maxtext_layers_linears.py b/MaxCode/rag/sources/generic/maxtext_layers_linears.py
new file mode 100644
index 0000000..4af9c5c
--- /dev/null
+++ b/MaxCode/rag/sources/generic/maxtext_layers_linears.py
@@ -0,0 +1,571 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear Layers."""

import functools
import operator
from typing import Any, Callable, Iterable, Sequence

import numpy as np
import jax
import jax.numpy as jnp

from jax import lax
from jax.sharding import NamedSharding, Mesh
from jax.ad_checkpoint import checkpoint_name

from flax import nnx
import flax.linen as nn

from maxtext.common.common_types import DecoderBlockType, ShardMode, DType, Array, Config
from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT
from maxtext.layers import nnx_wrappers, quantizations
from maxtext.layers import normalizations
from maxtext.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.sharding import maybe_shard_with_logical


def _convert_to_activation_function(fn_or_string: str | Callable[..., Any]) -> Callable[..., Any]:
  """Convert a string to an activation function.

  Args:
    fn_or_string: Either the literal string "linear" (identity), the name of
      a `flax.linen` activation (e.g. "relu", "gelu"), or an already-callable
      function, which is returned unchanged.

  Returns:
    A callable activation function.

  Raises:
    ValueError: If the argument is neither a recognized string nor a callable.
  """
  if fn_or_string == "linear":
    return lambda x: x
  elif isinstance(fn_or_string, str):
    # Any other string is looked up by name on flax.linen.
    return getattr(nn, fn_or_string)
  elif callable(fn_or_string):
    return fn_or_string
  else:
    raise ValueError(
        f"""Don't know how to convert {fn_or_string}
        to an activation function"""
    )


def normalize_axes(axes: Iterable[int], ndim: int) -> tuple[int, ...]:
  """Converts possibly-negative axis indices into non-negative ones for rank `ndim`."""
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def canonicalize_tuple(x):
  """Wraps `x` in a 1-tuple unless it is already an iterable (then returns tuple(x))."""
  if isinstance(x, Iterable):
    return tuple(x)
  else:
    return (x,)


def _compute_dot_general(inputs, kernel, kernel_axes, axis, contract_ind, matmul_precision, quant):
  """Computes a dot_general operation that may be quantized.

  Args:
    inputs: Activation operand of the contraction.
    kernel: Weight operand of the contraction.
    kernel_axes: Logical mesh axes of the kernel (used by the quantized path).
    axis: Input axes to contract.
    contract_ind: Matching kernel axes to contract.
    matmul_precision: Precision name for the unquantized path.
    quant: Optional quantization config; None means a plain lax.dot_general.
  """
  dot_general = lax.dot_general
  matmul_precision = lax.Precision(matmul_precision)
  if quant:
    dot_general_cls = quant.dot_general_cls(mesh_axes=kernel_axes)
    dot_general = dot_general_cls()
    # The quantized dot_general manages its own numerics, so the requested
    # matmul precision is intentionally not forwarded (precision=None).
    return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None)
  return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)


def _compute_dot_general_nnx(
    inputs,
    kernel,
    axis,
    contract_ind,
    matmul_precision,
    quant_dot_general: nnx_wrappers.ToNNX | None,
    initializing: bool,
    out_sharding: NamedSharding | None = None,
):
  """Computes a dot_general operation that may be quantized.

  Args:
    inputs: Activation operand of the contraction.
    kernel: Weight operand of the contraction.
    axis: Input axes to contract.
    contract_ind: Matching kernel axes to contract.
    matmul_precision: Precision name for the unquantized path.
    quant_dot_general: Optional NNX-wrapped quantized dot_general module.
    initializing: When True, lazily initializes the quantized module before use.
    out_sharding: Explicit output sharding; only honored on the unquantized path.
  """
  dot_general = lax.dot_general
  matmul_precision = lax.Precision(matmul_precision)
  if quant_dot_general is not None:
    if initializing:
      quant_dot_general.lazy_init(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None)
    # mutable=["aqt"] lets the quantized op update its calibration state.
    return quant_dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None, mutable=["aqt"])

  return dot_general(
      inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision, out_sharding=out_sharding
  )


class DenseGeneral(nnx.Module):
  """A linear transformation with flexible axes."""

  def __init__(
      self,
      in_features_shape: Iterable[int] | int,
      out_features_shape: Iterable[int] | int,
      axis: Iterable[int] | int = -1,
      weight_dtype: DType = jnp.float32,
      dtype: DType = jnp.float32,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
      # NOTE(review): rngs below defaults to None despite being required at
      # runtime (rngs.params() is called unconditionally) — confirm callers
      # always pass it.
      kernel_axes: tuple[None | str, ...]
= (), + quant: None | Quant = None, + use_bias: bool = False, + shard_mode: ShardMode = ShardMode.AUTO, + matmul_precision: str = "default", + parameter_memory_host_offload: bool = False, + *, # Following arguments are keyword-only + rngs: nnx.Rngs = None, + ): + """Initializes the DenseGeneral module. + + Args: + in_features_shape: tuple with numbers of input features for axes specified in + 'axis'. + out_features_shape: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + kernel_axes: logical axes for partitioning the kernel. + quant: quantization config, defaults to None implying no quantization. + use_bias: whether to add bias in linear transformation. + shard_mode: auto or explicit shard mode. + matmul_precision: Precision for matrix multiplication. + parameter_memory_host_offload: Determines whether to offload params to host + rngs: RNG state for initialization in nnx. 
+ """ + self.in_features_shape = canonicalize_tuple(in_features_shape) + self.out_features_shape = canonicalize_tuple(out_features_shape) + self.axis = canonicalize_tuple(axis) + self.weight_dtype = weight_dtype + self.dtype = dtype + self.kernel_init = kernel_init + self.kernel_axes = kernel_axes + self.quant = quant + self.use_bias = use_bias + self.shard_mode = shard_mode + self.matmul_precision = matmul_precision + self.parameter_memory_host_offload = parameter_memory_host_offload + + # Parameter initialization + kernel_shape = self.in_features_shape + self.out_features_shape + kernel_in_axis = np.arange(len(self.axis)) + kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape)) + + if not quantizations.in_serve_mode(self.quant): + self.kernel = nnx.Param( + self.kernel_init( + rngs.params(), + kernel_shape, + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ), + sharding=self.kernel_axes, + ) + + if self.use_bias: + bias_axes = self.kernel_axes[-len(self.out_features_shape) :] + bias_shape = kernel_shape[-len(self.out_features_shape) :] + self.bias = nnx.Param( + default_bias_init(rngs.params(), bias_shape, self.weight_dtype), + sharding=bias_axes, + ) + else: + self.bias = None + + if quant: + dot_general_cls = quant.dot_general_cls(mesh_axes=kernel_axes) + dot_general_linen = dot_general_cls() + quant_dot_general = nnx_wrappers.ToNNX(dot_general_linen, rngs=rngs) + self._quant_dot_general_name = f"{type(dot_general_linen).__name__}_0" + setattr(self, self._quant_dot_general_name, quant_dot_general) + block_size = getattr(quant, "get_block_size", lambda: 1)() # needed for TE MXFP8 + dummy_inputs = jnp.zeros((block_size, *self.in_features_shape), dtype=self.dtype) + self(dummy_inputs, _initializing=True) + else: + self._quant_dot_general_name = None + + @property + def quant_dot_general(self) -> nnx_wrappers.ToNNX | None: + if self._quant_dot_general_name is None: + return None + return getattr(self, 
self._quant_dot_general_name) + + def __call__(self, inputs: Array, _initializing: bool = False, out_sharding: NamedSharding | None = None) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + inputs = jnp.asarray(inputs, self.dtype) + norm_axis = normalize_axes(self.axis, inputs.ndim) + + for i, ax in enumerate(norm_axis): + if inputs.shape[ax] != self.in_features_shape[i]: + raise ValueError( + f"Input dimension {inputs.shape[ax]} at axis {ax} " + f"does not match expected input feature size {self.in_features_shape[i]}" + ) + + if quantizations.in_serve_mode(self.quant): + kernel_shape = self.in_features_shape + self.out_features_shape + kernel = jnp.zeros(kernel_shape, dtype=self.dtype) + else: + kernel = self.kernel[...] + # Move logit_dense kernel to device if parameter offloading is enabled + if self.parameter_memory_host_offload: + max_logging.log("linear.py: Moving parameter logits_dense kernel to device") + kernel = jax.device_put(kernel, max_utils.device_space()) + kernel = jnp.asarray(kernel, self.dtype) + + # out_sharding should be None for auto mesh axis + if self.shard_mode != ShardMode.EXPLICIT: + out_sharding = None + + contract_ind = tuple(range(0, len(self.axis))) + output = _compute_dot_general_nnx( + inputs, + kernel, + norm_axis, + contract_ind, + self.matmul_precision, + self.quant_dot_general, + _initializing, + out_sharding, + ) + + if self.bias is not None: + bias = jnp.asarray(self.bias[...], self.dtype) + output += bias + return output + + +def dense_general( + *, + inputs_shape: tuple[int, ...] | None = None, + in_features_shape: tuple[int, ...] 
| int | None = None, + out_features_shape: Iterable[int] | int, + axis: Iterable[int] | int = -1, + weight_dtype: DType = jnp.float32, + dtype: DType = jnp.float32, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes: tuple[None | str, ...] = (), + quant: None | Quant = None, + use_bias: bool = False, + shard_mode: ShardMode = ShardMode.AUTO, + matmul_precision: str = "default", + parameter_memory_host_offload: bool = False, + name: None | str = None, +): + """Creates a DenseGeneral Linen module using nnx.bridge.to_linen. + + Args: + inputs_shape: tuple with the shape of the inputs + in_features_shape: tuple with numbers of input features for axes specified in + 'axis'. + out_features_shape: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + kernel_axes: logical axes for partitioning the kernel. + quant: quantization config, defaults to None implying no quantization. + use_bias: whether to add bias in linear transformation. + shard_mode: indicating the shard mode + matmul_precision: Precision for matrix multiplication. 
    parameter_memory_host_offload: Determines whether to offload params to host
    name: name passed to the ToLinen Module
  """
  # Exactly one of the two shape specifications must be provided; `^` here is
  # a boolean XOR over the two "is not None" checks.
  if not (inputs_shape is not None) ^ (in_features_shape is not None):
    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")

  if inputs_shape is not None:
    # Derive the contracted feature sizes from the full input shape.
    axis = canonicalize_tuple(axis)
    in_features_shape = tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))
  else:
    assert in_features_shape is not None
  module = nnx_wrappers.to_linen(
      DenseGeneral,
      in_features_shape=in_features_shape,
      out_features_shape=out_features_shape,
      axis=axis,
      weight_dtype=weight_dtype,
      dtype=dtype,
      kernel_init=kernel_init,
      kernel_axes=kernel_axes,
      quant=quant,
      use_bias=use_bias,
      shard_mode=shard_mode,
      matmul_precision=matmul_precision,
      parameter_memory_host_offload=parameter_memory_host_offload,
      name=name,
      metadata_fn=variable_to_logically_partitioned,
      abstract_init=False,
  )
  return module


class Dropout(nnx.Dropout):
  """Forked nnx.Dropout that is easier to use with bridge"""

  def __init__(  # pylint: disable=super-init-not-called
      self,
      rate: float,
      *,
      broadcast_dims: Sequence[int] = (),
      deterministic: bool = False,
      rng_collection: str = "dropout",
      rngs: nnx.Rngs | None = None,
  ):
    # Deliberately skips nnx.Dropout.__init__ so the RNG stream can be wired
    # up manually for the Linen bridge.
    self.rate = rate
    self.broadcast_dims = broadcast_dims
    self.deterministic = deterministic
    self.rng_collection = rng_collection

    if isinstance(rngs, nnx.Rngs):
      # Fork when supported so this module owns an independent RNG stream.
      self.rngs = rngs.fork() if hasattr(type(rngs), "fork") else rngs
    else:
      # NOTE(review): the message claims None is accepted, but passing None
      # reaches this branch and raises — confirm the intended contract.
      raise TypeError(f"rngs must be a Rngs, RngStream or None, but got {type(rngs)}.")


class MlpBlock(nnx.Module):
  """Transformer MLP / feed-forward block."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      in_features: int,
      intermediate_dim: int = 2048,
      activations: Sequence[str | Callable[..., Any]] = ("relu",),
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
intermediate_dropout_rate: float = 0.1, + dtype: Any = jnp.float32, + weight_dtype: Any = jnp.float32, + use_bias: bool = False, + use_pre_norm: bool = False, + quant: None | Quant = None, + model_mode: None | str = None, + *, + rngs: nnx.Rngs, + ) -> None: + """A MlpBlock module. + + Args: + config: Config object containing model parameters. + mesh: Mesh object of device and physical axes information + in_features: Number of input features. + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: computation data type for the dense layer. + weight_dtype: weight data type for the dense layer. + use_bias: whether to add bias in all feedforward layers. + use_pre_norm: whether to add pre layer norm in mlp layers. + quant: Optional quantization config, no quantization if None. 
+ out_sharding: Named sharding of outputs + """ + self.config = config + self.mesh = mesh + self.in_features = in_features + self.intermediate_dim = intermediate_dim + self.activations = activations + self.kernel_init = kernel_init + self.intermediate_dropout_rate = intermediate_dropout_rate + self.dtype = dtype + self.weight_dtype = weight_dtype + self.use_bias = use_bias + self.use_pre_norm = use_pre_norm + self.quant = quant + self.model_mode = model_mode + + if self.use_pre_norm: + self.mlp_layer_norm = self.get_norm_layer(num_features=in_features)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + else: + self.mlp_layer_norm = None + + if self.model_mode == MODEL_MODE_PREFILL: + self.intermediate_logical = ("activation_batch", "prefill_activation_length", "activation_mlp") + elif config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + self.intermediate_logical = ("activation_batch_no_exp", "activation_length", "activation_mlp") + else: + self.intermediate_logical = ("activation_batch", "activation_length_no_exp", "activation_mlp") + + if config.fused_mlp: + self.wi = DenseGeneral( + in_features_shape=in_features, + out_features_shape=(len(self.activations), self.intermediate_dim), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "num_activations", "mlp"), + quant=self.quant, + use_bias=self.use_bias, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + rngs=rngs, + ) + else: + for idx in range(len(self.activations)): + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" + module = DenseGeneral( + in_features_shape=in_features, + out_features_shape=self.intermediate_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "mlp"), + quant=self.quant, + 
use_bias=self.use_bias, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + rngs=rngs, + ) + setattr(self, dense_name, module) + self.dropout = Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,), rngs=rngs) + self.wo = DenseGeneral( + in_features_shape=self.intermediate_dim, + out_features_shape=in_features, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("mlp", "embed"), + quant=self.quant, + use_bias=self.use_bias, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + rngs=rngs, + ) + + self._maybe_shard_with_logical = functools.partial( + maybe_shard_with_logical, + mesh=mesh, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + ) + + def get_norm_layer(self, num_features: int): + """get normalization layer.""" + if self.config.decoder_block in ( + DecoderBlockType.DEFAULT, + DecoderBlockType.LLAMA2, + DecoderBlockType.MISTRAL, + DecoderBlockType.MIXTRAL, + DecoderBlockType.GEMMA, + DecoderBlockType.GEMMA2, + DecoderBlockType.GEMMA3, + DecoderBlockType.QWEN3, + DecoderBlockType.DEEPSEEK, + DecoderBlockType.LLAMA4, + ): + return functools.partial(normalizations.RMSNorm, num_features=num_features) + elif self.config.decoder_block == DecoderBlockType.GPT3: + from maxtext.models import gpt3 # pylint: disable=import-outside-toplevel + + return functools.partial( + gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=self.use_bias + ) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def __call__( + self, + inputs, + decode: bool = False, + deterministic: bool = False, + intermediate_sharding: NamedSharding | None = None, + out_sharding: NamedSharding | None = None, + ): + """Applies Transformer MlpBlock module.""" + cfg = self.config + + if self.mlp_layer_norm is not None: + inputs = self.mlp_layer_norm(inputs) + + # Iterate over 
specified MLP input activation functions. + # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. + activations = [] + if cfg.fused_mlp: + x = self.wi(inputs, out_sharding=intermediate_sharding) + x = checkpoint_name(x, "mlpwi") + for idx, act_fn in enumerate(self.activations): + y = _convert_to_activation_function(act_fn)(x[:, :, idx, ...]) + activations.append(y) + else: + for idx, act_fn in enumerate(self.activations): + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" + module = getattr(self, dense_name) + x = module(inputs, out_sharding=intermediate_sharding) + x = checkpoint_name(x, "mlp" + dense_name) + if cfg.activations_in_float32: + x = x.astype(jnp.float32) + x = _convert_to_activation_function(act_fn)(x) + activations.append(x) + + # Take elementwise product of above intermediate activations. + x = functools.reduce(operator.mul, activations).astype(self.dtype) + # Apply dropout and final dense output projection. + x = self.dropout(x, deterministic=deterministic) # Broadcast along length. 
+ x = self._maybe_shard_with_logical(x, self.intermediate_logical) + output = self.wo(x, out_sharding=out_sharding) + + output = checkpoint_name(output, "mlpwo") + return output + + +def mlp_block( + *, + config: Config, + mesh: Mesh, + in_features: int, + intermediate_dim: int = 2048, + activations: Sequence[str | Callable[..., Any]] = ("relu",), + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"), + intermediate_dropout_rate: float = 0.1, + dtype: Any = jnp.float32, + weight_dtype: Any = jnp.float32, + use_bias: bool = False, + use_pre_norm: bool = False, + quant: None | Quant = None, + model_mode: None | str = None, + name: None | str = None, +): + """Creates a MlpBlock Linen module using nnx.bridge.to_linen.""" + module = nnx_wrappers.to_linen( + MlpBlock, + config=config, + mesh=mesh, + in_features=in_features, + intermediate_dim=intermediate_dim, + activations=activations, + kernel_init=kernel_init, + intermediate_dropout_rate=intermediate_dropout_rate, + dtype=dtype, + weight_dtype=weight_dtype, + use_bias=use_bias, + use_pre_norm=use_pre_norm, + quant=quant, + model_mode=model_mode, + name=name, + metadata_fn=variable_to_logically_partitioned, + abstract_init=False, + ) + return module diff --git a/MaxCode/rag/sources/generic/maxtext_layers_normalizations.py b/MaxCode/rag/sources/generic/maxtext_layers_normalizations.py new file mode 100644 index 0000000..195d5bc --- /dev/null +++ b/MaxCode/rag/sources/generic/maxtext_layers_normalizations.py @@ -0,0 +1,228 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalization Layers."""

from typing import Any

from flax import linen as nn
from flax import nnx
from flax.linen import initializers as linen_initializers
import jax
from jax import lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from maxtext.common.common_types import Array, DType, ShardMode
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned
from maxtext.utils import max_logging
from maxtext.utils import max_utils


class RMSNorm(nnx.Module):
  """RMS normalization.

  Normalizes the last axis by the root-mean-square of its elements and applies
  a learnable per-feature scale. The scale is applied as
  `scale + scale_offset`, so variants that initialize the scale to zeros can
  recover an effective initial scale of 1.0 via `scale_offset=1.0`.
  """

  def __init__(
      self,
      num_features: int,
      epsilon: float = 1e-6,
      dtype: Any = jnp.float32,
      weight_dtype: Any = jnp.float32,
      shard_mode: ShardMode = ShardMode.AUTO,
      kernel_axes: tuple[None | str, ...] = (),
      scale_init: Initializer = nn.initializers.ones,
      parameter_memory_host_offload: bool = False,
      scale_offset: float = 0.0,
      *,
      rngs: nnx.Rngs,
  ):
    self.num_features = num_features
    # epsilon is the numerical-stability floor added under the rsqrt.
    self.epsilon = epsilon
    self.dtype = dtype
    self.weight_dtype = weight_dtype
    self.shard_mode = shard_mode
    self.kernel_axes = kernel_axes
    self.scale_init = scale_init
    self.parameter_memory_host_offload = parameter_memory_host_offload
    self.scale_offset = scale_offset
    self.scale = nnx.Param(
        scale_init(rngs.params(), (num_features,), weight_dtype),
        sharding=kernel_axes,
    )

  def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> jnp.ndarray:
    """Applies layer normalization on the input."""
    # Compute the RMS statistics in float32 for numerical stability, then
    # cast back to the computation dtype.
    x = jnp.asarray(x, jnp.float32)
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
    scale = self.scale.value
    # Move scale to device if parameter offloading is enabled
    if self.parameter_memory_host_offload:
      max_logging.log("normalizations.py: Moving scale parameter to device")
      scale = jax.device_put(scale, max_utils.device_space())
    # out_sharding must be None in auto shard mode
    if self.shard_mode != ShardMode.EXPLICIT:
      out_sharding = None

    scale = jnp.asarray(scale, self.dtype)
    effective_scale = scale + self.scale_offset  # Apply offset
    # Elementwise multiply expressed as an einsum so an explicit out_sharding
    # can be threaded through; the per-feature scale broadcasts over all
    # leading axes.
    return jnp.einsum("i...k,...k->i...k", y, effective_scale, out_sharding=out_sharding)


class GlobalRMSNorm(RMSNorm):
  """
  Applies RMSNorm over the last two dimensions (Heads * HeadDim).
  Used for Olmo3 which normalizes across all heads combined.
  """

  def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> jnp.ndarray:
    """Flattens the last two axes, applies RMSNorm, and restores the shape."""
    # x shape: [..., Heads, HeadDim]
    input_shape = x.shape

    # Flatten the last two dimensions: [..., Heads * HeadDim]
    # We use -2 and -1 to ensure we capture the last two dims regardless of rank
    flattened_shape = input_shape[:-2] + (input_shape[-2] * input_shape[-1],)
    x_flat = x.reshape(flattened_shape)

    # Apply standard RMSNorm (which normalizes over the last axis)
    y_flat = super().__call__(x_flat, out_sharding)

    # Reshape back to [..., Heads, HeadDim]
    return y_flat.reshape(input_shape)


def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
  """
  Used for input and post attention layernorms
  in Qwen3NextDecoderLayer.

  This normalization layer is specific to Qwen3-Next. Key characteristics:
  1. The learnable scale parameter `scale` is initialized to ZEROS.
  2. The scale is applied as `(1.0 + self.scale)`, making the initial scale effectively 1.0.
  This matches the PyTorch implementation of Qwen3NextRMSNorm.
  """
  return nnx.data(
      RMSNorm(
          num_features=num_features,
          epsilon=eps,
          dtype=dtype,
          weight_dtype=weight_dtype,
          scale_init=linen_initializers.zeros,
          scale_offset=1.0,
          rngs=rngs,
      )
  )


class Qwen3NextRMSNormGated(nnx.Module):
  """
  This applies RMS Normalization and then a gated activation function (SiLU).
  This is used within the Qwen3NextGatedDeltaNet.

  The normalization is performed by an internal `RMSNorm` instance (`self.rms_norm`),
  which has its own learnable `scale` parameter, initialized to ONES.

  Attributes:
    num_features: The number of features in the input.
    eps: A small epsilon value to prevent division by zero in RMSNorm.
    dtype: The datatype of the computation.
    weight_dtype: The datatype of the internal RMSNorm scale.
  """

  def __init__(self, num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
    self.num_features = num_features
    self.eps = eps
    self.dtype = dtype
    self.weight_dtype = weight_dtype
    self.rms_norm = nnx.data(
        RMSNorm(
            num_features=num_features,
            epsilon=eps,
            dtype=dtype,
            weight_dtype=weight_dtype,
            scale_init=nnx.initializers.ones,
            rngs=rngs,
        )
    )

  def __call__(self, hidden_states: Array, gate: Array) -> Array:
    """
    Applies RMSNorm and then a SiLU gate.

    Args:
      hidden_states: The input array to be normalized (o). Shape: (..., F)
      gate: The gating array for the activation (z). Shape: (..., F)
        where F is num_features.

    Returns:
      The normalized and gated output array. Shape: (..., F)
    """
    normalized_states = self.rms_norm(hidden_states)

    # Gated Activation using SiLU (Sigmoid-weighted Linear Unit)
    gated_states = normalized_states * jax.nn.silu(gate.astype(jnp.float32))

    return gated_states.astype(self.dtype)


def rms_norm(
    num_features: int,
    epsilon: float = 1e-6,
    dtype: Any = jnp.float32,
    weight_dtype: Any = jnp.float32,
    shard_mode: ShardMode = ShardMode.AUTO,
    kernel_axes: tuple[None | str, ...] = (),
    scale_init: Initializer = nn.initializers.ones,
    name: None | str = None,
    parameter_memory_host_offload: bool = False,
):
  """Creates a RMSNorm module (Linen-facing wrapper via nnx_wrappers.to_linen)."""
  module = nnx_wrappers.to_linen(
      RMSNorm,
      num_features=num_features,
      epsilon=epsilon,
      dtype=dtype,
      weight_dtype=weight_dtype,
      shard_mode=shard_mode,
      kernel_axes=kernel_axes,
      scale_init=scale_init,
      parameter_memory_host_offload=parameter_memory_host_offload,
      name=name,
      metadata_fn=variable_to_logically_partitioned,
  )
  return module


def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:
  """L2 normalization function. Normalizes a vector to have a length of 1.

  Args:
    x: Input array.
    dim: The axis or axes along which to normalize. Defaults to the last axis.
    eps: Small epsilon to prevent division by zero.

  Returns:
    L2 normalized array with the same shape as x.
  """

  inv_norm = jax.lax.rsqrt((x * x).sum(axis=dim, keepdims=True) + jnp.array(eps, dtype=x.dtype))
  return x * inv_norm


# Linen-facing variant of RMSNorm preconfigured like Qwen3NextRMSNorm:
# scale initialized to zeros and applied as (1.0 + scale) via scale_offset.
Qwen3NextRMSNormLinen = nnx_wrappers.to_linen_class(
    RMSNorm,
    base_metadata_fn=variable_to_logically_partitioned,
    scale_init=linen_initializers.zeros,
    scale_offset=1.0,
)
diff --git a/MaxCode/rag/sources/generic/maxtext_models_deepseek.py b/MaxCode/rag/sources/generic/maxtext_models_deepseek.py
new file mode 100644
index 0000000..6d502d9
--- /dev/null
+++ b/MaxCode/rag/sources/generic/maxtext_models_deepseek.py
@@ -0,0 +1,531 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

import functools
from typing import Optional

from flax import nnx
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config
from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from maxtext.layers import attention_mla
from maxtext.layers import initializers
from maxtext.layers import linears
from maxtext.layers import mhc
from maxtext.layers import moe
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.linears import Dropout
from maxtext.layers.engram import Engram
from maxtext.layers.engram import NgramHashMapping
from maxtext.layers.normalizations import RMSNorm
from maxtext.models import deepseek_batchsplit
from maxtext.utils import max_utils
from maxtext.utils.sharding import create_sharding
from maxtext.utils.sharding import maybe_shard_with_logical

import transformers

# -----------------------------------------
# The Decoder Layer for DeepSeek v3
# -----------------------------------------


class DeepSeekGenericLayer(nnx.Module):
  """Generic DeepSeek layer with Multi-Head Latent Attention.

  This is to be used as a base class for DeepSeek layers with dense/sparse MLPs.
  This class follows a pattern of separating module creation from execution.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
      layer_idx: int = -1,
  ) -> None:
    """Builds layer norms, optional Engram memory, MLA attention, dropout, and optional MHC wiring.

    Args:
      config: Model configuration (dims, dtypes, sharding, feature flags).
      model_mode: One of the MODEL_MODE_* constants; selects prefill vs. train/decode shapes.
      mesh: JAX device mesh used to build activation shardings.
      rngs: NNX RNG streams for parameter init and dropout.
      quant: Optional AQT quantization configuration.
      layer_idx: Index of this layer in the decoder stack; -1 means "unknown",
        which disables per-layer features such as Engram.
    """
    self.config = config
    self.model_mode = model_mode
    self.mesh = mesh
    self.quant = quant
    self.rngs = rngs
    self.is_mhc_enabled = config.mhc_expansion_rate > 1
    self.layer_idx = layer_idx
    # NOTE(review): `engram_layers and layer_idx in engram_layers` relies on the
    # truthiness of engram_layers (empty/None disables Engram); the attribute may
    # therefore be a list rather than a bool — confirmed by the `in` test.
    self.is_engram_enabled = config.engram_layers and layer_idx in config.engram_layers

    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
    # Dummy activation shape used only to size submodules (last dim = embedding dim).
    self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)

    self.out_sharding = create_sharding(self.mesh, self.logical_axis_names)
    self.mlp_intermediate_sharding = create_sharding(self.mesh, self.mlp_logical_axis_names)

    self.pre_self_attention_layer_norm = RMSNorm(
        num_features=self.dummy_inputs_shape[-1],
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=self.config.normalization_layer_epsilon,
        rngs=rngs,
    )

    self.post_self_attention_layer_norm = RMSNorm(
        num_features=self.dummy_inputs_shape[-1],
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=self.config.normalization_layer_epsilon,
        rngs=rngs,
    )

    if self.is_engram_enabled:
      self.engram_layer_norm = RMSNorm(
          num_features=self.dummy_inputs_shape[-1],
          dtype=self.config.dtype,
          weight_dtype=self.config.weight_dtype,
          kernel_axes=("norm",),
          epsilon=self.config.normalization_layer_epsilon,
          rngs=rngs,
      )
      # Tokenizer is needed to derive the n-gram hash vocabulary (pad id, vocab bases).
      tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path, token=config.hf_access_token)
      # TODO(ranran): Refactor NgramHashMapping to initialize once globally or at the model level.
      # Moving this to decoders.py currently causes JAX initialization errors.
      self.ngram_hash_mapping = NgramHashMapping(
          engram_vocab_bases=config.engram_vocab_bases,
          max_ngram_size=config.engram_max_ngram_size,
          engram_num_heads=config.engram_num_heads,
          layer_ids=config.engram_layers,
          tokenizer=tokenizer,
          pad_id=tokenizer.pad_token_id,
          seed=config.engram_seed,
      )
      self.engram = Engram(
          config=config,
          mesh=mesh,
          vocab_sizes=self.ngram_hash_mapping.get_vocab_sizes(layer_idx),
          engram_num_heads=config.engram_num_heads,
          engram_head_dim=config.engram_head_dim,
          engram_max_ngram_size=config.engram_max_ngram_size,
          engram_kernel_size=config.engram_kernel_size,
          mhc_expansion_rate=config.mhc_expansion_rate,
          quant=quant,
          rngs=rngs,
      )
    else:
      self.engram_layer_norm = None
      self.engram = None

    self.self_attention = attention_mla.MLA(
        config=self.config,
        num_query_heads=self.config.num_query_heads,
        num_kv_heads=self.config.num_kv_heads,
        head_dim=self.config.head_dim,
        max_target_length=self.config.max_target_length,
        max_prefill_predict_length=self.config.max_prefill_predict_length,
        attention_kernel=self.config.attention,
        attention_type=self.config.attention_type,
        inputs_q_shape=self.dummy_inputs_shape,
        inputs_kv_shape=self.dummy_inputs_shape,
        mesh=mesh,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        dropout_rate=self.config.dropout_rate,
        name="self_attention",
        quant=quant,
        kv_quant=quantizations.configure_kv_quant(config),
        q_lora_rank=self.config.q_lora_rank,
        kv_lora_rank=self.config.kv_lora_rank,
        qk_nope_head_dim=self.config.qk_nope_head_dim,
        qk_rope_head_dim=self.config.qk_rope_head_dim,
        v_head_dim=self.config.v_head_dim,
        max_position_embeddings=self.config.max_position_embeddings,
        original_max_position_embeddings=self.config.original_max_position_embeddings,
        mscale=self.config.mscale,
        rope_factor=self.config.rope_factor,
        model_mode=model_mode,
        rngs=rngs,
        attn_logits_soft_cap=self.config.attn_logits_soft_cap,
    )

    self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
    if self.is_mhc_enabled:
      # Manifold-constrained hyper-connections replace the plain residual adds
      # around attention and MLP when mhc_expansion_rate > 1.
      self.mhc_attention = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)
      self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)

  def mlp_op(self, x, deterministic, *args, **kwargs):
    """Executes the MLP operation. To be implemented by subclasses."""
    raise NotImplementedError()

  def with_logical_constraint(self, x):
    """Applies this layer's logical-axis sharding constraint to activation `x`."""
    return maybe_shard_with_logical(
        x,
        logical_axes=self.logical_axis_names,
        mesh=self.mesh,
        shard_mode=self.config.shard_mode,
        debug_sharding=self.config.debug_sharding,
        extra_stack_level=1,
    )

  def dropout_op(self, x, deterministic):
    """Applies dropout, then re-asserts the activation sharding."""
    dropout = self.dropout(x, deterministic=deterministic)
    return self.with_logical_constraint(dropout)

  def pre_attention_norm_op(self, x):
    """Pre-attention RMSNorm followed by the activation sharding constraint."""
    pre_attention_norm = self.pre_self_attention_layer_norm(x)
    return self.with_logical_constraint(pre_attention_norm)

  def post_attention_norm_op(self, x):
    """Post-attention RMSNorm followed by the activation sharding constraint."""
    post_attention_norm = self.post_self_attention_layer_norm(x)
    return self.with_logical_constraint(post_attention_norm)

  def attention_op(
      self,
      x,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
  ):
    """Executes the attention layer."""
    # MLA returns (output, aux); only the output is used here.
    attention_result, _ = self.self_attention(
        x,
        x,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=self.model_mode,
        out_sharding=self.out_sharding,
        previous_chunk=previous_chunk,
        page_state=page_state,
        slot=slot,
    )
    return self.with_logical_constraint(attention_result)

  @property
  def logical_axis_names(self):
    """Generate logical names for activations generally."""
    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
    axis_names = ["activation_batch", length_name, "activation_embed"]
    return axis_names

  @property
  def mlp_logical_axis_names(self):
    """Generate logical names for activations in MLP."""
    length_name = "prefill_activation_norm_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
    axis_names = ["activation_batch", length_name, "activation_mlp"]
    return axis_names

  def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cache=None):
    """Sows MoE/debug intermediates and returns the layer result.

    Returns:
      `(layer_output, None)` when `config.scan_layers` (scan carries no kv), else
      `(layer_output, kv_cache)`.
    """

    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

    if self.config.routed_bias and self.config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
      self.sow(nnx.Intermediate, "moe_bias_updates", moe_bias_updates)

    if self.config.record_internal_nn_metrics:
      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
      self.sow(
          nnx.Intermediate,
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if self.config.scan_layers:
      return layer_output, None
    return layer_output, kv_cache

  def self_attention_with_norm_op(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
  ):
    """Pre-norm self-attention with residual, then post-norm.

    Returns:
      `(hidden_states, intermediate_inputs)` where `intermediate_inputs` is the
      post-residual (pre-post-norm) activation used as the MLP residual branch.
    """
    if self.is_mhc_enabled:
      # MHC fuses norm + attention + residual mixing in one call.
      intermediate_inputs, _ = self.mhc_attention(
          self.pre_attention_norm_op,
          self.self_attention,
          x=inputs,
          mhc_type=HyperConnectionType.ATTENTION,
          decoder_segment_ids=decoder_segment_ids,
          inputs_positions=decoder_positions,
          deterministic=deterministic,
          model_mode=self.model_mode,
          out_sharding=self.out_sharding,
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
      )
    else:
      lnx = self.pre_attention_norm_op(inputs)
      attention_lnx = self.attention_op(
          lnx,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          previous_chunk,
          page_state,
          slot,
      )
      intermediate_inputs = inputs + attention_lnx
    # Normalization
    hidden_states = self.post_attention_norm_op(intermediate_inputs)
    return hidden_states, intermediate_inputs

  def engram_op(self, x, decoder_input_tokens):
    """Applies the Engram branch: norm, n-gram hash lookup for this layer, Engram memory."""
    normed_x = self.engram_layer_norm(x)
    hash_ids = self.ngram_hash_mapping(decoder_input_tokens)[self.layer_idx]
    return self.engram(normed_x, hash_ids)


class DeepSeekDenseLayer(DeepSeekGenericLayer):
  """DeepSeek-style dense layer with Multi-Head Latent Attention."""

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
      layer_idx: int = -1,
  ) -> None:
    """Builds the generic layer plus a dense MLP block."""
    super().__init__(config, model_mode, mesh, rngs, quant, layer_idx)
    self.mlp = linears.MlpBlock(
        in_features=self.dummy_inputs_shape[-1],
        intermediate_dim=self.config.mlp_dim,
        activations=self.config.mlp_activations,
        intermediate_dropout_rate=self.config.dropout_rate,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        config=self.config,
        quant=quant,
        model_mode=model_mode,
        mesh=mesh,
        rngs=self.rngs,
    )

  def mlp_op(self, x, deterministic):
    """Dense MLP followed by the activation sharding constraint."""
    mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
    return self.with_logical_constraint(mlp)

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache=None,
      attention_metadata=None,
      decoder_input_tokens=None,
  ):
    """Runs one dense DeepSeek decoder layer; returns post_process's (output, kv_cache)."""
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    x = self.with_logical_constraint(inputs)
    x = checkpoint_name(x, "decoder_layer_input")

    if self.is_engram_enabled:
      engram_output = self.engram_op(x, decoder_input_tokens)
      x = x + engram_output

    hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
        x,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        previous_chunk,
        page_state,
        slot,
    )

    if self.is_mhc_enabled:
      layer_output, _ = self.mhc_mlp(
          self.post_attention_norm_op,
          self.mlp,
          x=intermediate_inputs,
          mhc_type=HyperConnectionType.MLP_DENSE,
          deterministic=deterministic,
      )
    else:
      mlp_lnx = self.mlp_op(hidden_states, deterministic)
      layer_output = mlp_lnx + intermediate_inputs
    layer_output = self.dropout_op(layer_output, deterministic=deterministic)

    return self.post_process(layer_output, None, None, kv_cache)


DeepSeekDenseLayerToLinen = nnx_wrappers.to_linen_class(
    DeepSeekDenseLayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)


class DeepSeekMoELayer(DeepSeekGenericLayer):
  """DeepSeek-style MoE layer with Multi-Head Latent Attention.

  Supports dropless and dropping based on configs. Uses a bias in routing instead
  of load balancing loss.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
      layer_idx: int = -1,
  ) -> None:
    """Builds the generic layer plus a routed+shared MoE block."""
    super().__init__(config, model_mode, mesh, rngs, quant, layer_idx)
    # Name kept in checkpoint-compatible CamelCase form used by MaxText MoE weights.
    self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
        config=self.config,
        mesh=mesh,
        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        quant=quant,
        rngs=self.rngs,
    )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache=None,
      attention_metadata=None,
      decoder_input_tokens=None,
  ):
    """Runs one MoE DeepSeek decoder layer (or traces the batch-split schedule)."""
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]

    # This code should only be traced during initialization when using
    # batch-split schedule. It is never run during model execution, since
    # `Decoder` directly calls `batch_split_schedule` during execution.
    # That is also why we can split/merge activations here as well as
    # in `Decoder`, since they will never be executed together.
    if self.config.use_batch_split_schedule:
      activation_pspec = jax.sharding.PartitionSpec(
          ("data", "fsdp", "fsdp_transpose", "expert", "context"),
          None,
          None,
      )
      inputs = jax.shard_map(
          functools.partial(
              deepseek_batchsplit.split,
              split_factor=self.config.batch_split_factor,
          ),
          mesh=self.mesh,
          in_specs=activation_pspec,
          out_specs=[activation_pspec] * self.config.batch_split_factor,
      )(inputs)
      dpos = deepseek_batchsplit.split(decoder_positions, self.config.batch_split_factor)
      dseg = deepseek_batchsplit.split(decoder_segment_ids, self.config.batch_split_factor)
      weights = deepseek_batchsplit.fetch_weights(nnx.to_pure_dict(nnx.state(self, nnx.Param)), self.config.dtype)
      outputs = deepseek_batchsplit.batch_split_schedule(
          inputs,
          weights,
          dpos,
          dseg,
          model_mode=model_mode,
          mesh=self.mesh,
          quant=self.quant,
          cfg=self.config,
      )
      outputs = jax.shard_map(
          functools.partial(
              deepseek_batchsplit.merge,
              split_factor=self.config.batch_split_factor,
          ),
          mesh=self.mesh,
          in_specs=([activation_pspec] * self.config.batch_split_factor,),
          out_specs=activation_pspec,
      )(outputs)
      return outputs, None

    x = self.with_logical_constraint(inputs)
    x = checkpoint_name(x, "decoder_layer_input")

    if self.is_engram_enabled:
      engram_output = self.engram_op(x, decoder_input_tokens)
      x = x + engram_output

    hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
        x,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        previous_chunk,
        page_state,
        slot,
    )

    if self.is_mhc_enabled:
      layer_output, metadata = self.mhc_mlp(
          self.post_attention_norm_op,
          self.DeepSeekMoeBlock_0,
          x=intermediate_inputs,
          mhc_type=HyperConnectionType.MLP_MOE,
      )
      load_balance_loss = metadata["load_balance_loss"]
      moe_bias_updates = metadata["moe_bias_updates"]
    else:
      mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
      layer_output = mlp_lnx + intermediate_inputs
    layer_output = self.dropout_op(layer_output, deterministic=deterministic)

    return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)

  def mlp_op(self, x, deterministic, *args, **kwargs):
    """Routed+shared MoE; returns (sharded output, load_balance_loss, moe_bias_updates)."""
    mlp_lnx, load_balance_loss, moe_bias_updates = self.DeepSeekMoeBlock_0(
        x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding
    )
    return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates


DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class(
    DeepSeekMoELayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)
diff --git a/MaxCode/rag/sources/generic/maxtext_models_models.py b/MaxCode/rag/sources/generic/maxtext_models_models.py
new file mode 100644
index 0000000..0d1fcab
--- /dev/null
+++ b/MaxCode/rag/sources/generic/maxtext_models_models.py
@@ -0,0 +1,574 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

"""Transformer models."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Any

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from flax import linen as nn
from flax import nnx

from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN
from maxtext.inference import page_manager
from maxtext.layers.nnx_decoders import NNXDecoder
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers.decoders import Decoder
from maxtext.layers.embeddings import Embed, embed_as_linen
from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.multimodal import processor as mm_processor
from maxtext.utils import max_utils

# ------------------------------------------------------------------------------
# The network: Transformer Definitions
# ------------------------------------------------------------------------------


class TransformerLinenPure(nn.Module):
  """An autoregressive transformer model."""

  # Make new attributes required, so that all Transformer dependencies (train, decode,
  # compile, etc) will error instead of silently use defaults.
  # pylint: disable=attribute-defined-outside-init
  config: Config
  mesh: Mesh
  quant: Quant
  # Possible model_mode values can be found in maxtext.common.common_types.
  # We generally use maxtext.common.common_types.MODEL_MODE_TRAIN or
  # maxtext.common.common_types.MODEL_MODE_PREFILL for initializations here.
  # TODO: Make model_mode required after confirming no users are affected.
  model_mode: str = MODEL_MODE_TRAIN  # May be different than the model_mode passed to __call__
  # pylint: enable=attribute-defined-outside-init

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Initializes the model."""
    # Clone so the module attribute and the call-time kwarg agree on model_mode.
    module = self.clone(model_mode=model_mode)
    kwargs["model_mode"] = model_mode
    return nn.Module.init(module, *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Applies the model."""
    module = self.clone(model_mode=model_mode)
    kwargs["model_mode"] = model_mode
    return nn.Module.apply(module, *args, **kwargs)

  def setup(self):
    """Initialize shared_embedding & decoder layers."""

    cfg = self.config
    mesh = self.mesh
    self.shared_embedding = embed_as_linen(
        num_embeddings=cfg.vocab_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        name="token_embedder",
        config=cfg,
        mesh=self.mesh,
    )
    self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None
    self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None
    self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)

    # If MTP is enabled via config, set up the MTP block.
    if self.config.mtp_num_layers > 0:
      # Get the list of layer blueprints for the current model.
      layer_types = self.decoder.get_decoder_layers()
      # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
      # By convention, this is the last layer in the list.
      mtp_layer = layer_types[-1]
      self.mtp_block = multi_token_prediction_block_as_linen(
          config=self.config,
          mesh=self.mesh,
          transformer_layer_module=mtp_layer,
          decoder=self.decoder,
          rngs=self.make_rng("mtp_block"),
      )

  def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
    """
    Compute logits from hidden states (wrapping decoder.apply_output_head).
    This function is only used for vocabulary tiling.
    """
    logits = self.decoder.apply_output_head(
        shared_embedding=self.shared_embedding,
        y=hidden_states,
        deterministic=deterministic,
        model_mode=model_mode,
    )
    return logits

  def __call__(
      self,
      decoder_input_tokens: jnp.ndarray,
      decoder_positions: jnp.ndarray,
      decoder_segment_ids=None,
      encoder_images: None | jnp.ndarray = None,
      encoder_image_masks: None | jnp.ndarray = None,
      encoder_audios: None | jnp.ndarray = None,
      enable_dropout=True,
      model_mode=MODEL_MODE_TRAIN,
      previous_chunk=None,
      true_length: None | int = None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      decoder_target_tokens: None | jnp.ndarray = None,
      decoder_target_mask: None | jnp.ndarray = None,
      nnx_method=None,
      kv_caches: list[jax.Array] | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Applies Transformer decoder-branch on encoded-input and target.

    Args:
      true_length: (Optional) Prompt length before padding
      slot: (Optional) An integer representing the decode batch index selected
        for this request.

    Returns:
      Logits, or (hidden_state, kv_caches) when `config.attention == "vllm_rpa"`.
    """

    if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          f"During autoregressive decoding we assume the tokens are in the active sequence"
          f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
      )

    bidirectional_mask = None
    image_embeddings = None
    audio_embeddings = None
    deepstack_visual_embeds = None

    if self.config.use_multimodal and encoder_images is not None:
      image_embeddings, deepstack_visual_embeds = self.vision_encoder(
          input_images=encoder_images, deterministic=not enable_dropout
      )
      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)

    if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
      audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)

    # Create audio mask for placeholder tokens (qwen3-omni models)
    audio_masks = None
    if audio_embeddings is not None:
      audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)

    logits, hidden_state, kv_caches = self.decoder(
        shared_embedding=self.shared_embedding,
        decoder_input_tokens=decoder_input_tokens,
        decoder_positions=decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=not enable_dropout,
        model_mode=model_mode,
        previous_chunk=previous_chunk,
        slot=slot,
        page_state=page_state,
        bidirectional_mask=bidirectional_mask,
        image_embeddings=image_embeddings,
        image_masks=encoder_image_masks,
        audio_embeddings=audio_embeddings,
        audio_masks=audio_masks,
        kv_caches=kv_caches,
        attention_metadata=attention_metadata,
        deepstack_visual_embeds=deepstack_visual_embeds,
    )

    # If we are initializing the model AND MTP is enabled, we must create
    # dummy target tensors. This allows Flax to trace the MTPBlock and create
    # all its necessary parameters, without requiring the main training pipeline
    # to be aware of this initialization detail.
    if self.is_initializing() and self.config.mtp_num_layers > 0:
      if decoder_target_tokens is None:
        dummy_shape = decoder_input_tokens.shape
        decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32)
        decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32)
        decoder_segment_ids = jnp.ones(dummy_shape, dtype=jnp.int32)

    # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
    # model, active only during training. It computes an auxiliary loss based on
    # predicting multiple future tokens, as described in the DeepSeek-V3 paper.
    # To ensure architectural consistency, it uses two key components from the parent Transformer:
    #   1. The same `DecoderLayer` blueprint for its internal transformer blocks.
    #   2. The `shared_embedding` for both embedding future tokens and for its final
    #      logit projection.
    # Its only effect is to "sow" these losses; it does not alter the primary logits output.
    if self.config.mtp_num_layers > 0:
      self.mtp_block(
          shared_embedding=self.shared_embedding,
          main_hidden_state=hidden_state,
          input_ids=decoder_input_tokens,
          target_ids=decoder_target_tokens,
          target_mask=decoder_target_mask,
          position_ids=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=not enable_dropout,
          model_mode=model_mode,
      )

    if self.config.attention == "vllm_rpa":
      # In vLLM, logits are computed separately after updating the KV cache.
      return hidden_state, kv_caches

    return logits


def transformer_as_linen(
    config: Config,
    mesh: Mesh,
    quant: Quant,
    model_mode: str = MODEL_MODE_TRAIN,
    *,
    name: str | None = None,
) -> nnx_wrappers.ToLinen | TransformerLinenPure:
  """Constructs a Transformer model as a Linen or NNX module.

  This function returns an autoregressive Transformer model as either a Linen module
  or an NNX-wrapped module, depending on the `config.enable_nnx` flag. The returned module
  is suitable for training, evaluation, or decoding.

  If `config.enable_nnx` is True, returns a `TransformerLinen` that wraps the NNX-style
  Transformer for integration with NNX-specific APIs and workflows.
  Otherwise, returns a pure Flax Linen implementation (`TransformerLinenPure`).

  Args:
    config (Config): The configuration object specifying model hyperparameters and options.
    mesh (Mesh): The JAX sharding mesh for device partitioning.
    quant (Quant): The quantization module or configuration to use.
    model_mode (str, optional): The operational mode for the model, e.g.
      training, prefill, or autoregressive. Defaults to `MODEL_MODE_TRAIN`.
    name (str, optional): Optional module name for Linen/NNX construction.

  Returns:
    nnx_wrappers.ToLinen | TransformerLinenPure:
      A constructed Transformer model compatible with the specified framework (Linen or NNX).
  """
  if config.enable_nnx:
    return TransformerLinen(
        Transformer,
        args=(),
        kwargs=nn.FrozenDict(
            {
                "mesh": mesh,
                "config": config,
                "quant": quant,
                "model_mode": model_mode,
            }
        ),
        metadata_fn=initializers.variable_to_logically_partitioned,
        name=name,
    )
  else:
    return TransformerLinenPure(config, mesh, quant, model_mode=model_mode, name=name)


class TransformerLinen(nnx_wrappers.ToLinen):
  """Transformer model as a linen module."""

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Initializes the model."""
    # Propagate model_mode into the wrapped NNX module's constructor kwargs.
    model_kwargs = self.kwargs.copy({"model_mode": model_mode})  # type: ignore[wrong-arg-types]
    module = self.clone(kwargs=model_kwargs)
    kwargs["model_mode"] = model_mode
    return nnx_wrappers.ToLinen.init(module, *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Applies the model."""
    model_kwargs = self.kwargs.copy({"model_mode": model_mode})  # type: ignore[wrong-arg-types]
    module = self.clone(kwargs=model_kwargs)
    kwargs["model_mode"] = model_mode
    return nnx_wrappers.ToLinen.apply(module, *args, **kwargs)


class Transformer(nnx.Module):
  """An autoregressive transformer model."""

  # Make new attributes required, so that all Transformer dependencies (train, decode,
  # compile, etc) will error instead of silently use defaults.
  # pylint: disable=attribute-defined-outside-init
  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      quant: Quant,
      *,
      model_mode: str = MODEL_MODE_TRAIN,
      rngs: nnx.Rngs,
  ):
    """Initialize shared_embedding & decoder layers."""
    self.config = config
    self.mesh = mesh
    self.quant = quant
    self.model_mode = model_mode

    cfg = self.config
    mesh = self.mesh
    self.token_embedder = Embed(
        mesh=self.mesh,
        num_embeddings=cfg.vocab_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        config=cfg,
        rngs=rngs,
    )
    self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None
    self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None
    if cfg.pure_nnx_decoder:
      self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
    else:
      # Legacy path: wrap the Linen Decoder so it can live inside this NNX module.
      decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
      self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
    self.hidden_states = None

    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
    dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    dummy_decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

    if self.config.attention == "vllm_rpa":
      try:
        # pylint: disable=import-outside-toplevel
        from tpu_inference.layers.common.attention_metadata import AttentionMetadata  # pytype: disable=import-error
      except ImportError as e:
        raise ImportError(
            "vLLM RPA attention requires the vllm-tpu package. Please install it with `pip install vllm-tpu`."
        ) from e
      # NOTE(review): `(1)`, `(2)`, `(3)` are plain ints, not tuples — jnp.ones
      # accepts them, but `(1,)`-style tuples were presumably intended; confirm.
      dummy_attention_metadata = AttentionMetadata(
          input_positions=jnp.ones((batch_size * seq_len,), dtype=jnp.int32),
          block_tables=jnp.ones((seq_len,), dtype=jnp.int32),
          seq_lens=jnp.ones((1), dtype=jnp.int32),
          query_start_loc=jnp.ones((2), dtype=jnp.int32),
          request_distribution=jnp.ones((3), dtype=jnp.int32),
      )
    else:
      dummy_attention_metadata = None

    if not cfg.pure_nnx_decoder:
      self.decoder.lazy_init(
          shared_embedding=self.token_embedder,
          decoder_input_tokens=dummy_decoder_input_tokens,
          decoder_positions=dummy_decoder_positions,
          attention_metadata=dummy_attention_metadata,
      )

    # If MTP is enabled via config, set up the MTP block.
    if self.config.mtp_num_layers > 0:
      # Get the list of layer blueprints for the current model.
      layer_types = self.decoder.get_decoder_layers()
      # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
      # By convention, this is the last layer in the list.
      mtp_layer = layer_types[-1]
      mtp_block_linen = multi_token_prediction_block_as_linen(
          config=self.config,
          mesh=self.mesh,
          transformer_layer_module=mtp_layer,
          decoder=self.decoder,
          rngs=rngs,
          name="mtp_block",
      )
      self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)

      self.mtp_block.lazy_init(
          shared_embedding=self.token_embedder,
          main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype),
          input_ids=jnp.ones((1, 1), dtype=jnp.int32),
          target_ids=jnp.ones((1, 1), dtype=jnp.int32),
          target_mask=jnp.ones((1, 1), dtype=jnp.int32),
          position_ids=jnp.ones((1, 1), dtype=jnp.int32),
          decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32),
          deterministic=True,
      )

  def no_op(self, *args, **kwargs):
    """A no-op method to allow the model to be used in a lazy context."""
    return

  def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
    """Initializes the KV cache for the Transformer.

    Args:
      cache_size: The maximum size of the KV cache.
      batch_size: The batch size for which the cache is initialized.
      dtype: Data type for the cache. Defaults to `jnp.float32`.

    Returns:
      True if the cache is successfully initialized.
    """
    # NOTE(review): stub — always reports success; actual cache allocation
    # presumably happens inside the decoder/attention layers. Confirm callers
    # do not rely on this performing real work.
    return True

  def __call__(
      self,
      decoder_input_tokens: jnp.ndarray,
      decoder_positions: jnp.ndarray,
      decoder_segment_ids=None,
      cache=None,
      encoder_images: jax.Array | None = None,
      encoder_image_masks: jax.Array | None = None,
      encoder_audios: jax.Array | None = None,
      enable_dropout=True,
      model_mode=MODEL_MODE_TRAIN,
      previous_chunk=None,
      true_length: int | None = None,
      slot: int | None = None,
      page_state: page_manager.PageState | None = None,
      decoder_target_tokens: jax.Array | None = None,
      decoder_target_mask: jax.Array | None = None,
      kv_caches: list[jax.Array] | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Applies the Transformer model (NNX implementation).

    Args:
      decoder_input_tokens: Input tokens for the decoder.
      decoder_positions: Positional encodings for the decoder inputs.
      decoder_segment_ids: Segment IDs for the decoder inputs (optional).
      cache: Not read in this method body; accepted for interface
        compatibility — TODO confirm callers still pass it.
      encoder_images: Encoder images for multimodal models (optional).
      encoder_image_masks: Masks for the encoder images (optional).
      encoder_audios: Encoder audio inputs for audio models (optional).
      enable_dropout: Whether to enable dropout. Defaults to True.
      model_mode: One of the MODEL_MODE_* constants.
      previous_chunk: Previous chunk for incremental decoding (optional).
      true_length: True length of the prompt before padding (optional).
      slot: An integer representing the decode batch index selected for this request (optional).
      page_state: Page state for paged attention (optional).
      decoder_target_tokens: Target tokens for the decoder (optional, used in MTP).
      decoder_target_mask: Target mask for the decoder (optional, used in MTP).
      kv_caches: List of KV caches for each attention layer, used when invoking from vLLM (optional).
      attention_metadata: Mapping to store attention metadata, used when invoking from vLLM (optional).

    Returns:
      Logits from the Transformer model; (hidden_state, kv_caches) when
      `config.attention == "vllm_rpa"`.
    """
    if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          f"During autoregressive decoding we assume the tokens are in the active sequence"
          f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
      )

    bidirectional_mask = None
    image_embeddings = None
    deepstack_visual_embeds = None
    if self.config.use_multimodal and encoder_images is not None:
      image_embeddings, deepstack_visual_embeds = self.vision_encoder(
          input_images=encoder_images, deterministic=not enable_dropout
      )
      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)

    audio_embeddings = None
    if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
      audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)

    # Create audio mask for placeholder tokens (qwen3-omni models)
    audio_masks = None
    if audio_embeddings is not None:
      audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)

    mutable_collections = []
    if self.config.record_internal_nn_metrics:
      mutable_collections.append("intermediates")
    if self.config.distill_beta > 0.0 and "intermediates" not in mutable_collections:
      mutable_collections.append("intermediates")

    if self.config.pure_nnx_decoder:
      logits, hidden_state, kv_caches = self.decoder(
          shared_embedding=self.token_embedder,
          decoder_input_tokens=decoder_input_tokens,
          decoder_positions=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=not enable_dropout,
          model_mode=model_mode,
          previous_chunk=previous_chunk,
          slot=slot,
          page_state=page_state,
          bidirectional_mask=bidirectional_mask,
          image_embeddings=image_embeddings,
          image_masks=encoder_image_masks,
          audio_embeddings=audio_embeddings,
          audio_masks=audio_masks,
          kv_caches=kv_caches,
          attention_metadata=attention_metadata,
          deepstack_visual_embeds=deepstack_visual_embeds,
      )
    else:
      # Linen-wrapped decoder additionally takes `mutable` so sown intermediates survive.
      logits, hidden_state, kv_caches = self.decoder(
          shared_embedding=self.token_embedder,
          decoder_input_tokens=decoder_input_tokens,
          decoder_positions=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=not enable_dropout,
          model_mode=model_mode,
          previous_chunk=previous_chunk,
          slot=slot,
          page_state=page_state,
          bidirectional_mask=bidirectional_mask,
          image_embeddings=image_embeddings,
          image_masks=encoder_image_masks,
          audio_embeddings=audio_embeddings,
          audio_masks=audio_masks,
          kv_caches=kv_caches,
          attention_metadata=attention_metadata,
          deepstack_visual_embeds=deepstack_visual_embeds,
          mutable=mutable_collections,
      )  # pytype: disable=wrong-keyword-args

    # Materialize hidden state when vocab tiling is enabled
    if self.config.num_vocab_tiling > 1:
      self.hidden_states = hidden_state

    # NOTE(review): Linen-only init shim below is intentionally disabled — nnx
    # modules have no `is_initializing`; kept for reference against
    # TransformerLinenPure.__call__. Consider deleting once MTP-on-NNX init is settled.
    # if self.is_initializing() and self.config.mtp_num_layers > 0:
    #   if decoder_target_tokens is None:
    #     dummy_shape = decoder_input_tokens.shape
    #     decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32)
    #     decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32)
    #     decoder_segment_ids = jnp.ones(dummy_shape, dtype=jnp.int32)

    # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
    # model, active only during training. It computes an auxiliary loss based on
    # predicting multiple future tokens, as described in the DeepSeek-V3 paper.
    # To ensure architectural consistency, it uses two key components from the parent Transformer:
    #   1. The same `DecoderLayer` blueprint for its internal transformer blocks.
    #   2. The `shared_embedding` for both embedding future tokens and for its final
    #      logit projection.
    # Its only effect is to "sow" these losses; it does not alter the primary logits output.
    if self.config.mtp_num_layers > 0:
      self.mtp_block(
          shared_embedding=self.token_embedder,
          main_hidden_state=hidden_state,
          input_ids=decoder_input_tokens,
          target_ids=decoder_target_tokens,
          target_mask=decoder_target_mask,
          position_ids=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=not enable_dropout,
          model_mode=model_mode,
      )

    if self.config.attention == "vllm_rpa":
      # In vLLM, logits are computed separately after updating the KV cache.
      return hidden_state, kv_caches

    return logits
diff --git a/MaxCode/rag/sources/generic/maxtext_models_qwen3.py b/MaxCode/rag/sources/generic/maxtext_models_qwen3.py
new file mode 100644
index 0000000..eb15747
--- /dev/null
+++ b/MaxCode/rag/sources/generic/maxtext_models_qwen3.py
@@ -0,0 +1,2256 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
def naive_jax_chunk_gated_delta_rule(
    query, key, value, g, beta, chunk_size=64, initial_state=None, use_qk_norm_in_gdn=False
):
  """Reference (unoptimized) chunked Gated Delta Rule in pure JAX.

  Inputs are head-minor: `query`/`key`/`value` are (batch, seq, heads, dim)
  and `g`/`beta` are (batch, seq, heads). All math runs in float32; the
  output is cast back to the input dtype. The final recurrent state is
  returned only when `initial_state` was provided, mirroring the optimized
  implementation.
  """
  out_dtype = query.dtype
  if use_qk_norm_in_gdn:
    query = l2norm(query, dim=-1, eps=1e-6)
    key = l2norm(key, dim=-1, eps=1e-6)

  # Move heads ahead of sequence: (B, S, H, D) -> (B, H, S, D), float32 throughout.
  query = jnp.transpose(query, (0, 2, 1, 3)).astype(jnp.float32)
  key = jnp.transpose(key, (0, 2, 1, 3)).astype(jnp.float32)
  value = jnp.transpose(value, (0, 2, 1, 3)).astype(jnp.float32)
  beta = jnp.transpose(beta, (0, 2, 1)).astype(jnp.float32)
  g = jnp.transpose(g, (0, 2, 1)).astype(jnp.float32)

  bsz, n_heads, seq_len, dk = key.shape
  dv = value.shape[-1]
  pad = (chunk_size - seq_len % chunk_size) % chunk_size

  # Right-pad the sequence axis so it divides evenly into chunks.
  if pad > 0:
    query = jnp.pad(query, ((0, 0), (0, 0), (0, pad), (0, 0)))
    key = jnp.pad(key, ((0, 0), (0, 0), (0, pad), (0, 0)))
    value = jnp.pad(value, ((0, 0), (0, 0), (0, pad), (0, 0)))
    beta = jnp.pad(beta, ((0, 0), (0, 0), (0, pad)))
    g = jnp.pad(g, ((0, 0), (0, 0), (0, pad)))

  padded_len = seq_len + pad
  # 1/sqrt(d_k) query scaling.
  query = query * jax.lax.rsqrt(jnp.array(query.shape[-1]).astype(jnp.float32))

  v_beta = value * jnp.expand_dims(beta, -1)
  k_beta = key * jnp.expand_dims(beta, -1)

  # Chunked views: (B, H, N, C, D) and (B, H, N, C).
  n_chunks = padded_len // chunk_size
  qc = query.reshape(bsz, n_heads, n_chunks, chunk_size, dk)
  kc = key.reshape(bsz, n_heads, n_chunks, chunk_size, dk)
  kbc = k_beta.reshape(bsz, n_heads, n_chunks, chunk_size, dk)
  vbc = v_beta.reshape(bsz, n_heads, n_chunks, chunk_size, dv)
  gc = g.reshape(bsz, n_heads, n_chunks, chunk_size)

  # Upper triangle including the diagonal — zeroed in the intra-chunk matrix.
  causal_mask = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=0)

  # Pairwise decay factors exp(G_i - G_j) on the lower triangle.
  g_run = jnp.cumsum(gc, axis=-1)
  decay_mask = jnp.exp(jnp.tril(jnp.expand_dims(g_run, -1) - jnp.expand_dims(g_run, -2))).astype(jnp.float32)

  prec = jax.lax.Precision.HIGHEST
  attn = -jnp.matmul(kbc, jnp.swapaxes(kc, -1, -2), precision=prec) * decay_mask
  attn = jnp.where(causal_mask, 0.0, attn)

  def forward_substitute(i, a):
    # Row-by-row forward substitution building (I + S)^-1 implicitly.
    idx = jnp.arange(chunk_size)
    below = idx < i
    row = a[..., i, :] * below
    sub = a * (jnp.expand_dims(below, -1) & below)
    new_row = jnp.where(below, row + jnp.sum(jnp.expand_dims(row, -1) * sub, axis=-2), a[..., i, :])
    return a.at[..., i, :].set(new_row)

  attn = jax.lax.fori_loop(1, chunk_size, forward_substitute, attn)
  attn = attn + jnp.eye(chunk_size, dtype=attn.dtype)
  value_intra = jnp.matmul(attn, vbc, precision=prec)
  k_cumdecay = jnp.matmul(attn, kbc * jnp.expand_dims(jnp.exp(g_run), -1), precision=prec)

  want_final_state = initial_state is not None
  if initial_state is None:
    state0 = jnp.zeros((bsz, n_heads, dk, dv), dtype=value_intra.dtype)
  else:
    state0 = initial_state.astype(value_intra.dtype)

  # Strictly-upper mask for the inter-chunk attention term.
  strict_upper = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=1)

  # Put the chunk axis first so lax.scan walks chunk-by-chunk.
  xs = (
      jnp.transpose(qc, (2, 0, 1, 3, 4)),
      jnp.transpose(kc, (2, 0, 1, 3, 4)),
      jnp.transpose(value_intra, (2, 0, 1, 3, 4)),
      jnp.transpose(k_cumdecay, (2, 0, 1, 3, 4)),
      jnp.transpose(g_run, (2, 0, 1, 3)),
      jnp.transpose(decay_mask, (2, 0, 1, 3, 4)),
  )

  def step(state, x):
    q_i, k_i, v_i, kcd_i, g_i, dm_i = x
    p = jax.lax.Precision.HIGHEST

    a_i = jnp.matmul(q_i, jnp.swapaxes(k_i, -1, -2), precision=p) * dm_i
    a_i = jnp.where(strict_upper, 0.0, a_i)

    # Delta-rule correction against the carried state, then combine with the
    # inter-chunk contribution.
    v_new = v_i - jnp.matmul(kcd_i, state, precision=p)
    inter = jnp.matmul(q_i * jnp.expand_dims(jnp.exp(g_i), -1), state, precision=p)
    out_i = inter + jnp.matmul(a_i, v_new, precision=p)

    # Decay the state to the end of the chunk and fold in the new keys/values.
    next_state = state * jnp.exp(g_i[..., -1, None, None])
    k_decayed = k_i * jnp.expand_dims(jnp.exp(jnp.expand_dims(g_i[..., -1], -1) - g_i), -1)
    next_state = next_state + jnp.matmul(jnp.swapaxes(k_decayed, -1, -2), v_new, precision=p)
    return next_state, out_i

  final_state, out_chunks = jax.lax.scan(step, state0, xs)

  # Stitch chunks back together, drop padding, restore (B, S, H, D) layout.
  out = jnp.transpose(out_chunks, (1, 2, 0, 3, 4)).reshape(bsz, n_heads, -1, dv)
  out = jnp.transpose(out[:, :, :seq_len, :], (0, 2, 1, 3)).astype(out_dtype)

  return out, final_state if want_final_state else None
def jax_chunk_gated_delta_rule(
    query: Array,
    key: Array,
    value: Array,
    g: Array,
    beta: Array,
    chunk_size: int = 64,
    initial_state: None | Array = None,
    use_qk_norm_in_gdn: bool = False,
    compute_dtype: jnp.dtype = jnp.bfloat16,
) -> tuple[Array, None | Array]:
  """Optimized JAX implementation of Gated Delta Rule.

  Chunked WY-factorized formulation: matmuls run in `compute_dtype` while the
  gates, triangular inversion, and the recurrent state stay in float32.

  Args:
    query: Queries of shape (B, S, H, K_dim).
    key: Keys of shape (B, S, H, K_dim).
    value: Values of shape (B, S, H, V_dim).
    g: Log-decay gate of shape (B, S, H).
    beta: Write-strength gate of shape (B, S, H).
    chunk_size: Sequence chunk length for the blocked recurrence.
    initial_state: Optional carried recurrent state of shape (B, H, K_dim, V_dim).
    use_qk_norm_in_gdn: If True, L2-normalize query and key first.
    compute_dtype: dtype used for the bulk matmuls.

  Returns:
    A pair `(o, final_state)` where `o` has shape (B, S, H, V_dim) in the
    input dtype, and `final_state` is the float32 recurrent state — returned
    only when `initial_state` was provided, otherwise None.
  """
  # =========================================================================
  # STAGE 1: PREPARATION & PADDING
  # =========================================================================
  initial_dtype = query.dtype

  if use_qk_norm_in_gdn:
    query = l2norm(query, dim=-1, eps=1e-6)
    key = l2norm(key, dim=-1, eps=1e-6)

  g = g.astype(jnp.float32)

  # 2. Cast inputs to the requested compute_dtype (cfg.dtype) to save memory/compute
  query = query.astype(compute_dtype)
  key = key.astype(compute_dtype)
  value = value.astype(compute_dtype)
  beta = beta.astype(compute_dtype)

  # Scale Query (keep in compute_dtype)
  scale = jax.lax.rsqrt(jnp.array(query.shape[-1], dtype=jnp.float32)).astype(compute_dtype)
  query = query * scale

  B, seq_len, H, K_dim = key.shape
  V_dim = value.shape[-1]

  # Right-pad the sequence axis so it divides evenly into chunks.
  pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
  if pad_len > 0:

    def pad_fn(x, val=0.0):
      return jnp.pad(x, ((0, 0), (0, pad_len)) + ((0, 0),) * (x.ndim - 2), constant_values=val)

    query = pad_fn(query)
    key = pad_fn(key)
    value = pad_fn(value)
    g = pad_fn(g)
    beta = pad_fn(beta)

  num_chunks = query.shape[1] // chunk_size

  # Helper: (B, S, H, D) -> (B, N, H, C, D)
  def to_chunk(x):
    return x.reshape(B, num_chunks, chunk_size, H, -1).transpose(0, 1, 3, 2, 4)

  # Helper for scalars: (B, S, H) -> (B, N, H, C)
  def to_chunk_scalar(x):
    return x.reshape(B, num_chunks, chunk_size, H).transpose(0, 1, 3, 2)

  q_c = to_chunk(query)
  k_c = to_chunk(key)
  v_c = to_chunk(value)
  g_c = to_chunk_scalar(g)
  beta_c = to_chunk_scalar(beta)

  # =========================================================================
  # STAGE 2: INTRA-CHUNK PRE-COMPUTATION (Parallel)
  # =========================================================================

  # Cumulative decay (Must be float32)
  g_cumsum = jnp.cumsum(g_c, axis=-1)
  k_beta = k_c * beta_c[..., None]

  # S Matrix Calculation
  S = jnp.matmul(k_beta, k_c.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST)
  S = S.astype(jnp.float32)

  # Apply mask BEFORE exp to prevent 'inf' gradients
  g_diff = g_cumsum[..., :, None] - g_cumsum[..., None, :]
  mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1)
  g_diff = jnp.where(mask, g_diff, -1e30)

  S = S * jnp.exp(g_diff)
  S = jnp.where(mask, S, 0.0)

  # Inversion (A) - Strictly float32
  identity = jnp.eye(chunk_size, dtype=jnp.float32)
  identity_broadcasted = jnp.broadcast_to(identity, S.shape)

  # A = (I + S)^-1 via a unit-lower-triangular solve.
  A = jax.scipy.linalg.solve_triangular(identity + S, identity_broadcasted, lower=True, unit_diagonal=True)

  # 5. WY Factors
  v_beta = v_c * beta_c[..., None]
  u_chunks = jnp.matmul(A, v_beta.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST)
  u_chunks = u_chunks.astype(compute_dtype)

  k_beta_g = k_beta.astype(jnp.float32) * jnp.exp(g_cumsum)[..., None]
  w_chunks = jnp.matmul(A, k_beta_g, precision=jax.lax.Precision.HIGHEST)
  w_chunks = w_chunks.astype(compute_dtype)

  # =========================================================================
  # STAGE 3: INTER-CHUNK RECURRENCE (Scan)
  # =========================================================================
  # Chunk axis moves to the front so lax.scan walks chunk-by-chunk.
  scan_perm_vec = (1, 0, 2, 3, 4)
  scan_perm_scl = (1, 0, 2, 3)

  w_scan = w_chunks.transpose(scan_perm_vec)
  u_scan = u_chunks.transpose(scan_perm_vec)
  k_scan = k_c.transpose(scan_perm_vec)
  q_scan = q_c.transpose(scan_perm_vec)
  g_scan = g_cumsum.transpose(scan_perm_scl)

  if initial_state is None:
    h_init = jnp.zeros((B, H, K_dim, V_dim), dtype=jnp.float32)
  else:
    h_init = initial_state.astype(jnp.float32)

  xs = (w_scan, u_scan, q_scan, k_scan, g_scan)

  def scan_body(h, args):
    w, u, q, k, g = args
    prec = jax.lax.Precision.HIGHEST

    # --- Output Computation ---
    # 1. Inter-chunk: q(dtype) * exp(g)(f32) -> f32
    q_g = q.astype(jnp.float32) * jnp.exp(g)[..., None]
    attn_inter = jnp.matmul(q_g, h, precision=prec)

    # 2. Delta Rule Subtraction (v_prime and v_new)
    # w serves as k_cumdecay, u serves as value_intra
    v_prime = jnp.matmul(w.astype(jnp.float32), h, precision=prec)
    v_new = u.astype(jnp.float32) - v_prime

    # 3. Intra-chunk: q(dtype) @ k(dtype) -> f32
    attn = jnp.matmul(q, k.swapaxes(-1, -2), precision=prec)
    attn = attn.astype(jnp.float32)

    # Mask before exp
    g_diff = g[..., :, None] - g[..., None, :]
    mask_intra = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool))
    g_diff = jnp.where(mask_intra, g_diff, -1e30)

    attn_i = attn * jnp.exp(g_diff)
    attn_i = jnp.where(mask_intra, attn_i, 0.0)

    # Note: We do NOT multiply attn_i by beta here. The Delta rule mathematically
    # absorbed beta inside v_new (via u).

    # 4. Combine Core Output
    term2 = jnp.matmul(attn_i, v_new, precision=prec)
    o_c = attn_inter + term2

    # --- State Update ---
    g_i_last_exp = jnp.exp(g[..., -1, None, None])
    h_new = h * g_i_last_exp

    # Apply Delta Rule K decay to state
    g_diff_exp_state = jnp.exp(g[..., -1, None] - g)[..., None]
    k_i_g_diff = k.astype(jnp.float32) * g_diff_exp_state

    update_term = jnp.matmul(k_i_g_diff.swapaxes(-1, -2), v_new, precision=prec)
    h_new = h_new + update_term

    return h_new, o_c

  final_h, o_chunks = lax.scan(scan_body, h_init, xs)

  # =========================================================================
  # STAGE 4: FINALIZATION
  # =========================================================================
  # Restore (B, S, H, V_dim) layout and drop the padding rows.
  o = o_chunks.transpose(1, 0, 3, 2, 4)
  o = o.reshape(B, -1, H, V_dim)

  if pad_len > 0:
    o = o[:, :seq_len, :, :]

  o = o.astype(initial_dtype)

  return o, (final_h if initial_state is not None else None)
class Qwen3NextGatedDeltaNet(nnx.Module):
  """
  This module implements the full end-to-end logic of a Gated Delta Network layer.

  End-to-End Equations Implemented:
  Let `x` be the input `hidden_states`.

  Step A: Input Projections
  1. (q_raw, k_raw, v_raw, z) = Linear_qkvz(x)
  2. (b, a) = Linear_ba(x)

  Step B: 1D Convolution
  1. qkv_conv = silu(Conv1D(concatenate(q_raw, k_raw, v_raw)))
  2. (q, k, v) = split(qkv_conv)

  Step C: Gated Delta Rule (Recurrent Core)
  1. Gates: β=sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias)
  2. Core Calculation: core_attn_out = jax_chunk_gated_delta_rule(q, k, v, g, β)

  Step D: Final Output Stage
  1. y = RMSNorm(core_attn_out) * silu(z)
  2. output = Linear_out(y)
  """

  def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
    """
    Args:
      config: MaxText configuration object.
      dtype: dtype used for the inference-time cache buffers.
      model_mode: One of the MODEL_MODE_* constants; a cache is only created
        for non-train modes.
      rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper.
    """
    self.config = config
    cfg = self.config

    in_features = cfg.emb_dim
    self.num_v_heads = cfg.gdn_num_value_heads
    self.num_k_heads = cfg.gdn_num_key_heads
    self.head_k_dim = cfg.gdn_key_head_dim
    self.head_v_dim = cfg.gdn_value_head_dim
    self.key_dim = self.head_k_dim * self.num_k_heads
    self.value_dim = self.head_v_dim * self.num_v_heads
    # Conv operates over concat(q, k, v): two key blocks plus one value block.
    conv_dim = self.key_dim * 2 + self.value_dim
    conv_kernel_size = cfg.gdn_conv_kernel_dim
    self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

    if model_mode != MODEL_MODE_TRAIN:
      # Inference-only cache holding the conv tail and the recurrent state.
      self.cache = kvcache.GatedDeltaNetCache(
          batch=config.per_device_batch_size,
          num_heads=self.num_v_heads,
          k_head_dim=self.head_k_dim,
          v_head_dim=self.head_v_dim,
          conv_kernel_size=self.config.gdn_conv_kernel_dim,
          conv_dim=conv_dim,
          dtype=dtype,
      )

    # Submodule instantiations
    self.in_proj_qkvz = DenseGeneral(
        in_features_shape=in_features,
        out_features_shape=(self.key_dim * 2 + self.value_dim * 2),
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        kernel_axes=("embed", "mlp"),
        matmul_precision=cfg.matmul_precision,
        rngs=rngs,
    )
    self.in_proj_ba = DenseGeneral(
        in_features_shape=in_features,
        out_features_shape=(self.num_v_heads * 2),
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        kernel_axes=("embed", "mlp"),
        matmul_precision=cfg.matmul_precision,
        rngs=rngs,
    )

    self.conv1d = nnx.Conv(
        in_features=conv_dim,
        out_features=conv_dim,
        kernel_size=(conv_kernel_size,),
        feature_group_count=conv_dim,  # Depthwise
        padding="CAUSAL",
        use_bias=False,
        dtype=cfg.dtype,
        param_dtype=cfg.weight_dtype,
        precision=cfg.matmul_precision,
        rngs=rngs,
    )

    # Initialize A_log to match torch.log(torch.uniform(0, 16))
    def a_log_init(key, shape, dtype=jnp.float32):
      # Sample from Uniform(epsilon, 16) to avoid log(0)
      a_vals = jax.random.uniform(key, shape=shape, dtype=dtype, minval=1e-9, maxval=16.0)
      return jnp.log(a_vals)

    self.A_log = nnx.Param(a_log_init(rngs.params(), (self.num_v_heads,), dtype=cfg.weight_dtype))
    self.dt_bias = nnx.Param(nnx.initializers.ones(rngs.params(), (self.num_v_heads,), dtype=cfg.weight_dtype))

    self.norm = Qwen3NextRMSNormGated(
        num_features=self.head_v_dim,  # Normalize over the head dimension (D_v)
        eps=cfg.normalization_layer_epsilon,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        rngs=rngs,
    )
    self.out_proj = DenseGeneral(
        in_features_shape=self.value_dim,
        out_features_shape=(in_features,),
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        kernel_axes=("mlp", "embed"),
        matmul_precision=cfg.matmul_precision,
        rngs=rngs,
    )

  def __call__(
      self,
      hidden_states: Array,
      model_mode: str = MODEL_MODE_TRAIN,
      kv_cache=None,
      decoder_segment_ids: None | Array = None,
      **kwargs,
  ) -> Array:
    """Applies the Gated Delta Network layer.

    Args:
      hidden_states: Input activations of shape (B, S, E).
      model_mode: MODEL_MODE_* constant; in non-train modes the layer reads
        and updates `self.cache` (conv tail + recurrent state) in place.
      kv_cache: NOTE(review): not referenced in this method body — confirm
        whether it should feed the internal cache.
      decoder_segment_ids: Optional (B, S) segment ids; zeros mark padding.
      **kwargs: Ignored; accepted for interface compatibility.

    Returns:
      Output activations of shape (B, S, E).
    """
    # hidden_states: (B, S, E)
    cfg = self.config
    batch, seq_len, _ = hidden_states.shape

    # =========================================================================
    # STEP A: Input Projections
    # =========================================================================
    # qkvz: (B, S, 2 * K_dim + 2 * V_dim)
    qkvz = self.in_proj_qkvz(hidden_states)
    # ba: (B, S, 2 * H_v)
    ba = self.in_proj_ba(hidden_states)

    # QKVZ Reshaping and Splitting
    # Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K
    new_shape_qkvz = (
        batch,
        seq_len,
        self.num_k_heads,  # H_k
        2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head,
    )
    # mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K)
    mixed_qkvz = qkvz.reshape(new_shape_qkvz)

    split_indices_qkvz = [
        self.head_k_dim,  # D_k
        2 * self.head_k_dim,  # 2 * D_k
        2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim),  # 2 * D_k + V_per_K * D_v
    ]
    # query: (B, S, H_k, D_k)
    # key: (B, S, H_k, D_k)
    # value_raw: (B, S, H_k, V_per_K * D_v)
    # z_raw: (B, S, H_k, V_per_K * D_v)
    query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3)

    # value: (B, S, H_v, D_v)
    value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
    # z: (B, S, H_v, D_v)
    z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)

    # BA Reshaping and Splitting
    new_shape_ba = (
        batch,
        seq_len,
        self.num_k_heads,  # H_k
        2 * self.v_heads_per_k_head,
    )
    # mixed_ba: (B, S, H_k, 2 * V_per_K)
    mixed_ba = ba.reshape(new_shape_ba)

    split_indices_ba = [self.v_heads_per_k_head]
    # b_raw: (B, S, H_k, V_per_K)
    # a_raw: (B, S, H_k, V_per_K)
    b_raw, a_raw = jnp.split(mixed_ba, split_indices_ba, axis=3)

    # b: (B, S, H_v)
    b = b_raw.reshape(batch, seq_len, self.num_v_heads)
    # a: (B, S, H_v)
    a = a_raw.reshape(batch, seq_len, self.num_v_heads)

    # Flatten head dimensions for concatenation before conv
    # q: (B, S, K_dim)
    q = query.reshape(batch, seq_len, -1)
    # k: (B, S, K_dim)
    k = key.reshape(batch, seq_len, -1)
    # v: (B, S, V_dim)
    v = value.reshape(batch, seq_len, -1)

    # =========================================================================
    # STEP B: 1D Convolution
    # =========================================================================
    qkv = jnp.concatenate([q, k, v], axis=-1)
    batch, seq_len, _ = qkv.shape
    conv_kernel_size = self.config.gdn_conv_kernel_dim

    conv_state = None
    if model_mode != MODEL_MODE_TRAIN:
      # Retrieve state from self.cache
      conv_state = self.cache.conv_state.value
      if conv_state.shape[0] != batch:
        # Assumes zero-initialized state for testing
        if conv_state.shape[0] == 1:
          conv_state = jnp.broadcast_to(conv_state, (batch,) + conv_state.shape[1:])
        else:
          conv_state = conv_state[:batch]

      # Concatenate previous state with new input
      conv_input = jnp.concatenate([conv_state, qkv], axis=1)

      if decoder_segment_ids is not None:
        valid_lens = jnp.sum(decoder_segment_ids != 0, axis=1)  # Shape: (B,)

        # Slice out the last (kernel-1) *valid* steps per batch element.
        def extract_state(c_in, v_len):
          return jax.lax.dynamic_slice_in_dim(c_in, v_len, conv_kernel_size - 1, axis=0)

        new_conv_state = jax.vmap(extract_state)(conv_input, valid_lens)
      else:
        new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :]

      # Update self.cache in place
      self.cache.conv_state.value = new_conv_state
    else:
      # Train: pad with zeros
      conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0)))

    # Perform the convolution.
    conv_out = self.conv1d(conv_input)
    # Slice the output to match the original input sequence length.
    conv_out = conv_out[:, -seq_len:, :]
    qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
    # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
    q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)

    # Reshape for multi-head processing
    batch, seq_len, _ = hidden_states.shape
    # query shape: (B, S, H_k, D_k)
    query = q_conv.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim)
    # key shape: (B, S, H_k, D_k)
    key = k_conv.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim)
    # value shape: (B, S, H_v, D_v)
    value = v_conv.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)

    # =========================================================================
    # STEP C: Gated Delta Rule Recurrence
    # =========================================================================
    A_log = jnp.asarray(self.A_log[...], dtype=cfg.dtype)
    dt_bias = jnp.asarray(self.dt_bias[...], dtype=cfg.dtype)
    # beta shape: (B, S, H_v)
    beta = jax.nn.sigmoid(b)
    # g shape: (B, S, H_v)
    g = -jnp.exp(A_log) * jax.nn.softplus(a + dt_bias)

    if decoder_segment_ids is not None:
      mask = decoder_segment_ids != 0
      # Apply mask by broadcasting to respective shapes
      key = jnp.where(mask[..., None, None], key, 0.0)
      value = jnp.where(mask[..., None, None], value, 0.0)
      g = jnp.where(mask[..., None], g, 0.0)

    if self.num_v_heads > self.num_k_heads and self.num_v_heads % self.num_k_heads == 0:
      repeats = self.num_v_heads // self.num_k_heads
      # query shape after repeat: (B, S, H_v, D_k)
      query = jnp.repeat(query, repeats, axis=2)
      # key shape after repeat: (B, S, H_v, D_k)
      key = jnp.repeat(key, repeats, axis=2)
    elif self.num_k_heads > self.num_v_heads and self.num_k_heads % self.num_v_heads == 0:
      # NOTE(review): deliberately a no-op — confirm that configs with more
      # key heads than value heads are unsupported or handled upstream.
      pass

    recurrent_state = None
    if model_mode != MODEL_MODE_TRAIN:
      # Retrieve state from self.cache
      recurrent_state = self.cache.recurrent_state.value

      if recurrent_state.shape[0] != batch:
        if recurrent_state.shape[0] == 1:
          recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:])
        else:
          recurrent_state = recurrent_state[:batch]

    core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule(
        query,
        key,
        value,
        g,
        beta,
        chunk_size=cfg.gdn_chunk_size,
        initial_state=recurrent_state,
        use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
        compute_dtype=cfg.dtype,
    )

    if model_mode != MODEL_MODE_TRAIN:
      # Update self.cache in place for both prefill and decode
      self.cache.recurrent_state.value = recurrent_state_out

    # =========================================================================
    # STEP D: Final Output Stage
    # =========================================================================

    # The normalization and gating is applied per-head on the value dimension.

    # Apply the norm and gate. Output shape: (B, S, H_v, D_v)
    gated_output_reshaped = self.norm(core_attn_out, z)

    # Reshape back to a single feature dimension for the final projection.
    # Shape from (B, S, H_v, D_v) -> (B, S, value_dim)
    gated_output = gated_output_reshaped.reshape(batch, seq_len, -1)

    # Final output shape: (B, S, E)
    output = self.out_proj(gated_output)

    return output
class Qwen3NextFullAttention(nnx.Module):
  """Qwen3-Next Full Attention Layer.

  This module implements the full self-attention mechanism as used in
  Qwen3-Next models for layers that do not use the Gated Delta Network.
  It wraps the main `attentions.Attention` class, which handles the core attention operation,
  including the query, key, value, and output projections.

  Qwen3 Next Attention differs from standard attention by the following features:
  - Query and Gate splitting from a single q projection.
  - Application of a sigmoid gate to the attention output.
  - Usage of `Qwen3NextRMSNorm` for query and key normalization.
  - Usage of `PartialRotaryEmbedding` for partial rotary position embeddings.
  - Partial ROPE is applied to the first 25% of head dimensions

  Attributes:
    config: MaxText configuration object.
    mesh: The device mesh for sharding.
    model_mode: The operational mode (e.g., 'train', 'prefill').
    layer_idx: The index of the current layer.
    quant: Optional quantization configuration.
    attention: An instance of `attentions.Attention` which contains the
      learnable parameters for query, key, value, and output projections
      (e.g., `attention.query`, `attention.key`, etc.), and performs
      the attention calculation.
  """

  def __init__(
      self, config: Config, mesh: Mesh, model_mode: str, layer_idx: int, quant: None | Quant = None, *, rngs: nnx.Rngs
  ):
    """Builds the wrapped Attention module sized for `model_mode`.

    Args:
      config: MaxText configuration object.
      mesh: Device mesh for sharding.
      model_mode: MODEL_MODE_* constant; determines the dummy input shape.
      layer_idx: Index of this layer within the decoder stack.
      quant: Optional quantization configuration.
      rngs: RNGs for parameter initialization.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.layer_idx = layer_idx
    self.quant = quant
    cfg = self.config

    # 1/sqrt(head_dim) pre-attention query scaling.
    scaling_factor = self.config.head_dim**-0.5
    # Shape used only to size the Attention projections for this mode.
    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

    self.attention = attentions.Attention(
        config=cfg,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        out_axis_names=(BATCH, LENGTH_NO_EXP, EMBED),
        mesh=self.mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        name="self_attention",
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        use_qk_norm=cfg.use_qk_norm,
        query_pre_attn_scalar=scaling_factor,
        model_mode=model_mode,
        rngs=rngs,
    )

  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    """Runs self-attention (q = kv = `inputs`).

    Returns:
      A `(attention_output, kv_cache)` pair as produced by the wrapped
      `attentions.Attention` module.
    """
    attention_output, kv_cache = self.attention(
        inputs_q=inputs,
        inputs_kv=inputs,
        inputs_positions=decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    return attention_output, kv_cache
class Qwen3NextSparseMoeBlock(nnx.Module):
  """Qwen3-Next mixture-of-experts block: routed experts plus a gated shared expert.

  The block combines:
    1. A routed expert bank, where each token is dispatched to a subset of experts.
    2. A single shared expert that every token passes through.
    3. A learnable sigmoid gate that scales the shared expert's contribution.

  Attributes:
    config: The model configuration object.
    mesh: The device mesh for sharding.
    quant: Optional quantization configuration.
  """

  def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rngs: nnx.Rngs):
    self.config = config
    self.mesh = mesh
    self.quant = quant

    # Routed expert bank; top-k dispatch happens inside RoutedMoE.
    # (Submodule creation order is preserved so RNG streams match checkpoints.)
    self.routed_experts = moe.RoutedMoE(
        config=config,
        num_experts=config.num_experts,
        num_experts_per_tok=config.num_experts_per_tok,
        mesh=self.mesh,
        kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        intermediate_dim=config.moe_mlp_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=self.quant,
        rngs=rngs,
    )

    # Dense expert shared by all tokens.
    self.shared_expert = MlpBlock(
        config=config,
        mesh=mesh,
        in_features=config.emb_dim,
        intermediate_dim=config.moe_mlp_dim,
        activations=config.mlp_activations,
        intermediate_dropout_rate=config.dropout_rate,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=self.quant,
        model_mode=config.model_call_mode,
        rngs=rngs,
    )

    # Scalar gate controlling how much of the shared expert is mixed in.
    self.shared_expert_gate = DenseGeneral(
        in_features_shape=config.emb_dim,
        out_features_shape=1,
        use_bias=False,  # Qwen3-Next shared_expert_gate does not have a bias
        dtype=config.dtype,
        kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(self, hidden_states: Array, deterministic: bool) -> tuple[Array, Array | None]:
    """Applies the sparse MoE block to `hidden_states`.

    Args:
      hidden_states: Input activations of shape (batch, seq, embed_dim).
      deterministic: If True, disables dropout.

    Returns:
      A `(output, load_balance_loss)` pair; the load-balancing loss comes
      from the routed experts and may be absent outside training.
    """
    routed_out, balance_loss, _ = self.routed_experts(hidden_states)
    shared_out = self.shared_expert(hidden_states, deterministic=deterministic)
    gate = jax.nn.sigmoid(self.shared_expert_gate(hidden_states))
    return routed_out + gate * shared_out, balance_loss
  """

  def __init__(self, config: Config, mesh: Mesh, model_mode: str, quant: None | Quant = None, *, rngs: nnx.Rngs):
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant
    self.rngs = rngs
    cfg = self.config

    # Instantiate each layer within the block in __init__
    for i in range(cfg.inhomogeneous_layer_cycle_interval):
      layer_rngs = self.rngs.fork()  # Fork RNGs for each layer
      # Layers are stored under string attribute names ("layer_0", ...) rather
      # than a list, so each sub-layer keys cleanly into the module pytree.
      layer_name = f"layer_{i}"
      layer = Qwen3NextDecoderLayer(
          config=self.config,
          mesh=self.mesh,
          quant=self.quant,
          model_mode=self.model_mode,
          layer_idx=i,
          rngs=layer_rngs,
      )
      setattr(self, layer_name, layer)

  def __call__(
      self,
      carry: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
  ) -> tuple[Array, None]:
    """Applies the block of decoder layers to the input carry.

    Args:
      carry: The input tensor from the previous scan iteration.
      # ... other arguments are broadcasted to each iteration.

    Returns:
      A tuple containing the output of the block (the new carry) and an empty
      value for the scan's `y` collection.
    """
    cfg = self.config
    x = carry

    # Loop over the number of sub-layers that make up one repeating pattern.
    for i in range(cfg.inhomogeneous_layer_cycle_interval):
      layer = getattr(self, f"layer_{i}")
      # The second return value is kv_cache, which we ignore here because
      # it is not passed as a carry in scannable layers.
      x, _ = layer(
          x,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          previous_chunk,
          page_state,
          slot,
      )

    # The output of the block is the carry for the next scan iteration.
    return x, None


class Qwen3NextDecoderLayer(nnx.Module):
  """
  This layer is a hybrid, capable of functioning as either:
  1. A standard attention + MoE layer.
  2. A linear attention + MoE layer.

  NOTE: This implementation assumes every layer contains a MoE block, which is true for
  models like Qwen3-Next-80B-A3B where `decoder_sparse_step=1`. For models that
  interleave dense and sparse MLP layers, conditional logic would be needed here.

  Attributes:
    config: The model configuration object.
    mesh: The device mesh for sharding.
    model_mode: The operational mode (e.g., 'train', 'prefill').
    layer_idx: The index of the current layer in the transformer stack.
    quant: Optional quantization configuration.
  """

  def __init__(
      self, config: Config, mesh: Mesh, model_mode: str, layer_idx: int, quant: None | Quant = None, *, rngs: nnx.Rngs
  ):
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.layer_idx = layer_idx
    self.quant = quant
    cfg = self.config
    self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

    # First LayerNorm, applied before the attention block.
    self.input_layernorm = Qwen3NextRMSNorm(
        num_features=cfg.emb_dim,
        eps=cfg.normalization_layer_epsilon,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        rngs=rngs,
    )

    # Determine the type of attention mechanism for the current layer:
    # every `inhomogeneous_layer_cycle_interval`-th layer uses full attention.
    is_full_attention_layer = (self.layer_idx + 1) % cfg.inhomogeneous_layer_cycle_interval == 0

    # Conditionally instantiate either the Linear Attention or Full Attention block.
    if is_full_attention_layer:
      self.attention = Qwen3NextFullAttention(
          config=cfg,
          mesh=self.mesh,
          quant=self.quant,
          model_mode=model_mode,
          layer_idx=self.layer_idx,
          rngs=rngs,
      )
    else:
      self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs)

    # Second LayerNorm, applied before the MoE block.
    self.post_attention_layernorm = Qwen3NextRMSNorm(
        num_features=cfg.emb_dim,
        eps=cfg.normalization_layer_epsilon,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        rngs=rngs,
    )

    # Instantiate our `Qwen3NextSparseMoeBlock`.
    self.mlp = Qwen3NextSparseMoeBlock(config=cfg, mesh=self.mesh, quant=self.quant, rngs=rngs)

  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache: None | dict[str, Array] = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    residual = inputs

    # First LayerNorm, applied before the attention block.
    hidden_states = self.input_layernorm(inputs)
    hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

    # Conditionally apply either the Linear Attention or Full Attention block.
    # (The two attention variants take different arguments, hence the isinstance dispatch.)
    if isinstance(self.attention, Qwen3NextFullAttention):
      attention_output, new_kv_cache = cast(Qwen3NextFullAttention, self.attention)(
          hidden_states,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          kv_cache=kv_cache,
          attention_metadata=attention_metadata,
      )
    else:
      attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(
          hidden_states,
          model_mode=model_mode,
          kv_cache=None,
          decoder_segment_ids=decoder_segment_ids,
      )
      new_kv_cache = None

    # First residual connection after attention
    hidden_states = residual + attention_output
    hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

    # Prepare for the MoE block by capturing the new residual
    residual = hidden_states

    # Second LayerNorm, applied before the MoE block.
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

    # Instantiate and call our `Qwen3NextSparseMoeBlock`.
    mlp_output, load_balance_loss = self.mlp(hidden_states, deterministic=deterministic)

    # We sow the load balancing loss so it can be collected and added to the total loss
    # during training.
    # NOTE(review): flax.nnx's Module.sow normally takes a Variable *type* as its
    # first argument rather than the string "intermediates" — confirm this call
    # matches the sow signature of the nnx version in use.
    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow("intermediates", "moe_lb_loss", load_balance_loss)

    # Final residual connection (after the MoE block)
    layer_output = residual + mlp_output
    layer_output = nn.with_logical_constraint(
        layer_output,
        self.activation_axis_names,
    )
    return layer_output, new_kv_cache


# -----------------------------------------
# The Base Decoder Layer for Qwen3
# -----------------------------------------
class AttentionWithNorm(nnx.Module):
  """Base class with shared common components: self-attention block with normalization."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: None | Quant,
      rngs: nnx.Rngs,
  ):
    self.config = config
    self.mesh = mesh
    self.quant = quant

    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    # Shape hint only — used by Attention to pre-build projection shapes.
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
    self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

    # Corresponds to Qwen3's `input_layernorm`
    self.pre_self_attention_layer_norm = RMSNorm(
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )

    # Self-attention block
    query_pre_attn_scalar = config.head_dim**-0.5  # Qwen3 specific scaling
    self.self_attention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        quant=quant,
        kv_quant=quantizations.configure_kv_quant(config),
        use_ragged_attention=config.use_ragged_attention,
        ragged_block_size=config.ragged_block_size,
        use_qk_norm=config.use_qk_norm,
        query_pre_attn_scalar=query_pre_attn_scalar,
        model_mode=model_mode,
        use_mrope=config.use_mrope,
        mrope_section=config.mrope_section,
        rngs=rngs,
    )

    # Post Attention LayerNorm (corresponds to Qwen3's `post_attention_layernorm`)
    self.post_self_attention_layer_norm = RMSNorm(
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )

  def apply_attention_with_norm(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    """Applies self-attention with pre and post-layer normalization.

    Returns:
      A tuple (hidden_states, intermediate_inputs, kv_cache) where
      `intermediate_inputs` is the post-attention residual stream the caller
      adds its MLP output onto.
    """
    inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
    inputs = checkpoint_name(inputs, "decoder_layer_input")
    # Pre attention norm
    lnx = self.pre_self_attention_layer_norm(inputs)
    lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
    # Self attention
    attention_lnx, kv_cache = self.self_attention(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
    # Residual connection after attention
    intermediate_inputs = inputs + attention_lnx
    # Post attention norm
    hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
    hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)
    return hidden_states, intermediate_inputs, kv_cache


# -----------------------------------------
# The Dense Decoder Layer for Qwen3
# -----------------------------------------
class Qwen3DecoderLayer(AttentionWithNorm):
  """Qwen3 Transformer decoder layer (dense)."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: None | Quant,
      rngs: nnx.Rngs,
  ):
    super().__init__(config, mesh, model_mode, quant, rngs)
    self.mlp = MlpBlock(
        in_features=config.emb_dim,
        intermediate_dim=config.mlp_dim,
        activations=config.mlp_activations,
        intermediate_dropout_rate=config.dropout_rate,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        config=config,
        mesh=mesh,
        quant=quant,
        model_mode=model_mode,
        rngs=rngs,
    )

  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    # Unpack inputs if it's a tuple (e.g.
    # from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
        inputs,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )

    # Dense MLP on the normalized stream, then residual add.
    mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

    layer_output = intermediate_inputs + mlp_lnx
    layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

    # Scanned layers carry no per-layer kv_cache through the scan.
    if self.config.scan_layers:
      return layer_output, None
    else:
      return layer_output, kv_cache


# -----------------------------------------
# The MoE Decoder Layer for Qwen3
# -----------------------------------------
class Qwen3MoeDecoderLayer(AttentionWithNorm):
  """Qwen3 Transformer decoder layer (MoE)."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: None | Quant,
      rngs: nnx.Rngs,
  ):
    super().__init__(config, mesh, model_mode, quant, rngs)
    self.moe_block = RoutedMoE(
        config=config,
        num_experts=config.num_experts,
        num_experts_per_tok=config.num_experts_per_tok,
        mesh=mesh,
        kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        intermediate_dim=config.moe_mlp_dim,  # same as config.mlp_dim
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=quant,
        rngs=rngs,
    )

  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    # Unpack inputs if it's a tuple (e.g.
    # from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
        inputs,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )

    # Routed-experts MoE on the normalized stream; third return value is ignored.
    mlp_lnx, load_balance_loss, _ = self.moe_block(hidden_states)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)
    # NOTE(review): flax.nnx's Module.sow normally takes a Variable *type* as its
    # first argument rather than the string "intermediates" — confirm this call
    # matches the sow signature of the nnx version in use.
    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow("intermediates", "moe_lb_loss", load_balance_loss)

    layer_output = intermediate_inputs + mlp_lnx
    layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

    # Scanned layers carry no per-layer kv_cache through the scan.
    if self.config.scan_layers:
      return layer_output, None
    else:
      return layer_output, kv_cache


class Qwen3OmniMoeVisionPatchMerger(nnx.Module):
  """Vision patch merger that spatially merges patches using an MLP.

  Attributes:
    config: Config containing model parameters
    hidden_size: Hidden dimension after spatial merging
    use_postshuffle_norm: Whether to apply normalization after spatial shuffle
    dtype: Data type for computation
    weight_dtype: Data type for weights
    kernel_init: Initializer for kernel weights
    rngs: RNG state for initialization
    ln_q: LayerNorm before MLP
    mlp_0: First MLP layer
    mlp_2: Second MLP layer
  """

  def __init__(
      self,
      config: Config,
      use_postshuffle_norm: bool = False,
      dtype: DType = jnp.float32,
      weight_dtype: DType = jnp.float32,
      kernel_init: max_initializers.NdInitializer = max_initializers.nd_dense_init(1.0, "fan_in", "normal"),
      rngs: nnx.Rngs = None,
  ):
    """Initializes the Qwen3Omni vision patch merger.

    Args:
      config: Config containing model parameters
      use_postshuffle_norm: Whether to apply normalization after spatial shuffle
      dtype: Data type for computation
      weight_dtype: Data type for weights
      kernel_init: Initializer for kernel weights
      rngs: RNG state for initialization
    """
    self.config = config
    self.use_postshuffle_norm = use_postshuffle_norm
    self.dtype = dtype
    self.weight_dtype = weight_dtype
    self.kernel_init = kernel_init
    self.rngs = rngs

    # Calculate hidden_size after spatial merge
    spatial_merge_size = config.spatial_merge_size_for_vit
    base_hidden_size = config.hidden_size_for_vit
    out_hidden_size = config.out_hidden_size_for_vit

    self.hidden_size = base_hidden_size * (spatial_merge_size**2)

    # LayerNorm before MLP. When normalizing after the shuffle, the norm sees the
    # merged (larger) feature dimension; otherwise the per-patch dimension.
    ln_features = self.hidden_size if use_postshuffle_norm else base_hidden_size
    self.ln_q = nnx.LayerNorm(
        num_features=ln_features,
        epsilon=config.normalization_layer_epsilon,
        dtype=dtype,
        rngs=rngs,
    )

    # MLP layers: Linear -> GELU -> Linear
    self.mlp_0 = DenseGeneral(
        in_features_shape=self.hidden_size,
        out_features_shape=self.hidden_size,
        use_bias=True,
        dtype=dtype,
        weight_dtype=weight_dtype,
        kernel_init=kernel_init,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

    self.mlp_2 = DenseGeneral(
        in_features_shape=self.hidden_size,
        out_features_shape=out_hidden_size,
        use_bias=True,
        dtype=dtype,
        weight_dtype=weight_dtype,
        kernel_init=kernel_init,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(self, hidden: Array) -> Array:
    """
    Args:
      hidden: Input tensor of shape (batch, seq_len, base_hidden_size) after spatial reordering

    Returns:
      Output tensor of shape (batch, seq_len//merge_size**2, out_hidden_size) - spatially merged
    """
    # Get dimensions
    spatial_merge_size = self.config.spatial_merge_size_for_vit
    base_hidden_size = self.config.hidden_size_for_vit
    tokens_per_block = spatial_merge_size**2

    batch_size = hidden.shape[0]
    seq_len = hidden.shape[1]
    # Assumes seq_len is an exact multiple of tokens_per_block; any remainder
    # tokens would be silently dropped by the integer division/reshape.
    num_blocks = seq_len // tokens_per_block

    # Fold each group of tokens_per_block patches into a single merged token.
    hidden = hidden.reshape(batch_size, num_blocks, tokens_per_block * base_hidden_size)

    # Apply layer norm (over merged features, or per-patch then re-merge).
    if self.use_postshuffle_norm:
      hidden = self.ln_q(hidden)
    else:
      hidden_unmerged = hidden.reshape(batch_size, seq_len, base_hidden_size)
      hidden_unmerged = self.ln_q(hidden_unmerged)
      hidden = hidden_unmerged.reshape(batch_size, num_blocks, tokens_per_block * base_hidden_size)

    # MLP: Linear -> GELU -> Linear
    hidden = self.mlp_0(hidden)
    hidden = jax.nn.gelu(hidden)
    hidden = self.mlp_2(hidden)

    return hidden


class Qwen3OmniMoeVisionMLP(nnx.Module):
  """Vision MLP block with GELU activation.

  Attributes:
    config: Config containing model parameters
    hidden_size: Hidden dimension size
    intermediate_size: Intermediate dimension size
    dtype: Data type for computation
    weight_dtype: Data type for weights
    kernel_init: Initializer for kernel weights
    rngs: RNG state for initialization
    linear_fc1: First linear layer
    linear_fc2: Second linear layer
  """

  def __init__(
      self,
      config: Config,
      dtype: DType = jnp.float32,
      weight_dtype: DType = jnp.float32,
      kernel_init: max_initializers.NdInitializer = max_initializers.nd_dense_init(1.0, "fan_in", "normal"),
      rngs: nnx.Rngs = None,
  ):
    """Initializes the Qwen3Omni vision MLP.

    Args:
      config: Config containing model parameters
      dtype: Data type for computation
      weight_dtype: Data type for weights
      kernel_init: Initializer for kernel weights
      rngs: RNG state for initialization
    """
    self.config = config
    self.dtype = dtype
    self.weight_dtype = weight_dtype
    self.kernel_init = kernel_init
    self.rngs = rngs

    self.hidden_size = config.hidden_size_for_vit
    self.intermediate_size = config.intermediate_size_for_vit

    self.linear_fc1 = DenseGeneral(
        in_features_shape=self.hidden_size,
        out_features_shape=self.intermediate_size,
        use_bias=True,
        dtype=dtype,
        weight_dtype=weight_dtype,
        kernel_init=kernel_init,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

    self.linear_fc2 = DenseGeneral(
        in_features_shape=self.intermediate_size,
        out_features_shape=self.hidden_size,
        use_bias=True,
        dtype=dtype,
        weight_dtype=weight_dtype,
        kernel_init=kernel_init,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(self, hidden_state: Array) -> Array:
    """
    Args:
      hidden_state: Input tensor of shape (..., hidden_size) - supports packed sequences

    Returns:
      Output tensor of shape (..., hidden_size)
    """
    # Linear -> GELU -> Linear.
    hidden_state = self.linear_fc1(hidden_state)
    hidden_state = jax.nn.gelu(hidden_state)
    hidden_state = self.linear_fc2(hidden_state)
    return hidden_state


class Qwen3OmniMoeVisionPatchEmbed(nnx.Module):
  """3D convolution-based patch embedding for vision inputs.

  Attributes:
    config: Config containing model parameters
    patch_size: Spatial patch size
    temporal_patch_size: Temporal patch size
    in_channels: Number of input channels
    embed_dim: Embedding dimension
    dtype: Data type for computation
    weight_dtype: Data type for weights
    rngs: RNG state for initialization
    proj: Convolution projection layer
  """

  def __init__(
      self,
      config: Config,
      # Default to float32 for numerical stability in 3D convolutions on image/video inputs
      dtype: DType = jnp.float32,
      weight_dtype: DType = jnp.float32,
      rngs: nnx.Rngs = None,
  ):
    """Initializes the Qwen3Omni vision patch embedding.

    Args:
      config: Config containing model parameters
      dtype: Data type for computation (defaults to float32 for numerical stability)
      weight_dtype: Data type for weights (defaults to float32 for numerical stability)
      rngs: RNG state for initialization
    """
    self.config = config
    self.dtype = dtype
    self.weight_dtype = weight_dtype
    self.rngs = rngs

    self.patch_size = config.patch_size_for_vit
    self.temporal_patch_size = config.temporal_patch_size_for_vit
    self.in_channels = config.num_channels_for_vit
    self.embed_dim = config.hidden_size_for_vit

    # stride == kernel size: non-overlapping 3D patches.
    kernel_size = (self.temporal_patch_size, self.patch_size, self.patch_size)

    self.proj = nnx.Conv(
        in_features=self.in_channels,
        out_features=self.embed_dim,
        kernel_size=kernel_size,
        strides=kernel_size,
        use_bias=True,
        dtype=dtype,
        param_dtype=weight_dtype,
        rngs=rngs,
    )

  def __call__(self, hidden_states: Array) -> Array:
    """
    Args:
      hidden_states: Input tensor of shape (batch, in_channels, temporal*patch_size, height*patch_size, width*patch_size)
    Returns:
      Output tensor of shape (batch, T*H*W, embed_dim) where T, H, W are the number of patches
    """
    # Channels-first -> channels-last for nnx.Conv.
    hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
    hidden_states = self.proj(hidden_states)
    # Flatten the (T, H, W) patch grid into one token axis.
    batch_size = hidden_states.shape[0]
    seq_len = hidden_states.shape[1] * hidden_states.shape[2] * hidden_states.shape[3]
    hidden_states = hidden_states.reshape(batch_size, seq_len, self.embed_dim)
    return hidden_states


class Qwen3OmniMoeVisionAttention(nnx.Module):
  """Vision attention layer wrapper.

  Attributes:
    config: Config containing model parameters
    attn: Underlying attention module
  """

  def __init__(self, config: Config, *, mesh=None, rngs: nnx.Rngs = None):
    """Initializes the Qwen3Omni vision attention layer.

    Args:
      config: Config containing model parameters
      mesh: JAX device mesh for sharding
      rngs: RNG state for initialization
    """
    self.config = config
    head_dim = self.config.hidden_size_for_vit // self.config.num_attention_heads_for_vit
    # Vision uses full SA, no kv cache
    self.attn = Attention(
        config=self.config,
        num_query_heads=self.config.num_attention_heads_for_vit,
        num_kv_heads=self.config.num_attention_heads_for_vit,
        head_dim=head_dim,
        max_target_length=self.config.num_position_embeddings_for_vit,
        attention_kernel="dot_product",
        inputs_q_shape=(1, 1, self.config.hidden_size_for_vit),
        inputs_kv_shape=(1, 1, self.config.hidden_size_for_vit),
        float32_qk_product=self.config.float32_qk_product,
        float32_logits=self.config.float32_logits,
        dtype=self.config.dtype_mm,
        weight_dtype=self.config.weight_dtype,
        mesh=mesh,
        dropout_rate=0.0,
        attention_type=AttentionType.FULL,
        is_nope_layer=False,
        use_bias_in_projections=True,
        is_vision=True,
        use_qk_norm=False,
        query_pre_attn_scalar=head_dim ** (-0.5),
        model_mode="train",
        rngs=rngs,
    )

  def __call__(
      self,
      hidden_states: Array,
      num_frames: int,
      height: int,
      width: int,
      deterministic: bool = True,
  ) -> Array:
    """
    Args:
      hidden_states: Input tensor of shape (batch, T*H*W, hidden_size)
      num_frames: Number of temporal frames (static)
      height: Height in patches (static)
      width: Width in patches (static)
      deterministic: Whether to use deterministic mode (disable dropout)

    Returns:
      Output tensor of shape (batch, T*H*W, hidden_size)
    """
    # Pass through attention with static dimensions via rope_kwargs
    rope_kwargs = {
        "num_frames": num_frames,
        "height": height,
        "width": width,
    }
    output, _ = self.attn(
        inputs_q=hidden_states,
        inputs_kv=hidden_states,
        deterministic=deterministic,
        rope_kwargs=rope_kwargs,
    )

    return output


class Qwen3OmniMoeVisionBlock(nnx.Module):
  """Vision transformer block with attention and MLP.

  Attributes:
    config: Config containing model parameters
    ln1: LayerNorm before attention
    ln2: LayerNorm before MLP
    attn: Attention module
    mlp: First MLP layer
    mlp_out: Second MLP layer
  """

  def __init__(self, config: Config, *, mesh=None, rngs: nnx.Rngs = None):
    """Initializes the Qwen3Omni vision transformer block.

    Args:
      config: Config containing model parameters
      mesh: JAX device mesh for sharding
      rngs: RNG state for initialization
    """
    self.config = config
    hs = self.config.hidden_size_for_vit
    self.ln1 = nnx.LayerNorm(num_features=hs, epsilon=config.normalization_layer_epsilon, rngs=rngs)
    self.ln2 = nnx.LayerNorm(num_features=hs, epsilon=config.normalization_layer_epsilon, rngs=rngs)
    self.attn = Qwen3OmniMoeVisionAttention(config=config, mesh=mesh, rngs=rngs)
    self.mlp = DenseGeneral(
        in_features_shape=hs,
        out_features_shape=self.config.intermediate_size_for_vit,
        use_bias=True,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )
    self.mlp_out = DenseGeneral(
        in_features_shape=self.config.intermediate_size_for_vit,
        out_features_shape=hs,
        use_bias=True,
        matmul_precision=config.matmul_precision,
        rngs=rngs,
    )

  def __call__(
      self,
      x: Array,
      num_frames: int,
      height: int,
      width: int,
  ) -> Array:
    """
    Args:
      x: Input tensor of shape (batch, T*H*W, hidden_size)
      num_frames: Number of temporal frames (static)
      height: Height in patches (static)
      width: Width in patches (static)

    Returns:
      Output tensor of shape (batch, T*H*W, hidden_size)
    """
    # Pre-norm attention with residual, then pre-norm MLP (Linear -> GELU -> Linear)
    # with residual.
    x = x + self.attn(self.ln1(x), num_frames=num_frames, height=height, width=width)
    y = self.ln2(x)
    y = self.mlp(y)
    y = jax.nn.gelu(y)
    y = self.mlp_out(y)
    return x + y


class Qwen3OmniMoeVisionEncoder(nnx.Module):
  """Vision encoder with patch embedding, positional embedding, and transformer blocks.

  Attributes:
    config: Config containing model parameters
    patch_embed: Patch embedding module
    pos_embed_interpolate: Position embedding interpolation module
    blocks: List of transformer blocks
    merger_list: List of patch mergers for deep supervision
    spatial_merge_size: Size of spatial merging
    deep_idx: Indices of layers to extract deep features from
  """

  def __init__(self, config: Config, *, mesh=None, rngs: nnx.Rngs = None):
    """Initializes the Qwen3Omni vision encoder.

    Args:
      config: Config containing model parameters
      mesh: JAX device mesh for sharding
      rngs: RNG state for initialization
    """
    self.config = config
    self.patch_embed = Qwen3OmniMoeVisionPatchEmbed(config=config, rngs=rngs)

    num_pos = config.num_position_embeddings_for_vit
    hs = config.hidden_size_for_vit
    self.spatial_merge_size = config.spatial_merge_size_for_vit

    self.pos_embed_interpolate = Qwen3OmniMoeVisionPosEmbedInterpolate(
        num_position_embeddings=num_pos,
        hidden_size=hs,
        spatial_merge_size=self.spatial_merge_size,
        rngs=rngs,
    )

    self.depth = config.num_hidden_layers_for_vit

    # Use setattr with string names instead of nnx.List to avoid Orbax integer key bug
    for i in range(self.depth):
      block_name = f"blocks_{i}"
      block = Qwen3OmniMoeVisionBlock(config=config, mesh=mesh, rngs=rngs)
      setattr(self, block_name, block)

    self.deep_idx = tuple(config.deepstack_visual_indexes_for_vit)
    # Use setattr with string names instead of nnx.List to avoid Orbax integer key bug
    for i, _ in enumerate(self.deep_idx):
      merger_name = f"merger_{i}"
      merger = Qwen3OmniMoeVisionPatchMerger(config=config, use_postshuffle_norm=True, rngs=rngs)
      setattr(self, merger_name, merger)

  def __call__(
      self,
      hidden_states: Array,
      deterministic: bool = True,
  ):
    """
    Args:
      hidden_states: Input visual tokens of shape (batch, in_channels, T*patch_size, H*patch_size, W*patch_size)
      deterministic: Whether to use deterministic mode
        NOTE(review): currently unused in the body — blocks are called without it;
        confirm whether it should be threaded into the attention layers.

    Returns:
      Tuple of:
        - encoder_output: shape (batch, T*H*W, hidden_size_for_vit)
        - deep_features: List of intermediate features, each of shape (batch, T*H*W, out_hidden_size)
    """
    # Derive static patch-grid dimensions from the raw input shape.
    _, _, num_frames, height, width = hidden_states.shape
    num_frames = num_frames // self.config.temporal_patch_size_for_vit
    height = height // self.config.patch_size_for_vit
    width = width // self.config.patch_size_for_vit

    x = self.patch_embed(hidden_states)
    pos = self.pos_embed_interpolate(num_frames, height, width)

    # Broadcast positional embedding over the batch axis.
    pos = pos[jnp.newaxis, :, :]
    x = x + pos

    # Run all blocks, keeping every intermediate activation for deep supervision.
    h_traj = []
    for i in range(self.depth):
      block_name = f"blocks_{i}"
      blk = getattr(self, block_name)
      x = blk(x, num_frames=num_frames, height=height, width=width)
      h_traj.append(x)

    # Merge the selected intermediate activations into deep features.
    deep_feats = []
    for i, idx in enumerate(self.deep_idx):
      h = h_traj[idx]
      merger_name = f"merger_{i}"
      merger = getattr(self, merger_name)
      deep_feat = merger(h)
      deep_feats.append(deep_feat)

    return x, deep_feats


class Qwen3OmniMoeVisionProjector(nnx.Module):
  """Projection layer that converts vision encoder output to model embedding space.

  Attributes:
    config: Config containing model parameters
    merger: Patch merger for spatial reduction
  """

  def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
    """Initializes the Qwen3Omni vision projector.

    Args:
      config: Config containing model parameters
      rngs: RNG state for initialization
    """
    self.config = config
    self.merger = Qwen3OmniMoeVisionPatchMerger(config=config, use_postshuffle_norm=False, rngs=rngs)

  def __call__(self, hidden_states: Array) -> Array:
    """
    Args:
      hidden_states: Encoder output of shape (batch, T*H*W, hidden_size_for_vit)

    Returns:
      Projected output of shape (batch, T*H*W//merge_size**2, out_hidden_size_for_vit)
    """
    output = self.merger(hidden_states)
    return output


def qwen3omni_visionencoder_as_linen(config: Config, mesh: Mesh) -> nn.Module:
  """Convert Qwen3OmniMoeVisionEncoder to Linen module."""
  return nnx_wrappers.to_linen(
      Qwen3OmniMoeVisionEncoder,
      config=config,
      mesh=mesh,
      name="Qwen3OmniMoeVisionEncoder_0",
      abstract_init=False,
      metadata_fn=max_initializers.variable_to_logically_partitioned,
  )


def qwen3omni_visionprojector_as_linen(config: Config, mesh: Mesh) -> nn.Module:
  """Convert Qwen3OmniMoeVisionProjector to Linen module."""
  return nnx_wrappers.to_linen(
      Qwen3OmniMoeVisionProjector,
      config=config,
      name="Qwen3OmniMoeVisionProjector_0",
      abstract_init=False,
      metadata_fn=max_initializers.variable_to_logically_partitioned,
  )


class Qwen3OmniAudioEncoderLayer(nnx.Module):
  """Transformer encoder layer for audio model."""

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs

    # Static (batch, seq, features) shape hint for the Attention constructor.
    self.hidden_states_shape = (
        self.config.per_device_batch_size,
        self.config.max_source_positions_for_audio,
        self.config.d_model_for_audio,
    )

    self.input_layer_norm = nnx.LayerNorm(
        num_features=self.config.d_model_for_audio,
        epsilon=1e-5,
        dtype=self.config.dtype_mm,
        rngs=self.rngs,
    )

    self.self_attention_audio = Attention(
        config=self.config,
        num_query_heads=self.config.encoder_attention_heads_for_audio,
        num_kv_heads=self.config.encoder_attention_heads_for_audio,
        head_dim=self.config.d_model_for_audio // self.config.encoder_attention_heads_for_audio,
        max_target_length=self.config.max_source_positions_for_audio,
        attention_kernel="dot_product",
        inputs_q_shape=self.hidden_states_shape,
        inputs_kv_shape=self.hidden_states_shape,
        float32_qk_product=self.config.float32_qk_product,
        float32_logits=self.config.float32_logits,
        dtype=self.config.dtype_mm,
        weight_dtype=self.config.weight_dtype,
        mesh=self.mesh,
        dropout_rate=self.config.attention_dropout_for_audio,
        name="self_attention_audio",
        attention_type=AttentionType.FULL,
        is_nope_layer=True,  # No rotary position embeddings for audio
        use_bias_in_projections=True,
        use_qk_norm=False,
        query_pre_attn_scalar=1
        / math.sqrt(self.config.d_model_for_audio // self.config.encoder_attention_heads_for_audio),
        model_mode=MODEL_MODE_TRAIN,
        rngs=self.rngs,
    )

    self.post_attention_layer_norm = nnx.LayerNorm(
        num_features=self.config.d_model_for_audio,
        epsilon=1e-5,
        dtype=self.config.dtype_mm,
        rngs=self.rngs,
    )

    self.AudioMLP = MlpBlock(
        config=self.config,
        mesh=self.mesh,
        in_features=self.config.d_model_for_audio,
        intermediate_dim=self.config.encoder_ffn_dim_for_audio,
        activations=("gelu",),  # Single GELU activation
        kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        intermediate_dropout_rate=0.0,  # No dropout to match AudioMLP
        dtype=self.config.dtype_mm,
        weight_dtype=self.config.weight_dtype,
        use_bias=True,  # AudioMLP uses bias
        use_pre_norm=False,  # Norm is handled outside
        quant=None,  # No quantization
        model_mode=None,  # Not needed for encoder
        rngs=rngs,
    )

  def __call__(
      self,
      hidden_states: Array,
      deterministic: bool = False,
  ):
    """Apply transformer encoder layer to audio hidden states.

    Args:
      hidden_states: Input tensor of shape (batch, seq_len, d_model_for_audio)
      deterministic: Whether to use deterministic mode (disable dropout)

    Returns:
      Output tensor of shape (batch, seq_len, d_model_for_audio)
    """
    # Pre-norm self-attention with residual.
    residual = hidden_states
    hidden_states = self.input_layer_norm(hidden_states)
    hidden_states, _ = self.self_attention_audio(
        inputs_q=hidden_states,
        inputs_kv=hidden_states,
        deterministic=deterministic,
    )
    hidden_states = residual + hidden_states
    # Pre-norm MLP with residual.
    residual = hidden_states
    hidden_states = self.post_attention_layer_norm(hidden_states)
    hidden_states = self.AudioMLP(hidden_states)
    hidden_states = residual + hidden_states
    return hidden_states


class Qwen3OmniAudioEncoder(nnx.Module):
  """Full audio encoder with convs, positional embeddings, and transformer layers.

  Attributes:
    config: Config containing model parameters
    mesh: Mesh, JAX device mesh (used for sharding)
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs

    self.positional_embedding = PositionalEmbedding(
        embedding_dims=self.config.d_model_for_audio,
        max_wavelength=self.config.max_timescale_for_audio,
        cast_as_fprop_dtype=True,
        fprop_dtype=self.config.dtype_mm,
    )

    self.layernorm_post = nnx.LayerNorm(
        num_features=self.config.d_model_for_audio,
        epsilon=1e-5,
        dtype=self.config.dtype_mm,
        rngs=self.rngs,
    )

    # Convolutional downsampling layers
    self.conv2d1 = nnx.Conv(
        in_features=1,
        out_features=self.config.downsample_hidden_size_for_audio,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding=((1, 1), (1, 1)),
        use_bias=True,
        dtype=self.config.dtype_mm,
        param_dtype=self.config.weight_dtype,
        precision=self.config.matmul_precision,
        rngs=self.rngs,
    )

    self.conv2d2 = nnx.Conv(
        in_features=self.config.downsample_hidden_size_for_audio,
        out_features=self.config.downsample_hidden_size_for_audio,
        kernel_size=(3, 3),
strides=(2, 2), + padding=((1, 1), (1, 1)), + use_bias=True, + dtype=self.config.dtype_mm, + param_dtype=self.config.weight_dtype, + precision=self.config.matmul_precision, + rngs=self.rngs, + ) + + self.conv2d3 = nnx.Conv( + in_features=self.config.downsample_hidden_size_for_audio, + out_features=self.config.downsample_hidden_size_for_audio, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), + use_bias=True, + dtype=self.config.dtype_mm, + param_dtype=self.config.weight_dtype, + precision=self.config.matmul_precision, + rngs=self.rngs, + ) + + conv_out_dim = self.config.downsample_hidden_size_for_audio * ( + (((self.config.num_mel_bins_for_audio + 1) // 2 + 1) // 2 + 1) // 2 + ) + self.conv_out = DenseGeneral( + in_features_shape=conv_out_dim, + out_features_shape=self.config.d_model_for_audio, + use_bias=False, + dtype=self.config.dtype_mm, + weight_dtype=self.config.weight_dtype, + kernel_init=nd_dense_init(1.0, "fan_in", "normal"), + matmul_precision=self.config.matmul_precision, + rngs=self.rngs, + ) + + # Transformer encoder layers + for lyr in range(self.config.encoder_layers_for_audio): + layer_name = f"layers_{lyr}" + layer = Qwen3OmniAudioEncoderLayer( + config=self.config, + mesh=self.mesh, + rngs=self.rngs, + ) + setattr(self, layer_name, layer) + + def __call__( + self, + audio_features: Array, + deterministic: bool = False, + ): + """Process audio features through convs + transformer encoder. 
+ + Args: + audio_features: Input of shape (batch, num_mel_bins, audio_length) + deterministic: Whether to use deterministic mode + + Returns: + Encoded features of shape (batch, seq_len, d_model_for_audio) + """ + batch_size, num_mel_bins, audio_length = audio_features.shape + chunk_size = self.config.n_window_for_audio * 2 + + # Reshape to chunks + num_chunks = audio_length // chunk_size + audio_chunks = audio_features.reshape(batch_size, num_mel_bins, num_chunks, chunk_size) + audio_chunks = audio_chunks.transpose(0, 2, 1, 3) + audio_chunks = audio_chunks.reshape(batch_size * num_chunks, num_mel_bins, chunk_size) + + # Add channel dimension + hidden_states = audio_chunks[:, :, :, jnp.newaxis] + + # Apply convolutional layers + hidden_states = self.conv2d1(hidden_states) + hidden_states = jax.nn.gelu(hidden_states) + hidden_states = self.conv2d2(hidden_states) + hidden_states = jax.nn.gelu(hidden_states) + hidden_states = self.conv2d3(hidden_states) + hidden_states = jax.nn.gelu(hidden_states) + + # Reshape conv output + bc, f, t, c = hidden_states.shape + hidden_states = hidden_states.transpose(0, 2, 3, 1) + hidden_states = hidden_states.reshape(bc, t, c * f) + hidden_states = self.conv_out(hidden_states) + + # Add positional embeddings + seq_len_per_chunk = hidden_states.shape[1] + pos_emb = self.positional_embedding(seq_len_per_chunk) + pos_emb = jnp.broadcast_to( + pos_emb[None, :, :], (batch_size * num_chunks, seq_len_per_chunk, self.config.d_model_for_audio) + ) + hidden_states = hidden_states + pos_emb + + # Apply transformer encoder layers + for lyr in range(self.config.encoder_layers_for_audio): + layer_name = f"layers_{lyr}" + layer = getattr(self, layer_name) + hidden_states = layer( + hidden_states, + deterministic=deterministic, + ) + + hidden_states = self.layernorm_post(hidden_states) + + # Reshape back: (batch*chunks, seq_len_per_chunk, d_model) -> (batch, chunks*seq_len_per_chunk, d_model) + hidden_states = hidden_states.reshape(batch_size, 
num_chunks * seq_len_per_chunk, self.config.d_model_for_audio) + + return hidden_states + + +class Qwen3OmniAudioProjector(nnx.Module): + """Projection layer that converts audio encoder output to model embedding space.""" + + def __init__(self, config: Config, *, rngs: nnx.Rngs = None): + self.config = config + self.proj1 = DenseGeneral( + in_features_shape=config.d_model_for_audio, + out_features_shape=config.d_model_for_audio, + use_bias=True, + dtype=config.dtype_mm, + weight_dtype=config.weight_dtype, + kernel_init=nd_dense_init(1.0, "fan_in", "normal"), + matmul_precision=config.matmul_precision, + rngs=rngs, + ) + + self.proj2 = DenseGeneral( + in_features_shape=config.d_model_for_audio, + out_features_shape=config.output_dim_for_audio, + use_bias=True, + dtype=config.dtype_mm, + weight_dtype=config.weight_dtype, + kernel_init=nd_dense_init(1.0, "fan_in", "normal"), + matmul_precision=config.matmul_precision, + rngs=rngs, + ) + + def __call__(self, hidden_states: Array) -> Array: + """ + Args: + hidden_states: Encoder output of shape (num_chunks, seq_len, d_model_for_audio) + + Returns: + Projected output of shape (num_chunks, seq_len, output_dim_for_audio) + """ + hidden_states = self.proj1(hidden_states) + hidden_states = jax.nn.gelu(hidden_states) + hidden_states = self.proj2(hidden_states) + return hidden_states + + +def qwen3omni_audioencoder_as_linen(config: Config, mesh: Mesh): + """Convert AudioEncoder (convs + transformer layers, no projector) to Linen module.""" + return nnx_wrappers.to_linen( + Qwen3OmniAudioEncoder, + config=config, + mesh=mesh, + name="Qwen3OmniAudioEncoder_0", + abstract_init=False, + metadata_fn=variable_to_logically_partitioned, + ) + + +def qwen3omni_audioprojector_as_linen(config: Config, mesh: Mesh): + """Convert AudioProjector to Linen module.""" + return nnx_wrappers.to_linen( + Qwen3OmniAudioProjector, + config=config, + name="Qwen3OmniAudioProjector_0", + abstract_init=False, + 
metadata_fn=variable_to_logically_partitioned, + ) + + +# Vision encoder Linen wrappers +Qwen3OmniMoeVisionPatchMergerToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionPatchMerger, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniMoeVisionMLPToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionMLP, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniMoeVisionPatchEmbedToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionPatchEmbed, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniMoeVisionAttentionToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionAttention, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniMoeVisionBlockToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionBlock, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniMoeVisionEncoderToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionEncoder, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniMoeVisionProjectorToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniMoeVisionProjector, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3DecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen3DecoderLayer, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3MoeDecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen3MoeDecoderLayer, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3NextDecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen3NextDecoderLayer, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3NextScannableBlockToLinen = nnx_wrappers.to_linen_class( + Qwen3NextScannableBlock, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +# Audio encoder Linen wrappers +Qwen3OmniAudioEncoderLayerToLinen = 
nnx_wrappers.to_linen_class( + Qwen3OmniAudioEncoderLayer, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniAudioEncoderToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniAudioEncoder, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) + +Qwen3OmniAudioProjectorToLinen = nnx_wrappers.to_linen_class( + Qwen3OmniAudioProjector, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) diff --git a/MaxCode/rag/sources/generic/nvlabs_gated_deltanet_config.py b/MaxCode/rag/sources/generic/nvlabs_gated_deltanet_config.py new file mode 100644 index 0000000..db26be8 --- /dev/null +++ b/MaxCode/rag/sources/generic/nvlabs_gated_deltanet_config.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Copyright Lightning AI. Licensed under the Apache License 2.0, +# see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE + +from dataclasses import dataclass +from typing import Any, Literal, Optional, Type + +import torch +from typing_extensions import Self + +import lit_gpt.model +from lit_gpt.utils import find_multiple + + +@dataclass +class Config: + org: str = "Lightning-AI" + name: str = "lit-GPT" + block_size: int = 4096 + vocab_size: int = 50254 + padding_multiple: int = 512 + padded_vocab_size: Optional[int] = None + n_layer: int = 16 + n_head: int = 32 + n_embd: int = 4096 + rotary_percentage: float = 0.25 + parallel_residual: bool = True + bias: bool = True + local_window: int = -1 + mlp: bool = True + full_per_layer: int = 1000000 + mb_per_layer: int = -1 + ret_per_layer: int = -1 + gla_per_layer: int = -1 + nope: bool = False + mamba: bool = False + sc_attn: bool = False + rms_norm: bool= True + residual_in_fp32: bool = True + fused_add_norm: bool = True + mamba_init: bool = False + attn_layer_pos: str = None + gated_delta_per_layer: int = -1 + n_query_groups: Optional[int] = None + shared_attention_norm: bool = False + 
_norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP" + intermediate_size: Optional[int] = None + condense_ratio: int = 1 + + def __post_init__(self): + # error checking + assert self.n_embd % self.n_head == 0 + # vocab size should be a power of 2 to be optimal on hardware. compute the closest value + if self.padded_vocab_size is None: + self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple) + # compute the number of query groups + if self.n_query_groups is not None: + assert self.n_head % self.n_query_groups == 0 + else: + self.n_query_groups = self.n_head + # compute the intermediate size for MLP if not set + if self.intermediate_size is None: + if self._mlp_class == "LLaMAMLP": + raise ValueError("The config needs to set the `intermediate_size`") + self.intermediate_size = 4 * self.n_embd + + @property + def head_size(self) -> int: + return self.n_embd // self.n_head + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + conf_dict = name_to_config[name].copy() + conf_dict.update(kwargs) + return cls(**conf_dict) + + @property + def mlp_class(self) -> Type: + # `self._mlp_class` cannot be the type to keep the config json serializable + return getattr(lit_gpt.model, self._mlp_class) + + @property + def norm_class(self) -> Type: + # `self._norm_class` cannot be the type to keep the config json serializable + if self._norm_class == "RMSNorm": + from lit_gpt.rmsnorm import RMSNorm + + return RMSNorm + elif self._norm_class == "FusedRMSNorm": + from lit_gpt.rmsnorm import FusedRMSNorm + return FusedRMSNorm + return getattr(torch.nn, self._norm_class) + + +configs=[] + +GatedDeltaNet = [ + dict( + org="NVIDIA", + name="GatedDeltaNet_0.4B", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + gated_delta_per_layer=1, + n_layer=11, + n_head=12, + n_embd=1536, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + 
_norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=6144, + local_window = 2048, + mamba_init = True, + ), + dict( + org="NVIDIA", + name="GatedDeltaNet_H1_0.4B", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + gated_delta_per_layer=2, + n_layer=12, + n_head=12, + n_embd=1536, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=6144, + local_window = 2048, + mamba_init = True, + ), + dict( + org="NVIDIA", + name="GatedDeltaNet_1.3B", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + gated_delta_per_layer=1, + n_layer=16, + n_head=16, + n_embd=2400, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=5888, + local_window = 2048, + mamba_init = True, + ), + dict( + org="NVIDIA", + name="GatedDeltaNet_H1_1.3B", + block_size=4096, + vocab_size=32000, + padding_multiple=64, + gated_delta_per_layer=2, + n_layer=18, + n_head=18, + n_embd=2304, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=6144, + local_window = 2048, + mamba_init = True, + ), +] +configs.extend(GatedDeltaNet) + +name_to_config = {config["name"]: config for config in configs} diff --git a/MaxCode/rag/sources/generic/nvlabs_gated_deltanet_model.py b/MaxCode/rag/sources/generic/nvlabs_gated_deltanet_model.py new file mode 100644 index 0000000..5bb1a42 --- /dev/null +++ b/MaxCode/rag/sources/generic/nvlabs_gated_deltanet_model.py @@ -0,0 +1,576 @@ +# Modified by Songlin Yang & Ali Hatamizadeh + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Copyright Lightning AI. 
Licensed under the Apache License 2.0, +# see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE + +import math +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn as nn +from lightning_utilities.core.imports import RequirementCache +from .gated_delta_net import GatedDeltaNet +from typing_extensions import Self +from lit_gpt.config import Config +from xformers.ops import SwiGLU +from .fused_rotary_embedding import apply_rotary_emb_func +from torch import Tensor +from functools import partial +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +from einops import rearrange +import torch.nn.functional as F + +from causal_conv1d import causal_conv1d_fn + +RoPECache = Tuple[torch.Tensor, torch.Tensor] +KVCache = Tuple[torch.Tensor, torch.Tensor] +FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") + +def create_block( + d_model, + ssm_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, +): + if ssm_cfg is None: + ssm_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block = MBlock( + d_model, + mixer_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + +class GPT(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + factory_kwargs = {"device": "cuda", "dtype": torch.float32} + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) + if config.mamba: + if self.config.fused_add_norm: + if 
layer_norm_fn is None or rms_norm_fn is None: + raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList( + create_block( + config.n_embd, + ssm_cfg=None, + norm_epsilon=config.norm_eps, + rms_norm=config.rms_norm, + residual_in_fp32=config.residual_in_fp32, + fused_add_norm=config.fused_add_norm, + layer_idx=i, + **factory_kwargs, + ) + for i in range(config.n_layer)), + ln_f= (nn.LayerNorm if not config.rms_norm else RMSNorm)( + config.n_embd, eps=config.norm_eps, + **factory_kwargs, + ) + ) + ) + + else: + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[torch.Tensor] = None + self.kv_caches: List[KVCache] = [] + self.max_len = self.config.block_size + self.mamba_init = config.mamba or config.mamba_init + if self.mamba_init: + self.tie_weights() + + def _init_weights(self, module: nn.Module, n_layer) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`.""" + # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf + if isinstance(module, nn.Embedding): + if self.mamba_init: + torch.nn.init.normal_(module.weight, std=0.02) + else: + torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) + elif isinstance(module, nn.Linear): + if self.mamba_init: + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + else: + torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + # GPT-NeoX + for name, p in module.named_parameters(): + if name in ["out_proj.weight", 
"fc2.weight"] or (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, CausalSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3 + if self.mamba_init: + n_residuals_per_layer = 1 if self.config.mamba or not self.config.mlp else 2 + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + else: + nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer) + + def tie_weights(self): + self.lm_head.weight = self.transformer.wte.weight + + + def reset_cache(self) -> None: + self.max_len = self.config.block_size + self.kv_caches.clear() + if self.mask_cache is not None and self.mask_cache.device.type == "xla": + # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179 + self.rope_cache = None + self.mask_cache = None + + def forward( + self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.config.mamba: + hidden_states = self.transformer.wte(idx) + residual = None + for block in self.transformer.h: + hidden_states, residual = block( + hidden_states, residual, inference_params=None + ) + norm_f = self.transformer.ln_f + if not self.config.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = norm_f(residual.to(dtype= norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + hidden_states, + norm_f.weight, + norm_f.bias, + eps=norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.config.residual_in_fp32, + ) + return self.lm_head(hidden_states) + + B, T = idx.size() + use_kv_cache = input_pos is not None + + block_size = 
self.config.block_size + if max_seq_length is None: + max_seq_length = block_size + if use_kv_cache: # not relevant otherwise + assert ( + max_seq_length >= T + ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + #assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + #assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" + if not self.config.nope: + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache(idx, self.max_len) + elif T> self.max_len: + self.max_len = T + self.rope_cache = self.build_rope_cache(idx, self.max_len) + cos, sin = self.rope_cache + # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 + if use_kv_cache and self.mask_cache is None: + self.mask_cache = self.build_mask_cache(idx) + + if use_kv_cache: + if not self.config.nope: + cos = cos.index_select(0, input_pos) + sin = sin.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + else: + if not self.config.nope: + cos = cos[:T] + sin = sin[:T] + mask = None + if self.config.nope: + rope = None + else: + rope = (cos, sin) + # forward the model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if not use_kv_cache: + for block in self.transformer.h: + x, *_ = block(x, rope, max_seq_length) + else: + if self.config.nope: + self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, None ) + else: + self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2) + for i, block in enumerate(self.transformer.h): + x, self.kv_caches[i] = block(x, rope, max_seq_length, mask, input_pos, 
self.kv_caches[i]) + + x = self.transformer.ln_f(x) + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def build_rope_cache(self, idx: torch.Tensor, seq_len: int) -> RoPECache: + return build_rope_cache( + seq_len=seq_len, + n_elem=int(self.config.rotary_percentage * self.config.head_size), + dtype=torch.bfloat16, + device=idx.device, + condense_ratio=self.config.condense_ratio, + ) + + def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor: + ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) + return torch.tril(ones).unsqueeze(0).unsqueeze(0) + + def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]: + B = idx.size(0) + heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups + if rope_cache_length is not None: + k_cache_shape = ( + B, + max_seq_length, + heads, + rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size), + ) + else: + k_cache_shape = ( + B, + max_seq_length, + heads, + self.config.head_size, + ) + v_cache_shape = (B, max_seq_length, heads, self.config.head_size) + device = idx.device + return [ + (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device)) + for _ in range(self.config.n_layer) + ] + + +class Block(nn.Module): + def __init__(self, config: Config, layer_idx: int) -> None: + super().__init__() + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.use_gated_deltanet = layer_idx % config.gated_delta_per_layer == 0 if config.gated_delta_per_layer >0 else False + if self.use_gated_deltanet: + self.attn = GatedDeltaNet(hidden_size=config.n_embd) + else: + self.attn = CausalSelfAttention(config, n_embd= config.n_embd, layer_idx= layer_idx, ) + if not config.shared_attention_norm and config.mlp and 
not config.parallel_residual: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + if config.mlp: + self.mlp = config.mlp_class(config,) + self.config = config + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + + n_1 = self.norm_1(x) + + if self.use_gated_deltanet: + h, _ , new_kv_cache = self.attn(n_1, attention_mask=mask) + else: + h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache) + if self.config.parallel_residual: + assert self.config.shared_attention_norm + if self.config.mlp: + h = h + self.mlp(n_1) + x = x + h + else: + x = x + h + if self.config.mlp: + n_2 = self.norm_2(x) + h = self.mlp(n_2) + x = x + h + return x, new_kv_cache + + +class MBlock(nn.Module): + def __init__( + self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). 
+ """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(dim) + self.norm = norm_cls(dim) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: Config, layer_idx: int , n_embd: int, head_size = None) -> None: + super().__init__() + self.local = layer_idx % config.full_per_layer < config.full_per_layer-1 + if head_size is not None: + self.head_size = head_size + self.n_head = n_embd // head_size + self.n_query_groups = self.n_head + else: + self.head_size = config.head_size + self.n_head = config.n_head + self.n_query_groups = config.n_query_groups + shape = 
(self.n_head + 2 * self.n_query_groups) * self.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(n_embd, shape, bias=config.bias) + # output projection + self.proj = nn.Linear(n_embd, n_embd, bias=config.bias) + self.config = config + self.sc = config.sc_attn + if self.sc: + self.q_dim = self.n_head * self.head_size + self.kv_dim = self.n_query_groups * self.head_size + d_conv = 4 + self.q_conv1d = nn.Conv1d( + in_channels=self.q_dim, + out_channels=self.q_dim, + bias=False, + kernel_size=d_conv, + groups=self.q_dim, + padding=d_conv - 1, + ) + self.k_conv1d = nn.Conv1d( + in_channels=self.kv_dim, + out_channels=self.kv_dim, + bias=False, + kernel_size=d_conv, + groups=self.kv_dim, + padding=d_conv - 1, + ) + self.v_conv1d = nn.Conv1d( + in_channels= self.kv_dim, + out_channels= self.kv_dim, + bias=False, + kernel_size=d_conv, + groups= self.kv_dim, + padding=d_conv - 1, + ) + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + max_seq_length: int, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.n_head // self.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.n_query_groups, total_qkv, self.head_size) # (B, T, n_query_groups, total_qkv, hs) + # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) + q = q.reshape(B, T, -1 ) # (B, T, nh_q, hs) + k = k.reshape(B, T, -1 ) + v = v.reshape(B, T, -1 ) + if self.sc: + q = causal_conv1d_fn( + x = q.transpose(-1,-2), + 
+ weight=rearrange(self.q_conv1d.weight, "d 1 w -> d w"),
+ bias=self.q_conv1d.bias,
+ activation="silu",
+ ).transpose(-1,-2)
+ k = causal_conv1d_fn(
+ x = k.transpose(-1,-2),
+ weight=rearrange(self.k_conv1d.weight, "d 1 w -> d w"),
+ bias=self.k_conv1d.bias,
+ activation="silu",
+ ).transpose(-1,-2)
+ v = causal_conv1d_fn(
+ x = v.transpose(-1,-2),
+ weight=rearrange(self.v_conv1d.weight, "d 1 w -> d w"),
+ bias=self.v_conv1d.bias,
+ activation="silu",
+ ).transpose(-1,-2)
+
+ # re-split the flattened channel dim back into heads for rope/attention
+ q = q.reshape(B, T, -1, self.head_size) # (B, T, nh_q, hs)
+ k = k.reshape(B, T, -1, self.head_size)
+ v = v.reshape(B, T, -1, self.head_size)
+
+ if not self.config.nope:
+ cos, sin = rope
+ # applying RoPE in fp32 significantly stabilizes training
+ # the fused rope kernel expects (batch_size, seqlen, nheads, headdim)
+ q = apply_rotary_emb_func(q, cos, sin, False, True)
+ k = apply_rotary_emb_func(k, cos, sin, False, True)
+
+ if kv_cache is not None:
+ cache_k, cache_v = kv_cache
+ cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
+ # check if reached token limit: when the cache is full, shift the whole
+ # window left by one slot and write the new token into the last position
+ if input_pos[-1] >= max_seq_length:
+ input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
+ # shift 1 position to the left
+ cache_k = torch.roll(cache_k, -1, dims=1)
+ cache_v = torch.roll(cache_v, -1, dims=1)
+
+ # in-place write of the new K/V at input_pos (index_copy_ mutates the cache)
+ k = cache_k.index_copy_(1, input_pos, k)
+ v = cache_v.index_copy_(1, input_pos, v)
+ kv_cache = k, v
+
+ y = self.scaled_dot_product_attention(q, k, v, mask=mask)
+
+ y = y.reshape(B, T, -1) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.proj(y)
+ return y, kv_cache
+
+ def scaled_dot_product_attention(
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ):
+ scale = 1.0 / math.sqrt(self.head_size)
+
+ # FlashAttention-2 fast path: only valid for causal (mask-free) fp16/bf16 CUDA inputs
+ if (
+ FlashAttention2Available
+ and mask is None
+ and q.device.type == "cuda"
+ and q.dtype in (torch.float16, torch.bfloat16)
+ ):
+ from flash_attn import flash_attn_func
+ if self.local and
self.config.local_window > -1: + win_tuple = (self.config.local_window-1, 0) + else: + win_tuple = (-1,-1) + return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True, window_size=win_tuple) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if q.size() != k.size(): + k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) + v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None + ) + return y.transpose(1, 2) + + +class LLaMAMLP(nn.Module): + def __init__(self, config: Config,) -> None: + super().__init__() + self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=config.bias, _pack_weights=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.swiglu(x) + return x + +def build_rope_cache( + seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1 +) -> RoPECache: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
+ """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta) + + cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) + + # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding + if dtype == torch.bfloat16: + return cos.bfloat16(), sin.bfloat16() + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + return cos.half(), sin.half() + return cos, sin + + + \ No newline at end of file diff --git a/MaxCode/rag/sources/targeted/targeted_buffer_dtype_fidelity_jax.py b/MaxCode/rag/sources/targeted/targeted_buffer_dtype_fidelity_jax.py new file mode 100644 index 0000000..d3d85a1 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_buffer_dtype_fidelity_jax.py @@ -0,0 +1,57 @@ +""" +TARGETED RAG: Preserve Buffer Dtypes When Converting register_buffer to JAX +============================================================================= + +When converting PyTorch's register_buffer() to JAX, you MUST preserve the +exact dtype of the buffer tensor. torch.Tensor() creates float32 by default, +torch.LongTensor() creates int64, etc. + +WRONG -- Changing buffer dtype during conversion: +--------------------------------------------------- + # PyTorch source: + # self.register_buffer('version', torch.Tensor([2])) + # # torch.Tensor([2]) creates a float32 tensor containing [2.0] + + # WRONG! 
Changed dtype from float32 to int32 + self.sow('buffers', 'version', jnp.array([2], dtype=jnp.int32)) + +WHY THIS IS WRONG: +- torch.Tensor([2]) creates float32, NOT int32 +- Changing the dtype means the buffer has different bit representation +- Code that checks buffer dtype or uses it in float operations will break +- State dict comparison tools will flag the dtype mismatch + +CORRECT -- Match the exact PyTorch dtype: +------------------------------------------- + # PyTorch: torch.Tensor([2]) -> float32 + # CORRECT: preserve float32 dtype + self.sow('buffers', 'version', jnp.array([2.0], dtype=jnp.float32)) + +DTYPE REFERENCE for torch tensor constructors: +------------------------------------------------ + torch.Tensor([...]) -> float32 -> jnp.array([...], dtype=jnp.float32) + torch.FloatTensor([...]) -> float32 -> jnp.array([...], dtype=jnp.float32) + torch.DoubleTensor([...]) -> float64 -> jnp.array([...], dtype=jnp.float64) + torch.HalfTensor([...]) -> float16 -> jnp.array([...], dtype=jnp.float16) + torch.LongTensor([...]) -> int64 -> jnp.array([...], dtype=jnp.int64) + torch.IntTensor([...]) -> int32 -> jnp.array([...], dtype=jnp.int32) + torch.BoolTensor([...]) -> bool -> jnp.array([...], dtype=jnp.bool_) + torch.tensor([...]) -> inferred -> match the inferred dtype + torch.zeros(N) -> float32 -> jnp.zeros(N, dtype=jnp.float32) + torch.ones(N) -> float32 -> jnp.ones(N, dtype=jnp.float32) + +REGISTER_BUFFER conversion patterns: +-------------------------------------- + # PyTorch: + self.register_buffer('name', torch.Tensor([2])) + # JAX (using sow for mutable state): + self.sow('buffers', 'name', jnp.array([2.0], dtype=jnp.float32)) + + # PyTorch: + self.register_buffer('mask', torch.ones(seq_len, seq_len).triu(1).bool()) + # JAX (using variable for persistent state): + mask = jnp.triu(jnp.ones((seq_len, seq_len), dtype=jnp.float32), k=1).astype(jnp.bool_) + +RULE: Every buffer's dtype must match the PyTorch source exactly. 
+torch.Tensor() is float32, not int32. Always check the constructor. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_causal_conv1d_prefill_decode_jax.py b/MaxCode/rag/sources/targeted/targeted_causal_conv1d_prefill_decode_jax.py new file mode 100644 index 0000000..b129933 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_causal_conv1d_prefill_decode_jax.py @@ -0,0 +1,151 @@ +""" +TARGETED JAX PATTERN: Causal Conv1d — Separate Prefill and Decode Functions + +APPLICABILITY: This pattern applies ONLY to **causal** convolutions — those used +in autoregressive models, SSMs, and linear attention layers. Identify causal +conv1d by looking for: `conv_state` / rolling state management, output slicing +like `[:, :, :seq_len]` after the conv, or functions named `causal_conv1d`. + +DO NOT apply this pattern to standard (non-causal) conv1d layers found in +encoders, classifiers, or non-autoregressive models. For those, translate the +padding directly (e.g., PyTorch `padding="same"` -> JAX `padding="SAME"`, +PyTorch `padding=P` -> JAX `padding=((P, P),)`). + +CRITICAL: When this pattern DOES apply, implement causal conv1d as TWO separate +functions, not a single unified function with conditional branching. This gives +clearer semantics, better XLA optimization, and matches the PyTorch source's +separate causal_conv1d_fn and causal_conv1d_update functions. + +## WRONG approach (single unified function -- DO NOT DO THIS): + + # WRONG! Single function with conditional branching + def causal_conv1d(x, weight, bias=None, conv_state=None): + if conv_state is not None: + # decode path + conv_state = jnp.roll(conv_state, -1, axis=-1) + conv_state = conv_state.at[:, :, -1].set(x[:, :, 0]) + y = jnp.sum(conv_state * weight, axis=-1) + bias + return jax.nn.silu(y), conv_state + else: + # prefill path + x_padded = jnp.pad(x, ((0,0), (0,0), (weight.shape[-1]-1, 0))) + y = jax.lax.conv_general_dilated(...) 
+ return jax.nn.silu(y), None + +## CORRECT approach (two separate functions): + + import jax + import jax.numpy as jnp + + def causal_conv1d(x, weight, bias=None, activation='silu'): + ''' + Causal conv1d for PREFILL: processes full sequence. + + Args: + x: [batch, channels, seq_len] input (channels-first) + weight: [channels, 1, kernel_size] depthwise conv kernel + bias: [channels] optional bias + activation: activation function name ('silu' or None) + + Returns: + y: [batch, channels, seq_len] output + conv_state: [batch, channels, kernel_size-1] state for subsequent decode + ''' + batch, channels, seq_len = x.shape + kernel_size = weight.shape[-1] + + # Depthwise 1D causal convolution: left-only padding prevents + # future information leakage. Passing the padding tuple directly + # to conv_general_dilated is cleaner than a separate jnp.pad call. + y = jax.lax.conv_general_dilated( + lhs=x, # [B, C, T] + rhs=weight, # [C, 1, K] + window_strides=(1,), + padding=((kernel_size - 1, 0),), # left-only pad + feature_group_count=channels, + dimension_numbers=('NCH', 'IOH', 'NCH'), + ) + + if bias is not None: + y = y + bias[None, :, None] + + if activation == 'silu': + y = jax.nn.silu(y) + + # Save the last (kernel_size - 1) timesteps as conv state for decode + conv_state = x[:, :, -(kernel_size - 1):] # [B, C, K-1] + + return y, conv_state + + def causal_conv1d_update(x_t, conv_state, weight, bias=None, activation='silu'): + ''' + Causal conv1d for DECODE: processes single timestep. 
+
+ Args:
+ x_t: [batch, channels] or [batch, channels, 1] single token input
+ conv_state: [batch, channels, kernel_size-1] rolling state
+ weight: [channels, 1, kernel_size] depthwise conv kernel
+ bias: [channels] optional bias
+ activation: activation function name ('silu' or None)
+
+ Returns:
+ y_t: [batch, channels] output for this timestep
+ new_conv_state: [batch, channels, kernel_size-1] updated state
+ '''
+ if x_t.ndim == 3:
+ x_t = x_t.squeeze(-1) # [B, C]
+
+ # Roll state left: drop oldest, append new input
+ new_conv_state = jnp.concatenate(
+ [conv_state[:, :, 1:], x_t[:, :, None]], axis=-1
+ ) # [B, C, K-1]
+
+ # The conv window for this step is the previous state (K-1 values) plus x_t,
+ # giving the K values needed to match the kernel: weight is [C, 1, K]
+ full_window = jnp.concatenate(
+ [conv_state, x_t[:, :, None]], axis=-1
+ ) # [B, C, K]
+
+ # Depthwise multiply-sum (equivalent to conv with kernel_size window)
+ weight_squeezed = weight.squeeze(1) # [C, K]
+ y_t = jnp.sum(full_window * weight_squeezed[None, :, :], axis=-1) # [B, C]
+
+ if bias is not None:
+ y_t = y_t + bias
+
+ if activation == 'silu':
+ y_t = jax.nn.silu(y_t)
+
+ return y_t, new_conv_state
+
+## Usage in a GatedDeltaNet layer:
+
+ class GatedDeltaNetLayer(nn.Module):
+ @nn.compact
+ def __call__(self, x, cache=None, decode=False):
+ # ... projection ...
+
+ if not decode:
+ # Prefill: full sequence convolution
+ conv_out, conv_state = causal_conv1d(
+ q_conv_input, self.conv_weight, self.conv_bias
+ )
+ # ... chunk-parallel delta rule ...
+ else:
+ # Decode: single-step update
+ conv_out, new_conv_state = causal_conv1d_update(
+ q_conv_input, cache.conv_state, self.conv_weight, self.conv_bias
+ )
+ # ... recurrent delta rule ...
+
+## Why two functions:
+
+1. **XLA optimization**: Two simple functions compile to tighter kernels than one
+ function with dynamic branching.
+2. **Clarity**: Prefill processes [B, C, T], decode processes [B, C, 1].
Different + shapes, different algorithms, different code. +3. **Matches PyTorch**: The source has separate `causal_conv1d_fn` and + `causal_conv1d_update` functions. +4. **Cache management**: Prefill returns initial conv_state. Decode takes and + returns updated conv_state. Clean separation of concerns. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_config_dataclass_jax.py b/MaxCode/rag/sources/targeted/targeted_config_dataclass_jax.py new file mode 100644 index 0000000..36bce60 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_config_dataclass_jax.py @@ -0,0 +1,94 @@ +""" +TARGETED JAX PATTERN: Model Config as a Python Dataclass + +Every model conversion MUST include a Config dataclass at the top of the file. +This dataclass mirrors the PyTorch model's configuration class and provides +typed, defaulted fields for all hyperparameters. Without it, modules use +`config: Any` which loses type safety, IDE support, and default values. + +## WRONG: No Config dataclass, using Any + + class Qwen3NextAttention(nn.Module): + config: Any # No type info, no defaults, can't instantiate standalone + layer_idx: int + + # WHY THIS IS WRONG: + # - Cannot create a default config for testing: config = ??? 
+ # - No IDE autocomplete for config.hidden_size, config.num_attention_heads + # - No documentation of what fields the config requires + # - Cannot validate config values at construction time + +## CORRECT: Full Config dataclass with all fields + + import dataclasses + from typing import Any, Dict, List + + @dataclasses.dataclass + class Qwen3NextConfig: + # Vocabulary and embeddings + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 22016 + + # Attention + num_attention_heads: int = 32 + num_key_value_heads: int = 32 + head_dim: int = 128 + num_key_value_groups: int = 1 + + # Sequence + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + initializer_range: float = 0.02 + + # Layer configuration + num_hidden_layers: int = 32 + layer_types: List[str] = dataclasses.field( + default_factory=lambda: ["full_attention"] * 32 + ) + rope_parameters: Dict[str, Any] = dataclasses.field( + default_factory=lambda: { + "rope_type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 1.0, + } + ) + + # Gated DeltaNet (linear attention) + gated_delta_rule_chunk_size: int = 64 + v_head_dim: int = 128 + conv_size: int = 4 + num_v_heads: int = 16 + qk_nope_head_dim: int = 128 + + # MoE + num_experts: int = 64 + num_experts_per_tok: int = 4 + decoder_sparse_step: int = 1 + moe_intermediate_size: int = 1408 + shared_expert_intermediate_size: int = 5632 + norm_topk_prob: bool = False + router_aux_loss_coef: float = 0.001 + output_router_logits: bool = False + + # MLP-only layers + mlp_only_layers: List[int] = dataclasses.field(default_factory=list) + + # Misc + attention_bias: bool = False + attention_dropout: float = 0.0 + hidden_act: str = "silu" + tie_word_embeddings: bool = True + + # Then use it in modules: + class Qwen3NextAttention(nn.Module): + config: Qwen3NextConfig # Typed, not Any! 
+ layer_idx: int + +## KEY POINTS: +## - ALWAYS include a @dataclasses.dataclass Config class at the top of the file +## - Use dataclasses.field(default_factory=...) for mutable defaults (lists, dicts) +## - Mirror ALL fields from the PyTorch config class +## - Use the Config type (not Any) in module annotations +## - Default values should match the PyTorch model's defaults +""" diff --git a/MaxCode/rag/sources/targeted/targeted_cosine_similarity_batchwise_jax.py b/MaxCode/rag/sources/targeted/targeted_cosine_similarity_batchwise_jax.py new file mode 100644 index 0000000..9d66f8d --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_cosine_similarity_batchwise_jax.py @@ -0,0 +1,104 @@ +""" +TARGETED JAX PATTERN: Batch-wise Cosine Similarity + +When the PyTorch source uses F.cosine_similarity on 2D tensors, it computes +per-sample (row-wise) similarity. The JAX conversion MUST preserve this +batch-wise semantics. Do NOT use a library function that computes a single +global similarity scalar over the entire tensor. + +## WRONG: Using optax.cosine_similarity (global, not per-sample) + + # PyTorch source: + # corr = F.cosine_similarity( + # expert_outputs[i].flatten(1), + # expert_outputs[j].flatten(1) + # ).mean() + # + # F.cosine_similarity with 2D input [B, D] returns a per-sample + # similarity vector of shape [B], then .mean() averages over samples. + + # WRONG! optax.cosine_similarity computes a single scalar over the + # entire tensor, not per-sample similarity. + sim = optax.cosine_similarity( + outputs[i].reshape(outputs[i].shape[0], -1), + outputs[j].reshape(outputs[j].shape[0], -1) + ) + return jnp.mean(sim) + +## CORRECT: Per-sample cosine similarity with manual computation + + # CORRECT: Compute cosine similarity per sample (row), then average. 
+ def _cosine_similarity(a, b): + '''Per-sample cosine similarity for 2D arrays [B, D] -> [B].''' + a_norm = a / (jnp.linalg.norm(a, axis=-1, keepdims=True) + 1e-8) + b_norm = b / (jnp.linalg.norm(b, axis=-1, keepdims=True) + 1e-8) + return jnp.sum(a_norm * b_norm, axis=-1) + + sim = _cosine_similarity( + outputs[i].reshape(outputs[i].shape[0], -1), + outputs[j].reshape(outputs[j].shape[0], -1) + ) + return jnp.mean(sim) + +## CORRECT (alternative): Using jax.vmap over single-vector cosine similarity + + def _single_cosine_sim(a, b): + '''Cosine similarity for 1D vectors.''' + return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8) + + batch_cosine_sim = jax.vmap(_single_cosine_sim) + sim = batch_cosine_sim( + outputs[i].reshape(outputs[i].shape[0], -1), + outputs[j].reshape(outputs[j].shape[0], -1) + ) + return jnp.mean(sim) + +## WRONG: Using einsum that sums over both batch AND feature dimensions + + # If you stack expert outputs into shape [num_experts, batch_size, features] + # and normalize, you might be tempted to use a single einsum: + + outputs_stacked = jnp.stack([out.reshape(out.shape[0], -1) for out in expert_outputs]) + norms = jnp.linalg.norm(outputs_stacked, axis=2, keepdims=True) + outputs_norm = outputs_stacked / (norms + 1e-8) + + # WRONG! This sums over BOTH batch (k) and feature (d) dimensions, + # producing sum_k(sum_d(a[i,k,d] * b[j,k,d])) -- a single scalar per + # expert pair that conflates batch and feature reductions. + correlations = jnp.einsum('ikd,jkd->ij', outputs_norm, outputs_norm) + + # The result is NOT the mean of per-sample cosine similarities. + # It equals batch_size * mean(per_sample_cos_sim) only when all samples + # have equal norms, and even then the scaling is wrong. 
+ +## CORRECT: Using einsum with separate batch and feature reductions + + outputs_stacked = jnp.stack([out.reshape(out.shape[0], -1) for out in expert_outputs]) + norms = jnp.linalg.norm(outputs_stacked, axis=2, keepdims=True) + outputs_norm = outputs_stacked / (norms + 1e-8) + + # CORRECT: First compute per-sample dot products with einsum over + # features only (d), keeping the batch dimension (b): + # per_sample_sim[i, j, b] = sum_d(a[i,b,d] * b[j,b,d]) + per_sample_sim = jnp.einsum('ibd,jbd->ijb', outputs_norm, outputs_norm) + + # Then average over the batch dimension to get mean cosine similarity: + correlations = per_sample_sim.mean(axis=2) + + # This matches F.cosine_similarity(...).mean() exactly: + # for each expert pair (i,j), compute per-sample cosine sim, then average. + +## WHY this matters: + +1. **Semantic difference**: F.cosine_similarity(a, b) with a=[B,D], b=[B,D] + returns shape [B] -- one similarity per sample. A global cosine similarity + returns a single scalar, which conflates all samples into one value. +2. **Numerical difference**: mean(per_sample_cosine_sim) != global_cosine_sim. + The global version effectively computes similarity between the "average + direction" of all samples, losing per-sample variation. +3. **Metric correctness**: expert_correlation is a diagnostic metric. Wrong + computation means misleading expert diversity analysis. +4. **General rule**: When the PyTorch source applies a pairwise operation + along dim=0 (batch dimension) and then reduces, preserve the per-sample + computation in JAX. Do not replace it with a global reduction. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_dead_code_helper_functions_jax.py b/MaxCode/rag/sources/targeted/targeted_dead_code_helper_functions_jax.py new file mode 100644 index 0000000..131ddfe --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_dead_code_helper_functions_jax.py @@ -0,0 +1,61 @@ +""" +TARGETED RAG: Preserve Helper Function Call Sites — No Dead Code +================================================================= + +When converting PyTorch to JAX, if the source defines a helper function and +calls it from another function, the JAX version MUST also call the helper. +Do not inline the helper's logic and leave the helper as dead code. + +WRONG -- Inlining logic and leaving helper as dead code: +---------------------------------------------------------- + # PyTorch source: + # def fill_with_neg_inf(t): + # return t.float().fill_(float('-inf')).type_as(t) + # + # def buffered_future_mask(tensor, tensor2=None): + # dim1 = dim2 = tensor.size(0) + # if tensor2 is not None: + # dim2 = tensor2.size(0) + # future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), ...) + # return future_mask[:dim1, :dim2] + + # WRONG! 
fill_with_neg_inf is defined but never called -- dead code + def fill_with_neg_inf(t): + return jnp.full_like(t, float('-inf'), dtype=t.dtype) + + def buffered_future_mask(tensor, tensor2=None): + dim1 = tensor.shape[0] + dim2 = dim1 if tensor2 is None else tensor2.shape[0] + # WRONG: inlined the logic instead of calling fill_with_neg_inf + inf_matrix = jnp.full((dim1, dim2), float('-inf'), dtype=jnp.float32) + future_mask = jnp.triu(inf_matrix, 1 + abs(dim2 - dim1)) + return future_mask[:dim1, :dim2] + +WHY THIS IS WRONG: +- fill_with_neg_inf preserves dtype via .type_as(t) -- important for FP16/BF16 +- The inlined version hardcodes jnp.float32, losing mixed-precision support +- Dead code confuses maintenance -- readers expect the helper to be used +- The source author created the helper for a reason (dtype safety) + +CORRECT -- Call the helper function just as the source does: +------------------------------------------------------------- + def fill_with_neg_inf(t): + \"\"\"FP16-compatible function that fills a tensor with -inf.\"\"\" + return jnp.full_like(t, float('-inf')) + + def buffered_future_mask(tensor, tensor2=None): + dim1 = tensor.shape[0] + dim2 = dim1 if tensor2 is None else tensor2.shape[0] + # CORRECT: calls fill_with_neg_inf just like the source + future_mask = jnp.triu( + fill_with_neg_inf(jnp.ones((dim1, dim2))), + 1 + abs(dim2 - dim1) + ) + return future_mask[:dim1, :dim2] + +GENERAL RULE: +- If the source defines function A and calls it from function B, + the JAX version must also call A from B. +- Never inline A's logic into B and leave A as dead code. +- This preserves dtype handling, code structure, and maintainability. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_detach_stop_gradient_jax.py b/MaxCode/rag/sources/targeted/targeted_detach_stop_gradient_jax.py new file mode 100644 index 0000000..ae2b1ae --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_detach_stop_gradient_jax.py @@ -0,0 +1,87 @@ +""" +TARGETED RAG: Preserve .detach() as jax.lax.stop_gradient() in JAX/Flax +========================================================================= + +When converting PyTorch code that calls .detach() on a tensor, you MUST +use jax.lax.stop_gradient() in the JAX version. Omitting this changes +the gradient flow and training dynamics. + +This is especially common for: +- Positional embeddings (sinusoidal or learned) that should not receive gradients +- Target values in loss computation +- Codebook entries in VQ-VAE +- Teacher outputs in knowledge distillation + +WRONG -- Omitting stop_gradient when source uses .detach(): +------------------------------------------------------------ + # PyTorch source: + # def forward(self, input): + # ... + # return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + # WRONG! Missing stop_gradient -- gradients will flow through positional embeddings + def __call__(self, input): + ... + return weights[positions] + +WHY THIS IS WRONG: +- .detach() in PyTorch severs the tensor from the computation graph +- Without it, gradients propagate back through the embedding lookup +- For sinusoidal positional embeddings this is especially wrong because: + 1. The embeddings are deterministic functions of position, not learnable + 2. Gradient flow through them wastes compute and can cause instability + 3. 
The PyTorch source author explicitly chose to block gradients here +- Omitting .detach() silently changes training behavior with no error or warning + +CORRECT -- Use jax.lax.stop_gradient() wherever source uses .detach(): +----------------------------------------------------------------------- + # CORRECT: stop_gradient preserves the .detach() semantics + def __call__(self, input): + ... + return jax.lax.stop_gradient(weights[positions]) + +PATTERN MATCHING: +----------------- +When you see ANY of these patterns in PyTorch, add jax.lax.stop_gradient(): + + PyTorch pattern 1: `tensor.detach()` + JAX equivalent: `jax.lax.stop_gradient(tensor)` + + PyTorch pattern 2: `tensor.detach().clone()` + JAX equivalent: `jax.lax.stop_gradient(tensor).copy()` + + PyTorch pattern 3: `with torch.no_grad(): result = ...` + JAX equivalent: `result = jax.lax.stop_gradient(...)` + + PyTorch pattern 4: `x.data` (accessing raw data, no grad tracking) + JAX equivalent: `jax.lax.stop_gradient(x)` + +FULL EXAMPLE -- Sinusoidal Positional Embedding: +------------------------------------------------- + # PyTorch source: + class SinusoidalPositionalEmbedding(nn.Module): + def forward(self, input): + bsz, seq_len = input.size() + max_pos = self.padding_idx + 1 + seq_len + weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx) + positions = make_positions(input, self.padding_idx, self.left_pad) + return weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + # CORRECT JAX conversion: + class SinusoidalPositionalEmbedding(nn.Module): + embedding_dim: int + padding_idx: int = 0 + left_pad: int = 0 + + @nn.compact + def __call__(self, input): + bsz, seq_len = input.shape + max_pos = self.padding_idx + 1 + seq_len + weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx) + positions = make_positions(input, self.padding_idx, self.left_pad) + # CRITICAL: preserve .detach() as stop_gradient + return 
jax.lax.stop_gradient(weights[positions.reshape(-1)].reshape(bsz, seq_len, -1)) + +RULE: Every .detach() in the source MUST become a jax.lax.stop_gradient() in JAX. +This is not optional -- it changes the mathematical gradient computation. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_dtype_mixed_precision_jax.py b/MaxCode/rag/sources/targeted/targeted_dtype_mixed_precision_jax.py new file mode 100644 index 0000000..614ce65 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_dtype_mixed_precision_jax.py @@ -0,0 +1,101 @@ +""" +TARGETED JAX PATTERN: dtype and Mixed Precision on TPU/GPU + +When converting PyTorch models to JAX, handle dtype carefully. TPU bfloat16 has +different precision characteristics than GPU float16, and certain operations +MUST be done in float32 for numerical stability. + +## Operations that MUST use float32: + +| Operation | Why float32 is needed | +|------------------------|----------------------------------------------------| +| Softmax | exp() overflows in bf16; sum of probs loses precision | +| Variance / RMS | Squaring amplifies error; mean of squares needs range | +| Layer/RMS normalization| Uses variance internally | +| Loss computation | Cross-entropy log() needs precision | +| Cumulative sum/prod | Accumulation amplifies rounding errors | +| Router logits (MoE) | Small differences in routing matter | + +## Pattern: Upcast before, cast back after + + import jax.numpy as jnp + + def stable_softmax(x, axis=-1): + '''Softmax with float32 upcast for numerical stability.''' + x_f32 = x.astype(jnp.float32) + result = jax.nn.softmax(x_f32, axis=axis) + return result.astype(x.dtype) + + def rms_norm(x, weight, eps=1e-6): + '''RMS normalization with float32 upcast.''' + orig_dtype = x.dtype + x = x.astype(jnp.float32) + rms = jax.lax.rsqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + eps) + return (x * rms).astype(orig_dtype) * weight + +## Flax param_dtype vs compute dtype: + + import flax.linen as nn + + class 
MyDense(nn.Module): + features: int + param_dtype: jnp.dtype = jnp.bfloat16 # Store weights in bf16 + compute_dtype: jnp.dtype = jnp.bfloat16 # Compute in bf16 + + @nn.compact + def __call__(self, x): + kernel = self.param( + 'kernel', + nn.initializers.normal(stddev=0.02), + (x.shape[-1], self.features), + self.param_dtype, # Weight stored in this dtype + ) + # Cast to compute dtype for matmul + x = x.astype(self.compute_dtype) + kernel = kernel.astype(self.compute_dtype) + return x @ kernel + +## TPU bfloat16 gotchas: + +1. **No float16 on TPU**: TPU natively supports bf16 and f32. Using float16 + requires emulation and is slower. Always use bfloat16 on TPU. + +2. **bf16 range vs precision**: bf16 has same exponent range as f32 (no overflow + for typical values) but only 7 bits of mantissa (vs 23 for f32). This means + additions of values with different magnitudes lose precision. + +3. **Matmul accumulation**: `jnp.matmul` on TPU accumulates in float32 internally + even with bf16 inputs, so matmuls are generally safe. But element-wise ops + (add, multiply, square) do NOT auto-upcast. + +4. **jnp.where dtype**: `jnp.where(cond, 0.0, -1e9)` -- the -1e9 must fit in + the output dtype. For bf16, -1e9 is representable. For fp16, use + `jnp.finfo(dtype).min` instead of a literal. 
+ +## Full pattern in a transformer layer: + + class TransformerLayer(nn.Module): + config: ModelConfig + + @nn.compact + def __call__(self, x): + dtype = self.config.compute_dtype # e.g., jnp.bfloat16 + + # RMSNorm: upcast to f32 internally + normed = rms_norm(x, self.param('norm', nn.initializers.ones_init(), + (self.config.hidden_size,))) + + # Attention: matmuls are safe in bf16 + q = nn.Dense(self.config.qk_dim, dtype=dtype)(normed) + k = nn.Dense(self.config.qk_dim, dtype=dtype)(normed) + v = nn.Dense(self.config.v_dim, dtype=dtype)(normed) + + # Attention scores: safe in bf16 (matmul accumulates in f32) + attn = q @ k.swapaxes(-2, -1) / jnp.sqrt(self.config.head_dim) + + # Softmax: MUST upcast to f32 + attn = stable_softmax(attn) + + out = attn @ v + return x + nn.Dense(self.config.hidden_size, dtype=dtype)(out) +""" diff --git a/MaxCode/rag/sources/targeted/targeted_encoder_decoder_cache_jax.py b/MaxCode/rag/sources/targeted/targeted_encoder_decoder_cache_jax.py new file mode 100644 index 0000000..c383f34 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_encoder_decoder_cache_jax.py @@ -0,0 +1,140 @@ +""" +TARGETED JAX PATTERN: Encoder-Decoder KV Cache with NamedTuple + +When converting encoder-decoder models (e.g., Whisper, T5, BART), the decoder +has TWO types of KV cache: + 1. Self-attention cache: grows with each decode step (like decoder-only models) + 2. Cross-attention cache: computed ONCE from encoder output, reused every step + +For migration output, use pure functional NamedTuple caches passed as arguments +and returned as outputs. Flax mutable variables (`self.variable('cache', ...)`) +are Flax's built-in approach but are not recommended for migration output because +they couple the code to Flax's variable management and complicate beam search. +Do NOT use init-flag protocols. 
+ +## WRONG approach (Flax mutable variables with init flag -- DO NOT DO THIS): + + class MultiHeadAttention(nn.Module): + @nn.compact + def __call__(self, x, xa=None, kv_cache=None): + if xa is not None and kv_cache is not None: + cross_k = self.variable('cache', 'cross_k', ...) + cross_v = self.variable('cache', 'cross_v', ...) + if kv_cache.get('init', False): # <-- BAD: init flag protocol + k = key_proj(xa) + cross_k.value = k # <-- BAD: mutable state + else: + k = cross_k.value # <-- BAD: reading mutable state + # This couples caching logic to the attention module, breaks pure + # functional JAX semantics, and makes beam search difficult. + +## WRONG approach 2 (config dict with no actual caches -- DO NOT DO THIS): + + def install_kv_cache_hooks(self, max_length=448): + cache_config = {'init': True, 'cache_index': 0, 'max_length': max_length} + return cache_config, [] + # This returns flags but no pre-allocated cache tensors! + # PyTorch hooks have no JAX equivalent -- replace with init function. 
+ +## CORRECT approach (NamedTuple caches, passed as args, returned as outputs): + + import jax + import jax.numpy as jnp + from typing import NamedTuple, Optional, Tuple + + class KVCache(NamedTuple): + '''Pre-allocated KV cache buffer.''' + key: jnp.ndarray # [B, max_len, D] + value: jnp.ndarray # [B, max_len, D] + index: jnp.ndarray # scalar: next write position + + class MultiHeadAttention(nn.Module): + n_state: int + n_head: int + + @nn.compact + def __call__(self, x, xa=None, mask=None, kv_cache=None): + q = nn.Dense(self.n_state, name='query')(x) + source = x if xa is None else xa + + if kv_cache is not None and xa is not None: + # Cross-attention: K/V already cached from encoder output + k = kv_cache.key + v = kv_cache.value + new_cache = kv_cache # pass through unchanged + elif kv_cache is not None: + # Self-attention: update cache with new K/V + k_new = nn.Dense(self.n_state, use_bias=False, name='key')(x) + v_new = nn.Dense(self.n_state, name='value')(x) + k = jax.lax.dynamic_update_slice(kv_cache.key, k_new, (0, kv_cache.index, 0)) + v = jax.lax.dynamic_update_slice(kv_cache.value, v_new, (0, kv_cache.index, 0)) + new_cache = KVCache(key=k, value=v, index=kv_cache.index + k_new.shape[1]) + else: + # No cache: compute K/V from source + k = nn.Dense(self.n_state, use_bias=False, name='key')(source) + v = nn.Dense(self.n_state, name='value')(source) + new_cache = None + + out, qk = self._qkv_attention(q, k, v, mask) + return nn.Dense(self.n_state, name='out')(out), qk, new_cache + + # ResidualAttentionBlock accepts SEPARATE self and cross caches: + class ResidualAttentionBlock(nn.Module): + n_state: int + n_head: int + cross_attention: bool = False + + @nn.compact + def __call__(self, x, xa=None, mask=None, self_attn_cache=None, cross_attn_cache=None): + out, _, new_self_cache = MultiHeadAttention( + self.n_state, self.n_head, name='attn' + )(nn.LayerNorm(name='attn_ln')(x), mask=mask, kv_cache=self_attn_cache) + x = x + out + + new_cross_cache = 
cross_attn_cache + if self.cross_attention: + cross_out, _, new_cross_cache = MultiHeadAttention( + self.n_state, self.n_head, name='cross_attn' + )(nn.LayerNorm(name='cross_attn_ln')(x), xa=xa, kv_cache=cross_attn_cache) + x = x + cross_out + + # MLP + h = nn.Dense(self.n_state * 4)(nn.LayerNorm(name='mlp_ln')(x)) + h = jax.nn.gelu(h) + h = nn.Dense(self.n_state)(h) + x = x + h + + return x, new_self_cache, new_cross_cache + + # Pre-allocate all caches for decoder layers: + def init_kv_caches(dims, batch_size, dtype=jnp.float32): + '''Create pre-allocated KV caches for all decoder layers.''' + self_caches = tuple( + KVCache( + key=jnp.zeros((batch_size, dims.n_text_ctx, dims.n_text_state), dtype=dtype), + value=jnp.zeros((batch_size, dims.n_text_ctx, dims.n_text_state), dtype=dtype), + index=jnp.array(0, dtype=jnp.int32), + ) + for _ in range(dims.n_text_layer) + ) + # Cross-attention caches: populated once from encoder output + cross_caches = tuple( + KVCache( + key=jnp.zeros((batch_size, dims.n_audio_ctx, dims.n_text_state), dtype=dtype), + value=jnp.zeros((batch_size, dims.n_audio_ctx, dims.n_text_state), dtype=dtype), + index=jnp.array(0, dtype=jnp.int32), + ) + for _ in range(dims.n_text_layer) + ) + return self_caches, cross_caches + +## WHY this pattern is correct: + +1. **Pure functional**: Caches are inputs AND outputs. No hidden mutable state. +2. **Cross-attention reuse**: Encoder K/V computed once, stored in cross_attn_cache, + passed through unchanged on every decode step. No init flag needed. +3. **JIT-safe**: All shapes static. dynamic_update_slice is traced, not Python mutation. +4. **Beam search**: Easy to duplicate/reorder NamedTuple caches by batch indexing. +5. **Replaces install_kv_cache_hooks**: PyTorch uses hooks to intercept projections. + JAX replaces this with init_kv_caches() that pre-allocates all layer caches. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_flax_checkpoint_api_jax.py b/MaxCode/rag/sources/targeted/targeted_flax_checkpoint_api_jax.py new file mode 100644 index 0000000..5c7d15a --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_flax_checkpoint_api_jax.py @@ -0,0 +1,70 @@ +""" +TARGETED JAX PATTERN: Flax Checkpoint and TensorBoard APIs + +CRITICAL: Several Flax APIs are deprecated or removed in newer versions. +When converting training utilities, use current stable APIs. + +## WRONG: Using deprecated flax.training.checkpoints + + # WRONG! This API is deprecated and may be removed. + from flax.training.checkpoints import save_checkpoint, restore_checkpoint + + save_checkpoint(ckpt_dir, target=state, step=epoch) + state = restore_checkpoint(ckpt_dir, target=state) + +## CORRECT: Use flax.serialization for simple cases + + import flax.serialization + + # Save + state_bytes = flax.serialization.to_bytes(state) + with open(path, 'wb') as f: + f.write(state_bytes) + + # Load + with open(path, 'rb') as f: + state_bytes = f.read() + state = flax.serialization.from_bytes(state, state_bytes) + +## CORRECT: Use orbax for production checkpointing + + import orbax.checkpoint as ocp + + # Save + checkpointer = ocp.StandardCheckpointer() + checkpointer.save(path, state) + + # Load + state = checkpointer.restore(path, target=state) + +## WRONG: Using flax.metrics.tensorboard + + # WRONG! This module may not exist in newer Flax versions. 
+ from flax.metrics.tensorboard import SummaryWriter + writer = SummaryWriter(log_dir) + +## CORRECT: Use tensorboardX or standard TensorBoard + + # Option 1: tensorboardX (most common in JAX ecosystem) + from tensorboardX import SummaryWriter + writer = SummaryWriter(log_dir) + writer.add_scalar('train/loss', loss_val, step) + + # Option 2: Use the source's TensorBoard pattern faithfully + # If the PyTorch source uses torch.utils.tensorboard.SummaryWriter, + # convert to tensorboardX which has the same API: + from tensorboardX import SummaryWriter + writer = SummaryWriter(tensorboard_dir) + for name, value in epoch_metrics.items(): + writer.add_scalar(f'train/{name}', float(value), epoch) + writer.close() + +## Why this matters: + +1. **Import errors**: Deprecated APIs cause ImportError at runtime, making the + converted code non-functional without manual fixes. +2. **API stability**: orbax and tensorboardX are the recommended replacements + and are actively maintained. +3. **Source fidelity**: If the source has TensorBoard logging, the conversion + should preserve it using the correct JAX-ecosystem equivalent. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_flax_train_eval_mode_jax.py b/MaxCode/rag/sources/targeted/targeted_flax_train_eval_mode_jax.py new file mode 100644 index 0000000..cb94ee8 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_flax_train_eval_mode_jax.py @@ -0,0 +1,82 @@ +""" +TARGETED JAX PATTERN: Train/Eval Mode in Flax — Use deterministic Flag + +CRITICAL: Flax nn.Module objects do NOT have a .train attribute like PyTorch. +Setting model.train = True or model.train = False does nothing in Flax and +will silently produce incorrect behavior. Flax controls train vs eval mode +through a `deterministic` argument passed to __call__. + +## WRONG: Setting .train attribute on Flax module (PyTorch habit) + + # WRONG! Flax modules have no .train attribute. This sets a random + # Python attribute that NO Flax module reads. 
Dropout, noise, and + # other stochastic layers will NOT change behavior. + model = MixtureOfExperts(config) + + # Training loop + model.train = True # <-- DOES NOTHING! Silently ignored. + output = model(x, deterministic=False) + + # Eval loop + model.train = False # <-- DOES NOTHING! Silently ignored. + output = model(x, deterministic=True) + +## WRONG: Using PyTorch's model.eval() / model.train() pattern + + # WRONG! Flax modules do not have .eval() or .train() methods. + # This will raise an AttributeError. + model.eval() + model.train() + +## CORRECT: Use the deterministic flag on __call__ + + # In Flax, train/eval mode is controlled by passing `deterministic` + # to the module's __call__ method. Each submodule (Dropout, etc.) + # checks this flag to decide whether to apply stochastic behavior. + + model = MixtureOfExperts(config) + + # Training: deterministic=False enables dropout, noise, etc. + output = model.apply( + {'params': params}, + x, + deterministic=False, + rngs={'dropout': dropout_rng} + ) + + # Evaluation: deterministic=True disables all stochastic behavior. + output = model.apply( + {'params': params}, + x, + deterministic=True + # No rngs needed in eval mode + ) + +## CORRECT: Training loop pattern + + # The training loop should NOT set any attribute on the model. + # Instead, pass deterministic=False to train_step and deterministic=True + # to eval_step via the model.apply call. + + for epoch in range(num_epochs): + # Training: pass deterministic=False + for batch in train_loader: + state, metrics = train_step(state, batch) # uses deterministic=False internally + + # Evaluation: pass deterministic=True + for batch in val_loader: + metrics = eval_step(state, batch) # uses deterministic=True internally + +## Why this matters: + +1. **Silent failure**: Setting model.train = True/False creates a new Python attribute + but no Flax code reads it. The model behaves identically in both cases. +2. 
**Dropout stays on/off**: Without the deterministic flag, nn.Dropout either always + drops (if deterministic defaults to False) or never drops. This corrupts training + dynamics or evaluation metrics. +3. **Router noise**: Routers that add noise during training (for load balancing) use + the deterministic flag to decide whether to inject noise. Without it, noise is + either always on (noisy eval) or always off (no exploration during training). +4. **Functional paradigm**: Flax follows JAX's functional style — behavior is controlled + by function arguments, not by mutable object state. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_float32_softmax_upcast_jax.py b/MaxCode/rag/sources/targeted/targeted_float32_softmax_upcast_jax.py new file mode 100644 index 0000000..ee31501 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_float32_softmax_upcast_jax.py @@ -0,0 +1,67 @@ +""" +TARGETED RAG: Float32 Softmax Upcast in JAX/Flax +================================================== + +When converting attention code that uses `.float()` before softmax in PyTorch, +you MUST preserve the float32 upcast in JAX. This is critical for numerical +stability when the model runs in bfloat16 or float16. 
+ +WRONG -- No upcast before softmax: +------------------------------------ + attn_weights = jnp.matmul(q, k.transpose(0, 2, 1)) * scale + if attn_mask is not None: + attn_weights = attn_weights + attn_mask + attn_weights = jax.nn.softmax(attn_weights, axis=-1) # WRONG: no upcast + attn_probs = nn.Dropout(rate=self.attn_dropout)( + attn_weights, deterministic=self.deterministic) + +WHY THIS IS WRONG: +- In bfloat16, the exp() inside softmax can overflow or underflow +- PyTorch code explicitly does `attn_weights_float = attn_weights.float()` + before softmax, then casts back with `.type_as(attn_weights)` +- Without the upcast, attention distributions become inaccurate, especially + for long sequences where values can be very negative +- This causes subtle numerical errors that compound through layers + +CORRECT -- Upcast to float32 before softmax, cast back after: +-------------------------------------------------------------- + attn_weights = jnp.matmul(q, k.transpose(0, 2, 1)) * scale + if attn_mask is not None: + attn_weights = attn_weights + attn_mask + # CORRECT: upcast to float32 before softmax for numerical stability + attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1) + attn_weights = attn_weights.astype(q.dtype) # cast back to compute dtype + attn_probs = nn.Dropout(rate=self.attn_dropout)( + attn_weights, deterministic=self.deterministic) + +PATTERN MATCHING: +----------------- +When you see ANY of these patterns in PyTorch source code, add the float32 upcast: + + PyTorch pattern 1: `attn_weights_float = attn_weights.float()` + PyTorch pattern 2: `attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)` + PyTorch pattern 3: `attn_weights.float().softmax(dim=-1).type_as(attn_weights)` + +JAX equivalent for ALL of these: + ``` + attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1) + attn_weights = attn_weights.astype(q.dtype) + ``` + +OTHER OPERATIONS THAT NEED FLOAT32 UPCAST: 
+------------------------------------------- +The same principle applies to: + +1. Layer normalization variance: + WRONG: variance = jnp.mean(x ** 2, axis=-1, keepdims=True) + CORRECT: variance = jnp.mean(x.astype(jnp.float32) ** 2, axis=-1, keepdims=True) + +2. Loss functions with log: + WRONG: loss = -jnp.log(probs) + CORRECT: loss = -jnp.log(probs.astype(jnp.float32)) + +3. Any operation with exp(), log(), or division where precision matters. + +RULE: When in doubt, upcast to float32. The cost is negligible (XLA fuses the +cast with the computation) but the benefit is correct numerics. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_fused_qkv_projection_jax.py b/MaxCode/rag/sources/targeted/targeted_fused_qkv_projection_jax.py new file mode 100644 index 0000000..b9768bb --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_fused_qkv_projection_jax.py @@ -0,0 +1,163 @@ +""" +TARGETED RAG: Fused QKV Projection in JAX/Flax +================================================ + +When converting fairseq-style MultiheadAttention that uses a single +`in_proj_weight` of shape [3*embed_dim, embed_dim] with sliced projection +methods (in_proj_qkv, in_proj_q, in_proj_kv), preserve this fused design +in JAX. Do NOT split into 3 separate nn.Dense layers. + +WRONG -- 3 separate Dense layers: +----------------------------------- +class MultiheadAttention(nn.Module): + embed_dim: int + num_heads: int + + @nn.compact + def __call__(self, query, key, value): + q = nn.Dense(self.embed_dim, name='q_proj')(query) # WRONG + k = nn.Dense(self.embed_dim, name='k_proj')(key) # WRONG + v = nn.Dense(self.embed_dim, name='v_proj')(value) # WRONG + ... 
+ +WHY THIS IS WRONG: +- Breaks weight compatibility with PyTorch checkpoints that store a single + in_proj_weight tensor of shape [3*D, D] +- Loses the qkv_same_embed_dim / kv_same_embed_dim optimization paths + where Q,K,V are projected from the same input in a single matmul +- Cannot faithfully represent in_proj_q (query-only), in_proj_kv + (key+value only) projection methods used for cross-attention + +CORRECT -- Single fused [3*D, D] parameter with sliced projection: +------------------------------------------------------------------- +import jax +import jax.numpy as jnp +import flax.linen as nn + +class MultiheadAttention(nn.Module): + embed_dim: int + num_heads: int + kdim: int = None + vdim: int = None + add_bias_kv: bool = False + add_zero_attn: bool = False + attn_dropout: float = 0.0 + deterministic: bool = False + + def _get_dims(self): + kdim = self.kdim if self.kdim is not None else self.embed_dim + vdim = self.vdim if self.vdim is not None else self.embed_dim + head_dim = self.embed_dim // self.num_heads + qkv_same = (kdim == self.embed_dim and vdim == self.embed_dim) + kv_same = (kdim == vdim) + return kdim, vdim, head_dim, qkv_same, kv_same + + @nn.compact + def __call__(self, query, key, value, attn_mask=None, need_weights=True): + kdim, vdim, head_dim, qkv_same, kv_same = self._get_dims() + bsz = query.shape[1] # (T, B, D) time-first layout + + # === Fused QKV weight: single [3*D, D] parameter === + if qkv_same: + in_proj_weight = self.param( + 'in_proj_weight', + nn.initializers.xavier_uniform(), + (3 * self.embed_dim, self.embed_dim), + ) + in_proj_bias = self.param( + 'in_proj_bias', + nn.initializers.zeros_init(), + (3 * self.embed_dim,), + ) + else: + # Separate weights when dims differ (cross-attention) + q_weight = self.param('q_proj_weight', nn.initializers.xavier_uniform(), + (self.embed_dim, self.embed_dim)) + k_weight = self.param('k_proj_weight', nn.initializers.xavier_uniform(), + (self.embed_dim, kdim)) + v_weight = 
self.param('v_proj_weight', nn.initializers.xavier_uniform(),
+                                  (self.embed_dim, vdim))
+            q_bias = self.param('q_proj_bias', nn.initializers.zeros_init(), (self.embed_dim,))
+            k_bias = self.param('k_proj_bias', nn.initializers.zeros_init(), (self.embed_dim,))
+            v_bias = self.param('v_proj_bias', nn.initializers.zeros_init(), (self.embed_dim,))
+
+        out_proj = nn.Dense(self.embed_dim, name='out_proj',
+                            kernel_init=nn.initializers.xavier_uniform())
+
+        # === Sliced projection methods (matching fairseq) ===
+        def _in_proj(x, weight, bias, start=0, end=None):
+            \"\"\"Project x using a slice of the fused weight and bias.\"\"\"
+            w = weight[start:end]
+            b = bias[start:end] if bias is not None else None
+            out = x @ w.T
+            if b is not None:
+                out = out + b
+            return out
+
+        def in_proj_qkv(x):
+            \"\"\"Project Q, K, V from the same input (self-attention).\"\"\"
+            D = self.embed_dim
+            return (_in_proj(x, in_proj_weight, in_proj_bias, 0, D),
+                    _in_proj(x, in_proj_weight, in_proj_bias, D, 2*D),
+                    _in_proj(x, in_proj_weight, in_proj_bias, 2*D, 3*D))
+
+        def in_proj_q(x):
+            \"\"\"Project Q only (used in cross-attention).\"\"\"
+            if qkv_same:
+                return _in_proj(x, in_proj_weight, in_proj_bias, 0, self.embed_dim)
+            else:
+                return x @ q_weight.T + q_bias
+
+        def in_proj_kv(x):
+            \"\"\"Project K and V together (used in cross-attention).\"\"\"
+            D = self.embed_dim
+            if qkv_same:
+                return (_in_proj(x, in_proj_weight, in_proj_bias, D, 2*D),
+                        _in_proj(x, in_proj_weight, in_proj_bias, 2*D, 3*D))
+            else:
+                # kv_same vs. not makes no difference with separate weights:
+                # K and V each project through their own matrix either way.
+                return (x @ k_weight.T + k_bias, x @ v_weight.T + v_bias)
+
+        # === Usage in forward pass ===
+        if qkv_same and (query is key is value):
+            # Self-attention: single fused projection
+            q, k, v = in_proj_qkv(query)
+        else:
+            # Cross-attention: separate Q and KV projections
+            q = in_proj_q(query)
+            k, v = in_proj_kv(key)  # key == value typically
+
+        # Reshape: (T, B, D) -> (B*H, T, head_dim)
+        T_q, T_kv = q.shape[0],
k.shape[0] + q = q.reshape(T_q, bsz * self.num_heads, head_dim).transpose(1, 0, 2) + k = k.reshape(T_kv, bsz * self.num_heads, head_dim).transpose(1, 0, 2) + v = v.reshape(T_kv, bsz * self.num_heads, head_dim).transpose(1, 0, 2) + + # Scaled dot-product attention + scale = head_dim ** -0.5 + attn_weights = jnp.matmul(q, k.transpose(0, 2, 1)) * scale + if attn_mask is not None: + attn_weights = attn_weights + attn_mask + attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1) + attn_weights = attn_weights.astype(q.dtype) + attn_weights = nn.Dropout(rate=self.attn_dropout)( + attn_weights, deterministic=self.deterministic) + + attn_output = jnp.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 0, 2).reshape(T_q, bsz, self.embed_dim) + attn_output = out_proj(attn_output) + + if need_weights: + attn_weights = attn_weights.reshape(bsz, self.num_heads, T_q, T_kv) + attn_weights = attn_weights.mean(axis=1) # avg over heads + return attn_output, attn_weights + +KEY POINTS: +----------- +1. Single `in_proj_weight` param of shape [3*embed_dim, embed_dim] -- matches PyTorch +2. Sliced access via in_proj_qkv(), in_proj_q(), in_proj_kv() -- matches fairseq API +3. Falls back to separate weights when kdim != embed_dim or vdim != embed_dim +4. Xavier uniform initialization matches PyTorch's default for MultiheadAttention +5. 
Weight loading from PyTorch is trivial: just copy in_proj_weight directly +""" diff --git a/MaxCode/rag/sources/targeted/targeted_integer_dtype_long_cast_jax.py b/MaxCode/rag/sources/targeted/targeted_integer_dtype_long_cast_jax.py new file mode 100644 index 0000000..8cb0293 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_integer_dtype_long_cast_jax.py @@ -0,0 +1,51 @@ +""" +TARGETED RAG: Preserve .long() / .int() Integer Dtype Casts in JAX +==================================================================== + +When PyTorch code explicitly calls .long() (int64) or .int() (int32) on a +tensor, you MUST preserve the equivalent dtype cast in JAX. These casts +exist for a reason -- often for indexing, embedding lookups, or API +compatibility. + +WRONG -- Omitting the .long() cast: +------------------------------------- + # PyTorch source: + # positions = make_positions(input, padding_idx, left_pad) + # return new_tensor.masked_scatter_(mask, positions[mask]).long() + + # WRONG! Missing .long() -- returns int32 instead of int64 + def make_positions(tensor, padding_idx, left_pad): + ... + return jnp.where(mask, positions, tensor) + +WHY THIS IS WRONG: +- .long() converts to int64 (torch.int64) +- Without the cast, positions may be int32, causing: + 1. Dtype mismatches when used as indices into int64-indexed arrays + 2. Overflow for very large sequence lengths or vocabularies + 3. Subtle bugs when comparing with other int64 tensors +- The source author explicitly added .long() for a reason + +CORRECT -- Preserve the int64 cast: +------------------------------------- + # CORRECT: .long() -> .astype(jnp.int64) or jnp.int64 + def make_positions(tensor, padding_idx, left_pad): + ... 
+        return jnp.where(mask, positions, tensor).astype(jnp.int64)
+
+PATTERN MATCHING:
+-----------------
+    PyTorch: `tensor.long()`      -> JAX: `tensor.astype(jnp.int64)`
+    PyTorch: `tensor.int()`       -> JAX: `tensor.astype(jnp.int32)`
+    PyTorch: `tensor.short()`     -> JAX: `tensor.astype(jnp.int16)`
+    PyTorch: `tensor.float()`     -> JAX: `tensor.astype(jnp.float32)`
+    PyTorch: `tensor.double()`    -> JAX: `tensor.astype(jnp.float64)`
+    PyTorch: `tensor.half()`      -> JAX: `tensor.astype(jnp.float16)`
+    PyTorch: `tensor.bfloat16()`  -> JAX: `tensor.astype(jnp.bfloat16)`
+    PyTorch: `tensor.bool()`      -> JAX: `tensor.astype(jnp.bool_)`
+    PyTorch: `tensor.to(dtype)`   -> JAX: `tensor.astype(dtype)`
+    PyTorch: `tensor.type_as(ref)` -> JAX: `tensor.astype(ref.dtype)`
+
+RULE: Every explicit dtype cast in PyTorch (.long(), .float(), .type_as(), etc.)
+must have an equivalent .astype() in JAX. Never drop dtype casts. NOTE: JAX defaults to 32-bit dtypes -- unless x64 is enabled via `jax.config.update('jax_enable_x64', True)`, `.astype(jnp.int64)` / `.astype(jnp.float64)` silently produce int32 / float32, so enable x64 when exact 64-bit fidelity with the PyTorch source is required.
+"""
diff --git a/MaxCode/rag/sources/targeted/targeted_kvcache_prefill_decode_jax.py b/MaxCode/rag/sources/targeted/targeted_kvcache_prefill_decode_jax.py
new file mode 100644
index 0000000..682d585
--- /dev/null
+++ b/MaxCode/rag/sources/targeted/targeted_kvcache_prefill_decode_jax.py
@@ -0,0 +1,155 @@
+"""
+TARGETED JAX PATTERN: KV Cache — Pure Functional with Pre-Allocated Buffers
+
+For migration output, use pre-allocated NamedTuple buffers instead of Flax mutable
+variables. NamedTuples are framework-agnostic, JIT-safe with static shapes, and
+beam-search friendly. Flax's `self.variable('cache', ...)` is the standard Flax API
+and works for Flax-only codebases, but couples the conversion to Flax internals.
+Do NOT use growing arrays (`jnp.concatenate`) -- they change shape each step and
+break jax.jit. Use `dynamic_update_slice` for writes and `dynamic_slice` for reads,
+with cache buffers passed as function arguments and returned as outputs.
+
+## WRONG approach 1 (Flax mutable variables -- DO NOT DO THIS):
+
+    # WRONG! 
Hidden mutable state breaks pure functional JAX semantics + class Attention(nn.Module): + @nn.compact + def __call__(self, x, deterministic=True): + k = nn.Dense(self.kv_dim)(x) + v = nn.Dense(self.kv_dim)(x) + + # BAD: Flax mutable variables are hard to manage with jax.jit, + # beam search, and custom training loops + cached_key = self.variable('cache', 'cached_key', + jnp.zeros, (batch, max_len, kv_dim)) + cached_key.value = jnp.concatenate([cached_key.value, k], axis=1) + +## WRONG approach 2 (growing arrays -- DO NOT DO THIS): + + # WRONG! Concatenation creates new arrays each step, breaking jax.jit + if cache is not None: + k = jnp.concatenate([cache['key'], k], axis=1) # Shape changes each step! + v = jnp.concatenate([cache['value'], v], axis=1) + +## CORRECT approach (pre-allocated buffers + dynamic_update_slice): + + import jax + import jax.numpy as jnp + from typing import NamedTuple + + class AttentionCache(NamedTuple): + '''Pure functional cache for standard attention.''' + key: jnp.ndarray # [batch, max_seq_len, num_heads, head_dim] + value: jnp.ndarray # [batch, max_seq_len, num_heads, head_dim] + index: jnp.ndarray # [] scalar: next write position + + def init_attention_cache(batch_size, max_seq_len, num_heads, head_dim, dtype=jnp.bfloat16): + '''Create an empty pre-allocated cache.''' + return AttentionCache( + key=jnp.zeros((batch_size, max_seq_len, num_heads, head_dim), dtype=dtype), + value=jnp.zeros((batch_size, max_seq_len, num_heads, head_dim), dtype=dtype), + index=jnp.array(0, dtype=jnp.int32), + ) + + def update_attention_cache(cache, new_key, new_value): + ''' + Write new K/V into pre-allocated buffers at the current index. 
+ + Args: + cache: AttentionCache with pre-allocated buffers + new_key: [batch, seq_len, num_heads, head_dim] new keys + new_value: [batch, seq_len, num_heads, head_dim] new values + + Returns: + updated_cache: AttentionCache with new K/V written in-place + full_key: [batch, max_seq_len, num_heads, head_dim] (view for attention) + full_value: [batch, max_seq_len, num_heads, head_dim] + ''' + seq_len = new_key.shape[1] + + # Write new K/V at current index using dynamic_update_slice + updated_key = jax.lax.dynamic_update_slice( + cache.key, new_key, + (0, cache.index, 0, 0) # start indices: batch=0, time=index, head=0, dim=0 + ) + updated_value = jax.lax.dynamic_update_slice( + cache.value, new_value, + (0, cache.index, 0, 0) + ) + + updated_cache = AttentionCache( + key=updated_key, + value=updated_value, + index=cache.index + seq_len, + ) + + return updated_cache, updated_key, updated_value + + def get_attention_mask(cache_index, new_seq_len, max_seq_len): + ''' + Build causal mask for cached attention. + + Returns additive mask: 0.0 for allowed positions, -1e9 for blocked. 
+ ''' + # Positions of new queries: [cache_index, cache_index + new_seq_len) + q_positions = jnp.arange(new_seq_len) + cache_index + # Positions of all keys: [0, max_seq_len) + k_positions = jnp.arange(max_seq_len) + + # Causal: query can attend to keys with position <= query position + causal_mask = q_positions[:, None] >= k_positions[None, :] + # Also mask out unfilled positions (beyond cache_index + new_seq_len) + valid_mask = k_positions[None, :] < (cache_index + new_seq_len) + + mask = causal_mask & valid_mask + return jnp.where(mask, 0.0, -1e9) + +## For GatedDeltaNet linear attention (recurrent state cache): + + class GatedDeltaNetCache(NamedTuple): + '''Cache for gated delta net linear attention layer.''' + state: jnp.ndarray # [batch, num_heads, head_k_dim, head_v_dim] recurrent state + conv_state: jnp.ndarray # [batch, channels, kernel_size-1] conv1d rolling state + + def init_gdn_cache(batch_size, num_heads, head_k_dim, head_v_dim, + conv_channels, kernel_size, dtype=jnp.bfloat16): + return GatedDeltaNetCache( + state=jnp.zeros((batch_size, num_heads, head_k_dim, head_v_dim), dtype=dtype), + conv_state=jnp.zeros((batch_size, conv_channels, kernel_size - 1), dtype=dtype), + ) + +## Full model cache as a NamedTuple of layer caches: + + class ModelCache(NamedTuple): + '''Cache for the full model -- one entry per layer.''' + layers: tuple # tuple of (AttentionCache | GatedDeltaNetCache) per layer + + def init_model_cache(config, batch_size, max_seq_len, dtype=jnp.bfloat16): + layers = [] + for i in range(config.num_hidden_layers): + if config.layer_types[i] == 'attention': + layers.append(init_attention_cache( + batch_size, max_seq_len, + config.num_attention_heads, config.head_dim, dtype + )) + else: + layers.append(init_gdn_cache( + batch_size, config.num_attention_heads, + config.head_k_dim, config.head_v_dim, + config.hidden_size, config.conv_kernel_size, dtype + )) + return ModelCache(layers=tuple(layers)) + +## Why pure functional cache: + +1. 
**JIT-compatible**: All shapes are static. `dynamic_update_slice` is a traced + op, not a Python-level mutation. +2. **Pure functional**: Cache is an input and output of the model -- no hidden + state. Works with `jax.jit`, `jax.vmap`, `jax.pmap`. +3. **Beam search**: Easy to duplicate/reorder caches for beam search by indexing + into the batch dimension. +4. **No Flax coupling**: NamedTuple cache works with any JAX framework, not just + Flax. No `self.variable('cache', ...)` magic. +5. **Efficient**: `dynamic_update_slice` is an O(seq_len) in-place XLA op, not + O(max_seq_len) like concatenation. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_linear_init_consistency_jax.py b/MaxCode/rag/sources/targeted/targeted_linear_init_consistency_jax.py new file mode 100644 index 0000000..028fc02 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_linear_init_consistency_jax.py @@ -0,0 +1,64 @@ +""" +TARGETED RAG: Use Consistent Initialization for All Linear Layers +================================================================== + +When converting PyTorch models that define a custom Linear() helper function +with explicit initialization (e.g., xavier_uniform), ALL nn.Linear layers +in the model must use that same helper in JAX. Do not use bare nn.Dense for +some layers while using the custom helper for others. + +WRONG -- Inconsistent initialization across layers: +----------------------------------------------------- + # PyTorch source defines a custom Linear helper: + # def Linear(in_features, out_features, bias=True): + # m = nn.Linear(in_features, out_features, bias) + # nn.init.xavier_uniform_(m.weight) + # if bias: nn.init.constant_(m.bias, 0.) 
+ # return m + # + # Some layers use it: self.fc1 = Linear(dim, 4*dim) + # Other layers use bare nn.Linear: self.proj1 = nn.Linear(dim, dim) + + # JAX helper correctly uses xavier_uniform: + def Linear(in_features, out_features, bias=True, name=None): + return nn.Dense(out_features, use_bias=bias, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros_init(), + name=name) + + # WRONG! fc1 uses the helper but proj1 uses bare nn.Dense + fc1 = Linear(dim, 4 * dim, name='fc1') # xavier_uniform -- correct + proj1 = nn.Dense(dim, name='proj1') # lecun_normal -- WRONG! + +WHY THIS IS WRONG: +- In PyTorch, both bare nn.Linear layers use kaiming_uniform by default +- The JAX helper uses xavier_uniform (matching the PyTorch helper) +- But bare nn.Dense uses lecun_normal (different from PyTorch's kaiming_uniform) +- This creates INCONSISTENT initialization between layers in the same model +- Layers initialized with different distributions train differently +- Weight transfer from PyTorch checkpoints will have mismatched assumptions + +CORRECT -- Use the same Linear helper for ALL linear layers: +-------------------------------------------------------------- + # CORRECT: All linear layers use the same helper, matching PyTorch behavior + fc1 = Linear(dim, 4 * dim, name='fc1') + proj1 = Linear(dim, dim, name='proj1') # Use helper, not bare nn.Dense + proj2 = Linear(dim, dim, name='proj2') # Use helper, not bare nn.Dense + out_layer = Linear(dim, output_dim, name='out_layer') # Use helper here too + + # If the PyTorch source uses bare nn.Linear (no custom init), use bare nn.Dense: + # self.proj = nn.Linear(dim, dim) -> proj = nn.Dense(dim, name='proj') + # + # If the PyTorch source uses a custom init helper, use the JAX equivalent for ALL: + # self.fc1 = Linear(dim, 4*dim) -> fc1 = Linear(dim, 4*dim, name='fc1') + # self.proj = nn.Linear(dim, dim) -> proj = Linear(dim, dim, name='proj') + # + # The key insight: in PyTorch, nn.Linear always uses 
kaiming_uniform. + # When some layers get xavier_uniform via a helper, the REST still have + # kaiming_uniform. In JAX, bare nn.Dense uses lecun_normal (different!). + # So for layers without explicit init in PyTorch, using bare nn.Dense in JAX + # is acceptable. But when the SAME CLASS mixes helper and bare, be consistent. + +RULE: When a model defines a custom Linear() helper, use it for ALL linear +layers in that model to ensure consistent initialization behavior. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_load_balancing_loss_jax.py b/MaxCode/rag/sources/targeted/targeted_load_balancing_loss_jax.py new file mode 100644 index 0000000..045f386 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_load_balancing_loss_jax.py @@ -0,0 +1,101 @@ +""" +TARGETED JAX PATTERN: Load Balancing Loss with Attention Mask + +This function computes the auxiliary load-balancing loss from Switch Transformer +(equations 4-6). It MUST support an optional attention_mask parameter to exclude +padding tokens from the loss computation. Without the mask, padding tokens +pollute the routing statistics and destabilize MoE training. + +## WRONG: No attention_mask support + + def load_balancing_loss(gate_logits, num_experts, top_k): + concatenated = jnp.concatenate(gate_logits, axis=0) + routing_weights = jax.nn.softmax(concatenated, axis=-1) + _, selected_experts = jax.lax.top_k(routing_weights, top_k) + expert_mask = jax.nn.one_hot(selected_experts, num_experts) + tokens_per_expert = jnp.mean(expert_mask, axis=0) + router_prob_per_expert = jnp.mean(routing_weights, axis=0) + return jnp.sum(tokens_per_expert * router_prob_per_expert[None, :]) * num_experts + + # WHY THIS IS WRONG: Without attention_mask, padding tokens are counted in + # the mean, which dilutes the expert frequency statistics. In batched + # inference with variable-length sequences, this makes the loss meaningless. 
+ +## WRONG: Collapsing the top_k dimension with axis=(0, 1) + + # expert_mask shape: [num_tokens, top_k, num_experts] + # PyTorch source: torch.mean(expert_mask.float(), dim=0) + # -> result shape: [top_k, num_experts] + + # WRONG! axis=(0, 1) reduces BOTH token AND top_k dimensions. + # Result shape becomes [num_experts] instead of [top_k, num_experts]. + tokens_per_expert = jnp.mean(expert_mask, axis=(0, 1)) # WRONG SHAPE! + + # WRONG! Flattening before reducing also collapses top_k. + expert_mask_flat = expert_mask.reshape(-1, num_experts) + tokens_per_expert = jnp.mean(expert_mask_flat, axis=0) # WRONG SHAPE! + + # WHY THIS IS WRONG: PyTorch dim=0 reduces ONLY the first dimension. + # The top_k dimension must be preserved. Collapsing it changes the loss + # value and breaks expert routing during training. + +## CORRECT: With attention_mask support + + def load_balancing_loss( + gate_logits: list[jnp.ndarray], + num_experts: int, + top_k: int, + attention_mask: jnp.ndarray | None = None, + ) -> jnp.ndarray: + if not gate_logits: + return jnp.array(0.0) + + # Concatenate all MoE layers: [num_layers * B * T, num_experts] + concatenated = jnp.concatenate(gate_logits, axis=0) + + routing_weights = jax.nn.softmax(concatenated, axis=-1) + _, selected_experts = jax.lax.top_k(routing_weights, top_k) + expert_mask = jax.nn.one_hot(selected_experts, num_experts) + # expert_mask: [num_layers * B * T, top_k, num_experts] + + if attention_mask is None: + # No padding: simple mean over all tokens + tokens_per_expert = jnp.mean(expert_mask.astype(jnp.float32), axis=0) + router_prob_per_expert = jnp.mean(routing_weights, axis=0) + else: + # With padding: mask out padding tokens before computing statistics + batch_size, seq_len = attention_mask.shape + num_layers = concatenated.shape[0] // (batch_size * seq_len) + + # Expand mask to [num_layers * B * T, top_k, num_experts] + expert_attn_mask = jnp.broadcast_to( + attention_mask[None, :, :, None, None], + (num_layers, 
batch_size, seq_len, top_k, num_experts), + ).reshape(-1, top_k, num_experts) + + tokens_per_expert = ( + jnp.sum(expert_mask.astype(jnp.float32) * expert_attn_mask, axis=0) + / jnp.maximum(jnp.sum(expert_attn_mask, axis=0), 1.0) + ) + + # Expand mask to [num_layers * B * T, num_experts] + router_attn_mask = jnp.broadcast_to( + attention_mask[None, :, :, None], + (num_layers, batch_size, seq_len, num_experts), + ).reshape(-1, num_experts) + + router_prob_per_expert = ( + jnp.sum(routing_weights * router_attn_mask, axis=0) + / jnp.maximum(jnp.sum(router_attn_mask, axis=0), 1.0) + ) + + overall_loss = jnp.sum(tokens_per_expert * router_prob_per_expert[None, :]) + return overall_loss * num_experts + +## KEY POINTS: +## - The attention_mask parameter is REQUIRED (even if optional=None) +## - Use jnp.maximum(..., 1.0) to avoid division by zero +## - Broadcast the mask to match [num_layers * B * T, ...] shape +## - The ForCausalLM forward method should pass attention_mask through: +## aux_loss = load_balancing_loss(router_logits, num_experts, top_k, attention_mask) +""" diff --git a/MaxCode/rag/sources/targeted/targeted_moe_capacity_routing_jax.py b/MaxCode/rag/sources/targeted/targeted_moe_capacity_routing_jax.py new file mode 100644 index 0000000..994e4ae --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_moe_capacity_routing_jax.py @@ -0,0 +1,122 @@ +""" +TARGETED JAX PATTERN: MoE Expert Dispatch with Capacity-Based Routing + +CRITICAL: When converting Mixture-of-Experts layers, the Experts class MUST use +capacity-based dispatch with einsum dispatch/combine tensors. Do NOT use per-token +weight gathering or dense all-experts einsum. + +## WRONG approach 1 (per-token gather -- DO NOT DO THIS): + + # WRONG! 
Gathers individual expert weights per token + flat_indices = top_k_index.reshape(-1) + gate_up_w = gate_up_proj[flat_indices] # [T*K, 2I, H] + hidden_repeated = jnp.repeat(x, top_k, axis=0) + out = jnp.sum(hidden_repeated[:, None, :] * gate_up_w, axis=-1) # unbatched! + # This does T*K individual matmuls -- not batched, XLA-unfriendly + +## WRONG approach 2 (dense einsum -- DO NOT DO THIS): + + # WRONG! Computes ALL experts for ALL tokens + expert_outputs = jnp.einsum('th,ehi->tei', x, expert_w1) # O(T*E*H*I) + # For E=64: wastes 93% of compute (each token only uses K=4 experts) + +## CORRECT approach (capacity-based dispatch with einsum): + + import jax + import jax.numpy as jnp + from flax import linen as nn + + class Experts(nn.Module): + config: Qwen3NextConfig + capacity_factor: float = 1.5 # Match source model's default -- this is an example value + + @nn.compact + def __call__(self, hidden_states, top_k_indices, top_k_weights): + config = self.config + num_experts = config.num_experts + hidden_dim = config.hidden_size + intermediate_dim = config.moe_intermediate_size + top_k = config.num_experts_per_tok + + # Expert weight parameters: [E, 2*I, H] and [E, H, I] + gate_up_proj = self.param('gate_up_proj', + nn.initializers.normal(config.initializer_range), + (num_experts, 2 * intermediate_dim, hidden_dim)) + down_proj = self.param('down_proj', + nn.initializers.normal(config.initializer_range), + (num_experts, hidden_dim, intermediate_dim)) + + num_tokens = hidden_states.shape[0] + + # ---- Step 1: Compute per-expert capacity ---- + raw_capacity = max((num_tokens * top_k + num_experts - 1) // num_experts, 1) + capacity = int(raw_capacity * self.capacity_factor) + + # ---- Step 2: Build dispatch and combine tensors ---- + # expert_one_hot: [T, K, E] + expert_one_hot = jax.nn.one_hot(top_k_indices, num_experts) + + # Flatten T*K for per-expert position counting + flat_mask = expert_one_hot.reshape(-1, num_experts) # [T*K, E] + + # Position within each expert's 
buffer (0-indexed via cumsum) + positions = (jnp.cumsum(flat_mask, axis=0) - 1) * flat_mask # [T*K, E] + + # Drop tokens exceeding capacity + within_cap = (positions < capacity) & (flat_mask > 0) + safe_positions = jnp.where(within_cap, positions, 0).astype(jnp.int32) + + # Dispatch tensor: [T*K, E, C] via one-hot on position + pos_one_hot = jax.nn.one_hot(safe_positions, capacity) # [T*K, E, C] + dispatch_flat = pos_one_hot * within_cap[..., None] + + # Combine tensor: dispatch weighted by routing weights + flat_weights = top_k_weights.reshape(-1) # [T*K] + combine_flat = dispatch_flat * flat_weights[:, None, None] + + # Aggregate over K dimension: [T, E, C] + dispatch = dispatch_flat.reshape(num_tokens, top_k, num_experts, capacity).sum(axis=1) + combine = combine_flat.reshape(num_tokens, top_k, num_experts, capacity).sum(axis=1) + + # ---- Step 3: Dispatch tokens to expert buffers ---- + # [E, C, H] = einsum([T, E, C], [T, H]) + expert_inputs = jnp.einsum('tec,th->ech', dispatch, hidden_states) + + # ---- Step 4: Batched expert computation ---- + gate_up_out = jnp.einsum('ech,eih->eci', expert_inputs, gate_up_proj) # [E, C, 2I] + gate_part, up_part = jnp.split(gate_up_out, 2, axis=-1) + expert_out = jnp.einsum( + 'eci,ehi->ech', jax.nn.silu(gate_part) * up_part, down_proj + ) # [E, C, H] + + # ---- Step 5: Combine -- scatter results back ---- + # [T, H] = einsum([T, E, C], [E, C, H]) + output = jnp.einsum('tec,ech->th', combine, expert_out) + + return output + +## WHY this pattern is correct: + +1. **Batched einsums**: All expert computation is batched via einsum. No Python loops, + no per-token gathers, no `.at[].add()`. XLA compiles this into efficient matmuls. +2. **O(E*C*H*I)** compute where C = ceil(T*K/E)*1.5, typically C << T. + For E=64, K=4, T=1024: C ~= 96 vs T=1024. Each expert only processes its share. +3. **Capacity overflow**: Tokens exceeding an expert's capacity are dropped via the + `within_cap` mask. 
With 1.5x capacity factor, drops are rare for trained routers. +4. **dispatch/combine tensors**: The dispatch tensor routes tokens TO expert buffers, + the combine tensor routes results BACK with routing weights. Both are [T, E, C]. +5. **Matches PyTorch**: The PyTorch Qwen3NextExperts uses this capacity-based pattern + internally (via scatter/gather ops). The einsum formulation is the JAX equivalent. + +## Router weight initialization: + +The router (gate) weight should be zero-initialized when the source model explicitly +zero-initializes it (e.g., Qwen3-Next, Switch Transformer, GShard). If the source uses +a different explicit init, match the source. If the source uses bare `nn.Linear` with +no custom init, use the Flax default (`lecun_normal`). + + # When source's _init_weights zeros the router: + weight = self.param('weight', nn.initializers.zeros_init(), (num_experts, hidden_dim)) + +Zero-init ensures uniform routing at start of training. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_no_explicit_init_for_bare_layers_jax.py b/MaxCode/rag/sources/targeted/targeted_no_explicit_init_for_bare_layers_jax.py new file mode 100644 index 0000000..44d2ace --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_no_explicit_init_for_bare_layers_jax.py @@ -0,0 +1,105 @@ +""" +TARGETED JAX PATTERN: No Explicit Initializer for Bare nn.Linear / nn.Conv1d + +CRITICAL: When converting bare PyTorch layers that use only framework defaults +(no explicit nn.init call), the JAX conversion must NOT add explicit initializer +arguments. Flax defaults (lecun_normal for kernel, zeros for bias) are the +accepted equivalent of PyTorch defaults (kaiming_uniform for weight, uniform for +bias). Adding explicit kaiming_uniform or uniform locks in a specific +initialization that may not match downstream usage. 
+ +## WRONG: Adding explicit kaiming_uniform to bare nn.Conv1d + + # PyTorch source: + # self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) + # (no nn.init call anywhere for conv1) + + # WRONG! Source uses the default init, but conversion adds explicit kaiming. + conv1 = nn.Conv( + features=out_channels, + kernel_size=(1,), + use_bias=False, + kernel_init=nn.initializers.kaiming_uniform(), # NOT in source! + ) + +## WRONG: Adding explicit kaiming_uniform and uniform to bare nn.Linear + + # PyTorch source: + # self.fc = nn.Linear(in_features, out_features) + # (no nn.init call anywhere for fc) + + # WRONG! Source uses the default init, but conversion adds explicit inits. + fc = nn.Dense( + features=out_features, + kernel_init=nn.initializers.kaiming_uniform(), # NOT in source! + bias_init=nn.initializers.uniform(), # NOT in source! + ) + +## WRONG: Adding explicit kaiming_uniform to a gate projection + + # PyTorch source: + # self.gate = nn.Linear(hidden_size, num_heads, bias=False) + # (no nn.init call) + + # WRONG! + gate = nn.Dense( + features=num_heads, + use_bias=False, + kernel_init=nn.initializers.kaiming_uniform(), # NOT in source! + ) + +## CORRECT: Bare nn.Conv1d -> bare nn.Conv (no explicit init args) + + # PyTorch source: + # self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) + + # CORRECT: No explicit initializer. Flax default (lecun_normal) is the + # accepted equivalent of PyTorch's default (kaiming_uniform). + conv1 = nn.Conv( + features=out_channels, + kernel_size=(1,), + use_bias=False, + ) + +## CORRECT: Bare nn.Linear -> bare nn.Dense (no explicit init args) + + # PyTorch source: + # self.fc = nn.Linear(in_features, out_features) + + # CORRECT: No explicit initializer. Flax defaults (lecun_normal for kernel, + # zeros for bias) are the accepted equivalent of PyTorch's defaults. 
+ fc = nn.Dense(features=out_features) + +## CORRECT: Only use explicit init when the source explicitly initializes + + # PyTorch source HAS an explicit init call: + # self.fc = nn.Linear(in_features, out_features) + # nn.init.xavier_uniform_(self.fc.weight) + # nn.init.zeros_(self.fc.bias) + + # CORRECT: Mirror the explicit init from source. + fc = nn.Dense( + features=out_features, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros_init(), + ) + +## Why this matters: + +1. **PyTorch default != Flax default, but both are accepted**: PyTorch uses + kaiming_uniform by default; Flax uses lecun_normal. These are DIFFERENT + distributions, but both are reasonable defaults. Adding explicit kaiming + to Flax code locks in a specific choice the source author never made. +2. **Bare layers signal "use framework default"**: When the source writes + `nn.Linear(in, out)` with no init call, the intent is "use whatever the + framework provides". The JAX equivalent of that intent is `nn.Dense(out)` + with no init args. +3. **Explicit init adds noise to verification**: Adding kaiming_uniform gets + flagged as a deviation from source faithfulness, even though the source + never specified any initializer. +4. **Weight loading overrides init anyway**: For inference or fine-tuning from + pretrained weights, the initializer is irrelevant because weights are loaded + from a checkpoint. Adding an explicit init is pure noise. +5. **Rule of thumb**: Only add kernel_init / bias_init to nn.Dense or nn.Conv + when the PyTorch source has an explicit nn.init.* call for that parameter. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_no_invented_attributes_jax.py b/MaxCode/rag/sources/targeted/targeted_no_invented_attributes_jax.py new file mode 100644 index 0000000..662e793 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_no_invented_attributes_jax.py @@ -0,0 +1,72 @@ +""" +TARGETED RAG: Do Not Invent Attributes or Fix Bugs in Source Code +=================================================================== + +When converting PyTorch to JAX, faithfully translate what the source code +ACTUALLY DOES, not what it SHOULD do. If the source has a bug (e.g., +referencing an undefined attribute), the JAX version should reproduce +that same behavior, not silently fix it by adding the missing attribute. + +WRONG -- Adding attributes that don't exist in the PyTorch source: +------------------------------------------------------------------- + # PyTorch source: + # class TransformerEncoder(nn.Module): + # def __init__(self, embed_dim, num_heads, layers, ...): + # self.embed_dim = embed_dim + # self.embed_positions = SinusoidalPositionalEmbedding(embed_dim) + # # NOTE: self.max_source_positions is NEVER defined here + # + # def max_positions(self): + # if self.embed_positions is None: + # return self.max_source_positions # Would crash: AttributeError + # return min(self.max_source_positions, self.embed_positions.max_positions()) + # # Also uses self.max_source_positions -- would crash + + # WRONG! Invented max_source_positions with a made-up default value + class TransformerEncoder(nn.Module): + embed_dim: int + num_heads: int + layers: int + max_source_positions: int = 100000 # NOT IN SOURCE! Invented attribute! 
+ + def max_positions(self): + return min(self.max_source_positions, self.embed_positions.max_positions()) + +WHY THIS IS WRONG: +- The PyTorch source never defines max_source_positions in __init__ +- Adding it with a default value of 100000 introduces behavior that doesn't + exist in the original model +- The original max_positions() method would crash if called -- the JAX version + silently "fixes" this by inventing an attribute +- Users loading PyTorch weights into the JAX model will have an unexpected + extra parameter that doesn't correspond to any PyTorch state +- The invented default (100000) is arbitrary and may not match user expectations + +CORRECT -- Faithfully reproduce the source's behavior: +-------------------------------------------------------- + # Option A: Reproduce the bug faithfully + class TransformerEncoder(nn.Module): + embed_dim: int + num_heads: int + layers: int + # Do NOT add max_source_positions -- it's not in the source + + def max_positions(self): + # Faithfully translated: embed_positions is always non-None, + # so we only need the path that actually executes + return self.embed_positions.max_positions() + + # Option B: If max_positions() is never called in the model's forward pass, + # translate only the code paths that are actually reachable + class TransformerEncoder(nn.Module): + embed_dim: int + num_heads: int + layers: int + # max_positions() method omitted since it references undefined attributes + # and is never called during forward() + +RULE: Never add attributes, parameters, or default values that don't exist in +the PyTorch source. If the source has unreachable or buggy code paths, +either faithfully reproduce them or omit them -- but never "fix" them +by inventing new state. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_pallas_kernel_opportunities.py b/MaxCode/rag/sources/targeted/targeted_pallas_kernel_opportunities.py new file mode 100644 index 0000000..6b61b70 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_pallas_kernel_opportunities.py @@ -0,0 +1,152 @@ +""" +TARGETED JAX PATTERN: Pallas Kernel Fusion Opportunities + +This document identifies high-priority operations that benefit from Pallas kernel +fusion on TPU/GPU. For initial conversion, implement these in pure JAX first, +then add Pallas kernels as optimizations. The pure JAX version serves as the +reference implementation. + +## What is Pallas? + +Pallas is JAX's kernel language for writing custom TPU/GPU kernels. It provides: +- Direct control over memory hierarchy (VMEM on TPU, shared memory on GPU) +- Kernel fusion (combine multiple ops into one kernel launch) +- BlockSpec for tiling large tensors into manageable chunks +- Automatic grid parallelism + +## High-Priority Fusion Opportunities: + +### 1. Chunk Delta Rule (3-5x speedup on TPU) + +Current pure JAX implementation uses 6+ separate kernels: + - cumsum for decay + - matmul for Q@K^T + - tril masking + - solve_triangular for WY representation + - matmul for attention @ value + - state update matmul + +Pallas fusion: Single kernel per chunk that does all of the above in VMEM/SRAM. 
+ + # Current pure JAX (correct, use as reference): + g_cumsum = jnp.cumsum(log_decay, axis=-1) + decay_mask = jnp.exp(g_cumsum[..., :, None] - g_cumsum[..., None, :]) + decay_mask = jnp.where(causal_mask, decay_mask, 0.0) + raw_attn = (k_beta @ key.swapaxes(-2, -1)) * decay_mask + attn = jax.scipy.linalg.solve_triangular(eye - raw_attn, eye, lower=True) + out = attn @ v_beta + + # Future Pallas kernel (pseudocode): + @pl.pallas_call( + out_shape=jax.ShapeDtypeStruct((batch, heads, chunk_size, v_dim), jnp.bfloat16), + grid=(batch, heads), + in_specs=[BlockSpec((1, 1, chunk_size, k_dim), lambda b, h: (b, h, 0, 0)), # q + BlockSpec((1, 1, chunk_size, k_dim), lambda b, h: (b, h, 0, 0)), # k + BlockSpec((1, 1, chunk_size, v_dim), lambda b, h: (b, h, 0, 0)), # v + BlockSpec((1, 1, chunk_size), lambda b, h: (b, h, 0))], # decay + out_specs=BlockSpec((1, 1, chunk_size, v_dim), lambda b, h: (b, h, 0, 0)), + ) + def chunk_delta_rule_kernel(q_ref, k_ref, v_ref, decay_ref, out_ref): + # All computation in on-chip memory, no HBM round-trips + q = q_ref[...] + k = k_ref[...] + v = v_ref[...] + # ... fused cumsum + mask + solve + matmul ... + out_ref[...] = result + +### 2. Causal Conv1d + SiLU (2-3x speedup) + +Current: 3 separate kernels (pad + conv_general_dilated + silu) +Fused: Single depthwise conv + activation kernel + + # Current pure JAX (correct, use as reference): + x_padded = jnp.pad(x, ((0, 0), (0, 0), (kernel_size - 1, 0))) + y = jax.lax.conv_general_dilated(x_padded, weight, (1,), 'VALID', + feature_group_count=channels, + dimension_numbers=('NCH', 'IOH', 'NCH')) + y = jax.nn.silu(y) + + # The fusion opportunity: pad + conv + silu in one kernel + # Especially beneficial for decode (single timestep, kernel launch overhead dominates) + +### 3. 
MoE Expert Dispatch + Compute (10-50x for large E) + +Current: 5+ kernels (top_k + one_hot + cumsum + scatter + expert_matmul + gather) +Fused: Single megakernel that routes and computes in shared memory + + # This is the MOST impactful fusion for models with many experts. + # For E=64, K=2, most tokens go to ~2 experts out of 64. + # Without fusion: scattered memory access patterns dominate runtime. + # With fusion: tokens are routed to expert SRAM tiles, computed locally. + + # Start with capacity-based pure JAX dispatch (see targeted_moe_capacity_routing_jax.py) + # Then profile to decide if Pallas fusion is needed. + +### 4. RMSNormGated (2x speedup) + +Current: 6 elementwise ops (square + mean + rsqrt + multiply + gate_silu + multiply) +Fused: Single-pass kernel reading x once, writing normalized + gated output + + # Current pure JAX (correct, use as reference): + def rms_norm_gated(x, gate, weight, eps=1e-6): + x_f32 = x.astype(jnp.float32) + rms = jax.lax.rsqrt(jnp.mean(x_f32 ** 2, axis=-1, keepdims=True) + eps) + normed = (x_f32 * rms).astype(x.dtype) * weight + return normed * jax.nn.silu(gate) + + # Fused version reads x and gate once from HBM, does everything in SRAM/registers + +## Pallas Basics: + +### @pl.pallas_call pattern: + + from jax.experimental import pallas as pl + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(output_shape, output_dtype), + grid=grid_dims, # Parallel grid dimensions + in_specs=[BlockSpec(...)], # How to tile inputs + out_specs=BlockSpec(...), # How to tile outputs + ) + def my_kernel(input_ref, output_ref): + # input_ref and output_ref are Ref types (like pointers to tiles) + x = input_ref[...] # Load tile from memory + result = x * 2 # Compute + output_ref[...] 
= result  # Store tile to memory
+
+### BlockSpec basics:
+
+    # BlockSpec(block_shape, index_map)
+    # block_shape: size of each tile
+    # index_map: maps grid indices to BLOCK indices (Pallas scales by block_shape)
+
+    # Example: tile a [1024, 512] matrix into [128, 128] blocks
+    BlockSpec(
+        block_shape=(128, 128),
+        index_map=lambda i, j: (i, j),  # block indices, NOT element offsets
+    )
+
+### When to use Pallas vs pure JAX:
+
+| Situation                                  | Use         |
+|--------------------------------------------|-------------|
+| Initial conversion / correctness           | Pure JAX    |
+| Element-wise fusion (norm + activation)    | Pallas      |
+| Complex memory access (scatter/gather MoE) | Pallas      |
+| Simple matmuls                             | Pure JAX    |
+| Custom reduction patterns                  | Pallas      |
+| Prototype / debugging                      | Pure JAX    |
+| Production TPU serving                     | Pallas      |
+
+## Implementation Strategy:
+
+1. **Phase 1**: Convert everything to pure JAX/Flax. Verify correctness against
+   PyTorch reference outputs.
+2. **Phase 2**: Profile on TPU to identify actual bottlenecks (don't guess!).
+3. **Phase 3**: Write Pallas kernels for the top 2-3 bottlenecks.
+4. **Phase 4**: Verify Pallas output matches pure JAX output numerically.
+
+Always keep the pure JAX version as a fallback and reference. Pallas kernels
+should be drop-in replacements with the same function signature.
+"""
diff --git a/MaxCode/rag/sources/targeted/targeted_preserve_class_hierarchy_jax.py b/MaxCode/rag/sources/targeted/targeted_preserve_class_hierarchy_jax.py
new file mode 100644
index 0000000..50b9f5f
--- /dev/null
+++ b/MaxCode/rag/sources/targeted/targeted_preserve_class_hierarchy_jax.py
@@ -0,0 +1,153 @@
+"""
+TARGETED JAX PATTERN: Preserve Class Hierarchy and All Source Components
+
+CRITICAL: When converting PyTorch to JAX/Flax, preserve EVERY class, function,
+and method from the source. Do not merge classes, drop base classes, or omit
+utility functions/classes — even if they seem redundant. The goal is a faithful
+1:1 conversion, not a redesign.
+ +## WRONG: Merging base class into subclass + + # Source has: + # class ExpertBase(nn.Module): ... # base with 2-layer network + # class FFNExpert(ExpertBase): ... # subclass with configurable layers + + # WRONG! Merging them loses the base class and breaks code that + # instantiates ExpertBase directly. + class FFNExpert(nn.Module): + config: MoEConfig + # ... only the subclass, base class gone + +## CORRECT: Preserve both classes + + class ExpertBase(nn.Module): + input_dim: int + output_dim: int + hidden_dim: int = None + + def setup(self): + hdim = self.hidden_dim if self.hidden_dim is not None else 4 * self.input_dim + self.dense1 = nn.Dense(hdim) + self.dense2 = nn.Dense(self.output_dim) + + def __call__(self, x): + x = self.dense1(x) + x = nn.relu(x) + x = self.dense2(x) + return x + + class FFNExpert(nn.Module): + input_dim: int + output_dim: int + hidden_dim: int = None + num_layers: int = 2 + dropout_rate: float = 0.1 + + @nn.compact + def __call__(self, x, deterministic=True): + hdim = self.hidden_dim if self.hidden_dim is not None else 4 * self.input_dim + for i in range(self.num_layers - 1): + x = nn.Dense(hdim, name=f'dense_{i}')(x) + x = nn.relu(x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = nn.Dense(self.output_dim, name=f'dense_{self.num_layers - 1}')(x) + return x + +## WRONG: Dropping get_config / serialization methods + + # Source has get_config() on multiple classes for checkpoint serialization. + # WRONG! Omitting these breaks save/load workflows. + class MixtureOfExperts(nn.Module): + # ... no get_config method + +## CORRECT: Preserve get_config methods + + class MixtureOfExperts(nn.Module): + input_dim: int + output_dim: int + num_experts: int + k: int = 1 + + # ... other methods ... 
+ + def get_config(self): + return { + 'input_dim': self.input_dim, + 'output_dim': self.output_dim, + 'num_experts': self.num_experts, + 'k': self.k, + } + +## WRONG: Omitting utility classes and functions + + # Source has: + # def expert_utilization(routing_weights): ... + # def expert_capacity_utilization(routing_weights, capacity): ... + # def routing_entropy(routing_weights): ... + # def expert_correlation(expert_outputs): ... + # class MoEMetrics: ... + + # WRONG! Only converting some functions and dropping the class. + def expert_utilization(routing_weights): + return routing_weights.mean(axis=0) + def routing_entropy(routing_weights): + ... + # expert_capacity_utilization -- MISSING + # expert_correlation -- MISSING + # MoEMetrics class -- MISSING + +## CORRECT: Convert ALL functions and classes + + def expert_utilization(routing_weights): + return jnp.mean(routing_weights, axis=0) + + def expert_capacity_utilization(routing_weights, capacity): + expert_counts = jnp.sum(routing_weights, axis=0) + return expert_counts / capacity + + def routing_entropy(routing_weights): + eps = 1e-10 + probs = routing_weights + eps + return -(probs * jnp.log(probs)).sum(axis=-1).mean() + + def expert_correlation(expert_outputs): + num_experts = len(expert_outputs) + correlations = jnp.zeros((num_experts, num_experts)) + for i in range(num_experts): + for j in range(i + 1, num_experts): + xi = expert_outputs[i].flatten() + xj = expert_outputs[j].flatten() + corr = jnp.dot(xi, xj) / (jnp.linalg.norm(xi) * jnp.linalg.norm(xj)) + correlations = correlations.at[i, j].set(corr) + correlations = correlations.at[j, i].set(corr) + return correlations + + class MoEMetrics: + def __init__(self, num_experts, expert_capacity=None): + self.num_experts = num_experts + self.expert_capacity = expert_capacity + + def compute_metrics(self, routing_weights, expert_outputs=None): + metrics = { + 'expert_utilization': expert_utilization(routing_weights), + 'routing_entropy': 
routing_entropy(routing_weights), + } + if self.expert_capacity is not None: + metrics['capacity_utilization'] = expert_capacity_utilization( + routing_weights, self.expert_capacity + ) + if expert_outputs is not None: + metrics['expert_correlation'] = expert_correlation(expert_outputs) + return metrics + +## Why preserving everything matters: + +1. **API compatibility**: Downstream code may instantiate ExpertBase, call get_config(), + or use MoEMetrics. Dropping them breaks the public interface. +2. **Testing**: Equivalence tests compare source and converted outputs class-by-class. + Missing classes cause test failures. +3. **Faithfulness**: The conversion should be a translation, not a redesign. Users + expect to find every source component in the output. +4. **Weight loading**: get_config() is used during checkpoint serialization/deserialization. + Without it, weights cannot be saved or loaded correctly. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_preserve_default_values_jax.py b/MaxCode/rag/sources/targeted/targeted_preserve_default_values_jax.py new file mode 100644 index 0000000..6c6a2ef --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_preserve_default_values_jax.py @@ -0,0 +1,107 @@ +""" +TARGETED JAX PATTERN: Preserve Default Parameter Values Exactly + +CRITICAL: When converting PyTorch to JAX, default parameter values must match +the source EXACTLY. Do not change defaults, even if you think a different value +is "better". Changed defaults silently alter model behavior and break +reproducibility between PyTorch and JAX versions. + +## WRONG: Changing default values during conversion + + # PyTorch source: + # class Router(nn.Module): + # def __init__(self, input_dim, num_experts, k=1, capacity_factor=1.0): + # ... + + # WRONG! Changed capacity_factor from 1.0 to 1.25 + class Router(nn.Module): + config: MoEConfig # where MoEConfig has capacity_factor: float = 1.25 + + # WRONG! 
Changed dropout from 0.1 to 0.0 + class FFNExpert(nn.Module): + dropout_rate: float = 0.0 # Source default is 0.1! + + # WRONG! Changed noise_epsilon from 1e-2 to 1e-3 + class Router(nn.Module): + noise_epsilon: float = 1e-3 # Source default is 1e-2! + +## CORRECT: Match source defaults exactly + + # PyTorch source: + # class Router(nn.Module): + # def __init__(self, input_dim, num_experts, k=1, capacity_factor=1.0): + + # CORRECT: All defaults match source + class Router(nn.Module): + input_dim: int + num_experts: int + k: int = 1 + capacity_factor: float = 1.0 # Matches source exactly + + # CORRECT: If using a config dataclass, defaults must also match + @dataclasses.dataclass + class MoEConfig: + input_dim: int + output_dim: int + num_experts: int + k: int = 1 + capacity_factor: float = 1.0 # Must match source Router default + noise_epsilon: float = 1e-2 # Must match source Router default + dropout_rate: float = 0.1 # Must match source FFNExpert default + num_layers: int = 2 # Must match source FFNExpert default + +## WRONG: Changing weight initialization from PyTorch default + + # PyTorch nn.Linear uses Kaiming uniform by default (not zeros, not normal). + # When the source uses bare nn.Linear(...) with no explicit init, use the + # Flax default initializer (lecun_normal), NOT zeros_init. + + # WRONG! Source uses default init, but conversion uses zeros + router_logits = nn.Dense( + features=num_experts, + use_bias=False, + kernel_init=nn.initializers.zeros_init(), # NOT what source does! + )(x) + +## CORRECT: Match PyTorch default initialization + + # When PyTorch source uses bare nn.Linear with no custom init: + router_logits = nn.Dense( + features=num_experts, + use_bias=False, + # Default Flax init (lecun_normal) is acceptable, or use: + # kernel_init=nn.initializers.normal(stddev=config.initializer_range) + # DO NOT use zeros_init unless the source explicitly does so. 
+ )(x) + + # ONLY use zeros_init when the source EXPLICITLY initializes to zeros: + # nn.init.zeros_(self.router.weight) # PyTorch source has this line + # Then and only then: + router_logits = nn.Dense( + features=num_experts, + kernel_init=nn.initializers.zeros_init(), + )(x) + +## Note on _init_weights and constructor defaults: + +When the source's `_init_weights` method explicitly zero-initializes a layer +(e.g., router weights via `nn.init.zeros_`), use `zeros_init()` in the Flax +conversion. This IS matching the source, since `_init_weights` overrides the +constructor default. The rule "match the source default" means match the +EFFECTIVE default after all initialization code runs, not just the bare +constructor signature. + +## Why preserving defaults matters: + +1. **Reproducibility**: Changed defaults mean the JAX model behaves differently + from PyTorch even with identical weights and inputs. +2. **Capacity factor**: Changing capacity_factor from 1.0 to 1.25 changes how many + tokens each expert receives, altering load balancing dynamics. +3. **Dropout rate**: A different default dropout rate changes regularization strength, + leading to different training outcomes. +4. **Router init**: Zero-initialized router weights produce uniform routing at step 0, + while Kaiming/lecun_normal produces non-uniform routing. This affects early + training dynamics and can lead to expert collapse or slower convergence. +5. **Trust the source**: The original author chose specific defaults for a reason. + The conversion should preserve their intent exactly. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_qkvz_interleaved_ordering.py b/MaxCode/rag/sources/targeted/targeted_qkvz_interleaved_ordering.py new file mode 100644 index 0000000..38e5a54 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_qkvz_interleaved_ordering.py @@ -0,0 +1,62 @@ +""" +TARGETED JAX PATTERN: Interleaved QKVZ Weight Ordering (fix_query_key_value_ordering) + +CRITICAL: When converting models where num_key_heads != num_value_heads, +the projection weights are stored in an INTERLEAVED order grouped by key heads. +You MUST NOT use a flat split on the concatenated projection output. + +## The Problem: + +If num_k_heads = 4 and num_v_heads = 8 (i.e., v_per_k = 2), the QKVZ +projection output is NOT laid out as [all_Q, all_K, all_V, all_Z]. + +Instead, it is grouped by key heads: + [key_head_0_Q, key_head_0_K, key_head_0_V0, key_head_0_V1, key_head_0_Z0, key_head_0_Z1, + key_head_1_Q, key_head_1_K, key_head_1_V0, key_head_1_V1, key_head_1_Z0, key_head_1_Z1, + ...] + +## WRONG approach (flat split -- DO NOT DO THIS): + + # WRONG! 
This assumes Q, K, V, Z are contiguous blocks + q, k, v, z = jnp.split(proj_qkvz, [key_dim, key_dim*2, key_dim*2+value_dim], axis=-1) + +## CORRECT approach (group by key heads, then split within each group): + + def fix_query_key_value_ordering(mixed_qkvz, mixed_ba, batch_size, seq_len, + num_k_heads, num_v_heads, head_k_dim, head_v_dim): + v_per_k = num_v_heads // num_k_heads + + # Step 1: Reshape to [B, T, num_k_heads, per_head_size] + per_head_size = 2 * head_k_dim + 2 * v_per_k * head_v_dim + qkvz = mixed_qkvz.reshape(batch_size, seq_len, num_k_heads, per_head_size) + + # Step 2: Split within each key-head group + split_points = [head_k_dim, 2 * head_k_dim, 2 * head_k_dim + v_per_k * head_v_dim] + q, k, v, z = jnp.split(qkvz, split_points, axis=-1) + # q: [B, T, num_k_heads, head_k_dim] + # k: [B, T, num_k_heads, head_k_dim] + # v: [B, T, num_k_heads, v_per_k * head_v_dim] + # z: [B, T, num_k_heads, v_per_k * head_v_dim] + + # Step 3: Reshape v, z to per-value-head + v = v.reshape(batch_size, seq_len, num_v_heads, head_v_dim) + z = z.reshape(batch_size, seq_len, num_v_heads, head_v_dim) + + # Same for BA projection: + ba_per_head = 2 * v_per_k + ba = mixed_ba.reshape(batch_size, seq_len, num_k_heads, ba_per_head) + b, a = jnp.split(ba, 2, axis=-1) + b = b.reshape(batch_size, seq_len, num_v_heads) + a = a.reshape(batch_size, seq_len, num_v_heads) + + return q, k, v, z, b, a + +## Why this matters: + +With num_k_heads=4 and num_v_heads=8, a flat split would assign the wrong +dimensions to Q, K, V, Z because the weights are interleaved per key-head group. +The model will produce completely wrong outputs if this ordering is not preserved. + +This pattern appears in Qwen3-Next's GatedDeltaNet and similar models with +grouped key-value heads in linear attention layers. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_reduction_axis_preservation_jax.py b/MaxCode/rag/sources/targeted/targeted_reduction_axis_preservation_jax.py new file mode 100644 index 0000000..892c923 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_reduction_axis_preservation_jax.py @@ -0,0 +1,112 @@ +""" +TARGETED JAX PATTERN: Preserve Exact Reduction Axes — Never Flatten or Combine + +CRITICAL: When PyTorch uses `dim=N` in a reduction (mean, sum, max, etc.), the +JAX conversion MUST use `axis=N` with the SAME single integer. Never combine +multiple axes like `axis=(0, 1)`, and never reshape/flatten the tensor before +reducing. These change the output shape and numerical result. + +This mistake is especially common in MoE load-balancing loss functions where +`expert_mask` has shape [tokens, top_k, num_experts]. The LLM "helpfully" +collapses the top_k dimension, but PyTorch's `dim=0` preserves it. + +## WRONG: Combining axes when source uses a single dim + + # PyTorch source: + # expert_mask = one_hot(selected_experts, num_experts) + # # expert_mask shape: [num_tokens, top_k, num_experts] + # tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # # result shape: [top_k, num_experts] + + # WRONG! axis=(0, 1) reduces BOTH token and top_k dims. + # Result shape becomes [num_experts] instead of [top_k, num_experts]. + tokens_per_expert = jnp.mean(expert_mask, axis=(0, 1)) + + # WRONG! Flattening first, then reducing, also collapses the top_k dim. + expert_mask_flat = expert_mask.reshape(-1, num_experts) + tokens_per_expert = jnp.mean(expert_mask_flat, axis=0) + +## WRONG: Flattening before sum changes the semantics + + # PyTorch source: + # tokens_per_expert = torch.sum( + # expert_mask.float() * expert_attention_mask, dim=0 + # ) / torch.sum(expert_attention_mask, dim=0) + # # Both sums reduce dim=0 only, preserving [top_k, num_experts] + + # WRONG! Flattening expert_mask before summing collapses top_k. 
+ expert_mask_flattened = expert_mask.reshape(-1, num_experts) + attn_mask_flattened = expert_attention_mask.reshape(-1, num_experts) + tokens_per_expert = jnp.sum(expert_mask_flattened * attn_mask_flattened, axis=0) + +## CORRECT: dim=0 becomes axis=0, nothing else changes + + # PyTorch source: + # tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + # # shape: [num_tokens, top_k, num_experts] -> [top_k, num_experts] + + # CORRECT: axis=0 reduces only the first dimension, preserving top_k. + tokens_per_expert = jnp.mean(expert_mask.astype(jnp.float32), axis=0) + # result shape: [top_k, num_experts] -- matches PyTorch exactly + +## CORRECT: Masked sum with axis=0 only + + # PyTorch source: + # tokens_per_expert = torch.sum( + # expert_mask.float() * expert_attention_mask, dim=0 + # ) / torch.sum(expert_attention_mask, dim=0) + + # CORRECT: reduce axis=0 without any reshaping or flattening. + tokens_per_expert = ( + jnp.sum(expert_mask.astype(jnp.float32) * expert_attention_mask, axis=0) + / jnp.maximum(jnp.sum(expert_attention_mask, axis=0), 1e-9) + ) + # result shape: [top_k, num_experts] -- matches PyTorch exactly + +## CORRECT: Subsequent operations use the preserved shape + + # PyTorch source: + # router_prob_per_expert = torch.mean(routing_weights, dim=0) + # # routing_weights shape: [num_tokens, num_experts] + # # result shape: [num_experts] + # overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert) + + # CORRECT: router_prob_per_expert is [num_experts], tokens_per_expert is + # [top_k, num_experts]. Broadcasting handles the shape difference. 
+ router_prob_per_expert = jnp.mean(routing_weights, axis=0) + overall_loss = jnp.sum(tokens_per_expert * router_prob_per_expert[None, :]) + +## The general rule: + + # torch.mean(x, dim=N) => jnp.mean(x, axis=N) + # torch.sum(x, dim=N) => jnp.sum(x, axis=N) + # torch.max(x, dim=N) => jnp.max(x, axis=N) + # torch.min(x, dim=N) => jnp.min(x, axis=N) + # + # The axis integer is ALWAYS the same as the dim integer. + # NEVER combine axes: dim=0 does NOT become axis=(0, 1). + # NEVER flatten before reducing: reshape(-1, K) + axis=0 != axis=0 on original. + # NEVER add axes that are not in the source. + +## Why this matters: + +1. **Shape change**: `axis=(0, 1)` produces a different output shape than + `axis=0`. Downstream code expecting [top_k, num_experts] will break or + silently compute wrong results with [num_experts]. + +2. **Numerical change**: Reducing over more elements changes the mean/sum + value. `mean(x, axis=0)` divides by `x.shape[0]`, while + `mean(x, axis=(0,1))` divides by `x.shape[0] * x.shape[1]`. + +3. **Load-balancing loss**: In MoE models, this bug makes the auxiliary loss + numerically wrong, which destabilizes expert routing during training. + Experts may collapse to a single active expert or oscillate wildly. + +4. **Flattening is not neutral**: `x.reshape(-1, K)` followed by `sum(axis=0)` + is mathematically equivalent to `sum(axis=tuple(range(x.ndim-1)))` — it + reduces ALL leading dimensions, not just the first one. + +5. **Rule of thumb**: If the source says `dim=0`, write `axis=0` and touch + nothing else. Do not reshape, flatten, squeeze, or combine axes. The + tensor shape flowing through JAX should match PyTorch at every step. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_scan_vs_forloop_jax.py b/MaxCode/rag/sources/targeted/targeted_scan_vs_forloop_jax.py new file mode 100644 index 0000000..ba3f9a1 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_scan_vs_forloop_jax.py @@ -0,0 +1,124 @@ +""" +TARGETED JAX PATTERN: scan vs fori_loop vs Python for-loop + +When converting sequential loops from PyTorch to JAX, choose the right primitive. +NEVER use a plain Python for-loop over a dynamic range for sequential computation -- +it unrolls at trace time, causing slow compilation and large XLA graphs. + +## Decision Table: + +| Pattern | JAX Primitive | When to Use | +|----------------------------------|----------------------|--------------------------------------| +| Sequential state + collect outputs| `jax.lax.scan` | RNN steps, chunk scans, time series | +| Sequential state, no outputs | `jax.lax.fori_loop` | Iterative refinement, power iteration| +| Fixed small N (< ~8) | Python for-loop | Unrolling is acceptable | +| Independent iterations | `jax.vmap` | Batched computation, no dependencies | + +## WRONG: Python for-loop for sequential scan (DO NOT DO THIS): + + # WRONG! Unrolls N iterations at trace time -> huge XLA graph, slow compile + state = init_state + outputs = [] + for i in range(num_chunks): + state, out = step_fn(state, inputs[i]) + outputs.append(out) + outputs = jnp.stack(outputs) + +## CORRECT: jax.lax.scan for sequential state + outputs: + + import jax + import jax.numpy as jnp + + def scan_chunks(init_state, inputs): + ''' + Process chunks sequentially, accumulating state and collecting outputs. 
+
+    Args:
+      init_state: [batch, heads, k_dim, v_dim] initial recurrent state
+      inputs: tuple of arrays, each with leading dim = num_chunks
+        (arrays are sliced along axis 0 for each step)
+
+    Returns:
+      final_state: [batch, heads, k_dim, v_dim]
+      all_outputs: [num_chunks, batch, heads, chunk_size, v_dim]
+    '''
+    def step_fn(carry, chunk_input):
+      state = carry
+      q_c, k_c, v_c, decay_c = chunk_input
+
+      # Inter-chunk: query the accumulated state
+      inter_out = jnp.einsum('bhck,bhkv->bhcv', q_c, state)
+
+      # Intra-chunk: local attention within the chunk
+      intra_out = local_attention(q_c, k_c, v_c, decay_c)
+
+      out = inter_out + intra_out
+
+      # Update state for next chunk
+      new_state = state * decay_c[..., -1:, None] + jnp.einsum(
+          'bhck,bhcv->bhkv', k_c, v_c
+      )
+
+      return new_state, out
+
+    final_state, all_outputs = jax.lax.scan(step_fn, init_state, inputs)
+    return final_state, all_outputs
+
+## CORRECT: Reshaping inputs for scan
+
+    # Inputs are [batch, heads, seq_len, dim]
+    # Need to reshape to [num_chunks, batch, heads, chunk_size, dim] for scan
+
+    batch, heads, seq_len, dim = x.shape
+    chunk_size = 64
+    num_chunks = seq_len // chunk_size
+
+    # Reshape: split seq_len into (num_chunks, chunk_size)
+    x_chunked = x.reshape(batch, heads, num_chunks, chunk_size, dim)
+
+    # Transpose time axis to LEADING position for scan
+    # scan slices along axis 0, so num_chunks must be first
+    x_chunked = jnp.transpose(x_chunked, (2, 0, 1, 3, 4))
+    # Now: [num_chunks, batch, heads, chunk_size, dim]
+
+    # Pack multiple arrays into a tuple for scan
+    scan_inputs = (q_chunked, k_chunked, v_chunked, decay_chunked)
+
+## CORRECT: jax.lax.fori_loop for state-only iteration:
+
+    def iterative_refinement(init_x, num_iters):
+        '''State-only loop -- no outputs collected per step.'''
+        def body_fn(i, state):
+            x = state
+            x = x - learning_rate * gradient(x)
+            return x
+
+        final_x = jax.lax.fori_loop(0, num_iters, body_fn, init_x)
+        return final_x
+
+## scan with auxiliary state (carry 
multiple values): + + def step_fn(carry, inputs): + state, running_sum = carry # Unpack multiple carry values + x = inputs + + out = state @ x + new_state = update(state, x) + new_sum = running_sum + jnp.sum(out) + + return (new_state, new_sum), out # Pack carry back as tuple + + (final_state, total_sum), outputs = jax.lax.scan( + step_fn, (init_state, jnp.zeros(())), inputs + ) + +## Key gotchas: + +1. **scan slices axis 0**: The scanned array's leading dimension is the loop length. + Transpose your data so the time/chunk axis is first. +2. **Carry must be a pytree**: Use tuples or NamedTuples for multiple carry values. +3. **Static shapes**: All arrays in the scan body must have shapes determinable at + trace time. No data-dependent shapes inside the body. +4. **scan unroll parameter**: `jax.lax.scan(..., unroll=k)` unrolls k iterations for + better optimization at the cost of compile time. Default unroll=1. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_source_faithfulness_jax.py b/MaxCode/rag/sources/targeted/targeted_source_faithfulness_jax.py new file mode 100644 index 0000000..e9fa46a --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_source_faithfulness_jax.py @@ -0,0 +1,187 @@ +""" +TARGETED JAX PATTERN: Source Faithfulness — Do Not "Improve" the Source + +CRITICAL: The goal of PyTorch-to-JAX conversion is a FAITHFUL TRANSLATION, not +a redesign or optimization. The converted code must produce identical behavior to +the source for the same inputs and weights. Never change defaults, initializers, +reduction operations, or function semantics — even if you believe a different +choice is "better", "more stable", or "more efficient". + +## Principle 1: Preserve Exact Initializer Semantics + +WRONG: Adding an explicit initializer when the source uses the framework default. + + # PyTorch source (uses default Kaiming uniform init): + # self.router = nn.Linear(input_dim, num_experts, bias=False) + + # WRONG! 
Source does NOT explicitly initialize to zeros. + # Adding zeros_init changes the model's behavior at initialization. + router_logits = nn.Dense( + features=num_experts, + use_bias=False, + kernel_init=nn.initializers.zeros_init(), # NOT in source! + )(x) + +CORRECT: Use the Flax default init (lecun_normal) to match "bare nn.Linear". + + # CORRECT: No explicit kernel_init => Flax default (lecun_normal), + # which is the closest match to PyTorch's default Kaiming uniform. + router_logits = nn.Dense( + features=num_experts, + use_bias=False, + )(x) + + # ONLY use a custom initializer when the PyTorch source EXPLICITLY sets one: + # nn.init.zeros_(self.router.weight) => kernel_init=nn.initializers.zeros_init() + # nn.init.normal_(self.fc.weight, std=0.02) => kernel_init=nn.initializers.normal(stddev=0.02) + # nn.init.xavier_uniform_(self.fc.weight) => kernel_init=nn.initializers.xavier_uniform() + + # Exception: MoE router layers -- when the model's `_init_weights` method + # explicitly zeros the router (common in Switch Transformer, Qwen3-Next), + # use `zeros_init()` even though the router is constructed as bare `nn.Linear`. + # The `_init_weights` override IS the source's explicit init. + + +## Principle 2: Preserve Exact Default Parameter Values + +WRONG: Changing numeric defaults because you think a different value is better. + + # PyTorch source: + # def __init__(self, ..., capacity_factor=1.0, noise_epsilon=1e-2): + + # WRONG! Changed capacity_factor. The comment does NOT justify this. + @dataclass + class Config: + capacity_factor: float = 1.25 # "Increased for stability" + # This silently changes model behavior! + +CORRECT: Copy every default value exactly from the source. + + # CORRECT: All defaults match source constructor signatures exactly. 
+ @dataclass + class Config: + capacity_factor: float = 1.0 # Matches source + noise_epsilon: float = 1e-2 # Matches source + + # This applies to ALL numeric values: learning rates, epsilon values, + # dropout rates, capacity factors, number of layers, hidden dimensions, etc. + # If the source says 1.0, write 1.0. If the source says 0.1, write 0.1. + # NEVER round, adjust, or "improve" any default. + + +## Principle 3: Preserve Exact Reduction Operations + +WRONG: Substituting one reduction for another. + + # PyTorch source: + # return routing_weights.mean(dim=0) + + # WRONG! .sum() != .mean() -- different semantics! + def expert_utilization(routing_weights): + return routing_weights.sum(axis=0) # Should be .mean()! + + # PyTorch source: + # expert_counts = routing_weights.sum(dim=0) + + # WRONG! .mean() != .sum() + def expert_counts(routing_weights): + return routing_weights.mean(axis=0) # Should be .sum()! + +CORRECT: Use the exact same reduction as the source. + + # If source uses .mean(dim=0), use .mean(axis=0) + def expert_utilization(routing_weights): + return jnp.mean(routing_weights, axis=0) + + # If source uses .sum(dim=0), use .sum(axis=0) + def expert_counts(routing_weights): + return jnp.sum(routing_weights, axis=0) + + # PyTorch dim= maps to JAX axis= with the same integer value. + # torch.mean(x, dim=0) => jnp.mean(x, axis=0) + # torch.sum(x, dim=-1) => jnp.sum(x, axis=-1) + # torch.max(x, dim=1) => jnp.max(x, axis=1) + # NEVER swap .mean() for .sum() or vice versa. + + +## Principle 4: Preserve Function Placement and Structure + +WRONG: Relocating a method from one class to another. + + # PyTorch source: + # class Router(nn.Module): + # def __init__(self, ...): + # self.capacity = lambda batch_size: int(batch_size * cf * k / E) + + # WRONG! Moving capacity computation to a different class + class MixtureOfExperts(nn.Module): + def __call__(self, x): + capacity = int(...) 
# Relocated from Router + +CORRECT: Keep methods and attributes on the same class as the source. + + # CORRECT: capacity stays on Router where the source defines it + class Router(nn.Module): + ... + def capacity(self, batch_size: int) -> int: + return int(batch_size * self.capacity_factor * self.k / self.num_experts) + + +## Principle 5: Preserve All Utility Components + +WRONG: Dropping "non-essential" components like logging, metrics, or I/O. + + # PyTorch source has TensorBoard logging in the trainer. + # WRONG! Dropping it because "it's not core model logic" + class Trainer: + def __init__(self, ...): + # No tensorboard setup <-- MISSING from source + +CORRECT: Convert ALL components, including logging and metrics. + + # CORRECT: Preserve TensorBoard logging using JAX-ecosystem equivalent + class Trainer: + def __init__(self, ..., tensorboard_dir=None): + self.writer = None + if tensorboard_dir: + os.makedirs(tensorboard_dir, exist_ok=True) + from tensorboardX import SummaryWriter + self.writer = SummaryWriter(tensorboard_dir) + + +## Approved Deviations from Literal Translation: + +The following JAX-specific changes are acceptable even though they differ from the +literal PyTorch code, because they preserve numerical equivalence while adapting to +JAX's programming model: + + # (a) f32 upcast before softmax/norm -- even if PyTorch relies on AMP autocast, + # JAX should explicitly upcast to f32 for numerical stability. + + # (b) lax.scan replacing Python for-loops over layers -- semantically identical, + # but enables XLA loop optimization and reduces compilation time. + + # (c) solve_triangular replacing Neumann-series for-loops -- numerically + # equivalent but more efficient and stable in JAX. + + # (d) Separate prefill/decode functions replacing if/else branching -- JAX's + # tracing requires static control flow; separate functions are the idiomatic + # equivalent of PyTorch's runtime if/else on cache state. 
+ + # (e) Additive masking replacing boolean masking -- numerically equivalent for + # standard attention (see targeted_triangular_masking_jax.py for details). + + +## Why faithfulness matters: + +1. **Reproducibility**: Users expect identical outputs from the JAX version when + loaded with the same weights. Changed defaults or reductions break this. +2. **Weight loading**: Different initializers mean the JAX model cannot use + PyTorch pretrained weights correctly for fine-tuning or inference. +3. **Testing**: Equivalence tests compare source and converted outputs. Semantic + changes cause test failures that are hard to debug. +4. **Trust**: If users find the conversion changed their defaults, they lose + confidence in the entire output and must audit every line. +5. **Downstream code**: Other code may depend on specific method placements, + return value semantics, or default behaviors. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_sum_div_not_mean_jax.py b/MaxCode/rag/sources/targeted/targeted_sum_div_not_mean_jax.py new file mode 100644 index 0000000..d101fc9 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_sum_div_not_mean_jax.py @@ -0,0 +1,67 @@ +""" +TARGETED JAX PATTERN: Preserve .sum() / divisor — Do Not Replace with .mean() + +CRITICAL: When PyTorch source computes `.sum(dim=N) / some_constant`, the JAX +conversion must use `jnp.sum(x, axis=N) / some_constant` — NOT `.mean(axis=N)`. +These are only equivalent when the dimension size equals the constant, which is +not guaranteed. + +## WRONG: Replacing .sum(dim=1) / num_heads with .mean(axis=1) + + # PyTorch source: + # attn_output = attn_weights.sum(dim=1) / self.num_heads + + # WRONG! .mean(axis=1) divides by the dimension size (dim_size), + # but the source divides by num_heads. These differ when dim_size != num_heads. 
+ attn_output = jnp.mean(attn_weights, axis=1) + +## WRONG: Replacing .sum(dim=-1) / divisor with .mean(axis=-1) + + # PyTorch source: + # normalized = scores.sum(dim=-1) / temperature + + # WRONG! .mean(axis=-1) divides by the last dimension size, + # but the source divides by temperature (a scalar parameter). + normalized = jnp.mean(scores, axis=-1) + +## CORRECT: Preserve .sum() / constant exactly + + # PyTorch source: + # attn_output = attn_weights.sum(dim=1) / self.num_heads + + # CORRECT: Faithful translation — sum then divide by the same constant. + attn_output = jnp.sum(attn_weights, axis=1) / self.num_heads + +## CORRECT: Preserve .sum() / scalar parameter + + # PyTorch source: + # normalized = scores.sum(dim=-1) / temperature + + # CORRECT: Same reduction and same divisor. + normalized = jnp.sum(scores, axis=-1) / temperature + +## CORRECT: Use .mean() ONLY when the source uses .mean() + + # PyTorch source: + # avg_pool = features.mean(dim=1) + + # CORRECT: Source uses .mean(), so JAX uses .mean(). + avg_pool = jnp.mean(features, axis=1) + +## Why this matters: + +1. **Different denominators**: `.mean(axis=N)` divides by `x.shape[N]` (the + dimension size). `.sum(axis=N) / C` divides by a constant C. These produce + different results whenever `x.shape[N] != C`. +2. **Concrete example**: If `attn_weights` has shape `(batch, 8, seq, seq)` and + `num_heads = 4`, then `.mean(axis=1)` divides by 8, but `.sum(axis=1) / 4` + divides by 4 — the result is off by a factor of 2. +3. **Numerical equivalence is not guaranteed**: Even when the dimension happens + to equal the constant for one model config, a different config (different + num_heads, different seq_len) may break the equivalence. +4. **Faithfulness principle**: The conversion must preserve the source's exact + arithmetic. If the source says "sum then divide by N", write "sum then divide + by N" — do not simplify to "mean". +5. 
**Rule of thumb**: Only use `.mean()` in JAX when the PyTorch source uses + `.mean()`. For `.sum() / constant`, always write `jnp.sum(...) / constant`. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_tied_output_projection_jax.py b/MaxCode/rag/sources/targeted/targeted_tied_output_projection_jax.py new file mode 100644 index 0000000..5fe55c5 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_tied_output_projection_jax.py @@ -0,0 +1,45 @@ +""" +TARGETED JAX PATTERN: Tied Output Projection (Weight Tying) + +When the PyTorch source uses explicit `x @ weight.T` for output projection, +the JAX conversion must use explicit matmul, not `.attend()`. Flax's +`nn.Embed.attend()` and framework-specific attend() methods (e.g., MaxText's +`Embed.attend()`) may internally match the matmul behavior, but explicit +`x @ embedding.T` guarantees numerical equivalence with the PyTorch source. + +## WRONG approach (attend() -- DO NOT DO THIS): + + # WRONG! attend() is for embedding lookup, not linear projection + token_embedding = nn.Embed(n_vocab, n_state, name='token_embedding') + x_emb = token_embedding(tokens) + # ... transformer layers ... + logits = token_embedding.attend(x_out) # <-- WRONG: may not match PyTorch + + # nn.Embed.attend() computes a dot product for attention-style lookup. + # It may apply different scaling or normalization than a simple matmul. + # The PyTorch source does `x @ weight.T` which is a plain linear projection. + +## CORRECT approach (explicit matmul with embedding table): + + token_embedding = nn.Embed(n_vocab, n_state, name='token_embedding') + x_emb = token_embedding(tokens) + # ... transformer layers ... + # Tied output projection: multiply by transpose of embedding table + logits = (x_out @ token_embedding.embedding.T).astype(jnp.float32) + + # `token_embedding.embedding` is the [n_vocab, n_state] weight matrix. + # `.T` transposes it to [n_state, n_vocab]. + # The matmul gives [B, T, n_vocab] logits -- exactly like PyTorch. 
+ +## WHY this matters: + +1. **Faithfulness**: PyTorch `x @ weight.T` is a plain matrix multiplication. + Using `token_embedding.embedding.T` in Flax does the exact same operation. +2. **Weight loading**: When loading PyTorch weights, the embedding weight is + shared between input embedding and output projection. Using explicit matmul + ensures the same weight is used for both, matching PyTorch exactly. +3. **Numerical equivalence**: `.attend()` may apply internal transformations + that produce different logits than the simple transpose+matmul. +4. **Float32 cast**: Apply `.astype(jnp.float32)` after the matmul to match + PyTorch's `.float()` call on the logits. +""" diff --git a/MaxCode/rag/sources/targeted/targeted_triangular_masking_jax.py b/MaxCode/rag/sources/targeted/targeted_triangular_masking_jax.py new file mode 100644 index 0000000..308d4ea --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_triangular_masking_jax.py @@ -0,0 +1,124 @@ +""" +TARGETED JAX PATTERN: Triangular Masking for Causal Attention + +For standard attention scores before softmax, use ADDITIVE masking with large negative +values, NOT multiplicative boolean masks. Multiplicative masks cause issues with +softmax (masked positions become 0 instead of being suppressed to near-zero probability). + +## WRONG: Multiplicative boolean mask (DO NOT DO THIS): + + # WRONG! After softmax, masked positions get non-zero probability + causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) + attn_weights = attn_scores * causal_mask # Zeros out future positions + attn_weights = jax.nn.softmax(attn_weights, axis=-1) + # Problem: softmax(0) != 0, so masked positions still get some probability! + +## CORRECT: Additive float mask with large negative value: + + import jax + import jax.numpy as jnp + + def make_causal_mask(seq_len, dtype=jnp.float32): + ''' + Create additive causal mask. 
+ + Returns: + mask: [seq_len, seq_len] where allowed=0.0, blocked=-1e9 + ''' + # Lower-triangular inclusive (k=0): position i can attend to j where j <= i + causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_), k=0) + mask = jnp.where(causal, 0.0, -1e9) + return mask.astype(dtype) + + # Usage: + attn_scores = q @ k.swapaxes(-2, -1) / jnp.sqrt(head_dim) + mask = make_causal_mask(seq_len, dtype=attn_scores.dtype) + attn_scores = attn_scores + mask # Add mask BEFORE softmax + attn_weights = jax.nn.softmax(attn_scores, axis=-1) + +## Key functions: + + # Lower triangular inclusive (causal: attend to self and past) + jnp.tril(jnp.ones((n, n)), k=0) + # [[1, 0, 0], + # [1, 1, 0], + # [1, 1, 1]] + + # Strict lower triangular (attend to past only, NOT self) + jnp.tril(jnp.ones((n, n)), k=-1) + # [[0, 0, 0], + # [1, 0, 0], + # [1, 1, 0]] + + # Strict upper triangular (what to BLOCK in causal attention) + jnp.triu(jnp.ones((n, n)), k=1) + # [[0, 1, 1], + # [0, 0, 1], + # [0, 0, 0]] + +## For chunk-parallel attention (within-chunk causal mask): + + def make_chunk_causal_mask(chunk_size, dtype=jnp.float32): + '''Causal mask for within-chunk attention.''' + causal = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=jnp.bool_), k=0) + return jnp.where(causal, 0.0, -1e9).astype(dtype) + + # For decay-based masking (gated delta rule): + # The decay mask is multiplicative but applied to attention weights + # BEFORE adding to the accumulator, not to raw scores before softmax. + # This is different from standard attention masking. + + def make_decay_mask(log_decay, chunk_size): + ''' + Create exponential decay mask for linear attention within a chunk. 
+ + Args: + log_decay: [batch, heads, chunk_size] log-decay values per timestep + + Returns: + decay_mask: [batch, heads, chunk_size, chunk_size] where + mask[i,j] = exp(sum(log_decay[j+1:i+1])) for j <= i, 0 otherwise + ''' + # Cumulative sum of log-decay gives log of product of decays + cumsum = jnp.cumsum(log_decay, axis=-1) + + # decay_mask[i,j] = exp(cumsum[i] - cumsum[j]) + mask = jnp.exp(cumsum[..., :, None] - cumsum[..., None, :]) + + # Zero out upper triangle (future positions) + causal = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=jnp.bool_), k=0) + return jnp.where(causal, mask, 0.0) + +## Combining causal mask with padding mask: + + def make_combined_mask(seq_len, padding_lengths, dtype=jnp.float32): + ''' + Combine causal mask with padding mask. + + Args: + seq_len: sequence length + padding_lengths: [batch] number of padding tokens at start + + Returns: + mask: [batch, 1, seq_len, seq_len] broadcastable over heads + ''' + causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_), k=0) + + # Padding mask: True where position is valid (not padding) + positions = jnp.arange(seq_len) + valid = positions[None, :] >= padding_lengths[:, None] # [batch, seq_len] + + # Combine: attend only to valid, causal positions + combined = causal[None, :, :] & valid[:, None, :] # [batch, seq_len, seq_len] + mask = jnp.where(combined, 0.0, -1e9).astype(dtype) + return mask[:, None, :, :] # [batch, 1, seq_len, seq_len] for head broadcast + +## Why additive masking: + +1. **Correct softmax behavior**: Adding -1e9 before softmax makes masked positions + have exp(-1e9) ~ 0 probability. Multiplying by 0 after scores but before + softmax doesn't suppress probability correctly. +2. **Gradient flow**: Additive mask has clean gradients. Multiplicative mask + creates 0 * gradient = 0 issues. +3. **JAX convention**: JAX/Flax examples universally use additive masking. 
+""" diff --git a/MaxCode/rag/sources/targeted/targeted_weight_init_patterns_jax.py b/MaxCode/rag/sources/targeted/targeted_weight_init_patterns_jax.py new file mode 100644 index 0000000..3ae8a33 --- /dev/null +++ b/MaxCode/rag/sources/targeted/targeted_weight_init_patterns_jax.py @@ -0,0 +1,125 @@ +""" +TARGETED JAX PATTERN: Weight Initialization — PyTorch to Flax Mapping + +CRITICAL: Weight initialization must match the PyTorch source EXACTLY. Wrong init +breaks routing, norms, and weight loading from PyTorch checkpoints. Each layer type +has a specific initializer -- do NOT use a single default for everything. + +## PyTorch to Flax Initializer Mapping Table: + +This table applies to models with `_init_weights` methods (e.g., HuggingFace-style). +When no `_init_weights` exists and the source uses bare `nn.Linear`, use the Flax +default (`lecun_normal`) as the closest match to PyTorch's default Kaiming uniform. + +| PyTorch Layer / Init | Flax Initializer | +|-----------------------------------|----------------------------------------------------------| +| nn.Linear (general Dense) | nn.initializers.normal(stddev=config.initializer_range) | +| nn.Embedding | nn.initializers.normal(stddev=1.0) | +| MoE Router / Gate | nn.initializers.zeros_init() (when source explicitly zero-inits) | +| RMSNorm weight (1 + w formulation)| nn.initializers.zeros_init() | +| RMSNorm weight (w formulation) | nn.initializers.ones_init() | +| LayerNorm weight | nn.initializers.ones_init() | +| LayerNorm bias | nn.initializers.zeros_init() | +| Log-decay / log-tau parameters | Custom log_uniform_init or specific range | +| Conv1d weight (depthwise) | nn.initializers.normal(stddev=config.initializer_range) | +| Bias (general) | nn.initializers.zeros_init() | + +## WRONG: Using default or wrong init for router + + # WRONG! 
Normal init causes non-uniform routing from step 0 + class MoERouter(nn.Module): + num_experts: int + + @nn.compact + def __call__(self, x): + return nn.Dense(self.num_experts)(x) # Default normal init! + +## CORRECT: Zero-init for router + + class MoERouter(nn.Module): + num_experts: int + + @nn.compact + def __call__(self, x): + return nn.Dense( + self.num_experts, + kernel_init=nn.initializers.zeros_init(), + use_bias=False, + )(x) + +## WRONG: Using ones_init for RMSNorm when source uses (1 + w) formulation + + # If PyTorch source initializes RMSNorm weight to zeros and computes: + # output = x * rsqrt(mean(x^2) + eps) * (1 + self.weight) + # Then weight starts at 0, making the initial scale factor = 1. + + # WRONG! ones_init means initial scale = 1 + 1 = 2 + weight = self.param('scale', nn.initializers.ones_init(), (dim,)) + return normed * (1 + weight) + +## CORRECT: Match the source formulation + + # If source uses (1 + w) with w initialized to zeros: + weight = self.param('scale', nn.initializers.zeros_init(), (dim,)) + return normed * (1 + weight) + + # If source uses plain w with w initialized to ones: + weight = self.param('scale', nn.initializers.ones_init(), (dim,)) + return normed * weight + +## Dense layer initialization: + + # General Dense projection -- match config.initializer_range (typically 0.02) + nn.Dense( + features, + kernel_init=nn.initializers.normal(stddev=config.initializer_range), + use_bias=config.use_bias, + ) + +## Embedding initialization: + + nn.Embed( + num_embeddings=config.vocab_size, + features=config.hidden_size, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + +## Custom log-uniform initializer for decay/tau parameters: + + import jax + import jax.numpy as jnp + + def log_uniform_init(min_val, max_val): + '''Initialize in log-space uniformly between min_val and max_val.''' + def init(key, shape, dtype=jnp.float32): + log_min = jnp.log(jnp.array(min_val, dtype=dtype)) + log_max = jnp.log(jnp.array(max_val, 
dtype=dtype))
+            return jax.random.uniform(key, shape, dtype=dtype,
+                                      minval=log_min, maxval=log_max)
+        return init
+
+    # Usage for log-decay parameters:
+    log_decay = self.param('log_decay', log_uniform_init(1.0, 16.0), (num_heads,))
+    decay = jnp.exp(-jnp.exp(log_decay))
+
+## Additional notes:
+
+Note: RMSNorm epsilon defaults vary by model (1e-6 in Flax, 1e-5 in FLA/PyTorch).
+Always match the source model's epsilon value.
+
+Note: Flax names norm weights 'scale'; PyTorch uses 'weight'. Checkpoint loading
+must handle this mapping (e.g., rename 'weight' -> 'scale' when loading PyTorch
+weights into Flax).
+
+## Why initialization matters:
+
+1. **Router zeros**: Ensures uniform expert selection at initialization. Normal init
+   creates random biases that can cause expert collapse (some experts never chosen).
+2. **RMSNorm**: Wrong init changes the effective scale factor, which means loaded
+   PyTorch weights will produce different outputs.
+3. **Dense layers**: stddev=0.02 is the transformer `_init_weights` convention
+   (config.initializer_range), not PyTorch's bare nn.Linear (Kaiming uniform) default.
+4. **Weight loading**: When loading PyTorch checkpoints, the Flax model's init
+   doesn't matter for loaded weights. But for any randomly-initialized weights
+   (e.g., during pretraining), matching init is essential for convergence.
+"""
diff --git a/MaxCode/rag/sources/targeted/targeted_wy_representation_jax.py b/MaxCode/rag/sources/targeted/targeted_wy_representation_jax.py
new file mode 100644
index 0000000..735eeff
--- /dev/null
+++ b/MaxCode/rag/sources/targeted/targeted_wy_representation_jax.py
@@ -0,0 +1,83 @@
+"""
+TARGETED JAX PATTERN: WY Representation for Chunk-Parallel Delta Rule
+
+When converting a PyTorch for-loop that computes a Neumann series row-by-row
+on a lower-triangular matrix, DO NOT translate it as a jax.lax.scan with
+dynamic slicing like attn[..., i, :i]. Dynamic slice sizes are NOT compatible
+with jax.jit because JAX requires static shapes at trace time.
+ +INSTEAD, use jax.scipy.linalg.solve_triangular to compute (I - W)^{-1} +directly. This is mathematically equivalent to the Neumann series +I + W + W^2 + ... but is JIT-safe, GPU-parallelizable, and numerically stable. + +## The PyTorch Pattern (for-loop, do NOT copy directly): + + # PyTorch: row-by-row Neumann series (CANNOT run under jax.jit) + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + \\ + (attn[..., i, :i, None] * attn[..., :i, :i]).sum(-2) + attn = attn + torch.eye(chunk_size) + +## The Correct JAX Pattern (solve_triangular): + + import jax + import jax.numpy as jnp + + # raw_attn is strictly lower triangular: -(k_beta @ key^T) * decay_mask + # with upper triangle and diagonal zeroed out + upper_mask = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=0) + raw_attn = -(k_beta @ jnp.transpose(key, (0, 1, 2, 4, 3))) * decay_mask + raw_attn = jnp.where(upper_mask, 0.0, raw_attn) + + # Compute (I - W)^{-1} using solve_triangular + # This solves (I - W) @ X = I, giving X = (I - W)^{-1} + eye = jnp.eye(chunk_size) + attn = jax.scipy.linalg.solve_triangular( + eye - raw_attn, # unit lower triangular matrix + eye, # solve for identity -> gives the inverse + lower=True, # it's lower triangular + ) + + # Then apply the WY transform: + value_corrected = attn @ v_beta + k_cumdecay = attn @ (k_beta * jnp.exp(g_cumsum)[..., None]) + +## Why solve_triangular works: + +The for-loop computes the Neumann series I + W + W^2 + ... which equals +(I - W)^{-1} for strictly lower triangular W. 
solve_triangular computes +this directly via back-substitution, which is: +- O(n^2) per row, same complexity as the for-loop +- JIT-compatible (no dynamic shapes) +- GPU-parallelizable (LAPACK/cuSOLVER backend) +- Numerically stable + +## Inter-chunk scan pattern: + +After computing the WY correction within each chunk, use jax.lax.scan +across chunks to accumulate the recurrent state: + + def chunk_scan_fn(S_prev, chunk_inputs): + q_c, k_c, v_c, k_cumdec_c, g_c, decay_c = chunk_inputs + + # Intra-chunk attention + intra_attn = (q_c @ jnp.transpose(k_c, (0, 1, 3, 2))) * decay_c + intra_attn = jnp.where(upper_mask_strict, 0.0, intra_attn) + + # Inter-chunk: project through accumulated state + v_prime = k_cumdec_c @ S_prev + v_new = v_c - v_prime + attn_inter = (q_c * jnp.exp(g_c)[..., None]) @ S_prev + + # Combine + out_c = attn_inter + intra_attn @ v_new + + # Update state + g_last = g_c[..., -1, None, None] + k_weighted = k_c * jnp.exp(g_c[..., -1:] - g_c)[..., None] + S_next = S_prev * jnp.exp(g_last) + jnp.transpose(k_weighted, (0, 1, 3, 2)) @ v_new + + return S_next, out_c + + final_state, core_attn_out = jax.lax.scan(chunk_scan_fn, init_S, scan_inputs) +""" diff --git a/MaxCode/rag/vector_db.py b/MaxCode/rag/vector_db.py index 6efed09..20504bc 100644 --- a/MaxCode/rag/vector_db.py +++ b/MaxCode/rag/vector_db.py @@ -3,11 +3,25 @@ import os import pickle import sqlite3 +from typing import Optional import numpy as np RAG_DB_FILE = os.path.join(os.environ["HOME"], "rag_store.db") +def _ensure_corpus_column(cur: sqlite3.Cursor) -> None: + """Adds the `corpus` column to existing databases that pre-date it. + + Existing rows are tagged 'jax' (the original behaviour). 
+ """ + cur.execute("PRAGMA table_info(documents)") + columns = {row[1] for row in cur.fetchall()} + if "corpus" not in columns: + cur.execute( + "ALTER TABLE documents ADD COLUMN corpus TEXT NOT NULL DEFAULT 'jax'" + ) + + def create_db(db_path: str = RAG_DB_FILE): """Create the SQLite database and `documents` table if they do not exist. @@ -23,9 +37,11 @@ def create_db(db_path: str = RAG_DB_FILE): text TEXT NOT NULL, desc TEXT NOT NULL, file TEXT NOT NULL, - embedding BLOB NOT NULL + embedding BLOB NOT NULL, + corpus TEXT NOT NULL DEFAULT 'jax' ) """) + _ensure_corpus_column(cur) conn.commit() conn.close() @@ -37,6 +53,7 @@ def save_document( file: str, embedding: np.ndarray, db_path: str = RAG_DB_FILE, + corpus: str = "jax", ): """Insert a document and its embedding into the database. @@ -48,14 +65,16 @@ def save_document( embedding: Dense vector representation of the document with shape (dim,) and dtype convertible to float32. db_path: Path to the SQLite database file. + corpus: Logical corpus tag for filtering (e.g. "jax" or "maxtext"). """ conn = sqlite3.connect(db_path) cur = conn.cursor() + _ensure_corpus_column(cur) emb_binary = pickle.dumps(embedding.astype(np.float32)) cur.execute( - "INSERT INTO documents (name,text,desc,file, embedding) VALUES (?," - " ?,?,?,?)", - (name, text, desc, file, emb_binary), + "INSERT INTO documents (name,text,desc,file, embedding, corpus) VALUES" + " (?, ?,?,?,?,?)", + (name, text, desc, file, emb_binary, corpus), ) conn.commit() conn.close() @@ -63,11 +82,13 @@ def save_document( def load_all_documents( db_path: str = RAG_DB_FILE, + corpus: Optional[str] = None, ) -> tuple[list[int], list[str], list[str], list[str], np.ndarray]: """Load all documents and embeddings from the database. Args: db_path: Path to the SQLite database file. + corpus: Optional corpus tag to filter on. If None, all rows are returned. 
Returns: tuple[list[int], list[str], list[str], list[str], numpy.ndarray]: @@ -79,7 +100,14 @@ def load_all_documents( """ conn = sqlite3.connect(db_path) cur = conn.cursor() - cur.execute("SELECT id,name, text,file, embedding FROM documents") + _ensure_corpus_column(cur) + if corpus is None: + cur.execute("SELECT id,name, text,file, embedding FROM documents") + else: + cur.execute( + "SELECT id,name, text,file, embedding FROM documents WHERE corpus = ?", + (corpus,), + ) rows = cur.fetchall() conn.close() @@ -134,16 +162,18 @@ def search_embedding( def make_embedding_index( db_path: str = RAG_DB_FILE, + corpus: Optional[str] = None, ) -> tuple[list[int], list[str], list[str], list[str], np.ndarray | None]: """Load all documents and return embeddings as index. Args: db_path: Path to the SQLite database file. + corpus: Optional corpus tag to filter on. If None, all rows are loaded. Returns: tuple[list[int], list[str], list[str], list[str], np.ndarray | None]: (ids, names, texts, files, index) """ - ids, names, texts, files, embeddings = load_all_documents(db_path) + ids, names, texts, files, embeddings = load_all_documents(db_path, corpus=corpus) index = build_numpy_index(embeddings) return ids, names, texts, files, index