feat!: make image segmentation generic, general refactor #814

Open
chmjkb wants to merge 24 commits into main from @chmjkb/image-segmentation-generic

Collaborator

@chmjkb chmjkb commented Feb 17, 2026

Description

Refactors image segmentation into a generic, multi-model architecture. Previously the module was hardcoded to DeepLab V3 — now it supports multiple built-in models (DeepLab V3, selfie segmentation, RF-DETR) and custom user-provided models with type-safe label maps.

Key changes:

  • C++ base class: Extracted BaseImageSegmentation with virtual preprocess()/postprocess() methods. ImageSegmentation is now a thin subclass. This allows future models to override preprocessing (e.g. different normalization) or postprocessing without duplicating the pipeline.
  • Optional normalization in C++: readImageToTensor now accepts optional normMean/normStd params, eliminating duplicated normalization logic.
  • Generic TypeScript module: ImageSegmentationModule<T> is generic over model name or custom LabelEnum. Two static factories: fromModelName() (built-in models with auto label resolution) and fromCustomConfig() (custom models with user-provided labels). Also, IMO it would be a good idea to provide such factories for the entire API.
  • Generic hook: useImageSegmentation infers the model's label types from the config — no explicit generic parameter needed. forward() return type narrows based on classesOfInterest passed in.
  • Correct return types: forward() now returns Record<'ARGMAX', Int32Array> & Record<K, Float32Array> matching what the native side actually produces (was incorrectly typed as number[]).
  • ARGMAX always returned: Removed 'ARGMAX' from classesOfInterest — it's always in the output regardless, and the return type reflects this.
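
The return-type narrowing described in the last three bullets can be illustrated with a small type-level sketch. `DemoLabels`, `SegmentationOutput`, and `mockForward` are hypothetical names made up for this illustration, not the module's actual API, and the mock fills zeroed buffers instead of running a model.

```typescript
// Hypothetical label map standing in for a built-in model's labels.
const DemoLabels = { BACKGROUND: 0, PERSON: 1, CAR: 2 } as const;
type DemoLabel = keyof typeof DemoLabels;

// ARGMAX is always present; each requested class adds a Float32Array entry.
type SegmentationOutput<K extends string> =
  Record<'ARGMAX', Int32Array> & Record<K, Float32Array>;

// Mock forward(): allocates zeroed buffers, no inference is performed.
function mockForward<K extends DemoLabel = never>(
  pixelCount: number,
  classesOfInterest: readonly K[] = [],
): SegmentationOutput<K> {
  const output: Record<string, Int32Array | Float32Array> = {
    ARGMAX: new Int32Array(pixelCount),
  };
  for (const label of classesOfInterest) {
    output[label] = new Float32Array(pixelCount);
  }
  // The cast is needed because the object is built dynamically.
  return output as unknown as SegmentationOutput<K>;
}

// Typed as Record<'ARGMAX', Int32Array> only.
const onlyArgmax = mockForward(4);

// Typed with an additional 'PERSON' key holding a Float32Array.
const withPerson = mockForward(4, ['PERSON']);

console.log(Object.keys(onlyArgmax), Object.keys(withPerson));
```

With this shape, accessing `withPerson.CAR` or passing `['DOG']` would be rejected by the compiler, which is the behaviour the bullets describe for `classesOfInterest`.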

Introduces a breaking change?

  • Yes
  • No

Type of change

  • Bug fix (change which fixes an issue)
  • New feature (change which adds functionality)
  • Documentation update (improves or adds clarity to existing documentation)
  • Other (chores, tests, code style improvements etc.)

Tested on

  • iOS
  • Android

Testing instructions

  1. Build and run the computer-vision demo app
  2. Navigate to Image Segmentation screen
  3. Pick an image and run segmentation — verify the ARGMAX overlay renders correctly
  4. Verify the hook API works as expected:
const { isReady, forward } = useImageSegmentation({
  model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 },
});

// Returns Record<'ARGMAX', Int32Array> — no generic needed
const result = await forward(imageUri);

// Narrows return type to include 'PERSON' key as Float32Array
const result2 = await forward(imageUri, ['PERSON']);
  5. Verify TypeScript autocompletion: classesOfInterest should only suggest valid label keys for the chosen model (e.g. 'PERSON', 'CAR' for DeepLab, 'SELFIE'/'BACKGROUND' for selfie segmentation)
  6. You can also try changing the parameters, say to selfie segmentation, and see how the return types react. Please contact me for the selfie segmentation weights, as I'm not pushing them to HF yet.

Screenshots

Related issues

Checklist

  • I have performed a self-review of my code
  • I have commented my code, particularly in hard-to-understand areas
  • I have updated the documentation accordingly
  • My changes generate no new warnings

Additional notes

The ImageSegmentationModule.fromCustomConfig() API allows users to bring their own segmentation model with a custom label map:

const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;
const seg = await ImageSegmentationModule.fromCustomConfig(
  'https://example.com/model.pte',
  { labelMap: MyLabels },
);
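
To show how a factory like this can carry the label map into the instance type, here is a minimal sketch. `SketchSegmenter`, `LabelMap`, and `indexOf` are hypothetical and exist only for this illustration; the real ImageSegmentationModule may be structured differently, and nothing is downloaded or loaded here.

```typescript
// A label map is any object from label names to numeric class indices.
type LabelMap = Readonly<Record<string, number>>;

class SketchSegmenter<L extends LabelMap> {
  // Private constructor: instances are created only through the factory.
  private constructor(
    readonly modelSource: string,
    readonly labelMap: L,
  ) {}

  // Hypothetical stand-in for fromCustomConfig(); it only stores its inputs.
  static async fromCustomConfig<L extends LabelMap>(
    modelSource: string,
    config: { labelMap: L },
  ): Promise<SketchSegmenter<L>> {
    return new SketchSegmenter(modelSource, config.labelMap);
  }

  // Only keys of the supplied label map are accepted by the compiler.
  indexOf(label: keyof L & string): number {
    return this.labelMap[label];
  }
}

const SketchLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;

const sketchSegPromise = SketchSegmenter.fromCustomConfig(
  'https://example.com/model.pte',
  { labelMap: SketchLabels },
);

sketchSegPromise.then((seg) => {
  // 'FOREGROUND' type-checks; a label outside SketchLabels would not.
  console.log(seg.indexOf('FOREGROUND'));
});
```

Because `L` is inferred from the `as const` object, the label names stay available as a literal union without an explicit generic argument.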

Comment on lines 26 to 29
'rfdetr': {
labelMap: CocoLabel,
preprocessorConfig: { normMean: IMAGENET_MEAN, normStd: IMAGENET_STD },
},
Collaborator Author

@chmjkb chmjkb Feb 17, 2026

FYI, this doesn't work, as RF-DETR is an instance segmentation model and needs a separate implementation, but I'm leaving it here so you can get an idea of what the configuration looks like.

@chmjkb chmjkb marked this pull request as ready for review February 17, 2026 14:30
@chmjkb chmjkb changed the title feat: make image segmentation generic, general refactor feat!: make image segmentation generic, general refactor Feb 17, 2026
@chmjkb chmjkb requested a review from msluszniak February 18, 2026 11:19
@chmjkb chmjkb force-pushed the @chmjkb/image-segmentation-generic branch from 9c0f38e to 6a523df February 18, 2026 11:59
@chmjkb chmjkb requested a review from msluszniak February 19, 2026 10:21
@@ -1,170 +1 @@
#include "ImageSegmentation.h"
Collaborator

remove file

@@ -0,0 +1,57 @@
#pragma once
Collaborator

We are missing a virtual destructor in the model classes

Comment on lines 46 to 48
if (preventLoad) return;

let currentInstance: ImageSegmentationModule<ModelNameOf<C>> | null = null;
Collaborator

add isMounted guard here (reference here)
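
A minimal sketch of what such a guard could look like, written without React so it runs on its own. `FakeModule` and `startLoad` are hypothetical stand-ins: `startLoad` mirrors the body of an effect that starts an async load and returns its cleanup, and it assumes the module exposes a `delete()` method for releasing native resources.

```typescript
// Hypothetical stand-in for a native-backed module that must be released.
class FakeModule {
  deleted = false;
  delete(): void {
    this.deleted = true;
  }
}

// Mirrors an effect body: starts an async load and returns a cleanup function.
function startLoad(
  load: () => Promise<FakeModule>,
  onReady: (instance: FakeModule) => void,
): () => void {
  let isMounted = true;
  let currentInstance: FakeModule | null = null;

  void load().then((instance) => {
    if (!isMounted) {
      // Cleanup already ran: release the instance instead of publishing it.
      instance.delete();
      return;
    }
    currentInstance = instance;
    onReady(instance);
  });

  return () => {
    isMounted = false;
    // Release an instance that finished loading before the cleanup ran.
    currentInstance?.delete();
  };
}

// Simulate unmounting before the load resolves.
let guardPublished = false;
const guardInstance = new FakeModule();
const guardCleanup = startLoad(
  async () => guardInstance,
  () => {
    guardPublished = true;
  },
);
guardCleanup();
```

Without the `isMounted` check, an instance that resolves after cleanup would be published to a component that no longer exists and never released.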

onDownloadProgress: (progress: number) => void = () => {}
): Promise<ImageSegmentationModule<L>> {
const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource);
if (!paths?.[0]) {
Collaborator

I guess we should also be checking for paths.length === 1
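
One possible shape for the stricter check, as a sketch only. `requireSinglePath` is a hypothetical helper, and the error type and message are assumptions; the PR may report the failure differently.

```typescript
// Returns the single fetched path, or throws when the result is not exactly one non-empty path.
function requireSinglePath(
  paths: readonly string[] | null | undefined,
): string {
  const first = paths?.[0];
  if (paths?.length !== 1 || !first) {
    throw new Error(
      `Expected exactly one model path, got ${paths?.length ?? 0}`,
    );
  }
  return first;
}

console.log(requireSinglePath(['/tmp/model.pte']));
```

This rejects a null result, an empty array, an empty string, and an array with more than one entry, whereas `!paths?.[0]` alone lets the last case through.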

Co-authored-by: Mateusz Kopcinski <120639731+mkopcins@users.noreply.github.com>
Comment on lines 25 to 30
if (normMean.size() >= 3) {
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
}
if (normStd.size() >= 3) {
normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]);
}
Contributor

What if the norm vector has fewer than 3 elements? Maybe we should log here?

@@ -1,170 +1 @@
#include "ImageSegmentation.h"
Contributor

Do we need a .cpp file that just includes the header?

Development

Successfully merging this pull request may close these issues.

Add a generic interface for image segmentation models

3 participants