Moves files

This commit is contained in:
Alexander Neumann
2017-07-23 14:19:13 +02:00
parent d1bd160b0a
commit 83d1a46526
284 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1,72 @@
package sftp
import (
"net/url"
"path"
"strings"
"restic/errors"
"restic/options"
)
// Config collects all information required to connect to an sftp server.
type Config struct {
User, Host, Path string
Layout string `option:"layout" help:"use this backend directory layout (default: auto-detect)"`
Command string `option:"command" help:"specify command to create sftp connection"`
}
func init() {
options.Register("sftp", Config{})
}
// ParseConfig parses the string s and extracts the sftp config. The
// supported configuration formats are sftp://user@host/directory
// and sftp:user@host:directory. The directory will be path Cleaned and can
// be an absolute path if it starts with a '/' (e.g.
// sftp://user@host//absolute and sftp:user@host:/absolute).
func ParseConfig(s string) (interface{}, error) {
var user, host, dir string
switch {
case strings.HasPrefix(s, "sftp://"):
// parse the "sftp://user@host/path" url format
url, err := url.Parse(s)
if err != nil {
return nil, errors.Wrap(err, "url.Parse")
}
if url.User != nil {
user = url.User.Username()
}
host = url.Host
dir = url.Path
if dir == "" {
return nil, errors.Errorf("invalid backend %q, no directory specified", s)
}
dir = dir[1:]
case strings.HasPrefix(s, "sftp:"):
// parse the sftp:user@host:path format, which means we'll get
// "user@host:path" in s
s = s[5:]
// split user@host and path at the colon
data := strings.SplitN(s, ":", 2)
if len(data) < 2 {
return nil, errors.New("sftp: invalid format, hostname or path not found")
}
host = data[0]
dir = data[1]
// split user and host at the "@"
data = strings.SplitN(host, "@", 2)
if len(data) == 2 {
user = data[0]
host = data[1]
}
default:
return nil, errors.New(`invalid format, does not start with "sftp:"`)
}
return Config{
User: user,
Host: host,
Path: path.Clean(dir),
}, nil
}

View File

@@ -0,0 +1,90 @@
package sftp
import "testing"
var configTests = []struct {
in string
cfg Config
}{
// first form, user specified sftp://user@host/dir
{
"sftp://user@host/dir/subdir",
Config{User: "user", Host: "host", Path: "dir/subdir"},
},
{
"sftp://host/dir/subdir",
Config{Host: "host", Path: "dir/subdir"},
},
{
"sftp://host//dir/subdir",
Config{Host: "host", Path: "/dir/subdir"},
},
{
"sftp://host:10022//dir/subdir",
Config{Host: "host:10022", Path: "/dir/subdir"},
},
{
"sftp://user@host:10022//dir/subdir",
Config{User: "user", Host: "host:10022", Path: "/dir/subdir"},
},
{
"sftp://user@host/dir/subdir/../other",
Config{User: "user", Host: "host", Path: "dir/other"},
},
{
"sftp://user@host/dir///subdir",
Config{User: "user", Host: "host", Path: "dir/subdir"},
},
// second form, user specified sftp:user@host:/dir
{
"sftp:user@host:/dir/subdir",
Config{User: "user", Host: "host", Path: "/dir/subdir"},
},
{
"sftp:host:../dir/subdir",
Config{Host: "host", Path: "../dir/subdir"},
},
{
"sftp:user@host:dir/subdir:suffix",
Config{User: "user", Host: "host", Path: "dir/subdir:suffix"},
},
{
"sftp:user@host:dir/subdir/../other",
Config{User: "user", Host: "host", Path: "dir/other"},
},
{
"sftp:user@host:dir///subdir",
Config{User: "user", Host: "host", Path: "dir/subdir"},
},
}
func TestParseConfig(t *testing.T) {
for i, test := range configTests {
cfg, err := ParseConfig(test.in)
if err != nil {
t.Errorf("test %d:%s failed: %v", i, test.in, err)
continue
}
if cfg != test.cfg {
t.Errorf("test %d:\ninput:\n %s\n wrong config, want:\n %v\ngot:\n %v",
i, test.in, test.cfg, cfg)
continue
}
}
}
var configTestsInvalid = []string{
"sftp://host:dir",
}
func TestParseConfigInvalid(t *testing.T) {
for i, test := range configTestsInvalid {
_, err := ParseConfig(test)
if err == nil {
t.Errorf("test %d: invalid config %s did not return an error", i, test)
continue
}
}
}

