Skip to content

Commit fc685b3

Browse files
Merge pull request astrale-sharp#13 from arnaudgolfouse/avoid-return-copy
Avoid copy when the plugin returns
2 parents 052fba5 + d200f8c commit fc685b3

File tree

6 files changed

+179
-90
lines changed

6 files changed

+179
-90
lines changed

examples/hello_c/hello.c

+27-12
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,37 @@
44
#include <cstddef>
55
#include <cstdint>
66
#include <cstdlib>
7+
#include <cstring>
78
#define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern "C"
89
#else
910
#include <stddef.h>
1011
#include <stdint.h>
1112
#include <stdlib.h>
13+
#include <string.h>
1214
#define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern
1315
#endif
1416

17+
// ===
18+
// Functions for the protocol
19+
1520
PROTOCOL_FUNCTION void
1621
wasm_minimal_protocol_send_result_to_host(const uint8_t *ptr, size_t len);
1722
PROTOCOL_FUNCTION void wasm_minimal_protocol_write_args_to_buffer(uint8_t *ptr);
1823

24+
EMSCRIPTEN_KEEPALIVE void wasm_minimal_protocol_free_byte_buffer(uint8_t *ptr,
25+
size_t len) {
26+
free(ptr);
27+
}
28+
29+
// ===
30+
1931
EMSCRIPTEN_KEEPALIVE
2032
int32_t hello(void) {
21-
const char message[] = "Hello world !";
22-
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
23-
sizeof(message) - 1);
33+
const char static_message[] = "Hello world !";
34+
const size_t length = sizeof(static_message);
35+
char *message = malloc(length);
36+
memcpy((void *)message, (void *)static_message, length);
37+
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
2438
return 0;
2539
}
2640

@@ -36,7 +50,6 @@ int32_t double_it(size_t arg_len) {
3650
alloc_result[arg_len + i] = alloc_result[i];
3751
}
3852
wasm_minimal_protocol_send_result_to_host(alloc_result, result_len);
39-
free(alloc_result);
4053
return 0;
4154
}
4255

@@ -66,7 +79,6 @@ int32_t concatenate(size_t arg1_len, size_t arg2_len) {
6679

6780
wasm_minimal_protocol_send_result_to_host(result, total_len + 1);
6881

69-
free(result);
7082
free(args);
7183
return 0;
7284
}
@@ -102,24 +114,27 @@ int32_t shuffle(size_t arg1_len, size_t arg2_len, size_t arg3_len) {
102114

103115
wasm_minimal_protocol_send_result_to_host(result, result_len);
104116

105-
free(result);
106117
free(args);
107118
return 0;
108119
}
109120

110121
EMSCRIPTEN_KEEPALIVE
111122
int32_t returns_ok() {
112-
const char message[] = "This is an `Ok`";
113-
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
114-
sizeof(message) - 1);
123+
const char static_message[] = "This is an `Ok`";
124+
const size_t length = sizeof(static_message);
125+
char *message = malloc(length);
126+
memcpy((void *)message, (void *)static_message, length);
127+
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
115128
return 0;
116129
}
117130

118131
EMSCRIPTEN_KEEPALIVE
119132
int32_t returns_err() {
120-
const char message[] = "This is an `Err`";
121-
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
122-
sizeof(message) - 1);
133+
const char static_message[] = "This is an `Err`";
134+
const size_t length = sizeof(static_message);
135+
char *message = malloc(length);
136+
memcpy((void *)message, (void *)static_message, length);
137+
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
123138
return 1;
124139
}
125140

examples/hello_zig/hello.zig

+32-17
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,36 @@
11
const std = @import("std");
22
const allocator = std.heap.page_allocator;
33

4+
// ===
5+
// Functions for the protocol
6+
47
extern "typst_env" fn wasm_minimal_protocol_send_result_to_host(ptr: [*]const u8, len: usize) void;
58
extern "typst_env" fn wasm_minimal_protocol_write_args_to_buffer(ptr: [*]u8) void;
69

10+
export fn wasm_minimal_protocol_free_byte_buffer(ptr: [*]u8, len: usize) void {
11+
var slice: []u8 = undefined;
12+
slice.ptr = ptr;
13+
slice.len = len;
14+
allocator.free(slice);
15+
}
16+
17+
// ===
18+
719
export fn hello() i32 {
820
const message = "Hello world !";
9-
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
21+
var result = allocator.alloc(u8, message.len) catch return 1;
22+
@memcpy(result, message);
23+
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
1024
return 0;
1125
}
1226

1327
export fn double_it(arg1_len: usize) i32 {
14-
var alloc_result = allocator.alloc(u8, arg1_len * 2) catch return 1;
15-
defer allocator.free(alloc_result);
16-
wasm_minimal_protocol_write_args_to_buffer(alloc_result.ptr);
28+
var result = allocator.alloc(u8, arg1_len * 2) catch return 1;
29+
wasm_minimal_protocol_write_args_to_buffer(result.ptr);
1730
for (0..arg1_len) |i| {
18-
alloc_result[i + arg1_len] = alloc_result[i];
31+
result[i + arg1_len] = result[i];
1932
}
20-
wasm_minimal_protocol_send_result_to_host(alloc_result.ptr, alloc_result.len);
33+
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
2134
return 0;
2235
}
2336

