diff --git a/exir/print_program.py b/exir/print_program.py index cf2daa2c2d3..fe3d1037c7f 100644 --- a/exir/print_program.py +++ b/exir/print_program.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -60,6 +61,8 @@ def _scalar_type_str(scalar_type: ScalarType) -> str: ScalarType.QUINT8: "qui8", ScalarType.QINT32: "qi32", ScalarType.BFLOAT16: "bf16", + ScalarType.FLOAT8E5M2: "f8e5m2", + ScalarType.FLOAT8E4M3FN: "f8e4m3fn", ScalarType.QUINT4x2: "qui4x2", ScalarType.QUINT2x4: "qui2x4", } diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 7fd1f9470d4..572d87f2dec 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -149,6 +150,8 @@ def _reverse_map(d: Dict[Any, Enum]): torch.complex128: ScalarType.COMPLEXDOUBLE, torch.bool: ScalarType.BOOL, torch.bfloat16: ScalarType.BFLOAT16, + torch.float8_e5m2: ScalarType.FLOAT8E5M2, + torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, torch.uint16: ScalarType.UINT16 } diff --git a/exir/serde/schema.py b/exir/serde/schema.py index f91526c385f..bd1904aa34f 100644 --- a/exir/serde/schema.py +++ b/exir/serde/schema.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -16,7 +17,7 @@ from executorch.exir.serde.union import _Union # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (5, 3) +SCHEMA_VERSION = (5, 4) TREESPEC_VERSION = 1 @@ -36,6 +37,8 @@ class ScalarType(IntEnum): BOOL = 12 BFLOAT16 = 13 UINT16 = 14 + FLOAT8E5M2 = 15 + FLOAT8E4M3FN = 16 class Layout(IntEnum): Unknown = 0 diff --git a/exir/tensor.py b/exir/tensor.py index b1619d16bdf..e40d6f10168 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -292,6 +293,8 @@ def memory_format_enum(memory_format: torch.memory_format) -> int: torch.qint32: ScalarType.QINT32, torch.bfloat16: ScalarType.BFLOAT16, torch.quint4x2: ScalarType.QUINT4x2, + torch.float8_e5m2: ScalarType.FLOAT8E5M2, + torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, torch.uint16: ScalarType.UINT16, torch.uint32: ScalarType.UINT32, } diff --git a/exir/tests/test_tensor.py b/exir/tests/test_tensor.py index c5383b0dac2..1a73e81319c 100644 --- a/exir/tests/test_tensor.py +++ b/exir/tests/test_tensor.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -90,6 +91,14 @@ def test_normal_tensor_conversion(self) -> None: # whereas strides for torch.memory_format = torch.channels_last is # (3*4*5, 1, 5*3, 3)) + def test_fp8_tensor_conversion(self) -> None: + for dtype in (torch.float8_e5m2, torch.float8_e4m3fn): + normal_tensor = torch.randn(2, 2, 3, dtype=torch.float32).to(dtype) + flatbuffer_tensor = make_tensor_value( + 1, 0, TensorSpec.from_tensor(normal_tensor) + ) + self.compare_tensors(normal_tensor, flatbuffer_tensor) + def test_allocation_info_succeeds(self) -> None: test_cases = ( (