View File

@@ -0,0 +1,3 @@
// Package sftp implements repository storage in a directory on a remote server
// via the sftp protocol.
package sftp

View File

@@ -0,0 +1,87 @@
package sftp_test
import (
"context"
"fmt"
"path/filepath"
"restic"
"restic/backend/sftp"
. "restic/test"
"testing"
)
func TestLayout(t *testing.T) {
if sftpServer == "" {
t.Skip("sftp server binary not available")
}
path, cleanup := TempDir(t)
defer cleanup()
var tests = []struct {
filename string
layout string
failureExpected bool
datafiles map[string]bool
}{
{"repo-layout-default.tar.gz", "", false, map[string]bool{
"aa464e9fd598fe4202492ee317ffa728e82fa83a1de1a61996e5bd2d6651646c": false,
"fc919a3b421850f6fa66ad22ebcf91e433e79ffef25becf8aef7c7b1eca91683": false,
"c089d62788da14f8b7cbf77188305c0874906f0b73d3fce5a8869050e8d0c0e1": false,
}},
{"repo-layout-s3legacy.tar.gz", "", false, map[string]bool{
"fc919a3b421850f6fa66ad22ebcf91e433e79ffef25becf8aef7c7b1eca91683": false,
"c089d62788da14f8b7cbf77188305c0874906f0b73d3fce5a8869050e8d0c0e1": false,
"aa464e9fd598fe4202492ee317ffa728e82fa83a1de1a61996e5bd2d6651646c": false,
}},
}
for _, test := range tests {
t.Run(test.filename, func(t *testing.T) {
SetupTarTestFixture(t, path, filepath.Join("..", "testdata", test.filename))
repo := filepath.Join(path, "repo")
be, err := sftp.Open(sftp.Config{
Command: fmt.Sprintf("%q -e", sftpServer),
Path: repo,
Layout: test.layout,
})
if err != nil {
t.Fatal(err)
}
if be == nil {
t.Fatalf("Open() returned nil but no error")
}
datafiles := make(map[string]bool)
for id := range be.List(context.TODO(), restic.DataFile) {
datafiles[id] = false
}
if len(datafiles) == 0 {
t.Errorf("List() returned zero data files")
}
for id := range test.datafiles {
if _, ok := datafiles[id]; !ok {
t.Errorf("datafile with id %v not found", id)
}
datafiles[id] = true
}
for id, v := range datafiles {
if !v {
t.Errorf("unexpected id %v found", id)
}
}
if err = be.Close(); err != nil {
t.Errorf("Close() returned error %v", err)
}
RemoveAll(t, filepath.Join(path, "repo"))
})
}
}

View File

