-
Notifications
You must be signed in to change notification settings - Fork 26
Documentation for JIT support #1145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
454f20d
first draft, need to refine
aseembits93 9fe6ef7
todo write more about the flags in torch/tensorflow and jax
aseembits93 85344f5
keep editing
aseembits93 eb9b3df
add examples
aseembits93 15f4b6d
typos
aseembits93 06f5460
start testing
aseembits93 2a73955
start cleaning up
aseembits93 446fdf9
sample code not needed right now
aseembits93 7b9d09a
mintlify icon
aseembits93 3ab8fbb
almost ready
aseembits93 754eb6c
improve gantt chart
aseembits93 ec3eed6
ready to review
aseembits93 8c66acb
reordering sections
aseembits93 2dcfba6
Merge branch 'main' into jit-docs
aseembits93 d571512
Merge branch 'main' into jit-docs
HeshamHM28 d020da8
Merge branch 'main' into jit-docs
aseembits93 b9cc789
precommit fix
aseembits93 5f74b14
restore version
aseembits93 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,266 @@ | ||
| --- | ||
| title: "Support for Just-in-Time Compilation" | ||
| description: "Learn how Codeflash optimizes code using JIT compilation with Numba, PyTorch, TensorFlow, and JAX" | ||
| icon: "bolt" | ||
| sidebarTitle: "JIT Compilation" | ||
| keywords: ["JIT", "just-in-time", "numba", "pytorch", "tensorflow", "jax", "GPU", "CUDA", "compilation", "performance"] | ||
| --- | ||
|
|
||
| # Support for Just-in-Time Compilation | ||
|
|
||
| Codeflash supports optimizing numerical code using Just-in-Time (JIT) compilation via leveraging JIT compilers from popular frameworks including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**. | ||
|
|
||
| ## Supported JIT Frameworks | ||
|
|
||
| Each framework uses different compilation strategies to accelerate Python code: | ||
|
|
||
| ### Numba (CPU Code) | ||
|
|
||
| Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use: | ||
|
|
||
| - **`@jit`** - General-purpose JIT compilation with optional flags. | ||
| - **`nopython=True`** - Compiles to machine code without falling back to the Python interpreter. | ||
| - **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag. | ||
| - **`cache=True`** - cache compiled function to disk which reduces future runtimes. | ||
| - **`parallel=True`** - Parallelizes code inside loops. | ||
|
|
||
| ### PyTorch | ||
|
|
||
| PyTorch provides JIT compilation through `torch.compile()`, the recommended compilation API introduced in PyTorch 2.0. It uses TorchDynamo to capture Python bytecode and TorchInductor to generate optimized kernels. | ||
|
|
||
| - **`torch.compile()`** - Compiles a function or module for optimized execution. | ||
| - **`mode`** - Controls the compilation strategy: | ||
| - `"default"` - Balanced compilation with moderate optimization. | ||
| - `"reduce-overhead"` - Minimizes Python overhead using CUDA graphs, ideal for small batches. | ||
| - `"max-autotune"` - Spends more time autotuning to find the fastest kernels. | ||
| - **`fullgraph=True`** - Requires the entire function to be captured as a single graph. Raises an error if graph breaks occur, useful for ensuring complete optimization. | ||
| - **`dynamic=True`** - Enables dynamic shape support, allowing the compiled function to handle varying input sizes without recompilation. | ||
|
|
||
| ### TensorFlow | ||
|
|
||
| TensorFlow uses `@tf.function` to compile Python functions into optimized TensorFlow graphs. When combined with XLA (Accelerated Linear Algebra), it can generate highly optimized machine code for both CPU and GPU. | ||
|
|
||
| - **`@tf.function`** - Converts Python functions into TensorFlow graphs for optimized execution. | ||
| - **`jit_compile=True`** - Enables XLA compilation, which performs whole-function optimization including operation fusion, memory layout optimization, and target-specific code generation. | ||
|
|
||
| ### JAX | ||
|
|
||
| JAX uses XLA to JIT compile pure functions into optimized machine code. It emphasizes functional programming patterns and captures side-effect-free operations for optimization. | ||
|
|
||
| - **`@jax.jit`** - JIT compiles functions using XLA with automatic operation fusion. | ||
|
|
||
| ## How Codeflash Optimizes with JIT | ||
|
|
||
| When Codeflash identifies a function that could benefit from JIT compilation, it: | ||
|
|
||
| 1. Rewrites the code in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components. | ||
| 2. Generates appropriate tests that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements. | ||
| 3. Disables JIT compilation while running coverage and tracer to get accurate coverage and trace information. Both of them rely on Python bytecode execution but JIT compiled code stops running as Python bytecode. | ||
| 4. Disables Line Profiler information collection whenever presented with JIT compiled code. It could be possible to disable JIT compilation and run the line profiler, but that would lead to inaccurate information which could misguide the optimization process. | ||
|
|
||
| ## Accurate Benchmarking on Non-CPU devices | ||
|
|
||
| Since Non-CPU operations execute asynchronously, Codeflash automatically inserts synchronization barriers before measuring performance. This ensures timing measurements reflect actual computation time rather than just the time to queue operations: | ||
|
|
||
| - **PyTorch**: Uses `torch.cuda.synchronize()` (NVIDIA GPUs) or `torch.mps.synchronize()` (MacOS Metal Performance Shaders) depending on the device. | ||
| - **JAX**: Uses `jax.block_until_ready()` to wait for computation to complete. | ||
| - **TensorFlow**: Uses `tf.test.experimental.sync_devices()` for device synchronization. | ||
|
|
||
| ## When JIT Compilation Helps | ||
|
|
||
| JIT compilation is most effective for: | ||
|
|
||
| - Numerical computations with loops that can't be easily vectorized. | ||
| - Custom algorithms not covered by existing optimized libraries. | ||
| - Functions that are called repeatedly with consistent input types. | ||
| - Code that benefits from hardware-specific optimizations (SIMD, GPU acceleration). | ||
|
|
||
| ### Example | ||
|
|
||
| #### Function Definition | ||
|
|
||
| ```python | ||
| import torch | ||
| def complex_activation(x): | ||
| """A custom activation with many small operations - compile makes a huge difference""" | ||
| # Many sequential element-wise ops create kernel launch overhead | ||
| x = torch.sin(x) | ||
| x = x * torch.cos(x) | ||
| x = x + torch.exp(-x.abs()) | ||
| x = x / (1 + x.pow(2)) | ||
| x = torch.tanh(x) * torch.sigmoid(x) | ||
| x = x - 0.5 * x.pow(3) | ||
| return x | ||
| ``` | ||
|
|
||
| #### Benchmarking Snippet (replace `cuda` with `mps` to run on your Mac) | ||
|
|
||
| ```python | ||
| import time | ||
| # Create compiled version | ||
| complex_activation_compiled = torch.compile(complex_activation) | ||
|
|
||
| # Benchmark | ||
| x = torch.randn(1000, 1000, device='cuda') | ||
|
|
||
| # Warmup | ||
| for _ in range(10): | ||
| _ = complex_activation(x) | ||
| _ = complex_activation_compiled(x) | ||
|
|
||
| # Time uncompiled | ||
| torch.cuda.synchronize() | ||
| start = time.time() | ||
| for _ in range(100): | ||
| y = complex_activation(x) | ||
| torch.cuda.synchronize() | ||
| uncompiled_time = time.time() - start | ||
|
|
||
| # Time compiled | ||
| torch.cuda.synchronize() | ||
| start = time.time() | ||
| for _ in range(100): | ||
| y = complex_activation_compiled(x) | ||
| torch.cuda.synchronize() | ||
| compiled_time = time.time() - start | ||
|
|
||
| print(f"Uncompiled: {uncompiled_time:.4f}s") | ||
| print(f"Compiled: {compiled_time:.4f}s") | ||
| print(f"Speedup: {uncompiled_time/compiled_time:.2f}x") | ||
| ``` | ||
|
|
||
| Expected Output on CUDA | ||
|
|
||
| ``` | ||
| Uncompiled: 0.0176s | ||
| Compiled: 0.0063s | ||
| Speedup: 2.80x | ||
| ``` | ||
|
|
||
| Here, JIT compilation via `torch.compile` is the only viable option because | ||
| 1. Already vectorized - All operations are already PyTorch tensor ops. | ||
| 2. Multiple Kernel Launches - Uncompiled code launches ~10 separate kernels. torch.compile fuses them into 1-2 kernels, eliminating kernel launch overhead. | ||
| 3. No algorithmic improvement - The computation itself is already optimal. | ||
| 4. Python overhead elimination - Removes Python interpreter overhead between operations. | ||
|
|
||
|
|
||
| ## When JIT Compilation May Not Help | ||
|
|
||
| JIT compilation may not provide speedups when: | ||
|
|
||
| - The code already uses highly optimized libraries (e.g., `NumPy` with `MKL`, `cuBLAS`, `cuDNN`). | ||
| - Functions have variable input types or shapes that prevent effective compilation. | ||
| - The compilation overhead exceeds the runtime savings for short-running functions. | ||
| - The code relies heavily on Python objects or dynamic features that JIT compilers can't optimize. | ||
|
|
||
| ### Example | ||
|
|
||
| #### Function Definition | ||
|
|
||
| ``` | ||
| def adaptive_processing(x, threshold=0.5): | ||
| """Function with data-dependent control flow - compile struggles here""" | ||
| # Check how many values exceed threshold (data-dependent!) | ||
| mask = x > threshold | ||
| num_large = mask.sum().item() # .item() causes graph break | ||
|
|
||
| if num_large > x.numel() * 0.3: | ||
| # Path 1: Many large values - use expensive operation | ||
| result = torch.matmul(x, x.T) # Already optimized by cuBLAS | ||
| result = result.mean(dim=0) | ||
| else: | ||
| # Path 2: Few large values - use cheap operation | ||
| result = x.mean(dim=1) | ||
|
|
||
| return result | ||
| ``` | ||
|
|
||
| #### Benchmarking Snippet (replace `cuda` with `mps` to run on your Mac) | ||
|
|
||
| ``` | ||
| # Create compiled version | ||
| adaptive_processing_compiled = torch.compile(adaptive_processing) | ||
|
|
||
| # Test with data that causes branch variation | ||
| x = torch.randn(500, 500, device='cuda') | ||
|
|
||
| # Warmup | ||
| for _ in range(10): | ||
| _ = adaptive_processing(x) | ||
| _ = adaptive_processing_compiled(x) | ||
|
|
||
| # Benchmark with varying data (causes recompilation) | ||
| torch.cuda.synchronize() | ||
| start = time.time() | ||
| for i in range(100): | ||
| # Vary the data to trigger different branches | ||
| x_test = torch.randn(500, 500, device='cuda') + (i % 2) | ||
| y = adaptive_processing(x_test) | ||
| torch.cuda.synchronize() | ||
| uncompiled_time = time.time() - start | ||
|
|
||
| torch.cuda.synchronize() | ||
| start = time.time() | ||
| for i in range(100): | ||
| x_test = torch.randn(500, 500, device='cuda') + (i % 2) | ||
| y = adaptive_processing_compiled(x_test) # Recompiles frequently! | ||
| torch.cuda.synchronize() | ||
| compiled_time = time.time() - start | ||
|
|
||
| print(f"Uncompiled: {uncompiled_time:.4f}s") | ||
| print(f"Compiled: {compiled_time:.4f}s") | ||
| print(f"Slowdown: {compiled_time/uncompiled_time:.2f}x") | ||
| ``` | ||
|
|
||
| Expected Output on CUDA | ||
|
|
||
| ``` | ||
| Uncompiled: 0.0296s | ||
| Compiled: 0.2847s | ||
| Slowdown: 9.63x | ||
| ``` | ||
|
|
||
| Why `torch.compile` is detrimental here: | ||
|
|
||
| 1. Graph breaks - `.item()` forces a graph break, negating compile benefits. | ||
| 2. Recompilation overhead - Different branches cause expensive recompilation each time. | ||
| 3. Dynamic control flow - Data-dependent conditionals can't be optimized away. | ||
| 4. Already optimized ops - `matmul` already uses `cuBLAS`; compile adds overhead without benefit. | ||
|
|
||
| #### Better Optimization Strategy | ||
|
|
||
| ```python | ||
| def optimized_version(x, threshold=0.5): | ||
| """Remove data-dependent control flow - vectorize instead""" | ||
| mask = (x > threshold).float() | ||
| weight = (mask.mean() > 0.3).float() # Keep on GPU | ||
|
|
||
| # Compute both paths, blend based on weight (branchless) | ||
| expensive = torch.matmul(x, x.T).mean(dim=0) | ||
| cheap = x.mean(dim=1).squeeze() | ||
|
|
||
| # Pad cheap result to match expensive dimensions | ||
| cheap_padded = cheap.expand(expensive.shape[0]) | ||
|
|
||
| result = weight * expensive + (1 - weight) * cheap_padded | ||
| return result | ||
| ``` | ||
|
|
||
| Expected Output on CUDA | ||
|
|
||
| ``` | ||
| Optimized: 0.0277s | ||
| Speedup compared to Uncompiled: 1.57x | ||
| ``` | ||
|
|
||
|
|
||
| Key improvements: | ||
|
|
||
| 1. Eliminate `.item()` - Keep computation on GPU. | ||
| 2. Branchless execution - Compute both paths, blend results. | ||
| 3. Vectorization - Replace conditionals with masked operations. | ||
| 4. Reduce Python overhead - Minimize host-device synchronization. | ||
|
|
||
| ## Configuration | ||
|
aseembits93 marked this conversation as resolved.
Outdated
|
||
|
|
||
| JIT compilation support is **enabled automatically** in Codeflash. You don't need to modify any configuration to enable JIT-based optimizations. Codeflash will automatically detect when JIT compilation could improve performance and suggest appropriate optimizations. | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.