Skip to content

Commit

Permalink
CNDB-11760: Prevent full deserialization in CQL's CONTAINS operator (#…
Browse files Browse the repository at this point in the history
…1441)

Prevent full deserialization in CQL's CONTAINS/CONTAINS KEY operators. 
This also solves a minor correctness issue by making contains equality rely on the column type comparator, 
instead of on the Object#equals method of the deserialized values.
  • Loading branch information
adelapena authored Nov 29, 2024
1 parent eae3180 commit bc17f2a
Show file tree
Hide file tree
Showing 10 changed files with 519 additions and 22 deletions.
24 changes: 3 additions & 21 deletions src/java/org/apache/cassandra/cql3/Operator.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -132,23 +130,8 @@ public String toString()
@Override
public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer)
{
switch(((CollectionType<?>) type).kind)
{
case LIST :
ListType<?> listType = (ListType<?>) type;
List<?> list = listType.getSerializer().deserialize(leftOperand);
return list.contains(listType.getElementsType().getSerializer().deserialize(rightOperand));
case SET:
SetType<?> setType = (SetType<?>) type;
Set<?> set = setType.getSerializer().deserialize(leftOperand);
return set.contains(setType.getElementsType().getSerializer().deserialize(rightOperand));
case MAP:
MapType<?, ?> mapType = (MapType<?, ?>) type;
Map<?, ?> map = mapType.getSerializer().deserialize(leftOperand);
return map.containsValue(mapType.getValuesType().getSerializer().deserialize(rightOperand));
default:
throw new AssertionError();
}
CollectionType<?> collectionType = (CollectionType<?>) type;
return collectionType.contains(leftOperand, rightOperand);
}
},
CONTAINS_KEY(6)
Expand All @@ -163,8 +146,7 @@ public String toString()
public boolean isSatisfiedBy(AbstractType<?> type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer)
{
MapType<?, ?> mapType = (MapType<?, ?>) type;
Map<?, ?> map = mapType.getSerializer().deserialize(leftOperand);
return map.containsKey(mapType.getKeysType().getSerializer().deserialize(rightOperand));
return mapType.containsKey(leftOperand, rightOperand);
}
},

Expand Down
8 changes: 8 additions & 0 deletions src/java/org/apache/cassandra/db/marshal/CollectionType.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ protected boolean equalsNoFrozenNoSubtypes(AbstractType<?> that)
return kind == ((CollectionType<?>)that).kind;
}

/**
 * Checks if the specified serialized collection contains the specified serialized collection element.
 * <p>
 * For maps, {@link MapType#contains(ByteBuffer, ByteBuffer)} looks the element up among the map values;
 * map keys are checked by {@link MapType#containsKey(ByteBuffer, ByteBuffer)} instead.
 *
 * @param collection a serialized collection
 * @param element a serialized collection element
 * @return {@code true} if the collection contains the element, {@code false} otherwise
 */
public abstract boolean contains(ByteBuffer collection, ByteBuffer element);

private static class CollectionPathSerializer implements CellPath.Serializer
{
public void serialize(CellPath path, DataOutputPlus out) throws IOException
Expand Down
6 changes: 6 additions & 0 deletions src/java/org/apache/cassandra/db/marshal/ListType.java
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,10 @@ public boolean isList()
{
return true;
}

/**
 * Checks if the specified serialized list contains the specified serialized element.
 * <p>
 * Delegates to {@code CollectionSerializer.contains}, matching against the list's elements type.
 *
 * @param list a serialized list
 * @param element a serialized list element
 * @return {@code true} if the list contains the element, {@code false} otherwise
 */
@Override
public boolean contains(ByteBuffer list, ByteBuffer element)
{
    // a list has no keys, so there is nothing to skip and only the elements themselves are compared
    final boolean hasKeys = false;
    final boolean getKeys = false;
    return CollectionSerializer.contains(getElementsType(), list, element, hasKeys, getKeys, ProtocolVersion.V3);
}
}
26 changes: 25 additions & 1 deletion src/java/org/apache/cassandra/db/marshal/MapType.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.cql3.Json;
import org.apache.cassandra.cql3.Maps;
import org.apache.cassandra.cql3.Term;
Expand Down Expand Up @@ -346,4 +345,29 @@ public String toJSONString(ByteBuffer buffer, ProtocolVersion protocolVersion)
}
return sb.append('}').toString();
}

/**
 * Tells whether a serialized map holds the given serialized value.
 * <p>
 * Delegates to {@code CollectionSerializer.contains}, matching against the map's values type. Keys are
 * read past but never compared.
 *
 * @param map a serialized map
 * @param value a serialized map value
 * @return {@code true} if any entry of the map has the value, {@code false} otherwise
 */
@Override
public boolean contains(ByteBuffer map, ByteBuffer value)
{
    // entries are (key, value) pairs, and it is the value side that is compared
    final boolean hasKeys = true;
    final boolean getKeys = false;
    return CollectionSerializer.contains(getValuesType(), map, value, hasKeys, getKeys, ProtocolVersion.V3);
}

