sftpgo-mirror/utils/utils.go
2020-06-19 17:08:51 +02:00

314 lines
7.9 KiB
Go

// Package utils provides some common utility methods
package utils
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"html/template"
"io"
"io/ioutil"
"net"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"time"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/logger"
)
// logSender is the sender name this package passes to the logger.
const logSender = "utils"
// IsStringInSlice reports whether obj is equal to at least one element of list.
// A nil or empty list never matches.
func IsStringInSlice(obj string, list []string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == obj {
			return true
		}
	}
	return false
}
// IsStringPrefixInSlice reports whether obj starts with at least one of the
// strings contained in list.
func IsStringPrefixInSlice(obj string, list []string) bool {
	for idx := range list {
		if strings.HasPrefix(obj, list[idx]) {
			return true
		}
	}
	return false
}
// GetTimeAsMsSinceEpoch returns the Unix timestamp of t in milliseconds.
//
// The value is built from the whole seconds plus the millisecond part of t
// instead of t.UnixNano(): the latter is undefined for instants whose
// nanosecond count does not fit an int64 (dates before 1678 or after 2262,
// which includes the zero time.Time). Sub-millisecond fractions are discarded
// by rounding toward negative infinity.
func GetTimeAsMsSinceEpoch(t time.Time) int64 {
	return t.Unix()*1000 + int64(t.Nanosecond())/int64(time.Millisecond)
}
// GetTimeFromMsecSinceEpoch returns the time corresponding to a Unix
// timestamp expressed in milliseconds.
//
// The timestamp is split into whole seconds and a sub-second remainder before
// being handed to time.Unix, so that large values are not multiplied into a
// nanosecond count that overflows int64.
func GetTimeFromMsecSinceEpoch(msec int64) time.Time {
	const msecPerSec = 1000
	// time.Unix accepts a nanosecond remainder outside [0, 999999999], so a
	// negative remainder (negative msec) is normalized correctly.
	return time.Unix(msec/msecPerSec, (msec%msecPerSec)*int64(time.Millisecond))
}
// GetDurationAsString formats d, rounded to the nearest second, as
// "HH:MM:SS". The hours field is omitted ("MM:SS") unless it is greater
// than zero.
func GetDurationAsString(d time.Duration) string {
	totalSeconds := int64(d.Round(time.Second) / time.Second)
	hours := totalSeconds / 3600
	minutes := (totalSeconds % 3600) / 60
	seconds := totalSeconds % 60
	if hours <= 0 {
		return fmt.Sprintf("%02d:%02d", minutes, seconds)
	}
	return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)
}
// ByteCountSI formats b as a human readable size using SI (decimal) units,
// i.e. powers of 1000.
func ByteCountSI(b int64) string {
	const siUnit int64 = 1000
	return byteCount(b, siUnit)
}
// ByteCountIEC formats b as a human readable size using IEC (binary) units,
// i.e. powers of 1024.
func ByteCountIEC(b int64) string {
	const iecUnit int64 = 1024
	return byteCount(b, iecUnit)
}
// byteCount renders b as a size string scaled by powers of unit.
// Values below unit are printed as plain bytes ("<n> B"). Larger values are
// printed with one decimal digit and a K/M/G/T/P/E prefix, followed by "B"
// when unit is 1000 and by "iB" for any other unit.
func byteCount(b int64, unit int64) string {
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}

	const prefixes = "KMGTPE"
	divisor := unit
	prefixIdx := 0
	// Each further factor of unit contained in b moves to the next prefix.
	remaining := b / unit
	for remaining >= unit {
		divisor *= unit
		prefixIdx++
		remaining /= unit
	}

	scaled := float64(b) / float64(divisor)
	if unit != 1000 {
		return fmt.Sprintf("%.1f %ciB", scaled, prefixes[prefixIdx])
	}
	return fmt.Sprintf("%.1f %cB", scaled, prefixes[prefixIdx])
}
// GetIPFromRemoteAddress extracts the host part from a "host:port" remote
// address. When the address cannot be split into host and port it is
// returned as is.
func GetIPFromRemoteAddress(remoteAddress string) string {
	host, _, err := net.SplitHostPort(remoteAddress)
	if err != nil {
		return remoteAddress
	}
	return host
}
// NilIfEmpty returns a pointer to a copy of s, or nil when s is the empty
// string.
func NilIfEmpty(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
// EncryptData encrypts data with AES-GCM using a freshly generated random key.
//
// The returned string has the form "$aes$<key>$<payload>": <key> is the hex
// encoding of 16 random bytes (the hex text itself is used as the AES key) and
// <payload> is the hex encoding of the random nonce followed by the sealed
// data. The key is therefore part of the result; see RemoveDecryptionKey.
func EncryptData(data string) (string, error) {
	rawKey := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, rawKey); err != nil {
		return "", err
	}
	keyHex := hex.EncodeToString(rawKey)

	block, err := aes.NewCipher([]byte(keyHex))
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Passing nonce as dst makes Seal append the ciphertext to it, so the
	// payload carries the nonce as its prefix.
	sealed := aead.Seal(nonce, nonce, []byte(data), nil)
	return fmt.Sprintf("$aes$%s$%x", keyHex, sealed), nil
}
// RemoveDecryptionKey strips the key field from a "$<type>$<key>$<payload>"
// string and returns "$<type>$<payload>". Input that does not split into
// exactly four "$" separated fields is returned unchanged.
func RemoveDecryptionKey(encryptData string) string {
	parts := strings.Split(encryptData, "$")
	if len(parts) != 4 {
		return encryptData
	}
	return fmt.Sprintf("$%v$%v", parts[1], parts[3])
}
// DecryptData decrypts data produced by EncryptData.
//
// data must have the form "$<type>$<key>$<payload>": <key> is used verbatim
// as the AES key and <payload> is the hex encoding of the GCM nonce followed
// by the sealed data. An error is returned when the format is wrong, the
// payload is not valid hex, the key has an invalid size, the payload is too
// short to contain the nonce, or authentication fails.
func DecryptData(data string) (string, error) {
	vals := strings.Split(data, "$")
	if len(vals) != 4 {
		return "", errors.New("data to decrypt is not in the correct format")
	}
	key := vals[2]
	encrypted, err := hex.DecodeString(vals[3])
	if err != nil {
		return "", err
	}
	block, err := aes.NewCipher([]byte(key))
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	nonceSize := gcm.NonceSize()
	// The payload comes from outside: without this check a truncated value
	// makes the slicing below panic instead of reporting an error.
	if len(encrypted) < nonceSize {
		return "", errors.New("data to decrypt is too short")
	}
	nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:]
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}
// GenerateRSAKeys generates a 4096 bit RSA key pair. The private key is
// written PEM encoded ("RSA PRIVATE KEY", PKCS #1) to file and the public
// key, in authorized_keys format, to file with the ".pub" suffix appended.
// Both files are created with mode 0600.
//
// The error reported when closing the private key file is returned too, so a
// failed flush of the key material is not silently ignored.
func GenerateRSAKeys(file string) error {
	key, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return err
	}

	o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	priv := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	if err := pem.Encode(o, priv); err != nil {
		// The encode error is the one worth reporting; closing is best effort.
		_ = o.Close()
		return err
	}
	if err := o.Close(); err != nil {
		return err
	}

	pub, err := ssh.NewPublicKey(&key.PublicKey)
	if err != nil {
		return err
	}
	return ioutil.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
}
// GenerateECDSAKeys generates an ECDSA key pair on the P-256 curve. The
// private key is written PEM encoded ("EC PRIVATE KEY") to file and the
// public key, in authorized_keys format, to file with the ".pub" suffix
// appended. Both files are created with mode 0600.
//
// The key is marshalled before the target file is opened, so a marshalling
// failure does not truncate an existing file, and the error reported when
// closing the private key file is returned instead of being ignored.
func GenerateECDSAKeys(file string) error {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return err
	}
	keyBytes, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return err
	}

	o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	priv := &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: keyBytes,
	}
	if err := pem.Encode(o, priv); err != nil {
		// The encode error is the one worth reporting; closing is best effort.
		_ = o.Close()
		return err
	}
	if err := o.Close(); err != nil {
		return err
	}

	pub, err := ssh.NewPublicKey(&key.PublicKey)
	if err != nil {
		return err
	}
	return ioutil.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
}
// GetDirsForSFTPPath returns the cleaned path followed by each of its parent
// directories, ending with the root. For example "/1/2/3/4" yields:
// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ]
func GetDirsForSFTPPath(p string) []string {
	current := CleanSFTPPath(p)
	dirs := []string{current}
	// Walk up one level at a time until the root is reached.
	for current != "/" {
		current = path.Dir(current)
		dirs = append(dirs, current)
	}
	return dirs
}
// CleanSFTPPath converts p to a slash separated, absolute and cleaned path.
// A leading "/" is added when p does not already start with one.
func CleanSFTPPath(p string) string {
	slashed := filepath.ToSlash(p)
	if path.IsAbs(p) {
		return path.Clean(slashed)
	}
	return path.Clean("/" + slashed)
}
// LoadTemplate wraps a call to a function returning (*Template, error).
// It behaves like template.Must, except that the error is written to the
// console and to the log before panicking.
func LoadTemplate(t *template.Template, err error) *template.Template {
	if err == nil {
		return t
	}
	logger.ErrorToConsole("error loading required template: %v", err)
	logger.Error(logSender, "", "error loading required template: %v", err)
	panic(err)
}
// IsFileInputValid reports whether fileInput is acceptable as a file name:
// it is rejected when, once cleaned, it is "." or "..".
// This method must be used before joining a file name, generally provided as
// user input, with a directory
func IsFileInputValid(fileInput string) bool {
	switch filepath.Clean(fileInput) {
	case ".", "..":
		return false
	default:
		return true
	}
}
// CleanDirInput sanitizes a directory path provided by the user and returns
// it cleaned.
// On Windows every trailing `"` is removed first, to help users that pass an
// invalid path such as "C:\ProgramData\SFTPGO\": the backslash before the
// closing double quote escapes it, so the quote ends up inside the value.
// This only helps if the invalid path is the last argument. In a command like
// sftpgo.exe serve -c "C:\ProgramData\SFTPGO\" -l "sftpgo.log"
// the -l flag is ignored and -c gets the value
// `C:\ProgramData\SFTPGO" -l sftpgo.log`, which cannot be repaired here.
func CleanDirInput(dirInput string) string {
	cleaned := dirInput
	if runtime.GOOS == "windows" {
		cleaned = strings.TrimRight(cleaned, "\"")
	}
	return filepath.Clean(cleaned)
}