// semantic_search_service.dart
  1. import "dart:async";
  2. import "dart:collection";
  3. import "dart:math" show min;
  4. import "package:computer/computer.dart";
  5. import "package:logging/logging.dart";
  6. import "package:photos/core/cache/lru_map.dart";
  7. import "package:photos/core/configuration.dart";
  8. import "package:photos/core/event_bus.dart";
  9. import "package:photos/db/embeddings_db.dart";
  10. import "package:photos/db/files_db.dart";
  11. import "package:photos/events/diff_sync_complete_event.dart";
  12. import 'package:photos/events/embedding_updated_event.dart';
  13. import "package:photos/events/file_uploaded_event.dart";
  14. import "package:photos/events/machine_learning_control_event.dart";
  15. import "package:photos/models/embedding.dart";
  16. import "package:photos/models/file/file.dart";
  17. import "package:photos/services/collections_service.dart";
  18. import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart';
  19. import 'package:photos/services/machine_learning/semantic_search/frameworks/ggml.dart';
  20. import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
  21. import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart';
  22. import "package:photos/utils/debouncer.dart";
  23. import "package:photos/utils/device_info.dart";
  24. import "package:photos/utils/local_settings.dart";
  25. import "package:photos/utils/thumbnail_util.dart";
  26. class SemanticSearchService {
  27. SemanticSearchService._privateConstructor();
  28. static final SemanticSearchService instance =
  29. SemanticSearchService._privateConstructor();
  30. static final Computer _computer = Computer.shared();
  31. static final LRUMap<String, List<double>> _queryCache = LRUMap(20);
  32. static const kEmbeddingLength = 512;
  33. static const kScoreThreshold = 0.23;
  34. static const kShouldPushEmbeddings = true;
  35. static const kDebounceDuration = Duration(milliseconds: 4000);
  36. final _logger = Logger("SemanticSearchService");
  37. final _queue = Queue<EnteFile>();
  38. final _frameworkInitialization = Completer<bool>();
  39. final _embeddingLoaderDebouncer =
  40. Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
  41. late Model _currentModel;
  42. late MLFramework _mlFramework;
  43. bool _hasInitialized = false;
  44. bool _isComputingEmbeddings = false;
  45. bool _isSyncing = false;
  46. List<Embedding> _cachedEmbeddings = <Embedding>[];
  47. Future<(String, List<EnteFile>)>? _searchScreenRequest;
  48. String? _latestPendingQuery;
  49. Completer<void> _mlController = Completer<void>();
  50. get hasInitialized => _hasInitialized;
  51. Future<void> init({bool shouldSyncImmediately = false}) async {
  52. if (!LocalSettings.instance.hasEnabledMagicSearch()) {
  53. return;
  54. }
  55. if (_hasInitialized) {
  56. _logger.info("Initialized already");
  57. return;
  58. }
  59. _hasInitialized = true;
  60. final shouldDownloadOverMobileData =
  61. Configuration.instance.shouldBackupOverMobileData();
  62. _currentModel = await _getCurrentModel();
  63. _mlFramework = _currentModel == Model.onnxClip
  64. ? ONNX(shouldDownloadOverMobileData)
  65. : GGML(shouldDownloadOverMobileData);
  66. await EmbeddingStore.instance.init();
  67. await EmbeddingsDB.instance.init();
  68. await _loadEmbeddings();
  69. Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
  70. _embeddingLoaderDebouncer.run(() async {
  71. await _loadEmbeddings();
  72. });
  73. });
  74. Bus.instance.on<DiffSyncCompleteEvent>().listen((event) {
  75. // Diff sync is complete, we can now pull embeddings from remote
  76. unawaited(sync());
  77. });
  78. if (Configuration.instance.hasConfiguredAccount() &&
  79. kShouldPushEmbeddings) {
  80. unawaited(EmbeddingStore.instance.pushEmbeddings());
  81. }
  82. // ignore: unawaited_futures
  83. _loadModels().then((v) async {
  84. _logger.info("Getting text embedding");
  85. await _getTextEmbedding("warm up text encoder");
  86. _logger.info("Got text embedding");
  87. });
  88. // Adding to queue only on init?
  89. Bus.instance.on<FileUploadedEvent>().listen((event) async {
  90. _addToQueue(event.file);
  91. });
  92. if (shouldSyncImmediately) {
  93. unawaited(sync());
  94. }
  95. Bus.instance.on<MachineLearningControlEvent>().listen((event) {
  96. if (event.shouldRun) {
  97. _startIndexing();
  98. } else {
  99. _pauseIndexing();
  100. }
  101. });
  102. }
  103. Future<void> release() async {
  104. if (_frameworkInitialization.isCompleted) {
  105. await _mlFramework.release();
  106. }
  107. }
  108. Future<void> sync() async {
  109. if (_isSyncing) {
  110. return;
  111. }
  112. _isSyncing = true;
  113. final fetchCompleted =
  114. await EmbeddingStore.instance.pullEmbeddings(_currentModel);
  115. if (fetchCompleted) {
  116. await _backFill();
  117. }
  118. _isSyncing = false;
  119. }
  120. // searchScreenQuery should only be used for the user initiate query on the search screen.
  121. // If there are multiple call tho this method, then for all the calls, the result will be the same as the last query.
  122. Future<(String, List<EnteFile>)> searchScreenQuery(String query) async {
  123. if (!LocalSettings.instance.hasEnabledMagicSearch() ||
  124. !_frameworkInitialization.isCompleted) {
  125. return (query, <EnteFile>[]);
  126. }
  127. // If there's an ongoing request, just update the last query and return its future.
  128. if (_searchScreenRequest != null) {
  129. _latestPendingQuery = query;
  130. return _searchScreenRequest!;
  131. } else {
  132. // No ongoing request, start a new search.
  133. _searchScreenRequest = _getMatchingFiles(query).then((result) {
  134. // Search completed, reset the ongoing request.
  135. _searchScreenRequest = null;
  136. // If there was a new query during the last search, start a new search with the last query.
  137. if (_latestPendingQuery != null) {
  138. final String newQuery = _latestPendingQuery!;
  139. _latestPendingQuery = null; // Reset last query.
  140. // Recursively call search with the latest query.
  141. return searchScreenQuery(newQuery);
  142. }
  143. return (query, result);
  144. });
  145. return _searchScreenRequest!;
  146. }
  147. }
  148. Future<IndexStatus> getIndexStatus() async {
  149. final indexableFileIDs = await FilesDB.instance
  150. .getOwnedFileIDs(Configuration.instance.getUserID()!);
  151. return IndexStatus(
  152. min(_cachedEmbeddings.length, indexableFileIDs.length),
  153. (await _getFileIDsToBeIndexed()).length,
  154. );
  155. }
  156. InitializationState getFrameworkInitializationState() {
  157. if (!_hasInitialized) {
  158. return InitializationState.notInitialized;
  159. }
  160. return _mlFramework.initializationState;
  161. }
  162. Future<void> clearIndexes() async {
  163. await EmbeddingStore.instance.clearEmbeddings(_currentModel);
  164. _logger.info("Indexes cleared for $_currentModel");
  165. }
  166. Future<void> _loadEmbeddings() async {
  167. _logger.info("Pulling cached embeddings");
  168. final startTime = DateTime.now();
  169. _cachedEmbeddings = await EmbeddingsDB.instance.getAll(_currentModel);
  170. final endTime = DateTime.now();
  171. _logger.info(
  172. "Loading ${_cachedEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
  173. );
  174. Bus.instance.fire(EmbeddingCacheUpdatedEvent());
  175. _logger.info("Cached embeddings: " + _cachedEmbeddings.length.toString());
  176. }
  177. Future<void> _backFill() async {
  178. if (!LocalSettings.instance.hasEnabledMagicSearch() ||
  179. !MLFramework.kImageEncoderEnabled) {
  180. return;
  181. }
  182. await _frameworkInitialization.future;
  183. _logger.info("Attempting backfill for image embeddings");
  184. final fileIDs = await _getFileIDsToBeIndexed();
  185. final files = await FilesDB.instance.getUploadedFiles(fileIDs);
  186. _logger.info(files.length.toString() + " to be embedded");
  187. // await _cacheThumbnails(files);
  188. _queue.addAll(files);
  189. unawaited(_pollQueue());
  190. }
  191. Future<void> _cacheThumbnails(List<EnteFile> files) async {
  192. int counter = 0;
  193. const batchSize = 100;
  194. for (var i = 0; i < files.length;) {
  195. final futures = <Future>[];
  196. for (var j = 0; j < batchSize && i < files.length; j++, i++) {
  197. futures.add(getThumbnail(files[i]));
  198. }
  199. await Future.wait(futures);
  200. counter += futures.length;
  201. _logger.info("$counter/${files.length} thumbnails cached");
  202. }
  203. }
  204. Future<List<int>> _getFileIDsToBeIndexed() async {
  205. final uploadedFileIDs = await FilesDB.instance
  206. .getOwnedFileIDs(Configuration.instance.getUserID()!);
  207. final embeddedFileIDs =
  208. await EmbeddingsDB.instance.getFileIDs(_currentModel);
  209. uploadedFileIDs.removeWhere(
  210. (id) => embeddedFileIDs.contains(id),
  211. );
  212. return uploadedFileIDs;
  213. }
  214. Future<void> clearQueue() async {
  215. _queue.clear();
  216. }
  217. Future<List<EnteFile>> _getMatchingFiles(String query) async {
  218. final textEmbedding = await _getTextEmbedding(query);
  219. final queryResults = await _getScores(textEmbedding);
  220. final filesMap = await FilesDB.instance
  221. .getFilesFromIDs(queryResults.map((e) => e.id).toList());
  222. final results = <EnteFile>[];
  223. final ignoredCollections =
  224. CollectionsService.instance.getHiddenCollectionIds();
  225. final deletedEntries = <int>[];
  226. for (final result in queryResults) {
  227. final file = filesMap[result.id];
  228. if (file != null && !ignoredCollections.contains(file.collectionID)) {
  229. results.add(file);
  230. }
  231. if (file == null) {
  232. deletedEntries.add(result.id);
  233. }
  234. }
  235. _logger.info(results.length.toString() + " results");
  236. if (deletedEntries.isNotEmpty) {
  237. unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries));
  238. }
  239. return results;
  240. }
  241. void _addToQueue(EnteFile file) {
  242. if (!LocalSettings.instance.hasEnabledMagicSearch()) {
  243. return;
  244. }
  245. _logger.info("Adding " + file.toString() + " to the queue");
  246. _queue.add(file);
  247. _pollQueue();
  248. }
  249. Future<void> _loadModels() async {
  250. _logger.info("Initializing ML framework");
  251. try {
  252. await _mlFramework.init();
  253. _frameworkInitialization.complete(true);
  254. } catch (e, s) {
  255. _logger.severe("ML framework initialization failed", e, s);
  256. }
  257. _logger.info("ML framework initialized");
  258. }
  259. Future<void> _pollQueue() async {
  260. if (_isComputingEmbeddings) {
  261. return;
  262. }
  263. _isComputingEmbeddings = true;
  264. while (_queue.isNotEmpty) {
  265. await computeImageEmbedding(_queue.removeLast());
  266. }
  267. _isComputingEmbeddings = false;
  268. }
  269. Future<void> computeImageEmbedding(EnteFile file) async {
  270. if (!MLFramework.kImageEncoderEnabled) {
  271. return;
  272. }
  273. if (!_frameworkInitialization.isCompleted) {
  274. return;
  275. }
  276. if (!_mlController.isCompleted) {
  277. _logger.info("Waiting for a green signal from controller...");
  278. await _mlController.future;
  279. }
  280. try {
  281. final thumbnail = await getThumbnailForUploadedFile(file);
  282. if (thumbnail == null) {
  283. _logger.warning("Could not get thumbnail for $file");
  284. return;
  285. }
  286. final filePath = thumbnail.path;
  287. _logger.info("Running clip over $file");
  288. final result = await _mlFramework.getImageEmbedding(filePath);
  289. if (result.length != kEmbeddingLength) {
  290. _logger.severe("Discovered incorrect embedding for $file - $result");
  291. return;
  292. }
  293. final embedding = Embedding(
  294. fileID: file.uploadedFileID!,
  295. model: _currentModel,
  296. embedding: result,
  297. );
  298. await EmbeddingStore.instance.storeEmbedding(
  299. file,
  300. embedding,
  301. );
  302. } catch (e, s) {
  303. _logger.severe(e, s);
  304. }
  305. }
  306. Future<List<double>> _getTextEmbedding(String query) async {
  307. _logger.info("Searching for " + query);
  308. final cachedResult = _queryCache.get(query);
  309. if (cachedResult != null) {
  310. return cachedResult;
  311. }
  312. try {
  313. final result = await _mlFramework.getTextEmbedding(query);
  314. _queryCache.put(query, result);
  315. return result;
  316. } catch (e) {
  317. _logger.severe("Could not get text embedding", e);
  318. return [];
  319. }
  320. }
  321. Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
  322. final startTime = DateTime.now();
  323. final List<QueryResult> queryResults = await _computer.compute(
  324. computeBulkScore,
  325. param: {
  326. "imageEmbeddings": _cachedEmbeddings,
  327. "textEmbedding": textEmbedding,
  328. },
  329. taskName: "computeBulkScore",
  330. );
  331. final endTime = DateTime.now();
  332. _logger.info(
  333. "computingScores took: " +
  334. (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
  335. .toString() +
  336. "ms",
  337. );
  338. return queryResults;
  339. }
  340. Future<Model> _getCurrentModel() async {
  341. if (await isGrapheneOS()) {
  342. return Model.ggmlClip;
  343. } else {
  344. return Model.onnxClip;
  345. }
  346. }
  347. void _startIndexing() {
  348. _logger.info("Start indexing");
  349. if (!_mlController.isCompleted) {
  350. _mlController.complete();
  351. }
  352. }
  353. void _pauseIndexing() {
  354. if (_mlController.isCompleted) {
  355. _logger.info("Pausing indexing");
  356. _mlController = Completer<void>();
  357. }
  358. }
  359. }
  360. List<QueryResult> computeBulkScore(Map args) {
  361. final queryResults = <QueryResult>[];
  362. final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
  363. final textEmbedding = args["textEmbedding"] as List<double>;
  364. for (final imageEmbedding in imageEmbeddings) {
  365. final score = computeScore(
  366. imageEmbedding.embedding,
  367. textEmbedding,
  368. );
  369. if (score >= SemanticSearchService.kScoreThreshold) {
  370. queryResults.add(QueryResult(imageEmbedding.fileID, score));
  371. }
  372. }
  373. queryResults.sort((first, second) => second.score.compareTo(first.score));
  374. return queryResults;
  375. }
  376. double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
  377. assert(
  378. imageEmbedding.length == textEmbedding.length,
  379. "The two embeddings should have the same length",
  380. );
  381. double score = 0;
  382. for (int index = 0; index < imageEmbedding.length; index++) {
  383. score += imageEmbedding[index] * textEmbedding[index];
  384. }
  385. return score;
  386. }
  387. class QueryResult {
  388. final int id;
  389. final double score;
  390. QueryResult(this.id, this.score);
  391. }
  392. class IndexStatus {
  393. final int indexedItems, pendingItems;
  394. IndexStatus(this.indexedItems, this.pendingItems);
  395. }