/**
 * Tells whether a serialized map holds the given serialized key.
 * <p>
 * Delegates to {@code CollectionSerializer.contains}, matching against the map's keys type.
 *
 * @param map a serialized map
 * @param key a serialized map key
 * @return {@code true} if any entry of the map has the key, {@code false} otherwise
 */
public boolean containsKey(ByteBuffer map, ByteBuffer key)
{
    // entries are (key, value) pairs, and it is the key side that is compared
    final boolean hasKeys = true;
    final boolean getKeys = true;
    return CollectionSerializer.contains(getKeysType(), map, key, hasKeys, getKeys, ProtocolVersion.V3);
}
}
7 changes: 7 additions & 0 deletions src/java/org/apache/cassandra/db/marshal/SetType.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.cassandra.db.rows.Cell;
import org.apache.cassandra.exceptions.ConfigurationException;
import org.apache.cassandra.exceptions.SyntaxException;
import org.apache.cassandra.serializers.CollectionSerializer;
import org.apache.cassandra.serializers.MarshalException;
import org.apache.cassandra.serializers.SetSerializer;
import org.apache.cassandra.transport.ProtocolVersion;
Expand Down Expand Up @@ -190,4 +191,10 @@ public String toJSONString(ByteBuffer buffer, ProtocolVersion protocolVersion)
{
return ListType.setOrListToJsonString(buffer, getElementsType(), protocolVersion);
}

/**
 * Checks if the specified serialized set contains the specified serialized element.
 * <p>
 * Delegates to {@code CollectionSerializer.contains}, matching against the set's elements type.
 *
 * @param set a serialized set
 * @param element a serialized set element
 * @return {@code true} if the set contains the element, {@code false} otherwise
 */
