diff --git a/pyproject.toml b/pyproject.toml index 3dd8a037..1221f3c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,14 +149,14 @@ exclude = [ "dist", "docs" ] -ignore-init-module-imports = true -[tool.ruff.per-file-ignores] + +[tool.ruff.lint.per-file-ignores] "setup.py" = ["D100", "SIM115"] "__about__.py" = ["D100"] "__init__.py" = ["D100"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] # Use numpy-style docstrings. convention = "numpy" diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index ad7bda85..a642fa67 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -5,8 +5,9 @@ import os import re +from collections.abc import Iterable from dataclasses import MISSING, dataclass, field -from typing import Any, Dict, Iterable, List, Optional +from typing import Any from omegaconf import OmegaConf @@ -103,22 +104,22 @@ class DataConfig: """ - target: Optional[List[str]] = field( + target: list[str] | None = field( default=None, metadata={ "help": "A list of strings with the names of the target column(s)." " It is mandatory for all except SSL tasks." }, ) - continuous_cols: List = field( + continuous_cols: list = field( default_factory=list, metadata={"help": "Column names of the numeric fields. Defaults to []"}, ) - categorical_cols: List = field( + categorical_cols: list = field( default_factory=list, metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"}, ) - date_columns: List = field( + date_columns: list = field( default_factory=list, metadata={ "help": "(Column names, Freq) tuples of the date fields. For eg. a field named" @@ -131,14 +132,14 @@ class DataConfig: default=True, metadata={"help": "Whether or not to encode the derived variables from date"}, ) - validation_split: Optional[float] = field( + validation_split: float | None = field( default=0.2, metadata={ "help": "Percentage of Training rows to keep aside as validation." " Used only if Validation Data is not given separately" }, ) - continuous_feature_transform: Optional[str] = field( + continuous_feature_transform: str | None = field( default=None, metadata={ "help": "Whether or not to transform the features before modelling. By default it is turned off.", @@ -164,7 +165,7 @@ class DataConfig: " the noise is only applied for QuantileTransformer" }, ) - num_workers: Optional[int] = field( + num_workers: int | None = field( default=0, metadata={"help": "The number of workers used for data loading. 
For windows always set to 0"}, ) @@ -186,7 +187,7 @@ class DataConfig: metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"}, ) - dataloader_kwargs: Dict[str, Any] = field( + dataloader_kwargs: dict[str, Any] = field( default_factory=dict, metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."}, ) @@ -229,19 +230,19 @@ class InferredConfig: continuous_dim: int = field( metadata={"help": "The number of continuous features"}, ) - output_dim: Optional[int] = field( + output_dim: int | None = field( default=None, metadata={"help": "The number of output targets"}, ) - output_cardinality: Optional[List[int]] = field( + output_cardinality: list[int] | None = field( default=None, metadata={"help": "The number of unique values in classification output"}, ) - categorical_cardinality: Optional[List[int]] = field( + categorical_cardinality: list[int] | None = field( default=None, metadata={"help": "The number of unique values in categorical features"}, ) - embedding_dims: Optional[List] = field( + embedding_dims: list | None = field( default=None, metadata={ "help": "The dimensions of the embedding for each categorical column as a list of tuples " @@ -384,15 +385,15 @@ class TrainerConfig: }, ) max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"}) - min_epochs: Optional[int] = field( + min_epochs: int | None = field( default=1, metadata={"help": "Force training for at least these many epochs. 1 by default"}, ) - max_time: Optional[int] = field( + max_time: int | None = field( default=None, metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"}, ) - accelerator: Optional[str] = field( + accelerator: str | None = field( default="auto", metadata={ "help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','auto'." @@ -400,14 +401,14 @@ class TrainerConfig: "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"], }, ) - devices: Optional[int] = field( + devices: int | None = field( default=-1, metadata={ "help": "Number of devices to train on. -1 uses all available devices." " By default uses all available devices (-1)", }, ) - devices_list: Optional[List[int]] = field( + devices_list: list[int] | None = field( default=None, metadata={ "help": "List of devices to train on (list). If specified, takes precedence over `devices` argument." @@ -454,7 +455,7 @@ class TrainerConfig: "help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility." }, ) - profiler: Optional[str] = field( + profiler: str | None = field( default=None, metadata={ "help": "To profile individual steps during training and assist in identifying bottlenecks." @@ -462,7 +463,7 @@ class TrainerConfig: "choices": [None, "simple", "advanced", "pytorch"], }, ) - early_stopping: Optional[str] = field( + early_stopping: str | None = field( default="valid_loss", metadata={ "help": "The loss/metric that needed to be monitored for early stopping." @@ -484,14 +485,14 @@ class TrainerConfig: default=3, metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"}, ) - early_stopping_kwargs: Optional[Dict[str, Any]] = field( + early_stopping_kwargs: dict[str, Any] | None = field( default_factory=lambda: {}, metadata={ "help": "Additional keyword arguments for the early stopping callback." " See the documentation for the PyTorch Lightning EarlyStopping callback for more details." 
}, ) - checkpoints: Optional[str] = field( + checkpoints: str | None = field( default="valid_loss", metadata={ "help": "The loss/metric that needed to be monitored for checkpoints. If None, there will be no checkpoints" @@ -505,7 +506,7 @@ class TrainerConfig: default=1, metadata={"help": "Number of training steps between checkpoints"}, ) - checkpoints_name: Optional[str] = field( + checkpoints_name: str | None = field( default=None, metadata={ "help": "The name under which the models will be saved. If left blank," @@ -521,7 +522,7 @@ class TrainerConfig: default=1, metadata={"help": "The number of best models to save"}, ) - checkpoints_kwargs: Optional[Dict[str, Any]] = field( + checkpoints_kwargs: dict[str, Any] | None = field( default_factory=lambda: {}, metadata={ "help": "Additional keyword arguments for the checkpoints callback. See the documentation" @@ -553,7 +554,7 @@ class TrainerConfig: default=42, metadata={"help": "Seed for random number generators. Defaults to 42"}, ) - trainer_kwargs: Dict[str, Any] = field( + trainer_kwargs: dict[str, Any] = field( default_factory=dict, metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."}, ) @@ -611,14 +612,14 @@ class ExperimentConfig: }, ) - run_name: Optional[str] = field( + run_name: str | None = field( default=None, metadata={ "help": "The name of the run; a specific identifier to recognize the run." " If left blank, will be assigned a auto-generated name" }, ) - exp_watch: Optional[str] = field( + exp_watch: str | None = field( default=None, metadata={ "help": "The level of logging required. Can be `gradients`, `parameters`, `all` or `None`." @@ -690,11 +691,11 @@ class OptimizerConfig: " for example 'torch_optimizer.RAdam'." }, ) - optimizer_params: Dict = field( + optimizer_params: dict = field( default_factory=lambda: {}, metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."}, ) - lr_scheduler: Optional[str] = field( + lr_scheduler: str | None = field( default=None, metadata={ "help": "The name of the LearningRateScheduler to use, if any, from" @@ -702,17 +703,17 @@ class OptimizerConfig: " If None, will not use any scheduler. Defaults to `None`", }, ) - lr_scheduler_params: Optional[Dict] = field( + lr_scheduler_params: dict | None = field( default_factory=lambda: {}, metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."}, ) - lr_scheduler_monitor_metric: Optional[str] = field( + lr_scheduler_monitor_metric: str | None = field( default="valid_loss", metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"}, ) - lr_scheduler_interval: Optional[str] = field( + lr_scheduler_interval: str | None = field( default="epoch", metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."}, ) @@ -823,7 +824,7 @@ class ModelConfig: } ) - head: Optional[str] = field( + head: str | None = field( default="LinearHead", metadata={ "help": "The head to be used for the model. Should be one of the heads defined" @@ -832,14 +833,14 @@ class ModelConfig: }, ) - head_config: Optional[Dict] = field( + head_config: dict | None = field( default_factory=lambda: {"layers": ""}, metadata={ "help": "The config as a dict which defines the head." " If left empty, will be initialized as default linear head." 
}, ) - embedding_dims: Optional[List] = field( + embedding_dims: list | None = field( default=None, metadata={ "help": "The dimensions of the embedding for each categorical column as a list of tuples " @@ -860,7 +861,7 @@ class ModelConfig: default=1e-3, metadata={"help": "The learning rate of the model. Defaults to 1e-3."}, ) - loss: Optional[str] = field( + loss: str | None = field( default=None, metadata={ "help": "The loss function to be applied. By Default it is MSELoss for regression " @@ -868,7 +869,7 @@ class ModelConfig: "leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification" }, ) - metrics: Optional[List[str]] = field( + metrics: list[str] | None = field( default=None, metadata={ "help": "the list of metrics you need to track during training. The metrics should be one " @@ -877,7 +878,7 @@ class ModelConfig: "and mean_squared_error for regression" }, ) - metrics_prob_input: Optional[List[bool]] = field( + metrics_prob_input: list[bool] | None = field( default=None, metadata={ "help": "Is a mandatory parameter for classification metrics defined in the config. This defines " @@ -885,7 +886,7 @@ class ModelConfig: "as the number of metrics. Defaults to None." }, ) - metrics_params: Optional[List] = field( + metrics_params: list | None = field( default=None, metadata={ "help": "The parameters to be passed to the metrics function. `task` is forced to be `multiclass`` " @@ -893,7 +894,7 @@ class ModelConfig: "`multiclass`." }, ) - target_range: Optional[List] = field( + target_range: list | None = field( default=None, metadata={ "help": "The range in which we should limit the output variable. " @@ -902,7 +903,7 @@ class ModelConfig: }, ) - virtual_batch_size: Optional[int] = field( + virtual_batch_size: int | None = field( default=None, metadata={ "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's " @@ -1001,7 +1002,7 @@ class SSLModelConfig: task: str = field(init=False, default="ssl") - encoder_config: Optional[ModelConfig] = field( + encoder_config: ModelConfig | None = field( default=None, metadata={ "help": "The config of the encoder to be used for the model." @@ -1009,7 +1010,7 @@ class SSLModelConfig: }, ) - decoder_config: Optional[ModelConfig] = field( + decoder_config: ModelConfig | None = field( default=None, metadata={ "help": "The config of decoder to be used for the model." 
@@ -1017,7 +1018,7 @@ class SSLModelConfig: }, ) - embedding_dims: Optional[List] = field( + embedding_dims: list | None = field( default=None, metadata={ "help": "The dimensions of the embedding for each categorical column as a list of tuples " @@ -1033,7 +1034,7 @@ class SSLModelConfig: default=True, metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."}, ) - virtual_batch_size: Optional[int] = field( + virtual_batch_size: int | None = field( default=None, metadata={ "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's " diff --git a/src/pytorch_tabular/models/autoint/config.py b/src/pytorch_tabular/models/autoint/config.py index 511b44d3..646c5af6 100644 --- a/src/pytorch_tabular/models/autoint/config.py +++ b/src/pytorch_tabular/models/autoint/config.py @@ -4,7 +4,6 @@ """AutomaticFeatureInteraction Config.""" from dataclasses import dataclass, field -from typing import Optional from pytorch_tabular.config import ModelConfig @@ -138,7 +137,7 @@ class AutoIntConfig(ModelConfig): default=16, metadata={"help": "The dimensions of the embedding for continuous and categorical columns. Defaults to 16"}, ) - embedding_initialization: Optional[str] = field( + embedding_initialization: str | None = field( default="kaiming_uniform", metadata={ "help": "Initialization scheme for the embedding layers. Defaults to `kaiming`", @@ -158,7 +157,7 @@ class AutoIntConfig(ModelConfig): " For more details refer to Appendix A of the TabTransformer paper. Defaults to False" }, ) - share_embedding_strategy: Optional[str] = field( + share_embedding_strategy: str | None = field( default="fraction", metadata={ "help": "There are two strategies in adding shared embeddings." diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index 380f4cab..9b075470 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -6,9 +6,10 @@ import importlib import warnings from abc import ABCMeta, abstractmethod +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from pathlib import Path +from typing import Any import numpy as np import pytorch_lightning as pl @@ -59,7 +60,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo return merged_config -def _create_optimizer(optimizer: Union[str, Callable]) -> Type[Optimizer]: +def _create_optimizer(optimizer: str | Callable) -> type[Optimizer]: """Instantiate Optimizer.""" if callable(optimizer): return optimizer @@ -74,11 +75,11 @@ class BaseModel(pl.LightningModule, metaclass=ABCMeta): def __init__( self, config: DictConfig, - custom_loss: Optional[torch.nn.Module] = None, - custom_metrics: Optional[List[Callable]] = None, - custom_metrics_prob_inputs: Optional[List[bool]] = None, - custom_optimizer: Optional[torch.optim.Optimizer] = None, - custom_optimizer_params: Dict = {}, + custom_loss: torch.nn.Module | None = None, + custom_metrics: list[Callable] | None = None, + custom_metrics_prob_inputs: list[bool] | None = None, + custom_optimizer: torch.optim.Optimizer | None = None, + custom_optimizer_params: dict = {}, **kwargs, ): """Base Model for PyTorch Tabular. 
@@ -210,7 +211,7 @@ def head(self): @classmethod def load_from_checkpoint( cls, - checkpoint_path: Union[str, Path], + checkpoint_path: str | Path, map_location=None, strict=True, **kwargs, @@ -267,7 +268,7 @@ def _setup_metrics(self): else: self.metrics = self.custom_metrics - def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str, sync_dist: bool = False) -> torch.Tensor: + def calculate_loss(self, output: dict, y: torch.Tensor, tag: str, sync_dist: bool = False) -> torch.Tensor: """Calculates the loss for the model. Args: @@ -344,7 +345,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str, sync_dist: boo def calculate_metrics( self, y: torch.Tensor, y_hat: torch.Tensor, tag: str, sync_dist: bool = False - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: """Calculates the metrics for the model. Args: @@ -433,12 +434,12 @@ def data_aware_initialization(self, datamodule): """Performs data-aware initialization of the model when defined.""" pass - def compute_backbone(self, x: Dict) -> torch.Tensor: + def compute_backbone(self, x: dict) -> torch.Tensor: # Returns output x = self.backbone(x) return x - def embed_input(self, x: Dict) -> torch.Tensor: + def embed_input(self, x: dict) -> torch.Tensor: return self.embedding_layer(x) def apply_output_sigmoid_scaling(self, y_hat: torch.Tensor) -> torch.Tensor: @@ -458,7 +459,7 @@ def apply_output_sigmoid_scaling(self, y_hat: torch.Tensor) -> torch.Tensor: y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min) return y_hat - def pack_output(self, y_hat: torch.Tensor, backbone_features: torch.tensor) -> Dict[str, Any]: + def pack_output(self, y_hat: torch.Tensor, backbone_features: torch.tensor) -> dict[str, Any]: """Packs the output of the model. Args: @@ -476,7 +477,7 @@ def pack_output(self, y_hat: torch.Tensor, backbone_features: torch.tensor) -> D return {"logits": y_hat} return {"logits": y_hat, "backbone_features": backbone_features} - def compute_head(self, backbone_features: Tensor) -> Dict[str, Any]: + def compute_head(self, backbone_features: Tensor) -> dict[str, Any]: """Computes the head of the model. Args: @@ -490,7 +491,7 @@ def compute_head(self, backbone_features: Tensor) -> Dict[str, Any]: y_hat = self.apply_output_sigmoid_scaling(y_hat) return self.pack_output(y_hat, backbone_features) - def forward(self, x: Dict) -> Dict[str, Any]: + def forward(self, x: dict) -> dict[str, Any]: """The forward pass of the model. Args: @@ -501,7 +502,7 @@ def forward(self, x: Dict) -> Dict[str, Any]: x = self.compute_backbone(x) return self.compute_head(x) - def predict(self, x: Dict, ret_model_output: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]: + def predict(self, x: dict, ret_model_output: bool = False) -> torch.Tensor | tuple[torch.Tensor, dict]: """Predicts the output of the model. 
Args: @@ -701,13 +702,13 @@ def __init__( self, backbone: nn.Module, head: str, - head_config: Dict, + head_config: dict, config: DictConfig, - custom_loss: Optional[torch.nn.Module] = None, - custom_metrics: Optional[List[Callable]] = None, - custom_metrics_prob_inputs: Optional[List[bool]] = None, - custom_optimizer: Optional[torch.optim.Optimizer] = None, - custom_optimizer_params: Dict = {}, + custom_loss: torch.nn.Module | None = None, + custom_metrics: list[Callable] | None = None, + custom_metrics_prob_inputs: list[bool] | None = None, + custom_optimizer: torch.optim.Optimizer | None = None, + custom_optimizer_params: dict = {}, **kwargs, ): assert hasattr(config, "loss") or custom_loss is not None, "Loss function not defined in the config" diff --git a/src/pytorch_tabular/models/common/heads/config.py b/src/pytorch_tabular/models/common/heads/config.py index f84c354f..a38ae0b0 100644 --- a/src/pytorch_tabular/models/common/heads/config.py +++ b/src/pytorch_tabular/models/common/heads/config.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import List, Optional # from typing import Any, Dict, Iterable, List, Optional @@ -114,7 +113,7 @@ class MixtureDensityHeadConfig: "help": "Whether to have a bias term in the sigma layer. Defaults to False", }, ) - mu_bias_init: Optional[List] = field( + mu_bias_init: list | None = field( default=None, metadata={ "help": "To initialize the bias parameter of the mu layer to predefined cluster centers." @@ -123,7 +122,7 @@ class MixtureDensityHeadConfig: }, ) - weight_regularization: Optional[int] = field( + weight_regularization: int | None = field( default=2, metadata={ "help": "Whether to apply L1 or L2 Norm to the MDN layers. Defaults to L2", @@ -131,25 +130,25 @@ class MixtureDensityHeadConfig: }, ) - lambda_sigma: Optional[float] = field( + lambda_sigma: float | None = field( default=0.1, metadata={ "help": "The regularization constant for weight regularization of sigma layer. Defaults to 0.1", }, ) - lambda_pi: Optional[float] = field( + lambda_pi: float | None = field( default=0.1, metadata={ "help": "The regularization constant for weight regularization of pi layer. Defaults to 0.1", }, ) - lambda_mu: Optional[float] = field( + lambda_mu: float | None = field( default=0, metadata={ "help": "The regularization constant for weight regularization of mu layer. Defaults to 0", }, ) - softmax_temperature: Optional[float] = field( + softmax_temperature: float | None = field( default=1, metadata={ "help": "The temperature to be used in the gumbel softmax of the mixing coefficients." diff --git a/src/pytorch_tabular/models/common/layers/__init__.py b/src/pytorch_tabular/models/common/layers/__init__.py index 9e5f22cc..8f6a48db 100644 --- a/src/pytorch_tabular/models/common/layers/__init__.py +++ b/src/pytorch_tabular/models/common/layers/__init__.py @@ -1,10 +1,26 @@ from . 
import activations from .batch_norm import GBN, BatchNorm1d -from .embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer, SharedEmbeddings -from .gated_units import GEGLU, GatedFeatureLearningUnit, PositionWiseFeedForward, ReGLU, SwiGLU +from .embeddings import ( + Embedding1dLayer, + Embedding2dLayer, + PreEncoded1dLayer, + SharedEmbeddings, +) +from .gated_units import ( + GEGLU, + GatedFeatureLearningUnit, + PositionWiseFeedForward, + ReGLU, + SwiGLU, +) from .misc import Add, Lambda, ModuleWithInit, Residual from .soft_trees import ODST, NeuralDecisionTree -from .transformers import AddNorm, AppendCLSToken, MultiHeadedAttention, TransformerEncoderBlock +from .transformers import ( + AddNorm, + AppendCLSToken, + MultiHeadedAttention, + TransformerEncoderBlock, +) __all__ = [ "PreEncoded1dLayer", diff --git a/src/pytorch_tabular/models/common/layers/embeddings.py b/src/pytorch_tabular/models/common/layers/embeddings.py index bbc7f2ba..38705cd4 100644 --- a/src/pytorch_tabular/models/common/layers/embeddings.py +++ b/src/pytorch_tabular/models/common/layers/embeddings.py @@ -1,7 +1,7 @@ # W605 import math from functools import partial -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import torch from torch import nn @@ -59,10 +59,10 @@ class PreEncoded1dLayer(nn.Module): def __init__( self, continuous_dim: int, - categorical_dim: Tuple[int, int], + categorical_dim: tuple[int, int], embedding_dropout: float = 0.0, batch_norm_continuous_input: bool = False, - virtual_batch_size: Optional[int] = None, + virtual_batch_size: int | None = None, ): super().__init__() self.continuous_dim = continuous_dim @@ -77,7 +77,7 @@ def __init__( if batch_norm_continuous_input: self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size) - def forward(self, x: Dict[str, Any]) -> torch.Tensor: + def forward(self, x: dict[str, Any]) -> torch.Tensor: assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features" # (B, N) continuous_data, categorical_data = ( @@ -114,10 +114,10 @@ class Embedding1dLayer(nn.Module): def __init__( self, continuous_dim: int, - categorical_embedding_dims: Tuple[int, int], + categorical_embedding_dims: tuple[int, int], embedding_dropout: float = 0.0, batch_norm_continuous_input: bool = False, - virtual_batch_size: Optional[int] = None, + virtual_batch_size: int | None = None, ): super().__init__() self.continuous_dim = continuous_dim @@ -134,7 +134,7 @@ def __init__( if batch_norm_continuous_input: self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size) - def forward(self, x: Dict[str, Any]) -> torch.Tensor: + def forward(self, x: dict[str, Any]) -> torch.Tensor: assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features" # (B, N) continuous_data, categorical_data = ( @@ -178,15 +178,15 @@ class Embedding2dLayer(nn.Module): def __init__( self, continuous_dim: int, - categorical_cardinality: List[int], + categorical_cardinality: list[int], embedding_dim: int, - shared_embedding_strategy: Optional[str] = None, + shared_embedding_strategy: str | None = None, frac_shared_embed: float = 0.25, embedding_bias: bool = False, batch_norm_continuous_input: bool = False, - virtual_batch_size: Optional[int] = None, + virtual_batch_size: int | None = None, embedding_dropout: float = 0.0, - initialization: Optional[str] = None, + initialization: str | None = None, ): """ Args: @@ -266,7 +266,7 @@ def __init__( else: 
self.embd_dropout = None - def forward(self, x: Dict[str, Any]) -> torch.Tensor: + def forward(self, x: dict[str, Any]) -> torch.Tensor: assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features" # (B, N) continuous_data, categorical_data = ( diff --git a/src/pytorch_tabular/models/common/layers/gated_units.py b/src/pytorch_tabular/models/common/layers/gated_units.py index 1eac0f2a..e7f2cd5a 100644 --- a/src/pytorch_tabular/models/common/layers/gated_units.py +++ b/src/pytorch_tabular/models/common/layers/gated_units.py @@ -2,7 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT import random -from typing import Callable +from collections.abc import Callable import torch import torch.nn as nn diff --git a/src/pytorch_tabular/models/common/layers/misc.py b/src/pytorch_tabular/models/common/layers/misc.py index a1a250e8..a4214ec8 100644 --- a/src/pytorch_tabular/models/common/layers/misc.py +++ b/src/pytorch_tabular/models/common/layers/misc.py @@ -1,5 +1,5 @@ # W605 -from typing import Callable, Union +from collections.abc import Callable import torch from torch import nn @@ -58,7 +58,7 @@ def __call__(self, *args, **kwargs): class Add(nn.Module): """A module that adds a constant/parameter value to the input.""" - def __init__(self, add_value: Union[float, torch.Tensor]): + def __init__(self, add_value: float | torch.Tensor): """Initialize the module. Args: diff --git a/src/pytorch_tabular/models/common/layers/soft_trees.py b/src/pytorch_tabular/models/common/layers/soft_trees.py index 921f5c3e..74729790 100644 --- a/src/pytorch_tabular/models/common/layers/soft_trees.py +++ b/src/pytorch_tabular/models/common/layers/soft_trees.py @@ -1,5 +1,5 @@ import random -from typing import Callable +from collections.abc import Callable from warnings import warn import numpy as np diff --git a/src/pytorch_tabular/models/common/layers/transformers.py b/src/pytorch_tabular/models/common/layers/transformers.py index 4711c19f..523ef746 100644 --- a/src/pytorch_tabular/models/common/layers/transformers.py +++ b/src/pytorch_tabular/models/common/layers/transformers.py @@ -1,6 +1,5 @@ # W605 import math -from typing import Optional import torch from einops import rearrange @@ -83,7 +82,7 @@ def __init__( keep_attn: bool = True, ff_dropout: float = 0.1, add_norm_dropout: float = 0.1, - transformer_head_dim: Optional[int] = None, + transformer_head_dim: int | None = None, ): """ Args: diff --git a/src/pytorch_tabular/models/danet/config.py b/src/pytorch_tabular/models/danet/config.py index 13978296..026b50a5 100644 --- a/src/pytorch_tabular/models/danet/config.py +++ b/src/pytorch_tabular/models/danet/config.py @@ -4,7 +4,6 @@ """AutomaticFeatureInteraction Config.""" from dataclasses import dataclass, field -from typing import Optional from pytorch_tabular.config import ModelConfig @@ -85,7 +84,7 @@ class DANetConfig(ModelConfig): }, ) - abstlay_dim_2: Optional[int] = field( + abstlay_dim_2: int | None = field( default=None, metadata={ "help": "The dimension for the intermediate output in the second ABSTLAY layer in a Block." 
@@ -108,7 +107,7 @@ class DANetConfig(ModelConfig): " https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity" }, ) - virtual_batch_size: Optional[int] = field( + virtual_batch_size: int | None = field( default=256, metadata={ "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's " diff --git a/src/pytorch_tabular/models/ft_transformer/config.py b/src/pytorch_tabular/models/ft_transformer/config.py index 3697da51..12e683f3 100644 --- a/src/pytorch_tabular/models/ft_transformer/config.py +++ b/src/pytorch_tabular/models/ft_transformer/config.py @@ -4,7 +4,6 @@ """AutomaticFeatureInteraction Config.""" from dataclasses import dataclass, field -from typing import Optional from pytorch_tabular.config import ModelConfig @@ -108,7 +107,7 @@ class FTTransformerConfig(ModelConfig): default=32, metadata={"help": "The embedding dimension for the input categorical features. Defaults to 32"}, ) - embedding_initialization: Optional[str] = field( + embedding_initialization: str | None = field( default="kaiming_uniform", metadata={ "help": "Initialization scheme for the embedding layers. Defaults to `kaiming`", @@ -128,7 +127,7 @@ class FTTransformerConfig(ModelConfig): " to Appendix A of the TabTransformer paper. Defaults to False" }, ) - share_embedding_strategy: Optional[str] = field( + share_embedding_strategy: str | None = field( default="fraction", metadata={ "help": "There are two strategies in adding shared embeddings." @@ -161,7 +160,7 @@ class FTTransformerConfig(ModelConfig): default=6, metadata={"help": "The number of layers of stacked Multi-Headed Attention layers. Defaults to 6"}, ) - transformer_head_dim: Optional[int] = field( + transformer_head_dim: int | None = field( default=None, metadata={ "help": "The number of hidden units in the Multi-Headed Attention layers." 
diff --git a/src/pytorch_tabular/models/gandalf/gandalf.py b/src/pytorch_tabular/models/gandalf/gandalf.py index e3bf7f55..c259cd1c 100644 --- a/src/pytorch_tabular/models/gandalf/gandalf.py +++ b/src/pytorch_tabular/models/gandalf/gandalf.py @@ -5,7 +5,11 @@ import torch.nn as nn from omegaconf import DictConfig -from pytorch_tabular.models.common.layers import Add, Embedding1dLayer, GatedFeatureLearningUnit +from pytorch_tabular.models.common.layers import ( + Add, + Embedding1dLayer, + GatedFeatureLearningUnit, +) from pytorch_tabular.models.common.layers.activations import t_softmax from pytorch_tabular.utils import get_logger diff --git a/src/pytorch_tabular/models/gate/gate_model.py b/src/pytorch_tabular/models/gate/gate_model.py index 7acf7065..dff9e6e7 100644 --- a/src/pytorch_tabular/models/gate/gate_model.py +++ b/src/pytorch_tabular/models/gate/gate_model.py @@ -6,8 +6,19 @@ from omegaconf import DictConfig from pytorch_tabular.models.common.heads import blocks -from pytorch_tabular.models.common.layers import Add, Embedding1dLayer, GatedFeatureLearningUnit, NeuralDecisionTree -from pytorch_tabular.models.common.layers.activations import entmax15, entmoid15, sparsemax, sparsemoid, t_softmax +from pytorch_tabular.models.common.layers import ( + Add, + Embedding1dLayer, + GatedFeatureLearningUnit, + NeuralDecisionTree, +) +from pytorch_tabular.models.common.layers.activations import ( + entmax15, + entmoid15, + sparsemax, + sparsemoid, + t_softmax, +) from pytorch_tabular.utils import get_logger from ..base_model import BaseModel diff --git a/src/pytorch_tabular/models/mixture_density/config.py b/src/pytorch_tabular/models/mixture_density/config.py index 428e5871..2d1bb4ff 100644 --- a/src/pytorch_tabular/models/mixture_density/config.py +++ b/src/pytorch_tabular/models/mixture_density/config.py @@ -4,7 +4,6 @@ """Mixture Density Head Config.""" from dataclasses import dataclass, field -from typing import Dict from pytorch_tabular.config.config import ModelConfig @@ -72,12 +71,12 @@ class MDNConfig(ModelConfig): " The config class should be a valid module path from `models`. e.g. 
`FTTransformerConfig`" }, ) - backbone_config_params: Dict = field( + backbone_config_params: dict = field( default=None, metadata={"help": "The dict of config parameters for defining the Backbone."}, ) head: str = field(init=False, default="MixtureDensityHead") - head_config: Dict = field( + head_config: dict = field( default=None, metadata={"help": "The config for defining the Mixed Density Network Head"}, ) diff --git a/src/pytorch_tabular/models/mixture_density/mdn.py b/src/pytorch_tabular/models/mixture_density/mdn.py index 6ae02db1..5d120f88 100644 --- a/src/pytorch_tabular/models/mixture_density/mdn.py +++ b/src/pytorch_tabular/models/mixture_density/mdn.py @@ -3,7 +3,6 @@ # For license information, see LICENSE.TXT """Mixture Density Models.""" -from typing import Dict, Optional, Union import torch import torch.nn as nn @@ -86,7 +85,7 @@ def _build_network(self): self._head = self._get_head_from_config() # Redefining forward because TabTransformer flow is slightly different - def forward(self, x: Dict): + def forward(self, x: dict): if isinstance(self.backbone, TabTransformerBackbone): if self.hparams.categorical_dim > 0: x_cat = self.embed_input({"categorical": x["categorical"]}) @@ -98,7 +97,7 @@ def forward(self, x: Dict): # Redefining compute_backbone because TabTransformer flow flow is slightly different - def compute_backbone(self, x: Union[Dict, torch.Tensor]): + def compute_backbone(self, x: dict | torch.Tensor): # Returns output if isinstance(self.backbone, TabTransformerBackbone): x = self.backbone(x["categorical"], x["continuous"]) @@ -110,11 +109,11 @@ def compute_head(self, x: Tensor): pi, sigma, mu = self.head(x) return {"pi": pi, "sigma": sigma, "mu": mu, "backbone_features": x} - def predict(self, x: Dict): + def predict(self, x: dict): ret_value = self.forward(x) return self.head.generate_point_predictions(ret_value["pi"], ret_value["sigma"], ret_value["mu"]) - def sample(self, x: Dict, n_samples: Optional[int] = None, ret_model_output=False): + def sample(self, x: dict, n_samples: int | None = None, ret_model_output=False): ret_value = self.forward(x) samples = self.head.generate_samples(ret_value["pi"], ret_value["sigma"], ret_value["mu"], n_samples) if ret_model_output: diff --git a/src/pytorch_tabular/models/node/config.py b/src/pytorch_tabular/models/node/config.py index a5374f8c..1d7a2f3f 100644 --- a/src/pytorch_tabular/models/node/config.py +++ b/src/pytorch_tabular/models/node/config.py @@ -1,6 +1,5 @@ import warnings from dataclasses import dataclass, field -from typing import Optional from pytorch_tabular.config import ModelConfig @@ -139,7 +138,7 @@ class NodeConfig(ModelConfig): "choices": ["entmoid15", "sparsemoid"], }, ) - max_features: Optional[int] = field( + max_features: int | None = field( default=None, metadata={ "help": "If not None, sets a max limit on the number of features to be carried forward" @@ -198,7 +197,7 @@ class NodeConfig(ModelConfig): }, ) - head: Optional[str] = field( + head: str | None = field( default=None, ) diff --git a/src/pytorch_tabular/models/tab_transformer/config.py b/src/pytorch_tabular/models/tab_transformer/config.py index d38986a5..67fac046 100644 --- a/src/pytorch_tabular/models/tab_transformer/config.py +++ b/src/pytorch_tabular/models/tab_transformer/config.py @@ -4,7 +4,6 @@ """AutomaticFeatureInteraction Config.""" from dataclasses import dataclass, field -from typing import Optional from pytorch_tabular.config import ModelConfig @@ -105,7 +104,7 @@ class TabTransformerConfig(ModelConfig): default=32, 
metadata={"help": "The embedding dimension for the input categorical features. Defaults to 32"}, ) - embedding_initialization: Optional[str] = field( + embedding_initialization: str | None = field( default="kaiming_uniform", metadata={ "help": "Initialization scheme for the embedding layers. Defaults to `kaiming`", @@ -125,7 +124,7 @@ class TabTransformerConfig(ModelConfig): " to Appendix A of the TabTransformer paper. Defaults to False" }, ) - share_embedding_strategy: Optional[str] = field( + share_embedding_strategy: str | None = field( default="fraction", metadata={ "help": "There are two strategies in adding shared embeddings." @@ -151,7 +150,7 @@ class TabTransformerConfig(ModelConfig): default=6, metadata={"help": "The number of layers of stacked Multi-Headed Attention layers. Defaults to 6"}, ) - transformer_head_dim: Optional[int] = field( + transformer_head_dim: int | None = field( default=None, metadata={ "help": "The number of hidden units in the Multi-Headed Attention layers." diff --git a/src/pytorch_tabular/models/tab_transformer/tab_transformer.py b/src/pytorch_tabular/models/tab_transformer/tab_transformer.py index da12d833..bab1b723 100644 --- a/src/pytorch_tabular/models/tab_transformer/tab_transformer.py +++ b/src/pytorch_tabular/models/tab_transformer/tab_transformer.py @@ -14,7 +14,6 @@ """TabTransformer Model.""" from collections import OrderedDict -from typing import Dict import torch import torch.nn as nn @@ -116,7 +115,7 @@ def _build_network(self): self._head = self._get_head_from_config() # Redefining forward because this model flow is slightly different - def forward(self, x: Dict): + def forward(self, x: dict): if self.hparams.categorical_dim > 0: x_cat = self.embed_input({"categorical": x["categorical"]}) else: @@ -125,7 +124,7 @@ def forward(self, x: Dict): return self.compute_head(x) # Redefining compute_backbone because this model flow is slightly different - def compute_backbone(self, x: Dict): + def compute_backbone(self, x: dict): # Returns output x = self.backbone(x["categorical"], x["continuous"]) return x diff --git a/src/pytorch_tabular/models/tabnet/config.py b/src/pytorch_tabular/models/tabnet/config.py index c1142273..f74b2f31 100644 --- a/src/pytorch_tabular/models/tabnet/config.py +++ b/src/pytorch_tabular/models/tabnet/config.py @@ -4,7 +4,6 @@ """Tabnet Model Config.""" from dataclasses import dataclass, field -from typing import List, Optional from pytorch_tabular.config import ModelConfig @@ -111,7 +110,7 @@ class TabNetModelConfig(ModelConfig): "choices": ["sparsemax", "entmax"], }, ) - grouped_features: Optional[List[List[str]]] = field( + grouped_features: list[list[str]] | None = field( default=None, metadata={ "help": ( diff --git a/src/pytorch_tabular/models/tabnet/tabnet_model.py b/src/pytorch_tabular/models/tabnet/tabnet_model.py index 49224048..813f104d 100644 --- a/src/pytorch_tabular/models/tabnet/tabnet_model.py +++ b/src/pytorch_tabular/models/tabnet/tabnet_model.py @@ -3,7 +3,6 @@ # For license information, see LICENSE.TXT """TabNet Model.""" -from typing import Dict import torch import torch.nn as nn @@ -12,7 +11,6 @@ from ..base_model import BaseModel - create_group_matrix = _safe_import( "pytorch_tabnet.utils.create_group_matrix", pkg_name="pytorch-tabnet" ) @@ -58,7 +56,7 @@ def _build_network(self): group_attention_matrix=group_matrix, ) - def unpack_input(self, x: Dict): + def unpack_input(self, x: dict): # unpacking into a tuple x = x["categorical"], x["continuous"] # eliminating None in case there is no categorical or 
continuous columns @@ -66,7 +64,7 @@ def unpack_input(self, x: Dict): x = torch.cat(tuple(x), dim=1) return x - def forward(self, x: Dict): + def forward(self, x: dict): # unpacking into a tuple x = self.unpack_input(x) # Making two parameters to the right device. diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py index 4c4b0257..e6a5f162 100644 --- a/src/pytorch_tabular/ssl_models/base_model.py +++ b/src/pytorch_tabular/ssl_models/base_model.py @@ -5,7 +5,6 @@ import warnings from abc import ABCMeta, abstractmethod -from typing import Dict, Optional, Union from pathlib import Path import pytorch_lightning as pl @@ -41,10 +40,10 @@ def __init__( self, config: DictConfig, mode: str = "pretrain", - encoder: Optional[nn.Module] = None, - decoder: Optional[nn.Module] = None, - custom_optimizer: Optional[torch.optim.Optimizer] = None, - custom_optimizer_params: Dict = {}, + encoder: nn.Module | None = None, + decoder: nn.Module | None = None, + custom_optimizer: torch.optim.Optimizer | None = None, + custom_optimizer_params: dict = {}, **kwargs, ): """Base Model for all SSL Models. @@ -145,17 +144,17 @@ def calculate_metrics(self, output, tag, sync_dist): pass @abstractmethod - def forward(self, x: Dict): + def forward(self, x: dict): pass @abstractmethod - def featurize(self, x: Dict): + def featurize(self, x: dict): pass @classmethod def load_from_checkpoint( cls, - checkpoint_path: Union[str, Path], + checkpoint_path: str | Path, map_location=None, strict=True, **kwargs, @@ -174,7 +173,7 @@ def load_from_checkpoint( **kwargs, ) - def predict(self, x: Dict, ret_model_output: bool = True): # ret_model_output only for compatibility + def predict(self, x: dict, ret_model_output: bool = True): # ret_model_output only for compatibility assert ret_model_output, "ret_model_output must be True in case of SSL predict" return self.featurize(x) diff --git a/src/pytorch_tabular/ssl_models/common/augmentations.py b/src/pytorch_tabular/ssl_models/common/augmentations.py index bd0984f5..09e4f6c8 100644 --- a/src/pytorch_tabular/ssl_models/common/augmentations.py +++ b/src/pytorch_tabular/ssl_models/common/augmentations.py @@ -1,10 +1,9 @@ -from typing import Dict import numpy as np import torch -def mixup(batch: Dict, lam: float = 0.5) -> Dict: +def mixup(batch: dict, lam: float = 0.5) -> dict: """It apply mixup augmentation, making a weighted average between a tensor and some random element of the tensor taking random rows. @@ -20,7 +19,7 @@ def mixup(batch: Dict, lam: float = 0.5) -> Dict: return result -def cutmix(batch: Dict, lam: float = 0.1) -> Dict: +def cutmix(batch: dict, lam: float = 0.1) -> dict: """Define how apply cutmix to a tensor. 
:param batch: Tensor on which apply the cutmix augmentation diff --git a/src/pytorch_tabular/ssl_models/common/layers.py b/src/pytorch_tabular/ssl_models/common/layers.py index a3e8f97e..04927c88 100644 --- a/src/pytorch_tabular/ssl_models/common/layers.py +++ b/src/pytorch_tabular/ssl_models/common/layers.py @@ -1,6 +1,6 @@ # W605 from collections import OrderedDict -from typing import Any, Dict, Tuple +from typing import Any import torch from torch import nn @@ -15,7 +15,7 @@ class MixedEmbedding1dLayer(nn.Module): def __init__( self, continuous_dim: int, - categorical_embedding_dims: Tuple[int, int], + categorical_embedding_dims: tuple[int, int], max_onehot_cardinality: int = 4, embedding_dropout: float = 0.0, batch_norm_continuous_input: bool = False, @@ -69,7 +69,7 @@ def embedded_cat_dim(self): ] ) - def forward(self, x: Dict[str, Any]) -> torch.Tensor: + def forward(self, x: dict[str, Any]) -> torch.Tensor: assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features" # (B, N) continuous_data, categorical_data = ( diff --git a/src/pytorch_tabular/ssl_models/dae/config.py b/src/pytorch_tabular/ssl_models/dae/config.py index b1f74885..8f0c7245 100644 --- a/src/pytorch_tabular/ssl_models/dae/config.py +++ b/src/pytorch_tabular/ssl_models/dae/config.py @@ -4,7 +4,6 @@ """DenoisingAutoEncoder Config.""" from dataclasses import dataclass, field -from typing import Dict, List, Optional from pytorch_tabular.config import SSLModelConfig @@ -74,7 +73,7 @@ class DenoisingAutoEncoderConfig(SSLModelConfig): }, ) # Union not supported by omegaconf. Currently Union[float, Dict[str, float]] - noise_probabilities: Dict[str, float] = field( + noise_probabilities: dict[str, float] = field( default_factory=lambda: {}, metadata={ "help": "Dict of individual probabilities to corrupt the input features with swap/zero noise." @@ -89,7 +88,7 @@ class DenoisingAutoEncoderConfig(SSLModelConfig): " For features for which noise_probabilities does not define a probability. Default is 0.8" }, ) - loss_type_weights: Optional[List[float]] = field( + loss_type_weights: list[float] | None = field( default=None, metadata={ "help": "Weights to be used for the loss function in the order [binary, categorical, numerical]." 
diff --git a/src/pytorch_tabular/ssl_models/dae/dae.py b/src/pytorch_tabular/ssl_models/dae/dae.py index 6550018f..0d833743 100644 --- a/src/pytorch_tabular/ssl_models/dae/dae.py +++ b/src/pytorch_tabular/ssl_models/dae/dae.py @@ -5,7 +5,6 @@ """DenoisingAutoEncoder Model.""" from collections import namedtuple -from typing import Dict import torch import torch.nn as nn @@ -60,11 +59,11 @@ def _build_network(self): self._swap_probabilities = swap_probabilities self.swap_noise = SwapNoiseCorrupter(swap_probabilities) - def _concatenate_features(self, x: Dict): + def _concatenate_features(self, x: dict): x = torch.cat([x[key] for key in self.pick_keys if x[key] is not None], 1) return x - def forward(self, x: Dict, perturb: bool = True, return_input: bool = False): + def forward(self, x: dict, perturb: bool = True, return_input: bool = False): # (B, N, E) x = self._concatenate_features(x) mask = None @@ -160,7 +159,7 @@ def _init_loss_weights(self): def _setup_metrics(self): return None - def forward(self, x: Dict): + def forward(self, x: dict): if self.mode == "pretrain": x = self.embedding_layer(x) # (B, N, E) @@ -238,7 +237,7 @@ def calculate_loss(self, output, tag, sync_dist=False): def calculate_metrics(self, output, tag, sync_dist=False): pass - def featurize(self, x: Dict): + def featurize(self, x: dict): x = self.embedding_layer(x) return self.featurizer(x, perturb=False).features diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 2baebecb..70cfa09e 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -5,9 +5,9 @@ import re import warnings +from collections.abc import Iterable from enum import Enum from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union import joblib import numpy as np @@ -41,9 +41,9 @@ def __init__( self, data: DataFrame, task: str, - continuous_cols: List[str] = None, - categorical_cols: List[str] = None, - target: List[str] = None, + continuous_cols: list[str] = None, + categorical_cols: list[str] = None, + target: list[str] = None, ): """Dataset to Load Tabular Data. @@ -143,9 +143,9 @@ def __init__( train: DataFrame, config: DictConfig, validation: DataFrame = None, - target_transform: Optional[Union[TransformerMixin, Tuple]] = None, - train_sampler: Optional[torch.utils.data.Sampler] = None, - seed: Optional[int] = 42, + target_transform: TransformerMixin | tuple | None = None, + train_sampler: torch.utils.data.Sampler | None = None, + seed: int | None = 42, cache_data: str = "memory", copy_data: bool = True, verbose: bool = True, @@ -246,7 +246,7 @@ def target_transforms(self): def target_transforms(self, value): self._target_transforms = value - def _setup_cache(self, cache_data: Union[str, bool]) -> None: + def _setup_cache(self, cache_data: str | bool) -> None: cache_data = cache_data.lower() if cache_data == self.CACHE_MODES.MEMORY.value: self.cache_mode = self.CACHE_MODES.MEMORY @@ -258,7 +258,7 @@ def _setup_cache(self, cache_data: Union[str, bool]) -> None: logger.warning(f"{cache_data} is not a valid path. 
Caching in memory") self.cache_mode = self.CACHE_MODES.MEMORY - def _set_target_transform(self, target_transform: Union[TransformerMixin, Tuple]) -> None: + def _set_target_transform(self, target_transform: TransformerMixin | tuple) -> None: if target_transform is not None: if isinstance(target_transform, Iterable): target_transform = FunctionTransformer(func=target_transform[0], inverse_func=target_transform[1]) @@ -425,7 +425,7 @@ def _target_transform(self, data: DataFrame, stage: str) -> DataFrame: data[col] = _target_transform.transform(data[col].values.reshape(-1, 1)) return data - def preprocess_data(self, data: DataFrame, stage: str = "inference") -> Tuple[DataFrame, list]: + def preprocess_data(self, data: DataFrame, stage: str = "inference") -> tuple[DataFrame, list]: """The preprocessing, like Categorical Encoding, Normalization, etc. which any dataframe should undergo before feeding into the dataloder. @@ -512,7 +512,7 @@ def split_train_val(self, train): train = train[~train.index.isin(val_idx)] return train, validation - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: str | None = None) -> None: """Data Operations you want to perform on all GPUs, like train-test split, transformations, etc. This is called before accessing the dataloaders. @@ -554,7 +554,7 @@ def inference_only_copy(self): # adapted from gluonts @classmethod - def time_features_from_frequency_str(cls, freq_str: str) -> List[str]: + def time_features_from_frequency_str(cls, freq_str: str) -> list[str]: """Returns a list of time features that will be appropriate for the given frequency string. Args: @@ -700,7 +700,7 @@ def add_datepart( frequency: str, prefix: str = None, drop: bool = True, - ) -> Tuple[DataFrame, List[str]]: + ) -> tuple[DataFrame, list[str]]: """Helper function that adds columns relevant to a date in the column `field_name` of `df`. Args: @@ -802,7 +802,7 @@ def validation_dataset(self) -> TabularDataset: def validation_dataset(self, value): self._validation_dataset = value - def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: + def train_dataloader(self, batch_size: int | None = None) -> DataLoader: """Function that loads the train set. Args: @@ -822,7 +822,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: **self.config.dataloader_kwargs, ) - def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: + def val_dataloader(self, batch_size: int | None = None) -> DataLoader: """Function that loads the validation set. Args: @@ -854,7 +854,7 @@ def _prepare_inference_data(self, df: DataFrame) -> DataFrame: return df def prepare_inference_dataloader( - self, df: DataFrame, batch_size: Optional[int] = None, copy_df: bool = True + self, df: DataFrame, batch_size: int | None = None, copy_df: bool = True ) -> DataLoader: """Function that prepares and loads the new data. @@ -884,7 +884,7 @@ def prepare_inference_dataloader( **self.config.dataloader_kwargs, ) - def save_dataloader(self, path: Union[str, Path]) -> None: + def save_dataloader(self, path: str | Path) -> None: """Saves the dataloader to a path. Args: @@ -896,7 +896,7 @@ def save_dataloader(self, path: Union[str, Path]) -> None: joblib.dump(self, path) @classmethod - def load_datamodule(cls, path: Union[str, Path]): + def load_datamodule(cls, path: str | Path): """Loads a datamodule from a path. 
Args: @@ -917,14 +917,14 @@ def copy( self, train: DataFrame, validation: DataFrame = None, - target_transform: Optional[Union[TransformerMixin, Tuple]] = None, - train_sampler: Optional[torch.utils.data.Sampler] = None, - seed: Optional[int] = None, + target_transform: TransformerMixin | tuple | None = None, + train_sampler: torch.utils.data.Sampler | None = None, + seed: int | None = None, cache_data: str = None, copy_data: bool = None, verbose: bool = None, call_setup: bool = True, - config_override: Optional[Dict] = {}, + config_override: dict | None = {}, ): if config_override is not None: for k, v in config_override.items(): diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 5c402d1a..2b48d067 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -10,10 +10,10 @@ import uuid import warnings from collections import defaultdict +from collections.abc import Callable, Iterable from functools import partial from pathlib import Path from pprint import pformat, pprint -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import joblib import numpy as np @@ -75,14 +75,14 @@ class TabularModel: def __init__( self, - config: Optional[DictConfig] = None, - data_config: Optional[Union[DataConfig, str]] = None, - model_config: Optional[Union[ModelConfig, str]] = None, - optimizer_config: Optional[Union[OptimizerConfig, str]] = None, - trainer_config: Optional[Union[TrainerConfig, str]] = None, - experiment_config: Optional[Union[ExperimentConfig, str]] = None, - model_callable: Optional[Callable] = None, - model_state_dict_path: Optional[Union[str, Path]] = None, + config: DictConfig | None = None, + data_config: DataConfig | str | None = None, + model_config: ModelConfig | str | None = None, + optimizer_config: OptimizerConfig | str | None = None, + trainer_config: TrainerConfig | str | None = None, + experiment_config: ExperimentConfig | str | None = None, + model_callable: Callable | None = None, + model_state_dict_path: str | Path | None = None, verbose: bool = True, suppress_lightning_logger: bool = False, ) -> None: @@ -251,7 +251,7 @@ def _read_parse_config(self, config, cls): config = OmegaConf.structured(config) return config - def _get_run_name_uid(self) -> Tuple[str, int]: + def _get_run_name_uid(self) -> tuple[str, int]: """Gets the name of the experiment and increments version by 1. Returns: @@ -286,7 +286,7 @@ def _setup_experiment_tracking(self): f"{self.config.log_target} is not implemented. Try one of [wandb," " tensorboard]" ) - def _prepare_callbacks(self, callbacks=None) -> List: + def _prepare_callbacks(self, callbacks=None) -> list: """Prepares the necesary callbacks to the Trainer based on the configuration. Returns: @@ -325,7 +325,7 @@ def _prepare_callbacks(self, callbacks=None) -> List: logger.debug(f"Callbacks used: {callbacks}") return callbacks - def _prepare_trainer(self, callbacks: List, max_epochs: int = None, min_epochs: int = None) -> pl.Trainer: + def _prepare_trainer(self, callbacks: list, max_epochs: int = None, min_epochs: int = None) -> pl.Trainer: """Prepares the Trainer object. Args: @@ -387,7 +387,7 @@ def _prepare_for_training(self, model, datamodule, callbacks=None, max_epochs=No self.datamodule = datamodule @classmethod - def _load_weights(cls, model, path: Union[str, Path]) -> None: + def _load_weights(cls, model, path: str | Path) -> None: """Loads the model weights in the specified directory. 
Args: @@ -510,10 +510,10 @@ def load_model(cls, dir: str, map_location=None, strict=True): def prepare_dataloader( self, train: DataFrame, - validation: Optional[DataFrame] = None, - train_sampler: Optional[torch.utils.data.Sampler] = None, - target_transform: Optional[Union[TransformerMixin, Tuple]] = None, - seed: Optional[int] = 42, + validation: DataFrame | None = None, + train_sampler: torch.utils.data.Sampler | None = None, + target_transform: TransformerMixin | tuple | None = None, + seed: int | None = 42, cache_data: str = "memory", ) -> TabularDatamodule: """Prepares the dataloaders for training and validation. @@ -564,11 +564,11 @@ def prepare_dataloader( def prepare_model( self, datamodule: TabularDatamodule, - loss: Optional[torch.nn.Module] = None, - metrics: Optional[List[Callable]] = None, - metrics_prob_inputs: Optional[List[bool]] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - optimizer_params: Dict = None, + loss: torch.nn.Module | None = None, + metrics: list[Callable] | None = None, + metrics_prob_inputs: list[bool] | None = None, + optimizer: torch.optim.Optimizer | None = None, + optimizer_params: dict = None, ) -> BaseModel: """Prepares the model for training. @@ -619,7 +619,7 @@ def train( self, model: pl.LightningModule, datamodule: TabularDatamodule, - callbacks: Optional[List[pl.Callback]] = None, + callbacks: list[pl.Callback] | None = None, max_epochs: int = None, min_epochs: int = None, handle_oom: bool = True, @@ -694,20 +694,20 @@ def train( def fit( self, - train: Optional[DataFrame], - validation: Optional[DataFrame] = None, - loss: Optional[torch.nn.Module] = None, - metrics: Optional[List[Callable]] = None, - metrics_prob_inputs: Optional[List[bool]] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - optimizer_params: Dict = None, - train_sampler: Optional[torch.utils.data.Sampler] = None, - target_transform: Optional[Union[TransformerMixin, Tuple]] = None, - max_epochs: Optional[int] = None, - min_epochs: Optional[int] = None, - seed: Optional[int] = 42, - callbacks: Optional[List[pl.Callback]] = None, - datamodule: Optional[TabularDatamodule] = None, + train: DataFrame | None, + validation: DataFrame | None = None, + loss: torch.nn.Module | None = None, + metrics: list[Callable] | None = None, + metrics_prob_inputs: list[bool] | None = None, + optimizer: torch.optim.Optimizer | None = None, + optimizer_params: dict = None, + train_sampler: torch.utils.data.Sampler | None = None, + target_transform: TransformerMixin | tuple | None = None, + max_epochs: int | None = None, + min_epochs: int | None = None, + seed: int | None = 42, + callbacks: list[pl.Callback] | None = None, + datamodule: TabularDatamodule | None = None, cache_data: str = "memory", handle_oom: bool = True, ) -> pl.Trainer: @@ -808,16 +808,16 @@ def fit( def pretrain( self, - train: Optional[DataFrame], - validation: Optional[DataFrame] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - optimizer_params: Dict = None, + train: DataFrame | None, + validation: DataFrame | None = None, + optimizer: torch.optim.Optimizer | None = None, + optimizer_params: dict = None, # train_sampler: Optional[torch.utils.data.Sampler] = None, - max_epochs: Optional[int] = None, - min_epochs: Optional[int] = None, - seed: Optional[int] = 42, - callbacks: Optional[List[pl.Callback]] = None, - datamodule: Optional[TabularDatamodule] = None, + max_epochs: int | None = None, + min_epochs: int | None = None, + seed: int | None = 42, + callbacks: list[pl.Callback] | None = 
+        datamodule: TabularDatamodule | None = None,
         cache_data: str = "memory",
     ) -> pl.Trainer:
         """The pretrained method which takes in the data and triggers the training.
@@ -886,24 +886,24 @@ def create_finetune_model(
         self,
         task: str,
         head: str,
-        head_config: Dict,
+        head_config: dict,
         train: DataFrame,
-        validation: Optional[DataFrame] = None,
-        train_sampler: Optional[torch.utils.data.Sampler] = None,
-        target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
-        target: Optional[str] = None,
-        optimizer_config: Optional[OptimizerConfig] = None,
-        trainer_config: Optional[TrainerConfig] = None,
-        experiment_config: Optional[ExperimentConfig] = None,
-        loss: Optional[torch.nn.Module] = None,
-        metrics: Optional[List[Union[Callable, str]]] = None,
-        metrics_prob_input: Optional[List[bool]] = None,
-        metrics_params: Optional[Dict] = None,
-        optimizer: Optional[torch.optim.Optimizer] = None,
-        optimizer_params: Dict = None,
-        learning_rate: Optional[float] = None,
-        target_range: Optional[Tuple[float, float]] = None,
-        seed: Optional[int] = 42,
+        validation: DataFrame | None = None,
+        train_sampler: torch.utils.data.Sampler | None = None,
+        target_transform: TransformerMixin | tuple | None = None,
+        target: str | None = None,
+        optimizer_config: OptimizerConfig | None = None,
+        trainer_config: TrainerConfig | None = None,
+        experiment_config: ExperimentConfig | None = None,
+        loss: torch.nn.Module | None = None,
+        metrics: list[Callable | str] | None = None,
+        metrics_prob_input: list[bool] | None = None,
+        metrics_params: dict | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        optimizer_params: dict = None,
+        learning_rate: float | None = None,
+        target_range: tuple[float, float] | None = None,
+        seed: int | None = 42,
     ):
         """Creates a new TabularModel model using the pretrained weights and the new task and head.
@@ -1081,9 +1081,9 @@ def create_finetune_model(

     def finetune(
         self,
-        max_epochs: Optional[int] = None,
-        min_epochs: Optional[int] = None,
-        callbacks: Optional[List[pl.Callback]] = None,
+        max_epochs: int | None = None,
+        min_epochs: int | None = None,
+        callbacks: list[pl.Callback] | None = None,
         freeze_backbone: bool = False,
     ) -> pl.Trainer:
         """Finetunes the model on the provided data.
@@ -1126,10 +1126,10 @@ def find_learning_rate(
         max_lr: float = 1,
         num_training: int = 100,
         mode: str = "exponential",
-        early_stop_threshold: Optional[float] = 4.0,
+        early_stop_threshold: float | None = 4.0,
         plot: bool = True,
-        callbacks: Optional[List] = None,
-    ) -> Tuple[float, DataFrame]:
+        callbacks: list | None = None,
+    ) -> tuple[float, DataFrame]:
         """Enables the user to do a range test of good initial learning rates, to reduce the amount
         of guesswork in picking a good starting learning rate.
@@ -1186,11 +1186,11 @@ def find_learning_rate(

     def evaluate(
         self,
-        test: Optional[DataFrame] = None,
-        test_loader: Optional[torch.utils.data.DataLoader] = None,
-        ckpt_path: Optional[Union[str, Path]] = None,
+        test: DataFrame | None = None,
+        test_loader: torch.utils.data.DataLoader | None = None,
+        ckpt_path: str | Path | None = None,
         verbose: bool = True,
-    ) -> Union[dict, list]:
+    ) -> dict | list:
         """Evaluates the dataframe using the loss and metrics already set in config.
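# --- Editor's usage sketch (not part of the patch): a minimal run against the retyped
# fit()/predict() signatures shown above, assuming the public pytorch_tabular API
# (TabularModel, DataConfig, TrainerConfig, OptimizerConfig, CategoryEmbeddingModelConfig).
# The synthetic DataFrame and config values are placeholders, not project defaults.
import numpy as np
import pandas as pd
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

rng = np.random.default_rng(42)
df = pd.DataFrame({"f1": rng.normal(size=256), "f2": rng.normal(size=256)})
df["target"] = 2 * df["f1"] + rng.normal(scale=0.1, size=256)

model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["f1", "f2"]),
    model_config=CategoryEmbeddingModelConfig(task="regression"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=3, accelerator="cpu", devices=1),
)
model.fit(train=df, validation=None, seed=42)  # DataFrame | None, int | None
preds = model.predict(df)  # still returns a DataFrame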
         Args:
@@ -1334,12 +1334,12 @@ def _format_predicitons(
     def _predict(
         self,
         test: DataFrame,
-        quantiles: Optional[List] = [0.25, 0.5, 0.75],
-        n_samples: Optional[int] = 100,
+        quantiles: list | None = [0.25, 0.5, 0.75],
+        n_samples: int | None = 100,
         ret_logits=False,
         include_input_features: bool = False,
-        device: Optional[torch.device] = None,
-        progress_bar: Optional[str] = None,
+        device: torch.device | None = None,
+        progress_bar: str | None = None,
     ) -> DataFrame:
         """Uses the trained model to predict on new data and return as a dataframe.
@@ -1407,17 +1407,17 @@ def _predict(
     def predict(
         self,
         test: DataFrame,
-        quantiles: Optional[List] = [0.25, 0.5, 0.75],
-        n_samples: Optional[int] = 100,
+        quantiles: list | None = [0.25, 0.5, 0.75],
+        n_samples: int | None = 100,
         ret_logits=False,
         include_input_features: bool = False,
-        device: Optional[torch.device] = None,
-        progress_bar: Optional[str] = None,
-        test_time_augmentation: Optional[bool] = False,
-        num_tta: Optional[float] = 5,
-        alpha_tta: Optional[float] = 0.1,
-        aggregate_tta: Optional[str] = "mean",
-        tta_seed: Optional[int] = 42,
+        device: torch.device | None = None,
+        progress_bar: str | None = None,
+        test_time_augmentation: bool | None = False,
+        num_tta: float | None = 5,
+        alpha_tta: float | None = 0.1,
+        aggregate_tta: str | None = "mean",
+        tta_seed: int | None = 42,
     ) -> DataFrame:
         """Uses the trained model to predict on new data and return as a dataframe.
@@ -1595,7 +1595,7 @@ def save_model(self, dir: str, inference_only: bool = False) -> None:
             if self.custom_model:
                 joblib.dump(self.model_callable, os.path.join(dir, "custom_model_callable.sav"))

-    def save_weights(self, path: Union[str, Path]) -> None:
+    def save_weights(self, path: str | Path) -> None:
         """Saves the model weights in the specified directory.

         Args:
@@ -1604,7 +1604,7 @@ def save_weights(self, path: Union[str, Path]) -> None:
         """
         torch.save(self.model.state_dict(), path)

-    def load_weights(self, path: Union[str, Path]) -> None:
+    def load_weights(self, path: str | Path) -> None:
         """Loads the model weights in the specified directory.

         Args:
@@ -1616,9 +1616,9 @@ def load_weights(self, path: Union[str, Path]) -> None:
         """
     # TODO Need to test ONNX export
     def save_model_for_inference(
         self,
-        path: Union[str, Path],
+        path: str | Path,
         kind: str = "pytorch",
-        onnx_export_params: Dict = {"opset_version": 12},
+        onnx_export_params: dict = {"opset_version": 12},
     ) -> bool:
         """Saves the model for inference.
@@ -1897,7 +1897,7 @@ def feature_importance(self) -> DataFrame:
         """Returns the feature importance of the model as a pandas DataFrame."""
         return self.model.feature_importance()

-    def _prepare_input_for_captum(self, test_dl: torch.utils.data.DataLoader) -> Dict:
+    def _prepare_input_for_captum(self, test_dl: torch.utils.data.DataLoader) -> dict:
         tensor_inp = []
         tensor_tgt = []
         for x in test_dl:
@@ -1909,7 +1909,7 @@ def _prepare_input_for_captum(self, test_dl: torch.utils.data.DataLoader) -> Dic

     def _prepare_baselines_captum(
         self,
-        baselines: Union[float, torch.tensor, str],
+        baselines: float | torch.tensor | str,
         test_dl: torch.utils.data.DataLoader,
         do_baselines: bool,
         is_full_baselines: bool,
@@ -1970,8 +1970,8 @@ def explain(
         self,
         data: DataFrame,
         method: str = "GradientShap",
-        method_args: Optional[Dict] = {},
-        baselines: Union[float, torch.tensor, str] = None,
+        method_args: dict | None = {},
+        baselines: float | torch.tensor | str = None,
         **kwargs,
     ) -> DataFrame:
         """Returns the feature attributions/explanations of the model as a pandas DataFrame.
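# --- Editor's sketch (not part of the patch): assumes the fitted `model` from the
# previous sketch. save_weights()/load_weights() above now advertise `str | Path`,
# so both spellings round-trip; the path below is a hypothetical example.
from pathlib import Path

weights_path = Path("artifacts") / "tabular_weights.pt"
weights_path.parent.mkdir(parents=True, exist_ok=True)
model.save_weights(weights_path)        # pathlib.Path is accepted
model.load_weights(str(weights_path))   # a plain string works as well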
         The shape of the returned
@@ -2134,11 +2134,11 @@ def _split_kwargs(self, kwargs):

     def cross_validate(
         self,
-        cv: Optional[Union[int, Iterable, BaseCrossValidator]],
+        cv: int | Iterable | BaseCrossValidator | None,
         train: DataFrame,
-        metric: Optional[Union[str, Callable]] = None,
+        metric: str | Callable | None = None,
         return_oof: bool = False,
-        groups: Optional[Union[str, np.ndarray]] = None,
+        groups: str | np.ndarray | None = None,
         verbose: bool = True,
         reset_datamodule: bool = True,
         handle_oom: bool = True,
@@ -2251,10 +2251,10 @@ def cross_validate(

     def _combine_predictions(
         self,
-        pred_prob_l: List[DataFrame],
-        pred_idx: Union[pd.Index, List],
-        aggregate: Union[str, Callable],
-        weights: Optional[List[float]] = None,
+        pred_prob_l: list[DataFrame],
+        pred_idx: pd.Index | list,
+        aggregate: str | Callable,
+        weights: list[float] | None = None,
     ):
         if aggregate == "mean":
             bagged_pred = np.average(pred_prob_l, axis=0, weights=weights)
@@ -2300,15 +2300,15 @@ def _combine_predictions(

     def bagging_predict(
         self,
-        cv: Optional[Union[int, Iterable, BaseCrossValidator]],
+        cv: int | Iterable | BaseCrossValidator | None,
         train: DataFrame,
         test: DataFrame,
-        groups: Optional[Union[str, np.ndarray]] = None,
+        groups: str | np.ndarray | None = None,
         verbose: bool = True,
         reset_datamodule: bool = True,
         return_raw_predictions: bool = False,
-        aggregate: Union[str, Callable] = "mean",
-        weights: Optional[List[float]] = None,
+        aggregate: str | Callable = "mean",
+        weights: list[float] | None = None,
         handle_oom: bool = True,
         **kwargs,
     ):
diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py
index fc97140e..4ecd56eb 100644
--- a/src/pytorch_tabular/tabular_model_sweep.py
+++ b/src/pytorch_tabular/tabular_model_sweep.py
@@ -1,8 +1,8 @@
 import copy
 import time
 import warnings
+from collections.abc import Callable
 from contextlib import nullcontext
-from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -98,17 +98,17 @@ def _validate_args(
     task: str,
     train: pd.DataFrame,
     test: pd.DataFrame,
-    data_config: Union[DataConfig, str],
-    optimizer_config: Union[OptimizerConfig, str],
-    trainer_config: Union[TrainerConfig, str],
-    model_list: Union[str, List[Union[ModelConfig, str]]] = "lite",
-    metrics: Optional[List[Union[str, Callable]]] = None,
-    metrics_params: Optional[List[dict]] = None,
-    metrics_prob_input: Optional[List[bool]] = None,
-    validation: Optional[pd.DataFrame] = None,
-    experiment_config: Optional[Union[ExperimentConfig, str]] = None,
-    common_model_args: Optional[dict] = {},
-    rank_metric: Optional[str] = "loss",
+    data_config: DataConfig | str,
+    optimizer_config: OptimizerConfig | str,
+    trainer_config: TrainerConfig | str,
+    model_list: str | list[ModelConfig | str] = "lite",
+    metrics: list[str | Callable] | None = None,
+    metrics_params: list[dict] | None = None,
+    metrics_prob_input: list[bool] | None = None,
+    validation: pd.DataFrame | None = None,
+    experiment_config: ExperimentConfig | str | None = None,
+    common_model_args: dict | None = {},
+    rank_metric: str | None = "loss",
 ):
     assert task in [
         "classification",
@@ -172,17 +172,17 @@ def model_sweep(
     task: str,
     train: pd.DataFrame,
     test: pd.DataFrame,
-    data_config: Union[DataConfig, str],
-    optimizer_config: Union[OptimizerConfig, str],
-    trainer_config: Union[TrainerConfig, str],
-    model_list: Union[str, List[Union[ModelConfig, str]]] = "lite",
-    metrics: Optional[List[Union[str, Callable]]] = None,
-    metrics_params: Optional[List[dict]] = None,
-    metrics_prob_input: Optional[List[bool]] = None,
-    validation: Optional[pd.DataFrame] = None,
-    experiment_config: Optional[Union[ExperimentConfig, str]] = None,
-    common_model_args: Optional[dict] = {},
-    rank_metric: Optional[Tuple[str, str]] = ("loss", "lower_is_better"),
+    data_config: DataConfig | str,
+    optimizer_config: OptimizerConfig | str,
+    trainer_config: TrainerConfig | str,
+    model_list: str | list[ModelConfig | str] = "lite",
+    metrics: list[str | Callable] | None = None,
+    metrics_params: list[dict] | None = None,
+    metrics_prob_input: list[bool] | None = None,
+    validation: pd.DataFrame | None = None,
+    experiment_config: ExperimentConfig | str | None = None,
+    common_model_args: dict | None = {},
+    rank_metric: tuple[str, str] | None = ("loss", "lower_is_better"),
     return_best_model: bool = True,
     seed: int = 42,
     ignore_oom: bool = True,
@@ -296,7 +296,7 @@ def model_sweep(
         model_list = [
             (
                 getattr(models, model_config[0])(task=task, **model_config[1], **common_model_args)
-                if isinstance(model_config, Tuple)
+                if isinstance(model_config, tuple)
                 else (
                     getattr(models, model_config)(task=task, **common_model_args)
                     if isinstance(model_config, str)
diff --git a/src/pytorch_tabular/tabular_model_tuner.py b/src/pytorch_tabular/tabular_model_tuner.py
index d199d1fb..2b21f0e9 100644
--- a/src/pytorch_tabular/tabular_model_tuner.py
+++ b/src/pytorch_tabular/tabular_model_tuner.py
@@ -5,9 +5,9 @@
 import warnings
 from collections import namedtuple
+from collections.abc import Callable, Iterable
 from copy import deepcopy
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Union

 import numpy as np
 import pandas as pd
@@ -23,7 +23,12 @@
     TrainerConfig,
 )
 from pytorch_tabular.tabular_model import TabularModel
-from pytorch_tabular.utils import OOMException, OutOfMemoryHandler, get_logger, suppress_lightning_logs
+from pytorch_tabular.utils import (
+    OOMException,
+    OutOfMemoryHandler,
+    get_logger,
+    suppress_lightning_logs,
+)

 logger = get_logger(__name__)
@@ -41,12 +46,12 @@ class TabularModelTuner:

     def __init__(
         self,
-        data_config: Optional[Union[DataConfig, str]] = None,
-        model_config: Optional[Union[ModelConfig, str]] = None,
-        optimizer_config: Optional[Union[OptimizerConfig, str]] = None,
-        trainer_config: Optional[Union[TrainerConfig, List[TrainerConfig]]] = None,
-        model_callable: Optional[Callable] = None,
-        model_state_dict_path: Optional[Union[str, Path]] = None,
+        data_config: DataConfig | str | None = None,
+        model_config: ModelConfig | str | None = None,
+        optimizer_config: OptimizerConfig | str | None = None,
+        trainer_config: TrainerConfig | list[TrainerConfig] | None = None,
+        model_callable: Callable | None = None,
+        model_state_dict_path: str | Path | None = None,
         suppress_lightning_logger: bool = True,
         **kwargs,
     ):
@@ -125,7 +130,7 @@ def _update_configs(
         self,
         optimizer_config: OptimizerConfig,
         model_config: ModelConfig,
-        params: Dict,
+        params: dict,
     ):
         """Update the configs with the new parameters."""
         # update configs with the new parameters
@@ -156,19 +161,19 @@ def _update_configs(

     def tune(
         self,
         train: DataFrame,
-        search_space: Union[Dict, List[Dict]],
-        metric: Union[str, Callable],
+        search_space: dict | list[dict],
+        metric: str | Callable,
         mode: str,
         strategy: str,
-        validation: Optional[DataFrame] = None,
-        n_trials: Optional[int] = None,
-        cv: Optional[Union[int, Iterable, BaseCrossValidator]] = None,
-        cv_agg_func: Optional[Callable] = np.mean,
-        cv_kwargs: Optional[Dict] = {},
+        validation: DataFrame | None = None,
+        n_trials: int | None = None,
+        cv: int | Iterable | BaseCrossValidator | None = None,
+        cv_agg_func: Callable | None = np.mean,
+        cv_kwargs: dict | None = {},
         return_best_model: bool = True,
         verbose: bool = False,
         progress_bar: bool = True,
-        random_state: Optional[int] = 42,
+        random_state: int | None = 42,
         ignore_oom: bool = True,
         **kwargs,
     ):
diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py
index e08503ed..3f323b54 100644
--- a/src/pytorch_tabular/utils/python_utils.py
+++ b/src/pytorch_tabular/utils/python_utils.py
@@ -1,7 +1,8 @@
 import math
 import textwrap
+from collections.abc import Callable
 from pathlib import Path
-from typing import IO, Any, Callable, Dict, Optional, Union
+from typing import IO, Any

 import numpy as np
 import torch
@@ -15,9 +16,9 @@
 from .logger import get_logger

-_PATH = Union[str, Path]
-_DEVICE = Union[torch.device, str, int]
-_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
+_PATH = str | Path
+_DEVICE = torch.device | str | int
+_MAP_LOCATION_TYPE = _DEVICE | Callable[[_DEVICE], _DEVICE] | dict[_DEVICE, _DEVICE] | None

 logger = get_logger(__name__)
@@ -62,7 +63,7 @@ def generate_doc_dataclass(dataclass, desc=None, width=100):

 # Copied over pytorch_lightning.utilities.cloud_io.load as it was deprecated
 def pl_load(
-    path_or_url: Union[IO, _PATH],
+    path_or_url: IO | _PATH,
     map_location: _MAP_LOCATION_TYPE = None,
 ) -> Any:
     """Loads a checkpoint.
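# --- Editor's sketch (illustrative, not part of the patch) ---
# The module-level aliases rewritten above (_PATH, _DEVICE, _MAP_LOCATION_TYPE) become
# plain `X | Y` unions. On Python >= 3.10 such unions are types.UnionType objects, so
# they can also serve as isinstance() targets; the names below mirror the patched
# python_utils.py, while normalize_device() is a hypothetical helper for demonstration.
from pathlib import Path

import torch

_PATH = str | Path
_DEVICE = torch.device | str | int


def normalize_device(device: _DEVICE) -> torch.device:
    # isinstance() accepts PEP 604 unions directly on 3.10+
    if isinstance(device, str | int):
        return torch.device(device)
    return device


assert isinstance("checkpoints/last.ckpt", _PATH)
assert normalize_device("cpu").type == "cpu"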