-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
59 lines (54 loc) · 1.94 KB
/
train.py
File metadata and controls
59 lines (54 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Script to train an agent to operate into the market according to the pair
Example:
    python train.py \
        --algo PPO \
        --pair XRP/USDT

Lucas Draichi 2019
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import gym
import os
import ray
from datetime import date
from gym.spaces import Discrete, Box
from configs.functions import get_datasets
from configs.vars import *
from env.TradingEnvV1 import TradingEnv
from ray.tune import run_experiments, grid_search
from ray.tune.registry import register_env
def parse_pair(pair):
    """Split a ``FROM/TO`` trading pair into its two symbols.

    Args:
        pair: Pair string as given on the command line, e.g. ``"XRP/USDT"``.

    Returns:
        A ``(from_symbol, to_symbol)`` tuple holding the text on either side
        of the single ``/`` separator.

    Raises:
        ValueError: If ``pair`` does not contain exactly one ``/`` or if
            either side of it is empty.
    """
    from_symbol, separator, to_symbol = pair.partition('/')
    # A second '/' would end up in the right-hand part; reject it instead of
    # letting it leak into the symbol name.
    if not separator or not from_symbol or not to_symbol or '/' in to_symbol:
        raise ValueError(
            "invalid pair {!r}: expected FROM/TO, e.g. ETH/BTC".format(pair))
    return from_symbol, to_symbol


if __name__ == "__main__":
    # Script entry point: parse the CLI, load the dataset for the requested
    # pair, register the trading environment and launch the Ray Tune run.
    import argparse

    parser = argparse.ArgumentParser(description='\n train a reinforcement learning agent')
    parser.add_argument('--pair', type=str, required=True, help='The pair to be traded e.g.: ETH/BTC')
    parser.add_argument('--algo', type=str, required=True, help='Choose algorithm to train')
    args = parser.parse_args()

    try:
        from_symbol, to_symbol = parse_pair(args.pair)
    except ValueError as err:
        # Report a malformed --pair as a normal usage error instead of an
        # unpacking traceback.
        parser.error(str(err))

    # get_datasets returns two values; only the first is used for training.
    df, _ = get_datasets(from_symbol, to_symbol, HISTO, LIMIT)

    # Register the environment under the id referenced by "env" below.
    register_env("TradingEnv-v0", lambda config: TradingEnv(config))
    ray.init()
    run_experiments({
        # Experiment name: <FROM><TO>_<LIMIT>_<HISTO>_<today's date>.
        "{}_{}_{}_{}".format(from_symbol + to_symbol, LIMIT, HISTO, date.today()): {
            "run": args.algo,
            "env": "TradingEnv-v0",
            "stop": {
                "timesteps_total": TIMESTEPS_TOTAL,
            },
            "checkpoint_freq": CHECKPOINT_FREQUENCY,
            "checkpoint_at_end": True,
            "config": {
                # Grid-searched over the entries of LEARNING_RATE_SCHEDULE.
                "lr_schedule": grid_search(LEARNING_RATE_SCHEDULE),
                "num_workers": 3,  # parallelism
                'observation_filter': 'MeanStdFilter',
                'vf_share_layers': True,  # testing
                # "vf_clip_param": 10000000.0,
                # Passed to TradingEnv as its config via the registered
                # creator above.
                "env_config": {
                    'df': df,
                    'render_title': ''
                },
            }
        }
    })