Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions exir/print_program.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -60,6 +61,8 @@ def _scalar_type_str(scalar_type: ScalarType) -> str:
ScalarType.QUINT8: "qui8",
ScalarType.QINT32: "qi32",
ScalarType.BFLOAT16: "bf16",
ScalarType.FLOAT8E5M2: "f8e5m2",
ScalarType.FLOAT8E4M3FN: "f8e4m3fn",
ScalarType.QUINT4x2: "qui4x2",
ScalarType.QUINT2x4: "qui2x4",
}
Expand Down
3 changes: 3 additions & 0 deletions exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -149,6 +150,8 @@ def _reverse_map(d: Dict[Any, Enum]):
torch.complex128: ScalarType.COMPLEXDOUBLE,
torch.bool: ScalarType.BOOL,
torch.bfloat16: ScalarType.BFLOAT16,
torch.float8_e5m2: ScalarType.FLOAT8E5M2,
torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN,
torch.uint16: ScalarType.UINT16
}

Expand Down
5 changes: 4 additions & 1 deletion exir/serde/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -16,7 +17,7 @@
from executorch.exir.serde.union import _Union

# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (5, 3)
SCHEMA_VERSION = (5, 4)
TREESPEC_VERSION = 1


Expand All @@ -36,6 +37,8 @@ class ScalarType(IntEnum):
BOOL = 12
BFLOAT16 = 13
UINT16 = 14
FLOAT8E5M2 = 15
FLOAT8E4M3FN = 16

class Layout(IntEnum):
Unknown = 0
Expand Down
3 changes: 3 additions & 0 deletions exir/tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -292,6 +293,8 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
torch.qint32: ScalarType.QINT32,
torch.bfloat16: ScalarType.BFLOAT16,
torch.quint4x2: ScalarType.QUINT4x2,
torch.float8_e5m2: ScalarType.FLOAT8E5M2,
torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN,
torch.uint16: ScalarType.UINT16,
torch.uint32: ScalarType.UINT32,
}
Expand Down
9 changes: 9 additions & 0 deletions exir/tests/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -90,6 +91,14 @@ def test_normal_tensor_conversion(self) -> None:
# whereas strides for torch.memory_format = torch.channels_last is
# (3*4*5, 1, 5*3, 3))

def test_fp8_tensor_conversion(self) -> None:
    """Check that float8 tensors round-trip through the flatbuffer
    tensor-value conversion for both supported fp8 formats."""
    fp8_dtypes = [torch.float8_e5m2, torch.float8_e4m3fn]
    for fp8_dtype in fp8_dtypes:
        # Start from float32 noise and downcast, since randn cannot
        # produce fp8 directly.
        source = torch.randn(2, 2, 3, dtype=torch.float32).to(fp8_dtype)
        converted = make_tensor_value(
            1, 0, TensorSpec.from_tensor(source)
        )
        self.compare_tensors(source, converted)

def test_allocation_info_succeeds(self) -> None:
test_cases = (
(
Expand Down
Loading