isolate_utils.dart 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import 'dart:isolate';
  2. import "dart:typed_data";
  3. import 'package:image/image.dart' as imgLib;
  4. import "package:photos/services/object_detection/models/predictions.dart";
  5. import "package:photos/services/object_detection/tflite/classifier.dart";
  6. import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
  7. import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
  8. import "package:photos/services/object_detection/tflite/scene_classifier.dart";
  9. import 'package:tflite_flutter/tflite_flutter.dart';
  /// Runs classifier inference on a separate [Isolate].
  ///
  /// Await [start] once, then send [IsolateData] requests to [sendPort]. Each
  /// request is answered on its own [IsolateData.responsePort] with either the
  /// classifier's result or a [Predictions] carrying the error that occurred.
  class IsolateUtils {
    /// Debug name given to the spawned isolate.
    static const String debugName = "InferenceIsolate";

    late SendPort _sendPort;
    final _receivePort = ReceivePort();

    /// Port on which the spawned isolate accepts [IsolateData] requests.
    ///
    /// Only assigned once [start] has completed; the backing field is `late`,
    /// so reading it earlier throws.
    SendPort get sendPort => _sendPort;

    /// Spawns the inference isolate and waits for it to report its request
    /// port.
    Future<void> start() async {
      await Isolate.spawn<SendPort>(
        entryPoint,
        _receivePort.sendPort,
        debugName: debugName,
      );
      // The first message sent by [entryPoint] is the isolate's own SendPort.
      _sendPort = await _receivePort.first;
    }

    /// Entry point executed inside the spawned isolate.
    ///
    /// Sends a fresh request port back through [sendPort], then serves every
    /// [IsolateData] that arrives on it until the port is closed.
    static void entryPoint(SendPort sendPort) async {
      final port = ReceivePort();
      sendPort.send(port.sendPort);
      await for (final IsolateData isolateData in port) {
        // Classifier construction and image decoding live inside the try so
        // that a bad interpreter address or corrupt image is reported to the
        // requester instead of terminating this loop and leaving the caller
        // waiting forever.
        try {
          final classifier = _getClassifier(isolateData);
          final image = imgLib.decodeImage(isolateData.input);
          if (image == null) {
            throw const FormatException("Could not decode input image");
          }
          final results = classifier.predict(image);
          isolateData.responsePort.send(results);
        } catch (e) {
          isolateData.responsePort.send(Predictions(null, null, error: e));
        }
      }
    }

    /// Builds the [Classifier] requested by [IsolateData.type].
    ///
    /// The interpreter is obtained from [IsolateData.interpreterAddress]. Any
    /// type other than `cocossd` or `mobilenet` yields a [SceneClassifier].
    static Classifier _getClassifier(IsolateData isolateData) {
      final interpreter =
          Interpreter.fromAddress(isolateData.interpreterAddress);
      if (isolateData.type == ClassifierType.cocossd) {
        return CocoSSDClassifier(
          interpreter: interpreter,
          labels: isolateData.labels,
        );
      }
      if (isolateData.type == ClassifierType.mobilenet) {
        return MobileNetClassifier(
          interpreter: interpreter,
          labels: isolateData.labels,
        );
      }
      return SceneClassifier(
        interpreter: interpreter,
        labels: isolateData.labels,
      );
    }
  }
  /// A single inference request sent to the isolate managed by [IsolateUtils].
  ///
  /// [responsePort] is not set by the constructor and must be assigned before
  /// the request is sent, because the isolate replies on it.
  class IsolateData {
    IsolateData(
      this.input,
      this.interpreterAddress,
      this.labels,
      this.type,
    );

    /// Encoded image bytes; the isolate passes them to `decodeImage`.
    Uint8List input;

    /// Address handed to `Interpreter.fromAddress` inside the isolate.
    int interpreterAddress;

    /// Labels forwarded to the classifier that is constructed for this request.
    List<String> labels;

    /// Selects which classifier implementation handles this request.
    ClassifierType type;

    /// Port on which the isolate sends back the prediction result or error.
    late SendPort responsePort;
  }
  /// Identifies which classifier the inference isolate should construct.
  enum ClassifierType {
    /// Handled by `CocoSSDClassifier`.
    cocossd,

    /// Handled by `MobileNetClassifier`.
    mobilenet,

    /// Handled by `SceneClassifier`.
    scenes,
  }