54 changes: 54 additions & 0 deletions lightgbmlss/distributions/Cauchy.py
@@ -1,6 +1,7 @@
from torch.distributions import Cauchy as Cauchy_Torch
from .distribution_utils import DistributionClass
from ..utils import *
from typing import List


class Cauchy(DistributionClass):
@@ -33,12 +34,15 @@ class Cauchy(DistributionClass):
Whether to initialize the distributional parameters with unconditional start values. Initialization can help
to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal
solutions if the unconditional start values are far from the optimal values.
natural_gradient: bool
Whether to use natural gradient descent for optimization.
"""
def __init__(self,
stabilization: str = "None",
response_fn: str = "exp",
loss_fn: str = "nll",
initialize: bool = False,
natural_gradient: bool = False,
):

# Input Checks
@@ -72,4 +76,54 @@ def __init__(self,
distribution_arg_names=list(param_dict.keys()),
loss_fn=loss_fn,
initialize=initialize,
natural_gradient=natural_gradient,
)

def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Compute Fisher Information Matrix diagonal for Cauchy distribution.

For the Cauchy distribution with parameters (μ, σ):
- Fisher Information w.r.t. μ: I(μ) = 1/(2σ²)
- Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ):
I(η_σ) = (1/(2σ²)) * (g'(η_σ))²

Parameters
----------
predt : List[torch.Tensor]
[eta_mu, eta_sigma] - raw parameters before response functions

Returns
-------
fim : List[torch.Tensor]
[FIM_mu, FIM_sigma]
"""
eta_mu, eta_sigma = predt[0], predt[1]

# Apply response functions
response_fn_sigma = self.param_dict["scale"]
sigma = response_fn_sigma(eta_sigma)

# FIM for μ (location parameter uses identity)
fim_mu = 1.0 / (2.0 * sigma ** 2 + 1e-12)

# FIM for σ natural parameter - optimize for exp case
if response_fn_sigma == exp_fn:
# For exp: g'(η) = σ, so I(η_σ) = 1/2
fim_sigma = torch.ones_like(eta_sigma) * 0.5
else:
# For other response functions: compute derivative
eta_sigma_grad = eta_sigma.detach().requires_grad_(True)
sigma_grad = response_fn_sigma(eta_sigma_grad)

g_prime = torch.autograd.grad(
outputs=sigma_grad.sum(),
inputs=eta_sigma_grad,
create_graph=False,
retain_graph=False
)[0]

# Fisher Information: I(η_σ) = (1/(2σ²)) * (g'(η))²
fim_sigma = (1.0 / (2.0 * sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2)

return [fim_mu, fim_sigma]
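
The reparameterized FIM above follows from the chain rule: if σ = g(η_σ), then I(η_σ) = I(σ) · (g'(η_σ))². A minimal standalone sketch (not part of the diff; torch.exp stands in for the library's exp response function) checks that the autograd fallback reproduces the closed-form exp-case value of 1/2:

import torch

eta_sigma = torch.tensor([-0.5, 0.0, 1.2])
response_fn = torch.exp                      # stand-in for the exp response function
sigma = response_fn(eta_sigma)

# Autograd branch: g'(eta) via a single backward pass through the response function
eta = eta_sigma.detach().requires_grad_(True)
g_prime = torch.autograd.grad(response_fn(eta).sum(), eta)[0]

fim_autograd = (1.0 / (2.0 * sigma.detach() ** 2 + 1e-12)) * g_prime ** 2
fim_closed_form = torch.full_like(eta_sigma, 0.5)

assert torch.allclose(fim_autograd, fim_closed_form)   # both branches give 1/2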
57 changes: 53 additions & 4 deletions lightgbmlss/distributions/Gaussian.py
@@ -1,6 +1,8 @@
import torch
from torch.distributions import Normal as Gaussian_Torch
from .distribution_utils import DistributionClass
from ..utils import *
from typing import List


class Gaussian(DistributionClass):
@@ -29,16 +31,13 @@ class Gaussian(DistributionClass):
Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score).
Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable.
Hence, using the CRPS disregards any variation in the curvature of the loss function.
initialize: bool
Whether to initialize the distributional parameters with unconditional start values. Initialization can help
to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal
solutions if the unconditional start values are far from the optimal values.
"""
def __init__(self,
stabilization: str = "None",
response_fn: str = "exp",
loss_fn: str = "nll",
initialize: bool = False,
natural_gradient: bool = False,
):

# Input Checks
@@ -72,4 +71,54 @@ def __init__(self,
distribution_arg_names=list(param_dict.keys()),
loss_fn=loss_fn,
initialize=initialize,
natural_gradient=natural_gradient,
)

def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Compute Fisher Information Matrix diagonal for Gaussian distribution.

For Gaussian N(μ, σ²):
- Fisher Information w.r.t. μ: I(μ) = 1/σ²
- Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ):
I(η_σ) = (2/σ²) * (g'(η_σ))²

Parameters
----------
predt : List[torch.Tensor]
[eta_mu, eta_sigma] - raw parameters before response functions

Returns
-------
fim : List[torch.Tensor]
[FIM_mu, FIM_sigma]
"""
eta_mu, eta_sigma = predt[0], predt[1]

# Apply response functions
response_fn_sigma = self.param_dict["scale"]
sigma = response_fn_sigma(eta_sigma)

# FIM for μ (location parameter uses identity)
fim_mu = 1.0 / (sigma ** 2 + 1e-12)

# FIM for σ natural parameter - optimize for exp case
if response_fn_sigma == exp_fn:
# For exp: g'(η) = σ, so I(η_σ) = 2
fim_sigma = torch.ones_like(eta_sigma) * 2.0
else:
# For other response functions: compute derivative
eta_sigma_grad = eta_sigma.detach().requires_grad_(True)
sigma_grad = response_fn_sigma(eta_sigma_grad)

g_prime = torch.autograd.grad(
outputs=sigma_grad.sum(),
inputs=eta_sigma_grad,
create_graph=False,
retain_graph=False
)[0]

# Fisher Information: I(η_σ) = (2/σ²) * (g'(η))²
fim_sigma = (2.0 / (sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2)

return [fim_mu, fim_sigma]
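
How the returned diagonal is consumed is not part of this diff (that wiring lives in distribution_utils), but conceptually a natural-gradient step preconditions each raw gradient by the inverse FIM. A purely illustrative sketch with made-up gradient values, assuming an element-wise update for a diagonal FIM:

import torch

grad_mu = torch.tensor([0.3, -1.1])      # hypothetical d(nll)/d(eta_mu)
grad_sigma = torch.tensor([0.8, 0.2])    # hypothetical d(nll)/d(eta_sigma)
fim_mu = torch.tensor([4.0, 0.25])       # 1/sigma^2 for sigma = 0.5 and 2.0
fim_sigma = torch.full((2,), 2.0)        # exp response: I(eta_sigma) = 2

# Natural gradient = FIM^{-1} * gradient, element-wise for a diagonal FIM
nat_grad_mu = grad_mu / (fim_mu + 1e-12)
nat_grad_sigma = grad_sigma / (fim_sigma + 1e-12)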
60 changes: 60 additions & 0 deletions lightgbmlss/distributions/Gumbel.py
@@ -1,6 +1,8 @@
from torch.distributions import Gumbel as Gumbel_Torch
from .distribution_utils import DistributionClass
from ..utils import *
from typing import List
import math


class Gumbel(DistributionClass):
@@ -39,6 +41,7 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll",
initialize: bool = False,
natural_gradient: bool = False,
):

# Input Checks
@@ -72,4 +75,61 @@ def __init__(self,
distribution_arg_names=list(param_dict.keys()),
loss_fn=loss_fn,
initialize=initialize,
natural_gradient=natural_gradient,
)

def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Compute Fisher Information Matrix diagonal for Gumbel distribution.

For the Gumbel distribution with parameters (μ, σ):
- Fisher Information w.r.t. μ: I(μ) = 1/σ²
- Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ):
I(η_σ) = [(γ-1)² + π²/6] / σ² * (g'(η_σ))²
where γ ≈ 0.5772156649 is the Euler-Mascheroni constant

Parameters
----------
predt : List[torch.Tensor]
[eta_mu, eta_sigma] - raw parameters before response functions

Returns
-------
fim : List[torch.Tensor]
[FIM_mu, FIM_sigma]
"""
eta_mu, eta_sigma = predt[0], predt[1]

# Apply response functions
response_fn_sigma = self.param_dict["scale"]
sigma = response_fn_sigma(eta_sigma)

# FIM for μ (location parameter uses identity)
fim_mu = 1.0 / (sigma ** 2 + 1e-12)

# Euler-Mascheroni constant
euler_gamma = 0.5772156649015329

# Constant factor: (γ-1)² + π²/6
constant_factor = (euler_gamma - 1.0) ** 2 + (math.pi ** 2) / 6.0

# FIM for σ natural parameter - optimize for exp case
if response_fn_sigma == exp_fn:
# For exp: g'(η) = σ, so I(η_σ) = (γ-1)² + π²/6
fim_sigma = torch.ones_like(eta_sigma) * constant_factor
else:
# For other response functions: compute derivative
eta_sigma_grad = eta_sigma.detach().requires_grad_(True)
sigma_grad = response_fn_sigma(eta_sigma_grad)

g_prime = torch.autograd.grad(
outputs=sigma_grad.sum(),
inputs=eta_sigma_grad,
create_graph=False,
retain_graph=False
)[0]

# Fisher Information: I(η_σ) = [(γ-1)² + π²/6] / σ² * (g'(η))²
fim_sigma = (constant_factor / (sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2)

return [fim_mu, fim_sigma]
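
As a quick arithmetic check (not part of the diff), the exp-case constant evaluates to about 1.8237:

import math

euler_gamma = 0.5772156649015329
constant_factor = (euler_gamma - 1.0) ** 2 + (math.pi ** 2) / 6.0
print(round(constant_factor, 4))   # 1.8237  (approx. 0.1787 + 1.6449)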
54 changes: 54 additions & 0 deletions lightgbmlss/distributions/Laplace.py
@@ -1,6 +1,7 @@
from torch.distributions import Laplace as Laplace_Torch
from .distribution_utils import DistributionClass
from ..utils import *
from typing import List


class Laplace(DistributionClass):
@@ -73,3 +74,56 @@ def __init__(self,
loss_fn=loss_fn,
initialize=initialize,
)

def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Compute Fisher Information Matrix diagonal for Laplace distribution.

For Laplace L(μ, b):
- Fisher Information w.r.t. μ: I(μ) = 1/b²
- Fisher Information w.r.t. natural parameter η_b where b = g(η_b):
I(η_b) = (1/b²) * (g'(η_b))²

For common response functions:
- exp: g'(η) = b, so I(η_b) = 1
- softplus: g'(η) = sigmoid(η), so I(η_b) = (1/b²) * sigmoid(η)²

Parameters
----------
predt : List[torch.Tensor]
[eta_mu, eta_b] - raw parameters before response functions

Returns
-------
fim : List[torch.Tensor]
[FIM_mu, FIM_b]
"""
eta_mu, eta_b = predt[0], predt[1]

# Apply response functions
response_fn_b = self.param_dict["scale"]
b = response_fn_b(eta_b)

# FIM for μ (location parameter uses identity)
fim_mu = 1.0 / (b ** 2 + 1e-12)

# FIM for b natural parameter - optimize for exp case
if response_fn_b == exp_fn:
# For exp: g'(η) = b, so I(η_b) = 1
fim_b = torch.ones_like(eta_b)
else:
# For other response functions: compute derivative
eta_b_grad = eta_b.detach().requires_grad_(True)
b_grad = response_fn_b(eta_b_grad)

g_prime = torch.autograd.grad(
outputs=b_grad.sum(),
inputs=eta_b_grad,
create_graph=False,
retain_graph=False
)[0]

# Fisher Information: I(η_b) = (1/b²) * (g'(η))²
fim_b = (1.0 / (b.detach() ** 2 + 1e-12)) * (g_prime ** 2)

return [fim_mu, fim_b]
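
The softplus case mentioned in the docstring can be checked numerically with a small standalone sketch (not part of the diff; torch.nn.functional.softplus stands in for the library's softplus response function):

import torch
import torch.nn.functional as F

eta_b = torch.tensor([-1.0, 0.0, 2.0])
b = F.softplus(eta_b)

eta = eta_b.detach().requires_grad_(True)
g_prime = torch.autograd.grad(F.softplus(eta).sum(), eta)[0]   # equals sigmoid(eta)

fim_autograd = (1.0 / (b.detach() ** 2 + 1e-12)) * g_prime ** 2
fim_analytic = torch.sigmoid(eta_b) ** 2 / (b ** 2 + 1e-12)

assert torch.allclose(fim_autograd, fim_analytic)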
55 changes: 55 additions & 0 deletions lightgbmlss/distributions/LogNormal.py
@@ -1,6 +1,7 @@
from torch.distributions import LogNormal as LogNormal_Torch
from .distribution_utils import DistributionClass
from ..utils import *
from typing import List


class LogNormal(DistributionClass):
@@ -33,12 +34,15 @@ class LogNormal(DistributionClass):
Whether to initialize the distributional parameters with unconditional start values. Initialization can help
to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal
solutions if the unconditional start values are far from the optimal values.
natural_gradient: bool
Whether to use natural gradient descent for optimization.
"""
def __init__(self,
stabilization: str = "None",
response_fn: str = "exp",
loss_fn: str = "nll",
initialize: bool = False,
natural_gradient: bool = False,
):

# Input Checks
@@ -72,4 +76,55 @@ def __init__(self,
distribution_arg_names=list(param_dict.keys()),
loss_fn=loss_fn,
initialize=initialize,
natural_gradient=natural_gradient,
)

def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Compute Fisher Information Matrix diagonal for LogNormal distribution.

For the LogNormal distribution with parameters (μ, σ), where Y = exp(X) and X ~ N(μ, σ²):
- Fisher Information w.r.t. μ: I(μ) = 1/σ²
- Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ):
I(η_σ) = (2/σ²) * (g'(η_σ))²

Parameters
----------
predt : List[torch.Tensor]
[eta_mu, eta_sigma] - raw parameters before response functions

Returns
-------
fim : List[torch.Tensor]
[FIM_mu, FIM_sigma]
"""
eta_mu, eta_sigma = predt[0], predt[1]

# Apply response functions
response_fn_sigma = self.param_dict["scale"]
sigma = response_fn_sigma(eta_sigma)

# FIM for μ (location parameter uses identity)
fim_mu = 1.0 / (sigma ** 2 + 1e-12)

# FIM for σ natural parameter - optimize for exp case
if response_fn_sigma == exp_fn:
# For exp: g'(η) = σ, so I(η_σ) = 2
fim_sigma = torch.ones_like(eta_sigma) * 2.0
else:
# For other response functions: compute derivative
eta_sigma_grad = eta_sigma.detach().requires_grad_(True)
sigma_grad = response_fn_sigma(eta_sigma_grad)

g_prime = torch.autograd.grad(
outputs=sigma_grad.sum(),
inputs=eta_sigma_grad,
create_graph=False,
retain_graph=False
)[0]

# Fisher Information: I(η_σ) = (2/σ²) * (g'(η))²
fim_sigma = (2.0 / (sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2)

return [fim_mu, fim_sigma]
