api_utils.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. package httpd
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net/http"
  11. "net/url"
  12. "os"
  13. "path"
  14. "strconv"
  15. "strings"
  16. "time"
  17. "github.com/drakkan/sftpgo/dataprovider"
  18. "github.com/drakkan/sftpgo/sftpd"
  19. "github.com/drakkan/sftpgo/utils"
  20. "github.com/go-chi/render"
  21. )
  22. var (
  23. httpBaseURL = "http://127.0.0.1:8080"
  24. authUsername = ""
  25. authPassword = ""
  26. )
  27. // SetBaseURLAndCredentials sets the base url and the optional credentials to use for HTTP requests.
  28. // Default URL is "http://127.0.0.1:8080" with empty credentials
  29. func SetBaseURLAndCredentials(url, username, password string) {
  30. httpBaseURL = url
  31. authUsername = username
  32. authPassword = password
  33. }
  34. // gets an HTTP Client with a timeout
  35. func getHTTPClient() *http.Client {
  36. return &http.Client{
  37. Timeout: 15 * time.Second,
  38. }
  39. }
  40. func sendHTTPRequest(method, url string, body io.Reader, contentType string) (*http.Response, error) {
  41. req, err := http.NewRequest(method, url, body)
  42. if err != nil {
  43. return nil, err
  44. }
  45. if len(contentType) > 0 {
  46. req.Header.Set("Content-Type", "application/json")
  47. }
  48. if len(authUsername) > 0 || len(authPassword) > 0 {
  49. req.SetBasicAuth(authUsername, authPassword)
  50. }
  51. return getHTTPClient().Do(req)
  52. }
  53. func buildURLRelativeToBase(paths ...string) string {
  54. // we need to use path.Join and not filepath.Join
  55. // since filepath.Join will use backslash separator on Windows
  56. p := path.Join(paths...)
  57. return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/"))
  58. }
  59. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  60. var errorString string
  61. if err != nil {
  62. errorString = err.Error()
  63. }
  64. resp := apiResponse{
  65. Error: errorString,
  66. Message: message,
  67. HTTPStatus: code,
  68. }
  69. ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
  70. render.JSON(w, r.WithContext(ctx), resp)
  71. }
  72. func getRespStatus(err error) int {
  73. if _, ok := err.(*dataprovider.ValidationError); ok {
  74. return http.StatusBadRequest
  75. }
  76. if _, ok := err.(*dataprovider.MethodDisabledError); ok {
  77. return http.StatusForbidden
  78. }
  79. if os.IsNotExist(err) {
  80. return http.StatusBadRequest
  81. }
  82. return http.StatusInternalServerError
  83. }
  84. // AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode.
  85. func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) {
  86. var newUser dataprovider.User
  87. var body []byte
  88. userAsJSON, err := json.Marshal(user)
  89. if err != nil {
  90. return newUser, body, err
  91. }
  92. resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(userPath), bytes.NewBuffer(userAsJSON),
  93. "application/json")
  94. if err != nil {
  95. return newUser, body, err
  96. }
  97. defer resp.Body.Close()
  98. err = checkResponse(resp.StatusCode, expectedStatusCode)
  99. if expectedStatusCode != http.StatusOK {
  100. body, _ = getResponseBody(resp)
  101. return newUser, body, err
  102. }
  103. if err == nil {
  104. err = render.DecodeJSON(resp.Body, &newUser)
  105. } else {
  106. body, _ = getResponseBody(resp)
  107. }
  108. if err == nil {
  109. err = checkUser(&user, &newUser)
  110. }
  111. return newUser, body, err
  112. }
  113. // UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode.
  114. func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) {
  115. var newUser dataprovider.User
  116. var body []byte
  117. userAsJSON, err := json.Marshal(user)
  118. if err != nil {
  119. return user, body, err
  120. }
  121. resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)),
  122. bytes.NewBuffer(userAsJSON), "application/json")
  123. if err != nil {
  124. return user, body, err
  125. }
  126. defer resp.Body.Close()
  127. body, _ = getResponseBody(resp)
  128. err = checkResponse(resp.StatusCode, expectedStatusCode)
  129. if expectedStatusCode != http.StatusOK {
  130. return newUser, body, err
  131. }
  132. if err == nil {
  133. newUser, body, err = GetUserByID(user.ID, expectedStatusCode)
  134. }
  135. if err == nil {
  136. err = checkUser(&user, &newUser)
  137. }
  138. return newUser, body, err
  139. }
  140. // RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode.
  141. func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) {
  142. var body []byte
  143. resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)), nil, "")
  144. if err != nil {
  145. return body, err
  146. }
  147. defer resp.Body.Close()
  148. body, _ = getResponseBody(resp)
  149. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  150. }
  151. // GetUserByID gets an user by database id and checks the received HTTP Status code against expectedStatusCode.
  152. func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, []byte, error) {
  153. var user dataprovider.User
  154. var body []byte
  155. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(userPath, strconv.FormatInt(userID, 10)), nil, "")
  156. if err != nil {
  157. return user, body, err
  158. }
  159. defer resp.Body.Close()
  160. err = checkResponse(resp.StatusCode, expectedStatusCode)
  161. if err == nil && expectedStatusCode == http.StatusOK {
  162. err = render.DecodeJSON(resp.Body, &user)
  163. } else {
  164. body, _ = getResponseBody(resp)
  165. }
  166. return user, body, err
  167. }
  168. // GetUsers allows to get a list of users and checks the received HTTP Status code against expectedStatusCode.
  169. // The number of results can be limited specifying a limit.
  170. // Some results can be skipped specifying an offset.
  171. // The results can be filtered specifying an username, the username filter is an exact match
  172. func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, []byte, error) {
  173. var users []dataprovider.User
  174. var body []byte
  175. url, err := url.Parse(buildURLRelativeToBase(userPath))
  176. if err != nil {
  177. return users, body, err
  178. }
  179. q := url.Query()
  180. if limit > 0 {
  181. q.Add("limit", strconv.FormatInt(limit, 10))
  182. }
  183. if offset > 0 {
  184. q.Add("offset", strconv.FormatInt(offset, 10))
  185. }
  186. if len(username) > 0 {
  187. q.Add("username", username)
  188. }
  189. url.RawQuery = q.Encode()
  190. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  191. if err != nil {
  192. return users, body, err
  193. }
  194. defer resp.Body.Close()
  195. err = checkResponse(resp.StatusCode, expectedStatusCode)
  196. if err == nil && expectedStatusCode == http.StatusOK {
  197. err = render.DecodeJSON(resp.Body, &users)
  198. } else {
  199. body, _ = getResponseBody(resp)
  200. }
  201. return users, body, err
  202. }
  203. // GetQuotaScans gets active quota scans and checks the received HTTP Status code against expectedStatusCode.
  204. func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, []byte, error) {
  205. var quotaScans []sftpd.ActiveQuotaScan
  206. var body []byte
  207. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "")
  208. if err != nil {
  209. return quotaScans, body, err
  210. }
  211. defer resp.Body.Close()
  212. err = checkResponse(resp.StatusCode, expectedStatusCode)
  213. if err == nil && expectedStatusCode == http.StatusOK {
  214. err = render.DecodeJSON(resp.Body, &quotaScans)
  215. } else {
  216. body, _ = getResponseBody(resp)
  217. }
  218. return quotaScans, body, err
  219. }
  220. // StartQuotaScan start a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode.
  221. func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) {
  222. var body []byte
  223. userAsJSON, err := json.Marshal(user)
  224. if err != nil {
  225. return body, err
  226. }
  227. resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotaScanPath), bytes.NewBuffer(userAsJSON), "")
  228. if err != nil {
  229. return body, err
  230. }
  231. defer resp.Body.Close()
  232. body, _ = getResponseBody(resp)
  233. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  234. }
  235. // GetConnections returns status and stats for active SFTP/SCP connections
  236. func GetConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) {
  237. var connections []sftpd.ConnectionStatus
  238. var body []byte
  239. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(activeConnectionsPath), nil, "")
  240. if err != nil {
  241. return connections, body, err
  242. }
  243. defer resp.Body.Close()
  244. err = checkResponse(resp.StatusCode, expectedStatusCode)
  245. if err == nil && expectedStatusCode == http.StatusOK {
  246. err = render.DecodeJSON(resp.Body, &connections)
  247. } else {
  248. body, _ = getResponseBody(resp)
  249. }
  250. return connections, body, err
  251. }
  252. // CloseConnection closes an active connection identified by connectionID
  253. func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
  254. var body []byte
  255. resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil, "")
  256. if err != nil {
  257. return body, err
  258. }
  259. defer resp.Body.Close()
  260. err = checkResponse(resp.StatusCode, expectedStatusCode)
  261. body, _ = getResponseBody(resp)
  262. return body, err
  263. }
  264. // GetVersion returns version details
  265. func GetVersion(expectedStatusCode int) (utils.VersionInfo, []byte, error) {
  266. var version utils.VersionInfo
  267. var body []byte
  268. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(versionPath), nil, "")
  269. if err != nil {
  270. return version, body, err
  271. }
  272. defer resp.Body.Close()
  273. err = checkResponse(resp.StatusCode, expectedStatusCode)
  274. if err == nil && expectedStatusCode == http.StatusOK {
  275. err = render.DecodeJSON(resp.Body, &version)
  276. } else {
  277. body, _ = getResponseBody(resp)
  278. }
  279. return version, body, err
  280. }
  281. // GetProviderStatus returns provider status
  282. func GetProviderStatus(expectedStatusCode int) (map[string]interface{}, []byte, error) {
  283. var response map[string]interface{}
  284. var body []byte
  285. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(providerStatusPath), nil, "")
  286. if err != nil {
  287. return response, body, err
  288. }
  289. defer resp.Body.Close()
  290. err = checkResponse(resp.StatusCode, expectedStatusCode)
  291. if err == nil && (expectedStatusCode == http.StatusOK || expectedStatusCode == http.StatusInternalServerError) {
  292. err = render.DecodeJSON(resp.Body, &response)
  293. } else {
  294. body, _ = getResponseBody(resp)
  295. }
  296. return response, body, err
  297. }
  298. // Dumpdata requests a backup to outputFile.
  299. // outputFile is relative to the configured backups_path
  300. func Dumpdata(outputFile, indent string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
  301. var response map[string]interface{}
  302. var body []byte
  303. url, err := url.Parse(buildURLRelativeToBase(dumpDataPath))
  304. if err != nil {
  305. return response, body, err
  306. }
  307. q := url.Query()
  308. q.Add("output_file", outputFile)
  309. if len(indent) > 0 {
  310. q.Add("indent", indent)
  311. }
  312. url.RawQuery = q.Encode()
  313. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  314. if err != nil {
  315. return response, body, err
  316. }
  317. defer resp.Body.Close()
  318. err = checkResponse(resp.StatusCode, expectedStatusCode)
  319. if err == nil && expectedStatusCode == http.StatusOK {
  320. err = render.DecodeJSON(resp.Body, &response)
  321. } else {
  322. body, _ = getResponseBody(resp)
  323. }
  324. return response, body, err
  325. }
  326. // Loaddata restores a backup.
  327. // New users are added, existing users are updated. Users will be restored one by one and the restore is stopped if a
  328. // user cannot be added/updated, so it could happen a partial restore
  329. func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
  330. var response map[string]interface{}
  331. var body []byte
  332. url, err := url.Parse(buildURLRelativeToBase(loadDataPath))
  333. if err != nil {
  334. return response, body, err
  335. }
  336. q := url.Query()
  337. q.Add("input_file", inputFile)
  338. if len(scanQuota) > 0 {
  339. q.Add("scan_quota", scanQuota)
  340. }
  341. if len(mode) > 0 {
  342. q.Add("mode", mode)
  343. }
  344. url.RawQuery = q.Encode()
  345. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  346. if err != nil {
  347. return response, body, err
  348. }
  349. defer resp.Body.Close()
  350. err = checkResponse(resp.StatusCode, expectedStatusCode)
  351. if err == nil && expectedStatusCode == http.StatusOK {
  352. err = render.DecodeJSON(resp.Body, &response)
  353. } else {
  354. body, _ = getResponseBody(resp)
  355. }
  356. return response, body, err
  357. }
  358. func checkResponse(actual int, expected int) error {
  359. if expected != actual {
  360. return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
  361. }
  362. return nil
  363. }
  364. func getResponseBody(resp *http.Response) ([]byte, error) {
  365. return ioutil.ReadAll(resp.Body)
  366. }
  367. func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
  368. if len(actual.Password) > 0 {
  369. return errors.New("User password must not be visible")
  370. }
  371. if expected.ID <= 0 {
  372. if actual.ID <= 0 {
  373. return errors.New("actual user ID must be > 0")
  374. }
  375. } else {
  376. if actual.ID != expected.ID {
  377. return errors.New("user ID mismatch")
  378. }
  379. }
  380. if len(expected.Permissions) != len(actual.Permissions) {
  381. return errors.New("Permissions mismatch")
  382. }
  383. for dir, perms := range expected.Permissions {
  384. if actualPerms, ok := actual.Permissions[dir]; ok {
  385. for _, v := range actualPerms {
  386. if !utils.IsStringInSlice(v, perms) {
  387. return errors.New("Permissions contents mismatch")
  388. }
  389. }
  390. } else {
  391. return errors.New("Permissions directories mismatch")
  392. }
  393. }
  394. if err := compareUserFilters(expected, actual); err != nil {
  395. return err
  396. }
  397. if err := compareUserFsConfig(expected, actual); err != nil {
  398. return err
  399. }
  400. return compareEqualsUserFields(expected, actual)
  401. }
  402. func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) error {
  403. if expected.FsConfig.Provider != actual.FsConfig.Provider {
  404. return errors.New("Fs provider mismatch")
  405. }
  406. if expected.FsConfig.S3Config.Bucket != actual.FsConfig.S3Config.Bucket {
  407. return errors.New("S3 bucket mismatch")
  408. }
  409. if expected.FsConfig.S3Config.Region != actual.FsConfig.S3Config.Region {
  410. return errors.New("S3 region mismatch")
  411. }
  412. if expected.FsConfig.S3Config.AccessKey != actual.FsConfig.S3Config.AccessKey {
  413. return errors.New("S3 access key mismatch")
  414. }
  415. if err := checkS3AccessSecret(expected.FsConfig.S3Config.AccessSecret, actual.FsConfig.S3Config.AccessSecret); err != nil {
  416. return err
  417. }
  418. if expected.FsConfig.S3Config.Endpoint != actual.FsConfig.S3Config.Endpoint {
  419. return errors.New("S3 endpoint mismatch")
  420. }
  421. if expected.FsConfig.S3Config.StorageClass != actual.FsConfig.S3Config.StorageClass {
  422. return errors.New("S3 storage class mismatch")
  423. }
  424. if expected.FsConfig.S3Config.KeyPrefix != actual.FsConfig.S3Config.KeyPrefix &&
  425. expected.FsConfig.S3Config.KeyPrefix+"/" != actual.FsConfig.S3Config.KeyPrefix {
  426. return errors.New("S3 key prefix mismatch")
  427. }
  428. if expected.FsConfig.GCSConfig.Bucket != actual.FsConfig.GCSConfig.Bucket {
  429. return errors.New("GCS bucket mismatch")
  430. }
  431. if expected.FsConfig.GCSConfig.StorageClass != actual.FsConfig.GCSConfig.StorageClass {
  432. return errors.New("GCS storage class mismatch")
  433. }
  434. if expected.FsConfig.GCSConfig.KeyPrefix != actual.FsConfig.GCSConfig.KeyPrefix &&
  435. expected.FsConfig.GCSConfig.KeyPrefix+"/" != actual.FsConfig.GCSConfig.KeyPrefix {
  436. return errors.New("GCS key prefix mismatch")
  437. }
  438. if expected.FsConfig.GCSConfig.AutomaticCredentials != actual.FsConfig.GCSConfig.AutomaticCredentials {
  439. return errors.New("GCS automatic credentials mismatch")
  440. }
  441. return nil
  442. }
  443. func checkS3AccessSecret(expectedAccessSecret, actualAccessSecret string) error {
  444. if len(expectedAccessSecret) > 0 {
  445. vals := strings.Split(expectedAccessSecret, "$")
  446. if strings.HasPrefix(expectedAccessSecret, "$aes$") && len(vals) == 4 {
  447. expectedAccessSecret = utils.RemoveDecryptionKey(expectedAccessSecret)
  448. if expectedAccessSecret != actualAccessSecret {
  449. return fmt.Errorf("S3 access secret mismatch, expected: %v", expectedAccessSecret)
  450. }
  451. } else {
  452. // here we check that actualAccessSecret is aes encrypted without the nonce
  453. parts := strings.Split(actualAccessSecret, "$")
  454. if !strings.HasPrefix(actualAccessSecret, "$aes$") || len(parts) != 3 {
  455. return errors.New("Invalid S3 access secret")
  456. }
  457. if len(parts) == len(vals) {
  458. if expectedAccessSecret != actualAccessSecret {
  459. return errors.New("S3 encrypted access secret mismatch")
  460. }
  461. }
  462. }
  463. } else {
  464. if expectedAccessSecret != actualAccessSecret {
  465. return errors.New("S3 access secret mismatch")
  466. }
  467. }
  468. return nil
  469. }
  470. func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) error {
  471. if len(expected.Filters.AllowedIP) != len(actual.Filters.AllowedIP) {
  472. return errors.New("AllowedIP mismatch")
  473. }
  474. if len(expected.Filters.DeniedIP) != len(actual.Filters.DeniedIP) {
  475. return errors.New("DeniedIP mismatch")
  476. }
  477. for _, IPMask := range expected.Filters.AllowedIP {
  478. if !utils.IsStringInSlice(IPMask, actual.Filters.AllowedIP) {
  479. return errors.New("AllowedIP contents mismatch")
  480. }
  481. }
  482. for _, IPMask := range expected.Filters.DeniedIP {
  483. if !utils.IsStringInSlice(IPMask, actual.Filters.DeniedIP) {
  484. return errors.New("DeniedIP contents mismatch")
  485. }
  486. }
  487. return nil
  488. }
  489. func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error {
  490. if expected.Username != actual.Username {
  491. return errors.New("Username mismatch")
  492. }
  493. if expected.HomeDir != actual.HomeDir {
  494. return errors.New("HomeDir mismatch")
  495. }
  496. if expected.UID != actual.UID {
  497. return errors.New("UID mismatch")
  498. }
  499. if expected.GID != actual.GID {
  500. return errors.New("GID mismatch")
  501. }
  502. if expected.MaxSessions != actual.MaxSessions {
  503. return errors.New("MaxSessions mismatch")
  504. }
  505. if expected.QuotaSize != actual.QuotaSize {
  506. return errors.New("QuotaSize mismatch")
  507. }
  508. if expected.QuotaFiles != actual.QuotaFiles {
  509. return errors.New("QuotaFiles mismatch")
  510. }
  511. if len(expected.Permissions) != len(actual.Permissions) {
  512. return errors.New("Permissions mismatch")
  513. }
  514. if expected.UploadBandwidth != actual.UploadBandwidth {
  515. return errors.New("UploadBandwidth mismatch")
  516. }
  517. if expected.DownloadBandwidth != actual.DownloadBandwidth {
  518. return errors.New("DownloadBandwidth mismatch")
  519. }
  520. if expected.Status != actual.Status {
  521. return errors.New("Status mismatch")
  522. }
  523. if expected.ExpirationDate != actual.ExpirationDate {
  524. return errors.New("ExpirationDate mismatch")
  525. }
  526. return nil
  527. }