api_utils.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. package httpd
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "net/url"
  10. "os"
  11. "path"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/drakkan/sftpgo/dataprovider"
  16. "github.com/drakkan/sftpgo/sftpd"
  17. "github.com/drakkan/sftpgo/utils"
  18. "github.com/go-chi/render"
  19. )
  20. var (
  21. httpBaseURL = "http://127.0.0.1:8080"
  22. )
  23. // SetBaseURL sets the base url to use for HTTP requests, default is "http://127.0.0.1:8080"
  24. func SetBaseURL(url string) {
  25. httpBaseURL = url
  26. }
  27. // gets an HTTP Client with a timeout
  28. func getHTTPClient() *http.Client {
  29. return &http.Client{
  30. Timeout: 15 * time.Second,
  31. }
  32. }
  33. func buildURLRelativeToBase(paths ...string) string {
  34. // we need to use path.Join and not filepath.Join
  35. // since filepath.Join will use backslash separator on Windows
  36. p := path.Join(paths...)
  37. return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/"))
  38. }
  39. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  40. var errorString string
  41. if err != nil {
  42. errorString = err.Error()
  43. }
  44. resp := apiResponse{
  45. Error: errorString,
  46. Message: message,
  47. HTTPStatus: code,
  48. }
  49. if code != http.StatusOK {
  50. w.Header().Set("Content-Type", "application/json; charset=utf-8")
  51. w.WriteHeader(code)
  52. }
  53. render.JSON(w, r, resp)
  54. }
  55. func getRespStatus(err error) int {
  56. if _, ok := err.(*dataprovider.ValidationError); ok {
  57. return http.StatusBadRequest
  58. }
  59. if _, ok := err.(*dataprovider.MethodDisabledError); ok {
  60. return http.StatusForbidden
  61. }
  62. if os.IsNotExist(err) {
  63. return http.StatusBadRequest
  64. }
  65. return http.StatusInternalServerError
  66. }
  67. // AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode.
  68. func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) {
  69. var newUser dataprovider.User
  70. var body []byte
  71. userAsJSON, err := json.Marshal(user)
  72. if err != nil {
  73. return newUser, body, err
  74. }
  75. resp, err := getHTTPClient().Post(buildURLRelativeToBase(userPath), "application/json", bytes.NewBuffer(userAsJSON))
  76. if err != nil {
  77. return newUser, body, err
  78. }
  79. defer resp.Body.Close()
  80. err = checkResponse(resp.StatusCode, expectedStatusCode)
  81. if expectedStatusCode != http.StatusOK {
  82. body, _ = getResponseBody(resp)
  83. return newUser, body, err
  84. }
  85. if err == nil {
  86. err = render.DecodeJSON(resp.Body, &newUser)
  87. } else {
  88. body, _ = getResponseBody(resp)
  89. }
  90. if err == nil {
  91. err = checkUser(&user, &newUser)
  92. }
  93. return newUser, body, err
  94. }
  95. // UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode.
  96. func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) {
  97. var newUser dataprovider.User
  98. var body []byte
  99. userAsJSON, err := json.Marshal(user)
  100. if err != nil {
  101. return user, body, err
  102. }
  103. req, err := http.NewRequest(http.MethodPut, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)),
  104. bytes.NewBuffer(userAsJSON))
  105. if err != nil {
  106. return user, body, err
  107. }
  108. resp, err := getHTTPClient().Do(req)
  109. if err != nil {
  110. return user, body, err
  111. }
  112. defer resp.Body.Close()
  113. body, _ = getResponseBody(resp)
  114. err = checkResponse(resp.StatusCode, expectedStatusCode)
  115. if expectedStatusCode != http.StatusOK {
  116. return newUser, body, err
  117. }
  118. if err == nil {
  119. newUser, body, err = GetUserByID(user.ID, expectedStatusCode)
  120. }
  121. if err == nil {
  122. err = checkUser(&user, &newUser)
  123. }
  124. return newUser, body, err
  125. }
  126. // RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode.
  127. func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) {
  128. var body []byte
  129. req, err := http.NewRequest(http.MethodDelete, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)), nil)
  130. if err != nil {
  131. return body, err
  132. }
  133. resp, err := getHTTPClient().Do(req)
  134. if err != nil {
  135. return body, err
  136. }
  137. defer resp.Body.Close()
  138. body, _ = getResponseBody(resp)
  139. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  140. }
  141. // GetUserByID gets an user by database id and checks the received HTTP Status code against expectedStatusCode.
  142. func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, []byte, error) {
  143. var user dataprovider.User
  144. var body []byte
  145. resp, err := getHTTPClient().Get(buildURLRelativeToBase(userPath, strconv.FormatInt(userID, 10)))
  146. if err != nil {
  147. return user, body, err
  148. }
  149. defer resp.Body.Close()
  150. err = checkResponse(resp.StatusCode, expectedStatusCode)
  151. if err == nil && expectedStatusCode == http.StatusOK {
  152. err = render.DecodeJSON(resp.Body, &user)
  153. } else {
  154. body, _ = getResponseBody(resp)
  155. }
  156. return user, body, err
  157. }
  158. // GetUsers allows to get a list of users and checks the received HTTP Status code against expectedStatusCode.
  159. // The number of results can be limited specifying a limit.
  160. // Some results can be skipped specifying an offset.
  161. // The results can be filtered specifying an username, the username filter is an exact match
  162. func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, []byte, error) {
  163. var users []dataprovider.User
  164. var body []byte
  165. url, err := url.Parse(buildURLRelativeToBase(userPath))
  166. if err != nil {
  167. return users, body, err
  168. }
  169. q := url.Query()
  170. if limit > 0 {
  171. q.Add("limit", strconv.FormatInt(limit, 10))
  172. }
  173. if offset > 0 {
  174. q.Add("offset", strconv.FormatInt(offset, 10))
  175. }
  176. if len(username) > 0 {
  177. q.Add("username", username)
  178. }
  179. url.RawQuery = q.Encode()
  180. resp, err := getHTTPClient().Get(url.String())
  181. if err != nil {
  182. return users, body, err
  183. }
  184. defer resp.Body.Close()
  185. err = checkResponse(resp.StatusCode, expectedStatusCode)
  186. if err == nil && expectedStatusCode == http.StatusOK {
  187. err = render.DecodeJSON(resp.Body, &users)
  188. } else {
  189. body, _ = getResponseBody(resp)
  190. }
  191. return users, body, err
  192. }
  193. // GetQuotaScans gets active quota scans and checks the received HTTP Status code against expectedStatusCode.
  194. func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, []byte, error) {
  195. var quotaScans []sftpd.ActiveQuotaScan
  196. var body []byte
  197. resp, err := getHTTPClient().Get(buildURLRelativeToBase(quotaScanPath))
  198. if err != nil {
  199. return quotaScans, body, err
  200. }
  201. defer resp.Body.Close()
  202. err = checkResponse(resp.StatusCode, expectedStatusCode)
  203. if err == nil && expectedStatusCode == http.StatusOK {
  204. err = render.DecodeJSON(resp.Body, &quotaScans)
  205. } else {
  206. body, _ = getResponseBody(resp)
  207. }
  208. return quotaScans, body, err
  209. }
  210. // StartQuotaScan start a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode.
  211. func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) {
  212. var body []byte
  213. userAsJSON, err := json.Marshal(user)
  214. if err != nil {
  215. return body, err
  216. }
  217. resp, err := getHTTPClient().Post(buildURLRelativeToBase(quotaScanPath), "application/json", bytes.NewBuffer(userAsJSON))
  218. if err != nil {
  219. return body, err
  220. }
  221. defer resp.Body.Close()
  222. body, _ = getResponseBody(resp)
  223. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  224. }
  225. // GetConnections returns status and stats for active SFTP/SCP connections
  226. func GetConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) {
  227. var connections []sftpd.ConnectionStatus
  228. var body []byte
  229. resp, err := getHTTPClient().Get(buildURLRelativeToBase(activeConnectionsPath))
  230. if err != nil {
  231. return connections, body, err
  232. }
  233. defer resp.Body.Close()
  234. err = checkResponse(resp.StatusCode, expectedStatusCode)
  235. if err == nil && expectedStatusCode == http.StatusOK {
  236. err = render.DecodeJSON(resp.Body, &connections)
  237. } else {
  238. body, _ = getResponseBody(resp)
  239. }
  240. return connections, body, err
  241. }
  242. // CloseConnection closes an active connection identified by connectionID
  243. func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
  244. var body []byte
  245. req, err := http.NewRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil)
  246. if err != nil {
  247. return body, err
  248. }
  249. resp, err := getHTTPClient().Do(req)
  250. if err != nil {
  251. return body, err
  252. }
  253. defer resp.Body.Close()
  254. err = checkResponse(resp.StatusCode, expectedStatusCode)
  255. body, _ = getResponseBody(resp)
  256. return body, err
  257. }
  258. // GetVersion returns version details
  259. func GetVersion(expectedStatusCode int) (utils.VersionInfo, []byte, error) {
  260. var version utils.VersionInfo
  261. var body []byte
  262. resp, err := getHTTPClient().Get(buildURLRelativeToBase(versionPath))
  263. if err != nil {
  264. return version, body, err
  265. }
  266. defer resp.Body.Close()
  267. err = checkResponse(resp.StatusCode, expectedStatusCode)
  268. if err == nil && expectedStatusCode == http.StatusOK {
  269. err = render.DecodeJSON(resp.Body, &version)
  270. } else {
  271. body, _ = getResponseBody(resp)
  272. }
  273. return version, body, err
  274. }
  275. // GetProviderStatus returns provider status
  276. func GetProviderStatus(expectedStatusCode int) (map[string]interface{}, []byte, error) {
  277. var response map[string]interface{}
  278. var body []byte
  279. resp, err := getHTTPClient().Get(buildURLRelativeToBase(providerStatusPath))
  280. if err != nil {
  281. return response, body, err
  282. }
  283. defer resp.Body.Close()
  284. err = checkResponse(resp.StatusCode, expectedStatusCode)
  285. if err == nil && (expectedStatusCode == http.StatusOK || expectedStatusCode == http.StatusInternalServerError) {
  286. err = render.DecodeJSON(resp.Body, &response)
  287. } else {
  288. body, _ = getResponseBody(resp)
  289. }
  290. return response, body, err
  291. }
  292. // Dumpdata requests a backup to outputFile.
  293. // outputFile is relative to the configured backups_path
  294. func Dumpdata(outputFile string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
  295. var response map[string]interface{}
  296. var body []byte
  297. url, err := url.Parse(buildURLRelativeToBase(dumpDataPath))
  298. if err != nil {
  299. return response, body, err
  300. }
  301. q := url.Query()
  302. q.Add("output_file", outputFile)
  303. url.RawQuery = q.Encode()
  304. resp, err := getHTTPClient().Get(url.String())
  305. if err != nil {
  306. return response, body, err
  307. }
  308. defer resp.Body.Close()
  309. err = checkResponse(resp.StatusCode, expectedStatusCode)
  310. if err == nil && expectedStatusCode == http.StatusOK {
  311. err = render.DecodeJSON(resp.Body, &response)
  312. } else {
  313. body, _ = getResponseBody(resp)
  314. }
  315. return response, body, err
  316. }
  317. // Loaddata restores a backup.
  318. // New users are added, existing users are updated. Users will be restored one by one and the restore is stopped if a
  319. // user cannot be added/updated, so it could happen a partial restore
  320. func Loaddata(inputFile, scanQuota string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
  321. var response map[string]interface{}
  322. var body []byte
  323. url, err := url.Parse(buildURLRelativeToBase(loadDataPath))
  324. if err != nil {
  325. return response, body, err
  326. }
  327. q := url.Query()
  328. q.Add("input_file", inputFile)
  329. if len(scanQuota) > 0 {
  330. q.Add("scan_quota", scanQuota)
  331. }
  332. url.RawQuery = q.Encode()
  333. resp, err := getHTTPClient().Get(url.String())
  334. if err != nil {
  335. return response, body, err
  336. }
  337. defer resp.Body.Close()
  338. err = checkResponse(resp.StatusCode, expectedStatusCode)
  339. if err == nil && expectedStatusCode == http.StatusOK {
  340. err = render.DecodeJSON(resp.Body, &response)
  341. } else {
  342. body, _ = getResponseBody(resp)
  343. }
  344. return response, body, err
  345. }
  346. func checkResponse(actual int, expected int) error {
  347. if expected != actual {
  348. return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
  349. }
  350. return nil
  351. }
  352. func getResponseBody(resp *http.Response) ([]byte, error) {
  353. return ioutil.ReadAll(resp.Body)
  354. }
  355. func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
  356. if len(actual.Password) > 0 {
  357. return errors.New("User password must not be visible")
  358. }
  359. if expected.ID <= 0 {
  360. if actual.ID <= 0 {
  361. return errors.New("actual user ID must be > 0")
  362. }
  363. } else {
  364. if actual.ID != expected.ID {
  365. return errors.New("user ID mismatch")
  366. }
  367. }
  368. if len(expected.Permissions) != len(actual.Permissions) {
  369. return errors.New("Permissions mismatch")
  370. }
  371. for dir, perms := range expected.Permissions {
  372. if actualPerms, ok := actual.Permissions[dir]; ok {
  373. for _, v := range actualPerms {
  374. if !utils.IsStringInSlice(v, perms) {
  375. return errors.New("Permissions contents mismatch")
  376. }
  377. }
  378. } else {
  379. return errors.New("Permissions directories mismatch")
  380. }
  381. }
  382. if err := compareUserFilters(expected, actual); err != nil {
  383. return err
  384. }
  385. if err := compareUserFsConfig(expected, actual); err != nil {
  386. return err
  387. }
  388. return compareEqualsUserFields(expected, actual)
  389. }
  390. func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) error {
  391. if expected.FsConfig.Provider != actual.FsConfig.Provider {
  392. return errors.New("Fs provider mismatch")
  393. }
  394. if expected.FsConfig.S3Config.Bucket != actual.FsConfig.S3Config.Bucket {
  395. return errors.New("S3 bucket mismatch")
  396. }
  397. if expected.FsConfig.S3Config.Region != actual.FsConfig.S3Config.Region {
  398. return errors.New("S3 region mismatch")
  399. }
  400. if expected.FsConfig.S3Config.AccessKey != actual.FsConfig.S3Config.AccessKey {
  401. return errors.New("S3 access key mismatch")
  402. }
  403. if err := checkS3AccessSecret(expected.FsConfig.S3Config.AccessSecret, actual.FsConfig.S3Config.AccessSecret); err != nil {
  404. return err
  405. }
  406. if expected.FsConfig.S3Config.Endpoint != actual.FsConfig.S3Config.Endpoint {
  407. return errors.New("S3 endpoint mismatch")
  408. }
  409. if expected.FsConfig.S3Config.StorageClass != actual.FsConfig.S3Config.StorageClass {
  410. return errors.New("S3 storage class mismatch")
  411. }
  412. if expected.FsConfig.S3Config.KeyPrefix != actual.FsConfig.S3Config.KeyPrefix &&
  413. expected.FsConfig.S3Config.KeyPrefix+"/" != actual.FsConfig.S3Config.KeyPrefix {
  414. return errors.New("S3 key prefix mismatch")
  415. }
  416. return nil
  417. }
  418. func checkS3AccessSecret(expectedAccessSecret, actualAccessSecret string) error {
  419. if len(expectedAccessSecret) > 0 {
  420. vals := strings.Split(expectedAccessSecret, "$")
  421. if strings.HasPrefix(expectedAccessSecret, "$aes$") && len(vals) == 4 {
  422. expectedAccessSecret = utils.RemoveDecryptionKey(expectedAccessSecret)
  423. if expectedAccessSecret != actualAccessSecret {
  424. return fmt.Errorf("S3 access secret mismatch, expected: %v", expectedAccessSecret)
  425. }
  426. } else {
  427. // here we check that actualAccessSecret is aes encrypted without the nonce
  428. parts := strings.Split(actualAccessSecret, "$")
  429. if !strings.HasPrefix(actualAccessSecret, "$aes$") || len(parts) != 3 {
  430. return errors.New("Invalid S3 access secret")
  431. }
  432. if len(parts) == len(vals) {
  433. if expectedAccessSecret != actualAccessSecret {
  434. return errors.New("S3 encrypted access secret mismatch")
  435. }
  436. }
  437. }
  438. } else {
  439. if expectedAccessSecret != actualAccessSecret {
  440. return errors.New("S3 access secret mismatch")
  441. }
  442. }
  443. return nil
  444. }
  445. func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) error {
  446. if len(expected.Filters.AllowedIP) != len(actual.Filters.AllowedIP) {
  447. return errors.New("AllowedIP mismatch")
  448. }
  449. if len(expected.Filters.DeniedIP) != len(actual.Filters.DeniedIP) {
  450. return errors.New("DeniedIP mismatch")
  451. }
  452. for _, IPMask := range expected.Filters.AllowedIP {
  453. if !utils.IsStringInSlice(IPMask, actual.Filters.AllowedIP) {
  454. return errors.New("AllowedIP contents mismatch")
  455. }
  456. }
  457. for _, IPMask := range expected.Filters.DeniedIP {
  458. if !utils.IsStringInSlice(IPMask, actual.Filters.DeniedIP) {
  459. return errors.New("DeniedIP contents mismatch")
  460. }
  461. }
  462. return nil
  463. }
  464. func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error {
  465. if expected.Username != actual.Username {
  466. return errors.New("Username mismatch")
  467. }
  468. if expected.HomeDir != actual.HomeDir {
  469. return errors.New("HomeDir mismatch")
  470. }
  471. if expected.UID != actual.UID {
  472. return errors.New("UID mismatch")
  473. }
  474. if expected.GID != actual.GID {
  475. return errors.New("GID mismatch")
  476. }
  477. if expected.MaxSessions != actual.MaxSessions {
  478. return errors.New("MaxSessions mismatch")
  479. }
  480. if expected.QuotaSize != actual.QuotaSize {
  481. return errors.New("QuotaSize mismatch")
  482. }
  483. if expected.QuotaFiles != actual.QuotaFiles {
  484. return errors.New("QuotaFiles mismatch")
  485. }
  486. if len(expected.Permissions) != len(actual.Permissions) {
  487. return errors.New("Permissions mismatch")
  488. }
  489. if expected.UploadBandwidth != actual.UploadBandwidth {
  490. return errors.New("UploadBandwidth mismatch")
  491. }
  492. if expected.DownloadBandwidth != actual.DownloadBandwidth {
  493. return errors.New("DownloadBandwidth mismatch")
  494. }
  495. if expected.Status != actual.Status {
  496. return errors.New("Status mismatch")
  497. }
  498. if expected.ExpirationDate != actual.ExpirationDate {
  499. return errors.New("ExpirationDate mismatch")
  500. }
  501. return nil
  502. }