@@ -0,0 +1,497 @@
package sftp
import (
"bufio"
"context"
"fmt"
"io"
"os"
"os/exec"
"path"
"restic"
"strings"
"time"
"restic/errors"
"restic/backend"
"restic/debug"
"github.com/pkg/sftp"
)
// SFTP is a backend in a directory accessed via SFTP.
type SFTP struct {
c *sftp.Client
p string
cmd *exec.Cmd
result <-chan error
backend.Layout
Config
}
var _ restic.Backend = &SFTP{}
const defaultLayout = "default"
func startClient(program string, args ...string) (*SFTP, error) {
debug.Log("start client %v %v", program, args)
// Connect to a remote host and request the sftp subsystem via the 'ssh'
// command. This assumes that passwordless login is correctly configured.
cmd := exec.Command(program, args...)
// prefix the errors with the program name
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, errors.Wrap(err, "cmd.StderrPipe")
}
go func() {
sc := bufio.NewScanner(stderr)
for sc.Scan() {
fmt.Fprintf(os.Stderr, "subprocess %v: %v\n", program, sc.Text())
}
}()
// ignore signals sent to the parent (e.g. SIGINT)
cmd.SysProcAttr = ignoreSigIntProcAttr()
// get stdin and stdout
wr, err := cmd.StdinPipe()
if err != nil {
return nil, errors.Wrap(err, "cmd.StdinPipe")
}
rd, err := cmd.StdoutPipe()
if err != nil {
return nil, errors.Wrap(err, "cmd.StdoutPipe")
}
// start the process
if err := cmd.Start(); err != nil {
return nil, errors.Wrap(err, "cmd.Start")
}
// wait in a different goroutine
ch := make(chan error, 1)
go func() {
err := cmd.Wait()
debug.Log("ssh command exited, err %v", err)
ch <- errors.Wrap(err, "cmd.Wait")
}()
// open the SFTP session
client, err := sftp.NewClientPipe(rd, wr)
if err != nil {
return nil, errors.Errorf("unable to start the sftp session, error: %v", err)
}
return &SFTP{c: client, cmd: cmd, result: ch}, nil
}
// clientError returns an error if the client has exited. Otherwise, nil is
// returned immediately.
func (r *SFTP) clientError() error {
select {
case err := <-r.result:
debug.Log("client has exited with err %v", err)
return err
default:
}
return nil
}
// Open opens an sftp backend as described by the config by running
// "ssh" with the appropriate arguments (or cfg.Command, if set).
func Open(cfg Config) (*SFTP, error) {
debug.Log("open backend with config %#v", cfg)
cmd, args, err := buildSSHCommand(cfg)
if err != nil {
return nil, err
}
sftp, err := startClient(cmd, args...)
if err != nil {
debug.Log("unable to start program: %v", err)
return nil, err
}
sftp.Layout, err = backend.ParseLayout(sftp, cfg.Layout, defaultLayout, cfg.Path)
if err != nil {
return nil, err
}
debug.Log("layout: %v\n", sftp.Layout)
if err := sftp.checkDataSubdirs(); err != nil {
debug.Log("checkDataSubdirs returned %v", err)
return nil, err
}
sftp.Config = cfg
sftp.p = cfg.Path
return sftp, nil
}
func (r *SFTP) checkDataSubdirs() error {
datadir := r.Dirname(restic.Handle{Type: restic.DataFile})
// check if all paths for data/ exist
entries, err := r.c.ReadDir(datadir)
if err != nil {
return err
}
subdirs := make(map[string]struct{}, len(entries))
for _, entry := range entries {
subdirs[entry.Name()] = struct{}{}
}
for i := 0; i < 256; i++ {
subdir := fmt.Sprintf("%02x", i)
if _, ok := subdirs[subdir]; !ok {
debug.Log("subdir %v is missing, creating", subdir)
err := r.mkdirAll(path.Join(datadir, subdir), backend.Modes.Dir)
if err != nil {
return err
}
}
}
return nil
}
func (r *SFTP) mkdirAllDataSubdirs() error {
for _, d := range r.Paths() {
err := r.mkdirAll(d, backend.Modes.Dir)
debug.Log("mkdirAll %v -> %v", d, err)
if err != nil {
return err
}
}
return nil
}
// Join combines path components with slashes (according to the sftp spec).
func (r *SFTP) Join(p ...string) string {
return path.Join(p...)
}
// ReadDir returns the entries for a directory.
func (r *SFTP) ReadDir(dir string) ([]os.FileInfo, error) {
return r.c.ReadDir(dir)
}
// IsNotExist returns true if the error is caused by a not existing file.
func (r *SFTP) IsNotExist(err error) bool {
if os.IsNotExist(err) {
return true
}
statusError, ok := err.(*sftp.StatusError)
if !ok {
return false
}
return statusError.Error() == `sftp: "No such file" (SSH_FX_NO_SUCH_FILE)`
}
func buildSSHCommand(cfg Config) (cmd string, args []string, err error) {
if cfg.Command != "" {
return SplitShellArgs(cfg.Command)
}
cmd = "ssh"
hostport := strings.Split(cfg.Host, ":")
args = []string{hostport[0]}
if len(hostport) > 1 {
args = append(args, "-p", hostport[1])
}
if cfg.User != "" {
args = append(args, "-l")
args = append(args, cfg.User)
}
args = append(args, "-s")
args = append(args, "sftp")
return cmd, args, nil
}
// Create creates an sftp backend as described by the config by running
// "ssh" with the appropriate arguments (or cfg.Command, if set).
func Create(cfg Config) (*SFTP, error) {
cmd, args, err := buildSSHCommand(cfg)
if err != nil {
return nil, err
}
sftp, err := startClient(cmd, args...)
if err != nil {
debug.Log("unable to start program: %v", err)
return nil, err
}
sftp.Layout, err = backend.ParseLayout(sftp, cfg.Layout, defaultLayout, cfg.Path)
if err != nil {
return nil, err
}
// test if config file already exists
_, err = sftp.c.Lstat(Join(cfg.Path, backend.Paths.Config))
if err == nil {
return nil, errors.New("config file already exists")
}
// create paths for data and refs
if err = sftp.mkdirAllDataSubdirs(); err != nil {
return nil, err
}
err = sftp.Close()
if err != nil {
return nil, errors.Wrap(err, "Close")
}
// open backend
return Open(cfg)
}
// Location returns this backend's location (the directory name).
func (r *SFTP) Location() string {
return r.p
}
func (r *SFTP) mkdirAll(dir string, mode os.FileMode) error {
// check if directory already exists
fi, err := r.c.Lstat(dir)
if err == nil {
if fi.IsDir() {
return nil
}
return errors.Errorf("mkdirAll(%s): entry exists but is not a directory", dir)
}
// create parent directories
errMkdirAll := r.mkdirAll(path.Dir(dir), backend.Modes.Dir)
// create directory
errMkdir := r.c.Mkdir(dir)
// test if directory was created successfully
fi, err = r.c.Lstat(dir)
if err != nil {
// return previous errors
return errors.Errorf("mkdirAll(%s): unable to create directories: %v, %v", dir, errMkdirAll, errMkdir)
}
if !fi.IsDir() {
return errors.Errorf("mkdirAll(%s): entry exists but is not a directory", dir)
}
// set mode
return r.c.Chmod(dir, mode)
}
// Join joins the given paths and cleans them afterwards. This always uses
// forward slashes, which is required by sftp.
func Join(parts ...string) string {
return path.Clean(path.Join(parts...))
}
// Save stores data in the backend at the handle.
func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
debug.Log("Save %v", h)
if err := r.clientError(); err != nil {
return err
}
if err := h.Valid(); err != nil {
return err
}
filename := r.Filename(h)
// create new file
f, err := r.c.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY)
if r.IsNotExist(errors.Cause(err)) {
// create the locks dir, then try again
err = r.mkdirAll(r.Dirname(h), backend.Modes.Dir)
if err != nil {
return errors.Wrap(err, "MkdirAll")
}
return r.Save(ctx, h, rd)
}
if err != nil {
return errors.Wrap(err, "OpenFile")
}
// save data
_, err = io.Copy(f, rd)
if err != nil {
_ = f.Close()
return errors.Wrap(err, "Write")
}
err = f.Close()
if err != nil {
return errors.Wrap(err, "Close")
}
// set mode to read-only
fi, err := r.c.Lstat(filename)
if err != nil {
return errors.Wrap(err, "Lstat")
}
err = r.c.Chmod(filename, fi.Mode()&os.FileMode(^uint32(0222)))
return errors.Wrap(err, "Chmod")
}
// Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use.
func (r *SFTP) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v", h, length, offset)
if err := h.Valid(); err != nil {
return nil, err
}
if offset < 0 {
return nil, errors.New("offset is negative")
}
f, err := r.c.Open(r.Filename(h))
if err != nil {
return nil, err
}
if offset > 0 {
_, err = f.Seek(offset, 0)
if err != nil {
_ = f.Close()
return nil, err
}
}
if length > 0 {
return backend.LimitReadCloser(f, int64(length)), nil
}
return f, nil
}
// Stat returns information about a blob.
func (r *SFTP) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
debug.Log("Stat(%v)", h)
if err := r.clientError(); err != nil {
return restic.FileInfo{}, err
}
if err := h.Valid(); err != nil {
return restic.FileInfo{}, err
}
fi, err := r.c.Lstat(r.Filename(h))
if err != nil {
return restic.FileInfo{}, errors.Wrap(err, "Lstat")
}
return restic.FileInfo{Size: fi.Size()}, nil
}
// Test returns true if a blob of the given type and name exists in the backend.
func (r *SFTP) Test(ctx context.Context, h restic.Handle) (bool, error) {
debug.Log("Test(%v)", h)
if err := r.clientError(); err != nil {
return false, err
}
_, err := r.c.Lstat(r.Filename(h))
if os.IsNotExist(errors.Cause(err)) {
return false, nil
}
if err != nil {
return false, errors.Wrap(err, "Lstat")
}
return true, nil
}
// Remove removes the content stored at name.
func (r *SFTP) Remove(ctx context.Context, h restic.Handle) error {
debug.Log("Remove(%v)", h)
if err := r.clientError(); err != nil {
return err
}
return r.c.Remove(r.Filename(h))
}
// List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending
// stops.
func (r *SFTP) List(ctx context.Context, t restic.FileType) <-chan string {
debug.Log("List %v", t)
ch := make(chan string)
go func() {
defer close(ch)
walker := r.c.Walk(r.Basedir(t))
for walker.Step() {
if walker.Err() != nil {
continue
}
if !walker.Stat().Mode().IsRegular() {
continue
}
select {
case ch <- path.Base(walker.Path()):
case <-ctx.Done():
return
}
}
}()
return ch
}
var closeTimeout = 2 * time.Second
// Close closes the sftp connection and terminates the underlying command.
func (r *SFTP) Close() error {
debug.Log("")
if r == nil {
return nil
}
err := r.c.Close()
debug.Log("Close returned error %v", err)
// wait for closeTimeout before killing the process
select {
case err := <-r.result:
return err
case <-time.After(closeTimeout):
}
if err := r.cmd.Process.Kill(); err != nil {
return err
}
// get the error, but ignore it
<-r.result
return nil
}

