Search improvements (#1645)

This commit is contained in:
Vishnu Mohandas 2024-01-12 15:51:01 +05:30 committed by GitHub
commit f401199c74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 170 additions and 29 deletions

View file

@ -69,7 +69,7 @@ class MessageLookup extends MessageLookupByLibrary {
"Please drop an email to ${supportEmail} from your registered email address";
static String m16(count, storageSaved) =>
"Your have cleaned up ${Intl.plural(count, one: '${count} duplicate file', other: '${count} duplicate files')}, saving (${storageSaved}!)";
"You have cleaned up ${Intl.plural(count, one: '${count} duplicate file', other: '${count} duplicate files')}, saving (${storageSaved}!)";
static String m17(count, formattedSize) =>
"${count} files, ${formattedSize} each";
@ -1451,6 +1451,8 @@ class MessageLookup extends MessageLookupByLibrary {
"viewer": MessageLookupByLibrary.simpleMessage("Viewer"),
"visitWebToManage": MessageLookupByLibrary.simpleMessage(
"Please visit web.ente.io to manage your subscription"),
"waitingForWifi":
MessageLookupByLibrary.simpleMessage("Waiting for WiFi..."),
"weAreOpenSource":
MessageLookupByLibrary.simpleMessage("We are open source!"),
"weDontSupportEditingPhotosAndAlbumsThatYouDont":

View file

@ -704,6 +704,7 @@ class MessageLookup extends MessageLookupByLibrary {
MessageLookupByLibrary.simpleMessage("正在加载 EXIF 数据..."),
"loadingGallery": MessageLookupByLibrary.simpleMessage("正在加载图库..."),
"loadingMessage": MessageLookupByLibrary.simpleMessage("正在加载您的照片..."),
"loadingModel": MessageLookupByLibrary.simpleMessage("正在下载模型..."),
"localGallery": MessageLookupByLibrary.simpleMessage("本地相册"),
"location": MessageLookupByLibrary.simpleMessage("地理位置"),
"locationName": MessageLookupByLibrary.simpleMessage("地点名称"),

View file

@ -2886,6 +2886,16 @@ class S {
);
}
/// `Waiting for WiFi...`
String get waitingForWifi {
return Intl.message(
'Waiting for WiFi...',
name: 'waitingForWifi',
desc: '',
args: [],
);
}
/// `Status`
String get status {
return Intl.message(
@ -3523,10 +3533,10 @@ class S {
);
}
/// `Your have cleaned up {count, plural, one{{count} duplicate file} other{{count} duplicate files}}, saving ({storageSaved}!)`
/// `You have cleaned up {count, plural, one{{count} duplicate file} other{{count} duplicate files}}, saving ({storageSaved}!)`
String duplicateFileCountWithStorageSaved(int count, String storageSaved) {
return Intl.message(
'Your have cleaned up ${Intl.plural(count, one: '$count duplicate file', other: '$count duplicate files')}, saving ($storageSaved!)',
'You have cleaned up ${Intl.plural(count, one: '$count duplicate file', other: '$count duplicate files')}, saving ($storageSaved!)',
name: 'duplicateFileCountWithStorageSaved',
desc:
'The text to display when the user has successfully cleaned up duplicate files',

View file

@ -410,6 +410,7 @@
"magicSearch": "Magic search",
"magicSearchDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.",
"loadingModel": "Downloading models...",
"waitingForWifi": "Waiting for WiFi...",
"status": "Status",
"indexedItems": "Indexed items",
"pendingItems": "Pending items",

View file

@ -50,9 +50,15 @@ class EmbeddingStore {
/// Uploads every locally stored embedding that has not been synced yet.
///
/// The files for all pending embeddings are resolved with a single batched
/// lookup. A failure while pushing one embedding is logged and does not stop
/// the remaining embeddings from being pushed.
Future<void> pushEmbeddings() async {
  final pendingItems = await EmbeddingsDB.instance.getUnsyncedEmbeddings();
  if (pendingItems.isEmpty) {
    // Nothing to upload; skip the file lookup entirely.
    return;
  }
  final fileMap = await FilesDB.instance
      .getFilesFromIDs(pendingItems.map((e) => e.fileID).toList());
  _logger.info("Pushing ${pendingItems.length} embeddings");
  for (final item in pendingItems) {
    final file = fileMap[item.fileID];
    if (file == null) {
      // No file was returned for this embedding, so it cannot be encrypted
      // with the file's key. Skip it instead of failing on a null assertion.
      _logger.warning("No file found for embedding of ${item.fileID}");
      continue;
    }
    try {
      await _pushEmbedding(file, item);
    } catch (e, s) {
      _logger.severe(e, s);
    }
  }
}
@ -67,6 +73,7 @@ class EmbeddingStore {
}
Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
_logger.info("Pushing embedding for $file");
final encryptionKey = getFileKey(file);
final embeddingJSON = jsonEncode(embedding.embedding);
final encryptedEmbedding = await CryptoUtil.encryptChaCha(
@ -86,7 +93,7 @@ class EmbeddingStore {
"decryptionHeader": header,
},
);
final updationTime = response.data["updationTime"];
final updationTime = response.data["updatedAt"];
embedding.updationTime = updationTime;
await EmbeddingsDB.instance.put(embedding);
} catch (e, s) {
@ -138,7 +145,10 @@ class EmbeddingStore {
for (final embedding in remoteEmbeddings) {
final file = fileMap[embedding.fileID];
final fileKey = getFileKey(file!);
if (file == null) {
continue;
}
final fileKey = getFileKey(file);
final input = EmbeddingsDecoderInput(embedding, fileKey);
inputs.add(input);
}

View file

@ -11,6 +11,8 @@ class GGML extends MLFramework {
final _computer = Computer.shared();
final _logger = Logger("GGML");
GGML(super.shouldDownloadOverMobileData);
@override
String getImageModelRemotePath() {
return kModelBucketEndpoint + kImageModel;

View file

@ -1,16 +1,47 @@
import "dart:async";
import "dart:io";
import "package:connectivity_plus/connectivity_plus.dart";
import "package:flutter/services.dart";
import "package:logging/logging.dart";
import "package:path/path.dart";
import "package:path_provider/path_provider.dart";
import "package:photos/core/errors.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/core/network/network.dart";
import "package:photos/events/event.dart";
abstract class MLFramework {
static const kImageEncoderEnabled = true;
static const kMaximumRetrials = 3;
final _logger = Logger("MLFramework");
static final _logger = Logger("MLFramework");
final bool shouldDownloadOverMobileData;
InitializationState _state = InitializationState.notInitialized;
final _initializationCompleter = Completer<void>();
/// Creates the framework and starts listening for connectivity changes.
///
/// [shouldDownloadOverMobileData] is consulted by [_canDownload] to decide
/// whether a mobile connection is acceptable for downloading.
MLFramework(this.shouldDownloadOverMobileData) {
// NOTE(review): the subscription returned by listen() is not stored here, so
// this constructor gives no way to cancel it later — confirm that is intended.
Connectivity()
.onConnectivityChanged
.listen((ConnectivityResult result) async {
_logger.info("Connectivity changed to $result");
// Only retry when an earlier attempt stopped in waitingForNetwork and the
// current connection now permits downloading.
if (_state == InitializationState.waitingForNetwork &&
await _canDownload()) {
// Fire-and-forget: the listener does not wait for initialization to finish.
unawaited(init());
}
});
}
/// The current stage of model initialization.
InitializationState get initializationState => _state;

/// Stores [state] as the current stage and announces it on the event bus.
///
/// The field is written before the event is fired so that a listener which
/// reads [initializationState] while handling the event sees the new value
/// rather than the previous one.
set _initState(InitializationState state) {
  _state = state;
  _logger.info("Init state is $state");
  Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state));
}
/// Returns the path of the Image Model hosted remotely
String getImageModelRemotePath();
@ -35,8 +66,18 @@ abstract class MLFramework {
/// initialization. For eg. if you wish to load the model from `/assets`
/// instead of a CDN.
Future<void> init() async {
  try {
    // Both models are prepared concurrently; each is initialized only once.
    await Future.wait([_initImageModel(), _initTextModel()]);
  } catch (e, s) {
    _logger.warning(e, s);
    if (e is WiFiUnavailableError) {
      // Initialization is deferred. The returned future completes once a
      // later init() call succeeds.
      return _initializationCompleter.future;
    } else {
      rethrow;
    }
  }
  _initState = InitializationState.initialized;
  // init() can be invoked again by the connectivity listener, and completing
  // a Completer twice throws a StateError, so complete it only once.
  if (!_initializationCompleter.isCompleted) {
    _initializationCompleter.complete();
  }
}
// Releases any resources held by the framework
@ -63,27 +104,33 @@ abstract class MLFramework {
if (!kImageEncoderEnabled) {
return;
}
_initState = InitializationState.initializingImageModel;
final path = await _getLocalImageModelPath();
if (File(path).existsSync()) {
if (await File(path).exists()) {
await loadImageModel(path);
} else {
_initState = InitializationState.downloadingImageModel;
final tempFile = File(path + ".temp");
await _downloadFile(getImageModelRemotePath(), tempFile.path);
await tempFile.rename(path);
await loadImageModel(path);
}
_initState = InitializationState.initializedImageModel;
}
/// Loads the text model, downloading it first when no local copy exists.
///
/// The download is written to a `.temp` file and renamed to the final path
/// only after it finishes, so the final path is not created before the
/// download has completed.
Future<void> _initTextModel() async {
  final path = await _getLocalTextModelPath();
  _initState = InitializationState.initializingTextModel;
  if (!await File(path).exists()) {
    _initState = InitializationState.downloadingTextModel;
    final tempFile = File("$path.temp");
    await _downloadFile(getTextModelRemotePath(), tempFile.path);
    await tempFile.rename(path);
  }
  await loadTextModel(path);
  _initState = InitializationState.initializedTextModel;
}
Future<String> _getLocalImageModelPath() async {
@ -103,6 +150,10 @@ abstract class MLFramework {
String savePath, {
int trialCount = 1,
}) async {
if (!await _canDownload()) {
_initState = InitializationState.waitingForNetwork;
throw WiFiUnavailableError();
}
_logger.info("Downloading " + url);
final existingFile = File(savePath);
if (await existingFile.exists()) {
@ -120,6 +171,12 @@ abstract class MLFramework {
}
}
/// Whether model files may be downloaded over the current connection.
///
/// Resolves to `false` only when the reported connection is mobile data and
/// [shouldDownloadOverMobileData] is `false`.
Future<bool> _canDownload() async {
  final connection = await Connectivity().checkConnectivity();
  final isOnMobileData = connection == ConnectivityResult.mobile;
  return !isOnMobileData || shouldDownloadOverMobileData;
}
Future<String> getAccessiblePathForAsset(
String assetPath,
String tempName,
@ -131,3 +188,21 @@ abstract class MLFramework {
return file.path;
}
}
/// Event fired on the event bus each time the framework's initialization
/// state is assigned through the `_initState` setter.
class MLFrameworkInitializationUpdateEvent extends Event {
/// The state that was assigned when this event was fired.
final InitializationState state;
MLFrameworkInitializationUpdateEvent(this.state);
}
/// Stages reported by [MLFramework] while its models are being prepared.
enum InitializationState {
// Initial value, before any initialization step has run.
notInitialized,
// A download was refused because the current connection does not permit it.
waitingForNetwork,
// The image model has no local copy and is being downloaded.
downloadingImageModel,
// Image model initialization has started.
initializingImageModel,
// The image model has been loaded.
initializedImageModel,
// The text model has no local copy and is being downloaded.
downloadingTextModel,
// Text model initialization has started.
initializingTextModel,
// The text model has been loaded.
initializedTextModel,
// Both model initializations have completed.
initialized,
}

View file

@ -17,6 +17,8 @@ class ONNX extends MLFramework {
int _textEncoderAddress = 0;
int _imageEncoderAddress = 0;
ONNX(super.shouldDownloadOverMobileData);
@override
String getImageModelRemotePath() {
return kModelBucketEndpoint + kImageModel;

View file

@ -38,11 +38,11 @@ class SemanticSearchService {
final _logger = Logger("SemanticSearchService");
final _queue = Queue<EnteFile>();
final _mlFramework = kCurrentModel == Model.onnxClip ? ONNX() : GGML();
final _frameworkInitialization = Completer<bool>();
final _embeddingLoaderDebouncer =
Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
late MLFramework _mlFramework;
bool _hasInitialized = false;
bool _isComputingEmbeddings = false;
bool _isSyncing = false;
@ -61,6 +61,11 @@ class SemanticSearchService {
return;
}
_hasInitialized = true;
final shouldDownloadOverMobileData =
Configuration.instance.shouldBackupOverMobileData();
_mlFramework = kCurrentModel == Model.onnxClip
? ONNX(shouldDownloadOverMobileData)
: GGML(shouldDownloadOverMobileData);
await EmbeddingsDB.instance.init();
await EmbeddingStore.instance.init();
await _loadEmbeddings();
@ -145,8 +150,8 @@ class SemanticSearchService {
);
}
Future<bool> getFrameworkInitializationStatus() {
return _frameworkInitialization.future;
InitializationState getFrameworkInitializationState() {
return _mlFramework.initializationState;
}
Future<void> clearIndexes() async {

View file

@ -6,6 +6,7 @@ import "package:photos/core/event_bus.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/generated/l10n.dart";
import "package:photos/services/feature_flag_service.dart";
import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
import "package:photos/services/semantic_search/semantic_search_service.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/common/loading_widget.dart";
@ -29,6 +30,32 @@ class MachineLearningSettingsPage extends StatefulWidget {
class _MachineLearningSettingsPageState
extends State<MachineLearningSettingsPage> {
late InitializationState _state;
late StreamSubscription<MLFrameworkInitializationUpdateEvent>
_eventSubscription;
/// Reads the current framework state and subscribes to subsequent updates.
@override
void initState() {
  super.initState();
  _fetchState();
  _eventSubscription =
      Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((_) {
    // Mutate the state inside the setState callback so the field change and
    // the rebuild request stay together.
    setState(_fetchState);
  });
}
/// Copies the framework's current initialization state into [_state].
void _fetchState() {
  final service = SemanticSearchService.instance;
  _state = service.getFrameworkInitializationState();
}
/// Cancels the initialization-event subscription before tearing down.
@override
void dispose() {
  // Release this State's own resources first; super.dispose() must be the
  // last call in an overridden dispose().
  _eventSubscription.cancel();
  super.dispose();
}
@override
Widget build(BuildContext context) {
return Scaffold(
@ -118,17 +145,9 @@ class _MachineLearningSettingsPageState
hasEnabled
? Column(
children: [
FutureBuilder(
future: SemanticSearchService.instance
.getFrameworkInitializationStatus(),
builder: (BuildContext context, AsyncSnapshot snapshot) {
if (snapshot.hasData) {
return const MagicSearchIndexStatsWidget();
} else {
return const ModelLoadingState();
}
},
),
_state == InitializationState.initialized
? const MagicSearchIndexStatsWidget()
: ModelLoadingState(_state),
const SizedBox(
height: 12,
),
@ -158,7 +177,12 @@ class _MachineLearningSettingsPageState
}
class ModelLoadingState extends StatelessWidget {
const ModelLoadingState({super.key});
final InitializationState state;
const ModelLoadingState(
this.state, {
Key? key,
}) : super(key: key);
@override
Widget build(BuildContext context) {
@ -167,7 +191,7 @@ class ModelLoadingState extends StatelessWidget {
MenuSectionTitle(title: S.of(context).status),
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: S.of(context).loadingModel,
title: _getTitle(context),
),
trailingWidget: EnteLoadingWidget(
size: 12,
@ -180,6 +204,15 @@ class ModelLoadingState extends StatelessWidget {
],
);
}
/// Returns the caption for the current [state].
///
/// The "waiting for WiFi" text is used while waiting for a usable network;
/// every other state shows the generic model-loading text.
String _getTitle(BuildContext context) {
  final strings = S.of(context);
  final isWaiting = state == InitializationState.waitingForNetwork;
  return isWaiting ? strings.waitingForWifi : strings.loadingModel;
}
}
class MagicSearchIndexStatsWidget extends StatefulWidget {

View file

@ -12,7 +12,7 @@ description: ente photos application
# Read more about iOS versioning at
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
version: 0.8.33+553
version: 0.8.35+555
environment:
sdk: ">=3.0.0 <4.0.0"