api_utils.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939
  1. package httpd
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net/http"
  11. "net/url"
  12. "os"
  13. "path"
  14. "path/filepath"
  15. "strconv"
  16. "strings"
  17. "github.com/go-chi/render"
  18. "github.com/drakkan/sftpgo/common"
  19. "github.com/drakkan/sftpgo/dataprovider"
  20. "github.com/drakkan/sftpgo/httpclient"
  21. "github.com/drakkan/sftpgo/kms"
  22. "github.com/drakkan/sftpgo/utils"
  23. "github.com/drakkan/sftpgo/version"
  24. "github.com/drakkan/sftpgo/vfs"
  25. )
  26. var (
  27. httpBaseURL = "http://127.0.0.1:8080"
  28. authUsername = ""
  29. authPassword = ""
  30. )
  31. // SetBaseURLAndCredentials sets the base url and the optional credentials to use for HTTP requests.
  32. // Default URL is "http://127.0.0.1:8080" with empty credentials
  33. func SetBaseURLAndCredentials(url, username, password string) {
  34. httpBaseURL = url
  35. authUsername = username
  36. authPassword = password
  37. }
  38. func sendHTTPRequest(method, url string, body io.Reader, contentType string) (*http.Response, error) {
  39. req, err := http.NewRequest(method, url, body)
  40. if err != nil {
  41. return nil, err
  42. }
  43. if len(contentType) > 0 {
  44. req.Header.Set("Content-Type", "application/json")
  45. }
  46. if len(authUsername) > 0 || len(authPassword) > 0 {
  47. req.SetBasicAuth(authUsername, authPassword)
  48. }
  49. return httpclient.GetHTTPClient().Do(req)
  50. }
  51. func buildURLRelativeToBase(paths ...string) string {
  52. // we need to use path.Join and not filepath.Join
  53. // since filepath.Join will use backslash separator on Windows
  54. p := path.Join(paths...)
  55. return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/"))
  56. }
  57. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  58. var errorString string
  59. if err != nil {
  60. errorString = err.Error()
  61. }
  62. resp := apiResponse{
  63. Error: errorString,
  64. Message: message,
  65. }
  66. ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
  67. render.JSON(w, r.WithContext(ctx), resp)
  68. }
  69. func getRespStatus(err error) int {
  70. if _, ok := err.(*dataprovider.ValidationError); ok {
  71. return http.StatusBadRequest
  72. }
  73. if _, ok := err.(*dataprovider.MethodDisabledError); ok {
  74. return http.StatusForbidden
  75. }
  76. if _, ok := err.(*dataprovider.RecordNotFoundError); ok {
  77. return http.StatusNotFound
  78. }
  79. if os.IsNotExist(err) {
  80. return http.StatusBadRequest
  81. }
  82. return http.StatusInternalServerError
  83. }
  84. // AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode.
  85. func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) {
  86. var newUser dataprovider.User
  87. var body []byte
  88. userAsJSON, _ := json.Marshal(user)
  89. resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(userPath), bytes.NewBuffer(userAsJSON),
  90. "application/json")
  91. if err != nil {
  92. return newUser, body, err
  93. }
  94. defer resp.Body.Close()
  95. err = checkResponse(resp.StatusCode, expectedStatusCode)
  96. if expectedStatusCode != http.StatusOK {
  97. body, _ = getResponseBody(resp)
  98. return newUser, body, err
  99. }
  100. if err == nil {
  101. err = render.DecodeJSON(resp.Body, &newUser)
  102. } else {
  103. body, _ = getResponseBody(resp)
  104. }
  105. if err == nil {
  106. err = checkUser(&user, &newUser)
  107. }
  108. return newUser, body, err
  109. }
  110. // UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode.
  111. func UpdateUser(user dataprovider.User, expectedStatusCode int, disconnect string) (dataprovider.User, []byte, error) {
  112. var newUser dataprovider.User
  113. var body []byte
  114. url, err := addDisconnectQueryParam(buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)), disconnect)
  115. if err != nil {
  116. return user, body, err
  117. }
  118. userAsJSON, _ := json.Marshal(user)
  119. resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json")
  120. if err != nil {
  121. return user, body, err
  122. }
  123. defer resp.Body.Close()
  124. body, _ = getResponseBody(resp)
  125. err = checkResponse(resp.StatusCode, expectedStatusCode)
  126. if expectedStatusCode != http.StatusOK {
  127. return newUser, body, err
  128. }
  129. if err == nil {
  130. newUser, body, err = GetUserByID(user.ID, expectedStatusCode)
  131. }
  132. if err == nil {
  133. err = checkUser(&user, &newUser)
  134. }
  135. return newUser, body, err
  136. }
  137. // RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode.
  138. func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) {
  139. var body []byte
  140. resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)), nil, "")
  141. if err != nil {
  142. return body, err
  143. }
  144. defer resp.Body.Close()
  145. body, _ = getResponseBody(resp)
  146. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  147. }
  148. // GetUserByID gets a user by database id and checks the received HTTP Status code against expectedStatusCode.
  149. func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, []byte, error) {
  150. var user dataprovider.User
  151. var body []byte
  152. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(userPath, strconv.FormatInt(userID, 10)), nil, "")
  153. if err != nil {
  154. return user, body, err
  155. }
  156. defer resp.Body.Close()
  157. err = checkResponse(resp.StatusCode, expectedStatusCode)
  158. if err == nil && expectedStatusCode == http.StatusOK {
  159. err = render.DecodeJSON(resp.Body, &user)
  160. } else {
  161. body, _ = getResponseBody(resp)
  162. }
  163. return user, body, err
  164. }
  165. // GetUsers returns a list of users and checks the received HTTP Status code against expectedStatusCode.
  166. // The number of results can be limited specifying a limit.
  167. // Some results can be skipped specifying an offset.
  168. // The results can be filtered specifying a username, the username filter is an exact match
  169. func GetUsers(limit, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, []byte, error) {
  170. var users []dataprovider.User
  171. var body []byte
  172. url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(userPath), limit, offset)
  173. if err != nil {
  174. return users, body, err
  175. }
  176. if len(username) > 0 {
  177. q := url.Query()
  178. q.Add("username", username)
  179. url.RawQuery = q.Encode()
  180. }
  181. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  182. if err != nil {
  183. return users, body, err
  184. }
  185. defer resp.Body.Close()
  186. err = checkResponse(resp.StatusCode, expectedStatusCode)
  187. if err == nil && expectedStatusCode == http.StatusOK {
  188. err = render.DecodeJSON(resp.Body, &users)
  189. } else {
  190. body, _ = getResponseBody(resp)
  191. }
  192. return users, body, err
  193. }
  194. // GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode.
  195. func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) {
  196. var quotaScans []common.ActiveQuotaScan
  197. var body []byte
  198. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "")
  199. if err != nil {
  200. return quotaScans, body, err
  201. }
  202. defer resp.Body.Close()
  203. err = checkResponse(resp.StatusCode, expectedStatusCode)
  204. if err == nil && expectedStatusCode == http.StatusOK {
  205. err = render.DecodeJSON(resp.Body, &quotaScans)
  206. } else {
  207. body, _ = getResponseBody(resp)
  208. }
  209. return quotaScans, body, err
  210. }
  211. // StartQuotaScan starts a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode.
  212. func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) {
  213. var body []byte
  214. userAsJSON, _ := json.Marshal(user)
  215. resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotaScanPath), bytes.NewBuffer(userAsJSON), "")
  216. if err != nil {
  217. return body, err
  218. }
  219. defer resp.Body.Close()
  220. body, _ = getResponseBody(resp)
  221. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  222. }
  223. // UpdateQuotaUsage updates the user used quota limits and checks the received HTTP Status code against expectedStatusCode.
  224. func UpdateQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) {
  225. var body []byte
  226. userAsJSON, _ := json.Marshal(user)
  227. url, err := addModeQueryParam(buildURLRelativeToBase(updateUsedQuotaPath), mode)
  228. if err != nil {
  229. return body, err
  230. }
  231. resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "")
  232. if err != nil {
  233. return body, err
  234. }
  235. defer resp.Body.Close()
  236. body, _ = getResponseBody(resp)
  237. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  238. }
  239. // GetConnections returns status and stats for active SFTP/SCP connections
  240. func GetConnections(expectedStatusCode int) ([]common.ConnectionStatus, []byte, error) {
  241. var connections []common.ConnectionStatus
  242. var body []byte
  243. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(activeConnectionsPath), nil, "")
  244. if err != nil {
  245. return connections, body, err
  246. }
  247. defer resp.Body.Close()
  248. err = checkResponse(resp.StatusCode, expectedStatusCode)
  249. if err == nil && expectedStatusCode == http.StatusOK {
  250. err = render.DecodeJSON(resp.Body, &connections)
  251. } else {
  252. body, _ = getResponseBody(resp)
  253. }
  254. return connections, body, err
  255. }
  256. // CloseConnection closes an active connection identified by connectionID
  257. func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
  258. var body []byte
  259. resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil, "")
  260. if err != nil {
  261. return body, err
  262. }
  263. defer resp.Body.Close()
  264. err = checkResponse(resp.StatusCode, expectedStatusCode)
  265. body, _ = getResponseBody(resp)
  266. return body, err
  267. }
  268. // AddFolder adds a new folder and checks the received HTTP Status code against expectedStatusCode
  269. func AddFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) {
  270. var newFolder vfs.BaseVirtualFolder
  271. var body []byte
  272. folderAsJSON, _ := json.Marshal(folder)
  273. resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(folderPath), bytes.NewBuffer(folderAsJSON),
  274. "application/json")
  275. if err != nil {
  276. return newFolder, body, err
  277. }
  278. defer resp.Body.Close()
  279. err = checkResponse(resp.StatusCode, expectedStatusCode)
  280. if expectedStatusCode != http.StatusOK {
  281. body, _ = getResponseBody(resp)
  282. return newFolder, body, err
  283. }
  284. if err == nil {
  285. err = render.DecodeJSON(resp.Body, &newFolder)
  286. } else {
  287. body, _ = getResponseBody(resp)
  288. }
  289. if err == nil {
  290. err = checkFolder(&folder, &newFolder)
  291. }
  292. return newFolder, body, err
  293. }
  294. // RemoveFolder removes an existing user and checks the received HTTP Status code against expectedStatusCode.
  295. func RemoveFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) {
  296. var body []byte
  297. baseURL := buildURLRelativeToBase(folderPath)
  298. url, err := url.Parse(baseURL)
  299. if err != nil {
  300. return body, err
  301. }
  302. q := url.Query()
  303. q.Add("folder_path", folder.MappedPath)
  304. url.RawQuery = q.Encode()
  305. resp, err := sendHTTPRequest(http.MethodDelete, url.String(), nil, "")
  306. if err != nil {
  307. return body, err
  308. }
  309. defer resp.Body.Close()
  310. body, _ = getResponseBody(resp)
  311. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  312. }
  313. // GetFolders returns a list of folders and checks the received HTTP Status code against expectedStatusCode.
  314. // The number of results can be limited specifying a limit.
  315. // Some results can be skipped specifying an offset.
  316. // The results can be filtered specifying a folder path, the folder path filter is an exact match
  317. func GetFolders(limit int64, offset int64, mappedPath string, expectedStatusCode int) ([]vfs.BaseVirtualFolder, []byte, error) {
  318. var folders []vfs.BaseVirtualFolder
  319. var body []byte
  320. url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(folderPath), limit, offset)
  321. if err != nil {
  322. return folders, body, err
  323. }
  324. if len(mappedPath) > 0 {
  325. q := url.Query()
  326. q.Add("folder_path", mappedPath)
  327. url.RawQuery = q.Encode()
  328. }
  329. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  330. if err != nil {
  331. return folders, body, err
  332. }
  333. defer resp.Body.Close()
  334. err = checkResponse(resp.StatusCode, expectedStatusCode)
  335. if err == nil && expectedStatusCode == http.StatusOK {
  336. err = render.DecodeJSON(resp.Body, &folders)
  337. } else {
  338. body, _ = getResponseBody(resp)
  339. }
  340. return folders, body, err
  341. }
  342. // GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode.
  343. func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) {
  344. var quotaScans []common.ActiveVirtualFolderQuotaScan
  345. var body []byte
  346. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "")
  347. if err != nil {
  348. return quotaScans, body, err
  349. }
  350. defer resp.Body.Close()
  351. err = checkResponse(resp.StatusCode, expectedStatusCode)
  352. if err == nil && expectedStatusCode == http.StatusOK {
  353. err = render.DecodeJSON(resp.Body, &quotaScans)
  354. } else {
  355. body, _ = getResponseBody(resp)
  356. }
  357. return quotaScans, body, err
  358. }
  359. // StartFolderQuotaScan start a new quota scan for the given folder and checks the received HTTP Status code against expectedStatusCode.
  360. func StartFolderQuotaScan(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) {
  361. var body []byte
  362. folderAsJSON, _ := json.Marshal(folder)
  363. resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotaScanVFolderPath), bytes.NewBuffer(folderAsJSON), "")
  364. if err != nil {
  365. return body, err
  366. }
  367. defer resp.Body.Close()
  368. body, _ = getResponseBody(resp)
  369. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  370. }
  371. // UpdateFolderQuotaUsage updates the folder used quota limits and checks the received HTTP Status code against expectedStatusCode.
  372. func UpdateFolderQuotaUsage(folder vfs.BaseVirtualFolder, mode string, expectedStatusCode int) ([]byte, error) {
  373. var body []byte
  374. folderAsJSON, _ := json.Marshal(folder)
  375. url, err := addModeQueryParam(buildURLRelativeToBase(updateFolderUsedQuotaPath), mode)
  376. if err != nil {
  377. return body, err
  378. }
  379. resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(folderAsJSON), "")
  380. if err != nil {
  381. return body, err
  382. }
  383. defer resp.Body.Close()
  384. body, _ = getResponseBody(resp)
  385. return body, checkResponse(resp.StatusCode, expectedStatusCode)
  386. }
  387. // GetVersion returns version details
  388. func GetVersion(expectedStatusCode int) (version.Info, []byte, error) {
  389. var appVersion version.Info
  390. var body []byte
  391. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(versionPath), nil, "")
  392. if err != nil {
  393. return appVersion, body, err
  394. }
  395. defer resp.Body.Close()
  396. err = checkResponse(resp.StatusCode, expectedStatusCode)
  397. if err == nil && expectedStatusCode == http.StatusOK {
  398. err = render.DecodeJSON(resp.Body, &appVersion)
  399. } else {
  400. body, _ = getResponseBody(resp)
  401. }
  402. return appVersion, body, err
  403. }
  404. // GetProviderStatus returns provider status
  405. func GetProviderStatus(expectedStatusCode int) (map[string]interface{}, []byte, error) {
  406. var response map[string]interface{}
  407. var body []byte
  408. resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(providerStatusPath), nil, "")
  409. if err != nil {
  410. return response, body, err
  411. }
  412. defer resp.Body.Close()
  413. err = checkResponse(resp.StatusCode, expectedStatusCode)
  414. if err == nil && (expectedStatusCode == http.StatusOK || expectedStatusCode == http.StatusInternalServerError) {
  415. err = render.DecodeJSON(resp.Body, &response)
  416. } else {
  417. body, _ = getResponseBody(resp)
  418. }
  419. return response, body, err
  420. }
  421. // Dumpdata requests a backup to outputFile.
  422. // outputFile is relative to the configured backups_path
  423. func Dumpdata(outputFile, indent string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
  424. var response map[string]interface{}
  425. var body []byte
  426. url, err := url.Parse(buildURLRelativeToBase(dumpDataPath))
  427. if err != nil {
  428. return response, body, err
  429. }
  430. q := url.Query()
  431. q.Add("output_file", outputFile)
  432. if len(indent) > 0 {
  433. q.Add("indent", indent)
  434. }
  435. url.RawQuery = q.Encode()
  436. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  437. if err != nil {
  438. return response, body, err
  439. }
  440. defer resp.Body.Close()
  441. err = checkResponse(resp.StatusCode, expectedStatusCode)
  442. if err == nil && expectedStatusCode == http.StatusOK {
  443. err = render.DecodeJSON(resp.Body, &response)
  444. } else {
  445. body, _ = getResponseBody(resp)
  446. }
  447. return response, body, err
  448. }
  449. // Loaddata restores a backup.
  450. // New users are added, existing users are updated. Users will be restored one by one and the restore is stopped if a
  451. // user cannot be added/updated, so it could happen a partial restore
  452. func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
  453. var response map[string]interface{}
  454. var body []byte
  455. url, err := url.Parse(buildURLRelativeToBase(loadDataPath))
  456. if err != nil {
  457. return response, body, err
  458. }
  459. q := url.Query()
  460. q.Add("input_file", inputFile)
  461. if len(scanQuota) > 0 {
  462. q.Add("scan_quota", scanQuota)
  463. }
  464. if len(mode) > 0 {
  465. q.Add("mode", mode)
  466. }
  467. url.RawQuery = q.Encode()
  468. resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "")
  469. if err != nil {
  470. return response, body, err
  471. }
  472. defer resp.Body.Close()
  473. err = checkResponse(resp.StatusCode, expectedStatusCode)
  474. if err == nil && expectedStatusCode == http.StatusOK {
  475. err = render.DecodeJSON(resp.Body, &response)
  476. } else {
  477. body, _ = getResponseBody(resp)
  478. }
  479. return response, body, err
  480. }
  481. func checkResponse(actual int, expected int) error {
  482. if expected != actual {
  483. return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
  484. }
  485. return nil
  486. }
  487. func getResponseBody(resp *http.Response) ([]byte, error) {
  488. return ioutil.ReadAll(resp.Body)
  489. }
  490. func checkFolder(expected *vfs.BaseVirtualFolder, actual *vfs.BaseVirtualFolder) error {
  491. if expected.ID <= 0 {
  492. if actual.ID <= 0 {
  493. return errors.New("actual folder ID must be > 0")
  494. }
  495. } else {
  496. if actual.ID != expected.ID {
  497. return errors.New("folder ID mismatch")
  498. }
  499. }
  500. if expected.MappedPath != actual.MappedPath {
  501. return errors.New("mapped path mismatch")
  502. }
  503. if expected.LastQuotaUpdate != actual.LastQuotaUpdate {
  504. return errors.New("last quota update mismatch")
  505. }
  506. if expected.UsedQuotaSize != actual.UsedQuotaSize {
  507. return errors.New("used quota size mismatch")
  508. }
  509. if expected.UsedQuotaFiles != actual.UsedQuotaFiles {
  510. return errors.New("used quota files mismatch")
  511. }
  512. if len(expected.Users) != len(actual.Users) {
  513. return errors.New("folder users mismatch")
  514. }
  515. for _, u := range actual.Users {
  516. if !utils.IsStringInSlice(u, expected.Users) {
  517. return errors.New("folder users mismatch")
  518. }
  519. }
  520. return nil
  521. }
  522. func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
  523. if len(actual.Password) > 0 {
  524. return errors.New("User password must not be visible")
  525. }
  526. if expected.ID <= 0 {
  527. if actual.ID <= 0 {
  528. return errors.New("actual user ID must be > 0")
  529. }
  530. } else {
  531. if actual.ID != expected.ID {
  532. return errors.New("user ID mismatch")
  533. }
  534. }
  535. if len(expected.Permissions) != len(actual.Permissions) {
  536. return errors.New("Permissions mismatch")
  537. }
  538. for dir, perms := range expected.Permissions {
  539. if actualPerms, ok := actual.Permissions[dir]; ok {
  540. for _, v := range actualPerms {
  541. if !utils.IsStringInSlice(v, perms) {
  542. return errors.New("Permissions contents mismatch")
  543. }
  544. }
  545. } else {
  546. return errors.New("Permissions directories mismatch")
  547. }
  548. }
  549. if err := compareUserFilters(expected, actual); err != nil {
  550. return err
  551. }
  552. if err := compareUserFsConfig(expected, actual); err != nil {
  553. return err
  554. }
  555. if err := compareUserVirtualFolders(expected, actual); err != nil {
  556. return err
  557. }
  558. return compareEqualsUserFields(expected, actual)
  559. }
  560. func compareUserVirtualFolders(expected *dataprovider.User, actual *dataprovider.User) error {
  561. if len(actual.VirtualFolders) != len(expected.VirtualFolders) {
  562. return errors.New("Virtual folders mismatch")
  563. }
  564. for _, v := range actual.VirtualFolders {
  565. found := false
  566. for _, v1 := range expected.VirtualFolders {
  567. if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) &&
  568. filepath.Clean(v.MappedPath) == filepath.Clean(v1.MappedPath) {
  569. found = true
  570. break
  571. }
  572. }
  573. if !found {
  574. return errors.New("Virtual folders mismatch")
  575. }
  576. }
  577. return nil
  578. }
  579. func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) error {
  580. if expected.FsConfig.Provider != actual.FsConfig.Provider {
  581. return errors.New("Fs provider mismatch")
  582. }
  583. if err := compareS3Config(expected, actual); err != nil {
  584. return err
  585. }
  586. if err := compareGCSConfig(expected, actual); err != nil {
  587. return err
  588. }
  589. if err := compareAzBlobConfig(expected, actual); err != nil {
  590. return err
  591. }
  592. if err := checkEncryptedSecret(expected.FsConfig.CryptConfig.Passphrase, actual.FsConfig.CryptConfig.Passphrase); err != nil {
  593. return err
  594. }
  595. return nil
  596. }
  597. func compareS3Config(expected *dataprovider.User, actual *dataprovider.User) error {
  598. if expected.FsConfig.S3Config.Bucket != actual.FsConfig.S3Config.Bucket {
  599. return errors.New("S3 bucket mismatch")
  600. }
  601. if expected.FsConfig.S3Config.Region != actual.FsConfig.S3Config.Region {
  602. return errors.New("S3 region mismatch")
  603. }
  604. if expected.FsConfig.S3Config.AccessKey != actual.FsConfig.S3Config.AccessKey {
  605. return errors.New("S3 access key mismatch")
  606. }
  607. if err := checkEncryptedSecret(expected.FsConfig.S3Config.AccessSecret, actual.FsConfig.S3Config.AccessSecret); err != nil {
  608. return fmt.Errorf("S3 access secret mismatch: %v", err)
  609. }
  610. if expected.FsConfig.S3Config.Endpoint != actual.FsConfig.S3Config.Endpoint {
  611. return errors.New("S3 endpoint mismatch")
  612. }
  613. if expected.FsConfig.S3Config.StorageClass != actual.FsConfig.S3Config.StorageClass {
  614. return errors.New("S3 storage class mismatch")
  615. }
  616. if expected.FsConfig.S3Config.UploadPartSize != actual.FsConfig.S3Config.UploadPartSize {
  617. return errors.New("S3 upload part size mismatch")
  618. }
  619. if expected.FsConfig.S3Config.UploadConcurrency != actual.FsConfig.S3Config.UploadConcurrency {
  620. return errors.New("S3 upload concurrency mismatch")
  621. }
  622. if expected.FsConfig.S3Config.KeyPrefix != actual.FsConfig.S3Config.KeyPrefix &&
  623. expected.FsConfig.S3Config.KeyPrefix+"/" != actual.FsConfig.S3Config.KeyPrefix {
  624. return errors.New("S3 key prefix mismatch")
  625. }
  626. return nil
  627. }
  628. func compareGCSConfig(expected *dataprovider.User, actual *dataprovider.User) error {
  629. if expected.FsConfig.GCSConfig.Bucket != actual.FsConfig.GCSConfig.Bucket {
  630. return errors.New("GCS bucket mismatch")
  631. }
  632. if expected.FsConfig.GCSConfig.StorageClass != actual.FsConfig.GCSConfig.StorageClass {
  633. return errors.New("GCS storage class mismatch")
  634. }
  635. if expected.FsConfig.GCSConfig.KeyPrefix != actual.FsConfig.GCSConfig.KeyPrefix &&
  636. expected.FsConfig.GCSConfig.KeyPrefix+"/" != actual.FsConfig.GCSConfig.KeyPrefix {
  637. return errors.New("GCS key prefix mismatch")
  638. }
  639. if expected.FsConfig.GCSConfig.AutomaticCredentials != actual.FsConfig.GCSConfig.AutomaticCredentials {
  640. return errors.New("GCS automatic credentials mismatch")
  641. }
  642. return nil
  643. }
  644. func compareAzBlobConfig(expected *dataprovider.User, actual *dataprovider.User) error {
  645. if expected.FsConfig.AzBlobConfig.Container != actual.FsConfig.AzBlobConfig.Container {
  646. return errors.New("Azure Blob container mismatch")
  647. }
  648. if expected.FsConfig.AzBlobConfig.AccountName != actual.FsConfig.AzBlobConfig.AccountName {
  649. return errors.New("Azure Blob account name mismatch")
  650. }
  651. if err := checkEncryptedSecret(expected.FsConfig.AzBlobConfig.AccountKey, actual.FsConfig.AzBlobConfig.AccountKey); err != nil {
  652. return fmt.Errorf("Azure Blob account key mismatch: %v", err)
  653. }
  654. if expected.FsConfig.AzBlobConfig.Endpoint != actual.FsConfig.AzBlobConfig.Endpoint {
  655. return errors.New("Azure Blob endpoint mismatch")
  656. }
  657. if expected.FsConfig.AzBlobConfig.SASURL != actual.FsConfig.AzBlobConfig.SASURL {
  658. return errors.New("Azure Blob SASL URL mismatch")
  659. }
  660. if expected.FsConfig.AzBlobConfig.UploadPartSize != actual.FsConfig.AzBlobConfig.UploadPartSize {
  661. return errors.New("Azure Blob upload part size mismatch")
  662. }
  663. if expected.FsConfig.AzBlobConfig.UploadConcurrency != actual.FsConfig.AzBlobConfig.UploadConcurrency {
  664. return errors.New("Azure Blob upload concurrency mismatch")
  665. }
  666. if expected.FsConfig.AzBlobConfig.KeyPrefix != actual.FsConfig.AzBlobConfig.KeyPrefix &&
  667. expected.FsConfig.AzBlobConfig.KeyPrefix+"/" != actual.FsConfig.AzBlobConfig.KeyPrefix {
  668. return errors.New("Azure Blob key prefix mismatch")
  669. }
  670. if expected.FsConfig.AzBlobConfig.UseEmulator != actual.FsConfig.AzBlobConfig.UseEmulator {
  671. return errors.New("Azure Blob use emulator mismatch")
  672. }
  673. if expected.FsConfig.AzBlobConfig.AccessTier != actual.FsConfig.AzBlobConfig.AccessTier {
  674. return errors.New("Azure Blob access tier mismatch")
  675. }
  676. return nil
  677. }
  678. func areSecretEquals(expected, actual *kms.Secret) bool {
  679. if expected == nil && actual == nil {
  680. return true
  681. }
  682. if expected != nil && expected.IsEmpty() && actual == nil {
  683. return true
  684. }
  685. if actual != nil && actual.IsEmpty() && expected == nil {
  686. return true
  687. }
  688. return false
  689. }
  690. func checkEncryptedSecret(expected, actual *kms.Secret) error {
  691. if areSecretEquals(expected, actual) {
  692. return nil
  693. }
  694. if expected == nil && actual != nil && !actual.IsEmpty() {
  695. return errors.New("secret mismatch")
  696. }
  697. if actual == nil && expected != nil && !expected.IsEmpty() {
  698. return errors.New("secret mismatch")
  699. }
  700. if expected.IsPlain() && actual.IsEncrypted() {
  701. if actual.GetPayload() == "" {
  702. return errors.New("invalid secret payload")
  703. }
  704. if actual.GetAdditionalData() != "" {
  705. return errors.New("invalid secret additional data")
  706. }
  707. if actual.GetKey() != "" {
  708. return errors.New("invalid secret key")
  709. }
  710. } else {
  711. if expected.GetStatus() != actual.GetStatus() || expected.GetPayload() != actual.GetPayload() {
  712. return errors.New("secret mismatch")
  713. }
  714. }
  715. return nil
  716. }
  717. func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) error {
  718. if len(expected.Filters.AllowedIP) != len(actual.Filters.AllowedIP) {
  719. return errors.New("AllowedIP mismatch")
  720. }
  721. if len(expected.Filters.DeniedIP) != len(actual.Filters.DeniedIP) {
  722. return errors.New("DeniedIP mismatch")
  723. }
  724. if len(expected.Filters.DeniedLoginMethods) != len(actual.Filters.DeniedLoginMethods) {
  725. return errors.New("Denied login methods mismatch")
  726. }
  727. if len(expected.Filters.DeniedProtocols) != len(actual.Filters.DeniedProtocols) {
  728. return errors.New("Denied protocols mismatch")
  729. }
  730. if expected.Filters.MaxUploadFileSize != actual.Filters.MaxUploadFileSize {
  731. return errors.New("Max upload file size mismatch")
  732. }
  733. for _, IPMask := range expected.Filters.AllowedIP {
  734. if !utils.IsStringInSlice(IPMask, actual.Filters.AllowedIP) {
  735. return errors.New("AllowedIP contents mismatch")
  736. }
  737. }
  738. for _, IPMask := range expected.Filters.DeniedIP {
  739. if !utils.IsStringInSlice(IPMask, actual.Filters.DeniedIP) {
  740. return errors.New("DeniedIP contents mismatch")
  741. }
  742. }
  743. for _, method := range expected.Filters.DeniedLoginMethods {
  744. if !utils.IsStringInSlice(method, actual.Filters.DeniedLoginMethods) {
  745. return errors.New("Denied login methods contents mismatch")
  746. }
  747. }
  748. for _, protocol := range expected.Filters.DeniedProtocols {
  749. if !utils.IsStringInSlice(protocol, actual.Filters.DeniedProtocols) {
  750. return errors.New("Denied protocols contents mismatch")
  751. }
  752. }
  753. if err := compareUserFileExtensionsFilters(expected, actual); err != nil {
  754. return err
  755. }
  756. return compareUserFilePatternsFilters(expected, actual)
  757. }
  758. func checkFilterMatch(expected []string, actual []string) bool {
  759. if len(expected) != len(actual) {
  760. return false
  761. }
  762. for _, e := range expected {
  763. if !utils.IsStringInSlice(strings.ToLower(e), actual) {
  764. return false
  765. }
  766. }
  767. return true
  768. }
  769. func compareUserFilePatternsFilters(expected *dataprovider.User, actual *dataprovider.User) error {
  770. if len(expected.Filters.FilePatterns) != len(actual.Filters.FilePatterns) {
  771. return errors.New("file patterns mismatch")
  772. }
  773. for _, f := range expected.Filters.FilePatterns {
  774. found := false
  775. for _, f1 := range actual.Filters.FilePatterns {
  776. if path.Clean(f.Path) == path.Clean(f1.Path) {
  777. if !checkFilterMatch(f.AllowedPatterns, f1.AllowedPatterns) ||
  778. !checkFilterMatch(f.DeniedPatterns, f1.DeniedPatterns) {
  779. return errors.New("file patterns contents mismatch")
  780. }
  781. found = true
  782. }
  783. }
  784. if !found {
  785. return errors.New("file patterns contents mismatch")
  786. }
  787. }
  788. return nil
  789. }
  790. func compareUserFileExtensionsFilters(expected *dataprovider.User, actual *dataprovider.User) error {
  791. if len(expected.Filters.FileExtensions) != len(actual.Filters.FileExtensions) {
  792. return errors.New("file extensions mismatch")
  793. }
  794. for _, f := range expected.Filters.FileExtensions {
  795. found := false
  796. for _, f1 := range actual.Filters.FileExtensions {
  797. if path.Clean(f.Path) == path.Clean(f1.Path) {
  798. if !checkFilterMatch(f.AllowedExtensions, f1.AllowedExtensions) ||
  799. !checkFilterMatch(f.DeniedExtensions, f1.DeniedExtensions) {
  800. return errors.New("file extensions contents mismatch")
  801. }
  802. found = true
  803. }
  804. }
  805. if !found {
  806. return errors.New("file extensions contents mismatch")
  807. }
  808. }
  809. return nil
  810. }
  811. func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error {
  812. if expected.Username != actual.Username {
  813. return errors.New("Username mismatch")
  814. }
  815. if expected.HomeDir != actual.HomeDir {
  816. return errors.New("HomeDir mismatch")
  817. }
  818. if expected.UID != actual.UID {
  819. return errors.New("UID mismatch")
  820. }
  821. if expected.GID != actual.GID {
  822. return errors.New("GID mismatch")
  823. }
  824. if expected.MaxSessions != actual.MaxSessions {
  825. return errors.New("MaxSessions mismatch")
  826. }
  827. if expected.QuotaSize != actual.QuotaSize {
  828. return errors.New("QuotaSize mismatch")
  829. }
  830. if expected.QuotaFiles != actual.QuotaFiles {
  831. return errors.New("QuotaFiles mismatch")
  832. }
  833. if len(expected.Permissions) != len(actual.Permissions) {
  834. return errors.New("Permissions mismatch")
  835. }
  836. if expected.UploadBandwidth != actual.UploadBandwidth {
  837. return errors.New("UploadBandwidth mismatch")
  838. }
  839. if expected.DownloadBandwidth != actual.DownloadBandwidth {
  840. return errors.New("DownloadBandwidth mismatch")
  841. }
  842. if expected.Status != actual.Status {
  843. return errors.New("Status mismatch")
  844. }
  845. if expected.ExpirationDate != actual.ExpirationDate {
  846. return errors.New("ExpirationDate mismatch")
  847. }
  848. if expected.AdditionalInfo != actual.AdditionalInfo {
  849. return errors.New("AdditionalInfo mismatch")
  850. }
  851. return nil
  852. }
  // addLimitAndOffsetQueryParams parses rawurl and returns it with "limit" and
