api_utils.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. package httpd
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "mime"
  9. "net/http"
  10. "net/url"
  11. "os"
  12. "path"
  13. "strconv"
  14. "strings"
  15. "time"
  16. "github.com/go-chi/chi/v5"
  17. "github.com/go-chi/chi/v5/middleware"
  18. "github.com/go-chi/render"
  19. "github.com/klauspost/compress/zip"
  20. "github.com/drakkan/sftpgo/v2/common"
  21. "github.com/drakkan/sftpgo/v2/dataprovider"
  22. "github.com/drakkan/sftpgo/v2/logger"
  23. "github.com/drakkan/sftpgo/v2/metric"
  24. "github.com/drakkan/sftpgo/v2/plugin"
  25. "github.com/drakkan/sftpgo/v2/smtp"
  26. "github.com/drakkan/sftpgo/v2/util"
  27. )
  // JSON payload types used by the password change/reset and profile handlers;
// field names on the wire are given by the json struct tags.
type (
	// pwdChange carries the current and the desired password.
	pwdChange struct {
		CurrentPassword string `json:"current_password"`
		NewPassword     string `json:"new_password"`
	}

	// pwdReset carries an emailed confirmation code and the new password.
	pwdReset struct {
		Code     string `json:"code"`
		Password string `json:"password"`
	}

	// baseProfile holds the profile fields shared by admins and users.
	baseProfile struct {
		Email           string `json:"email,omitempty"`
		Description     string `json:"description,omitempty"`
		AllowAPIKeyAuth bool   `json:"allow_api_key_auth"`
	}

	// adminProfile is the admin profile: only the shared fields.
	adminProfile struct {
		baseProfile
	}

	// userProfile is the user profile: the shared fields plus public keys.
	userProfile struct {
		baseProfile
		PublicKeys []string `json:"public_keys,omitempty"`
	}
)
  48. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  49. var errorString string
  50. if _, ok := err.(*util.RecordNotFoundError); ok {
  51. errorString = http.StatusText(http.StatusNotFound)
  52. } else if err != nil {
  53. errorString = err.Error()
  54. }
  55. resp := apiResponse{
  56. Error: errorString,
  57. Message: message,
  58. }
  59. ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
  60. render.JSON(w, r.WithContext(ctx), resp)
  61. }
  62. func getRespStatus(err error) int {
  63. if _, ok := err.(*util.ValidationError); ok {
  64. return http.StatusBadRequest
  65. }
  66. if _, ok := err.(*util.MethodDisabledError); ok {
  67. return http.StatusForbidden
  68. }
  69. if _, ok := err.(*util.RecordNotFoundError); ok {
  70. return http.StatusNotFound
  71. }
  72. if os.IsNotExist(err) {
  73. return http.StatusBadRequest
  74. }
  75. if os.IsPermission(err) || errors.Is(err, dataprovider.ErrLoginNotAllowedFromIP) {
  76. return http.StatusForbidden
  77. }
  78. if errors.Is(err, plugin.ErrNoSearcher) || errors.Is(err, dataprovider.ErrNotImplemented) {
  79. return http.StatusNotImplemented
  80. }
  81. return http.StatusInternalServerError
  82. }
  83. // mappig between fs errors for HTTP protocol and HTTP response status codes
  84. func getMappedStatusCode(err error) int {
  85. var statusCode int
  86. switch {
  87. case errors.Is(err, os.ErrPermission):
  88. statusCode = http.StatusForbidden
  89. case errors.Is(err, common.ErrReadQuotaExceeded):
  90. statusCode = http.StatusForbidden
  91. case errors.Is(err, os.ErrNotExist):
  92. statusCode = http.StatusNotFound
  93. case errors.Is(err, common.ErrQuotaExceeded):
  94. statusCode = http.StatusRequestEntityTooLarge
  95. case errors.Is(err, common.ErrOpUnsupported):
  96. statusCode = http.StatusBadRequest
  97. default:
  98. statusCode = http.StatusInternalServerError
  99. }
  100. return statusCode
  101. }
  102. func getURLParam(r *http.Request, key string) string {
  103. v := chi.URLParam(r, key)
  104. unescaped, err := url.PathUnescape(v)
  105. if err != nil {
  106. return v
  107. }
  108. return unescaped
  109. }
  110. func getCommaSeparatedQueryParam(r *http.Request, key string) []string {
  111. var result []string
  112. for _, val := range strings.Split(r.URL.Query().Get(key), ",") {
  113. val = strings.TrimSpace(val)
  114. if val != "" {
  115. result = append(result, val)
  116. }
  117. }
  118. return util.RemoveDuplicates(result)
  119. }
  // getBoolQueryParam reports whether the query parameter param is exactly "true".
