fix(login): host utility to provide correct host behind proxies (#10770)

<!--
Please inform yourself about the contribution guidelines on submitting a
PR here:
https://github.com/zitadel/zitadel/blob/main/CONTRIBUTING.md#submit-a-pull-request-pr.
Take note of how PR/commit titles should be written and replace the
template texts in the sections below. Don't remove any of the sections.
It is important that the commit history clearly shows what is changed
and why.
Important: By submitting a contribution you agree to the terms from our
Licensing Policy as described here:
https://github.com/zitadel/zitadel/blob/main/LICENSING.md#community-contributions.
-->

# Which Problems Are Solved

When deploying the login application behind proxies or using Vercel
rewrites (e.g., `zitadel.com/login` → `login-zitadel-qa.vercel.app`),
the application was using the internal rewritten host instead of the
original user-facing host. This caused several issues:

1. **Broken Password Reset Emails**: Email links contained internal
hosts like `login-zitadel-qa.vercel.app` instead of `zitadel.com`
2. **Inconsistent User Experience**: Users would see different domains
in various parts of the flow
3. **Security Concerns**: Internal infrastructure details were exposed
to end users
4. **Scattered Logic**: Host detection logic was duplicated across
multiple files with inconsistent error handling

# How the Problems Are Solved

Created comprehensive host detection utilities in `/lib/server/host.ts`
and `/lib/client/host.ts`:

**Server-side utilities:**
- `getOriginalHost()` - Returns the original user-facing host
- `getOriginalHostWithProtocol()` - Returns host with proper protocol
(http/https)
This commit is contained in:
Max Peintner
2025-09-23 18:21:01 +02:00
committed by GitHub
parent 16906d2c2c
commit 09d09ab337
24 changed files with 1324 additions and 840 deletions

View File

@@ -51,8 +51,7 @@ async function resolveOrganizationForUser({
serviceUrl,
domain: suffix,
});
const orgToCheckForDiscovery =
orgs.result && orgs.result.length === 1 ? orgs.result[0].id : undefined;
const orgToCheckForDiscovery = orgs.result && orgs.result.length === 1 ? orgs.result[0].id : undefined;
if (orgToCheckForDiscovery) {
const orgLoginSettings = await getLoginSettings({
@@ -141,12 +140,7 @@ export default async function Page(props: {
}
}
return loginSuccess(
userId,
{ idpIntentId: id, idpIntentToken: token },
requestId,
branding,
);
return loginSuccess(userId, { idpIntentId: id, idpIntentToken: token }, requestId, branding);
}
if (link) {
@@ -174,12 +168,7 @@ export default async function Page(props: {
if (!idpLink) {
return linkingFailed(branding);
} else {
return linkingSuccess(
userId,
{ idpIntentId: id, idpIntentToken: token },
requestId,
branding,
);
return linkingSuccess(userId, { idpIntentId: id, idpIntentToken: token }, requestId, branding);
}
}
@@ -230,12 +219,7 @@ export default async function Page(props: {
if (!idpLink) {
return linkingFailed(branding);
} else {
return linkingSuccess(
foundUser.userId,
{ idpIntentId: id, idpIntentToken: token },
requestId,
branding,
);
return linkingSuccess(foundUser.userId, { idpIntentId: id, idpIntentToken: token }, requestId, branding);
}
}
}
@@ -260,10 +244,7 @@ export default async function Page(props: {
organization: organizationSchema,
});
} else {
addHumanUserWithOrganization = create(
AddHumanUserRequestSchema,
addHumanUser,
);
addHumanUserWithOrganization = create(AddHumanUserRequestSchema, addHumanUser);
}
try {
@@ -272,16 +253,10 @@ export default async function Page(props: {
request: addHumanUserWithOrganization,
});
} catch (error: unknown) {
console.error(
"An error occurred while creating the user:",
error,
addHumanUser,
);
console.error("An error occurred while creating the user:", error, addHumanUser);
return loginFailed(
branding,
(error as ConnectError).message
? (error as ConnectError).message
: "Could not create user",
(error as ConnectError).message ? (error as ConnectError).message : "Could not create user",
);
}
} else if (options?.isCreationAllowed) {
@@ -325,11 +300,7 @@ export default async function Page(props: {
<p className="ztdl-p">
<Translated i18nKey="registerSuccess.description" namespace="idp" />
</p>
<IdpSignin
userId={newUser.userId}
idpIntent={{ idpIntentId: id, idpIntentToken: token }}
requestId={requestId}
/>
<IdpSignin userId={newUser.userId} idpIntent={{ idpIntentId: id, idpIntentToken: token }} requestId={requestId} />
</div>
</DynamicTheme>
);

View File

@@ -22,7 +22,7 @@ import { headers } from "next/headers";
export async function generateMetadata(): Promise<Metadata> {
const t = await getTranslations("mfa");
return { title: t('set.title')};
return { title: t("set.title") };
}
function isSessionValid(session: Partial<Session>): {
@@ -31,23 +31,19 @@ function isSessionValid(session: Partial<Session>): {
} {
const validPassword = session?.factors?.password?.verifiedAt;
const validPasskey = session?.factors?.webAuthN?.verifiedAt;
const stillValid = session.expirationDate
? timestampDate(session.expirationDate) > new Date()
: true;
const validIDP = session?.factors?.intent?.verifiedAt;
const stillValid = session.expirationDate ? timestampDate(session.expirationDate) > new Date() : true;
const verifiedAt = validPassword || validPasskey;
const valid = !!((validPassword || validPasskey) && stillValid);
const verifiedAt = validPassword || validPasskey || validIDP;
const valid = !!((validPassword || validPasskey || validIDP) && stillValid);
return { valid, verifiedAt };
}
export default async function Page(props: {
searchParams: Promise<Record<string | number | symbol, string | undefined>>;
}) {
export default async function Page(props: { searchParams: Promise<Record<string | number | symbol, string | undefined>> }) {
const searchParams = await props.searchParams;
const { loginName, checkAfter, force, requestId, organization, sessionId } =
searchParams;
const { loginName, checkAfter, force, requestId, organization, sessionId } = searchParams;
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
@@ -68,8 +64,7 @@ export default async function Page(props: {
userId,
}).then((methods) => {
return getUserByID({ serviceUrl, userId }).then((user) => {
const humanUser =
user.user?.type.case === "human" ? user.user?.type.value : undefined;
const humanUser = user.user?.type.case === "human" ? user.user?.type.value : undefined;
return {
id: session.id,
@@ -83,10 +78,7 @@ export default async function Page(props: {
});
}
async function loadSessionByLoginname(
loginName?: string,
organization?: string,
) {
async function loadSessionByLoginname(loginName?: string, organization?: string) {
return loadMostRecentSession({
serviceUrl,
sessionParams: {
@@ -152,24 +144,21 @@ export default async function Page(props: {
</Alert>
)}
{isSessionValid(sessionWithData).valid &&
loginSettings &&
sessionWithData &&
sessionWithData.factors?.user?.id && (
<ChooseSecondFactorToSetup
userId={sessionWithData.factors?.user?.id}
loginName={loginName}
sessionId={sessionWithData.id}
requestId={requestId}
organization={organization}
loginSettings={loginSettings}
userMethods={sessionWithData.authMethods ?? []}
phoneVerified={sessionWithData.phoneVerified ?? false}
emailVerified={sessionWithData.emailVerified ?? false}
checkAfter={checkAfter === "true"}
force={force === "true"}
></ChooseSecondFactorToSetup>
)}
{valid && loginSettings && sessionWithData && sessionWithData.factors?.user?.id && (
<ChooseSecondFactorToSetup
userId={sessionWithData.factors?.user?.id}
loginName={loginName}
sessionId={sessionWithData.id}
requestId={requestId}
organization={organization}
loginSettings={loginSettings}
userMethods={sessionWithData.authMethods ?? []}
phoneVerified={sessionWithData.phoneVerified ?? false}
emailVerified={sessionWithData.emailVerified ?? false}
checkAfter={checkAfter === "true"}
force={force === "true"}
></ChooseSecondFactorToSetup>
)}
<div className="mt-8 flex w-full flex-row items-center">
<BackButton />

View File

@@ -4,20 +4,17 @@ import { LoginOTP } from "@/components/login-otp";
import { Translated } from "@/components/translated";
import { UserAvatar } from "@/components/user-avatar";
import { getSessionCookieById } from "@/lib/cookies";
import { getOriginalHost } from "@/lib/server/host";
import { getServiceUrlFromHeaders } from "@/lib/service-url";
import { loadMostRecentSession } from "@/lib/session";
import {
getBrandingSettings,
getLoginSettings,
getSession,
} from "@/lib/zitadel";
import { getBrandingSettings, getLoginSettings, getSession } from "@/lib/zitadel";
import { Metadata } from "next";
import { getTranslations } from "next-intl/server";
import { headers } from "next/headers";
export async function generateMetadata(): Promise<Metadata> {
const t = await getTranslations("otp");
return { title: t('verify.title')};
return { title: t("verify.title") };
}
export default async function Page(props: {
@@ -29,11 +26,7 @@ export default async function Page(props: {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
const host = await getOriginalHost();
const {
loginName, // send from password page
@@ -120,9 +113,7 @@ export default async function Page(props: {
loginName={loginName ?? session.factors?.user?.loginName}
sessionId={sessionId}
requestId={requestId}
organization={
organization ?? session?.factors?.user?.organizationId
}
organization={organization ?? session?.factors?.user?.organizationId}
method={method}
loginSettings={loginSettings}
host={host}

View File

@@ -13,23 +13,16 @@ import { headers } from "next/headers";
export async function generateMetadata(): Promise<Metadata> {
const t = await getTranslations("u2f");
return { title: t('verify.title')};
return { title: t("verify.title") };
}
export default async function Page(props: {
searchParams: Promise<Record<string | number | symbol, string | undefined>>;
}) {
export default async function Page(props: { searchParams: Promise<Record<string | number | symbol, string | undefined>> }) {
const searchParams = await props.searchParams;
const { loginName, requestId, sessionId, organization } = searchParams;
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
const branding = await getBrandingSettings({
serviceUrl,
@@ -37,17 +30,13 @@ export default async function Page(props: {
});
const sessionFactors = sessionId
? await loadSessionById(serviceUrl, sessionId, organization)
? await loadSessionById(sessionId, organization)
: await loadMostRecentSession({
serviceUrl,
sessionParams: { loginName, organization },
});
async function loadSessionById(
host: string,
sessionId: string,
organization?: string,
) {
async function loadSessionById(sessionId: string, organization?: string) {
const recent = await getSessionCookieById({ sessionId, organization });
return getSession({
serviceUrl,

View File

@@ -4,6 +4,7 @@ import { Translated } from "@/components/translated";
import { UserAvatar } from "@/components/user-avatar";
import { VerifyForm } from "@/components/verify-form";
import { sendEmailCode, sendInviteEmailCode } from "@/lib/server/verify";
import { getOriginalHostWithProtocol } from "@/lib/server/host";
import { getServiceUrlFromHeaders } from "@/lib/service-url";
import { loadMostRecentSession } from "@/lib/session";
import { getBrandingSettings, getUserByID } from "@/lib/zitadel";
@@ -14,14 +15,13 @@ import { headers } from "next/headers";
export async function generateMetadata(): Promise<Metadata> {
const t = await getTranslations("verify");
return { title: t('verify.title')};
return { title: t("verify.title") };
}
export default async function Page(props: { searchParams: Promise<any> }) {
const searchParams = await props.searchParams;
const { userId, loginName, code, organization, requestId, invite, send } =
searchParams;
const { userId, loginName, code, organization, requestId, invite, send } = searchParams;
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
@@ -41,17 +41,13 @@ export default async function Page(props: { searchParams: Promise<any> }) {
const basePath = process.env.NEXT_PUBLIC_BASE_PATH ?? "";
async function sendEmail(userId: string) {
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
const hostWithProtocol = await getOriginalHostWithProtocol();
if (invite === "true") {
await sendInviteEmailCode({
userId,
urlTemplate:
`${host.includes("localhost") ? "http://" : "https://"}${host}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}&invite=true` +
`${hostWithProtocol}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}&invite=true` +
(requestId ? `&requestId=${requestId}` : ""),
}).catch((error) => {
console.error("Could not send invitation email", error);
@@ -61,7 +57,7 @@ export default async function Page(props: { searchParams: Promise<any> }) {
await sendEmailCode({
userId,
urlTemplate:
`${host.includes("localhost") ? "http://" : "https://"}${host}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}` +
`${hostWithProtocol}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}` +
(requestId ? `&requestId=${requestId}` : ""),
}).catch((error) => {
console.error("Could not send verification email", error);
@@ -157,11 +153,7 @@ export default async function Page(props: { searchParams: Promise<any> }) {
></UserAvatar>
) : (
user && (
<UserAvatar
loginName={user.preferredLoginName}
displayName={human?.profile?.displayName}
showDropdown={false}
/>
<UserAvatar loginName={user.preferredLoginName} displayName={human?.profile?.displayName} showDropdown={false} />
)
)}

View File

@@ -1,14 +1,7 @@
import { getAllSessions } from "@/lib/cookies";
import { getServiceUrlFromHeaders } from "@/lib/service-url";
import {
validateAuthRequest,
isRSCRequest
} from "@/lib/auth-utils";
import {
handleOIDCFlowInitiation,
handleSAMLFlowInitiation,
FlowInitiationParams
} from "@/lib/server/flow-initiation";
import { validateAuthRequest, isRSCRequest } from "@/lib/auth-utils";
import { handleOIDCFlowInitiation, handleSAMLFlowInitiation, FlowInitiationParams } from "@/lib/server/flow-initiation";
import { listSessions } from "@/lib/zitadel";
import { Session } from "@zitadel/proto/zitadel/session/v2/session_pb";
import { headers } from "next/headers";
@@ -17,6 +10,8 @@ import { NextRequest, NextResponse } from "next/server";
export const dynamic = "force-dynamic";
export const revalidate = false;
export const fetchCache = "default-no-store";
// Add this to prevent RSC requests
export const runtime = "nodejs";
async function loadSessions({ serviceUrl, ids }: { serviceUrl: string; ids: string[] }): Promise<Session[]> {
const response = await listSessions({
@@ -41,10 +36,7 @@ export async function GET(request: NextRequest) {
// Early validation: if no valid request parameters, return error immediately
const requestId = validateAuthRequest(searchParams);
if (!requestId) {
return NextResponse.json(
{ error: "No valid authentication request found" },
{ status: 400 },
);
return NextResponse.json({ error: "No valid authentication request found" }, { status: 400 });
}
const sessionCookies = await getAllSessions();
@@ -69,14 +61,8 @@ export async function GET(request: NextRequest) {
return handleSAMLFlowInitiation(flowParams);
} else if (requestId.startsWith("device_")) {
// Device Authorization does not need to start here as it is handled on the /device endpoint
return NextResponse.json(
{ error: "Device authorization should use /device endpoint" },
{ status: 400 }
);
return NextResponse.json({ error: "Device authorization should use /device endpoint" }, { status: 400 });
} else {
return NextResponse.json(
{ error: "Invalid request ID format" },
{ status: 400 }
);
return NextResponse.json({ error: "Invalid request ID format" }, { status: 400 });
}
}

View File

@@ -1,8 +1,8 @@
"use client";
import { createNewSessionFromIdpIntent } from "@/lib/server/idp";
import { CreateNewSessionCommand, createNewSessionFromIdpIntent } from "@/lib/server/idp";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { Alert } from "./alert";
import { Spinner } from "./spinner";
@@ -16,25 +16,33 @@ type Props = {
requestId?: string;
};
export function IdpSignin({
userId,
idpIntent: { idpIntentId, idpIntentToken },
requestId,
}: Props) {
export function IdpSignin({ userId, idpIntent: { idpIntentId, idpIntentToken }, requestId }: Props) {
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const executedRef = useRef(false);
const router = useRouter();
useEffect(() => {
createNewSessionFromIdpIntent({
// Prevent double execution in React Strict Mode
if (executedRef.current) {
return;
}
executedRef.current = true;
let request: CreateNewSessionCommand = {
userId,
idpIntent: {
idpIntentId,
idpIntentToken,
},
requestId,
})
};
if (requestId) {
request = { ...request, requestId: requestId };
}
createNewSessionFromIdpIntent(request)
.then((response) => {
if (response && "error" in response && response?.error) {
setError(response?.error);

View File

@@ -1,7 +1,7 @@
"use client";
import { sendLoginname } from "@/lib/server/loginname";
import { clearSession, continueWithSession } from "@/lib/server/session";
import { clearSession, continueWithSession, ContinueWithSessionCommand } from "@/lib/server/session";
import { XCircleIcon } from "@heroicons/react/24/outline";
import * as Tooltip from "@radix-ui/react-tooltip";
import { Timestamp, timestampDate } from "@zitadel/client";
@@ -21,9 +21,7 @@ export function isSessionValid(session: Partial<Session>): {
const validPasskey = session?.factors?.webAuthN?.verifiedAt;
const validIDP = session?.factors?.intent?.verifiedAt;
const stillValid = session.expirationDate
? timestampDate(session.expirationDate) > new Date()
: true;
const stillValid = session.expirationDate ? timestampDate(session.expirationDate) > new Date() : true;
const verifiedAt = validPassword || validPasskey || validIDP;
const valid = !!((validPassword || validPasskey || validIDP) && stillValid);
@@ -31,15 +29,7 @@ export function isSessionValid(session: Partial<Session>): {
return { valid, verifiedAt };
}
export function SessionItem({
session,
reload,
requestId,
}: {
session: Session;
reload: () => void;
requestId?: string;
}) {
export function SessionItem({ session, reload, requestId }: { session: Session; reload: () => void; requestId?: string }) {
const currentLocale = useLocale();
moment.locale(currentLocale === "zh" ? "zh-cn" : currentLocale);
@@ -73,10 +63,21 @@ export function SessionItem({
<button
onClick={async () => {
if (valid && session?.factors?.user) {
await continueWithSession({
...session,
requestId: requestId,
});
const sessionPayload: ContinueWithSessionCommand = session;
if (requestId) {
sessionPayload.requestId = requestId;
}
const callbackResponse = await continueWithSession(sessionPayload);
if (callbackResponse && "error" in callbackResponse) {
setError(callbackResponse.error);
return;
}
if (callbackResponse && "redirect" in callbackResponse) {
return router.push(callbackResponse.redirect);
}
} else if (session.factors?.user) {
setLoading(true);
const res = await sendLoginname({
@@ -114,9 +115,7 @@ export function SessionItem({
<div className="flex flex-col items-start overflow-hidden">
<span className="">{session.factors?.user?.displayName}</span>
<span className="text-ellipsis text-xs opacity-80">
{session.factors?.user?.loginName}
</span>
<span className="text-ellipsis text-xs opacity-80">{session.factors?.user?.loginName}</span>
{valid ? (
<span className="text-ellipsis text-xs opacity-80">
<Translated i18nKey="verified" namespace="accounts" />{" "}
@@ -126,8 +125,7 @@ export function SessionItem({
verifiedAt && (
<span className="text-ellipsis text-xs opacity-80">
<Translated i18nKey="expired" namespace="accounts" />{" "}
{session.expirationDate &&
moment(timestampDate(session.expirationDate)).fromNow()}
{session.expirationDate && moment(timestampDate(session.expirationDate)).fromNow()}
</span>
)
)}

View File

@@ -35,6 +35,7 @@ export async function loginWithOIDCAndSession({
console.log("Session is valid:", isValid);
if (!isValid && selectedSession.factors?.user) {
console.log("Session is not valid, need to re-authenticate user");
// if the session is not valid anymore, we need to redirect the user to re-authenticate /
// TODO: handle IDP intent direcly if available
const command: SendLoginnameCommand = {

View File

@@ -0,0 +1,297 @@
import { describe, expect, test, vi, beforeEach, afterEach } from "vitest";
import { getOriginalHost, getOriginalHostWithProtocol } from "./host";
// Mock the Next.js headers function
vi.mock("next/headers", () => ({
headers: vi.fn(),
}));
describe("Host utility functions", () => {
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
vi.restoreAllMocks();
});
describe("getOriginalHost", () => {
test("should return x-forwarded-host when available", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
if (key === "x-forwarded-host") return "zitadel.com";
if (key === "x-original-host") return "backup.com";
if (key === "host") return "internal.vercel.app";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHost();
expect(result).toBe("zitadel.com");
expect(mockHeaders.get).toHaveBeenCalledWith("x-forwarded-host");
});
test("should fall back to x-original-host when x-forwarded-host is not available", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
if (key === "x-forwarded-host") return null;
if (key === "x-original-host") return "original.com";
if (key === "host") return "internal.vercel.app";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHost();
expect(result).toBe("original.com");
expect(mockHeaders.get).toHaveBeenCalledWith("x-forwarded-host");
expect(mockHeaders.get).toHaveBeenCalledWith("x-original-host");
});
test("should fall back to host when forwarded headers are not available", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
if (key === "x-forwarded-host") return null;
if (key === "x-original-host") return null;
if (key === "host") return "fallback.com";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHost();
expect(result).toBe("fallback.com");
expect(mockHeaders.get).toHaveBeenCalledWith("x-forwarded-host");
expect(mockHeaders.get).toHaveBeenCalledWith("x-original-host");
expect(mockHeaders.get).toHaveBeenCalledWith("host");
});
test("should throw error when no host is found", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => null),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
await expect(getOriginalHost()).rejects.toThrow("No host found in headers");
});
test("should throw error when host is empty string", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => ""),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
await expect(getOriginalHost()).rejects.toThrow("No host found in headers");
});
test("should throw error when host is not a string", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => 123),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
await expect(getOriginalHost()).rejects.toThrow("No host found in headers");
});
});
describe("getOriginalHostWithProtocol", () => {
test("should return https for production domain", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "zitadel.com"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://zitadel.com");
});
test("should return http for localhost", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "localhost:3000"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("http://localhost:3000");
});
test("should return http for localhost without port", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "localhost"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("http://localhost");
});
test("should return https for custom domain", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "auth.company.com"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://auth.company.com");
});
});
describe("Real-world scenarios", () => {
test("should handle Vercel rewrite scenario", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
// Simulate Vercel rewrite: zitadel.com/login -> login-zitadel-qa.vercel.app
if (key === "x-forwarded-host") return "zitadel.com";
if (key === "host") return "login-zitadel-qa.vercel.app";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://zitadel.com");
});
test("should handle CloudFlare proxy scenario", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
if (key === "x-forwarded-host") return "auth.company.com";
if (key === "x-original-host") return null;
if (key === "host") return "cloudflare-worker.workers.dev";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHost();
expect(result).toBe("auth.company.com");
});
test("should handle development environment", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
if (key === "host") return "localhost:3000";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("http://localhost:3000");
});
test("should handle staging environment with subdomain", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
if (key === "x-forwarded-host") return "staging-auth.company.com";
if (key === "host") return "staging-internal.vercel.app";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://staging-auth.company.com");
});
});
describe("Edge cases", () => {
test("should handle IPv4 addresses", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "192.168.1.100:3000"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://192.168.1.100:3000");
});
test("should handle IPv6 addresses", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "[::1]:3000"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://[::1]:3000");
});
test("should handle hosts with ports", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "zitadel.com:8080"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("https://zitadel.com:8080");
});
test("should handle localhost with different ports", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn(() => "localhost:8080"),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHostWithProtocol();
expect(result).toBe("http://localhost:8080");
});
test("should handle priority order correctly", async () => {
const { headers } = await import("next/headers");
const mockHeaders = {
get: vi.fn((key: string) => {
// All headers are present, should return x-forwarded-host (highest priority)
if (key === "x-forwarded-host") return "priority1.com";
if (key === "x-original-host") return "priority2.com";
if (key === "host") return "priority3.com";
return null;
}),
};
vi.mocked(headers).mockResolvedValue(mockHeaders as any);
const result = await getOriginalHost();
expect(result).toBe("priority1.com");
// Should only call x-forwarded-host since it's available
expect(mockHeaders.get).toHaveBeenCalledWith("x-forwarded-host");
expect(mockHeaders.get).toHaveBeenCalledTimes(1);
});
});
});

View File

@@ -0,0 +1,48 @@
import { headers } from "next/headers";
/**
* Gets the original host that the user sees in their browser URL.
* When using rewrites this function prioritizes forwarded headers that preserve the original host.
*
* ⚠️ SERVER-SIDE ONLY: This function can only be used in:
* - Server Actions (functions with "use server")
* - Server Components (React components that run on the server)
* - Route Handlers (API routes)
* - Middleware
*
* @returns The host string (e.g., "zitadel.com")
* @throws Error if no host is found
*/
export async function getOriginalHost(): Promise<string> {
const _headers = await headers();
// Priority order:
// 1. x-forwarded-host - Set by proxies/CDNs with the original host
// 2. x-original-host - Alternative header sometimes used
// 3. host - Fallback to the current host header
const host = _headers.get("x-forwarded-host") || _headers.get("x-original-host") || _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found in headers");
}
return host;
}
/**
* Gets the original host with protocol prefix.
* Automatically detects if localhost should use http:// or https://
*
* ⚠️ SERVER-SIDE ONLY: This function can only be used in:
* - Server Actions (functions with "use server")
* - Server Components (React components that run on the server)
* - Route Handlers (API routes)
* - Middleware
*
* @returns The full URL prefix (e.g., "https://zitadel.com")
*/
export async function getOriginalHostWithProtocol(): Promise<string> {
const host = await getOriginalHost();
const protocol = host.includes("localhost") ? "http://" : "https://";
return `${protocol}${host}`;
}

View File

@@ -13,16 +13,14 @@ import { completeFlowOrGetUrl } from "../client";
import { getServiceUrlFromHeaders } from "../service-url";
import { checkEmailVerification, checkMFAFactors } from "../verify-helper";
import { createSessionForIdpAndUpdateCookie } from "./cookie";
import { getOriginalHost } from "./host";
export type RedirectToIdpState = { error?: string | null } | undefined;
export async function redirectToIdp(prevState: RedirectToIdpState, formData: FormData): Promise<RedirectToIdpState> {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
return { error: "Could not get host" };
}
const host = await getOriginalHost();
const params = new URLSearchParams();
@@ -88,7 +86,7 @@ async function startIDPFlow(command: StartIDPFlowCommand) {
return { redirect: url };
}
type CreateNewSessionCommand = {
export type CreateNewSessionCommand = {
userId: string;
idpIntent: {
idpIntentId: string;
@@ -104,11 +102,6 @@ export async function createNewSessionFromIdpIntent(command: CreateNewSessionCom
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
return { error: "Could not get domain" };
}
if (!command.userId || !command.idpIntent) {
throw new Error("No userId or loginName provided");
@@ -160,19 +153,17 @@ export async function createNewSessionFromIdpIntent(command: CreateNewSessionCom
}
}
if (authMethods) {
const mfaFactorCheck = await checkMFAFactors(
serviceUrl,
session,
loginSettings,
authMethods,
command.organization,
command.requestId,
);
const mfaFactorCheck = await checkMFAFactors(
serviceUrl,
session,
loginSettings,
authMethods || [], // Pass empty array if no auth methods
command.organization,
command.requestId,
);
if (mfaFactorCheck?.redirect) {
return mfaFactorCheck;
}
if (mfaFactorCheck?.redirect) {
return mfaFactorCheck;
}
return completeFlowOrGetUrl(
@@ -201,11 +192,6 @@ export async function createNewSessionForLDAP(command: createNewSessionForLDAPCo
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
return { error: "Could not get domain" };
}
if (!command.username || !command.password) {
return { error: "No username or password provided" };

View File

@@ -21,6 +21,7 @@ import {
startIdentityProviderFlow,
} from "../zitadel";
import { createSessionAndUpdateCookie } from "./cookie";
import { getOriginalHost } from "./host";
export type SendLoginnameCommand = {
loginName: string;
@@ -34,11 +35,6 @@ const ORG_SUFFIX_REGEX = /(?<=@)(.+)/;
export async function sendLoginname(command: SendLoginnameCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
throw new Error("Could not get domain");
}
const loginSettingsByContext = await getLoginSettings({
serviceUrl,
@@ -80,11 +76,7 @@ export async function sendLoginname(command: SendLoginnameCommand) {
if (identityProviders.length === 1) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
return { error: "Could not get host" };
}
const host = await getOriginalHost();
const identityProviderType = identityProviders[0].type;
@@ -134,11 +126,7 @@ export async function sendLoginname(command: SendLoginnameCommand) {
if (identityProviders.length === 1) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
return { error: "Could not get host" };
}
const host = await getOriginalHost();
const identityProviderId = identityProviders[0].idpId;

View File

@@ -23,6 +23,7 @@ import { getMostRecentSessionCookie, getSessionCookieById, getSessionCookieByLog
import { getServiceUrlFromHeaders } from "../service-url";
import { checkEmailVerification, checkUserVerification } from "../verify-helper";
import { setSessionAndUpdateCookie } from "./cookie";
import { getOriginalHost } from "./host";
type VerifyPasskeyCommand = {
passkeyId: string;
@@ -56,11 +57,7 @@ export async function registerPasskeyLink(
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
throw new Error("Could not get domain");
}
const host = await getOriginalHost();
const sessionCookie = await getSessionCookieById({ sessionId });
const session = await getSession({

View File

@@ -23,6 +23,7 @@ import { headers } from "next/headers";
import { completeFlowOrGetUrl } from "../client";
import { getSessionCookieById, getSessionCookieByLoginName } from "../cookies";
import { getServiceUrlFromHeaders } from "../service-url";
import { getOriginalHostWithProtocol } from "./host";
import {
checkEmailVerification,
checkMFAFactors,
@@ -40,11 +41,9 @@ type ResetPasswordCommand = {
export async function resetPassword(command: ResetPasswordCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
// Get the original host that the user sees with protocol
const hostWithProtocol = await getOriginalHostWithProtocol();
const users = await listUsers({
serviceUrl,
@@ -63,7 +62,7 @@ export async function resetPassword(command: ResetPasswordCommand) {
serviceUrl,
userId,
urlTemplate:
`${host.includes("localhost") ? "http://" : "https://"}${host}${basePath}/password/set?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}` +
`${hostWithProtocol}${basePath}/password/set?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}` +
(command.requestId ? `&requestId=${command.requestId}` : ""),
});
}

View File

@@ -1,14 +1,14 @@
"use server";
import { createSessionAndUpdateCookie, createSessionForIdpAndUpdateCookie } from "@/lib/server/cookie";
import { addHumanUser, addIDPLink, getLoginSettings, getUserByID } from "@/lib/zitadel";
import { addHumanUser, addIDPLink, getLoginSettings, getUserByID, listAuthenticationMethodTypes } from "@/lib/zitadel";
import { create } from "@zitadel/client";
import { Factors } from "@zitadel/proto/zitadel/session/v2/session_pb";
import { ChecksJson, ChecksSchema } from "@zitadel/proto/zitadel/session/v2/session_service_pb";
import { headers } from "next/headers";
import { completeFlowOrGetUrl } from "../client";
import { getServiceUrlFromHeaders } from "../service-url";
import { checkEmailVerification } from "../verify-helper";
import { checkEmailVerification, checkMFAFactors } from "../verify-helper";
type RegisterUserCommand = {
email: string;
@@ -27,11 +27,6 @@ export type RegisterUserResponse = {
export async function registerUser(command: RegisterUserCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
const addResponse = await addHumanUser({
serviceUrl,
@@ -147,13 +142,8 @@ export type registerUserAndLinkToIDPResponse = {
export async function registerUserAndLinkToIDP(command: RegisterUserAndLinkToIDPommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
const addResponse = await addHumanUser({
const addUserResponse = await addHumanUser({
serviceUrl,
email: command.email,
firstName: command.firstName,
@@ -161,7 +151,7 @@ export async function registerUserAndLinkToIDP(command: RegisterUserAndLinkToIDP
organization: command.organization,
});
if (!addResponse) {
if (!addUserResponse) {
return { error: "Could not create user" };
}
@@ -177,7 +167,7 @@ export async function registerUserAndLinkToIDP(command: RegisterUserAndLinkToIDP
userId: command.idpUserId,
userName: command.idpUserName,
},
userId: addResponse.userId,
userId: addUserResponse.userId,
});
if (!idpLink) {
@@ -186,7 +176,7 @@ export async function registerUserAndLinkToIDP(command: RegisterUserAndLinkToIDP
const session = await createSessionForIdpAndUpdateCookie({
requestId: command.requestId,
userId: addResponse.userId, // the user we just created
userId: addUserResponse.userId, // the user we just created
idpIntent: command.idpIntent,
lifetime: loginSettings?.externalLoginCheckLifetime,
});
@@ -195,6 +185,51 @@ export async function registerUserAndLinkToIDP(command: RegisterUserAndLinkToIDP
return { error: "Could not create session" };
}
// const userResponse = await getUserByID({
// serviceUrl,
// userId: session?.factors?.user?.id,
// });
// if (!userResponse.user) {
// return { error: "User not found in the system" };
// }
// const humanUser = userResponse.user.type.case === "human" ? userResponse.user.type.value : undefined;
// check to see if user was verified
// const emailVerificationCheck = checkEmailVerification(session, humanUser, command.organization, command.requestId);
// if (emailVerificationCheck?.redirect) {
// return emailVerificationCheck;
// }
// check if user has MFA methods
let authMethods;
if (session.factors?.user?.id) {
const response = await listAuthenticationMethodTypes({
serviceUrl,
userId: session.factors.user.id,
});
if (response.authMethodTypes && response.authMethodTypes.length) {
authMethods = response.authMethodTypes;
}
}
// Always check MFA factors, even if no auth methods are configured
// This ensures that force MFA settings are respected
const mfaFactorCheck = await checkMFAFactors(
serviceUrl,
session,
loginSettings,
authMethods || [], // Pass empty array if no auth methods
command.organization,
command.requestId,
);
if (mfaFactorCheck?.redirect) {
return mfaFactorCheck;
}
return completeFlowOrGetUrl(
command.requestId && session.id
? {

View File

@@ -21,6 +21,7 @@ import {
removeSessionFromCookie,
} from "../cookies";
import { getServiceUrlFromHeaders } from "../service-url";
import { getOriginalHost } from "./host";
export async function skipMFAAndContinueWithNextUrl({
userId,
@@ -67,7 +68,9 @@ export async function skipMFAAndContinueWithNextUrl({
return { error: "Could not skip MFA and continue" };
}
export async function continueWithSession({ requestId, ...session }: Session & { requestId?: string }) {
export type ContinueWithSessionCommand = Session & { requestId?: string };
export async function continueWithSession({ requestId, ...session }: ContinueWithSessionCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
@@ -122,7 +125,7 @@ export async function updateSession(options: UpdateSessionCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
const host = await getOriginalHost();
if (!host) {
return { error: "Could not get host" };

View File

@@ -7,6 +7,7 @@ import { headers } from "next/headers";
import { userAgent } from "next/server";
import { getSessionCookieById } from "../cookies";
import { getServiceUrlFromHeaders } from "../service-url";
import { getOriginalHost } from "./host";
type RegisterU2FCommand = {
sessionId: string;
@@ -22,11 +23,7 @@ type VerifyU2FCommand = {
export async function addU2F(command: RegisterU2FCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
const host = await getOriginalHost();
const sessionCookie = await getSessionCookieById({
sessionId: command.sessionId,
@@ -60,12 +57,6 @@ export async function addU2F(command: RegisterU2FCommand) {
export async function verifyU2F(command: VerifyU2FCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host || typeof host !== "string") {
throw new Error("No host found");
}
let passkeyName = command.passkeyName;
if (!passkeyName) {
const headersList = await headers();

View File

@@ -24,6 +24,7 @@ import { getServiceUrlFromHeaders } from "../service-url";
import { loadMostRecentSession } from "../session";
import { checkMFAFactors } from "../verify-helper";
import { createSessionAndUpdateCookie } from "./cookie";
import { getOriginalHostWithProtocol } from "./host";
export async function verifyTOTP(code: string, loginName?: string, organization?: string) {
const _headers = await headers();
@@ -250,11 +251,7 @@ type resendVerifyEmailCommand = {
export async function resendVerification(command: resendVerifyEmailCommand) {
const _headers = await headers();
const { serviceUrl } = getServiceUrlFromHeaders(_headers);
const host = _headers.get("host");
if (!host) {
return { error: "No host found" };
}
const hostWithProtocol = await getOriginalHostWithProtocol();
const basePath = process.env.NEXT_PUBLIC_BASE_PATH ?? "";
@@ -263,7 +260,7 @@ export async function resendVerification(command: resendVerifyEmailCommand) {
serviceUrl,
userId: command.userId,
urlTemplate:
`${host.includes("localhost") ? "http://" : "https://"}${host}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}&invite=true` +
`${hostWithProtocol}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}&invite=true` +
(command.requestId ? `&requestId=${command.requestId}` : ""),
}).catch((error) => {
if (error.code === 9) {
@@ -275,7 +272,7 @@ export async function resendVerification(command: resendVerifyEmailCommand) {
userId: command.userId,
serviceUrl,
urlTemplate:
`${host.includes("localhost") ? "http://" : "https://"}${host}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}` +
`${hostWithProtocol}${basePath}/verify?code={{.Code}}&userId={{.UserID}}&organization={{.OrgID}}` +
(command.requestId ? `&requestId=${command.requestId}` : ""),
});
}

View File

@@ -5,6 +5,8 @@
* - Session expiration checks
* - User presence validation
* - Authentication factor verification (password, passkey, IDP)
* - MFA validation using the shared shouldEnforceMFA function from verify-helper
* - Passkey authentication inherently satisfies MFA requirements
* - MFA validation with configured authentication methods (TOTP, OTP Email/SMS, U2F)
* - MFA validation with login settings (forceMfa, forceMfaLocalOnly)
* - Email verification when EMAIL_VERIFICATION environment variable is enabled
@@ -15,6 +17,7 @@ import { timestampDate } from "@zitadel/client";
import { AuthenticationMethodType } from "@zitadel/proto/zitadel/user/v2/user_service_pb";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { isSessionValid } from "./session";
import * as verifyHelperModule from "./verify-helper";
import * as zitadelModule from "./zitadel";
// Mock the zitadel client timestampDate function
@@ -29,6 +32,11 @@ vi.mock("./zitadel", () => ({
getUserByID: vi.fn(),
}));
// Mock the verify-helper module
vi.mock("./verify-helper", () => ({
shouldEnforceMFA: vi.fn(),
}));
// Mock environment variables
const originalEnv = process.env;
@@ -221,15 +229,21 @@ describe("isSessionValid", () => {
},
});
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [AuthenticationMethodType.PASSWORD, AuthenticationMethodType.TOTP],
} as any);
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: true,
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(true);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(false);
expect(consoleSpy).toHaveBeenCalledWith("Session has no valid multifactor", expect.any(Object));
expect(consoleSpy).toHaveBeenCalledWith("Session has no valid MFA factor. Configured methods:", expect.any(Array), "Session factors:", expect.any(Object));
consoleSpy.mockRestore();
});
@@ -349,6 +363,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
@@ -384,6 +400,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
@@ -420,6 +438,8 @@ describe("isSessionValid", () => {
authMethodTypes: [AuthenticationMethodType.TOTP],
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(true);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(false);
@@ -470,6 +490,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
// This should be true - if it's false, the original bug still exists
@@ -508,6 +530,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
// With our fix, this should be true (session is valid)
@@ -543,6 +567,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
@@ -576,6 +602,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(true);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(false);
@@ -650,6 +678,50 @@ describe("isSessionValid", () => {
expect(result).toBe(true);
});
test("should return false when forceMfaLocalOnly is enabled for password authentication but MFA not satisfied", async () => {
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const verifiedTimestamp = createMockTimestamp();
const session = createMockSession({
factors: {
user: {
id: mockUserId,
organizationId: mockOrganizationId,
loginName: "test@example.com",
displayName: "Test User",
verifiedAt: verifiedTimestamp,
},
password: {
verifiedAt: verifiedTimestamp,
},
// No MFA factors verified
},
});
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [AuthenticationMethodType.TOTP],
} as any);
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: false,
forceMfaLocalOnly: true,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(true);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(false);
expect(zitadelModule.getLoginSettings).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
organization: mockOrganizationId,
});
expect(zitadelModule.listAuthenticationMethodTypes).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
userId: mockUserId,
});
consoleSpy.mockRestore();
});
});
describe("email verification", () => {
@@ -682,6 +754,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
vi.mocked(zitadelModule.getUserByID).mockResolvedValue({
user: {
type: {
@@ -734,6 +808,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
vi.mocked(zitadelModule.getUserByID).mockResolvedValue({
user: {
type: {
@@ -781,6 +857,8 @@ describe("isSessionValid", () => {
forceMfaLocalOnly: false,
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
@@ -824,7 +902,7 @@ describe("isSessionValid", () => {
});
describe("IDP authentication", () => {
test("should return true when authenticated with IDP intent", async () => {
test("should return true when authenticated with IDP intent and no MFA required", async () => {
const verifiedTimestamp = createMockTimestamp();
const session = createMockSession({
factors: {
@@ -842,21 +920,28 @@ describe("isSessionValid", () => {
},
});
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [],
} as any);
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: false,
forceMfaLocalOnly: false,
} as any);
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [],
} as any);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
expect(verifyHelperModule.shouldEnforceMFA).toHaveBeenCalledWith(session, expect.any(Object));
expect(zitadelModule.getLoginSettings).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
organization: mockOrganizationId,
});
});
test("should return true when authenticated with IDP intent even with forced MFA", async () => {
test("should return false when authenticated with IDP intent but MFA required and not satisfied", async () => {
const verifiedTimestamp = createMockTimestamp();
const session = createMockSession({
factors: {
@@ -870,11 +955,13 @@ describe("isSessionValid", () => {
intent: {
verifiedAt: verifiedTimestamp,
},
// No password factor, no MFA factors
// No password factor, no MFA factors verified
},
});
// Organization enforces MFA
// shouldEnforceMFA returns true (MFA is required for this session)
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(true);
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: true,
forceMfaLocalOnly: false,
@@ -885,12 +972,141 @@ describe("isSessionValid", () => {
authMethodTypes: [AuthenticationMethodType.TOTP, AuthenticationMethodType.OTP_EMAIL],
} as any);
// Should still return true because IDP bypasses MFA requirements
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(false);
expect(verifyHelperModule.shouldEnforceMFA).toHaveBeenCalledWith(session, expect.any(Object));
expect(zitadelModule.getLoginSettings).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
organization: mockOrganizationId,
});
expect(zitadelModule.listAuthenticationMethodTypes).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
userId: mockUserId,
});
});
test("should return true when authenticated with IDP intent and forceMfaLocalOnly (IDP bypasses local-only MFA)", async () => {
const verifiedTimestamp = createMockTimestamp();
const session = createMockSession({
factors: {
user: {
id: mockUserId,
organizationId: mockOrganizationId,
loginName: "test@example.com",
displayName: "Test User",
verifiedAt: verifiedTimestamp,
},
intent: {
verifiedAt: verifiedTimestamp,
},
// No password factor, no MFA factors verified
},
});
// shouldEnforceMFA returns false (IDP bypasses forceMfaLocalOnly)
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: false,
forceMfaLocalOnly: true,
} as any);
// User has MFA methods configured but none verified
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [AuthenticationMethodType.TOTP, AuthenticationMethodType.OTP_EMAIL],
} as any);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
// Verify that getLoginSettings was not called since IDP should bypass MFA check entirely
expect(zitadelModule.getLoginSettings).not.toHaveBeenCalled();
expect(verifyHelperModule.shouldEnforceMFA).toHaveBeenCalledWith(session, expect.any(Object));
expect(zitadelModule.getLoginSettings).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
organization: mockOrganizationId,
});
// Should not call listAuthenticationMethodTypes since shouldEnforceMFA returned false
expect(zitadelModule.listAuthenticationMethodTypes).not.toHaveBeenCalled();
});
test("should return true when authenticated with IDP intent and MFA required and satisfied", async () => {
const verifiedTimestamp = createMockTimestamp();
const session = createMockSession({
factors: {
user: {
id: mockUserId,
organizationId: mockOrganizationId,
loginName: "test@example.com",
displayName: "Test User",
verifiedAt: verifiedTimestamp,
},
intent: {
verifiedAt: verifiedTimestamp,
},
totp: {
verifiedAt: verifiedTimestamp,
},
},
});
// Organization enforces MFA
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: true,
forceMfaLocalOnly: false,
} as any);
// User has TOTP configured and verified
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [AuthenticationMethodType.TOTP],
} as any);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
});
});
describe("passkey authentication", () => {
test("should return true when authenticated with passkey and MFA required (passkey satisfies MFA)", async () => {
const verifiedTimestamp = createMockTimestamp();
const session = createMockSession({
factors: {
user: {
id: mockUserId,
organizationId: mockOrganizationId,
loginName: "test@example.com",
displayName: "Test User",
verifiedAt: verifiedTimestamp,
},
webAuthN: {
verifiedAt: verifiedTimestamp,
},
// No password factor, no additional MFA factors
},
});
// shouldEnforceMFA returns false (passkey satisfies MFA requirements)
vi.mocked(verifyHelperModule.shouldEnforceMFA).mockReturnValue(false);
vi.mocked(zitadelModule.getLoginSettings).mockResolvedValue({
forceMfa: true,
forceMfaLocalOnly: false,
} as any);
// User has MFA methods configured but none verified (passkey should satisfy MFA)
vi.mocked(zitadelModule.listAuthenticationMethodTypes).mockResolvedValue({
authMethodTypes: [AuthenticationMethodType.TOTP],
} as any);
const result = await isSessionValid({ serviceUrl: mockServiceUrl, session });
expect(result).toBe(true);
expect(verifyHelperModule.shouldEnforceMFA).toHaveBeenCalledWith(session, expect.any(Object));
expect(zitadelModule.getLoginSettings).toHaveBeenCalledWith({
serviceUrl: mockServiceUrl,
organization: mockOrganizationId,
});
// Should not call listAuthenticationMethodTypes since shouldEnforceMFA returned false
expect(zitadelModule.listAuthenticationMethodTypes).not.toHaveBeenCalled();
});
});

View File

@@ -5,6 +5,7 @@ import { Session } from "@zitadel/proto/zitadel/session/v2/session_pb";
import { GetSessionResponse } from "@zitadel/proto/zitadel/session/v2/session_service_pb";
import { AuthenticationMethodType } from "@zitadel/proto/zitadel/user/v2/user_service_pb";
import { getMostRecentCookieWithLoginname } from "./cookies";
import { shouldEnforceMFA } from "./verify-helper";
import { getLoginSettings, getSession, getUserByID, listAuthenticationMethodTypes } from "./zitadel";
type LoadMostRecentSessionParams = {
@@ -44,76 +45,72 @@ export async function isSessionValid({ serviceUrl, session }: { serviceUrl: stri
let mfaValid = true;
// Check if user authenticated via IDP - if so, skip MFA validation entirely
// Check if user authenticated via different methods
const validIDP = session?.factors?.intent?.verifiedAt;
if (validIDP) {
// IDP authentication bypasses MFA requirements
mfaValid = true;
} else {
// Get login settings to determine if MFA is actually required by policy
const loginSettings = await getLoginSettings({
const validPassword = session?.factors?.password?.verifiedAt;
const validPasskey = session?.factors?.webAuthN?.verifiedAt;
// Get login settings to determine if MFA is actually required by policy
const loginSettings = await getLoginSettings({
serviceUrl,
organization: session.factors?.user?.organizationId,
});
// Use the existing shouldEnforceMFA function to determine if MFA is required
const isMfaRequired = shouldEnforceMFA(session, loginSettings);
// Only enforce MFA validation if MFA is required by policy
if (isMfaRequired) {
const authMethodTypes = await listAuthenticationMethodTypes({
serviceUrl,
organization: session.factors?.user?.organizationId,
userId: session.factors.user.id,
});
const isMfaRequired = loginSettings?.forceMfa || loginSettings?.forceMfaLocalOnly;
const authMethods = authMethodTypes.authMethodTypes;
// Filter to only MFA methods (exclude PASSWORD and PASSKEY)
const mfaMethods = authMethods?.filter(
(method) =>
method === AuthenticationMethodType.TOTP ||
method === AuthenticationMethodType.OTP_EMAIL ||
method === AuthenticationMethodType.OTP_SMS ||
method === AuthenticationMethodType.U2F,
);
// Only enforce MFA validation if MFA is required by policy
if (isMfaRequired) {
const authMethodTypes = await listAuthenticationMethodTypes({
serviceUrl,
userId: session.factors.user.id,
});
if (mfaMethods && mfaMethods.length > 0) {
// Check if any of the configured MFA methods have been verified
const totpValid = mfaMethods.includes(AuthenticationMethodType.TOTP) && !!session.factors.totp?.verifiedAt;
const otpEmailValid =
mfaMethods.includes(AuthenticationMethodType.OTP_EMAIL) && !!session.factors.otpEmail?.verifiedAt;
const otpSmsValid = mfaMethods.includes(AuthenticationMethodType.OTP_SMS) && !!session.factors.otpSms?.verifiedAt;
const u2fValid = mfaMethods.includes(AuthenticationMethodType.U2F) && !!session.factors.webAuthN?.verifiedAt;
const authMethods = authMethodTypes.authMethodTypes;
// Filter to only MFA methods (exclude PASSWORD and PASSKEY)
const mfaMethods = authMethods?.filter(
(method) =>
method === AuthenticationMethodType.TOTP ||
method === AuthenticationMethodType.OTP_EMAIL ||
method === AuthenticationMethodType.OTP_SMS ||
method === AuthenticationMethodType.U2F,
);
mfaValid = totpValid || otpEmailValid || otpSmsValid || u2fValid;
if (mfaMethods && mfaMethods.length > 0) {
// Check if any of the configured MFA methods have been verified
const totpValid = mfaMethods.includes(AuthenticationMethodType.TOTP) && !!session.factors.totp?.verifiedAt;
const otpEmailValid =
mfaMethods.includes(AuthenticationMethodType.OTP_EMAIL) && !!session.factors.otpEmail?.verifiedAt;
const otpSmsValid = mfaMethods.includes(AuthenticationMethodType.OTP_SMS) && !!session.factors.otpSms?.verifiedAt;
const u2fValid = mfaMethods.includes(AuthenticationMethodType.U2F) && !!session.factors.webAuthN?.verifiedAt;
if (!mfaValid) {
console.warn("Session has no valid MFA factor. Configured methods:", mfaMethods, "Session factors:", {
totp: session.factors.totp?.verifiedAt,
otpEmail: session.factors.otpEmail?.verifiedAt,
otpSms: session.factors.otpSms?.verifiedAt,
webAuthN: session.factors.webAuthN?.verifiedAt,
});
}
} else {
// No specific MFA methods configured, but MFA is forced - check for any verified MFA factors
// (excluding IDP which should be handled separately)
const otpEmail = session.factors.otpEmail?.verifiedAt;
const otpSms = session.factors.otpSms?.verifiedAt;
const totp = session.factors.totp?.verifiedAt;
const webAuthN = session.factors.webAuthN?.verifiedAt;
// Note: Removed IDP (session.factors.intent?.verifiedAt) as requested
mfaValid = totpValid || otpEmailValid || otpSmsValid || u2fValid;
if (!mfaValid) {
console.warn("Session has no valid MFA factor. Configured methods:", mfaMethods, "Session factors:", {
totp: session.factors.totp?.verifiedAt,
otpEmail: session.factors.otpEmail?.verifiedAt,
otpSms: session.factors.otpSms?.verifiedAt,
webAuthN: session.factors.webAuthN?.verifiedAt,
});
}
} else {
// No specific MFA methods configured, but MFA is forced - check for any verified MFA factors
// (excluding IDP which should be handled separately)
const otpEmail = session.factors.otpEmail?.verifiedAt;
const otpSms = session.factors.otpSms?.verifiedAt;
const totp = session.factors.totp?.verifiedAt;
const webAuthN = session.factors.webAuthN?.verifiedAt;
// Note: Removed IDP (session.factors.intent?.verifiedAt) as requested
mfaValid = !!(otpEmail || otpSms || totp || webAuthN);
if (!mfaValid) {
console.warn("Session has no valid multifactor", session.factors);
}
mfaValid = !!(otpEmail || otpSms || totp || webAuthN);
if (!mfaValid) {
console.warn("Session has no valid multifactor", session.factors);
}
}
}
// If MFA is not required by policy, mfaValid remains true
const validPassword = session?.factors?.password?.verifiedAt;
const validPasskey = session?.factors?.webAuthN?.verifiedAt;
// validIDP already declared above for IDP bypass logic
// If MFA is not required by policy, mfaValid remains true
const stillValid = session.expirationDate ? timestampDate(session.expirationDate).getTime() > new Date().getTime() : true;

View File

@@ -0,0 +1,322 @@
import { describe, it, expect, beforeEach } from "vitest";
import { shouldEnforceMFA } from "./verify-helper";
// Mock function to create timestamps - following the same pattern as session.test.ts
function createMockTimestamp(offsetMs = 3600000): any {
return {
seconds: BigInt(Math.floor((Date.now() + offsetMs) / 1000)),
nanos: 0,
};
}
// Mock function to create a basic session - following the same pattern as session.test.ts
function createMockSession(overrides: any = {}): any {
const futureTimestamp = createMockTimestamp();
const defaultSession = {
id: "test-session-id",
factors: {
user: {
id: "test-user-id",
loginName: "test@example.com",
displayName: "Test User",
organizationId: "test-org-id",
verifiedAt: futureTimestamp,
},
},
...overrides,
};
return defaultSession;
}
// Mock function to create login settings
function createMockLoginSettings(overrides: any = {}): any {
return {
forceMfa: false,
forceMfaLocalOnly: false,
...overrides,
};
}
describe("shouldEnforceMFA", () => {
let mockSession: any;
let mockLoginSettings: any;
beforeEach(() => {
mockSession = createMockSession();
mockLoginSettings = createMockLoginSettings();
});
describe("when loginSettings is undefined", () => {
it("should return false", () => {
const result = shouldEnforceMFA(mockSession, undefined);
expect(result).toBe(false);
});
});
describe("passkey authentication", () => {
beforeEach(() => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
webAuthN: {
verifiedAt: createMockTimestamp(),
userVerified: true,
},
},
});
});
it("should return false when user authenticated with passkey, even with forceMfa enabled", () => {
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
it("should return false when user authenticated with passkey, even with forceMfaLocalOnly enabled", () => {
mockLoginSettings = createMockLoginSettings({ forceMfaLocalOnly: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
it("should return false when user authenticated with passkey and both force settings enabled", () => {
mockLoginSettings = createMockLoginSettings({
forceMfa: true,
forceMfaLocalOnly: true,
});
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
it("should return true when passkey is not user verified", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
webAuthN: {
verifiedAt: createMockTimestamp(),
userVerified: false, // Not user verified
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
// Should return true because passkey is not user verified, so it doesn't count as passkey auth
expect(result).toBe(true);
});
});
describe("forceMfa setting", () => {
beforeEach(() => {
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
});
it("should return true when forceMfa is enabled and user authenticated with password", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
password: {
verifiedAt: createMockTimestamp(),
},
},
});
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true);
});
it("should return true when forceMfa is enabled and user authenticated with IDP", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
intent: {
verifiedAt: createMockTimestamp(),
},
},
});
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true);
});
it("should return true when forceMfa is enabled with no specific authentication method", () => {
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true);
});
});
describe("forceMfaLocalOnly setting", () => {
beforeEach(() => {
mockLoginSettings = createMockLoginSettings({ forceMfaLocalOnly: true });
});
it("should return true when forceMfaLocalOnly is enabled and user authenticated with password", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
password: {
verifiedAt: createMockTimestamp(),
},
},
});
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true);
});
it("should return false when forceMfaLocalOnly is enabled and user authenticated with IDP", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
intent: {
verifiedAt: createMockTimestamp(),
},
},
});
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
it("should return false when forceMfaLocalOnly is enabled with no specific authentication method", () => {
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
});
describe("mixed authentication scenarios", () => {
it("should prioritize passkey over password when both are present", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
password: {
verifiedAt: createMockTimestamp(),
},
webAuthN: {
verifiedAt: createMockTimestamp(),
userVerified: true,
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false); // Passkey should override password
});
it("should prioritize passkey over IDP when both are present", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
intent: {
verifiedAt: createMockTimestamp(),
},
webAuthN: {
verifiedAt: createMockTimestamp(),
userVerified: true,
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfaLocalOnly: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false); // Passkey should override IDP
});
it("should handle password + IDP scenario with forceMfaLocalOnly", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
password: {
verifiedAt: createMockTimestamp(),
},
intent: {
verifiedAt: createMockTimestamp(),
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfaLocalOnly: true });
// With both password and IDP, the current logic should return false for IDP
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
});
describe("no MFA enforcement", () => {
it("should return false when neither forceMfa nor forceMfaLocalOnly is enabled", () => {
mockLoginSettings = createMockLoginSettings({
forceMfa: false,
forceMfaLocalOnly: false,
});
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(false);
});
});
describe("edge cases", () => {
it("should handle session with no factors", () => {
mockSession = createMockSession({
factors: undefined,
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true);
});
it("should handle session with empty factors", () => {
mockSession = createMockSession({
factors: {
user: {
id: "test-user-id",
loginName: "test@example.com",
displayName: "Test User",
organizationId: "test-org-id",
verifiedAt: createMockTimestamp(),
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true);
});
it("should handle webAuthN factor without userVerified", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
webAuthN: {
verifiedAt: createMockTimestamp(),
userVerified: false,
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true); // Should require MFA since it's not a proper passkey
});
it("should handle webAuthN factor without verifiedAt", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
webAuthN: {
userVerified: true,
// verifiedAt is undefined
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true); // Should require MFA since webAuthN wasn't actually verified
});
it("should handle webAuthN factor with verifiedAt but no userVerified property", () => {
mockSession = createMockSession({
factors: {
...mockSession.factors,
webAuthN: {
verifiedAt: createMockTimestamp(),
// userVerified is undefined (should be falsy)
},
},
});
mockLoginSettings = createMockLoginSettings({ forceMfa: true });
const result = shouldEnforceMFA(mockSession, mockLoginSettings);
expect(result).toBe(true); // Should require MFA since userVerified is falsy
});
});
});

View File

@@ -20,9 +20,7 @@ export function checkPasswordChangeRequired(
let isOutdated = false;
if (expirySettings?.maxAgeDays && humanUser?.passwordChanged) {
const maxAgeDays = Number(expirySettings.maxAgeDays); // Convert bigint to number
const passwordChangedDate = moment(
timestampDate(humanUser.passwordChanged),
);
const passwordChangedDate = moment(timestampDate(humanUser.passwordChanged));
const outdatedPassword = passwordChangedDate.add(maxAgeDays, "days");
isOutdated = moment().isAfter(outdatedPassword);
}
@@ -33,10 +31,7 @@ export function checkPasswordChangeRequired(
});
if (organization || session.factors?.user?.organizationId) {
params.append(
"organization",
session.factors?.user?.organizationId as string,
);
params.append("organization", session.factors?.user?.organizationId as string);
}
if (requestId) {
@@ -47,12 +42,7 @@ export function checkPasswordChangeRequired(
}
}
export function checkEmailVerified(
session: Session,
humanUser?: HumanUser,
organization?: string,
requestId?: string,
) {
export function checkEmailVerified(session: Session, humanUser?: HumanUser, organization?: string, requestId?: string) {
if (!humanUser?.email?.isVerified) {
const paramsVerify = new URLSearchParams({
loginName: session.factors?.user?.loginName as string,
@@ -61,10 +51,7 @@ export function checkEmailVerified(
});
if (organization || session.factors?.user?.organizationId) {
paramsVerify.append(
"organization",
organization ?? (session.factors?.user?.organizationId as string),
);
paramsVerify.append("organization", organization ?? (session.factors?.user?.organizationId as string));
}
if (requestId) {
@@ -75,16 +62,8 @@ export function checkEmailVerified(
}
}
export function checkEmailVerification(
session: Session,
humanUser?: HumanUser,
organization?: string,
requestId?: string,
) {
if (
!humanUser?.email?.isVerified &&
process.env.EMAIL_VERIFICATION === "true"
) {
export function checkEmailVerification(session: Session, humanUser?: HumanUser, organization?: string, requestId?: string) {
if (!humanUser?.email?.isVerified && process.env.EMAIL_VERIFICATION === "true") {
const params = new URLSearchParams({
loginName: session.factors?.user?.loginName as string,
send: "true", // set this to true as we dont expect old email codes to be valid anymore
@@ -95,10 +74,7 @@ export function checkEmailVerification(
}
if (organization || session.factors?.user?.organizationId) {
params.append(
"organization",
organization ?? (session.factors?.user?.organizationId as string),
);
params.append("organization", organization ?? (session.factors?.user?.organizationId as string));
}
return { redirect: `/verify?` + params };
@@ -113,15 +89,23 @@ export async function checkMFAFactors(
organization?: string,
requestId?: string,
) {
console.log("checkMFAFactors called with session:", {
sessionId: session.id,
userId: session.factors?.user?.id,
loginName: session.factors?.user?.loginName,
hasIntentFactor: !!session.factors?.intent?.verifiedAt,
hasPasswordFactor: !!session.factors?.password?.verifiedAt,
hasWebAuthNFactor: !!session.factors?.webAuthN?.verifiedAt,
});
const availableMultiFactors = authMethods?.filter(
(m: AuthenticationMethodType) =>
m !== AuthenticationMethodType.PASSWORD &&
m !== AuthenticationMethodType.PASSKEY,
m === AuthenticationMethodType.TOTP ||
m === AuthenticationMethodType.OTP_SMS ||
m === AuthenticationMethodType.OTP_EMAIL ||
m === AuthenticationMethodType.U2F,
);
const hasAuthenticatedWithPasskey =
session.factors?.webAuthN?.verifiedAt &&
session.factors?.webAuthN?.userVerified;
const hasAuthenticatedWithPasskey = session.factors?.webAuthN?.verifiedAt && session.factors?.webAuthN?.userVerified;
// escape further checks if user has authenticated with passkey
if (hasAuthenticatedWithPasskey) {
@@ -139,10 +123,7 @@ export async function checkMFAFactors(
}
if (organization || session.factors?.user?.organizationId) {
params.append(
"organization",
organization ?? (session.factors?.user?.organizationId as string),
);
params.append("organization", organization ?? (session.factors?.user?.organizationId as string));
}
const factor = availableMultiFactors[0];
@@ -166,59 +147,50 @@ export async function checkMFAFactors(
}
if (organization || session.factors?.user?.organizationId) {
params.append(
"organization",
organization ?? (session.factors?.user?.organizationId as string),
);
params.append("organization", organization ?? (session.factors?.user?.organizationId as string));
}
return { redirect: `/mfa?` + params };
} else if (
(loginSettings?.forceMfa || loginSettings?.forceMfaLocalOnly) &&
!availableMultiFactors.length
) {
} else if (shouldEnforceMFA(session, loginSettings) && !availableMultiFactors.length) {
const params = new URLSearchParams({
loginName: session.factors?.user?.loginName as string,
force: "true", // this defines if the mfa is forced in the settings
checkAfter: "true", // this defines if the check is directly made after the setup
});
if (session.id) {
params.append("sessionId", session.id);
}
if (requestId) {
params.append("requestId", requestId);
}
if (organization || session.factors?.user?.organizationId) {
params.append(
"organization",
organization ?? (session.factors?.user?.organizationId as string),
);
params.append("organization", organization ?? (session.factors?.user?.organizationId as string));
}
// TODO: provide a way to setup passkeys on mfa page?
return { redirect: `/mfa/set?` + params };
} else if (
loginSettings?.mfaInitSkipLifetime &&
(loginSettings.mfaInitSkipLifetime.nanos > 0 ||
loginSettings.mfaInitSkipLifetime.seconds > 0) &&
(loginSettings.mfaInitSkipLifetime.nanos > 0 || loginSettings.mfaInitSkipLifetime.seconds > 0) &&
!availableMultiFactors.length &&
session?.factors?.user?.id
session?.factors?.user?.id &&
shouldEnforceMFA(session, loginSettings)
) {
const userResponse = await getUserByID({
serviceUrl,
userId: session.factors?.user?.id,
});
const humanUser =
userResponse?.user?.type.case === "human"
? userResponse?.user.type.value
: undefined;
const humanUser = userResponse?.user?.type.case === "human" ? userResponse?.user.type.value : undefined;
if (humanUser?.mfaInitSkipped) {
const mfaInitSkippedTimestamp = timestampDate(humanUser.mfaInitSkipped);
const mfaInitSkipLifetimeMillis =
Number(loginSettings.mfaInitSkipLifetime.seconds) * 1000 +
loginSettings.mfaInitSkipLifetime.nanos / 1000000;
Number(loginSettings.mfaInitSkipLifetime.seconds) * 1000 + loginSettings.mfaInitSkipLifetime.nanos / 1000000;
const currentTime = Date.now();
const mfaInitSkippedTime = mfaInitSkippedTimestamp.getTime();
const timeDifference = currentTime - mfaInitSkippedTime;
@@ -237,15 +209,16 @@ export async function checkMFAFactors(
checkAfter: "true", // this defines if the check is directly made after the setup
});
if (session.id) {
params.append("sessionId", session.id);
}
if (requestId) {
params.append("requestId", requestId);
}
if (organization || session.factors?.user?.organizationId) {
params.append(
"organization",
organization ?? (session.factors?.user?.organizationId as string),
);
params.append("organization", organization ?? (session.factors?.user?.organizationId as string));
}
// TODO: provide a way to setup passkeys on mfa page?
@@ -253,6 +226,52 @@ export async function checkMFAFactors(
}
}
/**
* Determines if MFA should be enforced based on the authentication method used and login settings
* @param session - The current session
* @param loginSettings - The login settings containing MFA enforcement rules
* @returns true if MFA should be enforced, false otherwise
*/
export function shouldEnforceMFA(session: Session, loginSettings: LoginSettings | undefined): boolean {
if (!loginSettings) {
return false;
}
// Check if user authenticated with passkey (passkeys are inherently multi-factor)
const authenticatedWithPasskey = session.factors?.webAuthN?.verifiedAt && session.factors?.webAuthN?.userVerified;
// If user authenticated with passkey, MFA is not required regardless of settings
if (authenticatedWithPasskey) {
return false;
}
// If forceMfa is enabled, MFA is required for ALL authentication methods (except passkeys)
if (loginSettings.forceMfa) {
return true;
}
// If forceMfaLocalOnly is enabled, MFA is only required for local/password authentication
if (loginSettings.forceMfaLocalOnly) {
// Check if user authenticated with password (local authentication)
const authenticatedWithPassword = !!session.factors?.password?.verifiedAt;
// Check if user authenticated with IDP (external authentication)
const authenticatedWithIDP = !!session.factors?.intent?.verifiedAt;
// If user authenticated with IDP, MFA is not required for forceMfaLocalOnly
if (authenticatedWithIDP) {
return false;
}
// If user authenticated with password, MFA is required for forceMfaLocalOnly
if (authenticatedWithPassword) {
return true;
}
}
return false;
}
export async function checkUserVerification(userId: string): Promise<boolean> {
// check if a verification was done earlier
const cookiesList = await cookies();
@@ -264,24 +283,17 @@ export async function checkUserVerification(userId: string): Promise<boolean> {
return false;
}
const verificationCheck = crypto
.createHash("sha256")
.update(`${userId}:${fingerPrintCookie.value}`)
.digest("hex");
const verificationCheck = crypto.createHash("sha256").update(`${userId}:${fingerPrintCookie.value}`).digest("hex");
const cookieValue = await cookiesList.get("verificationCheck")?.value;
if (!cookieValue) {
console.warn(
"User verification check cookie not found. User verification check failed.",
);
console.warn("User verification check cookie not found. User verification check failed.");
return false;
}
if (cookieValue !== verificationCheck) {
console.warn(
`User verification check failed. Expected ${verificationCheck} but got ${cookieValue}`,
);
console.warn(`User verification check failed. Expected ${verificationCheck} but got ${cookieValue}`);
return false;
}

View File

@@ -2,40 +2,19 @@ import { Client, create, Duration } from "@zitadel/client";
import { createServerTransport as libCreateServerTransport } from "@zitadel/client/node";
import { makeReqCtx } from "@zitadel/client/v2";
import { IdentityProviderService } from "@zitadel/proto/zitadel/idp/v2/idp_service_pb";
import {
OrganizationSchema,
TextQueryMethod,
} from "@zitadel/proto/zitadel/object/v2/object_pb";
import {
CreateCallbackRequest,
OIDCService,
} from "@zitadel/proto/zitadel/oidc/v2/oidc_service_pb";
import { OrganizationSchema, TextQueryMethod } from "@zitadel/proto/zitadel/object/v2/object_pb";
import { CreateCallbackRequest, OIDCService } from "@zitadel/proto/zitadel/oidc/v2/oidc_service_pb";
import { Organization } from "@zitadel/proto/zitadel/org/v2/org_pb";
import { OrganizationService } from "@zitadel/proto/zitadel/org/v2/org_service_pb";
import {
CreateResponseRequest,
SAMLService,
} from "@zitadel/proto/zitadel/saml/v2/saml_service_pb";
import { CreateResponseRequest, SAMLService } from "@zitadel/proto/zitadel/saml/v2/saml_service_pb";
import { RequestChallenges } from "@zitadel/proto/zitadel/session/v2/challenge_pb";
import {
Checks,
SessionService,
} from "@zitadel/proto/zitadel/session/v2/session_service_pb";
import { Checks, SessionService } from "@zitadel/proto/zitadel/session/v2/session_service_pb";
import { LoginSettings } from "@zitadel/proto/zitadel/settings/v2/login_settings_pb";
import { SettingsService } from "@zitadel/proto/zitadel/settings/v2/settings_service_pb";
import { SendEmailVerificationCodeSchema } from "@zitadel/proto/zitadel/user/v2/email_pb";
import type {
FormData,
RedirectURLsJson,
} from "@zitadel/proto/zitadel/user/v2/idp_pb";
import {
NotificationType,
SendPasswordResetLinkSchema,
} from "@zitadel/proto/zitadel/user/v2/password_pb";
import {
SearchQuery,
SearchQuerySchema,
} from "@zitadel/proto/zitadel/user/v2/query_pb";
import type { FormData, RedirectURLsJson } from "@zitadel/proto/zitadel/user/v2/idp_pb";
import { NotificationType, SendPasswordResetLinkSchema } from "@zitadel/proto/zitadel/user/v2/password_pb";
import { SearchQuery, SearchQuerySchema } from "@zitadel/proto/zitadel/user/v2/query_pb";
import { SendInviteCodeSchema } from "@zitadel/proto/zitadel/user/v2/user_pb";
import {
AddHumanUserRequest,
@@ -73,8 +52,7 @@ export async function getHostedLoginTranslation({
organization?: string;
locale?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getHostedLoginTranslation(
@@ -99,15 +77,8 @@ export async function getHostedLoginTranslation({
return useCache ? cacheWrapper(callback) : callback;
}
export async function getBrandingSettings({
serviceUrl,
organization,
}: {
serviceUrl: string;
organization?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
export async function getBrandingSettings({ serviceUrl, organization }: { serviceUrl: string; organization?: string }) {
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getBrandingSettings({ ctx: makeReqCtx(organization) }, {})
@@ -116,15 +87,8 @@ export async function getBrandingSettings({
return useCache ? cacheWrapper(callback) : callback;
}
export async function getLoginSettings({
serviceUrl,
organization,
}: {
serviceUrl: string;
organization?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
export async function getLoginSettings({ serviceUrl, organization }: { serviceUrl: string; organization?: string }) {
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getLoginSettings({ ctx: makeReqCtx(organization) }, {})
@@ -133,30 +97,16 @@ export async function getLoginSettings({
return useCache ? cacheWrapper(callback) : callback;
}
export async function getSecuritySettings({
serviceUrl,
}: {
serviceUrl: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
export async function getSecuritySettings({ serviceUrl }: { serviceUrl: string }) {
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getSecuritySettings({})
.then((resp) => (resp.settings ? resp.settings : undefined));
const callback = settingsService.getSecuritySettings({}).then((resp) => (resp.settings ? resp.settings : undefined));
return useCache ? cacheWrapper(callback) : callback;
}
export async function getLockoutSettings({
serviceUrl,
orgId,
}: {
serviceUrl: string;
orgId?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
export async function getLockoutSettings({ serviceUrl, orgId }: { serviceUrl: string; orgId?: string }) {
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getLockoutSettings({ ctx: makeReqCtx(orgId) }, {})
@@ -165,15 +115,8 @@ export async function getLockoutSettings({
return useCache ? cacheWrapper(callback) : callback;
}
export async function getPasswordExpirySettings({
serviceUrl,
orgId,
}: {
serviceUrl: string;
orgId?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
export async function getPasswordExpirySettings({ serviceUrl, orgId }: { serviceUrl: string; orgId?: string }) {
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getPasswordExpirySettings({ ctx: makeReqCtx(orgId) }, {})
@@ -182,77 +125,34 @@ export async function getPasswordExpirySettings({
return useCache ? cacheWrapper(callback) : callback;
}
export async function listIDPLinks({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function listIDPLinks({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.listIDPLinks({ userId }, {});
}
export async function addOTPEmail({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function addOTPEmail({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.addOTPEmail({ userId }, {});
}
export async function addOTPSMS({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function addOTPSMS({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.addOTPSMS({ userId }, {});
}
export async function registerTOTP({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function registerTOTP({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.registerTOTP({ userId }, {});
}
export async function getGeneralSettings({
serviceUrl,
}: {
serviceUrl: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
export async function getGeneralSettings({ serviceUrl }: { serviceUrl: string }) {
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getGeneralSettings({}, {})
.then((resp) => resp.supportedLanguages);
const callback = settingsService.getGeneralSettings({}, {}).then((resp) => resp.supportedLanguages);
return useCache ? cacheWrapper(callback) : callback;
}
@@ -264,8 +164,7 @@ export async function getLegalAndSupportSettings({
serviceUrl: string;
organization?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getLegalAndSupportSettings({ ctx: makeReqCtx(organization) }, {})
@@ -281,8 +180,7 @@ export async function getPasswordComplexitySettings({
serviceUrl: string;
organization?: string;
}) {
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
const callback = settingsService
.getPasswordComplexitySettings({ ctx: makeReqCtx(organization) })
@@ -300,8 +198,7 @@ export async function createSessionFromChecks({
checks: Checks;
lifetime: Duration;
}) {
const sessionService: Client<typeof SessionService> =
await createServiceForHost(SessionService, serviceUrl);
const sessionService: Client<typeof SessionService> = await createServiceForHost(SessionService, serviceUrl);
const userAgent = await getUserAgent();
@@ -322,8 +219,8 @@ export async function createSessionForUserIdAndIdpIntent({
};
lifetime: Duration;
}) {
const sessionService: Client<typeof SessionService> =
await createServiceForHost(SessionService, serviceUrl);
console.log("Creating session for userId and IDP intent", { userId, idpIntent, lifetime });
const sessionService: Client<typeof SessionService> = await createServiceForHost(SessionService, serviceUrl);
const userAgent = await getUserAgent();
@@ -357,8 +254,7 @@ export async function setSession({
checks?: Checks;
lifetime: Duration;
}) {
const sessionService: Client<typeof SessionService> =
await createServiceForHost(SessionService, serviceUrl);
const sessionService: Client<typeof SessionService> = await createServiceForHost(SessionService, serviceUrl);
return sessionService.setSession(
{
@@ -382,8 +278,7 @@ export async function getSession({
sessionId: string;
sessionToken: string;
}) {
const sessionService: Client<typeof SessionService> =
await createServiceForHost(SessionService, serviceUrl);
const sessionService: Client<typeof SessionService> = await createServiceForHost(SessionService, serviceUrl);
return sessionService.getSession({ sessionId, sessionToken }, {});
}
@@ -397,8 +292,7 @@ export async function deleteSession({
sessionId: string;
sessionToken: string;
}) {
const sessionService: Client<typeof SessionService> =
await createServiceForHost(SessionService, serviceUrl);
const sessionService: Client<typeof SessionService> = await createServiceForHost(SessionService, serviceUrl);
return sessionService.deleteSession({ sessionId, sessionToken }, {});
}
@@ -409,8 +303,7 @@ type ListSessionsCommand = {
};
export async function listSessions({ serviceUrl, ids }: ListSessionsCommand) {
const sessionService: Client<typeof SessionService> =
await createServiceForHost(SessionService, serviceUrl);
const sessionService: Client<typeof SessionService> = await createServiceForHost(SessionService, serviceUrl);
return sessionService.listSessions(
{
@@ -436,36 +329,21 @@ export type AddHumanUserData = {
organization: string;
};
export async function addHumanUser({
serviceUrl,
email,
firstName,
lastName,
password,
organization,
}: AddHumanUserData) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function addHumanUser({ serviceUrl, email, firstName, lastName, password, organization }: AddHumanUserData) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
let addHumanUserRequest: AddHumanUserRequest = create(
AddHumanUserRequestSchema,
{
email: {
email,
verification: {
case: "isVerified",
value: false,
},
let addHumanUserRequest: AddHumanUserRequest = create(AddHumanUserRequestSchema, {
email: {
email,
verification: {
case: "isVerified",
value: false,
},
username: email,
profile: { givenName: firstName, familyName: lastName },
passwordType: password
? { case: "password", value: { password } }
: undefined,
},
);
username: email,
profile: { givenName: firstName, familyName: lastName },
passwordType: password ? { case: "password", value: { password } } : undefined,
});
if (organization) {
const organizationSchema = create(OrganizationSchema, {
@@ -481,32 +359,14 @@ export async function addHumanUser({
return userService.addHumanUser(addHumanUserRequest);
}
export async function addHuman({
serviceUrl,
request,
}: {
serviceUrl: string;
request: AddHumanUserRequest;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function addHuman({ serviceUrl, request }: { serviceUrl: string; request: AddHumanUserRequest }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.addHumanUser(request);
}
export async function updateHuman({
serviceUrl,
request,
}: {
serviceUrl: string;
request: UpdateHumanUserRequest;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function updateHuman({ serviceUrl, request }: { serviceUrl: string; request: UpdateHumanUserRequest }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.updateHumanUser(request);
}
@@ -520,40 +380,19 @@ export async function verifyTOTPRegistration({
code: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.verifyTOTPRegistration({ code, userId }, {});
}
export async function getUserByID({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function getUserByID({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.getUserByID({ userId }, {});
}
export async function humanMFAInitSkipped({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function humanMFAInitSkipped({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.humanMFAInitSkipped({ userId }, {});
}
@@ -567,10 +406,7 @@ export async function verifyInviteCode({
userId: string;
verificationCode: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.verifyInviteCode({ userId, verificationCode }, {});
}
@@ -596,10 +432,7 @@ export async function sendEmailCode({
},
});
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.sendEmailCode(medium, {});
}
@@ -622,10 +455,7 @@ export async function createInviteCode({
urlTemplate,
};
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.createInviteCode(
{
@@ -648,14 +478,7 @@ export type ListUsersCommand = {
organizationId?: string;
};
export async function listUsers({
serviceUrl,
loginName,
userName,
phone,
email,
organizationId,
}: ListUsersCommand) {
export async function listUsers({ serviceUrl, loginName, userName, phone, email, organizationId }: ListUsersCommand) {
const queries: SearchQuery[] = [];
// either use loginName or userName, email, phone
@@ -738,10 +561,7 @@ export async function listUsers({
);
}
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.listUsers({ queries });
}
@@ -791,13 +611,7 @@ const EmailQuery = (searchValue: string) =>
* this is a dedicated search function to search for users from the loginname page
* it searches users based on the loginName or userName and org suffix combination, and falls back to email and phone if no users are found
* */
export async function searchUsers({
serviceUrl,
searchValue,
loginSettings,
organizationId,
suffix,
}: SearchUsersCommand) {
export async function searchUsers({ serviceUrl, searchValue, loginSettings, organizationId, suffix }: SearchUsersCommand) {
const queries: SearchQuery[] = [];
// if a suffix is provided, we search for the userName concatenated with the suffix
@@ -823,10 +637,7 @@ export async function searchUsers({
);
}
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
const loginNameResult = await userService.listUsers({ queries });
@@ -843,10 +654,7 @@ export async function searchUsers({
}
const emailAndPhoneQueries: SearchQuery[] = [];
if (
loginSettings.disableLoginWithEmail &&
loginSettings.disableLoginWithPhone
) {
if (loginSettings.disableLoginWithEmail && loginSettings.disableLoginWithPhone) {
return { error: "User not found in the system" };
} else if (loginSettings.disableLoginWithEmail && searchValue.length <= 20) {
const phoneQuery = PhoneQuery(searchValue);
@@ -910,13 +718,8 @@ export async function searchUsers({
return { error: "User not found in the system" };
}
export async function getDefaultOrg({
serviceUrl,
}: {
serviceUrl: string;
}): Promise<Organization | null> {
const orgService: Client<typeof OrganizationService> =
await createServiceForHost(OrganizationService, serviceUrl);
export async function getDefaultOrg({ serviceUrl }: { serviceUrl: string }): Promise<Organization | null> {
const orgService: Client<typeof OrganizationService> = await createServiceForHost(OrganizationService, serviceUrl);
return orgService
.listOrganizations(
@@ -935,15 +738,8 @@ export async function getDefaultOrg({
.then((resp) => (resp?.result && resp.result[0] ? resp.result[0] : null));
}
export async function getOrgsByDomain({
serviceUrl,
domain,
}: {
serviceUrl: string;
domain: string;
}) {
const orgService: Client<typeof OrganizationService> =
await createServiceForHost(OrganizationService, serviceUrl);
export async function getOrgsByDomain({ serviceUrl, domain }: { serviceUrl: string; domain: string }) {
const orgService: Client<typeof OrganizationService> = await createServiceForHost(OrganizationService, serviceUrl);
return orgService.listOrganizations(
{
@@ -969,10 +765,7 @@ export async function startIdentityProviderFlow({
idpId: string;
urls: RedirectURLsJson;
}): Promise<string | null> {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService
.startIdentityProviderIntent({
@@ -999,10 +792,7 @@ export async function startIdentityProviderFlow({
});
const stringifiedFields = JSON.stringify(formData.fields);
console.log(
"Successfully stringified formData.fields, length:",
stringifiedFields.length,
);
console.log("Successfully stringified formData.fields, length:", stringifiedFields.length);
// Check cookie size limits (typical limit is 4KB)
if (stringifiedFields.length > 4000) {
@@ -1038,10 +828,7 @@ export async function startLDAPIdentityProviderFlow({
username: string;
password: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.startIdentityProviderIntent({
idpId,
@@ -1055,13 +842,7 @@ export async function startLDAPIdentityProviderFlow({
});
}
export async function getAuthRequest({
serviceUrl,
authRequestId,
}: {
serviceUrl: string;
authRequestId: string;
}) {
export async function getAuthRequest({ serviceUrl, authRequestId }: { serviceUrl: string; authRequestId: string }) {
const oidcService = await createServiceForHost(OIDCService, serviceUrl);
return oidcService.getAuthRequest({
@@ -1069,13 +850,7 @@ export async function getAuthRequest({
});
}
export async function getDeviceAuthorizationRequest({
serviceUrl,
userCode,
}: {
serviceUrl: string;
userCode: string;
}) {
export async function getDeviceAuthorizationRequest({ serviceUrl, userCode }: { serviceUrl: string; userCode: string }) {
const oidcService = await createServiceForHost(OIDCService, serviceUrl);
return oidcService.getDeviceAuthorizationRequest({
@@ -1108,25 +883,13 @@ export async function authorizeOrDenyDeviceAuthorization({
});
}
export async function createCallback({
serviceUrl,
req,
}: {
serviceUrl: string;
req: CreateCallbackRequest;
}) {
export async function createCallback({ serviceUrl, req }: { serviceUrl: string; req: CreateCallbackRequest }) {
const oidcService = await createServiceForHost(OIDCService, serviceUrl);
return oidcService.createCallback(req);
}
export async function getSAMLRequest({
serviceUrl,
samlRequestId,
}: {
serviceUrl: string;
samlRequestId: string;
}) {
export async function getSAMLRequest({ serviceUrl, samlRequestId }: { serviceUrl: string; samlRequestId: string }) {
const samlService = await createServiceForHost(SAMLService, serviceUrl);
return samlService.getSAMLRequest({
@@ -1134,13 +897,7 @@ export async function getSAMLRequest({
});
}
export async function createResponse({
serviceUrl,
req,
}: {
serviceUrl: string;
req: CreateResponseRequest;
}) {
export async function createResponse({ serviceUrl, req }: { serviceUrl: string; req: CreateResponseRequest }) {
const samlService = await createServiceForHost(SAMLService, serviceUrl);
return samlService.createResponse(req);
@@ -1155,10 +912,7 @@ export async function verifyEmail({
userId: string;
verificationCode: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.verifyEmail(
{
@@ -1188,43 +942,19 @@ export async function resendEmailCode({
request = { ...request, verification: { case: "sendCode", value: medium } };
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.resendEmailCode(request, {});
}
export async function retrieveIDPIntent({
serviceUrl,
id,
token,
}: {
serviceUrl: string;
id: string;
token: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function retrieveIDPIntent({ serviceUrl, id, token }: { serviceUrl: string; id: string; token: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.retrieveIdentityProviderIntent(
{ idpIntentId: id, idpIntentToken: token },
{},
);
return userService.retrieveIdentityProviderIntent({ idpIntentId: id, idpIntentToken: token }, {});
}
export async function getIDPByID({
serviceUrl,
id,
}: {
serviceUrl: string;
id: string;
}) {
const idpService: Client<typeof IdentityProviderService> =
await createServiceForHost(IdentityProviderService, serviceUrl);
export async function getIDPByID({ serviceUrl, id }: { serviceUrl: string; id: string }) {
const idpService: Client<typeof IdentityProviderService> = await createServiceForHost(IdentityProviderService, serviceUrl);
return idpService.getIDPByID({ id }, {}).then((resp) => resp.idp);
}
@@ -1238,10 +968,7 @@ export async function addIDPLink({
idp: { id: string; userId: string; userName: string };
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.addIDPLink(
{
@@ -1274,10 +1001,7 @@ export async function passwordReset({
urlTemplate,
};
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.passwordReset(
{
@@ -1319,10 +1043,7 @@ export async function setUserPassword({
};
}
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.setPassword(payload, {}).catch((error) => {
// throw error if failed precondition (ex. User is not yet initialized)
@@ -1334,17 +1055,8 @@ export async function setUserPassword({
});
}
export async function setPassword({
serviceUrl,
payload,
}: {
serviceUrl: string;
payload: SetPasswordRequest;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function setPassword({ serviceUrl, payload }: { serviceUrl: string; payload: SetPasswordRequest }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.setPassword(payload, {});
}
@@ -1355,17 +1067,8 @@ export async function setPassword({
* @param userId the id of the user where the email should be set
* @returns the newly set email
*/
export async function createPasskeyRegistrationLink({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function createPasskeyRegistrationLink({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.createPasskeyRegistrationLink({
userId,
@@ -1383,19 +1086,8 @@ export async function createPasskeyRegistrationLink({
* @param domain the domain on which the factor is registered
* @returns the newly set email
*/
export async function registerU2F({
serviceUrl,
userId,
domain,
}: {
serviceUrl: string;
userId: string;
domain: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function registerU2F({ serviceUrl, userId, domain }: { serviceUrl: string; userId: string; domain: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.registerU2F({
userId,
@@ -1416,10 +1108,7 @@ export async function verifyU2FRegistration({
serviceUrl: string;
request: VerifyU2FRegistrationRequest;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.verifyU2FRegistration(request, {});
}
@@ -1444,8 +1133,7 @@ export async function getActiveIdentityProviders({
if (linking_allowed) {
props.linkingAllowed = linking_allowed;
}
const settingsService: Client<typeof SettingsService> =
await createServiceForHost(SettingsService, serviceUrl);
const settingsService: Client<typeof SettingsService> = await createServiceForHost(SettingsService, serviceUrl);
return settingsService.getActiveIdentityProviders(props, {});
}
@@ -1463,10 +1151,7 @@ export async function verifyPasskeyRegistration({
serviceUrl: string;
request: VerifyPasskeyRegistrationRequest;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.verifyPasskeyRegistration(request, {});
}
@@ -1490,10 +1175,7 @@ export async function registerPasskey({
code: { id: string; code: string };
domain: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.registerPasskey({
userId,
@@ -1508,17 +1190,8 @@ export async function registerPasskey({
* @param userId the id of the user where the email should be set
* @returns the list of authentication method types
*/
export async function listAuthenticationMethodTypes({
serviceUrl,
userId,
}: {
serviceUrl: string;
userId: string;
}) {
const userService: Client<typeof UserService> = await createServiceForHost(
UserService,
serviceUrl,
);
export async function listAuthenticationMethodTypes({ serviceUrl, userId }: { serviceUrl: string; userId: string }) {
const userService: Client<typeof UserService> = await createServiceForHost(UserService, serviceUrl);
return userService.listAuthenticationMethodTypes({
userId,
@@ -1533,16 +1206,14 @@ export function createServerTransport(token: string, baseUrl: string) {
: [
(next) => {
return (req) => {
process.env
.CUSTOM_REQUEST_HEADERS!.split(",")
.forEach((header) => {
const kv = header.split(":");
if (kv.length === 2) {
req.header.set(kv[0].trim(), kv[1].trim());
} else {
console.warn(`Skipping malformed header: ${header}`);
}
});
process.env.CUSTOM_REQUEST_HEADERS!.split(",").forEach((header) => {
const kv = header.split(":");
if (kv.length === 2) {
req.header.set(kv[0].trim(), kv[1].trim());
} else {
console.warn(`Skipping malformed header: ${header}`);
}
});
return next(req);
};
},