diff --git a/lib/services/semantic_search_service.dart b/lib/services/semantic_search_service.dart
index 23f2d3303..b4e56a0c1 100644
--- a/lib/services/semantic_search_service.dart
+++ b/lib/services/semantic_search_service.dart
@@ -18,6 +18,9 @@ class SemanticSearchService {
   SemanticSearchService._privateConstructor();
   static final Computer _computer = Computer.shared();
 
+  /// Maximum number of files passed to one [_computeImageEmbeddings] call.
+  static const int batchSize = 4;
+
   bool hasLoaded = false;
   final _logger = Logger("SemanticSearchService");
   Future<List<EnteFile>>? _ongoingRequest;
@@ -123,7 +126,7 @@ class SemanticSearchService {
 
   Future _loadModel() async {
     const modelPath =
-        "assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.q4_0.bin";
+        "assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.f16.bin";
 
     final path = await _getAccessiblePathForAsset(modelPath, "model.bin");
     final startTime = DateTime.now();
@@ -155,51 +158,89 @@ class SemanticSearchService {
   Future _computeMissingEmbeddings() async {
     final files = await FilesDB.instance.getFilesWithoutEmbeddings();
     _logger.info(files.length.toString() + " pending to be embedded");
+    final List<EnteFile> batch = [];
     for (final file in files) {
-      await _computeImageEmbedding(file);
+      batch.add(file);
+      // Flush as soon as the batch is full so that no file is skipped.
+      if (batch.length >= batchSize) {
+        await _computeImageEmbeddings(batch);
+        batch.clear();
+      }
+    }
+    // Flush the trailing partial batch (fewer than batchSize files).
+    if (batch.isNotEmpty) {
+      await _computeImageEmbeddings(batch);
     }
   }
 
-  Future _computeImageEmbedding(EnteFile file) async {
+  /// Computes and stores embeddings for the entries of [files] that have no
+  /// embedding in [FilesDB] yet.
+  ///
+  /// Returns without doing anything while [hasLoaded] is false.
+  Future<void> _computeImageEmbeddings(List<EnteFile> files) async {
     if (!hasLoaded) {
       return;
     }
-    // _logger.info("Running clip");
-    final imagePath = (await getThumbnailFile(file))!.path;
+    _logger.info("Running clip over " + files.length.toString() + " items");
+    final List<EnteFile> filesToBeIndexed = [];
+    final List<String> filePaths = [];
     final startTime = DateTime.now();
-    // ignore: prefer_typing_uninitialized_variables
-    var imageEmbedding;
     final embeddings = await FilesDB.instance.getAllEmbeddings();
-    bool hasCachedEmbedding = false;
-    for (final embedding in embeddings) {
-      if (embedding.id == file.generatedID) {
-        imageEmbedding = embedding.embedding;
-        hasCachedEmbedding = true;
-        _logger.info("Found cached embedding");
+    for (final file in files) {
+      bool hasCachedEmbedding = false;
+      for (final embedding in embeddings) {
+        if (embedding.id == file.generatedID) {
+          _logger.info("Found cached embedding");
+          hasCachedEmbedding = true;
+        }
+      }
+      if (!hasCachedEmbedding) {
+        filesToBeIndexed.add(file);
+        filePaths.add((await getThumbnailFile(file))!.path);
       }
     }
-    if (!hasCachedEmbedding) {
-      imageEmbedding ??= await _computer.compute(
-        createImageEmbedding,
-        param: {
-          "imagePath": imagePath,
-        },
-        taskName: "createImageEmbedding",
+    if (filePaths.isEmpty) {
+      // Every file in this batch already has a stored embedding.
+      return;
+    }
+    final imageEmbeddings = await _computer.compute(
+      createImageEmbeddings,
+      param: {
+        "imagePaths": filePaths,
+      },
+      taskName: "createImageEmbedding",
+    ) as List<List<double>>;
+    for (int i = 0; i < imageEmbeddings.length; i++) {
+      await FilesDB.instance.insertEmbedding(
+        Embedding(
+          // Index into filesToBeIndexed, not files: imageEmbeddings only
+          // covers the files that had no cached embedding.
+          filesToBeIndexed[i].generatedID!,
+          imageEmbeddings[i],
+          -1,
+        ),
       );
-      await FilesDB.instance
-          .insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
     }
 
     final endTime = DateTime.now();
     _logger.info(
-      "createImageEmbedding took: " +
+      "createImageEmbeddings took: " +
           (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
               .toString() +
-          "ms",
+          "ms for " +
+          imageEmbeddings.length.toString() +
+          " items",
     );
   }
 }
 
+/// Function handed to `_computer.compute`; forwards `args["imagePaths"]` to
+/// [CLIP.createBatchImageEmbedding] and returns its result.
+List<List<double>> createImageEmbeddings(Map args) {
+  return CLIP.createBatchImageEmbedding(args["imagePaths"]);
+}
+
 List createImageEmbedding(Map args) {
   return CLIP.createImageEmbedding(args["imagePath"]);
 }
diff --git a/plugins/clip_ggml b/plugins/clip_ggml
index 84c5a499e..c15bc3408 160000
--- a/plugins/clip_ggml
+++ b/plugins/clip_ggml
@@ -1 +1 @@
-Subproject commit 84c5a499e700cc72e546c42d771d4a92a7a21ec2
+Subproject commit c15bc340878fbc679027083c962f1fcbfe68b263