// classifier.dart
import 'dart:math';
import 'package:image/image.dart' as imageLib;
import "package:logging/logging.dart";
import 'package:photos/services/object_detection/models/predictions.dart';
import 'package:photos/services/object_detection/models/recognition.dart';
import "package:photos/services/object_detection/models/stats.dart";
import "package:tflite_flutter/tflite_flutter.dart";
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
  9. /// Classifier
  10. class ObjectClassifier {
  11. final _logger = Logger("Classifier");
  12. /// Instance of Interpreter
  13. late Interpreter _interpreter;
  14. /// Labels file loaded as list
  15. late List<String> _labels;
  16. /// Input size of image (height = width = 300)
  17. static const int inputSize = 300;
  18. /// Result score threshold
  19. static const double threshold = 0.5;
  20. static const String modelFileName = "detect.tflite";
  21. static const String labelFileName = "labelmap.txt";
  22. /// [ImageProcessor] used to pre-process the image
  23. ImageProcessor? imageProcessor;
  24. /// Padding the image to transform into square
  25. late int padSize;
  26. /// Shapes of output tensors
  27. late List<List<int>> _outputShapes;
  28. /// Types of output tensors
  29. late List<TfLiteType> _outputTypes;
  30. /// Number of results to show
  31. static const int numResults = 10;
  32. ObjectClassifier({
  33. Interpreter? interpreter,
  34. List<String>? labels,
  35. }) {
  36. loadModel(interpreter);
  37. loadLabels(labels);
  38. }
  39. /// Loads interpreter from asset
  40. void loadModel(Interpreter? interpreter) async {
  41. try {
  42. _interpreter = interpreter ??
  43. await Interpreter.fromAsset(
  44. "models/" + modelFileName,
  45. options: InterpreterOptions()..threads = 4,
  46. );
  47. final outputTensors = _interpreter.getOutputTensors();
  48. _outputShapes = [];
  49. _outputTypes = [];
  50. outputTensors.forEach((tensor) {
  51. _outputShapes.add(tensor.shape);
  52. _outputTypes.add(tensor.type);
  53. });
  54. _logger.info("Interpreter initialized");
  55. } catch (e, s) {
  56. _logger.severe("Error while creating interpreter", e, s);
  57. }
  58. }
  59. /// Loads labels from assets
  60. void loadLabels(List<String>? labels) async {
  61. try {
  62. _labels =
  63. labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName);
  64. _logger.info("Labels initialized");
  65. } catch (e, s) {
  66. _logger.severe("Error while loading labels", e, s);
  67. }
  68. }
  69. /// Pre-process the image
  70. TensorImage _getProcessedImage(TensorImage inputImage) {
  71. padSize = max(inputImage.height, inputImage.width);
  72. imageProcessor ??= ImageProcessorBuilder()
  73. .add(ResizeWithCropOrPadOp(padSize, padSize))
  74. .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
  75. .build();
  76. inputImage = imageProcessor!.process(inputImage);
  77. return inputImage;
  78. }
  79. /// Runs object detection on the input image
  80. Predictions? predict(imageLib.Image image) {
  81. final predictStartTime = DateTime.now().millisecondsSinceEpoch;
  82. final preProcessStart = DateTime.now().millisecondsSinceEpoch;
  83. // Create TensorImage from image
  84. TensorImage inputImage = TensorImage.fromImage(image);
  85. // Pre-process TensorImage
  86. inputImage = _getProcessedImage(inputImage);
  87. final preProcessElapsedTime =
  88. DateTime.now().millisecondsSinceEpoch - preProcessStart;
  89. // TensorBuffers for output tensors
  90. final outputLocations = TensorBufferFloat(_outputShapes[0]);
  91. final outputClasses = TensorBufferFloat(_outputShapes[1]);
  92. final outputScores = TensorBufferFloat(_outputShapes[2]);
  93. final numLocations = TensorBufferFloat(_outputShapes[3]);
  94. // Inputs object for runForMultipleInputs
  95. // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
  96. final inputs = [inputImage.buffer];
  97. // Outputs map
  98. final outputs = {
  99. 0: outputLocations.buffer,
  100. 1: outputClasses.buffer,
  101. 2: outputScores.buffer,
  102. 3: numLocations.buffer,
  103. };
  104. final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
  105. // run inference
  106. _interpreter.runForMultipleInputs(inputs, outputs);
  107. final inferenceTimeElapsed =
  108. DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
  109. // Maximum number of results to show
  110. final resultsCount = min(numResults, numLocations.getIntValue(0));
  111. // Using labelOffset = 1 as ??? at index 0
  112. const labelOffset = 1;
  113. final recognitions = <Recognition>[];
  114. for (int i = 0; i < resultsCount; i++) {
  115. // Prediction score
  116. final score = outputScores.getDoubleValue(i);
  117. // Label string
  118. final labelIndex = outputClasses.getIntValue(i) + labelOffset;
  119. final label = _labels.elementAt(labelIndex);
  120. if (score > threshold) {
  121. recognitions.add(
  122. Recognition(i, label, score),
  123. );
  124. }
  125. }
  126. final predictElapsedTime =
  127. DateTime.now().millisecondsSinceEpoch - predictStartTime;
  128. _logger.info(recognitions);
  129. return Predictions(
  130. recognitions,
  131. Stats(
  132. predictElapsedTime,
  133. predictElapsedTime,
  134. inferenceTimeElapsed,
  135. preProcessElapsedTime,
  136. ),
  137. );
  138. }
  139. /// Gets the interpreter instance
  140. Interpreter get interpreter => _interpreter;
  141. /// Gets the loaded labels
  142. List<String> get labels => _labels;
  143. }