Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HDDS-12583. Skip Fields in protobuf while deserializing protobuf #8068

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,63 @@

package org.apache.hadoop.hdds.utils.db;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import com.google.protobuf.WireFormat;
import jakarta.annotation.Nonnull;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.ratis.util.function.CheckedFunction;

/**
* Codecs to serialize/deserialize Protobuf v2 messages.
*/
public final class Proto2Codec<M extends MessageLite>
implements Codec<M> {
private static final ConcurrentMap<Class<? extends MessageLite>,
private static final ConcurrentMap<Pair<Class<? extends MessageLite>, Set<String>>,
Codec<? extends MessageLite>> CODECS
= new ConcurrentHashMap<>();

/**
* @return the {@link Codec} for the given class.
*/
public static <T extends MessageLite> Codec<T> get(T t) {
final Codec<?> codec = CODECS.computeIfAbsent(t.getClass(),
key -> new Proto2Codec<>(t));
return get(t, Collections.emptySet());
}

/**
* @return the {@link Codec} for the given class.
*/
public static <T extends MessageLite> Codec<T> get(T t, Set<String> fieldsToBeSkipped) {
final Codec<?> codec = CODECS.computeIfAbsent(Pair.of(t.getClass(), fieldsToBeSkipped),
key -> new Proto2Codec<>(t, fieldsToBeSkipped));
return (Codec<T>) codec;
}

private final Class<M> clazz;
private final Parser<M> parser;
private final Descriptors.Descriptor descriptor;
private final Supplier<Message.Builder> builderSupplier;
private final Set<String> fieldsToBeSkipped;

private Proto2Codec(M m) {
private Proto2Codec(M m, Set<String> fieldsToBeSkipped) {
this.clazz = (Class<M>) m.getClass();
this.parser = (Parser<M>) m.getParserForType();
this.descriptor = ((Message)m).getDescriptorForType();
this.fieldsToBeSkipped = fieldsToBeSkipped;
this.builderSupplier = ((Message)m)::newBuilderForType;
}

@Override
Expand Down Expand Up @@ -83,19 +105,124 @@ private CheckedFunction<OutputStream, Integer, IOException> writeTo(
public M fromCodecBuffer(@Nonnull CodecBuffer buffer)
throws IOException {
try (InputStream in = buffer.getInputStream()) {
return parser.parseFrom(in);
if (this.fieldsToBeSkipped.isEmpty()) {
return parser.parseFrom(in);
} else {
return parse(CodedInputStream.newInstance(in));
}
}
}

private Object getValue(CodedInputStream input, Descriptors.FieldDescriptor field) throws IOException {
Object value;
switch (field.getType()) {
case DOUBLE:
value = input.readDouble();
break;
case FLOAT:
value = input.readFloat();
break;
case INT64:
value = input.readInt64();
break;
case UINT64:
value = input.readUInt64();
break;
case INT32:
value = input.readInt32();
break;
case FIXED64:
value = input.readFixed64();
break;
case FIXED32:
value = input.readFixed32();
break;
case BOOL:
value = input.readBool();
break;
case STRING:
value = input.readString();
break;
case GROUP:
case MESSAGE:
value = DynamicMessage.newBuilder(field.getMessageType());
input.readMessage((MessageLite.Builder) value,
ExtensionRegistryLite.getEmptyRegistry());
value = ((MessageLite.Builder) value).build();
break;
case BYTES:
value = input.readBytes();
break;
case UINT32:
value = input.readUInt32();
break;
case ENUM:
value = field.getEnumType().findValueByNumber(input.readEnum());
System.out.println(((Descriptors.EnumValueDescriptor)value).getName());
break;
case SFIXED32:
value = input.readSFixed32();
break;
case SFIXED64:
value = input.readSFixed64();
break;
case SINT32:
value = input.readSInt32();
break;
case SINT64:
value = input.readSInt64();
break;
default:
throw new UnsupportedOperationException();
}
System.out.println(field.getName() + ": " + value);
return value;
}

private M parse(CodedInputStream codedInputStream) throws IOException {
Message.Builder builder = this.builderSupplier.get();
while (!codedInputStream.isAtEnd()) {
int tag = codedInputStream.readTag();

if (tag == 0) {
break;
}
int fieldNumber = WireFormat.getTagFieldNumber(tag);

final Descriptors.FieldDescriptor field = descriptor.findFieldByNumber(fieldNumber);
if (field != null && !this.fieldsToBeSkipped.contains(field.getName())) {
try {
Object value = getValue(codedInputStream, field);
if (field.isRepeated()) {
builder.addRepeatedField(field, value);
} else {
builder.setField(field, value);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
codedInputStream.skipField(tag);
}
}
return (M) builder.build();
}

@Override
public byte[] toPersistedFormat(M message) {
return message.toByteArray();
}


@Override
public M fromPersistedFormat(byte[] bytes)
throws InvalidProtocolBufferException {
return parser.parseFrom(bytes);
throws IOException {
if (fieldsToBeSkipped.isEmpty()) {
return parser.parseFrom(bytes);
} else {
return parse(CodedInputStream.newInstance(bytes));
}

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,19 @@ default VALUE getReadCopy(KEY key) throws IOException {
TableIterator<KEY, ? extends KeyValue<KEY, VALUE>> iterator()
throws IOException;

/**
* Returns the iterator for this metadata store.
* @param keyCodec
* @param valueCodec
* @param prefix
* @return MetaStoreIterator
* @throws IOException on failure.
*/
default TableIterator<KEY, ? extends KeyValue<KEY, VALUE>> iterator(Codec<KEY> keyCodec, Codec<VALUE> valueCodec,
KEY prefix) throws IOException {
throw new NotImplementedException("iterator is not implemented");
}

/**
* Returns a prefixed iterator for this metadata store.
* @param prefix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,16 @@ private byte[] encodeValue(VALUE value) throws IOException {
return value == null ? null : valueCodec.toPersistedFormat(value);
}

private KEY decodeKey(byte[] key) throws IOException {
return key == null ? null : keyCodec.fromPersistedFormat(key);
private KEY decodeKey(Codec<KEY> kCodec, byte[] key) throws IOException {
return key == null ? null : kCodec.fromPersistedFormat(key);
}

private VALUE decodeValue(byte[] value) throws IOException {
return value == null ? null : valueCodec.fromPersistedFormat(value);
return decodeValue(valueCodec, value);
}

private VALUE decodeValue(Codec<VALUE> vCodec, byte[] value) throws IOException {
return value == null ? null : vCodec.fromPersistedFormat(value);
}

@Override
Expand Down Expand Up @@ -418,12 +422,12 @@ public Table.KeyValueIterator<KEY, VALUE> iterator() throws IOException {
}

@Override
public Table.KeyValueIterator<KEY, VALUE> iterator(KEY prefix)
public Table.KeyValueIterator<KEY, VALUE> iterator(Codec<KEY> kCodec, Codec<VALUE> vCodec, KEY prefix)
throws IOException {
if (supportCodecBuffer) {
final CodecBuffer prefixBuffer = encodeKeyCodecBuffer(prefix);
try {
return newCodecBufferTableIterator(rawTable.iterator(prefixBuffer));
return newCodecBufferTableIterator(kCodec, vCodec, rawTable.iterator(prefixBuffer));
} catch (Throwable t) {
if (prefixBuffer != null) {
prefixBuffer.release();
Expand All @@ -432,10 +436,15 @@ public Table.KeyValueIterator<KEY, VALUE> iterator(KEY prefix)
}
} else {
final byte[] prefixBytes = encodeKey(prefix);
return new TypedTableIterator(rawTable.iterator(prefixBytes));
return new TypedTableIterator(kCodec, vCodec, rawTable.iterator(prefixBytes));
}
}

@Override
public Table.KeyValueIterator<KEY, VALUE> iterator(KEY prefix) throws IOException {
return iterator(keyCodec, valueCodec, prefix);
}

@Override
public String getName() {
return rawTable.getName();
Expand Down Expand Up @@ -499,7 +508,8 @@ public List<TypedKeyValue> getRangeKVs(
rawTable.getRangeKVs(startKeyBytes, count, prefixBytes, filters);

List<TypedKeyValue> rangeKVs = new ArrayList<>();
rangeKVBytes.forEach(byteKV -> rangeKVs.add(new TypedKeyValue(byteKV)));
rangeKVBytes.forEach(byteKV -> rangeKVs.add(new TypedKeyValue(keyCodec,
valueCodec, byteKV)));

return rangeKVs;
}
Expand All @@ -520,7 +530,8 @@ public List<TypedKeyValue> getSequentialRangeKVs(
prefixBytes, filters);

List<TypedKeyValue> rangeKVs = new ArrayList<>();
rangeKVBytes.forEach(byteKV -> rangeKVs.add(new TypedKeyValue(byteKV)));
rangeKVBytes.forEach(byteKV -> rangeKVs.add(new TypedKeyValue(
keyCodec, valueCodec, byteKV)));

return rangeKVs;
}
Expand Down Expand Up @@ -558,19 +569,23 @@ TableCache<KEY, VALUE> getCache() {
public final class TypedKeyValue implements KeyValue<KEY, VALUE> {

private final KeyValue<byte[], byte[]> rawKeyValue;
private final Codec<KEY> kCodec;
private final Codec<VALUE> vCodec;

private TypedKeyValue(KeyValue<byte[], byte[]> rawKeyValue) {
private TypedKeyValue(Codec<KEY> kCodec, Codec<VALUE> vCodec, KeyValue<byte[], byte[]> rawKeyValue) {
this.kCodec = kCodec;
this.vCodec = vCodec;
this.rawKeyValue = rawKeyValue;
}

@Override
public KEY getKey() throws IOException {
return decodeKey(rawKeyValue.getKey());
return decodeKey(kCodec, rawKeyValue.getKey());
}

@Override
public VALUE getValue() throws IOException {
return decodeValue(rawKeyValue.getValue());
return decodeValue(vCodec, rawKeyValue.getValue());
}

public byte[] getRawKey() throws IOException {
Expand All @@ -582,8 +597,8 @@ public byte[] getRawValue() throws IOException {
}
}

RawIterator<CodecBuffer> newCodecBufferTableIterator(
TableIterator<CodecBuffer, KeyValue<CodecBuffer, CodecBuffer>> i) {
RawIterator<CodecBuffer> newCodecBufferTableIterator(Codec<KEY> kCodec,
Codec<VALUE> vCodec, TableIterator<CodecBuffer, KeyValue<CodecBuffer, CodecBuffer>> i) {
return new RawIterator<CodecBuffer>(i) {
@Override
AutoCloseSupplier<CodecBuffer> convert(KEY key) throws IOException {
Expand Down Expand Up @@ -616,9 +631,14 @@ KeyValue<KEY, VALUE> convert(KeyValue<CodecBuffer, CodecBuffer> raw)
* Table Iterator implementation for strongly typed tables.
*/
public class TypedTableIterator extends RawIterator<byte[]> {
TypedTableIterator(
private final Codec<KEY> kCodec;
private final Codec<VALUE> vCodec;

TypedTableIterator(Codec<KEY> kCodec, Codec<VALUE> vCodec,
TableIterator<byte[], KeyValue<byte[], byte[]>> rawIterator) {
super(rawIterator);
this.kCodec = kCodec;
this.vCodec = vCodec;
}

@Override
Expand All @@ -629,7 +649,7 @@ AutoCloseSupplier<byte[]> convert(KEY key) throws IOException {

@Override
KeyValue<KEY, VALUE> convert(KeyValue<byte[], byte[]> raw) {
return new TypedKeyValue(raw);
return new TypedKeyValue(kCodec, vCodec, raw);
}
}

Expand Down
Loading