Skip to content

Commit 2e0bcee

Browse files
author
Colin Grambow
committed
Add prediction script
1 parent ae0063c commit 2e0bcee

3 files changed

Lines changed: 50 additions & 0 deletions

File tree

predict.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/usr/bin/env python
2+
3+
import reacdiff.parsing as parsing
4+
import reacdiff.train.predict as predict
5+
6+
7+
if __name__ == '__main__':
8+
args = parsing.parse_predict_args()
9+
predict.predict(args)

reacdiff/parsing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,23 @@ def parse_dataprep_args():
1111
return parser.parse_args()
1212

1313

14+
def parse_predict_args():
15+
parser = argparse.ArgumentParser(
16+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
17+
)
18+
parser.add_argument('--data_path', type=str, required=True,
19+
help='Path to data containing states for prediction task')
20+
parser.add_argument('--model', type=str, required=True,
21+
help='Path to trained model')
22+
parser.add_argument('--data_path2', type=str,
23+
help='Path to additional observable states for prediction')
24+
parser.add_argument('--save_path', type=str, default=os.path.join(os.getcwd(), 'preds.csv'),
25+
help='Path to save predictions to')
26+
parser.add_argument('--batch_size', type=int, default=32,
27+
help='Batch size')
28+
return parser.parse_args()
29+
30+
1431
def parse_train_args():
1532
parser = argparse.ArgumentParser(
1633
formatter_class=argparse.ArgumentDefaultsHelpFormatter

reacdiff/train/predict.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
3+
import keras
4+
5+
import reacdiff.data.data as datamod
6+
import reacdiff.utils as utils
7+
8+
9+
def predict(args):
10+
# Load data
11+
print('Loading data')
12+
data = datamod.Dataset(
13+
datamod.load_data(args.data_path),
14+
data2=None if args.data_path2 is None else datamod.load_data(args.data_path2)
15+
)
16+
17+
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
18+
19+
# Load model
20+
model = keras.models.load_model(args.model, custom_objects={'rmse': utils.rmse, 'mae': utils.mae})
21+
22+
# Predict
23+
preds = model.predict(data.get_data(), batch_size=args.batch_size, verbose=1)
24+
datamod.save_csv(preds, args.save_path)

0 commit comments

Comments
 (0)