From 622ae52320104cc7cac8cb7999c7617f3bb1a22c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 11:34:05 +0000 Subject: [PATCH 1/8] Initial plan From 4e028c29659fcaee07b05f595a3e24dc05e58263 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 11:47:46 +0000 Subject: [PATCH 2/8] feat(jax): add type hints to base atomic model and model functions Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/jax/atomic_model/base_atomic_model.py | 6 ++- deepmd/jax/model/base_model.py | 40 ++++++++++---------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py index ffd58daf5e..474fcb03c7 100644 --- a/deepmd/jax/atomic_model/base_atomic_model.py +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + from deepmd.jax.common import ( ArrayAPIVariable, to_jax_array, @@ -9,7 +13,7 @@ ) -def base_atomic_model_set_attr(name, value): +def base_atomic_model_set_attr(name: str, value: Any) -> Any: if name in {"out_bias", "out_std"}: value = to_jax_array(value) if value is not None: diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 34ee765459..203da40d07 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -20,7 +20,7 @@ def forward_common_atomic( - self, + self: "BaseModel", extended_coord: jnp.ndarray, extended_atype: jnp.ndarray, nlist: jnp.ndarray, @@ -28,7 +28,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, -): +) -> dict[str, jnp.ndarray]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, extended_atype, @@ -60,16 +60,16 @@ def forward_common_atomic( if vdef.r_differentiable: def eval_output( - cc_ext, - extended_atype, - nlist, - mapping, - fparam, - aparam, + cc_ext: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray], + fparam: Optional[jnp.ndarray], + aparam: Optional[jnp.ndarray], *, - _kk=kk, - _atom_axis=atom_axis, - ): + _kk: str = kk, + _atom_axis: int = atom_axis, + ) -> jnp.ndarray: atomic_ret = self.atomic_model.forward_common_atomic( cc_ext[None, ...], extended_atype[None, ...], @@ -117,16 +117,16 @@ def eval_output( if do_atomic_virial: def eval_ce( - cc_ext, - extended_atype, - nlist, - mapping, - fparam, - aparam, + cc_ext: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray], + fparam: Optional[jnp.ndarray], + aparam: Optional[jnp.ndarray], *, - _kk=kk, - _atom_axis=atom_axis - 1, - ): + _kk: str = kk, + _atom_axis: int = atom_axis - 1, + ) -> jnp.ndarray: # atomic_ret[_kk]: [nf, nloc, *def] atomic_ret = self.atomic_model.forward_common_atomic( cc_ext[None, ...], From 668043d4214704e337a774e95d54f748db038fe2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 12:00:55 +0000 Subject: [PATCH 3/8] feat(jax): enable ANN rule and add type hints to core JAX backend --- deepmd/jax/common.py | 16 ++++++------ deepmd/jax/infer/deep_eval.py | 8 +++--- deepmd/jax/model/dp_model.py | 4 +-- deepmd/jax/model/dp_zbl_model.py | 4 +-- deepmd/jax/model/hlo.py | 42 +++++++++++++++---------------- deepmd/jax/model/model.py | 4 +-- deepmd/jax/utils/neighbor_stat.py | 2 +- deepmd/jax/utils/network.py | 10 +++++--- deepmd/jax/utils/serialization.py | 14 ++++++++--- pyproject.toml | 2 +- 10 files changed, 60 insertions(+), 46 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 59f36d11ad..14ae1cad9d 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -70,11 +70,11 @@ def flax_module( metas.add(type(nnx.Module)) class MixedMetaClass(*metas): - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return type(nnx.Module).__call__(self, *args, **kwargs) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: return super().__init_subclass__(**kwargs) def __setattr__(self, name: str, value: Any) -> None: @@ -84,20 +84,22 @@ def __setattr__(self, name: str, value: Any) -> None: class ArrayAPIVariable(nnx.Variable): - def __array__(self, *args, **kwargs): + def __array__(self, *args: Any, **kwargs: Any) -> np.ndarray: return self.value.__array__(*args, **kwargs) - def __array_namespace__(self, *args, **kwargs): + def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__array_namespace__(*args, **kwargs) - def __dlpack__(self, *args, **kwargs): + def __dlpack__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack__(*args, **kwargs) - def __dlpack_device__(self, *args, **kwargs): + def __dlpack_device__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack_device__(*args, **kwargs) -def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray: +def scatter_sum( + input: jnp.ndarray, dim: int, index: jnp.ndarray, src: jnp.ndarray +) -> jnp.ndarray: """Reduces all values from the src tensor to the indices specified in the index tensor.""" idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape) new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel() diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 2e74c15fff..92ed78a13e 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -301,7 +301,7 @@ def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Calla """ if self.auto_batch_size is not None: - def eval_func(*args, **kwargs): + def eval_func(*args: Any, **kwargs: Any) -> Any: return self.auto_batch_size.execute_all( inner_func, numb_test, natoms, *args, **kwargs ) @@ -335,7 +335,7 @@ def _eval_model( fparam: Optional[np.ndarray], aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], - ): + ) -> tuple[np.ndarray, ...]: model = self.dp nframes = coords.shape[0] @@ -395,7 +395,9 @@ def _eval_model( ) # this is kinda hacky return tuple(results) - def _get_output_shape(self, odef, nframes, natoms): + def _get_output_shape( + self, odef: OutputVariableDef, nframes: int, natoms: int + ) -> list[int]: if odef.category == OutputVariableCategory.DERV_C_REDU: # virial return [nframes, *odef.shape[:-1], 9] diff --git a/deepmd/jax/model/dp_model.py b/deepmd/jax/model/dp_model.py index 436582f22b..ee98a689e4 100644 --- a/deepmd/jax/model/dp_model.py +++ b/deepmd/jax/model/dp_model.py @@ -56,7 +56,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: return forward_common_atomic( self, extended_coord, @@ -74,7 +74,7 @@ def format_nlist( extended_atype: jnp.ndarray, nlist: jnp.ndarray, extra_nlist_sort: bool = False, - ): + ) -> jnp.ndarray: return dpmodel_model.format_nlist( self, jax.lax.stop_gradient(extended_coord), diff --git a/deepmd/jax/model/dp_zbl_model.py b/deepmd/jax/model/dp_zbl_model.py index babbc65233..065dbc7aa7 100644 --- a/deepmd/jax/model/dp_zbl_model.py +++ b/deepmd/jax/model/dp_zbl_model.py @@ -38,7 +38,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: return forward_common_atomic( self, extended_coord, @@ -56,7 +56,7 @@ def format_nlist( extended_atype: jnp.ndarray, nlist: jnp.ndarray, extra_nlist_sort: bool = False, - ): + ) -> jnp.ndarray: return DPZBLModelDP.format_nlist( self, jax.lax.stop_gradient(extended_coord), diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 4d59957456..38f960df98 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -44,21 +44,21 @@ class HLO(BaseModel): def __init__( self, - stablehlo, - stablehlo_atomic_virial, - stablehlo_no_ghost, - stablehlo_atomic_virial_no_ghost, - model_def_script, - type_map, - rcut, - dim_fparam, - dim_aparam, - sel_type, - is_aparam_nall, - model_output_type, - mixed_types, - min_nbor_dist, - sel, + stablehlo: Any, + stablehlo_atomic_virial: Any, + stablehlo_no_ghost: Any, + stablehlo_atomic_virial_no_ghost: Any, + model_def_script: str, + type_map: list[str], + rcut: float, + dim_fparam: int, + dim_aparam: int, + sel_type: list[int], + is_aparam_nall: bool, + model_output_type: str, + mixed_types: bool, + min_nbor_dist: Optional[float], + sel: list[int], ) -> None: self._call_lower = jax_export.deserialize(stablehlo).call self._call_lower_atomic_virial = jax_export.deserialize( @@ -125,7 +125,7 @@ def call( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: """Return model prediction. Parameters @@ -165,7 +165,7 @@ def call( do_atomic_virial=do_atomic_virial, ) - def model_output_def(self): + def model_output_def(self) -> ModelOutputDef: return ModelOutputDef( FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) ) @@ -179,7 +179,7 @@ def call_lower( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: if extended_coord.shape[1] > nlist.shape[1]: if do_atomic_virial: call_lower = self._call_lower_atomic_virial @@ -203,15 +203,15 @@ def get_type_map(self) -> list[str]: """Get the type map.""" return self.type_map - def get_rcut(self): + def get_rcut(self) -> float: """Get the cut-off radius.""" return self.rcut - def get_dim_fparam(self): + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.dim_fparam - def get_dim_aparam(self): + def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.dim_aparam diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index dc350e968c..321f33b315 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -26,7 +26,7 @@ ) -def get_standard_model(data: dict): +def get_standard_model(data: dict) -> BaseModel: """Get a Model from a dictionary. Parameters @@ -103,7 +103,7 @@ def get_zbl_model(data: dict) -> DPZBLModel: ) -def get_model(data: dict): +def get_model(data: dict) -> BaseModel: """Get a model from a dictionary. Parameters diff --git a/deepmd/jax/utils/neighbor_stat.py b/deepmd/jax/utils/neighbor_stat.py index 6d9bc872e8..ddfc4199a3 100644 --- a/deepmd/jax/utils/neighbor_stat.py +++ b/deepmd/jax/utils/neighbor_stat.py @@ -82,7 +82,7 @@ def _execute( coord: np.ndarray, atype: np.ndarray, cell: Optional[np.ndarray], - ): + ) -> tuple[np.ndarray, np.ndarray]: """Execute the operation. Parameters diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 78da4c96f5..5a42323b90 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -4,6 +4,8 @@ ClassVar, ) +import numpy as np + from deepmd.dpmodel.common import ( NativeOP, ) @@ -26,16 +28,16 @@ class ArrayAPIParam(nnx.Param): - def __array__(self, *args, **kwargs): + def __array__(self, *args: Any, **kwargs: Any) -> np.ndarray: return self.value.__array__(*args, **kwargs) - def __array_namespace__(self, *args, **kwargs): + def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__array_namespace__(*args, **kwargs) - def __dlpack__(self, *args, **kwargs): + def __dlpack__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack__(*args, **kwargs) - def __dlpack_device__(self, *args, **kwargs): + def __dlpack_device__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack_device__(*args, **kwargs) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 5d4da49e08..93d42f5d31 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -2,6 +2,9 @@ from pathlib import ( Path, ) +from typing import ( + Any, +) import numpy as np import orbax.checkpoint as ocp @@ -55,10 +58,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial( do_atomic_virial: bool, has_ghost_atoms: bool - ): + ) -> Any: def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam - ): + coord: jnp.ndarray, + atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: jnp.ndarray, + fparam: jnp.ndarray, + aparam: jnp.ndarray, + ) -> dict[str, jnp.ndarray]: return call_lower( coord, atype, diff --git a/pyproject.toml b/pyproject.toml index 554a113ca2..3d0b7f1b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -424,7 +424,7 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "data/**" = ["ANN"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253"] -"deepmd/jax/**" = ["TID253", "ANN"] +"deepmd/jax/**" = ["TID253"] "deepmd/pd/**" = ["TID253", "ANN"] "source/**" = ["ANN"] From 4064b3b0f951169d960994950be8bf3cbdea4ab1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 12:31:27 +0000 Subject: [PATCH 4/8] feat(jax): add comprehensive type hints to jax2tf interop code Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/jax/jax2tf/format_nlist.py | 2 +- deepmd/jax/jax2tf/make_model.py | 2 +- deepmd/jax/jax2tf/nlist.py | 6 +- deepmd/jax/jax2tf/region.py | 2 +- deepmd/jax/jax2tf/serialization.py | 62 +- deepmd/jax/jax2tf/tfmodel.py | 14 +- pyproject.toml | 1 + source/3rdparty/implib/implib-gen.py | 1093 ++++++++++++++------------ 8 files changed, 640 insertions(+), 542 deletions(-) diff --git a/deepmd/jax/jax2tf/format_nlist.py b/deepmd/jax/jax2tf/format_nlist.py index f0c630206f..5cf93610e7 100644 --- a/deepmd/jax/jax2tf/format_nlist.py +++ b/deepmd/jax/jax2tf/format_nlist.py @@ -9,7 +9,7 @@ def format_nlist( nlist: tnp.ndarray, nsel: int, rcut: float, -): +) -> tnp.ndarray: """Format neighbor list. If nnei == nsel, do nothing; diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index 29ed131f8e..341fdf0d1f 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -44,7 +44,7 @@ def model_call_from_call_lower( fparam: tnp.ndarray, aparam: tnp.ndarray, do_atomic_virial: bool = False, -): +) -> dict[str, tnp.ndarray]: """Return model prediction from lower interface. Parameters diff --git a/deepmd/jax/jax2tf/nlist.py b/deepmd/jax/jax2tf/nlist.py index 5a0ed58b63..f85526f1e9 100644 --- a/deepmd/jax/jax2tf/nlist.py +++ b/deepmd/jax/jax2tf/nlist.py @@ -115,7 +115,7 @@ def nlist_distinguish_types( nlist: tnp.ndarray, atype: tnp.ndarray, sel: list[int], -): +) -> tnp.ndarray: """Given a nlist that does not distinguish atom types, return a nlist that distinguish atom types. @@ -140,7 +140,7 @@ def nlist_distinguish_types( return ret -def tf_outer(a, b): +def tf_outer(a: tnp.ndarray, b: tnp.ndarray) -> tnp.ndarray: return tf.einsum("i,j->ij", a, b) @@ -150,7 +150,7 @@ def extend_coord_with_ghosts( atype: tnp.ndarray, cell: tnp.ndarray, rcut: float, -): +) -> tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray]: """Extend the coordinates of the atoms by appending peridoc images. The number of images is large enough to ensure all the neighbors within rcut are appended. diff --git a/deepmd/jax/jax2tf/region.py b/deepmd/jax/jax2tf/region.py index 96024bd79a..a90e693478 100644 --- a/deepmd/jax/jax2tf/region.py +++ b/deepmd/jax/jax2tf/region.py @@ -93,7 +93,7 @@ def to_face_distance( return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0)) -def b_to_face_distance(cell): +def b_to_face_distance(cell: tnp.ndarray) -> tnp.ndarray: volume = tf.linalg.det(cell) c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...]) _h2yz = volume / tf.linalg.norm(c_yz, axis=-1) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index aac022ace9..096fc41e5a 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json from typing import ( + Callable, Optional, ) @@ -38,10 +39,17 @@ def deserialize_to_file(model_file: str, data: dict) -> None: tf_model = tf.Module() - def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms): + def exported_whether_do_atomic_virial( + do_atomic_virial: bool, has_ghost_atoms: bool + ) -> Callable: def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam - ): + coord: tnp.ndarray, + atype: tnp.ndarray, + nlist: tnp.ndarray, + mapping: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ) -> dict[str, tnp.ndarray]: return call_lower( coord, atype, @@ -86,8 +94,13 @@ def call_lower_with_fixed_do_atomic_virial( ], ) def call_lower_without_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam - ): + coord: tnp.ndarray, + atype: tnp.ndarray, + nlist: tnp.ndarray, + mapping: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ) -> dict[str, tnp.ndarray]: nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], @@ -112,7 +125,14 @@ def call_lower_without_atomic_virial( tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), ], ) - def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + def call_lower_with_atomic_virial( + coord: tnp.ndarray, + atype: tnp.ndarray, + nlist: tnp.ndarray, + mapping: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ) -> dict[str, tnp.ndarray]: nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], @@ -126,7 +146,7 @@ def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial - def make_call_whether_do_atomic_virial(do_atomic_virial: bool): + def make_call_whether_do_atomic_virial(do_atomic_virial: bool) -> Callable: if do_atomic_virial: call_lower = call_lower_with_atomic_virial else: @@ -138,7 +158,7 @@ def call( box: Optional[tnp.ndarray] = None, fparam: Optional[tnp.ndarray] = None, aparam: Optional[tnp.ndarray] = None, - ): + ) -> dict[str, tnp.ndarray]: """Return model prediction. Parameters @@ -194,7 +214,7 @@ def call_with_atomic_virial( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, - ): + ) -> dict[str, tnp.ndarray]: return make_call_whether_do_atomic_virial(do_atomic_virial=True)( coord, atype, box, fparam, aparam ) @@ -217,7 +237,7 @@ def call_without_atomic_virial( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, - ): + ) -> dict[str, tnp.ndarray]: return make_call_whether_do_atomic_virial(do_atomic_virial=False)( coord, atype, box, fparam, aparam ) @@ -226,49 +246,49 @@ def call_without_atomic_virial( # set functions to export other attributes @tf.function - def get_type_map(): + def get_type_map() -> tf.Tensor: return tf.constant(model.get_type_map(), dtype=tf.string) tf_model.get_type_map = get_type_map @tf.function - def get_rcut(): + def get_rcut() -> tf.Tensor: return tf.constant(model.get_rcut(), dtype=tf.double) tf_model.get_rcut = get_rcut @tf.function - def get_dim_fparam(): + def get_dim_fparam() -> tf.Tensor: return tf.constant(model.get_dim_fparam(), dtype=tf.int64) tf_model.get_dim_fparam = get_dim_fparam @tf.function - def get_dim_aparam(): + def get_dim_aparam() -> tf.Tensor: return tf.constant(model.get_dim_aparam(), dtype=tf.int64) tf_model.get_dim_aparam = get_dim_aparam @tf.function - def get_sel_type(): + def get_sel_type() -> tf.Tensor: return tf.constant(model.get_sel_type(), dtype=tf.int64) tf_model.get_sel_type = get_sel_type @tf.function - def is_aparam_nall(): + def is_aparam_nall() -> tf.Tensor: return tf.constant(model.is_aparam_nall(), dtype=tf.bool) tf_model.is_aparam_nall = is_aparam_nall @tf.function - def model_output_type(): + def model_output_type() -> tf.Tensor: return tf.constant(model.model_output_type(), dtype=tf.string) tf_model.model_output_type = model_output_type @tf.function - def mixed_types(): + def mixed_types() -> tf.Tensor: return tf.constant(model.mixed_types(), dtype=tf.bool) tf_model.mixed_types = mixed_types @@ -276,19 +296,19 @@ def mixed_types(): if model.get_min_nbor_dist() is not None: @tf.function - def get_min_nbor_dist(): + def get_min_nbor_dist() -> tf.Tensor: return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) tf_model.get_min_nbor_dist = get_min_nbor_dist @tf.function - def get_sel(): + def get_sel() -> tf.Tensor: return tf.constant(model.get_sel(), dtype=tf.int64) tf_model.get_sel = get_sel @tf.function - def get_model_def_script(): + def get_model_def_script() -> tf.Tensor: return tf.constant( json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 0d7b13ba1f..713e023e69 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -45,7 +45,7 @@ def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]: class TFModelWrapper(tf.Module): def __init__( self, - model, + model: str, ) -> None: self.model = tf.saved_model.load(model) self._call_lower = jax2tf.call_tf(self.model.call_lower) @@ -115,7 +115,7 @@ def call( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> Any: """Return model prediction. Parameters @@ -165,7 +165,7 @@ def call( aparam, ) - def model_output_def(self): + def model_output_def(self) -> ModelOutputDef: return ModelOutputDef( FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) ) @@ -179,7 +179,7 @@ def call_lower( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> Any: if do_atomic_virial: call_lower = self._call_lower_atomic_virial else: @@ -207,15 +207,15 @@ def get_type_map(self) -> list[str]: """Get the type map.""" return self.type_map - def get_rcut(self): + def get_rcut(self) -> float: """Get the cut-off radius.""" return self.rcut - def get_dim_fparam(self): + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.dim_fparam - def get_dim_aparam(self): + def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.dim_aparam diff --git a/pyproject.toml b/pyproject.toml index 3d0b7f1b9d..8feb5b08fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -425,6 +425,7 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253"] "deepmd/jax/**" = ["TID253"] +"deepmd/jax/jax2tf/**" = ["TID253", "ANN"] "deepmd/pd/**" = ["TID253", "ANN"] "source/**" = ["ANN"] diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 86cfa77378..3a51be271d 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,577 +22,654 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) + def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f'{me}: warning: {msg}\n') + """Emits a nicely-decorated warning.""" + sys.stderr.write(f"{me}: warning: {msg}\n") + def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f'{me}: error: {msg}\n') - sys.exit(1) - -def run(args, stdin=''): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env['LC_ALL'] = 'c' - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, env=env) as p: - out, err = p.communicate(input=stdin.encode('utf-8')) - out = out.decode('utf-8') - err = err.decode('utf-8') - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f"{me}: error: {msg}\n") + sys.exit(1) + + +def run(args, stdin=""): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env["LC_ALL"] = "c" + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) as p: + out, err = p.communicate(input=stdin.encode("utf-8")) + out = out.decode("utf-8") + err = err.decode("utf-8") + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err + def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc + def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals + def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(['readelf', '-sW', f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r' +', line) - if line.startswith('Num'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(':', ''), words)) - elif toc is not None: - sym = parse_row(words, toc, ['Value']) - name = sym['Name'] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format - if '@' in name: - sym['Default'] = '@@' in name - name, ver = re.split(r'@+', name) - sym['Name'] = name - sym['Version'] = ver - else: - sym['Default'] = True - sym['Version'] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]['Demangled Name'] = name - - return syms + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(["readelf", "-sW", f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r" +", line) + if line.startswith("Num"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(":", ""), words)) + elif toc is not None: + sym = parse_row(words, toc, ["Value"]) + name = sym["Name"] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format + if "@" in name: + sym["Default"] = "@@" in name + name, ver = re.split(r"@+", name) + sym["Name"] = name + sym["Version"] = ver + else: + sym["Default"] = True + sym["Version"] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]["Demangled Name"] = name + + return syms + def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(['readelf', '-rW', f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == 'There are no relocations in this file.': - return [] - if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS - continue - if re.match(r'^\s*Offset', line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r' \+ ', '+', line) - words = re.split(r'\s+', line) - rel = parse_row(words, toc, ['Offset', 'Info']) - rels.append(rel) - # Split symbolic representation - sym_name = 'Symbol\'s Name + Addend' - if sym_name not in rel and 'Symbol\'s Name' in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel['Symbol\'s Name'] + '+0' - if rel[sym_name]: - p = rel[sym_name].split('+') - if len(p) == 1: - p = ['', p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels + """Collect ELF dynamic relocs.""" + + out, _ = run(["readelf", "-rW", f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == "There are no relocations in this file.": + return [] + if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS + continue + if re.match(r"^\s*Offset", line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r" \+ ", "+", line) + words = re.split(r"\s+", line) + rel = parse_row(words, toc, ["Offset", "Info"]) + rels.append(rel) + # Split symbolic representation + sym_name = "Symbol's Name + Addend" + if sym_name not in rel and "Symbol's Name" in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel["Symbol's Name"] + "+0" + if rel[sym_name]: + p = rel[sym_name].split("+") + if len(p) == 1: + p = ["", p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels + def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(['readelf', '-SW', f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r'\[\s+', '[', line) - words = re.split(r' +', line) - if line.startswith('[Nr]'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {'Addr' : 'Address'}) - elif line.startswith('[') and toc is not None: - sec = parse_row(words, toc, ['Address', 'Off', 'Size']) - if 'A' in sec['Flg']: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections + """Collect section info from ELF.""" + + out, _ = run(["readelf", "-SW", f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r"\[\s+", "[", line) + words = re.split(r" +", line) + if line.startswith("[Nr]"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {"Addr": "Address"}) + elif line.startswith("[") and toc is not None: + sec = parse_row(words, toc, ["Address", "Off", "Size"]) + if "A" in sec["Flg"]: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections + def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, 'rb') as f: - def is_symbol_in_section(sym, sec): - sec_end = sec['Address'] + sec['Size'] - is_start_in_section = sec['Address'] <= sym['Value'] < sec_end - is_end_in_section = sym['Value'] + sym['Size'] <= sec_end - return is_start_in_section and is_end_in_section - for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") - sec = sec[0] - f.seek(sec['Off']) - data[name] = f.read(s['Size']) - return data + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, "rb") as f: + + def is_symbol_in_section(sym, sec): + sec_end = sec["Address"] + sec["Size"] + is_start_in_section = sec["Address"] <= sym["Value"] < sec_end + is_end_in_section = sym["Value"] + sym["Size"] <= sec_end + return is_start_in_section and is_end_in_section + + for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error( + f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" + ) + sec = sec[0] + f.seek(sec["Off"]) + data[name] = f.read(s["Size"]) + return data + def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s['Demangled Name'].startswith('typeinfo name'): - data[name] = [('byte', int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') - data[name].append(('offset', val)) - start = s['Value'] - finish = start + s['Size'] - # TODO: binary search (bisect) - for rel in rels: - if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: - i = (rel['Offset'] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = 'reloc', rel - return data + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s["Demangled Name"].startswith("typeinfo name"): + data[name] = [("byte", int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes( + b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" + ) + data[name].append(("offset", val)) + start = s["Value"] + finish = start + s["Size"] + # TODO: binary search (bisect) + for rel in rels: + if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: + i = (rel["Offset"] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = "reloc", rel + return data + def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = { - 'reloc' : 'const void *', - 'byte' : 'unsigned char', - 'offset' : 'size_t' - } - - ss = [] - ss.append('''\ + """Generate code for vtables""" + c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} + + ss = [] + ss.append("""\ #ifdef __cplusplus extern "C" { #endif -''') +""") - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != 'reloc': - continue - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f'''\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != "reloc": + continue + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f"""\ extern const char {sym_name}[]; -''') +""") - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s['Demangled Name'].startswith('typeinfo name'): - declarator = 'const unsigned char %s[]' - else: - field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) - declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != 'reloc': - vals.append(str(val) + 'UL') - else: - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - vals.append(f'(const char *)&{sym_name} + {addend}') - code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + '_type' - type_decl = decl % type_name - ss.append(f'''\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s["Demangled Name"].startswith("typeinfo name"): + declarator = "const unsigned char %s[]" + else: + field_types = ( + f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) + ) + declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != "reloc": + vals.append(str(val) + "UL") + else: + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + vals.append(f"(const char *)&{sym_name} + {addend}") + code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + "_type" + type_decl = decl % type_name + ss.append(f"""\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -''') +""") - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + '_type' - ss.append(f'''\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + "_type" + ss.append(f"""\ const {type_name} {name} = {init}; -''') +""") - ss.append('''\ + ss.append("""\ #ifdef __cplusplus } // extern "C" #endif -''') +""") + + return "".join(ss) - return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" + """Read ELF's SONAME.""" + + out, _ = run(["readelf", "-d", f]) - out, _ = run(['readelf', '-d', f]) + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) + if soname_match is not None: + return soname_match[1] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) - if soname_match is not None: - return soname_match[1] + return None - return None def main(): - """Driver function""" - parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser( + description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""") - - parser.add_argument('library', - metavar='LIB', - help="Library to be wrapped.") - parser.add_argument('--verbose', '-v', - help="Print diagnostic info", - action='count', - default=0) - parser.add_argument('--dlopen', - help="Emit dlopen call (default)", - dest='dlopen', action='store_true', default=True) - parser.add_argument('--no-dlopen', - help="Do not emit dlopen call (user must load/unload library himself)", - dest='dlopen', action='store_false') - parser.add_argument('--dlopen-callback', - help="Call user-provided custom callback to load library instead of dlopen", - default='') - parser.add_argument('--dlsym-callback', - help="Call user-provided custom callback to resolve a symbol, " - "instead of dlsym", - default='') - parser.add_argument('--library-load-name', - help="Use custom name for dlopened library (default is SONAME)") - parser.add_argument('--lazy-load', - help="Load library on first call to any of it's functions (default)", - dest='lazy_load', action='store_true', default=True) - parser.add_argument('--no-lazy-load', - help="Load library at program start", - dest='lazy_load', action='store_false') - parser.add_argument('--vtables', - help="Intercept virtual tables (EXPERIMENTAL)", - dest='vtables', action='store_true', default=False) - parser.add_argument('--no-vtables', - help="Do not intercept virtual tables (default)", - dest='vtables', action='store_false') - parser.add_argument('--no-weak-symbols', - help="Don't bind weak symbols", dest='no_weak_symbols', - action='store_true', default=False) - parser.add_argument('--target', - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1]) - parser.add_argument('--symbol-list', - help="Path to file with symbols that should be present in wrapper " - "(all by default)") - parser.add_argument('--symbol-prefix', - metavar='PFX', - help="Prefix wrapper symbols with PFX", - default='') - parser.add_argument('-q', '--quiet', - help="Do not print progress info", - action='store_true') - parser.add_argument('--outdir', '-o', - help="Path to create wrapper at", - default='./') - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith('arm'): - target = 'arm' # Handle armhf-..., armel-... - elif re.match(r'^i[0-9]86', args.target): - target = 'i386' - elif args.target.startswith('mips64'): - target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith('mips'): - target = 'mips' # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split('-')[0] - quiet = args.quiet - outdir = args.outdir - - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, 'r') as f: - funs = [] - for line in re.split(r'\r?\n', f.read()): - line = re.sub(r'#.*', '', line) - line = line.strip() - if line: - funs.append(line) +""", + ) + + parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") + parser.add_argument( + "--verbose", "-v", help="Print diagnostic info", action="count", default=0 + ) + parser.add_argument( + "--dlopen", + help="Emit dlopen call (default)", + dest="dlopen", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-dlopen", + help="Do not emit dlopen call (user must load/unload library himself)", + dest="dlopen", + action="store_false", + ) + parser.add_argument( + "--dlopen-callback", + help="Call user-provided custom callback to load library instead of dlopen", + default="", + ) + parser.add_argument( + "--dlsym-callback", + help="Call user-provided custom callback to resolve a symbol, instead of dlsym", + default="", + ) + parser.add_argument( + "--library-load-name", + help="Use custom name for dlopened library (default is SONAME)", + ) + parser.add_argument( + "--lazy-load", + help="Load library on first call to any of it's functions (default)", + dest="lazy_load", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-lazy-load", + help="Load library at program start", + dest="lazy_load", + action="store_false", + ) + parser.add_argument( + "--vtables", + help="Intercept virtual tables (EXPERIMENTAL)", + dest="vtables", + action="store_true", + default=False, + ) + parser.add_argument( + "--no-vtables", + help="Do not intercept virtual tables (default)", + dest="vtables", + action="store_false", + ) + parser.add_argument( + "--no-weak-symbols", + help="Don't bind weak symbols", + dest="no_weak_symbols", + action="store_true", + default=False, + ) + parser.add_argument( + "--target", + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1], + ) + parser.add_argument( + "--symbol-list", + help="Path to file with symbols that should be present in wrapper " + "(all by default)", + ) + parser.add_argument( + "--symbol-prefix", + metavar="PFX", + help="Prefix wrapper symbols with PFX", + default="", + ) + parser.add_argument( + "-q", "--quiet", help="Do not print progress info", action="store_true" + ) + parser.add_argument( + "--outdir", "-o", help="Path to create wrapper at", default="./" + ) + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith("arm"): + target = "arm" # Handle armhf-..., armel-... + elif re.match(r"^i[0-9]86", args.target): + target = "i386" + elif args.target.startswith("mips64"): + target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith("mips"): + target = "mips" # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split("-")[0] + quiet = args.quiet + outdir = args.outdir - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, "r") as f: + funs = [] + for line in re.split(r"\r?\n", f.read()): + line = re.sub(r"#.*", "", line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, 'arch', target) + target_dir = os.path.join(root, "arch", target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=';') - cfg.read(target_dir + '/config.ini') + cfg = configparser.ConfigParser(inline_comment_prefixes=";") + cfg.read(target_dir + "/config.ini") - ptr_size = int(cfg['Arch']['PointerSize']) - symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) + ptr_size = int(cfg["Arch"]["PointerSize"]) + symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) - def is_exported(s): - conditions = [ - s['Bind'] != 'LOCAL', - s['Type'] != 'NOTYPE', - s['Ndx'] != 'UND', - s['Name'] not in ['', '_init', '_fini']] - if args.no_weak_symbols: - conditions.append(s['Bind'] != 'WEAK') - return all(conditions) + def is_exported(s): + conditions = [ + s["Bind"] != "LOCAL", + s["Type"] != "NOTYPE", + s["Ndx"] != "UND", + s["Name"] not in ["", "_init", "_fini"], + ] + if args.no_weak_symbols: + conditions.append(s["Bind"] != "WEAK") + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return (s['Type'] == 'OBJECT' + def is_data_symbol(s): + return ( + s["Type"] == "OBJECT" # Allow vtables if --vtables is on - and not (' for ' in s['Demangled Name'] and args.vtables)) - - exported_data = [s['Name'] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn(f"library '{input_name}' contains data symbols which won't be intercepted: " - + ', '.join(exported_data)) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s['Default']: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s['Name']) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) - funs = [name for name in funs if name in all_funs] - - if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") - - # Collect vtables - - if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s['Name'] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") + and not (" for " in s["Demangled Name"] and args.vtables) + ) + + exported_data = [s["Name"] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn( + f"library '{input_name}' contains data symbols which won't be intercepted: " + + ", ".join(exported_data) + ) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s["Default"]: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s["Name"]) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn( + "some user-specified functions are not present in library: " + + ", ".join(missing_funs) + ) + funs = [name for name in funs if name in all_funs] - secs = collect_sections(input_name) if verbose: - print("Sections:") - for sec in secs: - print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}") + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") - bites = read_unrelocated_data(input_name, cls_syms, secs) + # Collect vtables - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel['Symbol\'s Name + Addend'] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]['Demangled Name'] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) - - tramp_file = f'{suffix}.tramp.S' - with open(os.path.join(outdir, tramp_file), 'w') as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + '/table.S.tpl', 'r') as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - table_size=ptr_size*(len(funs) + 1)) - f.write(table_text) - - with open(target_dir + '/trampoline.S.tpl', 'r') as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i*ptr_size, - number=i) - f.write(tramp_text) - - # Generate C code - - init_file = f'{suffix}.init.c' - with open(os.path.join(outdir, init_file), 'w') as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: - if funs: - sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' - else: - sym_names = '' - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names) - f.write(init_text) if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - -if __name__ == '__main__': - main() + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match( + r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] + ) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s["Name"] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + + secs = collect_sections(input_name) + if verbose: + print("Sections:") + for sec in secs: + print( + f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}" + ) + + bites = read_unrelocated_data(input_name, cls_syms, secs) + + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel["Symbol's Name + Addend"] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data( + cls_syms, bites, rels, ptr_size, symbol_reloc_types + ) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]["Demangled Name"] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print( + " " + + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) + ) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) + + tramp_file = f"{suffix}.tramp.S" + with open(os.path.join(outdir, tramp_file), "w") as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + "/table.S.tpl", "r") as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) + ) + f.write(table_text) + + with open(target_dir + "/trampoline.S.tpl", "r") as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i * ptr_size, + number=i, + ) + f.write(tramp_text) + + # Generate C code + + init_file = f"{suffix}.init.c" + with open(os.path.join(outdir, init_file), "w") as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: + if funs: + sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," + else: + sym_names = "" + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names, + ) + f.write(init_text) + if args.vtables: + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + + +if __name__ == "__main__": + main() From c1c0b2c91506b9ed70fc8474e81c096c494cd06d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 12:59:27 +0000 Subject: [PATCH 5/8] feat(jax): enable ANN rule and add comprehensive type hints to JAX backend Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/jax/utils/serialization.py | 5 +- pyproject.toml | 1 - source/3rdparty/implib/implib-gen.py | 1093 ++++++++++++-------------- 3 files changed, 509 insertions(+), 590 deletions(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 93d42f5d31..829dafae86 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -2,9 +2,6 @@ from pathlib import ( Path, ) -from typing import ( - Any, -) import numpy as np import orbax.checkpoint as ocp @@ -58,7 +55,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial( do_atomic_virial: bool, has_ghost_atoms: bool - ) -> Any: + ) -> jax_export.Exported: def call_lower_with_fixed_do_atomic_virial( coord: jnp.ndarray, atype: jnp.ndarray, diff --git a/pyproject.toml b/pyproject.toml index 8feb5b08fc..3d0b7f1b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -425,7 +425,6 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253"] "deepmd/jax/**" = ["TID253"] -"deepmd/jax/jax2tf/**" = ["TID253", "ANN"] "deepmd/pd/**" = ["TID253", "ANN"] "source/**" = ["ANN"] diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 3a51be271d..86cfa77378 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,654 +22,577 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) - def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f"{me}: warning: {msg}\n") - + """Emits a nicely-decorated warning.""" + sys.stderr.write(f'{me}: warning: {msg}\n') def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f"{me}: error: {msg}\n") - sys.exit(1) - - -def run(args, stdin=""): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env["LC_ALL"] = "c" - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen( - args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) as p: - out, err = p.communicate(input=stdin.encode("utf-8")) - out = out.decode("utf-8") - err = err.decode("utf-8") - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err - + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f'{me}: error: {msg}\n') + sys.exit(1) + +def run(args, stdin=''): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env['LC_ALL'] = 'c' + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, env=env) as p: + out, err = p.communicate(input=stdin.encode('utf-8')) + out = out.decode('utf-8') + err = err.decode('utf-8') + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc - + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals - + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(["readelf", "-sW", f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r" +", line) - if line.startswith("Num"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(":", ""), words)) - elif toc is not None: - sym = parse_row(words, toc, ["Value"]) - name = sym["Name"] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format - if "@" in name: - sym["Default"] = "@@" in name - name, ver = re.split(r"@+", name) - sym["Name"] = name - sym["Version"] = ver - else: - sym["Default"] = True - sym["Version"] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]["Demangled Name"] = name - - return syms - + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(['readelf', '-sW', f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r' +', line) + if line.startswith('Num'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(':', ''), words)) + elif toc is not None: + sym = parse_row(words, toc, ['Value']) + name = sym['Name'] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format + if '@' in name: + sym['Default'] = '@@' in name + name, ver = re.split(r'@+', name) + sym['Name'] = name + sym['Version'] = ver + else: + sym['Default'] = True + sym['Version'] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]['Demangled Name'] = name + + return syms def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(["readelf", "-rW", f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == "There are no relocations in this file.": - return [] - if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS - continue - if re.match(r"^\s*Offset", line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r" \+ ", "+", line) - words = re.split(r"\s+", line) - rel = parse_row(words, toc, ["Offset", "Info"]) - rels.append(rel) - # Split symbolic representation - sym_name = "Symbol's Name + Addend" - if sym_name not in rel and "Symbol's Name" in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel["Symbol's Name"] + "+0" - if rel[sym_name]: - p = rel[sym_name].split("+") - if len(p) == 1: - p = ["", p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels - + """Collect ELF dynamic relocs.""" + + out, _ = run(['readelf', '-rW', f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == 'There are no relocations in this file.': + return [] + if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS + continue + if re.match(r'^\s*Offset', line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r' \+ ', '+', line) + words = re.split(r'\s+', line) + rel = parse_row(words, toc, ['Offset', 'Info']) + rels.append(rel) + # Split symbolic representation + sym_name = 'Symbol\'s Name + Addend' + if sym_name not in rel and 'Symbol\'s Name' in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel['Symbol\'s Name'] + '+0' + if rel[sym_name]: + p = rel[sym_name].split('+') + if len(p) == 1: + p = ['', p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(["readelf", "-SW", f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r"\[\s+", "[", line) - words = re.split(r" +", line) - if line.startswith("[Nr]"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {"Addr": "Address"}) - elif line.startswith("[") and toc is not None: - sec = parse_row(words, toc, ["Address", "Off", "Size"]) - if "A" in sec["Flg"]: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections - + """Collect section info from ELF.""" + + out, _ = run(['readelf', '-SW', f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r'\[\s+', '[', line) + words = re.split(r' +', line) + if line.startswith('[Nr]'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {'Addr' : 'Address'}) + elif line.startswith('[') and toc is not None: + sec = parse_row(words, toc, ['Address', 'Off', 'Size']) + if 'A' in sec['Flg']: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, "rb") as f: - - def is_symbol_in_section(sym, sec): - sec_end = sec["Address"] + sec["Size"] - is_start_in_section = sec["Address"] <= sym["Value"] < sec_end - is_end_in_section = sym["Value"] + sym["Size"] <= sec_end - return is_start_in_section and is_end_in_section - - for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error( - f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" - ) - sec = sec[0] - f.seek(sec["Off"]) - data[name] = f.read(s["Size"]) - return data - + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, 'rb') as f: + def is_symbol_in_section(sym, sec): + sec_end = sec['Address'] + sec['Size'] + is_start_in_section = sec['Address'] <= sym['Value'] < sec_end + is_end_in_section = sym['Value'] + sym['Size'] <= sec_end + return is_start_in_section and is_end_in_section + for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") + sec = sec[0] + f.seek(sec['Off']) + data[name] = f.read(s['Size']) + return data def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s["Demangled Name"].startswith("typeinfo name"): - data[name] = [("byte", int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes( - b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" - ) - data[name].append(("offset", val)) - start = s["Value"] - finish = start + s["Size"] - # TODO: binary search (bisect) - for rel in rels: - if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: - i = (rel["Offset"] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = "reloc", rel - return data - + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s['Demangled Name'].startswith('typeinfo name'): + data[name] = [('byte', int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') + data[name].append(('offset', val)) + start = s['Value'] + finish = start + s['Size'] + # TODO: binary search (bisect) + for rel in rels: + if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: + i = (rel['Offset'] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = 'reloc', rel + return data def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} - - ss = [] - ss.append("""\ + """Generate code for vtables""" + c_types = { + 'reloc' : 'const void *', + 'byte' : 'unsigned char', + 'offset' : 'size_t' + } + + ss = [] + ss.append('''\ #ifdef __cplusplus extern "C" { #endif -""") +''') - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != "reloc": - continue - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f"""\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != 'reloc': + continue + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f'''\ extern const char {sym_name}[]; -""") +''') - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s["Demangled Name"].startswith("typeinfo name"): - declarator = "const unsigned char %s[]" - else: - field_types = ( - f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) - ) - declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != "reloc": - vals.append(str(val) + "UL") - else: - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - vals.append(f"(const char *)&{sym_name} + {addend}") - code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + "_type" - type_decl = decl % type_name - ss.append(f"""\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s['Demangled Name'].startswith('typeinfo name'): + declarator = 'const unsigned char %s[]' + else: + field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) + declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != 'reloc': + vals.append(str(val) + 'UL') + else: + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + vals.append(f'(const char *)&{sym_name} + {addend}') + code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + '_type' + type_decl = decl % type_name + ss.append(f'''\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -""") +''') - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + "_type" - ss.append(f"""\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + '_type' + ss.append(f'''\ const {type_name} {name} = {init}; -""") +''') - ss.append("""\ + ss.append('''\ #ifdef __cplusplus } // extern "C" #endif -""") - - return "".join(ss) +''') + return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" - - out, _ = run(["readelf", "-d", f]) + """Read ELF's SONAME.""" - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) - if soname_match is not None: - return soname_match[1] + out, _ = run(['readelf', '-d', f]) - return None + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) + if soname_match is not None: + return soname_match[1] + return None def main(): - """Driver function""" - parser = argparse.ArgumentParser( - description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""", - ) - - parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") - parser.add_argument( - "--verbose", "-v", help="Print diagnostic info", action="count", default=0 - ) - parser.add_argument( - "--dlopen", - help="Emit dlopen call (default)", - dest="dlopen", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-dlopen", - help="Do not emit dlopen call (user must load/unload library himself)", - dest="dlopen", - action="store_false", - ) - parser.add_argument( - "--dlopen-callback", - help="Call user-provided custom callback to load library instead of dlopen", - default="", - ) - parser.add_argument( - "--dlsym-callback", - help="Call user-provided custom callback to resolve a symbol, instead of dlsym", - default="", - ) - parser.add_argument( - "--library-load-name", - help="Use custom name for dlopened library (default is SONAME)", - ) - parser.add_argument( - "--lazy-load", - help="Load library on first call to any of it's functions (default)", - dest="lazy_load", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-lazy-load", - help="Load library at program start", - dest="lazy_load", - action="store_false", - ) - parser.add_argument( - "--vtables", - help="Intercept virtual tables (EXPERIMENTAL)", - dest="vtables", - action="store_true", - default=False, - ) - parser.add_argument( - "--no-vtables", - help="Do not intercept virtual tables (default)", - dest="vtables", - action="store_false", - ) - parser.add_argument( - "--no-weak-symbols", - help="Don't bind weak symbols", - dest="no_weak_symbols", - action="store_true", - default=False, - ) - parser.add_argument( - "--target", - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1], - ) - parser.add_argument( - "--symbol-list", - help="Path to file with symbols that should be present in wrapper " - "(all by default)", - ) - parser.add_argument( - "--symbol-prefix", - metavar="PFX", - help="Prefix wrapper symbols with PFX", - default="", - ) - parser.add_argument( - "-q", "--quiet", help="Do not print progress info", action="store_true" - ) - parser.add_argument( - "--outdir", "-o", help="Path to create wrapper at", default="./" - ) - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith("arm"): - target = "arm" # Handle armhf-..., armel-... - elif re.match(r"^i[0-9]86", args.target): - target = "i386" - elif args.target.startswith("mips64"): - target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith("mips"): - target = "mips" # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split("-")[0] - quiet = args.quiet - outdir = args.outdir +""") - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, "r") as f: - funs = [] - for line in re.split(r"\r?\n", f.read()): - line = re.sub(r"#.*", "", line) - line = line.strip() - if line: - funs.append(line) - - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + parser.add_argument('library', + metavar='LIB', + help="Library to be wrapped.") + parser.add_argument('--verbose', '-v', + help="Print diagnostic info", + action='count', + default=0) + parser.add_argument('--dlopen', + help="Emit dlopen call (default)", + dest='dlopen', action='store_true', default=True) + parser.add_argument('--no-dlopen', + help="Do not emit dlopen call (user must load/unload library himself)", + dest='dlopen', action='store_false') + parser.add_argument('--dlopen-callback', + help="Call user-provided custom callback to load library instead of dlopen", + default='') + parser.add_argument('--dlsym-callback', + help="Call user-provided custom callback to resolve a symbol, " + "instead of dlsym", + default='') + parser.add_argument('--library-load-name', + help="Use custom name for dlopened library (default is SONAME)") + parser.add_argument('--lazy-load', + help="Load library on first call to any of it's functions (default)", + dest='lazy_load', action='store_true', default=True) + parser.add_argument('--no-lazy-load', + help="Load library at program start", + dest='lazy_load', action='store_false') + parser.add_argument('--vtables', + help="Intercept virtual tables (EXPERIMENTAL)", + dest='vtables', action='store_true', default=False) + parser.add_argument('--no-vtables', + help="Do not intercept virtual tables (default)", + dest='vtables', action='store_false') + parser.add_argument('--no-weak-symbols', + help="Don't bind weak symbols", dest='no_weak_symbols', + action='store_true', default=False) + parser.add_argument('--target', + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1]) + parser.add_argument('--symbol-list', + help="Path to file with symbols that should be present in wrapper " + "(all by default)") + parser.add_argument('--symbol-prefix', + metavar='PFX', + help="Prefix wrapper symbols with PFX", + default='') + parser.add_argument('-q', '--quiet', + help="Do not print progress info", + action='store_true') + parser.add_argument('--outdir', '-o', + help="Path to create wrapper at", + default='./') + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith('arm'): + target = 'arm' # Handle armhf-..., armel-... + elif re.match(r'^i[0-9]86', args.target): + target = 'i386' + elif args.target.startswith('mips64'): + target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith('mips'): + target = 'mips' # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split('-')[0] + quiet = args.quiet + outdir = args.outdir + + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, 'r') as f: + funs = [] + for line in re.split(r'\r?\n', f.read()): + line = re.sub(r'#.*', '', line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, "arch", target) + target_dir = os.path.join(root, 'arch', target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=";") - cfg.read(target_dir + "/config.ini") + cfg = configparser.ConfigParser(inline_comment_prefixes=';') + cfg.read(target_dir + '/config.ini') - ptr_size = int(cfg["Arch"]["PointerSize"]) - symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) + ptr_size = int(cfg['Arch']['PointerSize']) + symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) - def is_exported(s): - conditions = [ - s["Bind"] != "LOCAL", - s["Type"] != "NOTYPE", - s["Ndx"] != "UND", - s["Name"] not in ["", "_init", "_fini"], - ] - if args.no_weak_symbols: - conditions.append(s["Bind"] != "WEAK") - return all(conditions) + def is_exported(s): + conditions = [ + s['Bind'] != 'LOCAL', + s['Type'] != 'NOTYPE', + s['Ndx'] != 'UND', + s['Name'] not in ['', '_init', '_fini']] + if args.no_weak_symbols: + conditions.append(s['Bind'] != 'WEAK') + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return ( - s["Type"] == "OBJECT" + def is_data_symbol(s): + return (s['Type'] == 'OBJECT' # Allow vtables if --vtables is on - and not (" for " in s["Demangled Name"] and args.vtables) - ) - - exported_data = [s["Name"] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn( - f"library '{input_name}' contains data symbols which won't be intercepted: " - + ", ".join(exported_data) - ) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s["Default"]: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s["Name"]) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn( - "some user-specified functions are not present in library: " - + ", ".join(missing_funs) - ) - funs = [name for name in funs if name in all_funs] + and not (' for ' in s['Demangled Name'] and args.vtables)) + + exported_data = [s['Name'] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn(f"library '{input_name}' contains data symbols which won't be intercepted: " + + ', '.join(exported_data)) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s['Default']: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s['Name']) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) + funs = [name for name in funs if name in all_funs] + + if verbose: + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") + + # Collect vtables + + if args.vtables: + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s['Name'] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + secs = collect_sections(input_name) if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") + print("Sections:") + for sec in secs: + print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}") - # Collect vtables + bites = read_unrelocated_data(input_name, cls_syms, secs) + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel['Symbol\'s Name + Addend'] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]['Demangled Name'] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) + + tramp_file = f'{suffix}.tramp.S' + with open(os.path.join(outdir, tramp_file), 'w') as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + '/table.S.tpl', 'r') as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + table_size=ptr_size*(len(funs) + 1)) + f.write(table_text) + + with open(target_dir + '/trampoline.S.tpl', 'r') as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i*ptr_size, + number=i) + f.write(tramp_text) + + # Generate C code + + init_file = f'{suffix}.init.c' + with open(os.path.join(outdir, init_file), 'w') as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: + if funs: + sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' + else: + sym_names = '' + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names) + f.write(init_text) if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match( - r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] - ) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s["Name"] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") - - secs = collect_sections(input_name) - if verbose: - print("Sections:") - for sec in secs: - print( - f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}" - ) - - bites = read_unrelocated_data(input_name, cls_syms, secs) - - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel["Symbol's Name + Addend"] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data( - cls_syms, bites, rels, ptr_size, symbol_reloc_types - ) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]["Demangled Name"] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print( - " " - + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) - ) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) - - tramp_file = f"{suffix}.tramp.S" - with open(os.path.join(outdir, tramp_file), "w") as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + "/table.S.tpl", "r") as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) - ) - f.write(table_text) - - with open(target_dir + "/trampoline.S.tpl", "r") as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i * ptr_size, - number=i, - ) - f.write(tramp_text) - - # Generate C code - - init_file = f"{suffix}.init.c" - with open(os.path.join(outdir, init_file), "w") as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: - if funs: - sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," - else: - sym_names = "" - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names, - ) - f.write(init_text) - if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - - -if __name__ == "__main__": - main() + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + +if __name__ == '__main__': + main() From 73b55acb47e4f019c0960e0d468403ea33b3d67b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 13:54:09 +0000 Subject: [PATCH 6/8] fix(jax): update type hints per review feedback and revert third-party file changes Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/jax/jax2tf/tfmodel.py | 4 +- deepmd/jax/model/hlo.py | 8 +- source/3rdparty/implib/implib-gen.py | 1093 ++++++++++++++------------ 3 files changed, 591 insertions(+), 514 deletions(-) diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 713e023e69..61c83fa028 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -115,7 +115,7 @@ def call( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ) -> Any: + ) -> dict[str, jnp.ndarray]: """Return model prediction. Parameters @@ -179,7 +179,7 @@ def call_lower( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ) -> Any: + ) -> dict[str, jnp.ndarray]: if do_atomic_virial: call_lower = self._call_lower_atomic_virial else: diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 38f960df98..cbeb915329 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -44,10 +44,10 @@ class HLO(BaseModel): def __init__( self, - stablehlo: Any, - stablehlo_atomic_virial: Any, - stablehlo_no_ghost: Any, - stablehlo_atomic_virial_no_ghost: Any, + stablehlo: bytearray, + stablehlo_atomic_virial: bytearray, + stablehlo_no_ghost: bytearray, + stablehlo_atomic_virial_no_ghost: bytearray, model_def_script: str, type_map: list[str], rcut: float, diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 86cfa77378..3a51be271d 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,577 +22,654 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) + def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f'{me}: warning: {msg}\n') + """Emits a nicely-decorated warning.""" + sys.stderr.write(f"{me}: warning: {msg}\n") + def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f'{me}: error: {msg}\n') - sys.exit(1) - -def run(args, stdin=''): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env['LC_ALL'] = 'c' - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, env=env) as p: - out, err = p.communicate(input=stdin.encode('utf-8')) - out = out.decode('utf-8') - err = err.decode('utf-8') - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f"{me}: error: {msg}\n") + sys.exit(1) + + +def run(args, stdin=""): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env["LC_ALL"] = "c" + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) as p: + out, err = p.communicate(input=stdin.encode("utf-8")) + out = out.decode("utf-8") + err = err.decode("utf-8") + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err + def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc + def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals + def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(['readelf', '-sW', f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r' +', line) - if line.startswith('Num'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(':', ''), words)) - elif toc is not None: - sym = parse_row(words, toc, ['Value']) - name = sym['Name'] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format - if '@' in name: - sym['Default'] = '@@' in name - name, ver = re.split(r'@+', name) - sym['Name'] = name - sym['Version'] = ver - else: - sym['Default'] = True - sym['Version'] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]['Demangled Name'] = name - - return syms + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(["readelf", "-sW", f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r" +", line) + if line.startswith("Num"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(":", ""), words)) + elif toc is not None: + sym = parse_row(words, toc, ["Value"]) + name = sym["Name"] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format + if "@" in name: + sym["Default"] = "@@" in name + name, ver = re.split(r"@+", name) + sym["Name"] = name + sym["Version"] = ver + else: + sym["Default"] = True + sym["Version"] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]["Demangled Name"] = name + + return syms + def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(['readelf', '-rW', f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == 'There are no relocations in this file.': - return [] - if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS - continue - if re.match(r'^\s*Offset', line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r' \+ ', '+', line) - words = re.split(r'\s+', line) - rel = parse_row(words, toc, ['Offset', 'Info']) - rels.append(rel) - # Split symbolic representation - sym_name = 'Symbol\'s Name + Addend' - if sym_name not in rel and 'Symbol\'s Name' in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel['Symbol\'s Name'] + '+0' - if rel[sym_name]: - p = rel[sym_name].split('+') - if len(p) == 1: - p = ['', p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels + """Collect ELF dynamic relocs.""" + + out, _ = run(["readelf", "-rW", f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == "There are no relocations in this file.": + return [] + if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS + continue + if re.match(r"^\s*Offset", line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r" \+ ", "+", line) + words = re.split(r"\s+", line) + rel = parse_row(words, toc, ["Offset", "Info"]) + rels.append(rel) + # Split symbolic representation + sym_name = "Symbol's Name + Addend" + if sym_name not in rel and "Symbol's Name" in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel["Symbol's Name"] + "+0" + if rel[sym_name]: + p = rel[sym_name].split("+") + if len(p) == 1: + p = ["", p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels + def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(['readelf', '-SW', f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r'\[\s+', '[', line) - words = re.split(r' +', line) - if line.startswith('[Nr]'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {'Addr' : 'Address'}) - elif line.startswith('[') and toc is not None: - sec = parse_row(words, toc, ['Address', 'Off', 'Size']) - if 'A' in sec['Flg']: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections + """Collect section info from ELF.""" + + out, _ = run(["readelf", "-SW", f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r"\[\s+", "[", line) + words = re.split(r" +", line) + if line.startswith("[Nr]"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {"Addr": "Address"}) + elif line.startswith("[") and toc is not None: + sec = parse_row(words, toc, ["Address", "Off", "Size"]) + if "A" in sec["Flg"]: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections + def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, 'rb') as f: - def is_symbol_in_section(sym, sec): - sec_end = sec['Address'] + sec['Size'] - is_start_in_section = sec['Address'] <= sym['Value'] < sec_end - is_end_in_section = sym['Value'] + sym['Size'] <= sec_end - return is_start_in_section and is_end_in_section - for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") - sec = sec[0] - f.seek(sec['Off']) - data[name] = f.read(s['Size']) - return data + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, "rb") as f: + + def is_symbol_in_section(sym, sec): + sec_end = sec["Address"] + sec["Size"] + is_start_in_section = sec["Address"] <= sym["Value"] < sec_end + is_end_in_section = sym["Value"] + sym["Size"] <= sec_end + return is_start_in_section and is_end_in_section + + for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error( + f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" + ) + sec = sec[0] + f.seek(sec["Off"]) + data[name] = f.read(s["Size"]) + return data + def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s['Demangled Name'].startswith('typeinfo name'): - data[name] = [('byte', int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') - data[name].append(('offset', val)) - start = s['Value'] - finish = start + s['Size'] - # TODO: binary search (bisect) - for rel in rels: - if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: - i = (rel['Offset'] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = 'reloc', rel - return data + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s["Demangled Name"].startswith("typeinfo name"): + data[name] = [("byte", int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes( + b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" + ) + data[name].append(("offset", val)) + start = s["Value"] + finish = start + s["Size"] + # TODO: binary search (bisect) + for rel in rels: + if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: + i = (rel["Offset"] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = "reloc", rel + return data + def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = { - 'reloc' : 'const void *', - 'byte' : 'unsigned char', - 'offset' : 'size_t' - } - - ss = [] - ss.append('''\ + """Generate code for vtables""" + c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} + + ss = [] + ss.append("""\ #ifdef __cplusplus extern "C" { #endif -''') +""") - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != 'reloc': - continue - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f'''\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != "reloc": + continue + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f"""\ extern const char {sym_name}[]; -''') +""") - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s['Demangled Name'].startswith('typeinfo name'): - declarator = 'const unsigned char %s[]' - else: - field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) - declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != 'reloc': - vals.append(str(val) + 'UL') - else: - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - vals.append(f'(const char *)&{sym_name} + {addend}') - code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + '_type' - type_decl = decl % type_name - ss.append(f'''\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s["Demangled Name"].startswith("typeinfo name"): + declarator = "const unsigned char %s[]" + else: + field_types = ( + f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) + ) + declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != "reloc": + vals.append(str(val) + "UL") + else: + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + vals.append(f"(const char *)&{sym_name} + {addend}") + code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + "_type" + type_decl = decl % type_name + ss.append(f"""\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -''') +""") - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + '_type' - ss.append(f'''\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + "_type" + ss.append(f"""\ const {type_name} {name} = {init}; -''') +""") - ss.append('''\ + ss.append("""\ #ifdef __cplusplus } // extern "C" #endif -''') +""") + + return "".join(ss) - return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" + """Read ELF's SONAME.""" + + out, _ = run(["readelf", "-d", f]) - out, _ = run(['readelf', '-d', f]) + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) + if soname_match is not None: + return soname_match[1] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) - if soname_match is not None: - return soname_match[1] + return None - return None def main(): - """Driver function""" - parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser( + description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""") - - parser.add_argument('library', - metavar='LIB', - help="Library to be wrapped.") - parser.add_argument('--verbose', '-v', - help="Print diagnostic info", - action='count', - default=0) - parser.add_argument('--dlopen', - help="Emit dlopen call (default)", - dest='dlopen', action='store_true', default=True) - parser.add_argument('--no-dlopen', - help="Do not emit dlopen call (user must load/unload library himself)", - dest='dlopen', action='store_false') - parser.add_argument('--dlopen-callback', - help="Call user-provided custom callback to load library instead of dlopen", - default='') - parser.add_argument('--dlsym-callback', - help="Call user-provided custom callback to resolve a symbol, " - "instead of dlsym", - default='') - parser.add_argument('--library-load-name', - help="Use custom name for dlopened library (default is SONAME)") - parser.add_argument('--lazy-load', - help="Load library on first call to any of it's functions (default)", - dest='lazy_load', action='store_true', default=True) - parser.add_argument('--no-lazy-load', - help="Load library at program start", - dest='lazy_load', action='store_false') - parser.add_argument('--vtables', - help="Intercept virtual tables (EXPERIMENTAL)", - dest='vtables', action='store_true', default=False) - parser.add_argument('--no-vtables', - help="Do not intercept virtual tables (default)", - dest='vtables', action='store_false') - parser.add_argument('--no-weak-symbols', - help="Don't bind weak symbols", dest='no_weak_symbols', - action='store_true', default=False) - parser.add_argument('--target', - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1]) - parser.add_argument('--symbol-list', - help="Path to file with symbols that should be present in wrapper " - "(all by default)") - parser.add_argument('--symbol-prefix', - metavar='PFX', - help="Prefix wrapper symbols with PFX", - default='') - parser.add_argument('-q', '--quiet', - help="Do not print progress info", - action='store_true') - parser.add_argument('--outdir', '-o', - help="Path to create wrapper at", - default='./') - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith('arm'): - target = 'arm' # Handle armhf-..., armel-... - elif re.match(r'^i[0-9]86', args.target): - target = 'i386' - elif args.target.startswith('mips64'): - target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith('mips'): - target = 'mips' # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split('-')[0] - quiet = args.quiet - outdir = args.outdir - - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, 'r') as f: - funs = [] - for line in re.split(r'\r?\n', f.read()): - line = re.sub(r'#.*', '', line) - line = line.strip() - if line: - funs.append(line) +""", + ) + + parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") + parser.add_argument( + "--verbose", "-v", help="Print diagnostic info", action="count", default=0 + ) + parser.add_argument( + "--dlopen", + help="Emit dlopen call (default)", + dest="dlopen", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-dlopen", + help="Do not emit dlopen call (user must load/unload library himself)", + dest="dlopen", + action="store_false", + ) + parser.add_argument( + "--dlopen-callback", + help="Call user-provided custom callback to load library instead of dlopen", + default="", + ) + parser.add_argument( + "--dlsym-callback", + help="Call user-provided custom callback to resolve a symbol, instead of dlsym", + default="", + ) + parser.add_argument( + "--library-load-name", + help="Use custom name for dlopened library (default is SONAME)", + ) + parser.add_argument( + "--lazy-load", + help="Load library on first call to any of it's functions (default)", + dest="lazy_load", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-lazy-load", + help="Load library at program start", + dest="lazy_load", + action="store_false", + ) + parser.add_argument( + "--vtables", + help="Intercept virtual tables (EXPERIMENTAL)", + dest="vtables", + action="store_true", + default=False, + ) + parser.add_argument( + "--no-vtables", + help="Do not intercept virtual tables (default)", + dest="vtables", + action="store_false", + ) + parser.add_argument( + "--no-weak-symbols", + help="Don't bind weak symbols", + dest="no_weak_symbols", + action="store_true", + default=False, + ) + parser.add_argument( + "--target", + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1], + ) + parser.add_argument( + "--symbol-list", + help="Path to file with symbols that should be present in wrapper " + "(all by default)", + ) + parser.add_argument( + "--symbol-prefix", + metavar="PFX", + help="Prefix wrapper symbols with PFX", + default="", + ) + parser.add_argument( + "-q", "--quiet", help="Do not print progress info", action="store_true" + ) + parser.add_argument( + "--outdir", "-o", help="Path to create wrapper at", default="./" + ) + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith("arm"): + target = "arm" # Handle armhf-..., armel-... + elif re.match(r"^i[0-9]86", args.target): + target = "i386" + elif args.target.startswith("mips64"): + target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith("mips"): + target = "mips" # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split("-")[0] + quiet = args.quiet + outdir = args.outdir - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, "r") as f: + funs = [] + for line in re.split(r"\r?\n", f.read()): + line = re.sub(r"#.*", "", line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, 'arch', target) + target_dir = os.path.join(root, "arch", target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=';') - cfg.read(target_dir + '/config.ini') + cfg = configparser.ConfigParser(inline_comment_prefixes=";") + cfg.read(target_dir + "/config.ini") - ptr_size = int(cfg['Arch']['PointerSize']) - symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) + ptr_size = int(cfg["Arch"]["PointerSize"]) + symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) - def is_exported(s): - conditions = [ - s['Bind'] != 'LOCAL', - s['Type'] != 'NOTYPE', - s['Ndx'] != 'UND', - s['Name'] not in ['', '_init', '_fini']] - if args.no_weak_symbols: - conditions.append(s['Bind'] != 'WEAK') - return all(conditions) + def is_exported(s): + conditions = [ + s["Bind"] != "LOCAL", + s["Type"] != "NOTYPE", + s["Ndx"] != "UND", + s["Name"] not in ["", "_init", "_fini"], + ] + if args.no_weak_symbols: + conditions.append(s["Bind"] != "WEAK") + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return (s['Type'] == 'OBJECT' + def is_data_symbol(s): + return ( + s["Type"] == "OBJECT" # Allow vtables if --vtables is on - and not (' for ' in s['Demangled Name'] and args.vtables)) - - exported_data = [s['Name'] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn(f"library '{input_name}' contains data symbols which won't be intercepted: " - + ', '.join(exported_data)) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s['Default']: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s['Name']) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) - funs = [name for name in funs if name in all_funs] - - if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") - - # Collect vtables - - if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s['Name'] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") + and not (" for " in s["Demangled Name"] and args.vtables) + ) + + exported_data = [s["Name"] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn( + f"library '{input_name}' contains data symbols which won't be intercepted: " + + ", ".join(exported_data) + ) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s["Default"]: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s["Name"]) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn( + "some user-specified functions are not present in library: " + + ", ".join(missing_funs) + ) + funs = [name for name in funs if name in all_funs] - secs = collect_sections(input_name) if verbose: - print("Sections:") - for sec in secs: - print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}") + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") - bites = read_unrelocated_data(input_name, cls_syms, secs) + # Collect vtables - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel['Symbol\'s Name + Addend'] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]['Demangled Name'] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) - - tramp_file = f'{suffix}.tramp.S' - with open(os.path.join(outdir, tramp_file), 'w') as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + '/table.S.tpl', 'r') as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - table_size=ptr_size*(len(funs) + 1)) - f.write(table_text) - - with open(target_dir + '/trampoline.S.tpl', 'r') as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i*ptr_size, - number=i) - f.write(tramp_text) - - # Generate C code - - init_file = f'{suffix}.init.c' - with open(os.path.join(outdir, init_file), 'w') as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: - if funs: - sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' - else: - sym_names = '' - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names) - f.write(init_text) if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - -if __name__ == '__main__': - main() + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match( + r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] + ) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s["Name"] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + + secs = collect_sections(input_name) + if verbose: + print("Sections:") + for sec in secs: + print( + f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}" + ) + + bites = read_unrelocated_data(input_name, cls_syms, secs) + + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel["Symbol's Name + Addend"] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data( + cls_syms, bites, rels, ptr_size, symbol_reloc_types + ) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]["Demangled Name"] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print( + " " + + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) + ) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) + + tramp_file = f"{suffix}.tramp.S" + with open(os.path.join(outdir, tramp_file), "w") as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + "/table.S.tpl", "r") as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) + ) + f.write(table_text) + + with open(target_dir + "/trampoline.S.tpl", "r") as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i * ptr_size, + number=i, + ) + f.write(tramp_text) + + # Generate C code + + init_file = f"{suffix}.init.c" + with open(os.path.join(outdir, init_file), "w") as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: + if funs: + sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," + else: + sym_names = "" + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names, + ) + f.write(init_text) + if args.vtables: + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + + +if __name__ == "__main__": + main() From 3cb3f241b59bc528695712328a2c85d805edbdf9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:13:42 +0000 Subject: [PATCH 7/8] fix(jax): revert accidental changes to third-party file implib-gen.py Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- source/3rdparty/implib/implib-gen.py | 1093 ++++++++++++-------------- 1 file changed, 508 insertions(+), 585 deletions(-) diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 3a51be271d..86cfa77378 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,654 +22,577 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) - def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f"{me}: warning: {msg}\n") - + """Emits a nicely-decorated warning.""" + sys.stderr.write(f'{me}: warning: {msg}\n') def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f"{me}: error: {msg}\n") - sys.exit(1) - - -def run(args, stdin=""): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env["LC_ALL"] = "c" - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen( - args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) as p: - out, err = p.communicate(input=stdin.encode("utf-8")) - out = out.decode("utf-8") - err = err.decode("utf-8") - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err - + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f'{me}: error: {msg}\n') + sys.exit(1) + +def run(args, stdin=''): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env['LC_ALL'] = 'c' + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, env=env) as p: + out, err = p.communicate(input=stdin.encode('utf-8')) + out = out.decode('utf-8') + err = err.decode('utf-8') + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc - + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals - + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(["readelf", "-sW", f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r" +", line) - if line.startswith("Num"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(":", ""), words)) - elif toc is not None: - sym = parse_row(words, toc, ["Value"]) - name = sym["Name"] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format - if "@" in name: - sym["Default"] = "@@" in name - name, ver = re.split(r"@+", name) - sym["Name"] = name - sym["Version"] = ver - else: - sym["Default"] = True - sym["Version"] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]["Demangled Name"] = name - - return syms - + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(['readelf', '-sW', f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r' +', line) + if line.startswith('Num'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(':', ''), words)) + elif toc is not None: + sym = parse_row(words, toc, ['Value']) + name = sym['Name'] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format + if '@' in name: + sym['Default'] = '@@' in name + name, ver = re.split(r'@+', name) + sym['Name'] = name + sym['Version'] = ver + else: + sym['Default'] = True + sym['Version'] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]['Demangled Name'] = name + + return syms def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(["readelf", "-rW", f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == "There are no relocations in this file.": - return [] - if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS - continue - if re.match(r"^\s*Offset", line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r" \+ ", "+", line) - words = re.split(r"\s+", line) - rel = parse_row(words, toc, ["Offset", "Info"]) - rels.append(rel) - # Split symbolic representation - sym_name = "Symbol's Name + Addend" - if sym_name not in rel and "Symbol's Name" in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel["Symbol's Name"] + "+0" - if rel[sym_name]: - p = rel[sym_name].split("+") - if len(p) == 1: - p = ["", p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels - + """Collect ELF dynamic relocs.""" + + out, _ = run(['readelf', '-rW', f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == 'There are no relocations in this file.': + return [] + if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS + continue + if re.match(r'^\s*Offset', line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r' \+ ', '+', line) + words = re.split(r'\s+', line) + rel = parse_row(words, toc, ['Offset', 'Info']) + rels.append(rel) + # Split symbolic representation + sym_name = 'Symbol\'s Name + Addend' + if sym_name not in rel and 'Symbol\'s Name' in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel['Symbol\'s Name'] + '+0' + if rel[sym_name]: + p = rel[sym_name].split('+') + if len(p) == 1: + p = ['', p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(["readelf", "-SW", f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r"\[\s+", "[", line) - words = re.split(r" +", line) - if line.startswith("[Nr]"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {"Addr": "Address"}) - elif line.startswith("[") and toc is not None: - sec = parse_row(words, toc, ["Address", "Off", "Size"]) - if "A" in sec["Flg"]: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections - + """Collect section info from ELF.""" + + out, _ = run(['readelf', '-SW', f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r'\[\s+', '[', line) + words = re.split(r' +', line) + if line.startswith('[Nr]'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {'Addr' : 'Address'}) + elif line.startswith('[') and toc is not None: + sec = parse_row(words, toc, ['Address', 'Off', 'Size']) + if 'A' in sec['Flg']: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, "rb") as f: - - def is_symbol_in_section(sym, sec): - sec_end = sec["Address"] + sec["Size"] - is_start_in_section = sec["Address"] <= sym["Value"] < sec_end - is_end_in_section = sym["Value"] + sym["Size"] <= sec_end - return is_start_in_section and is_end_in_section - - for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error( - f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" - ) - sec = sec[0] - f.seek(sec["Off"]) - data[name] = f.read(s["Size"]) - return data - + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, 'rb') as f: + def is_symbol_in_section(sym, sec): + sec_end = sec['Address'] + sec['Size'] + is_start_in_section = sec['Address'] <= sym['Value'] < sec_end + is_end_in_section = sym['Value'] + sym['Size'] <= sec_end + return is_start_in_section and is_end_in_section + for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") + sec = sec[0] + f.seek(sec['Off']) + data[name] = f.read(s['Size']) + return data def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s["Demangled Name"].startswith("typeinfo name"): - data[name] = [("byte", int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes( - b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" - ) - data[name].append(("offset", val)) - start = s["Value"] - finish = start + s["Size"] - # TODO: binary search (bisect) - for rel in rels: - if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: - i = (rel["Offset"] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = "reloc", rel - return data - + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s['Demangled Name'].startswith('typeinfo name'): + data[name] = [('byte', int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') + data[name].append(('offset', val)) + start = s['Value'] + finish = start + s['Size'] + # TODO: binary search (bisect) + for rel in rels: + if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: + i = (rel['Offset'] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = 'reloc', rel + return data def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} - - ss = [] - ss.append("""\ + """Generate code for vtables""" + c_types = { + 'reloc' : 'const void *', + 'byte' : 'unsigned char', + 'offset' : 'size_t' + } + + ss = [] + ss.append('''\ #ifdef __cplusplus extern "C" { #endif -""") +''') - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != "reloc": - continue - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f"""\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != 'reloc': + continue + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f'''\ extern const char {sym_name}[]; -""") +''') - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s["Demangled Name"].startswith("typeinfo name"): - declarator = "const unsigned char %s[]" - else: - field_types = ( - f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) - ) - declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != "reloc": - vals.append(str(val) + "UL") - else: - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - vals.append(f"(const char *)&{sym_name} + {addend}") - code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + "_type" - type_decl = decl % type_name - ss.append(f"""\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s['Demangled Name'].startswith('typeinfo name'): + declarator = 'const unsigned char %s[]' + else: + field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) + declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != 'reloc': + vals.append(str(val) + 'UL') + else: + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + vals.append(f'(const char *)&{sym_name} + {addend}') + code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + '_type' + type_decl = decl % type_name + ss.append(f'''\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -""") +''') - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + "_type" - ss.append(f"""\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + '_type' + ss.append(f'''\ const {type_name} {name} = {init}; -""") +''') - ss.append("""\ + ss.append('''\ #ifdef __cplusplus } // extern "C" #endif -""") - - return "".join(ss) +''') + return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" - - out, _ = run(["readelf", "-d", f]) + """Read ELF's SONAME.""" - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) - if soname_match is not None: - return soname_match[1] + out, _ = run(['readelf', '-d', f]) - return None + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) + if soname_match is not None: + return soname_match[1] + return None def main(): - """Driver function""" - parser = argparse.ArgumentParser( - description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""", - ) - - parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") - parser.add_argument( - "--verbose", "-v", help="Print diagnostic info", action="count", default=0 - ) - parser.add_argument( - "--dlopen", - help="Emit dlopen call (default)", - dest="dlopen", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-dlopen", - help="Do not emit dlopen call (user must load/unload library himself)", - dest="dlopen", - action="store_false", - ) - parser.add_argument( - "--dlopen-callback", - help="Call user-provided custom callback to load library instead of dlopen", - default="", - ) - parser.add_argument( - "--dlsym-callback", - help="Call user-provided custom callback to resolve a symbol, instead of dlsym", - default="", - ) - parser.add_argument( - "--library-load-name", - help="Use custom name for dlopened library (default is SONAME)", - ) - parser.add_argument( - "--lazy-load", - help="Load library on first call to any of it's functions (default)", - dest="lazy_load", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-lazy-load", - help="Load library at program start", - dest="lazy_load", - action="store_false", - ) - parser.add_argument( - "--vtables", - help="Intercept virtual tables (EXPERIMENTAL)", - dest="vtables", - action="store_true", - default=False, - ) - parser.add_argument( - "--no-vtables", - help="Do not intercept virtual tables (default)", - dest="vtables", - action="store_false", - ) - parser.add_argument( - "--no-weak-symbols", - help="Don't bind weak symbols", - dest="no_weak_symbols", - action="store_true", - default=False, - ) - parser.add_argument( - "--target", - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1], - ) - parser.add_argument( - "--symbol-list", - help="Path to file with symbols that should be present in wrapper " - "(all by default)", - ) - parser.add_argument( - "--symbol-prefix", - metavar="PFX", - help="Prefix wrapper symbols with PFX", - default="", - ) - parser.add_argument( - "-q", "--quiet", help="Do not print progress info", action="store_true" - ) - parser.add_argument( - "--outdir", "-o", help="Path to create wrapper at", default="./" - ) - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith("arm"): - target = "arm" # Handle armhf-..., armel-... - elif re.match(r"^i[0-9]86", args.target): - target = "i386" - elif args.target.startswith("mips64"): - target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith("mips"): - target = "mips" # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split("-")[0] - quiet = args.quiet - outdir = args.outdir +""") - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, "r") as f: - funs = [] - for line in re.split(r"\r?\n", f.read()): - line = re.sub(r"#.*", "", line) - line = line.strip() - if line: - funs.append(line) - - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + parser.add_argument('library', + metavar='LIB', + help="Library to be wrapped.") + parser.add_argument('--verbose', '-v', + help="Print diagnostic info", + action='count', + default=0) + parser.add_argument('--dlopen', + help="Emit dlopen call (default)", + dest='dlopen', action='store_true', default=True) + parser.add_argument('--no-dlopen', + help="Do not emit dlopen call (user must load/unload library himself)", + dest='dlopen', action='store_false') + parser.add_argument('--dlopen-callback', + help="Call user-provided custom callback to load library instead of dlopen", + default='') + parser.add_argument('--dlsym-callback', + help="Call user-provided custom callback to resolve a symbol, " + "instead of dlsym", + default='') + parser.add_argument('--library-load-name', + help="Use custom name for dlopened library (default is SONAME)") + parser.add_argument('--lazy-load', + help="Load library on first call to any of it's functions (default)", + dest='lazy_load', action='store_true', default=True) + parser.add_argument('--no-lazy-load', + help="Load library at program start", + dest='lazy_load', action='store_false') + parser.add_argument('--vtables', + help="Intercept virtual tables (EXPERIMENTAL)", + dest='vtables', action='store_true', default=False) + parser.add_argument('--no-vtables', + help="Do not intercept virtual tables (default)", + dest='vtables', action='store_false') + parser.add_argument('--no-weak-symbols', + help="Don't bind weak symbols", dest='no_weak_symbols', + action='store_true', default=False) + parser.add_argument('--target', + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1]) + parser.add_argument('--symbol-list', + help="Path to file with symbols that should be present in wrapper " + "(all by default)") + parser.add_argument('--symbol-prefix', + metavar='PFX', + help="Prefix wrapper symbols with PFX", + default='') + parser.add_argument('-q', '--quiet', + help="Do not print progress info", + action='store_true') + parser.add_argument('--outdir', '-o', + help="Path to create wrapper at", + default='./') + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith('arm'): + target = 'arm' # Handle armhf-..., armel-... + elif re.match(r'^i[0-9]86', args.target): + target = 'i386' + elif args.target.startswith('mips64'): + target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith('mips'): + target = 'mips' # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split('-')[0] + quiet = args.quiet + outdir = args.outdir + + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, 'r') as f: + funs = [] + for line in re.split(r'\r?\n', f.read()): + line = re.sub(r'#.*', '', line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, "arch", target) + target_dir = os.path.join(root, 'arch', target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=";") - cfg.read(target_dir + "/config.ini") + cfg = configparser.ConfigParser(inline_comment_prefixes=';') + cfg.read(target_dir + '/config.ini') - ptr_size = int(cfg["Arch"]["PointerSize"]) - symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) + ptr_size = int(cfg['Arch']['PointerSize']) + symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) - def is_exported(s): - conditions = [ - s["Bind"] != "LOCAL", - s["Type"] != "NOTYPE", - s["Ndx"] != "UND", - s["Name"] not in ["", "_init", "_fini"], - ] - if args.no_weak_symbols: - conditions.append(s["Bind"] != "WEAK") - return all(conditions) + def is_exported(s): + conditions = [ + s['Bind'] != 'LOCAL', + s['Type'] != 'NOTYPE', + s['Ndx'] != 'UND', + s['Name'] not in ['', '_init', '_fini']] + if args.no_weak_symbols: + conditions.append(s['Bind'] != 'WEAK') + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return ( - s["Type"] == "OBJECT" + def is_data_symbol(s): + return (s['Type'] == 'OBJECT' # Allow vtables if --vtables is on - and not (" for " in s["Demangled Name"] and args.vtables) - ) - - exported_data = [s["Name"] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn( - f"library '{input_name}' contains data symbols which won't be intercepted: " - + ", ".join(exported_data) - ) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s["Default"]: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s["Name"]) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn( - "some user-specified functions are not present in library: " - + ", ".join(missing_funs) - ) - funs = [name for name in funs if name in all_funs] + and not (' for ' in s['Demangled Name'] and args.vtables)) + + exported_data = [s['Name'] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn(f"library '{input_name}' contains data symbols which won't be intercepted: " + + ', '.join(exported_data)) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s['Default']: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s['Name']) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) + funs = [name for name in funs if name in all_funs] + + if verbose: + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") + + # Collect vtables + + if args.vtables: + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s['Name'] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + secs = collect_sections(input_name) if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") + print("Sections:") + for sec in secs: + print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}") - # Collect vtables + bites = read_unrelocated_data(input_name, cls_syms, secs) + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel['Symbol\'s Name + Addend'] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]['Demangled Name'] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) + + tramp_file = f'{suffix}.tramp.S' + with open(os.path.join(outdir, tramp_file), 'w') as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + '/table.S.tpl', 'r') as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + table_size=ptr_size*(len(funs) + 1)) + f.write(table_text) + + with open(target_dir + '/trampoline.S.tpl', 'r') as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i*ptr_size, + number=i) + f.write(tramp_text) + + # Generate C code + + init_file = f'{suffix}.init.c' + with open(os.path.join(outdir, init_file), 'w') as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: + if funs: + sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' + else: + sym_names = '' + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names) + f.write(init_text) if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match( - r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] - ) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s["Name"] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") - - secs = collect_sections(input_name) - if verbose: - print("Sections:") - for sec in secs: - print( - f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}" - ) - - bites = read_unrelocated_data(input_name, cls_syms, secs) - - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel["Symbol's Name + Addend"] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data( - cls_syms, bites, rels, ptr_size, symbol_reloc_types - ) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]["Demangled Name"] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print( - " " - + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) - ) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) - - tramp_file = f"{suffix}.tramp.S" - with open(os.path.join(outdir, tramp_file), "w") as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + "/table.S.tpl", "r") as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) - ) - f.write(table_text) - - with open(target_dir + "/trampoline.S.tpl", "r") as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i * ptr_size, - number=i, - ) - f.write(tramp_text) - - # Generate C code - - init_file = f"{suffix}.init.c" - with open(os.path.join(outdir, init_file), "w") as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: - if funs: - sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," - else: - sym_names = "" - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names, - ) - f.write(init_text) - if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - - -if __name__ == "__main__": - main() + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + +if __name__ == '__main__': + main() From 15ae1f3eb1685306656186dbb767dc33ea2b966c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 12 Sep 2025 18:23:56 +0000 Subject: [PATCH 8/8] fix(jax): use string literals for JAX type annotations to prevent import hangs Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/jax/utils/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 829dafae86..6a3c839608 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -55,7 +55,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial( do_atomic_virial: bool, has_ghost_atoms: bool - ) -> jax_export.Exported: + ) -> "jax_export.Exported": def call_lower_with_fixed_do_atomic_virial( coord: jnp.ndarray, atype: jnp.ndarray,