machine_learning_settings_page.dart 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. import "dart:async";
  2. import "dart:math" show max, min;
  3. import "package:flutter/material.dart";
  4. import "package:intl/intl.dart";
  5. import "package:logging/logging.dart";
  6. import "package:photos/core/event_bus.dart";
  7. import 'package:photos/events/embedding_updated_event.dart';
  8. import "package:photos/face/db.dart";
  9. import "package:photos/generated/l10n.dart";
  10. import "package:photos/models/ml/ml_versions.dart";
  11. import "package:photos/service_locator.dart";
  12. import "package:photos/services/machine_learning/face_ml/face_ml_service.dart";
  13. import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
  14. import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
  15. import "package:photos/services/remote_assets_service.dart";
  16. import "package:photos/theme/ente_theme.dart";
  17. import "package:photos/ui/common/loading_widget.dart";
  18. import "package:photos/ui/components/buttons/icon_button_widget.dart";
  19. import "package:photos/ui/components/captioned_text_widget.dart";
  20. import "package:photos/ui/components/menu_item_widget/menu_item_widget.dart";
  21. import "package:photos/ui/components/menu_section_description_widget.dart";
  22. import "package:photos/ui/components/menu_section_title.dart";
  23. import "package:photos/ui/components/title_bar_title_widget.dart";
  24. import "package:photos/ui/components/title_bar_widget.dart";
  25. import "package:photos/ui/components/toggle_switch_widget.dart";
  26. import "package:photos/utils/data_util.dart";
  27. import "package:photos/utils/local_settings.dart";
  /// Logger shared by the widgets and states declared in this file.
  final Logger _logger = Logger("MachineLearningSettingsPage");
  /// Settings screen for the on-device machine learning features
  /// (magic search and, when the flag is enabled, face recognition).
  class MachineLearningSettingsPage extends StatefulWidget {
    const MachineLearningSettingsPage({super.key});

    @override
    State<MachineLearningSettingsPage> createState() {
      return _MachineLearningSettingsPageState();
    }
  }
  /// State for [MachineLearningSettingsPage].
  ///
  /// Holds the semantic-search framework initialization state and rebuilds
  /// whenever an [MLFrameworkInitializationUpdateEvent] is published on the
  /// event bus.
  class _MachineLearningSettingsPageState
      extends State<MachineLearningSettingsPage> {
    /// Latest value returned by
    /// `SemanticSearchService.instance.getFrameworkInitializationState()`.
    late InitializationState _state;
    late StreamSubscription<MLFrameworkInitializationUpdateEvent>
        _eventSubscription;

    @override
    void initState() {
      super.initState();
      // Logged once when the page opens, not on every rebuild.
      _logger.info("On page open, facesFlag: ${flagService.faceSearchEnabled}");
      _eventSubscription =
          Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((_) {
        if (mounted) {
          setState(_fetchState);
        }
      });
      _fetchState();
    }

    /// Refreshes [_state] from the semantic search service.
    void _fetchState() {
      _state = SemanticSearchService.instance.getFrameworkInitializationState();
    }

    @override
    void dispose() {
      // Release resources before handing off to the framework.
      _eventSubscription.cancel();
      super.dispose();
    }

    @override
    Widget build(BuildContext context) {
      final bool facesFlag = flagService.faceSearchEnabled;
      return Scaffold(
        body: CustomScrollView(
          primary: false,
          slivers: <Widget>[
            TitleBarWidget(
              flexibleSpaceTitle: TitleBarTitleWidget(
                title: S.of(context).machineLearning,
              ),
              actionIcons: [
                IconButtonWidget(
                  icon: Icons.close_outlined,
                  iconButtonType: IconButtonType.secondary,
                  onTap: () {
                    // Pops three routes in a row.
                    // NOTE(review): presumably this page plus the settings
                    // screens that led here; confirm against the nav stack.
                    Navigator.pop(context);
                    Navigator.pop(context);
                    Navigator.pop(context);
                  },
                ),
              ],
            ),
            SliverList(
              delegate: SliverChildBuilderDelegate(
                (delegateBuildContext, index) {
                  return Padding(
                    padding: const EdgeInsets.symmetric(
                      horizontal: 16,
                      vertical: 20,
                    ),
                    child: Column(
                      mainAxisSize: MainAxisSize.min,
                      children: [
                        _getMagicSearchSettings(context),
                        const SizedBox(height: 12),
                        if (facesFlag) _getFacesSearchSettings(context),
                      ],
                    ),
                  );
                },
                childCount: 1,
              ),
            ),
          ],
        ),
      );
    }

    /// Builds the magic search toggle, its description and, when enabled,
    /// either the index stats or the model loading status.
    ///
    /// Internal users additionally get a "clear indexes" action.
    Widget _getMagicSearchSettings(BuildContext context) {
      final colorScheme = getEnteColorScheme(context);
      final hasEnabled = LocalSettings.instance.hasEnabledMagicSearch();
      return Column(
        children: [
          MenuItemWidget(
            captionedTextWidget: CaptionedTextWidget(
              title: S.of(context).magicSearch,
            ),
            menuItemColor: colorScheme.fillFaint,
            trailingWidget: ToggleSwitchWidget(
              value: () => LocalSettings.instance.hasEnabledMagicSearch(),
              onChanged: () async {
                await LocalSettings.instance.setShouldEnableMagicSearch(
                  !LocalSettings.instance.hasEnabledMagicSearch(),
                );
                if (LocalSettings.instance.hasEnabledMagicSearch()) {
                  unawaited(
                    SemanticSearchService.instance
                        .init(shouldSyncImmediately: true),
                  );
                } else {
                  await SemanticSearchService.instance.clearQueue();
                }
                // The page may have been closed while the awaits above ran.
                if (mounted) {
                  setState(() {});
                }
              },
            ),
            singleBorderRadius: 8,
            alignCaptionedTextToLeft: true,
            isGestureDetectorDisabled: true,
          ),
          const SizedBox(height: 4),
          MenuSectionDescriptionWidget(
            content: S.of(context).magicSearchDescription,
          ),
          const SizedBox(height: 12),
          if (hasEnabled)
            Column(
              children: [
                _state == InitializationState.initialized
                    ? const MagicSearchIndexStatsWidget()
                    : ModelLoadingState(_state),
                const SizedBox(height: 12),
                if (flagService.internalUser)
                  MenuItemWidget(
                    leadingIcon: Icons.delete_sweep_outlined,
                    captionedTextWidget: CaptionedTextWidget(
                      title: S.of(context).clearIndexes,
                    ),
                    menuItemColor: colorScheme.fillFaint,
                    singleBorderRadius: 8,
                    alwaysShowSuccessState: true,
                    onTap: () async {
                      await SemanticSearchService.instance.clearIndexes();
                      if (mounted) {
                        // An empty block: `() => {}` would return a map literal.
                        setState(() {});
                      }
                    },
                  ),
              ],
            ),
        ],
      );
    }

    /// Builds the face recognition toggle, its description and, when face
    /// indexing is enabled, the face recognition status.
    Widget _getFacesSearchSettings(BuildContext context) {
      final colorScheme = getEnteColorScheme(context);
      final hasEnabled = LocalSettings.instance.isFaceIndexingEnabled;
      return Column(
        children: [
          MenuItemWidget(
            captionedTextWidget: CaptionedTextWidget(
              title: S.of(context).faceRecognition,
            ),
            menuItemColor: colorScheme.fillFaint,
            trailingWidget: ToggleSwitchWidget(
              value: () => LocalSettings.instance.isFaceIndexingEnabled,
              onChanged: () async {
                final isEnabled =
                    await LocalSettings.instance.toggleFaceIndexing();
                if (isEnabled) {
                  unawaited(FaceMlService.instance.ensureInitialized());
                } else {
                  FaceMlService.instance.pauseIndexingAndClustering();
                }
                if (mounted) {
                  setState(() {});
                }
              },
            ),
            singleBorderRadius: 8,
            alignCaptionedTextToLeft: true,
            isGestureDetectorDisabled: true,
          ),
          const SizedBox(height: 4),
          MenuSectionDescriptionWidget(
            content: S.of(context).faceRecognitionIndexingDescription,
          ),
          const SizedBox(height: 12),
          if (hasEnabled) const FaceRecognitionStatusWidget(),
        ],
      );
    }
  }
  /// Status section shown for the semantic-search framework while it is
  /// still initializing.
  ///
  /// [state] selects which caption is displayed.
  class ModelLoadingState extends StatefulWidget {
    const ModelLoadingState(this.state, {super.key});

    /// The framework initialization state to describe.
    final InitializationState state;

    @override
    State<ModelLoadingState> createState() => _ModelLoadingStateState();
  }
  /// State for [ModelLoadingState].
  ///
  /// Listens to `RemoteAssetsService.instance.progressStream` and records the
  /// progress of the CLIP image/text model downloads.
  class _ModelLoadingStateState extends State<ModelLoadingState> {
    StreamSubscription<(String, int, int)>? _progressStream;

    /// Progress per display title as `(done, total)`. Both values are
    /// rendered with `formatBytes`.
    final Map<String, (int, int)> _progressMap = {};

    @override
    void initState() {
      // Flutter requires the inherited initState to run first.
      super.initState();
      _progressStream =
          RemoteAssetsService.instance.progressStream.listen((event) {
        final (url, done, total) = event;
        final title = _titleForUrl(url);
        // Ignore assets other than the two CLIP models.
        if (title == null || !mounted) return;
        setState(() {
          _progressMap[title] = (done, total);
        });
      });
    }

    /// Maps an asset URL to the label shown in the progress list, or null
    /// for URLs that are not tracked.
    static String? _titleForUrl(String url) {
      if (url.contains("clip-image")) return "Image Model";
      if (url.contains("clip-text")) return "Text Model";
      return null;
    }

    @override
    void dispose() {
      // Cancel before the inherited dispose, which must come last.
      _progressStream?.cancel();
      super.dispose();
    }

    @override
    Widget build(BuildContext context) {
      final textStyle = Theme.of(context).textTheme.bodySmall;
      return Column(
        children: [
          MenuSectionTitle(title: S.of(context).status),
          MenuItemWidget(
            captionedTextWidget: CaptionedTextWidget(
              title: _getTitle(context),
            ),
            trailingWidget: EnteLoadingWidget(
              size: 12,
              color: getEnteColorScheme(context).fillMuted,
            ),
            singleBorderRadius: 8,
            alignCaptionedTextToLeft: true,
            isGestureDetectorDisabled: true,
          ),
          // Per-model download progress is only shown to internal users.
          if (flagService.internalUser)
            for (final MapEntry(key: title, value: progress)
                in _progressMap.entries)
              MenuItemWidget(
                // The title is part of the key so that two models reporting
                // identical progress never produce duplicate sibling keys.
                // The progress part forces the trailing text to refresh.
                key: ValueKey((title, progress)),
                captionedTextWidget: CaptionedTextWidget(
                  title: title,
                ),
                trailingWidget: Text(
                  progress.$1 == progress.$2
                      ? "Done"
                      : "${formatBytes(progress.$1)} / ${formatBytes(progress.$2)}",
                  style: textStyle,
                ),
                singleBorderRadius: 8,
                alignCaptionedTextToLeft: true,
                isGestureDetectorDisabled: true,
              ),
        ],
      );
    }

    /// Returns the caption for the current [ModelLoadingState.state].
    String _getTitle(BuildContext context) {
      return switch (widget.state) {
        InitializationState.waitingForNetwork => S.of(context).waitingForWifi,
        _ => S.of(context).loadingModel,
      };
    }
  }
  /// Shows how many items magic search has indexed and how many are still
  /// pending.
  class MagicSearchIndexStatsWidget extends StatefulWidget {
    const MagicSearchIndexStatsWidget({super.key});

    @override
    State<MagicSearchIndexStatsWidget> createState() {
      return _MagicSearchIndexStatsWidgetState();
    }
  }
  /// State for [MagicSearchIndexStatsWidget].
  ///
  /// Fetches the semantic search index status on creation and again every
  /// time an [EmbeddingCacheUpdatedEvent] is published.
  class _MagicSearchIndexStatsWidgetState
      extends State<MagicSearchIndexStatsWidget> {
    /// Latest status; null until the first fetch completes.
    IndexStatus? _status;
    late StreamSubscription<EmbeddingCacheUpdatedEvent> _eventSubscription;

    @override
    void initState() {
      super.initState();
      _eventSubscription =
          Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((_) {
        unawaited(_fetchIndexStatus());
      });
      unawaited(_fetchIndexStatus());
    }

    /// Loads the current index status and rebuilds with it.
    ///
    /// The query is asynchronous, so the widget may be disposed before it
    /// completes. In that case the result is dropped instead of calling
    /// [setState] on a defunct State, which would throw.
    Future<void> _fetchIndexStatus() async {
      final status = await SemanticSearchService.instance.getIndexStatus();
      if (!mounted) return;
      setState(() {
        _status = status;
      });
    }

    @override
    void dispose() {
      // Cancel before the inherited dispose, which must come last.
      _eventSubscription.cancel();
      super.dispose();
    }

    @override
    Widget build(BuildContext context) {
      final status = _status;
      if (status == null) {
        return const EnteLoadingWidget();
      }
      final textStyle = Theme.of(context).textTheme.bodySmall;
      return Column(
        children: [
          Row(
            children: [
              MenuSectionTitle(title: S.of(context).status),
              Expanded(child: Container()),
              // Spinner while there is still work queued.
              if (status.pendingItems > 0)
                EnteLoadingWidget(
                  color: getEnteColorScheme(context).fillMuted,
                ),
            ],
          ),
          MenuItemWidget(
            captionedTextWidget: CaptionedTextWidget(
              title: S.of(context).indexedItems,
            ),
            trailingWidget: Text(
              NumberFormat().format(status.indexedItems),
              style: textStyle,
            ),
            singleBorderRadius: 8,
            alignCaptionedTextToLeft: true,
            isGestureDetectorDisabled: true,
            // Setting a key here to ensure trailingWidget is refreshed
            key: ValueKey("indexed_items_${status.indexedItems}"),
          ),
          MenuItemWidget(
            captionedTextWidget: CaptionedTextWidget(
              title: S.of(context).pendingItems,
            ),
            trailingWidget: Text(
              NumberFormat().format(status.pendingItems),
              style: textStyle,
            ),
            singleBorderRadius: 8,
            alignCaptionedTextToLeft: true,
            isGestureDetectorDisabled: true,
            // Setting a key here to ensure trailingWidget is refreshed
            key: ValueKey("pending_items_${status.pendingItems}"),
          ),
        ],
      );
    }
  }
  /// Shows face indexing and clustering progress for face recognition.
  class FaceRecognitionStatusWidget extends StatefulWidget {
    const FaceRecognitionStatusWidget({super.key});

    @override
    State<FaceRecognitionStatusWidget> createState() {
      return FaceRecognitionStatusWidgetState();
    }
  }
  /// State for [FaceRecognitionStatusWidget].
  ///
  /// Re-queries the face ML database every 10 seconds while mounted.
  class FaceRecognitionStatusWidgetState
      extends State<FaceRecognitionStatusWidget> {
    /// Periodically invalidates [_indexStatusFuture] to trigger a refresh.
    Timer? _timer;

    /// Cached status query, shared across rebuilds until [_timer] clears it.
    ///
    /// It is (re)created lazily inside [build] rather than in the timer
    /// callback. That way the [FutureBuilder] subscribes to it synchronously,
    /// and an early failure is never reported as an unhandled error. Caching
    /// it means unrelated rebuilds do not re-run the database queries.
    Future<(int, int, int, double)>? _indexStatusFuture;

    @override
    void initState() {
      super.initState();
      _timer = Timer.periodic(const Duration(seconds: 10), (_) {
        if (!mounted) return;
        setState(() {
          _indexStatusFuture = null;
        });
      });
    }

    /// Returns `(indexedFiles, pendingFiles, foundFaces, clusteringDoneRatio)`.
    ///
    /// - `indexedFiles` is capped at the number of indexable files.
    /// - `pendingFiles` is never negative.
    /// - `clusteringDoneRatio` is `clusteredFaces / foundFaces`. It is `1.0`
    ///   when no faces have been found, since there is nothing to cluster.
    ///   Callers therefore never receive NaN from a 0 / 0 division.
    ///
    /// Any error is logged and rethrown.
    Future<(int, int, int, double)> getIndexStatus() async {
      try {
        final indexedFiles = await FaceMLDataDB.instance
            .getIndexedFileCount(minimumMlVersion: faceMlVersion);
        final indexableFiles =
            (await FaceMlService.getIndexableFileIDs()).length;
        final showIndexedFiles = min(indexedFiles, indexableFiles);
        final pendingFiles = max(indexableFiles - indexedFiles, 0);
        final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
        final clusteredFaces =
            await FaceMLDataDB.instance.getClusteredFaceCount();
        final double clusteringDoneRatio =
            foundFaces > 0 ? clusteredFaces / foundFaces : 1.0;
        return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio);
      } catch (e, s) {
        _logger.severe('Error getting face recognition status', e, s);
        rethrow;
      }
    }

    @override
    void dispose() {
      _timer?.cancel();
      super.dispose();
    }

    /// Builds one read-only stat row.
    ///
    /// [rowKey] changes with the value so the trailing text is refreshed.
    Widget _buildStatRow(
      BuildContext context, {
      required String title,
      required String value,
      required Key rowKey,
    }) {
      return MenuItemWidget(
        captionedTextWidget: CaptionedTextWidget(title: title),
        trailingWidget: Text(
          value,
          style: Theme.of(context).textTheme.bodySmall,
        ),
        singleBorderRadius: 8,
        alignCaptionedTextToLeft: true,
        isGestureDetectorDisabled: true,
        key: rowKey,
      );
    }

    @override
    Widget build(BuildContext context) {
      return Column(
        children: [
          Row(
            children: [
              MenuSectionTitle(title: S.of(context).status),
              Expanded(child: Container()),
            ],
          ),
          FutureBuilder(
            future: _indexStatusFuture ??= getIndexStatus(),
            builder: (context, snapshot) {
              final data = snapshot.data;
              if (data == null) {
                return const EnteLoadingWidget();
              }
              final (indexedFiles, pendingFiles, foundFaces, doneRatio) = data;
              final percentage =
                  (doneRatio * 100).clamp(0.0, 100.0).toStringAsFixed(0);
              return Column(
                children: [
                  _buildStatRow(
                    context,
                    title: S.of(context).indexedItems,
                    value: NumberFormat().format(indexedFiles),
                    rowKey: ValueKey("indexed_items_$indexedFiles"),
                  ),
                  _buildStatRow(
                    context,
                    title: S.of(context).pendingItems,
                    value: NumberFormat().format(pendingFiles),
                    rowKey: ValueKey("pending_items_$pendingFiles"),
                  ),
                  _buildStatRow(
                    context,
                    title: S.of(context).foundFaces,
                    value: NumberFormat().format(foundFaces),
                    rowKey: ValueKey("found_faces_$foundFaces"),
                  ),
                  _buildStatRow(
                    context,
                    title: S.of(context).clusteringProgress,
                    value: "$percentage%",
                    rowKey: ValueKey("clustering_progress_$percentage"),
                  ),
                ],
              );
            },
          ),
        ],
      );
    }
  }