Convert cached magic files to GenericSearchResult and surface it on search tab
This commit is contained in:
parent
9134e75bf0
commit
71f1494444
6 changed files with 105 additions and 15 deletions
|
@ -13,6 +13,7 @@ import "package:photos/models/search/generic_search_result.dart";
|
|||
import "package:photos/models/search/search_result.dart";
|
||||
import "package:photos/models/typedefs.dart";
|
||||
import "package:photos/services/collections_service.dart";
|
||||
import "package:photos/services/magic_cache_service.dart";
|
||||
import "package:photos/services/search_service.dart";
|
||||
import "package:photos/ui/viewer/gallery/collection_page.dart";
|
||||
import "package:photos/ui/viewer/location/add_location_sheet.dart";
|
||||
|
@ -40,7 +41,7 @@ enum SectionType {
|
|||
face,
|
||||
location,
|
||||
// Grouping based on ML or manual tagging
|
||||
content,
|
||||
magic,
|
||||
// includes year, month , day, event ResultType
|
||||
moment,
|
||||
album,
|
||||
|
@ -56,7 +57,7 @@ extension SectionTypeExtensions on SectionType {
|
|||
switch (this) {
|
||||
case SectionType.face:
|
||||
return S.of(context).faces;
|
||||
case SectionType.content:
|
||||
case SectionType.magic:
|
||||
return S.of(context).contents;
|
||||
case SectionType.moment:
|
||||
return S.of(context).moments;
|
||||
|
@ -77,7 +78,7 @@ extension SectionTypeExtensions on SectionType {
|
|||
switch (this) {
|
||||
case SectionType.face:
|
||||
return S.of(context).searchFaceEmptySection;
|
||||
case SectionType.content:
|
||||
case SectionType.magic:
|
||||
return "Contents";
|
||||
case SectionType.moment:
|
||||
return S.of(context).searchDatesEmptySection;
|
||||
|
@ -100,7 +101,7 @@ extension SectionTypeExtensions on SectionType {
|
|||
switch (this) {
|
||||
case SectionType.face:
|
||||
return false;
|
||||
case SectionType.content:
|
||||
case SectionType.magic:
|
||||
return false;
|
||||
case SectionType.moment:
|
||||
return false;
|
||||
|
@ -121,7 +122,7 @@ extension SectionTypeExtensions on SectionType {
|
|||
switch (this) {
|
||||
case SectionType.face:
|
||||
return true;
|
||||
case SectionType.content:
|
||||
case SectionType.magic:
|
||||
return false;
|
||||
case SectionType.moment:
|
||||
return false;
|
||||
|
@ -143,7 +144,7 @@ extension SectionTypeExtensions on SectionType {
|
|||
case SectionType.face:
|
||||
// todo: later
|
||||
return "Setup";
|
||||
case SectionType.content:
|
||||
case SectionType.magic:
|
||||
// todo: later
|
||||
return "Add tags";
|
||||
case SectionType.moment:
|
||||
|
@ -165,7 +166,7 @@ extension SectionTypeExtensions on SectionType {
|
|||
switch (this) {
|
||||
case SectionType.face:
|
||||
return Icons.adaptive.arrow_forward_outlined;
|
||||
case SectionType.content:
|
||||
case SectionType.magic:
|
||||
return null;
|
||||
case SectionType.moment:
|
||||
return null;
|
||||
|
@ -247,8 +248,8 @@ extension SectionTypeExtensions on SectionType {
|
|||
case SectionType.face:
|
||||
return Future.value(List<GenericSearchResult>.empty());
|
||||
|
||||
case SectionType.content:
|
||||
return Future.value(List<GenericSearchResult>.empty());
|
||||
case SectionType.magic:
|
||||
return MagicCacheService.instance.getMagicGenericSearchResult();
|
||||
|
||||
case SectionType.moment:
|
||||
return SearchService.instance.getRandomMomentsSearchResults(context);
|
||||
|
|
|
@ -267,6 +267,48 @@ class SemanticSearchService {
|
|||
return results;
|
||||
}
|
||||
|
||||
Future<List<int>> getMatchingFileIDs(String query, double minScore) async {
|
||||
final textEmbedding = await _getTextEmbedding(query);
|
||||
|
||||
final queryResults =
|
||||
await _getScores(textEmbedding, scoreThreshold: minScore);
|
||||
|
||||
final filesMap = await FilesDB.instance.getFilesFromIDs(
|
||||
queryResults
|
||||
.map(
|
||||
(e) => e.id,
|
||||
)
|
||||
.toList(),
|
||||
);
|
||||
final results = <EnteFile>[];
|
||||
|
||||
final ignoredCollections =
|
||||
CollectionsService.instance.getHiddenCollectionIds();
|
||||
final deletedEntries = <int>[];
|
||||
for (final result in queryResults) {
|
||||
final file = filesMap[result.id];
|
||||
if (file != null && !ignoredCollections.contains(file.collectionID)) {
|
||||
results.add(file);
|
||||
}
|
||||
if (file == null) {
|
||||
deletedEntries.add(result.id);
|
||||
}
|
||||
}
|
||||
|
||||
_logger.info(results.length.toString() + " results");
|
||||
|
||||
if (deletedEntries.isNotEmpty) {
|
||||
unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries));
|
||||
}
|
||||
|
||||
final matchingFileIDs = <int>[];
|
||||
for (EnteFile file in results) {
|
||||
matchingFileIDs.add(file.uploadedFileID!);
|
||||
}
|
||||
|
||||
return matchingFileIDs;
|
||||
}
|
||||
|
||||
void _addToQueue(EnteFile file) {
|
||||
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
|
||||
return;
|
||||
|
@ -355,13 +397,17 @@ class SemanticSearchService {
|
|||
}
|
||||
}
|
||||
|
||||
Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
|
||||
Future<List<QueryResult>> _getScores(
|
||||
List<double> textEmbedding, {
|
||||
double? scoreThreshold,
|
||||
}) async {
|
||||
final startTime = DateTime.now();
|
||||
final List<QueryResult> queryResults = await _computer.compute(
|
||||
computeBulkScore,
|
||||
param: {
|
||||
"imageEmbeddings": _cachedEmbeddings,
|
||||
"textEmbedding": textEmbedding,
|
||||
"scoreThreshold": scoreThreshold,
|
||||
},
|
||||
taskName: "computeBulkScore",
|
||||
);
|
||||
|
@ -402,12 +448,14 @@ List<QueryResult> computeBulkScore(Map args) {
|
|||
final queryResults = <QueryResult>[];
|
||||
final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
|
||||
final textEmbedding = args["textEmbedding"] as List<double>;
|
||||
final scoreThreshold = args["scoreThreshold"] as double? ??
|
||||
SemanticSearchService.kScoreThreshold;
|
||||
for (final imageEmbedding in imageEmbeddings) {
|
||||
final score = computeScore(
|
||||
imageEmbedding.embedding,
|
||||
textEmbedding,
|
||||
);
|
||||
if (score >= SemanticSearchService.kScoreThreshold) {
|
||||
if (score >= scoreThreshold) {
|
||||
queryResults.add(QueryResult(imageEmbedding.fileID, score));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,11 @@ import "dart:convert";
|
|||
import 'dart:math';
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/models/search/generic_search_result.dart";
|
||||
import "package:photos/models/search/search_types.dart";
|
||||
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
|
||||
import "package:photos/services/search_service.dart";
|
||||
import "package:shared_preferences/shared_preferences.dart";
|
||||
|
||||
const _promptsJson = {
|
||||
|
@ -64,6 +68,24 @@ class MagicCache {
|
|||
}
|
||||
}
|
||||
|
||||
extension MagicCacheServiceExtension on MagicCache {
|
||||
Future<GenericSearchResult> toGenericSearchResult() async {
|
||||
final allEnteFiles = await SearchService.instance.getAllFiles();
|
||||
final enteFilesInMagicCache = <EnteFile>[];
|
||||
for (EnteFile file in allEnteFiles) {
|
||||
if (file.uploadedFileID != null &&
|
||||
fileUploadedIDs.contains(file.uploadedFileID as int)) {
|
||||
enteFilesInMagicCache.add(file);
|
||||
}
|
||||
}
|
||||
return GenericSearchResult(
|
||||
ResultType.magic,
|
||||
title,
|
||||
enteFilesInMagicCache,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class MagicCacheService {
|
||||
static const _key = "magic_cache";
|
||||
late SharedPreferences prefs;
|
||||
|
@ -132,4 +154,19 @@ class MagicCacheService {
|
|||
}
|
||||
return numbers;
|
||||
}
|
||||
|
||||
Future<List<GenericSearchResult>> getMagicGenericSearchResult() async {
|
||||
await Future.delayed(const Duration(seconds: 10));
|
||||
final magicCaches = await getMagicCache();
|
||||
if (magicCaches == null) {
|
||||
_logger.info("No magic cache found");
|
||||
return [];
|
||||
}
|
||||
final List<GenericSearchResult> genericSearchResults = [];
|
||||
for (MagicCache magicCache in magicCaches) {
|
||||
final genericSearchResult = await magicCache.toGenericSearchResult();
|
||||
genericSearchResults.add(genericSearchResult);
|
||||
}
|
||||
return genericSearchResults;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,8 +79,7 @@ class _AllSectionsExamplesProviderState
|
|||
_logger.info("'_debounceTimer: reloading all sections in search tab");
|
||||
final allSectionsExamples = <Future<List<SearchResult>>>[];
|
||||
for (SectionType sectionType in SectionType.values) {
|
||||
if (sectionType == SectionType.face ||
|
||||
sectionType == SectionType.content) {
|
||||
if (sectionType == SectionType.face) {
|
||||
continue;
|
||||
}
|
||||
allSectionsExamples.add(
|
||||
|
|
|
@ -22,7 +22,7 @@ class _NoResultWidgetState extends State<NoResultWidget> {
|
|||
searchTypes = SectionType.values.toList(growable: true);
|
||||
// remove face and content sectionType
|
||||
searchTypes.remove(SectionType.face);
|
||||
searchTypes.remove(SectionType.content);
|
||||
searchTypes.remove(SectionType.magic);
|
||||
}
|
||||
|
||||
@override
|
||||
|
|
|
@ -78,7 +78,7 @@ class _AllSearchSectionsState extends State<AllSearchSections> {
|
|||
final searchTypes = SectionType.values.toList(growable: true);
|
||||
// remove face and content sectionType
|
||||
searchTypes.remove(SectionType.face);
|
||||
searchTypes.remove(SectionType.content);
|
||||
// searchTypes.remove(SectionType.magic);
|
||||
return Padding(
|
||||
padding: const EdgeInsets.only(top: 8),
|
||||
child: Stack(
|
||||
|
@ -131,6 +131,11 @@ class _AllSearchSectionsState extends State<AllSearchSections> {
|
|||
snapshot.data!.elementAt(index)
|
||||
as List<GenericSearchResult>,
|
||||
);
|
||||
case SectionType.magic:
|
||||
return MomentsSection(
|
||||
snapshot.data!.elementAt(index)
|
||||
as List<GenericSearchResult>,
|
||||
);
|
||||
default:
|
||||
const SizedBox.shrink();
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue