switch from go-simple-mail to go-mail

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2023-01-15 15:28:31 +01:00
parent 6afbd77fd5
commit f2618e7de6
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
8 changed files with 716 additions and 278 deletions

21
go.mod
View file

@ -15,7 +15,7 @@ require (
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.47 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.47
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.14.0 github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.14.0
github.com/aws/aws-sdk-go-v2/service/s3 v1.30.0 github.com/aws/aws-sdk-go-v2/service/s3 v1.30.0
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.18.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.18.1
github.com/aws/aws-sdk-go-v2/service/sts v1.18.0 github.com/aws/aws-sdk-go-v2/service/sts v1.18.0
github.com/bmatcuk/doublestar/v4 v4.6.0 github.com/bmatcuk/doublestar/v4 v4.6.0
github.com/cockroachdb/cockroach-go/v2 v2.2.20 github.com/cockroachdb/cockroach-go/v2 v2.2.20
@ -59,27 +59,27 @@ require (
github.com/spf13/viper v1.14.0 github.com/spf13/viper v1.14.0
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
github.com/studio-b12/gowebdav v0.0.0-20221109171924-60ec5ad56012 github.com/studio-b12/gowebdav v0.0.0-20221109171924-60ec5ad56012
github.com/subosito/gotenv v1.4.1 github.com/subosito/gotenv v1.4.2
github.com/unrolled/secure v1.13.0 github.com/unrolled/secure v1.13.0
github.com/wagslane/go-password-validator v0.3.0 github.com/wagslane/go-password-validator v0.3.0
github.com/xhit/go-simple-mail/v2 v2.13.0 github.com/wneessen/go-mail v0.3.8
github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a
go.etcd.io/bbolt v1.3.6 go.etcd.io/bbolt v1.3.6
go.uber.org/automaxprocs v1.5.1 go.uber.org/automaxprocs v1.5.1
gocloud.dev v0.27.0 gocloud.dev v0.28.0
golang.org/x/crypto v0.5.0 golang.org/x/crypto v0.5.0
golang.org/x/net v0.5.0 golang.org/x/net v0.5.0
golang.org/x/oauth2 v0.4.0 golang.org/x/oauth2 v0.4.0
golang.org/x/sys v0.4.0 golang.org/x/sys v0.4.0
golang.org/x/term v0.4.0 golang.org/x/term v0.4.0
golang.org/x/time v0.3.0 golang.org/x/time v0.3.0
google.golang.org/api v0.106.0 google.golang.org/api v0.107.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
require ( require (
cloud.google.com/go v0.108.0 // indirect cloud.google.com/go v0.108.0 // indirect
cloud.google.com/go/compute v1.15.0 // indirect cloud.google.com/go/compute v1.15.1 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v0.10.0 // indirect cloud.google.com/go/iam v0.10.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect
@ -108,7 +108,6 @@ require (
github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.0 // indirect github.com/go-jose/go-jose/v3 v3.0.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-test/deep v1.1.0 // indirect
github.com/goccy/go-json v0.10.0 // indirect github.com/goccy/go-json v0.10.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
@ -129,8 +128,7 @@ require (
github.com/lestrrat-go/httprc v1.0.4 // indirect github.com/lestrrat-go/httprc v1.0.4 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.7 // indirect github.com/lufia/plan9stats v0.0.0-20230110061619-bbe2e5e100de // indirect
github.com/lufia/plan9stats v0.0.0-20220913051719-115f729f3c8c // indirect
github.com/magiconair/properties v1.8.7 // indirect github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect github.com/mattn/go-isatty v0.0.17 // indirect
@ -153,7 +151,6 @@ require (
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/tklauser/go-sysconf v0.3.11 // indirect github.com/tklauser/go-sysconf v0.3.11 // indirect
github.com/tklauser/numcpus v0.6.0 // indirect github.com/tklauser/numcpus v0.6.0 // indirect
github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
golang.org/x/mod v0.7.0 // indirect golang.org/x/mod v0.7.0 // indirect
@ -161,8 +158,8 @@ require (
golang.org/x/tools v0.5.0 // indirect golang.org/x/tools v0.5.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230106154932-a12b697841d9 // indirect google.golang.org/genproto v0.0.0-20230113154510-dbe35b8444a5 // indirect
google.golang.org/grpc v1.51.0 // indirect google.golang.org/grpc v1.52.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect

667
go.sum

File diff suppressed because it is too large Load diff

View file

@ -29,7 +29,7 @@ import (
"sync" "sync"
"time" "time"
mail "github.com/xhit/go-simple-mail/v2" "github.com/wneessen/go-mail"
"github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/dataprovider"
@ -372,7 +372,7 @@ func (c *RetentionCheck) sendEmailNotification(errCheck error) error {
Results: c.results, Results: c.results,
}) })
} }
var files []mail.File var files []*mail.File
f, err := params.getRetentionReportsAsMailAttachment() f, err := params.getRetentionReportsAsMailAttachment()
if err != nil { if err != nil {
c.conn.Log(logger.LevelError, "unable to get retention report as mail attachment: %v", err) c.conn.Log(logger.LevelError, "unable to get retention report as mail attachment: %v", err)
@ -391,11 +391,11 @@ func (c *RetentionCheck) sendEmailNotification(errCheck error) error {
body := "Further details attached." body := "Further details attached."
err = smtp.SendEmail([]string{c.Email}, subject, body, smtp.EmailContentTypeTextPlain, files...) err = smtp.SendEmail([]string{c.Email}, subject, body, smtp.EmailContentTypeTextPlain, files...)
if err != nil { if err != nil {
c.conn.Log(logger.LevelError, "unable to notify retention check result via email: %v, elapsed: %v", err, c.conn.Log(logger.LevelError, "unable to notify retention check result via email: %v, elapsed: %s", err,
time.Since(startTime)) time.Since(startTime))
return err return err
} }
c.conn.Log(logger.LevelInfo, "retention check result successfully notified via email, elapsed: %v", time.Since(startTime)) c.conn.Log(logger.LevelInfo, "retention check result successfully notified via email, elapsed: %s", time.Since(startTime))
return nil return nil
} }

View file

@ -84,6 +84,7 @@ func TestRetentionValidation(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "mail.example.com", Host: "mail.example.com",
Port: 25, Port: 25,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err = smtpCfg.Initialize(configDir) err = smtpCfg.Initialize(configDir)
@ -116,6 +117,7 @@ func TestRetentionEmailNotifications(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 2525, Port: 2525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err := smtpCfg.Initialize(configDir) err := smtpCfg.Initialize(configDir)

View file

@ -41,7 +41,7 @@ import (
"github.com/robfig/cron/v3" "github.com/robfig/cron/v3"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/sftpgo/sdk" "github.com/sftpgo/sdk"
mail "github.com/xhit/go-simple-mail/v2" "github.com/wneessen/go-mail"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/logger"
@ -563,16 +563,30 @@ func (p *EventParams) getCompressedDataRetentionReport() ([]byte, error) {
return nil, errors.New("no data retention report available") return nil, errors.New("no data retention report available")
} }
var b bytes.Buffer var b bytes.Buffer
wr := zip.NewWriter(&b) if _, err := p.writeCompressedDataRetentionReports(&b); err != nil {
return nil, err
}
return b.Bytes(), nil
}
func (p *EventParams) writeCompressedDataRetentionReports(w io.Writer) (int64, error) {
var n int64
wr := zip.NewWriter(w)
for _, check := range p.retentionChecks { for _, check := range p.retentionChecks {
if size := int64(len(b.Bytes())); size > maxAttachmentsSize {
eventManagerLog(logger.LevelError, "unable to get retention report, size too large: %s", util.ByteCountIEC(size))
return nil, fmt.Errorf("unable to get retention report, size too large: %s", util.ByteCountIEC(size))
}
data, err := getCSVRetentionReport(check.Results) data, err := getCSVRetentionReport(check.Results)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to get CSV report: %w", err) return n, fmt.Errorf("unable to get CSV report: %w", err)
} }
dataSize := int64(len(data))
n += dataSize
// we suppose a 3:1 compression ratio
if n > (maxAttachmentsSize * 3) {
eventManagerLog(logger.LevelError, "unable to get retention report, size too large: %s",
util.ByteCountIEC(n))
return n, fmt.Errorf("unable to get retention report, size too large: %s", util.ByteCountIEC(n))
}
fh := &zip.FileHeader{ fh := &zip.FileHeader{
Name: fmt.Sprintf("%s-%s.csv", check.ActionName, check.Username), Name: fmt.Sprintf("%s-%s.csv", check.ActionName, check.Username),
Method: zip.Deflate, Method: zip.Deflate,
@ -580,28 +594,28 @@ func (p *EventParams) getCompressedDataRetentionReport() ([]byte, error) {
} }
f, err := wr.CreateHeader(fh) f, err := wr.CreateHeader(fh)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create zip header for file %q: %w", fh.Name, err) return n, fmt.Errorf("unable to create zip header for file %q: %w", fh.Name, err)
} }
_, err = io.Copy(f, bytes.NewBuffer(data)) _, err = io.CopyN(f, bytes.NewBuffer(data), dataSize)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to write content to zip file %q: %w", fh.Name, err) return n, fmt.Errorf("unable to write content to zip file %q: %w", fh.Name, err)
} }
} }
if err := wr.Close(); err != nil { if err := wr.Close(); err != nil {
return nil, fmt.Errorf("unable to close zip writer: %w", err) return n, fmt.Errorf("unable to close zip writer: %w", err)
} }
return b.Bytes(), nil return n, nil
} }
func (p *EventParams) getRetentionReportsAsMailAttachment() (mail.File, error) { func (p *EventParams) getRetentionReportsAsMailAttachment() (*mail.File, error) {
var result mail.File if len(p.retentionChecks) == 0 {
data, err := p.getCompressedDataRetentionReport() return nil, errors.New("no data retention report available")
if err != nil {
return result, err
} }
result.Name = "retention-reports.zip" return &mail.File{
result.Data = data Name: "retention-reports.zip",
return result, nil Header: make(map[string][]string),
Writer: p.writeCompressedDataRetentionReports,
}, nil
} }
func (p *EventParams) getStringReplacements(addObjectData bool) []string { func (p *EventParams) getStringReplacements(addObjectData bool) []string {
@ -905,34 +919,24 @@ func writeFileContent(conn *BaseConnection, virtualPath string, w io.Writer) err
return err return err
} }
func getFileContent(conn *BaseConnection, virtualPath string, expectedSize int) ([]byte, error) { func getFileContentFn(conn *BaseConnection, virtualPath string, size int64) func(w io.Writer) (int64, error) {
reader, cancelFn, err := getFileReader(conn, virtualPath) return func(w io.Writer) (int64, error) {
if err != nil { reader, cancelFn, err := getFileReader(conn, virtualPath)
return nil, err if err != nil {
return 0, err
}
defer cancelFn()
defer reader.Close()
return io.CopyN(w, reader, size)
} }
defer cancelFn()
defer reader.Close()
data := make([]byte, expectedSize)
_, err = io.ReadFull(reader, data)
return data, err
} }
func getMailAttachments(user dataprovider.User, attachments []string, replacer *strings.Replacer) ([]mail.File, error) { func getMailAttachments(conn *BaseConnection, attachments []string, replacer *strings.Replacer) ([]*mail.File, error) {
var files []mail.File var files []*mail.File
user, err := getUserForEventAction(user)
if err != nil {
return nil, err
}
connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
err = user.CheckFsRoot(connectionID)
defer user.CloseFs() //nolint:errcheck
if err != nil {
return nil, fmt.Errorf("error getting email attachments, unable to check root fs for user %q: %w", user.Username, err)
}
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
totalSize := int64(0) totalSize := int64(0)
for _, virtualPath := range replacePathsPlaceholders(attachments, replacer) { for _, virtualPath := range replacePathsPlaceholders(attachments, replacer) {
info, err := conn.DoStat(virtualPath, 0, false) info, err := conn.DoStat(virtualPath, 0, false)
if err != nil { if err != nil {
@ -945,13 +949,10 @@ func getMailAttachments(user dataprovider.User, attachments []string, replacer *
if totalSize > maxAttachmentsSize { if totalSize > maxAttachmentsSize {
return nil, fmt.Errorf("unable to send files as attachment, size too large: %s", util.ByteCountIEC(totalSize)) return nil, fmt.Errorf("unable to send files as attachment, size too large: %s", util.ByteCountIEC(totalSize))
} }
data, err := getFileContent(conn, virtualPath, int(info.Size())) files = append(files, &mail.File{
if err != nil { Name: path.Base(virtualPath),
return nil, fmt.Errorf("unable to get content for file %q, user %q: %w", virtualPath, conn.User.Username, err) Header: make(map[string][]string),
} Writer: getFileContentFn(conn, virtualPath, info.Size()),
files = append(files, mail.File{
Name: path.Base(virtualPath),
Data: data,
}) })
} }
return files, nil return files, nil
@ -1265,7 +1266,7 @@ func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *Event
body := replaceWithReplacer(c.Body, replacer) body := replaceWithReplacer(c.Body, replacer)
subject := replaceWithReplacer(c.Subject, replacer) subject := replaceWithReplacer(c.Subject, replacer)
startTime := time.Now() startTime := time.Now()
var files []mail.File var files []*mail.File
fileAttachments := make([]string, 0, len(c.Attachments)) fileAttachments := make([]string, 0, len(c.Attachments))
for _, attachment := range c.Attachments { for _, attachment := range c.Attachments {
if attachment == dataprovider.RetentionReportPlaceHolder { if attachment == dataprovider.RetentionReportPlaceHolder {
@ -1283,7 +1284,18 @@ func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *Event
if err != nil { if err != nil {
return err return err
} }
res, err := getMailAttachments(user, fileAttachments, replacer) user, err = getUserForEventAction(user)
if err != nil {
return err
}
connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
err = user.CheckFsRoot(connectionID)
defer user.CloseFs() //nolint:errcheck
if err != nil {
return fmt.Errorf("error getting email attachments, unable to check root fs for user %q: %w", user.Username, err)
}
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
res, err := getMailAttachments(conn, fileAttachments, replacer)
if err != nil { if err != nil {
return err return err
} }

View file

@ -30,6 +30,7 @@ import (
"time" "time"
"github.com/klauspost/compress/zip" "github.com/klauspost/compress/zip"
"github.com/rs/xid"
"github.com/sftpgo/sdk" "github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms" sdkkms "github.com/sftpgo/sdk/kms"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -530,14 +531,6 @@ func TestEventManagerErrors(t *testing.T) {
}, },
}) })
assert.Error(t, err) assert.Error(t, err)
_, err = getMailAttachments(dataprovider.User{
Groups: []sdk.GroupMapping{
{
Name: groupName,
Type: sdk.GroupTypePrimary,
},
}}, []string{"/a", "/b"}, nil)
assert.Error(t, err)
err = executePwdExpirationCheckForUser(&dataprovider.User{ err = executePwdExpirationCheckForUser(&dataprovider.User{
Groups: []sdk.GroupMapping{ Groups: []sdk.GroupMapping{
{ {
@ -1253,17 +1246,21 @@ func TestGetFileContent(t *testing.T) {
fileContent := []byte("test file content") fileContent := []byte("test file content")
err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file.txt"), fileContent, 0666) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file.txt"), fileContent, 0666)
assert.NoError(t, err) assert.NoError(t, err)
conn := NewBaseConnection(xid.New().String(), protocolEventAction, "", "", user)
replacer := strings.NewReplacer("old", "new") replacer := strings.NewReplacer("old", "new")
files, err := getMailAttachments(user, []string{"/file.txt"}, replacer) files, err := getMailAttachments(conn, []string{"/file.txt"}, replacer)
assert.NoError(t, err) assert.NoError(t, err)
if assert.Len(t, files, 1) { if assert.Len(t, files, 1) {
assert.Equal(t, fileContent, files[0].Data) var b bytes.Buffer
_, err = files[0].Writer(&b)
assert.NoError(t, err)
assert.Equal(t, fileContent, b.Bytes())
} }
// missing file // missing file
_, err = getMailAttachments(user, []string{"/file1.txt"}, replacer) _, err = getMailAttachments(conn, []string{"/file1.txt"}, replacer)
assert.Error(t, err) assert.Error(t, err)
// directory // directory
_, err = getMailAttachments(user, []string{"/"}, replacer) _, err = getMailAttachments(conn, []string{"/"}, replacer)
assert.Error(t, err) assert.Error(t, err)
// files too large // files too large
content := make([]byte, maxAttachmentsSize/2+1) content := make([]byte, maxAttachmentsSize/2+1)
@ -1273,12 +1270,15 @@ func TestGetFileContent(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file2.txt"), content, 0666) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file2.txt"), content, 0666)
assert.NoError(t, err) assert.NoError(t, err)
files, err = getMailAttachments(user, []string{"/file1.txt"}, replacer) files, err = getMailAttachments(conn, []string{"/file1.txt"}, replacer)
assert.NoError(t, err) assert.NoError(t, err)
if assert.Len(t, files, 1) { if assert.Len(t, files, 1) {
assert.Equal(t, content, files[0].Data) var b bytes.Buffer
_, err = files[0].Writer(&b)
assert.NoError(t, err)
assert.Equal(t, content, b.Bytes())
} }
_, err = getMailAttachments(user, []string{"/file1.txt", "/file2.txt"}, replacer) _, err = getMailAttachments(conn, []string{"/file1.txt", "/file2.txt"}, replacer)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "size too large") assert.Contains(t, err.Error(), "size too large")
} }
@ -1287,9 +1287,15 @@ func TestGetFileContent(t *testing.T) {
user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("pwd") user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("pwd")
err = dataprovider.UpdateUser(&user, "", "", "") err = dataprovider.UpdateUser(&user, "", "", "")
assert.NoError(t, err) assert.NoError(t, err)
conn = NewBaseConnection(xid.New().String(), protocolEventAction, "", "", user)
// the file is not encrypted so reading the encryption header will fail // the file is not encrypted so reading the encryption header will fail
_, err = getMailAttachments(user, []string{"/file.txt"}, replacer) files, err = getMailAttachments(conn, []string{"/file.txt"}, replacer)
assert.Error(t, err) assert.NoError(t, err)
if assert.Len(t, files, 1) {
var b bytes.Buffer
_, err = files[0].Writer(&b)
assert.Error(t, err)
}
err = dataprovider.DeleteUser(username, "", "", "") err = dataprovider.DeleteUser(username, "", "", "")
assert.NoError(t, err) assert.NoError(t, err)
@ -1361,7 +1367,9 @@ func TestFilesystemActionErrors(t *testing.T) {
sender: username, sender: username,
}) })
assert.Error(t, err) assert.Error(t, err)
_, err = getFileContent(NewBaseConnection("", protocolEventAction, "", "", user), "/f.txt", 1234) fn := getFileContentFn(NewBaseConnection("", protocolEventAction, "", "", user), "/f.txt", 1234)
var b bytes.Buffer
_, err = fn(&b)
assert.Error(t, err) assert.Error(t, err)
err = executeHTTPRuleAction(dataprovider.EventActionHTTPConfig{ err = executeHTTPRuleAction(dataprovider.EventActionHTTPConfig{
Endpoint: "http://127.0.0.1:9999/", Endpoint: "http://127.0.0.1:9999/",

View file

@ -5973,6 +5973,7 @@ func TestNamingRules(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3525, Port: 3525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err := smtpCfg.Initialize(configDir) err := smtpCfg.Initialize(configDir)
@ -11655,6 +11656,7 @@ func TestMaxSessions(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3525, Port: 3525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err = smtpCfg.Initialize(configDir) err = smtpCfg.Initialize(configDir)
@ -11732,6 +11734,7 @@ func TestSFTPLoopError(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3525, Port: 3525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err = smtpCfg.Initialize(configDir) err = smtpCfg.Initialize(configDir)
@ -21672,6 +21675,7 @@ func TestAdminForgotPassword(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3525, Port: 3525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err := smtpCfg.Initialize(configDir) err := smtpCfg.Initialize(configDir)
@ -21777,6 +21781,7 @@ func TestAdminForgotPassword(t *testing.T) {
smtpCfg = smtp.Config{ smtpCfg = smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3526, Port: 3526,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err = smtpCfg.Initialize(configDir) err = smtpCfg.Initialize(configDir)
@ -21825,6 +21830,7 @@ func TestUserForgotPassword(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3525, Port: 3525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err := smtpCfg.Initialize(configDir) err := smtpCfg.Initialize(configDir)
@ -21975,6 +21981,7 @@ func TestAPIForgotPassword(t *testing.T) {
smtpCfg := smtp.Config{ smtpCfg := smtp.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 3525, Port: 3525,
From: "notification@example.com",
TemplatesPath: "templates", TemplatesPath: "templates",
} }
err := smtpCfg.Initialize(configDir) err := smtpCfg.Initialize(configDir)

View file

@ -17,13 +17,14 @@ package smtp
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
"path/filepath" "path/filepath"
"time" "time"
mail "github.com/xhit/go-simple-mail/v2" "github.com/wneessen/go-mail"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/util"
@ -49,14 +50,13 @@ const (
) )
var ( var (
smtpServer *mail.SMTPServer config *Config
from string
emailTemplates = make(map[string]*template.Template) emailTemplates = make(map[string]*template.Template)
) )
// IsEnabled returns true if an SMTP server is configured // IsEnabled returns true if an SMTP server is configured
func IsEnabled() bool { func IsEnabled() bool {
return smtpServer != nil return config != nil
} }
// Config defines the SMTP configuration to use to send emails // Config defines the SMTP configuration to use to send emails
@ -91,7 +91,7 @@ type Config struct {
// Initialize initialized and validates the SMTP configuration // Initialize initialized and validates the SMTP configuration
func (c *Config) Initialize(configDir string) error { func (c *Config) Initialize(configDir string) error {
smtpServer = nil config = nil
if c.Host == "" { if c.Host == "" {
logger.Debug(logSender, "", "configuration disabled, email capabilities will not be available") logger.Debug(logSender, "", "configuration disabled, email capabilities will not be available")
return nil return nil
@ -105,53 +105,51 @@ func (c *Config) Initialize(configDir string) error {
if c.Encryption < 0 || c.Encryption > 2 { if c.Encryption < 0 || c.Encryption > 2 {
return fmt.Errorf("smtp: invalid encryption %v", c.Encryption) return fmt.Errorf("smtp: invalid encryption %v", c.Encryption)
} }
if c.From == "" && c.User == "" {
return fmt.Errorf(`smtp: from address and user cannot both be empty`)
}
templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir) templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir)
if templatesPath == "" { if templatesPath == "" {
return fmt.Errorf("smtp: invalid templates path %#v", templatesPath) return fmt.Errorf("smtp: invalid templates path %#v", templatesPath)
} }
loadTemplates(filepath.Join(templatesPath, templateEmailDir)) loadTemplates(filepath.Join(templatesPath, templateEmailDir))
from = c.From config = c
smtpServer = mail.NewSMTPClient() logger.Debug(logSender, "", "configuration successfully initialized, host: %q, port: %d, username: %q, auth: %d, encryption: %d, helo: %q",
smtpServer.Host = c.Host config.Host, config.Port, config.User, config.AuthType, config.Encryption, config.Domain)
smtpServer.Port = c.Port
smtpServer.Username = c.User
smtpServer.Password = c.Password
smtpServer.Authentication = c.getAuthType()
smtpServer.Encryption = c.getEncryption()
smtpServer.KeepAlive = false
smtpServer.ConnectTimeout = 10 * time.Second
smtpServer.SendTimeout = 120 * time.Second
if c.Domain != "" {
smtpServer.Helo = c.Domain
}
logger.Debug(logSender, "", "configuration successfully initialized, host: %#v, port: %v, username: %#v, auth: %v, encryption: %v, helo: %#v",
smtpServer.Host, smtpServer.Port, smtpServer.Username, smtpServer.Authentication, smtpServer.Encryption, smtpServer.Helo)
return nil return nil
} }
func (c *Config) getEncryption() mail.Encryption { func (c *Config) getMailClientOptions() []mail.Option {
options := []mail.Option{mail.WithPort(c.Port)}
switch c.Encryption { switch c.Encryption {
case 1: case 1:
return mail.EncryptionSSLTLS options = append(options, mail.WithSSL())
case 2: case 2:
return mail.EncryptionSTARTTLS options = append(options, mail.WithTLSPolicy(mail.TLSMandatory))
default: default:
return mail.EncryptionNone options = append(options, mail.WithTLSPolicy(mail.NoTLS))
} }
} if config.User != "" {
options = append(options, mail.WithUsername(config.User))
func (c *Config) getAuthType() mail.AuthType {
if c.User == "" && c.Password == "" {
return mail.AuthNone
} }
switch c.AuthType { if config.Password != "" {
case 1: options = append(options, mail.WithPassword(config.Password))
return mail.AuthLogin
case 2:
return mail.AuthCRAMMD5
default:
return mail.AuthPlain
} }
if config.User != "" || config.Password != "" {
switch config.AuthType {
case 1:
options = append(options, mail.WithSMTPAuth(mail.SMTPAuthLogin))
case 2:
options = append(options, mail.WithSMTPAuth(mail.SMTPAuthCramMD5))
default:
options = append(options, mail.WithSMTPAuth(mail.SMTPAuthPlain))
}
}
if config.Domain != "" {
options = append(options, mail.WithHELO(config.Domain))
}
return options
} }
func loadTemplates(templatesPath string) { func loadTemplates(templatesPath string) {
@ -168,7 +166,7 @@ func loadTemplates(templatesPath string) {
// RenderPasswordResetTemplate executes the password reset template // RenderPasswordResetTemplate executes the password reset template
func RenderPasswordResetTemplate(buf *bytes.Buffer, data any) error { func RenderPasswordResetTemplate(buf *bytes.Buffer, data any) error {
if smtpServer == nil { if !IsEnabled() {
return errors.New("smtp: not configured") return errors.New("smtp: not configured")
} }
return emailTemplates[templatePasswordReset].Execute(buf, data) return emailTemplates[templatePasswordReset].Execute(buf, data)
@ -176,46 +174,51 @@ func RenderPasswordResetTemplate(buf *bytes.Buffer, data any) error {
// RenderPasswordExpirationTemplate executes the password expiration template // RenderPasswordExpirationTemplate executes the password expiration template
func RenderPasswordExpirationTemplate(buf *bytes.Buffer, data any) error { func RenderPasswordExpirationTemplate(buf *bytes.Buffer, data any) error {
if smtpServer == nil { if !IsEnabled() {
return errors.New("smtp: not configured") return errors.New("smtp: not configured")
} }
return emailTemplates[templatePasswordExpiration].Execute(buf, data) return emailTemplates[templatePasswordExpiration].Execute(buf, data)
} }
// SendEmail tries to send an email using the specified parameters. // SendEmail tries to send an email using the specified parameters.
func SendEmail(to []string, subject, body string, contentType EmailContentType, attachments ...mail.File) error { func SendEmail(to []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error {
if smtpServer == nil { if !IsEnabled() {
return errors.New("smtp: not configured") return errors.New("smtp: not configured")
} }
if len(to) == 0 { m := mail.NewMsg()
return errors.New("smtp: cannot send an email without recipients")
}
smtpClient, err := smtpServer.Connect()
if err != nil {
return fmt.Errorf("smtp: unable to connect: %w", err)
}
email := mail.NewMSG() var from string
email.AllowDuplicateAddress = true if config.From != "" {
if from != "" { from = config.From
email.SetFrom(from)
} else { } else {
email.SetFrom(smtpServer.Username) from = config.User
} }
email.AddTo(to...).SetSubject(subject) if err := m.From(from); err != nil {
return fmt.Errorf("invalid from address: %w", err)
}
if err := m.To(to...); err != nil {
return err
}
m.Subject(subject)
m.SetDate()
m.SetMessageID()
m.SetAttachements(attachments)
switch contentType { switch contentType {
case EmailContentTypeTextPlain: case EmailContentTypeTextPlain:
email.SetBody(mail.TextPlain, body) m.SetBodyString(mail.TypeTextPlain, body)
case EmailContentTypeTextHTML: case EmailContentTypeTextHTML:
email.SetBody(mail.TextHTML, body) m.SetBodyString(mail.TypeTextHTML, body)
default: default:
return fmt.Errorf("smtp: unsupported body content type %v", contentType) return fmt.Errorf("smtp: unsupported body content type %v", contentType)
} }
for _, attachment := range attachments {
email.Attach(&attachment) c, err := mail.NewClient(config.Host, config.getMailClientOptions()...)
if err != nil {
return fmt.Errorf("unable to create mail client: %w", err)
} }
if email.Error != nil { ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
return fmt.Errorf("smtp: email error: %w", email.Error) defer cancelFn()
}
return email.Send(smtpClient) return c.DialAndSendWithContext(ctx, m)
} }