Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-34996][Connectors/Kafka] Use UserCodeCL to instantiate Deserializer #89

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,12 @@ class KafkaSerializerWrapper<IN> implements SerializationSchema<IN> {
this(serializerClass, isKey, Collections.emptyMap(), topicSelector);
}

@SuppressWarnings("unchecked")
@Override
public void open(InitializationContext context) throws Exception {
final ClassLoader userCodeClassLoader = context.getUserCodeClassLoader().asClassLoader();
try (TemporaryClassLoaderContext ignored =
TemporaryClassLoaderContext.of(userCodeClassLoader)) {
serializer =
InstantiationUtil.instantiate(
serializerClass.getName(),
Serializer.class,
getClass().getClassLoader());
initializeSerializer(userCodeClassLoader);

if (serializer instanceof Configurable) {
((Configurable) serializer).configure(config);
Expand All @@ -88,4 +83,11 @@ public byte[] serialize(IN element) {
checkState(serializer != null, "Call open() once before trying to serialize elements.");
return serializer.serialize(topicSelector.apply(element), element);
}

/**
 * Reflectively instantiates the wrapped Kafka serializer from {@code serializerClass}
 * using the given class loader.
 *
 * <p>Extracted from {@code open(...)} (which passes the user-code class loader) and kept
 * {@code protected} so subclasses — e.g. tests — can override it to observe or alter which
 * class loader is used.
 *
 * @param classLoader the class loader used to resolve and instantiate the serializer class
 * @throws Exception if the serializer class cannot be loaded or instantiated
 */
@SuppressWarnings("unchecked")
protected void initializeSerializer(ClassLoader classLoader) throws Exception {
serializer =
InstantiationUtil.instantiate(
serializerClass.getName(), Serializer.class, classLoader);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,11 @@ class KafkaValueOnlyDeserializerWrapper<T> implements KafkaRecordDeserialization
}

@Override
@SuppressWarnings("unchecked")
public void open(DeserializationSchema.InitializationContext context) throws Exception {
ClassLoader userCodeClassLoader = context.getUserCodeClassLoader().asClassLoader();
try (TemporaryClassLoaderContext ignored =
TemporaryClassLoaderContext.of(userCodeClassLoader)) {
deserializer =
(Deserializer<T>)
InstantiationUtil.instantiate(
deserializerClass.getName(),
Deserializer.class,
getClass().getClassLoader());
initializeDeserializer(userCodeClassLoader);

if (deserializer instanceof Configurable) {
((Configurable) deserializer).configure(config);
Expand Down Expand Up @@ -103,4 +97,11 @@ public void deserialize(ConsumerRecord<byte[], byte[]> record, Collector<T> coll
/**
 * Derives the produced {@link TypeInformation} by extracting the {@code T} type argument
 * (index 0) of the configured {@link Deserializer} subclass.
 */
public TypeInformation<T> getProducedType() {
return TypeExtractor.createTypeInfo(Deserializer.class, deserializerClass, 0, null, null);
}

/**
 * Reflectively instantiates the wrapped Kafka deserializer from {@code deserializerClass}
 * using the given class loader.
 *
 * <p>Extracted from {@code open(...)} (which passes the user-code class loader) and kept
 * {@code protected} so subclasses — e.g. tests — can override it to observe or alter which
 * class loader is used.
 *
 * @param classLoader the class loader used to resolve and instantiate the deserializer class
 * @throws Exception if the deserializer class cannot be loaded or instantiated
 */
@SuppressWarnings("unchecked")
protected void initializeDeserializer(ClassLoader classLoader) throws Exception {
deserializer =
InstantiationUtil.instantiate(
deserializerClass.getName(), Deserializer.class, classLoader);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.apache.flink.connector.kafka.sink;

import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.connector.testutils.formats.DummyInitializationContext;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.util.FlinkUserCodeClassLoaders;
import org.apache.flink.util.SimpleUserCodeClassLoader;
import org.apache.flink.util.UserCodeClassLoader;

import org.apache.kafka.common.serialization.StringSerializer;
import org.junit.Test;

import java.net.URL;

import static org.junit.Assert.assertEquals;

/** Tests for {@link KafkaSerializerWrapper}. */
public class KafkaSerializerWrapperTest {

    @Test
    public void testUserCodeClassLoaderIsUsed() throws Exception {
        final ClassLoader userClassLoader = newChildFirstClassLoader();
        final KafkaSerializerWrapperCaptureForTest wrapper =
                new KafkaSerializerWrapperCaptureForTest();

        wrapper.open(contextWithClassLoader(userClassLoader));

        // open() must hand the user-code class loader to initializeSerializer.
        assertEquals(userClassLoader, wrapper.getClassLoaderUsed());
    }

    @Test
    public void testDefaultClassLoaderIsUsed() throws Exception {
        final KafkaSerializerWrapperCaptureForTest wrapper =
                new KafkaSerializerWrapperCaptureForTest();

        wrapper.open(new DummyInitializationContext());

        assertEquals(
                DummyInitializationContext.class.getClassLoader(), wrapper.getClassLoaderUsed());
    }

    /** Builds an empty child-first class loader, distinct from the test's own loader. */
    private ClassLoader newChildFirstClassLoader() {
        return FlinkUserCodeClassLoaders.childFirst(
                new URL[0], getClass().getClassLoader(), new String[0], throwable -> {}, true);
    }

    /** Returns an initialization context that exposes the given loader as user-code loader. */
    private static SerializationSchema.InitializationContext contextWithClassLoader(
            final ClassLoader classLoader) {
        return new SerializationSchema.InitializationContext() {
            @Override
            public MetricGroup getMetricGroup() {
                return new UnregisteredMetricsGroup();
            }

            @Override
            public UserCodeClassLoader getUserCodeClassLoader() {
                return SimpleUserCodeClassLoader.create(classLoader);
            }
        };
    }

    /** Test subclass that records which class loader is passed to initializeSerializer. */
    static class KafkaSerializerWrapperCaptureForTest extends KafkaSerializerWrapper<String> {
        private ClassLoader classLoaderUsed;

        KafkaSerializerWrapperCaptureForTest() {
            super(StringSerializer.class, true, (value) -> "topic");
        }

        public ClassLoader getClassLoaderUsed() {
            return classLoaderUsed;
        }

        @Override
        protected void initializeSerializer(ClassLoader classLoader) throws Exception {
            classLoaderUsed = classLoader;
            super.initializeSerializer(classLoader);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.apache.flink.connector.kafka.source.reader.deserializer;

import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.connector.testutils.formats.DummyInitializationContext;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.util.FlinkUserCodeClassLoaders;
import org.apache.flink.util.SimpleUserCodeClassLoader;
import org.apache.flink.util.UserCodeClassLoader;

import org.apache.kafka.common.serialization.StringDeserializer;
import org.junit.Test;

import java.net.URL;
import java.util.HashMap;

import static org.junit.Assert.assertEquals;

/** Tests for {@link KafkaValueOnlyDeserializerWrapper}. */
public class KafkaValueOnlyDeserializerWrapperTest {

    @Test
    public void testUserCodeClassLoaderIsUsed() throws Exception {
        final ClassLoader userClassLoader = newChildFirstClassLoader();
        final KafkaValueOnlyDeserializerWrapperCaptureForTest wrapper =
                new KafkaValueOnlyDeserializerWrapperCaptureForTest();

        wrapper.open(contextWithClassLoader(userClassLoader));

        // open() must hand the user-code class loader to initializeDeserializer.
        assertEquals(userClassLoader, wrapper.getClassLoaderUsed());
    }

    @Test
    public void testDefaultClassLoaderIsUsed() throws Exception {
        final KafkaValueOnlyDeserializerWrapperCaptureForTest wrapper =
                new KafkaValueOnlyDeserializerWrapperCaptureForTest();

        wrapper.open(new DummyInitializationContext());

        assertEquals(
                DummyInitializationContext.class.getClassLoader(), wrapper.getClassLoaderUsed());
    }

    /** Builds an empty child-first class loader, distinct from the test's own loader. */
    private ClassLoader newChildFirstClassLoader() {
        return FlinkUserCodeClassLoaders.childFirst(
                new URL[0], getClass().getClassLoader(), new String[0], throwable -> {}, true);
    }

    /** Returns an initialization context that exposes the given loader as user-code loader. */
    private static DeserializationSchema.InitializationContext contextWithClassLoader(
            final ClassLoader classLoader) {
        return new DeserializationSchema.InitializationContext() {
            @Override
            public MetricGroup getMetricGroup() {
                return new UnregisteredMetricsGroup();
            }

            @Override
            public UserCodeClassLoader getUserCodeClassLoader() {
                return SimpleUserCodeClassLoader.create(classLoader);
            }
        };
    }

    /** Test subclass that records which class loader is passed to initializeDeserializer. */
    static class KafkaValueOnlyDeserializerWrapperCaptureForTest
            extends KafkaValueOnlyDeserializerWrapper<String> {
        private ClassLoader classLoaderUsed;

        KafkaValueOnlyDeserializerWrapperCaptureForTest() {
            super(StringDeserializer.class, new HashMap<>());
        }

        public ClassLoader getClassLoaderUsed() {
            return classLoaderUsed;
        }

        @Override
        protected void initializeDeserializer(ClassLoader classLoader) throws Exception {
            classLoaderUsed = classLoader;
            super.initializeDeserializer(classLoader);
        }
    }
}
Loading