Skip to content

Commit 47ce2ba

Browse files
committed
allow-iterables
1 parent 4627815 commit 47ce2ba

2 files changed

Lines changed: 122 additions & 10 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -838,22 +838,26 @@ def dummy_function(*args, **kwargs):
838838
except:
839839
iterable = None
840840

841-
loop_var_value = None
841+
sample = None
842842
if iterable is not None:
843843
try:
844-
if policy.can_call(iterable.__iter__):
845-
iterator = iter(iterable)
846-
loop_var_value = next(iterator)
847-
except:
848-
pass
844+
if policy.can_call(getattr(iterable, "__iter__", None)):
845+
sample = next(iter(iterable))
846+
except Exception:
847+
sample = None
849848

850849
loop_locals = context.transient_locals.copy()
851-
if isinstance(node.target, ast.Name):
852-
if loop_var_value is not None:
853-
loop_locals[node.target.id] = loop_var_value
854-
855850
loop_context = context.replace(transient_locals=loop_locals)
856851

852+
if sample is not None:
853+
try:
854+
fake_assign = ast.Assign(
855+
targets=[node.target], value=ast.Constant(value=sample)
856+
)
857+
_handle_assign(fake_assign, loop_context)
858+
except:
859+
pass
860+
857861
result = None
858862
for stmt in node.body:
859863
result = eval_node(stmt, loop_context)
@@ -1419,44 +1423,69 @@ def _list_methods(cls, source=None):
14191423

14201424

14211425
dict_keys: type[collections.abc.KeysView] = type({}.keys())
1426+
dict_values: type = type({}.values())
1427+
dict_items: type = type({}.items())
14221428

14231429
NUMERICS = {int, float, complex}
14241430

14251431
ALLOWED_CALLS = {
14261432
bytes,
14271433
*_list_methods(bytes),
1434+
bytes.__iter__,
14281435
dict,
14291436
*_list_methods(dict, dict_non_mutating_methods),
1437+
dict.__iter__,
1438+
dict.keys,
1439+
dict.values,
1440+
dict.items,
1441+
dict_keys.__iter__,
1442+
dict_values.__iter__,
1443+
dict_items.__iter__,
14301444
dict_keys.isdisjoint,
14311445
list,
14321446
*_list_methods(list, list_non_mutating_methods),
1447+
list.__iter__,
14331448
set,
14341449
*_list_methods(set, set_non_mutating_methods),
1450+
set.__iter__,
14351451
frozenset,
14361452
*_list_methods(frozenset),
1453+
frozenset.__iter__,
14371454
range,
1455+
range.__iter__,
14381456
str,
14391457
*_list_methods(str),
1458+
str.__iter__,
14401459
tuple,
14411460
*_list_methods(tuple),
1461+
tuple.__iter__,
14421462
bool,
14431463
*_list_methods(bool),
1464+
enumerate,
1465+
enumerate.__iter__,
14441466
*NUMERICS,
14451467
*[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
14461468
collections.deque,
14471469
*_list_methods(collections.deque, list_non_mutating_methods),
1470+
collections.deque.__iter__,
14481471
collections.defaultdict,
14491472
*_list_methods(collections.defaultdict, dict_non_mutating_methods),
1473+
collections.defaultdict.__iter__,
14501474
collections.OrderedDict,
14511475
*_list_methods(collections.OrderedDict, dict_non_mutating_methods),
1476+
collections.OrderedDict.__iter__,
14521477
collections.UserDict,
14531478
*_list_methods(collections.UserDict, dict_non_mutating_methods),
1479+
collections.UserDict.__iter__,
14541480
collections.UserList,
14551481
*_list_methods(collections.UserList, list_non_mutating_methods),
1482+
collections.UserList.__iter__,
14561483
collections.UserString,
14571484
*_list_methods(collections.UserString, dir(str)),
1485+
collections.UserString.__iter__,
14581486
collections.Counter,
14591487
*_list_methods(collections.Counter, dict_non_mutating_methods),
1488+
collections.Counter.__iter__,
14601489
collections.Counter.elements,
14611490
collections.Counter.most_common,
14621491
object.__dir__,

tests/test_completer.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,6 +2430,89 @@ def _(expected):
24302430
),
24312431
["append", "capitalize"],
24322432
],
2433+
[
2434+
"\n".join(
2435+
[
2436+
"for i in range(10):",
2437+
" i.",
2438+
]
2439+
),
2440+
"bit_length",
2441+
],
2442+
[
2443+
"\n".join(
2444+
[
2445+
"for i in range(10):",
2446+
" if i % 2 == 0:",
2447+
" i.",
2448+
]
2449+
),
2450+
"bit_length",
2451+
],
2452+
[
2453+
"\n".join(
2454+
[
2455+
"for item in ['a', 'b', 'c']:",
2456+
" item.",
2457+
]
2458+
),
2459+
"capitalize",
2460+
],
2461+
[
2462+
"\n".join(
2463+
[
2464+
"for key, value in {'a': 1, 'b': 2}.items():",
2465+
" key.",
2466+
]
2467+
),
2468+
"capitalize",
2469+
],
2470+
[
2471+
"\n".join(
2472+
[
2473+
"for key, value in {'a': 1, 'b': 2}.items():",
2474+
" value.",
2475+
]
2476+
),
2477+
"bit_length",
2478+
],
2479+
[
2480+
"\n".join(
2481+
[
2482+
"for sublist in [[1, 2], [3, 4]]:",
2483+
" sublist.",
2484+
]
2485+
),
2486+
"append",
2487+
],
2488+
[
2489+
"\n".join(
2490+
[
2491+
"for sublist in [[1, 2], [3, 4]]:",
2492+
" for item in sublist:",
2493+
" item.",
2494+
]
2495+
),
2496+
"bit_length",
2497+
],
2498+
[
2499+
"\n".join(
2500+
[
2501+
"for i, char in enumerate('hello'):",
2502+
" char.",
2503+
]
2504+
),
2505+
"capitalize",
2506+
],
2507+
[
2508+
"\n".join(
2509+
[
2510+
"for i, char in enumerate('hello'):",
2511+
" i.",
2512+
]
2513+
),
2514+
"bit_length",
2515+
],
24332516
],
24342517
)
24352518
def test_undefined_variables(use_jedi, evaluation, code, insert_text):

0 commit comments

Comments
 (0)