Skip to content

Commit eda06d5

Browse files
danielsuotensorflower-gardener
authored andcommitted
[pmap] Push more into _cached_shard_map.
NOTE: We aren't using all the fields of CachedShardMap yet. Subsequent CLs will make use of them. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 861952422
1 parent 5ff8a51 commit eda06d5

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

tensorflow_probability/python/distributions/jax_transformation_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
PMAP_SAMPLE_BLOCKLIST = (
7676
'Bates',
7777
'BatchReshape', # http://b/163171224
78+
'IncrementLogProb', # pmap_shmap_merge sharding conflict
7879
'MixtureSameFamily', # Too slow: http://b/170871051
7980
'NegativeBinomial', # Too slow: http://b/170871051
8081
'ZeroInflatedNegativeBinomial', # Too slow: http://b/170871051

0 commit comments

Comments
 (0)