View File

@@ -0,0 +1,95 @@
package sftp_test
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"restic"
"restic/backend/sftp"
"restic/backend/test"
"restic/errors"
"strings"
"testing"
. "restic/test"
)
func findSFTPServerBinary() string {
for _, dir := range strings.Split(TestSFTPPath, ":") {
testpath := filepath.Join(dir, "sftp-server")
_, err := os.Stat(testpath)
if !os.IsNotExist(errors.Cause(err)) {
return testpath
}
}
return ""
}
var sftpServer = findSFTPServerBinary()
func newTestSuite(t testing.TB) *test.Suite {
return &test.Suite{
// NewConfig returns a config for a new temporary backend that will be used in tests.
NewConfig: func() (interface{}, error) {
dir, err := ioutil.TempDir(TestTempDir, "restic-test-sftp-")
if err != nil {
t.Fatal(err)
}
t.Logf("create new backend at %v", dir)
cfg := sftp.Config{
Path: dir,
Command: fmt.Sprintf("%q -e", sftpServer),
}
return cfg, nil
},
// CreateFn is a function that creates a temporary repository for the tests.
Create: func(config interface{}) (restic.Backend, error) {
cfg := config.(sftp.Config)
return sftp.Create(cfg)
},
// OpenFn is a function that opens a previously created temporary repository.
Open: func(config interface{}) (restic.Backend, error) {
cfg := config.(sftp.Config)
return sftp.Open(cfg)
},
// CleanupFn removes data created during the tests.
Cleanup: func(config interface{}) error {
cfg := config.(sftp.Config)
if !TestCleanupTempDirs {
t.Logf("leaving test backend dir at %v", cfg.Path)
}
RemoveAll(t, cfg.Path)
return nil
},
}
}
func TestBackendSFTP(t *testing.T) {
defer func() {
if t.Skipped() {
SkipDisallowed(t, "restic/backend/sftp.TestBackendSFTP")
}
}()
if sftpServer == "" {
t.Skip("sftp server binary not found")
}
newTestSuite(t).RunTests(t)
}
func BenchmarkBackendSFTP(t *testing.B) {
if sftpServer == "" {
t.Skip("sftp server binary not found")
}
newTestSuite(t).RunBenchmarks(t)
}

