27 changes: 9 additions & 18 deletions backends/aoti/CMakeLists.txt
@@ -25,34 +25,25 @@ endif()
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# Common AOTI functionality - combines all AOTI common components
set(_aoti_common_sources common_shims.cpp)
add_library(aoti_common STATIC ${_aoti_common_sources})
# Common AOTI functionality - header-only library for common shims
add_library(aoti_common INTERFACE)
target_include_directories(
aoti_common
PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
INTERFACE $<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
)
target_compile_options(
aoti_common
PUBLIC $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
INTERFACE $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
)
target_compile_definitions(
aoti_common PRIVATE $<$<PLATFORM_ID:Windows>:EXPORT_AOTI_FUNCTIONS>
aoti_common INTERFACE $<$<PLATFORM_ID:Windows>:EXPORT_AOTI_FUNCTIONS>
)
# Ensure symbols are exported properly
if(APPLE)
target_link_options(aoti_common PUBLIC -Wl,-export_dynamic)
else()
target_link_options(
aoti_common PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
)
endif()

# Link against ExecuTorch libraries and standard libraries
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(aoti_common)
target_link_libraries(aoti_common INTERFACE extension_tensor ${CMAKE_DL_LIBS})

install(
TARGETS aoti_common
6 changes: 6 additions & 0 deletions backends/aoti/aoti_delegate_handle.h
@@ -11,6 +11,11 @@
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <string>
#include <vector>

#ifdef CUDA_AVAILABLE
#include <executorch/backends/aoti/slim/core/slim_tensor.h>
#endif

namespace executorch {
namespace backends {
@@ -95,6 +100,7 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
AOTInductorModelContainerRunFunc run;
AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;

};

} // namespace aoti
99 changes: 97 additions & 2 deletions backends/aoti/common_shims_slim.cpp
@@ -147,13 +147,11 @@ int32_t aoti_torch_device_type_cuda() {
// ============================================================

bool aoti_torch_grad_mode_is_enabled() {
// ExecuTorch doesn't support autograd
return false;
}

AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) {
if (enabled) {
// ExecuTorch doesn't support autograd
return Error::NotSupported;
}
return Error::Ok;
@@ -162,3 +160,100 @@ AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) {
} // namespace aoti
} // namespace backends
} // namespace executorch

// ============================================================
// extern "C" wrappers for dynamic symbol lookup (dlsym)
// AOTI-generated .so files need these symbols with C linkage
// ============================================================

using Tensor = executorch::backends::aoti::Tensor;
using Error = executorch::runtime::Error;

extern "C" {

Error aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
return executorch::backends::aoti::aoti_torch_get_data_ptr(
tensor, ret_data_ptr);
}

Error aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
return executorch::backends::aoti::aoti_torch_get_sizes(tensor, ret_sizes);
}

Error aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
return executorch::backends::aoti::aoti_torch_get_strides(tensor, ret_strides);
}

Error aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
return executorch::backends::aoti::aoti_torch_get_dtype(tensor, ret_dtype);
}

Error aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
return executorch::backends::aoti::aoti_torch_get_dim(tensor, ret_dim);
}

Error aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset) {
return executorch::backends::aoti::aoti_torch_get_storage_offset(
tensor, ret_storage_offset);
}

Error aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
return executorch::backends::aoti::aoti_torch_get_storage_size(
tensor, ret_size);
}

Error aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type) {
return executorch::backends::aoti::aoti_torch_get_device_type(
tensor, ret_device_type);
}

Error aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index) {
return executorch::backends::aoti::aoti_torch_get_device_index(
tensor, ret_device_index);
}

int32_t aoti_torch_dtype_float32() {
return executorch::backends::aoti::aoti_torch_dtype_float32();
}

int32_t aoti_torch_dtype_bfloat16() {
return executorch::backends::aoti::aoti_torch_dtype_bfloat16();
}

int32_t aoti_torch_dtype_int64() {
return executorch::backends::aoti::aoti_torch_dtype_int64();
}

int32_t aoti_torch_dtype_int32() {
return executorch::backends::aoti::aoti_torch_dtype_int32();
}

int32_t aoti_torch_dtype_int16() {
return executorch::backends::aoti::aoti_torch_dtype_int16();
}

int32_t aoti_torch_dtype_int8() {
return executorch::backends::aoti::aoti_torch_dtype_int8();
}

int32_t aoti_torch_dtype_bool() {
return executorch::backends::aoti::aoti_torch_dtype_bool();
}

int32_t aoti_torch_device_type_cpu() {
return executorch::backends::aoti::aoti_torch_device_type_cpu();
}

int32_t aoti_torch_device_type_cuda() {
return executorch::backends::aoti::aoti_torch_device_type_cuda();
}

bool aoti_torch_grad_mode_is_enabled() {
return executorch::backends::aoti::aoti_torch_grad_mode_is_enabled();
}

Error aoti_torch_grad_mode_set_enabled(bool enabled) {
return executorch::backends::aoti::aoti_torch_grad_mode_set_enabled(enabled);
}

} // extern "C"
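
For context on the extern "C" block above: AOTI-generated shared objects resolve these shim symbols by name at load time, so the names must stay unmangled. The sketch below is illustrative only and not part of this diff; the library name libaoti_cuda_shims.so and the error handling are assumptions based on the aoti_cuda_shims CMake target.

```cpp
// Illustrative sketch (not from the PR): resolving one of the C-linkage shims
// at runtime via the POSIX dlopen/dlsym/dlclose APIs.
// Build with: g++ lookup.cpp -ldl
#include <dlfcn.h>
#include <cstdint>
#include <cstdio>

int main() {
  // RTLD_GLOBAL makes the shim symbols visible to any AOTI-generated .so
  // loaded afterwards; C linkage keeps the names unmangled for dlsym.
  void* handle = dlopen("libaoti_cuda_shims.so", RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }

  using DtypeFn = int32_t (*)();
  auto dtype_fn =
      reinterpret_cast<DtypeFn>(dlsym(handle, "aoti_torch_dtype_float32"));
  if (dtype_fn == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    dlclose(handle);
    return 1;
  }

  std::printf("aoti_torch_dtype_float32() = %d\n", dtype_fn());
  dlclose(handle);
  return 0;
}
```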
2 changes: 1 addition & 1 deletion backends/aoti/common_shims_slim.h
@@ -9,7 +9,7 @@
#pragma once

#include <executorch/backends/aoti/export.h>
#include <executorch/backends/aoti/slim/core/SlimTensor.h>
#include <executorch/backends/aoti/slim/core/slim_tensor.h>
#include <executorch/runtime/core/error.h>
#include <cstdint>

16 changes: 6 additions & 10 deletions backends/aoti/targets.bzl
@@ -33,26 +33,22 @@ def define_common_targets():
],
)

# AOTI common shims functionality
# AOTI common shims functionality (header-only library)
# The caller determines which tensor type is used by defining CUDA_AVAILABLE.
# - With CUDA_AVAILABLE=1: Uses SlimTensor
# - Without CUDA_AVAILABLE: Uses ETensor
runtime.cxx_library(
name = "common_shims",
srcs = [
"common_shims.cpp",
],
headers = [
"common_shims.h",
"export.h",
"utils.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
# Constructor needed for backend registration.
compiler_flags = ["-Wno-global-constructors"],
visibility = ["PUBLIC"],
deps = [
exported_deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/backends/aoti/slim/core:slimtensor",
],
)

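
A note on the CUDA_AVAILABLE comment in the targets.bzl hunk above: the idea is that a single shim header binds its tensor type to SlimTensor when the consumer defines CUDA_AVAILABLE and to the default ETensor otherwise. The sketch below is a hypothetical illustration of that pattern; the real common_shims.h is not shown in this diff, and the SlimTensor type name and namespace are assumed.

```cpp
// Hypothetical sketch of the CUDA_AVAILABLE switch described above; the
// actual common_shims.h may differ. One header, two tensor backings.
#pragma once

#ifdef CUDA_AVAILABLE
// SlimTensor-backed shims (CUDA builds). Include path taken from the diff;
// the SlimTensor type name and namespace are assumptions.
#include <executorch/backends/aoti/slim/core/slim_tensor.h>
namespace executorch::backends::aoti {
using Tensor = SlimTensor; // assumed name
} // namespace executorch::backends::aoti
#else
// ETensor-backed shims (default builds).
#include <executorch/runtime/core/exec_aten/exec_aten.h>
namespace executorch::backends::aoti {
using Tensor = executorch::aten::Tensor;
} // namespace executorch::backends::aoti
#endif
```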
19 changes: 0 additions & 19 deletions backends/aoti/tests/TARGETS
@@ -8,27 +8,8 @@ cpp_unittest(
srcs = [
"test_common_shims.cpp",
],
headers = [
"utils.h",
],
deps = [
"//executorch/backends/aoti:common_shims",
"//executorch/extension/tensor:tensor",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/extension/tensor:tensor",
],
)

cpp_unittest(
name = "test_common_shims_slim",
srcs = [
"test_common_shims_slim.cpp",
],
deps = [
"//executorch/backends/aoti:common_shims_slim",
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:empty",
"//executorch/runtime/core:core",
74 changes: 0 additions & 74 deletions backends/aoti/tests/utils.h

This file was deleted.

8 changes: 6 additions & 2 deletions backends/cuda/CMakeLists.txt
@@ -98,14 +98,18 @@ install(
)

# CUDA-specific AOTI shim symbols (dynamically linked)
# Uses common_shims_slim.cpp for SlimTensor-based shim implementations
set(_aoti_cuda_shim_sources
runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp
${EXECUTORCH_ROOT}/backends/aoti/common_shims_slim.cpp
runtime/shims/memory.cpp
runtime/guard.cpp runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
${EXECUTORCH_ROOT}/backends/aoti/common_shims.cpp
)

add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources})

# Define CUDA_AVAILABLE to use SlimTensor in common_shims.h
target_compile_definitions(aoti_cuda_shims PRIVATE CUDA_AVAILABLE=1)

# Define export macros for shared library
if(MSVC)
target_compile_definitions(aoti_cuda_shims PRIVATE EXPORT_AOTI_FUNCTIONS)
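
Regarding the EXPORT_AOTI_FUNCTIONS definition above (and the matching $<PLATFORM_ID:Windows> definition in the aoti/CMakeLists.txt hunk): this kind of macro normally switches a header between dllexport and dllimport on Windows. The actual backends/aoti/export.h is not part of this diff; the sketch below, including the AOTI_SHIM_EXPORT name, is only an assumed illustration of the usual pattern.

```cpp
// Hypothetical export-macro pattern driven by EXPORT_AOTI_FUNCTIONS; the real
// executorch/backends/aoti/export.h may differ.
#pragma once

#if defined(_WIN32)
  #if defined(EXPORT_AOTI_FUNCTIONS)
    // Compiling the shim DLL itself: export the shim symbols.
    #define AOTI_SHIM_EXPORT __declspec(dllexport)
  #else
    // Compiling a consumer of the shim DLL: import the symbols.
    #define AOTI_SHIM_EXPORT __declspec(dllimport)
  #endif
#else
  // ELF/Mach-O builds: keep the symbols visible for dlsym.
  #define AOTI_SHIM_EXPORT __attribute__((visibility("default")))
#endif
```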