diff --git a/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs b/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs index 390bb0d3e..b32b891b1 100644 --- a/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs +++ b/src/core/IronPython/Compiler/Ast/AsyncForStatement.cs @@ -6,7 +6,9 @@ using System.Threading; -using Microsoft.Scripting; +using IronPython.Runtime.Binding; +using IronPython.Runtime.Exceptions; + using MSAst = System.Linq.Expressions; namespace IronPython.Compiler.Ast { @@ -78,8 +80,7 @@ T SetScope(T node) where T : Node { } // _iter = ITER.__aiter__() - var aiterAttr = SetScope(new MemberExpression(List, "__aiter__")); - var aiterCall = SetScope(new CallExpression(aiterAttr, null, null)); + var aiterCall = SetScope(new UnaryExpression(PythonOperationKind.AIter, List)); var assignIter = SetScope(new AssignmentStatement([SetScope(new NameExpression(iterName))], aiterCall)); // running = True @@ -87,15 +88,14 @@ T SetScope(T node) where T : Node { var assignRunning = SetScope(new AssignmentStatement([SetScope(new NameExpression(runningName))], trueConst)); // TARGET = await __aiter.__anext__() - var anextAttr = SetScope(new MemberExpression(SetScope(new NameExpression(iterName)), "__anext__")); - var anextCall = SetScope(new CallExpression(anextAttr, null, null)); + var anextCall = SetScope(new UnaryExpression(PythonOperationKind.ANext, SetScope(new NameExpression(iterName)))); var awaitNext = new AwaitExpression(anextCall); var assignTarget = SetScope(new AssignmentStatement([Left], awaitNext)); // except StopAsyncIteration: __running = False var falseConst = SetScope(new ConstantExpression(false)); var stopRunning = SetScope(new AssignmentStatement([SetScope(new NameExpression(runningName))], falseConst)); - var handler = SetScope(new TryStatementHandler(SetScope(new NameExpression("StopAsyncIteration")), null!, SetScope(new SuiteStatement([stopRunning])))); + var handler = SetScope(new TryStatementHandler(SetScope(new NameExpression(nameof(PythonExceptions.StopAsyncIteration))), null!, SetScope(new SuiteStatement([stopRunning])))); handler.HeaderIndex = span.End; // try/except/else block diff --git a/src/core/IronPython/Compiler/Ast/AsyncStatement.cs b/src/core/IronPython/Compiler/Ast/AsyncStatement.cs deleted file mode 100644 index cfae57eef..000000000 --- a/src/core/IronPython/Compiler/Ast/AsyncStatement.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the Apache 2.0 License. -// See the LICENSE file in the project root for more information. - -#nullable enable - -using System; - -namespace IronPython.Compiler.Ast { - public class AsyncStatement : Statement { - public override void Walk(PythonWalker walker) { - throw new NotImplementedException(); - } - } -} diff --git a/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs b/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs index 2110c69db..c2b118472 100644 --- a/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs +++ b/src/core/IronPython/Compiler/Ast/PythonNameBinder.cs @@ -351,11 +351,6 @@ public override bool Walk(AsyncForStatement node) { node.Parent = _currentScope; return base.Walk(node); } - // AsyncStatement - public override bool Walk(AsyncStatement node) { - node.Parent = _currentScope; - return base.Walk(node); - } // AsyncWithStatement public override bool Walk(AsyncWithStatement node) { node.Parent = _currentScope; diff --git a/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs b/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs index c97de6a5c..6175f6402 100644 --- a/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs +++ b/src/core/IronPython/Compiler/Ast/PythonWalker.Generated.cs @@ -144,10 +144,6 @@ public virtual void PostWalk(AssignmentStatement node) { } public virtual bool Walk(AsyncForStatement node) { return true; } public virtual void PostWalk(AsyncForStatement node) { } - // AsyncStatement - public virtual bool Walk(AsyncStatement node) { return true; } - public virtual void PostWalk(AsyncStatement node) { } - // AsyncWithStatement public virtual bool Walk(AsyncWithStatement node) { return true; } public virtual void PostWalk(AsyncWithStatement node) { } @@ -415,10 +411,6 @@ public override void PostWalk(AssignmentStatement node) { } public override bool Walk(AsyncForStatement node) { return false; } public override void PostWalk(AsyncForStatement node) { } - // AsyncStatement - public override bool Walk(AsyncStatement node) { return false; } - public override void PostWalk(AsyncStatement node) { } - // AsyncWithStatement public override bool Walk(AsyncWithStatement node) { return false; } public override void PostWalk(AsyncWithStatement node) { } diff --git a/src/core/IronPython/Compiler/Ast/UnaryExpression.cs b/src/core/IronPython/Compiler/Ast/UnaryExpression.cs index dbdc75af7..8d02c67a7 100644 --- a/src/core/IronPython/Compiler/Ast/UnaryExpression.cs +++ b/src/core/IronPython/Compiler/Ast/UnaryExpression.cs @@ -2,20 +2,25 @@ // The .NET Foundation licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information. -using MSAst = System.Linq.Expressions; +#nullable enable -using System; using System.Diagnostics; using IronPython.Runtime.Binding; -namespace IronPython.Compiler.Ast { - using Ast = MSAst.Expression; - using AstUtils = Microsoft.Scripting.Ast.Utils; +using MSAst = System.Linq.Expressions; +namespace IronPython.Compiler.Ast { public class UnaryExpression : Expression { public UnaryExpression(PythonOperator op, Expression expression) { Operator = op; + OperationKind = PythonOperatorToOperatorString(op); + Expression = expression; + EndIndex = expression.EndIndex; + } + + internal UnaryExpression(PythonOperationKind op, Expression expression) { + OperationKind = op; Expression = expression; EndIndex = expression.EndIndex; } @@ -24,13 +29,10 @@ public UnaryExpression(PythonOperator op, Expression expression) { public PythonOperator Operator { get; } - public override MSAst.Expression Reduce() { - return GlobalParent.Operation( - typeof(object), - PythonOperatorToOperatorString(Operator), - Expression - ); - } + internal PythonOperationKind OperationKind { get; } + + public override MSAst.Expression Reduce() + => GlobalParent.Operation(typeof(object), OperationKind, Expression); public override void Walk(PythonWalker walker) { if (walker.Walk(this)) { diff --git a/src/core/IronPython/Runtime/Binding/PythonOperationKind.cs b/src/core/IronPython/Runtime/Binding/PythonOperationKind.cs index 3f540daed..c654361eb 100644 --- a/src/core/IronPython/Runtime/Binding/PythonOperationKind.cs +++ b/src/core/IronPython/Runtime/Binding/PythonOperationKind.cs @@ -106,6 +106,9 @@ internal enum PythonOperationKind { /// GetEnumeratorForIteration, + AIter, + ANext, + ///Operator for performing add Add, ///Operator for performing sub diff --git a/src/core/IronPython/Runtime/Binding/PythonProtocol.Operations.cs b/src/core/IronPython/Runtime/Binding/PythonProtocol.Operations.cs index 61f916e14..5bdad3a51 100644 --- a/src/core/IronPython/Runtime/Binding/PythonProtocol.Operations.cs +++ b/src/core/IronPython/Runtime/Binding/PythonProtocol.Operations.cs @@ -190,6 +190,12 @@ internal static partial class PythonProtocol { case PythonOperationKind.GetEnumeratorForIteration: res = MakeEnumeratorOperation(operation, args[0]); break; + case PythonOperationKind.AIter: + res = MakeUnaryOperation(operation, args[0], "__aiter__", TypeError(operation, "'async for' requires an object with __aiter__ method, got {0}", args)); + break; + case PythonOperationKind.ANext: + res = MakeUnaryOperation(operation, args[0], "__anext__", TypeError(operation, "'async for' received an invalid object from __aiter__: {0}", args)); + break; default: res = BindingHelpers.AddPythonBoxing(MakeBinaryOperation(operation, args, operation.Operation, null)); break; diff --git a/tests/suite/test_async.py b/tests/suite/test_async.py index 59ae3500a..b6c24bd09 100644 --- a/tests/suite/test_async.py +++ b/tests/suite/test_async.py @@ -296,6 +296,20 @@ async def test(): self.assertEqual(run_coro(test()), [110, 120, 210, 220]) + def test_special_method_lookup(self): + """Ensure async for looks up __aiter__/__anext__ on the type, not the instance.""" + + a = AsyncIter([1, 2, 3]) + a.__aiter__ = lambda: AsyncIter([98]) # should be ignored + a.__anext__ = lambda: 99 # should be ignored + + async def test(): + result = [] + async for x in a: + result.append(x) + return result + + self.assertEqual(run_coro(test()), [1, 2, 3]) class AsyncCombinedTest(unittest.TestCase): """Tests combining async with and async for."""