// The match is case-sensitive; any other value, or a missing parameter, yields false.
func getBoolQueryParam(r *http.Request, param string) bool {
	switch r.URL.Query().Get(param) {
	case "true":
		return true
	default:
		return false
	}
}
  123. func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
  124. r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
  125. connectionID := getURLParam(r, "connectionID")
  126. if connectionID == "" {
  127. sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
  128. return
  129. }
  130. if common.Connections.Close(connectionID) {
  131. sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
  132. } else {
  133. sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
  134. }
  135. }
  136. func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) {
  137. var err error
  138. limit := 100
  139. offset := 0
  140. order := dataprovider.OrderASC
  141. if _, ok := r.URL.Query()["limit"]; ok {
  142. limit, err = strconv.Atoi(r.URL.Query().Get("limit"))
  143. if err != nil {
  144. err = errors.New("invalid limit")
  145. sendAPIResponse(w, r, err, "", http.StatusBadRequest)
  146. return limit, offset, order, err
  147. }
  148. if limit > 500 {
  149. limit = 500
  150. }
  151. }
  152. if _, ok := r.URL.Query()["offset"]; ok {
  153. offset, err = strconv.Atoi(r.URL.Query().Get("offset"))
  154. if err != nil {
  155. err = errors.New("invalid offset")
  156. sendAPIResponse(w, r, err, "", http.StatusBadRequest)
  157. return limit, offset, order, err
  158. }
  159. }
  160. if _, ok := r.URL.Query()["order"]; ok {
  161. order = r.URL.Query().Get("order")
  162. if order != dataprovider.OrderASC && order != dataprovider.OrderDESC {
  163. err = errors.New("invalid order")
  164. sendAPIResponse(w, r, err, "", http.StatusBadRequest)
  165. return limit, offset, order, err
  166. }
  167. }
  168. return limit, offset, order, err
  169. }
  170. func renderAPIDirContents(w http.ResponseWriter, r *http.Request, contents []os.FileInfo, omitNonRegularFiles bool) {
  171. results := make([]map[string]interface{}, 0, len(contents))
  172. for _, info := range contents {
  173. if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() {
  174. continue
  175. }
  176. res := make(map[string]interface{})
  177. res["name"] = info.Name()
  178. if info.Mode().IsRegular() {
  179. res["size"] = info.Size()
  180. }
  181. res["mode"] = info.Mode()
  182. res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339)
  183. results = append(results, res)
  184. }
  185. render.JSON(w, r, results)
  186. }
  187. func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir string, files []string,
  188. share *dataprovider.Share,
  189. ) {
  190. w.Header().Set("Content-Type", "application/zip")
  191. w.Header().Set("Accept-Ranges", "none")
  192. w.Header().Set("Content-Transfer-Encoding", "binary")
  193. w.WriteHeader(http.StatusOK)
  194. wr := zip.NewWriter(w)
  195. for _, file := range files {
  196. fullPath := path.Join(baseDir, file)
  197. if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
  198. if share != nil {
  199. dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
  200. }
  201. panic(http.ErrAbortHandler)
  202. }
  203. }
  204. if err := wr.Close(); err != nil {
  205. conn.Log(logger.LevelError, "unable to close zip file: %v", err)
  206. if share != nil {
  207. dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
  208. }
  209. panic(http.ErrAbortHandler)
  210. }
  211. }
  212. func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) error {
  213. info, err := conn.Stat(entryPath, 1)
  214. if err != nil {
  215. conn.Log(logger.LevelDebug, "unable to add zip entry %#v, stat error: %v", entryPath, err)
  216. return err
  217. }
  218. if info.IsDir() {
  219. _, err := wr.Create(getZipEntryName(entryPath, baseDir) + "/")
  220. if err != nil {
  221. conn.Log(logger.LevelDebug, "unable to create zip entry %#v: %v", entryPath, err)
  222. return err
  223. }
  224. contents, err := conn.ReadDir(entryPath)
  225. if err != nil {
  226. conn.Log(logger.LevelDebug, "unable to add zip entry %#v, read dir error: %v", entryPath, err)
  227. return err
  228. }
  229. for _, info := range contents {
  230. fullPath := path.Join(entryPath, info.Name())
  231. if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
  232. return err
  233. }
  234. }
  235. return nil
  236. }
  237. if !info.Mode().IsRegular() {
  238. // we only allow regular files
  239. conn.Log(logger.LevelDebug, "skipping zip entry for non regular file %#v", entryPath)
  240. return nil
  241. }
  242. reader, err := conn.getFileReader(entryPath, 0, http.MethodGet)
  243. if err != nil {
  244. conn.Log(logger.LevelDebug, "unable to add zip entry %#v, cannot open file: %v", entryPath, err)
  245. return err
  246. }
  247. defer reader.Close()
  248. f, err := wr.Create(getZipEntryName(entryPath, baseDir))
  249. if err != nil {
  250. conn.Log(logger.LevelDebug, "unable to create zip entry %#v: %v", entryPath, err)
  251. return err
  252. }
  253. _, err = io.Copy(f, reader)
  254. return err
  255. }
  // getZipEntryName returns entryPath relative to baseDir for use as a zip entry name:
