diff --git a/lightgbmlss/distributions/LogitNormal.py b/lightgbmlss/distributions/LogitNormal.py new file mode 100644 index 0000000..0d99b73 --- /dev/null +++ b/lightgbmlss/distributions/LogitNormal.py @@ -0,0 +1,85 @@ +from torch.distributions import Normal, TransformedDistribution, SigmoidTransform +from .distribution_utils import DistributionClass +from ..utils import * + + +class LogitNormal(DistributionClass): + """ + Logit-Normal distribution class. + + Distributional Parameters + ------------------------- + loc: torch.Tensor + Mean of the normal distribution before applying the logit transformation. + scale: torch.Tensor + Standard deviation of the normal distribution before applying the logit transformation. + + Source + ------------------------- + https://pytorch.org/docs/stable/distributions.html#normal + + Parameters + ------------------------- + stabilization: str + Stabilization method for the Gradient and Hessian. Options are "None", "MAD", "L2". + response_fn: str + Response function for transforming the distributional parameters to the correct support. Options are + "identity" (no transformation) or "softplus" (softplus to ensure positivity). + loss_fn: str + Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score). + Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable. + initialize: bool + Whether to initialize the distributional parameters with unconditional start values. Initialization can help + to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal + solutions if the unconditional start values are far from the optimal values. + """ + + def __init__(self, + stabilization: str = "None", + response_fn: str = "identity", + loss_fn: str = "nll", + initialize: bool = False, + ): + + # Input Checks + if stabilization not in ["None", "MAD", "L2"]: + raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.") + if loss_fn not in ["nll", "crps"]: + raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.") + if not isinstance(initialize, bool): + raise ValueError("Invalid initialize. Please choose from True or False.") + + # Specify Response Functions + response_functions = {"identity": identity_fn, "softplus": softplus_fn} + if response_fn in response_functions: + response_fn = response_functions[response_fn] + else: + raise ValueError("Invalid response function. Please choose from 'identity' or 'softplus'.") + + # Define Logit-Normal as a transformed distribution + base_distribution = Normal + + # Create a proper class instead of lambda to have arg_constraints + class LogitNormalDistribution(TransformedDistribution): + arg_constraints = base_distribution.arg_constraints + + def __init__(self, loc, scale): + super().__init__(base_distribution(loc, scale), [SigmoidTransform()]) + + transformed_distribution = LogitNormalDistribution + + # Define Parameter Mapping + param_dict = {"loc": identity_fn, "scale": response_fn} + torch.distributions.Distribution.set_default_validate_args(False) + + # Specify Distribution Class + super().__init__(distribution=transformed_distribution, + univariate=True, + discrete=False, + n_dist_param=len(param_dict), + stabilization=stabilization, + param_dict=param_dict, + distribution_arg_names=list(param_dict.keys()), + loss_fn=loss_fn, + initialize=initialize, + ) \ No newline at end of file diff --git a/lightgbmlss/distributions/__init__.py b/lightgbmlss/distributions/__init__.py index af15802..f1af70a 100644 --- a/lightgbmlss/distributions/__init__.py +++ b/lightgbmlss/distributions/__init__.py @@ -23,4 +23,5 @@ from . import ZALN from . import SplineFlow from . import Mixture -from . import Logistic \ No newline at end of file +from . import Logistic +from . import LogitNormal \ No newline at end of file