123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- import 'dart:math';
- import 'package:image/image.dart' as imageLib;
- import "package:logging/logging.dart";
- import 'package:photos/services/object_detection/models/predictions.dart';
- import 'package:photos/services/object_detection/models/recognition.dart';
- import "package:photos/services/object_detection/models/stats.dart";
- import "package:tflite_flutter/tflite_flutter.dart";
- import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
/// Runs a TensorFlow Lite object-detection model over images.
///
/// The constructor starts loading the model and the labels. Loading from
/// assets is asynchronous, so await [ready] before calling [predict]; until
/// both have loaded successfully [predict] returns `null`. When both an
/// interpreter and labels are injected, loading completes synchronously
/// inside the constructor.
class ObjectClassifier {
  final _logger = Logger("Classifier");

  /// Instance of Interpreter
  late Interpreter _interpreter;

  /// Labels file loaded as list
  late List<String> _labels;

  /// Input size of image (height = width = 300)
  static const int inputSize = 300;

  /// Result score threshold
  static const double threshold = 0.5;

  static const String modelFileName = "detect.tflite";
  static const String labelFileName = "labelmap.txt";

  /// [ImageProcessor] used to pre-process the image.
  ///
  /// When left unset, a processor is built on demand and rebuilt whenever the
  /// required [padSize] changes. A processor assigned from outside is used
  /// unchanged.
  ImageProcessor? imageProcessor;

  /// Padding the image to transform into square
  late int padSize;

  /// Shapes of output tensors
  late List<List<int>> _outputShapes;

  /// Types of output tensors
  late List<TfLiteType> _outputTypes;

  /// Number of results to show
  static const int numResults = 10;

  /// Whether [loadModel] has completed successfully at least once.
  bool _modelLoaded = false;

  /// Whether [loadLabels] has completed successfully at least once.
  bool _labelsLoaded = false;

  /// The processor this class built itself, and the pad size it was built
  /// for. Used to tell an auto-built processor from an injected one.
  ImageProcessor? _builtProcessor;
  int? _builtPadSize;

  late final Future<void> _ready;

  /// Creates a classifier and starts loading the model and labels.
  ///
  /// [interpreter] and [labels] replace the bundled assets when provided.
  ObjectClassifier({
    Interpreter? interpreter,
    List<String>? labels,
  }) {
    _ready = Future.wait<void>([
      loadModel(interpreter),
      loadLabels(labels),
    ]);
  }

  /// Completes once the loading started by the constructor has finished.
  ///
  /// Never completes with an error: load failures are logged, and [predict]
  /// then keeps returning `null`.
  Future<void> get ready => _ready;

  /// Loads the interpreter from [interpreter] or, when `null`, from assets.
  ///
  /// Errors are logged rather than thrown. State is only replaced after the
  /// interpreter and its output tensor metadata were read successfully.
  Future<void> loadModel(Interpreter? interpreter) async {
    try {
      final loaded = interpreter ??
          await Interpreter.fromAsset(
            "models/$modelFileName",
            options: InterpreterOptions()..threads = 4,
          );
      final shapes = <List<int>>[];
      final types = <TfLiteType>[];
      for (final tensor in loaded.getOutputTensors()) {
        shapes.add(tensor.shape);
        types.add(tensor.type);
      }
      // Commit everything together so a failure above cannot leave the
      // interpreter and the tensor metadata out of sync.
      _interpreter = loaded;
      _outputShapes = shapes;
      _outputTypes = types;
      _modelLoaded = true;
      _logger.info("Interpreter initialized");
    } catch (e, s) {
      _logger.severe("Error while creating interpreter", e, s);
    }
  }

  /// Loads the labels from [labels] or, when `null`, from assets.
  ///
  /// Errors are logged rather than thrown.
  Future<void> loadLabels(List<String>? labels) async {
    try {
      _labels =
          labels ?? await FileUtil.loadLabels("assets/models/$labelFileName");
      _labelsLoaded = true;
      _logger.info("Labels initialized");
    } catch (e, s) {
      _logger.severe("Error while loading labels", e, s);
    }
  }

