Browse Source

Merge branch 'main' into migrate_files_db_to_sqlite_async

ashilkn 1 năm trước cách đây
mục cha
commit
a470ed4dfa

+ 39 - 17
mobile/lib/face/db.dart

@@ -13,6 +13,8 @@ import "package:photos/face/model/face.dart";
 import "package:photos/models/file/file.dart";
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
 import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
+import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
+import "package:photos/utils/ml_util.dart";
 import 'package:sqlite_async/sqlite_async.dart';
 
 /// Stores all data for the FacesML-related features. The database can be accessed by `FaceMLDataDB.instance.database`.
@@ -249,7 +251,7 @@ class FaceMLDataDB {
       final List<int> fileId = [recentFileID];
       int? avatarFileId;
       if (avatarFaceId != null) {
-        avatarFileId = int.tryParse(avatarFaceId.split('_')[0]);
+        avatarFileId = tryGetFileIdFromFaceId(avatarFaceId);
         if (avatarFileId != null) {
           fileId.add(avatarFileId);
         }
@@ -401,8 +403,10 @@ class FaceMLDataDB {
       final personID = map[personIdColumn] as String;
       final clusterID = map[fcClusterID] as int;
       final faceID = map[fcFaceId] as String;
-      result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {})
-        .add(faceID);
+      result
+          .putIfAbsent(personID, () => {})
+          .putIfAbsent(clusterID, () => {})
+          .add(faceID);
     }
     return result;
   }
