 
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
+import numpy as onp
+
+from tensor2tensor.trax import backend
 from tensor2tensor.trax import layers as tl
+from tensor2tensor.trax.backend import numpy as np
 from tensor2tensor.trax.models.research import reformer
 
 
+class PoisonOnRNGMismatchAttention(tl.BaseCausalAttention):
+  """Fills gradients with NaNs if reverse rng does not match forward rng."""
+
+  # pylint: disable=protected-access
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
+    assert backend.get_name() == 'jax', (
+        'JAX backend is required to use forward_and_backward.')
+
+    if ct is not None and tl.Layer._STASH_OUT is not None:
+      recovered_rng = tl.Layer._STASH_OUT.pop(self)
+      is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
+      is_same = is_same.astype(np.float32)
+      # Divides by zero if rngs are not the same, which results in NaNs.
+      inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)
+
+    def _do_forward(x):  # pylint: disable=invalid-name
+      res, _ = self.forward(x, rng=rng, **kwargs)
+      return res
+    output, vjpfun = jax.vjp(_do_forward, inputs)
+    return output, vjpfun(ct)[0]
+
+  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
+    # Stash the forward rng so forward_and_backward can check it later.
+    if tl.Layer._STASH_IN is not None:
+      tl.Layer._STASH_IN[self] = rng
+    return inputs[2], state
+  # pylint: enable=protected-access
+
+
 class ReformerTest(parameterized.TestCase):
 
   def test_reformer_lm_forward_shape(self):
@@ -39,6 +72,33 @@ def test_reformer_lm_forward_shape(self):
         model, tuple(input_shape), integer_inputs=True)
     self.assertEqual(((1, 8, 16), (1, 8, 16)), final_shape)
 
+  def test_reformer_rng_consistency(self):
+    with backend.use_backend('jax'):
+      vocab_size = 16
+      batch_size = 1
+      input_shape = ((batch_size, 8), (batch_size, 8))
+      model = reformer.ReformerLM(
+          vocab_size, d_model=32, d_ff=64,
+          d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
+          max_len=16, n_chunks=2, n_attention_chunks=1, mode='train',
+          attention_type=PoisonOnRNGMismatchAttention)
+
+      rng = backend.random.get_prng(0)
+      params, state = model.initialize_once(
+          input_shape, (np.int32, np.int32), rng)
+
+      def dummy_loss_fn(params):
+        inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
+        output = model(inputs, params=params, state=state, rng=rng)
+        dummy_loss = backend.numpy.sum(output[0])
+        return dummy_loss
+
+      grad_fn = backend.grad(dummy_loss_fn)
+      grads = grad_fn(params)
+      # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
+      for grad in jax.tree_util.tree_leaves(grads):
+        assert onp.all(onp.isfinite(grad))
+
 
 if __name__ == '__main__':
   absltest.main()