sftpgo-mirror/internal/httpd/api_utils.go
Nicola Murino 3dd412f6e3
WebAdmin and REST API: remove too granular permissions
Our permissions system for admin users is too granular and some
permissions overlap. For example, you can define an administrator
with the "manage_system" permission and not with the "manage_admins"
or "manage_user" permission, but the "manage_system" permission
allows you to restore a backup and then create users and
administrators. The following permissions will be removed:
"manage_admins", "manage_apikeys", "manage_system", "retention_checks",
"manage_event_rules", "manage_roles", "manage_ip_lists". Now you
need to add the "*" permission to replace the removed granular
permissions because the removed permissions allow actions that
should only be allowed to super administrators.
There is no point in having separate, overlapping permissions.

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-11-10 10:46:28 +01:00

944 lines
29 KiB
Go

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"mime"
"net/http"
"net/url"
"os"
"path"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/render"
"github.com/klauspost/compress/zip"
"github.com/rs/xid"
"github.com/sftpgo/sdk/plugin/notifier"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
type pwdChange struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
type pwdReset struct {
Code string `json:"code"`
Password string `json:"password"`
}
type baseProfile struct {
Email string `json:"email,omitempty"`
Description string `json:"description,omitempty"`
AllowAPIKeyAuth bool `json:"allow_api_key_auth"`
}
type adminProfile struct {
baseProfile
}
type userProfile struct {
baseProfile
AdditionalEmails []string `json:"additional_emails,omitempty"`
PublicKeys []string `json:"public_keys,omitempty"`
TLSCerts []string `json:"tls_certs,omitempty"`
}
func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
var errorString string
if errors.Is(err, util.ErrNotFound) {
errorString = http.StatusText(http.StatusNotFound)
} else if err != nil {
errorString = err.Error()
}
resp := apiResponse{
Error: errorString,
Message: message,
}
ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
render.JSON(w, r.WithContext(ctx), resp)
}
func getRespStatus(err error) int {
if errors.Is(err, util.ErrValidation) {
return http.StatusBadRequest
}
if errors.Is(err, util.ErrMethodDisabled) {
return http.StatusForbidden
}
if errors.Is(err, util.ErrNotFound) {
return http.StatusNotFound
}
if errors.Is(err, fs.ErrNotExist) {
return http.StatusBadRequest
}
if errors.Is(err, fs.ErrPermission) || errors.Is(err, dataprovider.ErrLoginNotAllowedFromIP) {
return http.StatusForbidden
}
if errors.Is(err, plugin.ErrNoSearcher) || errors.Is(err, dataprovider.ErrNotImplemented) {
return http.StatusNotImplemented
}
if errors.Is(err, dataprovider.ErrDuplicatedKey) || errors.Is(err, dataprovider.ErrForeignKeyViolated) {
return http.StatusConflict
}
return http.StatusInternalServerError
}
// mappig between fs errors for HTTP protocol and HTTP response status codes
func getMappedStatusCode(err error) int {
var statusCode int
switch {
case errors.Is(err, fs.ErrPermission):
statusCode = http.StatusForbidden
case errors.Is(err, common.ErrReadQuotaExceeded):
statusCode = http.StatusForbidden
case errors.Is(err, fs.ErrNotExist):
statusCode = http.StatusNotFound
case errors.Is(err, common.ErrQuotaExceeded):
statusCode = http.StatusRequestEntityTooLarge
case errors.Is(err, common.ErrOpUnsupported):
statusCode = http.StatusBadRequest
default:
if _, ok := err.(*http.MaxBytesError); ok {
statusCode = http.StatusRequestEntityTooLarge
} else {
statusCode = http.StatusInternalServerError
}
}
return statusCode
}
func getURLParam(r *http.Request, key string) string {
v := chi.URLParam(r, key)
unescaped, err := url.PathUnescape(v)
if err != nil {
return v
}
return unescaped
}
func getURLPath(r *http.Request) string {
rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
return rctx.RoutePath
}
return r.URL.Path
}
func getCommaSeparatedQueryParam(r *http.Request, key string) []string {
var result []string
for _, val := range strings.Split(r.URL.Query().Get(key), ",") {
val = strings.TrimSpace(val)
if val != "" {
result = append(result, val)
}
}
return util.RemoveDuplicates(result, false)
}
func getBoolQueryParam(r *http.Request, param string) bool {
return r.URL.Query().Get(param) == "true"
}
func getActiveConnections(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
stats := common.Connections.GetStats(claims.Role)
if claims.NodeID == "" {
stats = append(stats, getNodesConnections(claims.Username, claims.Role)...)
}
render.JSON(w, r, stats)
}
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
connectionID := getURLParam(r, "connectionID")
if connectionID == "" {
sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
return
}
node := r.URL.Query().Get("node")
if node == "" || node == dataprovider.GetNodeName() {
if common.Connections.Close(connectionID, claims.Role) {
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
} else {
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
}
return
}
n, err := dataprovider.GetNodeByName(node)
if err != nil {
logger.Warn(logSender, "", "unable to get node with name %q: %v", node, err)
status := getRespStatus(err)
sendAPIResponse(w, r, nil, http.StatusText(status), status)
return
}
if err := n.SendDeleteRequest(claims.Username, claims.Role, fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID)); err != nil {
logger.Warn(logSender, "", "unable to delete connection id %q from node %q: %v", connectionID, n.Name, err)
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
return
}
sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
}
// getNodesConnections returns the active connections from other nodes.
// Errors are silently ignored
func getNodesConnections(admin, role string) []common.ConnectionStatus {
nodes, err := dataprovider.GetNodes()
if err != nil || len(nodes) == 0 {
return nil
}
var results []common.ConnectionStatus
var mu sync.Mutex
var wg sync.WaitGroup
for _, n := range nodes {
wg.Add(1)
go func(node dataprovider.Node) {
defer wg.Done()
var stats []common.ConnectionStatus
if err := node.SendGetRequest(admin, role, activeConnectionsPath, &stats); err != nil {
logger.Warn(logSender, "", "unable to get connections from node %s: %v", node.Name, err)
return
}
mu.Lock()
results = append(results, stats...)
mu.Unlock()
}(n)
}
wg.Wait()
return results
}
func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) {
var err error
limit := 100
offset := 0
order := dataprovider.OrderASC
if _, ok := r.URL.Query()["limit"]; ok {
limit, err = strconv.Atoi(r.URL.Query().Get("limit"))
if err != nil {
err = errors.New("invalid limit")
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
return limit, offset, order, err
}
if limit > 500 {
limit = 500
}
}
if _, ok := r.URL.Query()["offset"]; ok {
offset, err = strconv.Atoi(r.URL.Query().Get("offset"))
if err != nil {
err = errors.New("invalid offset")
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
return limit, offset, order, err
}
}
if _, ok := r.URL.Query()["order"]; ok {
order = r.URL.Query().Get("order")
if order != dataprovider.OrderASC && order != dataprovider.OrderDESC {
err = errors.New("invalid order")
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
return limit, offset, order, err
}
}
return limit, offset, order, err
}
func renderAPIDirContents(w http.ResponseWriter, lister vfs.DirLister, omitNonRegularFiles bool) {
defer lister.Close()
dataGetter := func(limit, _ int) ([]byte, int, error) {
contents, err := lister.Next(limit)
if errors.Is(err, io.EOF) {
err = nil
}
if err != nil {
return nil, 0, err
}
results := make([]map[string]any, 0, len(contents))
for _, info := range contents {
if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() {
continue
}
res := make(map[string]any)
res["name"] = info.Name()
if info.Mode().IsRegular() {
res["size"] = info.Size()
}
res["mode"] = info.Mode()
res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339)
results = append(results, res)
}
data, err := json.Marshal(results)
count := limit
if len(results) == 0 {
count = 0
}
return data, count, err
}
streamJSONArray(w, defaultQueryLimit, dataGetter)
}
func streamData(w io.Writer, data []byte) {
b := bytes.NewBuffer(data)
_, err := io.CopyN(w, b, int64(len(data)))
if err != nil {
panic(http.ErrAbortHandler)
}
}
func streamJSONArray(w http.ResponseWriter, chunkSize int, dataGetter func(limit, offset int) ([]byte, int, error)) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Accept-Ranges", "none")
w.WriteHeader(http.StatusOK)
streamData(w, []byte("["))
offset := 0
for {
data, count, err := dataGetter(chunkSize, offset)
if err != nil {
panic(http.ErrAbortHandler)
}
if count == 0 {
break
}
if offset > 0 {
streamData(w, []byte(","))
}
streamData(w, data[1:len(data)-1])
if count < chunkSize {
break
}
offset += count
}
streamData(w, []byte("]"))
}
func renderPNGImage(w http.ResponseWriter, r *http.Request, b []byte) {
if len(b) == 0 {
ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusNotFound)
render.PlainText(w, r.WithContext(ctx), http.StatusText(http.StatusNotFound))
return
}
w.Header().Set("Content-Type", "image/png")
streamData(w, b)
}
func getCompressedFileName(username string, files []string) string {
if len(files) == 1 {
name := path.Base(files[0])
return fmt.Sprintf("%s-%s.zip", username, strings.TrimSuffix(name, path.Ext(name)))
}
return fmt.Sprintf("%s-download.zip", username)
}
func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir string, files []string,
share *dataprovider.Share,
) {
conn.User.CheckFsRoot(conn.ID) //nolint:errcheck
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Accept-Ranges", "none")
w.Header().Set("Content-Transfer-Encoding", "binary")
w.WriteHeader(http.StatusOK)
wr := zip.NewWriter(w)
for _, file := range files {
fullPath := util.CleanPath(path.Join(baseDir, file))
if err := addZipEntry(wr, conn, fullPath, baseDir, 0); err != nil {
if share != nil {
dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
}
panic(http.ErrAbortHandler)
}
}
if err := wr.Close(); err != nil {
conn.Log(logger.LevelError, "unable to close zip file: %v", err)
if share != nil {
dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
}
panic(http.ErrAbortHandler)
}
}
func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string, recursion int) error {
if recursion >= util.MaxRecursion {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, recursion too depth: %d", entryPath, recursion)
return util.ErrRecursionTooDeep
}
recursion++
info, err := conn.Stat(entryPath, 1)
if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err)
return err
}
entryName, err := getZipEntryName(entryPath, baseDir)
if err != nil {
conn.Log(logger.LevelError, "unable to get zip entry name: %v", err)
return err
}
if info.IsDir() {
_, err = wr.CreateHeader(&zip.FileHeader{
Name: entryName + "/",
Method: zip.Deflate,
Modified: info.ModTime(),
})
if err != nil {
conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
return err
}
lister, err := conn.ReadDir(entryPath)
if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, get list dir error: %v", entryPath, err)
return err
}
defer lister.Close()
for {
contents, err := lister.Next(vfs.ListerBatchSize)
finished := errors.Is(err, io.EOF)
if err != nil && !finished {
return err
}
for _, info := range contents {
fullPath := util.CleanPath(path.Join(entryPath, info.Name()))
if err := addZipEntry(wr, conn, fullPath, baseDir, recursion); err != nil {
return err
}
}
if finished {
return nil
}
}
}
if !info.Mode().IsRegular() {
// we only allow regular files
conn.Log(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath)
return nil
}
return addFileToZipEntry(wr, conn, entryPath, entryName, info)
}
func addFileToZipEntry(wr *zip.Writer, conn *Connection, entryPath, entryName string, info os.FileInfo) error {
reader, err := conn.getFileReader(entryPath, 0, http.MethodGet)
if err != nil {
conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err)
return err
}
defer reader.Close()
f, err := wr.CreateHeader(&zip.FileHeader{
Name: entryName,
Method: zip.Deflate,
Modified: info.ModTime(),
})
if err != nil {
conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err)
return err
}
_, err = io.Copy(f, reader)
return err
}
func getZipEntryName(entryPath, baseDir string) (string, error) {
if !strings.HasPrefix(entryPath, baseDir) {
return "", fmt.Errorf("entry path %q is outside base dir %q", entryPath, baseDir)
}
entryPath = strings.TrimPrefix(entryPath, baseDir)
return strings.TrimPrefix(entryPath, "/"), nil
}
func checkDownloadFileFromShare(share *dataprovider.Share, info os.FileInfo) error {
if share != nil && !info.Mode().IsRegular() {
return util.NewValidationError("non regular files are not supported for shares")
}
return nil
}
func downloadFile(w http.ResponseWriter, r *http.Request, connection *Connection, name string,
info os.FileInfo, inline bool, share *dataprovider.Share,
) (int, error) {
connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
err := checkDownloadFileFromShare(share, info)
if err != nil {
return http.StatusBadRequest, err
}
rangeHeader := r.Header.Get("Range")
if rangeHeader != "" && checkIfRange(r, info.ModTime()) == condFalse {
rangeHeader = ""
}
offset := int64(0)
size := info.Size()
responseStatus := http.StatusOK
if strings.HasPrefix(rangeHeader, "bytes=") {
if strings.Contains(rangeHeader, ",") {
return http.StatusRequestedRangeNotSatisfiable, fmt.Errorf("unsupported range %q", rangeHeader)
}
offset, size, err = parseRangeRequest(rangeHeader[6:], size)
if err != nil {
return http.StatusRequestedRangeNotSatisfiable, err
}
responseStatus = http.StatusPartialContent
}
reader, err := connection.getFileReader(name, offset, r.Method)
if err != nil {
return getMappedStatusCode(err), fmt.Errorf("unable to read file %q: %v", name, err)
}
defer reader.Close()
w.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
if checkPreconditions(w, r, info.ModTime()) {
return 0, fmt.Errorf("%v", http.StatusText(http.StatusPreconditionFailed))
}
ctype := mime.TypeByExtension(path.Ext(name))
if ctype == "" {
ctype = "application/octet-stream"
}
if responseStatus == http.StatusPartialContent {
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, info.Size()))
}
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.Header().Set("Content-Type", ctype)
if !inline {
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", path.Base(name)))
}
w.Header().Set("Accept-Ranges", "bytes")
w.WriteHeader(responseStatus)
if r.Method != http.MethodHead {
_, err = io.CopyN(w, reader, size)
if err != nil {
if share != nil {
dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
}
connection.Log(logger.LevelDebug, "error reading file to download: %v", err)
panic(http.ErrAbortHandler)
}
}
return http.StatusOK, nil
}
func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) bool {
if checkIfUnmodifiedSince(r, modtime) == condFalse {
w.WriteHeader(http.StatusPreconditionFailed)
return true
}
if checkIfModifiedSince(r, modtime) == condFalse {
w.WriteHeader(http.StatusNotModified)
return true
}
return false
}
func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult {
ius := r.Header.Get("If-Unmodified-Since")
if ius == "" || isZeroTime(modtime) {
return condNone
}
t, err := http.ParseTime(ius)
if err != nil {
return condNone
}
// The Last-Modified header truncates sub-second precision so
// the modtime needs to be truncated too.
modtime = modtime.Truncate(time.Second)
if modtime.Before(t) || modtime.Equal(t) {
return condTrue
}
return condFalse
}
func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
return condNone
}
ims := r.Header.Get("If-Modified-Since")
if ims == "" || isZeroTime(modtime) {
return condNone
}
t, err := http.ParseTime(ims)
if err != nil {
return condNone
}
// The Last-Modified header truncates sub-second precision so
// the modtime needs to be truncated too.
modtime = modtime.Truncate(time.Second)
if modtime.Before(t) || modtime.Equal(t) {
return condFalse
}
return condTrue
}
func checkIfRange(r *http.Request, modtime time.Time) condResult {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
return condNone
}
ir := r.Header.Get("If-Range")
if ir == "" {
return condNone
}
if modtime.IsZero() {
return condFalse
}
t, err := http.ParseTime(ir)
if err != nil {
return condFalse
}
if modtime.Unix() == t.Unix() {
return condTrue
}
return condFalse
}
func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) {
var start, end int64
var err error
values := strings.Split(bytesRange, "-")
if values[0] == "" {
start = -1
} else {
start, err = strconv.ParseInt(values[0], 10, 64)
if err != nil {
return start, size, err
}
}
if len(values) >= 2 {
if values[1] != "" {
end, err = strconv.ParseInt(values[1], 10, 64)
if err != nil {
return start, size, err
}
if end >= size {
end = size - 1
}
}
}
if start == -1 && end == 0 {
return 0, 0, fmt.Errorf("unsupported range %q", bytesRange)
}
if end > 0 {
if start == -1 {
// we have something like -500
start = size - end
size = end
// start cannot be < 0 here, we did end = size -1 above
} else {
// we have something like 500-600
size = end - start + 1
if size < 0 {
return 0, 0, fmt.Errorf("unacceptable range %q", bytesRange)
}
}
return start, size, nil
}
// we have something like 500-
size -= start
if size < 0 {
return 0, 0, fmt.Errorf("unacceptable range %q", bytesRange)
}
return start, size, err
}
func handleDefenderEventLoginFailed(ipAddr string, err error) error {
event := common.HostEventLoginFailed
if errors.Is(err, util.ErrNotFound) {
event = common.HostEventUserNotFound
err = dataprovider.ErrInvalidCredentials
}
common.AddDefenderEvent(ipAddr, common.ProtocolHTTP, event)
common.DelayLogin(err)
return err
}
func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error) {
metric.AddLoginAttempt(loginMethod)
var protocol string
switch loginMethod {
case dataprovider.LoginMethodIDP:
protocol = common.ProtocolOIDC
default:
protocol = common.ProtocolHTTP
}
if err == nil {
plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, protocol, user.Username, ip, "", nil)
common.DelayLogin(nil)
} else if err != common.ErrInternalFailure && err != common.ErrNoCredentials {
logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error())
err = handleDefenderEventLoginFailed(ip, err)
logEv := notifier.LogEventTypeLoginFailed
if errors.Is(err, util.ErrNotFound) {
logEv = notifier.LogEventTypeLoginNoUser
}
plugin.Handler.NotifyLogEvent(logEv, protocol, user.Username, ip, "", err)
}
metric.AddLoginResult(loginMethod, err)
dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
}
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol HTTP is not allowed", user.Username)
return util.NewI18nError(
fmt.Errorf("protocol HTTP is not allowed for user %q", user.Username),
util.I18nErrorProtocolForbidden,
)
}
if !isLoggedInWithOIDC(r) && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) {
logger.Info(logSender, connectionID, "cannot login user %q, password login method is not allowed", user.Username)
return util.NewI18nError(
fmt.Errorf("login method password is not allowed for user %q", user.Username),
util.I18nErrorPwdLoginForbidden,
)
}
if checkSessions && user.MaxSessions > 0 {
activeSessions := common.Connections.GetActiveSessions(user.Username)
if activeSessions >= user.MaxSessions {
logger.Info(logSender, connectionID, "authentication refused for user: %q, too many open sessions: %v/%v", user.Username,
activeSessions, user.MaxSessions)
return util.NewI18nError(fmt.Errorf("too many open sessions: %v", activeSessions), util.I18nError429Message)
}
}
if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", user.Username, r.RemoteAddr)
return util.NewI18nError(
fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, r.RemoteAddr),
util.I18nErrorIPForbidden,
)
}
return nil
}
func getActiveAdmin(username, ipAddr string) (dataprovider.Admin, error) {
admin, err := dataprovider.AdminExists(username)
if err != nil {
return admin, err
}
if err := admin.CanLogin(ipAddr); err != nil {
return admin, util.NewRecordNotFoundError(fmt.Sprintf("admin %q cannot login: %v", username, err))
}
return admin, nil
}
func getActiveUser(username string, r *http.Request) (dataprovider.User, error) {
user, err := dataprovider.GetUserWithGroupSettings(username, "")
if err != nil {
return user, err
}
if err := user.CheckLoginConditions(); err != nil {
return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err))
}
if err := checkHTTPClientUser(&user, r, xid.New().String(), false); err != nil {
return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err))
}
return user, nil
}
func handleForgotPassword(r *http.Request, username string, isAdmin bool) error {
var emails []string
var subject string
var err error
var admin dataprovider.Admin
var user dataprovider.User
if username == "" {
return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired)
}
if isAdmin {
admin, err = getActiveAdmin(username, util.GetIPFromRemoteAddress(r.RemoteAddr))
if admin.Email != "" {
emails = []string{admin.Email}
}
subject = fmt.Sprintf("Email Verification Code for admin %q", username)
} else {
user, err = getActiveUser(username, r)
emails = user.GetEmailAddresses()
subject = fmt.Sprintf("Email Verification Code for user %q", username)
if err == nil {
if !isUserAllowedToResetPassword(r, &user) {
return util.NewI18nError(
util.NewValidationError("you are not allowed to reset your password"),
util.I18nErrorPwdResetForbidded,
)
}
}
}
if err != nil {
if errors.Is(err, util.ErrNotFound) {
handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), err) //nolint:errcheck
logger.Debug(logSender, middleware.GetReqID(r.Context()),
"username %q does not exists or cannot login, reset password request silently ignored, is admin? %t, err: %v",
username, isAdmin, err)
return nil
}
return util.NewI18nError(util.NewGenericError("Error retrieving your account, please try again later"), util.I18nErrorGetUser)
}
if len(emails) == 0 {
return util.NewI18nError(
util.NewValidationError("Your account does not have an email address, it is not possible to reset your password by sending an email verification code"),
util.I18nErrorPwdResetNoEmail,
)
}
c := newResetCode(username, isAdmin)
body := new(bytes.Buffer)
data := make(map[string]string)
data["Code"] = c.Code
if err := smtp.RenderPasswordResetTemplate(body, data); err != nil {
logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to render password reset template: %v", err)
return util.NewGenericError("Unable to render password reset template")
}
startTime := time.Now()
if err := smtp.SendEmail(emails, nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to send password reset code via email: %v, elapsed: %v",
err, time.Since(startTime))
return util.NewI18nError(
util.NewGenericError(fmt.Sprintf("Error sending confirmation code via email: %v", err)),
util.I18nErrorPwdResetSendEmail,
)
}
logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %q, emails: %+v, is admin? %v, elapsed: %v",
username, emails, isAdmin, time.Since(startTime))
return resetCodesMgr.Add(c)
}
func handleResetPassword(r *http.Request, code, newPassword, confirmPassword string, isAdmin bool) (
*dataprovider.Admin, *dataprovider.User, error,
) {
var admin dataprovider.Admin
var user dataprovider.User
var err error
if newPassword == "" {
return &admin, &user, util.NewValidationError("please set a password")
}
if code == "" {
return &admin, &user, util.NewValidationError("please set a confirmation code")
}
if newPassword != confirmPassword {
return &admin, &user, util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
resetCode, err := resetCodesMgr.Get(code)
if err != nil {
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
return &admin, &user, util.NewValidationError("confirmation code not found")
}
if resetCode.IsAdmin != isAdmin {
return &admin, &user, util.NewValidationError("invalid confirmation code")
}
if isAdmin {
admin, err = getActiveAdmin(resetCode.Username, ipAddr)
if err != nil {
return &admin, &user, util.NewValidationError("unable to associate the confirmation code with an existing admin")
}
admin.Password = newPassword
admin.Filters.RequirePasswordChange = false
err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role)
if err != nil {
return &admin, &user, util.NewGenericError(fmt.Sprintf("unable to set the new password: %v", err))
}
err = resetCodesMgr.Delete(code)
return &admin, &user, err
}
user, err = getActiveUser(resetCode.Username, r)
if err != nil {
return &admin, &user, util.NewValidationError("Unable to associate the confirmation code with an existing user")
}
if !isUserAllowedToResetPassword(r, &user) {
return &admin, &user, util.NewI18nError(
util.NewValidationError("you are not allowed to reset your password"),
util.I18nErrorPwdResetForbidded,
)
}
err = dataprovider.UpdateUserPassword(user.Username, newPassword, dataprovider.ActionExecutorSelf,
util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role)
if err == nil {
err = resetCodesMgr.Delete(code)
}
user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now())
user.Filters.RequirePasswordChange = false
return &admin, &user, err
}
func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool {
if !user.CanResetPassword() {
return false
}
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
return false
}
if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) {
return false
}
if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
return false
}
return true
}
func getProtocolFromRequest(r *http.Request) string {
if isLoggedInWithOIDC(r) {
return common.ProtocolOIDC
}
return common.ProtocolHTTP
}
func hideConfidentialData(claims *jwtTokenClaims, r *http.Request) bool {
if !claims.hasPerm(dataprovider.PermAdminAny) {
return true
}
return r.URL.Query().Get("confidential_data") != "1"
}