machine_learning_settings_page.dart 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. import "dart:async";
  2. import "dart:math" show max, min;
  3. import "package:flutter/material.dart";
  4. import "package:intl/intl.dart";
  5. import "package:photos/core/event_bus.dart";
  6. import 'package:photos/events/embedding_updated_event.dart';
  7. import "package:photos/face/db.dart";
  8. import "package:photos/generated/l10n.dart";
  9. import "package:photos/models/ml/ml_versions.dart";
  10. import "package:photos/service_locator.dart";
  11. import "package:photos/services/machine_learning/face_ml/face_ml_service.dart";
  12. import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
  13. import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
  14. import "package:photos/services/remote_assets_service.dart";
  15. import "package:photos/theme/ente_theme.dart";
  16. import "package:photos/ui/common/loading_widget.dart";
  17. import "package:photos/ui/components/buttons/icon_button_widget.dart";
  18. import "package:photos/ui/components/captioned_text_widget.dart";
  19. import "package:photos/ui/components/menu_item_widget/menu_item_widget.dart";
  20. import "package:photos/ui/components/menu_section_description_widget.dart";
  21. import "package:photos/ui/components/menu_section_title.dart";
  22. import "package:photos/ui/components/title_bar_title_widget.dart";
  23. import "package:photos/ui/components/title_bar_widget.dart";
  24. import "package:photos/ui/components/toggle_switch_widget.dart";
  25. import "package:photos/utils/data_util.dart";
  26. import "package:photos/utils/local_settings.dart";
  27. class MachineLearningSettingsPage extends StatefulWidget {
  28. const MachineLearningSettingsPage({super.key});
  29. @override
  30. State<MachineLearningSettingsPage> createState() =>
  31. _MachineLearningSettingsPageState();
  32. }
  33. class _MachineLearningSettingsPageState
  34. extends State<MachineLearningSettingsPage> {
  35. late InitializationState _state;
  36. late StreamSubscription<MLFrameworkInitializationUpdateEvent>
  37. _eventSubscription;
  38. @override
  39. void initState() {
  40. super.initState();
  41. _eventSubscription =
  42. Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
  43. _fetchState();
  44. setState(() {});
  45. });
  46. _fetchState();
  47. }
  48. void _fetchState() {
  49. _state = SemanticSearchService.instance.getFrameworkInitializationState();
  50. }
  51. @override
  52. void dispose() {
  53. super.dispose();
  54. _eventSubscription.cancel();
  55. }
  56. @override
  57. Widget build(BuildContext context) {
  58. final bool facesFlag = flagService.faceSearchEnabled;
  59. return Scaffold(
  60. body: CustomScrollView(
  61. primary: false,
  62. slivers: <Widget>[
  63. TitleBarWidget(
  64. flexibleSpaceTitle: TitleBarTitleWidget(
  65. title: S.of(context).machineLearning,
  66. ),
  67. actionIcons: [
  68. IconButtonWidget(
  69. icon: Icons.close_outlined,
  70. iconButtonType: IconButtonType.secondary,
  71. onTap: () {
  72. Navigator.pop(context);
  73. Navigator.pop(context);
  74. Navigator.pop(context);
  75. },
  76. ),
  77. ],
  78. ),
  79. SliverList(
  80. delegate: SliverChildBuilderDelegate(
  81. (delegateBuildContext, index) {
  82. return Padding(
  83. padding: const EdgeInsets.symmetric(horizontal: 16),
  84. child: Padding(
  85. padding: const EdgeInsets.symmetric(vertical: 20),
  86. child: Column(
  87. mainAxisSize: MainAxisSize.min,
  88. children: [
  89. _getMagicSearchSettings(context),
  90. const SizedBox(height: 12),
  91. facesFlag
  92. ? _getFacesSearchSettings(context)
  93. : const SizedBox.shrink(),
  94. ],
  95. ),
  96. ),
  97. );
  98. },
  99. childCount: 1,
  100. ),
  101. ),
  102. ],
  103. ),
  104. );
  105. }
  106. Widget _getMagicSearchSettings(BuildContext context) {
  107. final colorScheme = getEnteColorScheme(context);
  108. final hasEnabled = LocalSettings.instance.hasEnabledMagicSearch();
  109. return Column(
  110. children: [
  111. MenuItemWidget(
  112. captionedTextWidget: CaptionedTextWidget(
  113. title: S.of(context).magicSearch,
  114. ),
  115. menuItemColor: colorScheme.fillFaint,
  116. trailingWidget: ToggleSwitchWidget(
  117. value: () => LocalSettings.instance.hasEnabledMagicSearch(),
  118. onChanged: () async {
  119. await LocalSettings.instance.setShouldEnableMagicSearch(
  120. !LocalSettings.instance.hasEnabledMagicSearch(),
  121. );
  122. if (LocalSettings.instance.hasEnabledMagicSearch()) {
  123. unawaited(
  124. SemanticSearchService.instance
  125. .init(shouldSyncImmediately: true),
  126. );
  127. } else {
  128. await SemanticSearchService.instance.clearQueue();
  129. }
  130. setState(() {});
  131. },
  132. ),
  133. singleBorderRadius: 8,
  134. alignCaptionedTextToLeft: true,
  135. isGestureDetectorDisabled: true,
  136. ),
  137. const SizedBox(
  138. height: 4,
  139. ),
  140. MenuSectionDescriptionWidget(
  141. content: S.of(context).magicSearchDescription,
  142. ),
  143. const SizedBox(
  144. height: 12,
  145. ),
  146. hasEnabled
  147. ? Column(
  148. children: [
  149. _state == InitializationState.initialized
  150. ? const MagicSearchIndexStatsWidget()
  151. : ModelLoadingState(_state),
  152. const SizedBox(
  153. height: 12,
  154. ),
  155. flagService.internalUser
  156. ? MenuItemWidget(
  157. leadingIcon: Icons.delete_sweep_outlined,
  158. captionedTextWidget: CaptionedTextWidget(
  159. title: S.of(context).clearIndexes,
  160. ),
  161. menuItemColor: getEnteColorScheme(context).fillFaint,
  162. singleBorderRadius: 8,
  163. alwaysShowSuccessState: true,
  164. onTap: () async {
  165. await SemanticSearchService.instance.clearIndexes();
  166. if (mounted) {
  167. setState(() => {});
  168. }
  169. },
  170. )
  171. : const SizedBox.shrink(),
  172. ],
  173. )
  174. : const SizedBox.shrink(),
  175. ],
  176. );
  177. }
  178. Widget _getFacesSearchSettings(BuildContext context) {
  179. final colorScheme = getEnteColorScheme(context);
  180. final hasEnabled = LocalSettings.instance.isFaceIndexingEnabled;
  181. return Column(
  182. children: [
  183. MenuItemWidget(
  184. captionedTextWidget: CaptionedTextWidget(
  185. title: S.of(context).faceRecognition,
  186. ),
  187. menuItemColor: colorScheme.fillFaint,
  188. trailingWidget: ToggleSwitchWidget(
  189. value: () => LocalSettings.instance.isFaceIndexingEnabled,
  190. onChanged: () async {
  191. final isEnabled =
  192. await LocalSettings.instance.toggleFaceIndexing();
  193. if (isEnabled) {
  194. unawaited(FaceMlService.instance.ensureInitialized());
  195. } else {
  196. FaceMlService.instance.pauseIndexing();
  197. }
  198. if (mounted) {
  199. setState(() {});
  200. }
  201. },
  202. ),
  203. singleBorderRadius: 8,
  204. alignCaptionedTextToLeft: true,
  205. isGestureDetectorDisabled: true,
  206. ),
  207. const SizedBox(
  208. height: 4,
  209. ),
  210. MenuSectionDescriptionWidget(
  211. content: S.of(context).faceRecognitionIndexingDescription,
  212. ),
  213. const SizedBox(
  214. height: 12,
  215. ),
  216. hasEnabled
  217. ? const FaceRecognitionStatusWidget()
  218. : const SizedBox.shrink(),
  219. ],
  220. );
  221. }
  222. }
  223. class ModelLoadingState extends StatefulWidget {
  224. final InitializationState state;
  225. const ModelLoadingState(
  226. this.state, {
  227. Key? key,
  228. }) : super(key: key);
  229. @override
  230. State<ModelLoadingState> createState() => _ModelLoadingStateState();
  231. }
  232. class _ModelLoadingStateState extends State<ModelLoadingState> {
  233. StreamSubscription<(String, int, int)>? _progressStream;
  234. final Map<String, (int, int)> _progressMap = {};
  235. @override
  236. void initState() {
  237. _progressStream =
  238. RemoteAssetsService.instance.progressStream.listen((event) {
  239. final String url = event.$1;
  240. String title = "";
  241. if (url.contains("clip-image")) {
  242. title = "Image Model";
  243. } else if (url.contains("clip-text")) {
  244. title = "Text Model";
  245. }
  246. if (title.isNotEmpty) {
  247. _progressMap[title] = (event.$2, event.$3);
  248. setState(() {});
  249. }
  250. });
  251. super.initState();
  252. }
  253. @override
  254. void dispose() {
  255. super.dispose();
  256. _progressStream?.cancel();
  257. }
  258. @override
  259. Widget build(BuildContext context) {
  260. return Column(
  261. children: [
  262. MenuSectionTitle(title: S.of(context).status),
  263. MenuItemWidget(
  264. captionedTextWidget: CaptionedTextWidget(
  265. title: _getTitle(context),
  266. ),
  267. trailingWidget: EnteLoadingWidget(
  268. size: 12,
  269. color: getEnteColorScheme(context).fillMuted,
  270. ),
  271. singleBorderRadius: 8,
  272. alignCaptionedTextToLeft: true,
  273. isGestureDetectorDisabled: true,
  274. ),
  275. // show the progress map if in debug mode
  276. if (flagService.internalUser)
  277. ..._progressMap.entries.map((entry) {
  278. return MenuItemWidget(
  279. key: ValueKey(entry.value),
  280. captionedTextWidget: CaptionedTextWidget(
  281. title: entry.key,
  282. ),
  283. trailingWidget: Text(
  284. entry.value.$1 == entry.value.$2
  285. ? "Done"
  286. : "${formatBytes(entry.value.$1)} / ${formatBytes(entry.value.$2)}",
  287. style: Theme.of(context).textTheme.bodySmall,
  288. ),
  289. singleBorderRadius: 8,
  290. alignCaptionedTextToLeft: true,
  291. isGestureDetectorDisabled: true,
  292. );
  293. }).toList(),
  294. ],
  295. );
  296. }
  297. String _getTitle(BuildContext context) {
  298. switch (widget.state) {
  299. case InitializationState.waitingForNetwork:
  300. return S.of(context).waitingForWifi;
  301. default:
  302. return S.of(context).loadingModel;
  303. }
  304. }
  305. }
  306. class MagicSearchIndexStatsWidget extends StatefulWidget {
  307. const MagicSearchIndexStatsWidget({
  308. super.key,
  309. });
  310. @override
  311. State<MagicSearchIndexStatsWidget> createState() =>
  312. _MagicSearchIndexStatsWidgetState();
  313. }
  314. class _MagicSearchIndexStatsWidgetState
  315. extends State<MagicSearchIndexStatsWidget> {
  316. IndexStatus? _status;
  317. late StreamSubscription<EmbeddingCacheUpdatedEvent> _eventSubscription;
  318. @override
  319. void initState() {
  320. super.initState();
  321. _eventSubscription =
  322. Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((event) {
  323. _fetchIndexStatus();
  324. });
  325. _fetchIndexStatus();
  326. }
  327. void _fetchIndexStatus() {
  328. SemanticSearchService.instance.getIndexStatus().then((status) {
  329. _status = status;
  330. setState(() {});
  331. });
  332. }
  333. @override
  334. void dispose() {
  335. super.dispose();
  336. _eventSubscription.cancel();
  337. }
  338. @override
  339. Widget build(BuildContext context) {
  340. if (_status == null) {
  341. return const EnteLoadingWidget();
  342. }
  343. return Column(
  344. children: [
  345. Row(
  346. children: [
  347. MenuSectionTitle(title: S.of(context).status),
  348. Expanded(child: Container()),
  349. _status!.pendingItems > 0
  350. ? EnteLoadingWidget(
  351. color: getEnteColorScheme(context).fillMuted,
  352. )
  353. : const SizedBox.shrink(),
  354. ],
  355. ),
  356. MenuItemWidget(
  357. captionedTextWidget: CaptionedTextWidget(
  358. title: S.of(context).indexedItems,
  359. ),
  360. trailingWidget: Text(
  361. NumberFormat().format(_status!.indexedItems),
  362. style: Theme.of(context).textTheme.bodySmall,
  363. ),
  364. singleBorderRadius: 8,
  365. alignCaptionedTextToLeft: true,
  366. isGestureDetectorDisabled: true,
  367. // Setting a key here to ensure trailingWidget is refreshed
  368. key: ValueKey("indexed_items_" + _status!.indexedItems.toString()),
  369. ),
  370. MenuItemWidget(
  371. captionedTextWidget: CaptionedTextWidget(
  372. title: S.of(context).pendingItems,
  373. ),
  374. trailingWidget: Text(
  375. NumberFormat().format(_status!.pendingItems),
  376. style: Theme.of(context).textTheme.bodySmall,
  377. ),
  378. singleBorderRadius: 8,
  379. alignCaptionedTextToLeft: true,
  380. isGestureDetectorDisabled: true,
  381. // Setting a key here to ensure trailingWidget is refreshed
  382. key: ValueKey("pending_items_" + _status!.pendingItems.toString()),
  383. ),
  384. ],
  385. );
  386. }
  387. }
  388. class FaceRecognitionStatusWidget extends StatefulWidget {
  389. const FaceRecognitionStatusWidget({
  390. super.key,
  391. });
  392. @override
  393. State<FaceRecognitionStatusWidget> createState() =>
  394. FaceRecognitionStatusWidgetState();
  395. }
  396. class FaceRecognitionStatusWidgetState
  397. extends State<FaceRecognitionStatusWidget> {
  398. Timer? _timer;
  399. @override
  400. void initState() {
  401. super.initState();
  402. _timer = Timer.periodic(const Duration(seconds: 10), (timer) {
  403. setState(() {
  404. // Your state update logic here
  405. });
  406. });
  407. }
  408. Future<(int, int, int, double)> getIndexStatus() async {
  409. final indexedFiles = await FaceMLDataDB.instance
  410. .getIndexedFileCount(minimumMlVersion: faceMlVersion);
  411. final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length;
  412. final showIndexedFiles = min(indexedFiles, indexableFiles);
  413. final pendingFiles = max(indexableFiles - indexedFiles, 0);
  414. final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
  415. final clusteredFaces = await FaceMLDataDB.instance.getClusteredFaceCount();
  416. final clusteringDoneRatio = clusteredFaces / foundFaces;
  417. return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio);
  418. }
  419. @override
  420. void dispose() {
  421. _timer?.cancel();
  422. super.dispose();
  423. }
  424. @override
  425. Widget build(BuildContext context) {
  426. return Column(
  427. children: [
  428. Row(
  429. children: [
  430. MenuSectionTitle(title: S.of(context).status),
  431. Expanded(child: Container()),
  432. ],
  433. ),
  434. FutureBuilder(
  435. future: getIndexStatus(),
  436. builder: (context, snapshot) {
  437. if (snapshot.hasData) {
  438. final int indexedFiles = snapshot.data!.$1;
  439. final int pendingFiles = snapshot.data!.$2;
  440. final int foundFaces = snapshot.data!.$3;
  441. final double clusteringDoneRatio = snapshot.data!.$4;
  442. final double clusteringPercentage =
  443. (clusteringDoneRatio * 100).clamp(0, 100);
  444. return Column(
  445. children: [
  446. MenuItemWidget(
  447. captionedTextWidget: CaptionedTextWidget(
  448. title: S.of(context).indexedItems,
  449. ),
  450. trailingWidget: Text(
  451. NumberFormat().format(indexedFiles),
  452. style: Theme.of(context).textTheme.bodySmall,
  453. ),
  454. singleBorderRadius: 8,
  455. alignCaptionedTextToLeft: true,
  456. isGestureDetectorDisabled: true,
  457. key: ValueKey("indexed_items_" + indexedFiles.toString()),
  458. ),
  459. MenuItemWidget(
  460. captionedTextWidget: CaptionedTextWidget(
  461. title: S.of(context).pendingItems,
  462. ),
  463. trailingWidget: Text(
  464. NumberFormat().format(pendingFiles),
  465. style: Theme.of(context).textTheme.bodySmall,
  466. ),
  467. singleBorderRadius: 8,
  468. alignCaptionedTextToLeft: true,
  469. isGestureDetectorDisabled: true,
  470. key: ValueKey("pending_items_" + pendingFiles.toString()),
  471. ),
  472. MenuItemWidget(
  473. captionedTextWidget: CaptionedTextWidget(
  474. title: S.of(context).foundFaces,
  475. ),
  476. trailingWidget: Text(
  477. NumberFormat().format(foundFaces),
  478. style: Theme.of(context).textTheme.bodySmall,
  479. ),
  480. singleBorderRadius: 8,
  481. alignCaptionedTextToLeft: true,
  482. isGestureDetectorDisabled: true,
  483. key: ValueKey("found_faces_" + foundFaces.toString()),
  484. ),
  485. MenuItemWidget(
  486. captionedTextWidget: CaptionedTextWidget(
  487. title: S.of(context).clusteringProgress,
  488. ),
  489. trailingWidget: Text(
  490. "${clusteringPercentage.toStringAsFixed(0)}%",
  491. style: Theme.of(context).textTheme.bodySmall,
  492. ),
  493. singleBorderRadius: 8,
  494. alignCaptionedTextToLeft: true,
  495. isGestureDetectorDisabled: true,
  496. key: ValueKey(
  497. "clustering_progress_" +
  498. clusteringPercentage.toStringAsFixed(0),
  499. ),
  500. ),
  501. ],
  502. );
  503. }
  504. return const EnteLoadingWidget();
  505. },
  506. ),
  507. ],
  508. );
  509. }
  510. }