Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNA callback should check inputs and outputs for null. #30

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions src/main/java/org/extism/sdk/HostFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,12 @@ public class HostFunction<T extends HostUserData> {

public final LibExtism.ExtismValType[] returns;

public final Optional<T> userData;

public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.ExtismValType[] returns, ExtismFunction f, Optional<T> userData) {
this.freed = false;
this.name = name;
this.params = params;
this.returns = returns;
this.userData = userData;
this.callback = (Pointer currentPlugin,
LibExtism.ExtismVal inputs,
int nInputs,
LibExtism.ExtismVal outs,
int nOutputs,
Pointer data) -> {

LibExtism.ExtismVal[] outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs);

f.invoke(
new ExtismCurrentPlugin(currentPlugin),
(LibExtism.ExtismVal[]) inputs.toArray(nInputs),
outputs,
userData
);

for (LibExtism.ExtismVal output : outputs) {
convertOutput(output, output);
}
};
this.callback = new Callback(f, userData);

this.pointer = LibExtism.INSTANCE.extism_function_new(
this.name,
Expand All @@ -61,7 +39,7 @@ public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.Ext
);
}

void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
static void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
if (fromHostFunction.t != original.t)
throw new ExtismException(String.format("Output type mismatch, got %d but expected %d", fromHostFunction.t, original.t));

Expand Down Expand Up @@ -103,4 +81,45 @@ public void free() {
this.freed = true;
}
}

static class Callback<T> implements LibExtism.InternalExtismFunction {
private final ExtismFunction f;
private final Optional<T> userData;

public Callback(ExtismFunction f, Optional<T> userData) {
this.f = f;
this.userData = userData;
}

@Override
public void invoke(Pointer currentPlugin, LibExtism.ExtismVal ins, int nInputs, LibExtism.ExtismVal outs, int nOutputs, Pointer data) {

LibExtism.ExtismVal[] inputs;
LibExtism.ExtismVal[] outputs;

if (outs == null) {
if (nOutputs > 0) {
throw new ExtismException("Output array is null but nOutputs is greater than 0");
}
outputs = new LibExtism.ExtismVal[0];
} else {
outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs);
}

if (ins == null) {
if (nInputs > 0) {
throw new ExtismException("Input array is null but nInputs is greater than 0");
}
inputs = new LibExtism.ExtismVal[0];
} else {
inputs = (LibExtism.ExtismVal[]) ins.toArray(nInputs);
}

f.invoke(new ExtismCurrentPlugin(currentPlugin), inputs, outputs, userData);

for (LibExtism.ExtismVal output : outputs) {
convertOutput(output, output);
}
}
}
}
25 changes: 25 additions & 0 deletions src/test/java/org/extism/sdk/HostFunctionTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.extism.sdk;

import com.sun.jna.Pointer;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

public class HostFunctionTests {
@Test
public void callbackShouldAcceptNullParameters() {
var callback = new HostFunction.Callback<>(
(plugin, params, returns, userData) -> {/* NOOP */}, null);
callback.invoke(Pointer.NULL, null, 0, null, 0, Pointer.NULL);
}

@Test
public void callbackShouldThrowOnNullParametersAndNonzeroCounts() {
var callback = new HostFunction.Callback<>(
(plugin, params, returns, userData) -> {/* NOOP */}, null);
assertThrows(ExtismException.class, () ->
callback.invoke(Pointer.NULL, null, 1, null, 0, Pointer.NULL));
assertThrows(ExtismException.class, () ->
callback.invoke(Pointer.NULL, null, 0, null, 1, Pointer.NULL));
}
}
102 changes: 51 additions & 51 deletions src/test/java/org/extism/sdk/PluginTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,57 +53,57 @@ public void shouldInvokeFunctionFromUrlWasmSource() {
assertThat(output).isEqualTo("{\"count\":4,\"total\":4,\"vowels\":\"aeiouyAEIOUY\"}");
}

// @Test
// public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() {
// var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm";
// var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url)));
//
// // Our application KV store
// // Pretend this is redis or a database :)
// var kvStore = new HashMap<String, byte[]>();
//
// ExtismFunction kvWrite = (plugin, params, returns, data) -> {
// System.out.println("Hello from Java Host Function!");
// var key = plugin.inputString(params[0]);
// var value = plugin.inputBytes(params[1]);
// System.out.println("Writing to key " + key);
// kvStore.put(key, value);
// };
//
// ExtismFunction kvRead = (plugin, params, returns, data) -> {
// System.out.println("Hello from Java Host Function!");
// var key = plugin.inputString(params[0]);
// System.out.println("Reading from key " + key);
// var value = kvStore.get(key);
// if (value == null) {
// // default to zeroed bytes
// var zero = new byte[]{0,0,0,0};
// plugin.returnBytes(returns[0], zero);
// } else {
// plugin.returnBytes(returns[0], value);
// }
// };
//
// HostFunction kvWriteHostFn = new HostFunction<>(
// "kv_write",
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64},
// new LibExtism.ExtismValType[0],
// kvWrite,
// Optional.empty()
// );
//
// HostFunction kvReadHostFn = new HostFunction<>(
// "kv_read",
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
// kvRead,
// Optional.empty()
// );
//
// HostFunction[] functions = {kvWriteHostFn, kvReadHostFn};
// var plugin = new Plugin(manifest, false, functions);
// var output = plugin.call("count_vowels", "Hello, World!");
// }
@Test
public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() {
var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm";
var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url)));

// Our application KV store
// Pretend this is redis or a database :)
var kvStore = new HashMap<String, byte[]>();

ExtismFunction kvWrite = (plugin, params, returns, data) -> {
System.out.println("Hello from Java Host Function!");
var key = plugin.inputString(params[0]);
var value = plugin.inputBytes(params[1]);
System.out.println("Writing to key " + key);
kvStore.put(key, value);
};

ExtismFunction kvRead = (plugin, params, returns, data) -> {
System.out.println("Hello from Java Host Function!");
var key = plugin.inputString(params[0]);
System.out.println("Reading from key " + key);
var value = kvStore.get(key);
if (value == null) {
// default to zeroed bytes
var zero = new byte[]{0,0,0,0};
plugin.returnBytes(returns[0], zero);
} else {
plugin.returnBytes(returns[0], value);
}
};

HostFunction kvWriteHostFn = new HostFunction<>(
"kv_write",
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64},
new LibExtism.ExtismValType[0],
kvWrite,
Optional.empty()
);

HostFunction kvReadHostFn = new HostFunction<>(
"kv_read",
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
kvRead,
Optional.empty()
);

HostFunction[] functions = {kvWriteHostFn, kvReadHostFn};
var plugin = new Plugin(manifest, false, functions);
var output = plugin.call("count_vowels", "Hello, World!");
}

@Test
public void shouldInvokeFunctionFromByteArrayWasmSource() {
Expand Down
Loading