Skip to content

Commit a20040e

Browse files
authored
Use special method lookup in async for (#2025)
* Use special method lookup in async for * Remove dead code * Use nameof for StopAsyncIteration
1 parent 19d49c6 commit a20040e

File tree

8 files changed

+43
-46
lines changed

8 files changed

+43
-46
lines changed

src/core/IronPython/Compiler/Ast/AsyncForStatement.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
using System.Threading;
88

9-
using Microsoft.Scripting;
9+
using IronPython.Runtime.Binding;
10+
using IronPython.Runtime.Exceptions;
11+
1012
using MSAst = System.Linq.Expressions;
1113

1214
namespace IronPython.Compiler.Ast {
@@ -78,24 +80,22 @@ T SetScope<T>(T node) where T : Node {
7880
}
7981

8082
// _iter = ITER.__aiter__()
81-
var aiterAttr = SetScope(new MemberExpression(List, "__aiter__"));
82-
var aiterCall = SetScope(new CallExpression(aiterAttr, null, null));
83+
var aiterCall = SetScope(new UnaryExpression(PythonOperationKind.AIter, List));
8384
var assignIter = SetScope(new AssignmentStatement([SetScope(new NameExpression(iterName))], aiterCall));
8485

8586
// running = True
8687
var trueConst = SetScope(new ConstantExpression(true));
8788
var assignRunning = SetScope(new AssignmentStatement([SetScope(new NameExpression(runningName))], trueConst));
8889

8990
// TARGET = await __aiter.__anext__()
90-
var anextAttr = SetScope(new MemberExpression(SetScope(new NameExpression(iterName)), "__anext__"));
91-
var anextCall = SetScope(new CallExpression(anextAttr, null, null));
91+
var anextCall = SetScope(new UnaryExpression(PythonOperationKind.ANext, SetScope(new NameExpression(iterName))));
9292
var awaitNext = new AwaitExpression(anextCall);
9393
var assignTarget = SetScope(new AssignmentStatement([Left], awaitNext));
9494

9595
// except StopAsyncIteration: __running = False
9696
var falseConst = SetScope(new ConstantExpression(false));
9797
var stopRunning = SetScope(new AssignmentStatement([SetScope(new NameExpression(runningName))], falseConst));
98-
var handler = SetScope(new TryStatementHandler(SetScope(new NameExpression("StopAsyncIteration")), null!, SetScope(new SuiteStatement([stopRunning]))));
98+
var handler = SetScope(new TryStatementHandler(SetScope(new NameExpression(nameof(PythonExceptions.StopAsyncIteration))), null!, SetScope(new SuiteStatement([stopRunning]))));
9999
handler.HeaderIndex = span.End;
100100

101101
// try/except/else block

src/core/IronPython/Compiler/Ast/AsyncStatement.cs

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/core/IronPython/Compiler/Ast/PythonNameBinder.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,6 @@ public override bool Walk(AsyncForStatement node) {
351351
node.Parent = _currentScope;
352352
return base.Walk(node);
353353
}
354-
// AsyncStatement
355-
public override bool Walk(AsyncStatement node) {
356-
node.Parent = _currentScope;
357-
return base.Walk(node);
358-
}
359354
// AsyncWithStatement
360355
public override bool Walk(AsyncWithStatement node) {
361356
node.Parent = _currentScope;

src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,6 @@ public virtual void PostWalk(AssignmentStatement node) { }
144144
public virtual bool Walk(AsyncForStatement node) { return true; }
145145
public virtual void PostWalk(AsyncForStatement node) { }
146146

147-
// AsyncStatement
148-
public virtual bool Walk(AsyncStatement node) { return true; }
149-
public virtual void PostWalk(AsyncStatement node) { }
150-
151147
// AsyncWithStatement
152148
public virtual bool Walk(AsyncWithStatement node) { return true; }
153149
public virtual void PostWalk(AsyncWithStatement node) { }
@@ -415,10 +411,6 @@ public override void PostWalk(AssignmentStatement node) { }
415411
public override bool Walk(AsyncForStatement node) { return false; }
416412
public override void PostWalk(AsyncForStatement node) { }
417413

418-
// AsyncStatement
419-
public override bool Walk(AsyncStatement node) { return false; }
420-
public override void PostWalk(AsyncStatement node) { }
421-
422414
// AsyncWithStatement
423415
public override bool Walk(AsyncWithStatement node) { return false; }
424416
public override void PostWalk(AsyncWithStatement node) { }

src/core/IronPython/Compiler/Ast/UnaryExpression.cs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
33
// See the LICENSE file in the project root for more information.
44

5-
using MSAst = System.Linq.Expressions;
5+
#nullable enable
66

7-
using System;
87
using System.Diagnostics;
98

109
using IronPython.Runtime.Binding;
1110

12-
namespace IronPython.Compiler.Ast {
13-
using Ast = MSAst.Expression;
14-
using AstUtils = Microsoft.Scripting.Ast.Utils;
11+
using MSAst = System.Linq.Expressions;
1512

13+
namespace IronPython.Compiler.Ast {
1614
public class UnaryExpression : Expression {
1715
public UnaryExpression(PythonOperator op, Expression expression) {
1816
Operator = op;
17+
OperationKind = PythonOperatorToOperatorString(op);
18+
Expression = expression;
19+
EndIndex = expression.EndIndex;
20+
}
21+
22+
internal UnaryExpression(PythonOperationKind op, Expression expression) {
23+
OperationKind = op;
1924
Expression = expression;
2025
EndIndex = expression.EndIndex;
2126
}
@@ -24,13 +29,10 @@ public UnaryExpression(PythonOperator op, Expression expression) {
2429

2530
public PythonOperator Operator { get; }
2631

27-
public override MSAst.Expression Reduce() {
28-
return GlobalParent.Operation(
29-
typeof(object),
30-
PythonOperatorToOperatorString(Operator),
31-
Expression
32-
);
33-
}
32+
internal PythonOperationKind OperationKind { get; }
33+
34+
public override MSAst.Expression Reduce()
35+
=> GlobalParent.Operation(typeof(object), OperationKind, Expression);
3436

3537
public override void Walk(PythonWalker walker) {
3638
if (walker.Walk(this)) {

src/core/IronPython/Runtime/Binding/PythonOperationKind.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ internal enum PythonOperationKind {
106106
/// </summary>
107107
GetEnumeratorForIteration,
108108

109+
AIter,
110+
ANext,
111+
109112
///<summary>Operator for performing add</summary>
110113
Add,
111114
///<summary>Operator for performing sub</summary>

src/core/IronPython/Runtime/Binding/PythonProtocol.Operations.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ internal static partial class PythonProtocol {
190190
case PythonOperationKind.GetEnumeratorForIteration:
191191
res = MakeEnumeratorOperation(operation, args[0]);
192192
break;
193+
case PythonOperationKind.AIter:
194+
res = MakeUnaryOperation(operation, args[0], "__aiter__", TypeError(operation, "'async for' requires an object with __aiter__ method, got {0}", args));
195+
break;
196+
case PythonOperationKind.ANext:
197+
res = MakeUnaryOperation(operation, args[0], "__anext__", TypeError(operation, "'async for' received an invalid object from __aiter__: {0}", args));
198+
break;
193199
default:
194200
res = BindingHelpers.AddPythonBoxing(MakeBinaryOperation(operation, args, operation.Operation, null));
195201
break;

tests/suite/test_async.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,20 @@ async def test():
296296

297297
self.assertEqual(run_coro(test()), [110, 120, 210, 220])
298298

299+
def test_special_method_lookup(self):
300+
"""Ensure async for looks up __aiter__/__anext__ on the type, not the instance."""
301+
302+
a = AsyncIter([1, 2, 3])
303+
a.__aiter__ = lambda: AsyncIter([98]) # should be ignored
304+
a.__anext__ = lambda: 99 # should be ignored
305+
306+
async def test():
307+
result = []
308+
async for x in a:
309+
result.append(x)
310+
return result
311+
312+
self.assertEqual(run_coro(test()), [1, 2, 3])
299313

300314
class AsyncCombinedTest(unittest.TestCase):
301315
"""Tests combining async with and async for."""

0 commit comments

Comments
 (0)