feat(torch): add torch extra and update config files#2290
Open
cosmicBboy wants to merge 27 commits into
Open
Conversation
- Add TensorDictSchema and Tensor component classes - Implement TensorDictModel with class-based schema definition - Add Field descriptor with check parameters (gt, ge, lt, le, etc.) - Create tensordict_engine for PyTorch dtype resolution - Implement TensorDictSchemaBackend with batch_size, key, dtype, shape validation - Add TensorDictCheckBackend for tensor value checks - Update spec examples to use correct Field API pattern - Add module layout documentation to spec
- Add test_tensordict_container.py for Tensor component and TensorDictSchema - Add test_tensordict_model.py for TensorDictModel class-based schemas - Add test_tensordict_engine.py for tensordict_engine dtype resolution - Add __init__.py for test package
- Rename tensordict_api.py to tensordict.py per spec - Fix Check calling to use get_backend() pattern - Add tensorclass validation tests - Fix spec module layout duplicate entry - Update test imports
- Add TestTensorDictErrorCases with 14+ error validation tests - Add TestTensorDictModelErrorCases with model error tests - Add TestTensorDictEngineErrorCases with engine error tests - Test batch_size, dtype, shape, missing keys, wrong input types - Test error messages contain relevant info - Test lazy validation collects multiple errors
- tensordict_engine: strip 'torch.' prefix from string dtypes - base.py: use data= not schema_context= for SchemaError - base.py: implement failure_cases_metadata for lazy validation - base.py: fix SchemaErrorReason codes (COLUMN_NOT_IN_DATAFRAME, CHECK_ERROR) - base.py: fix empty check result message - tests: fix dtype comparisons, SchemaError vs SchemaErrors, error messages
- Add pytorch_guide/ with index, schema, model, checks, and error reporting docs - Add reference/pytorch.rst API reference - Update index.md and reference/index.md to include pytorch guide
- Add torch/tensordict extra to pyproject.toml - Add torch to DATAFRAME_EXTRAS in noxfile.py - Fix tensordict/types.py to handle missing torch gracefully - Add torch and tensordict to requirements.txt and environment.yml - Remove non-existent modules from pytorch reference
… type annotation - Expose errors module and DataType in pandera.tensordict namespace - Use DataType (from tensordict_engine) instead of torch.Tensor for type annotations - Replace _field classmethod with Field function for field configuration - Fix validate to be a classmethod for correct class-level calling - Update tests and documentation to match other pandera backends
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## dev/tensordict #2290 +/- ##
==================================================
- Coverage 82.74% 78.39% -4.36%
==================================================
Files 180 193 +13
Lines 15086 15925 +839
==================================================
+ Hits 12483 12484 +1
- Misses 2603 3441 +838 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
- Run ruff check --fix (94 errors fixed) - Run ruff format on all files - Fix import ordering with isort - Convert to f-strings with flynt - Upgrade syntax with pyupgrade
- Add all required CheckResult fields in checks backend - Simplify tensor check application logic (no dict wrapping) - Fix validate() signatures to match base classes - Resolve type annotation errors with mypy config directives - Remove duplicate test function names - Clean up redundant code across tensordict modules
- Implement Tensor component class (pandera/api/tensordict/components.py) - Implement TensorDictSchema with keys parameter (pandera/api/tensordict/container.py) - Implement TensorDictModel for declarative schemas (pandera/api/tensordict/model.py) - Implement validation backend (pandera/backends/tensordict/base.py) - Implement check backend for tensor value validation (pandera/backends/tensordict/checks.py) - Register builtin checks for torch.Tensor types (pandera/backends/tensordict/builtin_checks.py) - Implement backend registration with tensorclass support (pandera/backends/tensordict/register.py) - Update entry point exports (pandera/tensordict.py) - Add tests for all functionality - Support both TensorDict and tensorclass objects - Support lazy validation and value checks via existing pandera Check API
- Fixed preprocess() signature to include key parameter - Changed CoreCheckResult to CheckResult for error handling - Fixed register_builtin_check fallback signature - Added missing imports (BaseSchemaBackend, BaseConfig, Tensor) - Fixed annotation access from .annotation to .raw_annotation
- Updated TensorDictModel to follow the same pattern as xarray.DatasetModel and DataFrameModel - Fields are now defined using pa.Field() directly in type annotations without needing custom _field classmethod - Type annotations specify dtype (e.g., torch.float32, torch.int64) instead of torch.Tensor - Added descriptors for field collection, schema caching, checks, and parsers - Exported SchemaError and SchemaErrors from pandera.tensordict module - Updated documentation to reflect correct API usage - Added comprehensive test coverage for TensorDictModel Fixes issue with class-based model implementation that required _field classmethod
- Add missing CHECK_KEY and PARSER_KEY constants to model.py - Fix mypy type errors by filtering None values from extract methods - Apply ruff formatting fixes
- error_reporting.md: Use str(err) instead of non-existent err.message attribute - tensordict_checks.md: Fix test data to pass less_than(1.0) check (0.95 vs 1.0) - tensordict_model.md: Fix indentation and replace unsupported checks parameter with isin - tensordict_schema.md: Update example to show validation errors without TensorDict creation conflicts
- Add optional torch import to model.py module namespace - Remove redundant dependency installation in docs session
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Implements phase one of the Tensordict spec