// baseDir is removed if it is a prefix, then a single leading "/" is removed.
func getZipEntryName(entryPath, baseDir string) string {
	name := entryPath
	if strings.HasPrefix(name, baseDir) {
		name = name[len(baseDir):]
	}
	if strings.HasPrefix(name, "/") {
		name = name[1:]
	}
	return name
}
  260. func checkDownloadFileFromShare(share *dataprovider.Share, info os.FileInfo) error {
  261. if share != nil && !info.Mode().IsRegular() {
  262. return util.NewValidationError("non regular files are not supported for shares")
  263. }
  264. return nil
  265. }
  266. func downloadFile(w http.ResponseWriter, r *http.Request, connection *Connection, name string,
  267. info os.FileInfo, inline bool, share *dataprovider.Share,
  268. ) (int, error) {
  269. err := checkDownloadFileFromShare(share, info)
  270. if err != nil {
  271. return http.StatusBadRequest, err
  272. }
  273. rangeHeader := r.Header.Get("Range")
  274. if rangeHeader != "" && checkIfRange(r, info.ModTime()) == condFalse {
  275. rangeHeader = ""
  276. }
  277. offset := int64(0)
  278. size := info.Size()
  279. responseStatus := http.StatusOK
  280. if strings.HasPrefix(rangeHeader, "bytes=") {
  281. if strings.Contains(rangeHeader, ",") {
  282. return http.StatusRequestedRangeNotSatisfiable, fmt.Errorf("unsupported range %#v", rangeHeader)
  283. }
  284. offset, size, err = parseRangeRequest(rangeHeader[6:], size)
  285. if err != nil {
  286. return http.StatusRequestedRangeNotSatisfiable, err
  287. }
  288. responseStatus = http.StatusPartialContent
  289. }
  290. reader, err := connection.getFileReader(name, offset, r.Method)
  291. if err != nil {
  292. return getMappedStatusCode(err), fmt.Errorf("unable to read file %#v: %v", name, err)
  293. }
  294. defer reader.Close()
  295. w.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
  296. if checkPreconditions(w, r, info.ModTime()) {
  297. return 0, fmt.Errorf("%v", http.StatusText(http.StatusPreconditionFailed))
  298. }
  299. ctype := mime.TypeByExtension(path.Ext(name))
  300. if ctype == "" {
  301. ctype = "application/octet-stream"
  302. }
  303. if responseStatus == http.StatusPartialContent {
  304. w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, info.Size()))
  305. }
  306. w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
  307. w.Header().Set("Content-Type", ctype)
  308. if !inline {
  309. w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%#v", path.Base(name)))
  310. }
  311. w.Header().Set("Accept-Ranges", "bytes")
  312. w.WriteHeader(responseStatus)
  313. if r.Method != http.MethodHead {
  314. _, err = io.CopyN(w, reader, size)
  315. if err != nil {
  316. if share != nil {
  317. dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck
  318. }
  319. connection.Log(logger.LevelDebug, "error reading file to download: %v", err)
  320. panic(http.ErrAbortHandler)
  321. }
  322. }
  323. return http.StatusOK, nil
  324. }
  325. func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) bool {
  326. if checkIfUnmodifiedSince(r, modtime) == condFalse {
  327. w.WriteHeader(http.StatusPreconditionFailed)
  328. return true
  329. }
  330. if checkIfModifiedSince(r, modtime) == condFalse {
  331. w.WriteHeader(http.StatusNotModified)
  332. return true
  333. }
  334. return false
  335. }
  336. func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult {
  337. ius := r.Header.Get("If-Unmodified-Since")
  338. if ius == "" || isZeroTime(modtime) {
  339. return condNone
  340. }
  341. t, err := http.ParseTime(ius)
  342. if err != nil {
  343. return condNone
  344. }
  345. // The Last-Modified header truncates sub-second precision so
  346. // the modtime needs to be truncated too.
  347. modtime = modtime.Truncate(time.Second)
  348. if modtime.Before(t) || modtime.Equal(t) {
  349. return condTrue
  350. }
  351. return condFalse
  352. }
  353. func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult {
  354. if r.Method != http.MethodGet && r.Method != http.MethodHead {
  355. return condNone
  356. }
  357. ims := r.Header.Get("If-Modified-Since")
  358. if ims == "" || isZeroTime(modtime) {
  359. return condNone
  360. }
  361. t, err := http.ParseTime(ims)
  362. if err != nil {
  363. return condNone
  364. }
  365. // The Last-Modified header truncates sub-second precision so
  366. // the modtime needs to be truncated too.
  367. modtime = modtime.Truncate(time.Second)
  368. if modtime.Before(t) || modtime.Equal(t) {
  369. return condFalse
  370. }
  371. return condTrue
  372. }
  373. func checkIfRange(r *http.Request, modtime time.Time) condResult {
  374. if r.Method != http.MethodGet && r.Method != http.MethodHead {
  375. return condNone
  376. }
  377. ir := r.Header.Get("If-Range")
  378. if ir == "" {
  379. return condNone
  380. }
  381. if modtime.IsZero() {
  382. return condFalse
  383. }
  384. t, err := http.ParseTime(ir)
  385. if err != nil {
  386. return condFalse
  387. }
  388. if modtime.Add(60 * time.Second).Before(t) {
  389. return condTrue
  390. }
  391. return condFalse
  392. }
  // parseRangeRequest parses a single byte-range-spec, the part of a "Range: bytes=..."
