-
Notifications
You must be signed in to change notification settings - Fork 527
Add distillation launchers for qwen3-30b-a3b-base and gpt-oss-20b #4028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,69 @@ | ||||||
| #!/bin/bash | ||||||
| # Launch gpt-oss-20b distillation on TPU v7x. | ||||||
| # Usage: bash scripts/distillation/distill_gpt_oss_20b.sh [submit|monitor|resume_until_done] | ||||||
| # | ||||||
| # Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own | ||||||
| # image before running — their defaults below (gs://YOUR-BUCKET and a | ||||||
| # project-specific image tag) will not work as-is. Everything else has a | ||||||
| # working default. Example: | ||||||
| # | ||||||
| # DISTILL_GCS_BUCKET=gs://your-bucket \ | ||||||
| # XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \ | ||||||
| # bash scripts/distillation/distill_gpt_oss_20b.sh submit | ||||||
| set -euo pipefail | ||||||
|
|
||||||
| MODE="${1:-submit}" | ||||||
| REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd) | ||||||
| cd "$REPO_ROOT" | ||||||
|
|
||||||
| # Your GCS bucket: run outputs, staged YAML, and tokenizer files all land here. | ||||||
| export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}" | ||||||
|
|
||||||
| export XPK_WORKLOAD="${XPK_WORKLOAD:-goss-base-$(date +%Y%m%d-%H%M)}" | ||||||
| export XPK_RUN_NAME="${XPK_RUN_NAME:-gpt_oss_20b_base}" | ||||||
| export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}" | ||||||
| export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}" | ||||||
| export XPK_ZONE="${XPK_ZONE:-us-central1}" | ||||||
| export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" | ||||||
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}" | ||||||
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" | ||||||
| export XPK_PRIORITY="${XPK_PRIORITY:-high}" | ||||||
|
|
||||||
| export XPK_USE_GCSFUSE=1 | ||||||
| export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}" | ||||||
| export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}" | ||||||
|
|
||||||
| # Stage HF tokenizer files (not in the image for gpt-oss). | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🟡 Using a specific user's bucket as a default for `XPK_YAML_GCS` will cause the `submit` mode to fail for any other user due to lack of write permissions. Consider using a more generic placeholder or documenting this as a mandatory override.
Suggested change
|
||||||
| export XPK_TOKENIZER_GCS="${XPK_TOKENIZER_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/tokenizer-gpt-oss-20b/}" | ||||||
| export XPK_TOKENIZER_LOCAL="${XPK_TOKENIZER_LOCAL:-/deps/src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer}" | ||||||
|
|
||||||
| LOCAL_YAML="src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml" | ||||||
| export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}" | ||||||
| export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_gpt_oss_20b.yml}" | ||||||
|
|
||||||
| # DISTILL_BETA defaults to 0 (distill_beta=0): decoder feature loss is broken on gpt-oss. | ||||||
| export DISTILL_ALPHA="${DISTILL_ALPHA:-0.5}" | ||||||
| export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" | ||||||
| export DISTILL_BETA="${DISTILL_BETA:-0}" | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🟢 For consistency with the `qwen3` script and the default in `run_distill_xpk.sh`, consider setting `--xla_tpu_scoped_vmem_limit_kib` to `61440` (60 MiB) unless `65536` (64 MiB) was specifically found to be necessary for `gpt-oss-20b`.
Suggested change
|
||||||
| export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[]}" | ||||||
|
|
||||||
| # XLA flags tuned for ~17% MFU. sparse_core_collective_aggregator is required | ||||||
| # by latency_hiding_layer_scheduler. | ||||||
| export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=65536 \ | ||||||
| --xla_tpu_impure_enable_packed_bf16_math_ops=true \ | ||||||
| --xla_tpu_aggressive_opt_barrier_removal=true \ | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🟡 For consistency with other scripts in the repository (e.g., `distill_qwen3_30b_base.sh` and `run_distill_xpk.sh`) and the default XLA flags defined in `benchmarks/xla_flags_library.py`, consider using `ENABLED` instead of `true` for `--xla_tpu_aggressive_opt_barrier_removal`.
Suggested change
|
||||||
| --xla_tpu_enable_sparse_core_collective_aggregator=true \ | ||||||
| --xla_tpu_enable_latency_hiding_layer_scheduler=true \ | ||||||
| --xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \ | ||||||
| --xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=true \ | ||||||
| --xla_tpu_scheduler_percent_shared_memory_limit=150 \ | ||||||
| --xla_enable_async_all_gather=true \ | ||||||
| --xla_tpu_prefer_async_allgather_to_allreduce=true \ | ||||||
| --xla_max_concurrent_async_all_gathers=2 \ | ||||||
| --xla_max_concurrent_async_reduce_scatters=2 \ | ||||||
| --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false}" | ||||||
|
|
||||||
| if [ "$MODE" = "submit" ]; then | ||||||
| gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS" | ||||||
| fi | ||||||
|
|
||||||
| exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE" | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,62 @@ | ||||||
| #!/bin/bash | ||||||
| # Launch qwen3-30b-a3b-base distillation on TPU v7x. | ||||||
| # Usage: bash scripts/distillation/distill_qwen3_30b_base.sh [submit|monitor|resume_until_done] | ||||||
| # | ||||||
| # Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own | ||||||
| # image before running — their defaults below (gs://YOUR-BUCKET and a | ||||||
| # project-specific image tag) will not work as-is. Everything else has a | ||||||
| # working default. Example: | ||||||
| # | ||||||
| # DISTILL_GCS_BUCKET=gs://your-bucket \ | ||||||
| # XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \ | ||||||
| # bash scripts/distillation/distill_qwen3_30b_base.sh submit | ||||||
| set -euo pipefail | ||||||
|
|
||||||
| MODE="${1:-submit}" | ||||||
| REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd) | ||||||
| cd "$REPO_ROOT" | ||||||
|
|
||||||
| # Your GCS bucket: run outputs and staged YAML land here. | ||||||
| export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}" | ||||||
|
|
||||||
| export XPK_WORKLOAD="${XPK_WORKLOAD:-q30b-base-$(date +%Y%m%d-%H%M)}" | ||||||
| export XPK_RUN_NAME="${XPK_RUN_NAME:-qwen3_30b_base}" | ||||||
| export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}" | ||||||
| export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}" | ||||||
| export XPK_ZONE="${XPK_ZONE:-us-central1}" | ||||||
| export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" | ||||||
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}" | ||||||
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🟢 Consider using a more generic or stable base image reference for the demo default.
Suggested change
|
||||||
| export XPK_PRIORITY="${XPK_PRIORITY:-high}" | ||||||
|
|
||||||
| export XPK_USE_GCSFUSE=1 | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🟡 Similar to the GPT-OSS script, the default `XPK_YAML_GCS` points to a specific user's bucket, which will prevent other users from using the `submit` mode out of the box.
Suggested change
|
||||||
| export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}" | ||||||
| export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}" | ||||||
|
|
||||||
| LOCAL_YAML="src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml" | ||||||
| export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}" | ||||||
| export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_qwen3_30b_base.yml}" | ||||||
|
|
||||||
| export DISTILL_ALPHA="${DISTILL_ALPHA:-0.6}" | ||||||
| export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" | ||||||
| export DISTILL_BETA="${DISTILL_BETA:-1.0}" | ||||||
| export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[0,1,2,3,4,5,6,7]}" | ||||||
|
|
||||||
| # XLA flags tuned for ~20% MFU. | ||||||
| export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=61440 \ | ||||||
| --xla_tpu_enable_all_experimental_scheduler_features=true \ | ||||||
| --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ | ||||||
| --xla_tpu_host_transfer_overlap_limit=24 \ | ||||||
| --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ | ||||||
| --xla_lhs_prioritize_async_depth_over_stall=ENABLED \ | ||||||
| --xla_tpu_enable_ag_backward_pipelining=true \ | ||||||
| --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ | ||||||
| --xla_should_add_loop_invariant_op_in_chain=ENABLED \ | ||||||
| --xla_max_concurrent_host_send_recv=100 \ | ||||||
| --xla_tpu_scheduler_percent_shared_memory_limit=100 \ | ||||||
| --xla_latency_hiding_scheduler_rerun=2}" | ||||||
|
|
||||||
| if [ "$MODE" = "submit" ]; then | ||||||
| gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS" | ||||||
| fi | ||||||
|
|
||||||
| exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE" | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # gpt-oss-20b distillation. ~17% MFU on v7x. | ||
| base_config: "base.yml" | ||
|
|
||
| # NOTE: load_parameters_path values are placeholders. Replace them with paths | ||
| # to your own student and teacher checkpoints (Orbax format, ending in /0/items). | ||
| student_overrides: | ||
| model_name: "gpt-oss-20b" | ||
| override_model_config: True | ||
| base_num_query_heads: 32 | ||
| head_dim: 128 | ||
| base_num_kv_heads: 4 | ||
| load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/student/0/items" | ||
| teacher_overrides: | ||
| model_name: "gpt-oss-20b" | ||
| load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/teacher/0/items" | ||
|
|
||
| # distill_beta=0: decoder feature loss is broken on gpt-oss. | ||
| distill_alpha: 0.5 | ||
| distill_temperature: 1.0 | ||
| distill_beta: 0 | ||
| distill_layer_indices: [] | ||
| enable_nnx: True | ||
| load_balance_loss_weight: 0.001 | ||
|
gagika marked this conversation as resolved.
|
||
|
|
||
| ici_fsdp_parallelism: 32 | ||
| ici_data_parallelism: 4 | ||
|
|
||
| dataset_type: "grain" | ||
| grain_file_type: "arrayrecord" | ||
| # Launcher's gcsfuse override (CLI arg) replaces this for the student config. | ||
| grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" | ||
| grain_worker_count: 16 | ||
| grain_ram_budget_mb: 1024 | ||
| grain_per_worker_buffer_size: 2 | ||
| grain_prefetch_buffer_size: 1024 | ||
| num_epoch: 10 | ||
|
|
||
| tokenizer_path: "src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer" | ||
| tokenizer_type: "huggingface" | ||
|
|
||
| max_target_length: 32768 | ||
| per_device_batch_size: 1 | ||
| gradient_accumulation_steps: 1 | ||
|
|
||
| steps: 100000 | ||
| learning_rate_schedule_steps: 100000 | ||
| log_period: 10 | ||
| checkpoint_period: 500 | ||
| save_checkpoint_on_completion: True | ||
| enable_checkpointing: True | ||
| async_checkpointing: True | ||
| skip_jax_distributed_system: False | ||
|
|
||
| learning_rate: 5.0e-5 | ||
| learning_rate_final_fraction: 0.2 | ||
| warmup_steps_fraction: 0.02 | ||
| adam_b1: 0.9 | ||
| adam_b2: 0.95 | ||
| adam_eps: 1.e-5 | ||
| adam_weight_decay: 0.01 | ||
| adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias'] | ||
| z_loss_multiplier: 1.0e-5 | ||
| float32_logits: True | ||
|
|
||
| # Tokamax splash attention. Layout options: HEAD_DIM_MINOR or SEQ_MINOR. | ||
| attention: "flash" | ||
| use_tokamax_splash: True | ||
| sa_use_fused_bwd_kernel: True | ||
| sa_block_q: 2048 | ||
| sa_block_kv: 2048 | ||
| sa_block_kv_compute: 2048 | ||
| sa_block_q_dkv: 2048 | ||
| sa_block_kv_dkv: 2048 | ||
| sa_block_kv_dkv_compute: 2048 | ||
| sa_block_q_dq: 2048 | ||
| sa_block_kv_dq: 2048 | ||
| sa_q_layout: "SEQ_MINOR" | ||
| sa_k_layout: "SEQ_MINOR" | ||
| sa_v_layout: "SEQ_MINOR" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # Qwen3-30b-a3b-base distillation. ~20% MFU on v7x. | ||
| base_config: "base.yml" | ||
|
|
||
| # NOTE: load_parameters_path values are placeholders. Replace them with paths | ||
| # to your own student and teacher checkpoints (Orbax format, ending in /0/items). | ||
| student_overrides: | ||
| model_name: "qwen3-30b-a3b-base" | ||
| override_model_config: True | ||
| base_num_query_heads: 16 | ||
| head_dim: 256 | ||
| base_num_kv_heads: 2 | ||
| rope_max_timescale: 1_000_000 | ||
| load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items" | ||
| teacher_overrides: | ||
| override_model_config: True | ||
| model_name: "qwen3-30b-a3b-base" | ||
| rope_max_timescale: 1_000_000 | ||
| load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items" | ||
|
|
||
| distill_alpha: 0.6 | ||
| distill_temperature: 1.0 | ||
| distill_beta: 1.0 | ||
| distill_layer_indices: [0,1,2,3,4,5,6,7] | ||
| enable_nnx: True | ||
| load_balance_loss_weight: 0.001 | ||
|
|
||
| ici_fsdp_parallelism: -1 | ||
| ici_data_parallelism: 4 | ||
|
|
||
| dataset_type: "grain" | ||
| grain_file_type: "arrayrecord" | ||
| # Launcher's gcsfuse override (CLI arg) replaces this for the student config. | ||
| grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" | ||
| grain_worker_count: 16 | ||
| grain_ram_budget_mb: 1024 | ||
| grain_per_worker_buffer_size: 2 | ||
| grain_prefetch_buffer_size: 1024 | ||
| num_epoch: 10 | ||
|
|
||
| tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer" | ||
| tokenizer_type: "huggingface" | ||
|
|
||
| max_target_length: 8192 | ||
| per_device_batch_size: 4 | ||
| gradient_accumulation_steps: 1 | ||
|
|
||
| steps: 100000 | ||
| learning_rate_schedule_steps: 100000 | ||
| log_period: 10 | ||
| checkpoint_period: 500 | ||
| save_checkpoint_on_completion: True | ||
| enable_checkpointing: True | ||
| async_checkpointing: True | ||
| skip_jax_distributed_system: False | ||
|
|
||
| learning_rate: 1.0e-4 | ||
| learning_rate_final_fraction: 0.1 | ||
| warmup_steps_fraction: 0.08 | ||
| adam_b1: 0.9 | ||
| adam_b2: 0.95 | ||
| adam_eps: 1.e-5 | ||
| adam_weight_decay: 0.01 | ||
| adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias'] | ||
| z_loss_multiplier: 1.0e-5 | ||
| float32_logits: True | ||
|
|
||
| # Tokamax splash attention. Layout options: HEAD_DIM_MINOR or SEQ_MINOR. | ||
| attention: "flash" | ||
| use_tokamax_splash: True | ||
| sa_use_fused_bwd_kernel: True | ||
| sa_block_q: 1024 | ||
| sa_block_kv: 1024 | ||
| sa_block_kv_compute: 512 | ||
| sa_block_q_dkv: 2048 | ||
| sa_block_kv_dkv: 2048 | ||
| sa_block_kv_dkv_compute: 1024 | ||
| sa_block_q_dq: 1024 | ||
| sa_block_kv_dq: 1024 | ||
| sa_q_layout: "HEAD_DIM_MINOR" | ||
| sa_k_layout: "SEQ_MINOR" | ||
| sa_v_layout: "HEAD_DIM_MINOR" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # Qwen3-30b-a3b-base distillation, pdbs=8 + activation offload. ~22% MFU on v7x. | ||
| base_config: "base.yml" | ||
|
|
||
| # NOTE: load_parameters_path values are placeholders. Replace them with paths | ||
| # to your own student and teacher checkpoints (Orbax format, ending in /0/items). | ||
| student_overrides: | ||
| model_name: "qwen3-30b-a3b-base" | ||
| override_model_config: True | ||
| base_num_query_heads: 16 | ||
| head_dim: 256 | ||
| base_num_kv_heads: 2 | ||
| rope_max_timescale: 1_000_000 | ||
| load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items" | ||
| teacher_overrides: | ||
| override_model_config: True | ||
| model_name: "qwen3-30b-a3b-base" | ||
| rope_max_timescale: 1_000_000 | ||
| load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items" | ||
|
|
||
| distill_alpha: 0.6 | ||
| distill_temperature: 1.0 | ||
| distill_beta: 1.0 | ||
| distill_layer_indices: [0,1,2,3,4,5,6,7] | ||
| enable_nnx: True | ||
| load_balance_loss_weight: 0.001 | ||
|
|
||
| ici_fsdp_parallelism: -1 | ||
| ici_data_parallelism: 4 | ||
|
|
||
| dataset_type: "grain" | ||
| grain_file_type: "arrayrecord" | ||
| # Launcher's gcsfuse override (CLI arg) replaces this for the student config. | ||
| grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" | ||
| grain_worker_count: 16 | ||
| grain_ram_budget_mb: 1024 | ||
| grain_per_worker_buffer_size: 2 | ||
| grain_prefetch_buffer_size: 1024 | ||
| num_epoch: 10 | ||
|
|
||
| tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer" | ||
| tokenizer_type: "huggingface" | ||
|
|
||
| max_target_length: 8192 | ||
| per_device_batch_size: 8 | ||
| gradient_accumulation_steps: 1 | ||
|
|
||
| # Activation offload to fit pdbs=8. | ||
| remat_policy: 'custom' | ||
| decoder_layer_input: 'offload' | ||
|
|
||
| steps: 100000 | ||
| learning_rate_schedule_steps: 100000 | ||
| log_period: 10 | ||
| checkpoint_period: 500 | ||
| save_checkpoint_on_completion: True | ||
| enable_checkpointing: True | ||
| async_checkpointing: True | ||
| skip_jax_distributed_system: False | ||
|
|
||
| learning_rate: 1.0e-4 | ||
| learning_rate_final_fraction: 0.1 | ||
| warmup_steps_fraction: 0.08 | ||
| adam_b1: 0.9 | ||
| adam_b2: 0.95 | ||
| adam_eps: 1.e-5 | ||
| adam_weight_decay: 0.01 | ||
| adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias'] | ||
| z_loss_multiplier: 1.0e-5 | ||
| float32_logits: True | ||
|
|
||
| # Tokamax splash attention. Layout options: HEAD_DIM_MINOR or SEQ_MINOR. | ||
| attention: "flash" | ||
| use_tokamax_splash: True | ||
| sa_use_fused_bwd_kernel: True | ||
| sa_block_q: 1024 | ||
| sa_block_kv: 1024 | ||
| sa_block_kv_compute: 512 | ||
| sa_block_q_dkv: 2048 | ||
| sa_block_kv_dkv: 2048 | ||
| sa_block_kv_dkv_compute: 1024 | ||
| sa_block_q_dq: 1024 | ||
| sa_block_kv_dq: 1024 | ||
| sa_q_layout: "HEAD_DIM_MINOR" | ||
| sa_k_layout: "SEQ_MINOR" | ||
| sa_v_layout: "HEAD_DIM_MINOR" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.