diff --git a/lib/services/semantic_search/semantic_search_service.dart b/lib/services/semantic_search/semantic_search_service.dart index dd01b5a61..096a55e99 100644 --- a/lib/services/semantic_search/semantic_search_service.dart +++ b/lib/services/semantic_search/semantic_search_service.dart @@ -110,17 +110,16 @@ class SemanticSearchService { ); startTime = DateTime.now(); - final queryResults = []; - for (final embedding in _cachedEmbeddings) { - final score = computeScore({ - "imageEmbedding": embedding.embedding, + + final List queryResults = await _computer.compute( + computeBulkScore, + param: { + "imageEmbeddings": _cachedEmbeddings, "textEmbedding": textEmbedding, - }); - if (score >= kScoreThreshold) { - queryResults.add(QueryResult(embedding.fileID, score)); - } - } - queryResults.sort((first, second) => second.score.compareTo(first.score)); + }, + taskName: "computeBulkScore", + ); + endTime = DateTime.now(); _logger.info( "computingScores took: " + @@ -260,11 +259,22 @@ List createTextEmbedding(Map args) { return CLIP.createTextEmbedding(args["text"]); } -double computeScore(Map args) { - return CLIP.computeScore( - args["imageEmbedding"] as List, - args["textEmbedding"] as List, - ); +List computeBulkScore(Map args) { + final queryResults = []; + final imageEmbeddings = args["imageEmbeddings"] as List; + final textEmbedding = args["textEmbedding"] as List; + for (final imageEmbedding in imageEmbeddings) { + final score = CLIP.computeScore( + imageEmbedding.embedding, + textEmbedding, + ); + if (score >= 0.23) { + queryResults.add(QueryResult(imageEmbedding.fileID, score)); + } + } + + queryResults.sort((first, second) => second.score.compareTo(first.score)); + return queryResults; } class QueryResult {