Skip to content

Commit 93fc679

Browse files
committed
Use special method lookup in async for
1 parent 19d49c6 commit 93fc679

File tree

5 files changed

+41
-17
lines changed

5 files changed

+41
-17
lines changed

src/core/IronPython/Compiler/Ast/AsyncForStatement.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
using System.Threading;
88

9-
using Microsoft.Scripting;
9+
using IronPython.Runtime.Binding;
10+
1011
using MSAst = System.Linq.Expressions;
1112

1213
namespace IronPython.Compiler.Ast {
@@ -78,17 +79,15 @@ T SetScope<T>(T node) where T : Node {
7879
}
7980

8081
// _iter = ITER.__aiter__()
81-
var aiterAttr = SetScope(new MemberExpression(List, "__aiter__"));
82-
var aiterCall = SetScope(new CallExpression(aiterAttr, null, null));
82+
var aiterCall = SetScope(new UnaryExpression(PythonOperationKind.AIter, List));
8383
var assignIter = SetScope(new AssignmentStatement([SetScope(new NameExpression(iterName))], aiterCall));
8484

8585
// running = True
8686
var trueConst = SetScope(new ConstantExpression(true));
8787
var assignRunning = SetScope(new AssignmentStatement([SetScope(new NameExpression(runningName))], trueConst));
8888

8989
// TARGET = await __aiter.__anext__()
90-
var anextAttr = SetScope(new MemberExpression(SetScope(new NameExpression(iterName)), "__anext__"));
91-
var anextCall = SetScope(new CallExpression(anextAttr, null, null));
90+
var anextCall = SetScope(new UnaryExpression(PythonOperationKind.ANext, SetScope(new NameExpression(iterName))));
9291
var awaitNext = new AwaitExpression(anextCall);
9392
var assignTarget = SetScope(new AssignmentStatement([Left], awaitNext));
9493

src/core/IronPython/Compiler/Ast/UnaryExpression.cs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
33
// See the LICENSE file in the project root for more information.
44

5-
using MSAst = System.Linq.Expressions;
5+
#nullable enable
66

7-
using System;
87
using System.Diagnostics;
98

109
using IronPython.Runtime.Binding;
1110

12-
namespace IronPython.Compiler.Ast {
13-
using Ast = MSAst.Expression;
14-
using AstUtils = Microsoft.Scripting.Ast.Utils;
11+
using MSAst = System.Linq.Expressions;
1512

13+
namespace IronPython.Compiler.Ast {
1614
public class UnaryExpression : Expression {
1715
public UnaryExpression(PythonOperator op, Expression expression) {
1816
Operator = op;
17+
OperationKind = PythonOperatorToOperatorString(op);
18+
Expression = expression;
19+
EndIndex = expression.EndIndex;
20+
}
21+
22+
internal UnaryExpression(PythonOperationKind op, Expression expression) {
23+
OperationKind = op;
1924
Expression = expression;
2025
EndIndex = expression.EndIndex;
2126
}
@@ -24,13 +29,10 @@ public UnaryExpression(PythonOperator op, Expression expression) {
2429

2530
public PythonOperator Operator { get; }
2631

27-
public override MSAst.Expression Reduce() {
28-
return GlobalParent.Operation(
29-
typeof(object),
30-
PythonOperatorToOperatorString(Operator),
31-
Expression
32-
);
33-
}
32+
internal PythonOperationKind OperationKind { get; }
33+
34+
public override MSAst.Expression Reduce()
35+
=> GlobalParent.Operation(typeof(object), OperationKind, Expression);
3436

3537
public override void Walk(PythonWalker walker) {
3638
if (walker.Walk(this)) {

src/core/IronPython/Runtime/Binding/PythonOperationKind.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ internal enum PythonOperationKind {
106106
/// </summary>
107107
GetEnumeratorForIteration,
108108

109+
AIter,
110+
ANext,
111+
109112
///<summary>Operator for performing add</summary>
110113
Add,
111114
///<summary>Operator for performing sub</summary>

src/core/IronPython/Runtime/Binding/PythonProtocol.Operations.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ internal static partial class PythonProtocol {
190190
case PythonOperationKind.GetEnumeratorForIteration:
191191
res = MakeEnumeratorOperation(operation, args[0]);
192192
break;
193+
case PythonOperationKind.AIter:
194+
res = MakeUnaryOperation(operation, args[0], "__aiter__", TypeError(operation, "'async for' requires an object with __aiter__ method, got {0}", args));
195+
break;
196+
case PythonOperationKind.ANext:
197+
res = MakeUnaryOperation(operation, args[0], "__anext__", TypeError(operation, "'async for' received an invalid object from __aiter__: {0}", args));
198+
break;
193199
default:
194200
res = BindingHelpers.AddPythonBoxing(MakeBinaryOperation(operation, args, operation.Operation, null));
195201
break;

tests/suite/test_async.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,20 @@ async def test():
296296

297297
self.assertEqual(run_coro(test()), [110, 120, 210, 220])
298298

299+
def test_special_method_lookup(self):
300+
"""Ensure async for looks up __aiter__ on the type, not the instance."""
301+
302+
a = AsyncIter([1, 2, 3])
303+
a.__aiter__ = lambda: AsyncIter([98]) # should be ignored
304+
a.__anext__ = lambda: 99 # should be ignored
305+
306+
async def test():
307+
result = []
308+
async for x in a:
309+
result.append(x)
310+
return result
311+
312+
self.assertEqual(run_coro(test()), [1, 2, 3])
299313

300314
class AsyncCombinedTest(unittest.TestCase):
301315
"""Tests combining async with and async for."""

0 commit comments

Comments
 (0)