Skip to content

Commit f458abe

Browse files
vthumbe1503pre-commit-ci[bot]timmoon10
authored
[PyTorch] Python DType enum (#3039)
* initial prototype Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comment Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cleanup Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * some more Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * done Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cache python_to_cpp and cpp_to_python casts for dtype Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add the missing conversion file Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cleanup Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * lint Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comment Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix build docs Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix review comment, lint Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix docs and addres review commentsg Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cache Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace 'te.DType' with 'tepytorch.DType' in tests Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Apply suggestion from @timmoon10 Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * address review comment Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename pybind_dtype_casters.h to pybind_dtype_caster.h Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Fix include statement for pybind_dtype_caster Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> --------- Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent 97a9bfe commit f458abe

66 files changed

Lines changed: 649 additions & 457 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

benchmarks/benchmark_rht_cast.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch.utils.benchmark as benchmark
99

1010
import transformer_engine.pytorch as te
11-
import transformer_engine_torch as tex
1211
import transformer_engine.pytorch.cpp_extensions as ext
1312

1413
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
@@ -17,7 +16,7 @@
1716
permute_scale = False
1817

1918
TORCH_TO_TE_FLOAT_MAP = {
20-
torch.bfloat16: tex.DType.kBFloat16,
19+
torch.bfloat16: te.DType.kBFloat16,
2120
}
2221

2322

@@ -31,7 +30,7 @@ def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
3130

3231
# Quantize
3332
nvfp4_quantizer = NVFP4Quantizer(
34-
fp4_dtype=tex.DType.kFloat4E2M1,
33+
fp4_dtype=te.DType.kFloat4E2M1,
3534
rowwise=True,
3635
columnwise=True,
3736
with_amax_reduction=False,

docs/api/pytorch.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ PyTorch
5959

6060
.. autoapifunction:: transformer_engine.pytorch.deinterleave_glu_tensor
6161

62+
Data types
63+
----------
64+
65+
.. autoapiclass:: transformer_engine.pytorch.DType()
66+
:members: kByte, kInt32, kFloat32, kFloat16, kBFloat16, kFloat8E4M3, kFloat8E5M2, kFloat4E2M1
67+
6268
Recipe availability
6369
-------------------
6470

docs/examples/quickstart_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,8 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
204204

205205
def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
206206
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
207-
import transformer_engine_torch as tex
208207

209-
fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
208+
fp8_type = te.DType.kFloat8E4M3 if fp8_format == "e4m3" else te.DType.kFloat8E5M2
210209
scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
211210
amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
212211
quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type)

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
1616
import transformer_engine_torch as tex
17+
from transformer_engine.pytorch import DType
1718
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
1819
from transformer_engine.pytorch import (
1920
autocast,
@@ -328,34 +329,34 @@ def run_dpa_with_cp(
328329
).cuda()
329330
if scaling_mode == "delayed":
330331
qkv_quantizer = Float8Quantizer(
331-
fp8_dtype=tex.DType.kFloat8E4M3,
332+
fp8_dtype=DType.kFloat8E4M3,
332333
scale=torch.tensor([1], dtype=torch.float32).cuda(),
333334
amax=torch.tensor([0], dtype=torch.float32).cuda(),
334335
)
335336
dout_quantizer = Float8Quantizer(
336-
fp8_dtype=tex.DType.kFloat8E5M2,
337+
fp8_dtype=DType.kFloat8E5M2,
337338
scale=torch.tensor([1], dtype=torch.float32).cuda(),
338339
amax=torch.tensor([0], dtype=torch.float32).cuda(),
339340
)
340341
if scaling_mode == "current":
341342
qkv_quantizer = Float8CurrentScalingQuantizer(
342-
fp8_dtype=tex.DType.kFloat8E4M3,
343+
fp8_dtype=DType.kFloat8E4M3,
343344
device="cuda",
344345
)
345346
dout_quantizer = Float8CurrentScalingQuantizer(
346-
fp8_dtype=tex.DType.kFloat8E5M2,
347+
fp8_dtype=DType.kFloat8E5M2,
347348
device="cuda",
348349
)
349350
if scaling_mode == "mxfp8":
350351
qkv_quantizer = MXFP8Quantizer(
351-
fp8_dtype=tex.DType.kFloat8E4M3,
352+
fp8_dtype=DType.kFloat8E4M3,
352353
rowwise=True,
353354
columnwise=True,
354355
)
355356
qkv_quantizer.optimize_for_gemm = True
356357
qkv_quantizer.internal = False
357358
dout_quantizer = MXFP8Quantizer(
358-
fp8_dtype=tex.DType.kFloat8E5M2,
359+
fp8_dtype=DType.kFloat8E5M2,
359360
rowwise=True,
360361
columnwise=True,
361362
)

tests/pytorch/debug/run_distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414
import torch.distributed as dist
1515
import transformer_engine
16-
import transformer_engine_torch as tex
16+
from transformer_engine.pytorch import DType
1717
import nvdlfw_inspect.api as debug_api
1818
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
1919
from transformer_engine.pytorch import is_fp8_available
@@ -683,7 +683,7 @@ def _run_test_with_combinations(
683683
)
684684

685685
# test_fake_quant_fp8
686-
dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
686+
dtype_options = [DType.kFloat8E4M3, DType.kFloat8E5M2, None]
687687
_run_test_with_combinations(
688688
test_fake_quant_fp8,
689689
dtype_options,

tests/pytorch/debug/test_api_features.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import torch
66
from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
7+
from transformer_engine.pytorch import DType
78

89
import nvdlfw_inspect.api as debug_api
910

1011
try:
1112
import transformer_engine
12-
import transformer_engine_torch as tex
1313
except (ImportError, ModuleNotFoundError):
1414
print("Could not find TransformerEngine package.")
1515
exit(1)
@@ -128,12 +128,12 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
128128
default_quantizer1 = Float8Quantizer(
129129
scale=torch.tensor([1]).cuda(),
130130
amax=torch.tensor([0]).cuda(),
131-
fp8_dtype=tex.DType.kFloat8E4M3,
131+
fp8_dtype=DType.kFloat8E4M3,
132132
)
133133
default_quantizer2 = Float8Quantizer(
134134
scale=torch.tensor([1]).cuda(),
135135
amax=torch.tensor([0]).cuda(),
136-
fp8_dtype=tex.DType.kFloat8E5M2,
136+
fp8_dtype=DType.kFloat8E5M2,
137137
)
138138

139139
output1 = debug_api.transformer_engine.modify_tensor(
@@ -145,7 +145,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
145145
tensor=tensor,
146146
)
147147
assert type(output1) == Float8Tensor
148-
assert output1._fp8_dtype == tex.DType.kFloat8E4M3
148+
assert output1._fp8_dtype == DType.kFloat8E4M3
149149

150150
output2 = debug_api.transformer_engine.modify_tensor(
151151
"decoder.1.mlp.fc1",
@@ -156,7 +156,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
156156
iteration=0,
157157
)
158158
assert type(output2) == Float8Tensor
159-
assert output2._fp8_dtype == tex.DType.kFloat8E5M2
159+
assert output2._fp8_dtype == DType.kFloat8E5M2
160160

161161
assert not debug_api.transformer_engine.modify_tensor_enabled(
162162
"decoder.1.mlp.fc1",
@@ -234,7 +234,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
234234
quantizer = Float8Quantizer(
235235
scale=torch.full([1], 1.0).cuda(),
236236
amax=torch.full([1], 1.0).cuda(),
237-
fp8_dtype=tex.DType.kFloat8E4M3,
237+
fp8_dtype=DType.kFloat8E4M3,
238238
)
239239
tensor_fp8 = quantizer(tensor)
240240

@@ -372,7 +372,7 @@ def log_stats():
372372
quantizer = Float8Quantizer(
373373
scale=torch.full([1], 1.0).cuda(),
374374
amax=torch.full([1], 1.0).cuda(),
375-
fp8_dtype=tex.DType.kFloat8E4M3,
375+
fp8_dtype=DType.kFloat8E4M3,
376376
)
377377

378378
def fp8_tensor(t):

tests/pytorch/debug/test_numerics.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import nvdlfw_inspect.api as debug_api
1616
import transformer_engine.debug
1717
import transformer_engine.pytorch as tepytorch
18-
import transformer_engine_torch as tex
1918
from transformer_engine.common.recipe import DelayedScaling, Format
2019
from transformer_engine.pytorch.quantization import _default_sf_compute
2120
from transformer_engine.pytorch import (
@@ -57,7 +56,7 @@ def _cast_to_fp8(tensor, scale, dtype):
5756

5857

5958
def _get_current_scale(tensor, fp8_dtype):
60-
if fp8_dtype == tex.DType.kFloat8E4M3:
59+
if fp8_dtype == tepytorch.DType.kFloat8E4M3:
6160
fp8_max = Format.E4M3.value.max_fwd
6261
else:
6362
fp8_max = Format.E5M2.value.max_fwd
@@ -93,19 +92,19 @@ def _emulate_linear(
9392
input: torch.Tensor,
9493
weight: torch.Tensor,
9594
fprop_fp8: bool = False,
96-
fprop_input_fake_quant: tex.DType = None,
95+
fprop_input_fake_quant: tepytorch.DType = None,
9796
fprop_input_scale: torch.Tensor = None,
98-
fprop_weight_fake_quant: tex.DType = None,
97+
fprop_weight_fake_quant: tepytorch.DType = None,
9998
fprop_weight_scale: torch.Tensor = None,
10099
dgrad_fp8: bool = False,
101-
dgrad_gradient_fake_quant: tex.DType = None,
100+
dgrad_gradient_fake_quant: tepytorch.DType = None,
102101
dgrad_gradient_scale: torch.Tensor = None,
103-
dgrad_weight_fake_quant: tex.DType = None,
102+
dgrad_weight_fake_quant: tepytorch.DType = None,
104103
dgrad_weight_scale: torch.Tensor = None,
105104
wgrad_fp8: bool = False,
106-
wgrad_gradient_fake_quant: tex.DType = None,
105+
wgrad_gradient_fake_quant: tepytorch.DType = None,
107106
wgrad_gradient_scale: torch.Tensor = None,
108-
wgrad_input_fake_quant: tex.DType = None,
107+
wgrad_input_fake_quant: tepytorch.DType = None,
109108
wgrad_input_scale: torch.Tensor = None,
110109
loss_multiplier: float = 1.0,
111110
activation_sync=None,
@@ -116,10 +115,10 @@ def _emulate_linear(
116115
activation = _fp8_gemm_kernel(
117116
input,
118117
_scalar(fprop_input_scale or 1.0),
119-
tex.DType.kFloat8E4M3,
118+
tepytorch.DType.kFloat8E4M3,
120119
weight,
121120
_scalar(fprop_weight_scale or 1.0),
122-
tex.DType.kFloat8E4M3,
121+
tepytorch.DType.kFloat8E4M3,
123122
_2X_ACC_FPROP,
124123
)
125124
activation = activation.clone().detach().contiguous().requires_grad_(True)
@@ -152,10 +151,10 @@ def _emulate_linear(
152151
dgrad = _fp8_gemm_kernel(
153152
weight.T,
154153
_scalar(dgrad_weight_scale or 1.0),
155-
tex.DType.kFloat8E4M3,
154+
tepytorch.DType.kFloat8E4M3,
156155
gradient,
157156
_scalar(dgrad_gradient_scale or 1.0),
158-
tex.DType.kFloat8E5M2,
157+
tepytorch.DType.kFloat8E5M2,
159158
_2X_ACC_DGRAD,
160159
).T
161160
else:
@@ -176,10 +175,10 @@ def _emulate_linear(
176175
wgrad = _fp8_gemm_kernel(
177176
input.T,
178177
_scalar(wgrad_input_scale or 1.0),
179-
tex.DType.kFloat8E4M3,
178+
tepytorch.DType.kFloat8E4M3,
180179
gradient.T,
181180
_scalar(wgrad_gradient_scale or 1.0),
182-
tex.DType.kFloat8E5M2,
181+
tepytorch.DType.kFloat8E5M2,
183182
_2X_ACC_WGRAD,
184183
).T
185184
else:
@@ -470,17 +469,17 @@ def set_scaling_factors(model, input_kwargs, fp8_kwargs):
470469
def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs):
471470
# Compute per tensor scaling factor if respective flag in input_kwargs is set.
472471
if input_kwargs["fprop_inp"]:
473-
fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3
472+
fp8_kwargs["fprop_input_scale"] = tepytorch.DType.kFloat8E4M3
474473
if input_kwargs["fprop_weight"]:
475-
fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3
474+
fp8_kwargs["fprop_weight_scale"] = tepytorch.DType.kFloat8E4M3
476475
if input_kwargs["dgrad_grad"]:
477-
fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2
476+
fp8_kwargs["dgrad_gradient_scale"] = tepytorch.DType.kFloat8E5M2
478477
if input_kwargs["dgrad_weight"]:
479-
fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3
478+
fp8_kwargs["dgrad_weight_scale"] = tepytorch.DType.kFloat8E4M3
480479
if input_kwargs["wgrad_grad"]:
481-
fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2
480+
fp8_kwargs["wgrad_gradient_scale"] = tepytorch.DType.kFloat8E5M2
482481
if input_kwargs["wgrad_input"]:
483-
fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3
482+
fp8_kwargs["wgrad_input_scale"] = tepytorch.DType.kFloat8E4M3
484483

485484

486485
@create_config_file
@@ -651,7 +650,7 @@ def init_and_warmup():
651650

652651

653652
all_combinations = list(
654-
itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6)
653+
itertools.product([tepytorch.DType.kFloat8E4M3, tepytorch.DType.kFloat8E5M2, None], repeat=6)
655654
)
656655
subset_combinations = random.sample(all_combinations, 10)
657656

@@ -687,7 +686,7 @@ def test_fake_quant_fp8(
687686
def fake_quant_fp8_create_config(
688687
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
689688
):
690-
format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"}
689+
format_to_str = {tepytorch.DType.kFloat8E4M3: "FP8E4M3", tepytorch.DType.kFloat8E5M2: "FP8E5M2"}
691690
gemms = ""
692691

693692
def _add_tensor(quant_format, tensor):

tests/pytorch/distributed/run_gemm_with_overlap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
534534
if opts.quantization == "fp8":
535535
# Structure to maintain amax and scale/scale_inv information for the kernel and input
536536
num_gemms = 6 if ub_obj2 is not None else 3
537-
fp8_dtype = tex.DType.kFloat8E4M3
537+
fp8_dtype = te.DType.kFloat8E4M3
538538
fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda")
539539
fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda")
540540

@@ -577,7 +577,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
577577
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
578578
)
579579
elif opts.quantization == "mxfp8":
580-
fp8_dtype = tex.DType.kFloat8E4M3
580+
fp8_dtype = te.DType.kFloat8E4M3
581581
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
582582
ker_quantizer = MXFP8Quantizer(fp8_dtype)
583583
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:

tests/pytorch/distributed/run_numerics.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import torch
1616
from torch import nn
1717
import torch.distributed as dist
18-
import transformer_engine_torch as tex
1918
from transformer_engine.common.recipe import (
2019
MXFP8BlockScaling,
2120
DelayedScaling,
@@ -399,7 +398,7 @@ def _test_quantizer(input_dtype, fp8_dtype):
399398
400399
Args:
401400
input_dtype (torch.dtype): The data type of the input.
402-
fp8_dtype (tex.DType): The data type of the fp8.
401+
fp8_dtype (te.DType): The data type of the fp8.
403402
"""
404403

405404
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE
@@ -443,7 +442,7 @@ def test_quantizer():
443442
return
444443

445444
input_dtypes = [torch.float32, torch.bfloat16]
446-
fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
445+
fp8_dtypes = [te.DType.kFloat8E4M3, te.DType.kFloat8E5M2]
447446

448447
for input_dtype in input_dtypes:
449448
for fp8_dtype in fp8_dtypes:
@@ -514,7 +513,7 @@ def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):
514513
515514
Args:
516515
input_dtype (torch.dtype): The data type of the input.
517-
low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
516+
low_precision_dtype (te.DType): The data type of the low precision, can be fp4 or fp8.
518517
"""
519518

520519
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
@@ -623,8 +622,8 @@ def test_quantized_all_gather():
623622
return
624623

625624
input_dtypes = [torch.bfloat16]
626-
fp4_dtype = [tex.DType.kFloat4E2M1]
627-
fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
625+
fp4_dtype = [te.DType.kFloat4E2M1]
626+
fp8_dtype = [te.DType.kFloat8E4M3, te.DType.kFloat8E5M2]
628627
quantizer_cls_nvfp4 = [NVFP4Quantizer]
629628
# add FP8 quantizers if needed
630629
quantizer_cls_fp8 = []

tests/pytorch/distributed/test_fusible_ops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
is_bf16_available,
3030
)
3131
import transformer_engine.pytorch.ops as te_ops
32-
import transformer_engine_torch as tex
3332

3433
# Import utility functions
3534
_current_file = pathlib.Path(__file__).resolve()
@@ -107,17 +106,17 @@ def make_reference_and_test_tensors(
107106
quantizer = Float8Quantizer(
108107
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
109108
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
110-
fp8_dtype=tex.DType.kFloat8E4M3,
109+
fp8_dtype=te.DType.kFloat8E4M3,
111110
)
112111
test = quantizer(test)
113112
elif quantization == "fp8_current_scaling":
114113
quantizer = Float8CurrentScalingQuantizer(
115-
fp8_dtype=tex.DType.kFloat8E4M3,
114+
fp8_dtype=te.DType.kFloat8E4M3,
116115
device=test_device,
117116
)
118117
test = quantizer(test)
119118
elif quantization == "mxfp8":
120-
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
119+
test = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3)(test)
121120
elif quantization == "nvfp4":
122121
test = NVFP4Quantizer(
123122
with_rht=False,

0 commit comments

Comments
 (0)