File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -161,6 +161,16 @@ py_binary(
161161 ],
162162)
163163
164+ py_test (
165+ name = "train_ggnn_test" ,
166+ srcs = ["train_ggnn_test.py" ],
167+ data = ["//programl/test/data:reachability_dataflow_dataset" ],
168+ deps = [
169+ ":train_ggnn" ,
170+ "//third_party/py/labm8" ,
171+ ],
172+ )
173+
164174py_binary (
165175 name = "train_lstm" ,
166176 srcs = ["train_lstm.py" ],
Original file line number Diff line number Diff line change 1+ # Copyright 2019-2020 the ProGraML authors.
2+ #
3+ # Contact Chris Cummins <chrisc.101@gmail.com>.
4+ #
5+ # Licensed under the Apache License, Version 2.0 (the "License");
6+ # you may not use this file except in compliance with the License.
7+ # You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+ import subprocess
17+ import sys
18+
19+ from labm8 .py import bazelutil , test
20+
21+ TRAIN_GGNN = bazelutil .DataPath ("programl/programl/task/dataflow/train_ggnn" )
22+
23+
24+ REACHABILITY_DATAFLOW_DATASET = bazelutil .DataArchive (
25+ "programl/programl/test/data/reachability_dataflow_dataset.tar.bz2"
26+ )
27+
28+
29+ def test_reachability_end_to_end ():
30+ with REACHABILITY_DATAFLOW_DATASET as d :
31+ p = subprocess .Popen (
32+ [
33+ TRAIN_GGNN ,
34+ "--path" ,
35+ str (d ),
36+ "--analysis" ,
37+ "reachability" ,
38+ "--limit_max_data_flow_steps" ,
39+ "--layer_timesteps=10" ,
40+ str (10 ),
41+ "--val_graph_count" ,
42+ str (10 ),
43+ "--val_seed" ,
44+ str (0xCC ),
45+ "--train_graph_counts" ,
46+ "10,20" ,
47+ "--batch_size" ,
48+ str (8 ),
49+ ]
50+ )
51+ p .communicate ()
52+ if p .returncode :
53+ sys .exit (1 )
54+
55+
56+ if __name__ == "__main__" :
57+ test .Main ()
You can’t perform that action at this time.
0 commit comments