View File

@@ -0,0 +1,13 @@
// +build !windows
package sftp
import (
"syscall"
)
// ignoreSigIntProcAttr returns a syscall.SysProcAttr that
// disables SIGINT on parent.
func ignoreSigIntProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{Setsid: true}
}

View File

@@ -0,0 +1,11 @@
package sftp
import (
"syscall"
)
// ignoreSigIntProcAttr returns a default syscall.SysProcAttr
// on Windows.
func ignoreSigIntProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}

View File

@@ -0,0 +1,77 @@
package sftp
import (
"restic/errors"
"unicode"
)
// shellSplitter splits a command string into separater arguments. It supports
// single and double quoted strings.
type shellSplitter struct {
quote rune
lastChar rune
}
func (s *shellSplitter) isSplitChar(c rune) bool {
// only test for quotes if the last char was not a backslash
if s.lastChar != '\\' {
// quote ended
if s.quote != 0 && c == s.quote {
s.quote = 0
return true
}
// quote starts
if s.quote == 0 && (c == '"' || c == '\'') {
s.quote = c
return true
}
}
s.lastChar = c
// within quote
if s.quote != 0 {
return false
}
// outside quote
return c == '\\' || unicode.IsSpace(c)
}
// SplitShellArgs returns the list of arguments from a shell command string.
func SplitShellArgs(data string) (cmd string, args []string, err error) {
s := &shellSplitter{}
// derived from strings.SplitFunc
fieldStart := -1 // Set to -1 when looking for start of field.
for i, rune := range data {
if s.isSplitChar(rune) {
if fieldStart >= 0 {
args = append(args, data[fieldStart:i])
fieldStart = -1
}
} else if fieldStart == -1 {
fieldStart = i
}
}
if fieldStart >= 0 { // Last field might end at EOF.
args = append(args, data[fieldStart:])
}
switch s.quote {
case '\'':
return "", nil, errors.New("single-quoted string not terminated")
case '"':
return "", nil, errors.New("double-quoted string not terminated")
}
if len(args) == 0 {
return "", nil, errors.New("command string is empty")
}
cmd, args = args[0], args[1:]
return cmd, args, nil
}

