Skip to content

Commit 3c1e3bc

Browse files
committed
fixes #759
1 parent 24e833a commit 3c1e3bc

3 files changed

Lines changed: 78 additions & 31 deletions

File tree

fastcore/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@
556556
'fastcore.script.bool_arg': ('script.html#bool_arg', 'fastcore/script.py'),
557557
'fastcore.script.call_parse': ('script.html#call_parse', 'fastcore/script.py'),
558558
'fastcore.script.clean_type_str': ('script.html#clean_type_str', 'fastcore/script.py'),
559+
'fastcore.script.set_ctx': ('script.html#set_ctx', 'fastcore/script.py'),
559560
'fastcore.script.store_false': ('script.html#store_false', 'fastcore/script.py'),
560561
'fastcore.script.store_true': ('script.html#store_true', 'fastcore/script.py')},
561562
'fastcore.shutil': {},

fastcore/script.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# %% auto #0
66
__all__ = ['SCRIPT_INFO', 'store_true', 'store_false', 'bool_arg', 'clean_type_str', 'Param', 'anno_parser', 'args_from_prog',
7-
'call_parse']
7+
'set_ctx', 'call_parse']
88

99
# %% ../nbs/06_script.ipynb #8a36db98
1010
import inspect,argparse,shutil,types
@@ -135,24 +135,39 @@ def args_from_prog(func, prog):
135135
# %% ../nbs/06_script.ipynb #f76b07f6
136136
SCRIPT_INFO = SimpleNamespace(func=None)
137137

138-
# %% ../nbs/06_script.ipynb #3c1b3a65
138+
# %% ../nbs/06_script.ipynb #42c8e85f
139+
from contextvars import ContextVar
140+
from contextlib import contextmanager
141+
142+
# %% ../nbs/06_script.ipynb #e4537112
143+
@contextmanager
144+
def set_ctx(cv, val=True):
145+
token = cv.set(val)
146+
try: yield
147+
finally: cv.reset(token)
148+
149+
# %% ../nbs/06_script.ipynb #fc816498
150+
_in_call_parse = ContextVar('_in_call_parse', default=False)
151+
139152
def call_parse(func=None, nested=False):
140153
"Decorator to create a simple CLI from `func` using `anno_parser`"
141154
if func is None: return partial(call_parse, nested=nested)
142155

143156
@wraps(func)
144157
def _f(*args, **kwargs):
145-
mod = inspect.getmodule(inspect.currentframe().f_back)
146-
if not mod: return func(*args, **kwargs)
147-
if not SCRIPT_INFO.func and mod.__name__=="__main__": SCRIPT_INFO.func = func.__name__
148-
if len(sys.argv)>1 and sys.argv[1]=='': sys.argv.pop(1)
149-
p = anno_parser(func)
150-
if nested: args, sys.argv[1:] = p.parse_known_args()
151-
else: args = p.parse_args()
152-
args = args.__dict__
153-
xtra = otherwise(args.pop('xtra', ''), eq(1), p.prog)
154-
tfunc = trace(func) if args.pop('pdb', False) else func
155-
return tfunc(**merge(args, args_from_prog(func, xtra)))
158+
if args or kwargs or _in_call_parse.get(): return func(*args, **kwargs)
159+
with set_ctx(_in_call_parse):
160+
mod = inspect.getmodule(inspect.currentframe().f_back)
161+
if not mod: return func(*args, **kwargs)
162+
if not SCRIPT_INFO.func and mod.__name__=="__main__": SCRIPT_INFO.func = func.__name__
163+
if len(sys.argv)>1 and sys.argv[1]=='': sys.argv.pop(1)
164+
p = anno_parser(func)
165+
if nested: args, sys.argv[1:] = p.parse_known_args()
166+
else: args = p.parse_args()
167+
args = args.__dict__
168+
xtra = otherwise(args.pop('xtra', ''), eq(1), p.prog)
169+
tfunc = trace(func) if args.pop('pdb', False) else func
170+
return tfunc(**merge(args, args_from_prog(func, xtra)))
156171

157172
mod = inspect.getmodule(inspect.currentframe().f_back)
158173
if getattr(mod, '__name__', '') =="__main__":

