
Commit 7bb8137

Megatron LoRA correctness: align distributed semantics with Megatron and validate TP/EP/ETP/DP against an oracle (#619)
* megatron: integrate lora grad sync with finalize_model_grads
* megatron: harden sharded lora merge validation
* tests: add megatron lora oracle correctness harness
* Minor typing changes
* megatron: extend LoRA grad-sync semantics across tp/expert-tp
* megatron: add MoE routing replay core and unit tests
* megatron runtime/service: wire routing replay into training jobs
* oracle worker/trace: capture forward traces and emit replay bundles
* oracle harness/tests: refactor suite and add oracle-replay parity flow
* typing: clear blocking ty errors in oracle replay and LoRA paths
* megatron: reduce oracle variance with sequence grad accumulation. Use per-step micro-accumulation over multiple packed sequences so updates are less sensitive to sparse expert token assignment; also make backend progress accounting accumulation-aware.
* megatron lora: fix TP/EP export participation rules. Correct LoRA shard export behavior so non-zero TP ranks in EP/ETP topologies contribute when required, while still filtering replicated-only entries.
* oracle trace: canonicalize MoE outputs across arbitrary topologies. Move normalization logic into ForwardTraceCapture so saved traces are canonicalized toward world-size-1 semantics (expert row identity/order and ETP fc1 layout).
* oracle harness: stabilize scoring and expand sensitivity mutations. Rework oracle pass/fail evaluation with per-phase functions, layer-averaged metrics, deterministic init, expanded sensitivity mutations, and a smaller Adam epsilon for tiny-gradient regimes.
* oracle tests: write suite output tables to log files. Redirect suite stdout/stderr into local correctness/sensitivity logs and make skip/report messaging point to those artifacts instead of terminal output.
* Add correct data parallelism.
* Fix per-token DP normalization in Megatron training
* Expand the oracle harness for DP correctness checks
* Clean up type errors in Megatron correctness changes
* The testing harness was working, but real training surfaced a few errors; most are now fixed.
* Cut over Megatron LoRA to QuACK
* Delete held packed tensors so the directory can be removed, plus small typing changes.
* Fuse LoRA scale into QuACK grouped GEMM
* Avoid grad_out copy in QuACK LoRA backward
* Fuse MoE FC1 gate and up LoRA paths
* Tune QuACK low-rank tiles and rank contract
* Inline FC1 QuACK dual call
* Revert unnecessary Python 3.12 requirement.
* Create LoRA without instantiating the full model by using the meta device; fix routing replay for torch 2.10.0.
* Update Megatron dependencies for the transformers v5 change.
* Update Megatron tests for the new LoRA kernel and average grads across experts for stability.
* Limit max build jobs when building the uv cache.
* Fix CI uv cache build robustness
* Tune CI uv cache build concurrency
* Fix CI Apex cache contract
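The per-token DP normalization fix listed above is easy to illustrate in isolation. The sketch below is not the repo's code; it is a minimal single-process emulation showing why averaging per-rank mean losses is biased when data-parallel ranks hold different numbers of (packed) tokens, while summing losses and token counts globally before dividing once is not.

```python
# Hypothetical single-process emulation of DP loss normalization.
# Each inner list stands for the per-token losses held by one DP rank.

def mean_of_rank_means(rank_losses):
    # Biased when token counts differ: each token is weighted by
    # 1 / (num_ranks * tokens_on_its_rank), not uniformly.
    return sum(sum(l) / len(l) for l in rank_losses) / len(rank_losses)

def per_token_mean(rank_losses):
    # Unbiased: emulate all-reducing the loss sum and the token count
    # across ranks, then dividing once globally.
    total = sum(sum(l) for l in rank_losses)
    count = sum(len(l) for l in rank_losses)
    return total / count

rank_losses = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0]]  # uneven token counts
print(mean_of_rank_means(rank_losses))  # 2.0
print(per_token_mean(rank_losses))      # 1.666...
```

With equal token counts per rank the two agree; with packed sequences of varying length they diverge, which is the kind of skew a per-token fix addresses.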
1 parent dc8c338 commit 7bb8137
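The commit message also mentions fusing the LoRA scale into the QuACK grouped GEMM. The QuACK kernel itself is not shown here; the NumPy sketch below only demonstrates the algebraic identity the fusion relies on: pre-scaling one low-rank factor by alpha/r gives the same output as applying the scale as a separate elementwise pass after the second GEMM, so a fused kernel can drop that pass.

```python
import numpy as np

# Hypothetical shapes for illustration only (tokens x hidden, rank r = 4).
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16))   # activations
A = rng.standard_normal((16, 4))   # LoRA down-projection
B = rng.standard_normal((4, 16))   # LoRA up-projection
alpha, r = 32.0, 4
scale = alpha / r

unfused = ((x @ A) @ B) * scale    # GEMM, GEMM, then a separate scale pass
fused = (x @ A) @ (B * scale)      # scale folded into B ahead of time

assert np.allclose(unfused, fused)
```

Folding the scale into the smaller factor (r x hidden) is cheap and can be done once, whereas the unfused form rescales a tokens x hidden output on every step.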

28 files changed: +10,901 −2,718 lines

.github/workflows/prek.yml

Lines changed: 24 additions & 3 deletions
```diff
@@ -13,6 +13,9 @@ env:
   CI_PYTHON_MM: "3.11"
   CI_UV_CACHE_RELEASE_TAG: "prek-uv-cache"
   CI_UV_CACHE_ASSET_PREFIX: "prek-uv-cache"
+  CI_APEX_PARALLEL_BUILD: "8"
+  CI_APEX_NVCC_THREADS: "1"
+  CI_UV_BUILD_SLOTS: "2"
   UV_CACHE_DIR: "/root/.cache/uv"
   UV_LINK_MODE: "copy"
   TORCH_CUDA_ARCH_LIST: "8.0"
@@ -34,7 +37,9 @@ jobs:
             --pyproject pyproject.toml \
             --uv-lock uv.lock \
             --base-image "${CI_BASE_IMAGE}" \
-            --python-mm "${CI_PYTHON_MM}")"
+            --python-mm "${CI_PYTHON_MM}" \
+            --ci-apex-parallel-build "${CI_APEX_PARALLEL_BUILD}" \
+            --ci-apex-nvcc-threads "${CI_APEX_NVCC_THREADS}")"
           echo "fingerprint=${fp}" >> "${GITHUB_OUTPUT}"
           echo "Expected uv cache fingerprint: ${fp}"
@@ -198,6 +203,13 @@ jobs:
 
       - name: Install dependencies (with all optional extras for complete type checking)
         run: |
+          original_pyproject="$(mktemp)"
+          cp pyproject.toml "${original_pyproject}"
+          cleanup() {
+            mv "${original_pyproject}" pyproject.toml
+          }
+          trap cleanup EXIT
+
           py_mm="$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')"
           cudnn_path="${GITHUB_WORKSPACE}/.venv/lib/python${py_mm}/site-packages/nvidia/cudnn"
           export CUDNN_PATH="${cudnn_path}"
@@ -207,13 +219,22 @@ jobs:
           export CPLUS_INCLUDE_PATH="${CUDNN_INCLUDE_PATH}${CPLUS_INCLUDE_PATH:+:${CPLUS_INCLUDE_PATH}}"
           export LIBRARY_PATH="${CUDNN_LIBRARY_PATH}${LIBRARY_PATH:+:${LIBRARY_PATH}}"
           export LD_LIBRARY_PATH="${CUDNN_LIBRARY_PATH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
+          export UV_CONCURRENT_BUILDS="${CI_UV_BUILD_SLOTS}"
+          export CMAKE_BUILD_PARALLEL_LEVEL="${CI_APEX_PARALLEL_BUILD}"
+          export MAX_JOBS="${CI_APEX_PARALLEL_BUILD}"
+          export NINJAFLAGS="-j${CI_APEX_PARALLEL_BUILD}"
+          python3 scripts/ci/apply_ci_uv_build_overrides.py \
+            --pyproject pyproject.toml \
+            --apex-parallel-build "${CI_APEX_PARALLEL_BUILD}" \
+            --apex-nvcc-threads "${CI_APEX_NVCC_THREADS}"
+          echo "CI uv build overrides: APEX_PARALLEL_BUILD=${CI_APEX_PARALLEL_BUILD}, NVCC_APPEND_FLAGS=--threads ${CI_APEX_NVC_THREADS:-${CI_APEX_NVCC_THREADS}}, UV_CONCURRENT_BUILDS=${CI_UV_BUILD_SLOTS}"
           uv --version
           uv sync --all-extras --group dev --frozen
 
       - name: Run prek hooks (lint, format, typecheck, uv.lock, tests)
         run: |
-          uv run prek run --all-files
+          uv run --no-sync prek run --all-files
 
       - name: Run unit tests (via prek)
         run: |
-          uv run prek run pytest
+          uv run --no-sync prek run pytest
```
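The install step above temporarily rewrites pyproject.toml and relies on mktemp plus an EXIT trap to restore the original no matter how the step ends. The standalone sketch below is not the workflow itself; it uses a stand-in file (`demo.toml`) and a subshell so the trap's effect is visible within one script.

```shell
# Hypothetical demo of the snapshot/mutate/restore pattern from the CI step.
printf 'original\n' > demo.toml        # stand-in for pyproject.toml

(
  backup="$(mktemp)"
  cp demo.toml "${backup}"
  cleanup() { mv "${backup}" demo.toml; }
  trap cleanup EXIT                    # fires when this subshell exits,
                                       # on success or on failure alike

  printf 'ci-override\n' > demo.toml   # temporary CI-only mutation
  cat demo.toml                        # prints: ci-override
)

cat demo.toml                          # prints: original (trap restored it)
```

The same guarantee is why the workflow registers the trap before running the override script: a failed build still leaves pyproject.toml untouched for later cache or checkout steps.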
