浏览代码

sql provider: enhanced folder mapping query using an upsert

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 3 年之前
父节点
当前提交
dd9c5b2149
共有 2 个文件被更改,包括 35 次插入45 次删除
  1. 18 39
      dataprovider/sqlcommon.go
  2. 17 6
      dataprovider/sqlqueries.go

+ 18 - 39
dataprovider/sqlcommon.go

@@ -2022,19 +2022,6 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
 	return user, nil
 	return user, nil
 }
 }
 
 
-func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error {
-	var folderName string
-	q := checkFolderNameQuery()
-	stmt, err := dbHandle.PrepareContext(ctx, q)
-	if err != nil {
-		providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
-		return err
-	}
-	defer stmt.Close()
-	row := stmt.QueryRowContext(ctx, name)
-	return row.Scan(&folderName)
-}
-
 func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
 func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
 	var folder vfs.BaseVirtualFolder
 	var folder vfs.BaseVirtualFolder
 	q := getFolderByNameQuery()
 	q := getFolderByNameQuery()
@@ -2093,29 +2080,23 @@ func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuer
 }
 }
 
 
 func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
 func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
-	usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
-	var folder vfs.BaseVirtualFolder
-	// FIXME: we could use an UPSERT here, this SELECT could be racy
-	err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
-	switch err {
-	case nil:
-		err = sqlCommonUpdateFolder(baseFolder, dbHandle)
-		if err != nil {
-			return folder, err
-		}
-	case sql.ErrNoRows:
-		baseFolder.UsedQuotaFiles = usedQuotaFiles
-		baseFolder.UsedQuotaSize = usedQuotaSize
-		baseFolder.LastQuotaUpdate = lastQuotaUpdate
-		err = sqlCommonAddFolder(baseFolder, dbHandle)
-		if err != nil {
-			return folder, err
-		}
-	default:
-		return folder, err
+	usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier,
+) error {
+	fsConfig, err := json.Marshal(baseFolder.FsConfig)
+	if err != nil {
+		return err
+	}
+	q := getUpsertFolderQuery()
+	stmt, err := dbHandle.PrepareContext(ctx, q)
+	if err != nil {
+		providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
+		return err
 	}
 	}
+	defer stmt.Close()
 
 
-	return sqlCommonGetFolder(ctx, baseFolder.Name, dbHandle)
+	_, err = stmt.ExecContext(ctx, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
+		lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
+	return err
 }
 }
 
 
 func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
 func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
@@ -2327,7 +2308,7 @@ func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.
 		return err
 		return err
 	}
 	}
 	defer stmt.Close()
 	defer stmt.Close()
-	_, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.ID, user.Username)
+	_, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username)
 	return err
 	return err
 }
 }
 
 
@@ -2362,11 +2343,10 @@ func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHan
 	}
 	}
 	for idx := range group.VirtualFolders {
 	for idx := range group.VirtualFolders {
 		vfolder := &group.VirtualFolders[idx]
 		vfolder := &group.VirtualFolders[idx]
-		f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
+		err = sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		vfolder.BaseVirtualFolder = f
 		err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle)
 		err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -2382,11 +2362,10 @@ func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle
 	}
 	}
 	for idx := range user.VirtualFolders {
 	for idx := range user.VirtualFolders {
 		vfolder := &user.VirtualFolders[idx]
 		vfolder := &user.VirtualFolders[idx]
-		f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
+		err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		vfolder.BaseVirtualFolder = f
 		err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle)
 		err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle)
 		if err != nil {
 		if err != nil {
 			return err
 			return err

+ 17 - 6
dataprovider/sqlqueries.go

@@ -450,10 +450,6 @@ func getFolderByNameQuery() string {
 	return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0])
 	return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0])
 }
 }
 
 
-func checkFolderNameQuery() string {
-	return fmt.Sprintf(`SELECT name FROM %v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0])
-}
-
 func getAddFolderQuery() string {
 func getAddFolderQuery() string {
 	return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
 	return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
 		VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
 		VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
@@ -469,6 +465,20 @@ func getDeleteFolderQuery() string {
 	return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableFolders, sqlPlaceholders[0])
 	return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableFolders, sqlPlaceholders[0])
 }
 }
 
 
+func getUpsertFolderQuery() string {
+	if config.Driver == MySQLDataProviderName {
+		return fmt.Sprintf("INSERT INTO %v (`path`,`used_quota_size`,`used_quota_files`,`last_quota_update`,`name`,"+
+			"`description`,`filesystem`) VALUES (%v,%v,%v,%v,%v,%v,%v) ON DUPLICATE KEY UPDATE "+
+			"`path`=VALUES(`path`),`description`=VALUES(`description`),`filesystem`=VALUES(`filesystem`)",
+			sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
+			sqlPlaceholders[5], sqlPlaceholders[6])
+	}
+	return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
+		VALUES (%v,%v,%v,%v,%v,%v,%v) ON CONFLICT (name) DO UPDATE SET path = EXCLUDED.path,description=EXCLUDED.description,
+		filesystem=EXCLUDED.filesystem`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
+		sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
+}
+
 func getClearUserGroupMappingQuery() string {
 func getClearUserGroupMappingQuery() string {
 	return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersGroupsMapping,
 	return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersGroupsMapping,
 		sqlTableUsers, sqlPlaceholders[0])
 		sqlTableUsers, sqlPlaceholders[0])
@@ -499,8 +509,9 @@ func getClearUserFolderMappingQuery() string {
 
 
 func getAddUserFolderMappingQuery() string {
 func getAddUserFolderMappingQuery() string {
 	return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,user_id)
 	return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,user_id)
-		VALUES (%v,%v,%v,%v,(SELECT id FROM %v WHERE username = %v))`, sqlTableUsersFoldersMapping, sqlPlaceholders[0],
-		sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4])
+		VALUES (%v,%v,%v,(SELECT id FROM %v WHERE name = %v),(SELECT id FROM %v WHERE username = %v))`,
+		sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders,
+		sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4])
 }
 }
 
 
 func getFoldersQuery(order string, minimal bool) string {
 func getFoldersQuery(order string, minimal bool) string {