machine_learning_settings_page.dart 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. import "dart:async";
  2. import "package:flutter/material.dart";
  3. import "package:intl/intl.dart";
  4. import "package:photos/core/event_bus.dart";
  5. import 'package:photos/events/embedding_updated_event.dart';
  6. import "package:photos/generated/l10n.dart";
  7. import "package:photos/services/feature_flag_service.dart";
  8. import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
  9. import "package:photos/services/semantic_search/semantic_search_service.dart";
  10. import "package:photos/theme/ente_theme.dart";
  11. import "package:photos/ui/common/loading_widget.dart";
  12. import "package:photos/ui/components/buttons/icon_button_widget.dart";
  13. import "package:photos/ui/components/captioned_text_widget.dart";
  14. import "package:photos/ui/components/menu_item_widget/menu_item_widget.dart";
  15. import "package:photos/ui/components/menu_section_description_widget.dart";
  16. import "package:photos/ui/components/menu_section_title.dart";
  17. import "package:photos/ui/components/title_bar_title_widget.dart";
  18. import "package:photos/ui/components/title_bar_widget.dart";
  19. import "package:photos/ui/components/toggle_switch_widget.dart";
  20. import "package:photos/utils/local_settings.dart";
  21. class MachineLearningSettingsPage extends StatefulWidget {
  22. const MachineLearningSettingsPage({super.key});
  23. @override
  24. State<MachineLearningSettingsPage> createState() =>
  25. _MachineLearningSettingsPageState();
  26. }
  27. class _MachineLearningSettingsPageState
  28. extends State<MachineLearningSettingsPage> {
  29. late InitializationState _state;
  30. late StreamSubscription<MLFrameworkInitializationUpdateEvent>
  31. _eventSubscription;
  32. @override
  33. void initState() {
  34. super.initState();
  35. _eventSubscription =
  36. Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
  37. _fetchState();
  38. setState(() {});
  39. });
  40. _fetchState();
  41. }
  42. void _fetchState() {
  43. _state = SemanticSearchService.instance.getFrameworkInitializationState();
  44. }
  45. @override
  46. void dispose() {
  47. super.dispose();
  48. _eventSubscription.cancel();
  49. }
  50. @override
  51. Widget build(BuildContext context) {
  52. return Scaffold(
  53. body: CustomScrollView(
  54. primary: false,
  55. slivers: <Widget>[
  56. TitleBarWidget(
  57. flexibleSpaceTitle: TitleBarTitleWidget(
  58. title: S.of(context).machineLearning,
  59. ),
  60. actionIcons: [
  61. IconButtonWidget(
  62. icon: Icons.close_outlined,
  63. iconButtonType: IconButtonType.secondary,
  64. onTap: () {
  65. Navigator.pop(context);
  66. Navigator.pop(context);
  67. Navigator.pop(context);
  68. },
  69. ),
  70. ],
  71. ),
  72. SliverList(
  73. delegate: SliverChildBuilderDelegate(
  74. (delegateBuildContext, index) {
  75. return Padding(
  76. padding: const EdgeInsets.symmetric(horizontal: 16),
  77. child: Padding(
  78. padding: const EdgeInsets.symmetric(vertical: 20),
  79. child: Column(
  80. mainAxisSize: MainAxisSize.min,
  81. children: [
  82. _getMagicSearchSettings(context),
  83. ],
  84. ),
  85. ),
  86. );
  87. },
  88. childCount: 1,
  89. ),
  90. ),
  91. ],
  92. ),
  93. );
  94. }
  95. Widget _getMagicSearchSettings(BuildContext context) {
  96. final colorScheme = getEnteColorScheme(context);
  97. final hasEnabled = LocalSettings.instance.hasEnabledMagicSearch();
  98. return Column(
  99. children: [
  100. MenuItemWidget(
  101. captionedTextWidget: CaptionedTextWidget(
  102. title: S.of(context).magicSearch,
  103. ),
  104. menuItemColor: colorScheme.fillFaint,
  105. trailingWidget: ToggleSwitchWidget(
  106. value: () => LocalSettings.instance.hasEnabledMagicSearch(),
  107. onChanged: () async {
  108. await LocalSettings.instance.setShouldEnableMagicSearch(
  109. !LocalSettings.instance.hasEnabledMagicSearch(),
  110. );
  111. if (LocalSettings.instance.hasEnabledMagicSearch()) {
  112. unawaited(
  113. SemanticSearchService.instance
  114. .init(shouldSyncImmediately: true),
  115. );
  116. } else {
  117. await SemanticSearchService.instance.clearQueue();
  118. }
  119. setState(() {});
  120. },
  121. ),
  122. singleBorderRadius: 8,
  123. alignCaptionedTextToLeft: true,
  124. isGestureDetectorDisabled: true,
  125. ),
  126. const SizedBox(
  127. height: 4,
  128. ),
  129. MenuSectionDescriptionWidget(
  130. content: S.of(context).magicSearchDescription,
  131. ),
  132. const SizedBox(
  133. height: 12,
  134. ),
  135. hasEnabled
  136. ? Column(
  137. children: [
  138. _state == InitializationState.initialized
  139. ? const MagicSearchIndexStatsWidget()
  140. : ModelLoadingState(_state),
  141. const SizedBox(
  142. height: 12,
  143. ),
  144. FeatureFlagService.instance.isInternalUserOrDebugBuild()
  145. ? MenuItemWidget(
  146. leadingIcon: Icons.delete_sweep_outlined,
  147. captionedTextWidget: CaptionedTextWidget(
  148. title: S.of(context).clearIndexes,
  149. ),
  150. menuItemColor: getEnteColorScheme(context).fillFaint,
  151. singleBorderRadius: 8,
  152. alwaysShowSuccessState: true,
  153. onTap: () async {
  154. await SemanticSearchService.instance.clearIndexes();
  155. if (mounted) {
  156. setState(() => {});
  157. }
  158. },
  159. )
  160. : const SizedBox.shrink(),
  161. ],
  162. )
  163. : const SizedBox.shrink(),
  164. ],
  165. );
  166. }
  167. }
  168. class ModelLoadingState extends StatelessWidget {
  169. final InitializationState state;
  170. const ModelLoadingState(
  171. this.state, {
  172. Key? key,
  173. }) : super(key: key);
  174. @override
  175. Widget build(BuildContext context) {
  176. return Column(
  177. children: [
  178. MenuSectionTitle(title: S.of(context).status),
  179. MenuItemWidget(
  180. captionedTextWidget: CaptionedTextWidget(
  181. title: _getTitle(context),
  182. ),
  183. trailingWidget: EnteLoadingWidget(
  184. size: 12,
  185. color: getEnteColorScheme(context).fillMuted,
  186. ),
  187. singleBorderRadius: 8,
  188. alignCaptionedTextToLeft: true,
  189. isGestureDetectorDisabled: true,
  190. ),
  191. ],
  192. );
  193. }
  194. String _getTitle(BuildContext context) {
  195. switch (state) {
  196. case InitializationState.waitingForNetwork:
  197. return S.of(context).waitingForWifi;
  198. default:
  199. return S.of(context).loadingModel;
  200. }
  201. }
  202. }
  203. class MagicSearchIndexStatsWidget extends StatefulWidget {
  204. const MagicSearchIndexStatsWidget({
  205. super.key,
  206. });
  207. @override
  208. State<MagicSearchIndexStatsWidget> createState() =>
  209. _MagicSearchIndexStatsWidgetState();
  210. }
  211. class _MagicSearchIndexStatsWidgetState
  212. extends State<MagicSearchIndexStatsWidget> {
  213. IndexStatus? _status;
  214. late StreamSubscription<EmbeddingUpdatedEvent> _eventSubscription;
  215. @override
  216. void initState() {
  217. super.initState();
  218. _eventSubscription =
  219. Bus.instance.on<EmbeddingUpdatedEvent>().listen((event) {
  220. _fetchIndexStatus();
  221. });
  222. _fetchIndexStatus();
  223. }
  224. void _fetchIndexStatus() {
  225. SemanticSearchService.instance.getIndexStatus().then((status) {
  226. _status = status;
  227. setState(() {});
  228. });
  229. }
  230. @override
  231. void dispose() {
  232. super.dispose();
  233. _eventSubscription.cancel();
  234. }
  235. @override
  236. Widget build(BuildContext context) {
  237. if (_status == null) {
  238. return const EnteLoadingWidget();
  239. }
  240. return Column(
  241. children: [
  242. Row(
  243. children: [
  244. MenuSectionTitle(title: S.of(context).status),
  245. Expanded(child: Container()),
  246. _status!.pendingItems > 0
  247. ? EnteLoadingWidget(
  248. color: getEnteColorScheme(context).fillMuted,
  249. )
  250. : const SizedBox.shrink(),
  251. ],
  252. ),
  253. MenuItemWidget(
  254. captionedTextWidget: CaptionedTextWidget(
  255. title: S.of(context).indexedItems,
  256. ),
  257. trailingWidget: Text(
  258. NumberFormat().format(_status!.indexedItems),
  259. style: Theme.of(context).textTheme.bodySmall,
  260. ),
  261. singleBorderRadius: 8,
  262. alignCaptionedTextToLeft: true,
  263. isGestureDetectorDisabled: true,
  264. // Setting a key here to ensure trailingWidget is refreshed
  265. key: ValueKey("indexed_items_" + _status!.indexedItems.toString()),
  266. ),
  267. MenuItemWidget(
  268. captionedTextWidget: CaptionedTextWidget(
  269. title: S.of(context).pendingItems,
  270. ),
  271. trailingWidget: Text(
  272. NumberFormat().format(_status!.pendingItems),
  273. style: Theme.of(context).textTheme.bodySmall,
  274. ),
  275. singleBorderRadius: 8,
  276. alignCaptionedTextToLeft: true,
  277. isGestureDetectorDisabled: true,
  278. // Setting a key here to ensure trailingWidget is refreshed
  279. key: ValueKey("pending_items_" + _status!.pendingItems.toString()),
  280. ),
  281. ],
  282. );
  283. }
  284. }