forked from PaddlePaddle/PaddleRec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstartup.py
More file actions
103 lines (83 loc) · 3.46 KB
/
startup.py
File metadata and controls
103 lines (83 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import warnings
import paddle.fluid as fluid
from paddlerec.core.utils import envs
# Public startup strategies exported by this module.
__all__ = [
    "StartupBase",
    "SingleStartup",
    "PSStartup",
    "CollectiveStartup",
]
class StartupBase(object):
    """Common base for the trainer's startup stage.

    Subclasses override ``startup``. This base supplies ``load``, which
    restores persistable variables from the directory configured under
    ``runner.<runner_name>.init_model_path`` (read via
    ``envs.get_global_env``).
    """

    def __init__(self, context):
        # The base startup keeps no state of its own.
        pass

    def startup(self, context):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def load(self, context, is_fleet=False, main_program=None):
        """Load persistable variables from the configured init model path.

        Args:
            context: trainer context dict. ``runner_name`` and ``exe`` are
                read; ``fleet`` is also read when ``is_fleet`` is true.
            is_fleet: when true, delegate to
                ``context["fleet"].load_persistables``; otherwise call
                ``fluid.io.load_persistables``.
            main_program: forwarded to ``fluid.io.load_persistables`` in the
                non-fleet path; unused when ``is_fleet`` is true.

        Returns:
            None. Nothing is loaded when the configured path is ``None`` or
            the empty string.
        """
        config_key = "runner." + context["runner_name"] + ".init_model_path"
        model_dir = envs.get_global_env(config_key, None)
        if model_dir is None or model_dir == "":
            # No warm-start directory configured; nothing to restore.
            return

        print("going to load ", model_dir)
        if not is_fleet:
            fluid.io.load_persistables(
                context["exe"], model_dir, main_program=main_program)
        else:
            context["fleet"].load_persistables(context["exe"], model_dir)
class SingleStartup(StartupBase):
    """Startup that initializes every model listed in ``context["phases"]``.

    For each phase, the model's startup program is run inside that model's
    scope, and ``load`` is then asked to warm-start the phase's main program.
    """

    def __init__(self, context):
        print("Running SingleStartup.")

    def startup(self, context):
        """Run each phase's startup program, then mark training as ready.

        Args:
            context: trainer context dict. Reads ``phases`` (each entry's
                ``name`` keys into ``context["model"]``), ``model`` and
                ``exe``; sets ``context["status"]`` to ``"train_pass"``.
        """
        for phase in context["phases"]:
            model_entry = context["model"][phase["name"]]
            with fluid.scope_guard(model_entry["scope"]):
                main_prog = model_entry["main_program"]
                init_prog = model_entry["startup_program"]
                with fluid.program_guard(main_prog, init_prog):
                    context["exe"].run(init_prog)
                    # The same init_model_path is loaded into every phase's
                    # main program (non-fleet path of StartupBase.load).
                    self.load(context, main_program=main_prog)
        context["status"] = "train_pass"
class PSStartup(StartupBase):
    """Startup that initializes only the first entry of ``context["env"]["phase"]``.

    Loading goes through the fleet path of ``load``. Presumably this is the
    parameter-server mode; confirm against the trainer that selects it.
    """

    def __init__(self, context):
        print("Running PSStartup.")

    def startup(self, context):
        """Run the first phase's startup program, then mark training as ready.

        Args:
            context: trainer context dict. Reads ``env["phase"][0]["name"]``,
                ``model``, ``exe`` (and ``fleet`` inside ``load``); sets
                ``context["status"]`` to ``"train_pass"``.
        """
        first_phase = context["env"]["phase"][0]
        model_entry = context["model"][first_phase["name"]]
        with fluid.scope_guard(model_entry["scope"]):
            main_prog = model_entry["main_program"]
            init_prog = model_entry["startup_program"]
            with fluid.program_guard(main_prog, init_prog):
                context["exe"].run(init_prog)
                self.load(context, True)  # positional arg is is_fleet
        context["status"] = "train_pass"
class CollectiveStartup(StartupBase):
    """Startup that initializes only the first entry of ``context["env"]["phase"]``.

    Unlike the other startups, this one reads the training program from the
    ``"default_main_program"`` key of the model entry. Presumably the
    collective network stage stores it there; confirm against that stage.
    Loading goes through the fleet path of ``load``.
    """

    def __init__(self, context):
        print("Running CollectiveStartup.")

    def startup(self, context):
        """Run the first phase's startup program, then mark training as ready.

        Args:
            context: trainer context dict. Reads ``env["phase"][0]["name"]``,
                ``model``, ``exe`` (and ``fleet`` inside ``load``); sets
                ``context["status"]`` to ``"train_pass"``.
        """
        first_phase = context["env"]["phase"][0]
        model_entry = context["model"][first_phase["name"]]
        with fluid.scope_guard(model_entry["scope"]):
            main_prog = model_entry["default_main_program"]
            init_prog = model_entry["startup_program"]
            with fluid.program_guard(main_prog, init_prog):
                context["exe"].run(init_prog)
                self.load(context, True)  # positional arg is is_fleet
        context["status"] = "train_pass"