This repository was archived by the owner on Aug 11, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathexperiments.py
More file actions
205 lines (167 loc) · 9.28 KB
/
experiments.py
File metadata and controls
205 lines (167 loc) · 9.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import pydoc
import terminaltables
from paperspace import logger, constants, client, config
from paperspace.commands import CommandBase
from paperspace.workspace import S3WorkspaceHandler
from paperspace.logger import log_response
from paperspace.utils import get_terminal_lines
# Module-level API client pointed at the experiments host from the config.
# It is the default `api` argument of the functions and of
# ListExperimentsCommand defined further down in this module.
experiments_api = client.API(config.CONFIG_EXPERIMENTS_HOST, headers=client.default_headers)
class ExperimentCommand(CommandBase):
    """Base class for commands that create experiments.

    Holds the workspace handler used to upload a workspace and provides a
    shared helper for reporting the outcome of a "create" request.
    """

    def __init__(self, workspace_handler=None, **kwargs):
        """Store the workspace handler, building a default one if none is given.

        :param workspace_handler: object exposing ``upload_workspace``; when
            falsy, an ``S3WorkspaceHandler`` is created from ``self.api`` and
            ``self.logger`` (presumably set by ``CommandBase.__init__`` --
            confirm against that class).
        :param kwargs: forwarded unchanged to ``CommandBase.__init__``.
        """
        super(ExperimentCommand, self).__init__(**kwargs)
        if not workspace_handler:
            workspace_handler = S3WorkspaceHandler(experiments_api=self.api, logger=self.logger)
        self._workspace_handler = workspace_handler

    def _log_create_experiment(self, response, success_msg_template, error_msg):
        """Log the result of an experiment-creation request.

        :param response: HTTP response object exposing ``ok`` and ``json()``.
        :param success_msg_template: format string with one ``{}`` placeholder
            that receives the new experiment's handle.
        :param error_msg: fallback message used when a failed response does
            not carry a JSON body.
        """
        if not response.ok:
            try:
                self.logger.log_error_response(response.json())
            except ValueError:
                # The error body could not be decoded as JSON.
                self.logger.error(error_msg)
            return

        new_handle = response.json()["handle"]
        self.logger.log(success_msg_template.format(new_handle))
class CreateExperimentCommand(ExperimentCommand):
    """Command that creates a new experiment without starting it."""

    def execute(self, json_):
        """Upload the workspace (if any) and POST the experiment definition.

        :param json_: dict describing the experiment; it is mutated in place
            with a ``workspaceUrl`` entry when the workspace handler returns
            a truthy URL.
        """
        uploaded_url = self._workspace_handler.upload_workspace(json_)
        if uploaded_url:
            json_['workspaceUrl'] = uploaded_url

        self._log_create_experiment(
            self.api.post("/experiments/", json=json_),
            "New experiment created with handle: {}",
            "Unknown error while creating the experiment",
        )
class CreateAndStartExperimentCommand(ExperimentCommand):
    """Command that creates a new experiment and starts it in one request."""

    def execute(self, json_):
        """Upload the workspace (if any) and POST to the create-and-start endpoint.

        :param json_: dict describing the experiment; it is mutated in place
            with a ``workspaceUrl`` entry when the workspace handler returns
            a truthy URL.
        """
        uploaded_url = self._workspace_handler.upload_workspace(json_)
        if uploaded_url:
            json_['workspaceUrl'] = uploaded_url

        self._log_create_experiment(
            self.api.post("/experiments/create_and_start/", json=json_),
            "New experiment created and started with handle: {}",
            "Unknown error while creating/starting the experiment",
        )
def start_experiment(experiment_handle, api=experiments_api):
    """Send a PUT request to start the experiment and log the outcome.

    :param experiment_handle: handle interpolated into the request URL.
    :param api: client exposing ``put``; defaults to the module-level client.
    """
    result = api.put("/experiments/{}/start/".format(experiment_handle))
    log_response(
        result,
        "Experiment started",
        "Unknown error while starting the experiment",
    )
def stop_experiment(experiment_handle, api=experiments_api):
    """Send a PUT request to stop the experiment and log the outcome.

    :param experiment_handle: handle interpolated into the request URL.
    :param api: client exposing ``put``; defaults to the module-level client.
    """
    result = api.put("/experiments/{}/stop/".format(experiment_handle))
    log_response(
        result,
        "Experiment stopped",
        "Unknown error while stopping the experiment",
    )
def delete_experiment(experiment_handle, api=experiments_api):
    """Send a DELETE request for the experiment and log the outcome.

    :param experiment_handle: handle interpolated into the request URL.
    :param api: client exposing ``delete``; defaults to the module-level client.
    """
    result = api.delete("/experiments/{}".format(experiment_handle))
    log_response(
        result,
        "Experiment deleted",
        "Unknown error while deleting the experiment",
    )