@@ -476,8 +480,7 @@ class FaceMLDataDB {
     for (final map in maps) {
       final clusterID = map[fcClusterID] as int;
       final faceID = map[fcFaceId] as String;
-      final x = faceID.split('_').first;
-      final fileID = int.parse(x);
+      final fileID = getFileIdFromFaceId(faceID);
       result[fileID] = (result[fileID] ?? {})..add(clusterID);
     }
     return result;
@@ -665,19 +668,38 @@ class FaceMLDataDB {
     return maps.first['count'] as int;
   }
 
-  Future<int> getClusteredFaceCount() async {
+  Future<int> getClusteredFileCount() async {
     final db = await instance.asyncDB;
     final List<Map<String, dynamic>> maps = await db.getAll(
-      'SELECT COUNT(DISTINCT $fcFaceId) as count FROM $faceClustersTable',
+      'SELECT $fcFaceId FROM $faceClustersTable',
     );
-    return maps.first['count'] as int;
+    final Set<int> fileIDs = {};
+    for (final map in maps) {
+      final int fileID = getFileIdFromFaceId(map[fcFaceId] as String);
+      fileIDs.add(fileID);
+    }
+    return fileIDs.length;
   }
 
-  Future<double> getClusteredToTotalFacesRatio() async {
-    final int totalFaces = await getTotalFaceCount();
-    final int clusteredFaces = await getClusteredFaceCount();
+  Future<double> getClusteredToIndexableFilesRatio() async {
+    final int indexableFiles = (await getIndexableFileIDs()).length;
+    final int clusteredFiles = await getClusteredFileCount();
+
+    return clusteredFiles / indexableFiles;
+  }
 
-    return clusteredFaces / totalFaces;
+  Future<int> getUnclusteredFaceCount() async {
+    final db = await instance.asyncDB;
+    const String query = '''
+      SELECT f.$faceIDColumn
+      FROM $facesTable f
+      LEFT JOIN $faceClustersTable fc ON f.$faceIDColumn = fc.$fcFaceId
+      WHERE f.$faceScore > $kMinimumQualityFaceScore
+      AND f.$faceBlur > $kLaplacianHardThreshold
+      AND fc.$fcFaceId IS NULL
+    ''';
+    final List<Map<String, dynamic>> maps = await db.getAll(query);
+    return maps.length;
   }
 
   Future<int> getBlurryFaceCount([
@@ -795,7 +817,7 @@ class FaceMLDataDB {
       for (final map in maps) {
         final clusterID = map[clusterIDColumn] as int;
         final String faceID = map[fcFaceId] as String;
-        final fileID = int.parse(faceID.split('_').first);
+        final fileID = getFileIdFromFaceId(faceID);
         result[fileID] = (result[fileID] ?? {})..add(clusterID);
       }
       return result;
@@ -814,8 +836,8 @@ class FaceMLDataDB {
       final Map<int, Set<int>> result = {};
       for (final map in maps) {
         final clusterID = map[fcClusterID] as int;
-        final faceId = map[fcFaceId] as String;
-        final fileID = int.parse(faceId.split("_").first);
+        final faceID = map[fcFaceId] as String;
+        final fileID = getFileIdFromFaceId(faceID);
         result[fileID] = (result[fileID] ?? {})..add(clusterID);
       }
       return result;
@@ -964,7 +986,7 @@ class FaceMLDataDB {
     final Map<String, int> faceIDToClusterID = {};
     for (final row in faceIdsResult) {
       final faceID = row[fcFaceId] as String;
-      if (fileIds.contains(faceID.split('_').first)) {
+      if (fileIds.contains(getFileIdFromFaceId(faceID))) {
         maxClusterID += 1;
         faceIDToClusterID[faceID] = maxClusterID;
       }
@@ -990,7 +1012,7 @@ class FaceMLDataDB {
     final Map<String, int> faceIDToClusterID = {};
     for (final row in faceIdsResult) {
       final faceID = row[fcFaceId] as String;
-      if (fileIds.contains(faceID.split('_').first)) {
+      if (fileIds.contains(getFileIdFromFaceId(faceID))) {
         maxClusterID += 1;
         faceIDToClusterID[faceID] = maxClusterID;
       }

+ 20 - 36
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -498,19 +498,8 @@ class FaceClusteringService {
       }
     }
 
-    // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
     if (fileIDToCreationTime != null) {
-      faceInfos.sort((a, b) {
-        if (a.fileCreationTime == null && b.fileCreationTime == null) {
-          return 0;
-        } else if (a.fileCreationTime == null) {
-          return 1;
-        } else if (b.fileCreationTime == null) {
-          return -1;
-        } else {
-          return a.fileCreationTime!.compareTo(b.fileCreationTime!);
-        }
-      });
+      _sortFaceInfosOnCreationTime(faceInfos);
     }
 
     // Sort the faceInfos such that the ones with null clusterId are at the end
@@ -796,19 +785,8 @@ class FaceClusteringService {
       );
     }
 
-    // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
     if (fileIDToCreationTime != null) {
-      faceInfos.sort((a, b) {
-        if (a.fileCreationTime == null && b.fileCreationTime == null) {
-          return 0;
-        } else if (a.fileCreationTime == null) {
-          return 1;
-        } else if (b.fileCreationTime == null) {
-          return -1;
-        } else {
-          return a.fileCreationTime!.compareTo(b.fileCreationTime!);
-        }
-      });
+      _sortFaceInfosOnCreationTime(faceInfos);
     }
 
     if (faceInfos.isEmpty) {
@@ -996,19 +974,8 @@ class FaceClusteringService {
       );
     }
 
-    // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
     if (fileIDToCreationTime != null) {
-      faceInfos.sort((a, b) {
-        if (a.fileCreationTime == null && b.fileCreationTime == null) {
-          return 0;
-        } else if (a.fileCreationTime == null) {
-          return 1;
-        } else if (b.fileCreationTime == null) {
-          return -1;
-        } else {
-          return a.fileCreationTime!.compareTo(b.fileCreationTime!);
-        }
-      });
+      _sortFaceInfosOnCreationTime(faceInfos);
     }
 
     // Get the embeddings
@@ -1027,3 +994,20 @@ class FaceClusteringService {
     return clusteredFaceIDs;
   }
 }
+
+/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first
+void _sortFaceInfosOnCreationTime(
+  List<FaceInfo> faceInfos,
+) {
+  faceInfos.sort((b, a) {
+    if (a.fileCreationTime == null && b.fileCreationTime == null) {
+      return 0;
+    } else if (a.fileCreationTime == null) {
+      return 1;
+    } else if (b.fileCreationTime == null) {
+      return -1;
+    } else {
+      return a.fileCreationTime!.compareTo(b.fileCreationTime!);
+    }
+  });
+}

+ 5 - 1
mobile/lib/services/machine_learning/face_ml/face_ml_result.dart

@@ -310,5 +310,9 @@ class FaceResultBuilder {
 }
 
 int getFileIdFromFaceId(String faceId) {
-  return int.parse(faceId.split("_")[0]);
+  return int.parse(faceId.split("_").first);
 }
+
+int? tryGetFileIdFromFaceId(String faceId) {
+  return int.tryParse(faceId.split("_").first);
+}

+ 11 - 28
mobile/lib/services/machine_learning/face_ml/face_ml_service.dart

@@ -12,7 +12,6 @@ import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
 import "package:logging/logging.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:package_info_plus/package_info_plus.dart";
-import "package:photos/core/configuration.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
 import "package:photos/events/diff_sync_complete_event.dart";
@@ -99,6 +98,7 @@ class FaceMlService {
 
   final int _fileDownloadLimit = 5;
   final int _embeddingFetchLimit = 200;
+  final int _kForceClusteringFaceCount = 4000;
 
   Future<void> init({bool initializeImageMlIsolate = false}) async {
     if (LocalSettings.instance.isFaceIndexingEnabled == false) {
@@ -358,16 +358,17 @@ class FaceMlService {
     if (_cannotRunMLFunction()) return;
 
     await sync(forceSync: _shouldSyncPeople);
-    await indexAllImages();
-    final indexingCompleteRatio = await _getIndexedDoneRatio();
-    if (indexingCompleteRatio < 0.95) {
+
+    final int unclusteredFacesCount =
+        await FaceMLDataDB.instance.getUnclusteredFaceCount();
+    if (unclusteredFacesCount > _kForceClusteringFaceCount) {
       _logger.info(
-        "Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio",
+        "There are $unclusteredFacesCount unclustered faces, doing clustering first",
       );
-      return;
-    } else {
       await clusterAllImages();
     }
+    await indexAllImages();
+    await clusterAllImages();
   }
 
   void pauseIndexingAndClustering() {
@@ -445,7 +446,7 @@ class FaceMlService {
 
         if (LocalSettings.instance.remoteFetchEnabled) {
           try {
-            final List<int> fileIds = [];
+            final Set<int> fileIds = {}; // if there are duplicates here server returns 400
             // Try to find embeddings on the remote server
             for (final f in chunk) {
               fileIds.add(f.uploadedFileID!);
@@ -590,8 +591,8 @@ class FaceMlService {
           allFaceInfoForClustering.add(faceInfo);
         }
       }
-      // sort the embeddings based on file creation time, oldest first
-      allFaceInfoForClustering.sort((a, b) {
+      // sort the embeddings based on file creation time, newest first
+      allFaceInfoForClustering.sort((b, a) {
         return fileIDToCreationTime[a.fileID]!
             .compareTo(fileIDToCreationTime[b.fileID]!);
       });
@@ -1171,24 +1172,6 @@ class FaceMlService {
     }
   }
 
-  Future<double> _getIndexedDoneRatio() async {
-    final w = (kDebugMode ? EnteWatch('_getIndexedDoneRatio') : null)?..start();
-
-    final int alreadyIndexedCount = await FaceMLDataDB.instance
-        .getIndexedFileCount(minimumMlVersion: faceMlVersion);
-    final int totalIndexableCount = (await getIndexableFileIDs()).length;
-    final ratio = alreadyIndexedCount / totalIndexableCount;
-
-    w?.log('getIndexedDoneRatio');
-
-    return ratio;
-  }
-
-  static Future<List<int>> getIndexableFileIDs() async {
-    return FilesDB.instance
-        .getOwnedFileIDs(Configuration.instance.getUserID()!);
-  }
-
   bool _skipAnalysisEnteFile(EnteFile enteFile, Map<int, int> indexedFileIds) {
     if (_isIndexingOrClusteringRunning == false ||
         _mlControllerStatus == false) {

+ 12 - 8
mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart

@@ -1,6 +1,7 @@
 import "dart:async";
 import "dart:convert";
 
+import "package:computer/computer.dart";
 import "package:logging/logging.dart";
 import "package:photos/core/network/network.dart";
 import "package:photos/db/files_db.dart";
@@ -16,6 +17,8 @@ import "package:shared_preferences/shared_preferences.dart";
 class RemoteFileMLService {
   RemoteFileMLService._privateConstructor();
 
+  static final Computer _computer = Computer.shared();
+
   static final RemoteFileMLService instance =
       RemoteFileMLService._privateConstructor();
 
@@ -52,13 +55,13 @@ class RemoteFileMLService {
   }
 
   Future<FilesMLDataResponse> getFilessEmbedding(
-    List<int> fileIds,
+    Set<int> fileIds,
   ) async {
     try {
       final res = await _dio.post(
         "/embeddings/files",
         data: {
-          "fileIDs": fileIds,
+          "fileIDs": fileIds.toList(),
           "model": 'file-ml-clip-face',
         },
       );
@@ -107,15 +110,17 @@ class RemoteFileMLService {
       final input = EmbeddingsDecoderInput(embedding, fileKey);
       inputs.add(input);
     }
-    // todo: use compute or isolate
-    return decryptFileMLComputer(
-      {
+    return _computer.compute<Map<String, dynamic>, Map<int, FileMl>>(
+      _decryptFileMLComputer,
+      param: {
         "inputs": inputs,
       },
     );
   }
 
-  Future<Map<int, FileMl>> decryptFileMLComputer(
+}
+
+Future<Map<int, FileMl>> _decryptFileMLComputer(
     Map<String, dynamic> args,
   ) async {
     final result = <int, FileMl>{};
@@ -134,5 +139,4 @@ class RemoteFileMLService {
       result[input.embedding.fileID] = decodedEmbedding;
     }
     return result;
-  }
-}
+  }

+ 3 - 4
mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart

@@ -23,6 +23,7 @@ import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx
 import "package:photos/utils/debouncer.dart";
 import "package:photos/utils/device_info.dart";
 import "package:photos/utils/local_settings.dart";
+import "package:photos/utils/ml_util.dart";
 import "package:photos/utils/thumbnail_util.dart";
 
 class SemanticSearchService {
@@ -160,8 +161,7 @@ class SemanticSearchService {
   }
 
   Future<IndexStatus> getIndexStatus() async {
-    final indexableFileIDs = await FilesDB.instance
-        .getOwnedFileIDs(Configuration.instance.getUserID()!);
+    final indexableFileIDs = await getIndexableFileIDs();
     return IndexStatus(
       min(_cachedEmbeddings.length, indexableFileIDs.length),
       (await _getFileIDsToBeIndexed()).length,
@@ -222,8 +222,7 @@ class SemanticSearchService {
   }
 
   Future<List<int>> _getFileIDsToBeIndexed() async {
-    final uploadedFileIDs = await FilesDB.instance
-        .getOwnedFileIDs(Configuration.instance.getUserID()!);
+    final uploadedFileIDs = await getIndexableFileIDs();
     final embeddedFileIDs =
         await EmbeddingsDB.instance.getFileIDs(_currentModel);
 

+ 0 - 9
mobile/lib/services/search_service.dart

@@ -754,15 +754,6 @@ class SearchService {
 
   Future<List<GenericSearchResult>> getAllFace(int? limit) async {
     try {
-      // Don't return anything if clustering is not nearly complete yet
-      final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
-      final clusteredFaces =
-          await FaceMLDataDB.instance.getClusteredFaceCount();
-      final clusteringDoneRatio = clusteredFaces / foundFaces;
-      if (clusteringDoneRatio < 0.9) {
-        return [];
-      }
-
       debugPrint("getting faces");
       final Map<int, Set<int>> fileIdToClusterID =
           await FaceMLDataDB.instance.getFileIdToClusterIds();

+ 1 - 1
mobile/lib/ui/settings/debug/face_debug_section_widget.dart

@@ -177,7 +177,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
         sectionOptionSpacing,
         MenuItemWidget(
           captionedTextWidget: FutureBuilder<double>(
-            future: FaceMLDataDB.instance.getClusteredToTotalFacesRatio(),
+            future: FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(),
             builder: (context, snapshot) {
               if (snapshot.hasData) {
                 return CaptionedTextWidget(

+ 6 - 22
mobile/lib/ui/settings/machine_learning_settings_page.dart

@@ -26,6 +26,7 @@ import "package:photos/ui/components/title_bar_widget.dart";
 import "package:photos/ui/components/toggle_switch_widget.dart";
 import "package:photos/utils/data_util.dart";
 import "package:photos/utils/local_settings.dart";
+import "package:photos/utils/ml_util.dart";
 
 final _logger = Logger("MachineLearningSettingsPage");
 
@@ -438,19 +439,16 @@ class FaceRecognitionStatusWidgetState
     });
   }
 
-  Future<(int, int, int, double)> getIndexStatus() async {
+  Future<(int, int, double)> getIndexStatus() async {
     try {
       final indexedFiles = await FaceMLDataDB.instance
           .getIndexedFileCount(minimumMlVersion: faceMlVersion);
-      final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length;
+      final indexableFiles = (await getIndexableFileIDs()).length;
       final showIndexedFiles = min(indexedFiles, indexableFiles);
       final pendingFiles = max(indexableFiles - indexedFiles, 0);
-      final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
-      final clusteredFaces =
-          await FaceMLDataDB.instance.getClusteredFaceCount();
-      final clusteringDoneRatio = clusteredFaces / foundFaces;
+      final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
 
-      return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio);
+      return (showIndexedFiles, pendingFiles, clusteringDoneRatio);
     } catch (e, s) {
       _logger.severe('Error getting face recognition status', e, s);
       rethrow;
@@ -479,8 +477,7 @@ class FaceRecognitionStatusWidgetState
             if (snapshot.hasData) {
               final int indexedFiles = snapshot.data!.$1;
               final int pendingFiles = snapshot.data!.$2;
-              final int foundFaces = snapshot.data!.$3;
-              final double clusteringDoneRatio = snapshot.data!.$4;
+              final double clusteringDoneRatio = snapshot.data!.$3;
               final double clusteringPercentage =
                   (clusteringDoneRatio * 100).clamp(0, 100);
 
@@ -512,19 +509,6 @@ class FaceRecognitionStatusWidgetState
                     isGestureDetectorDisabled: true,
                     key: ValueKey("pending_items_" + pendingFiles.toString()),
                   ),
-                  MenuItemWidget(
-                    captionedTextWidget: CaptionedTextWidget(
-                      title: S.of(context).foundFaces,
-                    ),
-                    trailingWidget: Text(
-                      NumberFormat().format(foundFaces),
-                      style: Theme.of(context).textTheme.bodySmall,
-                    ),
-                    singleBorderRadius: 8,
-                    alignCaptionedTextToLeft: true,
-                    isGestureDetectorDisabled: true,
-                    key: ValueKey("found_faces_" + foundFaces.toString()),
-                  ),
                   MenuItemWidget(
                     captionedTextWidget: CaptionedTextWidget(
                       title: S.of(context).clusteringProgress,

+ 7 - 0
mobile/lib/utils/ml_util.dart

@@ -0,0 +1,7 @@
+import "package:photos/core/configuration.dart";
+import "package:photos/db/files_db.dart";
+
+Future<List<int>> getIndexableFileIDs() async {
+    return FilesDB.instance
+        .getOwnedFileIDs(Configuration.instance.getUserID()!);
+  }

+ 1 - 1
mobile/pubspec.yaml

@@ -12,7 +12,7 @@ description: ente photos application
 # Read more about iOS versioning at
 # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
 
-version: 0.8.109+633
+version: 0.8.110+634
 publish_to: none
 
 environment:

+ 1 - 0
web/apps/auth/package.json

@@ -7,6 +7,7 @@
         "@ente/accounts": "*",
         "@ente/eslint-config": "*",
         "@ente/shared": "*",
+        "jssha": "~3.3.1",
         "otpauth": "^9"
     }
 }

+ 14 - 24
web/apps/auth/src/pages/auth.tsx

@@ -46,14 +46,11 @@ const AuthenticatorCodesPage = () => {
         appContext.showNavBar(false);
     }, []);
 
+    const lcSearch = searchTerm.toLowerCase();
     const filteredCodes = codes.filter(
-        (secret) =>
-            (secret.issuer ?? "")
-                .toLowerCase()
-                .includes(searchTerm.toLowerCase()) ||
-            (secret.account ?? "")
-                .toLowerCase()
-                .includes(searchTerm.toLowerCase()),
+        (code) =>
+            code.issuer?.toLowerCase().includes(lcSearch) ||
+            code.account?.toLowerCase().includes(lcSearch),
     );
 
     if (!hasFetched) {
@@ -190,28 +187,21 @@ const CodeDisplay: React.FC<CodeDisplay> = ({ code }) => {
     useEffect(() => {
         // Generate to set the initial otp and nextOTP on component mount.
         regen();
-        const codeType = code.type;
-        const codePeriodInMs = code.period * 1000;
-        const timeToNextCode =
-            codePeriodInMs - (new Date().getTime() % codePeriodInMs);
-        const interval = null;
+
+        const periodMs = code.period * 1000;
+        const timeToNextCode = periodMs - (Date.now() % periodMs);
+
+        let interval: ReturnType<typeof setInterval> | undefined;
         // Wait until we are at the start of the next code period, and then
         // start the interval loop.
         setTimeout(() => {
             // We need to call regen() once before the interval loop to set the
             // initial otp and nextOTP.
             regen();
-            codeType.toLowerCase() === "totp" ||
-            codeType.toLowerCase() === "hotp"
-                ? setInterval(() => {
-                      regen();
-                  }, codePeriodInMs)
-                : null;
+            interval = setInterval(regen, periodMs);
         }, timeToNextCode);
 
-        return () => {
-            if (interval) clearInterval(interval);
-        };
+        return () => interval && clearInterval(interval);
     }, [code]);
 
     return (
@@ -270,7 +260,7 @@ const OTPDisplay: React.FC<OTPDisplayProps> = ({ code, otp, nextOTP }) => {
                             textAlign: "left",
                         }}
                     >
-                        {code.issuer}
+                        {code.issuer ?? ""}
                     </p>
                     <p
                         style={{
@@ -283,7 +273,7 @@ const OTPDisplay: React.FC<OTPDisplayProps> = ({ code, otp, nextOTP }) => {
                             color: "grey",
                         }}
                     >
-                        {code.account}
+                        {code.account ?? ""}
                     </p>
                     <p
                         style={{
@@ -349,7 +339,7 @@ const TimerProgress: React.FC<TimerProgressProps> = ({ period }) => {
 
     useEffect(() => {
         const advance = () => {
-            const timeRemaining = us - ((new Date().getTime() * 1000) % us);
+            const timeRemaining = us - ((Date.now() * 1000) % us);
             setProgress(timeRemaining / us);
         };
 

+ 111 - 84
web/apps/auth/src/services/code.ts

@@ -1,5 +1,6 @@
+import { ensure } from "@/utils/ensure";
 import { HOTP, TOTP } from "otpauth";
-import { URI } from "vscode-uri";
+import { Steam } from "./steam";
 
 /**
  * A parsed representation of an *OTP code URI.
@@ -10,13 +11,19 @@ export interface Code {
     /** A unique id for the corresponding "auth entity" in our system. */
     id?: String;
     /** The type of the code. */
-    type: "totp" | "hotp";
+    type: "totp" | "hotp" | "steam";
     /** The user's account or email for which this code is used. */
-    account: string;
+    account?: string;
     /** The name of the entity that issued this code. */
     issuer: string;
-    /** Number of digits in the generated OTP. */
-    digits: number;
+    /**
+     * Length of the generated OTP.
+     *
+     * This is vernacularly called "digits", which is an accurate description
+     * for the OG TOTP/HOTP codes. However, steam codes are not just digits, so
+     * we name this as a content-neutral "length".
+     */
+    length: number;
     /**
      * The time period (in seconds) for which a single OTP generated from this
      * code remains valid.
@@ -32,7 +39,7 @@ export interface Code {
     /** The (HMAC) algorithm used by the OTP generator. */
     algorithm: "sha1" | "sha256" | "sha512";
     /** The original string from which this code was generated. */
-    uriString?: string;
+    uriString: string;
 }
 
 /**
@@ -45,100 +52,109 @@ export interface Code {
  *
  * - (TOTP)
  *   otpauth://totp/ACME:user@example.org?algorithm=SHA1&digits=6&issuer=acme&period=30&secret=ALPHANUM
+ *
+ * See also `auth/test/models/code_test.dart`.
  */
 export const codeFromURIString = (id: string, uriString: string): Code => {
-    const santizedRawData = uriString
-        .replaceAll("+", "%2B")
-        .replaceAll(":", "%3A")
-        .replaceAll("\r", "")
-        // trim quotes
-        .replace(/^"|"$/g, "");
-
-    const uriParams = {};
-    const searchParamsString =
-        decodeURIComponent(santizedRawData).split("?")[1];
-    searchParamsString.split("&").forEach((pair) => {
-        const [key, value] = pair.split("=");
-        uriParams[key] = value;
-    });
-
-    const uri = URI.parse(santizedRawData);
-    let uriPath = decodeURIComponent(uri.path);
-    if (uriPath.startsWith("/otpauth://") || uriPath.startsWith("otpauth://")) {
-        uriPath = uriPath.split("otpauth://")[1];
-    } else if (uriPath.startsWith("otpauth%3A//")) {
-        uriPath = uriPath.split("otpauth%3A//")[1];
+    try {
+        return _codeFromURIString(id, uriString);
+    } catch (e) {
+        // We might have legacy encodings of account names that contain a "#",
+        // which causes the rest of the URL to be treated as a fragment, and
+        // ignored. See if this was potentially such a case, otherwise rethrow.
+        if (uriString.includes("#"))
+            return _codeFromURIString(id, uriString.replaceAll("#", "%23"));
+        throw e;
     }
+};
+
+const _codeFromURIString = (id: string, uriString: string): Code => {
+    const url = new URL(uriString);
+
+    // A URL like
+    //
+    // new URL("otpauth://hotp/Test?secret=AAABBBCCCDDDEEEFFF&issuer=Test&counter=0")
+    //
+    // is parsed differently by the browser and Node depending on the scheme.
+    // When the scheme is http(s), then both of them consider "hotp" as the
+    // `host`. However, when the scheme is "otpauth", as is our case, the
+    // browser considers the entire thing as part of the pathname. so we get.
+    //
+    //     host: ""
+    //     pathname: "//hotp/Test"
+    //
+    // Since this code run on browsers only, we parse as per that behaviour.
+
+    const [type, path] = parsePathname(url);
 
     return {
         id,
-        type: _getType(uriPath),
-        account: _getAccount(uriPath),
-        issuer: _getIssuer(uriPath, uriParams),
-        digits: parseDigits(uriParams),
-        period: parsePeriod(uriParams),
-        secret: parseSecret(uriParams),
-        algorithm: parseAlgorithm(uriParams),
+        type,
+        account: parseAccount(path),
+        issuer: parseIssuer(url, path),
+        length: parseLength(url, type),
+        period: parsePeriod(url),
+        secret: parseSecret(url),
+        algorithm: parseAlgorithm(url),
         uriString,
     };
 };
 
-const _getType = (uriPath: string): Code["type"] => {
-    const oauthType = uriPath.split("/")[0].substring(0);
-    if (oauthType.toLowerCase() === "totp") {
-        return "totp";
-    } else if (oauthType.toLowerCase() === "hotp") {
-        return "hotp";
-    }
-    throw new Error(`Unsupported format with host ${oauthType}`);
+const parsePathname = (url: URL): [type: Code["type"], path: string] => {
+    const p = url.pathname.toLowerCase();
+    if (p.startsWith("//totp")) return ["totp", url.pathname.slice(6)];
+    if (p.startsWith("//hotp")) return ["hotp", url.pathname.slice(6)];
+    if (p.startsWith("//steam")) return ["steam", url.pathname.slice(7)];
+    throw new Error(`Unsupported code or unparseable path "${url.pathname}"`);
 };
 
-const _getAccount = (uriPath: string): string => {
-    try {
-        const path = decodeURIComponent(uriPath);
-        if (path.includes(":")) {
-            return path.split(":")[1];
-        } else if (path.includes("/")) {
-            return path.split("/")[1];
-        }
-    } catch (e) {
-        return "";
-    }
+const parseAccount = (path: string): string | undefined => {
+    // "/ACME:user@example.org" => "user@example.org"
+    let p = decodeURIComponent(path);
+    if (p.startsWith("/")) p = p.slice(1);
+    if (p.includes(":")) p = p.split(":").slice(1).join(":");
+    return p;
 };
 
-const _getIssuer = (uriPath: string, uriParams: { get?: any }): string => {
-    try {
-        if (uriParams["issuer"] !== undefined) {
-            let issuer = uriParams["issuer"];
-            // This is to handle bug in the ente auth app
-            if (issuer.endsWith("period")) {
-                issuer = issuer.substring(0, issuer.length - 6);
-            }
-            return issuer;
+const parseIssuer = (url: URL, path: string): string => {
+    // If there is a "issuer" search param, use that.
+    let issuer = url.searchParams.get("issuer");
+    if (issuer) {
+        // This is to handle bug in old versions of Ente Auth app.
+        if (issuer.endsWith("period")) {
+            issuer = issuer.substring(0, issuer.length - 6);
         }
-        let path = decodeURIComponent(uriPath);
-        if (path.startsWith("totp/") || path.startsWith("hotp/")) {
-            path = path.substring(5);
-        }
-        if (path.includes(":")) {
-            return path.split(":")[0];
-        } else if (path.includes("-")) {
-            return path.split("-")[0];
-        }
-        return path;
-    } catch (e) {
-        return "";
+        return issuer;
     }
+
+    // Otherwise use the `prefix:` from the account as the issuer.
+    // "/ACME:user@example.org" => "ACME"
+    let p = decodeURIComponent(path);
+    if (p.startsWith("/")) p = p.slice(1);
+
+    if (p.includes(":")) p = p.split(":")[0];
+    else if (p.includes("-")) p = p.split("-")[0];
+
+    return p;
 };
 
-const parseDigits = (uriParams): number =>
-    parseInt(uriParams["digits"] ?? "", 10) || 6;
+/**
+ * Parse the length of the generated code.
+ *
+ * The URI query param is called digits since originally TOTP/HOTP codes used
+ * this for generating numeric codes. Now we also support steam, which instead
+ * shows non-numeric codes, and also with a different default length of 5.
+ */
+const parseLength = (url: URL, type: Code["type"]): number => {
+    const defaultLength = type == "steam" ? 5 : 6;
+    return parseInt(url.searchParams.get("digits") ?? "", 10) || defaultLength;
+};
 
-const parsePeriod = (uriParams): number =>
-    parseInt(uriParams["period"] ?? "", 10) || 30;
+const parsePeriod = (url: URL): number =>
+    parseInt(url.searchParams.get("period") ?? "", 10) || 30;
 
-const parseAlgorithm = (uriParams): Code["algorithm"] => {
-    switch (uriParams["algorithm"]?.toLowerCase()) {
+const parseAlgorithm = (url: URL): Code["algorithm"] => {
+    switch (url.searchParams.get("algorithm")?.toLowerCase()) {
         case "sha256":
             return "sha256";
         case "sha512":
@@ -148,8 +164,8 @@ const parseAlgorithm = (uriParams): Code["algorithm"] => {
     }
 };
 
-const parseSecret = (uriParams): string =>
-    uriParams["secret"].replaceAll(" ", "").toUpperCase();
+const parseSecret = (url: URL): string =>
+    ensure(url.searchParams.get("secret")).replaceAll(" ", "").toUpperCase();
 
 /**
  * Generate a pair of OTPs (one time passwords) from the given {@link code}.
@@ -168,11 +184,11 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => {
                 secret: code.secret,
                 algorithm: code.algorithm,
                 period: code.period,
-                digits: code.digits,
+                digits: code.length,
             });
             otp = totp.generate();
             nextOTP = totp.generate({
-                timestamp: new Date().getTime() + code.period * 1000,
+                timestamp: Date.now() + code.period * 1000,
             });
             break;
         }
@@ -187,6 +203,17 @@ export const generateOTPs = (code: Code): [otp: string, nextOTP: string] => {
             nextOTP = hotp.generate({ counter: 1 });
             break;
         }
+
+        case "steam": {
+            const steam = new Steam({
+                secret: code.secret,
+            });
+            otp = steam.generate();
+            nextOTP = steam.generate({
+                timestamp: Date.now() + code.period * 1000,
+            });
+            break;
+        }
     }
     return [otp, nextOTP];
 };

+ 1 - 1
web/apps/auth/src/services/remote.ts

@@ -35,7 +35,7 @@ export const getAuthCodes = async (): Promise<Code[]> => {
                             );
                         return codeFromURIString(entity.id, decryptedCode);
                     } catch (e) {
-                        log.error(`failed to parse codeId = ${entity.id}`);
+                        log.error(`Failed to parse codeID ${entity.id}`, e);
                         return null;
                     }
                 }),

+ 74 - 0
web/apps/auth/src/services/steam.ts

@@ -0,0 +1,74 @@
+import jsSHA from "jssha";
+import { Secret } from "otpauth";
+
+/**
+ * Steam OTPs.
+ *
+ * Steam's algorithm is a custom variant of TOTP that uses a 26-character
+ * alphabet instead of digits.
+ *
+ * A Dart implementation of the algorithm can be found in
+ * https://github.com/elliotwutingfeng/steam_totp/blob/main/lib/src/steam_totp_base.dart
+ * (MIT license), and we use that as a reference. Our implementation is written
+ * in the style of the other TOTP/HOTP classes that are provided by the otpauth
+ * JS library that we use for the normal TOTP/HOTP generation
+ * https://github.com/hectorm/otpauth/blob/master/src/hotp.js (MIT license).
+ */
+export class Steam {
+    secret: Secret;
+    period: number;
+
+    constructor({ secret }: { secret: string }) {
+        this.secret = Secret.fromBase32(secret);
+        this.period = 30;
+    }
+
+    generate({ timestamp }: { timestamp: number } = { timestamp: Date.now() }) {
+        // Same as regular TOTP.
+        const counter = Math.floor(timestamp / 1000 / this.period);
+
+        // Same as regular HOTP, but algorithm is fixed to SHA-1.
+        const digest = sha1HMACDigest(this.secret.buffer, uintToArray(counter));
+
+        // Same calculation as regular HOTP.
+        const offset = digest[digest.length - 1] & 15;
+        let otp =
+            ((digest[offset] & 127) << 24) |
+            ((digest[offset + 1] & 255) << 16) |
+            ((digest[offset + 2] & 255) << 8) |
+            (digest[offset + 3] & 255);
+
+        // However, instead of using this as the OTP, use it to index into
+        // the steam OTP alphabet.
+        const alphabet = "23456789BCDFGHJKMNPQRTVWXY";
+        const N = alphabet.length;
+        const steamOTP = [];
+        for (let i = 0; i < 5; i++) {
+            steamOTP.push(alphabet[otp % N]);
+            otp = Math.trunc(otp / N);
+        }
+        return steamOTP.join("");
+    }
+}
+
+// Equivalent to
+// https://github.com/hectorm/otpauth/blob/master/src/utils/encoding/uint.js
+const uintToArray = (n: number): Uint8Array => {
+    const result = new Uint8Array(8);
+    for (let i = 7; i >= 0; i--) {
+        result[i] = n & 255;
+        n >>= 8;
+    }
+    return result;
+};
+
+// We don't necessarily need a dependency on `jssha`, we could use SubtleCrypto
+// here too. However, SubtleCrypto has an async interface, and we already have a
+// transitive dependency on `jssha` via `otpauth`, so just using it here doesn't
+// increase our bundle size any further.
+const sha1HMACDigest = (key: ArrayBuffer, message: Uint8Array) => {
+    const hmac = new jsSHA("SHA-1", "UINT8ARRAY");
+    hmac.setHMACKey(key, "ARRAYBUFFER");
+    hmac.update(message);
+    return hmac.getHMAC("UINT8ARRAY");
+};

+ 0 - 1
web/apps/photos/package.json

@@ -43,7 +43,6 @@
         "similarity-transformation": "^0.0.1",
         "transformation-matrix": "^2.16",
         "uuid": "^9.0.1",
-        "vscode-uri": "^3.0.7",
         "xml-js": "^1.6.11",
         "zxcvbn": "^4.4.2"
     },

+ 4 - 0
web/docs/dependencies.md

@@ -198,3 +198,7 @@ some cases.
 
 -   [otpauth](https://github.com/hectorm/otpauth) is used for the generation of
     the actual OTP from the user's TOTP/HOTP secret.
+
+-   However, otpauth doesn't support steam OTPs. For these, we need to compute
+    the SHA-1, and we use the same library, `jssha` that `otpauth` uses (since
+    it is already part of our bundle).

+ 0 - 5
web/yarn.lock

@@ -4804,11 +4804,6 @@ void-elements@3.1.0:
   resolved "https://registry.yarnpkg.com/void-elements/-/void-elements-3.1.0.tgz#614f7fbf8d801f0bb5f0661f5b2f5785750e4f09"
   integrity sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==
 
-vscode-uri@^3.0.7:
-  version "3.0.8"
-  resolved "https://registry.yarnpkg.com/vscode-uri/-/vscode-uri-3.0.8.tgz#1770938d3e72588659a172d0fd4642780083ff9f"
-  integrity sha512-AyFQ0EVmsOZOlAnxoFOGOq1SQDWAB7C6aqMGS23svWAllfOaxbuFvcT8D1i8z3Gyn8fraVeZNNmN6e9bxxXkKw==
-
 webidl-conversions@^3.0.0:
   version "3.0.1"
   resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"