Skip to content

Commit f96b414

Browse files
committed
[FLINK-39226][python] Fix embedded PyIterator class cast after recovery
1 parent c38b7e6 commit f96b414

9 files changed

Lines changed: 288 additions & 74 deletions

File tree

flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.apache.flink.util.Preconditions;
3232

3333
import com.google.protobuf.AbstractMessageLite;
34-
import pemja.core.object.PyIterator;
3534

3635
import java.util.List;
3736
import java.util.stream.Collectors;
@@ -151,18 +150,17 @@ public void processElement(StreamRecord<IN> element) throws Exception {
151150
timestamp = element.getTimestamp();
152151

153152
IN value = element.getValue();
154-
PyIterator results =
155-
(PyIterator)
153+
try (EmbeddedPythonIterator results =
154+
EmbeddedPythonIterator.from(
156155
interpreter.invokeMethod(
157156
"operation",
158157
"process_element",
159-
inputDataConverter.toExternal(value));
160-
161-
while (results.hasNext()) {
162-
OUT result = outputDataConverter.toInternal(results.next());
163-
collector.collect(result);
158+
inputDataConverter.toExternal(value)))) {
159+
while (results.hasNext()) {
160+
OUT result = outputDataConverter.toInternal(results.next());
161+
collector.collect(result);
162+
}
164163
}
165-
results.close();
166164
}
167165

