You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
style(jax): enable ANN rule and add comprehensive type hints to JAX backend (#4967)
This PR enables the Ruff ANN (type annotation) rule for the JAX backend
and adds comprehensive type hints to all methods across the core JAX
implementation.
## Changes Made
**Configuration Changes:**
- [x] Removed `ANN` from the exclude list for `deepmd/jax/**` in
`pyproject.toml`, enabling type annotation checking for the entire JAX
backend
- [x] Removed unnecessary exclusion for `deepmd/jax/jax2tf/**` as it now
passes ANN checks with proper type annotations
- [x] The global `ANN401` ignore remains active to allow necessary `Any`
type usage
**Type Annotations Added:**
- [x] **Base functions**: Added type hints to
`base_atomic_model_set_attr` and `forward_common_atomic` functions that
are used throughout the JAX backend
- [x] **Atomic models**: Complete type annotations for all classes in
`deepmd/jax/atomic_model/`
- [x] **Descriptors**: Type hints verified for all descriptor classes
- [x] **Fitting modules**: Type annotations confirmed for fitting
implementations
- [x] **Inference**: Added return types for `_eval_model`,
`_get_output_shape`, and nested evaluation functions
- [x] **Models**: Complete type hints for model classes including
complex HLO model parameters
- [x] **Utilities**: Type annotations for network classes, neighbor
statistics, and serialization functions
- [x] **Array protocol methods**: Proper typing for `__array__`,
`__array_namespace__`, `__dlpack__`, and `__dlpack_device__` methods
- [x] **Root level**: Type hints for common utility functions like
`scatter_sum`
- [x] **JAX2TF interop**: Added comprehensive type annotations to all
functions in the `deepmd/jax/jax2tf/` directory including:
- `format_nlist.py`: Return type annotation for nlist formatting
function
- `make_model.py`: Return type for model call wrapper function
- `nlist.py`: Type hints for neighbor list functions including
`nlist_distinguish_types`, `tf_outer`, and `extend_coord_with_ghosts`
- `region.py`: Type annotations for region distance calculations
- `serialization.py`: Complete type hints for all model serialization
functions and nested closures, using proper `jax.export.Exported` type
- `tfmodel.py`: Type annotations for TensorFlow model wrapper class
methods
**Bug Fixes:**
- [x] **Third-party file protection**: Reverted accidental changes to
`source/3rdparty/implib/implib-gen.py` which should not be modified
- [x] **Improved type accuracy**: Updated
`exported_whether_do_atomic_virial` return type from `Any` to
`jax.export.Exported` for better type safety
- [x] **Enhanced return type precision**: Updated
`TFModelWrapper.call()` and `TFModelWrapper.call_lower()` return types
from `Any` to `dict[str, jnp.ndarray]` for better type safety
- [x] **Improved HLO parameter types**: Updated HLO model stablehlo
parameters from `Any` to `bytearray` for more precise typing
- [x] **Fixed TF2 eager mode test hanging**: Used string literals for
JAX type annotations (`"jax_export.Exported"`) to prevent import-time
evaluation issues that could cause tests to hang in environments where
JAX is not fully available
## Technical Details
The implementation follows existing codebase patterns:
- Uses `Any` for complex interop types (properly ignored by global
ANN401 rule)
- Leverages forward references for circular dependencies (e.g.,
`"BaseModel"`)
- Maintains consistency with existing type annotation styles
- Handles JAX-specific array types (`jnp.ndarray`) and TensorFlow types
(`tnp.ndarray`, `tf.Tensor`) appropriately
- Uses appropriate return types for TensorFlow interop functions (e.g.,
`dict[str, tnp.ndarray]` for model outputs)
- Uses precise JAX export types like `jax.export.Exported` where
applicable
- Uses appropriate binary data types like `bytearray` for serialized HLO
models
- **Uses string literals for JAX types** to prevent import-time
evaluation issues in test environments where JAX may not be fully
available
## Validation
All core JAX backend directories now pass ruff checks with the ANN rule
enabled:
- `deepmd/jax/atomic_model/` ✅
- `deepmd/jax/descriptor/` ✅
- `deepmd/jax/fitting/` ✅
- `deepmd/jax/infer/` ✅
- `deepmd/jax/model/` ✅
- `deepmd/jax/utils/` ✅
- `deepmd/jax/jax2tf/` ✅ (now fully compliant with ANN rules)
- Root level files ✅
**Test Hanging Issue Fixed**: The TF2 eager mode test hanging issue was
caused by runtime evaluation of JAX type annotations in environments
where JAX was not fully available. This has been resolved by using
string literals for the problematic type annotations.
**Configuration Simplified**: Removed the specific exclusion for
`deepmd/jax/jax2tf/` directory as it now passes all ANN checks with
proper type annotations, making the configuration cleaner and more
consistent.
This change significantly improves type safety and developer experience
for the entire JAX backend while maintaining backward compatibility and
fixing the test hanging issue.
Fixes#4942.
<!-- START COPILOT CODING AGENT TIPS -->
---
💬 Share your feedback on Copilot coding agent for the chance to win a
$200 gift card! Click
[here](https://survey3.medallia.com/?EAHeSx-AP01bZqG0Ld9QLQ) to start
the survey.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
0 commit comments