machine_learning_settings_page.dart 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. import "dart:async";
  2. import "dart:math" show max, min;
  3. import "package:flutter/material.dart";
  4. import "package:intl/intl.dart";
  5. import "package:logging/logging.dart";
  6. import "package:photos/core/event_bus.dart";
  7. import 'package:photos/events/embedding_updated_event.dart';
  8. import "package:photos/face/db.dart";
  9. import "package:photos/generated/l10n.dart";
  10. import "package:photos/models/ml/ml_versions.dart";
  11. import "package:photos/service_locator.dart";
  12. import "package:photos/services/machine_learning/face_ml/face_ml_service.dart";
  13. import "package:photos/services/machine_learning/machine_learning_controller.dart";
  14. import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
  15. import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
  16. import "package:photos/services/remote_assets_service.dart";
  17. import "package:photos/theme/ente_theme.dart";
  18. import "package:photos/ui/common/loading_widget.dart";
  19. import "package:photos/ui/components/buttons/icon_button_widget.dart";
  20. import "package:photos/ui/components/captioned_text_widget.dart";
  21. import "package:photos/ui/components/menu_item_widget/menu_item_widget.dart";
  22. import "package:photos/ui/components/menu_section_description_widget.dart";
  23. import "package:photos/ui/components/menu_section_title.dart";
  24. import "package:photos/ui/components/title_bar_title_widget.dart";
  25. import "package:photos/ui/components/title_bar_widget.dart";
  26. import "package:photos/ui/components/toggle_switch_widget.dart";
  27. import "package:photos/utils/data_util.dart";
  28. import "package:photos/utils/local_settings.dart";
  29. import "package:photos/utils/ml_util.dart";
  30. import "package:photos/utils/wakelock_util.dart";
  31. final _logger = Logger("MachineLearningSettingsPage");
  32. class MachineLearningSettingsPage extends StatefulWidget {
  33. const MachineLearningSettingsPage({super.key});
  34. @override
  35. State<MachineLearningSettingsPage> createState() =>
  36. _MachineLearningSettingsPageState();
  37. }
  38. class _MachineLearningSettingsPageState
  39. extends State<MachineLearningSettingsPage> {
  40. late InitializationState _state;
  41. final EnteWakeLock _wakeLock = EnteWakeLock();
  42. late StreamSubscription<MLFrameworkInitializationUpdateEvent>
  43. _eventSubscription;
  44. @override
  45. void initState() {
  46. super.initState();
  47. _eventSubscription =
  48. Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
  49. _fetchState();
  50. setState(() {});
  51. });
  52. _fetchState();
  53. _wakeLock.enable();
  54. }
  55. void _fetchState() {
  56. _state = SemanticSearchService.instance.getFrameworkInitializationState();
  57. }
  58. @override
  59. void dispose() {
  60. super.dispose();
  61. _eventSubscription.cancel();
  62. _wakeLock.disable();
  63. }
  64. @override
  65. Widget build(BuildContext context) {
  66. final bool facesFlag = flagService.faceSearchEnabled;
  67. _logger.info("On page open, facesFlag: $facesFlag");
  68. return Scaffold(
  69. body: CustomScrollView(
  70. primary: false,
  71. slivers: <Widget>[
  72. TitleBarWidget(
  73. flexibleSpaceTitle: TitleBarTitleWidget(
  74. title: S.of(context).machineLearning,
  75. ),
  76. actionIcons: [
  77. IconButtonWidget(
  78. icon: Icons.close_outlined,
  79. iconButtonType: IconButtonType.secondary,
  80. onTap: () {
  81. Navigator.pop(context);
  82. Navigator.pop(context);
  83. Navigator.pop(context);
  84. },
  85. ),
  86. ],
  87. ),
  88. SliverList(
  89. delegate: SliverChildBuilderDelegate(
  90. (delegateBuildContext, index) {
  91. return Padding(
  92. padding: const EdgeInsets.symmetric(horizontal: 16),
  93. child: Padding(
  94. padding: const EdgeInsets.symmetric(vertical: 20),
  95. child: Column(
  96. mainAxisSize: MainAxisSize.min,
  97. children: [
  98. _getMagicSearchSettings(context),
  99. const SizedBox(height: 12),
  100. facesFlag
  101. ? _getFacesSearchSettings(context)
  102. : const SizedBox.shrink(),
  103. ],
  104. ),
  105. ),
  106. );
  107. },
  108. childCount: 1,
  109. ),
  110. ),
  111. ],
  112. ),
  113. );
  114. }
  115. Widget _getMagicSearchSettings(BuildContext context) {
  116. final colorScheme = getEnteColorScheme(context);
  117. final hasEnabled = LocalSettings.instance.hasEnabledMagicSearch();
  118. return Column(
  119. children: [
  120. MenuItemWidget(
  121. captionedTextWidget: CaptionedTextWidget(
  122. title: S.of(context).magicSearch,
  123. ),
  124. menuItemColor: colorScheme.fillFaint,
  125. trailingWidget: ToggleSwitchWidget(
  126. value: () => LocalSettings.instance.hasEnabledMagicSearch(),
  127. onChanged: () async {
  128. await LocalSettings.instance.setShouldEnableMagicSearch(
  129. !LocalSettings.instance.hasEnabledMagicSearch(),
  130. );
  131. if (LocalSettings.instance.hasEnabledMagicSearch()) {
  132. unawaited(
  133. SemanticSearchService.instance
  134. .init(shouldSyncImmediately: true),
  135. );
  136. } else {
  137. await SemanticSearchService.instance.clearQueue();
  138. }
  139. setState(() {});
  140. },
  141. ),
  142. singleBorderRadius: 8,
  143. alignCaptionedTextToLeft: true,
  144. isGestureDetectorDisabled: true,
  145. ),
  146. const SizedBox(
  147. height: 4,
  148. ),
  149. MenuSectionDescriptionWidget(
  150. content: S.of(context).magicSearchDescription,
  151. ),
  152. const SizedBox(
  153. height: 12,
  154. ),
  155. hasEnabled
  156. ? Column(
  157. children: [
  158. _state == InitializationState.initialized
  159. ? const MagicSearchIndexStatsWidget()
  160. : ModelLoadingState(_state),
  161. const SizedBox(
  162. height: 12,
  163. ),
  164. flagService.internalUser
  165. ? MenuItemWidget(
  166. leadingIcon: Icons.delete_sweep_outlined,
  167. captionedTextWidget: CaptionedTextWidget(
  168. title: S.of(context).clearIndexes,
  169. ),
  170. menuItemColor: getEnteColorScheme(context).fillFaint,
  171. singleBorderRadius: 8,
  172. alwaysShowSuccessState: true,
  173. onTap: () async {
  174. await SemanticSearchService.instance.clearIndexes();
  175. if (mounted) {
  176. setState(() => {});
  177. }
  178. },
  179. )
  180. : const SizedBox.shrink(),
  181. ],
  182. )
  183. : const SizedBox.shrink(),
  184. ],
  185. );
  186. }
  187. Widget _getFacesSearchSettings(BuildContext context) {
  188. final colorScheme = getEnteColorScheme(context);
  189. final hasEnabled = LocalSettings.instance.isFaceIndexingEnabled;
  190. return Column(
  191. children: [
  192. MenuItemWidget(
  193. captionedTextWidget: CaptionedTextWidget(
  194. title: S.of(context).faceRecognition,
  195. ),
  196. menuItemColor: colorScheme.fillFaint,
  197. trailingWidget: ToggleSwitchWidget(
  198. value: () => LocalSettings.instance.isFaceIndexingEnabled,
  199. onChanged: () async {
  200. final isEnabled =
  201. await LocalSettings.instance.toggleFaceIndexing();
  202. if (isEnabled) {
  203. unawaited(FaceMlService.instance.ensureInitialized());
  204. } else {
  205. FaceMlService.instance.pauseIndexingAndClustering();
  206. }
  207. if (mounted) {
  208. setState(() {});
  209. }
  210. },
  211. ),
  212. singleBorderRadius: 8,
  213. alignCaptionedTextToLeft: true,
  214. isGestureDetectorDisabled: true,
  215. ),
  216. const SizedBox(
  217. height: 4,
  218. ),
  219. MenuSectionDescriptionWidget(
  220. content: S.of(context).faceRecognitionIndexingDescription,
  221. ),
  222. const SizedBox(
  223. height: 12,
  224. ),
  225. hasEnabled
  226. ? const FaceRecognitionStatusWidget()
  227. : const SizedBox.shrink(),
  228. ],
  229. );
  230. }
  231. }
  232. class ModelLoadingState extends StatefulWidget {
  233. final InitializationState state;
  234. const ModelLoadingState(
  235. this.state, {
  236. Key? key,
  237. }) : super(key: key);
  238. @override
  239. State<ModelLoadingState> createState() => _ModelLoadingStateState();
  240. }
  241. class _ModelLoadingStateState extends State<ModelLoadingState> {
  242. StreamSubscription<(String, int, int)>? _progressStream;
  243. final Map<String, (int, int)> _progressMap = {};
  244. @override
  245. void initState() {
  246. _progressStream =
  247. RemoteAssetsService.instance.progressStream.listen((event) {
  248. final String url = event.$1;
  249. String title = "";
  250. if (url.contains("clip-image")) {
  251. title = "Image Model";
  252. } else if (url.contains("clip-text")) {
  253. title = "Text Model";
  254. }
  255. if (title.isNotEmpty) {
  256. _progressMap[title] = (event.$2, event.$3);
  257. setState(() {});
  258. }
  259. });
  260. super.initState();
  261. }
  262. @override
  263. void dispose() {
  264. super.dispose();
  265. _progressStream?.cancel();
  266. }
  267. @override
  268. Widget build(BuildContext context) {
  269. return Column(
  270. children: [
  271. MenuSectionTitle(title: S.of(context).status),
  272. MenuItemWidget(
  273. captionedTextWidget: CaptionedTextWidget(
  274. title: _getTitle(context),
  275. ),
  276. trailingWidget: EnteLoadingWidget(
  277. size: 12,
  278. color: getEnteColorScheme(context).fillMuted,
  279. ),
  280. singleBorderRadius: 8,
  281. alignCaptionedTextToLeft: true,
  282. isGestureDetectorDisabled: true,
  283. ),
  284. // show the progress map if in debug mode
  285. if (flagService.internalUser)
  286. ..._progressMap.entries.map((entry) {
  287. return MenuItemWidget(
  288. key: ValueKey(entry.value),
  289. captionedTextWidget: CaptionedTextWidget(
  290. title: entry.key,
  291. ),
  292. trailingWidget: Text(
  293. entry.value.$1 == entry.value.$2
  294. ? "Done"
  295. : "${formatBytes(entry.value.$1)} / ${formatBytes(entry.value.$2)}",
  296. style: Theme.of(context).textTheme.bodySmall,
  297. ),
  298. singleBorderRadius: 8,
  299. alignCaptionedTextToLeft: true,
  300. isGestureDetectorDisabled: true,
  301. );
  302. }).toList(),
  303. ],
  304. );
  305. }
  306. String _getTitle(BuildContext context) {
  307. switch (widget.state) {
  308. case InitializationState.waitingForNetwork:
  309. return S.of(context).waitingForWifi;
  310. default:
  311. return S.of(context).loadingModel;
  312. }
  313. }
  314. }
  315. class MagicSearchIndexStatsWidget extends StatefulWidget {
  316. const MagicSearchIndexStatsWidget({
  317. super.key,
  318. });
  319. @override
  320. State<MagicSearchIndexStatsWidget> createState() =>
  321. _MagicSearchIndexStatsWidgetState();
  322. }
  323. class _MagicSearchIndexStatsWidgetState
  324. extends State<MagicSearchIndexStatsWidget> {
  325. IndexStatus? _status;
  326. late StreamSubscription<EmbeddingCacheUpdatedEvent> _eventSubscription;
  327. @override
  328. void initState() {
  329. super.initState();
  330. _eventSubscription =
  331. Bus.instance.on<EmbeddingCacheUpdatedEvent>().listen((event) {
  332. _fetchIndexStatus();
  333. });
  334. _fetchIndexStatus();
  335. }
  336. void _fetchIndexStatus() {
  337. SemanticSearchService.instance.getIndexStatus().then((status) {
  338. _status = status;
  339. setState(() {});
  340. });
  341. }
  342. @override
  343. void dispose() {
  344. super.dispose();
  345. _eventSubscription.cancel();
  346. }
  347. @override
  348. Widget build(BuildContext context) {
  349. if (_status == null) {
  350. return const EnteLoadingWidget();
  351. }
  352. return Column(
  353. children: [
  354. Row(
  355. children: [
  356. MenuSectionTitle(title: S.of(context).status),
  357. Expanded(child: Container()),
  358. _status!.pendingItems > 0
  359. ? EnteLoadingWidget(
  360. color: getEnteColorScheme(context).fillMuted,
  361. )
  362. : const SizedBox.shrink(),
  363. ],
  364. ),
  365. MenuItemWidget(
  366. captionedTextWidget: CaptionedTextWidget(
  367. title: S.of(context).indexedItems,
  368. ),
  369. trailingWidget: Text(
  370. NumberFormat().format(_status!.indexedItems),
  371. style: Theme.of(context).textTheme.bodySmall,
  372. ),
  373. singleBorderRadius: 8,
  374. alignCaptionedTextToLeft: true,
  375. isGestureDetectorDisabled: true,
  376. // Setting a key here to ensure trailingWidget is refreshed
  377. key: ValueKey("indexed_items_" + _status!.indexedItems.toString()),
  378. ),
  379. MenuItemWidget(
  380. captionedTextWidget: CaptionedTextWidget(
  381. title: S.of(context).pendingItems,
  382. ),
  383. trailingWidget: Text(
  384. NumberFormat().format(_status!.pendingItems),
  385. style: Theme.of(context).textTheme.bodySmall,
  386. ),
  387. singleBorderRadius: 8,
  388. alignCaptionedTextToLeft: true,
  389. isGestureDetectorDisabled: true,
  390. // Setting a key here to ensure trailingWidget is refreshed
  391. key: ValueKey("pending_items_" + _status!.pendingItems.toString()),
  392. ),
  393. ],
  394. );
  395. }
  396. }
  397. class FaceRecognitionStatusWidget extends StatefulWidget {
  398. const FaceRecognitionStatusWidget({
  399. super.key,
  400. });
  401. @override
  402. State<FaceRecognitionStatusWidget> createState() =>
  403. FaceRecognitionStatusWidgetState();
  404. }
  405. class FaceRecognitionStatusWidgetState
  406. extends State<FaceRecognitionStatusWidget> {
  407. Timer? _timer;
  408. @override
  409. void initState() {
  410. super.initState();
  411. _timer = Timer.periodic(const Duration(seconds: 10), (timer) {
  412. setState(() {
  413. // Your state update logic here
  414. });
  415. });
  416. }
  417. Future<(int, int, double, bool)> getIndexStatus() async {
  418. try {
  419. final indexedFiles = await FaceMLDataDB.instance
  420. .getIndexedFileCount(minimumMlVersion: faceMlVersion);
  421. final indexableFiles = (await getIndexableFileIDs()).length;
  422. final showIndexedFiles = min(indexedFiles, indexableFiles);
  423. final pendingFiles = max(indexableFiles - indexedFiles, 0);
  424. final clusteringDoneRatio =
  425. await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
  426. final bool deviceIsHealthy =
  427. MachineLearningController.instance.isDeviceHealthy;
  428. return (
  429. showIndexedFiles,
  430. pendingFiles,
  431. clusteringDoneRatio,
  432. deviceIsHealthy
  433. );
  434. } catch (e, s) {
  435. _logger.severe('Error getting face recognition status', e, s);
  436. rethrow;
  437. }
  438. }
  439. @override
  440. void dispose() {
  441. _timer?.cancel();
  442. super.dispose();
  443. }
  444. @override
  445. Widget build(BuildContext context) {
  446. return Column(
  447. children: [
  448. Row(
  449. children: [
  450. MenuSectionTitle(title: S.of(context).status),
  451. Expanded(child: Container()),
  452. ],
  453. ),
  454. FutureBuilder(
  455. future: getIndexStatus(),
  456. builder: (context, snapshot) {
  457. if (snapshot.hasData) {
  458. final int indexedFiles = snapshot.data!.$1;
  459. final int pendingFiles = snapshot.data!.$2;
  460. final double clusteringDoneRatio = snapshot.data!.$3;
  461. final double clusteringPercentage =
  462. (clusteringDoneRatio * 100).clamp(0, 100);
  463. final bool isDeviceHealthy = snapshot.data!.$4;
  464. if (!isDeviceHealthy &&
  465. (pendingFiles > 0 || clusteringPercentage < 99)) {
  466. return MenuSectionDescriptionWidget(
  467. content: S.of(context).indexingIsPaused,
  468. );
  469. }
  470. return Column(
  471. children: [
  472. MenuItemWidget(
  473. captionedTextWidget: CaptionedTextWidget(
  474. title: S.of(context).indexedItems,
  475. ),
  476. trailingWidget: Text(
  477. NumberFormat().format(indexedFiles),
  478. style: Theme.of(context).textTheme.bodySmall,
  479. ),
  480. singleBorderRadius: 8,
  481. alignCaptionedTextToLeft: true,
  482. isGestureDetectorDisabled: true,
  483. key: ValueKey("indexed_items_" + indexedFiles.toString()),
  484. ),
  485. MenuItemWidget(
  486. captionedTextWidget: CaptionedTextWidget(
  487. title: S.of(context).pendingItems,
  488. ),
  489. trailingWidget: Text(
  490. NumberFormat().format(pendingFiles),
  491. style: Theme.of(context).textTheme.bodySmall,
  492. ),
  493. singleBorderRadius: 8,
  494. alignCaptionedTextToLeft: true,
  495. isGestureDetectorDisabled: true,
  496. key: ValueKey("pending_items_" + pendingFiles.toString()),
  497. ),
  498. MenuItemWidget(
  499. captionedTextWidget: CaptionedTextWidget(
  500. title: S.of(context).clusteringProgress,
  501. ),
  502. trailingWidget: Text(
  503. "${clusteringPercentage.toStringAsFixed(0)}%",
  504. style: Theme.of(context).textTheme.bodySmall,
  505. ),
  506. singleBorderRadius: 8,
  507. alignCaptionedTextToLeft: true,
  508. isGestureDetectorDisabled: true,
  509. key: ValueKey(
  510. "clustering_progress_" +
  511. clusteringPercentage.toStringAsFixed(0),
  512. ),
  513. ),
  514. ],
  515. );
  516. }
  517. return const EnteLoadingWidget();
  518. },
  519. ),
  520. ],
  521. );
  522. }
  523. }