diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java
index dd53fdf0b1b4c..5e57e7a07c604 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/ArrayKeyIndexType.java
@@ -29,8 +29,7 @@ public class ArrayKeyIndexType {

   @Override
   public boolean equals(Object o) {
-    if (o instanceof ArrayKeyIndexType) {
-      ArrayKeyIndexType other = (ArrayKeyIndexType) o;
+    if (o instanceof ArrayKeyIndexType other) {
       return Arrays.equals(key, other.key) && Arrays.equals(id, other.id);
     }
     return false;
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java
index ebb5c2c5ed55c..81b9044f7f096 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType1.java
@@ -36,8 +36,7 @@ public class CustomType1 {

   @Override
   public boolean equals(Object o) {
-    if (o instanceof CustomType1) {
-      CustomType1 other = (CustomType1) o;
+    if (o instanceof CustomType1 other) {
       return id.equals(other.id) && name.equals(other.name);
     }
     return false;
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType2.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType2.java
index 3bb66bb3ec700..6378f2219a15c 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType2.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/CustomType2.java
@@ -30,8 +30,7 @@ public class CustomType2 {

   @Override
   public boolean equals(Object o) {
-    if (o instanceof CustomType2) {
-      CustomType2 other = (CustomType2) o;
+    if (o instanceof CustomType2 other) {
       return id.equals(other.id) && parentId.equals(other.parentId);
     }
     return false;
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/IntKeyType.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/IntKeyType.java
index f7051246f77b6..8d3190e510a47 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/IntKeyType.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/IntKeyType.java
@@ -31,8 +31,7 @@ public class IntKeyType {

   @Override
   public boolean equals(Object o) {
-    if (o instanceof IntKeyType) {
-      IntKeyType other = (IntKeyType) o;
+    if (o instanceof IntKeyType other) {
       return key == other.key && id.equals(other.id) && values.equals(other.values);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 0f1781cbf1f2c..cbad4c61b9b4a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -62,8 +62,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof ChunkFetchFailure) {
-      ChunkFetchFailure o = (ChunkFetchFailure) other;
+    if (other instanceof ChunkFetchFailure o) {
       return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index 7b034d5c2f595..2865388b3297c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -56,8 +56,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof ChunkFetchRequest) {
-      ChunkFetchRequest o = (ChunkFetchRequest) other;
+    if (other instanceof ChunkFetchRequest o) {
       return streamChunkId.equals(o.streamChunkId);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index eaad143fc3f5f..aa89b2062f626 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -75,8 +75,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof ChunkFetchSuccess) {
-      ChunkFetchSuccess o = (ChunkFetchSuccess) other;
+    if (other instanceof ChunkFetchSuccess o) {
       return streamChunkId.equals(o.streamChunkId) && super.equals(o);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
index c85d104fdd0fe..3723730ebc06c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
@@ -84,8 +84,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof MergedBlockMetaRequest) {
-      MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
+    if (other instanceof MergedBlockMetaRequest o) {
       return requestId == o.requestId && shuffleId == o.shuffleId &&
         shuffleMergeId == o.shuffleMergeId && reduceId == o.reduceId &&
         Objects.equal(appId, o.appId);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 06dc447309dd9..00de47dc9fc2d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -59,8 +59,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) thro
       isBodyInFrame = in.isBodyInFrame();
     } catch (Exception e) {
       in.body().release();
-      if (in instanceof AbstractResponseMessage) {
-        AbstractResponseMessage resp = (AbstractResponseMessage) in;
+      if (in instanceof AbstractResponseMessage resp) {
         // Re-encode this message as a failure response.
         String error = e.getMessage() != null ? e.getMessage() : "null";
         logger.error(String.format("Error processing %s for client %s",
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
index 719f6c64c5dee..91c818f3612a9 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -66,8 +66,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof OneWayMessage) {
-      OneWayMessage o = (OneWayMessage) other;
+    if (other instanceof OneWayMessage o) {
       return super.equals(o);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index 6e4f5687d16cd..02a45d68c650e 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -60,8 +60,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof RpcFailure) {
-      RpcFailure o = (RpcFailure) other;
+    if (other instanceof RpcFailure o) {
       return requestId == o.requestId && errorString.equals(o.errorString);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index f2609ce2dbdb3..a7dbe1283b314 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -72,8 +72,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof RpcRequest) {
-      RpcRequest o = (RpcRequest) other;
+    if (other instanceof RpcRequest o) {
       return requestId == o.requestId && super.equals(o);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index 51b36ea183362..85709e36f83ee 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -72,8 +72,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof RpcResponse) {
-      RpcResponse o = (RpcResponse) other;
+    if (other instanceof RpcResponse o) {
       return requestId == o.requestId && super.equals(o);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
index 29201d135ba93..ae795ca4d1472 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -60,8 +60,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof StreamChunkId) {
-      StreamChunkId o = (StreamChunkId) other;
+    if (other instanceof StreamChunkId o) {
       return streamId == o.streamId && chunkIndex == o.chunkIndex;
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
index 06836f5eea390..9a7bf2f65af3a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -62,8 +62,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof StreamFailure) {
-      StreamFailure o = (StreamFailure) other;
+    if (other instanceof StreamFailure o) {
       return streamId.equals(o.streamId) && error.equals(o.error);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
index 3d035e5c94f23..5906b4d380d6e 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -61,8 +61,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof StreamRequest) {
-      StreamRequest o = (StreamRequest) other;
+    if (other instanceof StreamRequest o) {
       return streamId.equals(o.streamId);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
index f30605ce836fc..0c0aa5c9a635b 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -75,8 +75,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof StreamResponse) {
-      StreamResponse o = (StreamResponse) other;
+    if (other instanceof StreamResponse o) {
       return byteCount == o.byteCount && streamId.equals(o.streamId);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java
index fb50801a51ba3..4722f39dfa9db 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java
@@ -91,8 +91,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof UploadStream) {
-      UploadStream o = (UploadStream) other;
+    if (other instanceof UploadStream o) {
       return requestId == o.requestId && super.equals(o);
     }
     return false;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
index 05a5afe195e8c..524ff0a310655 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -127,17 +127,14 @@ private class ClientCallbackHandler implements CallbackHandler {
     @Override
     public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
       for (Callback callback : callbacks) {
-        if (callback instanceof NameCallback) {
+        if (callback instanceof NameCallback nc) {
           logger.trace("SASL client callback: setting username");
-          NameCallback nc = (NameCallback) callback;
           nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
-        } else if (callback instanceof PasswordCallback) {
+        } else if (callback instanceof PasswordCallback pc) {
           logger.trace("SASL client callback: setting password");
-          PasswordCallback pc = (PasswordCallback) callback;
           pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
-        } else if (callback instanceof RealmCallback) {
+        } else if (callback instanceof RealmCallback rc) {
           logger.trace("SASL client callback: setting realm");
-          RealmCallback rc = (RealmCallback) callback;
           rc.setText(rc.getDefaultText());
         } else if (callback instanceof RealmChoiceCallback) {
           // ignore (?)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
index e22e09d2a22e6..26e5718cb4a70 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -156,20 +156,16 @@ private class DigestCallbackHandler implements CallbackHandler {
     @Override
     public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
       for (Callback callback : callbacks) {
-        if (callback instanceof NameCallback) {
+        if (callback instanceof NameCallback nc) {
           logger.trace("SASL server callback: setting username");
-          NameCallback nc = (NameCallback) callback;
           nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
-        } else if (callback instanceof PasswordCallback) {
+        } else if (callback instanceof PasswordCallback pc) {
           logger.trace("SASL server callback: setting password");
-          PasswordCallback pc = (PasswordCallback) callback;
           pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
-        } else if (callback instanceof RealmCallback) {
+        } else if (callback instanceof RealmCallback rc) {
           logger.trace("SASL server callback: setting realm");
-          RealmCallback rc = (RealmCallback) callback;
           rc.setText(rc.getDefaultText());
-        } else if (callback instanceof AuthorizeCallback) {
-          AuthorizeCallback ac = (AuthorizeCallback) callback;
+        } else if (callback instanceof AuthorizeCallback ac) {
           String authId = ac.getAuthenticationID();
           String authzId = ac.getAuthorizationID();
           ac.setAuthorized(authId.equals(authzId));
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index f55ca2204cdb4..e12f9120fdbb3 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -148,8 +148,7 @@ public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exce
   /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */
   @Override
   public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
-    if (evt instanceof IdleStateEvent) {
-      IdleStateEvent e = (IdleStateEvent) evt;
+    if (evt instanceof IdleStateEvent e) {
       // See class comment for timeout semantics. In addition to ensuring we only timeout while
       // there are outstanding requests, we also do a secondary consistency check to ensure
       // there's no race between the idle timeout and incrementing the numOutstandingRequests
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index 3d7c1b1ca0cc1..efcc83f409eac 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -128,8 +128,7 @@ public StreamCallbackWithID receiveStream(
       ByteBuffer messageHeader,
       RpcResponseCallback callback) {
     BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader);
-    if (msgObj instanceof PushBlockStream) {
-      PushBlockStream message = (PushBlockStream) msgObj;
+    if (msgObj instanceof PushBlockStream message) {
       checkAuth(client, message.appId);
       return mergeManager.receiveBlockDataAsStream(message);
     } else {
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
index 39ddf2c2a7ed6..e0971d49510a9 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
@@ -131,8 +131,7 @@ protected void decode(ChannelHandlerContext channelHandlerContext,
         ByteBuf byteBuf, List<Object> list) throws Exception {
       delegate.decode(channelHandlerContext, byteBuf, list);
       Object msg = list.get(list.size() - 1);
-      if (msg instanceof RpcRequest) {
-        RpcRequest req = (RpcRequest) msg;
+      if (msg instanceof RpcRequest req) {
        ByteBuffer buffer = req.body().nioByteBuffer();
        byte type = Unpooled.wrappedBuffer(buffer).readByte();
        if (type == BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE.id()) {
@@ -171,8 +170,7 @@ static class FinalizedHandler extends SimpleChannelInboundHandler<RpcRequest> {
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ ... @@ public ByteBuffer getByteBuffer() {
-    if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
-      final byte[] bytes = (byte[]) base;
+    if (base instanceof byte[] bytes && offset >= BYTE_ARRAY_OFFSET) {
       // the offset includes an object header... this is only needed for unsafe copies
       final long arrayOffset = offset - BYTE_ARRAY_OFFSET;
@@ -1401,8 +1400,7 @@ public int compare(final UTF8String other) {

   @Override
   public boolean equals(final Object other) {
-    if (other instanceof UTF8String) {
-      UTF8String o = (UTF8String) other;
+    if (other instanceof UTF8String o) {
       if (numBytes != o.numBytes) {
         return false;
       }
diff --git a/core/src/main/java/org/apache/spark/api/java/Optional.java b/core/src/main/java/org/apache/spark/api/java/Optional.java
index fd0f495ca29da..362149c92145e 100644
--- a/core/src/main/java/org/apache/spark/api/java/Optional.java
+++ b/core/src/main/java/org/apache/spark/api/java/Optional.java
@@ -168,10 +168,9 @@ public T orNull() {

   @Override
   public boolean equals(Object obj) {
-    if (!(obj instanceof Optional)) {
+    if (!(obj instanceof Optional<?> other)) {
       return false;
     }
-    Optional<?> other = (Optional<?>) obj;
     return Objects.equals(value, other.value);
   }

diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java
index 125205f416d35..be7c0864c2f4c 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java
@@ -317,10 +317,9 @@ void setConnectionThread(Thread t) {
     @Override
     protected void handle(Message msg) throws IOException {
       try {
-        if (msg instanceof Hello) {
+        if (msg instanceof Hello hello) {
           timeout.cancel();
           timeout = null;
-          Hello hello = (Hello) msg;
           AbstractAppHandle handle = secretToPendingApps.remove(hello.secret);
           if (handle != null) {
             handle.setConnection(this);
@@ -334,8 +333,7 @@ protected void handle(Message msg) throws IOException {
         if (handle == null) {
           throw new IllegalArgumentException("Expected hello, got: " + msgClassName);
         }
-        if (msg instanceof SetAppId) {
-          SetAppId set = (SetAppId) msg;
+        if (msg instanceof SetAppId set) {
           handle.setAppId(set.appId);
         } else if (msg instanceof SetState) {
           handle.setState(((SetState)msg).state);
diff --git a/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java b/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java
index 61894b18fe695..17895e73d9fcf 100644
--- a/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java
+++ b/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentifierImpl.java
@@ -65,8 +65,7 @@ public String toString() {

   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
-    if (!(o instanceof IdentifierImpl)) return false;
-    IdentifierImpl that = (IdentifierImpl) o;
+    if (!(o instanceof IdentifierImpl that)) return false;
     return Arrays.equals(namespace, that.namespace) && name.equals(that.name);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index be50350b106a1..91f04c3d327ac 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -57,8 +57,7 @@ public static Object read(
     if (physicalDataType instanceof PhysicalStringType) {
       return obj.getUTF8String(ordinal);
     }
-    if (physicalDataType instanceof PhysicalDecimalType) {
-      PhysicalDecimalType dt = (PhysicalDecimalType) physicalDataType;
+    if (physicalDataType instanceof PhysicalDecimalType dt) {
       return obj.getDecimal(ordinal, dt.precision(), dt.scale());
     }
     if (physicalDataType instanceof PhysicalCalendarIntervalType) {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 6bea714e7d58a..ea6f1e05422b5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -329,8 +329,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof UnsafeArrayData) {
-      UnsafeArrayData o = (UnsafeArrayData) other;
+    if (other instanceof UnsafeArrayData o) {
       return (sizeInBytes == o.sizeInBytes) &&
         ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
           sizeInBytes);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index d2433292fc7bd..8f9d5919e1d9f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -538,8 +538,7 @@ public int hashCode() {

   @Override
   public boolean equals(Object other) {
-    if (other instanceof UnsafeRow) {
-      UnsafeRow o = (UnsafeRow) other;
+    if (other instanceof UnsafeRow o) {
       return (sizeInBytes == o.sizeInBytes) &&
         ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
           sizeInBytes);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ColumnDefaultValue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ColumnDefaultValue.java
index b8e75c11c813a..cc3ff63fb29bc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ColumnDefaultValue.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ColumnDefaultValue.java
@@ -67,8 +67,7 @@ public Literal<?> getValue() {

   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
-    if (!(o instanceof ColumnDefaultValue)) return false;
-    ColumnDefaultValue that = (ColumnDefaultValue) o;
+    if (!(o instanceof ColumnDefaultValue that)) return false;
     return sql.equals(that.sql) && value.equals(that.value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 9ca0fe4787f10..e529a8e9250fb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -53,18 +53,14 @@ public String build(Expression expr) {
       return visitLiteral((Literal) expr);
     } else if (expr instanceof NamedReference) {
       return visitNamedReference((NamedReference) expr);
-    } else if (expr instanceof Cast) {
-      Cast cast = (Cast) expr;
+    } else if (expr instanceof Cast cast) {
       return visitCast(build(cast.expression()), cast.dataType());
-    } else if (expr instanceof Extract) {
-      Extract extract = (Extract) expr;
+    } else if (expr instanceof Extract extract) {
       return visitExtract(extract.field(), build(extract.source()));
-    } else if (expr instanceof SortOrder) {
-      SortOrder sortOrder = (SortOrder) expr;
+    } else if (expr instanceof SortOrder sortOrder) {
       return visitSortOrder(
         build(sortOrder.expression()), sortOrder.direction(), sortOrder.nullOrdering());
-    } else if (expr instanceof GeneralScalarExpression) {
-      GeneralScalarExpression e = (GeneralScalarExpression) expr;
+    } else if (expr instanceof GeneralScalarExpression e) {
       String name = e.name();
       switch (name) {
         case "IN": {
@@ -181,26 +177,21 @@ public String build(Expression expr) {
         default:
           return visitUnexpectedExpr(expr);
       }
-    } else if (expr instanceof Min) {
-      Min min = (Min) expr;
+    } else if (expr instanceof Min min) {
       return visitAggregateFunction("MIN", false,
         expressionsToStringArray(min.children()));
-    } else if (expr instanceof Max) {
-      Max max = (Max) expr;
+    } else if (expr instanceof Max max) {
       return visitAggregateFunction("MAX", false,
         expressionsToStringArray(max.children()));
-    } else if (expr instanceof Count) {
-      Count count = (Count) expr;
+    } else if (expr instanceof Count count) {
       return visitAggregateFunction("COUNT", count.isDistinct(),
         expressionsToStringArray(count.children()));
-    } else if (expr instanceof Sum) {
-      Sum sum = (Sum) expr;
+    } else if (expr instanceof Sum sum) {
       return visitAggregateFunction("SUM", sum.isDistinct(),
         expressionsToStringArray(sum.children()));
     } else if (expr instanceof CountStar) {
       return visitAggregateFunction("COUNT", false, new String[]{"*"});
-    } else if (expr instanceof Avg) {
-      Avg avg = (Avg) expr;
+    } else if (expr instanceof Avg avg) {
       return visitAggregateFunction("AVG", avg.isDistinct(),
         expressionsToStringArray(avg.children()));
     } else if (expr instanceof GeneralAggregateFunc) {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 34870597ee321..31ecf5cbe17f8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -181,14 +181,11 @@ void initAccessor(ValueVector vector) {
       accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
     } else if (vector instanceof TimeStampMicroVector) {
       accessor = new TimestampNTZAccessor((TimeStampMicroVector) vector);
-    } else if (vector instanceof MapVector) {
-      MapVector mapVector = (MapVector) vector;
+    } else if (vector instanceof MapVector mapVector) {
       accessor = new MapAccessor(mapVector);
-    } else if (vector instanceof ListVector) {
-      ListVector listVector = (ListVector) vector;
+    } else if (vector instanceof ListVector listVector) {
       accessor = new ArrayAccessor(listVector);
-    } else if (vector instanceof StructVector) {
-      StructVector structVector = (StructVector) vector;
+    } else if (vector instanceof StructVector structVector) {
       accessor = new StructAccessor(structVector);

       childColumns = new ArrowColumnVector[structVector.size()];
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
index da794e4bb9184..c0d2ae8e7d0e8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
@@ -68,8 +68,7 @@ public InternalRow copy() {
         row.update(i, getUTF8String(i).copy());
       } else if (pdt instanceof PhysicalBinaryType) {
         row.update(i, getBinary(i));
-      } else if (pdt instanceof PhysicalDecimalType) {
-        PhysicalDecimalType t = (PhysicalDecimalType)pdt;
+      } else if (pdt instanceof PhysicalDecimalType t) {
         row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
       } else if (pdt instanceof PhysicalStructType) {
         row.update(i, getStruct(i, ((PhysicalStructType) pdt).fields().length).copy());
@@ -169,8 +168,7 @@ public Object get(int ordinal, DataType dataType) {
       return getUTF8String(ordinal);
     } else if (dataType instanceof BinaryType) {
       return getBinary(ordinal);
-    } else if (dataType instanceof DecimalType) {
-      DecimalType t = (DecimalType) dataType;
+    } else if (dataType instanceof DecimalType t) {
       return getDecimal(ordinal, t.precision(), t.scale());
     } else if (dataType instanceof DateType) {
       return getInt(ordinal);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index c4fbc2ff64229..1df4653f55276 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -75,8 +75,7 @@ public InternalRow copy() {
         row.update(i, getUTF8String(i).copy());
       } else if (pdt instanceof PhysicalBinaryType) {
         row.update(i, getBinary(i));
-      } else if (pdt instanceof PhysicalDecimalType) {
-        PhysicalDecimalType t = (PhysicalDecimalType)pdt;
+      } else if (pdt instanceof PhysicalDecimalType t) {
         row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
       } else if (pdt instanceof PhysicalStructType) {
         row.update(i, getStruct(i, ((PhysicalStructType) pdt).fields().length).copy());
@@ -176,8 +175,7 @@ public Object get(int ordinal, DataType dataType) {
       return getUTF8String(ordinal);
     } else if (dataType instanceof BinaryType) {
       return getBinary(ordinal);
-    } else if (dataType instanceof DecimalType) {
-      DecimalType t = (DecimalType) dataType;
+    } else if (dataType instanceof DecimalType t) {
       return getDecimal(ordinal, t.precision(), t.scale());
     } else if (dataType instanceof DateType) {
       return getInt(ordinal);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
index 89f6996e4610f..6fbd76538aa8f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
@@ -40,8 +40,7 @@ static OrcColumnVector toOrcColumnVector(DataType type, ColumnVector vector) {
       vector instanceof DecimalColumnVector ||
       vector instanceof TimestampColumnVector) {
       return new OrcAtomicColumnVector(type, vector);
-    } else if (vector instanceof StructColumnVector) {
-      StructColumnVector structVector = (StructColumnVector) vector;
+    } else if (vector instanceof StructColumnVector structVector) {
       OrcColumnVector[] fields = new OrcColumnVector[structVector.fields.length];
       int ordinal = 0;
       for (StructField f : ((StructType) type).fields()) {
@@ -49,13 +48,11 @@ static OrcColumnVector toOrcColumnVector(DataType type, ColumnVector vector) {
         ordinal++;
       }
       return new OrcStructColumnVector(type, vector, fields);
-    } else if (vector instanceof ListColumnVector) {
-      ListColumnVector listVector = (ListColumnVector) vector;
+    } else if (vector instanceof ListColumnVector listVector) {
       OrcColumnVector dataVector = toOrcColumnVector(
         ((ArrayType) type).elementType(), listVector.child);
       return new OrcArrayColumnVector(type, vector, dataVector);
-    } else if (vector instanceof MapColumnVector) {
-      MapColumnVector mapVector = (MapColumnVector) vector;
+    } else if (vector instanceof MapColumnVector mapVector) {
       MapType mapType = (MapType) type;
       OrcColumnVector keysVector = toOrcColumnVector(mapType.keyType(), mapVector.keys);
       OrcColumnVector valuesVector = toOrcColumnVector(mapType.valueType(), mapVector.values);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 15d58f0c7572a..d5675db4c3ad9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -1143,8 +1143,7 @@ private static boolean canReadAsBinaryDecimal(ColumnDescriptor descriptor, DataT
   }

   private static boolean isLongDecimal(DataType dt) {
-    if (dt instanceof DecimalType) {
-      DecimalType d = (DecimalType) dt;
+    if (dt instanceof DecimalType d) {
       return d.precision() == 20 && d.scale() == 0;
     }
     return false;
@@ -1153,8 +1152,7 @@ private static boolean isLongDecimal(DataType dt) {
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
     DecimalType d = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
-      DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
+    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation decimalType) {
       // It's OK if the required decimal precision is larger than or equal to the physical decimal
       // precision in the Parquet metadata, as long as the decimal scale is the same.
       return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 71ea3e9ce097f..baefa254466ff 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -403,9 +403,8 @@ private boolean containsPath(Type parquetType, String[] path) {

   private boolean containsPath(Type parquetType, String[] path, int depth) {
     if (path.length == depth) return true;
-    if (parquetType instanceof GroupType) {
+    if (parquetType instanceof GroupType parquetGroupType) {
       String fieldName = path[depth];
-      GroupType parquetGroupType = (GroupType) parquetType;
       if (parquetGroupType.containsField(fieldName)) {
         return containsPath(parquetGroupType.getType(fieldName), path, depth + 1);
       }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 6ab81cf404839..7b841ab9933e2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -75,8 +75,7 @@ public static void populate(ConstantColumnVector col, InternalRow row, int field
     } else if (pdt instanceof PhysicalStringType) {
       UTF8String v = row.getUTF8String(fieldIdx);
       col.setUtf8String(v);
-    } else if (pdt instanceof PhysicalDecimalType) {
-      PhysicalDecimalType dt = (PhysicalDecimalType) pdt;
+    } else if (pdt instanceof PhysicalDecimalType dt) {
       Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale());
       if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
         col.setInt((int)d.toUnscaledLong());
@@ -151,8 +150,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o)
     } else if (t == DataTypes.BinaryType) {
       byte[] b = (byte[]) o;
       dst.appendByteArray(b, 0, b.length);
-    } else if (t instanceof DecimalType) {
-      DecimalType dt = (DecimalType) t;
+    } else if (t instanceof DecimalType dt) {
       Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
       if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
         dst.appendInt((int) d.toUnscaledLong());
@@ -182,8 +180,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o)
   }

   private static void appendValue(WritableColumnVector dst, DataType t, Row src, int fieldIdx) {
-    if (t instanceof ArrayType) {
-      ArrayType at = (ArrayType)t;
+    if (t instanceof ArrayType at) {
       if (src.isNullAt(fieldIdx)) {
         dst.appendNull();
       } else {
@@ -193,8 +190,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i
           appendValue(dst.arrayData(), at.elementType(), o);
         }
       }
-    } else if (t instanceof StructType) {
-      StructType st = (StructType)t;
+    } else if (t instanceof StructType st) {
       if (src.isNullAt(fieldIdx)) {
         dst.appendStruct(true);
       } else {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 64568f18f6858..eda58815f3b3a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -73,8 +73,7 @@ public InternalRow copy() {
         row.update(i, getUTF8String(i).copy());
       } else if (dt instanceof BinaryType) {
         row.update(i, getBinary(i));
-      } else if (dt instanceof DecimalType) {
-        DecimalType t = (DecimalType)dt;
+      } else if (dt instanceof DecimalType t) {
         row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
       } else if (dt instanceof DateType) {
         row.setInt(i, getInt(i));
@@ -178,8 +177,7 @@ public Object get(int ordinal, DataType dataType) {
       return getUTF8String(ordinal);
     } else if (dataType instanceof BinaryType) {
       return getBinary(ordinal);
-    } else if (dataType instanceof DecimalType) {
-      DecimalType t = (DecimalType) dataType;
+    } else if (dataType instanceof DecimalType t) {
       return getDecimal(ordinal, t.precision(), t.scale());
     } else if (dataType instanceof DateType) {
       return getInt(ordinal);
@@ -214,8 +212,7 @@ public void update(int ordinal, Object value) {
       setFloat(ordinal, (float) value);
     } else if (dt instanceof DoubleType) {
       setDouble(ordinal, (double) value);
-    } else if (dt instanceof DecimalType) {
-      DecimalType t = (DecimalType) dt;
+    } else if (dt instanceof DecimalType t) {
       Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
       setDecimal(ordinal, d, t.precision());
     } else if (dt instanceof CalendarIntervalType) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index ac8da471f0033..4c8ceff356595 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -721,8 +721,7 @@ public Optional<Integer> appendObjects(int length, Object value) {
     if (value instanceof Byte) {
       return Optional.of(appendBytes(length, (Byte) value));
     }
-    if (value instanceof Decimal) {
-      Decimal decimal = (Decimal) value;
+    if (value instanceof Decimal decimal) {
       long unscaled = decimal.toUnscaledLong();
       if (decimal.precision() < 10) {
         return Optional.of(appendInts(length, (int) unscaled));
@@ -745,8 +744,7 @@ public Optional<Integer> appendObjects(int length, Object value) {
     if (value instanceof Short) {
       return Optional.of(appendShorts(length, (Short) value));
     }
-    if (value instanceof UTF8String) {
-      UTF8String utf8 = (UTF8String) value;
+    if (value instanceof UTF8String utf8) {
      byte[] bytes = utf8.getBytes();
      int result = 0;
      for (int i = 0; i < length; ++i) {
@@ -754,8 +752,7 @@ public Optional<Integer> appendObjects(int length, Object value) {
       }
       return Optional.of(result);
     }
-    if (value instanceof GenericArrayData) {
-      GenericArrayData arrayData = (GenericArrayData) value;
+    if (value instanceof GenericArrayData arrayData) {
       int result = 0;
       for (int i = 0; i < length; ++i) {
         appendArray(arrayData.numElements());
@@ -768,8 +765,7 @@ public Optional<Integer> appendObjects(int length, Object value) {
       }
       return Optional.of(result);
     }
-    if (value instanceof GenericInternalRow) {
-      GenericInternalRow row = (GenericInternalRow) value;
+    if (value instanceof GenericInternalRow row) {
       int result = 0;
       for (int i = 0; i < length; ++i) {
         appendStruct(false);
@@ -783,8 +779,7 @@ public Optional<Integer> appendObjects(int length, Object value) {
       }
       return Optional.of(result);
     }
-    if (value instanceof ArrayBasedMapData) {
-      ArrayBasedMapData data = (ArrayBasedMapData) value;
+    if (value instanceof ArrayBasedMapData data) {
       appendArray(length);
       int result = 0;
       for (int i = 0; i < length; ++i) {
@@ -965,14 +960,12 @@ protected WritableColumnVector(int capacity, DataType dataType) {
       }
       this.childColumns = new WritableColumnVector[1];
       this.childColumns[0] = reserveNewColumn(childCapacity, childType);
-    } else if (type instanceof StructType) {
-      StructType st = (StructType)type;
+    } else if (type instanceof StructType st) {
       this.childColumns = new WritableColumnVector[st.fields().length];
       for (int i = 0; i < childColumns.length; ++i) {
         this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
       }
-    } else if (type instanceof MapType) {
-      MapType mapType = (MapType) type;
+    } else if (type instanceof MapType mapType) {
       this.childColumns = new WritableColumnVector[2];
       this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType());
       this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType());
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index f5884efd8b23c..a83041dc522c6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -282,8 +282,7 @@ public int hashCode() {

     @Override
     public boolean equals(Object obj) {
-      if (!(obj instanceof ArrayRecord)) return false;
-      ArrayRecord other = (ArrayRecord) obj;
+      if (!(obj instanceof ArrayRecord other)) return false;
       return (other.id == this.id) && Objects.equals(other.intervals, this.intervals) &&
         Arrays.equals(other.ints, ints);
     }
@@ -330,8 +329,7 @@ public int hashCode() {

     @Override
     public boolean equals(Object obj) {
-      if (!(obj instanceof MapRecord)) return false;
-      MapRecord other = (MapRecord) obj;
+      if (!(obj instanceof MapRecord other)) return false;
       return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
     }

@@ -376,8 +374,7 @@ public int hashCode() {

     @Override
     public boolean equals(Object obj) {
-      if (!(obj instanceof Interval)) return false;
-      Interval other = (Interval) obj;
+      if (!(obj instanceof Interval other)) return false;
       return (other.startTime == this.startTime) && (other.endTime == this.endTime);
     }

@@ -635,10 +632,9 @@ public String toString() {
     }

     public boolean equals(Object o) {
-      if (!(o instanceof Item)) {
+      if (!(o instanceof Item other)) {
         return false;
       }
-      Item other = (Item) o;
       if (other.getK().equals(k) && other.getV() == v) {
         return true;
       }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index d8068d57ee5e3..254c6df282091 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1752,8 +1752,7 @@ public int hashCode() {
     }

     public boolean equals(Object other) {
-      if (other instanceof BeanWithEnum) {
-        BeanWithEnum beanWithEnum = (BeanWithEnum) other;
+      if (other instanceof BeanWithEnum beanWithEnum) {
         return beanWithEnum.regularField.equals(regularField)
           && beanWithEnum.enumField.equals(enumField);
       }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java
index 1a55d198361ee..0c12fd5484a65 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java
@@ -61,8 +61,7 @@ public StructType readSchema() {
     @Override
     public Filter[] pushFilters(Filter[] filters) {
       Filter[] supported = Arrays.stream(filters).filter(f -> {
-        if (f instanceof GreaterThan) {
-          GreaterThan gt = (GreaterThan) f;
+        if (f instanceof GreaterThan gt) {
           return gt.attribute().equals("i") && gt.value() instanceof Integer;
         } else {
           return false;
@@ -70,8 +69,7 @@ public Filter[] pushFilters(Filter[] filters) {
       }).toArray(Filter[]::new);

       Filter[] unsupported = Arrays.stream(filters).filter(f -> {
-        if (f instanceof GreaterThan) {
-          GreaterThan gt = (GreaterThan) f;
+        if (f instanceof GreaterThan gt) {
           return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
         } else {
           return true;
@@ -114,8 +112,7 @@ public InputPartition[] planInputPartitions() {

       Integer lowerBound = null;
       for (Filter filter : filters) {
-        if (filter instanceof GreaterThan) {
-          GreaterThan f = (GreaterThan) filter;
+        if (filter instanceof GreaterThan f) {
           if ("i".equals(f.attribute()) && f.value() instanceof Integer) {
             lowerBound = (Integer) f.value();
             break;
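
Every hunk in this patch is the same mechanical rewrite: Java 16 pattern matching for instanceof (JEP 394) folds the type test and the follow-up cast into a single test that binds a pattern variable exactly where the test is known to have succeeded. The standalone sketch below is illustrative only; the Circle/Rect types are hypothetical stand-ins, not code from this patch. It shows the three shapes the diff uses: the plain if/else-if chain, a pattern combined with a further condition via && (as in the UTF8String hunk), and the negated early-return form used by the equals() hunks.

// Illustrative sketch only; hypothetical types, not part of the patch.
public class InstanceofPatternDemo {

  record Circle(double radius) {}
  record Rect(double w, double h) {}

  // Before (pre-Java 16): test, then a separate explicit cast into a new local.
  static double areaOld(Object shape) {
    if (shape instanceof Circle) {
      Circle c = (Circle) shape;
      return Math.PI * c.radius() * c.radius();
    } else if (shape instanceof Rect) {
      Rect r = (Rect) shape;
      return r.w() * r.h();
    }
    return 0.0;
  }

  // After (Java 16+): the pattern variable is declared in the test itself
  // and is definitely assigned on every path where the test is true.
  static double areaNew(Object shape) {
    if (shape instanceof Circle c) {
      return Math.PI * c.radius() * c.radius();
    } else if (shape instanceof Rect r) {
      return r.w() * r.h();
    }
    return 0.0;
  }

  // The binding also flows across && (cf. the UTF8String hunk, where the
  // pattern is combined with an extra offset check).
  static boolean isWide(Object shape) {
    return shape instanceof Rect r && r.w() > r.h();
  }

  // ...and through a negated test, which is why the equals() hunks can write
  // "if (!(o instanceof Foo that)) return false;" and then use "that" after
  // the early return.
  static boolean sameWidth(Object o, double w) {
    if (!(o instanceof Rect that)) {
      return false;
    }
    return that.w() == w;
  }

  public static void main(String[] args) {
    System.out.println(areaOld(new Circle(1)));         // 3.141592653589793
    System.out.println(areaNew(new Rect(3, 2)));        // 6.0
    System.out.println(isWide(new Rect(3, 2)));         // true
    System.out.println(sameWidth(new Rect(3, 2), 3.0)); // true
  }
}

The guard in isWide type-checks because the compiler only lets r be read on paths where the instanceof succeeded; that same definite-assignment analysis is what makes the one-line equals() rewrites throughout this patch behavior-preserving.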