1616and pooling modes.
1717"""
1818
19+ import logging
20+ import os
1921import unittest
22+ import warnings
2023from typing import Any , Optional
2124
2225import hypothesis .strategies as st
4447
4548if open_source :
4649 # pyre-ignore[21]
47- from test_utils import gpu_unavailable , optests
50+ from test_utils import additional_decorators , gpu_unavailable , optests
4851else :
49- from fbgemm_gpu .test .test_utils import gpu_unavailable , optests
52+ from fbgemm_gpu .test .test_utils import (
53+ additional_decorators ,
54+ gpu_unavailable ,
55+ optests ,
56+ )
5057
5158
5259VERBOSITY : Verbosity = Verbosity .verbose
5360
61+ SUPPRESS_HEALTH_CHECKS : list [HealthCheck ] = [
62+ HealthCheck .filter_too_much ,
63+ HealthCheck .data_too_large ,
64+ HealthCheck .differing_executors ,
65+ ]
66+
5467
55- @optests .generate_opcheck_tests (fast = True )
68+ @optests .generate_opcheck_tests (fast = True , additional_decorators = additional_decorators )
5669class BackwardDeterminismTest (unittest .TestCase ):
5770 """Verify backward determinism for all TBE optimizer types.
5871
@@ -68,6 +81,14 @@ class BackwardDeterminismTest(unittest.TestCase):
6881 - All other optimizers: compares split_embedding_weights() per table
6982 """
7083
84+ def setUp (self ) -> None :
85+ # The test calls multiple TBE constructors, each emit many debug/info log lines ,
86+ # Suppress INFO/DEBUG log to reduce the noise.
87+ # Set FBGEMM_TEST_VERBOSE=1 to re-enable.
88+ if "FBGEMM_TEST_VERBOSE" not in os .environ :
89+ warnings .simplefilter ("ignore" )
90+ logging .disable (logging .INFO )
91+
7192 def _run_dense_backward (
7293 self ,
7394 Es : list [int ],
@@ -103,7 +124,6 @@ def _run_split_backward(
103124 optimizer : OptimType ,
104125 weights_precision : SparseType ,
105126 output_dtype : SparseType ,
106- stochastic_rounding : bool ,
107127 pooling_mode : PoolingMode ,
108128 extra_kwargs : dict [str , Any ],
109129 ref_cc : SplitTableBatchedEmbeddingBagsCodegen ,
@@ -121,7 +141,7 @@ def _run_split_backward(
121141 learning_rate = 0.5 ,
122142 weights_precision = weights_precision ,
123143 output_dtype = output_dtype ,
124- stochastic_rounding = stochastic_rounding ,
144+ stochastic_rounding = False ,
125145 pooling_mode = pooling_mode ,
126146 ** extra_kwargs ,
127147 )
@@ -185,7 +205,6 @@ def _run_backward_determinism(
185205 pooling_mode : PoolingMode ,
186206 weights_precision : SparseType ,
187207 output_dtype : SparseType = SparseType .FP32 ,
188- stochastic_rounding : bool = False ,
189208 dense : bool = False ,
190209 num_runs : int = 5 ,
191210 optimizer_kwargs : Optional [dict [str , Any ]] = None ,
@@ -204,7 +223,6 @@ def _run_backward_determinism(
204223 pooling_mode: SUM, MEAN, or NONE.
205224 weights_precision: FP16 or FP32 for embedding weights.
206225 output_dtype: Output tensor dtype (FP16, FP32, or BF16).
207- stochastic_rounding: Whether to enable stochastic rounding.
208226 dense: If True, use DenseTableBatchedEmbeddingBagsCodegen.
209227 num_runs: Number of repeated backward passes to compare.
210228 optimizer_kwargs: Extra kwargs passed to the TBE constructor
@@ -291,7 +309,7 @@ def _run_backward_determinism(
291309 learning_rate = 0.5 ,
292310 weights_precision = weights_precision ,
293311 output_dtype = output_dtype ,
294- stochastic_rounding = stochastic_rounding ,
312+ stochastic_rounding = False ,
295313 pooling_mode = pooling_mode ,
296314 ** extra_kwargs ,
297315 )
@@ -339,7 +357,6 @@ def _run_backward_determinism(
339357 optimizer ,
340358 weights_precision ,
341359 output_dtype ,
342- stochastic_rounding ,
343360 pooling_mode ,
344361 extra_kwargs ,
345362 ref_cc ,
@@ -375,7 +392,7 @@ def _run_backward_determinism(
375392 verbosity = VERBOSITY ,
376393 max_examples = 20 ,
377394 deadline = None ,
378- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
395+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
379396 )
380397 @unittest .skipIf (* gpu_unavailable )
381398 def test_backward_determinism_sgd (
@@ -425,13 +442,12 @@ def test_backward_determinism_sgd(
425442 output_dtype = st .sampled_from (
426443 [SparseType .FP32 , SparseType .FP16 , SparseType .BF16 ]
427444 ),
428- stochastic_rounding = st .booleans (),
429445 )
430446 @settings (
431447 verbosity = VERBOSITY ,
432448 max_examples = 20 ,
433449 deadline = None ,
434- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
450+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
435451 )
436452 @unittest .skipIf (* gpu_unavailable )
437453 def test_backward_determinism_adagrad (
@@ -447,11 +463,9 @@ def test_backward_determinism_adagrad(
447463 optimizer : OptimType ,
448464 pooling_mode : PoolingMode ,
449465 output_dtype : SparseType ,
450- stochastic_rounding : bool ,
451466 ) -> None :
452467 """Test determinism for EXACT_ADAGRAD and EXACT_ROWWISE_ADAGRAD with
453- FP16/FP32 weights, FP16/FP32/BF16 output dtypes, and stochastic
454- rounding on/off."""
468+ FP16/FP32 weights and FP16/FP32/BF16 output dtypes."""
455469 self ._run_backward_determinism (
456470 T = T ,
457471 D = D ,
@@ -464,7 +478,6 @@ def test_backward_determinism_adagrad(
464478 pooling_mode = pooling_mode ,
465479 weights_precision = weights_precision ,
466480 output_dtype = output_dtype ,
467- stochastic_rounding = stochastic_rounding ,
468481 optimizer_kwargs = {"eps" : 1e-4 },
469482 )
470483
@@ -493,7 +506,7 @@ def test_backward_determinism_adagrad(
493506 verbosity = VERBOSITY ,
494507 max_examples = 20 ,
495508 deadline = None ,
496- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
509+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
497510 )
498511 @unittest .skipIf (* gpu_unavailable )
499512 def test_backward_determinism_optimizers (
@@ -569,7 +582,7 @@ def test_backward_determinism_optimizers(
569582 verbosity = VERBOSITY ,
570583 max_examples = 20 ,
571584 deadline = None ,
572- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
585+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
573586 )
574587 @unittest .skipIf (* gpu_unavailable )
575588 def test_backward_determinism_partial_rowwise_adam (
@@ -629,7 +642,7 @@ def test_backward_determinism_partial_rowwise_adam(
629642 verbosity = VERBOSITY ,
630643 max_examples = 20 ,
631644 deadline = None ,
632- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
645+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
633646 )
634647 @unittest .skipIf (* gpu_unavailable )
635648 def test_backward_determinism_ensemble (
@@ -692,7 +705,7 @@ def test_backward_determinism_ensemble(
692705 verbosity = VERBOSITY ,
693706 max_examples = 20 ,
694707 deadline = None ,
695- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
708+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
696709 )
697710 @unittest .skipIf (* gpu_unavailable )
698711 def test_backward_determinism_none (
@@ -741,7 +754,7 @@ def test_backward_determinism_none(
741754 verbosity = VERBOSITY ,
742755 max_examples = 20 ,
743756 deadline = None ,
744- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
757+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
745758 )
746759 @unittest .skipIf (* gpu_unavailable )
747760 def test_backward_determinism_dense (
@@ -796,7 +809,7 @@ def test_backward_determinism_dense(
796809 verbosity = VERBOSITY ,
797810 max_examples = 10 ,
798811 deadline = None ,
799- suppress_health_check = [ HealthCheck . filter_too_much , HealthCheck . data_too_large ] ,
812+ suppress_health_check = SUPPRESS_HEALTH_CHECKS ,
800813 )
801814 @unittest .skipIf (* gpu_unavailable )
802815 def test_backward_determinism_long_segments (
0 commit comments