forked from apache/datafusion-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtpch.py
More file actions
98 lines (85 loc) · 3.4 KB
/
tpch.py
File metadata and controls
98 lines (85 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import time
from datafusion import SessionContext
def _record(results, label, millis) -> None:
    """Print one ``label,milliseconds`` line and append it to *results*."""
    line = f"{label},{round(millis, 1)}"
    print(line)
    results.write(line + "\n")
    # Flush after every line so partial results survive an aborted run.
    results.flush()


def _setup_statements(sql_path, data_path):
    """Yield each ``;``-terminated statement found in the file *sql_path*.

    Lines that start with ``--`` are skipped, the remaining lines are
    accumulated until the accumulated text (stripped) ends with ``;``, and
    every ``$PATH`` placeholder in the statement is replaced by *data_path*.
    Trailing text that is never terminated by ``;`` is not yielded.
    """
    with open(sql_path) as f:
        pending = ""
        for line in f:
            if line.startswith("--"):
                continue
            pending += line
            statement = pending.strip()
            if statement.endswith(";"):
                yield statement.replace("$PATH", data_path)
                pending = ""


def bench(data_path: str, query_path: str) -> None:
    """Run TPC-H queries 1-22 through DataFusion and record their timings.

    Table definitions are read from ``create_tables.sql`` in the current
    working directory, with ``$PATH`` replaced by ``data_path``. Queries are
    read from ``{query_path}/q1.sql`` through ``{query_path}/q22.sql``; each
    file is split on ``;`` and its statements are executed in order.

    One ``name,milliseconds`` line per step (``setup``, ``q1`` ... ``q22``,
    ``total``) is printed and written to ``results.csv`` in the current
    working directory. A query whose execution raises is reported on stdout
    and left out of both the CSV and the total; the run then continues with
    the next query.

    Args:
        data_path: Text substituted for ``$PATH`` in ``create_tables.sql``.
        query_path: Directory containing ``q1.sql`` ... ``q22.sql``.

    Raises:
        OSError: If ``results.csv``, ``create_tables.sql`` or a query file
            cannot be opened. Errors raised while executing the setup
            statements are not caught either.
    """
    with open("results.csv", "w") as results:
        total_time_millis = 0

        # Setup timing covers context creation and table registration.
        # perf_counter() is monotonic, so intervals are not distorted by
        # system clock adjustments the way time.time() differences can be.
        start = time.perf_counter()

        # NOTE: the default context is used. A tuned context can be built by
        # passing a SessionConfig and a RuntimeEnvBuilder to SessionContext.
        ctx = SessionContext()
        print("Configuration:\n", ctx)

        # register tables
        for statement in _setup_statements("create_tables.sql", data_path):
            ctx.sql(statement)

        time_millis = (time.perf_counter() - start) * 1000
        total_time_millis += time_millis
        _record(results, "setup", time_millis)

        # run queries
        for query in range(1, 23):
            # Read the whole file up front so the handle is not held open
            # while the queries execute.
            with open(f"{query_path}/q{query}.sql") as f:
                text = f.read()

            # A file may hold several statements. The split is purely
            # textual, so a ';' inside a string literal would also split.
            queries = [s for s in (part.strip() for part in text.split(";")) if s]

            try:
                start = time.perf_counter()
                for sql in queries:
                    print(sql)
                    # show() is the call that consumes the result here, so
                    # the measured time includes printing the output.
                    ctx.sql(sql).show()
                time_millis = (time.perf_counter() - start) * 1000
            except Exception as e:
                # Deliberate best-effort: one failing query must not abort
                # the rest of the benchmark run.
                print("query", query, "failed", e)
                continue

            total_time_millis += time_millis
            _record(results, f"q{query}", time_millis)

        _record(results, "total", total_time_millis)
if __name__ == "__main__":
    # Command-line entry point: both positional arguments are required and
    # are passed straight through to bench().
    parser = argparse.ArgumentParser(
        description=(
            "Run TPC-H queries 1-22 with DataFusion and write the timings "
            "to results.csv in the current directory."
        )
    )
    parser.add_argument(
        "data_path",
        help="value substituted for $PATH in create_tables.sql",
    )
    parser.add_argument(
        "query_path",
        help="directory containing q1.sql through q22.sql",
    )
    args = parser.parse_args()
    bench(args.data_path, args.query_path)