Skip to content

Commit 201e3a7

Browse files
committed
Add testing utilities
1 parent 9c26e70 commit 201e3a7

1 file changed

Lines changed: 46 additions & 8 deletions

File tree

singlestoredb/tests/utils.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,22 @@
55

66
import os
77
import uuid
8+
from typing import Any
9+
from typing import Dict
810
from urllib.parse import urlparse
911

1012
import singlestoredb as s2
1113
from singlestoredb.connection import build_params
1214

1315

16+
def apply_template(content: str, vars: Dict[str, Any]) -> str:
17+
for k, v in vars.items():
18+
key = '{{%s}}' % k
19+
if key in content:
20+
content = content.replace(key, v)
21+
return content
22+
23+
1424
def load_sql(sql_file: str) -> str:
1525
"""
1626
Load a file containing SQL code.
@@ -41,7 +51,7 @@ def load_sql(sql_file: str) -> str:
4151
# If no database name was specified, use initializer URL if given.
4252
# HTTP can't change databases, so you can't initialize from HTTP
4353
# while also creating a database.
44-
args = {}
54+
args = {'local_infile': True}
4555
if not dbname and 'SINGLESTOREDB_INIT_DB_URL' in os.environ:
4656
args['host'] = os.environ['SINGLESTOREDB_INIT_DB_URL']
4757

@@ -58,6 +68,8 @@ def load_sql(sql_file: str) -> str:
5868

5969
dbexisted = bool(dbname)
6070

71+
template_vars = dict(DATABASE_NAME=dbname, TEST_PATH=os.path.dirname(sql_file))
72+
6173
# Always use the default driver since not all operations are
6274
# permitted in the HTTP API.
6375
with open(sql_file, 'r') as infile:
@@ -66,14 +78,16 @@ def load_sql(sql_file: str) -> str:
6678
if not dbname:
6779
dbname = 'TEST_{}'.format(uuid.uuid4()).replace('-', '_')
6880
cur.execute(f'CREATE DATABASE {dbname};')
69-
cur.execute(f'USE {dbname};')
7081

71-
# Execute lines in SQL.
72-
for cmd in infile.read().split(';\n'):
73-
cmd = cmd.strip()
74-
if cmd:
75-
cmd += ';'
76-
cur.execute(cmd)
82+
cur.execute(f'USE {dbname};')
83+
template_vars['DATABASE_NAME'] = dbname
84+
85+
# Execute lines in SQL.
86+
for cmd in infile.read().split(';\n'):
87+
cmd = apply_template(cmd.strip(), template_vars)
88+
if cmd:
89+
cmd += ';'
90+
cur.execute(cmd)
7791

7892
# Start HTTP server as needed.
7993
if http_port:
@@ -93,3 +107,27 @@ def drop_database(name: str) -> None:
93107
with s2.connect(**args) as conn:
94108
with conn.cursor() as cur:
95109
cur.execute(f'DROP DATABASE {name};')
110+
111+
112+
def create_user(name: str, password: str, dbname: str) -> None:
113+
"""Create a user for the test database."""
114+
if name:
115+
args = {}
116+
if 'SINGLESTOREDB_INIT_DB_URL' in os.environ:
117+
args['host'] = os.environ['SINGLESTOREDB_INIT_DB_URL']
118+
with s2.connect(**args) as conn:
119+
with conn.cursor() as cur:
120+
cur.execute(f'DROP USER IF EXISTS {name};')
121+
cur.execute(f'CREATE USER "{name}"@"%" IDENTIFIED BY "{password}"')
122+
cur.execute(f'GRANT ALL ON {dbname}.* to "{name}"@"%"')
123+
124+
125+
def drop_user(name: str) -> None:
126+
"""Drop a database with the given name."""
127+
if name:
128+
args = {}
129+
if 'SINGLESTOREDB_INIT_DB_URL' in os.environ:
130+
args['host'] = os.environ['SINGLESTOREDB_INIT_DB_URL']
131+
with s2.connect(**args) as conn:
132+
with conn.cursor() as cur:
133+
cur.execute(f'DROP USER IF EXISTS {name};')

0 commit comments

Comments
 (0)