Skip to content

Commit a46a5e3

Browse files
committed
Improved code compilation pattern, closes #472
1 parent 23ef1d6 commit a46a5e3

3 files changed

Lines changed: 24 additions & 8 deletions

File tree

docs/cli.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,12 +1359,9 @@ The following example adds a new ``score`` column, then updates it to list a ran
13591359
random.seed(10)
13601360

13611361
def convert(value):
1362-
global random
13631362
return random.random()
13641363
'
13651364

1366-
Note the ``global random`` line here. Due to the way the tool compiles Python code, this is necessary to ensure the ``random`` module is available within the ``convert()`` function. If you were to omit this you would see a ``NameError: name 'random' is not defined`` error.
1367-
13681365
.. _cli_convert_recipes:
13691366

13701367
sqlite-utils convert recipes

sqlite_utils/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,11 @@ def progressbar(*args, **kwargs):
432432

433433

434434
def _compile_code(code, imports, variable="value"):
435-
locals = {}
436435
globals = {"r": recipes, "recipes": recipes}
437436
# If user defined a convert() function, return that
438437
try:
439-
exec(code, globals, locals)
440-
return locals["convert"]
438+
exec(code, globals)
439+
return globals["convert"]
441440
except (AttributeError, SyntaxError, NameError, KeyError, TypeError):
442441
pass
443442

@@ -464,8 +463,8 @@ def _compile_code(code, imports, variable="value"):
464463

465464
for import_ in imports:
466465
globals[import_.split(".")[0]] = __import__(import_)
467-
exec(code_o, globals, locals)
468-
return locals["fn"]
466+
exec(code_o, globals)
467+
return globals["fn"]
469468

470469

471470
def chunks(sequence: Iterable, size: int) -> Iterable[Iterable]:

tests/test_cli_convert.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,23 @@ def test_convert_hyphen_workaround(fresh_db_and_path):
606606
assert list(db["names"].rows) == [
607607
{"id": 1, "name": "-"},
608608
]
609+
610+
611+
def test_convert_initialization_pattern(fresh_db_and_path):
612+
db, db_path = fresh_db_and_path
613+
db["names"].insert_all([{"id": 1, "name": "Cleo"}], pk="id")
614+
result = CliRunner().invoke(
615+
cli.cli,
616+
[
617+
"convert",
618+
db_path,
619+
"names",
620+
"name",
621+
"-",
622+
],
623+
input="import random\nrandom.seed(1)\ndef convert(value): return random.randint(0, 100)",
624+
)
625+
assert 0 == result.exit_code, result.output
626+
assert list(db["names"].rows) == [
627+
{"id": 1, "name": "17"},
628+
]

0 commit comments

Comments
 (0)