ソースを参照

Add MobileNetv1

vishnukvmd 2 年 前
コミット
a5646511e0

+ 0 - 0
assets/models/labelmap.txt → assets/models/cocossd/labels.txt


+ 0 - 0
assets/models/detect.tflite → assets/models/cocossd/model.tflite


+ 1001 - 0
assets/models/mobilenet/labels_mobilenet_quant_v1_224.txt

@@ -0,0 +1,1001 @@
+background
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue

BIN
assets/models/mobilenet/mobilenet_v1_1.0_224_quant.tflite


+ 46 - 18
lib/services/object_detection/object_detection_service.dart

@@ -4,18 +4,18 @@ import "dart:typed_data";
 import "package:logging/logging.dart";
 import "package:photos/services/object_detection/models/predictions.dart";
 import 'package:photos/services/object_detection/models/recognition.dart';
-import "package:photos/services/object_detection/tflite/classifier.dart";
+import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
+import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
 import "package:photos/services/object_detection/utils/isolate_utils.dart";
 
 class ObjectDetectionService {
-  static const scoreThreshold = 0.6;
+  static const scoreThreshold = 0.5;
 
   final _logger = Logger("ObjectDetectionService");
 
-  /// Instance of [ObjectClassifier]
-  late ObjectClassifier _classifier;
+  late CocoSSDClassifier _objectClassifier;
+  late MobileNetClassifier _mobileNetClassifier;
 
-  /// Instance of [IsolateUtils]
   late IsolateUtils _isolateUtils;
 
   ObjectDetectionService._privateConstructor();
@@ -23,7 +23,8 @@ class ObjectDetectionService {
   Future<void> init() async {
     _isolateUtils = IsolateUtils();
     await _isolateUtils.start();
-    _classifier = ObjectClassifier();
+    _objectClassifier = CocoSSDClassifier();
+    _mobileNetClassifier = MobileNetClassifier();
   }
 
   static ObjectDetectionService instance =
@@ -31,18 +32,11 @@ class ObjectDetectionService {
 
   Future<List<String>> predict(Uint8List bytes) async {
     try {
-      final isolateData = IsolateData(
-        bytes,
-        _classifier.interpreter.address,
-        _classifier.labels,
-      );
-      final predictions = await _inference(isolateData);
-      final Set<String> results = {};
-      for (final Recognition result in predictions.recognitions) {
-        if (result.score > scoreThreshold) {
-          results.add(result.label);
-        }
-      }
+      final results = <String>{};
+      final objectResults = await _getObjects(bytes);
+      results.addAll(objectResults);
+      final mobileNetResults = await _getMobileNetResults(bytes);
+      results.addAll(mobileNetResults);
       return results.toList();
     } catch (e, s) {
       _logger.severe(e, s);
@@ -50,6 +44,40 @@ class ObjectDetectionService {
     }
   }
 
+  Future<List<String>> _getObjects(Uint8List bytes) async {
+    final isolateData = IsolateData(
+      bytes,
+      _objectClassifier.interpreter.address,
+      _objectClassifier.labels,
+      ClassifierType.cocossd,
+    );
+    final predictions = await _inference(isolateData);
+    final Set<String> results = {};
+    for (final Recognition result in predictions.recognitions) {
+      if (result.score > scoreThreshold) {
+        results.add(result.label);
+      }
+    }
+    return results.toList();
+  }
+
+  Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
+    final isolateData = IsolateData(
+      bytes,
+      _mobileNetClassifier.interpreter.address,
+      _mobileNetClassifier.labels,
+      ClassifierType.mobilenet,
+    );
+    final predictions = await _inference(isolateData);
+    final Set<String> results = {};
+    for (final Recognition result in predictions.recognitions) {
+      if (result.score > scoreThreshold) {
+        results.add(result.label);
+      }
+    }
+    return results.toList();
+  }
+
   /// Runs inference in another isolate
   Future<Predictions> _inference(IsolateData isolateData) async {
     final responsePort = ReceivePort();

+ 3 - 176
lib/services/object_detection/tflite/classifier.dart

@@ -1,179 +1,6 @@
-import 'dart:math';
-
 import 'package:image/image.dart' as imageLib;
-import "package:logging/logging.dart";
-import 'package:photos/services/object_detection/models/predictions.dart';
-import 'package:photos/services/object_detection/models/recognition.dart';
-import "package:photos/services/object_detection/models/stats.dart";
-import "package:tflite_flutter/tflite_flutter.dart";
-import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
-
-/// Classifier
-class ObjectClassifier {
-  final _logger = Logger("Classifier");
-
-  /// Instance of Interpreter
-  late Interpreter _interpreter;
-
-  /// Labels file loaded as list
-  late List<String> _labels;
-
-  /// Input size of image (height = width = 300)
-  static const int inputSize = 300;
-
-  /// Result score threshold
-  static const double threshold = 0.5;
-
-  static const String modelFileName = "detect.tflite";
-  static const String labelFileName = "labelmap.txt";
-
-  /// [ImageProcessor] used to pre-process the image
-  ImageProcessor? imageProcessor;
-
-  /// Padding the image to transform into square
-  late int padSize;
-
-  /// Shapes of output tensors
-  late List<List<int>> _outputShapes;
-
-  /// Types of output tensors
-  late List<TfLiteType> _outputTypes;
-
-  /// Number of results to show
-  static const int numResults = 10;
-
-  ObjectClassifier({
-    Interpreter? interpreter,
-    List<String>? labels,
-  }) {
-    loadModel(interpreter);
-    loadLabels(labels);
-  }
-
-  /// Loads interpreter from asset
-  void loadModel(Interpreter? interpreter) async {
-    try {
-      _interpreter = interpreter ??
-          await Interpreter.fromAsset(
-            "models/" + modelFileName,
-            options: InterpreterOptions()..threads = 4,
-          );
-      final outputTensors = _interpreter.getOutputTensors();
-      _outputShapes = [];
-      _outputTypes = [];
-      outputTensors.forEach((tensor) {
-        _outputShapes.add(tensor.shape);
-        _outputTypes.add(tensor.type);
-      });
-      _logger.info("Interpreter initialized");
-    } catch (e, s) {
-      _logger.severe("Error while creating interpreter", e, s);
-    }
-  }
-
-  /// Loads labels from assets
-  void loadLabels(List<String>? labels) async {
-    try {
-      _labels =
-          labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName);
-      _logger.info("Labels initialized");
-    } catch (e, s) {
-      _logger.severe("Error while loading labels", e, s);
-    }
-  }
-
-  /// Pre-process the image
-  TensorImage _getProcessedImage(TensorImage inputImage) {
-    padSize = max(inputImage.height, inputImage.width);
-    imageProcessor ??= ImageProcessorBuilder()
-        .add(ResizeWithCropOrPadOp(padSize, padSize))
-        .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
-        .build();
-    inputImage = imageProcessor!.process(inputImage);
-    return inputImage;
-  }
-
-  /// Runs object detection on the input image
-  Predictions? predict(imageLib.Image image) {
-    final predictStartTime = DateTime.now().millisecondsSinceEpoch;
-
-    final preProcessStart = DateTime.now().millisecondsSinceEpoch;
-
-    // Create TensorImage from image
-    TensorImage inputImage = TensorImage.fromImage(image);
-
-    // Pre-process TensorImage
-    inputImage = _getProcessedImage(inputImage);
-
-    final preProcessElapsedTime =
-        DateTime.now().millisecondsSinceEpoch - preProcessStart;
-
-    // TensorBuffers for output tensors
-    final outputLocations = TensorBufferFloat(_outputShapes[0]);
-    final outputClasses = TensorBufferFloat(_outputShapes[1]);
-    final outputScores = TensorBufferFloat(_outputShapes[2]);
-    final numLocations = TensorBufferFloat(_outputShapes[3]);
-
-    // Inputs object for runForMultipleInputs
-    // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
-    final inputs = [inputImage.buffer];
-
-    // Outputs map
-    final outputs = {
-      0: outputLocations.buffer,
-      1: outputClasses.buffer,
-      2: outputScores.buffer,
-      3: numLocations.buffer,
-    };
-
-    final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
-
-    // run inference
-    _interpreter.runForMultipleInputs(inputs, outputs);
-
-    final inferenceTimeElapsed =
-        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
-
-    // Maximum number of results to show
-    final resultsCount = min(numResults, numLocations.getIntValue(0));
-
-    // Using labelOffset = 1 as ??? at index 0
-    const labelOffset = 1;
-
-    final recognitions = <Recognition>[];
-
-    for (int i = 0; i < resultsCount; i++) {
-      // Prediction score
-      final score = outputScores.getDoubleValue(i);
-
-      // Label string
-      final labelIndex = outputClasses.getIntValue(i) + labelOffset;
-      final label = _labels.elementAt(labelIndex);
-
-      if (score > threshold) {
-        recognitions.add(
-          Recognition(i, label, score),
-        );
-      }
-    }
-
-    final predictElapsedTime =
-        DateTime.now().millisecondsSinceEpoch - predictStartTime;
-    _logger.info(recognitions);
-    return Predictions(
-      recognitions,
-      Stats(
-        predictElapsedTime,
-        predictElapsedTime,
-        inferenceTimeElapsed,
-        preProcessElapsedTime,
-      ),
-    );
-  }
-
-  /// Gets the interpreter instance
-  Interpreter get interpreter => _interpreter;
+import "package:photos/services/object_detection/models/predictions.dart";
 
-  /// Gets the loaded labels
-  List<String> get labels => _labels;
+abstract class Classifier {
+  Predictions? predict(imageLib.Image image);
 }

+ 180 - 0
lib/services/object_detection/tflite/cocossd_classifier.dart

@@ -0,0 +1,180 @@
+import 'dart:math';
+
+import 'package:image/image.dart' as imageLib;
+import "package:logging/logging.dart";
+import 'package:photos/services/object_detection/models/predictions.dart';
+import 'package:photos/services/object_detection/models/recognition.dart';
+import "package:photos/services/object_detection/models/stats.dart";
+import "package:photos/services/object_detection/tflite/classifier.dart";
+import "package:tflite_flutter/tflite_flutter.dart";
+import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
+
+/// Object detector that runs the COCO SSD TFLite model.
+class CocoSSDClassifier extends Classifier {
+  final _logger = Logger("Classifier");
+
+  /// Instance of Interpreter
+  late Interpreter _interpreter;
+
+  /// Labels file loaded as list
+  late List<String> _labels;
+
+  /// Input size of image (height = width = 300)
+  static const int inputSize = 300;
+
+  /// Result score threshold
+  static const double threshold = 0.5;
+
+  static const String modelFileName = "model.tflite";
+  static const String labelFileName = "labels.txt";
+
+  /// [ImageProcessor] used to pre-process the image
+  ImageProcessor? imageProcessor;
+
+  /// Padding the image to transform into square
+  late int padSize;
+
+  /// Shapes of output tensors
+  late List<List<int>> _outputShapes;
+
+  /// Types of output tensors
+  late List<TfLiteType> _outputTypes;
+
+  /// Number of results to show
+  static const int numResults = 10;
+
+  CocoSSDClassifier({
+    Interpreter? interpreter,
+    List<String>? labels,
+  }) {
+    loadModel(interpreter);
+    loadLabels(labels);
+  }
+
+  /// Loads interpreter from asset
+  void loadModel(Interpreter? interpreter) async {
+    try {
+      _interpreter = interpreter ??
+          await Interpreter.fromAsset(
+            "models/cocossd/" + modelFileName,
+            options: InterpreterOptions()..threads = 4,
+          );
+      final outputTensors = _interpreter.getOutputTensors();
+      _outputShapes = [];
+      _outputTypes = [];
+      outputTensors.forEach((tensor) {
+        _outputShapes.add(tensor.shape);
+        _outputTypes.add(tensor.type);
+      });
+      _logger.info("Interpreter initialized");
+    } catch (e, s) {
+      _logger.severe("Error while creating interpreter", e, s);
+    }
+  }
+
+  /// Loads labels from assets
+  void loadLabels(List<String>? labels) async {
+    try {
+      _labels = labels ??
+          await FileUtil.loadLabels("assets/models/cocossd/" + labelFileName);
+      _logger.info("Labels initialized");
+    } catch (e, s) {
+      _logger.severe("Error while loading labels", e, s);
+    }
+  }
+
+  /// Pre-process the image
+  TensorImage _getProcessedImage(TensorImage inputImage) {
+    padSize = max(inputImage.height, inputImage.width);
+    imageProcessor ??= ImageProcessorBuilder()
+        .add(ResizeWithCropOrPadOp(padSize, padSize))
+        .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
+        .build();
+    inputImage = imageProcessor!.process(inputImage);
+    return inputImage;
+  }
+
+  /// Runs object detection on the input image
+  Predictions? predict(imageLib.Image image) {
+    final predictStartTime = DateTime.now().millisecondsSinceEpoch;
+
+    final preProcessStart = DateTime.now().millisecondsSinceEpoch;
+
+    // Create TensorImage from image
+    TensorImage inputImage = TensorImage.fromImage(image);
+
+    // Pre-process TensorImage
+    inputImage = _getProcessedImage(inputImage);
+
+    final preProcessElapsedTime =
+        DateTime.now().millisecondsSinceEpoch - preProcessStart;
+
+    // TensorBuffers for output tensors
+    final outputLocations = TensorBufferFloat(_outputShapes[0]);
+    final outputClasses = TensorBufferFloat(_outputShapes[1]);
+    final outputScores = TensorBufferFloat(_outputShapes[2]);
+    final numLocations = TensorBufferFloat(_outputShapes[3]);
+
+    // Inputs object for runForMultipleInputs
+    // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
+    final inputs = [inputImage.buffer];
+
+    // Outputs map
+    final outputs = {
+      0: outputLocations.buffer,
+      1: outputClasses.buffer,
+      2: outputScores.buffer,
+      3: numLocations.buffer,
+    };
+
+    final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
+
+    // run inference
+    _interpreter.runForMultipleInputs(inputs, outputs);
+
+    final inferenceTimeElapsed =
+        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
+
+    // Maximum number of results to show
+    final resultsCount = min(numResults, numLocations.getIntValue(0));
+
+    // Using labelOffset = 1 as ??? at index 0
+    const labelOffset = 1;
+
+    final recognitions = <Recognition>[];
+
+    for (int i = 0; i < resultsCount; i++) {
+      // Prediction score
+      final score = outputScores.getDoubleValue(i);
+
+      // Label string
+      final labelIndex = outputClasses.getIntValue(i) + labelOffset;
+      final label = _labels.elementAt(labelIndex);
+
+      if (score > threshold) {
+        recognitions.add(
+          Recognition(i, label, score),
+        );
+      }
+    }
+
+    final predictElapsedTime =
+        DateTime.now().millisecondsSinceEpoch - predictStartTime;
+    _logger.info(recognitions);
+    return Predictions(
+      recognitions,
+      Stats(
+        predictElapsedTime,
+        predictElapsedTime,
+        inferenceTimeElapsed,
+        preProcessElapsedTime,
+      ),
+    );
+  }
+
+  /// Gets the interpreter instance
+  Interpreter get interpreter => _interpreter;
+
+  /// Gets the loaded labels
+  List<String> get labels => _labels;
+}

+ 151 - 0
lib/services/object_detection/tflite/mobilenet_classifier.dart

@@ -0,0 +1,151 @@
+import 'dart:math';
+
+import 'package:image/image.dart' as imageLib;
+import "package:logging/logging.dart";
+import 'package:photos/services/object_detection/models/predictions.dart';
+import 'package:photos/services/object_detection/models/recognition.dart';
+import "package:photos/services/object_detection/models/stats.dart";
+import "package:photos/services/object_detection/tflite/classifier.dart";
+import "package:tflite_flutter/tflite_flutter.dart";
+import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
+
+class MobileNetClassifier extends Classifier {
+  final _logger = Logger("MobileNetClassifier");
+
+  /// Instance of Interpreter
+  late Interpreter _interpreter;
+
+  /// Labels file loaded as list
+  late List<String> _labels;
+
+  /// Input size of image (height = width = 224)
+  static const int inputSize = 224;
+
+  /// Result score threshold
+  static const double threshold = 0.5;
+
+  static const String modelFileName = "mobilenet_v1_1.0_224_quant.tflite";
+  static const String labelFileName = "labels_mobilenet_quant_v1_224.txt";
+
+  /// [ImageProcessor] used to pre-process the image
+  ImageProcessor? imageProcessor;
+
+  /// Padding the image to transform into square
+  late int padSize;
+
+  /// Shapes of output tensors
+  late List<List<int>> _outputShapes;
+
+  /// Types of output tensors
+  late List<TfLiteType> _outputTypes;
+
+  /// Number of results to show
+  static const int numResults = 10;
+
+  MobileNetClassifier({
+    Interpreter? interpreter,
+    List<String>? labels,
+  }) {
+    loadModel(interpreter);
+    loadLabels(labels);
+  }
+
+  /// Loads interpreter from asset
+  void loadModel(Interpreter? interpreter) async {
+    try {
+      _interpreter = interpreter ??
+          await Interpreter.fromAsset(
+            "models/mobilenet/" + modelFileName,
+            options: InterpreterOptions()..threads = 4,
+          );
+      final outputTensors = _interpreter.getOutputTensors();
+      _outputShapes = [];
+      _outputTypes = [];
+      outputTensors.forEach((tensor) {
+        _outputShapes.add(tensor.shape);
+        _outputTypes.add(tensor.type);
+      });
+      _logger.info("Interpreter initialized");
+    } catch (e, s) {
+      _logger.severe("Error while creating interpreter", e, s);
+    }
+  }
+
+  /// Loads labels from assets
+  void loadLabels(List<String>? labels) async {
+    try {
+      _labels = labels ??
+          await FileUtil.loadLabels("assets/models/mobilenet/" + labelFileName);
+      _logger.info("Labels initialized");
+    } catch (e, s) {
+      _logger.severe("Error while loading labels", e, s);
+    }
+  }
+
+  /// Pre-process the image
+  TensorImage _getProcessedImage(TensorImage inputImage) {
+    padSize = max(inputImage.height, inputImage.width);
+    imageProcessor ??= ImageProcessorBuilder()
+        // .add(ResizeWithCropOrPadOp(padSize, padSize))
+        .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
+        .build();
+    inputImage = imageProcessor!.process(inputImage);
+    return inputImage;
+  }
+
+  /// Runs image classification on the input image
+  Predictions? predict(imageLib.Image image) {
+    final predictStartTime = DateTime.now().millisecondsSinceEpoch;
+
+    final preProcessStart = DateTime.now().millisecondsSinceEpoch;
+
+    // Create TensorImage from image
+    TensorImage inputImage = TensorImage.fromImage(image);
+
+    // Pre-process TensorImage
+    inputImage = _getProcessedImage(inputImage);
+
+    final preProcessElapsedTime =
+        DateTime.now().millisecondsSinceEpoch - preProcessStart;
+
+    // TensorBuffer for the single (uint8, quantized) output tensor
+    final output = TensorBufferUint8(_outputShapes[0]);
+    final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
+    // run inference
+    _interpreter.run(inputImage.buffer, output.buffer);
+
+    final inferenceTimeElapsed =
+        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
+
+    final recognitions = <Recognition>[];
+    for (int i = 0; i < 1001; i++) {
+      final score = output.getDoubleValue(i) / 255;
+      if (score >= threshold) {
+        final label = _labels.elementAt(i);
+
+        recognitions.add(
+          Recognition(i, "#" + label, score),
+        );
+      }
+    }
+
+    final predictElapsedTime =
+        DateTime.now().millisecondsSinceEpoch - predictStartTime;
+    _logger.info(recognitions);
+    return Predictions(
+      recognitions,
+      Stats(
+        predictElapsedTime,
+        predictElapsedTime,
+        inferenceTimeElapsed,
+        preProcessElapsedTime,
+      ),
+    );
+  }
+
+  /// Gets the interpreter instance
+  Interpreter get interpreter => _interpreter;
+
+  /// Gets the loaded labels
+  List<String> get labels => _labels;
+}

+ 20 - 5
lib/services/object_detection/utils/isolate_utils.dart

@@ -2,7 +2,8 @@ import 'dart:isolate';
 import "dart:typed_data";
 
 import 'package:image/image.dart' as imgLib;
-import "package:photos/services/object_detection/tflite/classifier.dart";
+import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
+import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
 import 'package:tflite_flutter/tflite_flutter.dart';
 
 /// Manages separate Isolate instance for inference
@@ -29,10 +30,17 @@ class IsolateUtils {
     sendPort.send(port.sendPort);
 
     await for (final IsolateData isolateData in port) {
-      final classifier = ObjectClassifier(
-        interpreter: Interpreter.fromAddress(isolateData.interpreterAddress),
-        labels: isolateData.labels,
-      );
+      final classifier = isolateData.type == ClassifierType.cocossd
+          ? CocoSSDClassifier(
+              interpreter:
+                  Interpreter.fromAddress(isolateData.interpreterAddress),
+              labels: isolateData.labels,
+            )
+          : MobileNetClassifier(
+              interpreter:
+                  Interpreter.fromAddress(isolateData.interpreterAddress),
+              labels: isolateData.labels,
+            );
       final image = imgLib.decodeImage(isolateData.input);
       final results = classifier.predict(image!);
       isolateData.responsePort.send(results);
@@ -45,11 +53,18 @@ class IsolateData {
   Uint8List input;
   int interpreterAddress;
   List<String> labels;
+  ClassifierType type;
   late SendPort responsePort;
 
   IsolateData(
     this.input,
     this.interpreterAddress,
     this.labels,
+    this.type,
   );
 }
+
+enum ClassifierType {
+  cocossd,
+  mobilenet,
+}

+ 2 - 1
pubspec.yaml

@@ -165,7 +165,8 @@ flutter_native_splash:
 flutter:
   assets:
     - assets/
-    - assets/models/
+    - assets/models/cocossd/
+    - assets/models/mobilenet/
   fonts:
   - family: Inter
     fonts: