diff --git a/BUILD b/BUILD index e531ce1b3..50753da04 100644 --- a/BUILD +++ b/BUILD @@ -187,31 +187,11 @@ cc_library( ], ) -genrule( - name = "prefixed_wasmtime_sources", - srcs = [ - "src/wasmtime/types.h", - "src/wasmtime/wasmtime.cc", - ], - outs = [ - "src/wasmtime/prefixed_types.h", - "src/wasmtime/prefixed_wasmtime.cc", - ], - cmd = """ - for file in $(SRCS); do - sed -e 's/wasm_/wasmtime_wasm_/g' \ - -e 's/wasmtime\\/types.h/wasmtime\\/prefixed_types.h/g' \ - $$file >$(@D)/src/wasmtime/prefixed_$$(basename $$file) - done - """, -) - cc_library( name = "wasmtime_lib", srcs = [ "src/common/types.h", - "src/wasmtime/prefixed_types.h", - "src/wasmtime/prefixed_wasmtime.cc", + "src/wasmtime/wasmtime.cc", ], hdrs = ["include/proxy-wasm/wasmtime.h"], copts = [ diff --git a/bazel/external/wasmtime.BUILD b/bazel/external/wasmtime.BUILD index 490f32779..1bf1724cd 100644 --- a/bazel/external/wasmtime.BUILD +++ b/bazel/external/wasmtime.BUILD @@ -57,11 +57,20 @@ genrule( toolchains = ["@bazel_tools//tools/cpp:current_cc_toolchain"], ) -# This must match the features defined in `bazel/cargo/wasmtime/Cargo.toml` for -# the C/C++ API to expose the right set of methods. +# This should match the features defined in `bazel/cargo/wasmtime/Cargo.toml` +# for the C/C++ API to expose the right set of methods. Listing the feature +# here enables the C-api implementations of the features. Rust-side +# implementations are controlled by the Cargo.toml file. Mismatching features +# will result in compile/link time failures. features = [ "cranelift", "gc-drc", + # The C++ API references the wat feature whenever cranelift is turned on. + # Without adding `wat` to the headers, the C++ API will fail at compile time. + # `wat` is not actually used by proxy-wasm, so the corresponding feature is not + # enabled in Cargo.toml. If proxy-wasm used wat, this configuration would fail + # at link time. + "wat", ] # Wasmtime C-api headers use cmakedefines to generate the config file. diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 9fc435d35..ef341cb45 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -67,12 +67,21 @@ template using WasmCallVoid = std::function>; template using WasmCallWord = std::function>; +// Callback used to test arg passing from host to wasm. +using WasmCall_WWlfd = std::function; +// Types used to test return values. Floats are passed as parameters as these +// do not conflict with ProxyWasm ABI signatures. +using WasmCall_lf = std::function; +using WasmCall_fff = std::function; +using WasmCall_dfff = std::function; #define FOR_ALL_WASM_VM_EXPORTS(_f) \ _f(proxy_wasm::WasmCallVoid<0>) _f(proxy_wasm::WasmCallVoid<1>) _f(proxy_wasm::WasmCallVoid<2>) \ _f(proxy_wasm::WasmCallVoid<3>) _f(proxy_wasm::WasmCallVoid<5>) \ _f(proxy_wasm::WasmCallWord<0>) _f(proxy_wasm::WasmCallWord<1>) \ - _f(proxy_wasm::WasmCallWord<2>) _f(proxy_wasm::WasmCallWord<3>) + _f(proxy_wasm::WasmCallWord<2>) _f(proxy_wasm::WasmCallWord<3>) \ + _f(proxy_wasm::WasmCall_WWlfd) _f(proxy_wasm::WasmCall_lf) \ + _f(proxy_wasm::WasmCall_fff) _f(proxy_wasm::WasmCall_dfff) // These are templates and its helper for constructing signatures of functions callbacks from Wasm // VMs. diff --git a/include/proxy-wasm/word.h b/include/proxy-wasm/word.h index bc0d23a8c..a51b981b3 100644 --- a/include/proxy-wasm/word.h +++ b/include/proxy-wasm/word.h @@ -24,8 +24,16 @@ namespace proxy_wasm { // Use byteswap functions only when compiling for big-endian platforms. #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#define htowasm(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? __builtin_bswap32(x) : (x)) -#define wasmtoh(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? __builtin_bswap32(x) : (x)) +static inline float bswap(float x) { + return std::bit_cast(__builtin_bswap32(std::bit_cast(x))); +} +static inline double bswap(double x) { + return std::bit_cast(__builtin_bswap64(std::bit_cast(x))); +} +static inline uint32_t bswap(uint32_t x) { return __builtin_bswap32(x); } +static inline auto bswap(auto x) { return __builtin_bswap64(x); } +#define htowasm(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? bswap(x) : (x)) +#define wasmtoh(x, vm_uses_wasm_byte_order) ((vm_uses_wasm_byte_order) ? bswap(x) : (x)) #else #define htowasm(x, vm_uses_wasm_byte_order) (x) #define wasmtoh(x, vm_uses_wasm_byte_order) (x) diff --git a/src/v8/v8.cc b/src/v8/v8.cc index 25b6623ca..4642de7a6 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -245,6 +245,7 @@ template <> constexpr auto convertArgToValKind() { return wasm::ValKin template <> constexpr auto convertArgToValKind() { return wasm::ValKind::I64; }; template <> constexpr auto convertArgToValKind() { return wasm::ValKind::I64; }; template <> constexpr auto convertArgToValKind() { return wasm::ValKind::F64; }; +template <> constexpr auto convertArgToValKind() { return wasm::ValKind::F32; }; template constexpr auto convertArgsTupleToValTypesImpl(std::index_sequence /*comptime*/) { diff --git a/src/wamr/wamr.cc b/src/wamr/wamr.cc index 8eef73590..f445186b7 100644 --- a/src/wamr/wamr.cc +++ b/src/wamr/wamr.cc @@ -462,6 +462,10 @@ template <> void assignVal(uint64_t t, wasm_val_t &val) { val.kind = WASM_I64; val.of.i64 = static_cast(t); } +template <> void assignVal(float t, wasm_val_t &val) { + val.kind = WASM_F32; + val.of.f32 = static_cast(t); +} template <> void assignVal(double t, wasm_val_t &val) { val.kind = WASM_F64; val.of.f64 = t; @@ -485,17 +489,21 @@ template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; +template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f32(); }; template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f64(); }; template T convertValueTypeToArg(wasm_val_t val); template <> uint32_t convertValueTypeToArg(wasm_val_t val) { return static_cast(val.of.i32); } -template <> Word convertValueTypeToArg(wasm_val_t val) { return val.of.i32; } +template <> Word convertValueTypeToArg(wasm_val_t val) { + return std::bit_cast(val.of.i32); +} template <> int64_t convertValueTypeToArg(wasm_val_t val) { return val.of.i64; } template <> uint64_t convertValueTypeToArg(wasm_val_t val) { return static_cast(val.of.i64); } +template <> float convertValueTypeToArg(wasm_val_t val) { return val.of.f32; } template <> double convertValueTypeToArg(wasm_val_t val) { return val.of.f64; } template diff --git a/src/wasmedge/wasmedge.cc b/src/wasmedge/wasmedge.cc index acfe15b3c..cc93dce69 100644 --- a/src/wasmedge/wasmedge.cc +++ b/src/wasmedge/wasmedge.cc @@ -49,6 +49,7 @@ template <> WasmEdge_Value makeVal(uint64_t t) { return WasmEdge_ValueGenI64(static_cast(t)); } template <> WasmEdge_Value makeVal(double t) { return WasmEdge_ValueGenF64(t); } +template <> WasmEdge_Value makeVal(float t) { return WasmEdge_ValueGenF32(t); } // Helper function to print values. std::string printValue(const WasmEdge_Value &value) { @@ -144,19 +145,23 @@ template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValT template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenI64(); } template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenI64(); } template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenF64(); } +template <> WasmEdge_ValType convArgToValType() { return WasmEdge_ValTypeGenF32(); } // Helper templates to convert valtype to arg. template T convValTypeToArg(WasmEdge_Value val); template <> uint32_t convValTypeToArg(WasmEdge_Value val) { return static_cast(WasmEdge_ValueGetI32(val)); } -template <> Word convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetI32(val); } +template <> Word convValTypeToArg(WasmEdge_Value val) { + return std::bit_cast(WasmEdge_ValueGetI32(val)); +} template <> int64_t convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetI64(val); } template <> uint64_t convValTypeToArg(WasmEdge_Value val) { return static_cast(WasmEdge_ValueGetI64(val)); } +template <> float convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetF32(val); } template <> double convValTypeToArg(WasmEdge_Value val) { return WasmEdge_ValueGetF64(val); } diff --git a/src/wasmtime/types.h b/src/wasmtime/types.h deleted file mode 100644 index 14fe75053..000000000 --- a/src/wasmtime/types.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "src/common/types.h" -#include "crates/c-api/include/wasm.h" - -namespace proxy_wasm::wasmtime { - -using WasmEnginePtr = common::CSmartPtr; -using WasmFuncPtr = common::CSmartPtr; -using WasmStorePtr = common::CSmartPtr; -using WasmModulePtr = common::CSmartPtr; -using WasmSharedModulePtr = common::CSmartPtr; -using WasmMemoryPtr = common::CSmartPtr; -using WasmTablePtr = common::CSmartPtr; -using WasmInstancePtr = common::CSmartPtr; -using WasmFunctypePtr = common::CSmartPtr; -using WasmTrapPtr = common::CSmartPtr; -using WasmExternPtr = common::CSmartPtr; - -using WasmByteVec = - common::CSmartType; -using WasmImporttypeVec = common::CSmartType; -using WasmExportTypeVec = common::CSmartType; -using WasmExternVec = - common::CSmartType; -using WasmValtypeVec = - common::CSmartType; - -} // namespace proxy_wasm::wasmtime diff --git a/src/wasmtime/wasmtime.cc b/src/wasmtime/wasmtime.cc index a72a0361d..c2ae65e38 100644 --- a/src/wasmtime/wasmtime.cc +++ b/src/wasmtime/wasmtime.cc @@ -14,38 +14,76 @@ #include "include/proxy-wasm/wasmtime.h" -#include #include #include #include #include #include -#include -#include -#include -#include "src/wasmtime/types.h" +#include "include/proxy-wasm/limits.h" +#include "include/proxy-wasm/word.h" -#include "crates/c-api/include/wasm.h" +#include "crates/c-api/include/wasmtime.hh" // IWYU pragma: keep + +namespace wasmtime::detail { +// Defines how wasmtime serializes proxy_wasm::Word. +template <> struct WasmType { + static const bool valid = true; + static const ValKind kind = ValKind::I32; + static void store(Store::Context cx, wasmtime_val_raw_t *p, const proxy_wasm::Word &t) { + p->i32 = t; + } + static proxy_wasm::Word load(Store::Context cx, wasmtime_val_raw_t *p) { return p->i32; } +}; +} // namespace wasmtime::detail namespace proxy_wasm { namespace wasmtime { +namespace { + +using ::wasmtime::Config; +using ::wasmtime::Engine; +using ::wasmtime::ExportType; +using ::wasmtime::Extern; +using ::wasmtime::Instance; +using ::wasmtime::Linker; +using ::wasmtime::Memory; +using ::wasmtime::Module; +using ::wasmtime::Result; +using ::wasmtime::Span; +using ::wasmtime::Store; +using ::wasmtime::Table; +using ::wasmtime::TrapResult; + +Engine *engine() { + static auto *const engine = []() { return new Engine(Config{}); }(); + return engine; +} -struct HostFuncData { - HostFuncData(std::string name) : name_(std::move(name)) {} - - std::string name_; - WasmFuncPtr callback_; - void *raw_func_{}; - WasmVm *vm_{}; -}; +template std::string printValue(const T &value) { return std::to_string(value); } +template <> std::string printValue(const proxy_wasm::Word &value) { + return std::to_string(value.u64_); +} -using HostFuncDataPtr = std::unique_ptr; +std::string printValues() { return ""; } +template std::string printValues(Arg arg, Args... args) { + return printValue(arg) + ((", " + printValue(args)) + ... + ""); +} -wasm_engine_t *engine() { - static const auto engine = WasmEnginePtr(wasm_engine_new()); - return engine.get(); +auto ConvertWasmToHostEndianness(const auto &arg) { + return wasmtoh(convertWordToUint32(arg), true); +} +void InPlaceConvertWasmToHostEndianness(auto &...args) { + (void(args = ConvertWasmToHostEndianness(args)), ...); +} +auto ConvertHostToWasmEndianness(const auto &arg) { + return htowasm(convertWordToUint32(arg), true); } +void InPlaceConvertHostToWasmEndianness(auto &...args) { + (void(args = ConvertHostToWasmEndianness(args)), ...); +} + +} // namespace class Wasmtime : public WasmVm { public: @@ -83,6 +121,10 @@ class Wasmtime : public WasmVm { void warm() override; + void terminate() override {} + + bool usesWasmByteOrder() override { return true; } + private: template void registerHostFunctionImpl(std::string_view module_name, std::string_view function_name, @@ -100,68 +142,57 @@ class Wasmtime : public WasmVm { void getModuleFunctionImpl(std::string_view function_name, std::function *function); - void terminate() override {} - bool usesWasmByteOrder() override { return true; } - // Initialize the Wasmtime store if necessary. void initStore(); - WasmStorePtr store_; - WasmModulePtr module_; - WasmSharedModulePtr shared_module_; - WasmInstancePtr instance_; - WasmMemoryPtr memory_; - WasmTablePtr table_; + std::optional store_; + std::optional module_; + std::optional instance_; + std::optional memory_; + std::optional table_; + Linker linker_ = Linker(*engine()); - std::unordered_map host_functions_; - std::unordered_map module_functions_; + std::unordered_map module_functions_; }; void Wasmtime::initStore() { - if (store_ != nullptr) { + if (store_.has_value()) { return; } - store_ = wasm_store_new(engine()); + store_.emplace(*engine()); } bool Wasmtime::load(std::string_view bytecode, std::string_view /*precompiled*/, const std::unordered_map & /*function_names*/) { initStore(); - if (store_ == nullptr) { + if (!store_.has_value()) { return false; } - WasmByteVec vec; - wasm_byte_vec_new(vec.get(), bytecode.size(), bytecode.data()); - - module_ = wasm_module_new(store_.get(), vec.get()); - if (module_ == nullptr) { - return false; - } - - shared_module_ = wasm_module_share(module_.get()); - if (shared_module_ == nullptr) { + Result module = + Module::compile(*engine(), std::span((uint8_t *)bytecode.data(), bytecode.size())); + if (!module) { + fail(FailState::UnableToInitializeCode, "Failed to load Wasm code: " + module.err().message()); return false; } + module_.emplace(module.ok()); return true; } std::unique_ptr Wasmtime::clone() { - assert(shared_module_ != nullptr); - auto clone = std::make_unique(); if (clone == nullptr) { return nullptr; } - clone->store_ = wasm_store_new(engine()); - if (clone->store_ == nullptr) { + clone->store_.emplace(Store(*engine())); + if (!clone->store_.has_value()) { return nullptr; } - clone->module_ = wasm_module_obtain(clone->store_.get(), shared_module_.get()); - if (clone->module_ == nullptr) { + clone->module_.emplace(*module_); + if (!clone->module_.has_value()) { return nullptr; } @@ -174,466 +205,159 @@ std::unique_ptr Wasmtime::clone() { return clone; } -static bool equalValTypes(const wasm_valtype_vec_t *left, const wasm_valtype_vec_t *right) { - if (left->size != right->size) { - return false; - } - - for (size_t i = 0; i < left->size; i++) { - if (wasm_valtype_kind(left->data[i]) != wasm_valtype_kind(right->data[i])) { - return false; - } - } - - return true; -} - -static std::string printValue(const wasm_val_t &value) { - switch (value.kind) { - case WASM_I32: - return std::to_string(value.of.i32); - case WASM_I64: - return std::to_string(value.of.i64); - case WASM_F32: - return std::to_string(value.of.f32); - case WASM_F64: - return std::to_string(value.of.f64); - default: - return "unknown"; - } -} - -static std::string printValues(const wasm_val_vec_t *values) { - if (values->size == 0) { - return ""; - } - - std::string s; - for (size_t i = 0; i < values->size; i++) { - if (i != 0U) { - s.append(", "); - } - s.append(printValue(values->data[i])); - } - return s; -} - -static const char *printValKind(wasm_valkind_t kind) { - switch (kind) { - case WASM_I32: - return "i32"; - case WASM_I64: - return "i64"; - case WASM_F32: - return "f32"; - case WASM_F64: - return "f64"; - case WASM_EXTERNREF: - return "externref"; - case WASM_FUNCREF: - return "funcref"; - default: - return "unknown"; - } -} - -static std::string printValTypes(const wasm_valtype_vec_t *types) { - if (types->size == 0) { - return "void"; - } - - std::string s; - s.reserve(types->size * 8 /* max size + " " */ - 1); - for (size_t i = 0; i < types->size; i++) { - if (i != 0U) { - s.append(" "); - } - s.append(printValKind(wasm_valtype_kind(types->data[i]))); - } - return s; -} - bool Wasmtime::link(std::string_view /*debug_name*/) { - assert(module_ != nullptr); - - WasmImporttypeVec import_types; - wasm_module_imports(module_.get(), import_types.get()); - - std::vector imports; - for (size_t i = 0; i < import_types.get()->size; i++) { - const wasm_name_t *module_name_ptr = wasm_importtype_module(import_types.get()->data[i]); - const wasm_name_t *name_ptr = wasm_importtype_name(import_types.get()->data[i]); - const wasm_externtype_t *extern_type = wasm_importtype_type(import_types.get()->data[i]); - - std::string_view module_name(module_name_ptr->data, module_name_ptr->size); - std::string_view name(name_ptr->data, name_ptr->size); - assert(name_ptr->size > 0); - switch (wasm_externtype_kind(extern_type)) { - case WASM_EXTERN_FUNC: { - auto it = host_functions_.find(std::string(module_name) + "." + std::string(name)); - if (it == host_functions_.end()) { - fail(FailState::UnableToInitializeCode, - std::string("Failed to load Wasm module due to a missing import: ") + - std::string(module_name) + "." + std::string(name)); - return false; - } - - auto *func = it->second->callback_.get(); - const wasm_functype_t *exp_type = wasm_externtype_as_functype_const(extern_type); - WasmFunctypePtr actual_type = wasm_func_type(it->second->callback_.get()); - if (!equalValTypes(wasm_functype_params(exp_type), wasm_functype_params(actual_type.get())) || - !equalValTypes(wasm_functype_results(exp_type), - wasm_functype_results(actual_type.get()))) { - fail( - FailState::UnableToInitializeCode, - std::string("Failed to load Wasm module due to an import type mismatch for function ") + - std::string(module_name) + "." + std::string(name) + - ", want: " + printValTypes(wasm_functype_params(exp_type)) + " -> " + - printValTypes(wasm_functype_results(exp_type)) + - ", but host exports: " + printValTypes(wasm_functype_params(actual_type.get())) + - " -> " + printValTypes(wasm_functype_results(actual_type.get()))); - return false; - } - imports.push_back(wasm_func_as_extern(func)); - } break; - case WASM_EXTERN_GLOBAL: { - // TODO(mathetake): add support when/if needed. - fail(FailState::UnableToInitializeCode, - "Failed to load Wasm module due to a missing import: " + std::string(module_name) + "." + - std::string(name)); - return false; - } break; - case WASM_EXTERN_MEMORY: { - assert(memory_ == nullptr); - const wasm_memorytype_t *memory_type = - wasm_externtype_as_memorytype_const(extern_type); // owned by `extern_type` - if (memory_type == nullptr) { - return false; - } - memory_ = wasm_memory_new(store_.get(), memory_type); - if (memory_ == nullptr) { - return false; - } - imports.push_back(wasm_memory_as_extern(memory_.get())); - } break; - case WASM_EXTERN_TABLE: { - assert(table_ == nullptr); - const wasm_tabletype_t *table_type = - wasm_externtype_as_tabletype_const(extern_type); // owned by `extern_type` - if (table_type == nullptr) { - return false; - } - table_ = wasm_table_new(store_.get(), table_type, nullptr); - if (table_ == nullptr) { - return false; - } - imports.push_back(wasm_table_as_extern(table_.get())); - } break; - } - } - - if (import_types.get()->size != imports.size()) { - return false; - } + assert(module_.has_value()); - wasm_extern_vec_t imports_vec = {imports.size(), imports.data()}; - instance_ = wasm_instance_new(store_.get(), module_.get(), &imports_vec, nullptr); - if (instance_ == nullptr) { - fail(FailState::UnableToInitializeCode, "Failed to create new Wasm instance"); + TrapResult instance = linker_.instantiate(store_->context(), *module_); + if (!instance) { + fail(FailState::UnableToInitializeCode, + "Failed to create new Wasm instance: " + instance.err().message()); return false; } + instance_.emplace(instance.ok()); - WasmExportTypeVec export_types; - wasm_module_exports(module_.get(), export_types.get()); - - WasmExternVec exports; - wasm_instance_exports(instance_.get(), exports.get()); - - for (size_t i = 0; i < export_types.get()->size; i++) { - const wasm_externtype_t *exp_extern_type = wasm_exporttype_type(export_types.get()->data[i]); - wasm_extern_t *actual_extern = exports.get()->data[i]; - - wasm_externkind_t kind = wasm_extern_kind(actual_extern); - assert(kind == wasm_externtype_kind(exp_extern_type)); - switch (kind) { - case WASM_EXTERN_FUNC: { - WasmFuncPtr func = wasm_func_copy(wasm_extern_as_func(actual_extern)); - const wasm_name_t *name_ptr = wasm_exporttype_name(export_types.get()->data[i]); - module_functions_.insert_or_assign(std::string(name_ptr->data, name_ptr->size), - std::move(func)); - } break; - case WASM_EXTERN_GLOBAL: { - // TODO(mathetake): add support when/if needed. - } break; - case WASM_EXTERN_MEMORY: { - assert(memory_ == nullptr); - memory_ = wasm_memory_copy(wasm_extern_as_memory(actual_extern)); - if (memory_ == nullptr) { - return false; - } - } break; - case WASM_EXTERN_TABLE: { - // TODO(mathetake): add support when/if needed. - } break; + ExportType::List export_types = module_->exports(); + for (ExportType::Ref export_type : export_types) { + std::optional actual_extern = instance_->get(store_->context(), export_type.name()); + if (!actual_extern.has_value()) { + continue; + } + if (std::holds_alternative<::wasmtime::Func>(*actual_extern)) { + module_functions_.insert_or_assign(std::string(export_type.name()), + std::get<::wasmtime::Func>(*actual_extern)); + } else if (std::holds_alternative(*actual_extern)) { + assert(!memory_.has_value()); + memory_.emplace(std::get(*actual_extern)); } + // TODO: add support for globals and tables and when/if needed. } return true; } -uint64_t Wasmtime::getMemorySize() { return wasm_memory_data_size(memory_.get()); } +uint64_t Wasmtime::getMemorySize() { + return memory_->size(store_->context()) * PROXY_WASM_HOST_WASM_MEMORY_PAGE_SIZE_BYTES; +} std::optional Wasmtime::getMemory(uint64_t pointer, uint64_t size) { - assert(memory_ != nullptr); - if (pointer + size > wasm_memory_data_size(memory_.get())) { + assert(store_.has_value()); + assert(memory_.has_value()); + ::wasmtime::Span data = memory_->data(store_->context()); + if (pointer + size > data.size()) { return std::nullopt; } - return std::string_view(wasm_memory_data(memory_.get()) + pointer, size); + return std::string_view(reinterpret_cast(data.data() + pointer), size); } bool Wasmtime::setMemory(uint64_t pointer, uint64_t size, const void *data) { - assert(memory_ != nullptr); - if (pointer + size > wasm_memory_data_size(memory_.get())) { + assert(store_.has_value()); + assert(memory_.has_value()); + ::wasmtime::Span memory = memory_->data(store_->context()); + if (pointer + size > memory.size()) { return false; } - ::memcpy(wasm_memory_data(memory_.get()) + pointer, data, size); + ::memcpy(memory.data() + pointer, data, size); return true; } bool Wasmtime::getWord(uint64_t pointer, Word *word) { - assert(memory_ != nullptr); + assert(store_.has_value()); + assert(memory_.has_value()); + ::wasmtime::Span memory = memory_->data(store_->context()); constexpr auto size = sizeof(uint32_t); - if (pointer + size > wasm_memory_data_size(memory_.get())) { + if (pointer + size > memory.size()) { return false; } uint32_t word32; - ::memcpy(&word32, wasm_memory_data(memory_.get()) + pointer, size); + ::memcpy(&word32, memory.data() + pointer, size); word->u64_ = wasmtoh(word32, true); return true; } bool Wasmtime::setWord(uint64_t pointer, Word word) { + assert(store_.has_value()); + assert(memory_.has_value()); + ::wasmtime::Span memory = memory_->data(store_->context()); constexpr auto size = sizeof(uint32_t); - if (pointer + size > wasm_memory_data_size(memory_.get())) { + if (pointer + size > memory.size()) { return false; } uint32_t word32 = htowasm(word.u32(), true); - ::memcpy(wasm_memory_data(memory_.get()) + pointer, &word32, size); + ::memcpy(memory.data() + pointer, &word32, size); return true; } -template void assignVal(T t, wasm_val_t &val); -template <> void assignVal(Word t, wasm_val_t &val) { - val.kind = WASM_I32; - val.of.i32 = static_cast(t.u64_); -} -template <> void assignVal(uint32_t t, wasm_val_t &val) { - val.kind = WASM_I32; - val.of.i32 = static_cast(t); -} -template <> void assignVal(uint64_t t, wasm_val_t &val) { - val.kind = WASM_I64; - val.of.i64 = static_cast(t); -} -template <> void assignVal(double t, wasm_val_t &val) { - val.kind = WASM_F64; - val.of.f64 = t; -} - -template wasm_val_t makeVal(T t) { - wasm_val_t val{}; - assignVal(t, val); - return val; -} - -template struct ConvertWordType { - using type = T; // NOLINT(readability-identifier-naming) -}; -template <> struct ConvertWordType { - using type = uint32_t; // NOLINT(readability-identifier-naming) -}; - -template auto convertArgToValTypePtr(); -template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); }; -template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i32(); }; -template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; -template <> auto convertArgToValTypePtr() { return wasm_valtype_new_i64(); }; -template <> auto convertArgToValTypePtr() { return wasm_valtype_new_f64(); }; - -template T convertValueTypeToArg(wasm_val_t val); -template <> uint32_t convertValueTypeToArg(wasm_val_t val) { - return static_cast(val.of.i32); -} -template <> Word convertValueTypeToArg(wasm_val_t val) { return val.of.i32; } -template <> int64_t convertValueTypeToArg(wasm_val_t val) { return val.of.i64; } -template <> uint64_t convertValueTypeToArg(wasm_val_t val) { - return static_cast(val.of.i64); -} -template <> double convertValueTypeToArg(wasm_val_t val) { return val.of.f64; } - -template -constexpr T convertValTypesToArgsTuple(const U &vec, std::index_sequence /*comptime*/) { - return std::make_tuple( - convertValueTypeToArg>::type>( - vec->data[I])...); -} - -template -void convertArgsTupleToValTypesImpl(wasm_valtype_vec_t *types, - std::index_sequence /*comptime*/) { - auto size = std::tuple_size::value; - auto ps = std::array::value>{ - convertArgToValTypePtr::type>()...}; - wasm_valtype_vec_new(types, size, ps.data()); -} - -template ::value>> -void convertArgsTupleToValTypes(wasm_valtype_vec_t *types) { - convertArgsTupleToValTypesImpl(types, Is()); -} - -template WasmFunctypePtr newWasmNewFuncType() { - WasmValtypeVec params; - WasmValtypeVec results; - convertArgsTupleToValTypes(params.get()); - convertArgsTupleToValTypes>(results.get()); - return wasm_functype_new(params.get(), results.get()); -} - -template WasmFunctypePtr newWasmNewFuncType() { - WasmValtypeVec params; - WasmValtypeVec results; - convertArgsTupleToValTypes(params.get()); - convertArgsTupleToValTypes>(results.get()); - return wasm_functype_new(params.get(), results.get()); -} - template void Wasmtime::registerHostFunctionImpl(std::string_view module_name, std::string_view function_name, void (*function)(Args...)) { - auto data = - std::make_unique(std::string(module_name) + "." + std::string(function_name)); - - WasmFunctypePtr type = newWasmNewFuncType>(); - WasmFuncPtr func = wasm_func_new_with_env( - store_.get(), type.get(), - [](void *data, const wasm_val_vec_t *params, wasm_val_vec_t * /*results*/) -> wasm_trap_t * { - auto *func_data = reinterpret_cast(data); - const bool log = func_data->vm_->cmpLogLevel(LogLevel::trace); + Result result = linker_.func_wrap( + module_name, function_name, + [this, function, + function_name = std::string(module_name) + "." + std::string(function_name)](Args... args) { + InPlaceConvertWasmToHostEndianness(args...); + const bool log = cmpLogLevel(LogLevel::trace); if (log) { - func_data->vm_->integration()->trace("[vm->host] " + func_data->name_ + "(" + - printValues(params) + ")"); + integration()->trace("[vm->host] " + function_name + "(" + printValues(args...) + ")"); } - auto args = convertValTypesToArgsTuple>( - params, std::make_index_sequence{}); - auto fn = reinterpret_cast(func_data->raw_func_); - std::apply(fn, args); + function(std::forward(args)...); if (log) { - func_data->vm_->integration()->trace("[vm<-host] " + func_data->name_ + " return: void"); + integration()->trace("[vm<-host] " + function_name + " return: void"); } - return nullptr; - }, - data.get(), nullptr); - - data->vm_ = this; - data->callback_ = std::move(func); - data->raw_func_ = reinterpret_cast(function); - host_functions_.insert_or_assign(std::string(module_name) + "." + std::string(function_name), - std::move(data)); + }); + if (!result) { + fail(FailState::ConfigureFailed, "Failed to register host function: " + result.err().message()); + } }; template void Wasmtime::registerHostFunctionImpl(std::string_view module_name, std::string_view function_name, R (*function)(Args...)) { - auto data = - std::make_unique(std::string(module_name) + "." + std::string(function_name)); - WasmFunctypePtr type = newWasmNewFuncType>(); - WasmFuncPtr func = wasm_func_new_with_env( - store_.get(), type.get(), - [](void *data, const wasm_val_vec_t *params, wasm_val_vec_t *results) -> wasm_trap_t * { - auto *func_data = reinterpret_cast(data); - const bool log = func_data->vm_->cmpLogLevel(LogLevel::trace); + Result result = linker_.func_wrap( + module_name, function_name, + [this, function, + function_name = std::string(module_name) + "." + std::string(function_name)](Args... args) { + InPlaceConvertWasmToHostEndianness(args...); + const bool log = cmpLogLevel(LogLevel::trace); if (log) { - func_data->vm_->integration()->trace("[vm->host] " + func_data->name_ + "(" + - printValues(params) + ")"); + integration()->trace("[vm->host] " + function_name + "(" + printValues(args...) + ")"); } - auto args = convertValTypesToArgsTuple>( - params, std::make_index_sequence{}); - auto fn = reinterpret_cast(func_data->raw_func_); - R res = std::apply(fn, args); - assignVal(res, results->data[0]); + R result = function(std::forward(args)...); if (log) { - func_data->vm_->integration()->trace("[vm<-host] " + func_data->name_ + - " return: " + std::to_string(res)); + integration()->trace("[vm<-host] " + function_name + " return: " + printValue(result)); } - return nullptr; - }, - data.get(), nullptr); - - data->vm_ = this; - data->callback_ = std::move(func); - data->raw_func_ = reinterpret_cast(function); - host_functions_.insert_or_assign(std::string(module_name) + "." + std::string(function_name), - std::move(data)); + return ConvertHostToWasmEndianness(result); + }); + if (!result) { + fail(FailState::ConfigureFailed, "Failed to register host function: " + result.err().message()); + } }; template void Wasmtime::getModuleFunctionImpl(std::string_view function_name, std::function *function) { - auto it = module_functions_.find(std::string(function_name)); if (it == module_functions_.end()) { *function = nullptr; return; } - - WasmValtypeVec exp_args; - WasmValtypeVec exp_returns; - convertArgsTupleToValTypes>(exp_args.get()); - convertArgsTupleToValTypes>(exp_returns.get()); - wasm_func_t *func = it->second.get(); - WasmFunctypePtr func_type = wasm_func_type(func); - - if (!equalValTypes(wasm_functype_params(func_type.get()), exp_args.get()) || - !equalValTypes(wasm_functype_results(func_type.get()), exp_returns.get())) { - fail(FailState::UnableToInitializeCode, - "Bad function signature for: " + std::string(function_name) + ", want: " + - printValTypes(exp_args.get()) + " -> " + printValTypes(exp_returns.get()) + - ", but the module exports: " + printValTypes(wasm_functype_params(func_type.get())) + - " -> " + printValTypes(wasm_functype_results(func_type.get()))); + auto typed_func = it->second.typed, std::monostate>(store_->context()); + if (!typed_func) { + *function = nullptr; return; } - - *function = [func, function_name, this](ContextBase *context, Args... args) -> void { + *function = [func = typed_func.ok(), function_name, this](ContextBase *context, + Args... args) -> void { const bool log = cmpLogLevel(LogLevel::trace); SaveRestoreContext saved_context(context); - wasm_val_vec_t results = WASM_EMPTY_VEC; - WasmTrapPtr trap; - - // Workaround for MSVC++ not supporting zero-sized arrays. - if constexpr (sizeof...(args) > 0) { - wasm_val_t params_arr[] = {makeVal(args)...}; - wasm_val_vec_t params = WASM_ARRAY_VEC(params_arr); - if (log) { - integration()->trace("[host->vm] " + std::string(function_name) + "(" + - printValues(¶ms) + ")"); - } - trap.reset(wasm_func_call(func, ¶ms, &results)); - } else { - wasm_val_vec_t params = WASM_EMPTY_VEC; - if (log) { - integration()->trace("[host->vm] " + std::string(function_name) + "()"); - } - trap.reset(wasm_func_call(func, ¶ms, &results)); + if (log) { + integration()->trace("[host->vm] " + std::string(function_name) + "(" + printValues(args...) + + ")"); } - - if (trap) { - WasmByteVec error_message; - wasm_trap_message(trap.get(), error_message.get()); - std::string message(error_message.get()->data); // NULL-terminated + InPlaceConvertHostToWasmEndianness(args...); + TrapResult result = func.call(store_->context(), {args...}); + if (!result) { fail(FailState::RuntimeError, - "Function: " + std::string(function_name) + " failed: " + message); - return; + "Function: " + std::string(function_name) + " failed: " + result.err().message()); } if (log) { integration()->trace("[host<-vm] " + std::string(function_name) + " return: void"); @@ -649,60 +373,32 @@ void Wasmtime::getModuleFunctionImpl(std::string_view function_name, *function = nullptr; return; } - WasmValtypeVec exp_args; - WasmValtypeVec exp_returns; - convertArgsTupleToValTypes>(exp_args.get()); - convertArgsTupleToValTypes>(exp_returns.get()); - wasm_func_t *func = it->second.get(); - WasmFunctypePtr func_type = wasm_func_type(func); - if (!equalValTypes(wasm_functype_params(func_type.get()), exp_args.get()) || - !equalValTypes(wasm_functype_results(func_type.get()), exp_returns.get())) { - fail(FailState::UnableToInitializeCode, - "Bad function signature for: " + std::string(function_name) + ", want: " + - printValTypes(exp_args.get()) + " -> " + printValTypes(exp_returns.get()) + - ", but the module exports: " + printValTypes(wasm_functype_params(func_type.get())) + - " -> " + printValTypes(wasm_functype_results(func_type.get()))); + auto typed_func = it->second.typed, R>(store_->context()); + if (!typed_func) { + *function = nullptr; return; } - - *function = [func, function_name, this](ContextBase *context, Args... args) -> R { + *function = [func = typed_func.ok(), function_name, this](ContextBase *context, + Args... args) -> R { const bool log = cmpLogLevel(LogLevel::trace); SaveRestoreContext saved_context(context); - wasm_val_t results_arr[1]; - wasm_val_vec_t results = WASM_ARRAY_VEC(results_arr); - WasmTrapPtr trap; - - // Workaround for MSVC++ not supporting zero-sized arrays. - if constexpr (sizeof...(args) > 0) { - wasm_val_t params_arr[] = {makeVal(args)...}; - wasm_val_vec_t params = WASM_ARRAY_VEC(params_arr); - if (log) { - integration()->trace("[host->vm] " + std::string(function_name) + "(" + - printValues(¶ms) + ")"); - } - trap.reset(wasm_func_call(func, ¶ms, &results)); - } else { - wasm_val_vec_t params = WASM_EMPTY_VEC; - if (log) { - integration()->trace("[host->vm] " + std::string(function_name) + "()"); - } - trap.reset(wasm_func_call(func, ¶ms, &results)); + if (log) { + integration()->trace("[host->vm] " + std::string(function_name) + "(" + printValues(args...) + + ")"); } - - if (trap) { - WasmByteVec error_message; - wasm_trap_message(trap.get(), error_message.get()); - std::string message(error_message.get()->data); // NULL-terminated + InPlaceConvertHostToWasmEndianness(args...); + TrapResult result_wasm = func.call(store_->context(), {args...}); + if (!result_wasm) { fail(FailState::RuntimeError, - "Function: " + std::string(function_name) + " failed: " + message); + "Function: " + std::string(function_name) + " failed: " + result_wasm.err().message()); return R{}; } - R ret = convertValueTypeToArg(results.data[0]); + R result = ConvertWasmToHostEndianness(result_wasm.ok()); if (log) { integration()->trace("[host<-vm] " + std::string(function_name) + - " return: " + std::to_string(ret)); + " return: " + printValue(result)); } - return ret; + return result; }; }; diff --git a/test/BUILD b/test/BUILD index 97e70558a..38d7810d3 100644 --- a/test/BUILD +++ b/test/BUILD @@ -70,6 +70,7 @@ cc_test( timeout = "long", srcs = ["runtime_test.cc"], data = [ + "//test/test_data:arg_passing.wasm", "//test/test_data:callback.wasm", "//test/test_data:clock.wasm", "//test/test_data:resource_limits.wasm", @@ -84,6 +85,22 @@ cc_test( ], ) +cc_test( + name = "arg_passing_test", + timeout = "long", + srcs = ["arg_passing_test.cc"], + data = [ + "//test/test_data:arg_passing.wasm", + ], + linkstatic = 1, + deps = [ + ":utility_lib", + "//:lib", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "exports_test", srcs = ["exports_test.cc"], diff --git a/test/arg_passing_test.cc b/test/arg_passing_test.cc new file mode 100644 index 000000000..6f4458ec5 --- /dev/null +++ b/test/arg_passing_test.cc @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include + +#include "include/proxy-wasm/context.h" +#include "include/proxy-wasm/wasm.h" + +#include "test/utility.h" + +namespace proxy_wasm { +namespace { + +class ArgPassingContext : public TestContext { +public: + using TestContext::TestContext; + WasmResult getHeaderMapPairs(WasmHeaderMapType /* type */, Pairs * /* result */) override { + return static_cast(3333333333U); + } +}; + +class ArgPassingWasm : public TestWasm { +public: + using TestWasm::TestWasm; + ContextBase *createVmContext() override { return new ArgPassingContext(this); }; +}; + +class ArgPassingTest : public TestVm { +public: + void SetUp() { + auto source = readTestWasmFile("arg_passing.wasm"); + ASSERT_FALSE(source.empty()); + wasm_.emplace(std::move(vm_)); + ASSERT_TRUE(wasm_->load(source, false)); + ASSERT_TRUE(wasm_->initialize()); + context_ = dynamic_cast(wasm_->vm_context()); + ASSERT_NE(context_, nullptr); + } + + std::optional wasm_; + ArgPassingContext *context_; +}; + +INSTANTIATE_TEST_SUITE_P(WasmEngines, ArgPassingTest, testing::ValuesIn(getWasmEngines()), + [](const testing::TestParamInfo &info) { + return info.param; + }); + +TEST_P(ArgPassingTest, WasmCallReturnsWordValue) { + WasmCallWord<0> test_return_u32; + wasm_->wasm_vm()->getFunction("test_return_u32", &test_return_u32); + + EXPECT_EQ(test_return_u32(context_).u32(), 3333333333U) << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsNegativeWordValue) { + WasmCallWord<0> test_return_i32; + wasm_->wasm_vm()->getFunction("test_return_i32", &test_return_i32); + + EXPECT_EQ(test_return_i32(context_).u32(), -1111111111) << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsLongValue) { + WasmCall_lf test_return_u64; + wasm_->wasm_vm()->getFunction("test_return_u64", &test_return_u64); + + EXPECT_EQ(test_return_u64(context_, 1.0), 11111111111111111111UL) << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsFloatValue) { + WasmCall_fff test_return_f32; + wasm_->wasm_vm()->getFunction("test_return_f32", &test_return_f32); + + EXPECT_THAT(test_return_f32(context_, 1.0, 1.0), + testing::AllOf(testing::Lt(1112.0), testing::Gt(1110.0))) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, WasmCallReturnsDoubleValue) { + WasmCall_dfff test_return_f64; + wasm_->wasm_vm()->getFunction("test_return_f64", &test_return_f64); + + EXPECT_THAT(test_return_f64(context_, 1.0, 1.0, 1.0), + testing::AllOf(testing::Lt(1111111112.0), testing::Gt(1111111110.0))) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostCallReturnsWordValue) { + WasmCallWord<0> test_host_return; + wasm_->wasm_vm()->getFunction("test_host_return", &test_host_return); + + EXPECT_TRUE(test_host_return(context_)) << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostPassesPrimitiveValues) { + WasmCall_WWlfd test_primitives; + wasm_->wasm_vm()->getFunction("test_primitives", &test_primitives); + + ASSERT_TRUE(test_primitives(context_, 3333333333U, 11111111111111111111UL, 1111, 1111111111)) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostPassesNegativePrimitiveValues) { + WasmCall_WWlfd test_negative_primitives; + wasm_->wasm_vm()->getFunction("test_negative_primitives", &test_negative_primitives); + + ASSERT_TRUE( + test_negative_primitives(context_, -1111111111, -1111111111111111111, -1111, -1111111111)) + << context_->getLog(); +} + +TEST_P(ArgPassingTest, HostReadsPointersToWasmMemory) { + WasmCallWord<0> test_buffer_from_wasm; + wasm_->wasm_vm()->getFunction("test_buffer_from_wasm", &test_buffer_from_wasm); + + ASSERT_TRUE(test_buffer_from_wasm(context_)) << context_->getLog(); + + context_->isLogged("hello from wasm land!"); +} + +TEST_P(ArgPassingTest, WasmCallReadsBufferPassedByHost) { + context_->setBuffer(0, "hello from host land!"); + WasmCallWord<0> test_buffer_from_host; + wasm_->wasm_vm()->getFunction("test_buffer_from_host", &test_buffer_from_host); + + ASSERT_TRUE(test_buffer_from_host(context_)) << context_->getLog(); +} + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ArgPassingTest); + +} // namespace +} // namespace proxy_wasm diff --git a/test/test_data/BUILD b/test/test_data/BUILD index e5ecd439e..4e27933af 100644 --- a/test/test_data/BUILD +++ b/test/test_data/BUILD @@ -57,6 +57,11 @@ wasm_rust_binary( srcs = ["trap.rs"], ) +wasm_rust_binary( + name = "arg_passing.wasm", + srcs = ["arg_passing.rs"], +) + wasm_rust_binary( name = "resource_limits.wasm", srcs = ["resource_limits.rs"], diff --git a/test/test_data/arg_passing.rs b/test/test_data/arg_passing.rs new file mode 100644 index 000000000..1188b4076 --- /dev/null +++ b/test/test_data/arg_passing.rs @@ -0,0 +1,176 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::{SystemTime, UNIX_EPOCH}; + +extern "C" { + fn proxy_log(level: u32, message_data: *const u8, message_size: usize) -> bool; +} + +fn log(message: &str) { + unsafe { + proxy_log(/*error*/ 4, message.as_bytes().as_ptr(), message.len()); + } +} + +#[no_mangle] +pub extern "C" fn _initialize() { + std::panic::set_hook(Box::new(|panic_info| { + log(&format!( + "panic message: {}", + panic_info.payload_as_str().unwrap_or("") + )); + })); +} + +#[no_mangle] +pub extern "C" fn proxy_abi_version_0_2_0() {} + +#[no_mangle] +pub extern "C" fn proxy_on_memory_allocate(size: usize) -> *mut u8 { + let mut vec: Vec = Vec::with_capacity(size); + unsafe { + vec.set_len(size); + } + let slice = vec.into_boxed_slice(); + Box::into_raw(slice) as *mut u8 +} + +extern "C" { + // Used by test_host_return to assert on values returned from imports from the host. + fn proxy_get_header_map_pairs( + map_type: u32, + return_map_data: *mut *mut u8, + return_map_size: *mut usize, + ) -> u32; +} + +#[no_mangle] +pub extern "C" fn test_return_u32() -> u32 { + return 3333333333; +} + +#[no_mangle] +pub extern "C" fn test_return_i32() -> i32 { + return -1111111111; +} + +#[no_mangle] +pub extern "C" fn test_return_u64(_: f32) -> u64 { + return 11111111111111111111; +} + +#[no_mangle] +pub extern "C" fn test_return_f32(_: f32, _: f32) -> f32 { + return 1111.0f32; +} + +#[no_mangle] +pub extern "C" fn test_return_f64(_: f32, _: f32, _: f32) -> f64 { + return 1111111111.0f64; +} + +#[no_mangle] +pub extern "C" fn test_host_return() -> u32 { + unsafe { + let ret = proxy_get_header_map_pairs(0, std::ptr::null_mut(), std::ptr::null_mut()); + if ret != 3333333333u32 { + panic!("unexpected get_header_map_pairs return value: {}", ret); + } + } + return 1; +} + +#[no_mangle] +pub extern "C" fn test_primitives(uint32: u32, uint64: u64, float32: f32, float64: f64) -> i32 { + if uint32 != 3333333333 { + panic!("unexpected uint32 value: {}", uint32); + } + if uint64 != 11111111111111111111 { + panic!("unexpected uint64 value: {}", uint64); + } + if float32 < 1110.0 || float32 > 1112.0 { + panic!("unexpected float32 value: {}", float32); + } + if float64 < 1111111110.0 || float64 > 1111111112.0 { + panic!("unexpected float64 value: {}", float64); + } + return 1; +} + +#[no_mangle] +pub extern "C" fn test_negative_primitives( + int32: i32, + int64: i64, + float32: f32, + float64: f64, +) -> i32 { + if int32 != -1111111111 { + panic!("unexpected int32 value: {}", int32); + } + if int64 != -1111111111111111111 { + panic!("unexpected int64 value: {}", int64); + } + if float32 > -1110.0 || float32 < -1112.0 { + panic!("unexpected float32 value: {}", float32); + } + if float64 > -1111111110.0 || float64 < -1111111112.0 { + panic!("unexpected float64 value: {}", float64); + } + return 1; +} + +#[no_mangle] +pub extern "C" fn test_buffer_from_wasm() -> bool { + let message = "hello from wasm land!"; + unsafe { + match proxy_log(/*info*/ 2, message.as_ptr(), message.len()) { + false => true, + status => panic!("unexpected status: {}", status as u32), + } + } +} + +extern "C" { + fn proxy_get_buffer_bytes( + buffer_type: u32, + start: usize, + max_size: usize, + return_buffer_data: *mut *mut u8, + return_buffer_size: *mut usize, + ) -> bool; +} + +#[no_mangle] +pub extern "C" fn test_buffer_from_host() -> bool { + let mut return_data: *mut u8 = std::ptr::null_mut(); + let mut return_size: usize = 0; + unsafe { + match proxy_get_buffer_bytes(0, 0, 30, &mut return_data, &mut return_size) { + false => { + if return_data.is_null() { + panic!("return_data was null"); + } + let result = + String::from_utf8(Vec::from_raw_parts(return_data, return_size, return_size)) + .unwrap(); + if result != "hello from host land!\0" { + panic!("message {} did not match expectation", result); + } + true + } + status => panic!("unexpected status: {}", status as u32), + } + } +} diff --git a/test/utility.h b/test/utility.h index 0eb743037..a53ed2e4b 100644 --- a/test/utility.h +++ b/test/utility.h @@ -13,6 +13,7 @@ // limitations under the License. #include "gtest/gtest.h" +#include #include #include #include @@ -101,6 +102,8 @@ class TestContext : public ContextBase { return WasmResult::Ok; } + std::string_view getLog() const { return log_; } + WasmResult getProperty(std::string_view path, std::string *result) override { if (path == "plugin_root_id") { *result = root_id_; @@ -109,6 +112,20 @@ class TestContext : public ContextBase { return unimplemented(); } + void setBuffer(int32_t buffer_type, std::string buffer) { + auto [it, inserted] = buffers_.emplace(buffer_type, std::make_unique()); + std::unique_ptr arr = std::make_unique(buffer.size() + 1); + strcpy(arr.get(), buffer.c_str()); + it->second->set(std::move(arr), buffer.size() + 1); + } + + BufferInterface *getBuffer(WasmBufferType type) override { + if (auto it = buffers_.find(static_cast(type)); it != buffers_.end()) { + return it->second.get(); + } + return nullptr; + } + bool isLogEmpty() { return log_.empty(); } bool isLogged(std::string_view message) { return log_.find(message) != std::string::npos; } @@ -133,6 +150,7 @@ class TestContext : public ContextBase { void set_allow_on_headers_stop_iteration(bool allow) { allow_on_headers_stop_iteration_ = allow; } private: + std::unordered_map> buffers_; std::string log_; static std::string global_log_; }; diff --git a/test/wasm_vm_test.cc b/test/wasm_vm_test.cc index 346fe2a07..c4ed7e3ef 100644 --- a/test/wasm_vm_test.cc +++ b/test/wasm_vm_test.cc @@ -85,7 +85,7 @@ TEST_P(TestVm, Memory) { ASSERT_EQ(100, word.u64_); uint32_t data[2] = {htowasm(static_cast(-1), vm_->usesWasmByteOrder()), - htowasm(200, vm_->usesWasmByteOrder())}; + htowasm(static_cast(200), vm_->usesWasmByteOrder())}; ASSERT_TRUE(vm_->setMemory(0x200, sizeof(int32_t) * 2, static_cast(data))); ASSERT_TRUE(vm_->getWord(0x200, &word)); ASSERT_EQ(-1, static_cast(word.u64_));