tailscale/net/connectproxy/connectproxy.go

94 lines
1.9 KiB
Go
Raw Normal View History

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package connectproxy provides an HTTP CONNECT proxy handler.
package connectproxy
import (
"context"
"io"
"log"
"net"
"net/http"
"time"
"tailscale.com/net/netx"
"tailscale.com/types/logger"
)
// Handler is an HTTP CONNECT proxy handler.
//
// The zero value is usable: it dials targets with a zero-value net.Dialer,
// logs with log.Printf, and performs no target validation.
type Handler struct {
	// Dial, if non-nil, is an alternate dialer to use
	// instead of the default dialer (a zero-value net.Dialer).
	// ServeHTTP calls it with network "tcp", the CONNECT target, and a
	// context bounded by a 15-second timeout.
	Dial netx.DialFunc

	// Logf, if non-nil, is an alternate logger to
	// use instead of log.Printf.
	Logf logger.Logf

	// Check, if non-nil, validates the CONNECT target before it is dialed.
	// It receives the request's RequestURI, which for an HTTP/1.1 CONNECT
	// request is expected to be "host:port". A non-nil error rejects the
	// request with 403 Forbidden.
	Check func(hostPort string) error
}
// ServeHTTP implements http.Handler for HTTP CONNECT requests.
//
// It optionally validates the target with h.Check, confirms the connection
// can be hijacked, dials the target, hijacks the client connection, replies
// "200 OK", and then relays bytes in both directions until either direction
// stops, at which point both connections are closed.
//
// Error responses:
//   - 405 for any method other than CONNECT
//   - 403 if h.Check rejects the target
//   - 500 if the ResponseWriter does not support hijacking
//   - 502 if dialing the target fails
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodConnect {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	dial := h.Dial
	if dial == nil {
		var d net.Dialer
		dial = d.DialContext
	}
	logf := h.Logf
	if logf == nil {
		logf = log.Printf
	}

	// RequestURI is the unmodified request-target; for an HTTP/1.1 CONNECT
	// request that is the authority form, "host:port".
	hostPort := r.RequestURI
	if h.Check != nil {
		if err := h.Check(hostPort); err != nil {
			logf("CONNECT target %q not allowed: %v", hostPort, err)
			http.Error(w, "Invalid CONNECT target", http.StatusForbidden)
			return
		}
	}

	// Confirm we can take over the client connection before dialing, so we
	// never open a backend connection that we have no way to use.
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError)
		return
	}

	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
	defer cancel()
	back, err := dial(ctx, "tcp", hostPort)
	if err != nil {
		logf("error CONNECT dialing %v: %v", hostPort, err)
		http.Error(w, "Connect failure", http.StatusBadGateway)
		return
	}
	defer back.Close()

	c, br, err := hj.Hijack()
	if err != nil {
		logf("CONNECT hijack: %v", err)
		return
	}
	defer c.Close()

	if _, err := io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n"); err != nil {
		logf("CONNECT writing response: %v", err)
		return
	}

	// Relay in both directions. The client->backend copy reads from br rather
	// than c because Hijack's buffered reader may already hold bytes the
	// client sent right after the request headers. When either copy finishes
	// we return; the deferred Closes unblock the other copy, whose send then
	// lands in the two-slot buffered channel so that goroutine exits too.
	errc := make(chan error, 2)
	go func() {
		_, err := io.Copy(c, back)
		errc <- err
	}()
	go func() {
		_, err := io.Copy(back, br)
		errc <- err
	}()
	<-errc
}