  /// Pads [inputImage] to a square and resizes it to [inputSize].
  TensorImage _getProcessedImage(TensorImage inputImage) {
    padSize = max(inputImage.height, inputImage.width);
    final current = imageProcessor;
    // Only a processor built here is tied to a pad size; an injected one is
    // left alone.
    final builtHere = identical(current, _builtProcessor);
    if (current == null || (builtHere && _builtPadSize != padSize)) {
      final rebuilt = ImageProcessorBuilder()
          .add(ResizeWithCropOrPadOp(padSize, padSize))
          .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
          .build();
      _builtProcessor = rebuilt;
      _builtPadSize = padSize;
      imageProcessor = rebuilt;
      return rebuilt.process(inputImage);
    }
    return current.process(inputImage);
  }

  /// Runs object detection on [image].
  ///
  /// Returns the detections scoring above [threshold] together with timing
  /// stats, or `null` when the model or labels are not loaded or the model
  /// does not expose the four expected output tensors.
  Predictions? predict(imageLib.Image image) {
    if (!_modelLoaded || !_labelsLoaded) {
      _logger.warning("predict called before model and labels were loaded");
      return null;
    }
    if (_outputShapes.length < 4) {
      _logger.severe(
        "Expected 4 output tensors but found ${_outputShapes.length}",
      );
      return null;
    }

    final totalWatch = Stopwatch()..start();

    // Create TensorImage from image and pre-process it
    final inputImage = _getProcessedImage(TensorImage.fromImage(image));
    final preProcessElapsedTime = totalWatch.elapsedMilliseconds;

    // TensorBuffers for output tensors
    final outputLocations = TensorBufferFloat(_outputShapes[0]);
    final outputClasses = TensorBufferFloat(_outputShapes[1]);
    final outputScores = TensorBufferFloat(_outputShapes[2]);
    final numLocations = TensorBufferFloat(_outputShapes[3]);

    // Inputs object for runForMultipleInputs
    // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
    final inputs = [inputImage.buffer];

    // Outputs map
    final outputs = {
      0: outputLocations.buffer,
      1: outputClasses.buffer,
      2: outputScores.buffer,
      3: numLocations.buffer,
    };

    // run inference
    final inferenceWatch = Stopwatch()..start();
    _interpreter.runForMultipleInputs(inputs, outputs);
    final inferenceTimeElapsed = inferenceWatch.elapsedMilliseconds;

    // Maximum number of results to show
    final resultsCount = min(numResults, numLocations.getIntValue(0));

    // Using labelOffset = 1 as ??? at index 0
    const labelOffset = 1;

    final recognitions = <Recognition>[];
    for (var i = 0; i < resultsCount; i++) {
      final score = outputScores.getDoubleValue(i);
      if (!(score > threshold)) {
        continue;
      }
      // Resolve the label only for accepted detections, and never index
      // outside the label list.
      final labelIndex = outputClasses.getIntValue(i) + labelOffset;
      if (labelIndex < 0 || labelIndex >= _labels.length) {
        _logger.warning("Skipping detection with unknown label $labelIndex");
        continue;
      }
      recognitions.add(Recognition(i, _labels[labelIndex], score));
    }

    final predictElapsedTime = totalWatch.elapsedMilliseconds;
    _logger.info(recognitions);

    // NOTE(review): the total time is passed for both of the first two
    // arguments, as before; confirm against the Stats constructor.
    return Predictions(
      recognitions,
      Stats(
        predictElapsedTime,
        predictElapsedTime,
        inferenceTimeElapsed,
        preProcessElapsedTime,
      ),
    );
  }

  /// Gets the interpreter instance.
  ///
  /// Throws if read before [loadModel] has succeeded.
  Interpreter get interpreter => _interpreter;

  /// Gets the loaded labels.
  ///
  /// Throws if read before [loadLabels] has succeeded.
  List<String> get labels => _labels;
}
|