object_detection_service.dart 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import "dart:isolate";
  2. import "dart:typed_data";
  3. import "package:logging/logging.dart";
  4. import "package:photos/services/object_detection/models/predictions.dart";
  5. import 'package:photos/services/object_detection/models/recognition.dart';
  6. import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
  7. import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
  8. import "package:photos/services/object_detection/tflite/scene_classifier.dart";
  9. import "package:photos/services/object_detection/utils/isolate_utils.dart";
  /// Runs the on-device image classifiers (COCO-SSD, MobileNet and a scene
  /// classifier) over an image and returns the labels they report.
  ///
  /// Inference itself happens in a separate isolate managed by [IsolateUtils];
  /// this service builds the requests and keeps labels whose score exceeds
  /// [scoreThreshold].
  class ObjectDetectionService {
    /// A recognition's score must be strictly greater than this for its label
    /// to be returned.
    static const scoreThreshold = 0.5;

    static ObjectDetectionService instance =
        ObjectDetectionService._privateConstructor();

    final _logger = Logger("ObjectDetectionService");

    // Nullable rather than `late`: a classifier whose constructor threw stays
    // null and is skipped, instead of raising LateInitializationError on every
    // prediction.
    CocoSSDClassifier? _objectClassifier;
    MobileNetClassifier? _mobileNetClassifier;
    SceneClassifier? _sceneClassifier;
    late IsolateUtils _isolateUtils;

    /// Whether [init] has completed. (Name kept for compatibility.)
    bool inInitiated = false;

    ObjectDetectionService._privateConstructor();

    /// Starts the inference isolate and loads the classifiers.
    ///
    /// A classifier that fails to load is logged and skipped; the remaining
    /// ones still load. Calling this again after a successful init is a no-op,
    /// so a second inference isolate is never started.
    Future<void> init() async {
      if (inInitiated) {
        return;
      }
      _isolateUtils = IsolateUtils();
      await _isolateUtils.start();
      _objectClassifier = _tryCreate("cocossd", () => CocoSSDClassifier());
      _mobileNetClassifier =
          _tryCreate("mobilenet", () => MobileNetClassifier());
      _sceneClassifier = _tryCreate("sceneclassifier", () => SceneClassifier());
      inInitiated = true;
    }

    /// Returns the result of [create], or null after logging if it throws.
    T? _tryCreate<T extends Object>(String name, T Function() create) {
      try {
        return create();
      } catch (e, s) {
        _logger.severe("Could not initialize $name", e, s);
        return null;
      }
    }

    /// Returns the de-duplicated labels found in the image [bytes] by all
    /// loaded classifiers, in first-seen order.
    ///
    /// Completes with an error if [init] has not completed. A classifier that
    /// fails is logged and contributes no labels; the others still run.
    Future<List<String>> predict(Uint8List bytes) async {
      if (!inInitiated) {
        return Future.error("ObjectDetectionService init is not completed");
      }
      try {
        // A (linked) Set drops labels reported by several models while keeping
        // insertion order; the models run one after another.
        final results = <String>{
          ...await _getObjects(bytes),
          ...await _getMobileNetResults(bytes),
          ...await _getSceneResults(bytes),
        };
        return results.toList();
      } catch (e, s) {
        // Logger.severe takes (message, error, stackTrace).
        _logger.severe("Could not run predictions", e, s);
        rethrow;
      }
    }

    /// Returns the COCO-SSD labels, or an empty list if it is unavailable.
    Future<List<String>> _getObjects(Uint8List bytes) async {
      final classifier = _objectClassifier;
      if (classifier == null) {
        // Loading failed in init(), which already logged the cause.
        return [];
      }
      return _runClassifier(
        "cocossd",
        () => IsolateData(
          bytes,
          classifier.interpreter.address,
          classifier.labels,
          ClassifierType.cocossd,
        ),
      );
    }

    /// Returns the MobileNet labels, or an empty list if it is unavailable.
    Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
      final classifier = _mobileNetClassifier;
      if (classifier == null) {
        // Loading failed in init(), which already logged the cause.
        return [];
      }
      return _runClassifier(
        "mobilenet",
        () => IsolateData(
          bytes,
          classifier.interpreter.address,
          classifier.labels,
          ClassifierType.mobilenet,
        ),
      );
    }

    /// Returns the scene-classifier labels, or an empty list if unavailable.
    Future<List<String>> _getSceneResults(Uint8List bytes) async {
      final classifier = _sceneClassifier;
      if (classifier == null) {
        // Loading failed in init(), which already logged the cause.
        return [];
      }
      return _runClassifier(
        "scene detection",
        () => IsolateData(
          bytes,
          classifier.interpreter.address,
          classifier.labels,
          ClassifierType.scenes,
        ),
      );
    }

    /// Builds a request with [buildRequest], runs it, and returns the labels.
    ///
    /// Any failure (building the request, the isolate round-trip, or reading
    /// the response) is logged as "Could not run <modelName>" and yields an
    /// empty list, so one broken model does not discard the others' labels.
    Future<List<String>> _runClassifier(
      String modelName,
      IsolateData Function() buildRequest,
    ) async {
      try {
        // `await` is required: returning the future un-awaited from inside
        // `try` would let its error escape this catch block.
        return await _getPredictions(buildRequest());
      } catch (e, s) {
        _logger.severe("Could not run $modelName", e, s);
        return [];
      }
    }

    /// Runs inference for [isolateData] and returns the distinct labels whose
    /// score exceeds [scoreThreshold].
    ///
    /// If the response carries an error, it is logged and an empty list is
    /// returned.
    Future<List<String>> _getPredictions(IsolateData isolateData) async {
      final predictions = await _inference(isolateData);
      final error = predictions.error;
      if (error != null) {
        _logger.severe(
          "Error while fetching predictions for ${isolateData.type}",
          error,
        );
        return [];
      }
      final labels = <String>{
        for (final Recognition recognition in predictions.recognitions!)
          if (recognition.score > scoreThreshold) recognition.label,
      };
      _logger.info(
        "Time taken for ${isolateData.type}: "
        "${predictions.stats!.totalElapsedTime}ms",
      );
      return labels.toList();
    }

    /// Runs inference in another isolate.
    ///
    /// Sends [isolateData] (with a fresh reply port attached) to the inference
    /// isolate and waits for the first reply.
    /// NOTE(review): there is no timeout; if the isolate never replies, this
    /// future never completes.
    Future<Predictions> _inference(IsolateData isolateData) async {
      final responsePort = ReceivePort();
      try {
        _isolateUtils.sendPort.send(
          isolateData..responsePort = responsePort.sendPort,
        );
        return (await responsePort.first) as Predictions;
      } finally {
        // Releases the port even when `send` throws; closing a port that is
        // already closed has no effect.
        responsePort.close();
      }
    }
  }