diff --git a/README.md b/README.md index 8d272596c..2ee7a5ae3 100644 --- a/README.md +++ b/README.md @@ -65,8 +65,9 @@ You can alternatively install the build from PlayStore or F-Droid. 2. Clone this repository with `git clone git@github.com:ente-io/photos-app.git` 3. Pull in all submodules with `git submodule update --init --recursive` 4. Enable repo git hooks `git config core.hooksPath hooks` -5. For Android, run `flutter build apk --release --flavor independent` -6. For iOS, run `flutter build ios` +5. Setup TensorFlowLite by executing `setup.sh` +6. For Android, run `flutter build apk --release --flavor independent` +7. For iOS, run `flutter build ios`
diff --git a/android/app/build.gradle b/android/app/build.gradle index 8e8dee587..1387fc824 100644 --- a/android/app/build.gradle +++ b/android/app/build.gradle @@ -46,7 +46,7 @@ android { defaultConfig { applicationId "io.ente.photos" - minSdkVersion 19 + minSdkVersion 21 targetSdkVersion 33 versionCode flutterVersionCode.toInteger() versionName flutterVersionName diff --git a/android/app/src/main/jniLibs/arm64-v8a/libtensorflowlite_c.so b/android/app/src/main/jniLibs/arm64-v8a/libtensorflowlite_c.so new file mode 100644 index 000000000..bf11c7ab4 Binary files /dev/null and b/android/app/src/main/jniLibs/arm64-v8a/libtensorflowlite_c.so differ diff --git a/android/app/src/main/jniLibs/armeabi-v7a/libtensorflowlite_c.so b/android/app/src/main/jniLibs/armeabi-v7a/libtensorflowlite_c.so new file mode 100644 index 000000000..3f687ab98 Binary files /dev/null and b/android/app/src/main/jniLibs/armeabi-v7a/libtensorflowlite_c.so differ diff --git a/android/app/src/main/jniLibs/x86/libtensorflowlite_c.so b/android/app/src/main/jniLibs/x86/libtensorflowlite_c.so new file mode 100644 index 000000000..1ac12a0f1 Binary files /dev/null and b/android/app/src/main/jniLibs/x86/libtensorflowlite_c.so differ diff --git a/android/app/src/main/jniLibs/x86_64/libtensorflowlite_c.so b/android/app/src/main/jniLibs/x86_64/libtensorflowlite_c.so new file mode 100644 index 000000000..9a007fddf Binary files /dev/null and b/android/app/src/main/jniLibs/x86_64/libtensorflowlite_c.so differ diff --git a/assets/models/detect.tflite b/assets/models/detect.tflite new file mode 100644 index 000000000..8015ee5d8 Binary files /dev/null and b/assets/models/detect.tflite differ diff --git a/assets/models/labelmap.txt b/assets/models/labelmap.txt new file mode 100644 index 000000000..941cb4e13 --- /dev/null +++ b/assets/models/labelmap.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +dining table +toilet +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/ios/Podfile.lock b/ios/Podfile.lock index eefda7d78..f5bb7ea2f 100644 --- a/ios/Podfile.lock +++ b/ios/Podfile.lock @@ -1,6 +1,8 @@ PODS: - background_fetch (1.1.4): - Flutter + - camera_avfoundation (0.0.1): + - Flutter - connectivity (0.0.1): - Flutter - Reachability @@ -156,6 +158,10 @@ PODS: - sqflite (0.0.2): - Flutter - FMDB (>= 2.7.5) + - tflite_flutter (0.1.0): + - Flutter + - tflite_flutter_helper (0.0.1): + - Flutter - Toast (4.0.0) - uni_links (0.0.1): - Flutter @@ -171,6 +177,7 @@ PODS: DEPENDENCIES: - background_fetch (from `.symlinks/plugins/background_fetch/ios`) + - camera_avfoundation (from `.symlinks/plugins/camera_avfoundation/ios`) - connectivity (from `.symlinks/plugins/connectivity/ios`) - device_info (from `.symlinks/plugins/device_info/ios`) - firebase_core (from `.symlinks/plugins/firebase_core/ios`) @@ -201,6 +208,8 @@ DEPENDENCIES: - share_plus (from `.symlinks/plugins/share_plus/ios`) - shared_preferences_foundation (from `.symlinks/plugins/shared_preferences_foundation/ios`) - sqflite (from `.symlinks/plugins/sqflite/ios`) + - tflite_flutter (from `.symlinks/plugins/tflite_flutter/ios`) + - tflite_flutter_helper (from `.symlinks/plugins/tflite_flutter_helper/ios`) - uni_links (from `.symlinks/plugins/uni_links/ios`) - url_launcher_ios (from `.symlinks/plugins/url_launcher_ios/ios`) - video_player (from `.symlinks/plugins/video_player/ios`) @@ -231,6 +240,8 @@ SPEC REPOS: EXTERNAL SOURCES: background_fetch: :path: ".symlinks/plugins/background_fetch/ios" + camera_avfoundation: + :path: ".symlinks/plugins/camera_avfoundation/ios" connectivity: :path: ".symlinks/plugins/connectivity/ios" device_info: @@ -291,6 +302,10 @@ EXTERNAL SOURCES: :path: ".symlinks/plugins/shared_preferences_foundation/ios" sqflite: :path: ".symlinks/plugins/sqflite/ios" + tflite_flutter: + :path: ".symlinks/plugins/tflite_flutter/ios" + tflite_flutter_helper: + :path: ".symlinks/plugins/tflite_flutter_helper/ios" uni_links: :path: ".symlinks/plugins/uni_links/ios" url_launcher_ios: @@ -304,6 +319,7 @@ EXTERNAL SOURCES: SPEC CHECKSUMS: background_fetch: bd64e544b303ee4cd4cf2fe8cb2187b72aecf9ca + camera_avfoundation: 07c77549ea54ad95d8581be86617c094a46280d9 connectivity: c4130b2985d4ef6fd26f9702e886bd5260681467 device_info: d7d233b645a32c40dfdc212de5cf646ca482f175 Firebase: f92fc551ead69c94168d36c2b26188263860acd9 @@ -351,6 +367,8 @@ SPEC CHECKSUMS: share_plus: 056a1e8ac890df3e33cb503afffaf1e9b4fbae68 shared_preferences_foundation: 297b3ebca31b34ec92be11acd7fb0ba932c822ca sqflite: 6d358c025f5b867b29ed92fc697fd34924e11904 + tflite_flutter: 9157a660578930a99728974f247369af1c3595d5 + tflite_flutter_helper: 543b46b6bd064b21c92ea6e54bc0b29f1ce74cb5 Toast: 91b396c56ee72a5790816f40d3a94dd357abc196 uni_links: d97da20c7701486ba192624d99bffaaffcfc298a url_launcher_ios: fb12c43172927bb5cf75aeebd073f883801f1993 diff --git a/ios/Runner.xcodeproj/project.pbxproj b/ios/Runner.xcodeproj/project.pbxproj index bfd8d9494..9f41a9a3a 100644 --- a/ios/Runner.xcodeproj/project.pbxproj +++ b/ios/Runner.xcodeproj/project.pbxproj @@ -274,6 +274,7 @@ "${BUILT_PRODUCTS_DIR}/Sentry/Sentry.framework", "${BUILT_PRODUCTS_DIR}/Toast/Toast.framework", "${BUILT_PRODUCTS_DIR}/background_fetch/background_fetch.framework", + "${BUILT_PRODUCTS_DIR}/camera_avfoundation/camera_avfoundation.framework", "${BUILT_PRODUCTS_DIR}/connectivity/connectivity.framework", "${BUILT_PRODUCTS_DIR}/device_info/device_info.framework", "${BUILT_PRODUCTS_DIR}/fk_user_agent/fk_user_agent.framework", @@ -303,6 +304,8 @@ "${BUILT_PRODUCTS_DIR}/share_plus/share_plus.framework", "${BUILT_PRODUCTS_DIR}/shared_preferences_foundation/shared_preferences_foundation.framework", "${BUILT_PRODUCTS_DIR}/sqflite/sqflite.framework", + "${BUILT_PRODUCTS_DIR}/tflite_flutter/tflite_flutter.framework", + "${BUILT_PRODUCTS_DIR}/tflite_flutter_helper/tflite_flutter_helper.framework", "${BUILT_PRODUCTS_DIR}/uni_links/uni_links.framework", "${BUILT_PRODUCTS_DIR}/url_launcher_ios/url_launcher_ios.framework", "${BUILT_PRODUCTS_DIR}/video_player/video_player.framework", @@ -327,6 +330,7 @@ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/Sentry.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/Toast.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/background_fetch.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/camera_avfoundation.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/connectivity.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/device_info.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fk_user_agent.framework", @@ -356,6 +360,8 @@ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/share_plus.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/shared_preferences_foundation.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/sqflite.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/tflite_flutter.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/tflite_flutter_helper.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/uni_links.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/url_launcher_ios.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/video_player.framework", diff --git a/lib/main.dart b/lib/main.dart index 4548871ab..e998a81d8 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -24,6 +24,7 @@ import 'package:photos/services/local_file_update_service.dart'; import 'package:photos/services/local_sync_service.dart'; import 'package:photos/services/memories_service.dart'; import 'package:photos/services/notification_service.dart'; +import "package:photos/services/object_detection/object_detection_service.dart"; import 'package:photos/services/push_service.dart'; import 'package:photos/services/remote_sync_service.dart'; import 'package:photos/services/search_service.dart'; @@ -160,6 +161,9 @@ Future _init(bool isBackground, {String via = ''}) async { }); } FeatureFlagService.instance.init(); + if (FeatureFlagService.instance.isInternalUserOrDebugBuild()) { + await ObjectDetectionService.instance.init(); + } _logger.info("Initialization done"); } diff --git a/lib/services/object_detection/models/predictions.dart b/lib/services/object_detection/models/predictions.dart new file mode 100644 index 000000000..80c41a58c --- /dev/null +++ b/lib/services/object_detection/models/predictions.dart @@ -0,0 +1,9 @@ +import "package:photos/services/object_detection/models/recognition.dart"; +import "package:photos/services/object_detection/models/stats.dart"; + +class Predictions { + final List recognitions; + final Stats stats; + + Predictions(this.recognitions, this.stats); +} diff --git a/lib/services/object_detection/models/recognition.dart b/lib/services/object_detection/models/recognition.dart new file mode 100644 index 000000000..469326265 --- /dev/null +++ b/lib/services/object_detection/models/recognition.dart @@ -0,0 +1,18 @@ +/// Represents the recognition output from the model +class Recognition { + /// Index of the result + int id; + + /// Label of the result + String label; + + /// Confidence [0.0, 1.0] + double score; + + Recognition(this.id, this.label, this.score); + + @override + String toString() { + return 'Recognition(id: $id, label: $label, score: $score)'; + } +} diff --git a/lib/services/object_detection/models/stats.dart b/lib/services/object_detection/models/stats.dart new file mode 100644 index 000000000..487f1f6a0 --- /dev/null +++ b/lib/services/object_detection/models/stats.dart @@ -0,0 +1,27 @@ +/// Bundles different elapsed times +class Stats { + /// Total time taken in the isolate where the inference runs + int totalPredictTime; + + /// [totalPredictTime] + communication overhead time + /// between main isolate and another isolate + int totalElapsedTime; + + /// Time for which inference runs + int inferenceTime; + + /// Time taken to pre-process the image + int preProcessingTime; + + Stats( + this.totalPredictTime, + this.totalElapsedTime, + this.inferenceTime, + this.preProcessingTime, + ); + + @override + String toString() { + return 'Stats{totalPredictTime: $totalPredictTime, totalElapsedTime: $totalElapsedTime, inferenceTime: $inferenceTime, preProcessingTime: $preProcessingTime}'; + } +} diff --git a/lib/services/object_detection/object_detection_service.dart b/lib/services/object_detection/object_detection_service.dart new file mode 100644 index 000000000..47294be1e --- /dev/null +++ b/lib/services/object_detection/object_detection_service.dart @@ -0,0 +1,57 @@ +import "dart:isolate"; +import "dart:typed_data"; + +import "package:logging/logging.dart"; +import "package:photos/services/object_detection/models/predictions.dart"; +import 'package:photos/services/object_detection/models/recognition.dart'; +import "package:photos/services/object_detection/tflite/classifier.dart"; +import "package:photos/services/object_detection/utils/isolate_utils.dart"; + +class ObjectDetectionService { + final _logger = Logger("ObjectDetectionService"); + + /// Instance of [ObjectClassifier] + late ObjectClassifier _classifier; + + /// Instance of [IsolateUtils] + late IsolateUtils _isolateUtils; + + ObjectDetectionService._privateConstructor(); + + Future init() async { + _isolateUtils = IsolateUtils(); + await _isolateUtils.start(); + _classifier = ObjectClassifier(); + } + + static ObjectDetectionService instance = + ObjectDetectionService._privateConstructor(); + + Future> predict(Uint8List bytes) async { + try { + final isolateData = IsolateData( + bytes, + _classifier.interpreter.address, + _classifier.labels, + ); + final predictions = await _inference(isolateData); + final Set results = {}; + for (final Recognition result in predictions.recognitions) { + results.add(result.label); + } + return results.toList(); + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + /// Runs inference in another isolate + Future _inference(IsolateData isolateData) async { + final responsePort = ReceivePort(); + _isolateUtils.sendPort.send( + isolateData..responsePort = responsePort.sendPort, + ); + return await responsePort.first; + } +} diff --git a/lib/services/object_detection/tflite/classifier.dart b/lib/services/object_detection/tflite/classifier.dart new file mode 100644 index 000000000..299b31244 --- /dev/null +++ b/lib/services/object_detection/tflite/classifier.dart @@ -0,0 +1,179 @@ +import 'dart:math'; + +import 'package:image/image.dart' as imageLib; +import "package:logging/logging.dart"; +import 'package:photos/services/object_detection/models/predictions.dart'; +import 'package:photos/services/object_detection/models/recognition.dart'; +import "package:photos/services/object_detection/models/stats.dart"; +import "package:tflite_flutter/tflite_flutter.dart"; +import "package:tflite_flutter_helper/tflite_flutter_helper.dart"; + +/// Classifier +class ObjectClassifier { + final _logger = Logger("Classifier"); + + /// Instance of Interpreter + late Interpreter _interpreter; + + /// Labels file loaded as list + late List _labels; + + /// Input size of image (height = width = 300) + static const int inputSize = 300; + + /// Result score threshold + static const double threshold = 0.5; + + static const String modelFileName = "detect.tflite"; + static const String labelFileName = "labelmap.txt"; + + /// [ImageProcessor] used to pre-process the image + ImageProcessor? imageProcessor; + + /// Padding the image to transform into square + late int padSize; + + /// Shapes of output tensors + late List> _outputShapes; + + /// Types of output tensors + late List _outputTypes; + + /// Number of results to show + static const int numResults = 10; + + ObjectClassifier({ + Interpreter? interpreter, + List? labels, + }) { + loadModel(interpreter); + loadLabels(labels); + } + + /// Loads interpreter from asset + void loadModel(Interpreter? interpreter) async { + try { + _interpreter = interpreter ?? + await Interpreter.fromAsset( + "models/" + modelFileName, + options: InterpreterOptions()..threads = 4, + ); + final outputTensors = _interpreter.getOutputTensors(); + _outputShapes = []; + _outputTypes = []; + outputTensors.forEach((tensor) { + _outputShapes.add(tensor.shape); + _outputTypes.add(tensor.type); + }); + _logger.info("Interpreter initialized"); + } catch (e, s) { + _logger.severe("Error while creating interpreter", e, s); + } + } + + /// Loads labels from assets + void loadLabels(List? labels) async { + try { + _labels = + labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName); + _logger.info("Labels initialized"); + } catch (e, s) { + _logger.severe("Error while loading labels", e, s); + } + } + + /// Pre-process the image + TensorImage _getProcessedImage(TensorImage inputImage) { + padSize = max(inputImage.height, inputImage.width); + imageProcessor ??= ImageProcessorBuilder() + .add(ResizeWithCropOrPadOp(padSize, padSize)) + .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR)) + .build(); + inputImage = imageProcessor!.process(inputImage); + return inputImage; + } + + /// Runs object detection on the input image + Predictions? predict(imageLib.Image image) { + final predictStartTime = DateTime.now().millisecondsSinceEpoch; + + final preProcessStart = DateTime.now().millisecondsSinceEpoch; + + // Create TensorImage from image + TensorImage inputImage = TensorImage.fromImage(image); + + // Pre-process TensorImage + inputImage = _getProcessedImage(inputImage); + + final preProcessElapsedTime = + DateTime.now().millisecondsSinceEpoch - preProcessStart; + + // TensorBuffers for output tensors + final outputLocations = TensorBufferFloat(_outputShapes[0]); + final outputClasses = TensorBufferFloat(_outputShapes[1]); + final outputScores = TensorBufferFloat(_outputShapes[2]); + final numLocations = TensorBufferFloat(_outputShapes[3]); + + // Inputs object for runForMultipleInputs + // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference + final inputs = [inputImage.buffer]; + + // Outputs map + final outputs = { + 0: outputLocations.buffer, + 1: outputClasses.buffer, + 2: outputScores.buffer, + 3: numLocations.buffer, + }; + + final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch; + + // run inference + _interpreter.runForMultipleInputs(inputs, outputs); + + final inferenceTimeElapsed = + DateTime.now().millisecondsSinceEpoch - inferenceTimeStart; + + // Maximum number of results to show + final resultsCount = min(numResults, numLocations.getIntValue(0)); + + // Using labelOffset = 1 as ??? at index 0 + const labelOffset = 1; + + final recognitions = []; + + for (int i = 0; i < resultsCount; i++) { + // Prediction score + final score = outputScores.getDoubleValue(i); + + // Label string + final labelIndex = outputClasses.getIntValue(i) + labelOffset; + final label = _labels.elementAt(labelIndex); + + if (score > threshold) { + recognitions.add( + Recognition(i, label, score), + ); + } + } + + final predictElapsedTime = + DateTime.now().millisecondsSinceEpoch - predictStartTime; + _logger.info(recognitions); + return Predictions( + recognitions, + Stats( + predictElapsedTime, + predictElapsedTime, + inferenceTimeElapsed, + preProcessElapsedTime, + ), + ); + } + + /// Gets the interpreter instance + Interpreter get interpreter => _interpreter; + + /// Gets the loaded labels + List get labels => _labels; +} diff --git a/lib/services/object_detection/utils/isolate_utils.dart b/lib/services/object_detection/utils/isolate_utils.dart new file mode 100644 index 000000000..2d55424d4 --- /dev/null +++ b/lib/services/object_detection/utils/isolate_utils.dart @@ -0,0 +1,55 @@ +import 'dart:isolate'; +import "dart:typed_data"; + +import 'package:image/image.dart' as imgLib; +import "package:photos/services/object_detection/tflite/classifier.dart"; +import 'package:tflite_flutter/tflite_flutter.dart'; + +/// Manages separate Isolate instance for inference +class IsolateUtils { + static const String debugName = "InferenceIsolate"; + + late SendPort _sendPort; + final _receivePort = ReceivePort(); + + SendPort get sendPort => _sendPort; + + Future start() async { + await Isolate.spawn( + entryPoint, + _receivePort.sendPort, + debugName: debugName, + ); + + _sendPort = await _receivePort.first; + } + + static void entryPoint(SendPort sendPort) async { + final port = ReceivePort(); + sendPort.send(port.sendPort); + + await for (final IsolateData isolateData in port) { + final classifier = ObjectClassifier( + interpreter: Interpreter.fromAddress(isolateData.interpreterAddress), + labels: isolateData.labels, + ); + final image = imgLib.decodeImage(isolateData.input); + final results = classifier.predict(image!); + isolateData.responsePort.send(results); + } + } +} + +/// Bundles data to pass between Isolate +class IsolateData { + Uint8List input; + int interpreterAddress; + List labels; + late SendPort responsePort; + + IsolateData( + this.input, + this.interpreterAddress, + this.labels, + ); +} diff --git a/lib/ui/viewer/file/file_info_widget.dart b/lib/ui/viewer/file/file_info_widget.dart index 2a6e7b8f3..e40c1c165 100644 --- a/lib/ui/viewer/file/file_info_widget.dart +++ b/lib/ui/viewer/file/file_info_widget.dart @@ -10,6 +10,7 @@ import "package:photos/ente_theme_data.dart"; import "package:photos/models/file.dart"; import "package:photos/models/file_type.dart"; import 'package:photos/services/collections_service.dart'; +import "package:photos/services/feature_flag_service.dart"; import 'package:photos/theme/ente_theme.dart'; import 'package:photos/ui/components/divider_widget.dart'; import 'package:photos/ui/components/icon_button_widget.dart'; @@ -17,6 +18,7 @@ import 'package:photos/ui/components/title_bar_widget.dart'; import 'package:photos/ui/viewer/file/collections_list_of_file_widget.dart'; import 'package:photos/ui/viewer/file/device_folders_list_of_file_widget.dart'; import 'package:photos/ui/viewer/file/file_caption_widget.dart'; +import "package:photos/ui/viewer/file/object_tags_widget.dart"; import 'package:photos/ui/viewer/file/raw_exif_list_tile_widget.dart'; import "package:photos/utils/date_time_util.dart"; import "package:photos/utils/exif_util.dart"; @@ -235,6 +237,16 @@ class _FileInfoWidgetState extends State { : DeviceFoldersListOfFileWidget(allDeviceFoldersOfFile), ), ), + FeatureFlagService.instance.isInternalUserOrDebugBuild() + ? SizedBox( + height: 62, + child: ListTile( + horizontalTitleGap: 0, + leading: const Icon(Icons.image_search), + title: ObjectTagsWidget(file), + ), + ) + : null, (file.uploadedFileID != null && file.updationTime != null) ? ListTile( horizontalTitleGap: 2, diff --git a/lib/ui/viewer/file/object_tags_widget.dart b/lib/ui/viewer/file/object_tags_widget.dart new file mode 100644 index 000000000..4a24b4ac1 --- /dev/null +++ b/lib/ui/viewer/file/object_tags_widget.dart @@ -0,0 +1,77 @@ +import "package:flutter/material.dart"; +import "package:logging/logging.dart"; +import "package:photos/ente_theme_data.dart"; +import "package:photos/models/file.dart"; +import "package:photos/services/object_detection/object_detection_service.dart"; +import "package:photos/ui/common/loading_widget.dart"; +import "package:photos/utils/thumbnail_util.dart"; + +class ObjectTagsWidget extends StatelessWidget { + final File file; + + const ObjectTagsWidget(this.file, {Key? key}) : super(key: key); + + @override + Widget build(BuildContext context) { + return FutureBuilder>( + future: getThumbnail(file).then((data) { + return ObjectDetectionService.instance.predict(data!); + }), + builder: (context, snapshot) { + if (snapshot.hasData) { + final List tags = snapshot.data!; + if (tags.isEmpty) { + return const ObjectTagWidget("No Results"); + } + return ListView.builder( + itemCount: tags.length, + scrollDirection: Axis.horizontal, + itemBuilder: (context, index) { + return ObjectTagWidget(tags[index]); + }, + ); + } else if (snapshot.hasError) { + Logger("ObjectTagsWidget").severe(snapshot.error); + return const Icon(Icons.error); + } else { + return const EnteLoadingWidget(); + } + }, + ); + } +} + +class ObjectTagWidget extends StatelessWidget { + final String name; + const ObjectTagWidget(this.name, {Key? key}) : super(key: key); + + @override + Widget build(BuildContext context) { + return Container( + margin: const EdgeInsets.only( + top: 10, + bottom: 18, + right: 8, + ), + decoration: BoxDecoration( + color: Theme.of(context) + .colorScheme + .inverseBackgroundColor + .withOpacity(0.025), + borderRadius: const BorderRadius.all( + Radius.circular(8), + ), + ), + child: Center( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8), + child: Text( + name!, + style: Theme.of(context).textTheme.subtitle2, + overflow: TextOverflow.ellipsis, + ), + ), + ), + ); + } +} diff --git a/lib/utils/thumbnail_util.dart b/lib/utils/thumbnail_util.dart index 916c9ff0b..7c699af3e 100644 --- a/lib/utils/thumbnail_util.dart +++ b/lib/utils/thumbnail_util.dart @@ -32,6 +32,17 @@ class FileDownloadItem { FileDownloadItem(this.file, this.completer, this.cancelToken, this.counter); } +Future getThumbnail(File file) async { + if (file.isRemoteFile) { + return getThumbnailFromServer(file); + } else { + return getThumbnailFromLocal( + file, + size: thumbnailLargeSize, + ); + } +} + Future getThumbnailFromServer(File file) async { final cachedThumbnail = cachedThumbnailPath(file); if (await cachedThumbnail.exists()) { diff --git a/pubspec.lock b/pubspec.lock index 2703ad68e..79660f993 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -121,6 +121,46 @@ packages: url: "https://pub.dev" source: hosted version: "1.0.2" + camera: + dependency: transitive + description: + name: camera + sha256: "3ad71371b8168a4c8012c0b40a53c05afc75d46cc688b0f37b4611a841d47b25" + url: "https://pub.dev" + source: hosted + version: "0.9.8+1" + camera_android: + dependency: transitive + description: + name: camera_android + sha256: "665d62c1f334722c7519ca5d3b94ad68ecaa801691870602da5638a42c1fff67" + url: "https://pub.dev" + source: hosted + version: "0.9.8+3" + camera_avfoundation: + dependency: transitive + description: + name: camera_avfoundation + sha256: "6a68c20593d4cd58974d555f74a48b244f9db28cc9156de57781122d11b8754b" + url: "https://pub.dev" + source: hosted + version: "0.9.11" + camera_platform_interface: + dependency: transitive + description: + name: camera_platform_interface + sha256: b632be28e61d00a233f67d98ea90fd7041956f27a1c65500188ee459be60e15f + url: "https://pub.dev" + source: hosted + version: "2.4.0" + camera_web: + dependency: transitive + description: + name: camera_web + sha256: "18cdbee5441e9a6fb129fdd9b68a06d1b8c5236932ba97d5faeaefe80db2e5bd" + url: "https://pub.dev" + source: hosted + version: "0.2.1+6" characters: dependency: transitive description: @@ -1498,6 +1538,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.1.1" + stream_transform: + dependency: transitive + description: + name: stream_transform + sha256: "14a00e794c7c11aa145a170587321aedce29769c08d7f58b1d141da75e3b1c6f" + url: "https://pub.dev" + source: hosted + version: "2.1.0" string_scanner: dependency: transitive description: @@ -1562,6 +1610,23 @@ packages: url: "https://pub.dev" source: hosted version: "0.4.20" + tflite_flutter: + dependency: "direct main" + description: + name: tflite_flutter + sha256: "663483abf86066cdf9eae29df65e5e39ea5823d3f426507a81840d3b84ce58cd" + url: "https://pub.dev" + source: hosted + version: "0.9.0" + tflite_flutter_helper: + dependency: "direct main" + description: + path: "." + ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a + resolved-ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a + url: "https://github.com/elephantum/tflite_flutter_helper.git" + source: git + version: "0.3.0" timezone: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 75c2fe194..a94c086f3 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -106,6 +106,11 @@ dependencies: step_progress_indicator: ^1.0.2 syncfusion_flutter_core: ^19.2.49 syncfusion_flutter_sliders: ^19.2.49 + tflite_flutter: ^0.9.0 + tflite_flutter_helper: + git: + url: https://github.com/elephantum/tflite_flutter_helper.git # Fixes https://github.com/am15h/tflite_flutter_helper/issues/57 + ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a tuple: ^2.0.0 uni_links: ^0.5.1 url_launcher: ^6.0.3 @@ -150,6 +155,7 @@ flutter_native_splash: flutter: assets: - assets/ + - assets/models/ fonts: - family: Inter fonts: diff --git a/setup.sh b/setup.sh new file mode 100755 index 000000000..9b23d2705 --- /dev/null +++ b/setup.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash + +# Originally from https://github.com/am15h/tflite_flutter_plugin/blob/master/install.sh + +cd "$(dirname "$(readlink -f "$0")")" + +# Pull from the latest tag where binaries were built +TAG="v0.5.0" + +IOS_URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/" +ANDROID_URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/" + +IOS_ASSET="TensorFlowLiteC.framework.zip" +IOS_FRAMEWORK="TensorFlowLiteC.framework" +IOS_DIR="ios/.symlinks/plugins/tflite_flutter/ios/" +MACOSX_METADATA_DIR="__MACOSX" + +ANDROID_DIR="android/app/src/main/jniLibs/" +ANDROID_LIB="libtensorflowlite_c.so" + +ARM_DELEGATE="libtensorflowlite_c_arm_delegate.so" +ARM_64_DELEGATE="libtensorflowlite_c_arm64_delegate.so" +ARM="libtensorflowlite_c_arm.so" +ARM_64="libtensorflowlite_c_arm64.so" +X86="libtensorflowlite_c_x86.so" +X86_64="libtensorflowlite_c_x86_64.so" + +delegate=0 + +while getopts "d" OPTION +do + case $OPTION in + d) delegate=1;; + esac +done + +wget "${IOS_URL}${TAG}/${IOS_ASSET}" +unzip ${IOS_ASSET} +rm -rf ${MACOSX_METADATA_DIR} +rm ${IOS_ASSET} +rm -rf "${IOS_DIR}/${IOS_FRAMEWORK}" +mv ${IOS_FRAMEWORK} ${IOS_DIR} + +download () { + wget "${ANDROID_URL}${TAG}/$1" + mkdir -p "${ANDROID_DIR}$2/" + mv $1 "${ANDROID_DIR}$2/${ANDROID_LIB}" +} + +if [ ${delegate} -eq 1 ] +then + +download ${ARM_DELEGATE} "armeabi-v7a" +download ${ARM_64_DELEGATE} "arm64-v8a" + +else + +download ${ARM} "armeabi-v7a" +download ${ARM_64} "arm64-v8a" + +fi + +download ${X86} "x86" +download ${X86_64} "x86_64"