@@ -49,31 +49,37 @@ def bench_layer_norm(shape, dtype, mode, backend, benchmark):
4949 torch .bfloat16 : (1e-2 , 1e-2 ),
5050 }[dtype ]
5151
52- y = backend (x , weight , bias , eps )
53- y_ref = torch_layer_norm (x , weight , bias , eps )
54- if mode == "forward" :
55- torch .testing .assert_close (y , y_ref , atol = atol , rtol = rtol )
56- bench_f , bench_args = backend , (x , weight , bias , eps )
57- else :
58- y .backward (dy , retain_graph = True )
59- dx , dw , db = [_ .grad .clone () for _ in [x , weight , bias ]]
60- x .grad , weight .grad , bias .grad = None , None , None
61-
62- y_ref .backward (dy , retain_graph = True )
63- dx_ref , dw_ref , db_ref = [_ .grad .clone () for _ in [x , weight , bias ]]
64-
65- torch .testing .assert_close (dx , dx_ref , atol = atol , rtol = rtol )
66- torch .testing .assert_close (dw , dw_ref , atol = atol , rtol = rtol )
67- torch .testing .assert_close (db , db_ref , atol = atol , rtol = rtol )
68-
69- bench_f , bench_args = partial (y .backward , retain_graph = True ), (dy ,)
70-
71- warmup_rounds , iterations , rounds = estimate_bench_iter (bench_f , bench_args )
72-
73- benchmark .pedantic (
74- bench_f , bench_args ,
75- rounds = rounds , warmup_rounds = warmup_rounds , iterations = iterations ,
76- )
52+ # Run in non default stream so backward graph can be captured without
53+ # sync with default stream
54+ s = torch .cuda .Stream ()
55+ s .wait_stream (torch .cuda .current_stream ())
56+ with torch .cuda .stream (s ):
57+ y = backend (x , weight , bias , eps )
58+ y_ref = torch_layer_norm (x , weight , bias , eps )
59+ if mode == "forward" :
60+ torch .testing .assert_close (y , y_ref , atol = atol , rtol = rtol )
61+ bench_f , bench_args = backend , (x , weight , bias , eps )
62+ else :
63+ y .backward (dy , retain_graph = True )
64+ dx , dw , db = [_ .grad .clone () for _ in [x , weight , bias ]]
65+ x .grad , weight .grad , bias .grad = None , None , None
66+
67+ y_ref .backward (dy , retain_graph = True )
68+ dx_ref , dw_ref , db_ref = [_ .grad .clone () for _ in [x , weight , bias ]]
69+
70+ torch .testing .assert_close (dx , dx_ref , atol = atol , rtol = rtol )
71+ torch .testing .assert_close (dw , dw_ref , atol = atol , rtol = rtol )
72+ torch .testing .assert_close (db , db_ref , atol = atol , rtol = rtol )
73+
74+ bench_f , bench_args = partial (y .backward , retain_graph = True ), (dy ,)
75+
76+ warmup_rounds , iterations , rounds = estimate_bench_iter (bench_f , bench_args , cudagraph = True )
77+
78+ benchmark .pedantic (
79+ bench_f , bench_args ,
80+ rounds = rounds , warmup_rounds = warmup_rounds , iterations = iterations ,
81+ cudagraph = True
82+ )
7783
7884
7985class CuTileLayerNorm (torch .autograd .Function ):
0 commit comments