123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- import "dart:async";
- import "dart:collection";
- import "dart:io";
- import "package:computer/computer.dart";
- import "package:logging/logging.dart";
- import "package:photos/core/cache/lru_map.dart";
- import "package:photos/core/configuration.dart";
- import "package:photos/core/event_bus.dart";
- import "package:photos/db/embeddings_db.dart";
- import "package:photos/db/files_db.dart";
- import "package:photos/events/diff_sync_complete_event.dart";
- import 'package:photos/events/embedding_updated_event.dart';
- import "package:photos/events/file_uploaded_event.dart";
- import "package:photos/events/machine_learning_control_event.dart";
- import "package:photos/models/embedding.dart";
- import "package:photos/models/file/file.dart";
- import "package:photos/services/collections_service.dart";
- import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart';
- import 'package:photos/services/machine_learning/semantic_search/frameworks/ggml.dart';
- import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
- import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart';
- import "package:photos/utils/debouncer.dart";
- import "package:photos/utils/device_info.dart";
- import "package:photos/utils/local_settings.dart";
- import "package:photos/utils/thumbnail_util.dart";
/// Service powering magic (semantic) search over the user's library.
///
/// It maintains an in-memory cache of image embeddings (pulled from remote
/// and/or computed on-device through an [MLFramework]), keeps a queue of
/// files that still need indexing, and answers free-text queries by scoring
/// the cached image embeddings against the query's text embedding.
class SemanticSearchService {
  SemanticSearchService._privateConstructor();

  /// Shared singleton instance.
  static final SemanticSearchService instance =
      SemanticSearchService._privateConstructor();

  static final Computer _computer = Computer.shared();

  // Caches text embeddings for the 20 most recently issued queries.
  static final LRUMap<String, List<double>> _queryCache = LRUMap(20);

  /// Expected length of every embedding vector produced by the framework.
  static const kEmbeddingLength = 512;

  /// Minimum score for a file to count as a match in [computeBulkScore].
  static const kScoreThreshold = 0.23;

  /// Whether locally computed embeddings should be pushed to remote.
  static const kShouldPushEmbeddings = true;

  /// Window used to coalesce bursts of [EmbeddingUpdatedEvent]s before
  /// reloading the embedding cache.
  static const kDebounceDuration = Duration(milliseconds: 4000);