// "offset" query parameters added. Each parameter is added only when its value
// is strictly positive. Query parameters already in rawurl are kept, but the
// query is re-encoded, so url.Values.Encode sorts it by key.
//
// If rawurl cannot be parsed, the parse error is returned with a nil URL.
func addLimitAndOffsetQueryParams(rawurl string, limit, offset int64) (*url.URL, error) {
	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(rawurl)
	if err != nil {
		return nil, err
	}
	q := u.Query()
	if limit > 0 {
		q.Add("limit", strconv.FormatInt(limit, 10))
	}
	if offset > 0 {
		q.Add("offset", strconv.FormatInt(offset, 10))
	}
	u.RawQuery = q.Encode()
	return u, nil
}
  // addModeQueryParam parses rawurl and returns it with a "mode" query
// parameter set to mode. The parameter is added only when mode is non-empty.
// Query parameters already in rawurl are kept, but the query is re-encoded,
// so url.Values.Encode sorts it by key.
//
// If rawurl cannot be parsed, the parse error is returned with a nil URL.
func addModeQueryParam(rawurl, mode string) (*url.URL, error) {
	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(rawurl)
	if err != nil {
		return nil, err
	}
	q := u.Query()
	if mode != "" {
		q.Add("mode", mode)
	}
	u.RawQuery = q.Encode()
	return u, nil
}
  // addDisconnectQueryParam parses rawurl and returns it with a "disconnect"
// query parameter set to disconnect. The parameter is added only when
// disconnect is non-empty. Query parameters already in rawurl are kept, but
// the query is re-encoded, so url.Values.Encode sorts it by key.
//
// If rawurl cannot be parsed, the parse error is returned with a nil URL.
func addDisconnectQueryParam(rawurl, disconnect string) (*url.URL, error) {
	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(rawurl)
	if err != nil {
		return nil, err
	}
	q := u.Query()
	if disconnect != "" {
		q.Add("disconnect", disconnect)
	}
	u.RawQuery = q.Encode()
	return u, nil
}