@@ -883,6 +883,10 @@ class TransformerLayer(nn.Module):
883883 Dimensions that will share the same dropout mask for hidden
884884 attention_dropout: float, default = 0.1
885885 Dropout probability for the dropout op during multi-head attention.
886+ intermediate_dropout: float, default = 0.1
887+ Dropout probability for the dropout op after FC1 layer.
888+ intermediate_dropout_dims: Sequence[int], default = ()
889+ Dimensions that will share the same dropout mask for hidden after FC1 layer.
886890 dropout_rng_name: str, default = 'dropout'
887891 The key in given RNGs via flax.linen.Module.apply that for
888892 generating Dropout masks in the Multi-Head Attention.
@@ -963,6 +967,8 @@ class TransformerLayer(nn.Module):
963967 hidden_dropout : float = 0.1
964968 hidden_dropout_dims : Sequence [int ] = ()
965969 attention_dropout : float = 0.1
970+ intermediate_dropout : float = 0.1
971+ intermediate_dropout_dims : Sequence [int ] = ()
966972 dropout_rng_name : str = 'dropout'
967973 mha_kernel_init : Initializer = None
968974 mlp_kernel_init : Initializer = None
@@ -1078,6 +1084,8 @@ def __call__(self,
10781084 else :
10791085 mha_name = 'self_attention'
10801086
1087+ inputs = _with_sharding_constraint (inputs , (BATCH_AXES , SEQLEN_AXES , HIDDEN_AXES ))
1088+
10811089 # [batch, length, emb_dim] -> [batch, length, emb_dim]
10821090 x , residual = MultiHeadAttention (
10831091 num_heads = self .num_attention_heads ,
@@ -1113,14 +1121,15 @@ def hidden_dropout(x, deterministic):
11131121 assert - x_shape_len <= dims < x_shape_len
11141122
11151123 return nn .Dropout (rate = self .hidden_dropout ,
1116- broadcast_dims = self .hidden_dropout_dims )( x ,
1117- deterministic = deterministic )
1124+ broadcast_dims = self .hidden_dropout_dims ,
1125+ rng_collection = self . dropout_rng_name )( x , deterministic = deterministic )
11181126
11191127 x = hidden_dropout (x , deterministic )
11201128 if self .drop_path > 0.0 :
11211129 drop_path_shape = _generate_drop_path_shape (x .shape , batch_dim )
11221130 x = nn .Dropout (rate = self .drop_path ,
1123- broadcast_dims = drop_path_shape )(x , deterministic = deterministic )
1131+ broadcast_dims = drop_path_shape ,
1132+ rng_collection = self .dropout_rng_name )(x , deterministic = deterministic )
11241133 x = x + residual
11251134
11261135 mlp_input = x
@@ -1156,6 +1165,8 @@ def hidden_dropout(x, deterministic):
11561165 y = hidden_dropout (y , deterministic )
11571166 mlp_input = y + residual
11581167
1168+ mlp_input = _with_sharding_constraint (mlp_input , (BATCH_AXES , SEQLEN_AXES , HIDDEN_AXES ))
1169+
11591170 # MlpBlock
11601171 residual = mlp_input
11611172 z , ln_out = LayerNormMLP (
@@ -1167,8 +1178,9 @@ def hidden_dropout(x, deterministic):
11671178 return_layernorm_output = self .apply_residual_connection_post_layernorm ,
11681179 intermediate_dim = self .mlp_hidden_size ,
11691180 activations = self .mlp_activations ,
1170- intermediate_dropout_rate = self .hidden_dropout ,
1171- intermediate_hidden_dropout_dims = self .hidden_dropout_dims ,
1181+ intermediate_dropout_rng_name = self .dropout_rng_name ,
1182+ intermediate_dropout_rate = self .intermediate_dropout ,
1183+ intermediate_hidden_dropout_dims = self .intermediate_dropout_dims ,
11721184 dtype = self .dtype ,
11731185 scale_axes = (W_NO_SHARD_AXES ,),
11741186 ln_bias_axes = (W_NO_SHARD_AXES ,),
0 commit comments