Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@
import org.apache.flink.state.rocksdb.EmbeddedRocksDBStateBackend;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.test.util.AbstractTestBaseJUnit4;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;

import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

import java.io.File;
import java.io.IOException;
Expand All @@ -54,26 +55,20 @@
import java.util.stream.Stream;

import static org.apache.flink.configuration.CheckpointingOptions.FS_SMALL_FILE_THRESHOLD;
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.isIn;
import static org.junit.Assert.assertThat;
import static org.assertj.core.api.Assertions.assertThat;

/** Test the savepoint deep copy. */
@RunWith(value = Parameterized.class)
public class SavepointDeepCopyTest extends AbstractTestBaseJUnit4 {
@ExtendWith(ParameterizedTestExtension.class)
class SavepointDeepCopyTest extends AbstractTestBase {

private static final MemorySize FILE_STATE_SIZE_THRESHOLD = new MemorySize(1);

private static final String TEXT = "The quick brown fox jumps over the lazy dog";
private static final String RANDOM_VALUE = RandomStringUtils.randomAlphanumeric(120);

private final StateBackend backend;
@Parameter public StateBackend backend;

public SavepointDeepCopyTest(StateBackend backend) throws Exception {
this.backend = backend;
}

@Parameterized.Parameters(name = "State Backend: {0}")
@Parameters(name = "State Backend: {0}")
public static Collection<StateBackend> data() {
return Arrays.asList(new HashMapStateBackend(), new EmbeddedRocksDBStateBackend());
}
Expand Down Expand Up @@ -132,8 +127,8 @@ public void readKey(String key, Context ctx, Collector<Tuple2<String, String>> o
*
* @throws Exception throw exceptions when anything goes wrong
*/
@Test
public void testSavepointDeepCopy() throws Exception {
@TestTemplate
void testSavepointDeepCopy() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);

Expand All @@ -156,9 +151,7 @@ public void testSavepointDeepCopy() throws Exception {

Set<String> stateFiles1 = getFileNamesInDirectory(Paths.get(savepointPath1));

Assert.assertTrue(
"Failed to bootstrap savepoint1 with additional state files",
stateFiles1.size() > 1);
assertThat(stateFiles1).hasSizeGreaterThan(1);

// create savepoint2 from savepoint1 created above
File savepointUrl2 = createAndRegisterTempFile(new AbstractID().toHexString());
Expand All @@ -175,14 +168,9 @@ public void testSavepointDeepCopy() throws Exception {

Set<String> stateFiles2 = getFileNamesInDirectory(Paths.get(savepointPath1));

Assert.assertTrue(
"Failed to create savepoint2 from savepoint1 with additional state files",
stateFiles2.size() > 1);
assertThat(stateFiles2).hasSizeGreaterThan(1);

assertThat(
"At least one state file in savepoint1 are not in savepoint2",
stateFiles1,
everyItem(isIn(stateFiles2)));
assertThat(stateFiles1).isSubsetOf(stateFiles2);

// Try to fromExistingSavepoint savepoint2 and read the state of "Operator1" (which has not
// been
Expand All @@ -197,10 +185,7 @@ public void testSavepointDeepCopy() throws Exception {
.size();

long expectedKeyNum = Arrays.stream(TEXT.split(" ")).distinct().count();
Assert.assertEquals(
"Unexpected number of keys in the state of Operator1",
expectedKeyNum,
actuallyKeyNum);
assertThat(actuallyKeyNum).isEqualTo(expectedKeyNum);
}

private static Set<String> getFileNamesInDirectory(Path path) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.apache.flink.api.connector.source.SplitEnumerator;
import org.apache.flink.api.connector.source.SplitEnumeratorContext;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.client.program.rest.RestClusterClient;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.core.io.InputStatus;
import org.apache.flink.runtime.jobgraph.JobGraph;
Expand All @@ -43,16 +43,17 @@
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
import org.apache.flink.test.util.AbstractTestBaseJUnit4;
import org.apache.flink.test.junit5.InjectClusterClient;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.source.AbstractTestSource;
import org.apache.flink.test.util.source.SingleSplitEnumerator;
import org.apache.flink.test.util.source.TestSourceReader;
import org.apache.flink.test.util.source.TestSplit;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;

import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.time.Duration;
Expand All @@ -65,9 +66,11 @@
import java.util.stream.Collectors;

import static org.apache.flink.state.api.utils.SavepointTestBase.waitForAllRunningOrSomeTerminal;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;

/** IT case for reading state. */
public abstract class SavepointReaderITTestBase extends AbstractTestBaseJUnit4 {
public abstract class SavepointReaderITTestBase extends AbstractTestBase {
static final String UID = "stateful-operator";

static final String LIST_NAME = "list";
Expand All @@ -82,6 +85,13 @@ public abstract class SavepointReaderITTestBase extends AbstractTestBaseJUnit4 {

private final MapStateDescriptor<Integer, String> broadcast;

private RestClusterClient<?> clusterClient;

@BeforeEach
void setClusterClient(@InjectClusterClient RestClusterClient<?> clusterClient) {
this.clusterClient = clusterClient;
}

SavepointReaderITTestBase(
ListStateDescriptor<Integer> list,
ListStateDescriptor<Integer> union,
Expand All @@ -93,7 +103,7 @@ public abstract class SavepointReaderITTestBase extends AbstractTestBaseJUnit4 {
}

@Test
public void testOperatorStateInputFormat() throws Exception {
void testOperatorStateInputFormat() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);

Expand Down Expand Up @@ -133,21 +143,15 @@ private void verifyListState(String path, StreamExecutionEnvironment env) throws
List<Integer> listResult = JobResultRetriever.collect(readListState(savepoint));
listResult.sort(Comparator.naturalOrder());

Assert.assertEquals(
"Unexpected elements read from list state",
SavepointSource.getElements(),
listResult);
assertThat(listResult).isEqualTo(SavepointSource.getElements());
}

private void verifyUnionState(String path, StreamExecutionEnvironment env) throws Exception {
SavepointReader savepoint = SavepointReader.read(env, path, new HashMapStateBackend());
List<Integer> unionResult = JobResultRetriever.collect(readUnionState(savepoint));
unionResult.sort(Comparator.naturalOrder());

Assert.assertEquals(
"Unexpected elements read from union state",
SavepointSource.getElements(),
unionResult);
assertThat(unionResult).isEqualTo(SavepointSource.getElements());
}

private void verifyBroadcastState(String path, StreamExecutionEnvironment env)
Expand All @@ -168,34 +172,29 @@ private void verifyBroadcastState(String path, StreamExecutionEnvironment env)
.sorted(Comparator.naturalOrder())
.collect(Collectors.toList());

Assert.assertEquals(
"Unexpected element in broadcast state keys",
SavepointSource.getElements(),
broadcastStateKeys);

Assert.assertEquals(
"Unexpected element in broadcast state values",
SavepointSource.getElements().stream()
.map(Object::toString)
.sorted()
.collect(Collectors.toList()),
broadcastStateValues);
assertThat(broadcastStateKeys).isEqualTo(SavepointSource.getElements());

assertThat(broadcastStateValues)
.isEqualTo(
SavepointSource.getElements().stream()
.map(Object::toString)
.sorted()
.collect(Collectors.toList()));
}

private String takeSavepoint(JobGraph jobGraph) throws Exception {
SavepointSource.initializeForTest();

ClusterClient<?> client = MINI_CLUSTER_RESOURCE.getClusterClient();
JobID jobId = jobGraph.getJobID();

Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5));

String dirPath = getTempDirPath(new AbstractID().toHexString());

try {
JobID jobID = client.submitJob(jobGraph).get();
JobID jobID = clusterClient.submitJob(jobGraph).get();

waitForAllRunningOrSomeTerminal(jobID, MINI_CLUSTER_RESOURCE);
waitForAllRunningOrSomeTerminal(jobID, clusterClient);
boolean finished = false;
while (deadline.hasTimeLeft()) {
if (SavepointSource.isFinished()) {
Expand All @@ -212,14 +211,14 @@ private String takeSavepoint(JobGraph jobGraph) throws Exception {
}

if (!finished) {
Assert.fail("Failed to initialize state within deadline");
fail("Failed to initialize state within deadline");
}

CompletableFuture<String> path =
client.triggerSavepoint(jobID, dirPath, SavepointFormatType.CANONICAL);
clusterClient.triggerSavepoint(jobID, dirPath, SavepointFormatType.CANONICAL);
return path.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
} finally {
client.cancel(jobId).get();
clusterClient.cancel(jobId).get();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
import org.apache.flink.util.Collector;

import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;
Expand All @@ -43,6 +42,8 @@
import java.util.Objects;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;

/** IT case for reading state. */
public abstract class SavepointReaderKeyedStateITCase<B extends StateBackend>
extends SavepointTestBase {
Expand All @@ -57,7 +58,7 @@ public abstract class SavepointReaderKeyedStateITCase<B extends StateBackend>
protected abstract Tuple2<Configuration, B> getStateBackendTuple();

@Test
public void testUserKeyedStateReader() throws Exception {
void testUserKeyedStateReader() throws Exception {
Tuple2<Configuration, B> backendTuple = getStateBackendTuple();
StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment(backendTuple.f0);
Expand All @@ -81,8 +82,7 @@ public void testUserKeyedStateReader() throws Exception {

Set<Pojo> expected = new HashSet<>(elements);

Assert.assertEquals(
"Unexpected results from keyed state", expected, new HashSet<>(results));
assertThat(new HashSet<>(results)).isEqualTo(expected);
}

private static class KeyedStatefulOperator extends KeyedProcessFunction<Integer, Pojo, Void> {
Expand Down
Loading