classifier.dart 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import "dart:math";
  2. import 'package:image/image.dart' as image_lib;
  3. import "package:logging/logging.dart";
  4. import "package:photos/services/object_detection/models/predictions.dart";
  5. import "package:tflite_flutter/tflite_flutter.dart";
  6. import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
  /// Base class for TFLite-backed image classifiers.
  ///
  /// Subclasses provide the model/label asset locations, the model's input
  /// size, a logger, and a [predict] implementation. The interpreter, labels
  /// and output-tensor metadata are held in `late` fields that are only set
  /// once [loadModel] / [loadLabels] complete; reading [interpreter],
  /// [labels], [outputShapes] or [outputTypes] before that throws a
  /// `LateInitializationError`, so callers should await both loaders first.
  abstract class Classifier {
    /// Asset path of the model file.
    String get modelPath;

    /// Asset path of the labels file.
    String get labelPath;

    /// Side length, in pixels, of the square input expected by the model
    /// (e.g. 224 when width = height = 224).
    int get inputSize;

    /// Logger implementation for the specific classifier.
    Logger get logger;

    /// Runs the model on [image]; implemented by concrete classifiers.
    Predictions? predict(image_lib.Image image);

    /// Instance of Interpreter; set by [loadModel].
    late Interpreter _interpreter;

    /// Labels file loaded as a list; set by [loadLabels].
    late List<String> _labels;

    /// Shapes of the interpreter's output tensors; set by [loadModel].
    late List<List<int>> _outputShapes;

    /// Types of the interpreter's output tensors; set by [loadModel].
    late List<TfLiteType> _outputTypes;

    /// The loaded interpreter instance.
    Interpreter get interpreter => _interpreter;

    /// The loaded labels.
    List<String> get labels => _labels;

    /// Shapes of the output tensors, in the interpreter's output order.
    List<List<int>> get outputShapes => _outputShapes;

    /// Types of the output tensors, in the interpreter's output order.
    List<TfLiteType> get outputTypes => _outputTypes;

    /// Loads the interpreter and records its output tensor shapes and types.
    ///
    /// Uses [interpreter] when given; otherwise creates one from [modelPath]
    /// with 4 threads. Returns a [Future] so callers can await readiness
    /// (previously `void`, which made completion unobservable). Failures are
    /// logged via [logger] and not rethrown, so on error some `late` fields
    /// may remain uninitialized.
    Future<void> loadModel(Interpreter? interpreter) async {
      try {
        _interpreter = interpreter ??
            await Interpreter.fromAsset(
              modelPath,
              options: InterpreterOptions()..threads = 4,
            );
        final outputTensors = _interpreter.getOutputTensors();
        _outputShapes = [for (final tensor in outputTensors) tensor.shape];
        _outputTypes = [for (final tensor in outputTensors) tensor.type];
        logger.info("Interpreter initialized");
      } catch (e, s) {
        // Best-effort: log and continue rather than crash the caller.
        logger.severe("Error while creating interpreter", e, s);
      }
    }

    /// Loads the labels, using [labels] when given, else reading [labelPath].
    ///
    /// Returns a [Future] so callers can await readiness. Failures are logged
    /// via [logger] and not rethrown; [labels] then stays uninitialized.
    Future<void> loadLabels(List<String>? labels) async {
      try {
        _labels = labels ?? await FileUtil.loadLabels(labelPath);
        logger.info("Labels initialized");
      } catch (e, s) {
        logger.severe("Error while loading labels", e, s);
      }
    }

    /// Pre-processes [inputImage] into the model's input geometry.
    ///
    /// First applies `ResizeWithCropOrPadOp` with a square target whose side
    /// is the larger of the image's height and width (so the target is never
    /// smaller than either dimension), then bilinearly resizes the result to
    /// [inputSize] x [inputSize]. Returns the processed [TensorImage].
    TensorImage getProcessedImage(TensorImage inputImage) {
      final padSize = max(inputImage.height, inputImage.width);
      final imageProcessor = ImageProcessorBuilder()
          .add(ResizeWithCropOrPadOp(padSize, padSize))
          .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
          .build();
      return imageProcessor.process(inputImage);
    }
  }