@@ -27,7 +40,6 @@ export fn concatenate(arg1_len: usize, arg2_len: usize) i32 {
2740
wasm_minimal_protocol_write_args_to_buffer(args.ptr);
2841

2942
var result = allocator.alloc(u8, arg1_len + arg2_len + 1) catch return 1;
30-
defer allocator.free(result);
3143
for (0..arg1_len) |i| {
3244
result[i] = args[i];
3345
}
@@ -49,27 +61,30 @@ export fn shuffle(arg1_len: usize, arg2_len: usize, arg3_len: usize) i32 {
4961
var arg2 = args[arg1_len .. arg1_len + arg2_len];
5062
var arg3 = args[arg1_len + arg2_len .. args.len];
5163

52-
var result: std.ArrayList(u8) = std.ArrayList(u8).initCapacity(allocator, args_len + 2) catch return 1;
53-
defer result.deinit();
54-
result.appendSlice(arg3) catch return 1;
55-
result.append('-') catch return 1;
56-
result.appendSlice(arg1) catch return 1;
57-
result.append('-') catch return 1;
58-
result.appendSlice(arg2) catch return 1;
64+
var result = allocator.alloc(u8, arg1_len + arg2_len + arg3_len + 2) catch return 1;
65+
@memcpy(result[0..arg3.len], arg3);
66+
result[arg3.len] = '-';
67+
@memcpy(result[arg3.len + 1 ..][0..arg1.len], arg1);
68+
result[arg3.len + arg1.len + 1] = '-';
69+
@memcpy(result[arg3.len + arg1.len + 2 ..][0..arg2.len], arg2);
5970

60-
wasm_minimal_protocol_send_result_to_host(result.items.ptr, result.items.len);
71+
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
6172
return 0;
6273
}
6374

6475
export fn returns_ok() i32 {
6576
const message = "This is an `Ok`";
66-
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
77+
var result = allocator.alloc(u8, message.len) catch return 1;
78+
@memcpy(result, message);
79+
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
6780
return 0;
6881
}
6982

7083
export fn returns_err() i32 {
7184
const message = "This is an `Err`";
72-
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
85+
var result = allocator.alloc(u8, message.len) catch return 1;
86+
@memcpy(result, message);
87+
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
7388
return 1;
7489
}
7590

examples/host-wasmi/src/lib.rs

+79-42
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,67 @@
1-
use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Module, Value};
1+
use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Memory, Module, Value};
22

33
type Store = wasmi::Store<PersistentData>;
44

5+
/// Reference to a slice of memory returned after
6+
/// [calling a wasm function](PluginInstance::call).
7+
///
8+
/// # Drop
9+
/// On [`Drop`], this will free the slice of memory inside the plugin.
10+
///
11+
/// As such, this structure mutably borrows the [`PluginInstance`], which prevents
12+
/// another function from being called.
13+
pub struct ReturnedData<'a> {
14+
memory: Memory,
15+
ptr: u32,
16+
len: u32,
17+
free_function: &'a Function,
18+
context_mut: &'a mut Store,
19+
}
20+
21+
impl<'a> ReturnedData<'a> {
22+
/// Get a reference to the returned slice of data.
23+
///
24+
/// # Panic
25+
/// This may panic if the function returned an invalid `(ptr, len)` pair.
26+
pub fn get(&self) -> &[u8] {
27+
&self.memory.data(&*self.context_mut)[self.ptr as usize..(self.ptr + self.len) as usize]
28+
}
29+
}
30+
31+
impl Drop for ReturnedData<'_> {
32+
fn drop(&mut self) {
33+
self.free_function
34+
.call(
35+
&mut *self.context_mut,
36+
&[Value::I32(self.ptr as _), Value::I32(self.len as _)],
37+
&mut [],
38+
)
39+
.unwrap();
40+
}
41+
}
42+
543
#[derive(Debug, Clone)]
644
struct PersistentData {
7-
result_data: Vec<u8>,
45+
result_ptr: u32,
46+
result_len: u32,
847
arg_buffer: Vec<u8>,
948
}
1049

1150
#[derive(Debug)]
1251
pub struct PluginInstance {
1352
store: Store,
53+
memory: Memory,
54+
free_function: Function,
1455
functions: Vec<(String, Function)>,
1556
}
1657

