importer.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. // Package subimporter implements a bulk ZIP/CSV importer of subscribers.
  2. // It implements a simple queue for buffering imports and committing records
  3. // to DB along with ZIP and CSV handling utilities. It is meant to be used as
  4. // a singleton as each Importer instance is stateful, where it keeps track of
  5. // an import in progress. Only one import should happen on a single importer
  6. // instance at a time.
  7. package subimporter
import (
	"archive/zip"
	"bytes"
	"database/sql"
	"encoding/csv"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/mail"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"

	"github.com/gofrs/uuid"
	"github.com/knadh/listmonk/internal/i18n"
	"github.com/knadh/listmonk/models"
	"github.com/lib/pq"
)
const (
	// stdInputMaxLen is the maximum allowed length for a standard input field
	// (e.g. the subscriber name validated in ValidateFields).
	stdInputMaxLen = 200

	// commitBatchSize is the number of inserts to commit in a single SQL
	// transaction. It also sizes the Session subscriber queue buffer.
	commitBatchSize = 10000
)
// Various import statuses.
const (
	StatusNone      = "none"      // No import has run or the importer was reset.
	StatusImporting = "importing" // An import session is in progress.
	StatusStopping  = "stopping"  // A stop signal has been sent to a running import.
	StatusFinished  = "finished"  // The last import completed.
	StatusFailed    = "failed"    // The last import errored out.

	// Import modes.
	ModeSubscribe = "subscribe" // Insert/update records as subscribers.
	ModeBlocklist = "blocklist" // Insert/update records as blocklisted.
)
// Importer represents the bulk CSV subscriber import system. It is stateful
// and tracks at most one import at a time, so it is meant to be used as a
// singleton (see the package comment).
type Importer struct {
	opt  Options
	db   *sql.DB
	i18n *i18n.I18n

	// stop is a buffered (capacity 1) channel used to signal a running
	// import session to stop.
	stop chan bool

	// status holds the stats of the ongoing/last import. It is guarded by
	// the embedded RWMutex.
	status Status
	sync.RWMutex
}
// Options represents import options.
type Options struct {
	// Prepared statements supplied by the caller. UpsertStmt inserts/updates
	// subscribers, BlocklistStmt blocklists them, and UpdateListDateStmt
	// touches the updated-at date on the affected lists.
	UpsertStmt         *sql.Stmt
	BlocklistStmt      *sql.Stmt
	UpdateListDateStmt *sql.Stmt

	// NotifCB is invoked to send admin notifications when an import
	// finishes or fails.
	NotifCB models.AdminNotifCallback

	// Lookup table for blocklisted domains.
	DomainBlocklist map[string]bool
}
// Session represents a single import session.
type Session struct {
	im *Importer

	// subQueue buffers validated subscriber rows produced by LoadCSV and
	// consumed by Start().
	subQueue chan SubReq

	// log writes to the importer's in-memory status log buffer.
	log *log.Logger
	opt SessionOpt
}
// SessionOpt represents the options for an importer session.
type SessionOpt struct {
	Filename  string `json:"filename"`            // Name of the uploaded file (kept for stats only).
	Mode      string `json:"mode"`                // ModeSubscribe or ModeBlocklist.
	SubStatus string `json:"subscription_status"` // Subscription status to apply in subscribe mode.
	Overwrite bool   `json:"overwrite"`           // Whether to overwrite existing subscriber fields.
	Delim     string `json:"delim"`               // CSV field delimiter.
	ListIDs   []int  `json:"lists"`               // Target list IDs for the imported subscribers.
}
// Status represents statistics from an ongoing import session.
type Status struct {
	Name     string `json:"name"`
	Total    int    `json:"total"`
	Imported int    `json:"imported"`
	Status   string `json:"status"`

	// logBuf accumulates session log lines; exposed via Importer.GetLogs().
	logBuf *bytes.Buffer
}
// SubReq is a wrapper over the Subscriber model that adds list membership
// fields used during import.
type SubReq struct {
	models.Subscriber
	Lists          pq.Int64Array  `json:"lists"`
	ListUUIDs      pq.StringArray `json:"list_uuids"`
	PreconfirmSubs bool           `json:"preconfirm_subscriptions"`
}
// importStatusTpl is the template data passed to the admin notification
// callback on import completion/failure.
type importStatusTpl struct {
	Name     string
	Status   string
	Imported int
	Total    int
}
var (
	// ErrIsImporting is thrown when an import request is made while an
	// import is already running.
	ErrIsImporting = errors.New("import is already running")

	// csvHeaders is the set of recognized CSV column headers.
	csvHeaders = map[string]bool{
		"email":      true,
		"name":       true,
		"attributes": true}

	// regexCleanStr strips non-ASCII characters (BOM etc.) from CSV headers.
	regexCleanStr = regexp.MustCompile("[[:^ascii:]]")
)
  110. // New returns a new instance of Importer.
  111. func New(opt Options, db *sql.DB, i *i18n.I18n) *Importer {
  112. im := Importer{
  113. opt: opt,
  114. db: db,
  115. i18n: i,
  116. status: Status{Status: StatusNone, logBuf: bytes.NewBuffer(nil)},
  117. stop: make(chan bool, 1),
  118. }
  119. return &im
  120. }
  121. // NewSession returns an new instance of Session. It takes the name
  122. // of the uploaded file, but doesn't do anything with it but retains it for stats.
  123. func (im *Importer) NewSession(opt SessionOpt) (*Session, error) {
  124. if im.getStatus() != StatusNone {
  125. return nil, errors.New("an import is already running")
  126. }
  127. im.Lock()
  128. im.status = Status{Status: StatusImporting,
  129. Name: opt.Filename,
  130. logBuf: bytes.NewBuffer(nil)}
  131. im.Unlock()
  132. s := &Session{
  133. im: im,
  134. log: log.New(im.status.logBuf, "", log.Ldate|log.Ltime|log.Lshortfile),
  135. subQueue: make(chan SubReq, commitBatchSize),
  136. opt: opt,
  137. }
  138. s.log.Printf("processing '%s'", opt.Filename)
  139. return s, nil
  140. }
  141. // GetStats returns the global Stats of the importer.
  142. func (im *Importer) GetStats() Status {
  143. im.RLock()
  144. defer im.RUnlock()
  145. return Status{
  146. Name: im.status.Name,
  147. Status: im.status.Status,
  148. Total: im.status.Total,
  149. Imported: im.status.Imported,
  150. }
  151. }
  152. // GetLogs returns the log entries of the last import session.
  153. func (im *Importer) GetLogs() []byte {
  154. im.RLock()
  155. defer im.RUnlock()
  156. if im.status.logBuf == nil {
  157. return []byte{}
  158. }
  159. return im.status.logBuf.Bytes()
  160. }
  161. // setStatus sets the Importer's status.
  162. func (im *Importer) setStatus(status string) {
  163. im.Lock()
  164. im.status.Status = status
  165. im.Unlock()
  166. }
  167. // getStatus get's the Importer's status.
  168. func (im *Importer) getStatus() string {
  169. im.RLock()
  170. status := im.status.Status
  171. im.RUnlock()
  172. return status
  173. }
  174. // isDone returns true if the importer is working (importing|stopping).
  175. func (im *Importer) isDone() bool {
  176. s := true
  177. im.RLock()
  178. if im.getStatus() == StatusImporting || im.getStatus() == StatusStopping {
  179. s = false
  180. }
  181. im.RUnlock()
  182. return s
  183. }
  184. // incrementImportCount sets the Importer's "imported" counter.
  185. func (im *Importer) incrementImportCount(n int) {
  186. im.Lock()
  187. im.status.Imported += n
  188. im.Unlock()
  189. }
  190. // sendNotif sends admin notifications for import completions.
  191. func (im *Importer) sendNotif(status string) error {
  192. var (
  193. s = im.GetStats()
  194. out = importStatusTpl{
  195. Name: s.Name,
  196. Status: status,
  197. Imported: s.Imported,
  198. Total: s.Total,
  199. }
  200. subject = fmt.Sprintf("%s: %s import",
  201. strings.Title(status),
  202. s.Name)
  203. )
  204. return im.opt.NotifCB(subject, out)
  205. }
