diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index d3e712a6..8c9522fb 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -2022,19 +2022,6 @@ func getUserFromDbRow(row sqlScanner) (User, error) { return user, nil } -func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error { - var folderName string - q := checkFolderNameQuery() - stmt, err := dbHandle.PrepareContext(ctx, q) - if err != nil { - providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) - return err - } - defer stmt.Close() - row := stmt.QueryRowContext(ctx, name) - return row.Scan(&folderName) -} - func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { var folder vfs.BaseVirtualFolder q := getFolderByNameQuery() @@ -2093,29 +2080,23 @@ func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuer } func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64, - usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { - var folder vfs.BaseVirtualFolder - // FIXME: we could use an UPSERT here, this SELECT could be racy - err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle) - switch err { - case nil: - err = sqlCommonUpdateFolder(baseFolder, dbHandle) - if err != nil { - return folder, err - } - case sql.ErrNoRows: - baseFolder.UsedQuotaFiles = usedQuotaFiles - baseFolder.UsedQuotaSize = usedQuotaSize - baseFolder.LastQuotaUpdate = lastQuotaUpdate - err = sqlCommonAddFolder(baseFolder, dbHandle) - if err != nil { - return folder, err - } - default: - return folder, err + usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier, +) error { + fsConfig, err := json.Marshal(baseFolder.FsConfig) + if err != nil { + return err } + q := getUpsertFolderQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() - return sqlCommonGetFolder(ctx, baseFolder.Name, dbHandle) + _, err = stmt.ExecContext(ctx, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles, + lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig)) + return err } func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { @@ -2327,7 +2308,7 @@ func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs. return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.ID, user.Username) + _, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username) return err } @@ -2362,11 +2343,10 @@ func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHan } for idx := range group.VirtualFolders { vfolder := &group.VirtualFolders[idx] - f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle) + err = sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle) if err != nil { return err } - vfolder.BaseVirtualFolder = f err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle) if err != nil { return err @@ -2382,11 +2362,10 @@ func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle } for idx := range user.VirtualFolders { vfolder := &user.VirtualFolders[idx] - f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle) + err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle) if err != nil { return err } - vfolder.BaseVirtualFolder = f err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle) if err != nil { return err diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index aff9a067..007a3509 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -450,10 +450,6 @@ func getFolderByNameQuery() string { return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0]) } -func checkFolderNameQuery() string { - return fmt.Sprintf(`SELECT name FROM %v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0]) -} - func getAddFolderQuery() string { return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem) VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], @@ -469,6 +465,20 @@ func getDeleteFolderQuery() string { return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableFolders, sqlPlaceholders[0]) } +func getUpsertFolderQuery() string { + if config.Driver == MySQLDataProviderName { + return fmt.Sprintf("INSERT INTO %v (`path`,`used_quota_size`,`used_quota_files`,`last_quota_update`,`name`,"+ + "`description`,`filesystem`) VALUES (%v,%v,%v,%v,%v,%v,%v) ON DUPLICATE KEY UPDATE "+ + "`path`=VALUES(`path`),`description`=VALUES(`description`),`filesystem`=VALUES(`filesystem`)", + sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], + sqlPlaceholders[5], sqlPlaceholders[6]) + } + return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem) + VALUES (%v,%v,%v,%v,%v,%v,%v) ON CONFLICT (name) DO UPDATE SET path = EXCLUDED.path,description=EXCLUDED.description, + filesystem=EXCLUDED.filesystem`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], + sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6]) +} + func getClearUserGroupMappingQuery() string { return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0]) @@ -499,8 +509,9 @@ func getClearUserFolderMappingQuery() string { func getAddUserFolderMappingQuery() string { return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,user_id) - VALUES (%v,%v,%v,%v,(SELECT id FROM %v WHERE username = %v))`, sqlTableUsersFoldersMapping, sqlPlaceholders[0], - sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4]) + VALUES (%v,%v,%v,(SELECT id FROM %v WHERE name = %v),(SELECT id FROM %v WHERE username = %v))`, + sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders, + sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4]) } func getFoldersQuery(order string, minimal bool) string {