diff --git a/onnx/src/main/proto/onnx-ml.proto b/onnx/src/main/proto/onnx-ml.proto index 4c26add1..c7c966c8 100644 --- a/onnx/src/main/proto/onnx-ml.proto +++ b/onnx/src/main/proto/onnx-ml.proto @@ -91,7 +91,7 @@ enum Version { IR_VERSION_2019_9_19 = 0x0000000000000006; // IR VERSION 7 published on May 8, 2020 - // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add support to allow function body graph to rely on multiple external operator sets. // - Add a list to promote inference graph's initializers to global and // mutable variables. Global variables are visible in all graphs of the // stored models. @@ -111,7 +111,23 @@ enum Version { // IR VERSION 9 published on May 5, 2023 // Added AttributeProto to FunctionProto so that default attribute values can be set. // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. - IR_VERSION = 0x0000000000000009; + IR_VERSION_2023_5_5 = 0x0000000000000009; + + // IR VERSION 10 published on March 25, 2024 + // Added UINT4, INT4, overload field for functions and metadata_props on multiple proto definitions. + IR_VERSION_2024_3_25 = 0x000000000000000A; + + // IR VERSION 11 published on May 12, 2025 + // Added FLOAT4E2M1, multi-device protobuf classes. + IR_VERSION_2025_05_12 = 0x000000000000000B; + + // IR VERSION 12 published on August 26, 2025 + // Added FLOAT8E8M0. + IR_VERSION_2025_08_26 = 0x000000000000000C; + + // IR VERSION 13 published on November 6, 2025 + // Added UINT2, INT2. + IR_VERSION = 0x000000000000000D; } // Attributes @@ -121,6 +137,8 @@ enum Version { // An AttributeProto MUST contain the name field, and *only one* of the // following content fields, effectively enforcing a C/C++ union equivalent. message AttributeProto { + reserved 12, 16 to 19; + reserved "v"; // Note: this enum is structurally identical to the OpSchema::AttrType // enum defined in schema.h. If you rev one, you likely need to rev the other. @@ -193,6 +211,8 @@ message ValueInfoProto { TypeProto type = 2; // A human-readable documentation for this value. Markdown is allowed. string doc_string = 3; + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 4; } // Nodes @@ -207,19 +227,101 @@ message NodeProto { repeated string output = 2; // namespace Value // An optional identifier for this node in a graph. - // This field MAY be absent in ths version of the IR. + // This field MAY be absent in this version of the IR. string name = 3; // namespace Node // The symbolic identifier of the Operator to execute. string op_type = 4; // namespace Operator // The domain of the OperatorSet that specifies the operator named by op_type. string domain = 7; // namespace Domain + // Overload identifier, used only to map this to a model-local function. + string overload = 8; // Additional named attributes. repeated AttributeProto attribute = 5; // A human-readable documentation for this node. Markdown is allowed. string doc_string = 6; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 9; + + // Configuration of multi-device annotations. + repeated NodeDeviceConfigurationProto device_configurations = 10; +} + +// IntIntListEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message IntIntListEntryProto { + int64 key = 1; + repeated int64 value = 2; +}; + +// Multi-device configuration proto for NodeProto. +message NodeDeviceConfigurationProto { + // This field MUST be present for this version of the IR. + // ID of the configuration. MUST match the name of a DeviceConfigurationProto. + string configuration_id = 1; + // Sharding spec for the node. + repeated ShardingSpecProto sharding_spec = 2; + // Pipeline stage of this node. + int32 pipeline_stage = 3; +} + +// ShardingSpecProto: This describes the sharding spec for a specific +// input or output tensor of a node. +message ShardingSpecProto { + // This field MUST be present for this version of the IR. + // Identifies the input or output of the node that is being sharded. + // Required to match a name specified in the node's input or output list of ValueInfoProtos. + // It is called `logical tensor` in subsequent descriptions. + string tensor_name = 1; + + // The following is the list of devices across which the logical + // tensor is sharded or replicated. + repeated int64 device = 2; + + // Each element v in above field devices may represent either a + // device or a set of devices (when we want the same shard/tensor + // to be replicated across a subset of devices), as indicated by + // the following optional map. If the map contains an entry for v, + // then v represents a device group, and the map indicates the set + // of devices in that group. + repeated IntIntListEntryProto index_to_device_group_map = 3; + + // The following is the sharded-shape of the tensor, consisting of + // the sharding-spec for each axis of the tensor. + repeated ShardedDimProto sharded_dim = 4; +} + +// ShardedDimProto: This describes the sharding spec for a single +// axis of a sharded tensor. +message ShardedDimProto { + // This field MUST be present for this version of the IR. + // The axis this sharding corresponds to. Must be in the range of + // [-r, r - 1], where r is the rank of the tensor. Negative axis values means + // counting from the back. + int64 axis = 1; + + // Describes how the tensor on the provided axis is sharded. + // The common-case is described by a single instance of SimpleShardedDimProto. + // Multiple instances can be used to handle cases where a sharded + // tensor is reshaped, fusing multiple axes into one. + repeated SimpleShardedDimProto simple_sharding = 2; +} + +// SimpleShardedDimProto: Indicates that N blocks are divided into M shards. +// N is allowed to be symbolic where M is required to be a constant. +message SimpleShardedDimProto { + // Dimension value to be sharded. + oneof dim { + int64 dim_value = 1; + string dim_param = 2; + } + + // This field MUST be present for this version of the IR. + // Number of shards to split dim into. + int64 num_shards = 3; } // Training information @@ -404,9 +506,9 @@ message ModelProto { // A list of function protos local to the model. // - // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". + // The (domain, name, overload) tuple must be unique across the function protos in this list. // In case of any conflicts the behavior (whether the model local functions are given higher priority, - // or standard operator sets are given higher priotity or this is treated as error) is defined by + // or standard operator sets are given higher priority or this is treated as error) is defined by // the runtimes. // // The operator sets imported by FunctionProto should be compatible with the ones @@ -419,8 +521,24 @@ message ModelProto { // One FunctionProto can reference other FunctionProto in the model, however, recursive reference // is not allowed. repeated FunctionProto functions = 25; + + // Describes different target configurations for a multi-device use case. + // A model MAY describe multiple multi-device configurations for execution. + repeated DeviceConfigurationProto configuration = 26; }; +// DeviceConfigurationProto describes a multi-device configuration for a model. +message DeviceConfigurationProto { + // This field MUST be present for this version of the IR. + // Name of the configuration. + string name = 1; + // This field MUST be present for this version of the IR. + // Number of devices inside this configuration. + int32 num_devices = 2; + // Optional names of the devices. MUST be length of num_devices if provided. + repeated string device = 3; +} + // StringStringEntryProto follows the pattern for cross-proto-version maps. // See https://developers.google.com/protocol-buffers/docs/proto3#maps message StringStringEntryProto { @@ -478,6 +596,9 @@ message GraphProto { // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. repeated TensorAnnotation quantization_annotation = 14; + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 16; + reserved 3, 4, 6 to 9; reserved "ir_version", "producer_version", "producer_tag", "domain"; } @@ -523,7 +644,22 @@ message TensorProto { FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients - FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero + FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero + + // 4-bit integer data types + UINT4 = 21; // Unsigned integer in range [0, 15] + INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation + + // 4-bit floating point data types + FLOAT4E2M1 = 23; + + // E8M0 type used as the scale for microscaling (MX) formats: + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + FLOAT8E8M0 = 24; + + // 2-bit integer data type + UINT2 = 25; // Unsigned integer in range [0, 3] + INT2 = 26; // Signed integer in range [-2, 1], using two's complement representation // Future extensions go here. } @@ -558,11 +694,23 @@ message TensorProto { // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. repeated float float_data = 4 [packed = true]; - // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values - // float16 and float8 values must be bit-wise converted to an uint16_t prior - // to writing to the buffer. + // For int32, uint8, int8, uint16, int16, uint4, int4, uint2, int2, bool, (b)float16, float8, and float4: + // - (b)float16 and float8 values MUST be converted bit-wise into an unsigned integer + // representation before being written to the buffer. + // - Each pair of uint4, int4, and float4 values MUST be packed as two 4-bit elements into a single byte. + // The first element is stored in the 4 least significant bits (LSB), + // and the second element is stored in the 4 most significant bits (MSB). + // - Each group of four uint2, int2 values MUST be packed as four 2-bit elements into a single byte. + // The elements are packed from LSB to MSB, with the first element in bits 0-1, second element in bits 2-3, + // third element in bits 4-5, and fourth element in bits 6-7. + // + // Consequently: + // - For data types with a bit-width of 8 or greater, each `int32_data` stores one element. + // - For 4-bit data types, each `int32_data` stores two elements. + // - For 2-bit data types, each `int32_data` stores four elements. + // // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ + // INT32, INT16, INT8, INT4, INT2, UINT16, UINT8, UINT4, UINT2, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ, FLOAT8E8M0, FLOAT4E2M1 repeated int32 int32_data = 5 [packed = true]; // For strings. @@ -592,6 +740,9 @@ message TensorProto { // Complex64 elements must be written as two consecutive FLOAT values, real component first. // Complex128 elements must be written as two consecutive DOUBLE values, real component first. // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB. + // uint2 and int2 values must be packed to 2bitx4, with elements packed from LSB to MSB in a single byte as: x0 | (x1 << 2) | (x2 << 4) | (x3 << 6) + // where x0, x1, x2, x3 are consecutive elements. // // Note: the advantage of specific field rather than the raw_data field is // that in some cases (e.g. int data), protobuf does a better packing via @@ -634,6 +785,9 @@ message TensorProto { // When this field is present, the data_type field MUST be // UINT32 or UINT64 repeated uint64 uint64_data = 11 [packed = true]; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 16; } // A serialized sparse-tensor value @@ -794,9 +948,8 @@ enum OperatorStatus { } message FunctionProto { - // The name of the function, similar usage of op_type in OperatorProto. - // Combined with FunctionProto.domain, this forms the unique identity of - // the FunctionProto. + // The name of the function, similar to op_type in NodeProto. + // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. string name = 1; // Deprecated since IR Version 8 @@ -843,11 +996,23 @@ message FunctionProto { repeated OperatorSetIdProto opset_import = 9; - // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of - // the FunctionProto. + // The domain which this function belongs to. + // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. string domain = 10; -} + // The overload identifier of the function. + // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. + string overload = 13; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +} // For using protobuf-lite option optimize_for = LITE_RUNTIME; diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensor.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensor.java index 1e534b78..5dc057c0 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensor.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensor.java @@ -15,6 +15,17 @@ /** * A representation of a dense tensor. Use {@link #getInfo()} to select the proper buffer type. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type. * + * The relevant view will depend on the number of bits of the tensor's type. Consider the following examples: + *