View File

@@ -0,0 +1,115 @@
package sftp
import (
"reflect"
"testing"
)
func TestShellSplitter(t *testing.T) {
var tests = []struct {
data string
cmd string
args []string
}{
{
`foo`,
"foo", []string{},
},
{
`'foo'`,
"foo", []string{},
},
{
`foo bar baz`,
"foo", []string{"bar", "baz"},
},
{
`foo 'bar' baz`,
"foo", []string{"bar", "baz"},
},
{
`'bar box' baz`,
"bar box", []string{"baz"},
},
{
`"bar 'box'" baz`,
"bar 'box'", []string{"baz"},
},
{
`'bar "box"' baz`,
`bar "box"`, []string{"baz"},
},
{
`\"bar box baz`,
`"bar`, []string{"box", "baz"},
},
{
`"bar/foo/x" "box baz"`,
"bar/foo/x", []string{"box baz"},
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
cmd, args, err := SplitShellArgs(test.data)
if err != nil {
t.Fatal(err)
}
if cmd != test.cmd {
t.Fatalf("wrong cmd returned, want:\n %#v\ngot:\n %#v",
test.cmd, cmd)
}
if !reflect.DeepEqual(args, test.args) {
t.Fatalf("wrong args returned, want:\n %#v\ngot:\n %#v",
test.args, args)
}
})
}
}
func TestShellSplitterInvalid(t *testing.T) {
var tests = []struct {
data string
err string
}{
{
"foo'",
"single-quoted string not terminated",
},
{
`foo"`,
"double-quoted string not terminated",
},
{
"foo 'bar",
"single-quoted string not terminated",
},
{
`foo "bar`,
"double-quoted string not terminated",
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
cmd, args, err := SplitShellArgs(test.data)
if err == nil {
t.Fatalf("expected error not found: %v", test.err)
}
if err.Error() != test.err {
t.Fatalf("expected error not found, want:\n %q\ngot:\n %q", test.err, err.Error())
}
if cmd != "" {
t.Fatalf("splitter returned cmd from invalid data: %v", cmd)
}
if len(args) > 0 {
t.Fatalf("splitter returned fields from invalid data: %v", args)
}
})
}
}

View File

@@ -0,0 +1,52 @@
package sftp
import (
"reflect"
"testing"
)
var sshcmdTests = []struct {
cfg Config
cmd string
args []string
}{
{
Config{User: "user", Host: "host", Path: "dir/subdir"},
"ssh",
[]string{"host", "-l", "user", "-s", "sftp"},
},
{
Config{Host: "host", Path: "dir/subdir"},
"ssh",
[]string{"host", "-s", "sftp"},
},
{
Config{Host: "host:10022", Path: "/dir/subdir"},
"ssh",
[]string{"host", "-p", "10022", "-s", "sftp"},
},
{
Config{User: "user", Host: "host:10022", Path: "/dir/subdir"},
"ssh",
[]string{"host", "-p", "10022", "-l", "user", "-s", "sftp"},
},
}
func TestBuildSSHCommand(t *testing.T) {
for _, test := range sshcmdTests {
t.Run("", func(t *testing.T) {
cmd, args, err := buildSSHCommand(test.cfg)
if err != nil {
t.Fatal(err)
}
if cmd != test.cmd {
t.Fatalf("cmd: want %v, got %v", test.cmd, cmd)
}
if !reflect.DeepEqual(test.args, args) {
t.Fatalf("wrong args, want:\n %v\ngot:\n %v", test.args, args)
}
})
}
}