https://github.com/google/praxis/blob/43db2717e0d2e9ef09b566f7e7bbad049d63dceb/praxis/layers/gpu_fast_attention.py#L133

https://github.com/google/jax/blame/main/jax/experimental/pallas/ops/attention.py#L163

It seems JAX has added `segment_ids` as a required argument to the Pallas attention op, but Praxis's `gpu_fast_attention.py` (line 133) has not been updated to pass it.
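A minimal sketch of one way Praxis could tolerate both older and newer JAX versions. `call_mha_compat` is a hypothetical helper, not an existing Praxis or JAX API. It assumes that `inspect.signature` can see the real parameters of the installed `mha` (JAX's decorators set `__wrapped__`, but this is worth verifying), and that passing `None` for `segment_ids` means "no segment masking". If Praxis only needs to support the newer JAX, simply passing `None` in that position at the call site would be the smaller fix.

```python
import inspect
from typing import Any, Callable


def call_mha_compat(mha_fn: Callable[..., Any], q, k, v, segment_ids=None, **kwargs):
    """Call a Pallas-style `mha` whether or not it accepts `segment_ids`.

    Hypothetical helper (assumption): newer JAX takes `segment_ids` as the
    fourth positional argument, older JAX has no such parameter at all.
    """
    params = inspect.signature(mha_fn).parameters
    if "segment_ids" in params:
        # Newer signature: always pass it; None is assumed to disable segment masking.
        return mha_fn(q, k, v, segment_ids, **kwargs)
    if segment_ids is not None:
        # Older signature cannot honor a segment mask, so fail loudly.
        raise ValueError("installed mha does not support segment_ids")
    # Older signature: omit the argument entirely.
    return mha_fn(q, k, v, **kwargs)
```

Usage would look like `call_mha_compat(attention.mha, q, k, v, causal=True)`, with `attention` imported from `jax.experimental.pallas.ops`. Checking the signature rather than the JAX version string keeps the shim independent of exactly which release introduced the argument.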