api_utils.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package api
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "strconv"
  10. "github.com/drakkan/sftpgo/dataprovider"
  11. "github.com/drakkan/sftpgo/sftpd"
  12. "github.com/drakkan/sftpgo/utils"
  13. "github.com/go-chi/render"
  14. )
  15. var (
  16. defaultPerms = []string{dataprovider.PermAny}
  17. httpBaseURL = "http://127.0.0.1:8080"
  18. )
  19. // SetBaseURL sets the url to use for HTTP request, default is "http://127.0.0.1:8080"
  20. func SetBaseURL(url string) {
  21. httpBaseURL = url
  22. }
  23. // AddUser add a new user, useful for tests
  24. func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
  25. var newUser dataprovider.User
  26. userAsJSON, err := json.Marshal(user)
  27. if err != nil {
  28. return newUser, err
  29. }
  30. resp, err := http.Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON))
  31. if err != nil {
  32. return newUser, err
  33. }
  34. defer resp.Body.Close()
  35. err = checkResponse(resp.StatusCode, expectedStatusCode, resp)
  36. if expectedStatusCode != http.StatusOK {
  37. return newUser, err
  38. }
  39. if err == nil {
  40. err = render.DecodeJSON(resp.Body, &newUser)
  41. }
  42. if err == nil {
  43. err = checkUser(user, newUser)
  44. }
  45. return newUser, err
  46. }
  47. // UpdateUser update an user, useful for tests
  48. func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
  49. var newUser dataprovider.User
  50. userAsJSON, err := json.Marshal(user)
  51. if err != nil {
  52. return user, err
  53. }
  54. req, err := http.NewRequest(http.MethodPut, httpBaseURL+userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON))
  55. if err != nil {
  56. return user, err
  57. }
  58. req.Header.Set("Content-Type", "application/json")
  59. resp, err := http.DefaultClient.Do(req)
  60. if err != nil {
  61. return user, err
  62. }
  63. defer resp.Body.Close()
  64. err = checkResponse(resp.StatusCode, expectedStatusCode, resp)
  65. if expectedStatusCode != http.StatusOK {
  66. return newUser, err
  67. }
  68. if err == nil {
  69. newUser, err = GetUserByID(user.ID, expectedStatusCode)
  70. }
  71. if err == nil {
  72. err = checkUser(user, newUser)
  73. }
  74. return newUser, err
  75. }
  76. // RemoveUser remove user, useful for tests
  77. func RemoveUser(user dataprovider.User, expectedStatusCode int) error {
  78. req, err := http.NewRequest(http.MethodDelete, httpBaseURL+userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
  79. if err != nil {
  80. return err
  81. }
  82. resp, err := http.DefaultClient.Do(req)
  83. if err != nil {
  84. return err
  85. }
  86. defer resp.Body.Close()
  87. return checkResponse(resp.StatusCode, expectedStatusCode, resp)
  88. }
  89. // GetUserByID get user by id, useful for tests
  90. func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error) {
  91. var user dataprovider.User
  92. resp, err := http.Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10))
  93. if err != nil {
  94. return user, err
  95. }
  96. defer resp.Body.Close()
  97. err = checkResponse(resp.StatusCode, expectedStatusCode, resp)
  98. if err == nil && expectedStatusCode == http.StatusOK {
  99. err = render.DecodeJSON(resp.Body, &user)
  100. }
  101. return user, err
  102. }
  103. // GetUsers useful for tests
  104. func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, error) {
  105. var users []dataprovider.User
  106. req, err := http.NewRequest(http.MethodGet, httpBaseURL+userPath, nil)
  107. if err != nil {
  108. return users, err
  109. }
  110. q := req.URL.Query()
  111. if limit > 0 {
  112. q.Add("limit", strconv.FormatInt(limit, 10))
  113. }
  114. if offset > 0 {
  115. q.Add("offset", strconv.FormatInt(offset, 10))
  116. }
  117. if len(username) > 0 {
  118. q.Add("username", username)
  119. }
  120. req.URL.RawQuery = q.Encode()
  121. resp, err := http.DefaultClient.Do(req)
  122. if err != nil {
  123. return users, err
  124. }
  125. defer resp.Body.Close()
  126. err = checkResponse(resp.StatusCode, expectedStatusCode, resp)
  127. if err == nil && expectedStatusCode == http.StatusOK {
  128. err = render.DecodeJSON(resp.Body, &users)
  129. }
  130. return users, err
  131. }
  132. // GetQuotaScans get active quota scans, useful for tests
  133. func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, error) {
  134. var quotaScans []sftpd.ActiveQuotaScan
  135. resp, err := http.Get(httpBaseURL + quotaScanPath)
  136. if err != nil {
  137. return quotaScans, err
  138. }
  139. defer resp.Body.Close()
  140. err = checkResponse(resp.StatusCode, expectedStatusCode, resp)
  141. if err == nil && expectedStatusCode == http.StatusOK {
  142. err = render.DecodeJSON(resp.Body, &quotaScans)
  143. }
  144. return quotaScans, err
  145. }
  146. // StartQuotaScan start a new quota scan
  147. func StartQuotaScan(user dataprovider.User, expectedStatusCode int) error {
  148. userAsJSON, err := json.Marshal(user)
  149. if err != nil {
  150. return err
  151. }
  152. resp, err := http.Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON))
  153. if err != nil {
  154. return err
  155. }
  156. defer resp.Body.Close()
  157. return checkResponse(resp.StatusCode, expectedStatusCode, resp)
  158. }
  159. func checkResponse(actual int, expected int, resp *http.Response) error {
  160. if expected != actual {
  161. return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
  162. }
  163. if expected != http.StatusOK && resp != nil {
  164. b, err := ioutil.ReadAll(resp.Body)
  165. if err == nil {
  166. fmt.Printf("request: %v, response body: %v", resp.Request.URL, string(b))
  167. }
  168. }
  169. return nil
  170. }
  171. func checkUser(expected dataprovider.User, actual dataprovider.User) error {
  172. if len(actual.Password) > 0 {
  173. return errors.New("User password must not be visible")
  174. }
  175. if len(actual.PublicKey) > 0 {
  176. return errors.New("User public key must not be visible")
  177. }
  178. if expected.ID <= 0 {
  179. if actual.ID <= 0 {
  180. return errors.New("actual user ID must be > 0")
  181. }
  182. } else {
  183. if actual.ID != expected.ID {
  184. return errors.New("user ID mismatch")
  185. }
  186. }
  187. for _, v := range expected.Permissions {
  188. if !utils.IsStringInSlice(v, actual.Permissions) {
  189. return errors.New("Permissions contents mismatch")
  190. }
  191. }
  192. return compareEqualsUserFields(expected, actual)
  193. }
  194. func compareEqualsUserFields(expected dataprovider.User, actual dataprovider.User) error {
  195. if expected.Username != actual.Username {
  196. return errors.New("Username mismatch")
  197. }
  198. if expected.HomeDir != actual.HomeDir {
  199. return errors.New("HomeDir mismatch")
  200. }
  201. if expected.UID != actual.UID {
  202. return errors.New("UID mismatch")
  203. }
  204. if expected.GID != actual.GID {
  205. return errors.New("GID mismatch")
  206. }
  207. if expected.MaxSessions != actual.MaxSessions {
  208. return errors.New("MaxSessions mismatch")
  209. }
  210. if expected.QuotaSize != actual.QuotaSize {
  211. return errors.New("QuotaSize mismatch")
  212. }
  213. if expected.QuotaFiles != actual.QuotaFiles {
  214. return errors.New("QuotaFiles mismatch")
  215. }
  216. if len(expected.Permissions) != len(actual.Permissions) {
  217. return errors.New("Permissions mismatch")
  218. }
  219. if expected.UploadBandwidth != actual.UploadBandwidth {
  220. return errors.New("UploadBandwidth mismatch")
  221. }
  222. if expected.DownloadBandwidth != actual.DownloadBandwidth {
  223. return errors.New("DownloadBandwidth mismatch")
  224. }
  225. return nil
  226. }