diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index ea4b0653a..66458cb56 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -218,11 +218,12 @@ BaseSemanticSegmentation::computeResult( } } - // Filter classes of interest auto buffersToReturn = std::make_shared< std::unordered_map>>(); + bool returnAllClasses = classesOfInterest.empty(); for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { - if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) { + if (cl < allClasses.size() && + (returnAllClasses || classesOfInterest.contains(allClasses[cl]))) { (*buffersToReturn)[allClasses[cl]] = resultClasses[cl]; } } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index f6fe386a7..d68ab3350 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -300,12 +300,24 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp LIBS tokenizers_deps ) +add_rn_test(SemanticSegmentationTests integration/SemanticSegmentationTest.cpp + SOURCES + ${RNEXECUTORCH_DIR}/models/semantic_segmentation/BaseSemanticSegmentation.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp + ${IMAGE_UTILS_SOURCES} + LIBS opencv_deps android +) + add_rn_test(InstanceSegmentationTests integration/InstanceSegmentationTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/instance_segmentation/BaseInstanceSegmentation.cpp ${RNEXECUTORCH_DIR}/models/VisionModel.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp ${RNEXECUTORCH_DIR}/utils/computer_vision/Processing.cpp ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp index 957421f09..09c5d42f6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp @@ -1,10 +1,10 @@ +#include #include #include #include #include #include -#include -#include +#include #include #include @@ -19,6 +19,18 @@ constexpr auto kValidSemanticSegmentationModelPath = constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +// DeepLab V3 class labels (Pascal VOC) +static const std::vector kDeeplabV3Labels = { + "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", + "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", + "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", + "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", + "TVMONITOR"}; + +// ImageNet normalization constants +static const std::vector kImageNetMean = {0.485f, 0.456f, 0.406f}; +static const std::vector kImageNetStd = {0.229f, 0.224f, 0.225f}; + static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, int32_t w) { buf.assign(static_cast(h * w * 3), 128); @@ -30,8 +42,9 @@ static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, class SemanticSegmentationForwardTest : public ::testing::Test { protected: void SetUp() override { - model = std::make_unique( - kValidSemanticSegmentationModelPath, nullptr); + model = std::make_unique( + kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto shapes = model->getAllInputShapes("forward"); ASSERT_FALSE(shapes.empty()); shape = shapes[0]; @@ -47,7 +60,7 @@ class SemanticSegmentationForwardTest : public ::testing::Test { make_tensor_ptr(sizes, dummyData.data(), exec_aten::ScalarType::Float); } - std::unique_ptr model; + std::unique_ptr model; std::vector shape; std::vector dummyData; std::vector sizes; @@ -55,13 +68,16 @@ class SemanticSegmentationForwardTest : public ::testing::Test { }; TEST(SemanticSegmentationCtorTests, InvalidPathThrows) { - EXPECT_THROW(SemanticSegmentation("this_file_does_not_exist.pte", nullptr), + EXPECT_THROW(BaseSemanticSegmentation("this_file_does_not_exist.pte", + kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr), RnExecutorchError); } TEST(SemanticSegmentationCtorTests, ValidPathDoesntThrow) { - EXPECT_NO_THROW( - SemanticSegmentation(kValidSemanticSegmentationModelPath, nullptr)); + EXPECT_NO_THROW(BaseSemanticSegmentation(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr)); } TEST_F(SemanticSegmentationForwardTest, ForwardWithValidTensorSucceeds) { @@ -108,40 +124,52 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) { // generateFromString tests // ============================================================================ TEST(SemanticSegmentationGenerateTests, InvalidImagePathThrows) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); EXPECT_THROW( (void)model.generateFromString("nonexistent_image.jpg", {}, true), RnExecutorchError); } TEST(SemanticSegmentationGenerateTests, EmptyImagePathThrows) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); EXPECT_THROW((void)model.generateFromString("", {}, true), RnExecutorchError); } TEST(SemanticSegmentationGenerateTests, MalformedURIThrows) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); EXPECT_THROW( (void)model.generateFromString("not_a_valid_uri://bad", {}, true), RnExecutorchError); } TEST(SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto result = model.generateFromString(kValidTestImagePath, {}, true); EXPECT_NE(result.argmax, nullptr); EXPECT_NE(result.classBuffers, nullptr); } TEST(SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto result = model.generateFromString(kValidTestImagePath, {}, true); ASSERT_NE(result.classBuffers, nullptr); EXPECT_EQ(result.classBuffers->size(), 21u); } TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); std::set> filter = {"PERSON", "CAT"}; auto result = model.generateFromString(kValidTestImagePath, filter, true); ASSERT_NE(result.classBuffers, nullptr); @@ -152,7 +180,9 @@ TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) { } TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto result = model.generateFromString(kValidTestImagePath, {}, false); EXPECT_NE(result.argmax, nullptr); } @@ -161,7 +191,9 @@ TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) { // generateFromPixels tests // ============================================================================ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); auto result = model.generateFromPixels(view, {}, true); @@ -170,7 +202,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) { } TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); auto result = model.generateFromPixels(view, {}, true); @@ -179,7 +213,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) { } TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); std::set> filter = {"PERSON"}; @@ -194,7 +230,9 @@ TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { // Inherited BaseModel tests // ============================================================================ TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto shape = model.getInputShape("forward", 0); EXPECT_EQ(shape.size(), 4); EXPECT_EQ(shape[0], 1); // Batch size @@ -202,24 +240,32 @@ TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) { } TEST(SemanticSegmentationInheritedTests, GetAllInputShapesWorks) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto shapes = model.getAllInputShapes("forward"); EXPECT_FALSE(shapes.empty()); } TEST(SemanticSegmentationInheritedTests, GetMethodMetaWorks) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } TEST(SemanticSegmentationInheritedTests, GetMemoryLowerBoundReturnsPositive) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); EXPECT_GT(model.getMemoryLowerBound(), 0u); } TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) { - SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, kDeeplabV3Labels, + nullptr); auto shape = model.getInputShape("forward", 0); EXPECT_EQ(shape[2], shape[3]); // Height == Width for DeepLabV3 } @@ -228,29 +274,18 @@ TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) { // Constants tests // ============================================================================ TEST(SemanticSegmentationConstantsTests, ClassLabelsHas21Entries) { - EXPECT_EQ(constants::kDeeplabV3Resnet50Labels.size(), 21u); + EXPECT_EQ(kDeeplabV3Labels.size(), 21u); } TEST(SemanticSegmentationConstantsTests, ClassLabelsContainExpectedClasses) { - auto &labels = constants::kDeeplabV3Resnet50Labels; - bool hasBackground = false; - bool hasPerson = false; - bool hasCat = false; - bool hasDog = false; - - for (const auto &label : labels) { - if (label == "BACKGROUND") - hasBackground = true; - if (label == "PERSON") - hasPerson = true; - if (label == "CAT") - hasCat = true; - if (label == "DOG") - hasDog = true; - } + const auto &labels = kDeeplabV3Labels; + + auto contains = [&labels](const std::string &target) { + return std::ranges::find(labels, target) != labels.end(); + }; - EXPECT_TRUE(hasBackground); - EXPECT_TRUE(hasPerson); - EXPECT_TRUE(hasCat); - EXPECT_TRUE(hasDog); + EXPECT_TRUE(contains("BACKGROUND")); + EXPECT_TRUE(contains("PERSON")); + EXPECT_TRUE(contains("CAT")); + EXPECT_TRUE(contains("DOG")); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh index 843aa1f4f..468f2c965 100755 --- a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh +++ b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh @@ -34,6 +34,7 @@ TEST_EXECUTABLES=( "LLMTests" "TextToImageTests" "InstanceSegmentationTests" + "SemanticSegmentationTests" "OCRTests" "VerticalOCRTests" )