object_detection_service.dart 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import "dart:isolate";
  2. import "dart:typed_data";
  3. import "package:logging/logging.dart";
  4. import "package:photos/services/object_detection/models/predictions.dart";
  5. import 'package:photos/services/object_detection/models/recognition.dart';
  6. import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
  7. import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
  8. import "package:photos/services/object_detection/utils/isolate_utils.dart";
  9. class ObjectDetectionService {
  10. static const scoreThreshold = 0.5;
  11. final _logger = Logger("ObjectDetectionService");
  12. late CocoSSDClassifier _objectClassifier;
  13. late MobileNetClassifier _mobileNetClassifier;
  14. late IsolateUtils _isolateUtils;
  15. ObjectDetectionService._privateConstructor();
  16. Future<void> init() async {
  17. _isolateUtils = IsolateUtils();
  18. await _isolateUtils.start();
  19. _objectClassifier = CocoSSDClassifier();
  20. _mobileNetClassifier = MobileNetClassifier();
  21. }
  22. static ObjectDetectionService instance =
  23. ObjectDetectionService._privateConstructor();
  24. Future<List<String>> predict(Uint8List bytes) async {
  25. try {
  26. final results = <String>{};
  27. final objectResults = await _getObjects(bytes);
  28. results.addAll(objectResults);
  29. final mobileNetResults = await _getMobileNetResults(bytes);
  30. results.addAll(mobileNetResults);
  31. return results.toList();
  32. } catch (e, s) {
  33. _logger.severe(e, s);
  34. rethrow;
  35. }
  36. }
  37. Future<List<String>> _getObjects(Uint8List bytes) async {
  38. final isolateData = IsolateData(
  39. bytes,
  40. _objectClassifier.interpreter.address,
  41. _objectClassifier.labels,
  42. ClassifierType.cocossd,
  43. );
  44. final predictions = await _inference(isolateData);
  45. final Set<String> results = {};
  46. for (final Recognition result in predictions.recognitions) {
  47. if (result.score > scoreThreshold) {
  48. results.add(result.label);
  49. }
  50. }
  51. return results.toList();
  52. }
  53. Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
  54. final isolateData = IsolateData(
  55. bytes,
  56. _mobileNetClassifier.interpreter.address,
  57. _mobileNetClassifier.labels,
  58. ClassifierType.mobilenet,
  59. );
  60. final predictions = await _inference(isolateData);
  61. final Set<String> results = {};
  62. for (final Recognition result in predictions.recognitions) {
  63. if (result.score > scoreThreshold) {
  64. results.add(result.label);
  65. }
  66. }
  67. return results.toList();
  68. }
  69. /// Runs inference in another isolate
  70. Future<Predictions> _inference(IsolateData isolateData) async {
  71. final responsePort = ReceivePort();
  72. _isolateUtils.sendPort.send(
  73. isolateData..responsePort = responsePort.sendPort,
  74. );
  75. return await responsePort.first;
  76. }
  77. }