cmd/tsidp: fix tsnet listener

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-11-14 23:53:20 -08:00
parent 93664ac8dc
commit 3fa5c76cba

View File

@ -56,6 +56,8 @@ func main() {
lc *tailscale.LocalClient
st *ipnstate.Status
err error
lns []net.Listener
)
if *flagUseLocalTailscaled {
lc = &tailscale.LocalClient{}
@ -63,6 +65,23 @@ func main() {
if err != nil {
log.Fatalf("getting status: %v", err)
}
portStr := fmt.Sprint(*flagPort)
anySuccess := false
for _, ip := range st.TailscaleIPs {
ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr))
if err != nil {
log.Printf("failed to listen on %v: %v", ip, err)
continue
}
anySuccess = true
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate,
})
lns = append(lns, ln)
}
if !anySuccess {
log.Fatalf("failed to listen on any of %v", st.TailscaleIPs)
}
} else {
ts := &tsnet.Server{
Hostname: "idp",
@ -78,34 +97,38 @@ func main() {
if err != nil {
log.Fatalf("getting local client: %v", err)
}
ln, err := ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort))
if err != nil {
log.Fatal(err)
}
lns = append(lns, ln)
}
srv := &idpServer{
lc: lc,
serverURL: fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort),
lc: lc,
}
if *flagPort != 443 {
srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort)
} else {
srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, "."))
}
log.Printf("Running tsidp at %s ...", srv.serverURL)
if *flagLocalPort != -1 {
log.Printf("Also running tsidp at %s ...", srv.loopbackURL)
srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort)
go func() {
ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort))
if err != nil {
log.Fatal(err)
}
log.Printf("Also running tsidp at %s ...", srv.loopbackURL)
http.Serve(ln, srv)
}()
ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort))
if err != nil {
log.Fatal(err)
}
lns = append(lns, ln)
}
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", st.TailscaleIPs[0], *flagPort))
if err != nil {
log.Fatal(err)
for _, ln := range lns {
go http.Serve(ln, srv)
}
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate,
})
log.Fatal(http.Serve(ln, srv))
select {}
}
type idpServer struct {