  final _logger = Logger("SemanticSearchService");
  final _queue = Queue<EnteFile>();
  // Completes (with true) once the ML framework has initialized; it is never
  // completed on failure, so dependents awaiting it stay blocked.
  final _frameworkInitialization = Completer<bool>();
  final _embeddingLoaderDebouncer =
      Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);

  late Model _currentModel;
  late MLFramework _mlFramework;
  bool _hasInitialized = false;
  bool _isComputingEmbeddings = false;
  bool _isSyncing = false;
  Future<List<EnteFile>>? _ongoingRequest;
  List<Embedding> _cachedEmbeddings = <Embedding>[];
  PendingQuery? _nextQuery;
  // Gate for indexing: completed => indexing may run; a fresh (pending)
  // completer pauses it. See _startIndexing / _pauseIndexing.
  Completer<void> _mlController = Completer<void>();

  /// Whether [init] has already run (or is currently running).
  get hasInitialized => _hasInitialized;

  /// Bootstraps the service: picks a model/framework, loads cached
  /// embeddings, wires up event listeners, and starts indexing.
  ///
  /// No-op when magic search is disabled or when already initialized.
  Future<void> init({bool shouldSyncImmediately = false}) async {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    if (_hasInitialized) {
      _logger.info("Initialized already");
      return;
    }
    _hasInitialized = true;
    final shouldDownloadOverMobileData =
        Configuration.instance.shouldBackupOverMobileData();
    _currentModel = await _getCurrentModel();
    _mlFramework = _currentModel == Model.onnxClip
        ? ONNX(shouldDownloadOverMobileData)
        : GGML(shouldDownloadOverMobileData);
    await EmbeddingsDB.instance.init();
    await EmbeddingStore.instance.init();
    await _loadEmbeddings();
    Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
      _embeddingLoaderDebouncer.run(() async {
        await _loadEmbeddings();
      });
    });
    Bus.instance.on<DiffSyncCompleteEvent>().listen((event) {
      // Diff sync is complete, we can now pull embeddings from remote
      unawaited(sync());
    });
    if (Configuration.instance.hasConfiguredAccount() &&
        kShouldPushEmbeddings) {
      unawaited(EmbeddingStore.instance.pushEmbeddings());
    }
    // ignore: unawaited_futures
    _loadModels().then((v) async {
      _logger.info("Getting text embedding");
      // Warm up the text encoder so the first real query is fast.
      await _getTextEmbedding("warm up text encoder");
      _logger.info("Got text embedding");
    });
    // Adding to queue only on init?
    Bus.instance.on<FileUploadedEvent>().listen((event) async {
      _addToQueue(event.file);
    });
    if (shouldSyncImmediately) {
      unawaited(sync());
    }
    if (Platform.isAndroid) {
      // On Android an external controller decides when indexing may run.
      Bus.instance.on<MachineLearningControlEvent>().listen((event) {
        if (event.shouldRun) {
          _startIndexing();
        } else {
          _pauseIndexing();
        }
      });
    } else {
      _startIndexing();
    }
  }

  /// Releases the ML framework's resources, if it ever finished initializing.
  Future<void> release() async {
    if (_frameworkInitialization.isCompleted) {
      await _mlFramework.release();
    }
  }

  /// Pulls remote embeddings for the current model and then backfills
  /// embeddings for files that still lack one.
  ///
  /// Re-entry is ignored while a sync is already in flight.
  Future<void> sync() async {
    if (_isSyncing) {
      return;
    }
    _isSyncing = true;
    try {
      await EmbeddingStore.instance.pullEmbeddings(_currentModel);
      await _backFill();
    } finally {
      // Always reset the flag; previously a single failure here would have
      // left _isSyncing stuck at true and blocked every future sync.
      _isSyncing = false;
    }
  }

  /// Returns files matching [query], best matches first.
  ///
  /// Only one search runs at a time; while one is in flight, the most
  /// recent call is queued and run afterwards.
  Future<List<EnteFile>> search(String query) async {
    if (!LocalSettings.instance.hasEnabledMagicSearch() ||
        !_frameworkInitialization.isCompleted) {
      return [];
    }
    if (_ongoingRequest == null) {
      _ongoingRequest = _getMatchingFiles(query).then((result) {
        _ongoingRequest = null;
        if (_nextQuery != null) {
          final next = _nextQuery;
          _nextQuery = null;
          search(next!.query).then((nextResult) {
            next.completer.complete(nextResult);
          });
        }
        return result;
      });
      return _ongoingRequest!;
    } else {
      // If there's an ongoing request, create or replace the nextCompleter.
      _logger.info("Queuing query $query");
      // NOTE(review): a zero-duration timeout on a pending future makes this
      // await throw TimeoutException, and the superseded query's completer is
      // never completed. Confirm this is the intended cancellation path.
      await _nextQuery?.completer.future
          .timeout(const Duration(seconds: 0)); // Cancels the previous future.
      _nextQuery = PendingQuery(query, Completer<List<EnteFile>>());
      return _nextQuery!.completer.future;
    }
  }

  /// Returns how many items are indexed and how many are still pending.
  Future<IndexStatus> getIndexStatus() async {
    return IndexStatus(
      _cachedEmbeddings.length,
      (await _getFileIDsToBeIndexed()).length,
    );
  }

  /// Current initialization state of the underlying ML framework.
  InitializationState getFrameworkInitializationState() {
    if (!_hasInitialized) {
      return InitializationState.notInitialized;
    }
    return _mlFramework.initializationState;
  }

  /// Deletes all stored embeddings for the current model.
  Future<void> clearIndexes() async {
    await EmbeddingStore.instance.clearEmbeddings(_currentModel);
    _logger.info("Indexes cleared for $_currentModel");
  }

  // Reloads every embedding for the current model from the local DB into
  // the in-memory cache.
  Future<void> _loadEmbeddings() async {
    _logger.info("Pulling cached embeddings");
    final startTime = DateTime.now();
    _cachedEmbeddings = await EmbeddingsDB.instance.getAll(_currentModel);
    final endTime = DateTime.now();
    _logger.info(
      "Loading ${_cachedEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
    );
    _logger.info("Cached embeddings: ${_cachedEmbeddings.length}");
  }

  // Queues every owned file that doesn't yet have an embedding, then kicks
  // off queue processing.
  Future<void> _backFill() async {
    if (!LocalSettings.instance.hasEnabledMagicSearch() ||
        !MLFramework.kImageEncoderEnabled) {
      return;
    }
    // Wait until the framework is ready before queueing work.
    await _frameworkInitialization.future;
    _logger.info("Attempting backfill for image embeddings");
    final fileIDs = await _getFileIDsToBeIndexed();
    final files = await FilesDB.instance.getUploadedFiles(fileIDs);
    _logger.info("${files.length} to be embedded");
    // await _cacheThumbnails(files);
    _queue.addAll(files);
    unawaited(_pollQueue());
  }

  // Pre-fetches thumbnails in batches of 100. Currently unused — see the
  // commented-out call in _backFill.
  Future<void> _cacheThumbnails(List<EnteFile> files) async {
    int counter = 0;
    const batchSize = 100;
    for (var i = 0; i < files.length;) {
      final futures = <Future>[];
      for (var j = 0; j < batchSize && i < files.length; j++, i++) {
        futures.add(getThumbnail(files[i]));
      }
      await Future.wait(futures);
      counter += futures.length;
      _logger.info("$counter/${files.length} thumbnails cached");
    }
  }

  // IDs of owned files that don't yet have a cached embedding.
  Future<List<int>> _getFileIDsToBeIndexed() async {
    final uploadedFileIDs = await FilesDB.instance
        .getOwnedFileIDs(Configuration.instance.getUserID()!);
    final embeddedFileIDs = _cachedEmbeddings.map((e) => e.fileID).toSet();
    uploadedFileIDs.removeWhere(
      (id) => embeddedFileIDs.contains(id),
    );
    return uploadedFileIDs;
  }

  /// Drops all files still waiting to be indexed.
  Future<void> clearQueue() async {
    _queue.clear();
  }

  // Scores the query against all cached embeddings and resolves the matching
  // files, skipping files that live in hidden collections.
  Future<List<EnteFile>> _getMatchingFiles(String query) async {
    final textEmbedding = await _getTextEmbedding(query);
    final queryResults = await _getScores(textEmbedding);
    final filesMap = await FilesDB.instance
        .getFilesFromIDs(queryResults.map((e) => e.id).toList());
    final results = <EnteFile>[];
    final ignoredCollections =
        CollectionsService.instance.getHiddenCollectionIds();
    for (final result in queryResults) {
      final file = filesMap[result.id];
      if (file != null && !ignoredCollections.contains(file.collectionID)) {
        // Reuse the promoted local instead of a redundant map lookup + `!`.
        results.add(file);
      }
    }
    _logger.info("${results.length} results");
    return results;
  }

  // Queues a single (just-uploaded) file for indexing.
  void _addToQueue(EnteFile file) {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    _logger.info("Adding $file to the queue");
    _queue.add(file);
    _pollQueue();
  }

  // Initializes the ML framework. On success, unblocks everything awaiting
  // _frameworkInitialization; on failure the completer is deliberately left
  // pending so dependent work (backfill, search) stays blocked.
  Future<void> _loadModels() async {
    _logger.info("Initializing ML framework");
    try {
      await _mlFramework.init();
      _frameworkInitialization.complete(true);
    } catch (e, s) {
      _logger.severe("ML framework initialization failed", e, s);
    }
    _logger.info("ML framework initialized");
  }

  // Drains the indexing queue one file at a time (most recently queued
  // first). Re-entrant calls are no-ops while a drain is in progress.
  Future<void> _pollQueue() async {
    if (_isComputingEmbeddings) {
      return;
    }
    _isComputingEmbeddings = true;
    try {
      while (_queue.isNotEmpty) {
        await computeImageEmbedding(_queue.removeLast());
      }
    } finally {
      // Reset even if an unexpected error escapes, so indexing can resume.
      _isComputingEmbeddings = false;
    }
  }

  /// Computes and stores the image embedding for [file]'s thumbnail.
  ///
  /// Waits for the indexing gate (see [_startIndexing] / [_pauseIndexing])
  /// and skips files whose embedding comes back with the wrong length.
  Future<void> computeImageEmbedding(EnteFile file) async {
    if (!MLFramework.kImageEncoderEnabled) {
      return;
    }
    if (!_frameworkInitialization.isCompleted) {
      return;
    }
    if (!_mlController.isCompleted) {
      _logger.info("Waiting for a green signal from controller...");
      await _mlController.future;
    }
    try {
      final thumbnail = await getThumbnailForUploadedFile(file);
      if (thumbnail == null) {
        _logger.warning("Could not get thumbnail for $file");
        return;
      }
      final filePath = thumbnail.path;
      _logger.info("Running clip over $file");
      final result = await _mlFramework.getImageEmbedding(filePath);
      if (result.length != kEmbeddingLength) {
        _logger.severe("Discovered incorrect embedding for $file - $result");
        return;
      }
      final embedding = Embedding(
        fileID: file.uploadedFileID!,
        model: _currentModel,
        embedding: result,
      );
      await EmbeddingStore.instance.storeEmbedding(
        file,
        embedding,
      );
    } catch (e, s) {
      // Fix: severe(message, error, stackTrace) — previously the exception
      // was passed as the message and the stack trace as the error object.
      _logger.severe("Could not compute image embedding for $file", e, s);
    }
  }

  // Returns the (possibly cached) text embedding for [query]; empty list on
  // failure.
  Future<List<double>> _getTextEmbedding(String query) async {
    _logger.info("Searching for $query");
    final cachedResult = _queryCache.get(query);
    if (cachedResult != null) {
      return cachedResult;
    }
    try {
      final result = await _mlFramework.getTextEmbedding(query);
      _queryCache.put(query, result);
      return result;
    } catch (e) {
      _logger.severe("Could not get text embedding", e);
      return [];
    }
  }

  // Scores [textEmbedding] against every cached image embedding on a
  // background isolate via Computer.
  Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
    final startTime = DateTime.now();
    final List<QueryResult> queryResults = await _computer.compute(
      computeBulkScore,
      param: {
        "imageEmbeddings": _cachedEmbeddings,
        "textEmbedding": textEmbedding,
      },
      taskName: "computeBulkScore",
    );
    final endTime = DateTime.now();
    _logger.info(
      "computingScores took: ${endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch}ms",
    );
    return queryResults;
  }

  // Selects GGML on GrapheneOS and ONNX everywhere else.
  Future<Model> _getCurrentModel() async {
    if (await isGrapheneOS()) {
      return Model.ggmlClip;
    } else {
      return Model.onnxClip;
    }
  }

  // Opens the indexing gate (idempotent).
  void _startIndexing() {
    _logger.info("Start indexing");
    if (!_mlController.isCompleted) {
      _mlController.complete();
    }
  }

  // Closes the indexing gate by swapping in a fresh, pending completer that
  // computeImageEmbedding will await.
  void _pauseIndexing() {
    if (_mlController.isCompleted) {
      _logger.info("Pausing indexing");
      _mlController = Completer<void>();
    }
  }
}
/// Isolate entry point (run via `Computer`) that scores a text embedding
/// against a batch of image embeddings.
///
/// [args] must carry "imageEmbeddings" (`List<Embedding>`) and
/// "textEmbedding" (`List<double>`). Returns only the results whose
/// dot-product score reaches [SemanticSearchService.kScoreThreshold],
/// ordered best match first.
List<QueryResult> computeBulkScore(Map args) {
  final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
  final textEmbedding = args["textEmbedding"] as List<double>;
  final matches = <QueryResult>[];
  for (final candidate in imageEmbeddings) {
    final similarity = computeScore(candidate.embedding, textEmbedding);
    if (similarity < SemanticSearchService.kScoreThreshold) {
      continue;
    }
    matches.add(QueryResult(candidate.fileID, similarity));
  }
  // Highest score first.
  matches.sort((a, b) => b.score.compareTo(a.score));
  return matches;
}
/// Returns the dot product of [imageEmbedding] and [textEmbedding].
///
/// The two vectors must be the same length; this is checked only in debug
/// builds (via `assert`).
double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
  assert(
    imageEmbedding.length == textEmbedding.length,
    "The two embeddings should have the same length",
  );
  var dotProduct = 0.0;
  var position = 0;
  for (final component in imageEmbedding) {
    dotProduct += component * textEmbedding[position++];
  }
  return dotProduct;
}
/// A single search hit: the matched file's ID and its similarity score.
class QueryResult {
  /// Uploaded file ID of the matched file.
  final int id;

  /// Dot-product similarity between the query text and the file's embedding.
  final double score;

  // const constructor: the class is immutable and value-like.
  const QueryResult(this.id, this.score);
}
/// A search query that arrived while another search was in flight and is
/// waiting for its turn to run.
class PendingQuery {
  /// The raw query text to search for.
  final String query;

  /// Completed with the results once this query eventually runs.
  final Completer<List<EnteFile>> completer;

  PendingQuery(this.query, this.completer);
}
/// Snapshot of indexing progress, as reported by
/// `SemanticSearchService.getIndexStatus`.
class IndexStatus {
  /// Number of files that already have a cached embedding.
  final int indexedItems;

  /// Number of owned files still waiting to be indexed.
  final int pendingItems;

  // const constructor: the class is immutable and value-like.
  const IndexStatus(this.indexedItems, this.pendingItems);
}
|