Commit eda06d5
[pmap] Push more into _cached_shard_map.
NOTE: We aren't using all the fields of CachedShardMap yet. Subsequent CLs will make use of them.
Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.
PiperOrigin-RevId: 8619524221 parent 5ff8a51 commit eda06d5
File tree
1 file changed
+1
-0
lines changed- tensorflow_probability/python/distributions
1 file changed
+1
-0
lines changedLines changed: 1 addition & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
75 | 75 | | |
76 | 76 | | |
77 | 77 | | |
| 78 | + | |
78 | 79 | | |
79 | 80 | | |
80 | 81 | | |
| |||
0 commit comments