[FLINK-32097][Connectors/Kinesis] Implement support for Kinesis deaggregation #188

Open · wants to merge 7 commits into main
22 changes: 22 additions & 0 deletions flink-connector-aws/flink-connector-aws-kinesis-streams/pom.xml
@@ -33,6 +33,14 @@ under the License.
<name>Flink : Connectors : AWS : Amazon Kinesis Data Streams Connector v2</name>
<packaging>jar</packaging>

<repositories>
<!-- Used for the kinesis aggregator dependency, since it is not available in Maven Central. -->
<repository>
<id>jitpack.io</id>
<url>https://jitpack.io</url>
</repository>
</repositories>

<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
@@ -52,6 +60,11 @@ under the License.
<artifactId>kinesis</artifactId>
</dependency>

<dependency>
<groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client</artifactId>
</dependency>

<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>arns</artifactId>
@@ -102,6 +115,15 @@ under the License.
<scope>test</scope>
</dependency>

<dependency>
<!-- The kinesis aggregator dependency, since it is not available in Maven Central. -->
<!-- See https://github.com/awslabs/kinesis-aggregation/issues/120 -->
<groupId>com.github.awslabs.kinesis-aggregation</groupId>
<artifactId>amazon-kinesis-aggregator</artifactId>
Reviewer comment: This has the Apache 2.0 license, so all good:
https://github.com/awslabs/kinesis-aggregation

<version>2.0.3</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>nl.jqno.equalsverifier</groupId>
<artifactId>equalsverifier</artifactId>
@@ -70,9 +70,9 @@
import software.amazon.awssdk.services.kinesis.KinesisClient;
import software.amazon.awssdk.services.kinesis.model.DescribeStreamConsumerResponse;
import software.amazon.awssdk.services.kinesis.model.LimitExceededException;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.time.Duration;
import java.util.Map;
@@ -209,8 +209,10 @@ public SimpleVersionedSerializer<KinesisShardSplit> getSplitSerializer() {
return new KinesisStreamsSourceEnumeratorStateSerializer(new KinesisShardSplitSerializer());
}

private Supplier<SplitReader<Record, KinesisShardSplit>> getKinesisShardSplitReaderSupplier(
Configuration sourceConfig, Map<String, KinesisShardMetrics> shardMetricGroupMap) {
private Supplier<SplitReader<KinesisClientRecord, KinesisShardSplit>>
getKinesisShardSplitReaderSupplier(
Configuration sourceConfig,
Map<String, KinesisShardMetrics> shardMetricGroupMap) {
KinesisSourceConfigOptions.ReaderType readerType = sourceConfig.get(READER_TYPE);
switch (readerType) {
// We create a new stream proxy for each split reader since they have their own
@@ -29,8 +29,8 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import javax.annotation.Nullable;

@@ -41,7 +41,6 @@
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

@@ -50,10 +49,10 @@
/** Base implementation of the SplitReader for reading from KinesisShardSplits. */
@Internal
public abstract class KinesisShardSplitReaderBase
implements SplitReader<Record, KinesisShardSplit> {
implements SplitReader<KinesisClientRecord, KinesisShardSplit> {
Contributor comment: This change actually changes the interface exposed for the KinesisSource, so it would be a backwards-incompatible change. Is there a way we can wrap this internally?

Author comment: I have been looking at the implications of keeping Record instead of changing to KinesisClientRecord, and this is what I have found:

  • The deaggregated records would have to be converted from KinesisClientRecord back to Record, since AggregatorUtil only works with KinesisClientRecord, adding a slight overhead.
  • The KinesisClientRecord has a subSequenceNumber that could, and should, be used for checkpointing.

Apart from that, I don't see any further issues. Either way, I don't mind changing it back to Record.
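
A minimal sketch of the conversion the first bullet describes, assuming Record had been kept in the public interface. `convertBackToRecords` is a hypothetical helper, not part of this PR; the builder calls follow the AWS SDK v2 Record API and the KCL KinesisClientRecord accessors:

```java
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.util.ArrayList;
import java.util.List;

class RecordConversionSketch {
    // Convert deaggregated KCL records back to SDK v2 records. Record has no
    // subSequenceNumber field, so the position of a sub-record inside an
    // aggregate is lost, which is exactly the checkpointing concern above.
    static List<Record> convertBackToRecords(List<KinesisClientRecord> deaggregated) {
        List<Record> records = new ArrayList<>(deaggregated.size());
        for (KinesisClientRecord kcr : deaggregated) {
            records.add(
                    Record.builder()
                            .data(SdkBytes.fromByteBuffer(kcr.data()))
                            .partitionKey(kcr.partitionKey())
                            .sequenceNumber(kcr.sequenceNumber())
                            .approximateArrivalTimestamp(kcr.approximateArrivalTimestamp())
                            .build());
        }
        return records;
    }
}
```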


private static final Logger LOG = LoggerFactory.getLogger(KinesisShardSplitReaderBase.class);
private static final RecordsWithSplitIds<Record> INCOMPLETE_SHARD_EMPTY_RECORDS =
private static final RecordsWithSplitIds<KinesisClientRecord> INCOMPLETE_SHARD_EMPTY_RECORDS =
new KinesisRecordsWithSplitIds(Collections.emptyIterator(), null, false);

private final Deque<KinesisShardSplitState> assignedSplits = new ArrayDeque<>();
@@ -65,7 +64,7 @@ protected KinesisShardSplitReaderBase(Map<String, KinesisShardMetrics> shardMetr
}

@Override
public RecordsWithSplitIds<Record> fetch() throws IOException {
public RecordsWithSplitIds<KinesisClientRecord> fetch() throws IOException {
KinesisShardSplitState splitState = assignedSplits.poll();

// When there are no assigned splits, return quickly
@@ -103,7 +102,7 @@ public RecordsWithSplitIds<Record> fetch() throws IOException {
.get(splitState.getShardId())
.setMillisBehindLatest(recordBatch.getMillisBehindLatest());

if (recordBatch.getRecords().isEmpty()) {
if (recordBatch.getDeaggregatedRecords().isEmpty()) {
if (recordBatch.isCompleted()) {
return new KinesisRecordsWithSplitIds(
Collections.emptyIterator(), splitState.getSplitId(), true);
@@ -115,12 +114,12 @@ public RecordsWithSplitIds<Record> fetch() throws IOException {
splitState.setNextStartingPosition(
StartingPosition.continueFromSequenceNumber(
recordBatch
.getRecords()
.get(recordBatch.getRecords().size() - 1)
.getDeaggregatedRecords()
.get(recordBatch.getDeaggregatedRecords().size() - 1)
.sequenceNumber()));

return new KinesisRecordsWithSplitIds(
recordBatch.getRecords().iterator(),
recordBatch.getDeaggregatedRecords().iterator(),
splitState.getSplitId(),
recordBatch.isCompleted());
}
@@ -154,48 +153,20 @@ public void pauseOrResumeSplits(
splitsToResume.forEach(split -> pausedSplitIds.remove(split.splitId()));
}

/**
* Dataclass to store a batch of Kinesis records with metadata. Used to pass Kinesis records
* from the SplitReader implementation to the SplitReaderBase.
*/
@Internal
protected static class RecordBatch {
private final List<Record> records;
private final long millisBehindLatest;
private final boolean completed;

public RecordBatch(List<Record> records, long millisBehindLatest, boolean completed) {
this.records = records;
this.millisBehindLatest = millisBehindLatest;
this.completed = completed;
}

public List<Record> getRecords() {
return records;
}

public long getMillisBehindLatest() {
return millisBehindLatest;
}

public boolean isCompleted() {
return completed;
}
}

/**
* Implementation of {@link RecordsWithSplitIds} for sending Kinesis records from fetcher to the
* SourceReader.
*/
@Internal
private static class KinesisRecordsWithSplitIds implements RecordsWithSplitIds<Record> {
private static class KinesisRecordsWithSplitIds
implements RecordsWithSplitIds<KinesisClientRecord> {

private final Iterator<Record> recordsIterator;
private final Iterator<KinesisClientRecord> recordsIterator;
private final String splitId;
private final boolean isComplete;

public KinesisRecordsWithSplitIds(
Iterator<Record> recordsIterator, String splitId, boolean isComplete) {
Iterator<KinesisClientRecord> recordsIterator, String splitId, boolean isComplete) {
this.recordsIterator = recordsIterator;
this.splitId = splitId;
this.isComplete = isComplete;
@@ -209,7 +180,7 @@ public String nextSplit() {

@Nullable
@Override
public Record nextRecordFromSplit() {
public KinesisClientRecord nextRecordFromSplit() {
return recordsIterator.hasNext() ? recordsIterator.next() : null;
}

@@ -26,7 +26,7 @@
import org.apache.flink.connector.kinesis.source.split.StartingPosition;
import org.apache.flink.util.Collector;

import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

/**
* Emits record from the source into the Flink job graph. This serves as the interface between the
@@ -36,7 +36,7 @@
*/
@Internal
public class KinesisStreamsRecordEmitter<T>
implements RecordEmitter<Record, T, KinesisShardSplitState> {
implements RecordEmitter<KinesisClientRecord, T, KinesisShardSplitState> {

private final KinesisDeserializationSchema<T> deserializationSchema;
private final SourceOutputWrapper<T> sourceOutputWrapper = new SourceOutputWrapper<>();
@@ -47,7 +47,7 @@ public KinesisStreamsRecordEmitter(KinesisDeserializationSchema<T> deserializati

@Override
public void emitRecord(
Record element, SourceOutput<T> output, KinesisShardSplitState splitState)
KinesisClientRecord element, SourceOutput<T> output, KinesisShardSplitState splitState)
throws Exception {
sourceOutputWrapper.setSourceOutput(output);
sourceOutputWrapper.setTimestamp(element.approximateArrivalTimestamp().toEpochMilli());
@@ -31,7 +31,7 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.util.HashSet;
import java.util.List;
@@ -45,14 +45,14 @@
@Internal
public class KinesisStreamsSourceReader<T>
extends SingleThreadMultiplexSourceReaderBase<
Record, T, KinesisShardSplit, KinesisShardSplitState> {
KinesisClientRecord, T, KinesisShardSplit, KinesisShardSplitState> {

private static final Logger LOG = LoggerFactory.getLogger(KinesisStreamsSourceReader.class);
private final Map<String, KinesisShardMetrics> shardMetricGroupMap;

public KinesisStreamsSourceReader(
SingleThreadFetcherManager<Record, KinesisShardSplit> splitFetcherManager,
RecordEmitter<Record, T, KinesisShardSplitState> recordEmitter,
SingleThreadFetcherManager<KinesisClientRecord, KinesisShardSplit> splitFetcherManager,
RecordEmitter<KinesisClientRecord, T, KinesisShardSplitState> recordEmitter,
Configuration config,
SourceReaderContext context,
Map<String, KinesisShardMetrics> shardMetricGroupMap) {
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.connector.kinesis.source.reader;

import org.apache.flink.annotation.Internal;
import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;

import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.AggregatorUtil;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.util.ArrayList;
import java.util.List;

/**
* Dataclass to store a batch of Kinesis records with metadata. Used to pass Kinesis records from
* the SplitReader implementation to the SplitReaderBase.
*
* <p>Input records are de-aggregated using the KCL 3.x library. It is expected that AWS SDK v2.x
* messages are converted to KCL 3.x {@link KinesisClientRecord}.
*/
@Internal
public class RecordBatch {
private final List<KinesisClientRecord> deaggregatedRecords;
private final long millisBehindLatest;
private final boolean completed;

public RecordBatch(
final List<Record> records,
final KinesisShardSplit subscribedShard,
final long millisBehindLatest,
final boolean completed) {
this.deaggregatedRecords = deaggregateRecords(records, subscribedShard);
this.millisBehindLatest = millisBehindLatest;
this.completed = completed;
}

public List<KinesisClientRecord> getDeaggregatedRecords() {
return deaggregatedRecords;
}

public long getMillisBehindLatest() {
return millisBehindLatest;
}

public boolean isCompleted() {
return completed;
}

private List<KinesisClientRecord> deaggregateRecords(
final List<Record> records, final KinesisShardSplit subscribedShard) {
final List<KinesisClientRecord> kinesisClientRecords = new ArrayList<>();
for (Record record : records) {
kinesisClientRecords.add(KinesisClientRecord.fromRecord(record));
}

final String startingHashKey = subscribedShard.getStartingHashKey();
final String endingHashKey = subscribedShard.getEndingHashKey();

return new AggregatorUtil()
.deaggregate(kinesisClientRecords, startingHashKey, endingHashKey);
}
}
@@ -23,6 +23,7 @@
import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics;
import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy;
import org.apache.flink.connector.kinesis.source.reader.KinesisShardSplitReaderBase;
import org.apache.flink.connector.kinesis.source.reader.RecordBatch;
import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit;
import org.apache.flink.connector.kinesis.source.split.KinesisShardSplitState;

@@ -69,7 +70,11 @@ protected RecordBatch fetchRecords(KinesisShardSplitState splitState) {
if (shardCompleted) {
splitSubscriptions.remove(splitState.getShardId());
}
return new RecordBatch(event.records(), event.millisBehindLatest(), shardCompleted);
return new RecordBatch(
event.records(),
splitState.getKinesisShardSplit(),
event.millisBehindLatest(),
shardCompleted);
}

@Override
@@ -24,6 +24,7 @@
import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics;
import org.apache.flink.connector.kinesis.source.proxy.StreamProxy;
import org.apache.flink.connector.kinesis.source.reader.KinesisShardSplitReaderBase;
import org.apache.flink.connector.kinesis.source.reader.RecordBatch;
import org.apache.flink.connector.kinesis.source.split.KinesisShardSplitState;

import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
@@ -59,8 +60,12 @@ protected RecordBatch fetchRecords(KinesisShardSplitState splitState) {
splitState.getNextStartingPosition(),
this.maxRecordsToGet);
boolean isCompleted = getRecordsResponse.nextShardIterator() == null;

return new RecordBatch(
getRecordsResponse.records(), getRecordsResponse.millisBehindLatest(), isCompleted);
getRecordsResponse.records(),
splitState.getKinesisShardSplit(),
getRecordsResponse.millisBehindLatest(),
isCompleted);
}

@Override
@@ -23,7 +23,7 @@
import org.apache.flink.connector.kinesis.source.KinesisStreamsSource;
import org.apache.flink.util.Collector;

import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.io.IOException;
import java.io.Serializable;
@@ -60,7 +60,7 @@ default void open(DeserializationSchema.InitializationContext context) throws Ex
* @param output the collector to put the resulting messages
* @throws IOException exception when deserializing record
*/
void deserialize(Record record, String stream, String shardId, Collector<T> output)
void deserialize(KinesisClientRecord record, String stream, String shardId, Collector<T> output)
Reviewer comment: This is a breaking change to the public interface.
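
A hedged sketch of what the break means for downstream code, assuming a job implements KinesisDeserializationSchema directly. MySchema is a hypothetical user class; note that KinesisClientRecord.data() returns a plain ByteBuffer rather than SdkBytes, so the bytes must be copied out manually:

```java
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.util.Collector;

import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical downstream implementation, for illustration only. Before this
// PR the method took an SDK v2 Record, and record.data().asByteArray() sufficed.
public class MySchema implements KinesisDeserializationSchema<String> {
    @Override
    public void deserialize(
            KinesisClientRecord record, String stream, String shardId, Collector<String> output)
            throws IOException {
        ByteBuffer data = record.data();
        byte[] bytes = new byte[data.remaining()];
        data.get(bytes);
        output.collect(new String(bytes, StandardCharsets.UTF_8));
    }

    @Override
    public TypeInformation<String> getProducedType() {
        return Types.STRING;
    }
}
```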

throws IOException;

static <T> KinesisDeserializationSchema<T> of(DeserializationSchema<T> deserializationSchema) {
@@ -22,9 +22,10 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.util.Collector;

import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.io.IOException;
import java.nio.ByteBuffer;

/**
* A simple wrapper for using the {@link DeserializationSchema} with the {@link
@@ -48,9 +49,15 @@ public void open(DeserializationSchema.InitializationContext context) throws Exc
}

@Override
public void deserialize(Record record, String stream, String shardId, Collector<T> output)
public void deserialize(
KinesisClientRecord record, String stream, String shardId, Collector<T> output)
throws IOException {
deserializationSchema.deserialize(record.data().asByteArray(), output);
ByteBuffer recordData = record.data();

byte[] dataBytes = new byte[recordData.remaining()];
recordData.get(dataBytes);

deserializationSchema.deserialize(dataBytes, output);
}

@Override
@@ -11,6 +11,8 @@
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.test.junit5.MiniClusterExtension;

import com.amazonaws.kinesis.agg.AggRecord;
import com.amazonaws.kinesis.agg.RecordAggregator;
import com.google.common.collect.Lists;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
@@ -24,7 +26,6 @@
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.SdkSystemSetting;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.regions.Region;
@@ -41,7 +42,6 @@

import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -117,6 +117,17 @@ void singleShardStreamIsConsumed() throws Exception {
.runScenario();
}

@Test
void singleShardStreamWithAggregationIsConsumed() throws Exception {
new Scenario()
.localstackStreamName("single-shard-stream-aggregation")
.shardCount(1)
.aggregationFactor(10)
.withSourceConnectionStreamArn(
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/single-shard-stream-aggregation")
.runScenario();
}

@Test
void multipleShardStreamIsConsumed() throws Exception {
new Scenario()
@@ -127,6 +138,17 @@ void multipleShardStreamIsConsumed() throws Exception {
.runScenario();
}

@Test
void multipleShardStreamWithAggregationIsConsumed() throws Exception {
new Scenario()
.localstackStreamName("multiple-shard-stream-aggregation")
.shardCount(4)
.aggregationFactor(10)
.withSourceConnectionStreamArn(
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/multiple-shard-stream-aggregation")
.runScenario();
}

@Test
void reshardedStreamIsConsumed() throws Exception {
new Scenario()
@@ -138,6 +160,18 @@ void reshardedStreamIsConsumed() throws Exception {
.runScenario();
}

@Test
void reshardedStreamWithAggregationIsConsumed() throws Exception {
new Scenario()
.localstackStreamName("resharded-stream-aggregation")
.shardCount(1)
.aggregationFactor(10)
.reshardStream(2)
.withSourceConnectionStreamArn(
"arn:aws:kinesis:ap-southeast-1:000000000000:stream/resharded-stream-aggregation")
.runScenario();
}

private Configuration getDefaultConfiguration() {
Configuration configuration = new Configuration();
configuration.setString(AWS_ENDPOINT, MOCK_KINESIS_CONTAINER.getEndpoint());
@@ -152,6 +186,7 @@ private Configuration getDefaultConfiguration() {

private class Scenario {
private final int expectedElements = 1000;
private int aggregationFactor = 1;
private String localstackStreamName = null;
private int shardCount = 1;
private boolean shouldReshardStream = false;
@@ -203,6 +238,11 @@ public Scenario reshardStream(int targetShardCount) {
return this;
}

public Scenario aggregationFactor(int aggregationFactor) {
this.aggregationFactor = aggregationFactor;
return this;
}

private void prepareStream(String streamName) throws Exception {
final RateLimiter rateLimiter =
RateLimiterBuilder.newBuilder()
@@ -242,13 +282,8 @@ private void putRecords(String streamName, int startInclusive, int endInclusive)

for (List<byte[]> partition : Lists.partition(messages, 500)) {
List<PutRecordsRequestEntry> entries =
partition.stream()
.map(
msg ->
PutRecordsRequestEntry.builder()
.partitionKey(UUID.randomUUID().toString())
.data(SdkBytes.fromByteArray(msg))
.build())
Lists.partition(partition, aggregationFactor).stream()
.map(this::createAggregatePutRecordsRequestEntry)
.collect(Collectors.toList());
PutRecordsRequest requests =
PutRecordsRequest.builder().streamName(streamName).records(entries).build();
@@ -259,6 +294,22 @@ private void putRecords(String streamName, int startInclusive, int endInclusive)
}
}

private PutRecordsRequestEntry createAggregatePutRecordsRequestEntry(
List<byte[]> messages) {
RecordAggregator recordAggregator = new RecordAggregator();

for (byte[] message : messages) {
try {
recordAggregator.addUserRecord("key", message);
} catch (Exception e) {
throw new RuntimeException("Failed to add record to aggregator", e);
}
}

AggRecord aggRecord = recordAggregator.clearAndGet();
return aggRecord.toPutRecordsRequestEntry();
}

private void reshard(String streamName) {
kinesisClient.updateShardCount(
UpdateShardCountRequest.builder()
@@ -29,10 +29,10 @@
import org.apache.flink.util.Collector;

import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
@@ -50,24 +50,18 @@ class KinesisStreamsRecordEmitterTest {
@Test
void testEmitRecord() throws Exception {
final Instant startTime = Instant.now();
List<Record> inputRecords =
List<KinesisClientRecord> inputRecords =
Stream.of(
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-1")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-1")))
.approximateArrivalTimestamp(startTime)
.build(),
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-2")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-2")))
.approximateArrivalTimestamp(startTime.plusSeconds(10))
.build(),
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-3")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-3")))
.approximateArrivalTimestamp(startTime.plusSeconds(20))
.sequenceNumber("some-sequence-number")
.build())
@@ -79,7 +73,7 @@ void testEmitRecord() throws Exception {

KinesisStreamsRecordEmitter<String> emitter =
new KinesisStreamsRecordEmitter<>(KinesisDeserializationSchema.of(STRING_SCHEMA));
for (Record record : inputRecords) {
for (KinesisClientRecord record : inputRecords) {
emitter.emitRecord(record, output, splitState);
}

@@ -97,26 +91,20 @@ void testEmitRecord() throws Exception {
@Test
void testEmitRecordBasedOnSequenceNumber() throws Exception {
final Instant startTime = Instant.now();
List<Record> inputRecords =
List<KinesisClientRecord> inputRecords =
Stream.of(
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-1")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-1")))
.sequenceNumber("emit")
.approximateArrivalTimestamp(startTime)
.build(),
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-2")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-2")))
.sequenceNumber("emit")
.approximateArrivalTimestamp(startTime.plusSeconds(10))
.build(),
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-3")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-3")))
.approximateArrivalTimestamp(startTime.plusSeconds(20))
.sequenceNumber("do-not-emit")
.build())
@@ -126,7 +114,7 @@ void testEmitRecordBasedOnSequenceNumber() throws Exception {

KinesisStreamsRecordEmitter<String> emitter =
new KinesisStreamsRecordEmitter<>(new SequenceNumberBasedDeserializationSchema());
for (Record record : inputRecords) {
for (KinesisClientRecord record : inputRecords) {
emitter.emitRecord(record, output, splitState);
}

@@ -139,24 +127,18 @@ void testEmitRecordBasedOnSequenceNumber() throws Exception {
@Test
void testEmitRecordWithMetadata() throws Exception {
final Instant startTime = Instant.now();
List<Record> inputRecords =
List<KinesisClientRecord> inputRecords =
Stream.of(
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-1")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-1")))
.approximateArrivalTimestamp(startTime)
.build(),
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-2")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-2")))
.approximateArrivalTimestamp(startTime.plusSeconds(10))
.build(),
Record.builder()
.data(
SdkBytes.fromByteArray(
STRING_SCHEMA.serialize("data-3")))
KinesisClientRecord.builder()
.data(ByteBuffer.wrap(STRING_SCHEMA.serialize("data-3")))
.approximateArrivalTimestamp(startTime.plusSeconds(20))
.sequenceNumber("some-sequence-number")
.build())
@@ -168,7 +150,7 @@ void testEmitRecordWithMetadata() throws Exception {
new KinesisStreamsRecordEmitter<>(
new AssertRecordMetadataDeserializationSchema(
splitState.getStreamArn(), splitState.getShardId()));
for (Record record : inputRecords) {
for (KinesisClientRecord record : inputRecords) {
emitter.emitRecord(record, output, splitState);
}

@@ -225,10 +207,13 @@ private static class SequenceNumberBasedDeserializationSchema

@Override
public void deserialize(
Record record, String stream, String shardId, Collector<String> output)
KinesisClientRecord record, String stream, String shardId, Collector<String> output)
throws IOException {
if (Objects.equals(record.sequenceNumber(), "emit")) {
STRING_SCHEMA.deserialize(record.data().asByteArray(), output);
ByteBuffer recordData = record.data();
byte[] dataBytes = new byte[recordData.remaining()];
recordData.get(dataBytes);
STRING_SCHEMA.deserialize(dataBytes, output);
}
}

@@ -251,11 +236,15 @@ private AssertRecordMetadataDeserializationSchema(

@Override
public void deserialize(
Record record, String stream, String shardId, Collector<String> output)
KinesisClientRecord record, String stream, String shardId, Collector<String> output)
throws IOException {
assertThat(stream).isEqualTo(expectedStreamArn);
assertThat(shardId).isEqualTo(expectedShardId);
STRING_SCHEMA.deserialize(record.data().asByteArray(), output);

ByteBuffer recordData = record.data();
byte[] dataBytes = new byte[recordData.remaining()];
recordData.get(dataBytes);
STRING_SCHEMA.deserialize(dataBytes, output);
}

@Override
@@ -32,8 +32,10 @@
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -46,6 +48,7 @@
import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.SHARD_GET_RECORDS_MAX;
import static org.apache.flink.connector.kinesis.source.util.KinesisStreamProxyProvider.getTestStreamProxy;
import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN;
import static org.apache.flink.connector.kinesis.source.util.TestUtil.convertToKinesisClientRecord;
import static org.apache.flink.connector.kinesis.source.util.TestUtil.generateShardId;
import static org.apache.flink.connector.kinesis.source.util.TestUtil.getTestRecord;
import static org.apache.flink.connector.kinesis.source.util.TestUtil.getTestSplit;
@@ -80,7 +83,7 @@ public void init() {

@Test
void testNoAssignedSplitsHandledGracefully() throws Exception {
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
assertThat(retrievedRecords.nextSplit()).isNull();
@@ -95,7 +98,7 @@ void testAssignedSplitHasNoRecordsHandledGracefully() throws Exception {
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

// When fetching records
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

// Then retrieve no records
assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
@@ -116,7 +119,7 @@ void testSplitWithExpiredShardHandledAsCompleted() throws Exception {
splitReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(testSplit)));

// When fetching records
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

// Then retrieve no records and mark split as complete
assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
@@ -128,28 +131,24 @@ void testSplitWithExpiredShardHandledAsCompleted() throws Exception {
void testSingleAssignedSplitAllConsumed() throws Exception {
// Given assigned split with records
testStreamProxy.addShards(TEST_SHARD_ID);
List<Record> expectedRecords =
List<Record> inputRecords =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3"))
.collect(Collectors.toList());
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(0)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(0)));
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(1)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(1)));
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(2)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(2)));
splitReader.handleSplitsChanges(
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

List<KinesisClientRecord> expectedRecords = convertToKinesisClientRecord(inputRecords);

// When fetching records
List<Record> records = new ArrayList<>();
for (int i = 0; i < expectedRecords.size(); i++) {
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
List<KinesisClientRecord> records = new ArrayList<>();
for (int i = 0; i < inputRecords.size(); i++) {
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();
records.addAll(readAllRecords(retrievedRecords));
}

@@ -160,35 +159,66 @@ void testSingleAssignedSplitAllConsumed() throws Exception {
void testMultipleAssignedSplitsAllConsumed() throws Exception {
// Given assigned split with records
testStreamProxy.addShards(TEST_SHARD_ID);
List<Record> expectedRecords =
List<Record> inputRecords =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3"))
.collect(Collectors.toList());
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(0)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(0)));
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(1)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(1)));
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(2)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(2)));
splitReader.handleSplitsChanges(
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

List<KinesisClientRecord> expectedRecords = convertToKinesisClientRecord(inputRecords);

// When records are fetched
List<Record> fetchedRecords = new ArrayList<>();
List<KinesisClientRecord> fetchedRecords = new ArrayList<>();
for (int i = 0; i < expectedRecords.size(); i++) {
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();
fetchedRecords.addAll(readAllRecords(retrievedRecords));
}

// Then all records are fetched
assertThat(fetchedRecords).containsExactlyInAnyOrderElementsOf(expectedRecords);
}

@Test
void testAggregatedRecordsAreDeaggregated() throws Exception {
// Given assigned split with aggregated records
testStreamProxy.addShards(TEST_SHARD_ID);
List<Record> inputRecords =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3"))
.collect(Collectors.toList());

KinesisClientRecord aggregatedRecord = TestUtil.createKinesisAggregatedRecord(inputRecords);
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(TestUtil.convertToRecord(aggregatedRecord)));

splitReader.handleSplitsChanges(
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

List<ByteBuffer> expectedRecords =
convertToKinesisClientRecord(inputRecords).stream()
.map(KinesisClientRecord::data)
.collect(Collectors.toList());

// When fetching records
List<KinesisClientRecord> fetchedRecords = readAllRecords(splitReader.fetch());

// Then all records are fetched
assertThat(fetchedRecords)
.allMatch(KinesisClientRecord::aggregated)
.allMatch(
record ->
record.explicitHashKey().equals(aggregatedRecord.explicitHashKey()))
.extracting("data")
.containsExactlyInAnyOrderElementsOf(expectedRecords);
}

@Test
void testHandleEmptyCompletedShard() throws Exception {
// Given assigned split with no records, and the shard is complete
@@ -199,7 +229,7 @@ void testHandleEmptyCompletedShard() throws Exception {
testStreamProxy.setShouldCompleteNextShard(true);

// When fetching records
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

// Returns completed split with no records
assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
@@ -211,21 +241,23 @@ void testHandleEmptyCompletedShard() throws Exception {
void testFinishedSplitsReturned() throws Exception {
// Given assigned split with records from completed shard
testStreamProxy.addShards(TEST_SHARD_ID);
List<Record> expectedRecords =
List<Record> inputRecords =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3"))
.collect(Collectors.toList());
testStreamProxy.addRecords(TestUtil.STREAM_ARN, TEST_SHARD_ID, expectedRecords);
testStreamProxy.addRecords(TestUtil.STREAM_ARN, TEST_SHARD_ID, inputRecords);
KinesisShardSplit split = getTestSplit(TEST_SHARD_ID);
splitReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split)));

// When fetching records
List<Record> fetchedRecords = new ArrayList<>();
List<KinesisClientRecord> fetchedRecords = new ArrayList<>();
testStreamProxy.setShouldCompleteNextShard(true);
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

List<KinesisClientRecord> expectedRecords = convertToKinesisClientRecord(inputRecords);

// Then records can be read successfully, with finishedSplit returned once all records are
// completed
for (int i = 0; i < expectedRecords.size(); i++) {
for (int i = 0; i < inputRecords.size(); i++) {
assertThat(retrievedRecords.nextSplit()).isEqualTo(split.splitId());
assertThat(retrievedRecords.finishedSplits()).isEmpty();
fetchedRecords.add(retrievedRecords.nextRecordFromSplit());
@@ -245,21 +277,19 @@ void testPauseOrResumeSplits() throws Exception {
testStreamProxy.addShards(TEST_SHARD_ID);
KinesisShardSplit testSplit = getTestSplit(TEST_SHARD_ID);

List<Record> expectedRecords =
List<Record> inputRecords =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"))
.collect(Collectors.toList());
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(0)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(0)));
testStreamProxy.addRecords(
TestUtil.STREAM_ARN,
TEST_SHARD_ID,
Collections.singletonList(expectedRecords.get(1)));
TestUtil.STREAM_ARN, TEST_SHARD_ID, Collections.singletonList(inputRecords.get(1)));
splitReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(testSplit)));

List<KinesisClientRecord> expectedRecords = convertToKinesisClientRecord(inputRecords);

// read data from split
RecordsWithSplitIds<Record> records = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> records = splitReader.fetch();
assertThat(readAllRecords(records)).containsExactlyInAnyOrder(expectedRecords.get(0));

// pause split
@@ -330,12 +360,15 @@ record ->
Arrays.asList(testSplit1, testSplit3), Collections.emptyList());

// read data from splits and verify that only records from split 2 were fetched by reader
List<Record> fetchedRecords = new ArrayList<>();
List<KinesisClientRecord> fetchedRecords = new ArrayList<>();
for (int i = 0; i < 10; i++) {
RecordsWithSplitIds<Record> records = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> records = splitReader.fetch();
fetchedRecords.addAll(readAllRecords(records));
}
assertThat(fetchedRecords).containsExactly(recordsFromSplit2.toArray(new Record[0]));

List<KinesisClientRecord> expectedRecordsFromSplit2 =
convertToKinesisClientRecord(recordsFromSplit2);
assertThat(fetchedRecords).containsExactlyElementsOf(expectedRecordsFromSplit2);

// resume split 3
splitReader.pauseOrResumeSplits(
@@ -344,10 +377,13 @@ record ->
// read data from splits and verify that only records from split 3 had been read
fetchedRecords.clear();
for (int i = 0; i < 10; i++) {
RecordsWithSplitIds<Record> records = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> records = splitReader.fetch();
fetchedRecords.addAll(readAllRecords(records));
}
assertThat(fetchedRecords).containsExactly(recordsFromSplit3.toArray(new Record[0]));

List<KinesisClientRecord> expectedRecordsFromSplit3 =
convertToKinesisClientRecord(recordsFromSplit3);
assertThat(fetchedRecords).containsExactlyElementsOf(expectedRecordsFromSplit3);
}

@Test
@@ -388,16 +424,17 @@ void testMaxRecordsToGetParameterPassed() throws IOException {
splitReader.handleSplitsChanges(
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
List<Record> records = new ArrayList<>(readAllRecords(retrievedRecords));
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();
List<KinesisClientRecord> records = new ArrayList<>(readAllRecords(retrievedRecords));

assertThat(sentRecords.size() > maxRecordsToGet).isTrue();
assertThat(records.size()).isEqualTo(maxRecordsToGet);
}

private List<Record> readAllRecords(RecordsWithSplitIds<Record> recordsWithSplitIds) {
List<Record> outputRecords = new ArrayList<>();
Record record;
private List<KinesisClientRecord> readAllRecords(
RecordsWithSplitIds<KinesisClientRecord> recordsWithSplitIds) {
List<KinesisClientRecord> outputRecords = new ArrayList<>();
KinesisClientRecord record;
do {
record = recordsWithSplitIds.nextRecordFromSplit();
if (record != null) {
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.connector.kinesis.source.reader;

import org.apache.flink.connector.kinesis.source.util.TestUtil;

import org.junit.jupiter.api.Test;
import software.amazon.awssdk.services.kinesis.model.Record;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.flink.connector.kinesis.source.util.TestUtil.getTestRecord;
import static org.apache.flink.connector.kinesis.source.util.TestUtil.getTestSplit;
import static org.assertj.core.api.Assertions.assertThat;

class RecordBatchTest {

@Test
public void testDeaggregateRecordsPassThrough() {
List<Record> records =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3"))
.collect(Collectors.toList());

RecordBatch result = new RecordBatch(records, getTestSplit(), 100L, true);

assertThat(result.getDeaggregatedRecords().size()).isEqualTo(3);
}

@Test
public void testDeaggregateRecordsWithAggregatedRecords() {
List<Record> records =
Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3"))
.collect(Collectors.toList());

Record aggregatedRecord = TestUtil.createAggregatedRecord(records);

RecordBatch result =
new RecordBatch(
Collections.singletonList(aggregatedRecord), getTestSplit(), 100L, true);

assertThat(result.getDeaggregatedRecords().size()).isEqualTo(3);
}

@Test
public void testGetMillisBehindLatest() {
RecordBatch result =
new RecordBatch(
Collections.singletonList(getTestRecord("data-1")),
getTestSplit(),
100L,
true);

assertThat(result.getMillisBehindLatest()).isEqualTo(100L);
}

@Test
public void testIsCompleted() {
RecordBatch result =
new RecordBatch(
Collections.singletonList(getTestRecord("data-1")),
getTestSplit(),
100L,
true);

assertThat(result.isCompleted()).isTrue();
}
}
@@ -31,7 +31,7 @@

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.time.Duration;
import java.util.ArrayList;
@@ -51,7 +51,7 @@ public class FanOutKinesisShardSplitReaderTest {
private static final String TEST_SHARD_ID = TestUtil.generateShardId(1);
private static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000);

SplitReader<Record, KinesisShardSplit> splitReader;
SplitReader<KinesisClientRecord, KinesisShardSplit> splitReader;

private AsyncStreamProxy testAsyncStreamProxy;
private Map<String, KinesisShardMetrics> shardMetricGroupMap;
@@ -78,7 +78,7 @@ public void testNoAssignedSplitsHandledGracefully() throws Exception {
CONSUMER_ARN,
shardMetricGroupMap,
TEST_SUBSCRIPTION_TIMEOUT);
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
assertThat(retrievedRecords.nextSplit()).isNull();
@@ -99,7 +99,7 @@ public void testAssignedSplitHasNoRecordsHandledGracefully() throws Exception {
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

// When fetching records
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

// Then retrieve no records
assertThat(retrievedRecords.nextRecordFromSplit()).isNull();
@@ -122,7 +122,7 @@ public void testSplitWithExpiredShardHandledAsCompleted() throws Exception {
new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID))));

// When fetching records
RecordsWithSplitIds<Record> retrievedRecords = splitReader.fetch();
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = splitReader.fetch();

// Then shard is marked as completed
// Then retrieve no records and mark split as complete
@@ -169,28 +169,28 @@ public void testCloseClosesStreamProxy() throws Exception {
}

private void consumeAllRecordsFromKinesis(
SplitReader<Record, KinesisShardSplit> splitReader, int numRecords) {
SplitReader<KinesisClientRecord, KinesisShardSplit> splitReader, int numRecords) {
consumeRecordsFromKinesis(splitReader, numRecords, true);
}

private void consumeSomeRecordsFromKinesis(
SplitReader<Record, KinesisShardSplit> splitReader, int numRecords) {
SplitReader<KinesisClientRecord, KinesisShardSplit> splitReader, int numRecords) {
consumeRecordsFromKinesis(splitReader, numRecords, false);
}

private void consumeRecordsFromKinesis(
SplitReader<Record, KinesisShardSplit> splitReader,
SplitReader<KinesisClientRecord, KinesisShardSplit> splitReader,
int numRecords,
boolean checkForShardCompletion) {
// Set timeout to prevent infinite loop on failure
assertTimeoutPreemptively(
Duration.ofSeconds(60),
() -> {
int numRetrievedRecords = 0;
RecordsWithSplitIds<Record> retrievedRecords = null;
RecordsWithSplitIds<KinesisClientRecord> retrievedRecords = null;
while (numRetrievedRecords < numRecords) {
retrievedRecords = splitReader.fetch();
List<Record> records = readAllRecords(retrievedRecords);
List<KinesisClientRecord> records = readAllRecords(retrievedRecords);
numRetrievedRecords += records.size();
}
assertThat(numRetrievedRecords).isEqualTo(numRecords);
@@ -206,9 +206,10 @@ private void consumeRecordsFromKinesis(
"did not receive expected " + numRecords + " records within 10 seconds.");
}

private List<Record> readAllRecords(RecordsWithSplitIds<Record> recordsWithSplitIds) {
List<Record> outputRecords = new ArrayList<>();
Record record;
private List<KinesisClientRecord> readAllRecords(
RecordsWithSplitIds<KinesisClientRecord> recordsWithSplitIds) {
List<KinesisClientRecord> outputRecords = new ArrayList<>();
KinesisClientRecord record;
do {
record = recordsWithSplitIds.nextRecordFromSplit();
if (record != null) {
@@ -25,6 +25,7 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.google.common.collect.Lists.partition;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.commons.lang3.RandomStringUtils.randomAlphabetic;
import static org.apache.flink.connector.kinesis.source.split.StartingPositionUtil.toSdkStartingPosition;
@@ -121,7 +122,12 @@ List<SubscribeToShardEvent> getEventsToSend() {
records.add(createRecord(sequenceNumber));
}

eventBuilder.records(records);
List<Record> aggregatedRecords =
partition(records, aggregationFactor).stream()
.map(TestUtil::createAggregatedRecord)
.collect(Collectors.toList());

eventBuilder.records(aggregatedRecords);

String continuation =
sequenceNumber.get() < totalRecords
@@ -27,17 +27,23 @@
import org.apache.flink.metrics.Gauge;
import org.apache.flink.metrics.testutils.MetricListener;

import com.amazonaws.kinesis.agg.AggRecord;
import com.amazonaws.kinesis.agg.RecordAggregator;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.model.HashKeyRange;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.assertj.core.api.Assertions.assertThat;
@@ -161,6 +167,43 @@ public static Record getTestRecord(String data) {
.build();
}

public static Record convertToRecord(KinesisClientRecord record) {
return Record.builder()
.data(SdkBytes.fromByteBuffer(record.data()))
.approximateArrivalTimestamp(record.approximateArrivalTimestamp())
.build();
}

public static List<KinesisClientRecord> convertToKinesisClientRecord(List<Record> records) {
return records.stream().map(KinesisClientRecord::fromRecord).collect(Collectors.toList());
}

public static Record createAggregatedRecord(List<Record> records) {
KinesisClientRecord aggregatedRecord = createKinesisAggregatedRecord(records);
return convertToRecord(aggregatedRecord);
}

public static KinesisClientRecord createKinesisAggregatedRecord(List<Record> records) {
RecordAggregator recordAggregator = new RecordAggregator();

for (Record record : records) {
try {
recordAggregator.addUserRecord("key", record.data().asByteArray());
} catch (Exception e) {
throw new RuntimeException("Failed to add record to aggregator", e);
}
}

AggRecord aggRecord = recordAggregator.clearAndGet();

return KinesisClientRecord.builder()
.data(ByteBuffer.wrap(aggRecord.toRecordBytes()))
.partitionKey(aggRecord.getPartitionKey())
.explicitHashKey(aggRecord.getExplicitHashKey())
.approximateArrivalTimestamp(Instant.now())
.build();
}

public static void assertMillisBehindLatest(
KinesisShardSplit split, long expectedValue, MetricListener metricListener) {
Arn kinesisArn = Arn.fromString(split.getStreamArn());
10 changes: 5 additions & 5 deletions pom.xml
@@ -332,11 +332,6 @@ under the License.
<artifactId>javassist</artifactId>
<version>3.24.0-GA</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.25.5</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
@@ -347,6 +342,11 @@ under the License.
<artifactId>amazon-kinesis-client</artifactId>
<version>1.14.8</version>
</dependency>
<dependency>
<groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client</artifactId>
<version>3.0.1</version>
</dependency>
<dependency>
<groupId>com.squareup.okio</groupId>
<artifactId>okio</artifactId>