// Start is a blocking function that selects on a channel queue until all
// subscriber entries in the import session are imported. It commits rows
// to the DB in transactions of commitBatchSize, and should be invoked as
// a goroutine. It finishes when the queue is closed (by LoadCSV completing,
// a stop signal, or Session.Stop).
func (s *Session) Start() {
	var (
		tx    *sql.Tx
		stmt  *sql.Stmt
		err   error
		total = 0 // Total rows processed across all batches.
		cur   = 0 // Rows in the current (uncommitted) batch.

		listIDs = make(pq.Int64Array, len(s.opt.ListIDs))
	)

	// Convert the int list IDs to int64 for the pq array parameter.
	for i, v := range s.opt.ListIDs {
		listIDs[i] = int64(v)
	}

	for sub := range s.subQueue {
		if cur == 0 {
			// New transaction batch.
			tx, err = s.im.db.Begin()
			if err != nil {
				s.log.Printf("error creating DB transaction: %v", err)
				continue
			}

			// Bind the mode-appropriate prepared statement to this transaction.
			if s.opt.Mode == ModeSubscribe {
				stmt = tx.Stmt(s.im.opt.UpsertStmt)
			} else {
				stmt = tx.Stmt(s.im.opt.BlocklistStmt)
			}
		}

		uu, err := uuid.NewV4()
		if err != nil {
			// On error, roll back and bail out; the post-loop commit path
			// below will surface the failure status.
			s.log.Printf("error generating UUID: %v", err)
			tx.Rollback()
			break
		}

		if s.opt.Mode == ModeSubscribe {
			_, err = stmt.Exec(uu, sub.Email, sub.Name, sub.Attribs, listIDs, s.opt.SubStatus, s.opt.Overwrite)
		} else if s.opt.Mode == ModeBlocklist {
			_, err = stmt.Exec(uu, sub.Email, sub.Name, sub.Attribs)
		}
		if err != nil {
			s.log.Printf("error executing insert: %v", err)
			tx.Rollback()
			break
		}
		cur++
		total++

		// Batch size is met. Commit.
		if cur%commitBatchSize == 0 {
			if err := tx.Commit(); err != nil {
				tx.Rollback()
				s.log.Printf("error committing to DB: %v", err)
			} else {
				s.im.incrementImportCount(cur)
				s.log.Printf("imported %d", total)
			}

			cur = 0
		}
	}

	// Queue's closed and there's nothing left to commit.
	if cur == 0 {
		s.im.setStatus(StatusFinished)
		s.log.Printf("imported finished")
		if _, err := s.im.opt.UpdateListDateStmt.Exec(listIDs); err != nil {
			s.log.Printf("error updating lists date: %v", err)
		}

		s.im.sendNotif(StatusFinished)
		return
	}

	// Queue's closed and there are records left to commit.
	// Note: if the loop broke on an error above, this Commit() on the
	// already rolled-back transaction fails, which marks the import failed.
	if err := tx.Commit(); err != nil {
		tx.Rollback()
		s.im.setStatus(StatusFailed)
		s.log.Printf("error committing to DB: %v", err)
		s.im.sendNotif(StatusFailed)
		return
	}

	s.im.incrementImportCount(cur)
	s.im.setStatus(StatusFinished)
	s.log.Printf("imported finished")
	if _, err := s.im.opt.UpdateListDateStmt.Exec(listIDs); err != nil {
		s.log.Printf("error updating lists date: %v", err)
	}

	s.im.sendNotif(StatusFinished)
}
// Stop stops an active import session by closing the subscriber queue,
// which lets Start() drain and commit what has already been queued.
// It must be called at most once per session: closing an already-closed
// channel panics.
func (s *Session) Stop() {
	close(s.subQueue)
}
  295. // ExtractZIP takes a ZIP file's path and extracts all .csv files in it to
  296. // a temporary directory, and returns the name of the temp directory and the
  297. // list of extracted .csv files.
  298. func (s *Session) ExtractZIP(srcPath string, maxCSVs int) (string, []string, error) {
  299. if s.im.isDone() {
  300. return "", nil, ErrIsImporting
  301. }
  302. failed := true
  303. defer func() {
  304. if failed {
  305. s.im.setStatus(StatusFailed)
  306. }
  307. }()
  308. z, err := zip.OpenReader(srcPath)
  309. if err != nil {
  310. return "", nil, err
  311. }
  312. defer z.Close()
  313. // Create a temporary directory to extract the files.
  314. dir, err := ioutil.TempDir("", "listmonk")
  315. if err != nil {
  316. s.log.Printf("error creating temporary directory for extracting ZIP: %v", err)
  317. return "", nil, err
  318. }
  319. files := make([]string, 0, len(z.File))
  320. for _, f := range z.File {
  321. fName := f.FileInfo().Name()
  322. // Skip directories.
  323. if f.FileInfo().IsDir() {
  324. s.log.Printf("skipping directory '%s'", fName)
  325. continue
  326. }
  327. // Skip files without the .csv extension.
  328. if !strings.HasSuffix(strings.ToLower(fName), ".csv") {
  329. s.log.Printf("skipping non .csv file '%s'", fName)
  330. continue
  331. }
  332. s.log.Printf("extracting '%s'", fName)
  333. src, err := f.Open()
  334. if err != nil {
  335. s.log.Printf("error opening '%s' from ZIP: '%v'", fName, err)
  336. return "", nil, err
  337. }
  338. defer src.Close()
  339. out, err := os.OpenFile(dir+"/"+fName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
  340. if err != nil {
  341. s.log.Printf("error creating '%s/%s': '%v'", dir, fName, err)
  342. return "", nil, err
  343. }
  344. defer out.Close()
  345. if _, err := io.Copy(out, src); err != nil {
  346. s.log.Printf("error extracting to '%s/%s': '%v'", dir, fName, err)
  347. return "", nil, err
  348. }
  349. s.log.Printf("extracted '%s'", fName)
  350. files = append(files, fName)
  351. if len(files) > maxCSVs {
  352. s.log.Printf("won't extract any more files. Maximum is %d", maxCSVs)
  353. break
  354. }
  355. }
  356. if len(files) == 0 {
  357. s.log.Println("no CSV files found in the ZIP")
  358. return "", nil, errors.New("no CSV files found in the ZIP")
  359. }
  360. failed = false
  361. return dir, files, nil
  362. }
  363. // LoadCSV loads a CSV file and validates and imports the subscriber entries in it.
  364. func (s *Session) LoadCSV(srcPath string, delim rune) error {
  365. if s.im.isDone() {
  366. return ErrIsImporting
  367. }
  368. // Default status is "failed" in case the function
  369. // returns at one of the many possible errors.
  370. failed := true
  371. defer func() {
  372. if failed {
  373. s.im.setStatus(StatusFailed)
  374. }
  375. }()
  376. f, err := os.Open(srcPath)
  377. if err != nil {
  378. return err
  379. }
  380. // Count the total number of lines in the file. This doesn't distinguish
  381. // between "blank" and non "blank" lines, and is only used to derive
  382. // the progress percentage for the frontend.
  383. numLines, err := countLines(f)
  384. if err != nil {
  385. s.log.Printf("error counting lines in '%s': '%v'", srcPath, err)
  386. return err
  387. }
  388. if numLines == 0 {
  389. return errors.New("empty file")
  390. }
  391. s.im.Lock()
  392. // Exclude the header from count.
  393. s.im.status.Total = numLines - 1
  394. s.im.Unlock()
  395. // Rewind, now that we've done a linecount on the same handler.
  396. _, _ = f.Seek(0, 0)
  397. rd := csv.NewReader(f)
  398. rd.Comma = delim
  399. // Read the header.
  400. csvHdr, err := rd.Read()
  401. if err != nil {
  402. s.log.Printf("error reading header from '%s': '%v'", srcPath, err)
  403. return err
  404. }
  405. hdrKeys := s.mapCSVHeaders(csvHdr, csvHeaders)
  406. // email, and name are required headers.
  407. if _, ok := hdrKeys["email"]; !ok {
  408. s.log.Printf("'email' column not found in '%s'", srcPath)
  409. return errors.New("'email' column not found")
  410. }
  411. if _, ok := hdrKeys["name"]; !ok {
  412. s.log.Printf("'name' column not found in '%s'", srcPath)
  413. return errors.New("'name' column not found")
  414. }
  415. var (
  416. lnHdr = len(hdrKeys)
  417. i = 0
  418. )
  419. for {
  420. i++
  421. // Check for the stop signal.
  422. select {
  423. case <-s.im.stop:
  424. failed = false
  425. close(s.subQueue)
  426. s.log.Println("stop request received")
  427. return nil
  428. default:
  429. }
  430. cols, err := rd.Read()
  431. if err == io.EOF {
  432. break
  433. } else if err != nil {
  434. if err, ok := err.(*csv.ParseError); ok && err.Err == csv.ErrFieldCount {
  435. s.log.Printf("skipping line %d. %v", i, err)
  436. continue
  437. } else {
  438. s.log.Printf("error reading CSV '%s'", err)
  439. return err
  440. }
  441. }
  442. lnCols := len(cols)
  443. if lnCols < lnHdr {
  444. s.log.Printf("skipping line %d. column count (%d) does not match minimum header count (%d)", i, lnCols, lnHdr)
  445. continue
  446. }
  447. // Iterate the key map and based on the indices mapped earlier,
  448. // form a map of key: csv_value, eg: email: user@user.com.
  449. row := make(map[string]string, lnCols)
  450. for key := range hdrKeys {
  451. row[key] = cols[hdrKeys[key]]
  452. }
  453. sub := SubReq{}
  454. sub.Email = row["email"]
  455. sub.Name = row["name"]
  456. sub, err = s.im.ValidateFields(sub)
  457. if err != nil {
  458. s.log.Printf("skipping line %d: %s: %v", i, sub.Email, err)
  459. continue
  460. }
  461. // JSON attributes.
  462. if len(row["attributes"]) > 0 {
  463. var (
  464. attribs models.SubscriberAttribs
  465. b = []byte(row["attributes"])
  466. )
  467. if err := json.Unmarshal(b, &attribs); err != nil {
  468. s.log.Printf("skipping invalid attributes JSON on line %d for '%s': %v", i, sub.Email, err)
  469. } else {
  470. sub.Attribs = attribs
  471. }
  472. }
  473. // Send the subscriber to the queue.
  474. s.subQueue <- sub
  475. }
  476. close(s.subQueue)
  477. failed = false
  478. return nil
  479. }
  480. // Stop sends a signal to stop the existing import.
  481. func (im *Importer) Stop() {
  482. if im.getStatus() != StatusImporting {
  483. im.Lock()
  484. im.status = Status{Status: StatusNone}
  485. im.Unlock()
  486. return
  487. }
  488. select {
  489. case im.stop <- true:
  490. im.setStatus(StatusStopping)
  491. default:
  492. }
  493. }
  494. // ValidateFields validates incoming subscriber field values and returns sanitized fields.
  495. func (im *Importer) ValidateFields(s SubReq) (SubReq, error) {
  496. if len(s.Email) > 1000 {
  497. return s, errors.New(im.i18n.T("subscribers.invalidEmail"))
  498. }
  499. s.Name = strings.TrimSpace(s.Name)
  500. if len(s.Name) == 0 || len(s.Name) > stdInputMaxLen {
  501. return s, errors.New(im.i18n.T("subscribers.invalidName"))
  502. }
  503. em, err := im.SanitizeEmail(s.Email)
  504. if err != nil {
  505. return s, err
  506. }
  507. s.Email = em
  508. return s, nil
  509. }
  510. // SanitizeEmail validates and sanitizes an e-mail string and returns the lowercased,
  511. // e-mail component of an e-mail string.
  512. func (im *Importer) SanitizeEmail(email string) (string, error) {
  513. email = strings.ToLower(strings.TrimSpace(email))
  514. // Since `mail.ParseAddress` parses an email address which can also contain optional name component
  515. // here we check if incoming email string is same as the parsed email.Address. So this eliminates
  516. // any valid email address with name and also valid address with empty name like `<abc@example.com>`.
  517. em, err := mail.ParseAddress(email)
  518. if err != nil || em.Address != email {
  519. return "", errors.New(im.i18n.T("subscribers.invalidEmail"))
  520. }
  521. // Check if the e-mail's domain is blocklisted.
  522. d := strings.Split(em.Address, "@")
  523. if len(d) == 2 {
  524. _, ok := im.opt.DomainBlocklist[d[1]]
  525. if ok {
  526. return "", errors.New(im.i18n.T("subscribers.domainBlocklisted"))
  527. }
  528. }
  529. return em.Address, nil
  530. }
  531. // mapCSVHeaders takes a list of headers obtained from a CSV file, a map of known headers,
  532. // and returns a new map with each of the headers in the known map mapped by the position (0-n)
  533. // in the given CSV list.
  534. func (s *Session) mapCSVHeaders(csvHdrs []string, knownHdrs map[string]bool) map[string]int {
  535. // Map 0-n column index to the header keys, name: 0, email: 1 etc.
  536. // This is to allow dynamic ordering of columns in th CSV.
  537. hdrKeys := make(map[string]int)
  538. for i, h := range csvHdrs {
  539. // Clean the string of non-ASCII characters (BOM etc.).
  540. h := regexCleanStr.ReplaceAllString(h, "")
  541. if _, ok := knownHdrs[h]; !ok {
  542. s.log.Printf("ignoring unknown header '%s'", h)
  543. continue
  544. }
  545. hdrKeys[h] = i
  546. }
  547. return hdrKeys
  548. }
  549. // countLines counts the number of line breaks in a file. This does not
  550. // distinguish between "blank" and non "blank" lines.
  551. // Credit: https://stackoverflow.com/a/24563853
  552. func countLines(r io.Reader) (int, error) {
  553. var (
  554. buf = make([]byte, 32*1024)
  555. count = 0
  556. lineSep = []byte{'\n'}
  557. )
  558. for {
  559. c, err := r.Read(buf)
  560. count += bytes.Count(buf[:c], lineSep)
  561. switch {
  562. case err == io.EOF:
  563. return count, nil
  564. case err != nil:
  565. return count, err
  566. }
  567. }
  568. }