diff --git a/README.md b/README.md
index 8d272596c..2ee7a5ae3 100644
--- a/README.md
+++ b/README.md
@@ -65,8 +65,9 @@ You can alternatively install the build from PlayStore or F-Droid.
2. Clone this repository with `git clone git@github.com:ente-io/photos-app.git`
3. Pull in all submodules with `git submodule update --init --recursive`
4. Enable repo git hooks `git config core.hooksPath hooks`
-5. For Android, run `flutter build apk --release --flavor independent`
-6. For iOS, run `flutter build ios`
+5. Setup TensorFlowLite by executing `setup.sh`
+6. For Android, run `flutter build apk --release --flavor independent`
+7. For iOS, run `flutter build ios`
diff --git a/android/app/build.gradle b/android/app/build.gradle
index 8e8dee587..1387fc824 100644
--- a/android/app/build.gradle
+++ b/android/app/build.gradle
@@ -46,7 +46,7 @@ android {
defaultConfig {
applicationId "io.ente.photos"
- minSdkVersion 19
+ minSdkVersion 21
targetSdkVersion 33
versionCode flutterVersionCode.toInteger()
versionName flutterVersionName
diff --git a/android/app/src/main/jniLibs/arm64-v8a/libtensorflowlite_c.so b/android/app/src/main/jniLibs/arm64-v8a/libtensorflowlite_c.so
new file mode 100644
index 000000000..bf11c7ab4
Binary files /dev/null and b/android/app/src/main/jniLibs/arm64-v8a/libtensorflowlite_c.so differ
diff --git a/android/app/src/main/jniLibs/armeabi-v7a/libtensorflowlite_c.so b/android/app/src/main/jniLibs/armeabi-v7a/libtensorflowlite_c.so
new file mode 100644
index 000000000..3f687ab98
Binary files /dev/null and b/android/app/src/main/jniLibs/armeabi-v7a/libtensorflowlite_c.so differ
diff --git a/android/app/src/main/jniLibs/x86/libtensorflowlite_c.so b/android/app/src/main/jniLibs/x86/libtensorflowlite_c.so
new file mode 100644
index 000000000..1ac12a0f1
Binary files /dev/null and b/android/app/src/main/jniLibs/x86/libtensorflowlite_c.so differ
diff --git a/android/app/src/main/jniLibs/x86_64/libtensorflowlite_c.so b/android/app/src/main/jniLibs/x86_64/libtensorflowlite_c.so
new file mode 100644
index 000000000..9a007fddf
Binary files /dev/null and b/android/app/src/main/jniLibs/x86_64/libtensorflowlite_c.so differ
diff --git a/assets/models/detect.tflite b/assets/models/detect.tflite
new file mode 100644
index 000000000..8015ee5d8
Binary files /dev/null and b/assets/models/detect.tflite differ
diff --git a/assets/models/labelmap.txt b/assets/models/labelmap.txt
new file mode 100644
index 000000000..941cb4e13
--- /dev/null
+++ b/assets/models/labelmap.txt
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+dining table
+toilet
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/ios/Podfile.lock b/ios/Podfile.lock
index eefda7d78..f5bb7ea2f 100644
--- a/ios/Podfile.lock
+++ b/ios/Podfile.lock
@@ -1,6 +1,8 @@
PODS:
- background_fetch (1.1.4):
- Flutter
+ - camera_avfoundation (0.0.1):
+ - Flutter
- connectivity (0.0.1):
- Flutter
- Reachability
@@ -156,6 +158,10 @@ PODS:
- sqflite (0.0.2):
- Flutter
- FMDB (>= 2.7.5)
+ - tflite_flutter (0.1.0):
+ - Flutter
+ - tflite_flutter_helper (0.0.1):
+ - Flutter
- Toast (4.0.0)
- uni_links (0.0.1):
- Flutter
@@ -171,6 +177,7 @@ PODS:
DEPENDENCIES:
- background_fetch (from `.symlinks/plugins/background_fetch/ios`)
+ - camera_avfoundation (from `.symlinks/plugins/camera_avfoundation/ios`)
- connectivity (from `.symlinks/plugins/connectivity/ios`)
- device_info (from `.symlinks/plugins/device_info/ios`)
- firebase_core (from `.symlinks/plugins/firebase_core/ios`)
@@ -201,6 +208,8 @@ DEPENDENCIES:
- share_plus (from `.symlinks/plugins/share_plus/ios`)
- shared_preferences_foundation (from `.symlinks/plugins/shared_preferences_foundation/ios`)
- sqflite (from `.symlinks/plugins/sqflite/ios`)
+ - tflite_flutter (from `.symlinks/plugins/tflite_flutter/ios`)
+ - tflite_flutter_helper (from `.symlinks/plugins/tflite_flutter_helper/ios`)
- uni_links (from `.symlinks/plugins/uni_links/ios`)
- url_launcher_ios (from `.symlinks/plugins/url_launcher_ios/ios`)
- video_player (from `.symlinks/plugins/video_player/ios`)
@@ -231,6 +240,8 @@ SPEC REPOS:
EXTERNAL SOURCES:
background_fetch:
:path: ".symlinks/plugins/background_fetch/ios"
+ camera_avfoundation:
+ :path: ".symlinks/plugins/camera_avfoundation/ios"
connectivity:
:path: ".symlinks/plugins/connectivity/ios"
device_info:
@@ -291,6 +302,10 @@ EXTERNAL SOURCES:
:path: ".symlinks/plugins/shared_preferences_foundation/ios"
sqflite:
:path: ".symlinks/plugins/sqflite/ios"
+ tflite_flutter:
+ :path: ".symlinks/plugins/tflite_flutter/ios"
+ tflite_flutter_helper:
+ :path: ".symlinks/plugins/tflite_flutter_helper/ios"
uni_links:
:path: ".symlinks/plugins/uni_links/ios"
url_launcher_ios:
@@ -304,6 +319,7 @@ EXTERNAL SOURCES:
SPEC CHECKSUMS:
background_fetch: bd64e544b303ee4cd4cf2fe8cb2187b72aecf9ca
+ camera_avfoundation: 07c77549ea54ad95d8581be86617c094a46280d9
connectivity: c4130b2985d4ef6fd26f9702e886bd5260681467
device_info: d7d233b645a32c40dfdc212de5cf646ca482f175
Firebase: f92fc551ead69c94168d36c2b26188263860acd9
@@ -351,6 +367,8 @@ SPEC CHECKSUMS:
share_plus: 056a1e8ac890df3e33cb503afffaf1e9b4fbae68
shared_preferences_foundation: 297b3ebca31b34ec92be11acd7fb0ba932c822ca
sqflite: 6d358c025f5b867b29ed92fc697fd34924e11904
+ tflite_flutter: 9157a660578930a99728974f247369af1c3595d5
+ tflite_flutter_helper: 543b46b6bd064b21c92ea6e54bc0b29f1ce74cb5
Toast: 91b396c56ee72a5790816f40d3a94dd357abc196
uni_links: d97da20c7701486ba192624d99bffaaffcfc298a
url_launcher_ios: fb12c43172927bb5cf75aeebd073f883801f1993
diff --git a/ios/Runner.xcodeproj/project.pbxproj b/ios/Runner.xcodeproj/project.pbxproj
index bfd8d9494..9f41a9a3a 100644
--- a/ios/Runner.xcodeproj/project.pbxproj
+++ b/ios/Runner.xcodeproj/project.pbxproj
@@ -274,6 +274,7 @@
"${BUILT_PRODUCTS_DIR}/Sentry/Sentry.framework",
"${BUILT_PRODUCTS_DIR}/Toast/Toast.framework",
"${BUILT_PRODUCTS_DIR}/background_fetch/background_fetch.framework",
+ "${BUILT_PRODUCTS_DIR}/camera_avfoundation/camera_avfoundation.framework",
"${BUILT_PRODUCTS_DIR}/connectivity/connectivity.framework",
"${BUILT_PRODUCTS_DIR}/device_info/device_info.framework",
"${BUILT_PRODUCTS_DIR}/fk_user_agent/fk_user_agent.framework",
@@ -303,6 +304,8 @@
"${BUILT_PRODUCTS_DIR}/share_plus/share_plus.framework",
"${BUILT_PRODUCTS_DIR}/shared_preferences_foundation/shared_preferences_foundation.framework",
"${BUILT_PRODUCTS_DIR}/sqflite/sqflite.framework",
+ "${BUILT_PRODUCTS_DIR}/tflite_flutter/tflite_flutter.framework",
+ "${BUILT_PRODUCTS_DIR}/tflite_flutter_helper/tflite_flutter_helper.framework",
"${BUILT_PRODUCTS_DIR}/uni_links/uni_links.framework",
"${BUILT_PRODUCTS_DIR}/url_launcher_ios/url_launcher_ios.framework",
"${BUILT_PRODUCTS_DIR}/video_player/video_player.framework",
@@ -327,6 +330,7 @@
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/Sentry.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/Toast.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/background_fetch.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/camera_avfoundation.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/connectivity.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/device_info.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fk_user_agent.framework",
@@ -356,6 +360,8 @@
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/share_plus.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/shared_preferences_foundation.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/sqflite.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/tflite_flutter.framework",
+ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/tflite_flutter_helper.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/uni_links.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/url_launcher_ios.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/video_player.framework",
diff --git a/lib/main.dart b/lib/main.dart
index 4548871ab..e998a81d8 100644
--- a/lib/main.dart
+++ b/lib/main.dart
@@ -24,6 +24,7 @@ import 'package:photos/services/local_file_update_service.dart';
import 'package:photos/services/local_sync_service.dart';
import 'package:photos/services/memories_service.dart';
import 'package:photos/services/notification_service.dart';
+import "package:photos/services/object_detection/object_detection_service.dart";
import 'package:photos/services/push_service.dart';
import 'package:photos/services/remote_sync_service.dart';
import 'package:photos/services/search_service.dart';
@@ -160,6 +161,9 @@ Future _init(bool isBackground, {String via = ''}) async {
});
}
FeatureFlagService.instance.init();
+ if (FeatureFlagService.instance.isInternalUserOrDebugBuild()) {
+ await ObjectDetectionService.instance.init();
+ }
_logger.info("Initialization done");
}
diff --git a/lib/services/object_detection/models/predictions.dart b/lib/services/object_detection/models/predictions.dart
new file mode 100644
index 000000000..80c41a58c
--- /dev/null
+++ b/lib/services/object_detection/models/predictions.dart
@@ -0,0 +1,9 @@
+import "package:photos/services/object_detection/models/recognition.dart";
+import "package:photos/services/object_detection/models/stats.dart";
+
+class Predictions {
+ final List recognitions;
+ final Stats stats;
+
+ Predictions(this.recognitions, this.stats);
+}
diff --git a/lib/services/object_detection/models/recognition.dart b/lib/services/object_detection/models/recognition.dart
new file mode 100644
index 000000000..469326265
--- /dev/null
+++ b/lib/services/object_detection/models/recognition.dart
@@ -0,0 +1,18 @@
+/// Represents the recognition output from the model
+class Recognition {
+ /// Index of the result
+ int id;
+
+ /// Label of the result
+ String label;
+
+ /// Confidence [0.0, 1.0]
+ double score;
+
+ Recognition(this.id, this.label, this.score);
+
+ @override
+ String toString() {
+ return 'Recognition(id: $id, label: $label, score: $score)';
+ }
+}
diff --git a/lib/services/object_detection/models/stats.dart b/lib/services/object_detection/models/stats.dart
new file mode 100644
index 000000000..487f1f6a0
--- /dev/null
+++ b/lib/services/object_detection/models/stats.dart
@@ -0,0 +1,27 @@
+/// Bundles different elapsed times
+class Stats {
+ /// Total time taken in the isolate where the inference runs
+ int totalPredictTime;
+
+ /// [totalPredictTime] + communication overhead time
+ /// between main isolate and another isolate
+ int totalElapsedTime;
+
+ /// Time for which inference runs
+ int inferenceTime;
+
+ /// Time taken to pre-process the image
+ int preProcessingTime;
+
+ Stats(
+ this.totalPredictTime,
+ this.totalElapsedTime,
+ this.inferenceTime,
+ this.preProcessingTime,
+ );
+
+ @override
+ String toString() {
+ return 'Stats{totalPredictTime: $totalPredictTime, totalElapsedTime: $totalElapsedTime, inferenceTime: $inferenceTime, preProcessingTime: $preProcessingTime}';
+ }
+}
diff --git a/lib/services/object_detection/object_detection_service.dart b/lib/services/object_detection/object_detection_service.dart
new file mode 100644
index 000000000..47294be1e
--- /dev/null
+++ b/lib/services/object_detection/object_detection_service.dart
@@ -0,0 +1,57 @@
+import "dart:isolate";
+import "dart:typed_data";
+
+import "package:logging/logging.dart";
+import "package:photos/services/object_detection/models/predictions.dart";
+import 'package:photos/services/object_detection/models/recognition.dart';
+import "package:photos/services/object_detection/tflite/classifier.dart";
+import "package:photos/services/object_detection/utils/isolate_utils.dart";
+
+class ObjectDetectionService {
+ final _logger = Logger("ObjectDetectionService");
+
+ /// Instance of [ObjectClassifier]
+ late ObjectClassifier _classifier;
+
+ /// Instance of [IsolateUtils]
+ late IsolateUtils _isolateUtils;
+
+ ObjectDetectionService._privateConstructor();
+
+ Future init() async {
+ _isolateUtils = IsolateUtils();
+ await _isolateUtils.start();
+ _classifier = ObjectClassifier();
+ }
+
+ static ObjectDetectionService instance =
+ ObjectDetectionService._privateConstructor();
+
+ Future> predict(Uint8List bytes) async {
+ try {
+ final isolateData = IsolateData(
+ bytes,
+ _classifier.interpreter.address,
+ _classifier.labels,
+ );
+ final predictions = await _inference(isolateData);
+ final Set results = {};
+ for (final Recognition result in predictions.recognitions) {
+ results.add(result.label);
+ }
+ return results.toList();
+ } catch (e, s) {
+ _logger.severe(e, s);
+ rethrow;
+ }
+ }
+
+ /// Runs inference in another isolate
+ Future _inference(IsolateData isolateData) async {
+ final responsePort = ReceivePort();
+ _isolateUtils.sendPort.send(
+ isolateData..responsePort = responsePort.sendPort,
+ );
+ return await responsePort.first;
+ }
+}
diff --git a/lib/services/object_detection/tflite/classifier.dart b/lib/services/object_detection/tflite/classifier.dart
new file mode 100644
index 000000000..299b31244
--- /dev/null
+++ b/lib/services/object_detection/tflite/classifier.dart
@@ -0,0 +1,179 @@
+import 'dart:math';
+
+import 'package:image/image.dart' as imageLib;
+import "package:logging/logging.dart";
+import 'package:photos/services/object_detection/models/predictions.dart';
+import 'package:photos/services/object_detection/models/recognition.dart';
+import "package:photos/services/object_detection/models/stats.dart";
+import "package:tflite_flutter/tflite_flutter.dart";
+import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
+
+/// Classifier
+class ObjectClassifier {
+ final _logger = Logger("Classifier");
+
+ /// Instance of Interpreter
+ late Interpreter _interpreter;
+
+ /// Labels file loaded as list
+ late List _labels;
+
+ /// Input size of image (height = width = 300)
+ static const int inputSize = 300;
+
+ /// Result score threshold
+ static const double threshold = 0.5;
+
+ static const String modelFileName = "detect.tflite";
+ static const String labelFileName = "labelmap.txt";
+
+ /// [ImageProcessor] used to pre-process the image
+ ImageProcessor? imageProcessor;
+
+ /// Padding the image to transform into square
+ late int padSize;
+
+ /// Shapes of output tensors
+ late List> _outputShapes;
+
+ /// Types of output tensors
+ late List _outputTypes;
+
+ /// Number of results to show
+ static const int numResults = 10;
+
+ ObjectClassifier({
+ Interpreter? interpreter,
+ List? labels,
+ }) {
+ loadModel(interpreter);
+ loadLabels(labels);
+ }
+
+ /// Loads interpreter from asset
+ void loadModel(Interpreter? interpreter) async {
+ try {
+ _interpreter = interpreter ??
+ await Interpreter.fromAsset(
+ "models/" + modelFileName,
+ options: InterpreterOptions()..threads = 4,
+ );
+ final outputTensors = _interpreter.getOutputTensors();
+ _outputShapes = [];
+ _outputTypes = [];
+ outputTensors.forEach((tensor) {
+ _outputShapes.add(tensor.shape);
+ _outputTypes.add(tensor.type);
+ });
+ _logger.info("Interpreter initialized");
+ } catch (e, s) {
+ _logger.severe("Error while creating interpreter", e, s);
+ }
+ }
+
+ /// Loads labels from assets
+ void loadLabels(List? labels) async {
+ try {
+ _labels =
+ labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName);
+ _logger.info("Labels initialized");
+ } catch (e, s) {
+ _logger.severe("Error while loading labels", e, s);
+ }
+ }
+
+ /// Pre-process the image
+ TensorImage _getProcessedImage(TensorImage inputImage) {
+ padSize = max(inputImage.height, inputImage.width);
+ imageProcessor ??= ImageProcessorBuilder()
+ .add(ResizeWithCropOrPadOp(padSize, padSize))
+ .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
+ .build();
+ inputImage = imageProcessor!.process(inputImage);
+ return inputImage;
+ }
+
+ /// Runs object detection on the input image
+ Predictions? predict(imageLib.Image image) {
+ final predictStartTime = DateTime.now().millisecondsSinceEpoch;
+
+ final preProcessStart = DateTime.now().millisecondsSinceEpoch;
+
+ // Create TensorImage from image
+ TensorImage inputImage = TensorImage.fromImage(image);
+
+ // Pre-process TensorImage
+ inputImage = _getProcessedImage(inputImage);
+
+ final preProcessElapsedTime =
+ DateTime.now().millisecondsSinceEpoch - preProcessStart;
+
+ // TensorBuffers for output tensors
+ final outputLocations = TensorBufferFloat(_outputShapes[0]);
+ final outputClasses = TensorBufferFloat(_outputShapes[1]);
+ final outputScores = TensorBufferFloat(_outputShapes[2]);
+ final numLocations = TensorBufferFloat(_outputShapes[3]);
+
+ // Inputs object for runForMultipleInputs
+ // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
+ final inputs = [inputImage.buffer];
+
+ // Outputs map
+ final outputs = {
+ 0: outputLocations.buffer,
+ 1: outputClasses.buffer,
+ 2: outputScores.buffer,
+ 3: numLocations.buffer,
+ };
+
+ final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
+
+ // run inference
+ _interpreter.runForMultipleInputs(inputs, outputs);
+
+ final inferenceTimeElapsed =
+ DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
+
+ // Maximum number of results to show
+ final resultsCount = min(numResults, numLocations.getIntValue(0));
+
+ // Using labelOffset = 1 as ??? at index 0
+ const labelOffset = 1;
+
+ final recognitions = [];
+
+ for (int i = 0; i < resultsCount; i++) {
+ // Prediction score
+ final score = outputScores.getDoubleValue(i);
+
+ // Label string
+ final labelIndex = outputClasses.getIntValue(i) + labelOffset;
+ final label = _labels.elementAt(labelIndex);
+
+ if (score > threshold) {
+ recognitions.add(
+ Recognition(i, label, score),
+ );
+ }
+ }
+
+ final predictElapsedTime =
+ DateTime.now().millisecondsSinceEpoch - predictStartTime;
+ _logger.info(recognitions);
+ return Predictions(
+ recognitions,
+ Stats(
+ predictElapsedTime,
+ predictElapsedTime,
+ inferenceTimeElapsed,
+ preProcessElapsedTime,
+ ),
+ );
+ }
+
+ /// Gets the interpreter instance
+ Interpreter get interpreter => _interpreter;
+
+ /// Gets the loaded labels
+ List get labels => _labels;
+}
diff --git a/lib/services/object_detection/utils/isolate_utils.dart b/lib/services/object_detection/utils/isolate_utils.dart
new file mode 100644
index 000000000..2d55424d4
--- /dev/null
+++ b/lib/services/object_detection/utils/isolate_utils.dart
@@ -0,0 +1,55 @@
+import 'dart:isolate';
+import "dart:typed_data";
+
+import 'package:image/image.dart' as imgLib;
+import "package:photos/services/object_detection/tflite/classifier.dart";
+import 'package:tflite_flutter/tflite_flutter.dart';
+
+/// Manages separate Isolate instance for inference
+class IsolateUtils {
+ static const String debugName = "InferenceIsolate";
+
+ late SendPort _sendPort;
+ final _receivePort = ReceivePort();
+
+ SendPort get sendPort => _sendPort;
+
+ Future start() async {
+ await Isolate.spawn(
+ entryPoint,
+ _receivePort.sendPort,
+ debugName: debugName,
+ );
+
+ _sendPort = await _receivePort.first;
+ }
+
+ static void entryPoint(SendPort sendPort) async {
+ final port = ReceivePort();
+ sendPort.send(port.sendPort);
+
+ await for (final IsolateData isolateData in port) {
+ final classifier = ObjectClassifier(
+ interpreter: Interpreter.fromAddress(isolateData.interpreterAddress),
+ labels: isolateData.labels,
+ );
+ final image = imgLib.decodeImage(isolateData.input);
+ final results = classifier.predict(image!);
+ isolateData.responsePort.send(results);
+ }
+ }
+}
+
+/// Bundles data to pass between Isolate
+class IsolateData {
+ Uint8List input;
+ int interpreterAddress;
+ List labels;
+ late SendPort responsePort;
+
+ IsolateData(
+ this.input,
+ this.interpreterAddress,
+ this.labels,
+ );
+}
diff --git a/lib/ui/viewer/file/file_info_widget.dart b/lib/ui/viewer/file/file_info_widget.dart
index 2a6e7b8f3..e40c1c165 100644
--- a/lib/ui/viewer/file/file_info_widget.dart
+++ b/lib/ui/viewer/file/file_info_widget.dart
@@ -10,6 +10,7 @@ import "package:photos/ente_theme_data.dart";
import "package:photos/models/file.dart";
import "package:photos/models/file_type.dart";
import 'package:photos/services/collections_service.dart';
+import "package:photos/services/feature_flag_service.dart";
import 'package:photos/theme/ente_theme.dart';
import 'package:photos/ui/components/divider_widget.dart';
import 'package:photos/ui/components/icon_button_widget.dart';
@@ -17,6 +18,7 @@ import 'package:photos/ui/components/title_bar_widget.dart';
import 'package:photos/ui/viewer/file/collections_list_of_file_widget.dart';
import 'package:photos/ui/viewer/file/device_folders_list_of_file_widget.dart';
import 'package:photos/ui/viewer/file/file_caption_widget.dart';
+import "package:photos/ui/viewer/file/object_tags_widget.dart";
import 'package:photos/ui/viewer/file/raw_exif_list_tile_widget.dart';
import "package:photos/utils/date_time_util.dart";
import "package:photos/utils/exif_util.dart";
@@ -235,6 +237,16 @@ class _FileInfoWidgetState extends State {
: DeviceFoldersListOfFileWidget(allDeviceFoldersOfFile),
),
),
+ FeatureFlagService.instance.isInternalUserOrDebugBuild()
+ ? SizedBox(
+ height: 62,
+ child: ListTile(
+ horizontalTitleGap: 0,
+ leading: const Icon(Icons.image_search),
+ title: ObjectTagsWidget(file),
+ ),
+ )
+ : null,
(file.uploadedFileID != null && file.updationTime != null)
? ListTile(
horizontalTitleGap: 2,
diff --git a/lib/ui/viewer/file/object_tags_widget.dart b/lib/ui/viewer/file/object_tags_widget.dart
new file mode 100644
index 000000000..4a24b4ac1
--- /dev/null
+++ b/lib/ui/viewer/file/object_tags_widget.dart
@@ -0,0 +1,77 @@
+import "package:flutter/material.dart";
+import "package:logging/logging.dart";
+import "package:photos/ente_theme_data.dart";
+import "package:photos/models/file.dart";
+import "package:photos/services/object_detection/object_detection_service.dart";
+import "package:photos/ui/common/loading_widget.dart";
+import "package:photos/utils/thumbnail_util.dart";
+
+class ObjectTagsWidget extends StatelessWidget {
+ final File file;
+
+ const ObjectTagsWidget(this.file, {Key? key}) : super(key: key);
+
+ @override
+ Widget build(BuildContext context) {
+ return FutureBuilder>(
+ future: getThumbnail(file).then((data) {
+ return ObjectDetectionService.instance.predict(data!);
+ }),
+ builder: (context, snapshot) {
+ if (snapshot.hasData) {
+ final List tags = snapshot.data!;
+ if (tags.isEmpty) {
+ return const ObjectTagWidget("No Results");
+ }
+ return ListView.builder(
+ itemCount: tags.length,
+ scrollDirection: Axis.horizontal,
+ itemBuilder: (context, index) {
+ return ObjectTagWidget(tags[index]);
+ },
+ );
+ } else if (snapshot.hasError) {
+ Logger("ObjectTagsWidget").severe(snapshot.error);
+ return const Icon(Icons.error);
+ } else {
+ return const EnteLoadingWidget();
+ }
+ },
+ );
+ }
+}
+
+class ObjectTagWidget extends StatelessWidget {
+ final String name;
+ const ObjectTagWidget(this.name, {Key? key}) : super(key: key);
+
+ @override
+ Widget build(BuildContext context) {
+ return Container(
+ margin: const EdgeInsets.only(
+ top: 10,
+ bottom: 18,
+ right: 8,
+ ),
+ decoration: BoxDecoration(
+ color: Theme.of(context)
+ .colorScheme
+ .inverseBackgroundColor
+ .withOpacity(0.025),
+ borderRadius: const BorderRadius.all(
+ Radius.circular(8),
+ ),
+ ),
+ child: Center(
+ child: Padding(
+ padding: const EdgeInsets.symmetric(horizontal: 8),
+ child: Text(
+ name!,
+ style: Theme.of(context).textTheme.subtitle2,
+ overflow: TextOverflow.ellipsis,
+ ),
+ ),
+ ),
+ );
+ }
+}
diff --git a/lib/utils/thumbnail_util.dart b/lib/utils/thumbnail_util.dart
index 916c9ff0b..7c699af3e 100644
--- a/lib/utils/thumbnail_util.dart
+++ b/lib/utils/thumbnail_util.dart
@@ -32,6 +32,17 @@ class FileDownloadItem {
FileDownloadItem(this.file, this.completer, this.cancelToken, this.counter);
}
+Future getThumbnail(File file) async {
+ if (file.isRemoteFile) {
+ return getThumbnailFromServer(file);
+ } else {
+ return getThumbnailFromLocal(
+ file,
+ size: thumbnailLargeSize,
+ );
+ }
+}
+
Future getThumbnailFromServer(File file) async {
final cachedThumbnail = cachedThumbnailPath(file);
if (await cachedThumbnail.exists()) {
diff --git a/pubspec.lock b/pubspec.lock
index 2703ad68e..79660f993 100644
--- a/pubspec.lock
+++ b/pubspec.lock
@@ -121,6 +121,46 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.0.2"
+ camera:
+ dependency: transitive
+ description:
+ name: camera
+ sha256: "3ad71371b8168a4c8012c0b40a53c05afc75d46cc688b0f37b4611a841d47b25"
+ url: "https://pub.dev"
+ source: hosted
+ version: "0.9.8+1"
+ camera_android:
+ dependency: transitive
+ description:
+ name: camera_android
+ sha256: "665d62c1f334722c7519ca5d3b94ad68ecaa801691870602da5638a42c1fff67"
+ url: "https://pub.dev"
+ source: hosted
+ version: "0.9.8+3"
+ camera_avfoundation:
+ dependency: transitive
+ description:
+ name: camera_avfoundation
+ sha256: "6a68c20593d4cd58974d555f74a48b244f9db28cc9156de57781122d11b8754b"
+ url: "https://pub.dev"
+ source: hosted
+ version: "0.9.11"
+ camera_platform_interface:
+ dependency: transitive
+ description:
+ name: camera_platform_interface
+ sha256: b632be28e61d00a233f67d98ea90fd7041956f27a1c65500188ee459be60e15f
+ url: "https://pub.dev"
+ source: hosted
+ version: "2.4.0"
+ camera_web:
+ dependency: transitive
+ description:
+ name: camera_web
+ sha256: "18cdbee5441e9a6fb129fdd9b68a06d1b8c5236932ba97d5faeaefe80db2e5bd"
+ url: "https://pub.dev"
+ source: hosted
+ version: "0.2.1+6"
characters:
dependency: transitive
description:
@@ -1498,6 +1538,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "2.1.1"
+ stream_transform:
+ dependency: transitive
+ description:
+ name: stream_transform
+ sha256: "14a00e794c7c11aa145a170587321aedce29769c08d7f58b1d141da75e3b1c6f"
+ url: "https://pub.dev"
+ source: hosted
+ version: "2.1.0"
string_scanner:
dependency: transitive
description:
@@ -1562,6 +1610,23 @@ packages:
url: "https://pub.dev"
source: hosted
version: "0.4.20"
+ tflite_flutter:
+ dependency: "direct main"
+ description:
+ name: tflite_flutter
+ sha256: "663483abf86066cdf9eae29df65e5e39ea5823d3f426507a81840d3b84ce58cd"
+ url: "https://pub.dev"
+ source: hosted
+ version: "0.9.0"
+ tflite_flutter_helper:
+ dependency: "direct main"
+ description:
+ path: "."
+ ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a
+ resolved-ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a
+ url: "https://github.com/elephantum/tflite_flutter_helper.git"
+ source: git
+ version: "0.3.0"
timezone:
dependency: transitive
description:
diff --git a/pubspec.yaml b/pubspec.yaml
index 75c2fe194..a94c086f3 100644
--- a/pubspec.yaml
+++ b/pubspec.yaml
@@ -106,6 +106,11 @@ dependencies:
step_progress_indicator: ^1.0.2
syncfusion_flutter_core: ^19.2.49
syncfusion_flutter_sliders: ^19.2.49
+ tflite_flutter: ^0.9.0
+ tflite_flutter_helper:
+ git:
+ url: https://github.com/elephantum/tflite_flutter_helper.git # Fixes https://github.com/am15h/tflite_flutter_helper/issues/57
+ ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a
tuple: ^2.0.0
uni_links: ^0.5.1
url_launcher: ^6.0.3
@@ -150,6 +155,7 @@ flutter_native_splash:
flutter:
assets:
- assets/
+ - assets/models/
fonts:
- family: Inter
fonts:
diff --git a/setup.sh b/setup.sh
new file mode 100755
index 000000000..9b23d2705
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,64 @@
+#!/usr/bin/env bash
+
+# Originally from https://github.com/am15h/tflite_flutter_plugin/blob/master/install.sh
+
+cd "$(dirname "$(readlink -f "$0")")"
+
+# Pull from the latest tag where binaries were built
+TAG="v0.5.0"
+
+IOS_URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/"
+ANDROID_URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/"
+
+IOS_ASSET="TensorFlowLiteC.framework.zip"
+IOS_FRAMEWORK="TensorFlowLiteC.framework"
+IOS_DIR="ios/.symlinks/plugins/tflite_flutter/ios/"
+MACOSX_METADATA_DIR="__MACOSX"
+
+ANDROID_DIR="android/app/src/main/jniLibs/"
+ANDROID_LIB="libtensorflowlite_c.so"
+
+ARM_DELEGATE="libtensorflowlite_c_arm_delegate.so"
+ARM_64_DELEGATE="libtensorflowlite_c_arm64_delegate.so"
+ARM="libtensorflowlite_c_arm.so"
+ARM_64="libtensorflowlite_c_arm64.so"
+X86="libtensorflowlite_c_x86.so"
+X86_64="libtensorflowlite_c_x86_64.so"
+
+delegate=0
+
+while getopts "d" OPTION
+do
+ case $OPTION in
+ d) delegate=1;;
+ esac
+done
+
+wget "${IOS_URL}${TAG}/${IOS_ASSET}"
+unzip ${IOS_ASSET}
+rm -rf ${MACOSX_METADATA_DIR}
+rm ${IOS_ASSET}
+rm -rf "${IOS_DIR}/${IOS_FRAMEWORK}"
+mv ${IOS_FRAMEWORK} ${IOS_DIR}
+
+download () {
+ wget "${ANDROID_URL}${TAG}/$1"
+ mkdir -p "${ANDROID_DIR}$2/"
+ mv $1 "${ANDROID_DIR}$2/${ANDROID_LIB}"
+}
+
+if [ ${delegate} -eq 1 ]
+then
+
+download ${ARM_DELEGATE} "armeabi-v7a"
+download ${ARM_64_DELEGATE} "arm64-v8a"
+
+else
+
+download ${ARM} "armeabi-v7a"
+download ${ARM_64} "arm64-v8a"
+
+fi
+
+download ${X86} "x86"
+download ${X86_64} "x86_64"