diff --git a/backends/aoti/common_shims_slim.cpp b/backends/aoti/common_shims_slim.cpp
index b004c1d16a6..1976ee94a3d 100644
--- a/backends/aoti/common_shims_slim.cpp
+++ b/backends/aoti/common_shims_slim.cpp
@@ -98,6 +98,67 @@ AOTITorchError aoti_torch_get_device_index(
   return Error::Ok;
 }
 
+// ============================================================
+// DType Constants - Implementations
+// ============================================================
+
+int32_t aoti_torch_dtype_float32() {
+  return 6; // ScalarType::Float
+}
+
+int32_t aoti_torch_dtype_bfloat16() {
+  return 15; // ScalarType::BFloat16
+}
+
+int32_t aoti_torch_dtype_int64() {
+  return 4; // ScalarType::Long
+}
+
+int32_t aoti_torch_dtype_int32() {
+  return 3; // ScalarType::Int
+}
+
+int32_t aoti_torch_dtype_int16() {
+  return 2; // ScalarType::Short
+}
+
+int32_t aoti_torch_dtype_int8() {
+  return 1; // ScalarType::Char
+}
+
+int32_t aoti_torch_dtype_bool() {
+  return 11; // ScalarType::Bool
+}
+
+// ============================================================
+// Device Type Constants - Implementations
+// ============================================================
+
+int32_t aoti_torch_device_type_cpu() {
+  return 0; // DeviceType::CPU
+}
+
+int32_t aoti_torch_device_type_cuda() {
+  return 1; // DeviceType::CUDA
+}
+
+// ============================================================
+// Grad Mode Functions - Implementations
+// ============================================================
+
+bool aoti_torch_grad_mode_is_enabled() {
+  // ExecuTorch doesn't support autograd
+  return false;
+}
+
+AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) {
+  if (enabled) {
+    // ExecuTorch doesn't support autograd
+    return Error::NotSupported;
+  }
+  return Error::Ok;
+}
+
 } // namespace aoti
 } // namespace backends
 } // namespace executorch
diff --git a/backends/aoti/common_shims_slim.h b/backends/aoti/common_shims_slim.h
index 26022a76f14..b75b8d784e5 100644
--- a/backends/aoti/common_shims_slim.h
+++ b/backends/aoti/common_shims_slim.h
@@ -62,6 +62,32 @@ aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type);
 AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);
 
+// ============================================================
+// DType Constants - Declarations
+// ============================================================
+
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
+
+// ============================================================
+// Device Type Constants - Declarations
+// ============================================================
+
+AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
+AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cuda();
+
+// ============================================================
+// Grad Mode Functions - Declarations
+// ============================================================
+
+AOTI_SHIM_EXPORT bool aoti_torch_grad_mode_is_enabled();
+AOTI_SHIM_EXPORT AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled);
+
 } // namespace aoti
 } // namespace backends
 } // namespace executorch
diff --git a/backends/aoti/tests/test_common_shims_slim.cpp b/backends/aoti/tests/test_common_shims_slim.cpp
index 728bcc6a34f..ca744565955 100644
--- a/backends/aoti/tests/test_common_shims_slim.cpp
+++ b/backends/aoti/tests/test_common_shims_slim.cpp
@@ -589,3 +589,44 @@ TEST_F(CommonShimsSlimTest, ConsistentPointerReturn) {
   EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
   EXPECT_EQ(strides_ptr1, strides_ptr2);
 }
+
+// ============================================================================
+// DType Constants Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, DTypeConstants) {
+  // Verify dtype constants match expected PyTorch ScalarType values
+  EXPECT_EQ(aoti_torch_dtype_float32(), 6); // ScalarType::Float
+  EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // ScalarType::BFloat16
+  EXPECT_EQ(aoti_torch_dtype_int64(), 4); // ScalarType::Long
+  EXPECT_EQ(aoti_torch_dtype_int32(), 3); // ScalarType::Int
+  EXPECT_EQ(aoti_torch_dtype_int16(), 2); // ScalarType::Short
+  EXPECT_EQ(aoti_torch_dtype_int8(), 1); // ScalarType::Char
+  EXPECT_EQ(aoti_torch_dtype_bool(), 11); // ScalarType::Bool
+}
+
+// ============================================================================
+// Device Type Constants Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, DeviceTypeConstants) {
+  EXPECT_EQ(aoti_torch_device_type_cpu(), 0); // DeviceType::CPU
+  EXPECT_EQ(aoti_torch_device_type_cuda(), 1); // DeviceType::CUDA
+}
+
+// ============================================================================
+// Grad Mode Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, GradModeIsEnabled) {
+  // ExecuTorch doesn't support autograd, so should always return false
+  EXPECT_EQ(aoti_torch_grad_mode_is_enabled(), false);
+}
+
+TEST_F(CommonShimsSlimTest, GradModeSetEnabled) {
+  // Setting to false should succeed
+  EXPECT_EQ(aoti_torch_grad_mode_set_enabled(false), Error::Ok);
+
+  // Setting to true should fail (not supported in ExecuTorch)
+  EXPECT_EQ(aoti_torch_grad_mode_set_enabled(true), Error::NotSupported);
+}
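
Not part of the patch: a minimal consumer-side sketch of how the dtype codes returned by these shims could be branched on by host code, since they mirror PyTorch's ScalarType numbering. It assumes the include path `executorch/backends/aoti/common_shims_slim.h` and the `executorch::backends::aoti` namespace visible in the diff; `element_size_for_dtype` is a hypothetical helper, not an existing API.

```cpp
// Illustrative sketch only; not part of this change. Maps the dtype codes
// returned by the new shims to per-element byte sizes, e.g. for sizing a
// buffer before handing it to an AOTI-compiled graph.
#include <cstddef>
#include <cstdint>

#include <executorch/backends/aoti/common_shims_slim.h> // assumed include path

namespace {

// Hypothetical helper: translate a ScalarType-style dtype code into the
// number of bytes each element occupies.
size_t element_size_for_dtype(int32_t dtype) {
  using namespace executorch::backends::aoti;
  if (dtype == aoti_torch_dtype_int64()) {
    return 8;
  }
  if (dtype == aoti_torch_dtype_float32() || dtype == aoti_torch_dtype_int32()) {
    return 4;
  }
  if (dtype == aoti_torch_dtype_bfloat16() || dtype == aoti_torch_dtype_int16()) {
    return 2;
  }
  // int8 and bool are both single-byte types.
  return 1;
}

} // namespace
```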