-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.js
More file actions
29 lines (24 loc) · 789 Bytes
/
train_model.js
File metadata and controls
29 lines (24 loc) · 789 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
const tf = require('@tensorflow/tfjs');
const tf_cpu = require('@tensorflow/tfjs-node'); // run on CPU
const data = require('./data');
// Passed to data.getData(); presumably the fraction of samples held out as the
// validation/test set (0.30 -> 30%) — TODO confirm against ./data.
const TEST_SPLIT = 0.30;
// Number of training epochs; passed as `epochs` to model.fit().
const NUM_EPOCHS = 100;
/**
 * Loads the dataset and model from ./data, trains the model, then saves it.
 *
 * Per-epoch metrics are logged to the console. Once training has fully
 * finished, the model is written to `savePath` and the save result is logged.
 *
 * @param {object} [options]
 * @param {number} [options.testSplit=TEST_SPLIT] - Passed to `data.getData`;
 *   presumably the fraction of samples held out for validation (confirm in ./data).
 * @param {number} [options.epochs=NUM_EPOCHS] - Number of training epochs.
 * @param {string} [options.savePath='file://saved-model'] - Destination URL for `model.save`.
 * @returns {Promise<object>} The History object resolved by `model.fit`.
 */
async function run({
  testSplit = TEST_SPLIT,
  epochs = NUM_EPOCHS,
  savePath = 'file://saved-model',
} = {}) {
  const [xTrain, yTrain, xTest, yTest] = await data.getData(testSplit);
  const model = await data.getModel();

  // model.fit returns a Promise. Awaiting it keeps run() pending until training
  // completes and makes training failures reject run() instead of surfacing
  // as an unhandled rejection.
  const history = await model.fit(xTrain, yTrain, {
    epochs,
    validationData: [xTest, yTest],
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        // See the loss and accuracy values at the end of every training epoch.
        // `epoch` is zero-based, so print a one-based counter.
        console.log(`${epoch + 1}: `, logs);
      },
    },
  });

  // Save only after fit has resolved; any save error propagates out of run().
  const saveResults = await model.save(savePath);
  console.log(saveResults);

  return history;
}
// Entry point. run() is async: report any rejection on stderr and set a
// non-zero exit code rather than leaving an unhandled promise rejection.
run().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});