-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain_multi_model.py
More file actions
67 lines (60 loc) · 2.25 KB
/
Copy pathtrain_multi_model.py
File metadata and controls
67 lines (60 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Script to train an agent to operate in the market and trade between 3 different pairs.
Update the variables in configs/vars.py before running.
Example:
python train_multi_model.py
Lucas Draichi 2019
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import gym
import os
import ray
from datetime import date
from gym.spaces import Discrete, Box
from configs.functions import get_datasets
from configs.vars import *
from ray.tune import run_experiments, grid_search
from ray.tune.registry import register_env
from env.MultiModelEnvRank1 import TradingEnv
# https://github.com/ray-project/ray/blob/master/python/ray/rllib/train.py
# Author-specific fallback used when TUNE_RESULT_DIR is not set; kept so the
# original workstation behaves exactly as before.
DEFAULT_LOCAL_DIR = '/home/lucas/Documents/cryptocurrency_prediction/tensorboard'


def resolve_local_dir(default=DEFAULT_LOCAL_DIR):
    """Return the directory Ray Tune should write results and checkpoints to.

    The experiment spec previously hard-coded an absolute path that only
    exists on the author's machine. A non-empty ``TUNE_RESULT_DIR``
    environment variable now takes precedence; otherwise *default* is used.

    Args:
        default: Path returned when ``TUNE_RESULT_DIR`` is unset or empty.

    Returns:
        The value of ``TUNE_RESULT_DIR`` if it is non-empty, else *default*.
    """
    # `or` (rather than a .get() default) also treats an empty value as unset.
    return os.environ.get("TUNE_RESULT_DIR") or default


if __name__ == "__main__":
    # One dataset per traded pair, fetched with the shared settings from
    # configs/vars.py. Only the first value returned by get_datasets is used;
    # presumably a pandas DataFrame (named df*) — confirm in configs.functions.
    df1, _ = get_datasets(SYMBOL_1, TRADE_INSTRUMENT, HISTO, LIMIT)
    df2, _ = get_datasets(SYMBOL_2, TRADE_INSTRUMENT, HISTO, LIMIT)
    df3, _ = get_datasets(SYMBOL_3, TRADE_INSTRUMENT, HISTO, LIMIT)

    # Make the environment addressable by name in the experiment spec below.
    # The creator is handed a config object, which is expected to be the
    # "env_config" dict defined in the spec.
    register_env("MultiTradingEnv-v1", lambda config: TradingEnv(config))

    experiment_spec = {
        "agravai": {
            "run": "PPO",
            "env": "MultiTradingEnv-v1",
            "stop": {
                "timesteps_total": TIMESTEPS_TOTAL,  # e.g. 1e6 = 1M
            },
            "checkpoint_freq": CHECKPOINT_FREQUENCY,
            "checkpoint_at_end": True,
            # Set TUNE_RESULT_DIR to redirect checkpoints/results elsewhere.
            # Remove this entry to fall back to Ray's default (~/ray_results/).
            "local_dir": resolve_local_dir(),
            "restore": RESTORE_PATH,
            "config": {
                # "lr_schedule": grid_search(LEARNING_RATE_SCHEDULE),
                "lr": 7e-6,
                "num_workers": 3,  # parallelism
                'observation_filter': 'MeanStdFilter',
                'vf_share_layers': True,  # testing
                # Passed to TradingEnv via the creator registered above.
                "env_config": {
                    'df1': df1,
                    'df2': df2,
                    'df3': df3,
                    's1': SYMBOL_1,
                    's2': SYMBOL_2,
                    's3': SYMBOL_3,
                    'trade_instrument': TRADE_INSTRUMENT,
                    'render_title': '',
                    'histo': HISTO
                },
            }
        }
    }
    ray.init()
    run_experiments(experiments=experiment_spec)