client/web: move auth session creation out of /api/auth

Splits auth session creation into two new endpoints:

/api/auth/session/new - to request a new auth session

/api/auth/session/wait - to block until user has completed auth url

Updates tailscale/corp#14335

Signed-off-by: Sonia Appasamy <sonia@tailscale.com>
This commit is contained in:
Sonia Appasamy 2023-11-03 13:38:01 -04:00 committed by Sonia Appasamy
parent 658971d7c0
commit e5dcf7bdde
5 changed files with 176 additions and 97 deletions

View File

@ -3,19 +3,19 @@ import React from "react"
import LegacyClientView from "src/components/views/legacy-client-view"
import LoginClientView from "src/components/views/login-client-view"
import ReadonlyClientView from "src/components/views/readonly-client-view"
import useAuth, { AuthResponse } from "src/hooks/auth"
import useAuth, { AuthResponse, SessionsCallbacks } from "src/hooks/auth"
import useNodeData from "src/hooks/node-data"
import ManagementClientView from "./views/management-client-view"
export default function App() {
const { data: auth, loading: loadingAuth, waitOnAuth } = useAuth()
const { data: auth, loading: loadingAuth, sessions } = useAuth()
return (
<div className="flex flex-col items-center min-w-sm max-w-lg mx-auto py-14">
{loadingAuth ? (
<div className="text-center py-14">Loading...</div> // TODO(sonia): add a loading view
) : (
<WebClient auth={auth} waitOnAuth={waitOnAuth} />
<WebClient auth={auth} sessions={sessions} />
)}
</div>
)
@ -23,10 +23,10 @@ export default function App() {
function WebClient({
auth,
waitOnAuth,
sessions,
}: {
auth?: AuthResponse
waitOnAuth: () => Promise<void>
sessions: SessionsCallbacks
}) {
const { data, refreshData, updateNode } = useNodeData()
@ -45,7 +45,7 @@ function WebClient({
<ManagementClientView {...data} />
) : data.DebugMode === "login" || data.DebugMode === "full" ? (
// Render new client interface in readonly mode.
<ReadonlyClientView data={data} auth={auth} waitOnAuth={waitOnAuth} />
<ReadonlyClientView data={data} auth={auth} sessions={sessions} />
) : (
// Render legacy client interface.
<LegacyClientView

View File

@ -1,5 +1,5 @@
import React from "react"
import { AuthResponse } from "src/hooks/auth"
import { AuthResponse, AuthType, SessionsCallbacks } from "src/hooks/auth"
import { NodeData } from "src/hooks/node-data"
import { ReactComponent as ConnectedDeviceIcon } from "src/icons/connected-device.svg"
import { ReactComponent as TailscaleLogo } from "src/icons/tailscale-logo.svg"
@ -17,11 +17,11 @@ import ProfilePic from "src/ui/profile-pic"
export default function ReadonlyClientView({
data,
auth,
waitOnAuth,
sessions,
}: {
data: NodeData
auth?: AuthResponse
waitOnAuth: () => Promise<void>
sessions: SessionsCallbacks
}) {
return (
<>
@ -51,12 +51,14 @@ export default function ReadonlyClientView({
<div className="text-sm leading-tight">{data.IP}</div>
</div>
</div>
{data.DebugMode === "full" && (
{auth?.authNeeded == AuthType.tailscale && (
<button
className="button button-blue ml-6"
onClick={() => {
window.open(auth?.authUrl, "_blank")
waitOnAuth()
sessions
.new()
.then((url) => window.open(url, "_blank"))
.then(() => sessions.wait())
}}
>
Access

View File

@ -8,20 +8,23 @@ export enum AuthType {
export type AuthResponse = {
ok: boolean
authUrl?: string
authNeeded?: AuthType
}
export type SessionsCallbacks = {
new: () => Promise<string> // creates new auth session and returns authURL
wait: () => Promise<void> // blocks until auth is completed
}
// useAuth reports and refreshes Tailscale auth status
// for the web client.
export default function useAuth() {
const [data, setData] = useState<AuthResponse>()
const [loading, setLoading] = useState<boolean>(true)
const loadAuth = useCallback((wait?: boolean) => {
const url = wait ? "/auth?wait=true" : "/auth"
const loadAuth = useCallback(() => {
setLoading(true)
return apiFetch(url, "GET")
return apiFetch("/auth", "GET")
.then((r) => r.json())
.then((d) => {
setData(d)
@ -44,11 +47,33 @@ export default function useAuth() {
})
}, [])
const newSession = useCallback(() => {
return apiFetch("/auth/session/new", "GET")
.then((r) => r.json())
.then((d) => d.authUrl)
.catch((error) => {
console.error(error)
})
}, [])
const waitForSessionCompletion = useCallback(() => {
return apiFetch("/auth/session/wait", "GET")
.then(() => loadAuth()) // refresh auth data
.catch((error) => {
console.error(error)
})
}, [])
useEffect(() => {
loadAuth()
}, [])
const waitOnAuth = useCallback(() => loadAuth(true), [])
return { data, loading, waitOnAuth }
return {
data,
loading,
sessions: {
new: newSession,
wait: waitForSessionCompletion,
},
}
}

View File

@ -203,8 +203,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/api/") {
if r.Method == httpm.GET && r.URL.Path == "/api/auth" {
s.serveAPIAuth(w, r)
switch {
case r.URL.Path == "/api/auth" && r.Method == httpm.GET:
s.serveAPIAuth(w, r) // serve auth status
return
case r.URL.Path == "/api/auth/session/new" && r.Method == httpm.GET:
s.serveAPIAuthSessionNew(w, r) // create new session
return
case r.URL.Path == "/api/auth/session/wait" && r.Method == httpm.GET:
s.serveAPIAuthSessionWait(w, r) // wait for session to be authorized
return
}
if ok := s.authorizeRequest(w, r); !ok {
@ -295,20 +302,15 @@ func (s *Server) serveLoginAPI(w http.ResponseWriter, r *http.Request) {
type authResponse struct {
OK bool `json:"ok"` // true when user has valid auth session
AuthURL string `json:"authUrl,omitempty"` // filled when user has control auth action to take
AuthNeeded authType `json:"authNeeded,omitempty"` // filled when user needs to complete a specific type of auth
}
// serverAPIAuth handles requests to the /api/auth endpoint
// and returns an authResponse indicating the current auth state and any steps the user needs to take.
func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) {
if r.Method != httpm.GET {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var resp authResponse
session, whois, err := s.getSession(r)
session, _, err := s.getSession(r)
switch {
case err != nil && errors.Is(err, errNotUsingTailscale):
// not using tailscale, so perform platform auth
@ -328,12 +330,43 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) {
default:
resp.OK = true // no additional auth for this distro
}
case err != nil && (errors.Is(err, errNotOwner) ||
errors.Is(err, errNotUsingTailscale) ||
errors.Is(err, errTaggedLocalSource) ||
errors.Is(err, errTaggedRemoteSource)):
// These cases are all restricted to the readonly view.
// No auth action to take.
resp = authResponse{OK: false}
case err != nil && !errors.Is(err, errNoSession):
// Any other error.
http.Error(w, err.Error(), http.StatusInternalServerError)
return
case session.isAuthorized(s.timeNow()):
resp = authResponse{OK: true}
default:
resp = authResponse{OK: false, AuthNeeded: tailscaleAuth}
}
writeJSON(w, resp)
}
type newSessionAuthResponse struct {
AuthURL string `json:"authUrl,omitempty"`
}
// serveAPIAuthSessionNew handles requests to the /api/auth/session/new endpoint.
func (s *Server) serveAPIAuthSessionNew(w http.ResponseWriter, r *http.Request) {
session, whois, err := s.getSession(r)
if err != nil && !errors.Is(err, errNoSession) {
// Source associated with request not allowed to create
// a session for this web client.
http.Error(w, err.Error(), http.StatusUnauthorized)
return
case session == nil:
}
if session == nil {
// Create a new session.
session, err := s.newSession(r.Context(), whois)
// If one already existed, we return that authURL rather than creating a new one.
session, err = s.newSession(r.Context(), whois)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -346,30 +379,26 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) {
Path: "/",
Expires: session.expires(),
})
resp = authResponse{OK: false, AuthURL: session.AuthURL}
case !session.isAuthorized(s.timeNow()):
if r.URL.Query().Get("wait") == "true" {
// Client requested we block until user completes auth.
}
writeJSON(w, newSessionAuthResponse{AuthURL: session.AuthURL})
}
// serveAPIAuthSessionWait handles requests to the /api/auth/session/wait endpoint.
func (s *Server) serveAPIAuthSessionWait(w http.ResponseWriter, r *http.Request) {
session, _, err := s.getSession(r)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
if session.isAuthorized(s.timeNow()) {
return // already authorized
}
if err := s.awaitUserAuth(r.Context(), session); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
}
if session.isAuthorized(s.timeNow()) {
resp = authResponse{OK: true}
} else {
resp = authResponse{OK: false, AuthURL: session.AuthURL}
}
default:
resp = authResponse{OK: true}
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
}
// serveAPI serves requests for the web client api.
// It should only be called by Server.ServeHTTP, via Server.apiHandler,
@ -458,11 +487,7 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) {
if len(st.TailscaleIPs) != 0 {
data.IP = st.TailscaleIPs[0].String()
}
if err := json.NewEncoder(w).Encode(*data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
writeJSON(w, *data)
}
type nodeUpdate struct {
@ -720,3 +745,12 @@ func enforcePrefix(prefix string, h http.HandlerFunc) http.HandlerFunc {
http.StripPrefix(prefix, h).ServeHTTP(w, r)
}
}
func writeJSON(w http.ResponseWriter, data any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil {
w.Header().Set("Content-Type", "text/plain")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

View File

@ -408,7 +408,7 @@ func() *ipnstate.PeerStatus { return self },
}
}
func TestServeTailscaleAuth(t *testing.T) {
func TestServeAuth(t *testing.T) {
user := &tailcfg.UserProfile{ID: tailcfg.UserID(1)}
self := &ipnstate.PeerStatus{ID: "self", UserID: user.ID}
remoteNode := &apitype.WhoIsResponse{Node: &tailcfg.Node{ID: 1}, UserProfile: user}
@ -463,17 +463,28 @@ func() *ipnstate.PeerStatus { return self },
tests := []struct {
name string
cookie string
query string
cookie string // cookie attached to request
wantNewCookie bool // want new cookie generated during request
wantSession *browserSession // session associated w/ cookie after request
path string
wantStatus int
wantResp *authResponse
wantNewCookie bool // new cookie generated
wantSession *browserSession // session associated w/ cookie at end of request
wantResp any
}{
{
name: "new-session-created",
name: "no-session",
path: "/api/auth",
wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPath},
wantResp: &authResponse{OK: false, AuthNeeded: tailscaleAuth},
wantNewCookie: false,
wantSession: nil,
},
{
name: "new-session",
path: "/api/auth/session/new",
wantStatus: http.StatusOK,
wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPath},
wantNewCookie: true,
wantSession: &browserSession{
ID: "GENERATED_ID", // gets swapped for newly created ID by test
@ -487,9 +498,26 @@ func() *ipnstate.PeerStatus { return self },
},
{
name: "query-existing-incomplete-session",
path: "/api/auth",
cookie: successCookie,
wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPathSuccess},
wantResp: &authResponse{OK: false, AuthNeeded: tailscaleAuth},
wantSession: &browserSession{
ID: successCookie,
SrcNode: remoteNode.Node.ID,
SrcUser: user.ID,
Created: oneHourAgo,
AuthID: testAuthPathSuccess,
AuthURL: testControlURL + testAuthPathSuccess,
Authenticated: false,
},
},
{
name: "existing-session-used",
path: "/api/auth/session/new", // should not create new session
cookie: successCookie,
wantStatus: http.StatusOK,
wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPathSuccess},
wantSession: &browserSession{
ID: successCookie,
SrcNode: remoteNode.Node.ID,
@ -502,12 +530,10 @@ func() *ipnstate.PeerStatus { return self },
},
{
name: "transition-to-successful-session",
path: "/api/auth/session/wait",
cookie: successCookie,
// query "wait" indicates the FE wants to make
// local api call to wait until session completed.
query: "wait=true",
wantStatus: http.StatusOK,
wantResp: &authResponse{OK: true},
wantResp: nil,
wantSession: &browserSession{
ID: successCookie,
SrcNode: remoteNode.Node.ID,
@ -520,6 +546,7 @@ func() *ipnstate.PeerStatus { return self },
},
{
name: "query-existing-complete-session",
path: "/api/auth",
cookie: successCookie,
wantStatus: http.StatusOK,
wantResp: &authResponse{OK: true},
@ -535,17 +562,18 @@ func() *ipnstate.PeerStatus { return self },
},
{
name: "transition-to-failed-session",
path: "/api/auth/session/wait",
cookie: failureCookie,
query: "wait=true",
wantStatus: http.StatusUnauthorized,
wantResp: nil,
wantSession: nil, // session deleted
},
{
name: "failed-session-cleaned-up",
path: "/api/auth/session/new",
cookie: failureCookie,
wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPath},
wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPath},
wantNewCookie: true,
wantSession: &browserSession{
ID: "GENERATED_ID",
@ -559,9 +587,10 @@ func() *ipnstate.PeerStatus { return self },
},
{
name: "expired-cookie-gets-new-session",
path: "/api/auth/session/new",
cookie: expiredCookie,
wantStatus: http.StatusOK,
wantResp: &authResponse{OK: false, AuthURL: testControlURL + testAuthPath},
wantResp: &newSessionAuthResponse{AuthURL: testControlURL + testAuthPath},
wantNewCookie: true,
wantSession: &browserSession{
ID: "GENERATED_ID",
@ -576,12 +605,11 @@ func() *ipnstate.PeerStatus { return self },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "/api/auth", nil)
r.URL.RawQuery = tt.query
r := httptest.NewRequest("GET", tt.path, nil)
r.RemoteAddr = remoteIP
r.AddCookie(&http.Cookie{Name: sessionCookieName, Value: tt.cookie})
w := httptest.NewRecorder()
s.serveAPIAuth(w, r)
s.serve(w, r)
res := w.Result()
defer res.Body.Close()
@ -589,17 +617,20 @@ func() *ipnstate.PeerStatus { return self },
if gotStatus := res.StatusCode; tt.wantStatus != gotStatus {
t.Errorf("wrong status; want=%v, got=%v", tt.wantStatus, gotStatus)
}
var gotResp *authResponse
var gotResp string
if res.StatusCode == http.StatusOK {
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if err := json.Unmarshal(body, &gotResp); err != nil {
t.Fatal(err)
gotResp = strings.Trim(string(body), "\n")
}
var wantResp string
if tt.wantResp != nil {
b, _ := json.Marshal(tt.wantResp)
wantResp = string(b)
}
if diff := cmp.Diff(gotResp, tt.wantResp); diff != "" {
if diff := cmp.Diff(gotResp, string(wantResp)); diff != "" {
t.Errorf("wrong response; (-got+want):%v", diff)
}
// Validate cookie creation.
@ -654,22 +685,13 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu
t.Fatalf("/whois call missing \"addr\" query")
}
if node := whoIs[addr]; node != nil {
if err := json.NewEncoder(w).Encode(&node); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
writeJSON(w, &node)
return
}
http.Error(w, "not a node", http.StatusUnauthorized)
return
case "/localapi/v0/status":
status := ipnstate.Status{Self: self()}
if err := json.NewEncoder(w).Encode(status); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
writeJSON(w, ipnstate.Status{Self: self()})
return
case "/localapi/v0/debug-web-client": // used by TestServeTailscaleAuth
type reqData struct {
@ -694,11 +716,7 @@ type reqData struct {
http.Error(w, "authenticated as wrong user", http.StatusUnauthorized)
return
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
writeJSON(w, resp)
return
default:
t.Fatalf("unhandled localapi test endpoint %q, add to localapi handler func in test", r.URL.Path)