diff --git a/lightgbmlss/distributions/Cauchy.py b/lightgbmlss/distributions/Cauchy.py index 36b9ce2..5751d7c 100644 --- a/lightgbmlss/distributions/Cauchy.py +++ b/lightgbmlss/distributions/Cauchy.py @@ -1,6 +1,7 @@ from torch.distributions import Cauchy as Cauchy_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class Cauchy(DistributionClass): @@ -33,12 +34,15 @@ class Cauchy(DistributionClass): Whether to initialize the distributional parameters with unconditional start values. Initialization can help to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal solutions if the unconditional start values are far from the optimal values. + natural_gradient: bool + Whether to use natural gradient descent for optimization. """ def __init__(self, stabilization: str = "None", response_fn: str = "exp", loss_fn: str = "nll", initialize: bool = False, + natural_gradient: bool = False, ): # Input Checks @@ -72,4 +76,54 @@ def __init__(self, distribution_arg_names=list(param_dict.keys()), loss_fn=loss_fn, initialize=initialize, + natural_gradient=natural_gradient, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Cauchy distribution. + + For Cauchy distribution with parameters (μ, σ): + - Fisher Information w.r.t. μ: I(μ) = 1/(2σ²) + - Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ): + I(η_σ) = (1/(2σ²)) * (g'(η_σ))² + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_mu, eta_sigma] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_mu, FIM_sigma] + """ + eta_mu, eta_sigma = predt[0], predt[1] + + # Apply response functions + response_fn_sigma = self.param_dict["scale"] + sigma = response_fn_sigma(eta_sigma) + + # FIM for μ (location parameter uses identity) + fim_mu = 1.0 / (2.0 * sigma ** 2 + 1e-12) + + # FIM for σ natural parameter - optimize for exp case + if response_fn_sigma == exp_fn: + # For exp: g'(η) = σ, so I(η_σ) = 1/2 + fim_sigma = torch.ones_like(eta_sigma) * 0.5 + else: + # For other response functions: compute derivative + eta_sigma_grad = eta_sigma.detach().requires_grad_(True) + sigma_grad = response_fn_sigma(eta_sigma_grad) + + g_prime = torch.autograd.grad( + outputs=sigma_grad.sum(), + inputs=eta_sigma_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_σ) = (1/(2σ²)) * (g'(η))² + fim_sigma = (1.0 / (2.0 * sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2) + + return [fim_mu, fim_sigma] diff --git a/lightgbmlss/distributions/Gaussian.py b/lightgbmlss/distributions/Gaussian.py index 9f75cf3..ee9a58f 100644 --- a/lightgbmlss/distributions/Gaussian.py +++ b/lightgbmlss/distributions/Gaussian.py @@ -1,6 +1,8 @@ +import torch from torch.distributions import Normal as Gaussian_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class Gaussian(DistributionClass): @@ -29,16 +31,13 @@ class Gaussian(DistributionClass): Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score). Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable. Hence, using the CRPS disregards any variation in the curvature of the loss function. - initialize: bool - Whether to initialize the distributional parameters with unconditional start values. Initialization can help - to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal - solutions if the unconditional start values are far from the optimal values. """ def __init__(self, stabilization: str = "None", response_fn: str = "exp", loss_fn: str = "nll", initialize: bool = False, + natural_gradient: bool = False, ): # Input Checks @@ -72,4 +71,54 @@ def __init__(self, distribution_arg_names=list(param_dict.keys()), loss_fn=loss_fn, initialize=initialize, + natural_gradient=natural_gradient, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Gaussian distribution. + + For Gaussian N(μ, σ²): + - Fisher Information w.r.t. μ: I(μ) = 1/σ² + - Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ): + I(η_σ) = (2/σ²) * (g'(η_σ))² + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_mu, eta_sigma] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_mu, FIM_sigma] + """ + eta_mu, eta_sigma = predt[0], predt[1] + + # Apply response functions + response_fn_sigma = self.param_dict["scale"] + sigma = response_fn_sigma(eta_sigma) + + # FIM for μ (location parameter uses identity) + fim_mu = 1.0 / (sigma ** 2 + 1e-12) + + # FIM for σ natural parameter - optimize for exp case + if response_fn_sigma == exp_fn: + # For exp: g'(η) = σ, so I(η_σ) = 2 + fim_sigma = torch.ones_like(eta_sigma) * 2.0 + else: + # For other response functions: compute derivative + eta_sigma_grad = eta_sigma.detach().requires_grad_(True) + sigma_grad = response_fn_sigma(eta_sigma_grad) + + g_prime = torch.autograd.grad( + outputs=sigma_grad.sum(), + inputs=eta_sigma_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_σ) = (2/σ²) * (g'(η))² + fim_sigma = (2.0 / (sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2) + + return [fim_mu, fim_sigma] diff --git a/lightgbmlss/distributions/Gumbel.py b/lightgbmlss/distributions/Gumbel.py index 4106225..95c1b92 100644 --- a/lightgbmlss/distributions/Gumbel.py +++ b/lightgbmlss/distributions/Gumbel.py @@ -1,6 +1,8 @@ from torch.distributions import Gumbel as Gumbel_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List +import math class Gumbel(DistributionClass): @@ -39,6 +41,7 @@ def __init__(self, response_fn: str = "exp", loss_fn: str = "nll", initialize: bool = False, + natural_gradient: bool = False, ): # Input Checks @@ -72,4 +75,61 @@ def __init__(self, distribution_arg_names=list(param_dict.keys()), loss_fn=loss_fn, initialize=initialize, + natural_gradient=natural_gradient, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Gumbel distribution. + + For Gumbel distribution with parameters (μ, σ): + - Fisher Information w.r.t. μ: I(μ) = 1/σ² + - Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ): + I(η_σ) = [(γ-1)² + π²/6] / σ² * (g'(η_σ))² + where γ ≈ 0.5772156649 is the Euler-Mascheroni constant + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_mu, eta_sigma] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_mu, FIM_sigma] + """ + eta_mu, eta_sigma = predt[0], predt[1] + + # Apply response functions + response_fn_sigma = self.param_dict["scale"] + sigma = response_fn_sigma(eta_sigma) + + # FIM for μ (location parameter uses identity) + fim_mu = 1.0 / (sigma ** 2 + 1e-12) + + # Euler-Mascheroni constant + euler_gamma = 0.5772156649015329 + + # Constant factor: (γ-1)² + π²/6 + constant_factor = (euler_gamma - 1.0) ** 2 + (math.pi ** 2) / 6.0 + + # FIM for σ natural parameter - optimize for exp case + if response_fn_sigma == exp_fn: + # For exp: g'(η) = σ, so I(η_σ) = (γ-1)² + π²/6 + fim_sigma = torch.ones_like(eta_sigma) * constant_factor + else: + # For other response functions: compute derivative + eta_sigma_grad = eta_sigma.detach().requires_grad_(True) + sigma_grad = response_fn_sigma(eta_sigma_grad) + + g_prime = torch.autograd.grad( + outputs=sigma_grad.sum(), + inputs=eta_sigma_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_σ) = [(γ-1)² + π²/6] / σ² * (g'(η))² + fim_sigma = (constant_factor / (sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2) + + return [fim_mu, fim_sigma] diff --git a/lightgbmlss/distributions/Laplace.py b/lightgbmlss/distributions/Laplace.py index 6b3bff0..082aaa7 100644 --- a/lightgbmlss/distributions/Laplace.py +++ b/lightgbmlss/distributions/Laplace.py @@ -1,6 +1,7 @@ from torch.distributions import Laplace as Laplace_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class Laplace(DistributionClass): @@ -73,3 +74,56 @@ def __init__(self, loss_fn=loss_fn, initialize=initialize, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Laplace distribution. + + For Laplace L(μ, b): + - Fisher Information w.r.t. μ: I(μ) = 1/b² + - Fisher Information w.r.t. natural parameter η_b where b = g(η_b): + I(η_b) = (1/b²) * (g'(η_b))² + + For common response functions: + - exp: g'(η) = b, so I(η_b) = 1 + - softplus: g'(η) = sigmoid(η), so I(η_b) = (1/b²) * sigmoid(η)² + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_mu, eta_b] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_mu, FIM_b] + """ + eta_mu, eta_b = predt[0], predt[1] + + # Apply response functions + response_fn_b = self.param_dict["scale"] + b = response_fn_b(eta_b) + + # FIM for μ (location parameter uses identity) + fim_mu = 1.0 / (b ** 2 + 1e-12) + + # FIM for b natural parameter - optimize for exp case + if response_fn_b == exp_fn: + # For exp: g'(η) = b, so I(η_b) = 1 + fim_b = torch.ones_like(eta_b) + else: + # For other response functions: compute derivative + eta_b_grad = eta_b.detach().requires_grad_(True) + b_grad = response_fn_b(eta_b_grad) + + g_prime = torch.autograd.grad( + outputs=b_grad.sum(), + inputs=eta_b_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_b) = (1/b²) * (g'(η))² + fim_b = (1.0 / (b.detach() ** 2 + 1e-12)) * (g_prime ** 2) + + return [fim_mu, fim_b] diff --git a/lightgbmlss/distributions/LogNormal.py b/lightgbmlss/distributions/LogNormal.py index 226715a..7908435 100644 --- a/lightgbmlss/distributions/LogNormal.py +++ b/lightgbmlss/distributions/LogNormal.py @@ -1,6 +1,7 @@ from torch.distributions import LogNormal as LogNormal_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class LogNormal(DistributionClass): @@ -33,12 +34,15 @@ class LogNormal(DistributionClass): Whether to initialize the distributional parameters with unconditional start values. Initialization can help to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal solutions if the unconditional start values are far from the optimal values. + natural_gradient: bool + Whether to use natural gradient descent for optimization. """ def __init__(self, stabilization: str = "None", response_fn: str = "exp", loss_fn: str = "nll", initialize: bool = False, + natural_gradient: bool = False, ): # Input Checks @@ -72,4 +76,55 @@ def __init__(self, distribution_arg_names=list(param_dict.keys()), loss_fn=loss_fn, initialize=initialize, + natural_gradient=natural_gradient, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for LogNormal distribution. + + For LogNormal with parameters (μ, σ) where Y = exp(X) and X ~ N(μ, σ²): + - Fisher Information w.r.t. μ: I(μ) = 1/σ² + - Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ): + I(η_σ) = (2/σ²) * (g'(η_σ))² + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_mu, eta_sigma] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_mu, FIM_sigma] + """ + eta_mu, eta_sigma = predt[0], predt[1] + + # Apply response functions + response_fn_sigma = self.param_dict["scale"] + sigma = response_fn_sigma(eta_sigma) + + # FIM for μ (location parameter uses identity) + fim_mu = 1.0 / (sigma ** 2 + 1e-12) + + # FIM for σ natural parameter - optimize for exp case + if response_fn_sigma == exp_fn: + # For exp: g'(η) = σ, so I(η_σ) = 2 + fim_sigma = torch.ones_like(eta_sigma) * 2.0 + else: + # For other response functions: compute derivative + eta_sigma_grad = eta_sigma.detach().requires_grad_(True) + sigma_grad = response_fn_sigma(eta_sigma_grad) + + g_prime = torch.autograd.grad( + outputs=sigma_grad.sum(), + inputs=eta_sigma_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_σ) = (2/σ²) * (g'(η))² + fim_sigma = (2.0 / (sigma.detach() ** 2 + 1e-12)) * (g_prime ** 2) + + return [fim_mu, fim_sigma] + diff --git a/lightgbmlss/distributions/Logistic.py b/lightgbmlss/distributions/Logistic.py index 85909d5..207bc18 100644 --- a/lightgbmlss/distributions/Logistic.py +++ b/lightgbmlss/distributions/Logistic.py @@ -1,6 +1,7 @@ from pyro.distributions import Logistic as Logistic_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class Logistic(DistributionClass): @@ -73,3 +74,59 @@ def __init__(self, loss_fn=loss_fn, initialize=initialize, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Logistic distribution. + + For Logistic L(μ, σ): + - Fisher Information w.r.t. μ: I(μ) = 3/σ² + - Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ): + I(η_σ) = ((π²+3)/(9σ²)) * (g'(η_σ))² + + For common response functions: + - exp: g'(η) = σ, so I(η_σ) = (π²+3)/9 + - softplus: g'(η) = sigmoid(η), so I(η_σ) = ((π²+3)/(9σ²)) * sigmoid(η)² + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_mu, eta_sigma] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_mu, FIM_sigma] + """ + eta_mu, eta_sigma = predt[0], predt[1] + + # Apply response functions + response_fn_sigma = self.param_dict["scale"] + sigma = response_fn_sigma(eta_sigma) + + # FIM for μ (location parameter uses identity) + fim_mu = 3.0 / (sigma ** 2 + 1e-12) + + # Constant for logistic distribution + pi_squared_plus_3 = torch.pi ** 2 + 3.0 + + # FIM for σ natural parameter - optimize for exp case + if response_fn_sigma == exp_fn: + # For exp: g'(η) = σ, so I(η_σ) = (π²+3)/9 + fim_sigma = torch.ones_like(eta_sigma) * (pi_squared_plus_3 / 9.0) + else: + # For other response functions: compute derivative + eta_sigma_grad = eta_sigma.detach().requires_grad_(True) + sigma_grad = response_fn_sigma(eta_sigma_grad) + + g_prime = torch.autograd.grad( + outputs=sigma_grad.sum(), + inputs=eta_sigma_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_σ) = ((π²+3)/(9σ²)) * (g'(η))² + fim_sigma = (pi_squared_plus_3 / (9.0 * (sigma.detach() ** 2 + 1e-12))) * (g_prime ** 2) + + return [fim_mu, fim_sigma] diff --git a/lightgbmlss/distributions/Poisson.py b/lightgbmlss/distributions/Poisson.py index 051a869..30271e1 100644 --- a/lightgbmlss/distributions/Poisson.py +++ b/lightgbmlss/distributions/Poisson.py @@ -1,6 +1,7 @@ from torch.distributions import Poisson as Poisson_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class Poisson(DistributionClass): @@ -69,3 +70,57 @@ def __init__(self, loss_fn=loss_fn, initialize=initialize, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Poisson distribution. + + For Poisson with rate parameter λ: + - The variance equals the mean: Var(Y) = λ + - Fisher Information with respect to the natural parameter η (where λ = g(η)) + and g is the response function (exp, softplus, or relu): + FIM_η = λ * (g'(η))^2 + + For common response functions: + - exp: g'(η) = exp(η) = λ, so FIM = λ^2 + - softplus: g'(η) = sigmoid(η), so FIM = λ * sigmoid(η)^2 + - relu: g'(η) = 1 if η > 0 else 0, so FIM ≈ λ (for η > 0) + + Parameters + ---------- + predt : List[torch.Tensor] + [eta] - raw parameter before response function + + Returns + ------- + fim : List[torch.Tensor] + [FIM_eta] - Fisher Information for the natural parameter + """ + # Raw parameter (before response function) + eta = predt[0] + + # Apply response function to get λ + response_fn = list(self.param_dict.values())[0] + lam = response_fn(eta) + + # FIM for rate parameter - optimize for exp case + if response_fn == exp_fn: + # For exp: g'(η) = λ, so I(η) = λ² + fim_eta = lam.detach() ** 2 + 1e-8 + else: + # For other response functions: compute derivative + eta_grad = eta.detach().requires_grad_(True) + lam_grad = response_fn(eta_grad) + + # Get derivative g'(η) + g_prime = torch.autograd.grad( + outputs=lam_grad.sum(), + inputs=eta_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η) = λ * (g'(η))^2 + fim_eta = lam.detach() * (g_prime ** 2) + 1e-8 + + return [fim_eta] \ No newline at end of file diff --git a/lightgbmlss/distributions/StudentT.py b/lightgbmlss/distributions/StudentT.py index 8ec63cb..6ad811e 100644 --- a/lightgbmlss/distributions/StudentT.py +++ b/lightgbmlss/distributions/StudentT.py @@ -1,6 +1,8 @@ +import torch from torch.distributions import StudentT as StudentT_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List class StudentT(DistributionClass): @@ -78,3 +80,61 @@ def __init__(self, loss_fn=loss_fn, initialize=initialize, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Student-T distribution. + + For Student-T(ν, μ, σ): + - Fisher Information w.r.t. μ: I(μ) = (ν+1)/((ν+3)σ²) + - Fisher Information w.r.t. natural parameter η_σ where σ = g(η_σ): + I(η_σ) = (2ν/((ν+3)σ²)) * (g'(η_σ))² + - Fisher Information w.r.t. ν: Using simplified constant approximation + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_df, eta_loc, eta_scale] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_df, FIM_loc, FIM_scale] + """ + eta_df, eta_loc, eta_scale = predt[0], predt[1], predt[2] + + # Apply response functions + df = self.param_dict["df"](eta_df) + loc = self.param_dict["loc"](eta_loc) + scale = self.param_dict["scale"](eta_scale) + + # FIM for loc (μ) - location parameter uses identity + fim_loc = (df + 1.0) / ((df + 3.0) * scale ** 2 + 1e-12) + + # FIM for scale (σ) natural parameter + response_fn_scale = self.param_dict["scale"] + if response_fn_scale == exp_fn: + # For exp: g'(η) = σ, so I(η_σ) = (2ν/((ν+3)σ²)) * σ² = 2ν/(ν+3) + fim_scale = (2.0 * df) / (df + 3.0 + 1e-12) + else: + # For other response functions: compute derivative + eta_scale_grad = eta_scale.detach().requires_grad_(True) + scale_grad = response_fn_scale(eta_scale_grad) + + g_prime = torch.autograd.grad( + outputs=scale_grad.sum(), + inputs=eta_scale_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_σ) = (2ν/((ν+3)σ²)) * (g'(η))² + fim_scale = (2.0 * df.detach() / ((df.detach() + 3.0) * scale.detach() ** 2 + 1e-12)) * (g_prime ** 2) + + # FIM for df (ν) - using simplified constant approximation + # The exact FIM for df is complex and not provided in standard references + # Using a conservative constant value + fim_df = torch.ones_like(eta_df) * 0.5 + + return [fim_df, fim_loc, fim_scale] + diff --git a/lightgbmlss/distributions/Weibull.py b/lightgbmlss/distributions/Weibull.py index 48ddfe6..0492c84 100644 --- a/lightgbmlss/distributions/Weibull.py +++ b/lightgbmlss/distributions/Weibull.py @@ -1,6 +1,8 @@ from torch.distributions import Weibull as Weibull_Torch from .distribution_utils import DistributionClass from ..utils import * +from typing import List +import math class Weibull(DistributionClass): @@ -33,12 +35,15 @@ class Weibull(DistributionClass): Whether to initialize the distributional parameters with unconditional start values. Initialization can help to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal solutions if the unconditional start values are far from the optimal values. + natural_gradient: bool + Whether to use natural gradient descent for optimization. """ def __init__(self, stabilization: str = "None", response_fn: str = "exp", loss_fn: str = "nll", initialize: bool = False, + natural_gradient: bool = False, ): # Input Checks @@ -72,4 +77,80 @@ def __init__(self, distribution_arg_names=list(param_dict.keys()), loss_fn=loss_fn, initialize=initialize, + natural_gradient=natural_gradient, ) + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix diagonal for Weibull distribution. + + For Weibull distribution with parameters (λ, β) where λ is scale and β is concentration: + - Fisher Information w.r.t. natural parameter η_λ where λ = g(η_λ): + I(η_λ) = (β²/λ²) * (g'(η_λ))² + - Fisher Information w.r.t. natural parameter η_β where β = g(η_β): + I(η_β) = [(γ-1)² + π²/6] * (g'(η_β))² + where γ ≈ 0.5772156649 is the Euler-Mascheroni constant + + Parameters + ---------- + predt : List[torch.Tensor] + [eta_scale, eta_concentration] - raw parameters before response functions + + Returns + ------- + fim : List[torch.Tensor] + [FIM_scale, FIM_concentration] + """ + eta_scale, eta_concentration = predt[0], predt[1] + + # Apply response functions + response_fn_scale = self.param_dict["scale"] + response_fn_concentration = self.param_dict["concentration"] + scale = response_fn_scale(eta_scale) + concentration = response_fn_concentration(eta_concentration) + + # Euler-Mascheroni constant + euler_gamma = 0.5772156649015329 + + # FIM for scale (λ) natural parameter + if response_fn_scale == exp_fn: + # For exp: g'(η) = λ, so I(η_λ) = β² + fim_scale = concentration.detach() ** 2 + else: + # For other response functions: compute derivative + eta_scale_grad = eta_scale.detach().requires_grad_(True) + scale_grad = response_fn_scale(eta_scale_grad) + + g_prime_scale = torch.autograd.grad( + outputs=scale_grad.sum(), + inputs=eta_scale_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_λ) = (β²/λ²) * (g'(η))² + fim_scale = (concentration.detach() ** 2 / (scale.detach() ** 2 + 1e-12)) * (g_prime_scale ** 2) + + # FIM for concentration (β) natural parameter + # Constant factor: (γ-1)² + π²/6 + constant_factor = (euler_gamma - 1.0) ** 2 + (math.pi ** 2) / 6.0 + + if response_fn_concentration == exp_fn: + # For exp: g'(η) = β, so I(η_β) = constant_factor + fim_concentration = torch.ones_like(eta_concentration) * constant_factor + else: + # For other response functions: compute derivative + eta_concentration_grad = eta_concentration.detach().requires_grad_(True) + concentration_grad = response_fn_concentration(eta_concentration_grad) + + g_prime_concentration = torch.autograd.grad( + outputs=concentration_grad.sum(), + inputs=eta_concentration_grad, + create_graph=False, + retain_graph=False + )[0] + + # Fisher Information: I(η_β) = [(γ-1)² + π²/6] * (g'(η))² + fim_concentration = constant_factor * (g_prime_concentration ** 2) + + return [fim_scale, fim_concentration] diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py index da9d379..ac78191 100644 --- a/lightgbmlss/distributions/distribution_utils.py +++ b/lightgbmlss/distributions/distribution_utils.py @@ -9,6 +9,8 @@ from tqdm import tqdm from typing import Any, Dict, Optional, List, Tuple +import matplotlib.pyplot as plt +import seaborn as sns import warnings @@ -55,6 +57,7 @@ def __init__(self, distribution_arg_names: List = None, loss_fn: str = "nll", initialize: bool = False, + natural_gradient: bool = False, tau: Optional[List[torch.Tensor]] = None, penalize_crossing: bool = False, ): @@ -68,6 +71,7 @@ def __init__(self, self.distribution_arg_names = distribution_arg_names self.loss_fn = loss_fn self.initialize = initialize + self.natural_gradient = natural_gradient self.tau = tau self.penalize_crossing = penalize_crossing @@ -161,15 +165,15 @@ def loss_fn_start_values(self, loss: torch.Tensor Loss value. """ - # Replace NaNs and infinity values with 0.5 - nan_inf_idx = torch.isnan(torch.stack(params)) | torch.isinf(torch.stack(params)) - params = torch.where(nan_inf_idx, torch.tensor(0.5), torch.stack(params)) - # Transform parameters to response scale params = [ response_fn(params[i].reshape(-1, 1)) for i, response_fn in enumerate(self.param_dict.values()) ] + # Replace NaNs and infinity values with 0.5 + nan_inf_idx = torch.isnan(torch.stack(params)) | torch.isinf(torch.stack(params)) + params = torch.where(nan_inf_idx, torch.tensor(0.5), torch.stack(params)) + # Specify Distribution and Loss if self.tau is None: dist = self.distribution(*params) @@ -429,6 +433,26 @@ def predict_dist(self, if self.discrete: pred_quant_df = pred_quant_df.astype(int) return pred_quant_df + + def compute_fisher_information_matrix(self, predt: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Compute Fisher Information Matrix (FIM) diagonal elements. + + This is a default implementation that returns ones (no scaling). + Override this method in specific distribution classes for proper FIM computation. + + Parameters + ---------- + predt : List[torch.Tensor] + List of predicted distributional parameters (raw scale, before response functions). + + Returns + ------- + fim : List[torch.Tensor] + List of FIM diagonal elements for each parameter. + """ + # Default: return ones (no natural gradient scaling) + return [torch.ones_like(p) for p in predt] def compute_gradients_and_hessians(self, loss: torch.tensor, @@ -460,34 +484,21 @@ def compute_gradients_and_hessians(self, # Gradient and Hessian grad = autograd(loss, inputs=predt, create_graph=True) hess = [autograd(grad[i].nansum(), inputs=predt[i], retain_graph=True)[0] for i in range(len(grad))] + if self.natural_gradient: + # Compute Fisher Information Matrix + fim = self.compute_fisher_information_matrix(predt) + # Apply natural gradient: grad_natural = grad / FIM + grad = [grad[i] / (fim[i] + 1e-12) for i in range(len(grad))] + else: + pass elif self.loss_fn == "crps": # Gradient and Hessian grad = autograd(loss, inputs=predt, create_graph=True) hess = [torch.ones_like(grad[i]) for i in range(len(grad))] - - # # Approximation of Hessian - # step_size = 1e-6 - # predt_upper = [ - # response_fn(predt[i] + step_size).reshape(-1, 1) for i, response_fn in - # enumerate(self.param_dict.values()) - # ] - # dist_kwargs_upper = dict(zip(self.distribution_arg_names, predt_upper)) - # dist_fit_upper = self.distribution(**dist_kwargs_upper) - # dist_samples_upper = dist_fit_upper.rsample((30,)).squeeze(-1) - # loss_upper = torch.nansum(self.crps_score(self.target, dist_samples_upper)) - # - # predt_lower = [ - # response_fn(predt[i] - step_size).reshape(-1, 1) for i, response_fn in - # enumerate(self.param_dict.values()) - # ] - # dist_kwargs_lower = dict(zip(self.distribution_arg_names, predt_lower)) - # dist_fit_lower = self.distribution(**dist_kwargs_lower) - # dist_samples_lower = dist_fit_lower.rsample((30,)).squeeze(-1) - # loss_lower = torch.nansum(self.crps_score(self.target, dist_samples_lower)) - # - # grad_upper = autograd(loss_upper, inputs=predt_upper) - # grad_lower = autograd(loss_lower, inputs=predt_lower) - # hess = [(grad_upper[i] - grad_lower[i]) / (2 * step_size) for i in range(len(grad))] + if self.natural_gradient: + warnings.warn("Natural Gradient is not implemented for CRPS. Using standard Gradient instead.") + else: + pass # Stabilization of Derivatives if self.stabilization != "None": diff --git a/tests/test_distributions/test_natural.py b/tests/test_distributions/test_natural.py new file mode 100644 index 0000000..c7c1b55 --- /dev/null +++ b/tests/test_distributions/test_natural.py @@ -0,0 +1,52 @@ +from ..utils import BaseTestClass +import pytest +import torch +import inspect + + +class TestClass(BaseTestClass): + @staticmethod + def supports_natural_gradient(dist_class): + """Check if distribution supports natural_gradient parameter.""" + sig = inspect.signature(dist_class.__init__) + return 'natural_gradient' in sig.parameters + + def test_natural_gradient_parameter(self, univariate_cont_dist): + """Test that natural_gradient parameter is properly handled.""" + if not self.supports_natural_gradient(univariate_cont_dist): + pytest.skip(f"{univariate_cont_dist.__name__} doesn't support natural_gradient parameter") + + # Test default value + dist_default = univariate_cont_dist() + assert isinstance(dist_default.natural_gradient, bool) + assert dist_default.natural_gradient is False + + # Test explicit True/False + dist_true = univariate_cont_dist(natural_gradient=True) + assert dist_true.natural_gradient is True + + dist_false = univariate_cont_dist(natural_gradient=False) + assert dist_false.natural_gradient is False + + def test_natural_gradient_affects_gradient_computation(self, univariate_cont_dist): + """Test that natural_gradient enables FIM computation.""" + if not self.supports_natural_gradient(univariate_cont_dist): + pytest.skip(f"{univariate_cont_dist.__name__} doesn't support natural_gradient parameter") + + # Create test data + n_params = univariate_cont_dist().n_dist_param + predt = [torch.randn(100, requires_grad=True) for _ in range(n_params)] + + # Test that natural gradient distribution can compute FIM + dist_natural = univariate_cont_dist(natural_gradient=True) + fim = dist_natural.compute_fisher_information_matrix(predt) + + # Verify FIM is computed correctly + assert fim is not None + assert len(fim) == n_params + + # Verify each FIM element has correct shape and is positive + for i, fim_i in enumerate(fim): + assert fim_i.shape == predt[i].shape + assert torch.all(fim_i > 0), f"FIM element {i} should be positive" + assert torch.all(torch.isfinite(fim_i)), f"FIM element {i} should be finite" \ No newline at end of file diff --git a/tests/test_distributions/test_univariate_cont_distns.py b/tests/test_distributions/test_univariate_cont_distns.py index 539d7f9..6b09ee0 100644 --- a/tests/test_distributions/test_univariate_cont_distns.py +++ b/tests/test_distributions/test_univariate_cont_distns.py @@ -1,5 +1,6 @@ from ..utils import BaseTestClass import pytest +import torch class TestClass(BaseTestClass): @@ -49,3 +50,79 @@ def test_defaults(self, univariate_cont_dist): assert univariate_cont_dist().initialize is False assert univariate_cont_dist(initialize=True).initialize is True assert univariate_cont_dist(initialize=False).initialize is False + + def test_fisher_information_matrix_exp(self, univariate_cont_dist): + """Test FIM computation with exp response function.""" + dist = univariate_cont_dist(response_fn="exp") + + # Create test data with appropriate dimensions + n_params = dist.n_dist_param + eta_tensors = [torch.tensor([0.0, 1.0, -1.0]) for _ in range(n_params)] + + # Compute FIM + fim = dist.compute_fisher_information_matrix(eta_tensors) + + # Assertions + assert len(fim) == n_params + for i, f in enumerate(fim): + assert f.shape == eta_tensors[i].shape + # All FIM values should be positive + assert torch.all(f > 0), f"FIM values for parameter {i} should be positive" + # Should not contain NaN or Inf + assert not torch.any(torch.isnan(f)), f"FIM contains NaN for parameter {i}" + assert not torch.any(torch.isinf(f)), f"FIM contains Inf for parameter {i}" + + def test_fisher_information_matrix_softplus(self, univariate_cont_dist): + """Test FIM computation with softplus response function.""" + dist = univariate_cont_dist(response_fn="softplus") + + # Create test data with appropriate dimensions + n_params = dist.n_dist_param + eta_tensors = [torch.tensor([0.5, 1.0, 1.5]) for _ in range(n_params)] + + # Compute FIM + fim = dist.compute_fisher_information_matrix(eta_tensors) + + # Assertions + assert len(fim) == n_params + for i, f in enumerate(fim): + assert f.shape == eta_tensors[i].shape + # All FIM values should be positive + assert torch.all(f > 0), f"FIM values for parameter {i} should be positive" + # Should not contain NaN or Inf + assert not torch.any(torch.isnan(f)), f"FIM contains NaN for parameter {i}" + assert not torch.any(torch.isinf(f)), f"FIM contains Inf for parameter {i}" + + def test_fisher_information_matrix_numerical_stability(self, univariate_cont_dist): + """Test FIM computation with extreme values.""" + dist = univariate_cont_dist(response_fn="exp") + + # Test with very small and large values + n_params = dist.n_dist_param + eta_tensors = [torch.tensor([-10.0, 0.0, 10.0]) for _ in range(n_params)] + + # Compute FIM + fim = dist.compute_fisher_information_matrix(eta_tensors) + + # Assertions + for i, f in enumerate(fim): + # Should not contain NaN or Inf + assert not torch.any(torch.isnan(f)), f"FIM contains NaN for parameter {i}" + assert not torch.any(torch.isinf(f)), f"FIM contains Inf for parameter {i}" + # All values should be positive + assert torch.all(f > 0), f"FIM values for parameter {i} should be positive" + + def test_fisher_information_matrix_batch_dimensions(self, univariate_cont_dist): + """Test FIM computation with different batch dimensions.""" + dist = univariate_cont_dist(response_fn="exp") + + # Test with different batch sizes + for batch_size in [1, 10, 100]: + eta_tensors = [torch.randn(batch_size) for _ in range(dist.n_dist_param)] + + fim = dist.compute_fisher_information_matrix(eta_tensors) + + assert len(fim) == dist.n_dist_param + for i, f in enumerate(fim): + assert f.shape == (batch_size,) + assert torch.all(f > 0), f"FIM values for parameter {i} should be positive" diff --git a/tests/test_model/test_model.py b/tests/test_model/test_model.py index d33625f..db35579 100644 --- a/tests/test_model/test_model.py +++ b/tests/test_model/test_model.py @@ -241,3 +241,114 @@ def test_model_flow_train(self, univariate_data, flow_lgblss): # Assertions assert isinstance(lgblss.booster, lgb.Booster) + + @pytest.mark.parametrize("quantile_clipping, clip_value", [(True, 0.1)]) + def test_lightgbmlss_with_quantile_gradient_clipping(self, univariate_data, quantile_clipping, clip_value): + dtrain, _, deval, X_test = univariate_data + + # Define the distribution and model with gradient clipping + dist = Gaussian(stabilization="None", + response_fn="exp", + loss_fn="nll", + initialize=False, + natural_gradient=True, + quantile_clipping=quantile_clipping, + clip_value=clip_value) + model = LightGBMLSS(dist=dist) + + # Define training parameters + params = { + "verbosity": -1, + "learning_rate": 0.1, + "num_leaves": 31, + "min_data_in_leaf": 20, + "feature_fraction": 0.9 + } + + # Train the model + model.train( + params=params, + train_set=dtrain, + num_boost_round=1, + valid_sets=[dtrain, deval], + ) + + # Predict and evaluate + y_pred = model.predict(pd.DataFrame(X_test))['loc'].values + + # Assert that predictions are not NaN and have the correct shape + assert not np.isnan(y_pred).any(), "Predictions contain NaN values" + assert y_pred.shape == (X_test.shape[0],), "Predictions have incorrect shape" + + @pytest.mark.parametrize("quantile_clipping, clip_value", [(False, 0.1)]) + def test_lightgbmlss_with_gradient_clipping(self, univariate_data, quantile_clipping, clip_value): + dtrain, _, deval, X_test = univariate_data + + # Define the distribution and model with gradient clipping + dist = Gaussian(stabilization="None", + response_fn="exp", + loss_fn="nll", + initialize=False, + natural_gradient=True, + quantile_clipping=quantile_clipping, + clip_value=clip_value) + model = LightGBMLSS(dist=dist) + + # Define training parameters + params = { + "verbosity": -1, + "learning_rate": 0.1, + "num_leaves": 31, + "min_data_in_leaf": 20, + "feature_fraction": 0.9 + } + + # Train the model + model.train( + params=params, + train_set=dtrain, + num_boost_round=1, + valid_sets=[dtrain, deval], + ) + + # Predict and evaluate + y_pred = model.predict(pd.DataFrame(X_test))['loc'].values + + # Assert that predictions are not NaN and have the correct shape + assert not np.isnan(y_pred).any(), "Predictions contain NaN values" + assert y_pred.shape == (X_test.shape[0],), "Predictions have incorrect shape" + + def test_lightgbmlss_with_natural_gradient(self, univariate_data): + dtrain, _, deval, X_test = univariate_data + + # Define the distribution and model + dist = Gaussian(stabilization="None", + response_fn="exp", + loss_fn="nll", + initialize=False, + natural_gradient=True) + model = LightGBMLSS(dist=dist) + + # Define training parameters + params = { + "verbosity": -1, + "learning_rate": 0.1, + "num_leaves": 31, + "min_data_in_leaf": 20, + "feature_fraction": 0.9 + } + + # Train the model + model.train( + params=params, + train_set=dtrain, + num_boost_round=1, + valid_sets=[dtrain, deval], + ) + + # Predict and evaluate + y_pred = model.predict(pd.DataFrame(X_test))['loc'].values + + # Assert that predictions are not NaN and have the correct shape + assert not np.isnan(y_pred).any(), "Predictions contain NaN values" + assert y_pred.shape == (X_test.shape[0],), "Predictions have incorrect shape"