mirror of
https://github.com/restic/restic.git
synced 2025-08-13 19:56:42 +00:00
Read TLS client cert and key from the same file
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func newAzureTestSuite(t testing.TB) *test.Suite {
|
||||
tr, err := backend.Transport(nil, "", "")
|
||||
tr, err := backend.Transport(backend.TransportOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create transport for tests: %v", err)
|
||||
}
|
||||
|
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func newB2TestSuite(t testing.TB) *test.Suite {
|
||||
tr, err := backend.Transport(nil, "", "")
|
||||
tr, err := backend.Transport(backend.TransportOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create transport for tests: %v", err)
|
||||
}
|
||||
|
@@ -3,19 +3,66 @@ package backend
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
"github.com/restic/restic/internal/errors"
|
||||
)
|
||||
|
||||
// TransportOptions collects various options which can be set for an HTTP based
|
||||
// transport.
|
||||
type TransportOptions struct {
|
||||
// contains filenames of PEM encoded root certificates to trust
|
||||
RootCertFilenames []string
|
||||
|
||||
// contains the name of a file containing the TLS client certificate and private key in PEM format
|
||||
TLSClientCertKeyFilename string
|
||||
}
|
||||
|
||||
// readPEMCertKey reads a file and returns the PEM encoded certificate and key
|
||||
// blocks.
|
||||
func readPEMCertKey(filename string) (certs []byte, key []byte, err error) {
|
||||
data, err := ioutil.ReadFile(os.Args[1])
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "ReadFile")
|
||||
}
|
||||
|
||||
var block *pem.Block
|
||||
for {
|
||||
if len(data) == 0 {
|
||||
break
|
||||
}
|
||||
block, data = pem.Decode(data)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(block.Type, "CERTIFICATE"):
|
||||
certs = append(certs, pem.EncodeToMemory(block)...)
|
||||
case strings.HasSuffix(block.Type, "PRIVATE KEY"):
|
||||
if key != nil {
|
||||
return nil, nil, errors.Errorf("error loading TLS cert and key from %v: more than one private key found", filename)
|
||||
}
|
||||
key = pem.EncodeToMemory(block)
|
||||
default:
|
||||
return nil, nil, errors.Errorf("error loading TLS cert and key from %v: unknown block type %v found", filename, block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return certs, key, nil
|
||||
}
|
||||
|
||||
// Transport returns a new http.RoundTripper with default settings applied. If
|
||||
// a custom rootCertFilename is non-empty, it must point to a valid PEM file,
|
||||
// otherwise the function will return an error.
|
||||
func Transport(rootCertFilenames []string, tlsClientCert string, tlsClientKey string) (http.RoundTripper, error) {
|
||||
func Transport(opts TransportOptions) (http.RoundTripper, error) {
|
||||
// copied from net/http
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -32,34 +79,36 @@ func Transport(rootCertFilenames []string, tlsClientCert string, tlsClientKey st
|
||||
TLSClientConfig: &tls.Config{},
|
||||
}
|
||||
|
||||
if tlsClientCert != "" && tlsClientKey != "" {
|
||||
c, err := tls.LoadX509KeyPair(tlsClientCert, tlsClientKey)
|
||||
if opts.TLSClientCertKeyFilename != "" {
|
||||
certs, key, err := readPEMCertKey(opts.TLSClientCertKeyFilename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read client certificate/key pair: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
tr.TLSClientConfig.Certificates = []tls.Certificate{c}
|
||||
}
|
||||
|
||||
if rootCertFilenames == nil {
|
||||
return debug.RoundTripper(tr), nil
|
||||
}
|
||||
|
||||
p := x509.NewCertPool()
|
||||
for _, filename := range rootCertFilenames {
|
||||
if filename == "" {
|
||||
return nil, fmt.Errorf("empty filename for root certificate supplied")
|
||||
}
|
||||
b, err := ioutil.ReadFile(filename)
|
||||
crt, err := tls.X509KeyPair(certs, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read root certificate: %v", err)
|
||||
}
|
||||
if ok := p.AppendCertsFromPEM(b); !ok {
|
||||
return nil, fmt.Errorf("cannot parse root certificate from %q", filename)
|
||||
return nil, errors.Errorf("parse TLS client cert or key: %v", err)
|
||||
}
|
||||
tr.TLSClientConfig.Certificates = []tls.Certificate{crt}
|
||||
}
|
||||
|
||||
tr.TLSClientConfig.RootCAs = p
|
||||
if opts.RootCertFilenames != nil {
|
||||
pool := x509.NewCertPool()
|
||||
for _, filename := range opts.RootCertFilenames {
|
||||
if filename == "" {
|
||||
return nil, errors.Errorf("empty filename for root certificate supplied")
|
||||
}
|
||||
b, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("unable to read root certificate: %v", err)
|
||||
}
|
||||
if ok := pool.AppendCertsFromPEM(b); !ok {
|
||||
return nil, errors.Errorf("cannot parse root certificate from %q", filename)
|
||||
}
|
||||
}
|
||||
tr.TLSClientConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
// wrap in the debug round tripper
|
||||
// wrap in the debug round tripper (if active)
|
||||
return debug.RoundTripper(tr), nil
|
||||
}
|
||||
|
@@ -68,7 +68,7 @@ func runRESTServer(ctx context.Context, t testing.TB, dir string) (*url.URL, fun
|
||||
}
|
||||
|
||||
func newTestSuite(ctx context.Context, t testing.TB, url *url.URL, minimalData bool) *test.Suite {
|
||||
tr, err := backend.Transport(nil, "", "")
|
||||
tr, err := backend.Transport(backend.TransportOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create transport for tests: %v", err)
|
||||
}
|
||||
|
@@ -121,7 +121,7 @@ func createS3(t testing.TB, cfg MinioTestConfig, tr http.RoundTripper) (be resti
|
||||
}
|
||||
|
||||
func newMinioTestSuite(ctx context.Context, t testing.TB) *test.Suite {
|
||||
tr, err := backend.Transport(nil, "", "")
|
||||
tr, err := backend.Transport(backend.TransportOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create transport for tests: %v", err)
|
||||
}
|
||||
@@ -221,7 +221,7 @@ func BenchmarkBackendMinio(t *testing.B) {
|
||||
}
|
||||
|
||||
func newS3TestSuite(t testing.TB) *test.Suite {
|
||||
tr, err := backend.Transport(nil, "", "")
|
||||
tr, err := backend.Transport(backend.TransportOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create transport for tests: %v", err)
|
||||
}
|
||||
|
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func newSwiftTestSuite(t testing.TB) *test.Suite {
|
||||
tr, err := backend.Transport(nil, "", "")
|
||||
tr, err := backend.Transport(backend.TransportOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create transport for tests: %v", err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user