1515import nvdlfw_inspect .api as debug_api
1616import transformer_engine .debug
1717import transformer_engine .pytorch as tepytorch
18- import transformer_engine_torch as tex
1918from transformer_engine .common .recipe import DelayedScaling , Format
2019from transformer_engine .pytorch .quantization import _default_sf_compute
2120from transformer_engine .pytorch import (
@@ -57,7 +56,7 @@ def _cast_to_fp8(tensor, scale, dtype):
5756
5857
5958def _get_current_scale (tensor , fp8_dtype ):
60- if fp8_dtype == tex .DType .kFloat8E4M3 :
59+ if fp8_dtype == tepytorch .DType .kFloat8E4M3 :
6160 fp8_max = Format .E4M3 .value .max_fwd
6261 else :
6362 fp8_max = Format .E5M2 .value .max_fwd
@@ -93,19 +92,19 @@ def _emulate_linear(
9392 input : torch .Tensor ,
9493 weight : torch .Tensor ,
9594 fprop_fp8 : bool = False ,
96- fprop_input_fake_quant : tex .DType = None ,
95+ fprop_input_fake_quant : tepytorch .DType = None ,
9796 fprop_input_scale : torch .Tensor = None ,
98- fprop_weight_fake_quant : tex .DType = None ,
97+ fprop_weight_fake_quant : tepytorch .DType = None ,
9998 fprop_weight_scale : torch .Tensor = None ,
10099 dgrad_fp8 : bool = False ,
101- dgrad_gradient_fake_quant : tex .DType = None ,
100+ dgrad_gradient_fake_quant : tepytorch .DType = None ,
102101 dgrad_gradient_scale : torch .Tensor = None ,
103- dgrad_weight_fake_quant : tex .DType = None ,
102+ dgrad_weight_fake_quant : tepytorch .DType = None ,
104103 dgrad_weight_scale : torch .Tensor = None ,
105104 wgrad_fp8 : bool = False ,
106- wgrad_gradient_fake_quant : tex .DType = None ,
105+ wgrad_gradient_fake_quant : tepytorch .DType = None ,
107106 wgrad_gradient_scale : torch .Tensor = None ,
108- wgrad_input_fake_quant : tex .DType = None ,
107+ wgrad_input_fake_quant : tepytorch .DType = None ,
109108 wgrad_input_scale : torch .Tensor = None ,
110109 loss_multiplier : float = 1.0 ,
111110 activation_sync = None ,
@@ -116,10 +115,10 @@ def _emulate_linear(
116115 activation = _fp8_gemm_kernel (
117116 input ,
118117 _scalar (fprop_input_scale or 1.0 ),
119- tex .DType .kFloat8E4M3 ,
118+ tepytorch .DType .kFloat8E4M3 ,
120119 weight ,
121120 _scalar (fprop_weight_scale or 1.0 ),
122- tex .DType .kFloat8E4M3 ,
121+ tepytorch .DType .kFloat8E4M3 ,
123122 _2X_ACC_FPROP ,
124123 )
125124 activation = activation .clone ().detach ().contiguous ().requires_grad_ (True )
@@ -152,10 +151,10 @@ def _emulate_linear(
152151 dgrad = _fp8_gemm_kernel (
153152 weight .T ,
154153 _scalar (dgrad_weight_scale or 1.0 ),
155- tex .DType .kFloat8E4M3 ,
154+ tepytorch .DType .kFloat8E4M3 ,
156155 gradient ,
157156 _scalar (dgrad_gradient_scale or 1.0 ),
158- tex .DType .kFloat8E5M2 ,
157+ tepytorch .DType .kFloat8E5M2 ,
159158 _2X_ACC_DGRAD ,
160159 ).T
161160 else :
@@ -176,10 +175,10 @@ def _emulate_linear(
176175 wgrad = _fp8_gemm_kernel (
177176 input .T ,
178177 _scalar (wgrad_input_scale or 1.0 ),
179- tex .DType .kFloat8E4M3 ,
178+ tepytorch .DType .kFloat8E4M3 ,
180179 gradient .T ,
181180 _scalar (wgrad_gradient_scale or 1.0 ),
182- tex .DType .kFloat8E5M2 ,
181+ tepytorch .DType .kFloat8E5M2 ,
183182 _2X_ACC_WGRAD ,
184183 ).T
185184 else :
@@ -470,17 +469,17 @@ def set_scaling_factors(model, input_kwargs, fp8_kwargs):
470469def set_current_scaling_factors (x , weight , y , input_kwargs , fp8_kwargs ):
471470 # Compute per tensor scaling factor if respective flag in input_kwargs is set.
472471 if input_kwargs ["fprop_inp" ]:
473- fp8_kwargs ["fprop_input_scale" ] = tex .DType .kFloat8E4M3
472+ fp8_kwargs ["fprop_input_scale" ] = tepytorch .DType .kFloat8E4M3
474473 if input_kwargs ["fprop_weight" ]:
475- fp8_kwargs ["fprop_weight_scale" ] = tex .DType .kFloat8E4M3
474+ fp8_kwargs ["fprop_weight_scale" ] = tepytorch .DType .kFloat8E4M3
476475 if input_kwargs ["dgrad_grad" ]:
477- fp8_kwargs ["dgrad_gradient_scale" ] = tex .DType .kFloat8E5M2
476+ fp8_kwargs ["dgrad_gradient_scale" ] = tepytorch .DType .kFloat8E5M2
478477 if input_kwargs ["dgrad_weight" ]:
479- fp8_kwargs ["dgrad_weight_scale" ] = tex .DType .kFloat8E4M3
478+ fp8_kwargs ["dgrad_weight_scale" ] = tepytorch .DType .kFloat8E4M3
480479 if input_kwargs ["wgrad_grad" ]:
481- fp8_kwargs ["wgrad_gradient_scale" ] = tex .DType .kFloat8E5M2
480+ fp8_kwargs ["wgrad_gradient_scale" ] = tepytorch .DType .kFloat8E5M2
482481 if input_kwargs ["wgrad_input" ]:
483- fp8_kwargs ["wgrad_input_scale" ] = tex .DType .kFloat8E4M3
482+ fp8_kwargs ["wgrad_input_scale" ] = tepytorch .DType .kFloat8E4M3
484483
485484
486485@create_config_file
@@ -651,7 +650,7 @@ def init_and_warmup():
651650
652651
653652all_combinations = list (
654- itertools .product ([tex .DType .kFloat8E4M3 , tex .DType .kFloat8E5M2 , None ], repeat = 6 )
653+ itertools .product ([tepytorch .DType .kFloat8E4M3 , tepytorch .DType .kFloat8E5M2 , None ], repeat = 6 )
655654)
656655subset_combinations = random .sample (all_combinations , 10 )
657656
@@ -687,7 +686,7 @@ def test_fake_quant_fp8(
687686def fake_quant_fp8_create_config (
688687 fprop_inp , fprop_weight , dgrad_weight , dgrad_grad , wgrad_input , wgrad_grad , config_file
689688):
690- format_to_str = {tex .DType .kFloat8E4M3 : "FP8E4M3" , tex .DType .kFloat8E5M2 : "FP8E5M2" }
689+ format_to_str = {tepytorch .DType .kFloat8E4M3 : "FP8E4M3" , tepytorch .DType .kFloat8E5M2 : "FP8E5M2" }
691690 gemms = ""
692691
693692 def _add_tensor (quant_format , tensor ):
0 commit comments