semantic_search_service.dart

import "dart:async";
import "dart:collection";
import "dart:io";

import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:photos/core/cache/lru_map.dart";
import "package:photos/core/configuration.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/embeddings_db.dart";
import "package:photos/db/files_db.dart";
import "package:photos/events/diff_sync_complete_event.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/events/file_uploaded_event.dart";
import "package:photos/events/machine_learning_control_event.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/collections_service.dart";
import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ggml.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart';
import "package:photos/utils/debouncer.dart";
import "package:photos/utils/device_info.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/thumbnail_util.dart";

class SemanticSearchService {
  SemanticSearchService._privateConstructor();

  static final SemanticSearchService instance =
      SemanticSearchService._privateConstructor();
  static final Computer _computer = Computer.shared();
  static final LRUMap<String, List<double>> _queryCache = LRUMap(20);

  static const kEmbeddingLength = 512;
  static const kScoreThreshold = 0.23;
  static const kShouldPushEmbeddings = true;
  static const kDebounceDuration = Duration(milliseconds: 4000);

  final _logger = Logger("SemanticSearchService");
  final _queue = Queue<EnteFile>();
  final _frameworkInitialization = Completer<bool>();
  final _embeddingLoaderDebouncer =
      Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);

  late Model _currentModel;
  late MLFramework _mlFramework;
  bool _hasInitialized = false;
  bool _isComputingEmbeddings = false;
  bool _isSyncing = false;
  Future<List<EnteFile>>? _ongoingRequest;
  List<Embedding> _cachedEmbeddings = <Embedding>[];
  PendingQuery? _nextQuery;
  Completer<void> _mlController = Completer<void>();

  get hasInitialized => _hasInitialized;

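  // Initializes the service: picks the CLIP framework (ONNX, or GGML on
  // GrapheneOS), loads cached embeddings, wires up the event listeners, and
  // starts (or gates) indexing. No-op if magic search is disabled or the
  // service is already initialized.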
  Future<void> init({bool shouldSyncImmediately = false}) async {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    if (_hasInitialized) {
      _logger.info("Initialized already");
      return;
    }
    _hasInitialized = true;
    final shouldDownloadOverMobileData =
        Configuration.instance.shouldBackupOverMobileData();
    _currentModel = await _getCurrentModel();
    _mlFramework = _currentModel == Model.onnxClip
        ? ONNX(shouldDownloadOverMobileData)
        : GGML(shouldDownloadOverMobileData);
    await EmbeddingsDB.instance.init();
    await EmbeddingStore.instance.init();
    await _loadEmbeddings();
    Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
      _embeddingLoaderDebouncer.run(() async {
        await _loadEmbeddings();
      });
    });
    Bus.instance.on<DiffSyncCompleteEvent>().listen((event) {
      // Diff sync is complete, we can now pull embeddings from remote
      unawaited(sync());
    });
    if (Configuration.instance.hasConfiguredAccount() &&
        kShouldPushEmbeddings) {
      unawaited(EmbeddingStore.instance.pushEmbeddings());
    }
    // ignore: unawaited_futures
    _loadModels().then((v) async {
      _logger.info("Getting text embedding");
      await _getTextEmbedding("warm up text encoder");
      _logger.info("Got text embedding");
    });
    // Adding to queue only on init?
    Bus.instance.on<FileUploadedEvent>().listen((event) async {
      _addToQueue(event.file);
    });
    if (shouldSyncImmediately) {
      unawaited(sync());
    }
    if (Platform.isAndroid) {
      Bus.instance.on<MachineLearningControlEvent>().listen((event) {
        if (event.shouldRun) {
          _startIndexing();
        } else {
          _pauseIndexing();
        }
      });
    } else {
      _startIndexing();
    }
  }

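  // Releases the ML framework's resources, if it finished initializing.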
  Future<void> release() async {
    if (_frameworkInitialization.isCompleted) {
      await _mlFramework.release();
    }
  }

  Future<void> sync() async {
    if (_isSyncing) {
      return;
    }
    _isSyncing = true;
    await EmbeddingStore.instance.pullEmbeddings(_currentModel);
    await _backFill();
    _isSyncing = false;
  }

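  // Returns files matching [query]. Only one request runs at a time; while a
  // request is in flight the latest query is parked in [_nextQuery] and is
  // executed once the current request completes.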
  Future<List<EnteFile>> search(String query) async {
    if (!LocalSettings.instance.hasEnabledMagicSearch() ||
        !_frameworkInitialization.isCompleted) {
      return [];
    }
    if (_ongoingRequest == null) {
      _ongoingRequest = _getMatchingFiles(query).then((result) {
        _ongoingRequest = null;
        if (_nextQuery != null) {
          final next = _nextQuery;
          _nextQuery = null;
          search(next!.query).then((nextResult) {
            next.completer.complete(nextResult);
          });
        }
        return result;
      });
      return _ongoingRequest!;
    } else {
      // If there's an ongoing request, create or replace the nextCompleter.
      _logger.info("Queuing query $query");
      await _nextQuery?.completer.future
          .timeout(const Duration(seconds: 0)); // Cancels the previous future.
      _nextQuery = PendingQuery(query, Completer<List<EnteFile>>());
      return _nextQuery!.completer.future;
    }
  }

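  // Reports how many files already have embeddings and how many still need
  // to be indexed.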
  Future<IndexStatus> getIndexStatus() async {
    return IndexStatus(
      _cachedEmbeddings.length,
      (await _getFileIDsToBeIndexed()).length,
    );
  }

  InitializationState getFrameworkInitializationState() {
    if (!_hasInitialized) {
      return InitializationState.notInitialized;
    }
    return _mlFramework.initializationState;
  }

  Future<void> clearIndexes() async {
    await EmbeddingStore.instance.clearEmbeddings(_currentModel);
    _logger.info("Indexes cleared for $_currentModel");
  }

  Future<void> _loadEmbeddings() async {
    _logger.info("Pulling cached embeddings");
    final startTime = DateTime.now();
    _cachedEmbeddings = await EmbeddingsDB.instance.getAll(_currentModel);
    final endTime = DateTime.now();
    _logger.info(
      "Loading ${_cachedEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
    );
    _logger.info("Cached embeddings: " + _cachedEmbeddings.length.toString());
  }

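  // Queues every owned file that has no embedding for the current model and
  // kicks off queue processing.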
  Future<void> _backFill() async {
    if (!LocalSettings.instance.hasEnabledMagicSearch() ||
        !MLFramework.kImageEncoderEnabled) {
      return;
    }
    await _frameworkInitialization.future;
    _logger.info("Attempting backfill for image embeddings");
    final fileIDs = await _getFileIDsToBeIndexed();
    final files = await FilesDB.instance.getUploadedFiles(fileIDs);
    _logger.info(files.length.toString() + " to be embedded");
    // await _cacheThumbnails(files);
    _queue.addAll(files);
    unawaited(_pollQueue());
  }

  Future<void> _cacheThumbnails(List<EnteFile> files) async {
    int counter = 0;
    const batchSize = 100;
    for (var i = 0; i < files.length;) {
      final futures = <Future>[];
      for (var j = 0; j < batchSize && i < files.length; j++, i++) {
        futures.add(getThumbnail(files[i]));
      }
      await Future.wait(futures);
      counter += futures.length;
      _logger.info("$counter/${files.length} thumbnails cached");
    }
  }

  Future<List<int>> _getFileIDsToBeIndexed() async {
    final uploadedFileIDs = await FilesDB.instance
        .getOwnedFileIDs(Configuration.instance.getUserID()!);
    final embeddedFileIDs = _cachedEmbeddings.map((e) => e.fileID).toSet();
    uploadedFileIDs.removeWhere(
      (id) => embeddedFileIDs.contains(id),
    );
    return uploadedFileIDs;
  }

  Future<void> clearQueue() async {
    _queue.clear();
  }

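  // Embeds the query text, scores it against the cached image embeddings, and
  // maps the matches back to files, skipping hidden collections.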
  Future<List<EnteFile>> _getMatchingFiles(String query) async {
    final textEmbedding = await _getTextEmbedding(query);
    final queryResults = await _getScores(textEmbedding);
    final filesMap = await FilesDB.instance
        .getFilesFromIDs(queryResults.map((e) => e.id).toList());
    final results = <EnteFile>[];
    final ignoredCollections =
        CollectionsService.instance.getHiddenCollectionIds();
    for (final result in queryResults) {
      final file = filesMap[result.id];
      if (file != null && !ignoredCollections.contains(file.collectionID)) {
        results.add(filesMap[result.id]!);
      }
    }
    _logger.info(results.length.toString() + " results");
    return results;
  }

  void _addToQueue(EnteFile file) {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    _logger.info("Adding " + file.toString() + " to the queue");
    _queue.add(file);
    _pollQueue();
  }

  Future<void> _loadModels() async {
    _logger.info("Initializing ML framework");
    try {
      await _mlFramework.init();
      _frameworkInitialization.complete(true);
    } catch (e, s) {
      _logger.severe("ML framework initialization failed", e, s);
    }
    _logger.info("ML framework initialized");
  }

  Future<void> _pollQueue() async {
    if (_isComputingEmbeddings) {
      return;
    }
    _isComputingEmbeddings = true;
    while (_queue.isNotEmpty) {
      await computeImageEmbedding(_queue.removeLast());
    }
    _isComputingEmbeddings = false;
  }

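  // Computes and stores the image embedding for a single file, using its
  // thumbnail as input. Waits on the ML controller before running so that
  // indexing can be paused (see MachineLearningControlEvent).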
  Future<void> computeImageEmbedding(EnteFile file) async {
    if (!MLFramework.kImageEncoderEnabled) {
      return;
    }
    if (!_frameworkInitialization.isCompleted) {
      return;
    }
    if (!_mlController.isCompleted) {
      _logger.info("Waiting for a green signal from controller...");
      await _mlController.future;
    }
    try {
      final thumbnail = await getThumbnailForUploadedFile(file);
      if (thumbnail == null) {
        _logger.warning("Could not get thumbnail for $file");
        return;
      }
      final filePath = thumbnail.path;
      _logger.info("Running clip over $file");
      final result = await _mlFramework.getImageEmbedding(filePath);
      if (result.length != kEmbeddingLength) {
        _logger.severe("Discovered incorrect embedding for $file - $result");
        return;
      }
      final embedding = Embedding(
        fileID: file.uploadedFileID!,
        model: _currentModel,
        embedding: result,
      );
      await EmbeddingStore.instance.storeEmbedding(
        file,
        embedding,
      );
    } catch (e, s) {
      _logger.severe(e, s);
    }
  }

  Future<List<double>> _getTextEmbedding(String query) async {
    _logger.info("Searching for " + query);
    final cachedResult = _queryCache.get(query);
    if (cachedResult != null) {
      return cachedResult;
    }
    try {
      final result = await _mlFramework.getTextEmbedding(query);
      _queryCache.put(query, result);
      return result;
    } catch (e) {
      _logger.severe("Could not get text embedding", e);
      return [];
    }
  }

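  // Offloads scoring of all cached embeddings against the text embedding to a
  // background isolate via Computer, keeping the UI isolate responsive.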
  Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
    final startTime = DateTime.now();
    final List<QueryResult> queryResults = await _computer.compute(
      computeBulkScore,
      param: {
        "imageEmbeddings": _cachedEmbeddings,
        "textEmbedding": textEmbedding,
      },
      taskName: "computeBulkScore",
    );
    final endTime = DateTime.now();
    _logger.info(
      "computingScores took: " +
          (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
              .toString() +
          "ms",
    );
    return queryResults;
  }

  Future<Model> _getCurrentModel() async {
    if (await isGrapheneOS()) {
      return Model.ggmlClip;
    } else {
      return Model.onnxClip;
    }
  }

  void _startIndexing() {
    _logger.info("Start indexing");
    if (!_mlController.isCompleted) {
      _mlController.complete();
    }
  }

  void _pauseIndexing() {
    if (_mlController.isCompleted) {
      _logger.info("Pausing indexing");
      _mlController = Completer<void>();
    }
  }
}

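// Isolate entry point: scores every image embedding against the text
// embedding and returns matches above kScoreThreshold, highest score first.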
List<QueryResult> computeBulkScore(Map args) {
  final queryResults = <QueryResult>[];
  final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
  final textEmbedding = args["textEmbedding"] as List<double>;
  for (final imageEmbedding in imageEmbeddings) {
    final score = computeScore(
      imageEmbedding.embedding,
      textEmbedding,
    );
    if (score >= SemanticSearchService.kScoreThreshold) {
      queryResults.add(QueryResult(imageEmbedding.fileID, score));
    }
  }
  queryResults.sort((first, second) => second.score.compareTo(first.score));
  return queryResults;
}

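// Dot product of the two embeddings; equal to their cosine similarity when
// both vectors are unit-normalized.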
double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
  assert(
    imageEmbedding.length == textEmbedding.length,
    "The two embeddings should have the same length",
  );
  double score = 0;
  for (int index = 0; index < imageEmbedding.length; index++) {
    score += imageEmbedding[index] * textEmbedding[index];
  }
  return score;
}

class QueryResult {
  final int id;
  final double score;

  QueryResult(this.id, this.score);
}

class PendingQuery {
  final String query;
  final Completer<List<EnteFile>> completer;

  PendingQuery(this.query, this.completer);
}

class IndexStatus {
  final int indexedItems, pendingItems;

  IndexStatus(this.indexedItems, this.pendingItems);
}