Skip to content

Commit

Permalink
Downcall: FIX cross-module access
Browse files Browse the repository at this point in the history
  • Loading branch information
squid233 committed Feb 2, 2024
1 parent 2b048d0 commit 4ed2f7b
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 35 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import overrun.marshal.gen.*;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.lang.invoke.MethodHandles;

/**
* GLFW constants and functions
Expand All @@ -25,7 +26,7 @@ interface GLFW {
/**
* The instance of the loaded library
*/
GLFW INSTANCE = Downcall.load("libglfw3.so");
GLFW INSTANCE = Downcall.load(MethodHandles.lookup(), "libglfw3.so");

/**
* A field
Expand Down Expand Up @@ -94,6 +95,6 @@ Import as a Gradle dependency:

```groovy
dependencies {
implementation("io.github.over-run:marshal:0.1.0-alpha.16-jdk22")
implementation("io.github.over-run:marshal:0.1.0-alpha.17-jdk22")
}
```
15 changes: 6 additions & 9 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ allprojects {
dependencies {
// add your dependencies
compileOnly("org.jetbrains:annotations:24.1.0")
testImplementation(platform("org.junit:junit-bom:5.10.1"))
testImplementation("org.junit.jupiter:junit-jupiter")
}

tasks.withType<Test> {
useJUnitPlatform()
}

tasks.withType<JavaCompile> {
Expand Down Expand Up @@ -129,15 +135,6 @@ allprojects {
}
}

dependencies {
testImplementation(platform("org.junit:junit-bom:5.10.1"))
testImplementation("org.junit.jupiter:junit-jupiter")
}

tasks.withType<Test> {
useJUnitPlatform()
}

tasks.withType<Jar> {
archiveBaseName = projArtifactId
from(rootProject.file(projLicenseFileName)).rename(
Expand Down
3 changes: 3 additions & 0 deletions demo/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dependencies {
implementation(rootProject)
}
26 changes: 26 additions & 0 deletions demo/src/test/java/module-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* MIT License
*
* Copyright (c) 2024 Overrun Organization
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*/

/**
* @author squid233
* @since 0.1.0
*/
module io.github.overrun.marshal.demo {
exports overrun.marshal.demo;
opens overrun.marshal.demo;
requires io.github.overrun.marshal;
requires org.junit.jupiter.api;
}
60 changes: 60 additions & 0 deletions demo/src/test/java/overrun/marshal/demo/CrossModuleTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* MIT License
*
* Copyright (c) 2024 Overrun Organization
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*/

package overrun.marshal.demo;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import overrun.marshal.Downcall;

import java.lang.foreign.*;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Optional;

