From 7a3d08895e2960adeb94f0d2d6d1205ae1c142d9 Mon Sep 17 00:00:00 2001
From: Yeahn Kim <yeahnkim@gmail.com>
Date: Wed, 13 Nov 2024 15:56:43 -0800
Subject: [PATCH] Substitute type arguments when checking type parameter
 nullability at call site (#1070)

- I have made NullAway to check generic type paramater's nullability
using type arguments explicityly passed at the call site.
- The relevant unit test is `void genericMethodAndVoidType()` in the
`jspecify/GenericMethodTests.java` file.

---------

Co-authored-by: Manu Sridharan <msridhar@gmail.com>
---
 .../main/java/com/uber/nullaway/NullAway.java |  2 +-
 .../nullaway/generics/GenericsChecks.java     | 88 ++++++++++++-------
 .../nullaway/jspecify/GenericMethodTests.java | 30 ++++++-
 .../uber/nullaway/jspecify/GenericsTests.java | 22 +++++
 4 files changed, 107 insertions(+), 35 deletions(-)

diff --git a/nullaway/src/main/java/com/uber/nullaway/NullAway.java b/nullaway/src/main/java/com/uber/nullaway/NullAway.java
index e7c6af8f47..38b1b1d6dd 100644
--- a/nullaway/src/main/java/com/uber/nullaway/NullAway.java
+++ b/nullaway/src/main/java/com/uber/nullaway/NullAway.java
@@ -1852,7 +1852,7 @@ private Description handleInvocation(
       }
       if (config.isJSpecifyMode()) {
         GenericsChecks.compareGenericTypeParameterNullabilityForCall(
-            formalParams, actualParams, varArgsMethod, this, state);
+            methodSymbol, tree, actualParams, varArgsMethod, this, state);
         if (!methodSymbol.getTypeParameters().isEmpty()) {
           GenericsChecks.checkGenericMethodCallTypeArguments(tree, state, this, config, handler);
         }
diff --git a/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java b/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java
index 7f39ca2375..85ca6861bb 100644
--- a/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java
+++ b/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java
@@ -595,14 +595,16 @@ public static void checkTypeParameterNullnessForConditionalExpression(
    * Checks that for each parameter p at a call, the type parameter nullability for p's type matches
    * that of the corresponding formal parameter. If a mismatch is found, report an error.
    *
-   * @param formalParams the formal parameters
-   * @param actualParams the actual parameters
+   * @param methodSymbol the symbol for the method being called
+   * @param tree the tree representing the method call
+   * @param actualParams the actual parameters at the call
    * @param isVarArgs true if the call is to a varargs method
    * @param analysis the analysis object
    * @param state the visitor state
    */
   public static void compareGenericTypeParameterNullabilityForCall(
-      List<Symbol.VarSymbol> formalParams,
+      Symbol.MethodSymbol methodSymbol,
+      Tree tree,
       List<? extends ExpressionTree> actualParams,
       boolean isVarArgs,
       NullAway analysis,
@@ -610,14 +612,35 @@ public static void compareGenericTypeParameterNullabilityForCall(
     if (!analysis.getConfig().isJSpecifyMode()) {
       return;
     }
-    int n = formalParams.size();
+    Type invokedMethodType = methodSymbol.type;
+    // substitute class-level type arguments for instance methods
+    if (!methodSymbol.isStatic() && tree instanceof MethodInvocationTree) {
+      ExpressionTree methodSelect = ((MethodInvocationTree) tree).getMethodSelect();
+      Type enclosingType;
+      if (methodSelect instanceof MemberSelectTree) {
+        enclosingType = getTreeType(((MemberSelectTree) methodSelect).getExpression(), state);
+      } else {
+        // implicit this parameter
+        enclosingType = methodSymbol.owner.type;
+      }
+      if (enclosingType != null) {
+        invokedMethodType = state.getTypes().memberType(enclosingType, methodSymbol);
+      }
+    }
+    // substitute type arguments for generic methods
+    if (tree instanceof MethodInvocationTree && methodSymbol.type instanceof Type.ForAll) {
+      invokedMethodType =
+          substituteTypeArgsInGenericMethodType((MethodInvocationTree) tree, methodSymbol, state);
+    }
+    List<Type> formalParamTypes = invokedMethodType.getParameterTypes();
+    int n = formalParamTypes.size();
     if (isVarArgs) {
       // If the last argument is var args, don't check it now, it will be checked against
       // all remaining actual arguments in the next loop.
       n = n - 1;
     }
     for (int i = 0; i < n; i++) {
-      Type formalParameter = formalParams.get(i).type;
+      Type formalParameter = formalParamTypes.get(i);
       if (formalParameter.isRaw()) {
         // bail out of any checking involving raw types for now
         return;
@@ -630,11 +653,11 @@ public static void compareGenericTypeParameterNullabilityForCall(
         }
       }
     }
-    if (isVarArgs && !formalParams.isEmpty()) {
+    if (isVarArgs && !formalParamTypes.isEmpty()) {
       Type.ArrayType varargsArrayType =
-          (Type.ArrayType) formalParams.get(formalParams.size() - 1).type;
+          (Type.ArrayType) formalParamTypes.get(formalParamTypes.size() - 1);
       Type varargsElementType = varargsArrayType.elemtype;
-      for (int i = formalParams.size() - 1; i < actualParams.size(); i++) {
+      for (int i = formalParamTypes.size() - 1; i < actualParams.size(); i++) {
         Type actualParameterType = getTreeType(actualParams.get(i), state);
         // If the actual parameter type is assignable to the varargs array type, then the call site
         // is passing the varargs directly in an array, and we should skip our check.
@@ -796,19 +819,9 @@ public static Nullness getGenericReturnNullnessAtInvocation(
       Config config) {
     // If generic method invocation
     if (!invokedMethodSymbol.getTypeParameters().isEmpty()) {
-      List<? extends Tree> typeArgumentTrees = tree.getTypeArguments();
-      com.sun.tools.javac.util.List<Type> explicitTypeArgs =
-          convertTreesToTypes(typeArgumentTrees); // Convert to Type objects
-      Type.ForAll forAllType = (Type.ForAll) invokedMethodSymbol.type;
-      // Extract the underlying MethodType (the actual signature)
-      Type.MethodType methodTypeInsideForAll = (Type.MethodType) forAllType.asMethodType();
       // Substitute type arguments inside the return type
-      // NOTE: if the return type it not a type variable of the method itself, or if
-      // explicitTypeArgs is empty, this is a noop.
       Type substitutedReturnType =
-          state
-              .getTypes()
-              .subst(methodTypeInsideForAll.restype, forAllType.tvars, explicitTypeArgs);
+          substituteTypeArgsInGenericMethodType(tree, invokedMethodSymbol, state).getReturnType();
       // If this condition evaluates to false, we fall through to the subsequent logic, to handle
       // type variables declared on the enclosing class
       if (substitutedReturnType != null
@@ -842,6 +855,27 @@ private static com.sun.tools.javac.util.List<Type> convertTreesToTypes(
     return com.sun.tools.javac.util.List.from(types);
   }
 
+  /**
+   * Substitutes the type arguments from a generic method invocation into the method's type.
+   *
+   * @param methodInvocationTree the method invocation tree
+   * @param methodSymbol symbol for the invoked generic method
+   * @param state the visitor state
+   * @return the substituted method type for the generic method
+   */
+  private static Type substituteTypeArgsInGenericMethodType(
+      MethodInvocationTree methodInvocationTree,
+      Symbol.MethodSymbol methodSymbol,
+      VisitorState state) {
+
+    List<? extends Tree> typeArgumentTrees = methodInvocationTree.getTypeArguments();
+    com.sun.tools.javac.util.List<Type> explicitTypeArgs = convertTreesToTypes(typeArgumentTrees);
+
+    Type.ForAll forAllType = (Type.ForAll) methodSymbol.type;
+    Type.MethodType underlyingMethodType = (Type.MethodType) forAllType.qtype;
+    return state.getTypes().subst(underlyingMethodType, forAllType.tvars, explicitTypeArgs);
+  }
+
   /**
    * Computes the nullness of a formal parameter of a generic method at an invocation, in the
    * context of the declared type of its receiver argument. If the formal parameter's type is a type
@@ -884,23 +918,11 @@ public static Nullness getGenericParameterNullnessAtInvocation(
       Config config) {
     // If generic method invocation
     if (!invokedMethodSymbol.getTypeParameters().isEmpty()) {
-      List<? extends Tree> typeArgumentTrees = tree.getTypeArguments();
-      com.sun.tools.javac.util.List<Type> explicitTypeArgs =
-          convertTreesToTypes(typeArgumentTrees); // Convert to Type objects
-
-      Type.ForAll forAllType = (Type.ForAll) invokedMethodSymbol.type;
-      // Extract the underlying MethodType (the actual signature)
-      Type.MethodType methodTypeInsideForAll = (Type.MethodType) forAllType.qtype;
       // Substitute the argument types within the MethodType
       // NOTE: if explicitTypeArgs is empty, this is a noop
       List<Type> substitutedParamTypes =
-          state
-              .getTypes()
-              .subst(
-                  methodTypeInsideForAll.argtypes,
-                  forAllType.tvars, // The type variables from the ForAll
-                  explicitTypeArgs // The actual type arguments from the method invocation
-                  );
+          substituteTypeArgsInGenericMethodType(tree, invokedMethodSymbol, state)
+              .getParameterTypes();
       // If this condition evaluates to false, we fall through to the subsequent logic, to handle
       // type variables declared on the enclosing class
       if (substitutedParamTypes != null
diff --git a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java
index 6b9d409ba2..2fbd4e68f0 100644
--- a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java
+++ b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java
@@ -105,8 +105,36 @@ public void genericInstanceMethods() {
   }
 
   @Test
-  @Ignore("requires generic method support")
   public void genericMethodAndVoidType() {
+    makeHelper()
+        .addSourceLines(
+            "Test.java",
+            "package com.uber;",
+            "import org.jspecify.annotations.Nullable;",
+            "class Test {",
+            "  static class Foo {",
+            "    <C extends @Nullable Object> void foo(C c, Visitor<C> visitor) {",
+            "      visitor.visit(this, c);",
+            "    }",
+            "  }",
+            "  static abstract class Visitor<C extends @Nullable Object> {",
+            "    abstract void visit(Foo foo, C c);",
+            "  }",
+            "  static class MyVisitor extends Visitor<@Nullable Void> {",
+            "    @Override",
+            "    void visit(Foo foo, @Nullable Void c) {}",
+            "  }",
+            "  static void test(Foo f) {",
+            "    // this is safe",
+            "    f.<@Nullable Void>foo(null, new MyVisitor());",
+            "  }",
+            "}")
+        .doTest();
+  }
+
+  @Test
+  @Ignore("requires inference of generic method type arguments")
+  public void genericMethodAndVoidTypeWithInference() {
     makeHelper()
         .addSourceLines(
             "Test.java",
diff --git a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericsTests.java b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericsTests.java
index c686b25f0c..74a7ef481f 100644
--- a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericsTests.java
+++ b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericsTests.java
@@ -938,6 +938,28 @@ public void parameterPassing() {
         .doTest();
   }
 
+  @Test
+  public void parameterPassingInstanceMethods() {
+    makeHelper()
+        .addSourceLines(
+            "Test.java",
+            "package com.uber;",
+            "import org.jspecify.annotations.Nullable;",
+            "class Test {",
+            "  static class A<T extends @Nullable Object> {",
+            "    void foo(A<T> a) {}",
+            "    void bar(A<T> a) { foo(a); this.foo(a); }",
+            "  }",
+            "  static void test(A<@Nullable String> p, A<String> q) {",
+            "    // BUG: Diagnostic contains: Cannot pass parameter of type",
+            "    p.foo(q);",
+            "    // this one is fine",
+            "    p.foo(p);",
+            "  }",
+            "}")
+        .doTest();
+  }
+
   @Test
   public void varargsParameter() {
     makeHelper()