Integrate torchax custom attention kernel into ulysses#392
copybara-service[bot] merged 1 commit into main
bq = 2048
bkv = 2048
bkv_compute = 1024
bkv_compute_in = 256
heads_per_tile = 1
Updating this to
bq = 4864
bkv = 1024
bkv_compute = 1024
bkv_compute_in = 1024
heads_per_tile = 1
and running with this command gave me the following latencies:
Load (checkpoint): 297.0s
Compile: 219.8s
───────────────────────────────
Inference: 147.4s
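For comparing the two tilings side by side, here is a minimal sketch that holds the values from this thread in a small config object. The class name `BlockSizes`, the constant names, and the divisibility checks are hypothetical. The checks assume each inner tile evenly divides the one enclosing it, which the field names suggest but the kernel in this PR may not require.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockSizes:
    """Tile sizes for the custom attention kernel (field names taken from this PR)."""

    bq: int
    bkv: int
    bkv_compute: int
    bkv_compute_in: int
    heads_per_tile: int = 1

    def __post_init__(self):
        # ASSUMPTION: inner tiles evenly divide the enclosing tile.
        # The real kernel may enforce different (or no) constraints.
        if self.bkv % self.bkv_compute != 0:
            raise ValueError("bkv must be a multiple of bkv_compute")
        if self.bkv_compute % self.bkv_compute_in != 0:
            raise ValueError("bkv_compute must be a multiple of bkv_compute_in")


# Values in the diff under review.
DIFF_VALUES = BlockSizes(bq=2048, bkv=2048, bkv_compute=1024, bkv_compute_in=256)

# Values suggested in the review comment above.
REVIEW_VALUES = BlockSizes(bq=4864, bkv=1024, bkv_compute=1024, bkv_compute_in=1024)
```

Both value sets pass the assumed checks, so the sketch only guards against typos when trying further tilings.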
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
56c76b8 to daf4a31
Updated the PR based on @Perseus14's comments: passing Also moved the custom kernel file under Updated the profiling code to log perf with more granularity:
daf4a31 to 3f0f13d
Updated the perf logging code; it was not logging each component's time correctly. In New run: Also added the two new flags from @rishabhmanoj, which give us about 1 sec of gain.
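For readers who want to reproduce per-component timings, here is a minimal sketch of a stage timer. It is not the logging code from this PR: `StageTimer`, `stage`, and the `sync` callback are hypothetical names. The `sync` hook is there because one common source of misattributed times (which may or may not be what happened here) is asynchronous dispatch, where a stage returns before the device has finished; on a JAX backend the callback could wrap something like `jax.block_until_ready(output)`.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StageTimer:
    """Accumulates wall-clock seconds per named stage (hypothetical helper)."""

    def __init__(self):
        self.totals = defaultdict(float)

    @contextmanager
    def stage(self, name, sync=None):
        start = time.perf_counter()
        try:
            yield
        finally:
            # If the stage dispatches work asynchronously, wait for it to
            # finish before stopping the clock so the time lands on this stage.
            if sync is not None:
                sync()
            self.totals[name] += time.perf_counter() - start

    def report(self):
        # One "name: seconds" line per stage, in first-use order.
        return "\n".join(f"{name}: {secs:.1f}s" for name, secs in self.totals.items())


if __name__ == "__main__":
    timer = StageTimer()
    with timer.stage("Compile"):
        time.sleep(0.05)  # stand-in for real work
    with timer.stage("Inference"):
        time.sleep(0.05)  # stand-in for real work
    print(timer.report())
```

Entering the same stage name more than once adds to its total, which makes it usable inside a denoising or decoding loop.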
3f0f13d to 82033df
82033df to 589c3d5
Adds the torchax path's custom kernel to ulysses (triggered when attention=ulysses_custom).
Inference time: