diff --git a/core-common/src/main/java/io/roastedroot/protobuf4j/common/Protobuf.java b/core-common/src/main/java/io/roastedroot/protobuf4j/common/Protobuf.java index 5a44b5d..8562007 100644 --- a/core-common/src/main/java/io/roastedroot/protobuf4j/common/Protobuf.java +++ b/core-common/src/main/java/io/roastedroot/protobuf4j/common/Protobuf.java @@ -87,13 +87,6 @@ public static ImportMemory defaultMemory() { new MemoryLimits(WASM_INITIAL_MEMORY_PAGES, MemoryLimits.MAX_PAGES, true))); } - private static int writeCString(Instance instance, String str) { - byte[] strBytes = str.getBytes(StandardCharsets.UTF_8); - var strPtr = (int) instance.exports().function("malloc").apply(strBytes.length + 1)[0]; - instance.memory().writeCString(strPtr, str); - return strPtr; - } - public static PluginProtos.CodeGeneratorResponse runNativePlugin( Function instanceBuilder, NativePlugin plugin, @@ -186,20 +179,20 @@ public static DescriptorProtos.FileDescriptorSet getDescriptors( fileNamesStrBuilder.append(file); fileNamesStrBuilder.append(FILE_NAMES_SEPARATOR); } - var ptr = writeCString(instance, fileNamesStrBuilder.toString()); - - var result = exports.exportDescriptors(ptr); - if (result == 0) { - throw new RuntimeException("Null pointer returned from protobuf"); + try (var namesBuffer = new WasmCStringBuffer(exports, fileNamesStrBuilder.toString())) { + var result = exports.exportDescriptors(namesBuffer.ptr()); + if (result == 0) { + throw new RuntimeException("Null pointer returned from protobuf"); + } + var resultPtr = (int) (result & 0xFFFFFFFFL); + var resultLen = (int) ((result >> 32) & 0xFFFFFFFFL); + try { + var resultBytes = exports.memory().readBytes(resultPtr, resultLen); + return DescriptorProtos.FileDescriptorSet.parseFrom(resultBytes); + } finally { + exports.free(resultPtr); + } } - var resultPtr = (int) (result & 0xFFFFFFFFL); - var resultLen = (int) ((result >> 32) & 0xFFFFFFFFL); - var resultBytes = exports.memory().readBytes(resultPtr, resultLen); - - exports.free(ptr); - exports.free(resultPtr); - - return DescriptorProtos.FileDescriptorSet.parseFrom(resultBytes); } catch (IOException e) { throw new RuntimeException( "Failed to generate java files from proto files " @@ -353,18 +346,14 @@ public static CompatibilityResult checkCompatibility( public static ValidationResult validateSyntax(Instance instance, String fileName) { var exports = new Protobuf_ModuleExports(instance); - var ptr = writeCString(instance, fileName); - try { - var result = exports.validateSyntax(ptr); + try (var nameBuffer = new WasmCStringBuffer(exports, fileName)) { + var result = exports.validateSyntax(nameBuffer.ptr()); if (result == 0) { return ValidationResult.valid(); - } else { - var res = ValidationResult.invalid(exports.memory().readCString(result)); - exports.free(result); - return res; } - } finally { - exports.free(ptr); + var res = ValidationResult.invalid(exports.memory().readCString(result)); + exports.free(result); + return res; } } diff --git a/core-common/src/main/java/io/roastedroot/protobuf4j/common/WasmCStringBuffer.java b/core-common/src/main/java/io/roastedroot/protobuf4j/common/WasmCStringBuffer.java new file mode 100644 index 0000000..20b20a3 --- /dev/null +++ b/core-common/src/main/java/io/roastedroot/protobuf4j/common/WasmCStringBuffer.java @@ -0,0 +1,28 @@ +package io.roastedroot.protobuf4j.common; + +import java.nio.charset.StandardCharsets; + +/** + * Malloc'd null-terminated UTF-8 string in WASM memory. Prefer try-with-resources so {@link #close} + * runs on all exit paths (including exceptions). + */ +final class WasmCStringBuffer implements AutoCloseable { + private final Protobuf_ModuleExports exports; + private final int ptr; + + WasmCStringBuffer(Protobuf_ModuleExports exports, String str) { + this.exports = exports; + byte[] strBytes = str.getBytes(StandardCharsets.UTF_8); + this.ptr = exports.malloc(strBytes.length + 1); + exports.memory().writeCString(ptr, str); + } + + int ptr() { + return ptr; + } + + @Override + public void close() { + exports.free(ptr); + } +} diff --git a/core-v4/src/test/java/io/roastedroot/protobuf4j/common/WasmCStringBufferTest.java b/core-v4/src/test/java/io/roastedroot/protobuf4j/common/WasmCStringBufferTest.java new file mode 100644 index 0000000..2abbf47 --- /dev/null +++ b/core-v4/src/test/java/io/roastedroot/protobuf4j/common/WasmCStringBufferTest.java @@ -0,0 +1,62 @@ +package io.roastedroot.protobuf4j.common; + +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.dylibso.chicory.runtime.ByteArrayMemory; +import com.dylibso.chicory.runtime.ImportValues; +import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.wasi.WasiOptions; +import com.dylibso.chicory.wasi.WasiPreview1; +import io.roastedroot.protobuf4j.ProtobufWrapperV4; +import io.roastedroot.zerofs.Configuration; +import io.roastedroot.zerofs.ZeroFs; +import java.nio.file.FileSystem; +import java.nio.file.Path; +import org.junit.jupiter.api.Test; + +/** + * Ensures {@link WasmCStringBuffer} is closed when the try block exits abnormally, so WASM malloc + * does not leak across iterations. + */ +public class WasmCStringBufferTest { + + @Test + void closesWhenTryBlockThrows() throws Exception { + try (FileSystem fs = + ZeroFs.newFileSystem( + Configuration.unix().toBuilder().setAttributeViews("unix").build())) { + Path workdir = fs.getPath("."); + try (WasiPreview1 wasi = + WasiPreview1.builder() + .withOptions( + WasiOptions.builder() + .withDirectory(workdir.toString(), workdir) + .build()) + .build()) { + Instance instance = + Instance.builder(ProtobufWrapperV4.load()) + .withImportValues( + ImportValues.builder() + .addFunction(wasi.toHostFunctions()) + .addMemory(Protobuf.defaultMemory()) + .build()) + .withMachineFactory(ProtobufWrapperV4::create) + .withMemoryFactory(ByteArrayMemory::new) + .withStart(false) + .build(); + var exports = new Protobuf_ModuleExports(instance); + for (int i = 0; i < 2000; i++) { + assertThrows( + IllegalStateException.class, + () -> { + try (var buf = new WasmCStringBuffer(exports, "x")) { + assertNotEquals(0, buf.ptr(), "malloc returned 0"); + throw new IllegalStateException("boom"); + } + }); + } + } + } + } +}