-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathLogLinearGLM.java
More file actions
149 lines (113 loc) · 4.29 KB
/
Copy pathLogLinearGLM.java
File metadata and controls
149 lines (113 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package mascot.parameterdynamics;
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.inference.StateNode;
import beast.base.inference.StateNodeInitialiser;
import beast.base.spec.inference.parameter.BoolVectorParam;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.RealScalarParam;
import beast.base.spec.inference.parameter.RealVectorParam;
import mascot.glmmodel.CovariateList;
import java.util.List;
/**
 * Piecewise-constant Ne dynamics whose per-interval values come from a log-linear GLM.
 * <p>
 * For interval {@code i} the value is
 * {@code clock * exp(sum_j [indicator_j] * scaler_j * covariate_j[i] + error[i] + constantError[0])},
 * where the two error terms are only added when the corresponding inputs are supplied.
 * The interval containing a time {@code t} is determined by the {@code rateShifts} parameter.
 */
public class LogLinearGLM extends NeDynamics implements StateNodeInitialiser {
    public Input<CovariateList> covariateListInput = new Input<>("covariateList", "input of covariates", Validate.REQUIRED);
    public Input<RealVectorParam<? extends Real>> scalerInput = new Input<>("scaler", "input of covariates scaler", Validate.REQUIRED);
    public Input<BoolVectorParam> indicatorInput = new Input<>("indicator", "input of covariate indicators", Validate.REQUIRED);
    public Input<RealScalarParam<Real>> clockInput = new Input<>("clock", "clock rate of the parameter",Validate.REQUIRED);
    public Input<RealVectorParam<? extends Real>> errorInput = new Input<>("error", "time variant error term in the GLM model for the rates");
    public Input<RealVectorParam<? extends Real>> constantErrorInput = new Input<>("constantError", "time invariant error term in the GLM model for the rates");
    final public Input<RealVectorParam<? extends Real>> rateShiftsInput = new Input<>("rateShifts","When to switch between elements of Ne", Input.Validate.REQUIRED);

    /** Times at which the dynamics switches to the next element of {@link #rates}. */
    RealVectorParam<? extends Real> rateShifts;
    /** Set by {@link #recalculate()}, cleared by {@link #isDirty()} and {@link #restore()}. */
    boolean valuesKnown = false;
    /** GLM value per interval; length equals the dimension of the first covariate. */
    double[] rates;

    /**
     * Sizes the scaler and indicator parameters to one entry per covariate, sizes the optional
     * time-variant error term to one entry per interval, and allocates the per-interval buffer.
     *
     * @throws IllegalArgumentException if the covariate list is empty
     */
    @Override
    public void initAndValidate() {
        CovariateList covariates = covariateListInput.get();
        if (covariates.size() == 0) {
            // the interval count is read from the first covariate below, so at least one is needed
            throw new IllegalArgumentException("LogLinearGLM: covariateList must contain at least one covariate");
        }

        // one scaler and one indicator per covariate
        scalerInput.get().setDimension(covariates.size());
        indicatorInput.get().setDimension(covariates.size());

        // the first covariate's dimension defines the number of intervals
        // NOTE(review): assumes all covariates share this dimension — not checked here
        int nrIntervals = covariates.get(0).getDimension();
        if (errorInput.get() != null)
            errorInput.get().setDimension(nrIntervals);

        isTime = true;
        rateShifts = rateShiftsInput.get();
        rates = new double[nrIntervals];
    }

    /**
     * Returns the GLM value for the interval containing time {@code t}.
     * Times at or after the last rate shift use {@code rates[rateShifts.size() - 1]}.
     *
     * @param t the time at which Ne is requested
     * @return the value of the interval containing {@code t}
     */
    @Override
    public double getNeTime(double t) {
        // NOTE(review): values are recomputed on every call; caching on valuesKnown was
        // disabled in the original code, presumably because isDirty() is not guaranteed
        // to run before this method — confirm before re-enabling.
        recalculate();
        int intervalNr = getIntervalNr(t);
        // past the last shift, keep using the final element indexed by rateShifts
        return rates[Math.min(intervalNr, rateShifts.size() - 1)];
    }

    /**
     * Returns the index of the first rate shift strictly greater than {@code t},
     * or {@code rateShifts.size()} if {@code t} is at or after every shift.
     */
    private int getIntervalNr(double t) {
        for (int i = 0; i < rateShifts.size(); i++)
            if (t < rateShifts.get(i))
                return i;
        // after the last interval; the caller clamps this to the last element
        return rateShifts.size();
    }

    /**
     * Recomputes every entry of {@link #rates} from the current covariates, scalers,
     * indicators, clock and (optional) error terms.
     */
    @Override
    public void recalculate() {
        // these do not change inside the loop, so look them up once
        CovariateList covariates = covariateListInput.get();
        RealVectorParam<? extends Real> scalers = scalerInput.get();
        BoolVectorParam indicators = indicatorInput.get();
        RealVectorParam<? extends Real> error = errorInput.get();
        double constantError = constantErrorInput.get() != null ? constantErrorInput.get().get(0) : 0.0;
        double clock = clockInput.get().get();

        for (int i = 0; i < rates.length; i++) {
            double logRate = 0;
            // only covariates switched on by their indicator contribute
            for (int j = 0; j < covariates.size(); j++) {
                if (indicators.get(j)) {
                    logRate += scalers.get(j) * covariates.get(j).getArrayValue(i);
                }
            }
            if (error != null)
                logRate += error.get(i);
            logRate += constantError;
            rates[i] = clock * Math.exp(logRate);
        }
        valuesKnown = true;
    }

    /**
     * Reports whether any input that affects Ne(t) has changed: the scalers, indicators,
     * optional error terms, the clock, or the rate-shift times.
     *
     * @return {@code true} if any of these inputs is dirty
     */
    @Override
    public boolean isDirty() {
        for (int i = 0; i < scalerInput.get().size(); i++)
            if (scalerInput.get().isDirty(i)) {
                valuesKnown = false;
                return true;
            }
        for (int i = 0; i < indicatorInput.get().size(); i++)
            if (indicatorInput.get().isDirty(i)) {
                valuesKnown = false;
                return true;
            }
        if (errorInput.get() != null)
            for (int i = 0; i < errorInput.get().size(); i++)
                if (errorInput.get().isDirty(i)) {
                    valuesKnown = false;
                    return true;
                }
        if (constantErrorInput.get() != null)
            for (int i = 0; i < constantErrorInput.get().size(); i++)
                if (constantErrorInput.get().isDirty(i)) {
                    valuesKnown = false;
                    return true;
                }
        if (clockInput.get().somethingIsDirty()) {
            valuesKnown = false;
            return true;
        }
        // getNeTime reads rateShifts directly, so moving a shift changes Ne(t)
        // even though the per-interval rates themselves are unaffected
        for (int i = 0; i < rateShifts.size(); i++)
            if (rateShifts.isDirty(i))
                return true;
        return false;
    }

    /** Invalidates the cached per-interval values after a rejected proposal. */
    @Override
    public void restore() {
        valuesKnown = false;
    }

    /** No-op: the state nodes are sized in {@link #initAndValidate()}. */
    @Override
    public void initStateNodes() {
    }

    /**
     * Registers the state nodes whose dimensions this object sets in
     * {@link #initAndValidate()}: the scalers, the indicators and, if present, the
     * time-variant error term.
     */
    @Override
    public void getInitialisedStateNodes(List<StateNode> stateNodes) {
        stateNodes.add(scalerInput.get());
        stateNodes.add(indicatorInput.get());
        if (errorInput.get() != null) stateNodes.add(errorInput.get());
    }
}