22
33namespace HiFolks \Statistics ;
44
5+ use HiFolks \Statistics \Enums \KdeKernel ;
56use HiFolks \Statistics \Exception \InvalidDataInputException ;
67
78class Stat
@@ -723,16 +724,16 @@ private static function ranks(array $data): array
723724 *
724725 * @param array<int|float> $data sample data
725726 * @param float $h bandwidth (smoothing parameter), must be > 0
726- * @param string $kernel kernel name (normal, logistic, sigmoid, rectangular, triangular, parabolic, quartic, triweight, cosine) or alias
727+ * @param KdeKernel $kernel kernel to use for estimation
727728 * @param bool $cumulative if true, return CDF estimator; otherwise PDF estimator
728729 * @return \Closure a callable that takes a float and returns the estimated density or CDF value
729730 *
730- * @throws InvalidDataInputException if data is empty, bandwidth <= 0, or kernel is invalid
731+ * @throws InvalidDataInputException if data is empty or bandwidth <= 0
731732 */
732733 public static function kde (
733734 array $ data ,
734735 float $ h ,
735- string $ kernel = ' normal ' ,
736+ KdeKernel $ kernel = KdeKernel::Normal ,
736737 bool $ cumulative = false ,
737738 ): \Closure {
738739 if ($ data === []) {
@@ -742,16 +743,7 @@ public static function kde(
742743 throw new InvalidDataInputException ("Bandwidth h must be positive. " );
743744 }
744745
745- $ aliases = [
746- 'gauss ' => 'normal ' ,
747- 'uniform ' => 'rectangular ' ,
748- 'epanechnikov ' => 'parabolic ' ,
749- 'biweight ' => 'quartic ' ,
750- ];
751- $ kernel = strtolower ($ kernel );
752- if (isset ($ aliases [$ kernel ])) {
753- $ kernel = $ aliases [$ kernel ];
754- }
746+ $ kernel = $ kernel ->resolve ();
755747
756748 $ sqrt2pi = sqrt (2.0 * M_PI );
757749
@@ -773,63 +765,56 @@ public static function kde(
773765 };
774766
775767 $ kernels = [
776- ' normal ' => [
768+ KdeKernel::Normal-> value => [
777769 'pdf ' => static fn (float $ t ): float => exp (-$ t * $ t / 2.0 ) / $ sqrt2pi ,
778770 'cdf ' => $ normalCdf ,
779771 'support ' => null ,
780772 ],
781- ' logistic ' => [
773+ KdeKernel::Logistic-> value => [
782774 'pdf ' => static fn (float $ t ): float => 0.5 / (1.0 + cosh ($ t )),
783775 'cdf ' => static fn (float $ t ): float => 1.0 / (1.0 + exp (-$ t )),
784776 'support ' => null ,
785777 ],
786- ' sigmoid ' => [
778+ KdeKernel::Sigmoid-> value => [
787779 'pdf ' => static fn (float $ t ): float => (1.0 / M_PI ) / cosh ($ t ),
788780 'cdf ' => static fn (float $ t ): float => (2.0 / M_PI ) * atan (exp ($ t )),
789781 'support ' => null ,
790782 ],
791- ' rectangular ' => [
783+ KdeKernel::Rectangular-> value => [
792784 'pdf ' => static fn (float $ t ): float => 0.5 ,
793785 'cdf ' => static fn (float $ t ): float => 0.5 * $ t + 0.5 ,
794786 'support ' => 1.0 ,
795787 ],
796- ' triangular ' => [
788+ KdeKernel::Triangular-> value => [
797789 'pdf ' => static fn (float $ t ): float => 1.0 - abs ($ t ),
798790 'cdf ' => static fn (float $ t ): float => $ t >= 0
799791 ? 1.0 - (1.0 - $ t ) * (1.0 - $ t ) / 2.0
800792 : (1.0 + $ t ) * (1.0 + $ t ) / 2.0 ,
801793 'support ' => 1.0 ,
802794 ],
803- ' parabolic ' => [
795+ KdeKernel::Parabolic-> value => [
804796 'pdf ' => static fn (float $ t ): float => 0.75 * (1.0 - $ t * $ t ),
805797 'cdf ' => static fn (float $ t ): float => -0.25 * $ t * $ t * $ t + 0.75 * $ t + 0.5 ,
806798 'support ' => 1.0 ,
807799 ],
808- ' quartic ' => [
800+ KdeKernel::Quartic-> value => [
809801 'pdf ' => static fn (float $ t ): float => (15.0 / 16.0 ) * (1.0 - $ t * $ t ) ** 2 ,
810802 'cdf ' => static fn (float $ t ): float => (15.0 * $ t - 10.0 * $ t ** 3 + 3.0 * $ t ** 5 ) / 16.0 + 0.5 ,
811803 'support ' => 1.0 ,
812804 ],
813- ' triweight ' => [
805+ KdeKernel::Triweight-> value => [
814806 'pdf ' => static fn (float $ t ): float => (35.0 / 32.0 ) * (1.0 - $ t * $ t ) ** 3 ,
815807 'cdf ' => static fn (float $ t ): float => (35.0 * $ t - 35.0 * $ t ** 3 + 21.0 * $ t ** 5 - 5.0 * $ t ** 7 ) / 32.0 + 0.5 ,
816808 'support ' => 1.0 ,
817809 ],
818- ' cosine ' => [
810+ KdeKernel::Cosine-> value => [
819811 'pdf ' => static fn (float $ t ): float => (M_PI / 4.0 ) * cos (M_PI * $ t / 2.0 ),
820812 'cdf ' => static fn (float $ t ): float => 0.5 * sin (M_PI * $ t / 2.0 ) + 0.5 ,
821813 'support ' => 1.0 ,
822814 ],
823815 ];
824816
825- if (! isset ($ kernels [$ kernel ])) {
826- $ valid = implode (', ' , array_merge (array_keys ($ kernels ), array_keys ($ aliases )));
827- throw new InvalidDataInputException (
828- "Unknown kernel ' {$ kernel }'. Valid kernels: {$ valid }. " ,
829- );
830- }
831-
832- $ kernelDef = $ kernels [$ kernel ];
817+ $ kernelDef = $ kernels [$ kernel ->value ]; // @phpstan-ignore offsetAccess.notFound
833818 $ support = $ kernelDef ['support ' ];
834819 $ fn = $ cumulative ? $ kernelDef ['cdf ' ] : $ kernelDef ['pdf ' ];
835820
@@ -888,16 +873,16 @@ public static function kde(
888873 *
889874 * @param array<int|float> $data sample data
890875 * @param float $h bandwidth (smoothing parameter), must be > 0
891- * @param string $kernel kernel name or alias
876+ * @param KdeKernel $kernel kernel to use for estimation
892877 * @param int|null $seed optional seed for reproducibility
893878 * @return \Closure a callable that returns a random float from the KDE
894879 *
895- * @throws InvalidDataInputException if data is empty, bandwidth <= 0, or kernel is invalid
880+ * @throws InvalidDataInputException if data is empty or bandwidth <= 0
896881 */
897882 public static function kdeRandom (
898883 array $ data ,
899884 float $ h ,
900- string $ kernel = ' normal ' ,
885+ KdeKernel $ kernel = KdeKernel::Normal ,
901886 ?int $ seed = null ,
902887 ): \Closure {
903888 if ($ data === []) {
@@ -907,16 +892,7 @@ public static function kdeRandom(
907892 throw new InvalidDataInputException ("Bandwidth h must be positive. " );
908893 }
909894
910- $ aliases = [
911- 'gauss ' => 'normal ' ,
912- 'uniform ' => 'rectangular ' ,
913- 'epanechnikov ' => 'parabolic ' ,
914- 'biweight ' => 'quartic ' ,
915- ];
916- $ kernel = strtolower ($ kernel );
917- if (isset ($ aliases [$ kernel ])) {
918- $ kernel = $ aliases [$ kernel ];
919- }
895+ $ kernel = $ kernel ->resolve ();
920896
921897 // Acklam rational approximation for standard normal inverse CDF
922898 $ normalInvCdf = static function (float $ p ): float {
@@ -995,15 +971,15 @@ public static function kdeRandom(
995971 => ($ t < -1.0 || $ t > 1.0 ) ? 0.0 : (35.0 / 32.0 ) * (1.0 - $ t * $ t ) ** 3 ;
996972
997973 $ invcdfMap = [
998- ' normal ' => $ normalInvCdf ,
999- ' logistic ' => static fn (float $ p ): float => log ($ p / (1.0 - $ p )),
1000- ' sigmoid ' => static fn (float $ p ): float => log (tan ($ p * M_PI / 2.0 )),
1001- ' rectangular ' => static fn (float $ p ): float => 2.0 * $ p - 1.0 ,
1002- ' triangular ' => static fn (float $ p ): float
974+ KdeKernel::Normal-> value => $ normalInvCdf ,
975+ KdeKernel::Logistic-> value => static fn (float $ p ): float => log ($ p / (1.0 - $ p )),
976+ KdeKernel::Sigmoid-> value => static fn (float $ p ): float => log (tan ($ p * M_PI / 2.0 )),
977+ KdeKernel::Rectangular-> value => static fn (float $ p ): float => 2.0 * $ p - 1.0 ,
978+ KdeKernel::Triangular-> value => static fn (float $ p ): float
1003979 => $ p < 0.5 ? sqrt (2.0 * $ p ) - 1.0 : 1.0 - sqrt (2.0 - 2.0 * $ p ),
1004- ' parabolic ' => static fn (float $ p ): float
980+ KdeKernel::Parabolic-> value => static fn (float $ p ): float
1005981 => 2.0 * cos ((acos (2.0 * $ p - 1.0 ) + M_PI ) / 3.0 ),
1006- ' quartic ' => static function (float $ p ) use ($ newtonRaphson , $ quarticCdf , $ quarticPdf ): float {
982+ KdeKernel::Quartic-> value => static function (float $ p ) use ($ newtonRaphson , $ quarticCdf , $ quarticPdf ): float {
1007983 if ($ p <= 0.5 ) {
1008984 $ sign = 1.0 ;
1009985 } else {
@@ -1021,7 +997,7 @@ public static function kdeRandom(
1021997 $ x *= $ sign ;
1022998 return $ newtonRaphson ($ sign === 1.0 ? $ p : 1.0 - $ p , $ quarticCdf , $ quarticPdf , $ x );
1023999 },
1024- ' triweight ' => static function (float $ p ) use ($ newtonRaphson , $ triweightCdf , $ triweightPdf ): float {
1000+ KdeKernel::Triweight-> value => static function (float $ p ) use ($ newtonRaphson , $ triweightCdf , $ triweightPdf ): float {
10251001 if ($ p <= 0.5 ) {
10261002 $ sign = 1.0 ;
10271003 } else {
@@ -1035,17 +1011,10 @@ public static function kdeRandom(
10351011 $ x *= $ sign ;
10361012 return $ newtonRaphson ($ sign === 1.0 ? $ p : 1.0 - $ p , $ triweightCdf , $ triweightPdf , $ x );
10371013 },
1038- ' cosine ' => static fn (float $ p ): float => (2.0 / M_PI ) * asin (2.0 * $ p - 1.0 ),
1014+ KdeKernel::Cosine-> value => static fn (float $ p ): float => (2.0 / M_PI ) * asin (2.0 * $ p - 1.0 ),
10391015 ];
10401016
1041- if (! isset ($ invcdfMap [$ kernel ])) {
1042- $ valid = implode (', ' , array_merge (array_keys ($ invcdfMap ), array_keys ($ aliases )));
1043- throw new InvalidDataInputException (
1044- "Unknown kernel ' {$ kernel }'. Valid kernels: {$ valid }. " ,
1045- );
1046- }
1047-
1048- $ invcdf = $ invcdfMap [$ kernel ];
1017+ $ invcdf = $ invcdfMap [$ kernel ->value ]; // @phpstan-ignore offsetAccess.notFound
10491018 $ n = count ($ data );
10501019
10511020 if ($ seed !== null ) {
0 commit comments