login route host context

This commit is contained in:
Max Peintner
2025-01-16 10:01:51 +01:00
parent 52548e35c5
commit b6a2f20dee

View File

@@ -22,16 +22,18 @@ import {
} from "@zitadel/proto/zitadel/oidc/v2/oidc_service_pb"; } from "@zitadel/proto/zitadel/oidc/v2/oidc_service_pb";
import { Session } from "@zitadel/proto/zitadel/session/v2/session_pb"; import { Session } from "@zitadel/proto/zitadel/session/v2/session_pb";
import { AuthenticationMethodType } from "@zitadel/proto/zitadel/user/v2/user_service_pb"; import { AuthenticationMethodType } from "@zitadel/proto/zitadel/user/v2/user_service_pb";
import { headers } from "next/headers";
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
export const dynamic = "force-dynamic"; export const dynamic = "force-dynamic";
export const revalidate = false; export const revalidate = false;
export const fetchCache = "default-no-store"; export const fetchCache = "default-no-store";
async function loadSessions(ids: string[]): Promise<Session[]> { async function loadSessions(host: string, ids: string[]): Promise<Session[]> {
const response = await listSessions( const response = await listSessions({
ids.filter((id: string | undefined) => !!id), host,
); ids: ids.filter((id: string | undefined) => !!id),
});
return response?.sessions ?? []; return response?.sessions ?? [];
} }
@@ -44,7 +46,10 @@ const IDP_SCOPE_REGEX = /urn:zitadel:iam:org:idp:id:(.+)/;
* mfa is required, session is not valid anymore (e.g. session expired, user logged out, etc.) * mfa is required, session is not valid anymore (e.g. session expired, user logged out, etc.)
* to check for mfa for automatically selected session -> const response = await listAuthenticationMethodTypes(userId); * to check for mfa for automatically selected session -> const response = await listAuthenticationMethodTypes(userId);
**/ **/
async function isSessionValid(session: Session): Promise<boolean> { async function isSessionValid(
host: string,
session: Session,
): Promise<boolean> {
// session can't be checked without user // session can't be checked without user
if (!session.factors?.user) { if (!session.factors?.user) {
console.warn("Session has no user"); console.warn("Session has no user");
@@ -53,9 +58,10 @@ async function isSessionValid(session: Session): Promise<boolean> {
let mfaValid = true; let mfaValid = true;
const authMethodTypes = await listAuthenticationMethodTypes( const authMethodTypes = await listAuthenticationMethodTypes({
session.factors.user.id, host,
); userId: session.factors.user.id,
});
const authMethods = authMethodTypes.authMethodTypes; const authMethods = authMethodTypes.authMethodTypes;
if (authMethods && authMethods.includes(AuthenticationMethodType.TOTP)) { if (authMethods && authMethods.includes(AuthenticationMethodType.TOTP)) {
@@ -101,9 +107,10 @@ async function isSessionValid(session: Session): Promise<boolean> {
} }
} else { } else {
// only check settings if no auth methods are available, as this would require a setup // only check settings if no auth methods are available, as this would require a setup
const loginSettings = await getLoginSettings( const loginSettings = await getLoginSettings({
session.factors?.user?.organizationId, host,
); organization: session.factors?.user?.organizationId,
});
if (loginSettings?.forceMfa || loginSettings?.forceMfaLocalOnly) { if (loginSettings?.forceMfa || loginSettings?.forceMfaLocalOnly) {
const otpEmail = session.factors.otpEmail?.verifiedAt; const otpEmail = session.factors.otpEmail?.verifiedAt;
const otpSms = session.factors.otpSms?.verifiedAt; const otpSms = session.factors.otpSms?.verifiedAt;
@@ -144,6 +151,7 @@ async function isSessionValid(session: Session): Promise<boolean> {
} }
async function findValidSession( async function findValidSession(
host: string,
sessions: Session[], sessions: Session[],
authRequest: AuthRequest, authRequest: AuthRequest,
): Promise<Session | undefined> { ): Promise<Session | undefined> {
@@ -170,7 +178,7 @@ async function findValidSession(
// return the first valid session according to settings // return the first valid session according to settings
for (const session of sessionsWithHint) { for (const session of sessionsWithHint) {
if (await isSessionValid(session)) { if (await isSessionValid(host, session)) {
return session; return session;
} }
} }
@@ -183,6 +191,12 @@ export async function GET(request: NextRequest) {
const authRequestId = searchParams.get("authRequest"); const authRequestId = searchParams.get("authRequest");
const sessionId = searchParams.get("sessionId"); const sessionId = searchParams.get("sessionId");
const host = (await headers()).get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
// TODO: find a better way to handle _rsc (react server components) requests and block them to avoid conflicts when creating oidc callback // TODO: find a better way to handle _rsc (react server components) requests and block them to avoid conflicts when creating oidc callback
const _rsc = searchParams.get("_rsc"); const _rsc = searchParams.get("_rsc");
if (_rsc) { if (_rsc) {
@@ -193,7 +207,7 @@ export async function GET(request: NextRequest) {
const ids = sessionCookies.map((s) => s.id); const ids = sessionCookies.map((s) => s.id);
let sessions: Session[] = []; let sessions: Session[] = [];
if (ids && ids.length) { if (ids && ids.length) {
sessions = await loadSessions(ids); sessions = await loadSessions(host, ids);
} }
if (authRequestId && sessionId) { if (authRequestId && sessionId) {
@@ -206,7 +220,7 @@ export async function GET(request: NextRequest) {
if (selectedSession && selectedSession.id) { if (selectedSession && selectedSession.id) {
console.log(`Found session ${selectedSession.id}`); console.log(`Found session ${selectedSession.id}`);
const isValid = await isSessionValid(selectedSession); const isValid = await isSessionValid(host, selectedSession);
console.log("Session is valid:", isValid); console.log("Session is valid:", isValid);
@@ -239,15 +253,16 @@ export async function GET(request: NextRequest) {
// works not with _rsc request // works not with _rsc request
try { try {
const { callbackUrl } = await createCallback( const { callbackUrl } = await createCallback({
create(CreateCallbackRequestSchema, { host,
req: create(CreateCallbackRequestSchema, {
authRequestId, authRequestId,
callbackKind: { callbackKind: {
case: "session", case: "session",
value: create(SessionSchema, session), value: create(SessionSchema, session),
}, },
}), }),
); });
if (callbackUrl) { if (callbackUrl) {
return NextResponse.redirect(callbackUrl); return NextResponse.redirect(callbackUrl);
} else { } else {
@@ -265,9 +280,10 @@ export async function GET(request: NextRequest) {
"code" in error && "code" in error &&
error?.code === 9 error?.code === 9
) { ) {
const loginSettings = await getLoginSettings( const loginSettings = await getLoginSettings({
selectedSession.factors?.user?.organizationId, host,
); organization: selectedSession.factors?.user?.organizationId,
});
if (loginSettings?.defaultRedirectUri) { if (loginSettings?.defaultRedirectUri) {
return NextResponse.redirect(loginSettings.defaultRedirectUri); return NextResponse.redirect(loginSettings.defaultRedirectUri);
@@ -297,7 +313,7 @@ export async function GET(request: NextRequest) {
} }
if (authRequestId) { if (authRequestId) {
const { authRequest } = await getAuthRequest({ authRequestId }); const { authRequest } = await getAuthRequest({ host, authRequestId });
let organization = ""; let organization = "";
let suffix = ""; let suffix = "";
@@ -324,7 +340,7 @@ export async function GET(request: NextRequest) {
const matched = ORG_DOMAIN_SCOPE_REGEX.exec(orgDomainScope); const matched = ORG_DOMAIN_SCOPE_REGEX.exec(orgDomainScope);
const orgDomain = matched?.[1] ?? ""; const orgDomain = matched?.[1] ?? "";
if (orgDomain) { if (orgDomain) {
const orgs = await getOrgsByDomain(orgDomain); const orgs = await getOrgsByDomain({ host, domain: orgDomain });
if (orgs.result && orgs.result.length === 1) { if (orgs.result && orgs.result.length === 1) {
organization = orgs.result[0].id ?? ""; organization = orgs.result[0].id ?? "";
suffix = orgDomain; suffix = orgDomain;
@@ -337,9 +353,10 @@ export async function GET(request: NextRequest) {
const matched = IDP_SCOPE_REGEX.exec(idpScope); const matched = IDP_SCOPE_REGEX.exec(idpScope);
idpId = matched?.[1] ?? ""; idpId = matched?.[1] ?? "";
const identityProviders = await getActiveIdentityProviders( const identityProviders = await getActiveIdentityProviders({
organization ? organization : undefined, host,
).then((resp) => { orgId: organization ? organization : undefined,
}).then((resp) => {
return resp.identityProviders; return resp.identityProviders;
}); });
@@ -362,6 +379,7 @@ export async function GET(request: NextRequest) {
} }
return startIdentityProviderFlow({ return startIdentityProviderFlow({
host,
idpId, idpId,
urls: { urls: {
successUrl: successUrl:
@@ -460,7 +478,11 @@ export async function GET(request: NextRequest) {
* This means that the user should not be prompted to enter their password again. * This means that the user should not be prompted to enter their password again.
* Instead, the server attempts to silently authenticate the user using an existing session or other authentication mechanisms that do not require user interaction * Instead, the server attempts to silently authenticate the user using an existing session or other authentication mechanisms that do not require user interaction
**/ **/
const selectedSession = await findValidSession(sessions, authRequest); const selectedSession = await findValidSession(
host,
sessions,
authRequest,
);
if (!selectedSession || !selectedSession.id) { if (!selectedSession || !selectedSession.id) {
return NextResponse.json( return NextResponse.json(
@@ -485,19 +507,24 @@ export async function GET(request: NextRequest) {
sessionToken: cookie.token, sessionToken: cookie.token,
}; };
const { callbackUrl } = await createCallback( const { callbackUrl } = await createCallback({
create(CreateCallbackRequestSchema, { host,
req: create(CreateCallbackRequestSchema, {
authRequestId, authRequestId,
callbackKind: { callbackKind: {
case: "session", case: "session",
value: create(SessionSchema, session), value: create(SessionSchema, session),
}, },
}), }),
); });
return NextResponse.redirect(callbackUrl); return NextResponse.redirect(callbackUrl);
} else { } else {
// check for loginHint, userId hint and valid sessions // check for loginHint, userId hint and valid sessions
let selectedSession = await findValidSession(sessions, authRequest); let selectedSession = await findValidSession(
host,
sessions,
authRequest,
);
if (!selectedSession || !selectedSession.id) { if (!selectedSession || !selectedSession.id) {
return gotoAccounts(); return gotoAccounts();
@@ -517,15 +544,16 @@ export async function GET(request: NextRequest) {
}; };
try { try {
const { callbackUrl } = await createCallback( const { callbackUrl } = await createCallback({
create(CreateCallbackRequestSchema, { host,
req: create(CreateCallbackRequestSchema, {
authRequestId, authRequestId,
callbackKind: { callbackKind: {
case: "session", case: "session",
value: create(SessionSchema, session), value: create(SessionSchema, session),
}, },
}), }),
); });
if (callbackUrl) { if (callbackUrl) {
return NextResponse.redirect(callbackUrl); return NextResponse.redirect(callbackUrl);
} else { } else {