@@ -25,11 +25,16 @@ class Expectile(DistributionClass):
2525 List of expectiles in increasing order.
2626 penalize_crossing: bool
2727 Whether to include a penalty term to discourage crossing of expectiles.
28+ initialize: bool
29+ Whether to initialize the distributional parameters with unconditional start values. Initialization can help
30+ to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal
31+ solutions if the unconditional start values are far from the optimal values.
2832 """
2933 def __init__ (self ,
3034 stabilization : str = "None" ,
3135 expectiles : List = [0.1 , 0.5 , 0.9 ],
3236 penalize_crossing : bool = False ,
37+ initialize : bool = False ,
3338 ):
3439
3540 # Input Checks
@@ -41,6 +46,8 @@ def __init__(self,
4146 raise ValueError ("Expectiles must be between 0 and 1." )
4247 if not isinstance (penalize_crossing , bool ):
4348 raise ValueError ("penalize_crossing must be a boolean. Please choose from True or False." )
49+ if not isinstance (initialize , bool ):
50+ raise ValueError ("Invalid initialize. Please choose from True or False." )
4451
4552 # Set the parameters specific to the distribution
4653 distribution = Expectile_Torch
@@ -61,7 +68,8 @@ def __init__(self,
6168 distribution_arg_names = list (param_dict .keys ()),
6269 loss_fn = "nll" ,
6370 tau = torch .tensor (expectiles ),
64- penalize_crossing = penalize_crossing
71+ penalize_crossing = penalize_crossing ,
72+ initialize = initialize ,
6573 )
6674
6775
0 commit comments