// header after "bytes=", for a resource of size bytes (RFC 7233 section 2.1).
// Supported forms:
//   - "first-last": bytes first..last inclusive, with last clamped to size-1;
//   - "first-":     from first to the end of the resource;
//   - "-suffix":    the final suffix bytes, or the whole resource if it is shorter.
//
// It returns the start offset and the number of bytes to send, which is always > 0 on
// success. An error is returned for malformed specs and for unsatisfiable ranges,
// including any range over an empty resource.
func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) {
	sep := strings.Index(bytesRange, "-")
	if sep < 0 {
		return 0, 0, fmt.Errorf("unsupported range %#v", bytesRange)
	}
	first := strings.TrimSpace(bytesRange[:sep])
	last := strings.TrimSpace(bytesRange[sep+1:])

	if first == "" {
		// suffix form "-N": N is a length, not a byte position, so it must not be
		// clamped to size-1 the way a last-byte-pos is
		if last == "" {
			return 0, 0, fmt.Errorf("unsupported range %#v", bytesRange)
		}
		suffix, err := strconv.ParseInt(last, 10, 64)
		if err != nil {
			return 0, 0, err
		}
		if suffix <= 0 || size <= 0 {
			return 0, 0, fmt.Errorf("unacceptable range %#v", bytesRange)
		}
		if suffix > size {
			suffix = size
		}
		return size - suffix, suffix, nil
	}

	start, err := strconv.ParseInt(first, 10, 64)
	if err != nil {
		return 0, 0, err
	}
	if start < 0 || start >= size {
		return 0, 0, fmt.Errorf("unacceptable range %#v", bytesRange)
	}
	end := size - 1
	if last != "" {
		// an explicit last position, including 0 ("0-0" is the first byte only)
		end, err = strconv.ParseInt(last, 10, 64)
		if err != nil {
			return 0, 0, err
		}
		if end < start {
			return 0, 0, fmt.Errorf("unacceptable range %#v", bytesRange)
		}
		if end >= size {
			end = size - 1
		}
	}
	return start, end - start + 1, nil
}
  441. func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error) {
  442. metric.AddLoginAttempt(loginMethod)
  443. var protocol string
  444. switch loginMethod {
  445. case dataprovider.LoginMethodIDP:
  446. protocol = common.ProtocolOIDC
  447. default:
  448. protocol = common.ProtocolHTTP
  449. }
  450. if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
  451. logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error())
  452. event := common.HostEventLoginFailed
  453. if _, ok := err.(*util.RecordNotFoundError); ok {
  454. event = common.HostEventUserNotFound
  455. }
  456. common.AddDefenderEvent(ip, event)
  457. }
  458. metric.AddLoginResult(loginMethod, err)
  459. dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
  460. }
  461. func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error {
  462. if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) {
  463. logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username)
  464. return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username)
  465. }
  466. if !isLoggedInWithOIDC(r) && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, nil) {
  467. logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
  468. return fmt.Errorf("login method password is not allowed for user %#v", user.Username)
  469. }
  470. if user.MaxSessions > 0 {
  471. activeSessions := common.Connections.GetActiveSessions(user.Username)
  472. if activeSessions >= user.MaxSessions {
  473. logger.Info(logSender, connectionID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username,
  474. activeSessions, user.MaxSessions)
  475. return fmt.Errorf("too many open sessions: %v", activeSessions)
  476. }
  477. }
  478. if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
  479. logger.Info(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
  480. return fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
  481. }
  482. return nil
  483. }
  484. func handleForgotPassword(r *http.Request, username string, isAdmin bool) error {
  485. var email, subject string
  486. var err error
  487. var admin dataprovider.Admin
  488. var user dataprovider.User
  489. if username == "" {
  490. return util.NewValidationError("Username is mandatory")
  491. }
  492. if isAdmin {
  493. admin, err = dataprovider.AdminExists(username)
  494. email = admin.Email
  495. subject = fmt.Sprintf("Email Verification Code for admin %#v", username)
  496. } else {
  497. user, err = dataprovider.UserExists(username)
  498. email = user.Email
  499. subject = fmt.Sprintf("Email Verification Code for user %#v", username)
  500. if err == nil {
  501. if !isUserAllowedToResetPassword(r, &user) {
  502. return util.NewValidationError("You are not allowed to reset your password")
  503. }
  504. }
  505. }
  506. if err != nil {
  507. if _, ok := err.(*util.RecordNotFoundError); ok {
  508. logger.Debug(logSender, middleware.GetReqID(r.Context()), "username %#v does not exists, reset password request silently ignored, is admin? %v",
  509. username, isAdmin)
  510. return nil
  511. }
  512. return util.NewGenericError("Error retrieving your account, please try again later")
  513. }
  514. if email == "" {
  515. return util.NewValidationError("Your account does not have an email address, it is not possible to reset your password by sending an email verification code")
  516. }
  517. c := newResetCode(username, isAdmin)
  518. body := new(bytes.Buffer)
  519. data := make(map[string]string)
  520. data["Code"] = c.Code
  521. if err := smtp.RenderPasswordResetTemplate(body, data); err != nil {
  522. logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to render password reset template: %v", err)
  523. return util.NewGenericError("Unable to render password reset template")
  524. }
  525. startTime := time.Now()
  526. if err := smtp.SendEmail(email, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
  527. logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to send password reset code via email: %v, elapsed: %v",
  528. err, time.Since(startTime))
  529. return util.NewGenericError(fmt.Sprintf("Unable to send confirmation code via email: %v", err))
  530. }
  531. logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %#v, email: %#v, is admin? %v, elapsed: %v",
  532. username, email, isAdmin, time.Since(startTime))
  533. resetCodes.Store(c.Code, c)
  534. return nil
  535. }
  536. func handleResetPassword(r *http.Request, code, newPassword string, isAdmin bool) (
  537. *dataprovider.Admin, *dataprovider.User, error,
  538. ) {
  539. var admin dataprovider.Admin
  540. var user dataprovider.User
  541. var err error
  542. if newPassword == "" {
  543. return &admin, &user, util.NewValidationError("Please set a password")
  544. }
  545. if code == "" {
  546. return &admin, &user, util.NewValidationError("Please set a confirmation code")
  547. }
  548. c, ok := resetCodes.Load(code)
  549. if !ok {
  550. return &admin, &user, util.NewValidationError("Confirmation code not found")
  551. }
  552. resetCode := c.(*resetCode)
  553. if resetCode.IsAdmin != isAdmin {
  554. return &admin, &user, util.NewValidationError("Invalid confirmation code")
  555. }
  556. if isAdmin {
  557. admin, err = dataprovider.AdminExists(resetCode.Username)
  558. if err != nil {
  559. return &admin, &user, util.NewValidationError("Unable to associate the confirmation code with an existing admin")
  560. }
  561. admin.Password = newPassword
  562. err = dataprovider.UpdateAdmin(&admin, admin.Username, util.GetIPFromRemoteAddress(r.RemoteAddr))
  563. if err != nil {
  564. return &admin, &user, util.NewGenericError(fmt.Sprintf("Unable to set the new password: %v", err))
  565. }
  566. } else {
  567. user, err = dataprovider.UserExists(resetCode.Username)
  568. if err != nil {
  569. return &admin, &user, util.NewValidationError("Unable to associate the confirmation code with an existing user")
  570. }
  571. if err == nil {
  572. if !isUserAllowedToResetPassword(r, &user) {
  573. return &admin, &user, util.NewValidationError("You are not allowed to reset your password")
  574. }
  575. }
  576. user.Password = newPassword
  577. err = dataprovider.UpdateUser(&user, user.Username, util.GetIPFromRemoteAddress(r.RemoteAddr))
  578. if err != nil {
  579. return &admin, &user, util.NewGenericError(fmt.Sprintf("Unable to set the new password: %v", err))
  580. }
  581. }
  582. resetCodes.Delete(code)
  583. return &admin, &user, nil
  584. }
  585. func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool {
  586. if !user.CanResetPassword() {
  587. return false
  588. }
  589. if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) {
  590. return false
  591. }
  592. if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, nil) {
  593. return false
  594. }
  595. if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
  596. return false
  597. }
  598. return true
  599. }
  600. func getProtocolFromRequest(r *http.Request) string {
  601. if isLoggedInWithOIDC(r) {
  602. return common.ProtocolOIDC
  603. }
  604. return common.ProtocolHTTP
  605. }