1758
impl PluginInstance {
1859
pub fn new_from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, String> {
1960
let engine = Engine::default();
2061
let data = PersistentData {
21-
result_data: Vec::new(),
2262
arg_buffer: Vec::new(),
63+
result_ptr: 0,
64+
result_len: 0,
2365
};
2466
let mut store = Store::new(&engine, data);
2567

@@ -32,11 +74,8 @@ impl PluginInstance {
3274
"typst_env",
3375
"wasm_minimal_protocol_send_result_to_host",
3476
move |mut caller: Caller<PersistentData>, ptr: u32, len: u32| {
35-
let memory = caller.get_export("memory").unwrap().into_memory().unwrap();
36-
let mut buffer = std::mem::take(&mut caller.data_mut().result_data);
37-
buffer.resize(len as usize, 0);
38-
memory.read(&caller, ptr as _, &mut buffer).unwrap();
39-
caller.data_mut().result_data = buffer;
77+
caller.data_mut().result_ptr = ptr;
78+
caller.data_mut().result_len = len;
4079
},
4180
)
4281
.unwrap()
@@ -51,54 +90,44 @@ impl PluginInstance {
5190
},
5291
)
5392
.unwrap()
54-
// hack to accept wasi file
55-
// https://github.com/near/wasi-stub is preferred
56-
/*
57-
.func_wrap(
58-
"wasi_snapshot_preview1",
59-
"fd_write",
60-
|_: i32, _: i32, _: i32, _: i32| 0i32,
61-
)
62-
.unwrap()
63-
.func_wrap(
64-
"wasi_snapshot_preview1",
65-
"environ_get",
66-
|_: i32, _: i32| 0i32,
67-
)
68-
.unwrap()
69-
.func_wrap(
70-
"wasi_snapshot_preview1",
71-
"environ_sizes_get",
72-
|_: i32, _: i32| 0i32,
73-
)
74-
.unwrap()
75-
.func_wrap(
76-
"wasi_snapshot_preview1",
77-
"proc_exit",
78-
|_: i32| {},
79-
)
80-
.unwrap()
81-
*/
8293
.instantiate(&mut store, &module)
8394
.map_err(|e| format!("{e}"))?
8495
.start(&mut store)
8596
.map_err(|e| format!("{e}"))?;
8697

98+
let mut free_function = None;
8799
let functions = instance
88100
.exports(&store)
89101
.filter_map(|e| {
90102
let name = e.name().to_owned();
91-
e.into_func().map(|func| (name, func))
103+
104+
e.into_func().map(|func| {
105+
if name == "wasm_minimal_protocol_free_byte_buffer" {
106+
free_function = Some(func);
107+
}
108+
(name, func)
109+
})
92110
})
93111
.collect::<Vec<_>>();
94-
Ok(Self { store, functions })
112+
let free_function = free_function.unwrap();
113+
let memory = instance
114+
.get_export(&store, "memory")
115+
.unwrap()
116+
.into_memory()
117+
.unwrap();
118+
Ok(Self {
119+
store,
120+
memory,
121+
free_function,
122+
functions,
123+
})
95124
}
96125

97126
fn write(&mut self, args: &[&[u8]]) {
98127
self.store.data_mut().arg_buffer = args.concat();
99128
}
100129

101-
pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result<Vec<u8>, String> {
130+
pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result<ReturnedData, String> {
102131
self.write(args);
103132

104133
let (_, function) = self
@@ -122,11 +151,19 @@ impl PluginInstance {
122151
code.first().cloned().unwrap_or(Value::I32(3)) // if the function returns nothing
123152
};
124153

125-
let s = std::mem::take(&mut self.store.data_mut().result_data);
154+
let (ptr, len) = (self.store.data().result_ptr, self.store.data().result_len);
155+
156+
let result = ReturnedData {
157+
memory: self.memory,
158+
ptr,
159+
len,
160+
free_function: &self.free_function,
161+
context_mut: &mut self.store,
162+
};
126163

127164
match code {
128-
Value::I32(0) => Ok(s),
129-
Value::I32(1) => Err(match String::from_utf8(s) {
165+
Value::I32(0) => Ok(result),
166+
Value::I32(1) => Err(match std::str::from_utf8(result.get()) {
130167
Ok(err) => format!("plugin errored with: '{}'", err,),
131168
Err(_) => String::from("plugin errored and did not return valid UTF-8"),
132169
}),

examples/test-runner/src/main.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
// you need to build the hello example first
33

44
use anyhow::Result;
5-
use std::process::Command;
6-
75
use host_wasmi::PluginInstance;
6+
use std::process::Command;
87

98
#[cfg(not(feature = "wasi"))]
109
mod consts {
@@ -118,7 +117,7 @@ fn main() -> Result<()> {
118117
return Ok(());
119118
}
120119
};
121-
match String::from_utf8(result) {
120+
match std::str::from_utf8(result.get()) {
122121
Ok(s) => println!("{s}"),
123122
Err(_) => panic!("Error: function call '{function}' did not return UTF-8"),
124123
}
@@ -141,7 +140,7 @@ fn main() -> Result<()> {
141140
continue;
142141
}
143142
};
144-
match String::from_utf8(result) {
143+
match std::str::from_utf8(result.get()) {
145144
Ok(s) => println!("{s}"),
146145
Err(_) => panic!("Error: function call '{function}' did not return UTF-8"),
147146
}

0 commit comments

Comments
 (0)