Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,12 @@ BaseSemanticSegmentation::computeResult(
}
}

// Filter classes of interest
auto buffersToReturn = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<OwningArrayBuffer>>>();
bool returnAllClasses = classesOfInterest.empty();
for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) {
if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) {
if (cl < allClasses.size() &&
(returnAllClasses || classesOfInterest.contains(allClasses[cl]))) {
(*buffersToReturn)[allClasses[cl]] = resultClasses[cl];
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,24 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp
LIBS tokenizers_deps
)

add_rn_test(SemanticSegmentationTests integration/SemanticSegmentationTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/semantic_segmentation/BaseSemanticSegmentation.cpp
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp
${IMAGE_UTILS_SOURCES}
LIBS opencv_deps android
)

add_rn_test(InstanceSegmentationTests integration/InstanceSegmentationTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/instance_segmentation/BaseInstanceSegmentation.cpp
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp
${RNEXECUTORCH_DIR}/utils/computer_vision/Processing.cpp
${IMAGE_UTILS_SOURCES}
LIBS opencv_deps android
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include <algorithm>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <gtest/gtest.h>
#include <rnexecutorch/Error.h>
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
#include <rnexecutorch/models/semantic_segmentation/Constants.h>
#include <rnexecutorch/models/semantic_segmentation/SemanticSegmentation.h>
#include <rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h>
#include <string>
#include <vector>

Expand All @@ -19,6 +19,18 @@ constexpr auto kValidSemanticSegmentationModelPath =
constexpr auto kValidTestImagePath =
"file:///data/local/tmp/rnexecutorch_tests/test_image.jpg";

// DeepLab V3 class labels (Pascal VOC)
static const std::vector<std::string> kDeeplabV3Labels = {
"BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
"BOTTLE", "BUS", "CAR", "CAT", "CHAIR",
"COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE",
"PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN",
"TVMONITOR"};

// ImageNet normalization constants
static const std::vector<float> kImageNetMean = {0.485f, 0.456f, 0.406f};
static const std::vector<float> kImageNetStd = {0.229f, 0.224f, 0.225f};

static JSTensorViewIn makeRgbView(std::vector<uint8_t> &buf, int32_t h,
int32_t w) {
buf.assign(static_cast<size_t>(h * w * 3), 128);
Expand All @@ -30,8 +42,9 @@ static JSTensorViewIn makeRgbView(std::vector<uint8_t> &buf, int32_t h,
class SemanticSegmentationForwardTest : public ::testing::Test {
protected:
void SetUp() override {
model = std::make_unique<SemanticSegmentation>(
kValidSemanticSegmentationModelPath, nullptr);
model = std::make_unique<BaseSemanticSegmentation>(
kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd,
kDeeplabV3Labels, nullptr);
auto shapes = model->getAllInputShapes("forward");
ASSERT_FALSE(shapes.empty());
shape = shapes[0];
Expand All @@ -47,21 +60,24 @@ class SemanticSegmentationForwardTest : public ::testing::Test {
make_tensor_ptr(sizes, dummyData.data(), exec_aten::ScalarType::Float);
}

std::unique_ptr<SemanticSegmentation> model;
std::unique_ptr<BaseSemanticSegmentation> model;
std::vector<int32_t> shape;
std::vector<float> dummyData;
std::vector<int32_t> sizes;
TensorPtr inputTensor;
};

TEST(SemanticSegmentationCtorTests, InvalidPathThrows) {
EXPECT_THROW(SemanticSegmentation("this_file_does_not_exist.pte", nullptr),
EXPECT_THROW(BaseSemanticSegmentation("this_file_does_not_exist.pte",
kImageNetMean, kImageNetStd,
kDeeplabV3Labels, nullptr),
RnExecutorchError);
}

TEST(SemanticSegmentationCtorTests, ValidPathDoesntThrow) {
EXPECT_NO_THROW(
SemanticSegmentation(kValidSemanticSegmentationModelPath, nullptr));
EXPECT_NO_THROW(BaseSemanticSegmentation(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd,
kDeeplabV3Labels, nullptr));
}

TEST_F(SemanticSegmentationForwardTest, ForwardWithValidTensorSucceeds) {
Expand Down Expand Up @@ -108,40 +124,52 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) {
// generateFromString tests
// ============================================================================
TEST(SemanticSegmentationGenerateTests, InvalidImagePathThrows) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
EXPECT_THROW(
(void)model.generateFromString("nonexistent_image.jpg", {}, true),
RnExecutorchError);
}

TEST(SemanticSegmentationGenerateTests, EmptyImagePathThrows) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
EXPECT_THROW((void)model.generateFromString("", {}, true), RnExecutorchError);
}

TEST(SemanticSegmentationGenerateTests, MalformedURIThrows) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
EXPECT_THROW(
(void)model.generateFromString("not_a_valid_uri://bad", {}, true),
RnExecutorchError);
}

TEST(SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto result = model.generateFromString(kValidTestImagePath, {}, true);
EXPECT_NE(result.argmax, nullptr);
EXPECT_NE(result.classBuffers, nullptr);
}

TEST(SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto result = model.generateFromString(kValidTestImagePath, {}, true);
ASSERT_NE(result.classBuffers, nullptr);
EXPECT_EQ(result.classBuffers->size(), 21u);
}

TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
std::set<std::string, std::less<>> filter = {"PERSON", "CAT"};
auto result = model.generateFromString(kValidTestImagePath, filter, true);
ASSERT_NE(result.classBuffers, nullptr);
Expand All @@ -152,7 +180,9 @@ TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) {
}

TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto result = model.generateFromString(kValidTestImagePath, {}, false);
EXPECT_NE(result.argmax, nullptr);
}
Expand All @@ -161,7 +191,9 @@ TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) {
// generateFromPixels tests
// ============================================================================
TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
std::vector<uint8_t> buf;
auto view = makeRgbView(buf, 64, 64);
auto result = model.generateFromPixels(view, {}, true);
Expand All @@ -170,7 +202,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) {
}

TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
std::vector<uint8_t> buf;
auto view = makeRgbView(buf, 64, 64);
auto result = model.generateFromPixels(view, {}, true);
Expand All @@ -179,7 +213,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) {
}

TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
std::vector<uint8_t> buf;
auto view = makeRgbView(buf, 64, 64);
std::set<std::string, std::less<>> filter = {"PERSON"};
Expand All @@ -194,32 +230,42 @@ TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) {
// Inherited BaseModel tests
// ============================================================================
TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto shape = model.getInputShape("forward", 0);
EXPECT_EQ(shape.size(), 4);
EXPECT_EQ(shape[0], 1); // Batch size
EXPECT_EQ(shape[1], 3); // RGB channels
}

TEST(SemanticSegmentationInheritedTests, GetAllInputShapesWorks) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto shapes = model.getAllInputShapes("forward");
EXPECT_FALSE(shapes.empty());
}

TEST(SemanticSegmentationInheritedTests, GetMethodMetaWorks) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto result = model.getMethodMeta("forward");
EXPECT_TRUE(result.ok());
}

TEST(SemanticSegmentationInheritedTests, GetMemoryLowerBoundReturnsPositive) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
EXPECT_GT(model.getMemoryLowerBound(), 0u);
}

TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) {
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
nullptr);
auto shape = model.getInputShape("forward", 0);
EXPECT_EQ(shape[2], shape[3]); // Height == Width for DeepLabV3
}
Expand All @@ -228,29 +274,18 @@ TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) {
// Constants tests
// ============================================================================
TEST(SemanticSegmentationConstantsTests, ClassLabelsHas21Entries) {
EXPECT_EQ(constants::kDeeplabV3Resnet50Labels.size(), 21u);
EXPECT_EQ(kDeeplabV3Labels.size(), 21u);
}

TEST(SemanticSegmentationConstantsTests, ClassLabelsContainExpectedClasses) {
auto &labels = constants::kDeeplabV3Resnet50Labels;
bool hasBackground = false;
bool hasPerson = false;
bool hasCat = false;
bool hasDog = false;

for (const auto &label : labels) {
if (label == "BACKGROUND")
hasBackground = true;
if (label == "PERSON")
hasPerson = true;
if (label == "CAT")
hasCat = true;
if (label == "DOG")
hasDog = true;
}
const auto &labels = kDeeplabV3Labels;

auto contains = [&labels](const std::string &target) {
return std::ranges::find(labels, target) != labels.end();
};

EXPECT_TRUE(hasBackground);
EXPECT_TRUE(hasPerson);
EXPECT_TRUE(hasCat);
EXPECT_TRUE(hasDog);
EXPECT_TRUE(contains("BACKGROUND"));
EXPECT_TRUE(contains("PERSON"));
EXPECT_TRUE(contains("CAT"));
EXPECT_TRUE(contains("DOG"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ TEST_EXECUTABLES=(
"LLMTests"
"TextToImageTests"
"InstanceSegmentationTests"
"SemanticSegmentationTests"
"OCRTests"
"VerticalOCRTests"
)
Expand Down
Loading