class ListExperimentsCommand(object):
    """Fetch experiments from the API and print them as an ASCII table."""

    def __init__(self, api=experiments_api, logger_=logger):
        """Remember the API client and logger used by this command."""
        self.logger = logger_
        self.api = api

    def execute(self, project_handles=None):
        """Request the experiments list and log it.

        :param project_handles: optional iterable of project handles used to
            filter the list; a falsy value means "no filtering".
        """
        project_handles = project_handles or []
        query = self._get_query_params(project_handles)
        response = self.api.get("/experiments/", params=query)

        try:
            payload = response.json()
            if not response.ok:
                self.logger.log_error_response(payload)
                return
            found = self._get_experiments_list(payload, bool(project_handles))
        except (ValueError, KeyError) as err:
            self.logger.error("Error while parsing response data: {}".format(err))
            return

        self._log_experiments_list(found)

    @staticmethod
    def _get_query_params(project_handles):
        """Build the query-string dict for the list request.

        Each handle is sent under an indexed ``projectHandle[i]`` key.
        """
        # limit=-1 so the API sends back full list without pagination
        query = {"limit": -1}
        query.update(
            ("projectHandle[{}]".format(position), handle)
            for position, handle in enumerate(project_handles)
        )
        return query

    @staticmethod
    def _make_experiments_list_table(experiments):
        """Return an ASCII table string with name, handle and status per experiment."""
        rows = [("Name", "Handle", "Status")]
        rows.extend(
            (
                item["templateHistory"]["params"].get("name"),
                item["handle"],
                constants.ExperimentState.get_state_str(item["state"]),
            )
            for item in experiments
        )
        return terminaltables.AsciiTable(rows).table

    @staticmethod
    def _get_experiments_list(data, filtered=False):
        """Extract the flat list of experiments from the response payload.

        When filtering by projectHandle the response data has a different
        format: ``data["data"]`` holds per-project groups, each with its own
        ``"data"`` list, which are flattened here.
        """
        if filtered:
            return [
                experiment
                for project_group in data["data"]
                for experiment in project_group["data"]
            ]
        return data["data"]

    def _log_experiments_list(self, experiments):
        """Log the experiments table, paging it when it exceeds the terminal height."""
        if not experiments:
            self.logger.warning("No experiments found")
            return

        rendered = self._make_experiments_list_table(experiments)
        fits_on_screen = len(rendered.splitlines()) <= get_terminal_lines()
        if fits_on_screen:
            self.logger.log(rendered)
        else:
            pydoc.pager(rendered)
def _make_details_table(experiment):
    """Render one experiment's details as an ASCII table string.

    The rows shown depend on ``experiment["experimentTypeId"]``: one layout
    for single-node experiments and another for the GRPC/MPI multi-node ones.

    :param experiment: dict with ``experimentTypeId`` and a nested
        ``templateHistory -> params`` dict.
    :return: the table as a string.
    :raises ValueError: if the experiment type is not one of the handled types.
    :raises KeyError: if ``experimentTypeId``, ``templateHistory`` or
        ``params`` is missing.
    """
    type_id = experiment["experimentTypeId"]
    is_single_node = type_id == constants.ExperimentType.SINGLE_NODE
    if not is_single_node and type_id not in (constants.ExperimentType.GRPC_MULTI_NODE,
                                              constants.ExperimentType.MPI_MULTI_NODE):
        raise ValueError("Wrong experiment type: {}".format(type_id))

    params = experiment["templateHistory"]["params"]

    # The first three rows are the same for both layouts.
    leading_rows = (
        ("Name", params.get("name")),
        ("Handle", experiment.get("handle")),
        ("State", constants.ExperimentState.get_state_str(experiment.get("state"))),
    )

    if is_single_node:
        remaining_rows = (
            ("Ports", params.get("ports")),
            ("Project Handle", params.get("project_handle")),
            ("Worker Command", params.get("worker_command")),
            ("Worker Container", params.get("worker_container")),
            ("Worker Machine Type", params.get("worker_machine_type")),
            ("Working Directory", params.get("workingDirectory")),
            ("Workspace URL", params.get("workspaceUrl")),
            ("Model Type", params.get("modelType")),
            ("Model Path", params.get("modelPath")),
        )
    else:
        remaining_rows = (
            ("Artifact directory", params.get("artifactDirectory")),
            ("Cluster ID", params.get("clusterId")),
            ("Experiment Env", params.get("experimentEnv")),
            # The type shown here is read from the template params, not from
            # the top-level experimentTypeId used for the branch above.
            ("Experiment Type", constants.ExperimentType.get_type_str(params.get("experimentTypeId"))),
            ("Model Type", params.get("modelType")),
            ("Model Path", params.get("modelPath")),
            ("Parameter Server Command", params.get("parameter_server_command")),
            ("Parameter Server Container", params.get("parameter_server_container")),
            ("Parameter Server Count", params.get("parameter_server_count")),
            ("Parameter Server Machine Type", params.get("parameter_server_machine_type")),
            ("Ports", params.get("ports")),
            ("Project Handle", params.get("project_handle")),
            ("Worker Command", params.get("worker_command")),
            ("Worker Container", params.get("worker_container")),
            ("Worker Count", params.get("worker_count")),
            ("Worker Machine Type", params.get("worker_machine_type")),
            ("Working Directory", params.get("workingDirectory")),
            ("Workspace URL", params.get("workspaceUrl")),
        )

    return terminaltables.AsciiTable(leading_rows + remaining_rows).table
def get_experiment_details(experiment_handle, api=experiments_api):
    """Fetch one experiment and log its details as a table.

    On a successful response the JSON payload's ``"data"`` entry is rendered
    with ``_make_details_table``. If that fails, the error (including its
    cause) is logged and the raw response content is passed on instead of the
    table. ``log_response`` is always called with the response, so a failed
    request is reported through it.

    :param experiment_handle: handle interpolated into the request URL.
    :param api: client exposing ``get``; defaults to the module-level client.
    """
    url = "/experiments/{}/".format(experiment_handle)
    response = api.get(url)

    # Fall back to the raw body when no table can be built.
    details = response.content
    if response.ok:
        try:
            experiment = response.json()["data"]
            details = _make_details_table(experiment)
        except (ValueError, KeyError) as e:
            # Include the cause: _make_details_table raises ValueError for an
            # unsupported experiment type, and that message would otherwise be
            # lost. Same message shape as ListExperimentsCommand.execute.
            logger.error("Error parsing response data: {}".format(e))

    log_response(response, details, "Unknown error while retrieving details of the experiment")