nbs/06_script.ipynb

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -503,16 +503,16 @@
503503
"name": "stdout",
504504
"output_type": "stream",
505505
"text": [
506-
"usage: progname [-h] [--a] [--v] [--b B] [--c {aa,bb,cc}] required\n",
506+
"usage: progname [-h] [--v] [--b B] [--c {aa,bb,cc}] required a\n",
507507
"\n",
508508
"my docs\n",
509509
"\n",
510510
"positional arguments:\n",
511511
" required Required param\n",
512+
" a param 1\n",
512513
"\n",
513514
"options:\n",
514515
" -h, --help show this help message and exit\n",
515-
" --a param 1 (default: False)\n",
516516
" --v Print version\n",
517517
" --b B param 2 (default: test)\n",
518518
" --c {aa,bb,cc} param 3 (default: aa)\n"
@@ -570,16 +570,16 @@
570570
"name": "stdout",
571571
"output_type": "stream",
572572
"text": [
573-
"usage: progname [-h] [--a] [--v] [--b B] [--c {aa,bb,cc}] required\n",
573+
"usage: progname [-h] [--v] [--b B] [--c {aa,bb,cc}] required a\n",
574574
"\n",
575575
"my docs\n",
576576
"\n",
577577
"positional arguments:\n",
578578
" required Required param\n",
579+
" a param 1\n",
579580
"\n",
580581
"options:\n",
581582
" -h, --help show this help message and exit\n",
582-
" --a param 1 (default: False)\n",
583583
" --v Print version\n",
584584
" --b B param 2 (default: test)\n",
585585
" --c {aa,bb,cc} param 3 (default: aa)\n"
@@ -609,16 +609,16 @@
609609
"name": "stdout",
610610
"output_type": "stream",
611611
"text": [
612-
"usage: progname [-h] [--a] [--b B] [--c {aa,bb,cc}] required\n",
612+
"usage: progname [-h] [--b B] [--c {aa,bb,cc}] required a\n",
613613
"\n",
614614
"my docs\n",
615615
"\n",
616616
"positional arguments:\n",
617617
" required Required param\n",
618+
" a param 1\n",
618619
"\n",
619620
"options:\n",
620621
" -h, --help show this help message and exit\n",
621-
" --a param 1 (default: False)\n",
622622
" --b B param 2 (default: test)\n",
623623
" --c {aa,bb,cc} param 3 (default: aa)\n"
624624
]
@@ -752,28 +752,59 @@
752752
{
753753
"cell_type": "code",
754754
"execution_count": null,
755-
"id": "3c1b3a65",
755+
"id": "42c8e85f",
756756
"metadata": {},
757757
"outputs": [],
758758
"source": [
759759
"#| export\n",
760+
"from contextvars import ContextVar\n",
761+
"from contextlib import contextmanager"
762+
]
763+
},
764+
{
765+
"cell_type": "code",
766+
"execution_count": null,
767+
"id": "e4537112",
768+
"metadata": {},
769+
"outputs": [],
770+
"source": [
771+
"#| export\n",
772+
"@contextmanager\n",
773+
"def set_ctx(cv, val=True):\n",
774+
" token = cv.set(val)\n",
775+
" try: yield\n",
776+
" finally: cv.reset(token)"
777+
]
778+
},
779+
{
780+
"cell_type": "code",
781+
"execution_count": null,
782+
"id": "fc816498",
783+
"metadata": {},
784+
"outputs": [],
785+
"source": [
786+
"#| export\n",
787+
"_in_call_parse = ContextVar('_in_call_parse', default=False)\n",
788+
"\n",
760789
"def call_parse(func=None, nested=False):\n",
761790
" \"Decorator to create a simple CLI from `func` using `anno_parser`\"\n",
762791
" if func is None: return partial(call_parse, nested=nested)\n",
763792
"\n",
764793
" @wraps(func)\n",
765794
" def _f(*args, **kwargs):\n",
766-
" mod = inspect.getmodule(inspect.currentframe().f_back)\n",
767-
" if not mod: return func(*args, **kwargs)\n",
768-
" if not SCRIPT_INFO.func and mod.__name__==\"__main__\": SCRIPT_INFO.func = func.__name__\n",
769-
" if len(sys.argv)>1 and sys.argv[1]=='': sys.argv.pop(1)\n",
770-
" p = anno_parser(func)\n",
771-
" if nested: args, sys.argv[1:] = p.parse_known_args()\n",
772-
" else: args = p.parse_args()\n",
773-
" args = args.__dict__\n",
774-
" xtra = otherwise(args.pop('xtra', ''), eq(1), p.prog)\n",
775-
" tfunc = trace(func) if args.pop('pdb', False) else func\n",
776-
" return tfunc(**merge(args, args_from_prog(func, xtra)))\n",
795+
" if args or kwargs or _in_call_parse.get(): return func(*args, **kwargs)\n",
796+
" with set_ctx(_in_call_parse):\n",
797+
" mod = inspect.getmodule(inspect.currentframe().f_back)\n",
798+
" if not mod: return func(*args, **kwargs)\n",
799+
" if not SCRIPT_INFO.func and mod.__name__==\"__main__\": SCRIPT_INFO.func = func.__name__\n",
800+
" if len(sys.argv)>1 and sys.argv[1]=='': sys.argv.pop(1)\n",
801+
" p = anno_parser(func)\n",
802+
" if nested: args, sys.argv[1:] = p.parse_known_args()\n",
803+
" else: args = p.parse_args()\n",
804+
" args = args.__dict__\n",
805+
" xtra = otherwise(args.pop('xtra', ''), eq(1), p.prog)\n",
806+
" tfunc = trace(func) if args.pop('pdb', False) else func\n",
807+
" return tfunc(**merge(args, args_from_prog(func, xtra)))\n",
777808
"\n",
778809
" mod = inspect.getmodule(inspect.currentframe().f_back)\n",
779810
" if getattr(mod, '__name__', '') ==\"__main__\":\n",

0 commit comments

Comments
 (0)