Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ exclude = [
"dist",
"docs"
]
ignore-init-module-imports = true

[tool.ruff.per-file-ignores]

[tool.ruff.lint.per-file-ignores]
"setup.py" = ["D100", "SIM115"]
"__about__.py" = ["D100"]
"__init__.py" = ["D100"]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
# Use numpy-style docstrings.
convention = "numpy"

Expand Down
91 changes: 46 additions & 45 deletions src/pytorch_tabular/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import os
import re
from collections.abc import Iterable
from dataclasses import MISSING, dataclass, field
from typing import Any, Dict, Iterable, List, Optional
from typing import Any

from omegaconf import OmegaConf

Expand Down Expand Up @@ -103,22 +104,22 @@ class DataConfig:

"""

target: Optional[List[str]] = field(
target: list[str] | None = field(
default=None,
metadata={
"help": "A list of strings with the names of the target column(s)."
" It is mandatory for all except SSL tasks."
},
)
continuous_cols: List = field(
continuous_cols: list = field(
default_factory=list,
metadata={"help": "Column names of the numeric fields. Defaults to []"},
)
categorical_cols: List = field(
categorical_cols: list = field(
default_factory=list,
metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"},
)
date_columns: List = field(
date_columns: list = field(
default_factory=list,
metadata={
"help": "(Column names, Freq) tuples of the date fields. For eg. a field named"
Expand All @@ -131,14 +132,14 @@ class DataConfig:
default=True,
metadata={"help": "Whether or not to encode the derived variables from date"},
)
validation_split: Optional[float] = field(
validation_split: float | None = field(
default=0.2,
metadata={
"help": "Percentage of Training rows to keep aside as validation."
" Used only if Validation Data is not given separately"
},
)
continuous_feature_transform: Optional[str] = field(
continuous_feature_transform: str | None = field(
default=None,
metadata={
"help": "Whether or not to transform the features before modelling. By default it is turned off.",
Expand All @@ -164,7 +165,7 @@ class DataConfig:
" the noise is only applied for QuantileTransformer"
},
)
num_workers: Optional[int] = field(
num_workers: int | None = field(
default=0,
metadata={"help": "The number of workers used for data loading. For windows always set to 0"},
)
Expand All @@ -186,7 +187,7 @@ class DataConfig:
metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
)

dataloader_kwargs: Dict[str, Any] = field(
dataloader_kwargs: dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
)
Expand Down Expand Up @@ -229,19 +230,19 @@ class InferredConfig:
continuous_dim: int = field(
metadata={"help": "The number of continuous features"},
)
output_dim: Optional[int] = field(
output_dim: int | None = field(
default=None,
metadata={"help": "The number of output targets"},
)
output_cardinality: Optional[List[int]] = field(
output_cardinality: list[int] | None = field(
default=None,
metadata={"help": "The number of unique values in classification output"},
)
categorical_cardinality: Optional[List[int]] = field(
categorical_cardinality: list[int] | None = field(
default=None,
metadata={"help": "The number of unique values in categorical features"},
)
embedding_dims: Optional[List] = field(
embedding_dims: list | None = field(
default=None,
metadata={
"help": "The dimensions of the embedding for each categorical column as a list of tuples "
Expand Down Expand Up @@ -384,30 +385,30 @@ class TrainerConfig:
},
)
max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"})
min_epochs: Optional[int] = field(
min_epochs: int | None = field(
default=1,
metadata={"help": "Force training for at least these many epochs. 1 by default"},
)
max_time: Optional[int] = field(
max_time: int | None = field(
default=None,
metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"},
)
accelerator: Optional[str] = field(
accelerator: str | None = field(
default="auto",
metadata={
"help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','auto'."
" Defaults to 'auto'",
"choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
},
)
devices: Optional[int] = field(
devices: int | None = field(
default=-1,
metadata={
"help": "Number of devices to train on. -1 uses all available devices."
" By default uses all available devices (-1)",
},
)
devices_list: Optional[List[int]] = field(
devices_list: list[int] | None = field(
default=None,
metadata={
"help": "List of devices to train on (list). If specified, takes precedence over `devices` argument."
Expand Down Expand Up @@ -454,15 +455,15 @@ class TrainerConfig:
"help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
},
)
profiler: Optional[str] = field(
profiler: str | None = field(
default=None,
metadata={
"help": "To profile individual steps during training and assist in identifying bottlenecks."
" None, simple or advanced, pytorch",
"choices": [None, "simple", "advanced", "pytorch"],
},
)
early_stopping: Optional[str] = field(
early_stopping: str | None = field(
default="valid_loss",
metadata={
"help": "The loss/metric that needed to be monitored for early stopping."
Expand All @@ -484,14 +485,14 @@ class TrainerConfig:
default=3,
metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
)
early_stopping_kwargs: Optional[Dict[str, Any]] = field(
early_stopping_kwargs: dict[str, Any] | None = field(
default_factory=lambda: {},
metadata={
"help": "Additional keyword arguments for the early stopping callback."
" See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
},
)
checkpoints: Optional[str] = field(
checkpoints: str | None = field(
default="valid_loss",
metadata={
"help": "The loss/metric that needed to be monitored for checkpoints. If None, there will be no checkpoints"
Expand All @@ -505,7 +506,7 @@ class TrainerConfig:
default=1,
metadata={"help": "Number of training steps between checkpoints"},
)
checkpoints_name: Optional[str] = field(
checkpoints_name: str | None = field(
default=None,
metadata={
"help": "The name under which the models will be saved. If left blank,"
Expand All @@ -521,7 +522,7 @@ class TrainerConfig:
default=1,
metadata={"help": "The number of best models to save"},
)
checkpoints_kwargs: Optional[Dict[str, Any]] = field(
checkpoints_kwargs: dict[str, Any] | None = field(
default_factory=lambda: {},
metadata={
"help": "Additional keyword arguments for the checkpoints callback. See the documentation"
Expand Down Expand Up @@ -553,7 +554,7 @@ class TrainerConfig:
default=42,
metadata={"help": "Seed for random number generators. Defaults to 42"},
)
trainer_kwargs: Dict[str, Any] = field(
trainer_kwargs: dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."},
)
Expand Down Expand Up @@ -611,14 +612,14 @@ class ExperimentConfig:
},
)

run_name: Optional[str] = field(
run_name: str | None = field(
default=None,
metadata={
"help": "The name of the run; a specific identifier to recognize the run."
" If left blank, will be assigned a auto-generated name"
},
)
exp_watch: Optional[str] = field(
exp_watch: str | None = field(
default=None,
metadata={
"help": "The level of logging required. Can be `gradients`, `parameters`, `all` or `None`."
Expand Down Expand Up @@ -690,29 +691,29 @@ class OptimizerConfig:
" for example 'torch_optimizer.RAdam'."
},
)
optimizer_params: Dict = field(
optimizer_params: dict = field(
default_factory=lambda: {},
metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."},
)
lr_scheduler: Optional[str] = field(
lr_scheduler: str | None = field(
default=None,
metadata={
"help": "The name of the LearningRateScheduler to use, if any, from"
" https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
" If None, will not use any scheduler. Defaults to `None`",
},
)
lr_scheduler_params: Optional[Dict] = field(
lr_scheduler_params: dict | None = field(
default_factory=lambda: {},
metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."},
)

lr_scheduler_monitor_metric: Optional[str] = field(
lr_scheduler_monitor_metric: str | None = field(
default="valid_loss",
metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
)

lr_scheduler_interval: Optional[str] = field(
lr_scheduler_interval: str | None = field(
default="epoch",
metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
)
Expand Down Expand Up @@ -823,7 +824,7 @@ class ModelConfig:
}
)

head: Optional[str] = field(
head: str | None = field(
default="LinearHead",
metadata={
"help": "The head to be used for the model. Should be one of the heads defined"
Expand All @@ -832,14 +833,14 @@ class ModelConfig:
},
)

head_config: Optional[Dict] = field(
head_config: dict | None = field(
default_factory=lambda: {"layers": ""},
metadata={
"help": "The config as a dict which defines the head."
" If left empty, will be initialized as default linear head."
},
)
embedding_dims: Optional[List] = field(
embedding_dims: list | None = field(
default=None,
metadata={
"help": "The dimensions of the embedding for each categorical column as a list of tuples "
Expand All @@ -860,15 +861,15 @@ class ModelConfig:
default=1e-3,
metadata={"help": "The learning rate of the model. Defaults to 1e-3."},
)
loss: Optional[str] = field(
loss: str | None = field(
default=None,
metadata={
"help": "The loss function to be applied. By Default it is MSELoss for regression "
"and CrossEntropyLoss for classification. Unless you are sure what you are doing, "
"leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification"
},
)
metrics: Optional[List[str]] = field(
metrics: list[str] | None = field(
default=None,
metadata={
"help": "the list of metrics you need to track during training. The metrics should be one "
Expand All @@ -877,23 +878,23 @@ class ModelConfig:
"and mean_squared_error for regression"
},
)
metrics_prob_input: Optional[List[bool]] = field(
metrics_prob_input: list[bool] | None = field(
default=None,
metadata={
"help": "Is a mandatory parameter for classification metrics defined in the config. This defines "
"whether the input to the metric function is the probability or the class. Length should be same "
"as the number of metrics. Defaults to None."
},
)
metrics_params: Optional[List] = field(
metrics_params: list | None = field(
default=None,
metadata={
"help": "The parameters to be passed to the metrics function. `task` is forced to be `multiclass`` "
"because the multiclass version can handle binary as well and for simplicity we are only using "
"`multiclass`."
},
)
target_range: Optional[List] = field(
target_range: list | None = field(
default=None,
metadata={
"help": "The range in which we should limit the output variable. "
Expand All @@ -902,7 +903,7 @@ class ModelConfig:
},
)

virtual_batch_size: Optional[int] = field(
virtual_batch_size: int | None = field(
default=None,
metadata={
"help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
Expand Down Expand Up @@ -1001,23 +1002,23 @@ class SSLModelConfig:

task: str = field(init=False, default="ssl")

encoder_config: Optional[ModelConfig] = field(
encoder_config: ModelConfig | None = field(
default=None,
metadata={
"help": "The config of the encoder to be used for the model."
" Should be one of the model configs defined in PyTorch Tabular",
},
)

decoder_config: Optional[ModelConfig] = field(
decoder_config: ModelConfig | None = field(
default=None,
metadata={
"help": "The config of decoder to be used for the model."
" Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity",
},
)

embedding_dims: Optional[List] = field(
embedding_dims: list | None = field(
default=None,
metadata={
"help": "The dimensions of the embedding for each categorical column as a list of tuples "
Expand All @@ -1033,7 +1034,7 @@ class SSLModelConfig:
default=True,
metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
)
virtual_batch_size: Optional[int] = field(
virtual_batch_size: int | None = field(
default=None,
metadata={
"help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
Expand Down
Loading