168166
TypeInformation<IN> getInputTypeInfo() {

flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.apache.flink.util.Preconditions;
3232

3333
import com.google.protobuf.AbstractMessageLite;
34-
import pemja.core.object.PyIterator;
3534

3635
import java.util.List;
3736
import java.util.stream.Collectors;
@@ -178,18 +177,17 @@ public void processElement1(StreamRecord<IN1> element) throws Exception {
178177
timestamp = element.getTimestamp();
179178

180179
IN1 value = element.getValue();
181-
PyIterator results =
182-
(PyIterator)
180+
try (EmbeddedPythonIterator results =
181+
EmbeddedPythonIterator.from(
183182
interpreter.invokeMethod(
184183
"operation",
185184
"process_element1",
186-
inputDataConverter1.toExternal(value));
187-
188-
while (results.hasNext()) {
189-
OUT result = outputDataConverter.toInternal(results.next());
190-
collector.collect(result);
185+
inputDataConverter1.toExternal(value)))) {
186+
while (results.hasNext()) {
187+
OUT result = outputDataConverter.toInternal(results.next());
188+
collector.collect(result);
189+
}
191190
}
192-
results.close();
193191
}
194192

195193
@Override
@@ -198,18 +196,17 @@ public void processElement2(StreamRecord<IN2> element) throws Exception {
198196
timestamp = element.getTimestamp();
199197

200198
IN2 value = element.getValue();
201-
PyIterator results =
202-
(PyIterator)
199+
try (EmbeddedPythonIterator results =
200+
EmbeddedPythonIterator.from(
203201
interpreter.invokeMethod(
204202
"operation",
205203
"process_element2",
206-
inputDataConverter2.toExternal(value));
207-
208-
while (results.hasNext()) {
209-
OUT result = outputDataConverter.toInternal(results.next());
210-
collector.collect(result);
204+
inputDataConverter2.toExternal(value)))) {
205+
while (results.hasNext()) {
206+
OUT result = outputDataConverter.toInternal(results.next());
207+
collector.collect(result);
208+
}
211209
}
212-
results.close();
213210
}
214211

215212
TypeInformation<IN1> getInputTypeInfo1() {
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.streaming.api.operators.python.embedded;
20+
21+
import org.apache.flink.annotation.Internal;
22+
23+
import java.lang.reflect.Method;
24+
import java.util.Objects;
25+
26+
/**
27+
* Reflective adapter for embedded Python iterators.
28+
*
29+
* <p>PEMJA iterator objects may come back from a different user-code classloader after recovery, so
30+
* callers should not hard-cast them to {@code pemja.core.object.PyIterator}.
31+
*/
32+
@Internal
33+
public final class EmbeddedPythonIterator implements AutoCloseable {
34+
35+
private final Object iterator;
36+
private final Method hasNextMethod;
37+
private final Method nextMethod;
38+
private final Method closeMethod;
39+
40+
private EmbeddedPythonIterator(Object iterator) {
41+
this.iterator = Objects.requireNonNull(iterator, "iterator must not be null");
42+
43+
try {
44+
Class<?> iteratorClass = iterator.getClass();
45+
this.hasNextMethod = iteratorClass.getMethod("hasNext");
46+
this.nextMethod = iteratorClass.getMethod("next");
47+
this.closeMethod = iteratorClass.getMethod("close");
48+
} catch (ReflectiveOperationException e) {
49+
throw new IllegalStateException(
50+
String.format(
51+
"Failed to adapt embedded Python iterator of type %s.",
52+
iterator.getClass().getName()),
53+
e);
54+
}
55+
}
56+
57+
public static EmbeddedPythonIterator from(Object iterator) {
58+
return new EmbeddedPythonIterator(iterator);
59+
}
60+
61+
public boolean hasNext() throws Exception {
62+
return (boolean) hasNextMethod.invoke(iterator);
63+
}
64+
65+
public Object next() throws Exception {
66+
return nextMethod.invoke(iterator);
67+
}
68+
69+
@Override
70+
public void close() throws Exception {
71+
closeMethod.invoke(iterator);
72+
}
73+
}

flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedCoProcessOperator.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
import org.apache.flink.streaming.api.utils.PythonTypeUtils;
3737
import org.apache.flink.types.Row;
3838

39-
import pemja.core.object.PyIterator;
40-
4139
import java.util.List;
4240

4341
import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
@@ -149,15 +147,14 @@ private void invokeUserFunction(TimeDomain timeDomain, InternalTimer<K, VoidName
149147
onTimerContext.timeDomain = timeDomain;
150148
onTimerContext.timer = timer;
151149

152-
PyIterator results =
153-
(PyIterator)
154-
interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp());
155-
156-
while (results.hasNext()) {
157-
OUT result = outputDataConverter.toInternal(results.next());
158-
collector.collect(result);
150+
try (EmbeddedPythonIterator results =
151+
EmbeddedPythonIterator.from(
152+
interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp()))) {
153+
while (results.hasNext()) {
154+
OUT result = outputDataConverter.toInternal(results.next());
155+
collector.collect(result);
156+
}
159157
}
160-
results.close();
161158

162159
onTimerContext.timeDomain = null;
163160
onTimerContext.timer = null;

flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
import org.apache.flink.streaming.api.utils.PythonTypeUtils;
3737
import org.apache.flink.types.Row;
3838

39-
import pemja.core.object.PyIterator;
40-
4139
import java.util.List;
4240

4341
import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
@@ -143,15 +141,14 @@ private void invokeUserFunction(TimeDomain timeDomain, InternalTimer<K, VoidName
143141
throws Exception {
144142
onTimerContext.timeDomain = timeDomain;
145143
onTimerContext.timer = timer;
146-
PyIterator results =
147-
(PyIterator)
148-
interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp());
149-
150-
while (results.hasNext()) {
151-
OUT result = outputDataConverter.toInternal(results.next());
152-
collector.collect(result);
144+
try (EmbeddedPythonIterator results =
145+
EmbeddedPythonIterator.from(
146+
interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp()))) {
147+
while (results.hasNext()) {
148+
OUT result = outputDataConverter.toInternal(results.next());
149+
collector.collect(result);
150+
}
153151
}
154-
results.close();
155152

156153
onTimerContext.timeDomain = null;
157154
onTimerContext.timer = null;

flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonWindowOperator.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
import org.apache.flink.table.runtime.operators.window.Window;
3535
import org.apache.flink.types.Row;
3636

37-
import pemja.core.object.PyIterator;
38-
3937
import java.util.List;
4038

4139
import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
@@ -143,15 +141,14 @@ public <T> DataStreamPythonFunctionOperator<T> copy(
143141
private void invokeUserFunction(InternalTimer<K, W> timer) throws Exception {
144142
windowTimerContext.timer = timer;
145143

146-
PyIterator results =
147-
(PyIterator)
148-
interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp());
149-
150-
while (results.hasNext()) {
151-
OUT result = outputDataConverter.toInternal(results.next());
152-
collector.collect(result);
144+
try (EmbeddedPythonIterator results =
145+
EmbeddedPythonIterator.from(
146+
interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp()))) {
147+
while (results.hasNext()) {
148+
OUT result = outputDataConverter.toInternal(results.next());
149+
collector.collect(result);
150+
}
153151
}
154-
results.close();
155152

156153
windowTimerContext.timer = null;
157154
}

flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.flink.configuration.Configuration;
2323
import org.apache.flink.fnexecution.v1.FlinkFnApi;
2424
import org.apache.flink.python.util.ProtoUtils;
25+
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonIterator;
2526
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
2627
import org.apache.flink.table.data.GenericRowData;
2728
import org.apache.flink.table.data.RowData;
@@ -33,8 +34,6 @@
3334
import org.apache.flink.table.types.logical.RowType;
3435
import org.apache.flink.util.Preconditions;
3536

36-
import pemja.core.object.PyIterator;
37-
3837
import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
3938
import static org.apache.flink.python.PythonOptions.PYTHON_PROFILE_ENABLED;
4039
import static org.apache.flink.python.util.ProtoUtils.createFlattenRowTypeCoderInfoDescriptorProto;
@@ -147,25 +146,25 @@ public void processElement(StreamRecord<RowData> element) throws Exception {
147146
userDefinedFunctionInputConverters[i].toExternal(value, udfInputOffsets[i]);
148147
}
149148

150-
PyIterator udtfResults =
151-
(PyIterator)
149+
try (EmbeddedPythonIterator udtfResults =
150+
EmbeddedPythonIterator.from(
152151
interpreter.invokeMethod(
153152
"table_operation",
154153
"process_element",
155-
(Object) (userDefinedFunctionInputArgs));
156-
157-
if (udtfResults.hasNext()) {
158-
do {
159-
Object[] udtfResult = (Object[]) udtfResults.next();
160-
for (int i = 0; i < udtfResult.length; i++) {
161-
reuseResultRowData.setField(
162-
i, userDefinedFunctionOutputConverters[i].toInternal(udtfResult[i]));
163-
}
164-
rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseResultRowData));
165-
} while (udtfResults.hasNext());
166-
} else if (joinType == FlinkJoinType.LEFT) {
167-
rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseNullResultRowData));
154+
(Object) (userDefinedFunctionInputArgs)))) {
155+
if (udtfResults.hasNext()) {
156+
do {
157+
Object[] udtfResult = (Object[]) udtfResults.next();
158+
for (int i = 0; i < udtfResult.length; i++) {
159+
reuseResultRowData.setField(
160+
i,
161+
userDefinedFunctionOutputConverters[i].toInternal(udtfResult[i]));
162+
}
163+
rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseResultRowData));
164+
} while (udtfResults.hasNext());
165+
} else if (joinType == FlinkJoinType.LEFT) {
166+
rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseNullResultRowData));
167+
}
168168
}
169-
udtfResults.close();
170169
}
171170
}

0 commit comments

Comments
 (0)