diff --git a/eea-generator/pom.xml b/eea-generator/pom.xml index a75200cf15..315b7b85a9 100644 --- a/eea-generator/pom.xml +++ b/eea-generator/pom.xml @@ -5,7 +5,9 @@ SPDX-FileContributor: Sebastian Thomschke (https://sebthom.de), Vegard IT GmbH ( SPDX-License-Identifier: EPL-2.0 SPDX-ArtifactOfProjectHomePage: https://github.com/vegardit/no-npe --> - + 4.0.0 @@ -53,7 +55,12 @@ SPDX-ArtifactOfProjectHomePage: https://github.com/vegardit/no-npe io.github.classgraph classgraph - 4.8.176 + 4.8.177 + + + org.ow2.asm + asm-util + 9.7 diff --git a/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/EEAGenerator.java b/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/EEAGenerator.java index 93cfa03b5d..765b4ee2aa 100644 --- a/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/EEAGenerator.java +++ b/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/EEAGenerator.java @@ -38,6 +38,9 @@ import com.vegardit.no_npe.eea_generator.EEAFile.ClassMember; import com.vegardit.no_npe.eea_generator.EEAFile.SaveOption; import com.vegardit.no_npe.eea_generator.EEAFile.ValueWithComment; +import com.vegardit.no_npe.eea_generator.internal.BytecodeAnalyzer; +import com.vegardit.no_npe.eea_generator.internal.ClassGraphUtils; +import com.vegardit.no_npe.eea_generator.internal.ClassGraphUtils.MethodReturnKind; import com.vegardit.no_npe.eea_generator.internal.Props; import io.github.classgraph.ClassGraph; @@ -218,7 +221,7 @@ public static void main(final String... args) throws Exception { } protected static ValueWithComment computeAnnotatedSignature(final EEAFile.ClassMember member, final ClassInfo classInfo, - final ClassMemberInfo memberInfo) { + final ClassMemberInfo memberInfo, final BytecodeAnalyzer bytecodeAnalyzer) { final var templates = new ArrayList(); if (isThrowable(classInfo)) { @@ -238,73 +241,91 @@ protected static ValueWithComment computeAnnotatedSignature(final EEAFile.ClassM if (memberInfo instanceof MethodInfo) { final MethodInfo methodInfo = (MethodInfo) memberInfo; - /* - * mark the return value of builder methods as @NonNull. - */ - if (classInfo.getName().endsWith("Builder") // - && !methodInfo.isStatic() // non-static - && methodInfo.isPublic() // - && methodInfo.getTypeDescriptor().getResultType() instanceof ClassRefTypeSignature // - && (methodInfo.getName().equals("build") && methodInfo.getParameterInfo().length == 0 // - || Objects.equals(((ClassRefTypeSignature) methodInfo.getTypeDescriptor().getResultType()).getClassInfo(), classInfo))) - // (...)Lcom/example/MyBuilder -> (...)L1com/example/MyBuilder; - return new ValueWithComment(insert(member.originalSignature.value, member.originalSignature.value.lastIndexOf(")") + 2, "1"), - ""); - - /* - * mark the parameter of Comparable#compareTo(Object) as @NonNull. - */ - if (classInfo.implementsInterface("java.lang.Comparable") // - && !methodInfo.isStatic() // non-static - && member.originalSignature.value.endsWith(")I") // returns Integer - && methodInfo.isPublic() // - && methodInfo.getParameterInfo().length == 1 // only 1 parameter - && methodInfo.getParameterInfo()[0].getTypeDescriptor() instanceof ClassRefTypeSignature) - // (Lcom/example/Entity;)I -> (L1com/example/Entity;)I - return new ValueWithComment(insert(member.originalSignature.value, 2, "1"), ""); - - /* - * mark the parameter of single-parameter void methods as @NonNull, - * if the class name matches "*Listener" and the parameter type name matches "*Event" - */ - if (classInfo.isInterface() // - && classInfo.getName().endsWith("Listener") // - && !methodInfo.isStatic() // non-static - && member.originalSignature.value.endsWith(")V") // returns void - && methodInfo.getParameterInfo().length == 1 // only 1 parameter - && methodInfo.getParameterInfo()[0].getTypeDescriptor().toString().endsWith("Event")) - // (Ljava/lang/String;)V -> (L1java/lang/String;)V - return new ValueWithComment(insert(member.originalSignature.value, 2, "1"), ""); - - /* - * mark the parameter of single-parameter methods as @NonNull - * with signature matching: void (add|remove)*Listener(*Listener) - */ - if (!methodInfo.isStatic() // non-static - && (methodInfo.getName().startsWith("add") || methodInfo.getName().startsWith("remove")) // - && methodInfo.getName().endsWith("Listener") // - && member.originalSignature.value.endsWith(")V") // returns void - && methodInfo.getParameterInfo().length == 1 // only 1 parameter - && methodInfo.getParameterInfo()[0].getTypeDescriptor().toString().endsWith("Listener")) - return new ValueWithComment( // - member.originalSignature.value.startsWith("(") // - // (Lcom/example/MyListener;)V -> (L1com/example/MyListener;)V - // (TT;)V -> (T1T;)V - ? insert(member.originalSignature.value, 2, "1") // - // (TT;)V --> <1T::Lcom/example/MyListener;>(TT;)V - : insert(member.originalSignature.value, 1, "1"), // - ""); - - if (hasObjectReturnType(member)) { // returns non-void - if (hasNullableAnnotation(methodInfo.getAnnotationInfo())) + final var returnKind = ClassGraphUtils.getMethodReturnKind(methodInfo); + if (returnKind == MethodReturnKind.ARRAY || returnKind == MethodReturnKind.OBJECT) { + + final var returnTypeNullability = bytecodeAnalyzer.determineMethodReturnTypeNullability(methodInfo); + /* + * mark the return value of a method as nullable if the byte code analysis of the method body determines it returns null values + * or the method is annotated with a known nullable annotation. + */ + if (returnTypeNullability.isNullable() // + || hasNullableAnnotation(methodInfo.getAnnotationInfo())) // ()Ljava/lang/String -> ()L0java/lang/String; return new ValueWithComment(insert(member.originalSignature.value, member.originalSignature.value.lastIndexOf(")") + 2, "0"), ""); - if (hasNonNullAnnotation(methodInfo.getAnnotationInfo())) + /* + * mark the return value of a method as non-null if the method is annotated with a non-null annotation + * or has a method name starting with "create". + */ + if (returnTypeNullability.isNonNull() // + || hasNonNullAnnotation(methodInfo.getAnnotationInfo()) // + || methodInfo.getName().startsWith("create")) // ()Ljava/lang/String -> ()L1java/lang/String; + // create...(...)LLcom/example/Entity -> create...(...)L1Lcom/example/Entity; + return new ValueWithComment(insert(member.originalSignature.value, member.originalSignature.value.lastIndexOf(")") + 2, "1"), + ""); + + /* + * mark the return value of builder methods as @NonNull. + */ + if (classInfo.getName().endsWith("Builder") // + && !methodInfo.isStatic() // non-static + && methodInfo.isPublic() // + && methodInfo.getTypeDescriptor().getResultType() instanceof ClassRefTypeSignature // + && (methodInfo.getName().equals("build") && methodInfo.getParameterInfo().length == 0 // + || Objects.equals(((ClassRefTypeSignature) methodInfo.getTypeDescriptor().getResultType()).getClassInfo(), + classInfo))) + // (...)Lcom/example/MyBuilder -> (...)L1com/example/MyBuilder; return new ValueWithComment(insert(member.originalSignature.value, member.originalSignature.value.lastIndexOf(")") + 2, "1"), ""); + + } else { + + /* + * mark the parameter of Comparable#compareTo(Object) as @NonNull. + */ + if (classInfo.implementsInterface("java.lang.Comparable") // + && !methodInfo.isStatic() // non-static + && member.originalSignature.value.endsWith(")I") // returns Integer + && methodInfo.isPublic() // + && methodInfo.getParameterInfo().length == 1 // only 1 parameter + && methodInfo.getParameterInfo()[0].getTypeDescriptor() instanceof ClassRefTypeSignature) + // (Lcom/example/Entity;)I -> (L1com/example/Entity;)I + return new ValueWithComment(insert(member.originalSignature.value, 2, "1"), ""); + + /* + * mark the parameter of single-parameter void methods as @NonNull, + * if the class name matches "*Listener" and the parameter type name matches "*Event" + */ + if (classInfo.isInterface() // + && classInfo.getName().endsWith("Listener") // + && !methodInfo.isStatic() // non-static + && member.originalSignature.value.endsWith(")V") // returns void + && methodInfo.getParameterInfo().length == 1 // only 1 parameter + && methodInfo.getParameterInfo()[0].getTypeDescriptor().toString().endsWith("Event")) + // (Ljava/lang/String;)V -> (L1java/lang/String;)V + return new ValueWithComment(insert(member.originalSignature.value, 2, "1"), ""); + + /* + * mark the parameter of single-parameter methods as @NonNull + * with signature matching: void (add|remove)*Listener(*Listener) + */ + if (!methodInfo.isStatic() // non-static + && (methodInfo.getName().startsWith("add") || methodInfo.getName().startsWith("remove")) // + && methodInfo.getName().endsWith("Listener") // + && member.originalSignature.value.endsWith(")V") // returns void + && methodInfo.getParameterInfo().length == 1 // only 1 parameter + && methodInfo.getParameterInfo()[0].getTypeDescriptor().toString().endsWith("Listener")) + return new ValueWithComment( // + member.originalSignature.value.startsWith("(") // + // (Lcom/example/MyListener;)V -> (L1com/example/MyListener;)V + // (TT;)V -> (T1T;)V + ? insert(member.originalSignature.value, 2, "1") // + // (TT;)V --> <1T::Lcom/example/MyListener;>(TT;)V + : insert(member.originalSignature.value, 1, "1"), // + ""); } } @@ -324,14 +345,6 @@ protected static ValueWithComment computeAnnotatedSignature(final EEAFile.ClassM return new ValueWithComment(member.originalSignature.value); } - protected static boolean hasObjectReturnType(final EEAFile.ClassMember member) { - final String sig = member.originalSignature.value; - // object return type: (Ljava/lang/String;)Ljava/lang/String; or (Ljava/lang/String;)TT; - // void return type: (Ljava/lang/String;)V - // primitive return type: (Ljava/lang/String;)B - return sig.charAt(sig.length() - 2) != ')'; - } - protected static EEAFile computeEEAFile(final ClassInfo classInfo) { LOG.log(Level.DEBUG, "Scanning class [{0}]...", classInfo.getName()); @@ -390,6 +403,8 @@ protected static EEAFile computeEEAFile(final ClassInfo classInfo) { } eeaFile.addEmptyLine(); + final var bytecodeAnalyzer = new BytecodeAnalyzer(classInfo); + // static fields for (final FieldInfo f : getStaticFields(fields)) { if (classInfo.isEnum()) { @@ -400,28 +415,28 @@ protected static EEAFile computeEEAFile(final ClassInfo classInfo) { } final var member = eeaFile.addMember(f.getName(), f.getTypeSignatureOrTypeDescriptorStr()); // CHECKSTYLE:IGNORE .* - member.annotatedSignature = computeAnnotatedSignature(member, classInfo, f); + member.annotatedSignature = computeAnnotatedSignature(member, classInfo, f, bytecodeAnalyzer); } eeaFile.addEmptyLine(); // static methods for (final MethodInfo m : getStaticMethods(methods)) { final var member = eeaFile.addMember(m.getName(), m.getTypeSignatureOrTypeDescriptorStr()); - member.annotatedSignature = computeAnnotatedSignature(member, classInfo, m); + member.annotatedSignature = computeAnnotatedSignature(member, classInfo, m, bytecodeAnalyzer); } eeaFile.addEmptyLine(); // instance fields for (final FieldInfo f : getInstanceFields(fields)) { final var member = eeaFile.addMember(f.getName(), f.getTypeSignatureOrTypeDescriptorStr()); // CHECKSTYLE:IGNORE .* - member.annotatedSignature = computeAnnotatedSignature(member, classInfo, f); + member.annotatedSignature = computeAnnotatedSignature(member, classInfo, f, bytecodeAnalyzer); } eeaFile.addEmptyLine(); // instance methods for (final MethodInfo m : getInstanceMethods(methods)) { final var member = eeaFile.addMember(m.getName(), m.getTypeSignatureOrTypeDescriptorStr()); // CHECKSTYLE:IGNORE .* - member.annotatedSignature = computeAnnotatedSignature(member, classInfo, m); + member.annotatedSignature = computeAnnotatedSignature(member, classInfo, m, bytecodeAnalyzer); } return eeaFile; } diff --git a/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/BytecodeAnalyzer.java b/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/BytecodeAnalyzer.java new file mode 100644 index 0000000000..421f22ffac --- /dev/null +++ b/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/BytecodeAnalyzer.java @@ -0,0 +1,257 @@ +/* + * SPDX-FileCopyrightText: © Vegard IT GmbH (https://vegardit.com) and contributors. + * SPDX-License-Identifier: EPL-2.0 + */ +package com.vegardit.no_npe.eea_generator.internal; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import org.eclipse.jdt.annotation.NonNullByDefault; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.ConstantDynamic; +import org.objectweb.asm.Handle; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +import io.github.classgraph.ClassInfo; +import io.github.classgraph.MethodInfo; + +/** + * Analyzes bytecode to determine the nullability of method return types. + * + * @author Sebastian Thomschke (https://sebthom.de), Vegard IT GmbH (https://vegardit.com) + */ +public class BytecodeAnalyzer { + + public enum Nullability { + NON_NULL, + NULLABLE, + UNKNOWN; + + public boolean isNullable() { + return this == NULLABLE; + } + + public boolean isNonNull() { + return this == NON_NULL; + } + } + + private final ClassReader classReader; + + public BytecodeAnalyzer(final ClassInfo classInfo) { + try (var classFileResource = classInfo.getResource()) { + if (classFileResource == null) + throw new IOException("Class resource not found: " + classInfo); + try (var is = classFileResource.open()) { + classReader = new ClassReader(is); + } + } catch (final IOException ex) { + throw new UncheckedIOException("Failed to read class resource: " + classInfo, ex); + } + } + + public Nullability determineMethodReturnTypeNullability(final MethodInfo methodInfo) { + switch (ClassGraphUtils.getMethodReturnKind(methodInfo)) { + case PRIMITIVE: + case VOID: + return Nullability.NON_NULL; + default: + } + + if (methodInfo.isAbstract()) + return Nullability.UNKNOWN; + + final String methodName = methodInfo.getName(); + final String methodDescriptor = methodInfo.getTypeDescriptorStr(); + + final var returnNullabilities = new ArrayList(); + classReader.accept(new ClassVisitor(Opcodes.ASM9) { + @Override + @NonNullByDefault({}) + public MethodVisitor visitMethod(final int access, final String name, final String descriptor, final String signature, + final String[] exceptions) { + if (name.equals(methodName) && descriptor.equals(methodDescriptor)) + return new MethodVisitor(Opcodes.ASM9) { + private final ArrayDeque operandStack = new ArrayDeque<>(); + private final Map localVariableNullability = new HashMap<>(); + + private boolean isKnownNonNullMethod(final String clazz, final String methodName, final String descriptor) { + // CHECKSTYLE:IGNORE .* FOR NEXT 5 LINES + return methodName.equals("") // + || methodName.equals("toString") && descriptor.equals("()Ljava/lang/String;") // + || clazz.equals("java/lang/String") && methodName.equals("valueOf") && descriptor.equals( + "(Ljava/lang/Object;)Ljava/lang/String;") // + || clazz.equals("java/lang/StringBuilder") && methodName.startsWith("append"); + } + + @Override + public void visitFieldInsn(final int opcode, final String owner, final String name, final String descriptor) { + switch (opcode) { + case Opcodes.GETSTATIC: + case Opcodes.GETFIELD: + // pushing a field value onto the stack; nullability unknown + operandStack.push(Nullability.UNKNOWN); + break; + case Opcodes.PUTSTATIC: + case Opcodes.PUTFIELD: + if (!operandStack.isEmpty()) { + operandStack.pop(); + } + break; + default: + operandStack.push(Nullability.UNKNOWN); + break; + } + } + + @Override + public void visitInsn(final int opcode) { + switch (opcode) { + case Opcodes.ACONST_NULL: + operandStack.push(Nullability.NULLABLE); + break; + case Opcodes.ARETURN: + if (operandStack.isEmpty()) { + // stack underflow; treat as possibly null + returnNullabilities.add(Nullability.UNKNOWN); + } else { + final Nullability returnValue = operandStack.pop(); + returnNullabilities.add(returnValue); + } + // clear the operand stack after a return + operandStack.clear(); + break; + case Opcodes.DUP: + if (operandStack.isEmpty()) { + // stack underflow; treat as possibly null + operandStack.push(Nullability.UNKNOWN); + } else { + final Nullability top = operandStack.peek(); + operandStack.push(top); + } + break; + case Opcodes.POP: + if (!operandStack.isEmpty()) { + operandStack.pop(); + } + break; + default: + // for other instructions, assume they might alter the stack + // for simplicity, we reset the stack to UNKNOWN + operandStack.clear(); + operandStack.push(Nullability.UNKNOWN); + break; + } + } + + @Override + public void visitLdcInsn(final Object constant) { + if (constant instanceof Integer || constant instanceof Float // + || constant instanceof Long || constant instanceof Double // + || constant instanceof String // + || constant instanceof Type || constant instanceof Handle) { + // primitive constants, string constants, class literals, method handles are non-null + operandStack.push(Nullability.NON_NULL); + } else if (constant instanceof ConstantDynamic) { + // ConstantDynamic may resolve to null, so treat it as possibly null + operandStack.push(Nullability.UNKNOWN); + } else { + // handle other unexpected types conservatively as possibly null + operandStack.push(Nullability.UNKNOWN); + } + } + + @Override + public void visitMethodInsn(final int opcode, final String owner, final String name, final String descriptor, + final boolean isInterface) { + // pop arguments off the stack, push return value + final Type methodType = Type.getMethodType(descriptor); + for (int i = 0, argCount = methodType.getArgumentTypes().length; i < argCount; i++) { + if (!operandStack.isEmpty()) { + operandStack.pop(); + } + } + if (opcode != Opcodes.INVOKESTATIC && opcode != Opcodes.INVOKEDYNAMIC) { + // pop 'this' reference + if (!operandStack.isEmpty()) { + operandStack.pop(); + } + } + + // push the return value onto the stack + switch (methodType.getReturnType().getSort()) { + case Type.VOID: + break; + case Type.OBJECT: + case Type.ARRAY: + if (isKnownNonNullMethod(owner, name, descriptor)) { + operandStack.push(Nullability.NON_NULL); + } else { + // reference type; nullability unknown + operandStack.push(Nullability.UNKNOWN); + } + break; + default: + // primitive type; definitely non-null + operandStack.push(Nullability.NON_NULL); + } + } + + @Override + public void visitTypeInsn(final int opcode, final String type) { + if (opcode == Opcodes.NEW) { + operandStack.push(Nullability.NON_NULL); + } else { + operandStack.push(Nullability.UNKNOWN); + } + } + + @Override + public void visitVarInsn(final int opcode, final int varIndex) { + switch (opcode) { + case Opcodes.ALOAD: + // loading a reference from a local variable + final Nullability varNullability = localVariableNullability.getOrDefault(varIndex, Nullability.UNKNOWN); + operandStack.push(varNullability); + break; + case Opcodes.ASTORE: + // storing a value into a local variable + if (!operandStack.isEmpty()) { // CHECKSTYLE:IGNORE .* + final Nullability valueNullability = operandStack.pop(); + localVariableNullability.put(varIndex, valueNullability); + } else { + // stack underflow; assume possibly null + localVariableNullability.put(varIndex, Nullability.UNKNOWN); + } + break; + default: + operandStack.push(Nullability.UNKNOWN); + break; + } + } + }; + return super.visitMethod(access, name, descriptor, signature, exceptions); + } + }, 0); + + if (returnNullabilities.isEmpty()) + // no return statements found (shouldn't happen); nullability unknown + return Nullability.UNKNOWN; + + if (returnNullabilities.contains(Nullability.NULLABLE)) + return Nullability.NULLABLE; + + if (returnNullabilities.contains(Nullability.UNKNOWN)) + return Nullability.UNKNOWN; + + return Nullability.NON_NULL; + } +} diff --git a/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/ClassGraphUtils.java b/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/ClassGraphUtils.java index 54855d928e..1573854748 100644 --- a/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/ClassGraphUtils.java +++ b/eea-generator/src/main/java/com/vegardit/no_npe/eea_generator/internal/ClassGraphUtils.java @@ -10,17 +10,28 @@ import java.util.TreeSet; import io.github.classgraph.AnnotationInfoList; +import io.github.classgraph.ArrayTypeSignature; +import io.github.classgraph.BaseTypeSignature; import io.github.classgraph.ClassInfo; +import io.github.classgraph.ClassRefTypeSignature; import io.github.classgraph.FieldInfo; import io.github.classgraph.FieldInfoList; import io.github.classgraph.MethodInfo; import io.github.classgraph.MethodInfoList; +import io.github.classgraph.TypeSignature; /** * @author Sebastian Thomschke (https://sebthom.de), Vegard IT GmbH (https://vegardit.com) */ public final class ClassGraphUtils { + public enum MethodReturnKind { + ARRAY, + PRIMITIVE, + OBJECT, + VOID + } + private static final Set NULLABLE_ANNOTATIONS = Set.of( // "android.annotation.Nullable", // "android.support.annotation.Nullable", // @@ -160,6 +171,26 @@ public static SortedSet getInstanceMethods(final MethodInfoList meth return getFilteredAndSortedMethods(methods, false); } + /** + * Determines the return kind of the given method. + * This method distinguishes between methods that return objects, arrays, primitive types, or void. + * + * @param methodInfo the method whose return type is to be checked + * @return MethodReturnKind representing if the method returns an object, array, primitive, or void + */ + public static MethodReturnKind getMethodReturnKind(final MethodInfo methodInfo) { + final TypeSignature returnType = methodInfo.getTypeDescriptor().getResultType(); + if (returnType == null) + return MethodReturnKind.VOID; + if (returnType instanceof BaseTypeSignature) + return MethodReturnKind.PRIMITIVE; + if (returnType instanceof ArrayTypeSignature) + return MethodReturnKind.ARRAY; + if (returnType instanceof ClassRefTypeSignature) + return MethodReturnKind.OBJECT; + throw new IllegalStateException("Unknown method return kind: " + returnType); + } + public static SortedSet getStaticFields(final FieldInfoList fields) { return getFilteredAndSortedFields(fields, true); } @@ -176,14 +207,14 @@ public static boolean hasNullableAnnotation(final AnnotationInfoList annos) { return annos.stream().anyMatch(a -> NULLABLE_ANNOTATIONS.contains(a.getName())); } - public static boolean hasSuperclass(final ClassInfo classInfo, final String superClassName) { - return !classInfo.getSuperclasses().filter(c -> c.getName().equals(superClassName)).isEmpty(); - } - public static boolean hasPackageVisibility(final ClassInfo classInfo) { return !classInfo.isPublic() && !classInfo.isPrivate() && !classInfo.isProtected(); } + public static boolean hasSuperclass(final ClassInfo classInfo, final String superClassName) { + return !classInfo.getSuperclasses().filter(c -> c.getName().equals(superClassName)).isEmpty(); + } + public static boolean isStaticField(final ClassInfo classInfo, final String fieldName) { final var fieldInfo = classInfo.getDeclaredFieldInfo(fieldName); if (fieldInfo == null) diff --git a/eea-generator/src/test/java/com/vegardit/no_npe/eea_generator/internal/BytecodeAnalyzerTest.java b/eea-generator/src/test/java/com/vegardit/no_npe/eea_generator/internal/BytecodeAnalyzerTest.java new file mode 100644 index 0000000000..e6440d25a5 --- /dev/null +++ b/eea-generator/src/test/java/com/vegardit/no_npe/eea_generator/internal/BytecodeAnalyzerTest.java @@ -0,0 +1,148 @@ +/* + * SPDX-FileCopyrightText: © Vegard IT GmbH (https://vegardit.com) and contributors. + * SPDX-License-Identifier: EPL-2.0 + */ +package com.vegardit.no_npe.eea_generator.internal; + +import static java.lang.annotation.ElementType.METHOD; +import static org.assertj.core.api.Assertions.assertThat; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.lang.reflect.Method; +import java.util.function.Supplier; + +import org.eclipse.jdt.annotation.Nullable; +import org.junit.jupiter.api.Test; + +import com.vegardit.no_npe.eea_generator.internal.BytecodeAnalyzer.Nullability; + +import io.github.classgraph.ClassGraph; +import io.github.classgraph.ScanResult; + +/** + * @author Sebastian Thomschke (https://sebthom.de), Vegard IT GmbH (https://vegardit.com) + */ +class BytecodeAnalyzerTest { + + @Retention(RetentionPolicy.RUNTIME) + @Target(METHOD) + public @interface ReturnValueNullability { + Nullability value(); + } + + @ReturnValueNullability(Nullability.NULLABLE) + static @Nullable Object returningAConstNull1() { + return null; + } + + @ReturnValueNullability(Nullability.NULLABLE) + static @Nullable Object returningAConstNull2() { + if (System.currentTimeMillis() == 123) + return "Hey"; + return null; + } + + @ReturnValueNullability(Nullability.NULLABLE) + static @Nullable Object returningAConstNull3() { + return System.currentTimeMillis() == 123 ? "Hey" : null; + } + + @ReturnValueNullability(Nullability.NULLABLE) + static @Nullable Object returningAConstNull4(final boolean condition) { + return condition ? "Hey" : null; + } + + @ReturnValueNullability(Nullability.NULLABLE) + static @Nullable Object returningAConstNull5() { + final String foo = null; + return foo; + } + + @ReturnValueNullability(Nullability.UNKNOWN) + static @Nullable Object returningDynamicNull1() { + return new @Nullable Object[] {null}[0]; + } + + @ReturnValueNullability(Nullability.UNKNOWN) + static @Nullable Object returningDynamicNull2() { + final @Nullable String env = System.getProperty("Abcdefg1234567"); + @SuppressWarnings("unused") + final var unused = new Object(); + return env; + } + + @ReturnValueNullability(Nullability.NON_NULL) + static Object neverReturningNull1() { + return "Hey"; + } + + @ReturnValueNullability(Nullability.NON_NULL) + static Object neverReturningNull2() { + return new Object(); + } + + @ReturnValueNullability(Nullability.NON_NULL) + static Object neverReturningNull3() { + return new String("Test"); + } + + @ReturnValueNullability(Nullability.NON_NULL) + static Object neverReturningNull4() { + return new Object() + " test"; + } + + @ReturnValueNullability(Nullability.NON_NULL) + public Object neverReturningNull5(final boolean condition) { + if (condition) + return new Object(); + return "Constant String"; + } + + /* test method to ensure that `return null` in lambdas are not mistaken as null returns */ + @ReturnValueNullability(Nullability.NON_NULL) + static Object neverReturningNull6() { + final Supplier<@Nullable String> foo = () -> { + System.out.print("Ho"); + return null; + }; + foo.get(); // use foo to avoid potential dead code elimination + return "Hey"; + } + + @ReturnValueNullability(Nullability.NON_NULL) + static void neverReturningNull7() { + // nothing to do + } + + @ReturnValueNullability(Nullability.NON_NULL) + static int neverReturningNull8() { + return 1; + } + + @Test + @SuppressWarnings("null") + void testDetermineMethodReturnTypeNullability() { + final var className = BytecodeAnalyzerTest.class.getName(); + try (ScanResult scanResult = new ClassGraph() // + .enableAllInfo() // + .enableSystemJarsAndModules() // + .acceptClasses(className) // + .scan()) { + + final var classInfo = scanResult.getClassInfo(className); + assert classInfo != null; + + final var analyzer = new BytecodeAnalyzer(classInfo); + + for (final Method m : this.getClass().getDeclaredMethods()) { + final var anno = m.getAnnotation(ReturnValueNullability.class); + if (anno != null) { + assertThat(analyzer.determineMethodReturnTypeNullability(classInfo.getMethodInfo(m.getName()).get(0))).describedAs(m + .getName()).isEqualTo(anno.value()); + } + } + } + } +}