feat(torch): add torch extra and update config files #2290

Open
cosmicBboy wants to merge 27 commits into dev/tensordict from dev/pytorch-tensordict-phase1
@cosmicBboy
Collaborator

Implements phase one of the TensorDict spec.

Niels Bantilan and others added 10 commits April 13, 2026 20:57
- Add TensorDictSchema and Tensor component classes
- Implement TensorDictModel with class-based schema definition
- Add Field descriptor with check parameters (gt, ge, lt, le, etc.)
- Create tensordict_engine for PyTorch dtype resolution
- Implement TensorDictSchemaBackend with batch_size, key, dtype, shape validation
- Add TensorDictCheckBackend for tensor value checks
- Update spec examples to use correct Field API pattern
- Add module layout documentation to spec
- Add test_tensordict_container.py for Tensor component and TensorDictSchema
- Add test_tensordict_model.py for TensorDictModel class-based schemas
- Add test_tensordict_engine.py for tensordict_engine dtype resolution
- Add __init__.py for test package
- Rename tensordict_api.py to tensordict.py per spec
- Fix Check calling to use get_backend() pattern
- Add tensorclass validation tests
- Fix spec module layout duplicate entry
- Update test imports
- Add TestTensorDictErrorCases with 14+ error validation tests
- Add TestTensorDictModelErrorCases with model error tests
- Add TestTensorDictEngineErrorCases with engine error tests
- Test batch_size, dtype, shape, missing keys, wrong input types
- Test error messages contain relevant info
- Test lazy validation collects multiple errors
- tensordict_engine: strip 'torch.' prefix from string dtypes
- base.py: use data= not schema_context= for SchemaError
- base.py: implement failure_cases_metadata for lazy validation
- base.py: fix SchemaErrorReason codes (COLUMN_NOT_IN_DATAFRAME, CHECK_ERROR)
- base.py: fix empty check result message
- tests: fix dtype comparisons, SchemaError vs SchemaErrors, error messages
- Add pytorch_guide/ with index, schema, model, checks, and error reporting docs
- Add reference/pytorch.rst API reference
- Update index.md and reference/index.md to include pytorch guide
- Add torch/tensordict extra to pyproject.toml
- Add torch to DATAFRAME_EXTRAS in noxfile.py
- Fix tensordict/types.py to handle missing torch gracefully
- Add torch and tensordict to requirements.txt and environment.yml
- Remove non-existent modules from pytorch reference
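
Two of the commits above mention stripping the `torch.` prefix from string dtypes and handling a missing torch install gracefully. The following is a minimal sketch of how those two ideas can fit together, not the code from this PR: the names `TORCH_AVAILABLE`, `normalize_dtype_name`, and `resolve_dtype` are hypothetical, and falling back to the bare dtype name when torch is absent is an assumption made for illustration.

```python
import importlib.util

# Hypothetical helper names for illustration; not taken from the PR's source.
# find_spec returns None when the package cannot be found, without importing it.
TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None


def normalize_dtype_name(dtype: str) -> str:
    """Drop an optional 'torch.' prefix so both spellings map to one name."""
    return dtype.removeprefix("torch.")


def resolve_dtype(dtype: str):
    """Return a torch.dtype when torch is installed, else the bare name."""
    name = normalize_dtype_name(dtype)
    if not TORCH_AVAILABLE:
        # Assumed fallback: callers still get a comparable value without torch.
        return name
    import torch  # imported lazily so this module loads without torch

    resolved = getattr(torch, name, None)
    if not isinstance(resolved, torch.dtype):
        raise TypeError(f"{dtype!r} is not a torch dtype")
    return resolved
```

Keeping the torch import inside the function means the module itself can be imported in environments that only install the other pandera extras.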
… type annotation

- Expose errors module and DataType in pandera.tensordict namespace
- Use DataType (from tensordict_engine) instead of torch.Tensor for type annotations
- Replace _field classmethod with Field function for field configuration
- Fix validate to be a classmethod for correct class-level calling
- Update tests and documentation to match other pandera backends
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
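
The commits above test that lazy validation collects multiple errors rather than stopping at the first one. As a rough illustration of that behaviour, here is a standard-library sketch that uses plain dictionaries in place of TensorDict objects. Every name in it (`ValidationFailure`, `ValidationFailures`, `validate`) is hypothetical and it does not reproduce the backend in this PR; in pandera's existing backends the corresponding switch is the `lazy=True` argument to `validate`, which raises `SchemaErrors` instead of `SchemaError`.

```python
class ValidationFailure(Exception):
    """Raised in eager mode for the first problem found."""


class ValidationFailures(Exception):
    """Raised in lazy mode; carries every problem found."""

    def __init__(self, failures):
        super().__init__(f"{len(failures)} validation failure(s)")
        self.failures = failures


def validate(data: dict, schema: dict, lazy: bool = False) -> dict:
    """Check that each schema key exists in `data` with the expected dtype name.

    `data` maps key -> (dtype_name, shape); `schema` maps key -> dtype_name.
    """
    failures = []
    for key, expected in schema.items():
        if key not in data:
            message = f"missing key {key!r}"
        elif data[key][0] != expected:
            message = f"key {key!r}: expected dtype {expected}, got {data[key][0]}"
        else:
            continue
        if not lazy:
            # Eager mode stops at the first failure.
            raise ValidationFailure(message)
        failures.append(message)
    if failures:
        raise ValidationFailures(failures)
    return data
```

With a schema of `{"obs": "float32", "action": "int64"}` and data that has the wrong dtype for `obs` and no `action` key, the eager call reports only the `obs` problem, while the lazy call reports both.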
@codecov

codecov Bot commented Apr 15, 2026

Codecov Report

❌ Patch coverage is 2.31214% with 845 lines in your changes missing coverage. Please review.
✅ Project coverage is 78.39%. Comparing base (5c35aa6) to head (8550274).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| pandera/api/tensordict/model.py | 0.00% | 331 Missing ⚠️ |
| pandera/backends/tensordict/base.py | 0.00% | 161 Missing ⚠️ |
| pandera/engines/tensordict_engine.py | 0.00% | 88 Missing ⚠️ |
| pandera/backends/tensordict/checks.py | 0.00% | 54 Missing ⚠️ |
| pandera/api/tensordict/container.py | 0.00% | 46 Missing ⚠️ |
| pandera/backends/tensordict/builtin_checks.py | 0.00% | 44 Missing ⚠️ |
| pandera/api/tensordict/model_components.py | 0.00% | 34 Missing ⚠️ |
| pandera/api/tensordict/components.py | 0.00% | 19 Missing ⚠️ |
| pandera/api/tensordict/types.py | 0.00% | 15 Missing ⚠️ |
| pandera/backends/tensordict/register.py | 0.00% | 15 Missing ⚠️ |

... and 11 more
Additional details and impacted files
@@                Coverage Diff                 @@
##           dev/tensordict    #2290      +/-   ##
==================================================
- Coverage           82.74%   78.39%   -4.36%     
==================================================
  Files                 180      193      +13     
  Lines               15086    15925     +839     
==================================================
+ Hits                12483    12484       +1     
- Misses               2603     3441     +838     
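
The percentages in the report can be reproduced from the counts in the coverage diff. The sketch below is a reader's back-of-the-envelope check, not Codecov's own calculation: the patch size of 865 lines (and therefore 20 covered lines) is inferred from the reported 2.31214% and 845 missing lines, and Codecov may treat partially covered lines differently.

```python
def coverage_percent(hits: int, lines: int) -> float:
    """Plain line coverage: covered lines as a percentage of all lines."""
    return 100.0 * hits / lines


# Head coverage from the diff table: 12484 hits out of 15925 lines.
head = coverage_percent(12484, 15925)

# Inferred patch size: if 845 lines are missing and 2.31214% are covered,
# then missing / (1 - covered_fraction) gives the total number of patch lines.
patch_lines = round(845 / (1 - 2.31214 / 100))
patch_hits = patch_lines - 845

print(f"head coverage: {head:.2f}%")
print(f"patch: {patch_hits} of {patch_lines} lines covered "
      f"({coverage_percent(patch_hits, patch_lines):.5f}%)")
```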

Niels Bantilan and others added 17 commits April 15, 2026 00:44
- Run ruff check --fix (94 errors fixed)
- Run ruff format on all files
- Fix import ordering with isort
- Convert to f-strings with flynt
- Upgrade syntax with pyupgrade
- Add all required CheckResult fields in checks backend
- Simplify tensor check application logic (no dict wrapping)
- Fix validate() signatures to match base classes
- Resolve type annotation errors with mypy config directives
- Remove duplicate test function names
- Clean up redundant code across tensordict modules
- Implement Tensor component class (pandera/api/tensordict/components.py)
- Implement TensorDictSchema with keys parameter (pandera/api/tensordict/container.py)
- Implement TensorDictModel for declarative schemas (pandera/api/tensordict/model.py)
- Implement validation backend (pandera/backends/tensordict/base.py)
- Implement check backend for tensor value validation (pandera/backends/tensordict/checks.py)
- Register builtin checks for torch.Tensor types (pandera/backends/tensordict/builtin_checks.py)
- Implement backend registration with tensorclass support (pandera/backends/tensordict/register.py)
- Update entry point exports (pandera/tensordict.py)
- Add tests for all functionality
- Support both TensorDict and tensorclass objects
- Support lazy validation and value checks via existing pandera Check API
- Fixed preprocess() signature to include key parameter
- Changed CoreCheckResult to CheckResult for error handling
- Fixed register_builtin_check fallback signature
- Added missing imports (BaseSchemaBackend, BaseConfig, Tensor)
- Fixed annotation access from .annotation to .raw_annotation
- Updated TensorDictModel to follow the same pattern as xarray.DatasetModel and DataFrameModel
- Fields are now defined using pa.Field() directly in type annotations without needing custom _field classmethod
- Type annotations specify dtype (e.g., torch.float32, torch.int64) instead of torch.Tensor
- Added descriptors for field collection, schema caching, checks, and parsers
- Exported SchemaError and SchemaErrors from pandera.tensordict module
- Updated documentation to reflect correct API usage
- Added comprehensive test coverage for TensorDictModel

Fixes issue with class-based model implementation that required _field classmethod
- Add missing CHECK_KEY and PARSER_KEY constants to model.py
- Fix mypy type errors by filtering None values from extract methods
- Apply ruff formatting fixes
- error_reporting.md: Use str(err) instead of non-existent err.message attribute
- tensordict_checks.md: Fix test data to pass less_than(1.0) check (0.95 vs 1.0)
- tensordict_model.md: Fix indentation and replace unsupported checks parameter with isin
- tensordict_schema.md: Update example to show validation errors without TensorDict creation conflicts
- Add optional torch import to model.py module namespace
- Remove redundant dependency installation in docs session
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>