-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathregression.clj
More file actions
112 lines (91 loc) · 3.29 KB
/
Copy pathregression.clj
File metadata and controls
112 lines (91 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
(ns examples.regression
(:require
[zero-one.geni.core :as g]
[zero-one.geni.ml :as ml]))
;; ---------------------------------------------------------------------------
;; Linear Regression
;; ---------------------------------------------------------------------------

;; Dataset read from the sample LIBSVM file shipped under test/resources.
(def training
  (g/read-libsvm! "test/resources/sample_libsvm_data.txt"))

;; Unfitted linear-regression estimator; the option map is handed to
;; ml/linear-regression unchanged.
(def lr
  (ml/linear-regression
    {:elastic-net-param 0.8
     :reg-param         0.8
     :max-iter          10}))

;; Estimator fitted on `training`.
(def lr-model
  (ml/fit training lr))

;; Run the fitted model over the same dataset it was fitted on and display
;; the :label and :prediction columns of five rows.
(g/show
  (g/limit
    (g/select (ml/transform training lr-model) :label :prediction)
    5))
;;=>
;; +-----+----------+
;; |label|prediction|
;; +-----+----------+
;; |0.0  |0.57      |
;; |1.0  |0.57      |
;; |1.0  |0.57      |
;; |1.0  |0.57      |
;; |1.0  |0.57      |
;; +-----+----------+

;; First three coefficients of the fitted model.
(->> lr-model
     ml/coefficients
     (take 3))
;;=> (0.0 0.0 0.0)

;; Intercept of the fitted model; the recorded value equals the prediction
;; shown for every row in the table above.
(ml/intercept lr-model)
;;=> 0.57
;; ---------------------------------------------------------------------------
;; Random Forest Regression
;; ---------------------------------------------------------------------------

;; Dataset read from the sample LIBSVM file shipped under test/resources.
(def data
  (g/read-libsvm! "test/resources/sample_libsvm_data.txt"))

;; Vector indexer fitted on the full dataset: reads :features and writes
;; :indexed-features, configured with :max-categories 4.
(def feature-indexer
  (let [indexer (ml/vector-indexer
                  {:max-categories 4
                   :input-col      :features
                   :output-col     :indexed-features})]
    (ml/fit data indexer)))

;; Split of `data` with weights 0.7 / 0.3.
;; NOTE(review): no seed is passed to g/random-split, so the outputs recorded
;; further down presumably differ from run to run -- confirm before relying
;; on the exact numbers.
(def split-data
  (g/random-split data [0.7 0.3]))

;; First part of the split is used for fitting, second part for evaluation.
(def train-data
  (first split-data))

(def test-data
  (second split-data))

;; Two-stage pipeline: the fitted indexer followed by a random-forest
;; regressor that consumes the indexed features.
(def pipeline
  (let [forest (ml/random-forest-regressor
                 {:features-col :indexed-features
                  :label-col    :label})]
    (ml/pipeline feature-indexer forest)))

;; Pipeline fitted on the training part of the split.
(def model
  (ml/fit train-data pipeline))

;; Result of applying the fitted pipeline to the held-out part of the split.
(def predictions
  (ml/transform test-data model))

;; Regression evaluator comparing :prediction with :label, metric "rmse".
(def evaluator
  (ml/regression-evaluator
    {:metric-name    "rmse"
     :prediction-col :prediction
     :label-col      :label}))

;; Display five rows of prediction next to label.
(g/show
  (g/select predictions :prediction :label)
  {:num-rows 5})
;;=>
;; +----------+-----+
;; |prediction|label|
;; +----------+-----+
;; |0.0       |0.0  |
;; |0.0       |0.0  |
;; |0.0       |0.0  |
;; |0.0       |0.0  |
;; |0.25      |0.0  |
;; +----------+-----+
;; only showing top 5 rows

;; Evaluate the predictions and print the resulting metric.
(let [rmse (ml/evaluate predictions evaluator)]
  (println "RMSE:" rmse))
;;=> RMSE: 0.173685539601507
;; ---------------------------------------------------------------------------
;; Survival Regression
;; ---------------------------------------------------------------------------

;; Five-row in-memory dataset with columns :label, :censor and :features,
;; where each :features value is a two-element dense vector.
(def train
  (let [rows    [[1.218 1.0 (g/dense 1.560 -0.605)]
                 [2.949 0.0 (g/dense 0.346 2.158)]
                 [3.627 0.0 (g/dense 1.380 0.231)]
                 [0.273 1.0 (g/dense 0.520 1.151)]
                 [4.199 0.0 (g/dense 0.795 -0.226)]]
        columns [:label :censor :features]]
    (g/table->dataset rows columns)))

;; Probabilities passed to the estimator as :quantile-probabilities.
(def quantile-probabilities
  [0.3 0.6])

;; Unfitted AFT survival-regression estimator; quantile output is directed
;; to the :quantiles column.
(def aft
  (ml/aft-survival-regression
    {:quantiles-col          :quantiles
     :quantile-probabilities quantile-probabilities}))

;; Estimator fitted on `train`.
(def aft-model
  (ml/fit train aft))

;; Apply the fitted model to the same dataset it was fitted on and display
;; the result, which carries the added :prediction and :quantiles columns.
(g/show (ml/transform train aft-model))
;;=>
;; +-----+------+--------------+------------------+--------------------------------------+
;; |label|censor|features      |prediction        |quantiles                             |
;; +-----+------+--------------+------------------+--------------------------------------+
;; |1.218|1.0   |[1.56,-0.605] |5.71897948763501  |[1.1603238947151657,4.995456010274772]|
;; |2.949|0.0   |[0.346,2.158] |18.07652118149533 |[3.667545845471735,15.789611866277625]|
;; |3.627|0.0   |[1.38,0.231]  |7.381861804239101 |[1.4977061305190822,6.447962612338965]|
;; |0.273|1.0   |[0.52,1.151]  |13.577612501425284|[2.7547621481506823,11.8598722240697] |
;; |4.199|0.0   |[0.795,-0.226]|9.013097744073898 |[1.8286676321297806,7.87282650587843] |
;; +-----+------+--------------+------------------+--------------------------------------+