diff --git a/lib/services/object_detection/tflite/classifier.dart b/lib/services/object_detection/tflite/classifier.dart new file mode 100644 index 000000000..299b31244 --- /dev/null +++ b/lib/services/object_detection/tflite/classifier.dart @@ -0,0 +1,179 @@ +import 'dart:math'; + +import 'package:image/image.dart' as imageLib; +import "package:logging/logging.dart"; +import 'package:photos/services/object_detection/models/predictions.dart'; +import 'package:photos/services/object_detection/models/recognition.dart'; +import "package:photos/services/object_detection/models/stats.dart"; +import "package:tflite_flutter/tflite_flutter.dart"; +import "package:tflite_flutter_helper/tflite_flutter_helper.dart"; + +/// Classifier +class ObjectClassifier { + final _logger = Logger("Classifier"); + + /// Instance of Interpreter + late Interpreter _interpreter; + + /// Labels file loaded as list + late List _labels; + + /// Input size of image (height = width = 300) + static const int inputSize = 300; + + /// Result score threshold + static const double threshold = 0.5; + + static const String modelFileName = "detect.tflite"; + static const String labelFileName = "labelmap.txt"; + + /// [ImageProcessor] used to pre-process the image + ImageProcessor? imageProcessor; + + /// Padding the image to transform into square + late int padSize; + + /// Shapes of output tensors + late List> _outputShapes; + + /// Types of output tensors + late List _outputTypes; + + /// Number of results to show + static const int numResults = 10; + + ObjectClassifier({ + Interpreter? interpreter, + List? labels, + }) { + loadModel(interpreter); + loadLabels(labels); + } + + /// Loads interpreter from asset + void loadModel(Interpreter? interpreter) async { + try { + _interpreter = interpreter ?? + await Interpreter.fromAsset( + "models/" + modelFileName, + options: InterpreterOptions()..threads = 4, + ); + final outputTensors = _interpreter.getOutputTensors(); + _outputShapes = []; + _outputTypes = []; + outputTensors.forEach((tensor) { + _outputShapes.add(tensor.shape); + _outputTypes.add(tensor.type); + }); + _logger.info("Interpreter initialized"); + } catch (e, s) { + _logger.severe("Error while creating interpreter", e, s); + } + } + + /// Loads labels from assets + void loadLabels(List? labels) async { + try { + _labels = + labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName); + _logger.info("Labels initialized"); + } catch (e, s) { + _logger.severe("Error while loading labels", e, s); + } + } + + /// Pre-process the image + TensorImage _getProcessedImage(TensorImage inputImage) { + padSize = max(inputImage.height, inputImage.width); + imageProcessor ??= ImageProcessorBuilder() + .add(ResizeWithCropOrPadOp(padSize, padSize)) + .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR)) + .build(); + inputImage = imageProcessor!.process(inputImage); + return inputImage; + } + + /// Runs object detection on the input image + Predictions? predict(imageLib.Image image) { + final predictStartTime = DateTime.now().millisecondsSinceEpoch; + + final preProcessStart = DateTime.now().millisecondsSinceEpoch; + + // Create TensorImage from image + TensorImage inputImage = TensorImage.fromImage(image); + + // Pre-process TensorImage + inputImage = _getProcessedImage(inputImage); + + final preProcessElapsedTime = + DateTime.now().millisecondsSinceEpoch - preProcessStart; + + // TensorBuffers for output tensors + final outputLocations = TensorBufferFloat(_outputShapes[0]); + final outputClasses = TensorBufferFloat(_outputShapes[1]); + final outputScores = TensorBufferFloat(_outputShapes[2]); + final numLocations = TensorBufferFloat(_outputShapes[3]); + + // Inputs object for runForMultipleInputs + // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference + final inputs = [inputImage.buffer]; + + // Outputs map + final outputs = { + 0: outputLocations.buffer, + 1: outputClasses.buffer, + 2: outputScores.buffer, + 3: numLocations.buffer, + }; + + final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch; + + // run inference + _interpreter.runForMultipleInputs(inputs, outputs); + + final inferenceTimeElapsed = + DateTime.now().millisecondsSinceEpoch - inferenceTimeStart; + + // Maximum number of results to show + final resultsCount = min(numResults, numLocations.getIntValue(0)); + + // Using labelOffset = 1 as ??? at index 0 + const labelOffset = 1; + + final recognitions = []; + + for (int i = 0; i < resultsCount; i++) { + // Prediction score + final score = outputScores.getDoubleValue(i); + + // Label string + final labelIndex = outputClasses.getIntValue(i) + labelOffset; + final label = _labels.elementAt(labelIndex); + + if (score > threshold) { + recognitions.add( + Recognition(i, label, score), + ); + } + } + + final predictElapsedTime = + DateTime.now().millisecondsSinceEpoch - predictStartTime; + _logger.info(recognitions); + return Predictions( + recognitions, + Stats( + predictElapsedTime, + predictElapsedTime, + inferenceTimeElapsed, + preProcessElapsedTime, + ), + ); + } + + /// Gets the interpreter instance + Interpreter get interpreter => _interpreter; + + /// Gets the loaded labels + List get labels => _labels; +}