|
@@ -0,0 +1,179 @@
|
|
|
|
+import 'dart:math';
|
|
|
|
+
|
|
|
|
+import 'package:image/image.dart' as imageLib;
|
|
|
|
+import "package:logging/logging.dart";
|
|
|
|
+import 'package:photos/services/object_detection/models/predictions.dart';
|
|
|
|
+import 'package:photos/services/object_detection/models/recognition.dart';
|
|
|
|
+import "package:photos/services/object_detection/models/stats.dart";
|
|
|
|
+import "package:tflite_flutter/tflite_flutter.dart";
|
|
|
|
+import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
|
|
|
|
+
|
|
|
|
+/// Classifier
|
|
|
|
+class ObjectClassifier {
|
|
|
|
+ final _logger = Logger("Classifier");
|
|
|
|
+
|
|
|
|
+ /// Instance of Interpreter
|
|
|
|
+ late Interpreter _interpreter;
|
|
|
|
+
|
|
|
|
+ /// Labels file loaded as list
|
|
|
|
+ late List<String> _labels;
|
|
|
|
+
|
|
|
|
+ /// Input size of image (height = width = 300)
|
|
|
|
+ static const int inputSize = 300;
|
|
|
|
+
|
|
|
|
+ /// Result score threshold
|
|
|
|
+ static const double threshold = 0.5;
|
|
|
|
+
|
|
|
|
+ static const String modelFileName = "detect.tflite";
|
|
|
|
+ static const String labelFileName = "labelmap.txt";
|
|
|
|
+
|
|
|
|
+ /// [ImageProcessor] used to pre-process the image
|
|
|
|
+ ImageProcessor? imageProcessor;
|
|
|
|
+
|
|
|
|
+ /// Padding the image to transform into square
|
|
|
|
+ late int padSize;
|
|
|
|
+
|
|
|
|
+ /// Shapes of output tensors
|
|
|
|
+ late List<List<int>> _outputShapes;
|
|
|
|
+
|
|
|
|
+ /// Types of output tensors
|
|
|
|
+ late List<TfLiteType> _outputTypes;
|
|
|
|
+
|
|
|
|
+ /// Number of results to show
|
|
|
|
+ static const int numResults = 10;
|
|
|
|
+
|
|
|
|
+ ObjectClassifier({
|
|
|
|
+ Interpreter? interpreter,
|
|
|
|
+ List<String>? labels,
|
|
|
|
+ }) {
|
|
|
|
+ loadModel(interpreter);
|
|
|
|
+ loadLabels(labels);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Loads interpreter from asset
|
|
|
|
+ void loadModel(Interpreter? interpreter) async {
|
|
|
|
+ try {
|
|
|
|
+ _interpreter = interpreter ??
|
|
|
|
+ await Interpreter.fromAsset(
|
|
|
|
+ "models/" + modelFileName,
|
|
|
|
+ options: InterpreterOptions()..threads = 4,
|
|
|
|
+ );
|
|
|
|
+ final outputTensors = _interpreter.getOutputTensors();
|
|
|
|
+ _outputShapes = [];
|
|
|
|
+ _outputTypes = [];
|
|
|
|
+ outputTensors.forEach((tensor) {
|
|
|
|
+ _outputShapes.add(tensor.shape);
|
|
|
|
+ _outputTypes.add(tensor.type);
|
|
|
|
+ });
|
|
|
|
+ _logger.info("Interpreter initialized");
|
|
|
|
+ } catch (e, s) {
|
|
|
|
+ _logger.severe("Error while creating interpreter", e, s);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Loads labels from assets
|
|
|
|
+ void loadLabels(List<String>? labels) async {
|
|
|
|
+ try {
|
|
|
|
+ _labels =
|
|
|
|
+ labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName);
|
|
|
|
+ _logger.info("Labels initialized");
|
|
|
|
+ } catch (e, s) {
|
|
|
|
+ _logger.severe("Error while loading labels", e, s);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Pre-process the image
|
|
|
|
+ TensorImage _getProcessedImage(TensorImage inputImage) {
|
|
|
|
+ padSize = max(inputImage.height, inputImage.width);
|
|
|
|
+ imageProcessor ??= ImageProcessorBuilder()
|
|
|
|
+ .add(ResizeWithCropOrPadOp(padSize, padSize))
|
|
|
|
+ .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
|
|
|
+ .build();
|
|
|
|
+ inputImage = imageProcessor!.process(inputImage);
|
|
|
|
+ return inputImage;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Runs object detection on the input image
|
|
|
|
+ Predictions? predict(imageLib.Image image) {
|
|
|
|
+ final predictStartTime = DateTime.now().millisecondsSinceEpoch;
|
|
|
|
+
|
|
|
|
+ final preProcessStart = DateTime.now().millisecondsSinceEpoch;
|
|
|
|
+
|
|
|
|
+ // Create TensorImage from image
|
|
|
|
+ TensorImage inputImage = TensorImage.fromImage(image);
|
|
|
|
+
|
|
|
|
+ // Pre-process TensorImage
|
|
|
|
+ inputImage = _getProcessedImage(inputImage);
|
|
|
|
+
|
|
|
|
+ final preProcessElapsedTime =
|
|
|
|
+ DateTime.now().millisecondsSinceEpoch - preProcessStart;
|
|
|
|
+
|
|
|
|
+ // TensorBuffers for output tensors
|
|
|
|
+ final outputLocations = TensorBufferFloat(_outputShapes[0]);
|
|
|
|
+ final outputClasses = TensorBufferFloat(_outputShapes[1]);
|
|
|
|
+ final outputScores = TensorBufferFloat(_outputShapes[2]);
|
|
|
|
+ final numLocations = TensorBufferFloat(_outputShapes[3]);
|
|
|
|
+
|
|
|
|
+ // Inputs object for runForMultipleInputs
|
|
|
|
+ // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
|
|
|
|
+ final inputs = [inputImage.buffer];
|
|
|
|
+
|
|
|
|
+ // Outputs map
|
|
|
|
+ final outputs = {
|
|
|
|
+ 0: outputLocations.buffer,
|
|
|
|
+ 1: outputClasses.buffer,
|
|
|
|
+ 2: outputScores.buffer,
|
|
|
|
+ 3: numLocations.buffer,
|
|
|
|
+ };
|
|
|
|
+
|
|
|
|
+ final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
|
|
|
|
+
|
|
|
|
+ // run inference
|
|
|
|
+ _interpreter.runForMultipleInputs(inputs, outputs);
|
|
|
|
+
|
|
|
|
+ final inferenceTimeElapsed =
|
|
|
|
+ DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
|
|
|
|
+
|
|
|
|
+ // Maximum number of results to show
|
|
|
|
+ final resultsCount = min(numResults, numLocations.getIntValue(0));
|
|
|
|
+
|
|
|
|
+ // Using labelOffset = 1 as ??? at index 0
|
|
|
|
+ const labelOffset = 1;
|
|
|
|
+
|
|
|
|
+ final recognitions = <Recognition>[];
|
|
|
|
+
|
|
|
|
+ for (int i = 0; i < resultsCount; i++) {
|
|
|
|
+ // Prediction score
|
|
|
|
+ final score = outputScores.getDoubleValue(i);
|
|
|
|
+
|
|
|
|
+ // Label string
|
|
|
|
+ final labelIndex = outputClasses.getIntValue(i) + labelOffset;
|
|
|
|
+ final label = _labels.elementAt(labelIndex);
|
|
|
|
+
|
|
|
|
+ if (score > threshold) {
|
|
|
|
+ recognitions.add(
|
|
|
|
+ Recognition(i, label, score),
|
|
|
|
+ );
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ final predictElapsedTime =
|
|
|
|
+ DateTime.now().millisecondsSinceEpoch - predictStartTime;
|
|
|
|
+ _logger.info(recognitions);
|
|
|
|
+ return Predictions(
|
|
|
|
+ recognitions,
|
|
|
|
+ Stats(
|
|
|
|
+ predictElapsedTime,
|
|
|
|
+ predictElapsedTime,
|
|
|
|
+ inferenceTimeElapsed,
|
|
|
|
+ preProcessElapsedTime,
|
|
|
|
+ ),
|
|
|
|
+ );
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Gets the interpreter instance
|
|
|
|
+ Interpreter get interpreter => _interpreter;
|
|
|
|
+
|
|
|
|
+ /// Gets the loaded labels
|
|
|
|
+ List<String> get labels => _labels;
|
|
|
|
+}
|