diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 9fc435d35..ef341cb45 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -67,12 +67,21 @@ template using WasmCallVoid = std::function>; template using WasmCallWord = std::function>; +// Callback used to test arg passing from host to wasm. +using WasmCall_WWlfd = std::function; +// Types used to test return values. Floats are passed as parameters as these +// do not conflict with ProxyWasm ABI signatures. +using WasmCall_lf = std::function; +using WasmCall_fff = std::function; +using WasmCall_dfff = std::function; #define FOR_ALL_WASM_VM_EXPORTS(_f) \ _f(proxy_wasm::WasmCallVoid<0>) _f(proxy_wasm::WasmCallVoid<1>) _f(proxy_wasm::WasmCallVoid<2>) \ _f(proxy_wasm::WasmCallVoid<3>) _f(proxy_wasm::WasmCallVoid<5>) \ _f(proxy_wasm::WasmCallWord<0>) _f(proxy_wasm::WasmCallWord<1>) \ - _f(proxy_wasm::WasmCallWord<2>) _f(proxy_wasm::WasmCallWord<3>) + _f(proxy_wasm::WasmCallWord<2>) _f(proxy_wasm::WasmCallWord<3>) \ + _f(proxy_wasm::WasmCall_WWlfd) _f(proxy_wasm::WasmCall_lf) \ + _f(proxy_wasm::WasmCall_fff) _f(proxy_wasm::WasmCall_dfff) // These are templates and its helper for constructing signatures of functions callbacks from Wasm // VMs. diff --git a/include/proxy-wasm/word.h b/include/proxy-wasm/word.h index bc0d23a8c..a51b981b3 100644 --- a/include/proxy-wasm/word.h +++ b/include/proxy-wasm/word.h @@ -24,8 +24,16 @@ namespace proxy_wasm { // Use byteswap functions only when compiling for big-endian platforms. #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#define htowasm(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? __builtin_bswap32(x) : (x)) -#define wasmtoh(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? __builtin_bswap32(x) : (x)) +static inline float bswap(float x) { + return std::bit_cast(__builtin_bswap32(std::bit_cast(x))); +} +static inline double bswap(double x) { + return std::bit_cast(__builtin_bswap64(std::bit_cast(x))); +} +static inline uint32_t bswap(uint32_t x) { return __builtin_bswap32(x); } +static inline auto bswap(auto x) { return __builtin_bswap64(x); } +#define htowasm(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? bswap(x) : (x)) +#define wasmtoh(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? bswap(x) : (x)) #else #define htowasm(x, vm_uses_wasm_byte_order) (x) #define wasmtoh(x, vm_uses_wasm_byte_order) (x) diff --git a/src/v8/v8.cc b/src/v8/v8.cc index 25b6623ca..4642de7a6 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -245,6 +245,7 @@ template <> constexpr auto convertArgToValKind() { return wasm::ValKin template <> constexpr auto convertArgToValKind() { return wasm::ValKind::I64; }; template <> constexpr auto convertArgToValKind() { return wasm::ValKind::I64; }; template <> constexpr auto convertArgToValKind() { return wasm::ValKind::F64; }; +template <> constexpr auto convertArgToValKind() { return wasm::ValKind::F32; }; template constexpr auto convertArgsTupleToValTypesImpl(std::index_sequence /*comptime*/) { diff --git a/src/wamr/wamr.cc b/src/wamr/wamr.cc index 8eef73590..f445186b7 100644 --- a/src/wamr/wamr.cc +++ b/src/wamr/wamr.cc @@ -462,6 +462,10 @@ template <> void assignVal(uint64_t t, wasm_val_t &val) { val.kind = WASM_I64; val.of.i64 = static_cast(t); } +template <> void assignVal(float t, wasm_val_t &val) { + val.kind = WASM_F32; + val.of.f32 = static_cast(t); +} template <> void assignVal(double t, wasm_val_t &val) { val.kind = WASM_F64; val.of.f64 = t; @@ -485,17 +489,21 @@ template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; +template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f32(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f64(); }; template T convertValueTypeToArg(wasm_val_t val); template <> uint32_t convertValueTypeToArg(wasm_val_t val) { return static_cast(val.of.i32); } -template <> Word convertValueTypeToArg(wasm_val_t val) { return val.of.i32; } +template <> Word convertValueTypeToArg(wasm_val_t val) { + return std::bit_cast(val.of.i32); +} template <> int64_t convertValueTypeToArg(wasm_val_t val) { return val.of.i64; } template <> uint64_t convertValueTypeToArg(wasm_val_t val) { return static_cast(val.of.i64); } +template <> float convertValueTypeToArg(wasm_val_t val) { return val.of.f32; } template <> double convertValueTypeToArg(wasm_val_t val) { return val.of.f64; } template diff --git a/src/wasmedge/wasmedge.cc b/src/wasmedge/wasmedge.cc index acfe15b3c..b3873c549 100644 --- a/src/wasmedge/wasmedge.cc +++ b/src/wasmedge/wasmedge.cc @@ -49,6 +49,7 @@ template <> WasmEdge_Value makeVal(uint64_t t) { return WasmEdge_ValueGenI64(static_cast(t)); } template <> WasmEdge_Value makeVal(double t) { return WasmEdge_ValueGenF64(t); } +template <> WasmEdge_Value makeVal(float t) { return WasmEdge_ValueGenF32(t); } // Helper function to print values. std::string printValue(const WasmEdge_Value &value) { @@ -143,6 +144,7 @@ template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeG template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenI32(); } template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenI64(); } template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenI64(); } +template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenF32(); } template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenF64(); } // Helper templates to convert valtype to arg. @@ -150,13 +152,16 @@ template T convValTypeToArg(WasmEdge_Value val); template <> uint32_t convValTypeToArg(WasmEdge_Value val) { return static_cast(WasmEdge_ValueGetI32(val)); } -template <> Word convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetI32(val); } +template <> Word convValTypeToArg(WasmEdge_Value val) { + return std::bit_cast(WasmEdge_ValueGetI32(val)); +} template <> int64_t convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetI64(val); } template <> uint64_t convValTypeToArg(WasmEdge_Value val) { return static_cast(WasmEdge_ValueGetI64(val)); } +template <> float convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetF32(val); } template <> double convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetF64(val); } diff --git a/src/wasmtime/wasmtime.cc b/src/wasmtime/wasmtime.cc index a72a0361d..7fd990262 100644 --- a/src/wasmtime/wasmtime.cc +++ b/src/wasmtime/wasmtime.cc @@ -435,6 +435,10 @@ template <> void assignVal(uint64_t t, wasm_val_t &val) { val.kind = WASM_I64; val.of.i64 = static_cast(t); } +template <> void assignVal(float t, wasm_val_t &val) { + val.kind = WASM_F32; + val.of.f32 = t; +} template <> void assignVal(double t, wasm_val_t &val) { val.kind = WASM_F64; val.of.f64 = t; @@ -458,17 +462,21 @@ template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; +template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f32(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f64(); }; template T convertValueTypeToArg(wasm_val_t val); template <> uint32_t convertValueTypeToArg(wasm_val_t val) { return static_cast(val.of.i32); } -template <> Word convertValueTypeToArg(wasm_val_t val) { return val.of.i32; } +template <> Word convertValueTypeToArg(wasm_val_t val) { + return std::bit_cast(val.of.i32); +} template <> int64_t convertValueTypeToArg(wasm_val_t val) { return val.of.i64; } template <> uint64_t convertValueTypeToArg(wasm_val_t val) { return static_cast(val.of.i64); } +template <> float convertValueTypeToArg(wasm_val_t val) { return val.of.f32; } template <> double convertValueTypeToArg(wasm_val_t val) { return val.of.f64; } template diff --git a/test/BUILD b/test/BUILD index 97e70558a..38d7810d3 100644 --- a/test/BUILD +++ b/test/BUILD @@ -70,6 +70,7 @@ cc_test( timeout = "long", srcs = ["runtime_test.cc"], data = [ + "//test/test_data:arg_passing.wasm", "//test/test_data:callback.wasm", "//test/test_data:clock.wasm", "//test/test_data:resource_limits.wasm", @@ -84,6 +85,22 @@ cc_test( ], ) +cc_test( + name = "arg_passing_test", + timeout = "long", + srcs = ["arg_passing_test.cc"], + data = [ + "//test/test_data:arg_passing.wasm", + ], + linkstatic = 1, + deps = [ + ":utility_lib", + "//:lib", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "exports_test", srcs = ["exports_test.cc"], diff --git a/test/arg_passing_test.cc b/test/arg_passing_test.cc new file mode 100644 index 000000000..eb6474a8b --- /dev/null +++ b/test/arg_passing_test.cc @@ -0,0 +1,148 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include + +#include "include/proxy-wasm/context.h" +#include "include/proxy-wasm/wasm.h" + +#include "test/utility.h" + +namespace proxy_wasm { +namespace { + +class ArgPassingContext : public TestContext { +public: + using TestContext::TestContext; + WasmResult getHeaderMapPairs(WasmHeaderMapType /* type */, Pairs * /* result */) override { + // GetHeaderMapPairs passes this value as the hostcall return value as + // opposed to output parameter. + return static_cast(3333333333U); + } +}; + +class ArgPassingWasm : public TestWasm { +public: + using TestWasm::TestWasm; + ContextBase *createVmContext() override { return new ArgPassingContext(this); }; +}; + +class ArgPassingTest : public TestVm { +public: + void SetUp() { + auto source = readTestWasmFile("arg_passing.wasm"); + ASSERT_FALSE(source.empty()); + wasm_.emplace(std::move(vm_)); + ASSERT_TRUE(wasm_->load(source, false)); + ASSERT_TRUE(wasm_->initialize()); + context_ = dynamic_cast(wasm_->vm_context()); + ASSERT_NE(context_, nullptr); + } + + std::optional wasm_; + ArgPassingContext *context_; +}; + +INSTANTIATE_TEST_SUITE_P(WasmEngines, ArgPassingTest, testing::ValuesIn(getWasmEngines()), + [](const testing::TestParamInfo &info) { + return info.param; + }); + +TEST_P(ArgPassingTest, WasmCallReturnsWordValue) { + WasmCallWord<0> test_return_u32; + wasm_->wasm_vm()->getFunction("test_return_u32", &test_return_u32); + + EXPECT_EQ(test_return_u32(context_).u32(), 3333333333U) << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsNegativeWordValue) { + WasmCallWord<0> test_return_i32; + wasm_->wasm_vm()->getFunction("test_return_i32", &test_return_i32); + + EXPECT_EQ(test_return_i32(context_).u32(), -1111111111) << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsLongValue) { + WasmCall_lf test_return_u64; + wasm_->wasm_vm()->getFunction("test_return_u64", &test_return_u64); + + EXPECT_EQ(test_return_u64(context_, 1.0), 11111111111111111111UL) << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsFloatValue) { + WasmCall_fff test_return_f32; + wasm_->wasm_vm()->getFunction("test_return_f32", &test_return_f32); + + EXPECT_THAT(test_return_f32(context_, 1.0, 1.0), + testing::AllOf(testing::Lt(1112.0), testing::Gt(1110.0))) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsDoubleValue) { + WasmCall_dfff test_return_f64; + wasm_->wasm_vm()->getFunction("test_return_f64", &test_return_f64); + + EXPECT_THAT(test_return_f64(context_, 1.0, 1.0, 1.0), + testing::AllOf(testing::Lt(1111111112.0), testing::Gt(1111111110.0))) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostCallReturnsWordValue) { + WasmCallWord<0> test_host_return; + wasm_->wasm_vm()->getFunction("test_host_return", &test_host_return); + + EXPECT_TRUE(test_host_return(context_)) << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostPassesPrimitiveValues) { + WasmCall_WWlfd test_primitives; + wasm_->wasm_vm()->getFunction("test_primitives", &test_primitives); + + ASSERT_TRUE(test_primitives(context_, 3333333333U, 11111111111111111111UL, 1111, 1111111111)) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostPassesNegativePrimitiveValues) { + WasmCall_WWlfd test_negative_primitives; + wasm_->wasm_vm()->getFunction("test_negative_primitives", &test_negative_primitives); + + ASSERT_TRUE( + test_negative_primitives(context_, -1111111111, -1111111111111111111, -1111, -1111111111)) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostReadsPointersToWasmMemory) { + WasmCallWord<0> test_buffer_from_wasm; + wasm_->wasm_vm()->getFunction("test_buffer_from_wasm", &test_buffer_from_wasm); + + ASSERT_TRUE(test_buffer_from_wasm(context_)) << context_->getLog(); + + context_->isLogged("hello from wasm land!"); +} + +TEST_P(ArgPassingTest, WasmCallReadsBufferPassedByHost) { + context_->setBuffer(0, "hello from host land!"); + WasmCallWord<0> test_buffer_from_host; + wasm_->wasm_vm()->getFunction("test_buffer_from_host", &test_buffer_from_host); + + ASSERT_TRUE(test_buffer_from_host(context_)) << context_->getLog(); +} + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ArgPassingTest); + +} // namespace +} // namespace proxy_wasm diff --git a/test/test_data/BUILD b/test/test_data/BUILD index e5ecd439e..4e27933af 100644 --- a/test/test_data/BUILD +++ b/test/test_data/BUILD @@ -57,6 +57,11 @@ wasm_rust_binary( srcs = ["trap.rs"], ) +wasm_rust_binary( + name = "arg_passing.wasm", + srcs = ["arg_passing.rs"], +) + wasm_rust_binary( name = "resource_limits.wasm", srcs = ["resource_limits.rs"], diff --git a/test/test_data/arg_passing.rs b/test/test_data/arg_passing.rs new file mode 100644 index 000000000..8b8cda5f9 --- /dev/null +++ b/test/test_data/arg_passing.rs @@ -0,0 +1,174 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern "C" { + fn proxy_log(level: u32, message_data: *const u8, message_size: usize) -> bool; +} + +fn log(message: &str) { + unsafe { + proxy_log(/*error*/ 4, message.as_bytes().as_ptr(), message.len()); + } +} + +#[no_mangle] +pub extern "C" fn _initialize() { + std::panic::set_hook(Box::new(|panic_info| { + log(&format!( + "panic message: {}", + panic_info.payload_as_str().unwrap_or("") + )); + })); +} + +#[no_mangle] +pub extern "C" fn proxy_abi_version_0_2_0() {} + +#[no_mangle] +pub extern "C" fn proxy_on_memory_allocate(size: usize) -> *mut u8 { + let mut vec: Vec = Vec::with_capacity(size); + unsafe { + vec.set_len(size); + } + let slice = vec.into_boxed_slice(); + Box::into_raw(slice) as *mut u8 +} + +extern "C" { + // Used by test_host_return to assert on values returned from imports from the host. + fn proxy_get_header_map_pairs( + map_type: u32, + return_map_data: *mut *mut u8, + return_map_size: *mut usize, + ) -> u32; +} + +#[no_mangle] +pub extern "C" fn test_return_u32() -> u32 { + return 3333333333; +} + +#[no_mangle] +pub extern "C" fn test_return_i32() -> i32 { + return -1111111111; +} + +#[no_mangle] +pub extern "C" fn test_return_u64(_: f32) -> u64 { + return 11111111111111111111; +} + +#[no_mangle] +pub extern "C" fn test_return_f32(_: f32, _: f32) -> f32 { + return 1111.0f32; +} + +#[no_mangle] +pub extern "C" fn test_return_f64(_: f32, _: f32, _: f32) -> f64 { + return 1111111111.0f64; +} + +#[no_mangle] +pub extern "C" fn test_host_return() -> u32 { + unsafe { + let ret = proxy_get_header_map_pairs(0, std::ptr::null_mut(), std::ptr::null_mut()); + if ret != 3333333333u32 { + panic!("unexpected get_header_map_pairs return value: {}", ret); + } + } + return 1; +} + +#[no_mangle] +pub extern "C" fn test_primitives(uint32: u32, uint64: u64, float32: f32, float64: f64) -> i32 { + if uint32 != 3333333333 { + panic!("unexpected uint32 value: {}", uint32); + } + if uint64 != 11111111111111111111 { + panic!("unexpected uint64 value: {}", uint64); + } + if float32 < 1110.0 || float32 > 1112.0 { + panic!("unexpected float32 value: {}", float32); + } + if float64 < 1111111110.0 || float64 > 1111111112.0 { + panic!("unexpected float64 value: {}", float64); + } + return 1; +} + +#[no_mangle] +pub extern "C" fn test_negative_primitives( + int32: i32, + int64: i64, + float32: f32, + float64: f64, +) -> i32 { + if int32 != -1111111111 { + panic!("unexpected int32 value: {}", int32); + } + if int64 != -1111111111111111111 { + panic!("unexpected int64 value: {}", int64); + } + if float32 > -1110.0 || float32 < -1112.0 { + panic!("unexpected float32 value: {}", float32); + } + if float64 > -1111111110.0 || float64 < -1111111112.0 { + panic!("unexpected float64 value: {}", float64); + } + return 1; +} + +#[no_mangle] +pub extern "C" fn test_buffer_from_wasm() -> bool { + let message = "hello from wasm land!"; + unsafe { + match proxy_log(/*info*/ 2, message.as_ptr(), message.len()) { + false => true, + status => panic!("unexpected status: {}", status as u32), + } + } +} + +extern "C" { + fn proxy_get_buffer_bytes( + buffer_type: u32, + start: usize, + max_size: usize, + return_buffer_data: *mut *mut u8, + return_buffer_size: *mut usize, + ) -> bool; +} + +#[no_mangle] +pub extern "C" fn test_buffer_from_host() -> bool { + let mut return_data: *mut u8 = std::ptr::null_mut(); + let mut return_size: usize = 0; + unsafe { + match proxy_get_buffer_bytes(0, 0, 30, &mut return_data, &mut return_size) { + false => { + if return_data.is_null() { + panic!("return_data was null"); + } + let result = + String::from_utf8(Vec::from_raw_parts(return_data, return_size, return_size)) + .unwrap(); + if result != "hello from host land!\0" { + panic!("message {} did not match expectation", result); + } + true + } + status => panic!("unexpected status: {}", status as u32), + } + } +} diff --git a/test/utility.h b/test/utility.h index 0eb743037..a53ed2e4b 100644 --- a/test/utility.h +++ b/test/utility.h @@ -13,6 +13,7 @@ // limitations under the License. #include "gtest/gtest.h" +#include #include #include #include @@ -101,6 +102,8 @@ class TestContext : public ContextBase { return WasmResult::Ok; } + std::string_view getLog() const { return log_; } + WasmResult getProperty(std::string_view path, std::string *result) override { if (path == "plugin_root_id") { *result = root_id_; @@ -109,6 +112,20 @@ class TestContext : public ContextBase { return unimplemented(); } + void setBuffer(int32_t buffer_type, std::string buffer) { + auto [it, inserted] = buffers_.emplace(buffer_type, std::make_unique()); + std::unique_ptr arr = std::make_unique(buffer.size() + 1); + strcpy(arr.get(), buffer.c_str()); + it->second->set(std::move(arr), buffer.size() + 1); + } + + BufferInterface *getBuffer(WasmBufferType type) override { + if (auto it = buffers_.find(static_cast(type)); it != buffers_.end()) { + return it->second.get(); + } + return nullptr; + } + bool isLogEmpty() { return log_.empty(); } bool isLogged(std::string_view message) { return log_.find(message) != std::string::npos; } @@ -133,6 +150,7 @@ class TestContext : public ContextBase { void set_allow_on_headers_stop_iteration(bool allow) { allow_on_headers_stop_iteration_ = allow; } private: + std::unordered_map> buffers_; std::string log_; static std::string global_log_; }; diff --git a/test/wasm_vm_test.cc b/test/wasm_vm_test.cc index 346fe2a07..d792c63df 100644 --- a/test/wasm_vm_test.cc +++ b/test/wasm_vm_test.cc @@ -85,7 +85,7 @@ TEST_P(TestVm, Memory) { ASSERT_EQ(100, word.u64_); uint32_t data[2] = {htowasm(static_cast(-1), vm_->usesWasmByteOrder()), - htowasm(200, vm_->usesWasmByteOrder())}; + htowasm(200U, vm_->usesWasmByteOrder())}; ASSERT_TRUE(vm_->setMemory(0x200, sizeof(int32_t) * 2, static_cast(data))); ASSERT_TRUE(vm_->getWord(0x200, &word)); ASSERT_EQ(-1, static_cast(word.u64_));