api_utils.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package httpd
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net/http"
  7. "os"
  8. "path"
  9. "strconv"
  10. "strings"
  11. "github.com/go-chi/render"
  12. "github.com/klauspost/compress/zip"
  13. "github.com/drakkan/sftpgo/common"
  14. "github.com/drakkan/sftpgo/dataprovider"
  15. "github.com/drakkan/sftpgo/logger"
  16. )
  17. func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
  18. var errorString string
  19. if err != nil {
  20. errorString = err.Error()
  21. }
  22. resp := apiResponse{
  23. Error: errorString,
  24. Message: message,
  25. }
  26. ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
  27. render.JSON(w, r.WithContext(ctx), resp)
  28. }
  29. func getRespStatus(err error) int {
  30. if _, ok := err.(*dataprovider.ValidationError); ok {
  31. return http.StatusBadRequest
  32. }
  33. if _, ok := err.(*dataprovider.MethodDisabledError); ok {
  34. return http.StatusForbidden
  35. }
  36. if _, ok := err.(*dataprovider.RecordNotFoundError); ok {
  37. return http.StatusNotFound
  38. }
  39. if os.IsNotExist(err) {
  40. return http.StatusBadRequest
  41. }
  42. return http.StatusInternalServerError
  43. }
  44. func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
  45. connectionID := getURLParam(r, "connectionID")
  46. if connectionID == "" {
  47. sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
  48. return
  49. }
  50. if common.Connections.Close(connectionID) {
  51. sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
  52. } else {
  53. sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
  54. }
  55. }
  56. func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) {
  57. var err error
  58. limit := 100
  59. offset := 0
  60. order := dataprovider.OrderASC
  61. if _, ok := r.URL.Query()["limit"]; ok {
  62. limit, err = strconv.Atoi(r.URL.Query().Get("limit"))
  63. if err != nil {
  64. err = errors.New("invalid limit")
  65. sendAPIResponse(w, r, err, "", http.StatusBadRequest)
  66. return limit, offset, order, err
  67. }
  68. if limit > 500 {
  69. limit = 500
  70. }
  71. }
  72. if _, ok := r.URL.Query()["offset"]; ok {
  73. offset, err = strconv.Atoi(r.URL.Query().Get("offset"))
  74. if err != nil {
  75. err = errors.New("invalid offset")
  76. sendAPIResponse(w, r, err, "", http.StatusBadRequest)
  77. return limit, offset, order, err
  78. }
  79. }
  80. if _, ok := r.URL.Query()["order"]; ok {
  81. order = r.URL.Query().Get("order")
  82. if order != dataprovider.OrderASC && order != dataprovider.OrderDESC {
  83. err = errors.New("invalid order")
  84. sendAPIResponse(w, r, err, "", http.StatusBadRequest)
  85. return limit, offset, order, err
  86. }
  87. }
  88. return limit, offset, order, err
  89. }
  90. func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir string, files []string) {
  91. w.Header().Set("Content-Type", "application/zip")
  92. w.Header().Set("Accept-Ranges", "none")
  93. w.Header().Set("Content-Transfer-Encoding", "binary")
  94. w.WriteHeader(http.StatusOK)
  95. wr := zip.NewWriter(w)
  96. for _, file := range files {
  97. fullPath := path.Join(baseDir, file)
  98. if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
  99. panic(http.ErrAbortHandler)
  100. }
  101. }
  102. if err := wr.Close(); err != nil {
  103. conn.Log(logger.LevelWarn, "unable to close zip file: %v", err)
  104. panic(http.ErrAbortHandler)
  105. }
  106. }
  107. func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string) error {
  108. info, err := conn.Stat(entryPath, 1)
  109. if err != nil {
  110. conn.Log(logger.LevelDebug, "unable to add zip entry %#v, stat error: %v", entryPath, err)
  111. return err
  112. }
  113. if info.IsDir() {
  114. _, err := wr.Create(getZipEntryName(entryPath, baseDir) + "/")
  115. if err != nil {
  116. conn.Log(logger.LevelDebug, "unable to create zip entry %#v: %v", entryPath, err)
  117. return err
  118. }
  119. contents, err := conn.ReadDir(entryPath)
  120. if err != nil {
  121. conn.Log(logger.LevelDebug, "unable to add zip entry %#v, read dir error: %v", entryPath, err)
  122. return err
  123. }
  124. for _, info := range contents {
  125. fullPath := path.Join(entryPath, info.Name())
  126. if err := addZipEntry(wr, conn, fullPath, baseDir); err != nil {
  127. return err
  128. }
  129. }
  130. return nil
  131. }
  132. if !info.Mode().IsRegular() {
  133. // we only allow regular files
  134. conn.Log(logger.LevelDebug, "skipping zip entry for non regular file %#v", entryPath)
  135. return nil
  136. }
  137. reader, err := conn.getFileReader(entryPath, 0, http.MethodGet)
  138. if err != nil {
  139. conn.Log(logger.LevelDebug, "unable to add zip entry %#v, cannot open file: %v", entryPath, err)
  140. return err
  141. }
  142. defer reader.Close()
  143. f, err := wr.Create(getZipEntryName(entryPath, baseDir))
  144. if err != nil {
  145. conn.Log(logger.LevelDebug, "unable to create zip entry %#v: %v", entryPath, err)
  146. return err
  147. }
  148. _, err = io.Copy(f, reader)
  149. return err
  150. }
  151. func getZipEntryName(entryPath, baseDir string) string {
  152. entryPath = strings.TrimPrefix(entryPath, baseDir)
  153. return strings.TrimPrefix(entryPath, "/")
  154. }