@Override
public boolean contains(ByteBuffer set, ByteBuffer element)
{
    // a set has no keys, so there is nothing to skip and only the elements themselves are compared
    final boolean hasKeys = false;
    final boolean getKeys = false;
    return CollectionSerializer.contains(getElementsType(), set, element, hasKeys, getKeys, ProtocolVersion.V3);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,47 @@ protected ByteBuffer copyAsNewCollection(ByteBuffer input, int count, int startP
ByteBufferUtil.copyBytes(input, startPos, output, sizeLen, bodyLen);
return output;
}

/**
 * Tests whether a serialized collection holds a given serialized element by walking the serialized entries
 * one at a time and comparing each candidate with {@code elementType.compare(candidate, element) == 0}.
 * <p>
 * For collections with keys (maps) every entry is a key followed by a value; {@code getKeys} selects which of
 * the two is compared, and the other one is only read so that the offset can be moved past it.
 *
 * @param elementType the type used to compare the candidates against {@code element}
 * @param collection a serialized collection
 * @param element a serialized collection element
 * @param hasKeys whether the collection has keys, that is, it's a map
 * @param getKeys whether to compare keys ({@code true}) or values ({@code false}); requires {@code hasKeys}
 * @param version the protocol version used for serialization
 * @return {@code true} if the collection contains the element, {@code false} otherwise
 */
public static boolean contains(AbstractType<?> elementType,
                               ByteBuffer collection,
                               ByteBuffer element,
                               boolean hasKeys,
                               boolean getKeys,
                               ProtocolVersion version)
{
    // keys can only be looked up in a collection that actually has keys
    assert hasKeys || !getKeys;

    int entryCount = readCollectionSize(collection, ByteBufferAccessor.instance, version);
    // the first entry starts right after the serialized collection size
    int position = sizeOfCollectionSize(entryCount, version);

    for (int remaining = entryCount; remaining > 0; remaining--)
    {
        if (hasKeys)
        {
            // key side of the entry: compared only when looking for keys, always skipped afterwards
            ByteBuffer entryKey = readValue(collection, ByteBufferAccessor.instance, position, version);
            if (getKeys && elementType.compare(entryKey, element) == 0)
                return true;
            position += sizeOfValue(entryKey, ByteBufferAccessor.instance, version);
        }

        // value side of the entry (or the plain element for lists and sets)
        ByteBuffer entryValue = readValue(collection, ByteBufferAccessor.instance, position, version);
        if (!getKeys && elementType.compare(entryValue, element) == 0)
            return true;
        position += sizeOfValue(entryValue, ByteBufferAccessor.instance, version);
    }

    // every entry was visited without finding a match
    return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.cassandra.test.microbench;


import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.CollectionType;
import org.apache.cassandra.db.marshal.ListType;
import org.apache.cassandra.db.marshal.MapType;
import org.apache.cassandra.db.marshal.SetType;
import org.apache.cassandra.utils.AbstractTypeGenerators;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;

import static org.apache.cassandra.utils.AbstractTypeGenerators.getTypeSupport;
import static org.quicktheories.QuickTheory.qt;

/**
 * JMH microbenchmark for {@link org.apache.cassandra.cql3.Operator#CONTAINS} and
 * {@link org.apache.cassandra.cql3.Operator#CONTAINS_KEY}.
 * <p>
 * Each collection kind is measured twice: once through {@link CollectionType#contains(ByteBuffer, ByteBuffer)}
 * (or {@link MapType#containsKey(ByteBuffer, ByteBuffer)}), and once through the approach used before CNDB-11760,
 * which fully deserializes the collection and then calls {@link java.util.Collection#contains(Object)} or the
 * equivalent {@link Map} method.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@Warmup(iterations = 1, time = 1) // seconds
@Measurement(iterations = 3, time = 1) // seconds
@Fork(value = 4)
@Threads(4)
@State(Scope.Benchmark)
public class CollectionContainsTest
{
    /** Name of the {@link CQL3Type.Native} constant used as the type of the collection elements. */
    @Param({ "INT", "TEXT" })
    public String type;

    /**
     * Number of values generated to populate the collections. The set and the map are built from a
     * {@link HashSet} and a {@link HashMap}, so they can end up smaller if the generator repeats a value.
     */
    @Param({ "1", "10", "100", "1000" })
    public int collectionSize;

    // collection types under test
    private ListType<?> listType;
    private SetType<?> setType;
    private MapType<?, ?> mapType;

    // serialized collections that are searched by the benchmarks
    private ByteBuffer list;
    private ByteBuffer set;
    private ByteBuffer map;

    // serialized values that are looked up in the collections on every benchmark invocation
    private final List<ByteBuffer> values = new ArrayList<>();

    /**
     * Resolves the element type named by {@link #type} and builds the benchmark state for it.
     */
    @Setup(Level.Trial)
    public void setup() throws Throwable
    {
        AbstractType<?> resolvedType = CQL3Type.Native.valueOf(type).getType();
        populate(resolvedType);
    }

    /**
     * Creates the non-frozen list, set and map types for the given element type, fills them with generated
     * values, stores their serialized form, and finally generates the serialized values to look up.
     */
    private <T> void populate(AbstractType<T> elementsType)
    {
        ListType<T> typedList = ListType.getInstance(elementsType, false);
        SetType<T> typedSet = SetType.getInstance(elementsType, false);
        MapType<T, T> typedMap = MapType.getInstance(elementsType, elementsType, false);

        List<T> generatedList = new ArrayList<>();
        Set<T> generatedSet = new HashSet<>();
        Map<T, T> generatedMap = new HashMap<>();

        // the same generated value feeds the three collections; the map uses it as both key and value
        AbstractTypeGenerators.TypeSupport<T> typeSupport = getTypeSupport(elementsType);
        qt().withExamples(collectionSize).forAll(typeSupport.valueGen).checkAssert(generated -> {
            generatedList.add(generated);
            generatedSet.add(generated);
            generatedMap.put(generated, generated);
        });

        this.listType = typedList;
        this.setType = typedSet;
        this.mapType = typedMap;

        list = typedList.decompose(generatedList);
        set = typedSet.decompose(generatedSet);
        map = typedMap.decompose(generatedMap);

        // 100 serialized values to search for, generated independently from the collection contents
        qt().withExamples(100).forAll(typeSupport.bytesGen()).checkAssert(values::add);
    }

    /** List lookup on the serialized form. */
    @Benchmark
    public Object listContainsNonDeserializing()
    {
        return countContained(probe -> listType.contains(list, probe));
    }

    /** List lookup after deserializing both the list and the searched value. */
    @Benchmark
    public Object listContainsDeserializing()
    {
        return countContained(probe -> listType.compose(list).contains(listType.getElementsType().compose(probe)));
    }

    /** Set lookup on the serialized form. */
    @Benchmark
    public Object setContainsNonDeserializing()
    {
        return countContained(probe -> setType.contains(set, probe));
    }

    /** Set lookup after deserializing both the set and the searched value. */
    @Benchmark
    public Object setContainsDeserializing()
    {
        return countContained(probe -> setType.compose(set).contains(setType.getElementsType().compose(probe)));
    }

    /** Map value lookup on the serialized form. */
    @Benchmark
    public Object mapContainsNonDeserializing()
    {
        return countContained(probe -> mapType.contains(map, probe));
    }

    /** Map value lookup after deserializing both the map and the searched value. */
    @Benchmark
    public Object mapContainsDeserializing()
    {
        return countContained(probe -> mapType.compose(map).containsValue(mapType.getValuesType().compose(probe)));
    }

    /** Map key lookup on the serialized form. */
    @Benchmark
    public Object mapContainsKeyNonDeserializing()
    {
        return countContained(probe -> mapType.containsKey(map, probe));
    }

    /** Map key lookup after deserializing both the map and the searched key. */
    @Benchmark
    public Object mapContainsKeyDeserializing()
    {
        return countContained(probe -> mapType.compose(map).containsKey(mapType.getKeysType().compose(probe)));
    }

    /**
     * Applies the given lookup to every value in {@link #values} and returns how many of them were found.
     * The count is returned from the benchmark methods so the lookups produce an observable result.
     */
    private int countContained(Function<ByteBuffer, Boolean> containsFunction)
    {
        int matches = 0;
        for (ByteBuffer candidate : values)
        {
            boolean found = containsFunction.apply(candidate);
            if (found)
                matches++;
        }
        return matches;
    }
}
Loading

0 comments on commit bc17f2a

Please sign in to comment.