/**
* Test cross module
*
* @author squid233
* @since 0.1.0
*/
public final class CrossModuleTest {
private static final Linker LINKER = Linker.nativeLinker();
private static final MemorySegment s_get;

static {
try {
s_get = LINKER.upcallStub(MethodHandles.lookup().findStatic(CrossModuleTest.class, "get", MethodType.methodType(int.class)), FunctionDescriptor.of(ValueLayout.JAVA_INT), Arena.ofAuto());
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}

private static final SymbolLookup LOOKUP = name -> "get".equals(name) ? Optional.of(s_get) : Optional.empty();

private static int get() {
return 1;
}

public interface I {
int get();
}

@Test
void testCrossModule() {
Assertions.assertEquals(1, Downcall.load(MethodHandles.lookup(), I.class, LOOKUP).get());
}
}
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ projGroupId=io.github.over-run
projArtifactId=marshal
# The project name should only contain lowercase letters, numbers and hyphen.
projName=marshal
projVersion=0.1.0-alpha.16-jdk22
projVersion=0.1.0-alpha.17-jdk22
projDesc=Marshaler of native libraries
# Uncomment them if you want to publish to maven repository.
projUrl=https://github.com/Over-Run/marshal
Expand Down
2 changes: 2 additions & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ pluginManagement {

val projName: String by settings
rootProject.name = projName

include("demo")
47 changes: 28 additions & 19 deletions src/main/java/overrun/marshal/Downcall.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@
/**
* Downcall library loader.
* <h2>Loading native library</h2>
* You can load native libraries with {@link #load(Class, SymbolLookup)}.
* You can load native libraries with {@link #load(MethodHandles.Lookup, Class, SymbolLookup)}.
* This method generates a hidden class that loads method handle with the given symbol lookup.
* <p>
* The {@code load} methods accept a lookup object for defining hidden class with the caller.
* The lookup object <strong>MUST</strong> have {@linkplain MethodHandles.Lookup#hasFullPrivilegeAccess() full privilege access}
* which allows {@code Downcall} to define the hidden class.
* You can obtain that lookup object with {@link MethodHandles#lookup()}.
* <p>
* The generated class implements the target class.
* The target class <strong>MUST</strong> be {@code public}
* as the generated class is defined in package {@code overrun.marshal}.
* <h2>Methods</h2>
* The loader finds method from the target class and its superclasses.
* <p>
Expand All @@ -74,7 +77,8 @@
* <h3>Example</h3>
* <pre>{@code
* public interface GL {
* GL INSTANCE = Downcall.load(lookup, Map.of("glClear", FunctionDescriptor.ofVoid(ValueLayout.JAVA_INT)));
* GL INSTANCE = Downcall.load(MethodHandles.lookup(), lookup,
* Map.of("glClear", FunctionDescriptor.ofVoid(ValueLayout.JAVA_INT)));
* MethodHandle glClear();
*
* @Skip
Expand All @@ -86,7 +90,7 @@
* <h2>Example</h2>
* <pre>{@code
* public interface GL {
* GL INSTANCE = Downcall.load("libGL.so");
* GL INSTANCE = Downcall.load(MethodHandles.lookup(), "libGL.so");
* int COLOR_BUFFER_BIT = 0x00004000;
* void glClear(int mask);
* }
Expand All @@ -103,7 +107,6 @@
* @since 0.1.0
*/
public final class Downcall {
private static final StackWalker STACK_WALKER = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
private static final ClassDesc CD_Addressable = ClassDesc.of("overrun.marshal.Addressable");
private static final ClassDesc CD_AddressLayout = ClassDesc.of("java.lang.foreign.AddressLayout");
private static final ClassDesc CD_Arena = ClassDesc.of("java.lang.foreign.Arena");
Expand Down Expand Up @@ -190,63 +193,68 @@ private Downcall() {
/**
* Loads the given class with the given symbol lookup.
*
* @param caller the lookup object for the caller
* @param targetClass the target class
* @param lookup the symbol lookup
* @param descriptorMap the custom function descriptors for each method handle
* @param <T> the type of the target class
* @return the loaded implementation instance of the target class
*/
public static <T> T load(Class<T> targetClass, SymbolLookup lookup, Map<String, FunctionDescriptor> descriptorMap) {
return loadBytecode(targetClass, lookup, descriptorMap);
public static <T> T load(MethodHandles.Lookup caller, Class<T> targetClass, SymbolLookup lookup, Map<String, FunctionDescriptor> descriptorMap) {
return loadBytecode(caller, targetClass, lookup, descriptorMap);
}

/**
* Loads the given class with the given symbol lookup.
*
* @param caller the lookup object for the caller
* @param targetClass the target class
* @param lookup the symbol lookup
* @param <T> the type of the target class
* @return the loaded implementation instance of the target class
*/
public static <T> T load(Class<T> targetClass, SymbolLookup lookup) {
return load(targetClass, lookup, Map.of());
public static <T> T load(MethodHandles.Lookup caller, Class<T> targetClass, SymbolLookup lookup) {
return load(caller, targetClass, lookup, Map.of());
}

/**
* Loads the caller class with the given symbol lookup.
*
* @param caller the lookup object for the caller
* @param lookup the symbol lookup
* @param descriptorMap the custom function descriptors for each method handle
* @param <T> the type of the caller class
* @return the loaded implementation instance of the caller class
*/
@SuppressWarnings("unchecked")
public static <T> T load(SymbolLookup lookup, Map<String, FunctionDescriptor> descriptorMap) {
return load((Class<T>) STACK_WALKER.getCallerClass(), lookup, descriptorMap);
public static <T> T load(MethodHandles.Lookup caller, SymbolLookup lookup, Map<String, FunctionDescriptor> descriptorMap) {
return load(caller, (Class<T>) caller.lookupClass(), lookup, descriptorMap);
}

/**
* Loads the caller class with the given symbol lookup.
*
* @param caller the lookup object for the caller
* @param lookup the symbol lookup
* @param <T> the type of the caller class
* @return the loaded implementation instance of the caller class
*/
@SuppressWarnings("unchecked")
public static <T> T load(SymbolLookup lookup) {
return load((Class<T>) STACK_WALKER.getCallerClass(), lookup);
public static <T> T load(MethodHandles.Lookup caller, SymbolLookup lookup) {
return load(caller, (Class<T>) caller.lookupClass(), lookup);
}

/**
* Loads the caller class with the given library name.
*
* @param caller the lookup object for the caller
* @param libname the library name
* @param <T> the type of the caller class
* @return the loaded implementation instance of the caller class
*/
@SuppressWarnings("unchecked")
public static <T> T load(String libname) {
return load((Class<T>) STACK_WALKER.getCallerClass(), SymbolLookup.libraryLookup(libname, Arena.ofAuto()));
public static <T> T load(MethodHandles.Lookup caller, String libname) {
return load(caller, (Class<T>) caller.lookupClass(), SymbolLookup.libraryLookup(libname, Arena.ofAuto()));
}

private static void convertToValueLayout(CodeBuilder codeBuilder, Class<?> aClass) {
Expand Down Expand Up @@ -399,9 +407,10 @@ private static Method findUpcallWrapper(Class<?> aClass) {
}

@SuppressWarnings("unchecked")
private static <T> T loadBytecode(Class<?> targetClass, SymbolLookup lookup, Map<String, FunctionDescriptor> descriptorMap) {
private static <T> T loadBytecode(MethodHandles.Lookup caller, Class<?> targetClass, SymbolLookup lookup, Map<String, FunctionDescriptor> descriptorMap) {
final Class<?> lookupClass = caller.lookupClass();
final ClassFile cf = of();
final ClassDesc cd_thisClass = ClassDesc.of(Downcall.class.getPackageName(), DEFAULT_NAME);
final ClassDesc cd_thisClass = ClassDesc.of(lookupClass.getPackageName(), DEFAULT_NAME);
final byte[] bytes = cf.build(cd_thisClass, classBuilder -> {
final List<Method> methodList = Arrays.stream(targetClass.getMethods())
.filter(method ->
Expand Down Expand Up @@ -1048,7 +1057,7 @@ private static <T> T loadBytecode(Class<?> targetClass, SymbolLookup lookup, Map
});

try {
final MethodHandles.Lookup hiddenClass = MethodHandles.lookup()
final MethodHandles.Lookup hiddenClass = caller
.defineHiddenClassWithClassData(bytes, Map.copyOf(descriptorMap), true, MethodHandles.Lookup.ClassOption.STRONG);
return (T) hiddenClass.findConstructor(hiddenClass.lookupClass(), MethodType.methodType(void.class, SymbolLookup.class))
.invoke(lookup);
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/overrun/marshal/test/DescriptorMapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ static boolean acceptLong(long d) {

public interface Interface {
static Interface getInstance(ValueLayout returnLayout, ValueLayout acceptLayout) {
return Downcall.load(lookup(returnLayout, acceptLayout), Map.of(
return Downcall.load(MethodHandles.lookup(), lookup(returnLayout, acceptLayout), Map.of(
"testReturn", FunctionDescriptor.of(returnLayout),
"testAccept", FunctionDescriptor.of(JAVA_BOOLEAN, acceptLayout)
));
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/overrun/marshal/test/ID3.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ static SymbolLookup load() {
}
}

ID3 INSTANCE = Downcall.load(Provider.load());
ID3 INSTANCE = Downcall.load(MethodHandles.lookup(), Provider.load());

@Override
@Entrypoint("get1")
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/overrun/marshal/test/IDowncall.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.Map;

/**
Expand All @@ -35,7 +36,7 @@ public interface IDowncall {
Map<String, FunctionDescriptor> MAP = Map.of("testDefault", FunctionDescriptor.of(ValueLayout.JAVA_INT));

static IDowncall getInstance(boolean testDefaultNull) {
return Downcall.load(DowncallProvider.lookup(testDefaultNull), MAP);
return Downcall.load(MethodHandles.lookup(), DowncallProvider.lookup(testDefaultNull), MAP);
}

void test();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import overrun.marshal.Downcall;

import java.lang.foreign.SymbolLookup;
import java.lang.invoke.MethodHandles;

/**
* Test indirect interface
Expand All @@ -38,7 +39,7 @@ default int fun1() {
public interface I2 extends I1 {
}

I2 INSTANCE = Downcall.load(I2.class, SymbolLookup.loaderLookup());
I2 INSTANCE = Downcall.load(MethodHandles.lookup(), I2.class, SymbolLookup.loaderLookup());

@Test
void testIndirectInterface() {
Expand Down

0 comments on commit 4ed2f7b

Please sign in to comment.