diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index 13e7c0dee7..c864c650a7 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -25,6 +25,7 @@ env: internal/api/assets/router.go openapi/v2 pkg/grpc/**/*.pb.* + pkg/grpc/**/*.connect.go jobs: build: diff --git a/.gitignore b/.gitignore index 23469d4209..0aa6cc1976 100644 --- a/.gitignore +++ b/.gitignore @@ -52,7 +52,8 @@ console/src/app/proto/generated/ !pkg/grpc/protoc/v2/options.pb.go **.proto.mock.go **.pb.*.go -**.gen.go +pkg/**/**.connect.go +**.gen.go openapi/**/*.json /internal/api/assets/authz.go /internal/api/assets/router.go diff --git a/Makefile b/Makefile index 10f52b7c4c..3bad5aa1c6 100644 --- a/Makefile +++ b/Makefile @@ -78,12 +78,13 @@ core_grpc_dependencies: go install github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2@v2.22.0 # https://pkg.go.dev/github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2?tab=versions go install github.com/envoyproxy/protoc-gen-validate@v1.1.0 # https://pkg.go.dev/github.com/envoyproxy/protoc-gen-validate?tab=versions go install github.com/bufbuild/buf/cmd/buf@v1.45.0 # https://pkg.go.dev/github.com/bufbuild/buf/cmd/buf?tab=versions + go install connectrpc.com/connect/cmd/protoc-gen-connect-go@v1.18.1 # https://pkg.go.dev/connectrpc.com/connect/cmd/protoc-gen-connect-go?tab=versions .PHONY: core_api core_api: core_api_generator core_grpc_dependencies buf generate mkdir -p pkg/grpc - cp -r .artifacts/grpc/github.com/zitadel/zitadel/pkg/grpc/* pkg/grpc/ + cp -r .artifacts/grpc/github.com/zitadel/zitadel/pkg/grpc/** pkg/grpc/ mkdir -p openapi/v2/zitadel cp -r .artifacts/grpc/zitadel/ openapi/v2/zitadel diff --git a/buf.gen.yaml b/buf.gen.yaml index 858a1e6404..5a29ba9cd3 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -19,3 +19,5 @@ plugins: out: .artifacts/grpc - plugin: zitadel out: .artifacts/grpc + - plugin: connect-go + out: .artifacts/grpc diff --git a/cmd/start/start.go b/cmd/start/start.go index 06f3554a58..50bb9fbdb3 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -59,7 +59,8 @@ import ( "github.com/zitadel/zitadel/internal/api/grpc/system" user_v2 "github.com/zitadel/zitadel/internal/api/grpc/user/v2" user_v2beta "github.com/zitadel/zitadel/internal/api/grpc/user/v2beta" - webkey "github.com/zitadel/zitadel/internal/api/grpc/webkey/v2beta" + webkey_v2 "github.com/zitadel/zitadel/internal/api/grpc/webkey/v2" + webkey_v2beta "github.com/zitadel/zitadel/internal/api/grpc/webkey/v2beta" http_util "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/api/idp" @@ -515,7 +516,10 @@ func startAPIs( if err := apis.RegisterService(ctx, user_v3_alpha.CreateServer(commands)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, webkey.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, webkey_v2beta.CreateServer(commands, queries)); err != nil { + return nil, err + } + if err := apis.RegisterService(ctx, webkey_v2.CreateServer(commands, queries)); err != nil { return nil, err } if err := apis.RegisterService(ctx, debug_events.CreateServer(commands, queries)); err != nil { diff --git a/docs/.gitignore b/docs/.gitignore index bd99d98c6f..e894d20ec6 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -27,3 +27,4 @@ npm-debug.log* yarn-debug.log* yarn-error.log* .vercel +/protoc-gen-connect-openapi* diff --git a/docs/base.yaml b/docs/base.yaml new file mode 100644 index 0000000000..dc5b9aa0f9 --- /dev/null +++ b/docs/base.yaml @@ -0,0 +1,3 @@ +openapi: 3.1.0 +info: + version: v2 \ No newline at end of file diff --git a/docs/buf.gen.yaml b/docs/buf.gen.yaml index a628f6e748..b507a2fb9c 100644 --- a/docs/buf.gen.yaml +++ b/docs/buf.gen.yaml @@ -1,11 +1,18 @@ # buf.gen.yaml -version: v1 +version: v2 managed: enabled: true plugins: - - plugin: buf.build/grpc-ecosystem/openapiv2 + - remote: buf.build/grpc-ecosystem/openapiv2 out: .artifacts/openapi opt: - allow_delete_body - remove_internal_comments=true - preserve_rpc_order=true + - local: ./protoc-gen-connect-openapi + out: .artifacts/openapi3 + strategy: all + opt: + - short-service-tags + - ignore-googleapi-http + - base=base.yaml diff --git a/docs/docusaurus.config.js b/docs/docusaurus.config.js index abf5c742a5..ffca8b21de 100644 --- a/docs/docusaurus.config.js +++ b/docs/docusaurus.config.js @@ -337,7 +337,7 @@ module.exports = { }, webkey_v2: { specPath: - ".artifacts/openapi/zitadel/webkey/v2beta/webkey_service.swagger.json", + ".artifacts/openapi3/zitadel/webkey/v2/webkey_service.openapi.yaml", outputDir: "docs/apis/resources/webkey_service_v2", sidebarOptions: { groupPathsBy: "tag", @@ -373,7 +373,7 @@ module.exports = { }, org_v2beta: { specPath: - ".artifacts/openapi/zitadel/org/v2beta/org_service.swagger.json", + ".artifacts/openapi3/zitadel/org/v2beta/org_service.openapi.yaml", outputDir: "docs/apis/resources/org_service_v2beta", sidebarOptions: { groupPathsBy: "tag", @@ -382,16 +382,24 @@ module.exports = { }, project_v2beta: { specPath: - ".artifacts/openapi/zitadel/project/v2beta/project_service.swagger.json", + ".artifacts/openapi3/zitadel/project/v2beta/project_service.openapi.yaml", outputDir: "docs/apis/resources/project_service_v2", sidebarOptions: { groupPathsBy: "tag", categoryLinkSource: "auto", }, }, + application_v2: { + specPath: ".artifacts/openapi3/zitadel/app/v2beta/app_service.openapi.yaml", + outputDir: "docs/apis/resources/application_service_v2", + sidebarOptions: { + groupPathsBy: "tag", + categoryLinkSource: "auto", + }, + }, instance_v2: { specPath: - ".artifacts/openapi/zitadel/instance/v2beta/instance_service.swagger.json", + ".artifacts/openapi3/zitadel/instance/v2beta/instance_service.openapi.yaml", outputDir: "docs/apis/resources/instance_service_v2", sidebarOptions: { groupPathsBy: "tag", diff --git a/docs/package.json b/docs/package.json index 2e1214f378..f799a5e76f 100644 --- a/docs/package.json +++ b/docs/package.json @@ -18,7 +18,8 @@ "generate:apidocs": "docusaurus gen-api-docs all", "generate:configdocs": "cp -r ../cmd/defaults.yaml ./docs/self-hosting/manage/configure/ && cp -r ../cmd/setup/steps.yaml ./docs/self-hosting/manage/configure/", "generate:re-gen": "yarn generate:clean-all && yarn generate", - "generate:clean-all": "docusaurus clean-api-docs all" + "generate:clean-all": "docusaurus clean-api-docs all", + "postinstall": "sh ./plugin-download.sh" }, "dependencies": { "@bufbuild/buf": "^1.14.0", diff --git a/docs/plugin-download.sh b/docs/plugin-download.sh new file mode 100644 index 0000000000..c6de8d702f --- /dev/null +++ b/docs/plugin-download.sh @@ -0,0 +1,21 @@ +echo $(uname -m) + +if [ "$(uname)" = "Darwin" ]; then + curl -L -o protoc-gen-connect-openapi.tar.gz https://github.com/sudorandom/protoc-gen-connect-openapi/releases/download/v0.18.0/protoc-gen-connect-openapi_0.18.0_darwin_all.tar.gz +else + ARCH=$(uname -m) + case $ARCH in + x86_64) + ARCH="amd64" + ;; + aarch64|arm64) + ARCH="arm64" + ;; + *) + echo "Unsupported architecture: $ARCH" + exit 1 + ;; + esac + curl -L -o protoc-gen-connect-openapi.tar.gz https://github.com/sudorandom/protoc-gen-connect-openapi/releases/download/v0.18.0/protoc-gen-connect-openapi_0.18.0_linux_${ARCH}.tar.gz +fi +tar -xvf protoc-gen-connect-openapi.tar.gz \ No newline at end of file diff --git a/docs/sidebars.js b/docs/sidebars.js index fe77ea0af2..a0de30271d 100644 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -16,6 +16,7 @@ const sidebar_api_actions_v2 = require("./docs/apis/resources/action_service_v2/ const sidebar_api_project_service_v2 = require("./docs/apis/resources/project_service_v2/sidebar.ts").default const sidebar_api_webkey_service_v2 = require("./docs/apis/resources/webkey_service_v2/sidebar.ts").default const sidebar_api_instance_service_v2 = require("./docs/apis/resources/instance_service_v2/sidebar.ts").default +const sidebar_api_app_v2 = require("./docs/apis/resources/application_service_v2/sidebar.ts").default module.exports = { guides: [ @@ -806,6 +807,18 @@ module.exports = { }, items: sidebar_api_org_service_v2, }, + { + type: "category", + label: "Organization (Beta)", + link: { + type: "generated-index", + title: "Organization Service beta API", + slug: "/apis/resources/org_service/v2beta", + description: + "This API is intended to manage organizations for ZITADEL. \n", + }, + items: sidebar_api_org_service_v2beta, + }, { type: "category", label: "Identity Provider", @@ -820,19 +833,15 @@ module.exports = { }, { type: "category", - label: "Web key (Beta)", + label: "Web Key", link: { type: "generated-index", - title: "Web Key Service API (Beta)", + title: "Web Key Service API", slug: "/apis/resources/webkey_service_v2", description: "This API is intended to manage web keys for a ZITADEL instance, used to sign and validate OIDC tokens.\n" + - "\n" + - "This service is in beta state. It can AND will continue breaking until a stable version is released.\n"+ "\n"+ - "The public key endpoint (outside of this service) is used to retrieve the public keys of the active and inactive keys.\n"+ - "\n"+ - "Please make sure to enable the `web_key` feature flag on your instance to use this service and that you're running ZITADEL V3.", + "The public key endpoint (outside of this service) is used to retrieve the public keys of the active and inactive keys.\n", }, items: sidebar_api_webkey_service_v2 }, @@ -857,6 +866,54 @@ module.exports = { }, items: sidebar_api_actions_v2, }, + { + type: "category", + label: "Project (Beta)", + link: { + type: "generated-index", + title: "Project Service API (Beta)", + slug: "/apis/resources/project_service_v2", + description: + "This API is intended to manage projects and subresources for ZITADEL. \n" + + "\n" + + "This service is in beta state. It can AND will continue breaking until a stable version is released.", + }, + items: sidebar_api_project_service_v2, + }, + { + type: "category", + label: "Instance (Beta)", + link: { + type: "generated-index", + title: "Instance Service API (Beta)", + slug: "/apis/resources/instance_service_v2", + description: + "This API is intended to manage instances, custom domains and trusted domains in ZITADEL.\n" + + "\n" + + "This service is in beta state. It can AND will continue breaking until a stable version is released.\n"+ + "\n" + + "This v2 of the API provides the same functionalities as the v1, but organised on a per resource basis.\n" + + "The whole functionality related to domains (custom and trusted) has been moved under this instance API." + , + }, + items: sidebar_api_instance_service_v2, + }, + { + type: "category", + label: "App (Beta)", + link: { + type: "generated-index", + title: "Application Service API (Beta)", + slug: "/apis/resources/application_service_v2", + description: + "This API lets you manage Zitadel applications (API, SAML, OIDC).\n"+ + "\n"+ + "The API offers generic endpoints that work for all app types (API, SAML, OIDC), "+ + "\n"+ + "This API is in beta state. It can AND will continue breaking until a stable version is released.\n" + }, + items: sidebar_api_app_v2, + }, ], }, { diff --git a/docs/yarn.lock b/docs/yarn.lock index c48c5b8bd6..307577b44e 100644 --- a/docs/yarn.lock +++ b/docs/yarn.lock @@ -6121,6 +6121,11 @@ caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001718: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001724.tgz#312e163553dd70d2c0fb603d74810c85d8ed94a0" integrity sha512-WqJo7p0TbHDOythNTqYujmaJTvtYRZrjpP8TCvH6Vb9CYJerJNKamKzIWOM4BkQatWj9H2lYulpdAQNBe7QhNA== +caniuse-lite@^1.0.30001716: + version "1.0.30001726" + resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001726.tgz#a15bd87d5a4bf01f6b6f70ae7c97fdfd28b5ae47" + integrity sha512-VQAUIUzBiZ/UnlM28fSp2CRF3ivUn1BWEvxMcVTNwpw91Py1pGbPIyIKtd+tzct9C3ouceCVdGAXxZOpZAsgdw== + ccount@^2.0.0: version "2.0.1" resolved "https://registry.yarnpkg.com/ccount/-/ccount-2.0.1.tgz#17a3bf82302e0870d6da43a01311a8bc02a3ecf5" @@ -7503,6 +7508,11 @@ electron-to-chromium@^1.4.796: resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.4.803.tgz#cf55808a5ee12e2a2778bbe8cdc941ef87c2093b" integrity sha512-61H9mLzGOCLLVsnLiRzCbc63uldP0AniRYPV3hbGVtONA1pI7qSGILdbofR7A8TMbOypDocEAjH/e+9k1QIe3g== +electron-to-chromium@^1.5.149: + version "1.5.178" + resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.5.178.tgz#6fc4d69eb5275bb13068931448fd822458901fbb" + integrity sha512-wObbz/ar3Bc6e4X5vf0iO8xTN8YAjN/tgiAOJLr7yjYFtP9wAjq8Mb5h0yn6kResir+VYx2DXBj9NNobs0ETSA== + electron-to-chromium@^1.5.160: version "1.5.172" resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.5.172.tgz#fe1d99028d8d6321668d0f1fed61d99ac896259c" diff --git a/go.mod b/go.mod index ee7fb0a33a..22980acfaf 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ toolchain go1.24.1 require ( cloud.google.com/go/profiler v0.4.2 cloud.google.com/go/storage v1.54.0 + connectrpc.com/connect v1.18.1 + connectrpc.com/grpcreflect v1.3.0 dario.cat/mergo v1.0.2 github.com/BurntSushi/toml v1.5.0 github.com/DATA-DOG/go-sqlmock v1.5.2 diff --git a/go.sum b/go.sum index 01acf27c5d..7221111a2b 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,10 @@ cloud.google.com/go/storage v1.54.0 h1:Du3XEyliAiftfyW0bwfdppm2MMLdpVAfiIg4T2nAI cloud.google.com/go/storage v1.54.0/go.mod h1:hIi9Boe8cHxTyaeqh7KMMwKg088VblFK46C2x/BWaZE= cloud.google.com/go/trace v1.11.3 h1:c+I4YFjxRQjvAhRmSsmjpASUKq88chOX854ied0K/pE= cloud.google.com/go/trace v1.11.3/go.mod h1:pt7zCYiDSQjC9Y2oqCsh9jF4GStB/hmjrYLsxRR27q8= +connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= +connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= +connectrpc.com/grpcreflect v1.3.0 h1:Y4V+ACf8/vOb1XOc251Qun7jMB75gCUNw6llvB9csXc= +connectrpc.com/grpcreflect v1.3.0/go.mod h1:nfloOtCS8VUQOQ1+GTdFzVg2CJo4ZGaat8JIovCtDYs= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= diff --git a/internal/api/api.go b/internal/api/api.go index 62d3e14b35..349e9186bc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -7,16 +7,18 @@ import ( "sort" "strings" + "connectrpc.com/grpcreflect" "github.com/gorilla/mux" "github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/zitadel/logging" "google.golang.org/grpc" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/reflection" "github.com/zitadel/zitadel/internal/api/authz" + grpc_api "github.com/zitadel/zitadel/internal/api/grpc" "github.com/zitadel/zitadel/internal/api/grpc/server" + "github.com/zitadel/zitadel/internal/api/grpc/server/connect_middleware" http_util "github.com/zitadel/zitadel/internal/api/http" http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/api/ui/login" @@ -24,10 +26,16 @@ import ( "github.com/zitadel/zitadel/internal/telemetry/metrics" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" + system_pb "github.com/zitadel/zitadel/pkg/grpc/system" +) + +var ( + metricTypes = []metrics.MetricType{metrics.MetricTypeTotalCount, metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode} ) type API struct { port uint16 + externalDomain string grpcServer *grpc.Server verifier authz.APITokenVerifier health healthCheck @@ -37,16 +45,23 @@ type API struct { healthServer *health.Server accessInterceptor *http_mw.AccessInterceptor queries *query.Queries + authConfig authz.Config + systemAuthZ authz.Config + connectServices map[string][]string } func (a *API) ListGrpcServices() []string { serviceInfo := a.grpcServer.GetServiceInfo() - services := make([]string, len(serviceInfo)) + services := make([]string, len(serviceInfo)+len(a.connectServices)) i := 0 for servicename := range serviceInfo { services[i] = servicename i++ } + for prefix := range a.connectServices { + services[i] = strings.Trim(prefix, "/") + i++ + } sort.Strings(services) return services } @@ -59,6 +74,11 @@ func (a *API) ListGrpcMethods() []string { methods = append(methods, "/"+servicename+"/"+method.Name) } } + for service, methodList := range a.connectServices { + for _, method := range methodList { + methods = append(methods, service+method) + } + } sort.Strings(methods) return methods } @@ -82,12 +102,16 @@ func New( ) (_ *API, err error) { api := &API{ port: port, + externalDomain: externalDomain, verifier: verifier, health: queries, router: router, queries: queries, accessInterceptor: accessInterceptor, hostHeaders: hostHeaders, + authConfig: authZ, + systemAuthZ: systemAuthz, + connectServices: make(map[string][]string), } api.grpcServer = server.CreateServer(api.verifier, systemAuthz, authZ, queries, externalDomain, tlsConfig, accessInterceptor.AccessService()) @@ -100,10 +124,15 @@ func New( api.RegisterHandlerOnPrefix("/debug", api.healthHandler()) api.router.Handle("/", http.RedirectHandler(login.HandlerPrefix, http.StatusFound)) - reflection.Register(api.grpcServer) return api, nil } +func (a *API) serverReflection() { + reflector := grpcreflect.NewStaticReflector(a.ListGrpcServices()...) + a.RegisterHandlerOnPrefix(grpcreflect.NewHandlerV1(reflector)) + a.RegisterHandlerOnPrefix(grpcreflect.NewHandlerV1Alpha(reflector)) +} + // RegisterServer registers a grpc service on the grpc server, // creates a new grpc gateway and registers it as a separate http handler // @@ -131,17 +160,50 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayP // and its gateway on the gateway handler // // used for >= v2 api (e.g. user, session, ...) -func (a *API) RegisterService(ctx context.Context, grpcServer server.Server) error { - grpcServer.RegisterServer(a.grpcServer) - err := server.RegisterGateway(ctx, a.grpcGateway, grpcServer) - if err != nil { - return err +func (a *API) RegisterService(ctx context.Context, srv server.Server) error { + switch service := srv.(type) { + case server.GrpcServer: + service.RegisterServer(a.grpcServer) + case server.ConnectServer: + a.registerConnectServer(service) } - a.verifier.RegisterServer(grpcServer.AppName(), grpcServer.MethodPrefix(), grpcServer.AuthMethods()) - a.healthServer.SetServingStatus(grpcServer.MethodPrefix(), healthpb.HealthCheckResponse_SERVING) + if withGateway, ok := srv.(server.WithGateway); ok { + err := server.RegisterGateway(ctx, a.grpcGateway, withGateway) + if err != nil { + return err + } + } + a.verifier.RegisterServer(srv.AppName(), srv.MethodPrefix(), srv.AuthMethods()) + a.healthServer.SetServingStatus(srv.MethodPrefix(), healthpb.HealthCheckResponse_SERVING) return nil } +func (a *API) registerConnectServer(service server.ConnectServer) { + prefix, handler := service.RegisterConnectServer( + connect_middleware.CallDurationHandler(), + connect_middleware.MetricsHandler(metricTypes, grpc_api.Probes...), + connect_middleware.NoCacheInterceptor(), + connect_middleware.InstanceInterceptor(a.queries, a.externalDomain, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName), + connect_middleware.AccessStorageInterceptor(a.accessInterceptor.AccessService()), + connect_middleware.ErrorHandler(), + connect_middleware.LimitsInterceptor(system_pb.SystemService_ServiceDesc.ServiceName), + connect_middleware.AuthorizationInterceptor(a.verifier, a.systemAuthZ, a.authConfig), + connect_middleware.TranslationHandler(), + connect_middleware.QuotaExhaustedInterceptor(a.accessInterceptor.AccessService(), system_pb.SystemService_ServiceDesc.ServiceName), + connect_middleware.ExecutionHandler(a.queries), + connect_middleware.ValidationHandler(), + connect_middleware.ServiceHandler(), + connect_middleware.ActivityInterceptor(), + ) + methods := service.FileDescriptor().Services().Get(0).Methods() + methodNames := make([]string, methods.Len()) + for i := 0; i < methods.Len(); i++ { + methodNames[i] = string(methods.Get(i).Name()) + } + a.connectServices[prefix] = methodNames + a.RegisterHandlerPrefixes(handler, prefix) +} + // HandleFunc allows registering a [http.HandlerFunc] on an exact // path, instead of prefix like RegisterHandlerOnPrefix. func (a *API) HandleFunc(path string, f http.HandlerFunc) { @@ -173,6 +235,9 @@ func (a *API) registerHealthServer() { } func (a *API) RouteGRPC() { + // since all services are now registered, we can build the grpc server reflection and register the handler + a.serverReflection() + http2Route := a.router. MatcherFunc(func(r *http.Request, _ *mux.RouteMatch) bool { return r.ProtoMajor == 2 diff --git a/internal/api/grpc/action/v2beta/execution.go b/internal/api/grpc/action/v2beta/execution.go index 5477a8128e..3b49ebb364 100644 --- a/internal/api/grpc/action/v2beta/execution.go +++ b/internal/api/grpc/action/v2beta/execution.go @@ -3,6 +3,7 @@ package action import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/authz" @@ -13,8 +14,8 @@ import ( action "github.com/zitadel/zitadel/pkg/grpc/action/v2beta" ) -func (s *Server) SetExecution(ctx context.Context, req *action.SetExecutionRequest) (*action.SetExecutionResponse, error) { - reqTargets := req.GetTargets() +func (s *Server) SetExecution(ctx context.Context, req *connect.Request[action.SetExecutionRequest]) (*connect.Response[action.SetExecutionResponse], error) { + reqTargets := req.Msg.GetTargets() targets := make([]*execution.Target, len(reqTargets)) for i, target := range reqTargets { targets[i] = &execution.Target{Type: domain.ExecutionTargetTypeTarget, Target: target} @@ -25,7 +26,7 @@ func (s *Server) SetExecution(ctx context.Context, req *action.SetExecutionReque var err error var details *domain.ObjectDetails instanceID := authz.GetInstance(ctx).InstanceID() - switch t := req.GetCondition().GetConditionType().(type) { + switch t := req.Msg.GetCondition().GetConditionType().(type) { case *action.Condition_Request: cond := executionConditionFromRequest(t.Request) details, err = s.command.SetExecutionRequest(ctx, cond, set, instanceID) @@ -43,27 +44,27 @@ func (s *Server) SetExecution(ctx context.Context, req *action.SetExecutionReque if err != nil { return nil, err } - return &action.SetExecutionResponse{ + return connect.NewResponse(&action.SetExecutionResponse{ SetDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) ListExecutionFunctions(ctx context.Context, _ *action.ListExecutionFunctionsRequest) (*action.ListExecutionFunctionsResponse, error) { - return &action.ListExecutionFunctionsResponse{ +func (s *Server) ListExecutionFunctions(ctx context.Context, _ *connect.Request[action.ListExecutionFunctionsRequest]) (*connect.Response[action.ListExecutionFunctionsResponse], error) { + return connect.NewResponse(&action.ListExecutionFunctionsResponse{ Functions: s.ListActionFunctions(), - }, nil + }), nil } -func (s *Server) ListExecutionMethods(ctx context.Context, _ *action.ListExecutionMethodsRequest) (*action.ListExecutionMethodsResponse, error) { - return &action.ListExecutionMethodsResponse{ +func (s *Server) ListExecutionMethods(ctx context.Context, _ *connect.Request[action.ListExecutionMethodsRequest]) (*connect.Response[action.ListExecutionMethodsResponse], error) { + return connect.NewResponse(&action.ListExecutionMethodsResponse{ Methods: s.ListGRPCMethods(), - }, nil + }), nil } -func (s *Server) ListExecutionServices(ctx context.Context, _ *action.ListExecutionServicesRequest) (*action.ListExecutionServicesResponse, error) { - return &action.ListExecutionServicesResponse{ +func (s *Server) ListExecutionServices(ctx context.Context, _ *connect.Request[action.ListExecutionServicesRequest]) (*connect.Response[action.ListExecutionServicesResponse], error) { + return connect.NewResponse(&action.ListExecutionServicesResponse{ Services: s.ListGRPCServices(), - }, nil + }), nil } func executionConditionFromRequest(request *action.RequestExecution) *command.ExecutionAPICondition { diff --git a/internal/api/grpc/action/v2beta/query.go b/internal/api/grpc/action/v2beta/query.go index 1dbe80a8f7..9428b6ab7b 100644 --- a/internal/api/grpc/action/v2beta/query.go +++ b/internal/api/grpc/action/v2beta/query.go @@ -4,6 +4,7 @@ import ( "context" "strings" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -22,14 +23,14 @@ const ( conditionIDEventGroupSegmentCount = 1 ) -func (s *Server) GetTarget(ctx context.Context, req *action.GetTargetRequest) (*action.GetTargetResponse, error) { - resp, err := s.query.GetTargetByID(ctx, req.GetId()) +func (s *Server) GetTarget(ctx context.Context, req *connect.Request[action.GetTargetRequest]) (*connect.Response[action.GetTargetResponse], error) { + resp, err := s.query.GetTargetByID(ctx, req.Msg.GetId()) if err != nil { return nil, err } - return &action.GetTargetResponse{ + return connect.NewResponse(&action.GetTargetResponse{ Target: targetToPb(resp), - }, nil + }), nil } type InstanceContext interface { @@ -41,8 +42,8 @@ type Context interface { GetOwner() InstanceContext } -func (s *Server) ListTargets(ctx context.Context, req *action.ListTargetsRequest) (*action.ListTargetsResponse, error) { - queries, err := s.ListTargetsRequestToModel(req) +func (s *Server) ListTargets(ctx context.Context, req *connect.Request[action.ListTargetsRequest]) (*connect.Response[action.ListTargetsResponse], error) { + queries, err := s.ListTargetsRequestToModel(req.Msg) if err != nil { return nil, err } @@ -50,14 +51,14 @@ func (s *Server) ListTargets(ctx context.Context, req *action.ListTargetsRequest if err != nil { return nil, err } - return &action.ListTargetsResponse{ + return connect.NewResponse(&action.ListTargetsResponse{ Result: targetsToPb(resp.Targets), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, resp.SearchResponse), - }, nil + }), nil } -func (s *Server) ListExecutions(ctx context.Context, req *action.ListExecutionsRequest) (*action.ListExecutionsResponse, error) { - queries, err := s.ListExecutionsRequestToModel(req) +func (s *Server) ListExecutions(ctx context.Context, req *connect.Request[action.ListExecutionsRequest]) (*connect.Response[action.ListExecutionsResponse], error) { + queries, err := s.ListExecutionsRequestToModel(req.Msg) if err != nil { return nil, err } @@ -65,10 +66,10 @@ func (s *Server) ListExecutions(ctx context.Context, req *action.ListExecutionsR if err != nil { return nil, err } - return &action.ListExecutionsResponse{ + return connect.NewResponse(&action.ListExecutionsResponse{ Result: executionsToPb(resp.Executions), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, resp.SearchResponse), - }, nil + }), nil } func targetsToPb(targets []*query.Target) []*action.Target { diff --git a/internal/api/grpc/action/v2beta/server.go b/internal/api/grpc/action/v2beta/server.go index ef0d8eb2ba..440bf842ca 100644 --- a/internal/api/grpc/action/v2beta/server.go +++ b/internal/api/grpc/action/v2beta/server.go @@ -1,7 +1,10 @@ package action import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/query" action "github.com/zitadel/zitadel/pkg/grpc/action/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/action/v2beta/actionconnect" ) -var _ action.ActionServiceServer = (*Server)(nil) +var _ actionconnect.ActionServiceHandler = (*Server)(nil) type Server struct { - action.UnimplementedActionServiceServer systemDefaults systemdefaults.SystemDefaults command *command.Commands query *query.Queries @@ -43,8 +46,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - action.RegisterActionServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return actionconnect.NewActionServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return action.File_zitadel_action_v2beta_action_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/action/v2beta/target.go b/internal/api/grpc/action/v2beta/target.go index 26c88b9683..b13f3461f0 100644 --- a/internal/api/grpc/action/v2beta/target.go +++ b/internal/api/grpc/action/v2beta/target.go @@ -3,6 +3,7 @@ package action import ( "context" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "google.golang.org/protobuf/types/known/timestamppb" @@ -13,8 +14,8 @@ import ( action "github.com/zitadel/zitadel/pkg/grpc/action/v2beta" ) -func (s *Server) CreateTarget(ctx context.Context, req *action.CreateTargetRequest) (*action.CreateTargetResponse, error) { - add := createTargetToCommand(req) +func (s *Server) CreateTarget(ctx context.Context, req *connect.Request[action.CreateTargetRequest]) (*connect.Response[action.CreateTargetResponse], error) { + add := createTargetToCommand(req.Msg) instanceID := authz.GetInstance(ctx).InstanceID() createdAt, err := s.command.AddTarget(ctx, add, instanceID) if err != nil { @@ -24,16 +25,16 @@ func (s *Server) CreateTarget(ctx context.Context, req *action.CreateTargetReque if !createdAt.IsZero() { creationDate = timestamppb.New(createdAt) } - return &action.CreateTargetResponse{ + return connect.NewResponse(&action.CreateTargetResponse{ Id: add.AggregateID, CreationDate: creationDate, SigningKey: add.SigningKey, - }, nil + }), nil } -func (s *Server) UpdateTarget(ctx context.Context, req *action.UpdateTargetRequest) (*action.UpdateTargetResponse, error) { +func (s *Server) UpdateTarget(ctx context.Context, req *connect.Request[action.UpdateTargetRequest]) (*connect.Response[action.UpdateTargetResponse], error) { instanceID := authz.GetInstance(ctx).InstanceID() - update := updateTargetToCommand(req) + update := updateTargetToCommand(req.Msg) changedAt, err := s.command.ChangeTarget(ctx, update, instanceID) if err != nil { return nil, err @@ -42,15 +43,15 @@ func (s *Server) UpdateTarget(ctx context.Context, req *action.UpdateTargetReque if !changedAt.IsZero() { changeDate = timestamppb.New(changedAt) } - return &action.UpdateTargetResponse{ + return connect.NewResponse(&action.UpdateTargetResponse{ ChangeDate: changeDate, SigningKey: update.SigningKey, - }, nil + }), nil } -func (s *Server) DeleteTarget(ctx context.Context, req *action.DeleteTargetRequest) (*action.DeleteTargetResponse, error) { +func (s *Server) DeleteTarget(ctx context.Context, req *connect.Request[action.DeleteTargetRequest]) (*connect.Response[action.DeleteTargetResponse], error) { instanceID := authz.GetInstance(ctx).InstanceID() - deletedAt, err := s.command.DeleteTarget(ctx, req.GetId(), instanceID) + deletedAt, err := s.command.DeleteTarget(ctx, req.Msg.GetId(), instanceID) if err != nil { return nil, err } @@ -58,9 +59,9 @@ func (s *Server) DeleteTarget(ctx context.Context, req *action.DeleteTargetReque if !deletedAt.IsZero() { deletionDate = timestamppb.New(deletedAt) } - return &action.DeleteTargetResponse{ + return connect.NewResponse(&action.DeleteTargetResponse{ DeletionDate: deletionDate, - }, nil + }), nil } func createTargetToCommand(req *action.CreateTargetRequest) *command.AddTarget { diff --git a/internal/api/grpc/app/v2beta/app.go b/internal/api/grpc/app/v2beta/app.go index 48c602f454..e751bf503f 100644 --- a/internal/api/grpc/app/v2beta/app.go +++ b/internal/api/grpc/app/v2beta/app.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/grpc/app/v2beta/convert" @@ -13,15 +14,15 @@ import ( app "github.com/zitadel/zitadel/pkg/grpc/app/v2beta" ) -func (s *Server) CreateApplication(ctx context.Context, req *app.CreateApplicationRequest) (*app.CreateApplicationResponse, error) { - switch t := req.GetCreationRequestType().(type) { +func (s *Server) CreateApplication(ctx context.Context, req *connect.Request[app.CreateApplicationRequest]) (*connect.Response[app.CreateApplicationResponse], error) { + switch t := req.Msg.GetCreationRequestType().(type) { case *app.CreateApplicationRequest_ApiRequest: - apiApp, err := s.command.AddAPIApplication(ctx, convert.CreateAPIApplicationRequestToDomain(req.GetName(), req.GetProjectId(), req.GetId(), t.ApiRequest), "") + apiApp, err := s.command.AddAPIApplication(ctx, convert.CreateAPIApplicationRequestToDomain(req.Msg.GetName(), req.Msg.GetProjectId(), req.Msg.GetId(), t.ApiRequest), "") if err != nil { return nil, err } - return &app.CreateApplicationResponse{ + return connect.NewResponse(&app.CreateApplicationResponse{ AppId: apiApp.AppID, CreationDate: timestamppb.New(apiApp.ChangeDate), CreationResponseType: &app.CreateApplicationResponse_ApiResponse{ @@ -30,10 +31,10 @@ func (s *Server) CreateApplication(ctx context.Context, req *app.CreateApplicati ClientSecret: apiApp.ClientSecretString, }, }, - }, nil + }), nil case *app.CreateApplicationRequest_OidcRequest: - oidcAppRequest, err := convert.CreateOIDCAppRequestToDomain(req.GetName(), req.GetProjectId(), req.GetOidcRequest()) + oidcAppRequest, err := convert.CreateOIDCAppRequestToDomain(req.Msg.GetName(), req.Msg.GetProjectId(), req.Msg.GetOidcRequest()) if err != nil { return nil, err } @@ -43,7 +44,7 @@ func (s *Server) CreateApplication(ctx context.Context, req *app.CreateApplicati return nil, err } - return &app.CreateApplicationResponse{ + return connect.NewResponse(&app.CreateApplicationResponse{ AppId: oidcApp.AppID, CreationDate: timestamppb.New(oidcApp.ChangeDate), CreationResponseType: &app.CreateApplicationResponse_OidcResponse{ @@ -54,10 +55,10 @@ func (s *Server) CreateApplication(ctx context.Context, req *app.CreateApplicati ComplianceProblems: convert.ComplianceProblemsToLocalizedMessages(oidcApp.Compliance.Problems), }, }, - }, nil + }), nil case *app.CreateApplicationRequest_SamlRequest: - samlAppRequest, err := convert.CreateSAMLAppRequestToDomain(req.GetName(), req.GetProjectId(), req.GetSamlRequest()) + samlAppRequest, err := convert.CreateSAMLAppRequestToDomain(req.Msg.GetName(), req.Msg.GetProjectId(), req.Msg.GetSamlRequest()) if err != nil { return nil, err } @@ -67,27 +68,27 @@ func (s *Server) CreateApplication(ctx context.Context, req *app.CreateApplicati return nil, err } - return &app.CreateApplicationResponse{ + return connect.NewResponse(&app.CreateApplicationResponse{ AppId: samlApp.AppID, CreationDate: timestamppb.New(samlApp.ChangeDate), CreationResponseType: &app.CreateApplicationResponse_SamlResponse{ SamlResponse: &app.CreateSAMLApplicationResponse{}, }, - }, nil + }), nil default: return nil, zerrors.ThrowInvalidArgument(nil, "APP-0iiN46", "unknown app type") } } -func (s *Server) UpdateApplication(ctx context.Context, req *app.UpdateApplicationRequest) (*app.UpdateApplicationResponse, error) { +func (s *Server) UpdateApplication(ctx context.Context, req *connect.Request[app.UpdateApplicationRequest]) (*connect.Response[app.UpdateApplicationResponse], error) { var changedTime time.Time - if name := strings.TrimSpace(req.GetName()); name != "" { + if name := strings.TrimSpace(req.Msg.GetName()); name != "" { updatedDetails, err := s.command.UpdateApplicationName( ctx, - req.GetProjectId(), + req.Msg.GetProjectId(), &domain.ChangeApp{ - AppID: req.GetId(), + AppID: req.Msg.GetId(), AppName: name, }, "", @@ -99,9 +100,9 @@ func (s *Server) UpdateApplication(ctx context.Context, req *app.UpdateApplicati changedTime = updatedDetails.EventDate } - switch t := req.GetUpdateRequestType().(type) { + switch t := req.Msg.GetUpdateRequestType().(type) { case *app.UpdateApplicationRequest_ApiConfigurationRequest: - updatedAPIApp, err := s.command.UpdateAPIApplication(ctx, convert.UpdateAPIApplicationConfigurationRequestToDomain(req.GetId(), req.GetProjectId(), t.ApiConfigurationRequest), "") + updatedAPIApp, err := s.command.UpdateAPIApplication(ctx, convert.UpdateAPIApplicationConfigurationRequestToDomain(req.Msg.GetId(), req.Msg.GetProjectId(), t.ApiConfigurationRequest), "") if err != nil { return nil, err } @@ -109,7 +110,7 @@ func (s *Server) UpdateApplication(ctx context.Context, req *app.UpdateApplicati changedTime = updatedAPIApp.ChangeDate case *app.UpdateApplicationRequest_OidcConfigurationRequest: - oidcApp, err := convert.UpdateOIDCAppConfigRequestToDomain(req.GetId(), req.GetProjectId(), t.OidcConfigurationRequest) + oidcApp, err := convert.UpdateOIDCAppConfigRequestToDomain(req.Msg.GetId(), req.Msg.GetProjectId(), t.OidcConfigurationRequest) if err != nil { return nil, err } @@ -122,7 +123,7 @@ func (s *Server) UpdateApplication(ctx context.Context, req *app.UpdateApplicati changedTime = updatedOIDCApp.ChangeDate case *app.UpdateApplicationRequest_SamlConfigurationRequest: - samlApp, err := convert.UpdateSAMLAppConfigRequestToDomain(req.GetId(), req.GetProjectId(), t.SamlConfigurationRequest) + samlApp, err := convert.UpdateSAMLAppConfigRequestToDomain(req.Msg.GetId(), req.Msg.GetProjectId(), t.SamlConfigurationRequest) if err != nil { return nil, err } @@ -135,53 +136,53 @@ func (s *Server) UpdateApplication(ctx context.Context, req *app.UpdateApplicati changedTime = updatedSAMLApp.ChangeDate } - return &app.UpdateApplicationResponse{ + return connect.NewResponse(&app.UpdateApplicationResponse{ ChangeDate: timestamppb.New(changedTime), - }, nil + }), nil } -func (s *Server) DeleteApplication(ctx context.Context, req *app.DeleteApplicationRequest) (*app.DeleteApplicationResponse, error) { - details, err := s.command.RemoveApplication(ctx, req.GetProjectId(), req.GetId(), "") +func (s *Server) DeleteApplication(ctx context.Context, req *connect.Request[app.DeleteApplicationRequest]) (*connect.Response[app.DeleteApplicationResponse], error) { + details, err := s.command.RemoveApplication(ctx, req.Msg.GetProjectId(), req.Msg.GetId(), "") if err != nil { return nil, err } - return &app.DeleteApplicationResponse{ + return connect.NewResponse(&app.DeleteApplicationResponse{ DeletionDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) DeactivateApplication(ctx context.Context, req *app.DeactivateApplicationRequest) (*app.DeactivateApplicationResponse, error) { - details, err := s.command.DeactivateApplication(ctx, req.GetProjectId(), req.GetId(), "") +func (s *Server) DeactivateApplication(ctx context.Context, req *connect.Request[app.DeactivateApplicationRequest]) (*connect.Response[app.DeactivateApplicationResponse], error) { + details, err := s.command.DeactivateApplication(ctx, req.Msg.GetProjectId(), req.Msg.GetId(), "") if err != nil { return nil, err } - return &app.DeactivateApplicationResponse{ + return connect.NewResponse(&app.DeactivateApplicationResponse{ DeactivationDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) ReactivateApplication(ctx context.Context, req *app.ReactivateApplicationRequest) (*app.ReactivateApplicationResponse, error) { - details, err := s.command.ReactivateApplication(ctx, req.GetProjectId(), req.GetId(), "") +func (s *Server) ReactivateApplication(ctx context.Context, req *connect.Request[app.ReactivateApplicationRequest]) (*connect.Response[app.ReactivateApplicationResponse], error) { + details, err := s.command.ReactivateApplication(ctx, req.Msg.GetProjectId(), req.Msg.GetId(), "") if err != nil { return nil, err } - return &app.ReactivateApplicationResponse{ + return connect.NewResponse(&app.ReactivateApplicationResponse{ ReactivationDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) RegenerateClientSecret(ctx context.Context, req *app.RegenerateClientSecretRequest) (*app.RegenerateClientSecretResponse, error) { +func (s *Server) RegenerateClientSecret(ctx context.Context, req *connect.Request[app.RegenerateClientSecretRequest]) (*connect.Response[app.RegenerateClientSecretResponse], error) { var secret string var changeDate time.Time - switch req.GetAppType().(type) { + switch req.Msg.GetAppType().(type) { case *app.RegenerateClientSecretRequest_IsApi: - config, err := s.command.ChangeAPIApplicationSecret(ctx, req.GetProjectId(), req.GetApplicationId(), "") + config, err := s.command.ChangeAPIApplicationSecret(ctx, req.Msg.GetProjectId(), req.Msg.GetApplicationId(), "") if err != nil { return nil, err } @@ -189,7 +190,7 @@ func (s *Server) RegenerateClientSecret(ctx context.Context, req *app.Regenerate changeDate = config.ChangeDate case *app.RegenerateClientSecretRequest_IsOidc: - config, err := s.command.ChangeOIDCApplicationSecret(ctx, req.GetProjectId(), req.GetApplicationId(), "") + config, err := s.command.ChangeOIDCApplicationSecret(ctx, req.Msg.GetProjectId(), req.Msg.GetApplicationId(), "") if err != nil { return nil, err } @@ -201,8 +202,8 @@ func (s *Server) RegenerateClientSecret(ctx context.Context, req *app.Regenerate return nil, zerrors.ThrowInvalidArgument(nil, "APP-aLWIzw", "unknown app type") } - return &app.RegenerateClientSecretResponse{ + return connect.NewResponse(&app.RegenerateClientSecretResponse{ ClientSecret: secret, CreationDate: timestamppb.New(changeDate), - }, nil + }), nil } diff --git a/internal/api/grpc/app/v2beta/app_key.go b/internal/api/grpc/app/v2beta/app_key.go index 8c0c1989b2..087ff90916 100644 --- a/internal/api/grpc/app/v2beta/app_key.go +++ b/internal/api/grpc/app/v2beta/app_key.go @@ -4,14 +4,15 @@ import ( "context" "strings" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/grpc/app/v2beta/convert" app "github.com/zitadel/zitadel/pkg/grpc/app/v2beta" ) -func (s *Server) CreateApplicationKey(ctx context.Context, req *app.CreateApplicationKeyRequest) (*app.CreateApplicationKeyResponse, error) { - domainReq := convert.CreateAPIClientKeyRequestToDomain(req) +func (s *Server) CreateApplicationKey(ctx context.Context, req *connect.Request[app.CreateApplicationKeyRequest]) (*connect.Response[app.CreateApplicationKeyResponse], error) { + domainReq := convert.CreateAPIClientKeyRequestToDomain(req.Msg) appKey, err := s.command.AddApplicationKey(ctx, domainReq, "") if err != nil { @@ -23,25 +24,25 @@ func (s *Server) CreateApplicationKey(ctx context.Context, req *app.CreateApplic return nil, err } - return &app.CreateApplicationKeyResponse{ + return connect.NewResponse(&app.CreateApplicationKeyResponse{ Id: appKey.KeyID, CreationDate: timestamppb.New(appKey.ChangeDate), KeyDetails: keyDetails, - }, nil + }), nil } -func (s *Server) DeleteApplicationKey(ctx context.Context, req *app.DeleteApplicationKeyRequest) (*app.DeleteApplicationKeyResponse, error) { +func (s *Server) DeleteApplicationKey(ctx context.Context, req *connect.Request[app.DeleteApplicationKeyRequest]) (*connect.Response[app.DeleteApplicationKeyResponse], error) { deletionDetails, err := s.command.RemoveApplicationKey(ctx, - strings.TrimSpace(req.GetProjectId()), - strings.TrimSpace(req.GetApplicationId()), - strings.TrimSpace(req.GetId()), - strings.TrimSpace(req.GetOrganizationId()), + strings.TrimSpace(req.Msg.GetProjectId()), + strings.TrimSpace(req.Msg.GetApplicationId()), + strings.TrimSpace(req.Msg.GetId()), + strings.TrimSpace(req.Msg.GetOrganizationId()), ) if err != nil { return nil, err } - return &app.DeleteApplicationKeyResponse{ + return connect.NewResponse(&app.DeleteApplicationKeyResponse{ DeletionDate: timestamppb.New(deletionDetails.EventDate), - }, nil + }), nil } diff --git a/internal/api/grpc/app/v2beta/query.go b/internal/api/grpc/app/v2beta/query.go index 2926884520..ab2a98d14a 100644 --- a/internal/api/grpc/app/v2beta/query.go +++ b/internal/api/grpc/app/v2beta/query.go @@ -4,6 +4,7 @@ import ( "context" "strings" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/grpc/app/v2beta/convert" @@ -12,19 +13,19 @@ import ( app "github.com/zitadel/zitadel/pkg/grpc/app/v2beta" ) -func (s *Server) GetApplication(ctx context.Context, req *app.GetApplicationRequest) (*app.GetApplicationResponse, error) { - res, err := s.query.AppByIDWithPermission(ctx, req.GetId(), false, s.checkPermission) +func (s *Server) GetApplication(ctx context.Context, req *connect.Request[app.GetApplicationRequest]) (*connect.Response[app.GetApplicationResponse], error) { + res, err := s.query.AppByIDWithPermission(ctx, req.Msg.GetId(), false, s.checkPermission) if err != nil { return nil, err } - return &app.GetApplicationResponse{ + return connect.NewResponse(&app.GetApplicationResponse{ App: convert.AppToPb(res), - }, nil + }), nil } -func (s *Server) ListApplications(ctx context.Context, req *app.ListApplicationsRequest) (*app.ListApplicationsResponse, error) { - queries, err := convert.ListApplicationsRequestToModel(s.systemDefaults, req) +func (s *Server) ListApplications(ctx context.Context, req *connect.Request[app.ListApplicationsRequest]) (*connect.Response[app.ListApplicationsResponse], error) { + queries, err := convert.ListApplicationsRequestToModel(s.systemDefaults, req.Msg) if err != nil { return nil, err } @@ -34,32 +35,32 @@ func (s *Server) ListApplications(ctx context.Context, req *app.ListApplications return nil, err } - return &app.ListApplicationsResponse{ + return connect.NewResponse(&app.ListApplicationsResponse{ Applications: convert.AppsToPb(res.Apps), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, res.SearchResponse), - }, nil + }), nil } -func (s *Server) GetApplicationKey(ctx context.Context, req *app.GetApplicationKeyRequest) (*app.GetApplicationKeyResponse, error) { - queries, err := convert.GetApplicationKeyQueriesRequestToDomain(req.GetOrganizationId(), req.GetProjectId(), req.GetApplicationId()) +func (s *Server) GetApplicationKey(ctx context.Context, req *connect.Request[app.GetApplicationKeyRequest]) (*connect.Response[app.GetApplicationKeyResponse], error) { + queries, err := convert.GetApplicationKeyQueriesRequestToDomain(req.Msg.GetOrganizationId(), req.Msg.GetProjectId(), req.Msg.GetApplicationId()) if err != nil { return nil, err } - key, err := s.query.GetAuthNKeyByIDWithPermission(ctx, true, strings.TrimSpace(req.GetId()), s.checkPermission, queries...) + key, err := s.query.GetAuthNKeyByIDWithPermission(ctx, true, strings.TrimSpace(req.Msg.GetId()), s.checkPermission, queries...) if err != nil { return nil, err } - return &app.GetApplicationKeyResponse{ + return connect.NewResponse(&app.GetApplicationKeyResponse{ Id: key.ID, CreationDate: timestamppb.New(key.CreationDate), ExpirationDate: timestamppb.New(key.Expiration), - }, nil + }), nil } -func (s *Server) ListApplicationKeys(ctx context.Context, req *app.ListApplicationKeysRequest) (*app.ListApplicationKeysResponse, error) { - queries, err := convert.ListApplicationKeysRequestToDomain(s.systemDefaults, req) +func (s *Server) ListApplicationKeys(ctx context.Context, req *connect.Request[app.ListApplicationKeysRequest]) (*connect.Response[app.ListApplicationKeysResponse], error) { + queries, err := convert.ListApplicationKeysRequestToDomain(s.systemDefaults, req.Msg) if err != nil { return nil, err } @@ -69,8 +70,8 @@ func (s *Server) ListApplicationKeys(ctx context.Context, req *app.ListApplicati return nil, err } - return &app.ListApplicationKeysResponse{ + return connect.NewResponse(&app.ListApplicationKeysResponse{ Keys: convert.ApplicationKeysToPb(res.AuthNKeys), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, res.SearchResponse), - }, nil + }), nil } diff --git a/internal/api/grpc/app/v2beta/server.go b/internal/api/grpc/app/v2beta/server.go index 8343cbe404..54842070cb 100644 --- a/internal/api/grpc/app/v2beta/server.go +++ b/internal/api/grpc/app/v2beta/server.go @@ -1,21 +1,23 @@ package app import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" app "github.com/zitadel/zitadel/pkg/grpc/app/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/app/v2beta/appconnect" ) -var _ app.AppServiceServer = (*Server)(nil) +var _ appconnect.AppServiceHandler = (*Server)(nil) type Server struct { - app.UnimplementedAppServiceServer command *command.Commands query *query.Queries systemDefaults systemdefaults.SystemDefaults @@ -36,8 +38,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - app.RegisterAppServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return appconnect.NewAppServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return app.File_zitadel_app_v2beta_app_service_proto } func (s *Server) AppName() string { @@ -51,7 +57,3 @@ func (s *Server) MethodPrefix() string { func (s *Server) AuthMethods() authz.MethodMapping { return app.AppService_AuthMethods } - -func (s *Server) RegisterGateway() server.RegisterGatewayFunc { - return app.RegisterAppServiceHandler -} diff --git a/internal/api/grpc/feature/v2/feature.go b/internal/api/grpc/feature/v2/feature.go index f4527689fc..f450f734e4 100644 --- a/internal/api/grpc/feature/v2/feature.go +++ b/internal/api/grpc/feature/v2/feature.go @@ -3,6 +3,7 @@ package feature import ( "context" + "connectrpc.com/connect" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -10,8 +11,8 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/feature/v2" ) -func (s *Server) SetSystemFeatures(ctx context.Context, req *feature.SetSystemFeaturesRequest) (_ *feature.SetSystemFeaturesResponse, err error) { - features, err := systemFeaturesToCommand(req) +func (s *Server) SetSystemFeatures(ctx context.Context, req *connect.Request[feature.SetSystemFeaturesRequest]) (_ *connect.Response[feature.SetSystemFeaturesResponse], err error) { + features, err := systemFeaturesToCommand(req.Msg) if err != nil { return nil, err } @@ -19,31 +20,31 @@ func (s *Server) SetSystemFeatures(ctx context.Context, req *feature.SetSystemFe if err != nil { return nil, err } - return &feature.SetSystemFeaturesResponse{ + return connect.NewResponse(&feature.SetSystemFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ResetSystemFeatures(ctx context.Context, req *feature.ResetSystemFeaturesRequest) (_ *feature.ResetSystemFeaturesResponse, err error) { +func (s *Server) ResetSystemFeatures(ctx context.Context, req *connect.Request[feature.ResetSystemFeaturesRequest]) (_ *connect.Response[feature.ResetSystemFeaturesResponse], err error) { details, err := s.command.ResetSystemFeatures(ctx) if err != nil { return nil, err } - return &feature.ResetSystemFeaturesResponse{ + return connect.NewResponse(&feature.ResetSystemFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) GetSystemFeatures(ctx context.Context, req *feature.GetSystemFeaturesRequest) (_ *feature.GetSystemFeaturesResponse, err error) { +func (s *Server) GetSystemFeatures(ctx context.Context, req *connect.Request[feature.GetSystemFeaturesRequest]) (_ *connect.Response[feature.GetSystemFeaturesResponse], err error) { f, err := s.query.GetSystemFeatures(ctx) if err != nil { return nil, err } - return systemFeaturesToPb(f), nil + return connect.NewResponse(systemFeaturesToPb(f)), nil } -func (s *Server) SetInstanceFeatures(ctx context.Context, req *feature.SetInstanceFeaturesRequest) (_ *feature.SetInstanceFeaturesResponse, err error) { - features, err := instanceFeaturesToCommand(req) +func (s *Server) SetInstanceFeatures(ctx context.Context, req *connect.Request[feature.SetInstanceFeaturesRequest]) (_ *connect.Response[feature.SetInstanceFeaturesResponse], err error) { + features, err := instanceFeaturesToCommand(req.Msg) if err != nil { return nil, err } @@ -51,44 +52,44 @@ func (s *Server) SetInstanceFeatures(ctx context.Context, req *feature.SetInstan if err != nil { return nil, err } - return &feature.SetInstanceFeaturesResponse{ + return connect.NewResponse(&feature.SetInstanceFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ResetInstanceFeatures(ctx context.Context, req *feature.ResetInstanceFeaturesRequest) (_ *feature.ResetInstanceFeaturesResponse, err error) { +func (s *Server) ResetInstanceFeatures(ctx context.Context, req *connect.Request[feature.ResetInstanceFeaturesRequest]) (_ *connect.Response[feature.ResetInstanceFeaturesResponse], err error) { details, err := s.command.ResetInstanceFeatures(ctx) if err != nil { return nil, err } - return &feature.ResetInstanceFeaturesResponse{ + return connect.NewResponse(&feature.ResetInstanceFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) GetInstanceFeatures(ctx context.Context, req *feature.GetInstanceFeaturesRequest) (_ *feature.GetInstanceFeaturesResponse, err error) { - f, err := s.query.GetInstanceFeatures(ctx, req.GetInheritance()) +func (s *Server) GetInstanceFeatures(ctx context.Context, req *connect.Request[feature.GetInstanceFeaturesRequest]) (_ *connect.Response[feature.GetInstanceFeaturesResponse], err error) { + f, err := s.query.GetInstanceFeatures(ctx, req.Msg.GetInheritance()) if err != nil { return nil, err } - return instanceFeaturesToPb(f), nil + return connect.NewResponse(instanceFeaturesToPb(f)), nil } -func (s *Server) SetOrganizationFeatures(ctx context.Context, req *feature.SetOrganizationFeaturesRequest) (_ *feature.SetOrganizationFeaturesResponse, err error) { +func (s *Server) SetOrganizationFeatures(ctx context.Context, req *connect.Request[feature.SetOrganizationFeaturesRequest]) (_ *connect.Response[feature.SetOrganizationFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method SetOrganizationFeatures not implemented") } -func (s *Server) ResetOrganizationFeatures(ctx context.Context, req *feature.ResetOrganizationFeaturesRequest) (_ *feature.ResetOrganizationFeaturesResponse, err error) { +func (s *Server) ResetOrganizationFeatures(ctx context.Context, req *connect.Request[feature.ResetOrganizationFeaturesRequest]) (_ *connect.Response[feature.ResetOrganizationFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method ResetOrganizationFeatures not implemented") } -func (s *Server) GetOrganizationFeatures(ctx context.Context, req *feature.GetOrganizationFeaturesRequest) (_ *feature.GetOrganizationFeaturesResponse, err error) { +func (s *Server) GetOrganizationFeatures(ctx context.Context, req *connect.Request[feature.GetOrganizationFeaturesRequest]) (_ *connect.Response[feature.GetOrganizationFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method GetOrganizationFeatures not implemented") } -func (s *Server) SetUserFeatures(ctx context.Context, req *feature.SetUserFeatureRequest) (_ *feature.SetUserFeaturesResponse, err error) { +func (s *Server) SetUserFeatures(ctx context.Context, req *connect.Request[feature.SetUserFeatureRequest]) (_ *connect.Response[feature.SetUserFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method SetUserFeatures not implemented") } -func (s *Server) ResetUserFeatures(ctx context.Context, req *feature.ResetUserFeaturesRequest) (_ *feature.ResetUserFeaturesResponse, err error) { +func (s *Server) ResetUserFeatures(ctx context.Context, req *connect.Request[feature.ResetUserFeaturesRequest]) (_ *connect.Response[feature.ResetUserFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method ResetUserFeatures not implemented") } -func (s *Server) GetUserFeatures(ctx context.Context, req *feature.GetUserFeaturesRequest) (_ *feature.GetUserFeaturesResponse, err error) { +func (s *Server) GetUserFeatures(ctx context.Context, req *connect.Request[feature.GetUserFeaturesRequest]) (_ *connect.Response[feature.GetUserFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method GetUserFeatures not implemented") } diff --git a/internal/api/grpc/feature/v2/server.go b/internal/api/grpc/feature/v2/server.go index ab92df5822..3eb4cc6813 100644 --- a/internal/api/grpc/feature/v2/server.go +++ b/internal/api/grpc/feature/v2/server.go @@ -1,17 +1,22 @@ package feature import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/feature/v2" + "github.com/zitadel/zitadel/pkg/grpc/feature/v2/featureconnect" ) +var _ featureconnect.FeatureServiceHandler = (*Server)(nil) + type Server struct { - feature.UnimplementedFeatureServiceServer command *command.Commands query *query.Queries } @@ -26,8 +31,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - feature.RegisterFeatureServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return featureconnect.NewFeatureServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return feature.File_zitadel_feature_v2_feature_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/feature/v2beta/feature.go b/internal/api/grpc/feature/v2beta/feature.go index b94f8e7de2..4ff51af883 100644 --- a/internal/api/grpc/feature/v2beta/feature.go +++ b/internal/api/grpc/feature/v2beta/feature.go @@ -3,6 +3,7 @@ package feature import ( "context" + "connectrpc.com/connect" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -10,77 +11,77 @@ import ( feature "github.com/zitadel/zitadel/pkg/grpc/feature/v2beta" ) -func (s *Server) SetSystemFeatures(ctx context.Context, req *feature.SetSystemFeaturesRequest) (_ *feature.SetSystemFeaturesResponse, err error) { - details, err := s.command.SetSystemFeatures(ctx, systemFeaturesToCommand(req)) +func (s *Server) SetSystemFeatures(ctx context.Context, req *connect.Request[feature.SetSystemFeaturesRequest]) (_ *connect.Response[feature.SetSystemFeaturesResponse], err error) { + details, err := s.command.SetSystemFeatures(ctx, systemFeaturesToCommand(req.Msg)) if err != nil { return nil, err } - return &feature.SetSystemFeaturesResponse{ + return connect.NewResponse(&feature.SetSystemFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ResetSystemFeatures(ctx context.Context, req *feature.ResetSystemFeaturesRequest) (_ *feature.ResetSystemFeaturesResponse, err error) { +func (s *Server) ResetSystemFeatures(ctx context.Context, req *connect.Request[feature.ResetSystemFeaturesRequest]) (_ *connect.Response[feature.ResetSystemFeaturesResponse], err error) { details, err := s.command.ResetSystemFeatures(ctx) if err != nil { return nil, err } - return &feature.ResetSystemFeaturesResponse{ + return connect.NewResponse(&feature.ResetSystemFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) GetSystemFeatures(ctx context.Context, req *feature.GetSystemFeaturesRequest) (_ *feature.GetSystemFeaturesResponse, err error) { +func (s *Server) GetSystemFeatures(ctx context.Context, req *connect.Request[feature.GetSystemFeaturesRequest]) (_ *connect.Response[feature.GetSystemFeaturesResponse], err error) { f, err := s.query.GetSystemFeatures(ctx) if err != nil { return nil, err } - return systemFeaturesToPb(f), nil + return connect.NewResponse(systemFeaturesToPb(f)), nil } -func (s *Server) SetInstanceFeatures(ctx context.Context, req *feature.SetInstanceFeaturesRequest) (_ *feature.SetInstanceFeaturesResponse, err error) { - details, err := s.command.SetInstanceFeatures(ctx, instanceFeaturesToCommand(req)) +func (s *Server) SetInstanceFeatures(ctx context.Context, req *connect.Request[feature.SetInstanceFeaturesRequest]) (_ *connect.Response[feature.SetInstanceFeaturesResponse], err error) { + details, err := s.command.SetInstanceFeatures(ctx, instanceFeaturesToCommand(req.Msg)) if err != nil { return nil, err } - return &feature.SetInstanceFeaturesResponse{ + return connect.NewResponse(&feature.SetInstanceFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ResetInstanceFeatures(ctx context.Context, req *feature.ResetInstanceFeaturesRequest) (_ *feature.ResetInstanceFeaturesResponse, err error) { +func (s *Server) ResetInstanceFeatures(ctx context.Context, req *connect.Request[feature.ResetInstanceFeaturesRequest]) (_ *connect.Response[feature.ResetInstanceFeaturesResponse], err error) { details, err := s.command.ResetInstanceFeatures(ctx) if err != nil { return nil, err } - return &feature.ResetInstanceFeaturesResponse{ + return connect.NewResponse(&feature.ResetInstanceFeaturesResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) GetInstanceFeatures(ctx context.Context, req *feature.GetInstanceFeaturesRequest) (_ *feature.GetInstanceFeaturesResponse, err error) { - f, err := s.query.GetInstanceFeatures(ctx, req.GetInheritance()) +func (s *Server) GetInstanceFeatures(ctx context.Context, req *connect.Request[feature.GetInstanceFeaturesRequest]) (_ *connect.Response[feature.GetInstanceFeaturesResponse], err error) { + f, err := s.query.GetInstanceFeatures(ctx, req.Msg.GetInheritance()) if err != nil { return nil, err } - return instanceFeaturesToPb(f), nil + return connect.NewResponse(instanceFeaturesToPb(f)), nil } -func (s *Server) SetOrganizationFeatures(ctx context.Context, req *feature.SetOrganizationFeaturesRequest) (_ *feature.SetOrganizationFeaturesResponse, err error) { +func (s *Server) SetOrganizationFeatures(ctx context.Context, req *connect.Request[feature.SetOrganizationFeaturesRequest]) (_ *connect.Response[feature.SetOrganizationFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method SetOrganizationFeatures not implemented") } -func (s *Server) ResetOrganizationFeatures(ctx context.Context, req *feature.ResetOrganizationFeaturesRequest) (_ *feature.ResetOrganizationFeaturesResponse, err error) { +func (s *Server) ResetOrganizationFeatures(ctx context.Context, req *connect.Request[feature.ResetOrganizationFeaturesRequest]) (_ *connect.Response[feature.ResetOrganizationFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method ResetOrganizationFeatures not implemented") } -func (s *Server) GetOrganizationFeatures(ctx context.Context, req *feature.GetOrganizationFeaturesRequest) (_ *feature.GetOrganizationFeaturesResponse, err error) { +func (s *Server) GetOrganizationFeatures(ctx context.Context, req *connect.Request[feature.GetOrganizationFeaturesRequest]) (_ *connect.Response[feature.GetOrganizationFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method GetOrganizationFeatures not implemented") } -func (s *Server) SetUserFeatures(ctx context.Context, req *feature.SetUserFeatureRequest) (_ *feature.SetUserFeaturesResponse, err error) { +func (s *Server) SetUserFeatures(ctx context.Context, req *connect.Request[feature.SetUserFeatureRequest]) (_ *connect.Response[feature.SetUserFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method SetUserFeatures not implemented") } -func (s *Server) ResetUserFeatures(ctx context.Context, req *feature.ResetUserFeaturesRequest) (_ *feature.ResetUserFeaturesResponse, err error) { +func (s *Server) ResetUserFeatures(ctx context.Context, req *connect.Request[feature.ResetUserFeaturesRequest]) (_ *connect.Response[feature.ResetUserFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method ResetUserFeatures not implemented") } -func (s *Server) GetUserFeatures(ctx context.Context, req *feature.GetUserFeaturesRequest) (_ *feature.GetUserFeaturesResponse, err error) { +func (s *Server) GetUserFeatures(ctx context.Context, req *connect.Request[feature.GetUserFeaturesRequest]) (_ *connect.Response[feature.GetUserFeaturesResponse], err error) { return nil, status.Errorf(codes.Unimplemented, "method GetUserFeatures not implemented") } diff --git a/internal/api/grpc/feature/v2beta/server.go b/internal/api/grpc/feature/v2beta/server.go index 4208c4acfc..29877f77f9 100644 --- a/internal/api/grpc/feature/v2beta/server.go +++ b/internal/api/grpc/feature/v2beta/server.go @@ -1,17 +1,22 @@ package feature import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" feature "github.com/zitadel/zitadel/pkg/grpc/feature/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/feature/v2beta/featureconnect" ) +var _ featureconnect.FeatureServiceHandler = (*Server)(nil) + type Server struct { - feature.UnimplementedFeatureServiceServer command *command.Commands query *query.Queries } @@ -26,8 +31,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - feature.RegisterFeatureServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return featureconnect.NewFeatureServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return feature.File_zitadel_feature_v2beta_feature_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/gerrors/zitadel_errors.go b/internal/api/grpc/gerrors/zitadel_errors.go index d679054da6..b5d2893062 100644 --- a/internal/api/grpc/gerrors/zitadel_errors.go +++ b/internal/api/grpc/gerrors/zitadel_errors.go @@ -3,10 +3,12 @@ package gerrors import ( "errors" + "connectrpc.com/connect" "github.com/jackc/pgx/v5/pgconn" "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt" commandErrors "github.com/zitadel/zitadel/internal/command/errors" @@ -36,6 +38,30 @@ func ZITADELToGRPCError(err error) error { return s.Err() } +func ZITADELToConnectError(err error) error { + if err == nil { + return nil + } + connectError := new(connect.Error) + if errors.As(err, &connectError) { + return err + } + code, key, id, ok := ExtractZITADELError(err) + if !ok { + return status.Convert(err).Err() + } + msg := key + msg += " (" + id + ")" + + errorInfo := getErrorInfo(id, key, err) + + cErr := connect.NewError(connect.Code(code), errors.New(msg)) + if detail, detailErr := connect.NewErrorDetail(errorInfo.(proto.Message)); detailErr == nil { + cErr.AddDetail(detail) + } + return cErr +} + func ExtractZITADELError(err error) (c codes.Code, msg, id string, ok bool) { if err == nil { return codes.OK, "", "", false diff --git a/internal/api/grpc/idp/v2/query.go b/internal/api/grpc/idp/v2/query.go index 082a94d18f..587b1687b9 100644 --- a/internal/api/grpc/idp/v2/query.go +++ b/internal/api/grpc/idp/v2/query.go @@ -3,6 +3,7 @@ package idp import ( "context" + "connectrpc.com/connect" "github.com/crewjam/saml" "github.com/muhlemmer/gu" "google.golang.org/protobuf/types/known/durationpb" @@ -15,12 +16,12 @@ import ( idp_pb "github.com/zitadel/zitadel/pkg/grpc/idp/v2" ) -func (s *Server) GetIDPByID(ctx context.Context, req *idp_pb.GetIDPByIDRequest) (*idp_pb.GetIDPByIDResponse, error) { - idp, err := s.query.IDPTemplateByID(ctx, true, req.Id, false, s.checkPermission) +func (s *Server) GetIDPByID(ctx context.Context, req *connect.Request[idp_pb.GetIDPByIDRequest]) (*connect.Response[idp_pb.GetIDPByIDResponse], error) { + idp, err := s.query.IDPTemplateByID(ctx, true, req.Msg.GetId(), false, s.checkPermission) if err != nil { return nil, err } - return &idp_pb.GetIDPByIDResponse{Idp: idpToPb(idp)}, nil + return connect.NewResponse(&idp_pb.GetIDPByIDResponse{Idp: idpToPb(idp)}), nil } func idpToPb(idp *query.IDPTemplate) *idp_pb.IDP { diff --git a/internal/api/grpc/idp/v2/server.go b/internal/api/grpc/idp/v2/server.go index 246e980434..666c39294d 100644 --- a/internal/api/grpc/idp/v2/server.go +++ b/internal/api/grpc/idp/v2/server.go @@ -1,7 +1,10 @@ package idp import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/idp/v2" + "github.com/zitadel/zitadel/pkg/grpc/idp/v2/idpconnect" ) -var _ idp.IdentityProviderServiceServer = (*Server)(nil) +var _ idpconnect.IdentityProviderServiceHandler = (*Server)(nil) type Server struct { - idp.UnimplementedIdentityProviderServiceServer command *command.Commands query *query.Queries @@ -35,8 +38,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - idp.RegisterIdentityProviderServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return idpconnect.NewIdentityProviderServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return idp.File_zitadel_idp_v2_idp_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/instance/v2beta/domain.go b/internal/api/grpc/instance/v2beta/domain.go index 439c6e5d8d..380ebff5a7 100644 --- a/internal/api/grpc/instance/v2beta/domain.go +++ b/internal/api/grpc/instance/v2beta/domain.go @@ -3,48 +3,49 @@ package instance import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" instance "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta" ) -func (s *Server) AddCustomDomain(ctx context.Context, req *instance.AddCustomDomainRequest) (*instance.AddCustomDomainResponse, error) { - details, err := s.command.AddInstanceDomain(ctx, req.GetDomain()) +func (s *Server) AddCustomDomain(ctx context.Context, req *connect.Request[instance.AddCustomDomainRequest]) (*connect.Response[instance.AddCustomDomainResponse], error) { + details, err := s.command.AddInstanceDomain(ctx, req.Msg.GetDomain()) if err != nil { return nil, err } - return &instance.AddCustomDomainResponse{ + return connect.NewResponse(&instance.AddCustomDomainResponse{ CreationDate: timestamppb.New(details.CreationDate), - }, nil + }), nil } -func (s *Server) RemoveCustomDomain(ctx context.Context, req *instance.RemoveCustomDomainRequest) (*instance.RemoveCustomDomainResponse, error) { - details, err := s.command.RemoveInstanceDomain(ctx, req.GetDomain()) +func (s *Server) RemoveCustomDomain(ctx context.Context, req *connect.Request[instance.RemoveCustomDomainRequest]) (*connect.Response[instance.RemoveCustomDomainResponse], error) { + details, err := s.command.RemoveInstanceDomain(ctx, req.Msg.GetDomain()) if err != nil { return nil, err } - return &instance.RemoveCustomDomainResponse{ + return connect.NewResponse(&instance.RemoveCustomDomainResponse{ DeletionDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) AddTrustedDomain(ctx context.Context, req *instance.AddTrustedDomainRequest) (*instance.AddTrustedDomainResponse, error) { - details, err := s.command.AddTrustedDomain(ctx, req.GetDomain()) +func (s *Server) AddTrustedDomain(ctx context.Context, req *connect.Request[instance.AddTrustedDomainRequest]) (*connect.Response[instance.AddTrustedDomainResponse], error) { + details, err := s.command.AddTrustedDomain(ctx, req.Msg.GetDomain()) if err != nil { return nil, err } - return &instance.AddTrustedDomainResponse{ + return connect.NewResponse(&instance.AddTrustedDomainResponse{ CreationDate: timestamppb.New(details.CreationDate), - }, nil + }), nil } -func (s *Server) RemoveTrustedDomain(ctx context.Context, req *instance.RemoveTrustedDomainRequest) (*instance.RemoveTrustedDomainResponse, error) { - details, err := s.command.RemoveTrustedDomain(ctx, req.GetDomain()) +func (s *Server) RemoveTrustedDomain(ctx context.Context, req *connect.Request[instance.RemoveTrustedDomainRequest]) (*connect.Response[instance.RemoveTrustedDomainResponse], error) { + details, err := s.command.RemoveTrustedDomain(ctx, req.Msg.GetDomain()) if err != nil { return nil, err } - return &instance.RemoveTrustedDomainResponse{ + return connect.NewResponse(&instance.RemoveTrustedDomainResponse{ DeletionDate: timestamppb.New(details.EventDate), - }, nil + }), nil } diff --git a/internal/api/grpc/instance/v2beta/instance.go b/internal/api/grpc/instance/v2beta/instance.go index b1c36e74bb..b3f2d6e478 100644 --- a/internal/api/grpc/instance/v2beta/instance.go +++ b/internal/api/grpc/instance/v2beta/instance.go @@ -3,30 +3,31 @@ package instance import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" instance "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta" ) -func (s *Server) DeleteInstance(ctx context.Context, request *instance.DeleteInstanceRequest) (*instance.DeleteInstanceResponse, error) { - obj, err := s.command.RemoveInstance(ctx, request.GetInstanceId()) +func (s *Server) DeleteInstance(ctx context.Context, request *connect.Request[instance.DeleteInstanceRequest]) (*connect.Response[instance.DeleteInstanceResponse], error) { + obj, err := s.command.RemoveInstance(ctx, request.Msg.GetInstanceId()) if err != nil { return nil, err } - return &instance.DeleteInstanceResponse{ + return connect.NewResponse(&instance.DeleteInstanceResponse{ DeletionDate: timestamppb.New(obj.EventDate), - }, nil + }), nil } -func (s *Server) UpdateInstance(ctx context.Context, request *instance.UpdateInstanceRequest) (*instance.UpdateInstanceResponse, error) { - obj, err := s.command.UpdateInstance(ctx, request.GetInstanceName()) +func (s *Server) UpdateInstance(ctx context.Context, request *connect.Request[instance.UpdateInstanceRequest]) (*connect.Response[instance.UpdateInstanceResponse], error) { + obj, err := s.command.UpdateInstance(ctx, request.Msg.GetInstanceName()) if err != nil { return nil, err } - return &instance.UpdateInstanceResponse{ + return connect.NewResponse(&instance.UpdateInstanceResponse{ ChangeDate: timestamppb.New(obj.EventDate), - }, nil + }), nil } diff --git a/internal/api/grpc/instance/v2beta/query.go b/internal/api/grpc/instance/v2beta/query.go index 74f79313ea..10716ffda0 100644 --- a/internal/api/grpc/instance/v2beta/query.go +++ b/internal/api/grpc/instance/v2beta/query.go @@ -3,23 +3,25 @@ package instance import ( "context" + "connectrpc.com/connect" + filter "github.com/zitadel/zitadel/internal/api/grpc/filter/v2beta" instance "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta" ) -func (s *Server) GetInstance(ctx context.Context, _ *instance.GetInstanceRequest) (*instance.GetInstanceResponse, error) { +func (s *Server) GetInstance(ctx context.Context, _ *connect.Request[instance.GetInstanceRequest]) (*connect.Response[instance.GetInstanceResponse], error) { inst, err := s.query.Instance(ctx, true) if err != nil { return nil, err } - return &instance.GetInstanceResponse{ + return connect.NewResponse(&instance.GetInstanceResponse{ Instance: ToProtoObject(inst), - }, nil + }), nil } -func (s *Server) ListInstances(ctx context.Context, req *instance.ListInstancesRequest) (*instance.ListInstancesResponse, error) { - queries, err := ListInstancesRequestToModel(req, s.systemDefaults) +func (s *Server) ListInstances(ctx context.Context, req *connect.Request[instance.ListInstancesRequest]) (*connect.Response[instance.ListInstancesResponse], error) { + queries, err := ListInstancesRequestToModel(req.Msg, s.systemDefaults) if err != nil { return nil, err } @@ -29,14 +31,14 @@ func (s *Server) ListInstances(ctx context.Context, req *instance.ListInstancesR return nil, err } - return &instance.ListInstancesResponse{ + return connect.NewResponse(&instance.ListInstancesResponse{ Instances: InstancesToPb(instances.Instances), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, instances.SearchResponse), - }, nil + }), nil } -func (s *Server) ListCustomDomains(ctx context.Context, req *instance.ListCustomDomainsRequest) (*instance.ListCustomDomainsResponse, error) { - queries, err := ListCustomDomainsRequestToModel(req, s.systemDefaults) +func (s *Server) ListCustomDomains(ctx context.Context, req *connect.Request[instance.ListCustomDomainsRequest]) (*connect.Response[instance.ListCustomDomainsResponse], error) { + queries, err := ListCustomDomainsRequestToModel(req.Msg, s.systemDefaults) if err != nil { return nil, err } @@ -46,14 +48,14 @@ func (s *Server) ListCustomDomains(ctx context.Context, req *instance.ListCustom return nil, err } - return &instance.ListCustomDomainsResponse{ + return connect.NewResponse(&instance.ListCustomDomainsResponse{ Domains: DomainsToPb(domains.Domains), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, domains.SearchResponse), - }, nil + }), nil } -func (s *Server) ListTrustedDomains(ctx context.Context, req *instance.ListTrustedDomainsRequest) (*instance.ListTrustedDomainsResponse, error) { - queries, err := ListTrustedDomainsRequestToModel(req, s.systemDefaults) +func (s *Server) ListTrustedDomains(ctx context.Context, req *connect.Request[instance.ListTrustedDomainsRequest]) (*connect.Response[instance.ListTrustedDomainsResponse], error) { + queries, err := ListTrustedDomainsRequestToModel(req.Msg, s.systemDefaults) if err != nil { return nil, err } @@ -63,8 +65,8 @@ func (s *Server) ListTrustedDomains(ctx context.Context, req *instance.ListTrust return nil, err } - return &instance.ListTrustedDomainsResponse{ + return connect.NewResponse(&instance.ListTrustedDomainsResponse{ TrustedDomain: trustedDomainsToPb(domains.Domains), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, domains.SearchResponse), - }, nil + }), nil } diff --git a/internal/api/grpc/instance/v2beta/server.go b/internal/api/grpc/instance/v2beta/server.go index aaeaa4cc8f..1fb3513dd6 100644 --- a/internal/api/grpc/instance/v2beta/server.go +++ b/internal/api/grpc/instance/v2beta/server.go @@ -1,7 +1,10 @@ package instance import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/query" instance "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta/instanceconnect" ) -var _ instance.InstanceServiceServer = (*Server)(nil) +var _ instanceconnect.InstanceServiceHandler = (*Server)(nil) type Server struct { - instance.UnimplementedInstanceServiceServer command *command.Commands query *query.Queries systemDefaults systemdefaults.SystemDefaults @@ -39,8 +42,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - instance.RegisterInstanceServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return instanceconnect.NewInstanceServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return instance.File_zitadel_instance_v2beta_instance_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/oidc/v2/oidc.go b/internal/api/grpc/oidc/v2/oidc.go index 8612d11558..d56d6da056 100644 --- a/internal/api/grpc/oidc/v2/oidc.go +++ b/internal/api/grpc/oidc/v2/oidc.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" + "connectrpc.com/connect" "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/op" "google.golang.org/protobuf/types/known/durationpb" @@ -18,30 +19,30 @@ import ( oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2" ) -func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) { - authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true) +func (s *Server) GetAuthRequest(ctx context.Context, req *connect.Request[oidc_pb.GetAuthRequestRequest]) (*connect.Response[oidc_pb.GetAuthRequestResponse], error) { + authRequest, err := s.query.AuthRequestByID(ctx, true, req.Msg.GetAuthRequestId(), true) if err != nil { logging.WithError(err).Error("query authRequest by ID") return nil, err } - return &oidc_pb.GetAuthRequestResponse{ + return connect.NewResponse(&oidc_pb.GetAuthRequestResponse{ AuthRequest: authRequestToPb(authRequest), - }, nil + }), nil } -func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) { - switch v := req.GetCallbackKind().(type) { +func (s *Server) CreateCallback(ctx context.Context, req *connect.Request[oidc_pb.CreateCallbackRequest]) (*connect.Response[oidc_pb.CreateCallbackResponse], error) { + switch v := req.Msg.GetCallbackKind().(type) { case *oidc_pb.CreateCallbackRequest_Error: - return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error) + return s.failAuthRequest(ctx, req.Msg.GetAuthRequestId(), v.Error) case *oidc_pb.CreateCallbackRequest_Session: - return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session) + return s.linkSessionToAuthRequest(ctx, req.Msg.GetAuthRequestId(), v.Session) default: return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v) } } -func (s *Server) GetDeviceAuthorizationRequest(ctx context.Context, req *oidc_pb.GetDeviceAuthorizationRequestRequest) (*oidc_pb.GetDeviceAuthorizationRequestResponse, error) { - deviceRequest, err := s.query.DeviceAuthRequestByUserCode(ctx, req.GetUserCode()) +func (s *Server) GetDeviceAuthorizationRequest(ctx context.Context, req *connect.Request[oidc_pb.GetDeviceAuthorizationRequestRequest]) (*connect.Response[oidc_pb.GetDeviceAuthorizationRequestResponse], error) { + deviceRequest, err := s.query.DeviceAuthRequestByUserCode(ctx, req.Msg.GetUserCode()) if err != nil { return nil, err } @@ -49,7 +50,7 @@ func (s *Server) GetDeviceAuthorizationRequest(ctx context.Context, req *oidc_pb if err != nil { return nil, err } - return &oidc_pb.GetDeviceAuthorizationRequestResponse{ + return connect.NewResponse(&oidc_pb.GetDeviceAuthorizationRequestResponse{ DeviceAuthorizationRequest: &oidc_pb.DeviceAuthorizationRequest{ Id: base64.RawURLEncoding.EncodeToString(encrypted), ClientId: deviceRequest.ClientID, @@ -57,24 +58,24 @@ func (s *Server) GetDeviceAuthorizationRequest(ctx context.Context, req *oidc_pb AppName: deviceRequest.AppName, ProjectName: deviceRequest.ProjectName, }, - }, nil + }), nil } -func (s *Server) AuthorizeOrDenyDeviceAuthorization(ctx context.Context, req *oidc_pb.AuthorizeOrDenyDeviceAuthorizationRequest) (*oidc_pb.AuthorizeOrDenyDeviceAuthorizationResponse, error) { - deviceCode, err := s.deviceCodeFromID(req.GetDeviceAuthorizationId()) +func (s *Server) AuthorizeOrDenyDeviceAuthorization(ctx context.Context, req *connect.Request[oidc_pb.AuthorizeOrDenyDeviceAuthorizationRequest]) (*connect.Response[oidc_pb.AuthorizeOrDenyDeviceAuthorizationResponse], error) { + deviceCode, err := s.deviceCodeFromID(req.Msg.GetDeviceAuthorizationId()) if err != nil { return nil, err } - switch req.GetDecision().(type) { + switch req.Msg.GetDecision().(type) { case *oidc_pb.AuthorizeOrDenyDeviceAuthorizationRequest_Session: - _, err = s.command.ApproveDeviceAuthWithSession(ctx, deviceCode, req.GetSession().GetSessionId(), req.GetSession().GetSessionToken()) + _, err = s.command.ApproveDeviceAuthWithSession(ctx, deviceCode, req.Msg.GetSession().GetSessionId(), req.Msg.GetSession().GetSessionToken()) case *oidc_pb.AuthorizeOrDenyDeviceAuthorizationRequest_Deny: _, err = s.command.CancelDeviceAuth(ctx, deviceCode, domain.DeviceAuthCanceledDenied) } if err != nil { return nil, err } - return &oidc_pb.AuthorizeOrDenyDeviceAuthorizationResponse{}, nil + return connect.NewResponse(&oidc_pb.AuthorizeOrDenyDeviceAuthorizationResponse{}), nil } func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest { @@ -136,7 +137,7 @@ func (s *Server) checkPermission(ctx context.Context, clientID string, userID st return nil } -func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) { +func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*connect.Response[oidc_pb.CreateCallbackResponse], error) { details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError())) if err != nil { return nil, err @@ -146,13 +147,13 @@ func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae * if err != nil { return nil, err } - return &oidc_pb.CreateCallbackResponse{ + return connect.NewResponse(&oidc_pb.CreateCallbackResponse{ Details: object.DomainToDetailsPb(details), CallbackUrl: callback, - }, nil + }), nil } -func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) { +func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*connect.Response[oidc_pb.CreateCallbackResponse], error) { details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true, s.checkPermission) if err != nil { return nil, err @@ -172,10 +173,10 @@ func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID str if err != nil { return nil, err } - return &oidc_pb.CreateCallbackResponse{ + return connect.NewResponse(&oidc_pb.CreateCallbackResponse{ Details: object.DomainToDetailsPb(details), CallbackUrl: callback, - }, nil + }), nil } func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason { diff --git a/internal/api/grpc/oidc/v2/server.go b/internal/api/grpc/oidc/v2/server.go index 99234ee3d7..3d8f78a8ad 100644 --- a/internal/api/grpc/oidc/v2/server.go +++ b/internal/api/grpc/oidc/v2/server.go @@ -1,7 +1,10 @@ package oidc import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -10,12 +13,12 @@ import ( "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/query" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2" + "github.com/zitadel/zitadel/pkg/grpc/oidc/v2/oidcconnect" ) -var _ oidc_pb.OIDCServiceServer = (*Server)(nil) +var _ oidcconnect.OIDCServiceHandler = (*Server)(nil) type Server struct { - oidc_pb.UnimplementedOIDCServiceServer command *command.Commands query *query.Queries @@ -42,8 +45,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - oidc_pb.RegisterOIDCServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return oidcconnect.NewOIDCServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return oidc_pb.File_zitadel_oidc_v2_oidc_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/oidc/v2beta/oidc.go b/internal/api/grpc/oidc/v2beta/oidc.go index 66c4bee828..432e6f833f 100644 --- a/internal/api/grpc/oidc/v2beta/oidc.go +++ b/internal/api/grpc/oidc/v2beta/oidc.go @@ -3,6 +3,7 @@ package oidc import ( "context" + "connectrpc.com/connect" "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/op" "google.golang.org/protobuf/types/known/durationpb" @@ -17,15 +18,15 @@ import ( oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" ) -func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) { - authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true) +func (s *Server) GetAuthRequest(ctx context.Context, req *connect.Request[oidc_pb.GetAuthRequestRequest]) (*connect.Response[oidc_pb.GetAuthRequestResponse], error) { + authRequest, err := s.query.AuthRequestByID(ctx, true, req.Msg.GetAuthRequestId(), true) if err != nil { logging.WithError(err).Error("query authRequest by ID") return nil, err } - return &oidc_pb.GetAuthRequestResponse{ + return connect.NewResponse(&oidc_pb.GetAuthRequestResponse{ AuthRequest: authRequestToPb(authRequest), - }, nil + }), nil } func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest { @@ -73,18 +74,18 @@ func promptToPb(p domain.Prompt) oidc_pb.Prompt { } } -func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) { - switch v := req.GetCallbackKind().(type) { +func (s *Server) CreateCallback(ctx context.Context, req *connect.Request[oidc_pb.CreateCallbackRequest]) (*connect.Response[oidc_pb.CreateCallbackResponse], error) { + switch v := req.Msg.GetCallbackKind().(type) { case *oidc_pb.CreateCallbackRequest_Error: - return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error) + return s.failAuthRequest(ctx, req.Msg.GetAuthRequestId(), v.Error) case *oidc_pb.CreateCallbackRequest_Session: - return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session) + return s.linkSessionToAuthRequest(ctx, req.Msg.GetAuthRequestId(), v.Session) default: return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v) } } -func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) { +func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*connect.Response[oidc_pb.CreateCallbackResponse], error) { details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError())) if err != nil { return nil, err @@ -94,10 +95,10 @@ func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae * if err != nil { return nil, err } - return &oidc_pb.CreateCallbackResponse{ + return connect.NewResponse(&oidc_pb.CreateCallbackResponse{ Details: object.DomainToDetailsPb(details), CallbackUrl: callback, - }, nil + }), nil } func (s *Server) checkPermission(ctx context.Context, clientID string, userID string) error { @@ -114,7 +115,7 @@ func (s *Server) checkPermission(ctx context.Context, clientID string, userID st return nil } -func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) { +func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*connect.Response[oidc_pb.CreateCallbackResponse], error) { details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true, s.checkPermission) if err != nil { return nil, err @@ -130,10 +131,10 @@ func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID str if err != nil { return nil, err } - return &oidc_pb.CreateCallbackResponse{ + return connect.NewResponse(&oidc_pb.CreateCallbackResponse{ Details: object.DomainToDetailsPb(details), CallbackUrl: callback, - }, nil + }), nil } func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason { diff --git a/internal/api/grpc/oidc/v2beta/server.go b/internal/api/grpc/oidc/v2beta/server.go index 7595ae927e..5309a5093e 100644 --- a/internal/api/grpc/oidc/v2beta/server.go +++ b/internal/api/grpc/oidc/v2beta/server.go @@ -1,7 +1,10 @@ package oidc import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta/oidcconnect" ) -var _ oidc_pb.OIDCServiceServer = (*Server)(nil) +var _ oidcconnect.OIDCServiceHandler = (*Server)(nil) type Server struct { - oidc_pb.UnimplementedOIDCServiceServer command *command.Commands query *query.Queries @@ -38,8 +41,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - oidc_pb.RegisterOIDCServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return oidcconnect.NewOIDCServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return oidc_pb.File_zitadel_oidc_v2beta_oidc_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/org/v2/org.go b/internal/api/grpc/org/v2/org.go index b876826365..42832d147f 100644 --- a/internal/api/grpc/org/v2/org.go +++ b/internal/api/grpc/org/v2/org.go @@ -3,6 +3,8 @@ package org import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/api/grpc/user/v2" "github.com/zitadel/zitadel/internal/command" @@ -10,8 +12,8 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/org/v2" ) -func (s *Server) AddOrganization(ctx context.Context, request *org.AddOrganizationRequest) (*org.AddOrganizationResponse, error) { - orgSetup, err := addOrganizationRequestToCommand(request) +func (s *Server) AddOrganization(ctx context.Context, request *connect.Request[org.AddOrganizationRequest]) (*connect.Response[org.AddOrganizationResponse], error) { + orgSetup, err := addOrganizationRequestToCommand(request.Msg) if err != nil { return nil, err } @@ -68,7 +70,7 @@ func addOrganizationRequestAdminToCommand(admin *org.AddOrganizationRequest_Admi } } -func createdOrganizationToPb(createdOrg *command.CreatedOrg) (_ *org.AddOrganizationResponse, err error) { +func createdOrganizationToPb(createdOrg *command.CreatedOrg) (_ *connect.Response[org.AddOrganizationResponse], err error) { admins := make([]*org.AddOrganizationResponse_CreatedAdmin, 0, len(createdOrg.OrgAdmins)) for _, admin := range createdOrg.OrgAdmins { admin, ok := admin.(*command.CreatedOrgAdmin) @@ -80,9 +82,9 @@ func createdOrganizationToPb(createdOrg *command.CreatedOrg) (_ *org.AddOrganiza }) } } - return &org.AddOrganizationResponse{ + return connect.NewResponse(&org.AddOrganizationResponse{ Details: object.DomainToDetailsPb(createdOrg.ObjectDetails), OrganizationId: createdOrg.ObjectDetails.ResourceOwner, CreatedAdmins: admins, - }, nil + }), nil } diff --git a/internal/api/grpc/org/v2/org_test.go b/internal/api/grpc/org/v2/org_test.go index 37a3dca41a..564c5597ee 100644 --- a/internal/api/grpc/org/v2/org_test.go +++ b/internal/api/grpc/org/v2/org_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/timestamppb" @@ -138,7 +139,7 @@ func Test_createdOrganizationToPb(t *testing.T) { tests := []struct { name string args args - want *org.AddOrganizationResponse + want *connect.Response[org.AddOrganizationResponse] wantErr error }{ { @@ -159,7 +160,7 @@ func Test_createdOrganizationToPb(t *testing.T) { }, }, }, - want: &org.AddOrganizationResponse{ + want: connect.NewResponse(&org.AddOrganizationResponse{ Details: &object.Details{ Sequence: 1, ChangeDate: timestamppb.New(now), @@ -173,7 +174,7 @@ func Test_createdOrganizationToPb(t *testing.T) { PhoneCode: gu.Ptr("phoneCode"), }, }, - }, + }), }, } for _, tt := range tests { diff --git a/internal/api/grpc/org/v2/query.go b/internal/api/grpc/org/v2/query.go index 27f279d40e..09e2534e8d 100644 --- a/internal/api/grpc/org/v2/query.go +++ b/internal/api/grpc/org/v2/query.go @@ -3,6 +3,8 @@ package org import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/domain" @@ -11,36 +13,36 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/org/v2" ) -func (s *Server) ListOrganizations(ctx context.Context, req *org.ListOrganizationsRequest) (*org.ListOrganizationsResponse, error) { +func (s *Server) ListOrganizations(ctx context.Context, req *connect.Request[org.ListOrganizationsRequest]) (*connect.Response[org.ListOrganizationsResponse], error) { queries, err := listOrgRequestToModel(ctx, req) if err != nil { return nil, err } - orgs, err := s.query.SearchOrgs(ctx, queries, s.checkPermission) + orgs, err := s.query.SearchOrgs(ctx, queries.Msg, s.checkPermission) if err != nil { return nil, err } - return &org.ListOrganizationsResponse{ + return connect.NewResponse(&org.ListOrganizationsResponse{ Result: organizationsToPb(orgs.Orgs), Details: object.ToListDetails(orgs.SearchResponse), - }, nil + }), nil } -func listOrgRequestToModel(ctx context.Context, req *org.ListOrganizationsRequest) (*query.OrgSearchQueries, error) { - offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, err := orgQueriesToQuery(ctx, req.Queries) +func listOrgRequestToModel(ctx context.Context, req *connect.Request[org.ListOrganizationsRequest]) (*connect.Response[query.OrgSearchQueries], error) { + offset, limit, asc := object.ListQueryToQuery(req.Msg.Query) + queries, err := orgQueriesToQuery(ctx, req.Msg.Queries) if err != nil { return nil, err } - return &query.OrgSearchQueries{ + return connect.NewResponse(&query.OrgSearchQueries{ SearchRequest: query.SearchRequest{ Offset: offset, Limit: limit, - SortingColumn: fieldNameToOrganizationColumn(req.SortingColumn), + SortingColumn: fieldNameToOrganizationColumn(req.Msg.SortingColumn), Asc: asc, }, Queries: queries, - }, nil + }), nil } func orgQueriesToQuery(ctx context.Context, queries []*org.SearchQuery) (_ []query.SearchQuery, err error) { diff --git a/internal/api/grpc/org/v2/server.go b/internal/api/grpc/org/v2/server.go index 36588f3eb7..6fd318d114 100644 --- a/internal/api/grpc/org/v2/server.go +++ b/internal/api/grpc/org/v2/server.go @@ -1,7 +1,10 @@ package org import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/org/v2" + "github.com/zitadel/zitadel/pkg/grpc/org/v2/orgconnect" ) -var _ org.OrganizationServiceServer = (*Server)(nil) +var _ orgconnect.OrganizationServiceHandler = (*Server)(nil) type Server struct { - org.UnimplementedOrganizationServiceServer command *command.Commands query *query.Queries checkPermission domain.PermissionCheck @@ -34,8 +37,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - org.RegisterOrganizationServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return orgconnect.NewOrganizationServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return org.File_zitadel_org_v2_org_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/org/v2beta/helper.go b/internal/api/grpc/org/v2beta/helper.go index 6f47819bb4..77c3130488 100644 --- a/internal/api/grpc/org/v2beta/helper.go +++ b/internal/api/grpc/org/v2beta/helper.go @@ -3,6 +3,7 @@ package org import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" // TODO fix below @@ -71,7 +72,7 @@ func OrgStateToPb(state domain.OrgState) v2beta_org.OrgState { } } -func createdOrganizationToPb(createdOrg *command.CreatedOrg) (_ *org.CreateOrganizationResponse, err error) { +func createdOrganizationToPb(createdOrg *command.CreatedOrg) (_ *connect.Response[org.CreateOrganizationResponse], err error) { admins := make([]*org.OrganizationAdmin, len(createdOrg.OrgAdmins)) for i, admin := range createdOrg.OrgAdmins { switch admin := admin.(type) { @@ -95,11 +96,11 @@ func createdOrganizationToPb(createdOrg *command.CreatedOrg) (_ *org.CreateOrgan } } } - return &org.CreateOrganizationResponse{ + return connect.NewResponse(&org.CreateOrganizationResponse{ CreationDate: timestamppb.New(createdOrg.ObjectDetails.EventDate), Id: createdOrg.ObjectDetails.ResourceOwner, OrganizationAdmins: admins, - }, nil + }), nil } func OrgViewsToPb(orgs []*query.Org) []*v2beta_org.Organization { diff --git a/internal/api/grpc/org/v2beta/org.go b/internal/api/grpc/org/v2beta/org.go index 66198757cb..35e1d72d3c 100644 --- a/internal/api/grpc/org/v2beta/org.go +++ b/internal/api/grpc/org/v2beta/org.go @@ -4,6 +4,7 @@ import ( "context" "errors" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" metadata "github.com/zitadel/zitadel/internal/api/grpc/metadata/v2beta" @@ -17,8 +18,8 @@ import ( v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" ) -func (s *Server) CreateOrganization(ctx context.Context, request *v2beta_org.CreateOrganizationRequest) (*v2beta_org.CreateOrganizationResponse, error) { - orgSetup, err := createOrganizationRequestToCommand(request) +func (s *Server) CreateOrganization(ctx context.Context, request *connect.Request[v2beta_org.CreateOrganizationRequest]) (*connect.Response[v2beta_org.CreateOrganizationResponse], error) { + orgSetup, err := createOrganizationRequestToCommand(request.Msg) if err != nil { return nil, err } @@ -29,19 +30,19 @@ func (s *Server) CreateOrganization(ctx context.Context, request *v2beta_org.Cre return createdOrganizationToPb(createdOrg) } -func (s *Server) UpdateOrganization(ctx context.Context, request *v2beta_org.UpdateOrganizationRequest) (*v2beta_org.UpdateOrganizationResponse, error) { - org, err := s.command.ChangeOrg(ctx, request.Id, request.Name) +func (s *Server) UpdateOrganization(ctx context.Context, request *connect.Request[v2beta_org.UpdateOrganizationRequest]) (*connect.Response[v2beta_org.UpdateOrganizationResponse], error) { + org, err := s.command.ChangeOrg(ctx, request.Msg.GetId(), request.Msg.GetName()) if err != nil { return nil, err } - return &v2beta_org.UpdateOrganizationResponse{ + return connect.NewResponse(&v2beta_org.UpdateOrganizationResponse{ ChangeDate: timestamppb.New(org.EventDate), - }, nil + }), nil } -func (s *Server) ListOrganizations(ctx context.Context, request *v2beta_org.ListOrganizationsRequest) (*v2beta_org.ListOrganizationsResponse, error) { - queries, err := listOrgRequestToModel(s.systemDefaults, request) +func (s *Server) ListOrganizations(ctx context.Context, request *connect.Request[v2beta_org.ListOrganizationsRequest]) (*connect.Response[v2beta_org.ListOrganizationsResponse], error) { + queries, err := listOrgRequestToModel(s.systemDefaults, request.Msg) if err != nil { return nil, err } @@ -49,107 +50,107 @@ func (s *Server) ListOrganizations(ctx context.Context, request *v2beta_org.List if err != nil { return nil, err } - return &v2beta_org.ListOrganizationsResponse{ + return connect.NewResponse(&v2beta_org.ListOrganizationsResponse{ Organizations: OrgViewsToPb(orgs.Orgs), Pagination: &filter.PaginationResponse{ TotalResult: orgs.Count, - AppliedLimit: uint64(request.GetPagination().GetLimit()), + AppliedLimit: uint64(request.Msg.GetPagination().GetLimit()), }, - }, nil + }), nil } -func (s *Server) DeleteOrganization(ctx context.Context, request *v2beta_org.DeleteOrganizationRequest) (*v2beta_org.DeleteOrganizationResponse, error) { - details, err := s.command.RemoveOrg(ctx, request.Id) +func (s *Server) DeleteOrganization(ctx context.Context, request *connect.Request[v2beta_org.DeleteOrganizationRequest]) (*connect.Response[v2beta_org.DeleteOrganizationResponse], error) { + details, err := s.command.RemoveOrg(ctx, request.Msg.GetId()) if err != nil { var notFoundError *zerrors.NotFoundError if errors.As(err, ¬FoundError) { - return &v2beta_org.DeleteOrganizationResponse{}, nil + return connect.NewResponse(&v2beta_org.DeleteOrganizationResponse{}), nil } return nil, err } - return &v2beta_org.DeleteOrganizationResponse{ + return connect.NewResponse(&v2beta_org.DeleteOrganizationResponse{ DeletionDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) SetOrganizationMetadata(ctx context.Context, request *v2beta_org.SetOrganizationMetadataRequest) (*v2beta_org.SetOrganizationMetadataResponse, error) { - result, err := s.command.BulkSetOrgMetadata(ctx, request.OrganizationId, BulkSetOrgMetadataToDomain(request)...) +func (s *Server) SetOrganizationMetadata(ctx context.Context, request *connect.Request[v2beta_org.SetOrganizationMetadataRequest]) (*connect.Response[v2beta_org.SetOrganizationMetadataResponse], error) { + result, err := s.command.BulkSetOrgMetadata(ctx, request.Msg.GetOrganizationId(), BulkSetOrgMetadataToDomain(request.Msg)...) if err != nil { return nil, err } - return &org.SetOrganizationMetadataResponse{ + return connect.NewResponse(&org.SetOrganizationMetadataResponse{ SetDate: timestamppb.New(result.EventDate), - }, nil + }), nil } -func (s *Server) ListOrganizationMetadata(ctx context.Context, request *v2beta_org.ListOrganizationMetadataRequest) (*v2beta_org.ListOrganizationMetadataResponse, error) { - metadataQueries, err := ListOrgMetadataToDomain(s.systemDefaults, request) +func (s *Server) ListOrganizationMetadata(ctx context.Context, request *connect.Request[v2beta_org.ListOrganizationMetadataRequest]) (*connect.Response[v2beta_org.ListOrganizationMetadataResponse], error) { + metadataQueries, err := ListOrgMetadataToDomain(s.systemDefaults, request.Msg) if err != nil { return nil, err } - res, err := s.query.SearchOrgMetadata(ctx, true, request.OrganizationId, metadataQueries, false) + res, err := s.query.SearchOrgMetadata(ctx, true, request.Msg.GetOrganizationId(), metadataQueries, false) if err != nil { return nil, err } - return &v2beta_org.ListOrganizationMetadataResponse{ + return connect.NewResponse(&v2beta_org.ListOrganizationMetadataResponse{ Metadata: metadata.OrgMetadataListToPb(res.Metadata), Pagination: &filter.PaginationResponse{ TotalResult: res.Count, - AppliedLimit: uint64(request.GetPagination().GetLimit()), + AppliedLimit: uint64(request.Msg.GetPagination().GetLimit()), }, - }, nil + }), nil } -func (s *Server) DeleteOrganizationMetadata(ctx context.Context, request *v2beta_org.DeleteOrganizationMetadataRequest) (*v2beta_org.DeleteOrganizationMetadataResponse, error) { - result, err := s.command.BulkRemoveOrgMetadata(ctx, request.OrganizationId, request.Keys...) +func (s *Server) DeleteOrganizationMetadata(ctx context.Context, request *connect.Request[v2beta_org.DeleteOrganizationMetadataRequest]) (*connect.Response[v2beta_org.DeleteOrganizationMetadataResponse], error) { + result, err := s.command.BulkRemoveOrgMetadata(ctx, request.Msg.GetOrganizationId(), request.Msg.Keys...) if err != nil { return nil, err } - return &v2beta_org.DeleteOrganizationMetadataResponse{ + return connect.NewResponse(&v2beta_org.DeleteOrganizationMetadataResponse{ DeletionDate: timestamppb.New(result.EventDate), - }, nil + }), nil } -func (s *Server) DeactivateOrganization(ctx context.Context, request *org.DeactivateOrganizationRequest) (*org.DeactivateOrganizationResponse, error) { - objectDetails, err := s.command.DeactivateOrg(ctx, request.Id) +func (s *Server) DeactivateOrganization(ctx context.Context, request *connect.Request[org.DeactivateOrganizationRequest]) (*connect.Response[org.DeactivateOrganizationResponse], error) { + objectDetails, err := s.command.DeactivateOrg(ctx, request.Msg.GetId()) if err != nil { return nil, err } - return &org.DeactivateOrganizationResponse{ + return connect.NewResponse(&org.DeactivateOrganizationResponse{ ChangeDate: timestamppb.New(objectDetails.EventDate), - }, nil + }), nil } -func (s *Server) ActivateOrganization(ctx context.Context, request *org.ActivateOrganizationRequest) (*org.ActivateOrganizationResponse, error) { - objectDetails, err := s.command.ReactivateOrg(ctx, request.Id) +func (s *Server) ActivateOrganization(ctx context.Context, request *connect.Request[org.ActivateOrganizationRequest]) (*connect.Response[org.ActivateOrganizationResponse], error) { + objectDetails, err := s.command.ReactivateOrg(ctx, request.Msg.GetId()) if err != nil { return nil, err } - return &org.ActivateOrganizationResponse{ + return connect.NewResponse(&org.ActivateOrganizationResponse{ ChangeDate: timestamppb.New(objectDetails.EventDate), - }, err + }), err } -func (s *Server) AddOrganizationDomain(ctx context.Context, request *org.AddOrganizationDomainRequest) (*org.AddOrganizationDomainResponse, error) { - userIDs, err := s.getClaimedUserIDsOfOrgDomain(ctx, request.Domain, request.OrganizationId) +func (s *Server) AddOrganizationDomain(ctx context.Context, request *connect.Request[org.AddOrganizationDomainRequest]) (*connect.Response[org.AddOrganizationDomainResponse], error) { + userIDs, err := s.getClaimedUserIDsOfOrgDomain(ctx, request.Msg.GetDomain(), request.Msg.GetOrganizationId()) if err != nil { return nil, err } - details, err := s.command.AddOrgDomain(ctx, request.OrganizationId, request.Domain, userIDs) + details, err := s.command.AddOrgDomain(ctx, request.Msg.GetOrganizationId(), request.Msg.GetDomain(), userIDs) if err != nil { return nil, err } - return &org.AddOrganizationDomainResponse{ + return connect.NewResponse(&org.AddOrganizationDomainResponse{ CreationDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) ListOrganizationDomains(ctx context.Context, req *org.ListOrganizationDomainsRequest) (*org.ListOrganizationDomainsResponse, error) { - queries, err := ListOrgDomainsRequestToModel(s.systemDefaults, req) +func (s *Server) ListOrganizationDomains(ctx context.Context, req *connect.Request[org.ListOrganizationDomainsRequest]) (*connect.Response[org.ListOrganizationDomainsResponse], error) { + queries, err := ListOrgDomainsRequestToModel(s.systemDefaults, req.Msg) if err != nil { return nil, err } - orgIDQuery, err := query.NewOrgDomainOrgIDSearchQuery(req.OrganizationId) + orgIDQuery, err := query.NewOrgDomainOrgIDSearchQuery(req.Msg.GetOrganizationId()) if err != nil { return nil, err } @@ -159,48 +160,48 @@ func (s *Server) ListOrganizationDomains(ctx context.Context, req *org.ListOrgan if err != nil { return nil, err } - return &org.ListOrganizationDomainsResponse{ + return connect.NewResponse(&org.ListOrganizationDomainsResponse{ Domains: object.DomainsToPb(domains.Domains), Pagination: &filter.PaginationResponse{ TotalResult: domains.Count, - AppliedLimit: uint64(req.GetPagination().GetLimit()), + AppliedLimit: uint64(req.Msg.GetPagination().GetLimit()), }, - }, nil + }), nil } -func (s *Server) DeleteOrganizationDomain(ctx context.Context, req *org.DeleteOrganizationDomainRequest) (*org.DeleteOrganizationDomainResponse, error) { - details, err := s.command.RemoveOrgDomain(ctx, RemoveOrgDomainRequestToDomain(ctx, req)) +func (s *Server) DeleteOrganizationDomain(ctx context.Context, req *connect.Request[org.DeleteOrganizationDomainRequest]) (*connect.Response[org.DeleteOrganizationDomainResponse], error) { + details, err := s.command.RemoveOrgDomain(ctx, RemoveOrgDomainRequestToDomain(ctx, req.Msg)) if err != nil { return nil, err } - return &org.DeleteOrganizationDomainResponse{ + return connect.NewResponse(&org.DeleteOrganizationDomainResponse{ DeletionDate: timestamppb.New(details.EventDate), - }, err + }), err } -func (s *Server) GenerateOrganizationDomainValidation(ctx context.Context, req *org.GenerateOrganizationDomainValidationRequest) (*org.GenerateOrganizationDomainValidationResponse, error) { - token, url, err := s.command.GenerateOrgDomainValidation(ctx, GenerateOrgDomainValidationRequestToDomain(ctx, req)) +func (s *Server) GenerateOrganizationDomainValidation(ctx context.Context, req *connect.Request[org.GenerateOrganizationDomainValidationRequest]) (*connect.Response[org.GenerateOrganizationDomainValidationResponse], error) { + token, url, err := s.command.GenerateOrgDomainValidation(ctx, GenerateOrgDomainValidationRequestToDomain(ctx, req.Msg)) if err != nil { return nil, err } - return &org.GenerateOrganizationDomainValidationResponse{ + return connect.NewResponse(&org.GenerateOrganizationDomainValidationResponse{ Token: token, Url: url, - }, nil + }), nil } -func (s *Server) VerifyOrganizationDomain(ctx context.Context, request *org.VerifyOrganizationDomainRequest) (*org.VerifyOrganizationDomainResponse, error) { - userIDs, err := s.getClaimedUserIDsOfOrgDomain(ctx, request.Domain, request.OrganizationId) +func (s *Server) VerifyOrganizationDomain(ctx context.Context, request *connect.Request[org.VerifyOrganizationDomainRequest]) (*connect.Response[org.VerifyOrganizationDomainResponse], error) { + userIDs, err := s.getClaimedUserIDsOfOrgDomain(ctx, request.Msg.GetDomain(), request.Msg.GetOrganizationId()) if err != nil { return nil, err } - details, err := s.command.ValidateOrgDomain(ctx, ValidateOrgDomainRequestToDomain(ctx, request), userIDs) + details, err := s.command.ValidateOrgDomain(ctx, ValidateOrgDomainRequestToDomain(ctx, request.Msg), userIDs) if err != nil { return nil, err } - return &org.VerifyOrganizationDomainResponse{ + return connect.NewResponse(&org.VerifyOrganizationDomainResponse{ ChangeDate: timestamppb.New(details.EventDate), - }, nil + }), nil } func createOrganizationRequestToCommand(request *v2beta_org.CreateOrganizationRequest) (*command.OrgSetup, error) { diff --git a/internal/api/grpc/org/v2beta/org_test.go b/internal/api/grpc/org/v2beta/org_test.go index 346d6b88c1..85dec79be4 100644 --- a/internal/api/grpc/org/v2beta/org_test.go +++ b/internal/api/grpc/org/v2beta/org_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -138,7 +139,7 @@ func Test_createdOrganizationToPb(t *testing.T) { tests := []struct { name string args args - want *org.CreateOrganizationResponse + want *connect.Response[org.CreateOrganizationResponse] wantErr error }{ { @@ -159,7 +160,7 @@ func Test_createdOrganizationToPb(t *testing.T) { }, }, }, - want: &org.CreateOrganizationResponse{ + want: connect.NewResponse(&org.CreateOrganizationResponse{ CreationDate: timestamppb.New(now), Id: "orgID", OrganizationAdmins: []*org.OrganizationAdmin{ @@ -173,7 +174,7 @@ func Test_createdOrganizationToPb(t *testing.T) { }, }, }, - }, + }), }, } for _, tt := range tests { diff --git a/internal/api/grpc/org/v2beta/server.go b/internal/api/grpc/org/v2beta/server.go index b7e8d4994f..8f9091c7c3 100644 --- a/internal/api/grpc/org/v2beta/server.go +++ b/internal/api/grpc/org/v2beta/server.go @@ -1,7 +1,10 @@ package org import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -10,12 +13,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/org/v2beta/orgconnect" ) -var _ org.OrganizationServiceServer = (*Server)(nil) +var _ orgconnect.OrganizationServiceHandler = (*Server)(nil) type Server struct { - org.UnimplementedOrganizationServiceServer systemDefaults systemdefaults.SystemDefaults command *command.Commands query *query.Queries @@ -38,8 +41,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - org.RegisterOrganizationServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return orgconnect.NewOrganizationServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return org.File_zitadel_org_v2beta_org_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/project/v2beta/project.go b/internal/api/grpc/project/v2beta/project.go index 01b478f5be..b3294f1ea6 100644 --- a/internal/api/grpc/project/v2beta/project.go +++ b/internal/api/grpc/project/v2beta/project.go @@ -3,6 +3,7 @@ package project import ( "context" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "google.golang.org/protobuf/types/known/timestamppb" @@ -13,8 +14,8 @@ import ( project_pb "github.com/zitadel/zitadel/pkg/grpc/project/v2beta" ) -func (s *Server) CreateProject(ctx context.Context, req *project_pb.CreateProjectRequest) (*project_pb.CreateProjectResponse, error) { - add := projectCreateToCommand(req) +func (s *Server) CreateProject(ctx context.Context, req *connect.Request[project_pb.CreateProjectRequest]) (*connect.Response[project_pb.CreateProjectResponse], error) { + add := projectCreateToCommand(req.Msg) project, err := s.command.AddProject(ctx, add) if err != nil { return nil, err @@ -23,10 +24,10 @@ func (s *Server) CreateProject(ctx context.Context, req *project_pb.CreateProjec if !project.EventDate.IsZero() { creationDate = timestamppb.New(project.EventDate) } - return &project_pb.CreateProjectResponse{ + return connect.NewResponse(&project_pb.CreateProjectResponse{ Id: add.AggregateID, CreationDate: creationDate, - }, nil + }), nil } func projectCreateToCommand(req *project_pb.CreateProjectRequest) *command.AddProject { @@ -60,8 +61,8 @@ func privateLabelingSettingToDomain(setting project_pb.PrivateLabelingSetting) d } } -func (s *Server) UpdateProject(ctx context.Context, req *project_pb.UpdateProjectRequest) (*project_pb.UpdateProjectResponse, error) { - project, err := s.command.ChangeProject(ctx, projectUpdateToCommand(req)) +func (s *Server) UpdateProject(ctx context.Context, req *connect.Request[project_pb.UpdateProjectRequest]) (*connect.Response[project_pb.UpdateProjectResponse], error) { + project, err := s.command.ChangeProject(ctx, projectUpdateToCommand(req.Msg)) if err != nil { return nil, err } @@ -69,9 +70,9 @@ func (s *Server) UpdateProject(ctx context.Context, req *project_pb.UpdateProjec if !project.EventDate.IsZero() { changeDate = timestamppb.New(project.EventDate) } - return &project_pb.UpdateProjectResponse{ + return connect.NewResponse(&project_pb.UpdateProjectResponse{ ChangeDate: changeDate, - }, nil + }), nil } func projectUpdateToCommand(req *project_pb.UpdateProjectRequest) *command.ChangeProject { @@ -91,13 +92,13 @@ func projectUpdateToCommand(req *project_pb.UpdateProjectRequest) *command.Chang } } -func (s *Server) DeleteProject(ctx context.Context, req *project_pb.DeleteProjectRequest) (*project_pb.DeleteProjectResponse, error) { - userGrantIDs, err := s.userGrantsFromProject(ctx, req.Id) +func (s *Server) DeleteProject(ctx context.Context, req *connect.Request[project_pb.DeleteProjectRequest]) (*connect.Response[project_pb.DeleteProjectResponse], error) { + userGrantIDs, err := s.userGrantsFromProject(ctx, req.Msg.GetId()) if err != nil { return nil, err } - deletedAt, err := s.command.DeleteProject(ctx, req.Id, "", userGrantIDs...) + deletedAt, err := s.command.DeleteProject(ctx, req.Msg.GetId(), "", userGrantIDs...) if err != nil { return nil, err } @@ -105,9 +106,9 @@ func (s *Server) DeleteProject(ctx context.Context, req *project_pb.DeleteProjec if !deletedAt.IsZero() { deletionDate = timestamppb.New(deletedAt) } - return &project_pb.DeleteProjectResponse{ + return connect.NewResponse(&project_pb.DeleteProjectResponse{ DeletionDate: deletionDate, - }, nil + }), nil } func (s *Server) userGrantsFromProject(ctx context.Context, projectID string) ([]string, error) { @@ -124,8 +125,8 @@ func (s *Server) userGrantsFromProject(ctx context.Context, projectID string) ([ return userGrantsToIDs(userGrants.UserGrants), nil } -func (s *Server) DeactivateProject(ctx context.Context, req *project_pb.DeactivateProjectRequest) (*project_pb.DeactivateProjectResponse, error) { - details, err := s.command.DeactivateProject(ctx, req.Id, "") +func (s *Server) DeactivateProject(ctx context.Context, req *connect.Request[project_pb.DeactivateProjectRequest]) (*connect.Response[project_pb.DeactivateProjectResponse], error) { + details, err := s.command.DeactivateProject(ctx, req.Msg.GetId(), "") if err != nil { return nil, err } @@ -133,13 +134,13 @@ func (s *Server) DeactivateProject(ctx context.Context, req *project_pb.Deactiva if !details.EventDate.IsZero() { changeDate = timestamppb.New(details.EventDate) } - return &project_pb.DeactivateProjectResponse{ + return connect.NewResponse(&project_pb.DeactivateProjectResponse{ ChangeDate: changeDate, - }, nil + }), nil } -func (s *Server) ActivateProject(ctx context.Context, req *project_pb.ActivateProjectRequest) (*project_pb.ActivateProjectResponse, error) { - details, err := s.command.ReactivateProject(ctx, req.Id, "") +func (s *Server) ActivateProject(ctx context.Context, req *connect.Request[project_pb.ActivateProjectRequest]) (*connect.Response[project_pb.ActivateProjectResponse], error) { + details, err := s.command.ReactivateProject(ctx, req.Msg.GetId(), "") if err != nil { return nil, err } @@ -147,9 +148,9 @@ func (s *Server) ActivateProject(ctx context.Context, req *project_pb.ActivatePr if !details.EventDate.IsZero() { changeDate = timestamppb.New(details.EventDate) } - return &project_pb.ActivateProjectResponse{ + return connect.NewResponse(&project_pb.ActivateProjectResponse{ ChangeDate: changeDate, - }, nil + }), nil } func userGrantsToIDs(userGrants []*query.UserGrant) []string { diff --git a/internal/api/grpc/project/v2beta/project_grant.go b/internal/api/grpc/project/v2beta/project_grant.go index 6c3b195c66..555d4bfd27 100644 --- a/internal/api/grpc/project/v2beta/project_grant.go +++ b/internal/api/grpc/project/v2beta/project_grant.go @@ -3,6 +3,7 @@ package project import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/command" @@ -11,8 +12,8 @@ import ( project_pb "github.com/zitadel/zitadel/pkg/grpc/project/v2beta" ) -func (s *Server) CreateProjectGrant(ctx context.Context, req *project_pb.CreateProjectGrantRequest) (*project_pb.CreateProjectGrantResponse, error) { - add := projectGrantCreateToCommand(req) +func (s *Server) CreateProjectGrant(ctx context.Context, req *connect.Request[project_pb.CreateProjectGrantRequest]) (*connect.Response[project_pb.CreateProjectGrantResponse], error) { + add := projectGrantCreateToCommand(req.Msg) project, err := s.command.AddProjectGrant(ctx, add) if err != nil { return nil, err @@ -21,9 +22,9 @@ func (s *Server) CreateProjectGrant(ctx context.Context, req *project_pb.CreateP if !project.EventDate.IsZero() { creationDate = timestamppb.New(project.EventDate) } - return &project_pb.CreateProjectGrantResponse{ + return connect.NewResponse(&project_pb.CreateProjectGrantResponse{ CreationDate: creationDate, - }, nil + }), nil } func projectGrantCreateToCommand(req *project_pb.CreateProjectGrantRequest) *command.AddProjectGrant { @@ -37,8 +38,8 @@ func projectGrantCreateToCommand(req *project_pb.CreateProjectGrantRequest) *com } } -func (s *Server) UpdateProjectGrant(ctx context.Context, req *project_pb.UpdateProjectGrantRequest) (*project_pb.UpdateProjectGrantResponse, error) { - project, err := s.command.ChangeProjectGrant(ctx, projectGrantUpdateToCommand(req)) +func (s *Server) UpdateProjectGrant(ctx context.Context, req *connect.Request[project_pb.UpdateProjectGrantRequest]) (*connect.Response[project_pb.UpdateProjectGrantResponse], error) { + project, err := s.command.ChangeProjectGrant(ctx, projectGrantUpdateToCommand(req.Msg)) if err != nil { return nil, err } @@ -46,9 +47,9 @@ func (s *Server) UpdateProjectGrant(ctx context.Context, req *project_pb.UpdateP if !project.EventDate.IsZero() { changeDate = timestamppb.New(project.EventDate) } - return &project_pb.UpdateProjectGrantResponse{ + return connect.NewResponse(&project_pb.UpdateProjectGrantResponse{ ChangeDate: changeDate, - }, nil + }), nil } func projectGrantUpdateToCommand(req *project_pb.UpdateProjectGrantRequest) *command.ChangeProjectGrant { @@ -61,8 +62,8 @@ func projectGrantUpdateToCommand(req *project_pb.UpdateProjectGrantRequest) *com } } -func (s *Server) DeactivateProjectGrant(ctx context.Context, req *project_pb.DeactivateProjectGrantRequest) (*project_pb.DeactivateProjectGrantResponse, error) { - details, err := s.command.DeactivateProjectGrant(ctx, req.ProjectId, "", req.GrantedOrganizationId, "") +func (s *Server) DeactivateProjectGrant(ctx context.Context, req *connect.Request[project_pb.DeactivateProjectGrantRequest]) (*connect.Response[project_pb.DeactivateProjectGrantResponse], error) { + details, err := s.command.DeactivateProjectGrant(ctx, req.Msg.GetProjectId(), "", req.Msg.GetGrantedOrganizationId(), "") if err != nil { return nil, err } @@ -70,13 +71,13 @@ func (s *Server) DeactivateProjectGrant(ctx context.Context, req *project_pb.Dea if !details.EventDate.IsZero() { changeDate = timestamppb.New(details.EventDate) } - return &project_pb.DeactivateProjectGrantResponse{ + return connect.NewResponse(&project_pb.DeactivateProjectGrantResponse{ ChangeDate: changeDate, - }, nil + }), nil } -func (s *Server) ActivateProjectGrant(ctx context.Context, req *project_pb.ActivateProjectGrantRequest) (*project_pb.ActivateProjectGrantResponse, error) { - details, err := s.command.ReactivateProjectGrant(ctx, req.ProjectId, "", req.GrantedOrganizationId, "") +func (s *Server) ActivateProjectGrant(ctx context.Context, req *connect.Request[project_pb.ActivateProjectGrantRequest]) (*connect.Response[project_pb.ActivateProjectGrantResponse], error) { + details, err := s.command.ReactivateProjectGrant(ctx, req.Msg.GetProjectId(), "", req.Msg.GetGrantedOrganizationId(), "") if err != nil { return nil, err } @@ -84,17 +85,17 @@ func (s *Server) ActivateProjectGrant(ctx context.Context, req *project_pb.Activ if !details.EventDate.IsZero() { changeDate = timestamppb.New(details.EventDate) } - return &project_pb.ActivateProjectGrantResponse{ + return connect.NewResponse(&project_pb.ActivateProjectGrantResponse{ ChangeDate: changeDate, - }, nil + }), nil } -func (s *Server) DeleteProjectGrant(ctx context.Context, req *project_pb.DeleteProjectGrantRequest) (*project_pb.DeleteProjectGrantResponse, error) { - userGrantIDs, err := s.userGrantsFromProjectGrant(ctx, req.ProjectId, req.GrantedOrganizationId) +func (s *Server) DeleteProjectGrant(ctx context.Context, req *connect.Request[project_pb.DeleteProjectGrantRequest]) (*connect.Response[project_pb.DeleteProjectGrantResponse], error) { + userGrantIDs, err := s.userGrantsFromProjectGrant(ctx, req.Msg.GetProjectId(), req.Msg.GetGrantedOrganizationId()) if err != nil { return nil, err } - details, err := s.command.DeleteProjectGrant(ctx, req.ProjectId, "", req.GrantedOrganizationId, "", userGrantIDs...) + details, err := s.command.DeleteProjectGrant(ctx, req.Msg.GetProjectId(), "", req.Msg.GetGrantedOrganizationId(), "", userGrantIDs...) if err != nil { return nil, err } @@ -102,9 +103,9 @@ func (s *Server) DeleteProjectGrant(ctx context.Context, req *project_pb.DeleteP if !details.EventDate.IsZero() { deletionDate = timestamppb.New(details.EventDate) } - return &project_pb.DeleteProjectGrantResponse{ + return connect.NewResponse(&project_pb.DeleteProjectGrantResponse{ DeletionDate: deletionDate, - }, nil + }), nil } func (s *Server) userGrantsFromProjectGrant(ctx context.Context, projectID, grantedOrganizationID string) ([]string, error) { diff --git a/internal/api/grpc/project/v2beta/project_role.go b/internal/api/grpc/project/v2beta/project_role.go index 07fc4e9eac..2316ef4028 100644 --- a/internal/api/grpc/project/v2beta/project_role.go +++ b/internal/api/grpc/project/v2beta/project_role.go @@ -3,6 +3,7 @@ package project import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/command" @@ -11,8 +12,8 @@ import ( project_pb "github.com/zitadel/zitadel/pkg/grpc/project/v2beta" ) -func (s *Server) AddProjectRole(ctx context.Context, req *project_pb.AddProjectRoleRequest) (*project_pb.AddProjectRoleResponse, error) { - role, err := s.command.AddProjectRole(ctx, addProjectRoleRequestToCommand(req)) +func (s *Server) AddProjectRole(ctx context.Context, req *connect.Request[project_pb.AddProjectRoleRequest]) (*connect.Response[project_pb.AddProjectRoleResponse], error) { + role, err := s.command.AddProjectRole(ctx, addProjectRoleRequestToCommand(req.Msg)) if err != nil { return nil, err } @@ -20,9 +21,9 @@ func (s *Server) AddProjectRole(ctx context.Context, req *project_pb.AddProjectR if !role.EventDate.IsZero() { creationDate = timestamppb.New(role.EventDate) } - return &project_pb.AddProjectRoleResponse{ + return connect.NewResponse(&project_pb.AddProjectRoleResponse{ CreationDate: creationDate, - }, nil + }), nil } func addProjectRoleRequestToCommand(req *project_pb.AddProjectRoleRequest) *command.AddProjectRole { @@ -41,8 +42,8 @@ func addProjectRoleRequestToCommand(req *project_pb.AddProjectRoleRequest) *comm } } -func (s *Server) UpdateProjectRole(ctx context.Context, req *project_pb.UpdateProjectRoleRequest) (*project_pb.UpdateProjectRoleResponse, error) { - role, err := s.command.ChangeProjectRole(ctx, updateProjectRoleRequestToCommand(req)) +func (s *Server) UpdateProjectRole(ctx context.Context, req *connect.Request[project_pb.UpdateProjectRoleRequest]) (*connect.Response[project_pb.UpdateProjectRoleResponse], error) { + role, err := s.command.ChangeProjectRole(ctx, updateProjectRoleRequestToCommand(req.Msg)) if err != nil { return nil, err } @@ -50,9 +51,9 @@ func (s *Server) UpdateProjectRole(ctx context.Context, req *project_pb.UpdatePr if !role.EventDate.IsZero() { changeDate = timestamppb.New(role.EventDate) } - return &project_pb.UpdateProjectRoleResponse{ + return connect.NewResponse(&project_pb.UpdateProjectRoleResponse{ ChangeDate: changeDate, - }, nil + }), nil } func updateProjectRoleRequestToCommand(req *project_pb.UpdateProjectRoleRequest) *command.ChangeProjectRole { @@ -75,16 +76,16 @@ func updateProjectRoleRequestToCommand(req *project_pb.UpdateProjectRoleRequest) } } -func (s *Server) RemoveProjectRole(ctx context.Context, req *project_pb.RemoveProjectRoleRequest) (*project_pb.RemoveProjectRoleResponse, error) { - userGrantIDs, err := s.userGrantsFromProjectAndRole(ctx, req.ProjectId, req.RoleKey) +func (s *Server) RemoveProjectRole(ctx context.Context, req *connect.Request[project_pb.RemoveProjectRoleRequest]) (*connect.Response[project_pb.RemoveProjectRoleResponse], error) { + userGrantIDs, err := s.userGrantsFromProjectAndRole(ctx, req.Msg.GetProjectId(), req.Msg.GetRoleKey()) if err != nil { return nil, err } - projectGrantIDs, err := s.projectGrantsFromProjectAndRole(ctx, req.ProjectId, req.RoleKey) + projectGrantIDs, err := s.projectGrantsFromProjectAndRole(ctx, req.Msg.GetProjectId(), req.Msg.GetRoleKey()) if err != nil { return nil, err } - details, err := s.command.RemoveProjectRole(ctx, req.ProjectId, req.RoleKey, "", projectGrantIDs, userGrantIDs...) + details, err := s.command.RemoveProjectRole(ctx, req.Msg.GetProjectId(), req.Msg.GetRoleKey(), "", projectGrantIDs, userGrantIDs...) if err != nil { return nil, err } @@ -92,9 +93,9 @@ func (s *Server) RemoveProjectRole(ctx context.Context, req *project_pb.RemovePr if !details.EventDate.IsZero() { deletionDate = timestamppb.New(details.EventDate) } - return &project_pb.RemoveProjectRoleResponse{ + return connect.NewResponse(&project_pb.RemoveProjectRoleResponse{ RemovalDate: deletionDate, - }, nil + }), nil } func (s *Server) userGrantsFromProjectAndRole(ctx context.Context, projectID, roleKey string) ([]string, error) { diff --git a/internal/api/grpc/project/v2beta/query.go b/internal/api/grpc/project/v2beta/query.go index 42b69a480e..c736c5a086 100644 --- a/internal/api/grpc/project/v2beta/query.go +++ b/internal/api/grpc/project/v2beta/query.go @@ -3,6 +3,7 @@ package project import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" filter "github.com/zitadel/zitadel/internal/api/grpc/filter/v2beta" @@ -13,18 +14,18 @@ import ( project_pb "github.com/zitadel/zitadel/pkg/grpc/project/v2beta" ) -func (s *Server) GetProject(ctx context.Context, req *project_pb.GetProjectRequest) (*project_pb.GetProjectResponse, error) { - project, err := s.query.GetProjectByIDWithPermission(ctx, true, req.Id, s.checkPermission) +func (s *Server) GetProject(ctx context.Context, req *connect.Request[project_pb.GetProjectRequest]) (*connect.Response[project_pb.GetProjectResponse], error) { + project, err := s.query.GetProjectByIDWithPermission(ctx, true, req.Msg.GetId(), s.checkPermission) if err != nil { return nil, err } - return &project_pb.GetProjectResponse{ + return connect.NewResponse(&project_pb.GetProjectResponse{ Project: projectToPb(project), - }, nil + }), nil } -func (s *Server) ListProjects(ctx context.Context, req *project_pb.ListProjectsRequest) (*project_pb.ListProjectsResponse, error) { - queries, err := s.listProjectRequestToModel(req) +func (s *Server) ListProjects(ctx context.Context, req *connect.Request[project_pb.ListProjectsRequest]) (*connect.Response[project_pb.ListProjectsResponse], error) { + queries, err := s.listProjectRequestToModel(req.Msg) if err != nil { return nil, err } @@ -32,10 +33,10 @@ func (s *Server) ListProjects(ctx context.Context, req *project_pb.ListProjectsR if err != nil { return nil, err } - return &project_pb.ListProjectsResponse{ + return connect.NewResponse(&project_pb.ListProjectsResponse{ Projects: grantedProjectsToPb(resp.GrantedProjects), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, resp.SearchResponse), - }, nil + }), nil } func (s *Server) listProjectRequestToModel(req *project_pb.ListProjectsRequest) (*query.ProjectAndGrantedProjectSearchQueries, error) { @@ -213,8 +214,8 @@ func privateLabelingSettingToPb(setting domain.PrivateLabelingSetting) project_p } } -func (s *Server) ListProjectGrants(ctx context.Context, req *project_pb.ListProjectGrantsRequest) (*project_pb.ListProjectGrantsResponse, error) { - queries, err := s.listProjectGrantsRequestToModel(req) +func (s *Server) ListProjectGrants(ctx context.Context, req *connect.Request[project_pb.ListProjectGrantsRequest]) (*connect.Response[project_pb.ListProjectGrantsResponse], error) { + queries, err := s.listProjectGrantsRequestToModel(req.Msg) if err != nil { return nil, err } @@ -222,10 +223,10 @@ func (s *Server) ListProjectGrants(ctx context.Context, req *project_pb.ListProj if err != nil { return nil, err } - return &project_pb.ListProjectGrantsResponse{ + return connect.NewResponse(&project_pb.ListProjectGrantsResponse{ ProjectGrants: projectGrantsToPb(resp.ProjectGrants), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, resp.SearchResponse), - }, nil + }), nil } func (s *Server) listProjectGrantsRequestToModel(req *project_pb.ListProjectGrantsRequest) (*query.ProjectGrantSearchQueries, error) { @@ -329,12 +330,12 @@ func projectGrantStateToPb(state domain.ProjectGrantState) project_pb.ProjectGra } } -func (s *Server) ListProjectRoles(ctx context.Context, req *project_pb.ListProjectRolesRequest) (*project_pb.ListProjectRolesResponse, error) { - queries, err := s.listProjectRolesRequestToModel(req) +func (s *Server) ListProjectRoles(ctx context.Context, req *connect.Request[project_pb.ListProjectRolesRequest]) (*connect.Response[project_pb.ListProjectRolesResponse], error) { + queries, err := s.listProjectRolesRequestToModel(req.Msg) if err != nil { return nil, err } - err = queries.AppendProjectIDQuery(req.ProjectId) + err = queries.AppendProjectIDQuery(req.Msg.GetProjectId()) if err != nil { return nil, err } @@ -342,10 +343,10 @@ func (s *Server) ListProjectRoles(ctx context.Context, req *project_pb.ListProje if err != nil { return nil, err } - return &project_pb.ListProjectRolesResponse{ + return connect.NewResponse(&project_pb.ListProjectRolesResponse{ ProjectRoles: roleViewsToPb(roles.ProjectRoles), Pagination: filter.QueryToPaginationPb(queries.SearchRequest, roles.SearchResponse), - }, nil + }), nil } func (s *Server) listProjectRolesRequestToModel(req *project_pb.ListProjectRolesRequest) (*query.ProjectRoleSearchQueries, error) { diff --git a/internal/api/grpc/project/v2beta/server.go b/internal/api/grpc/project/v2beta/server.go index fe197f9688..12c18ae4c6 100644 --- a/internal/api/grpc/project/v2beta/server.go +++ b/internal/api/grpc/project/v2beta/server.go @@ -1,21 +1,23 @@ package project import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" project "github.com/zitadel/zitadel/pkg/grpc/project/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/project/v2beta/projectconnect" ) -var _ project.ProjectServiceServer = (*Server)(nil) +var _ projectconnect.ProjectServiceHandler = (*Server)(nil) type Server struct { - project.UnimplementedProjectServiceServer systemDefaults systemdefaults.SystemDefaults command *command.Commands query *query.Queries @@ -39,8 +41,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - project.RegisterProjectServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return projectconnect.NewProjectServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return project.File_zitadel_project_v2beta_project_service_proto } func (s *Server) AppName() string { @@ -54,7 +60,3 @@ func (s *Server) MethodPrefix() string { func (s *Server) AuthMethods() authz.MethodMapping { return project.ProjectService_AuthMethods } - -func (s *Server) RegisterGateway() server.RegisterGatewayFunc { - return project.RegisterProjectServiceHandler -} diff --git a/internal/api/grpc/saml/v2/saml.go b/internal/api/grpc/saml/v2/saml.go index 43eae5feb1..5491a5e04b 100644 --- a/internal/api/grpc/saml/v2/saml.go +++ b/internal/api/grpc/saml/v2/saml.go @@ -3,6 +3,7 @@ package saml import ( "context" + "connectrpc.com/connect" "github.com/zitadel/logging" "github.com/zitadel/saml/pkg/provider" "google.golang.org/protobuf/types/known/timestamppb" @@ -16,15 +17,15 @@ import ( saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" ) -func (s *Server) GetSAMLRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) { - authRequest, err := s.query.SamlRequestByID(ctx, true, req.GetSamlRequestId(), true) +func (s *Server) GetSAMLRequest(ctx context.Context, req *connect.Request[saml_pb.GetSAMLRequestRequest]) (*connect.Response[saml_pb.GetSAMLRequestResponse], error) { + authRequest, err := s.query.SamlRequestByID(ctx, true, req.Msg.GetSamlRequestId(), true) if err != nil { logging.WithError(err).Error("query samlRequest by ID") return nil, err } - return &saml_pb.GetSAMLRequestResponse{ + return connect.NewResponse(&saml_pb.GetSAMLRequestResponse{ SamlRequest: samlRequestToPb(authRequest), - }, nil + }), nil } func samlRequestToPb(a *query.SamlRequest) *saml_pb.SAMLRequest { @@ -34,18 +35,18 @@ func samlRequestToPb(a *query.SamlRequest) *saml_pb.SAMLRequest { } } -func (s *Server) CreateResponse(ctx context.Context, req *saml_pb.CreateResponseRequest) (*saml_pb.CreateResponseResponse, error) { - switch v := req.GetResponseKind().(type) { +func (s *Server) CreateResponse(ctx context.Context, req *connect.Request[saml_pb.CreateResponseRequest]) (*connect.Response[saml_pb.CreateResponseResponse], error) { + switch v := req.Msg.GetResponseKind().(type) { case *saml_pb.CreateResponseRequest_Error: - return s.failSAMLRequest(ctx, req.GetSamlRequestId(), v.Error) + return s.failSAMLRequest(ctx, req.Msg.GetSamlRequestId(), v.Error) case *saml_pb.CreateResponseRequest_Session: - return s.linkSessionToSAMLRequest(ctx, req.GetSamlRequestId(), v.Session) + return s.linkSessionToSAMLRequest(ctx, req.Msg.GetSamlRequestId(), v.Session) default: return nil, zerrors.ThrowUnimplementedf(nil, "SAMLv2-0Tfak3fBS0", "verification oneOf %T in method CreateResponse not implemented", v) } } -func (s *Server) failSAMLRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateResponseResponse, error) { +func (s *Server) failSAMLRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*connect.Response[saml_pb.CreateResponseResponse], error) { details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError())) if err != nil { return nil, err @@ -55,7 +56,7 @@ func (s *Server) failSAMLRequest(ctx context.Context, samlRequestID string, ae * if err != nil { return nil, err } - return createCallbackResponseFromBinding(details, url, body, authReq.RelayState), nil + return connect.NewResponse(createCallbackResponseFromBinding(details, url, body, authReq.RelayState)), nil } func (s *Server) checkPermission(ctx context.Context, issuer string, userID string) error { @@ -72,7 +73,7 @@ func (s *Server) checkPermission(ctx context.Context, issuer string, userID stri return nil } -func (s *Server) linkSessionToSAMLRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateResponseResponse, error) { +func (s *Server) linkSessionToSAMLRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*connect.Response[saml_pb.CreateResponseResponse], error) { details, aar, err := s.command.LinkSessionToSAMLRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true, s.checkPermission) if err != nil { return nil, err @@ -87,7 +88,7 @@ func (s *Server) linkSessionToSAMLRequest(ctx context.Context, samlRequestID str if err != nil { return nil, err } - return createCallbackResponseFromBinding(details, url, body, authReq.RelayState), nil + return connect.NewResponse(createCallbackResponseFromBinding(details, url, body, authReq.RelayState)), nil } func createCallbackResponseFromBinding(details *domain.ObjectDetails, url string, body string, relayState string) *saml_pb.CreateResponseResponse { diff --git a/internal/api/grpc/saml/v2/server.go b/internal/api/grpc/saml/v2/server.go index 62299d88c5..312a7c356a 100644 --- a/internal/api/grpc/saml/v2/server.go +++ b/internal/api/grpc/saml/v2/server.go @@ -1,7 +1,10 @@ package saml import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,9 +12,10 @@ import ( "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" + "github.com/zitadel/zitadel/pkg/grpc/saml/v2/samlconnect" ) -var _ saml_pb.SAMLServiceServer = (*Server)(nil) +var _ samlconnect.SAMLServiceHandler = (*Server)(nil) type Server struct { saml_pb.UnimplementedSAMLServiceServer @@ -38,8 +42,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - saml_pb.RegisterSAMLServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return samlconnect.NewSAMLServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return saml_pb.File_zitadel_saml_v2_saml_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/server/connect_middleware/access_interceptor.go b/internal/api/grpc/server/connect_middleware/access_interceptor.go new file mode 100644 index 0000000000..a08df59860 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/access_interceptor.go @@ -0,0 +1,57 @@ +package connect_middleware + +import ( + "context" + "net/http" + "time" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/authz" + http_util "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) { + if !svc.Enabled() { + return handler(ctx, req) + } + resp, handlerErr := handler(ctx, req) + + interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx) + defer func() { span.EndWithError(err) }() + + var respStatus uint32 + if code := connect.CodeOf(handlerErr); code != connect.CodeUnknown { + respStatus = uint32(code) + } + + respHeader := http.Header{} + if resp != nil { + respHeader = resp.Header() + } + instance := authz.GetInstance(ctx) + domainCtx := http_util.DomainContext(ctx) + + r := &record.AccessLog{ + LogDate: time.Now(), + Protocol: record.GRPC, + RequestURL: req.Spec().Procedure, + ResponseStatus: respStatus, + RequestHeaders: req.Header(), + ResponseHeaders: respHeader, + InstanceID: instance.InstanceID(), + ProjectID: instance.ProjectID(), + RequestedDomain: domainCtx.RequestedDomain(), + RequestedHost: domainCtx.RequestedHost(), + } + + svc.Handle(interceptorCtx, r) + return resp, handlerErr + } + } +} diff --git a/internal/api/grpc/server/connect_middleware/activity_interceptor.go b/internal/api/grpc/server/connect_middleware/activity_interceptor.go new file mode 100644 index 0000000000..4ba6044645 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/activity_interceptor.go @@ -0,0 +1,52 @@ +package connect_middleware + +import ( + "context" + "net/http" + "slices" + "strings" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/activity" + "github.com/zitadel/zitadel/internal/api/grpc/gerrors" + ainfo "github.com/zitadel/zitadel/internal/api/info" +) + +func ActivityInterceptor() connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + ctx = activityInfoFromGateway(ctx, req.Header()).SetMethod(req.Spec().Procedure).IntoContext(ctx) + resp, err := handler(ctx, req) + if isResourceAPI(req.Spec().Procedure) { + code, _, _, _ := gerrors.ExtractZITADELError(err) + ctx = ainfo.ActivityInfoFromContext(ctx).SetGRPCStatus(code).IntoContext(ctx) + activity.TriggerGRPCWithContext(ctx, activity.ResourceAPI) + } + return resp, err + } + } +} + +var resourcePrefixes = []string{ + "/zitadel.management.v1.ManagementService/", + "/zitadel.admin.v1.AdminService/", + "/zitadel.user.v2.UserService/", + "/zitadel.settings.v2.SettingsService/", + "/zitadel.user.v2beta.UserService/", + "/zitadel.settings.v2beta.SettingsService/", + "/zitadel.auth.v1.AuthService/", +} + +func isResourceAPI(method string) bool { + return slices.ContainsFunc(resourcePrefixes, func(prefix string) bool { + return strings.HasPrefix(method, prefix) + }) +} + +func activityInfoFromGateway(ctx context.Context, headers http.Header) *ainfo.ActivityInfo { + info := ainfo.ActivityInfoFromContext(ctx) + path := headers.Get(activity.PathKey) + requestMethod := headers.Get(activity.RequestMethodKey) + return info.SetPath(path).SetRequestMethod(requestMethod) +} diff --git a/internal/api/grpc/server/connect_middleware/auth_interceptor.go b/internal/api/grpc/server/connect_middleware/auth_interceptor.go new file mode 100644 index 0000000000..9e500601d0 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/auth_interceptor.go @@ -0,0 +1,65 @@ +package connect_middleware + +import ( + "context" + "errors" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +func AuthorizationInterceptor(verifier authz.APITokenVerifier, systemUserPermissions authz.Config, authConfig authz.Config) connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + return authorize(ctx, req, handler, verifier, systemUserPermissions, authConfig) + } + } +} + +func authorize(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.APITokenVerifier, systemUserPermissions authz.Config, authConfig authz.Config) (_ connect.AnyResponse, err error) { + authOpt, needsToken := verifier.CheckAuthMethod(req.Spec().Procedure) + if !needsToken { + return handler(ctx, req) + } + + authCtx, span := tracing.NewServerInterceptorSpan(ctx) + defer func() { span.EndWithError(err) }() + + authToken := req.Header().Get(http.Authorization) + if authToken == "" { + return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("auth header missing")) + } + + orgID, orgDomain := orgIDAndDomainFromRequest(req) + ctxSetter, err := authz.CheckUserAuthorization(authCtx, req, authToken, orgID, orgDomain, verifier, systemUserPermissions.RolePermissionMappings, authConfig.RolePermissionMappings, authOpt, req.Spec().Procedure) + if err != nil { + return nil, err + } + span.End() + return handler(ctxSetter(ctx), req) +} + +func orgIDAndDomainFromRequest(req connect.AnyRequest) (id, domain string) { + orgID := req.Header().Get(http.ZitadelOrgID) + oz, ok := req.Any().(OrganizationFromRequest) + if ok { + id = oz.OrganizationFromRequestConnect().ID + domain = oz.OrganizationFromRequestConnect().Domain + if id != "" || domain != "" { + return id, domain + } + } + return orgID, domain +} + +type Organization struct { + ID string + Domain string +} + +type OrganizationFromRequest interface { + OrganizationFromRequestConnect() *Organization +} diff --git a/internal/api/grpc/server/connect_middleware/auth_interceptor_test.go b/internal/api/grpc/server/connect_middleware/auth_interceptor_test.go new file mode 100644 index 0000000000..06e716c140 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/auth_interceptor_test.go @@ -0,0 +1,318 @@ +package connect_middleware + +import ( + "context" + "errors" + "net/http" + "reflect" + "testing" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const anAPIRole = "AN_API_ROLE" + +type authzRepoMock struct{} + +func (v *authzRepoMock) VerifyAccessToken(ctx context.Context, token, clientID, projectID string) (string, string, string, string, string, error) { + return "", "", "", "", "", nil +} + +func (v *authzRepoMock) SearchMyMemberships(ctx context.Context, orgID string, _ bool) ([]*authz.Membership, error) { + return authz.Memberships{{ + MemberType: authz.MemberTypeOrganization, + AggregateID: orgID, + Roles: []string{anAPIRole}, + }}, nil +} + +func (v *authzRepoMock) ProjectIDAndOriginsByClientID(ctx context.Context, clientID string) (string, []string, error) { + return "", nil, nil +} + +func (v *authzRepoMock) ExistsOrg(ctx context.Context, orgID, domain string) (string, error) { + return orgID, nil +} + +func (v *authzRepoMock) VerifierClientID(ctx context.Context, appName string) (string, string, error) { + return "", "", nil +} + +var ( + accessTokenOK = authz.AccessTokenVerifierFunc(func(ctx context.Context, token string) (userID string, clientID string, agentID string, prefLan string, resourceOwner string, err error) { + return "user1", "", "", "", "org1", nil + }) + accessTokenNOK = authz.AccessTokenVerifierFunc(func(ctx context.Context, token string) (userID string, clientID string, agentID string, prefLan string, resourceOwner string, err error) { + return "", "", "", "", "", zerrors.ThrowUnauthenticated(nil, "TEST-fQHDI", "unauthenticaded") + }) + systemTokenNOK = authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) { + return nil, "", errors.New("system token error") + }) +) + +type mockOrgFromRequest struct { + id string +} + +func (m *mockOrgFromRequest) OrganizationFromRequestConnect() *Organization { + return &Organization{ + ID: m.id, + Domain: "", + } +} + +func Test_authorize(t *testing.T) { + type args struct { + ctx context.Context + req connect.AnyRequest + handler func(t *testing.T) connect.UnaryFunc + verifier func() authz.APITokenVerifier + authConfig authz.Config + } + type res struct { + want interface{} + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + "no token needed ok", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/no/token/needed"}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{}) + return verifier + }, + }, + res{ + &connect.Response[struct{}]{}, + false, + }, + }, + { + "auth header missing error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication"}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}}) + return verifier + }, + authConfig: authz.Config{}, + }, + res{ + nil, + true, + }, + }, + { + "unauthorized error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"wrong"}}}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}}) + return verifier + }, + authConfig: authz.Config{}, + }, + res{ + nil, + true, + }, + }, + { + "authorized ok", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{ + UserID: "user1", + OrgID: "org1", + ResourceOwner: "org1", + }), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}}) + return verifier + }, + authConfig: authz.Config{}, + }, + res{ + &connect.Response[struct{}]{}, + false, + }, + }, + { + "authorized ok, org by request", + args{ + ctx: context.Background(), + req: &mockReq[mockOrgFromRequest]{ + Request: connect.Request[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}}, + procedure: "/need/authentication", + header: http.Header{"Authorization": []string{"Bearer token"}}, + }, + handler: emptyMockHandler(&connect.Response[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}}, authz.CtxData{ + UserID: "user1", + OrgID: "id", + ResourceOwner: "org1", + }), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}}) + return verifier + }, + authConfig: authz.Config{}, + }, + res{ + &connect.Response[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}}, + false, + }, + }, + { + "permission denied error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{ + UserID: "user1", + OrgID: "org1", + ResourceOwner: "org1", + }), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}}) + return verifier + }, + authConfig: authz.Config{ + RolePermissionMappings: []authz.RoleMapping{{ + Role: anAPIRole, + Permissions: []string{"to.do.something.else"}, + }}, + }, + }, + res{ + nil, + true, + }, + }, + { + "permission ok", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{ + UserID: "user1", + OrgID: "org1", + ResourceOwner: "org1", + }), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}}) + return verifier + }, + authConfig: authz.Config{ + RolePermissionMappings: []authz.RoleMapping{{ + Role: anAPIRole, + Permissions: []string{"to.do.something"}, + }}, + }, + }, + res{ + &connect.Response[struct{}]{}, + false, + }, + }, + { + "system token permission denied error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenNOK, authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) { + return authz.Memberships{{ + MemberType: authz.MemberTypeSystem, + Roles: []string{"A_SYSTEM_ROLE"}, + }}, "systemuser", nil + })) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}}) + return verifier + }, + authConfig: authz.Config{ + RolePermissionMappings: []authz.RoleMapping{{ + Role: "A_SYSTEM_ROLE", + Permissions: []string{"to.do.something.else"}, + }}, + }, + }, + res{ + nil, + true, + }, + }, + { + "system token permission denied error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{ + UserID: "systemuser", + SystemMemberships: authz.Memberships{{ + MemberType: authz.MemberTypeSystem, + Roles: []string{"A_SYSTEM_ROLE"}, + }}, + SystemUserPermissions: []authz.SystemUserPermissions{{ + MemberType: authz.MemberTypeSystem, + Permissions: []string{"to.do.something"}, + }}, + }), + verifier: func() authz.APITokenVerifier { + verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenNOK, authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) { + return authz.Memberships{{ + MemberType: authz.MemberTypeSystem, + Roles: []string{"A_SYSTEM_ROLE"}, + }}, "systemuser", nil + })) + verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}}) + return verifier + }, + authConfig: authz.Config{ + RolePermissionMappings: []authz.RoleMapping{{ + Role: "A_SYSTEM_ROLE", + Permissions: []string{"to.do.something"}, + }}, + }, + }, + res{ + &connect.Response[struct{}]{}, + false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := authorize(tt.args.ctx, tt.args.req, tt.args.handler(t), tt.args.verifier(), tt.args.authConfig, tt.args.authConfig) + if (err != nil) != tt.res.wantErr { + t.Errorf("authorize() error = %v, wantErr %v", err, tt.res.wantErr) + return + } + if !reflect.DeepEqual(got, tt.res.want) { + t.Errorf("authorize() got = %v, want %v", got, tt.res.want) + } + }) + } +} diff --git a/internal/api/grpc/server/connect_middleware/cache_interceptor.go b/internal/api/grpc/server/connect_middleware/cache_interceptor.go new file mode 100644 index 0000000000..60ba0032f1 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/cache_interceptor.go @@ -0,0 +1,31 @@ +package connect_middleware + +import ( + "context" + "net/http" + "time" + + "connectrpc.com/connect" + + _ "github.com/zitadel/zitadel/internal/statik" +) + +func NoCacheInterceptor() connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + headers := map[string]string{ + "cache-control": "no-store", + "expires": time.Now().UTC().Format(http.TimeFormat), + "pragma": "no-cache", + } + resp, err := handler(ctx, req) + if err != nil { + return nil, err + } + for key, value := range headers { + resp.Header().Set(key, value) + } + return resp, err + } + } +} diff --git a/internal/api/grpc/server/connect_middleware/call_interceptor.go b/internal/api/grpc/server/connect_middleware/call_interceptor.go new file mode 100644 index 0000000000..cc74e10f85 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/call_interceptor.go @@ -0,0 +1,18 @@ +package connect_middleware + +import ( + "context" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/call" +) + +func CallDurationHandler() connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + ctx = call.WithTimestamp(ctx) + return handler(ctx, req) + } + } +} diff --git a/internal/api/grpc/server/connect_middleware/error_interceptor.go b/internal/api/grpc/server/connect_middleware/error_interceptor.go new file mode 100644 index 0000000000..9aef95bc6d --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/error_interceptor.go @@ -0,0 +1,23 @@ +package connect_middleware + +import ( + "context" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/grpc/gerrors" + _ "github.com/zitadel/zitadel/internal/statik" +) + +func ErrorHandler() connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + return toConnectError(ctx, req, handler) + } + } +} + +func toConnectError(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc) (connect.AnyResponse, error) { + resp, err := handler(ctx, req) + return resp, gerrors.ZITADELToConnectError(err) // TODO ! +} diff --git a/internal/api/grpc/server/connect_middleware/error_interceptor_test.go b/internal/api/grpc/server/connect_middleware/error_interceptor_test.go new file mode 100644 index 0000000000..954f2fd58f --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/error_interceptor_test.go @@ -0,0 +1,65 @@ +package connect_middleware + +import ( + "context" + "reflect" + "testing" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/authz" +) + +func Test_toGRPCError(t *testing.T) { + type args struct { + ctx context.Context + req connect.AnyRequest + handler func(t *testing.T) connect.UnaryFunc + } + type res struct { + want interface{} + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + "no error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{}, + handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}), + }, + res{ + &connect.Response[struct{}]{}, + false, + }, + }, + { + "error", + args{ + ctx: context.Background(), + req: &mockReq[struct{}]{}, + handler: errorMockHandler(), + }, + res{ + nil, + true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := toConnectError(tt.args.ctx, tt.args.req, tt.args.handler(t)) + if (err != nil) != tt.res.wantErr { + t.Errorf("toGRPCError() error = %v, wantErr %v", err, tt.res.wantErr) + return + } + if !reflect.DeepEqual(got, tt.res.want) { + t.Errorf("toGRPCError() got = %v, want %v", got, tt.res.want) + } + }) + } +} diff --git a/internal/api/grpc/server/connect_middleware/execution_interceptor.go b/internal/api/grpc/server/connect_middleware/execution_interceptor.go new file mode 100644 index 0000000000..879496a33f --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/execution_interceptor.go @@ -0,0 +1,160 @@ +package connect_middleware + +import ( + "context" + "encoding/json" + + "connectrpc.com/connect" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/execution" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +func ExecutionHandler(queries *query.Queries) connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) { + requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, req.Spec().Procedure) + + // call targets otherwise return req + handledReq, err := executeTargetsForRequest(ctx, requestTargets, req.Spec().Procedure, req) + if err != nil { + return nil, err + } + + response, err := handler(ctx, handledReq) + if err != nil { + return nil, err + } + + return executeTargetsForResponse(ctx, responseTargets, req.Spec().Procedure, handledReq, response) + } + } +} + +func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req connect.AnyRequest) (_ connect.AnyRequest, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + // if no targets are found, return without any calls + if len(targets) == 0 { + return req, nil + } + + ctxData := authz.GetCtxData(ctx) + info := &ContextInfoRequest{ + FullMethod: fullMethod, + InstanceID: authz.GetInstance(ctx).InstanceID(), + ProjectID: ctxData.ProjectID, + OrgID: ctxData.OrgID, + UserID: ctxData.UserID, + Request: Message{req.Any().(proto.Message)}, + } + + _, err = execution.CallTargets(ctx, targets, info) + if err != nil { + return nil, err + } + return req, nil +} + +func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req connect.AnyRequest, resp connect.AnyResponse) (_ connect.AnyResponse, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + // if no targets are found, return without any calls + if len(targets) == 0 { + return resp, nil + } + + ctxData := authz.GetCtxData(ctx) + info := &ContextInfoResponse{ + FullMethod: fullMethod, + InstanceID: authz.GetInstance(ctx).InstanceID(), + ProjectID: ctxData.ProjectID, + OrgID: ctxData.OrgID, + UserID: ctxData.UserID, + Request: Message{req.Any().(proto.Message)}, + Response: Message{resp.Any().(proto.Message)}, + } + + _, err = execution.CallTargets(ctx, targets, info) + if err != nil { + return nil, err + } + return resp, nil +} + +var _ execution.ContextInfo = &ContextInfoRequest{} + +type ContextInfoRequest struct { + FullMethod string `json:"fullMethod,omitempty"` + InstanceID string `json:"instanceID,omitempty"` + OrgID string `json:"orgID,omitempty"` + ProjectID string `json:"projectID,omitempty"` + UserID string `json:"userID,omitempty"` + Request Message `json:"request,omitempty"` +} + +type Message struct { + proto.Message +} + +func (r *Message) MarshalJSON() ([]byte, error) { + data, err := protojson.Marshal(r.Message) + if err != nil { + return nil, err + } + return data, nil +} + +func (r *Message) UnmarshalJSON(data []byte) error { + return protojson.Unmarshal(data, r.Message) +} + +func (c *ContextInfoRequest) GetHTTPRequestBody() []byte { + data, err := json.Marshal(c) + if err != nil { + return nil + } + return data +} + +func (c *ContextInfoRequest) SetHTTPResponseBody(resp []byte) error { + return json.Unmarshal(resp, &c.Request) +} + +func (c *ContextInfoRequest) GetContent() interface{} { + return c.Request.Message +} + +var _ execution.ContextInfo = &ContextInfoResponse{} + +type ContextInfoResponse struct { + FullMethod string `json:"fullMethod,omitempty"` + InstanceID string `json:"instanceID,omitempty"` + OrgID string `json:"orgID,omitempty"` + ProjectID string `json:"projectID,omitempty"` + UserID string `json:"userID,omitempty"` + Request Message `json:"request,omitempty"` + Response Message `json:"response,omitempty"` +} + +func (c *ContextInfoResponse) GetHTTPRequestBody() []byte { + data, err := json.Marshal(c) + if err != nil { + return nil + } + return data +} + +func (c *ContextInfoResponse) SetHTTPResponseBody(resp []byte) error { + return json.Unmarshal(resp, &c.Response) +} + +func (c *ContextInfoResponse) GetContent() interface{} { + return c.Response.Message +} diff --git a/internal/api/grpc/server/connect_middleware/execution_interceptor_test.go b/internal/api/grpc/server/connect_middleware/execution_interceptor_test.go new file mode 100644 index 0000000000..d910824f21 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/execution_interceptor_test.go @@ -0,0 +1,815 @@ +package connect_middleware + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/execution" +) + +var _ execution.Target = &mockExecutionTarget{} + +type mockExecutionTarget struct { + InstanceID string + ExecutionID string + TargetID string + TargetType domain.TargetType + Endpoint string + Timeout time.Duration + InterruptOnError bool + SigningKey string +} + +func (e *mockExecutionTarget) SetEndpoint(endpoint string) { + e.Endpoint = endpoint +} +func (e *mockExecutionTarget) IsInterruptOnError() bool { + return e.InterruptOnError +} +func (e *mockExecutionTarget) GetEndpoint() string { + return e.Endpoint +} +func (e *mockExecutionTarget) GetTargetType() domain.TargetType { + return e.TargetType +} +func (e *mockExecutionTarget) GetTimeout() time.Duration { + return e.Timeout +} +func (e *mockExecutionTarget) GetTargetID() string { + return e.TargetID +} +func (e *mockExecutionTarget) GetExecutionID() string { + return e.ExecutionID +} +func (e *mockExecutionTarget) GetSigningKey() string { + return e.SigningKey +} + +func newMockContentRequest(content string) *connect.Request[structpb.Struct] { + return connect.NewRequest(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "content": { + Kind: &structpb.Value_StringValue{StringValue: content}, + }, + }, + }) +} + +func newMockContentResponse(content string) *connect.Response[structpb.Struct] { + return connect.NewResponse(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "content": { + Kind: &structpb.Value_StringValue{StringValue: content}, + }, + }, + }) +} + +func newMockContextInfoRequest(fullMethod, request string) *ContextInfoRequest { + return &ContextInfoRequest{ + FullMethod: fullMethod, + Request: Message{Message: newMockContentRequest(request).Msg}, + } +} + +func newMockContextInfoResponse(fullMethod, request, response string) *ContextInfoResponse { + return &ContextInfoResponse{ + FullMethod: fullMethod, + Request: Message{Message: newMockContentRequest(request).Msg}, + Response: Message{Message: newMockContentResponse(response).Msg}, + } +} + +func Test_executeTargetsForGRPCFullMethod_request(t *testing.T) { + type target struct { + reqBody execution.ContextInfo + sleep time.Duration + statusCode int + respBody connect.AnyResponse + } + type args struct { + ctx context.Context + + executionTargets []execution.Target + targets []target + fullMethod string + req connect.AnyRequest + } + type res struct { + want interface{} + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + "target, executionTargets nil", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: nil, + req: newMockContentRequest("request"), + }, + res{ + want: newMockContentRequest("request"), + }, + }, + { + "target, executionTargets empty", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{}, + req: newMockContentRequest("request"), + }, + res{ + want: newMockContentRequest("request"), + }, + }, + { + "target, not reachable", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + }, + }, + targets: []target{}, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "target, error without interrupt", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusBadRequest, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + want: newMockContentRequest("content"), + }, + }, + { + "target, interruptOnError", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusBadRequest, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "target, timeout", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Second, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 5 * time.Second, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "target, wrong request", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Second, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + {reqBody: newMockContextInfoRequest("/service/method", "wrong")}, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "target, ok", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + want: newMockContentRequest("content1"), + }, + }, + { + "target async, timeout", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeAsync, + Timeout: time.Second, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 5 * time.Second, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + want: newMockContentRequest("content"), + }, + }, + { + "target async, ok", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeAsync, + Timeout: time.Minute, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + want: newMockContentRequest("content"), + }, + }, + { + "webhook, error", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeWebhook, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + sleep: 0, + statusCode: http.StatusInternalServerError, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "webhook, timeout", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeWebhook, + Timeout: time.Second, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 5 * time.Second, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "webhook, ok", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeWebhook, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + want: newMockContentRequest("content"), + }, + }, + { + "with includes, interruptOnError", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target1", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target2", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target3", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusOK, + }, + { + reqBody: newMockContextInfoRequest("/service/method", "content1"), + respBody: newMockContentResponse("content2"), + sleep: 0, + statusCode: http.StatusBadRequest, + }, + { + reqBody: newMockContextInfoRequest("/service/method", "content2"), + respBody: newMockContentResponse("content3"), + sleep: 0, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + { + "with includes, timeout", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target1", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target2", + TargetType: domain.TargetTypeCall, + Timeout: time.Second, + InterruptOnError: true, + SigningKey: "signingkey", + }, + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target3", + TargetType: domain.TargetTypeCall, + Timeout: time.Second, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse("content1"), + sleep: 0, + statusCode: http.StatusOK, + }, + { + reqBody: newMockContextInfoRequest("/service/method", "content1"), + respBody: newMockContentResponse("content2"), + sleep: 5 * time.Second, + statusCode: http.StatusBadRequest, + }, + { + reqBody: newMockContextInfoRequest("/service/method", "content2"), + respBody: newMockContentResponse("content3"), + sleep: 5 * time.Second, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("content"), + }, + res{ + wantErr: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + closeFuncs := make([]func(), len(tt.args.targets)) + for i, target := range tt.args.targets { + url, closeF := testServerCall( + target.reqBody, + target.sleep, + target.statusCode, + target.respBody, + ) + + et := tt.args.executionTargets[i].(*mockExecutionTarget) + et.SetEndpoint(url) + closeFuncs[i] = closeF + } + + resp, err := executeTargetsForRequest( + tt.args.ctx, + tt.args.executionTargets, + tt.args.fullMethod, + tt.args.req, + ) + + if tt.res.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.EqualExportedValues(t, tt.res.want, resp) + + for _, closeF := range closeFuncs { + closeF() + } + }) + } +} + +func testServerCall( + reqBody interface{}, + sleep time.Duration, + statusCode int, + respBody connect.AnyResponse, +) (string, func()) { + handler := func(w http.ResponseWriter, r *http.Request) { + data, err := json.Marshal(reqBody) + if err != nil { + http.Error(w, "error", http.StatusInternalServerError) + return + } + + sentBody, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "error", http.StatusInternalServerError) + return + } + + if !reflect.DeepEqual(data, sentBody) { + http.Error(w, "error", http.StatusInternalServerError) + return + } + + if statusCode != http.StatusOK { + http.Error(w, "error", statusCode) + return + } + + time.Sleep(sleep) + + w.Header().Set("Content-Type", "application/json") + resp, err := protojson.Marshal(respBody.Any().(proto.Message)) + if err != nil { + http.Error(w, "error", http.StatusInternalServerError) + return + } + if _, err := w.Write(resp); err != nil { + http.Error(w, "error", http.StatusInternalServerError) + return + } + } + + server := httptest.NewServer(http.HandlerFunc(handler)) + + return server.URL, server.Close +} + +func Test_executeTargetsForGRPCFullMethod_response(t *testing.T) { + type target struct { + reqBody execution.ContextInfo + sleep time.Duration + statusCode int + respBody connect.AnyResponse + } + type args struct { + ctx context.Context + + executionTargets []execution.Target + targets []target + fullMethod string + req connect.AnyRequest + resp connect.AnyResponse + } + type res struct { + want interface{} + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + "target, executionTargets nil", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: nil, + req: newMockContentRequest("request"), + resp: newMockContentResponse("response"), + }, + res{ + want: newMockContentResponse("response"), + }, + }, + { + "target, executionTargets empty", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{}, + req: newMockContentRequest("request"), + resp: newMockContentResponse("response"), + }, + res{ + want: newMockContentResponse("response"), + }, + }, + { + "target, empty response", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "request./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoRequest("/service/method", "content"), + respBody: newMockContentResponse(""), + sleep: 0, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest(""), + resp: newMockContentResponse(""), + }, + res{ + wantErr: true, + }, + }, + { + "target, ok", + args{ + ctx: context.Background(), + fullMethod: "/service/method", + executionTargets: []execution.Target{ + &mockExecutionTarget{ + InstanceID: "instance", + ExecutionID: "response./zitadel.session.v2.SessionService/SetSession", + TargetID: "target", + TargetType: domain.TargetTypeCall, + Timeout: time.Minute, + InterruptOnError: true, + SigningKey: "signingkey", + }, + }, + targets: []target{ + { + reqBody: newMockContextInfoResponse("/service/method", "request", "response"), + respBody: newMockContentResponse("response1"), + sleep: 0, + statusCode: http.StatusOK, + }, + }, + req: newMockContentRequest("request"), + resp: newMockContentResponse("response"), + }, + res{ + want: newMockContentResponse("response1"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + closeFuncs := make([]func(), len(tt.args.targets)) + for i, target := range tt.args.targets { + url, closeF := testServerCall( + target.reqBody, + target.sleep, + target.statusCode, + target.respBody, + ) + + et := tt.args.executionTargets[i].(*mockExecutionTarget) + et.SetEndpoint(url) + closeFuncs[i] = closeF + } + + resp, err := executeTargetsForResponse( + tt.args.ctx, + tt.args.executionTargets, + tt.args.fullMethod, + tt.args.req, + tt.args.resp, + ) + + if tt.res.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.EqualExportedValues(t, tt.res.want, resp) + + for _, closeF := range closeFuncs { + closeF() + } + }) + } +} diff --git a/internal/api/grpc/server/connect_middleware/instance_interceptor.go b/internal/api/grpc/server/connect_middleware/instance_interceptor.go new file mode 100644 index 0000000000..27f59313f8 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/instance_interceptor.go @@ -0,0 +1,107 @@ +package connect_middleware + +import ( + "context" + "errors" + "fmt" + "strings" + + "connectrpc.com/connect" + "github.com/zitadel/logging" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/authz" + zitadel_http "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/i18n" + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" + object_v3 "github.com/zitadel/zitadel/pkg/grpc/object/v3alpha" +) + +func InstanceInterceptor(verifier authz.InstanceVerifier, externalDomain string, explicitInstanceIdServices ...string) connect.UnaryInterceptorFunc { + translator, err := i18n.NewZitadelTranslator(language.English) + logging.OnError(err).Panic("unable to get translator") + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + return setInstance(ctx, req, handler, verifier, externalDomain, translator, explicitInstanceIdServices...) + } + } +} + +func setInstance(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, externalDomain string, translator *i18n.Translator, idFromRequestsServices ...string) (_ connect.AnyResponse, err error) { + interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx) + defer func() { span.EndWithError(err) }() + + for _, service := range idFromRequestsServices { + if !strings.HasPrefix(service, "/") { + service = "/" + service + } + if strings.HasPrefix(req.Spec().Procedure, service) { + withInstanceIDProperty, ok := req.Any().(interface { + GetInstanceId() string + }) + if !ok { + return handler(ctx, req) + } + return addInstanceByID(interceptorCtx, req, handler, verifier, translator, withInstanceIDProperty.GetInstanceId()) + } + } + explicitInstanceRequest, ok := req.Any().(interface { + GetInstance() *object_v3.Instance + }) + if ok { + instance := explicitInstanceRequest.GetInstance() + if id := instance.GetId(); id != "" { + return addInstanceByID(interceptorCtx, req, handler, verifier, translator, id) + } + if domain := instance.GetDomain(); domain != "" { + return addInstanceByDomain(interceptorCtx, req, handler, verifier, translator, domain) + } + } + return addInstanceByRequestedHost(interceptorCtx, req, handler, verifier, translator, externalDomain) +} + +func addInstanceByID(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, id string) (connect.AnyResponse, error) { + instance, err := verifier.InstanceByID(ctx, id) + if err != nil { + notFoundErr := new(zerrors.ZitadelError) + if errors.As(err, ¬FoundErr) { + notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil) + } + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using id %s: %w", id, notFoundErr)) + } + return handler(authz.WithInstance(ctx, instance), req) +} + +func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, domain string) (connect.AnyResponse, error) { + instance, err := verifier.InstanceByHost(ctx, domain, "") + if err != nil { + notFoundErr := new(zerrors.NotFoundError) + if errors.As(err, ¬FoundErr) { + notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil) + } + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using domain %s: %w", domain, notFoundErr)) + } + return handler(authz.WithInstance(ctx, instance), req) +} + +func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) { + requestContext := zitadel_http.DomainContext(ctx) + if requestContext.InstanceHost == "" { + logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance") + return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified")) + } + instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost) + if err != nil { + origin := zitadel_http.DomainContext(ctx) + logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance") + zErr := new(zerrors.ZitadelError) + if errors.As(err, &zErr) { + zErr.SetMessage(translator.LocalizeFromCtx(ctx, zErr.GetMessage(), nil)) + zErr.Parent = err + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using origin %s (ExternalDomain is %s): %s", origin, externalDomain, zErr.Error())) + } + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using origin %s (ExternalDomain is %s)", origin, externalDomain)) + } + return handler(authz.WithInstance(ctx, instance), req) +} diff --git a/internal/api/grpc/server/connect_middleware/limits_interceptor.go b/internal/api/grpc/server/connect_middleware/limits_interceptor.go new file mode 100644 index 0000000000..abf7e5f0aa --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/limits_interceptor.go @@ -0,0 +1,34 @@ +package connect_middleware + +import ( + "context" + "strings" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func LimitsInterceptor(ignoreService ...string) connect.UnaryInterceptorFunc { + for idx, service := range ignoreService { + if !strings.HasPrefix(service, "/") { + ignoreService[idx] = "/" + service + } + } + + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) { + for _, service := range ignoreService { + if strings.HasPrefix(req.Spec().Procedure, service) { + return handler(ctx, req) + } + } + instance := authz.GetInstance(ctx) + if block := instance.Block(); block != nil && *block { + return nil, zerrors.ThrowResourceExhausted(nil, "LIMITS-molsj", "Errors.Limits.Instance.Blocked") + } + return handler(ctx, req) + } + } +} diff --git a/internal/api/grpc/server/connect_middleware/metrics_interceptor.go b/internal/api/grpc/server/connect_middleware/metrics_interceptor.go new file mode 100644 index 0000000000..552fa5658d --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/metrics_interceptor.go @@ -0,0 +1,96 @@ +package connect_middleware + +import ( + "context" + "strings" + + "connectrpc.com/connect" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/zitadel/logging" + "go.opentelemetry.io/otel/attribute" + "google.golang.org/grpc/codes" + + _ "github.com/zitadel/zitadel/internal/statik" + "github.com/zitadel/zitadel/internal/telemetry/metrics" +) + +const ( + GrpcMethod = "grpc_method" + ReturnCode = "return_code" + GrpcRequestCounter = "grpc.server.request_counter" + GrpcRequestCounterDescription = "Grpc request counter" + TotalGrpcRequestCounter = "grpc.server.total_request_counter" + TotalGrpcRequestCounterDescription = "Total grpc request counter" + GrpcStatusCodeCounter = "grpc.server.grpc_status_code" + GrpcStatusCodeCounterDescription = "Grpc status code counter" +) + +func MetricsHandler(metricTypes []metrics.MetricType, ignoredMethodSuffixes ...string) connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + return RegisterMetrics(ctx, req, handler, metricTypes, ignoredMethodSuffixes...) + } + } +} + +func RegisterMetrics(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, metricTypes []metrics.MetricType, ignoredMethodSuffixes ...string) (_ connect.AnyResponse, err error) { + if len(metricTypes) == 0 { + return handler(ctx, req) + } + + for _, ignore := range ignoredMethodSuffixes { + if strings.HasSuffix(req.Spec().Procedure, ignore) { + return handler(ctx, req) + } + } + + resp, err := handler(ctx, req) + if containsMetricsMethod(metrics.MetricTypeRequestCount, metricTypes) { + RegisterGrpcRequestCounter(ctx, req.Spec().Procedure) + } + if containsMetricsMethod(metrics.MetricTypeTotalCount, metricTypes) { + RegisterGrpcTotalRequestCounter(ctx) + } + if containsMetricsMethod(metrics.MetricTypeStatusCode, metricTypes) { + RegisterGrpcRequestCodeCounter(ctx, req.Spec().Procedure, err) + } + return resp, err +} + +func RegisterGrpcRequestCounter(ctx context.Context, path string) { + var labels = map[string]attribute.Value{ + GrpcMethod: attribute.StringValue(path), + } + err := metrics.RegisterCounter(GrpcRequestCounter, GrpcRequestCounterDescription) + logging.OnError(err).Warn("failed to register grpc request counter") + err = metrics.AddCount(ctx, GrpcRequestCounter, 1, labels) + logging.OnError(err).Warn("failed to add grpc request count") +} + +func RegisterGrpcTotalRequestCounter(ctx context.Context) { + err := metrics.RegisterCounter(TotalGrpcRequestCounter, TotalGrpcRequestCounterDescription) + logging.OnError(err).Warn("failed to register total grpc request counter") + err = metrics.AddCount(ctx, TotalGrpcRequestCounter, 1, nil) + logging.OnError(err).Warn("failed to add total grpc request count") +} + +func RegisterGrpcRequestCodeCounter(ctx context.Context, path string, err error) { + statusCode := connect.CodeOf(err) + var labels = map[string]attribute.Value{ + GrpcMethod: attribute.StringValue(path), + ReturnCode: attribute.IntValue(runtime.HTTPStatusFromCode(codes.Code(statusCode))), + } + err = metrics.RegisterCounter(GrpcStatusCodeCounter, GrpcStatusCodeCounterDescription) + logging.OnError(err).Warn("failed to register grpc status code counter") + err = metrics.AddCount(ctx, GrpcStatusCodeCounter, 1, labels) + logging.OnError(err).Warn("failed to add grpc status code count") +} + +func containsMetricsMethod(metricType metrics.MetricType, metricTypes []metrics.MetricType) bool { + for _, m := range metricTypes { + if m == metricType { + return true + } + } + return false +} diff --git a/internal/api/grpc/server/connect_middleware/mock_test.go b/internal/api/grpc/server/connect_middleware/mock_test.go new file mode 100644 index 0000000000..abd996b01f --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/mock_test.go @@ -0,0 +1,50 @@ +package connect_middleware + +import ( + "context" + "net/http" + "testing" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func emptyMockHandler(resp connect.AnyResponse, expectedCtxData authz.CtxData) func(*testing.T) connect.UnaryFunc { + return func(t *testing.T) connect.UnaryFunc { + return func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) { + assert.Equal(t, expectedCtxData, authz.GetCtxData(ctx)) + return resp, nil + } + } +} + +func errorMockHandler() func(*testing.T) connect.UnaryFunc { + return func(t *testing.T) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + return nil, zerrors.ThrowInternal(nil, "test", "error") + } + } +} + +type mockReq[t any] struct { + connect.Request[t] + + procedure string + header http.Header +} + +func (m *mockReq[T]) Spec() connect.Spec { + return connect.Spec{ + Procedure: m.procedure, + } +} + +func (m *mockReq[T]) Header() http.Header { + if m.header == nil { + m.header = make(http.Header) + } + return m.header +} diff --git a/internal/api/grpc/server/connect_middleware/quota_interceptor.go b/internal/api/grpc/server/connect_middleware/quota_interceptor.go new file mode 100644 index 0000000000..caa32511e4 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/quota_interceptor.go @@ -0,0 +1,53 @@ +package connect_middleware + +import ( + "context" + "strings" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func QuotaExhaustedInterceptor(svc *logstore.Service[*record.AccessLog], ignoreService ...string) connect.UnaryInterceptorFunc { + for idx, service := range ignoreService { + if !strings.HasPrefix(service, "/") { + ignoreService[idx] = "/" + service + } + } + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) { + if !svc.Enabled() { + return handler(ctx, req) + } + interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx) + defer func() { span.EndWithError(err) }() + + // The auth interceptor will ensure that only authorized or public requests are allowed. + // So if there's no authorization context, we don't need to check for limitation + // Also, we don't limit calls with system user tokens + ctxData := authz.GetCtxData(ctx) + if ctxData.IsZero() || ctxData.SystemMemberships != nil { + return handler(ctx, req) + } + + for _, service := range ignoreService { + if strings.HasPrefix(req.Spec().Procedure, service) { + return handler(ctx, req) + } + } + + instance := authz.GetInstance(ctx) + remaining := svc.Limit(interceptorCtx, instance.InstanceID()) + if remaining != nil && *remaining == 0 { + return nil, zerrors.ThrowResourceExhausted(nil, "QUOTA-vjAy8", "Quota.Access.Exhausted") + } + span.End() + return handler(ctx, req) + } + } +} diff --git a/internal/api/grpc/server/connect_middleware/service_interceptor.go b/internal/api/grpc/server/connect_middleware/service_interceptor.go new file mode 100644 index 0000000000..c5cf798ce5 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/service_interceptor.go @@ -0,0 +1,45 @@ +package connect_middleware + +import ( + "context" + "strings" + + "connectrpc.com/connect" + + "github.com/zitadel/zitadel/internal/api/service" + _ "github.com/zitadel/zitadel/internal/statik" +) + +const ( + unknown = "UNKNOWN" +) + +func ServiceHandler() connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + serviceName, _ := serviceAndMethod(req.Spec().Procedure) + if serviceName != unknown { + return handler(ctx, req) + } + ctx = service.WithService(ctx, serviceName) + return handler(ctx, req) + } + } +} + +// serviceAndMethod returns the service and method from a procedure. +func serviceAndMethod(procedure string) (string, string) { + procedure = strings.TrimPrefix(procedure, "/") + serviceName, method := unknown, unknown + if strings.Contains(procedure, "/") { + long := strings.Split(procedure, "/")[0] + if strings.Contains(long, ".") { + split := strings.Split(long, ".") + serviceName = split[len(split)-1] + } + } + if strings.Contains(procedure, "/") { + method = strings.Split(procedure, "/")[1] + } + return serviceName, method +} diff --git a/internal/api/grpc/server/connect_middleware/translation_interceptor.go b/internal/api/grpc/server/connect_middleware/translation_interceptor.go new file mode 100644 index 0000000000..f01b1c85ab --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/translation_interceptor.go @@ -0,0 +1,48 @@ +package connect_middleware + +import ( + "context" + + "connectrpc.com/connect" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/i18n" + _ "github.com/zitadel/zitadel/internal/statik" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +func TranslationHandler() connect.UnaryInterceptorFunc { + + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + resp, err := handler(ctx, req) + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + if err != nil { + translator, translatorError := getTranslator(ctx) + if translatorError != nil { + return resp, err + } + return resp, translateError(ctx, err, translator) + } + if loc, ok := resp.Any().(localizers); ok { + translator, translatorError := getTranslator(ctx) + if translatorError != nil { + return resp, err + } + translateFields(ctx, loc, translator) + } + return resp, nil + } + } +} + +func getTranslator(ctx context.Context) (*i18n.Translator, error) { + translator, err := i18n.NewZitadelTranslator(authz.GetInstance(ctx).DefaultLanguage()) + if err != nil { + logging.New().WithError(err).Error("could not load translator") + } + return translator, err +} diff --git a/internal/api/grpc/server/connect_middleware/translator.go b/internal/api/grpc/server/connect_middleware/translator.go new file mode 100644 index 0000000000..6d61b1d772 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/translator.go @@ -0,0 +1,37 @@ +package connect_middleware + +import ( + "context" + "errors" + + "github.com/zitadel/zitadel/internal/i18n" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type localizers interface { + Localizers() []Localizer +} +type Localizer interface { + LocalizationKey() string + SetLocalizedMessage(string) +} + +func translateFields(ctx context.Context, object localizers, translator *i18n.Translator) { + if translator == nil || object == nil { + return + } + for _, field := range object.Localizers() { + field.SetLocalizedMessage(translator.LocalizeFromCtx(ctx, field.LocalizationKey(), nil)) + } +} + +func translateError(ctx context.Context, err error, translator *i18n.Translator) error { + if translator == nil || err == nil { + return err + } + caosErr := new(zerrors.ZitadelError) + if errors.As(err, &caosErr) { + caosErr.SetMessage(translator.LocalizeFromCtx(ctx, caosErr.GetMessage(), nil)) + } + return err +} diff --git a/internal/api/grpc/server/connect_middleware/validation_interceptor.go b/internal/api/grpc/server/connect_middleware/validation_interceptor.go new file mode 100644 index 0000000000..8441886114 --- /dev/null +++ b/internal/api/grpc/server/connect_middleware/validation_interceptor.go @@ -0,0 +1,36 @@ +package connect_middleware + +import ( + "context" + + "connectrpc.com/connect" + // import to make sure go.mod does not lose it + // because dependency is only needed for generated code + _ "github.com/envoyproxy/protoc-gen-validate/validate" +) + +func ValidationHandler() connect.UnaryInterceptorFunc { + return func(handler connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + return validate(ctx, req, handler) + } + } +} + +// validator interface needed for github.com/envoyproxy/protoc-gen-validate +// (it does not expose an interface itself) +type validator interface { + Validate() error +} + +func validate(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc) (connect.AnyResponse, error) { + validate, ok := req.Any().(validator) + if !ok { + return handler(ctx, req) + } + err := validate.Validate() + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + return handler(ctx, req) +} diff --git a/internal/api/grpc/server/gateway.go b/internal/api/grpc/server/gateway.go index ca7579ee89..b20819b850 100644 --- a/internal/api/grpc/server/gateway.go +++ b/internal/api/grpc/server/gateway.go @@ -171,7 +171,7 @@ func CreateGateway( }, nil } -func RegisterGateway(ctx context.Context, gateway *Gateway, server Server) error { +func RegisterGateway(ctx context.Context, gateway *Gateway, server WithGateway) error { err := server.RegisterGateway()(ctx, gateway.mux, gateway.connection) if err != nil { return fmt.Errorf("failed to register grpc gateway: %w", err) diff --git a/internal/api/grpc/server/server.go b/internal/api/grpc/server/server.go index b686d3add9..0c02087c89 100644 --- a/internal/api/grpc/server/server.go +++ b/internal/api/grpc/server/server.go @@ -2,11 +2,14 @@ package server import ( "crypto/tls" + "net/http" + "connectrpc.com/connect" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "google.golang.org/grpc" "google.golang.org/grpc/credentials" healthpb "google.golang.org/grpc/health/grpc_health_v1" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" grpc_api "github.com/zitadel/zitadel/internal/api/grpc" @@ -19,21 +22,36 @@ import ( ) type Server interface { - RegisterServer(*grpc.Server) - RegisterGateway() RegisterGatewayFunc AppName() string MethodPrefix() string AuthMethods() authz.MethodMapping } +type GrpcServer interface { + Server + RegisterServer(*grpc.Server) +} + +type WithGateway interface { + Server + RegisterGateway() RegisterGatewayFunc +} + // WithGatewayPrefix extends the server interface with a prefix for the grpc gateway // // it's used for the System, Admin, Mgmt and Auth API type WithGatewayPrefix interface { - Server + GrpcServer + WithGateway GatewayPathPrefix() string } +type ConnectServer interface { + Server + RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) + FileDescriptor() protoreflect.FileDescriptor +} + func CreateServer( verifier authz.APITokenVerifier, systemAuthz authz.Config, diff --git a/internal/api/grpc/session/v2/query.go b/internal/api/grpc/session/v2/query.go index 73303dd9e8..78d8623ee7 100644 --- a/internal/api/grpc/session/v2/query.go +++ b/internal/api/grpc/session/v2/query.go @@ -4,6 +4,7 @@ import ( "context" "time" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "google.golang.org/protobuf/types/known/timestamppb" @@ -26,18 +27,18 @@ var ( } ) -func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) +func (s *Server) GetSession(ctx context.Context, req *connect.Request[session.GetSessionRequest]) (*connect.Response[session.GetSessionResponse], error) { + res, err := s.query.SessionByID(ctx, true, req.Msg.GetSessionId(), req.Msg.GetSessionToken(), s.checkPermission) if err != nil { return nil, err } - return &session.GetSessionResponse{ + return connect.NewResponse(&session.GetSessionResponse{ Session: sessionToPb(res), - }, nil + }), nil } -func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { - queries, err := listSessionsRequestToQuery(ctx, req) +func (s *Server) ListSessions(ctx context.Context, req *connect.Request[session.ListSessionsRequest]) (*connect.Response[session.ListSessionsResponse], error) { + queries, err := listSessionsRequestToQuery(ctx, req.Msg) if err != nil { return nil, err } @@ -45,10 +46,10 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ if err != nil { return nil, err } - return &session.ListSessionsResponse{ + return connect.NewResponse(&session.ListSessionsResponse{ Details: object.ToListDetails(sessions.SearchResponse), Sessions: sessionsToPb(sessions.Sessions), - }, nil + }), nil } func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { diff --git a/internal/api/grpc/session/v2/server.go b/internal/api/grpc/session/v2/server.go index ee534cb26c..8f06cb3fb0 100644 --- a/internal/api/grpc/session/v2/server.go +++ b/internal/api/grpc/session/v2/server.go @@ -1,7 +1,10 @@ package session import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/session/v2" + "github.com/zitadel/zitadel/pkg/grpc/session/v2/sessionconnect" ) -var _ session.SessionServiceServer = (*Server)(nil) +var _ sessionconnect.SessionServiceHandler = (*Server)(nil) type Server struct { - session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries @@ -35,8 +38,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - session.RegisterSessionServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return sessionconnect.NewSessionServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return session.File_zitadel_session_v2_session_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index 08f19368ef..94f686a72c 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "connectrpc.com/connect" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/structpb" @@ -17,12 +18,12 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) -func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { - checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req) +func (s *Server) CreateSession(ctx context.Context, req *connect.Request[session.CreateSessionRequest]) (*connect.Response[session.CreateSessionResponse], error) { + checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req.Msg) if err != nil { return nil, err } - challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) + challengeResponse, cmds, err := s.challengesToCommand(req.Msg.GetChallenges(), checks) if err != nil { return nil, err } @@ -32,43 +33,43 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe return nil, err } - return &session.CreateSessionResponse{ + return connect.NewResponse(&session.CreateSessionResponse{ Details: object.DomainToDetailsPb(set.ObjectDetails), SessionId: set.ID, SessionToken: set.NewToken, Challenges: challengeResponse, - }, nil + }), nil } -func (s *Server) SetSession(ctx context.Context, req *session.SetSessionRequest) (*session.SetSessionResponse, error) { - checks, err := s.setSessionRequestToCommand(ctx, req) +func (s *Server) SetSession(ctx context.Context, req *connect.Request[session.SetSessionRequest]) (*connect.Response[session.SetSessionResponse], error) { + checks, err := s.setSessionRequestToCommand(ctx, req.Msg) if err != nil { return nil, err } - challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) + challengeResponse, cmds, err := s.challengesToCommand(req.Msg.GetChallenges(), checks) if err != nil { return nil, err } - set, err := s.command.UpdateSession(ctx, req.GetSessionId(), cmds, req.GetMetadata(), req.GetLifetime().AsDuration()) + set, err := s.command.UpdateSession(ctx, req.Msg.GetSessionId(), cmds, req.Msg.GetMetadata(), req.Msg.GetLifetime().AsDuration()) if err != nil { return nil, err } - return &session.SetSessionResponse{ + return connect.NewResponse(&session.SetSessionResponse{ Details: object.DomainToDetailsPb(set.ObjectDetails), SessionToken: set.NewToken, Challenges: challengeResponse, - }, nil + }), nil } -func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRequest) (*session.DeleteSessionResponse, error) { - details, err := s.command.TerminateSession(ctx, req.GetSessionId(), req.GetSessionToken()) +func (s *Server) DeleteSession(ctx context.Context, req *connect.Request[session.DeleteSessionRequest]) (*connect.Response[session.DeleteSessionResponse], error) { + details, err := s.command.TerminateSession(ctx, req.Msg.GetSessionId(), req.Msg.GetSessionToken()) if err != nil { return nil, err } - return &session.DeleteSessionResponse{ + return connect.NewResponse(&session.DeleteSessionResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) { diff --git a/internal/api/grpc/session/v2beta/server.go b/internal/api/grpc/session/v2beta/server.go index cf0d0c27f0..e659b406eb 100644 --- a/internal/api/grpc/session/v2beta/server.go +++ b/internal/api/grpc/session/v2beta/server.go @@ -1,7 +1,10 @@ package session import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -9,12 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/session/v2beta/sessionconnect" ) -var _ session.SessionServiceServer = (*Server)(nil) +var _ sessionconnect.SessionServiceHandler = (*Server)(nil) type Server struct { - session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries @@ -35,8 +38,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - session.RegisterSessionServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return sessionconnect.NewSessionServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return session.File_zitadel_session_v2beta_session_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/session/v2beta/session.go b/internal/api/grpc/session/v2beta/session.go index 3b36b8ba83..459cf77f05 100644 --- a/internal/api/grpc/session/v2beta/session.go +++ b/internal/api/grpc/session/v2beta/session.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/structpb" @@ -31,18 +32,18 @@ var ( } ) -func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) +func (s *Server) GetSession(ctx context.Context, req *connect.Request[session.GetSessionRequest]) (*connect.Response[session.GetSessionResponse], error) { + res, err := s.query.SessionByID(ctx, true, req.Msg.GetSessionId(), req.Msg.GetSessionToken(), s.checkPermission) if err != nil { return nil, err } - return &session.GetSessionResponse{ + return connect.NewResponse(&session.GetSessionResponse{ Session: sessionToPb(res), - }, nil + }), nil } -func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { - queries, err := listSessionsRequestToQuery(ctx, req) +func (s *Server) ListSessions(ctx context.Context, req *connect.Request[session.ListSessionsRequest]) (*connect.Response[session.ListSessionsResponse], error) { + queries, err := listSessionsRequestToQuery(ctx, req.Msg) if err != nil { return nil, err } @@ -50,18 +51,18 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ if err != nil { return nil, err } - return &session.ListSessionsResponse{ + return connect.NewResponse(&session.ListSessionsResponse{ Details: object.ToListDetails(sessions.SearchResponse), Sessions: sessionsToPb(sessions.Sessions), - }, nil + }), nil } -func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { - checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req) +func (s *Server) CreateSession(ctx context.Context, req *connect.Request[session.CreateSessionRequest]) (*connect.Response[session.CreateSessionResponse], error) { + checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req.Msg) if err != nil { return nil, err } - challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) + challengeResponse, cmds, err := s.challengesToCommand(req.Msg.GetChallenges(), checks) if err != nil { return nil, err } @@ -71,43 +72,43 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe return nil, err } - return &session.CreateSessionResponse{ + return connect.NewResponse(&session.CreateSessionResponse{ Details: object.DomainToDetailsPb(set.ObjectDetails), SessionId: set.ID, SessionToken: set.NewToken, Challenges: challengeResponse, - }, nil + }), nil } -func (s *Server) SetSession(ctx context.Context, req *session.SetSessionRequest) (*session.SetSessionResponse, error) { - checks, err := s.setSessionRequestToCommand(ctx, req) +func (s *Server) SetSession(ctx context.Context, req *connect.Request[session.SetSessionRequest]) (*connect.Response[session.SetSessionResponse], error) { + checks, err := s.setSessionRequestToCommand(ctx, req.Msg) if err != nil { return nil, err } - challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) + challengeResponse, cmds, err := s.challengesToCommand(req.Msg.GetChallenges(), checks) if err != nil { return nil, err } - set, err := s.command.UpdateSession(ctx, req.GetSessionId(), cmds, req.GetMetadata(), req.GetLifetime().AsDuration()) + set, err := s.command.UpdateSession(ctx, req.Msg.GetSessionId(), cmds, req.Msg.GetMetadata(), req.Msg.GetLifetime().AsDuration()) if err != nil { return nil, err } - return &session.SetSessionResponse{ + return connect.NewResponse(&session.SetSessionResponse{ Details: object.DomainToDetailsPb(set.ObjectDetails), SessionToken: set.NewToken, Challenges: challengeResponse, - }, nil + }), nil } -func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRequest) (*session.DeleteSessionResponse, error) { - details, err := s.command.TerminateSession(ctx, req.GetSessionId(), req.GetSessionToken()) +func (s *Server) DeleteSession(ctx context.Context, req *connect.Request[session.DeleteSessionRequest]) (*connect.Response[session.DeleteSessionResponse], error) { + details, err := s.command.TerminateSession(ctx, req.Msg.GetSessionId(), req.Msg.GetSessionToken()) if err != nil { return nil, err } - return &session.DeleteSessionResponse{ + return connect.NewResponse(&session.DeleteSessionResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func sessionsToPb(sessions []*query.Session) []*session.Session { diff --git a/internal/api/grpc/settings/v2/query.go b/internal/api/grpc/settings/v2/query.go index b8994ccb87..d522424040 100644 --- a/internal/api/grpc/settings/v2/query.go +++ b/internal/api/grpc/settings/v2/query.go @@ -3,6 +3,7 @@ package settings import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/authz" @@ -14,12 +15,12 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/settings/v2" ) -func (s *Server) GetLoginSettings(ctx context.Context, req *settings.GetLoginSettingsRequest) (*settings.GetLoginSettingsResponse, error) { - current, err := s.query.LoginPolicyByID(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetLoginSettings(ctx context.Context, req *connect.Request[settings.GetLoginSettingsRequest]) (*connect.Response[settings.GetLoginSettingsResponse], error) { + current, err := s.query.LoginPolicyByID(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetLoginSettingsResponse{ + return connect.NewResponse(&settings.GetLoginSettingsResponse{ Settings: loginSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -27,15 +28,15 @@ func (s *Server) GetLoginSettings(ctx context.Context, req *settings.GetLoginSet ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.OrgID, }, - }, nil + }), nil } -func (s *Server) GetPasswordComplexitySettings(ctx context.Context, req *settings.GetPasswordComplexitySettingsRequest) (*settings.GetPasswordComplexitySettingsResponse, error) { - current, err := s.query.PasswordComplexityPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetPasswordComplexitySettings(ctx context.Context, req *connect.Request[settings.GetPasswordComplexitySettingsRequest]) (*connect.Response[settings.GetPasswordComplexitySettingsResponse], error) { + current, err := s.query.PasswordComplexityPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetPasswordComplexitySettingsResponse{ + return connect.NewResponse(&settings.GetPasswordComplexitySettingsResponse{ Settings: passwordComplexitySettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -43,15 +44,15 @@ func (s *Server) GetPasswordComplexitySettings(ctx context.Context, req *setting ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetPasswordExpirySettings(ctx context.Context, req *settings.GetPasswordExpirySettingsRequest) (*settings.GetPasswordExpirySettingsResponse, error) { - current, err := s.query.PasswordAgePolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetPasswordExpirySettings(ctx context.Context, req *connect.Request[settings.GetPasswordExpirySettingsRequest]) (*connect.Response[settings.GetPasswordExpirySettingsResponse], error) { + current, err := s.query.PasswordAgePolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetPasswordExpirySettingsResponse{ + return connect.NewResponse(&settings.GetPasswordExpirySettingsResponse{ Settings: passwordExpirySettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -59,15 +60,15 @@ func (s *Server) GetPasswordExpirySettings(ctx context.Context, req *settings.Ge ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetBrandingSettings(ctx context.Context, req *settings.GetBrandingSettingsRequest) (*settings.GetBrandingSettingsResponse, error) { - current, err := s.query.ActiveLabelPolicyByOrg(ctx, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetBrandingSettings(ctx context.Context, req *connect.Request[settings.GetBrandingSettingsRequest]) (*connect.Response[settings.GetBrandingSettingsResponse], error) { + current, err := s.query.ActiveLabelPolicyByOrg(ctx, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetBrandingSettingsResponse{ + return connect.NewResponse(&settings.GetBrandingSettingsResponse{ Settings: brandingSettingsToPb(current, s.assetsAPIDomain(ctx)), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -75,15 +76,15 @@ func (s *Server) GetBrandingSettings(ctx context.Context, req *settings.GetBrand ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetDomainSettings(ctx context.Context, req *settings.GetDomainSettingsRequest) (*settings.GetDomainSettingsResponse, error) { - current, err := s.query.DomainPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetDomainSettings(ctx context.Context, req *connect.Request[settings.GetDomainSettingsRequest]) (*connect.Response[settings.GetDomainSettingsResponse], error) { + current, err := s.query.DomainPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetDomainSettingsResponse{ + return connect.NewResponse(&settings.GetDomainSettingsResponse{ Settings: domainSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -91,15 +92,15 @@ func (s *Server) GetDomainSettings(ctx context.Context, req *settings.GetDomainS ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetLegalAndSupportSettings(ctx context.Context, req *settings.GetLegalAndSupportSettingsRequest) (*settings.GetLegalAndSupportSettingsResponse, error) { - current, err := s.query.PrivacyPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetLegalAndSupportSettings(ctx context.Context, req *connect.Request[settings.GetLegalAndSupportSettingsRequest]) (*connect.Response[settings.GetLegalAndSupportSettingsResponse], error) { + current, err := s.query.PrivacyPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetLegalAndSupportSettingsResponse{ + return connect.NewResponse(&settings.GetLegalAndSupportSettingsResponse{ Settings: legalAndSupportSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -107,15 +108,15 @@ func (s *Server) GetLegalAndSupportSettings(ctx context.Context, req *settings.G ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetLockoutSettings(ctx context.Context, req *settings.GetLockoutSettingsRequest) (*settings.GetLockoutSettingsResponse, error) { - current, err := s.query.LockoutPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx())) +func (s *Server) GetLockoutSettings(ctx context.Context, req *connect.Request[settings.GetLockoutSettingsRequest]) (*connect.Response[settings.GetLockoutSettingsResponse], error) { + current, err := s.query.LockoutPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx())) if err != nil { return nil, err } - return &settings.GetLockoutSettingsResponse{ + return connect.NewResponse(&settings.GetLockoutSettingsResponse{ Settings: lockoutSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -123,24 +124,24 @@ func (s *Server) GetLockoutSettings(ctx context.Context, req *settings.GetLockou ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetActiveIdentityProviders(ctx context.Context, req *settings.GetActiveIdentityProvidersRequest) (*settings.GetActiveIdentityProvidersResponse, error) { - queries, err := activeIdentityProvidersToQuery(req) +func (s *Server) GetActiveIdentityProviders(ctx context.Context, req *connect.Request[settings.GetActiveIdentityProvidersRequest]) (*connect.Response[settings.GetActiveIdentityProvidersResponse], error) { + queries, err := activeIdentityProvidersToQuery(req.Msg) if err != nil { return nil, err } - links, err := s.query.IDPLoginPolicyLinks(ctx, object.ResourceOwnerFromReq(ctx, req.GetCtx()), &query.IDPLoginPolicyLinksSearchQuery{Queries: queries}, false) + links, err := s.query.IDPLoginPolicyLinks(ctx, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), &query.IDPLoginPolicyLinksSearchQuery{Queries: queries}, false) if err != nil { return nil, err } - return &settings.GetActiveIdentityProvidersResponse{ + return connect.NewResponse(&settings.GetActiveIdentityProvidersResponse{ Details: object.ToListDetails(links.SearchResponse), IdentityProviders: identityProvidersToPb(links.Links), - }, nil + }), nil } func activeIdentityProvidersToQuery(req *settings.GetActiveIdentityProvidersRequest) (_ []query.SearchQuery, err error) { @@ -180,30 +181,30 @@ func activeIdentityProvidersToQuery(req *settings.GetActiveIdentityProvidersRequ return q, nil } -func (s *Server) GetGeneralSettings(ctx context.Context, _ *settings.GetGeneralSettingsRequest) (*settings.GetGeneralSettingsResponse, error) { +func (s *Server) GetGeneralSettings(ctx context.Context, _ *connect.Request[settings.GetGeneralSettingsRequest]) (*connect.Response[settings.GetGeneralSettingsResponse], error) { instance := authz.GetInstance(ctx) - return &settings.GetGeneralSettingsResponse{ + return connect.NewResponse(&settings.GetGeneralSettingsResponse{ SupportedLanguages: domain.LanguagesToStrings(i18n.SupportedLanguages()), DefaultOrgId: instance.DefaultOrganisationID(), DefaultLanguage: instance.DefaultLanguage().String(), - }, nil + }), nil } -func (s *Server) GetSecuritySettings(ctx context.Context, req *settings.GetSecuritySettingsRequest) (*settings.GetSecuritySettingsResponse, error) { +func (s *Server) GetSecuritySettings(ctx context.Context, req *connect.Request[settings.GetSecuritySettingsRequest]) (*connect.Response[settings.GetSecuritySettingsResponse], error) { policy, err := s.query.SecurityPolicy(ctx) if err != nil { return nil, err } - return &settings.GetSecuritySettingsResponse{ + return connect.NewResponse(&settings.GetSecuritySettingsResponse{ Settings: securityPolicyToSettingsPb(policy), - }, nil + }), nil } -func (s *Server) GetHostedLoginTranslation(ctx context.Context, req *settings.GetHostedLoginTranslationRequest) (*settings.GetHostedLoginTranslationResponse, error) { - translation, err := s.query.GetHostedLoginTranslation(ctx, req) +func (s *Server) GetHostedLoginTranslation(ctx context.Context, req *connect.Request[settings.GetHostedLoginTranslationRequest]) (*connect.Response[settings.GetHostedLoginTranslationResponse], error) { + translation, err := s.query.GetHostedLoginTranslation(ctx, req.Msg) if err != nil { return nil, err } - return translation, nil + return connect.NewResponse(translation), nil } diff --git a/internal/api/grpc/settings/v2/server.go b/internal/api/grpc/settings/v2/server.go index 9cae50824f..bfaec17fc2 100644 --- a/internal/api/grpc/settings/v2/server.go +++ b/internal/api/grpc/settings/v2/server.go @@ -2,8 +2,10 @@ package settings import ( "context" + "net/http" - "google.golang.org/grpc" + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/assets" "github.com/zitadel/zitadel/internal/api/authz" @@ -11,12 +13,12 @@ import ( "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/settings/v2" + "github.com/zitadel/zitadel/pkg/grpc/settings/v2/settingsconnect" ) -var _ settings.SettingsServiceServer = (*Server)(nil) +var _ settingsconnect.SettingsServiceHandler = (*Server)(nil) type Server struct { - settings.UnimplementedSettingsServiceServer command *command.Commands query *query.Queries assetsAPIDomain func(context.Context) string @@ -35,8 +37,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - settings.RegisterSettingsServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return settingsconnect.NewSettingsServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return settings.File_zitadel_settings_v2_settings_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/settings/v2/settings.go b/internal/api/grpc/settings/v2/settings.go index 09ee6b27c8..c7db200211 100644 --- a/internal/api/grpc/settings/v2/settings.go +++ b/internal/api/grpc/settings/v2/settings.go @@ -3,25 +3,27 @@ package settings import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/pkg/grpc/settings/v2" ) -func (s *Server) SetSecuritySettings(ctx context.Context, req *settings.SetSecuritySettingsRequest) (*settings.SetSecuritySettingsResponse, error) { - details, err := s.command.SetSecurityPolicy(ctx, securitySettingsToCommand(req)) +func (s *Server) SetSecuritySettings(ctx context.Context, req *connect.Request[settings.SetSecuritySettingsRequest]) (*connect.Response[settings.SetSecuritySettingsResponse], error) { + details, err := s.command.SetSecurityPolicy(ctx, securitySettingsToCommand(req.Msg)) if err != nil { return nil, err } - return &settings.SetSecuritySettingsResponse{ + return connect.NewResponse(&settings.SetSecuritySettingsResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) SetHostedLoginTranslation(ctx context.Context, req *settings.SetHostedLoginTranslationRequest) (*settings.SetHostedLoginTranslationResponse, error) { - res, err := s.command.SetHostedLoginTranslation(ctx, req) +func (s *Server) SetHostedLoginTranslation(ctx context.Context, req *connect.Request[settings.SetHostedLoginTranslationRequest]) (*connect.Response[settings.SetHostedLoginTranslationResponse], error) { + res, err := s.command.SetHostedLoginTranslation(ctx, req.Msg) if err != nil { return nil, err } - return res, nil + return connect.NewResponse(res), nil } diff --git a/internal/api/grpc/settings/v2beta/server.go b/internal/api/grpc/settings/v2beta/server.go index 24c8f7774a..a8200a7216 100644 --- a/internal/api/grpc/settings/v2beta/server.go +++ b/internal/api/grpc/settings/v2beta/server.go @@ -2,8 +2,10 @@ package settings import ( "context" + "net/http" - "google.golang.org/grpc" + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/assets" "github.com/zitadel/zitadel/internal/api/authz" @@ -11,12 +13,12 @@ import ( "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" settings "github.com/zitadel/zitadel/pkg/grpc/settings/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/settings/v2beta/settingsconnect" ) -var _ settings.SettingsServiceServer = (*Server)(nil) +var _ settingsconnect.SettingsServiceHandler = (*Server)(nil) type Server struct { - settings.UnimplementedSettingsServiceServer command *command.Commands query *query.Queries assetsAPIDomain func(context.Context) string @@ -35,8 +37,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - settings.RegisterSettingsServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return settingsconnect.NewSettingsServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return settings.File_zitadel_settings_v2beta_settings_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/settings/v2beta/settings.go b/internal/api/grpc/settings/v2beta/settings.go index 6193f129ba..53d2c37c32 100644 --- a/internal/api/grpc/settings/v2beta/settings.go +++ b/internal/api/grpc/settings/v2beta/settings.go @@ -3,6 +3,7 @@ package settings import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/authz" @@ -14,12 +15,12 @@ import ( settings "github.com/zitadel/zitadel/pkg/grpc/settings/v2beta" ) -func (s *Server) GetLoginSettings(ctx context.Context, req *settings.GetLoginSettingsRequest) (*settings.GetLoginSettingsResponse, error) { - current, err := s.query.LoginPolicyByID(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetLoginSettings(ctx context.Context, req *connect.Request[settings.GetLoginSettingsRequest]) (*connect.Response[settings.GetLoginSettingsResponse], error) { + current, err := s.query.LoginPolicyByID(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetLoginSettingsResponse{ + return connect.NewResponse(&settings.GetLoginSettingsResponse{ Settings: loginSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -27,15 +28,15 @@ func (s *Server) GetLoginSettings(ctx context.Context, req *settings.GetLoginSet ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.OrgID, }, - }, nil + }), nil } -func (s *Server) GetPasswordComplexitySettings(ctx context.Context, req *settings.GetPasswordComplexitySettingsRequest) (*settings.GetPasswordComplexitySettingsResponse, error) { - current, err := s.query.PasswordComplexityPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetPasswordComplexitySettings(ctx context.Context, req *connect.Request[settings.GetPasswordComplexitySettingsRequest]) (*connect.Response[settings.GetPasswordComplexitySettingsResponse], error) { + current, err := s.query.PasswordComplexityPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetPasswordComplexitySettingsResponse{ + return connect.NewResponse(&settings.GetPasswordComplexitySettingsResponse{ Settings: passwordComplexitySettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -43,15 +44,15 @@ func (s *Server) GetPasswordComplexitySettings(ctx context.Context, req *setting ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetPasswordExpirySettings(ctx context.Context, req *settings.GetPasswordExpirySettingsRequest) (*settings.GetPasswordExpirySettingsResponse, error) { - current, err := s.query.PasswordAgePolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetPasswordExpirySettings(ctx context.Context, req *connect.Request[settings.GetPasswordExpirySettingsRequest]) (*connect.Response[settings.GetPasswordExpirySettingsResponse], error) { + current, err := s.query.PasswordAgePolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetPasswordExpirySettingsResponse{ + return connect.NewResponse(&settings.GetPasswordExpirySettingsResponse{ Settings: passwordExpirySettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -59,15 +60,15 @@ func (s *Server) GetPasswordExpirySettings(ctx context.Context, req *settings.Ge ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetBrandingSettings(ctx context.Context, req *settings.GetBrandingSettingsRequest) (*settings.GetBrandingSettingsResponse, error) { - current, err := s.query.ActiveLabelPolicyByOrg(ctx, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetBrandingSettings(ctx context.Context, req *connect.Request[settings.GetBrandingSettingsRequest]) (*connect.Response[settings.GetBrandingSettingsResponse], error) { + current, err := s.query.ActiveLabelPolicyByOrg(ctx, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetBrandingSettingsResponse{ + return connect.NewResponse(&settings.GetBrandingSettingsResponse{ Settings: brandingSettingsToPb(current, s.assetsAPIDomain(ctx)), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -75,15 +76,15 @@ func (s *Server) GetBrandingSettings(ctx context.Context, req *settings.GetBrand ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetDomainSettings(ctx context.Context, req *settings.GetDomainSettingsRequest) (*settings.GetDomainSettingsResponse, error) { - current, err := s.query.DomainPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetDomainSettings(ctx context.Context, req *connect.Request[settings.GetDomainSettingsRequest]) (*connect.Response[settings.GetDomainSettingsResponse], error) { + current, err := s.query.DomainPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetDomainSettingsResponse{ + return connect.NewResponse(&settings.GetDomainSettingsResponse{ Settings: domainSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -91,15 +92,15 @@ func (s *Server) GetDomainSettings(ctx context.Context, req *settings.GetDomainS ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetLegalAndSupportSettings(ctx context.Context, req *settings.GetLegalAndSupportSettingsRequest) (*settings.GetLegalAndSupportSettingsResponse, error) { - current, err := s.query.PrivacyPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx()), false) +func (s *Server) GetLegalAndSupportSettings(ctx context.Context, req *connect.Request[settings.GetLegalAndSupportSettingsRequest]) (*connect.Response[settings.GetLegalAndSupportSettingsResponse], error) { + current, err := s.query.PrivacyPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), false) if err != nil { return nil, err } - return &settings.GetLegalAndSupportSettingsResponse{ + return connect.NewResponse(&settings.GetLegalAndSupportSettingsResponse{ Settings: legalAndSupportSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -107,15 +108,15 @@ func (s *Server) GetLegalAndSupportSettings(ctx context.Context, req *settings.G ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetLockoutSettings(ctx context.Context, req *settings.GetLockoutSettingsRequest) (*settings.GetLockoutSettingsResponse, error) { - current, err := s.query.LockoutPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.GetCtx())) +func (s *Server) GetLockoutSettings(ctx context.Context, req *connect.Request[settings.GetLockoutSettingsRequest]) (*connect.Response[settings.GetLockoutSettingsResponse], error) { + current, err := s.query.LockoutPolicyByOrg(ctx, true, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx())) if err != nil { return nil, err } - return &settings.GetLockoutSettingsResponse{ + return connect.NewResponse(&settings.GetLockoutSettingsResponse{ Settings: lockoutSettingsToPb(current), Details: &object_pb.Details{ Sequence: current.Sequence, @@ -123,46 +124,46 @@ func (s *Server) GetLockoutSettings(ctx context.Context, req *settings.GetLockou ChangeDate: timestamppb.New(current.ChangeDate), ResourceOwner: current.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) GetActiveIdentityProviders(ctx context.Context, req *settings.GetActiveIdentityProvidersRequest) (*settings.GetActiveIdentityProvidersResponse, error) { - links, err := s.query.IDPLoginPolicyLinks(ctx, object.ResourceOwnerFromReq(ctx, req.GetCtx()), &query.IDPLoginPolicyLinksSearchQuery{}, false) +func (s *Server) GetActiveIdentityProviders(ctx context.Context, req *connect.Request[settings.GetActiveIdentityProvidersRequest]) (*connect.Response[settings.GetActiveIdentityProvidersResponse], error) { + links, err := s.query.IDPLoginPolicyLinks(ctx, object.ResourceOwnerFromReq(ctx, req.Msg.GetCtx()), &query.IDPLoginPolicyLinksSearchQuery{}, false) if err != nil { return nil, err } - return &settings.GetActiveIdentityProvidersResponse{ + return connect.NewResponse(&settings.GetActiveIdentityProvidersResponse{ Details: object.ToListDetails(links.SearchResponse), IdentityProviders: identityProvidersToPb(links.Links), - }, nil + }), nil } -func (s *Server) GetGeneralSettings(ctx context.Context, _ *settings.GetGeneralSettingsRequest) (*settings.GetGeneralSettingsResponse, error) { +func (s *Server) GetGeneralSettings(ctx context.Context, _ *connect.Request[settings.GetGeneralSettingsRequest]) (*connect.Response[settings.GetGeneralSettingsResponse], error) { instance := authz.GetInstance(ctx) - return &settings.GetGeneralSettingsResponse{ + return connect.NewResponse(&settings.GetGeneralSettingsResponse{ SupportedLanguages: domain.LanguagesToStrings(i18n.SupportedLanguages()), DefaultOrgId: instance.DefaultOrganisationID(), DefaultLanguage: instance.DefaultLanguage().String(), - }, nil + }), nil } -func (s *Server) GetSecuritySettings(ctx context.Context, req *settings.GetSecuritySettingsRequest) (*settings.GetSecuritySettingsResponse, error) { +func (s *Server) GetSecuritySettings(ctx context.Context, req *connect.Request[settings.GetSecuritySettingsRequest]) (*connect.Response[settings.GetSecuritySettingsResponse], error) { policy, err := s.query.SecurityPolicy(ctx) if err != nil { return nil, err } - return &settings.GetSecuritySettingsResponse{ + return connect.NewResponse(&settings.GetSecuritySettingsResponse{ Settings: securityPolicyToSettingsPb(policy), - }, nil + }), nil } -func (s *Server) SetSecuritySettings(ctx context.Context, req *settings.SetSecuritySettingsRequest) (*settings.SetSecuritySettingsResponse, error) { - details, err := s.command.SetSecurityPolicy(ctx, securitySettingsToCommand(req)) +func (s *Server) SetSecuritySettings(ctx context.Context, req *connect.Request[settings.SetSecuritySettingsRequest]) (*connect.Response[settings.SetSecuritySettingsResponse], error) { + details, err := s.command.SetSecurityPolicy(ctx, securitySettingsToCommand(req.Msg)) if err != nil { return nil, err } - return &settings.SetSecuritySettingsResponse{ + return connect.NewResponse(&settings.SetSecuritySettingsResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/email.go b/internal/api/grpc/user/v2/email.go index 4b247ef10f..df68e58c7d 100644 --- a/internal/api/grpc/user/v2/email.go +++ b/internal/api/grpc/user/v2/email.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/domain" @@ -11,18 +12,18 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) { +func (s *Server) SetEmail(ctx context.Context, req *connect.Request[user.SetEmailRequest]) (resp *connect.Response[user.SetEmailResponse], err error) { var email *domain.Email - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SetEmailRequest_SendCode: - email, err = s.command.ChangeUserEmailURLTemplate(ctx, req.GetUserId(), req.GetEmail(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) + email, err = s.command.ChangeUserEmailURLTemplate(ctx, req.Msg.GetUserId(), req.Msg.GetEmail(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) case *user.SetEmailRequest_ReturnCode: - email, err = s.command.ChangeUserEmailReturnCode(ctx, req.GetUserId(), req.GetEmail(), s.userCodeAlg) + email, err = s.command.ChangeUserEmailReturnCode(ctx, req.Msg.GetUserId(), req.Msg.GetEmail(), s.userCodeAlg) case *user.SetEmailRequest_IsVerified: - email, err = s.command.ChangeUserEmailVerified(ctx, req.GetUserId(), req.GetEmail()) + email, err = s.command.ChangeUserEmailVerified(ctx, req.Msg.GetUserId(), req.Msg.GetEmail()) case nil: - email, err = s.command.ChangeUserEmail(ctx, req.GetUserId(), req.GetEmail(), s.userCodeAlg) + email, err = s.command.ChangeUserEmail(ctx, req.Msg.GetUserId(), req.Msg.GetEmail(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-Ahng0", "verification oneOf %T in method SetEmail not implemented", v) } @@ -30,26 +31,26 @@ func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp return nil, err } - return &user.SetEmailResponse{ + return connect.NewResponse(&user.SetEmailResponse{ Details: &object.Details{ Sequence: email.Sequence, ChangeDate: timestamppb.New(email.ChangeDate), ResourceOwner: email.ResourceOwner, }, VerificationCode: email.PlainCode, - }, nil + }), nil } -func (s *Server) ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.ResendEmailCodeResponse, err error) { +func (s *Server) ResendEmailCode(ctx context.Context, req *connect.Request[user.ResendEmailCodeRequest]) (resp *connect.Response[user.ResendEmailCodeResponse], err error) { var email *domain.Email - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.ResendEmailCodeRequest_SendCode: - email, err = s.command.ResendUserEmailCodeURLTemplate(ctx, req.GetUserId(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) + email, err = s.command.ResendUserEmailCodeURLTemplate(ctx, req.Msg.GetUserId(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) case *user.ResendEmailCodeRequest_ReturnCode: - email, err = s.command.ResendUserEmailReturnCode(ctx, req.GetUserId(), s.userCodeAlg) + email, err = s.command.ResendUserEmailReturnCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case nil: - email, err = s.command.ResendUserEmailCode(ctx, req.GetUserId(), s.userCodeAlg) + email, err = s.command.ResendUserEmailCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-faj0l0nj5x", "verification oneOf %T in method ResendEmailCode not implemented", v) } @@ -57,26 +58,26 @@ func (s *Server) ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeR return nil, err } - return &user.ResendEmailCodeResponse{ + return connect.NewResponse(&user.ResendEmailCodeResponse{ Details: &object.Details{ Sequence: email.Sequence, ChangeDate: timestamppb.New(email.ChangeDate), ResourceOwner: email.ResourceOwner, }, VerificationCode: email.PlainCode, - }, nil + }), nil } -func (s *Server) SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) { +func (s *Server) SendEmailCode(ctx context.Context, req *connect.Request[user.SendEmailCodeRequest]) (resp *connect.Response[user.SendEmailCodeResponse], err error) { var email *domain.Email - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SendEmailCodeRequest_SendCode: - email, err = s.command.SendUserEmailCodeURLTemplate(ctx, req.GetUserId(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) + email, err = s.command.SendUserEmailCodeURLTemplate(ctx, req.Msg.GetUserId(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) case *user.SendEmailCodeRequest_ReturnCode: - email, err = s.command.SendUserEmailReturnCode(ctx, req.GetUserId(), s.userCodeAlg) + email, err = s.command.SendUserEmailReturnCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case nil: - email, err = s.command.SendUserEmailCode(ctx, req.GetUserId(), s.userCodeAlg) + email, err = s.command.SendUserEmailCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-faj0l0nj5x", "verification oneOf %T in method SendEmailCode not implemented", v) } @@ -84,30 +85,30 @@ func (s *Server) SendEmailCode(ctx context.Context, req *user.SendEmailCodeReque return nil, err } - return &user.SendEmailCodeResponse{ + return connect.NewResponse(&user.SendEmailCodeResponse{ Details: &object.Details{ Sequence: email.Sequence, ChangeDate: timestamppb.New(email.ChangeDate), ResourceOwner: email.ResourceOwner, }, VerificationCode: email.PlainCode, - }, nil + }), nil } -func (s *Server) VerifyEmail(ctx context.Context, req *user.VerifyEmailRequest) (*user.VerifyEmailResponse, error) { +func (s *Server) VerifyEmail(ctx context.Context, req *connect.Request[user.VerifyEmailRequest]) (*connect.Response[user.VerifyEmailResponse], error) { details, err := s.command.VerifyUserEmail(ctx, - req.GetUserId(), - req.GetVerificationCode(), + req.Msg.GetUserId(), + req.Msg.GetVerificationCode(), s.userCodeAlg, ) if err != nil { return nil, err } - return &user.VerifyEmailResponse{ + return connect.NewResponse(&user.VerifyEmailResponse{ Details: &object.Details{ Sequence: details.Sequence, ChangeDate: timestamppb.New(details.EventDate), ResourceOwner: details.ResourceOwner, }, - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/human.go b/internal/api/grpc/user/v2/human.go index d8a0891396..06414d12cb 100644 --- a/internal/api/grpc/user/v2/human.go +++ b/internal/api/grpc/user/v2/human.go @@ -4,6 +4,7 @@ import ( "context" "io" + "connectrpc.com/connect" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/timestamppb" @@ -14,7 +15,7 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) createUserTypeHuman(ctx context.Context, humanPb *user.CreateUserRequest_Human, orgId string, userName, userId *string) (*user.CreateUserResponse, error) { +func (s *Server) createUserTypeHuman(ctx context.Context, humanPb *user.CreateUserRequest_Human, orgId string, userName, userId *string) (*connect.Response[user.CreateUserResponse], error) { addHumanPb := &user.AddHumanUserRequest{ Username: userName, UserId: userId, @@ -52,15 +53,15 @@ func (s *Server) createUserTypeHuman(ctx context.Context, humanPb *user.CreateUs ); err != nil { return nil, err } - return &user.CreateUserResponse{ + return connect.NewResponse(&user.CreateUserResponse{ Id: newHuman.ID, CreationDate: timestamppb.New(newHuman.Details.EventDate), EmailCode: newHuman.EmailCode, PhoneCode: newHuman.PhoneCode, - }, nil + }), nil } -func (s *Server) updateUserTypeHuman(ctx context.Context, humanPb *user.UpdateUserRequest_Human, userId string, userName *string) (*user.UpdateUserResponse, error) { +func (s *Server) updateUserTypeHuman(ctx context.Context, humanPb *user.UpdateUserRequest_Human, userId string, userName *string) (*connect.Response[user.UpdateUserResponse], error) { cmd, err := updateHumanUserToCommand(userId, userName, humanPb) if err != nil { return nil, err @@ -68,11 +69,11 @@ func (s *Server) updateUserTypeHuman(ctx context.Context, humanPb *user.UpdateUs if err = s.command.ChangeUserHuman(ctx, cmd, s.userCodeAlg); err != nil { return nil, err } - return &user.UpdateUserResponse{ + return connect.NewResponse(&user.UpdateUserResponse{ ChangeDate: timestamppb.New(cmd.Details.EventDate), EmailCode: cmd.EmailCode, PhoneCode: cmd.PhoneCode, - }, nil + }), nil } func updateHumanUserToCommand(userId string, userName *string, human *user.UpdateUserRequest_Human) (*command.ChangeHuman, error) { diff --git a/internal/api/grpc/user/v2/idp_link.go b/internal/api/grpc/user/v2/idp_link.go index bef40617cf..0b1e7ab998 100644 --- a/internal/api/grpc/user/v2/idp_link.go +++ b/internal/api/grpc/user/v2/idp_link.go @@ -3,6 +3,8 @@ package user import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" @@ -11,22 +13,22 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) AddIDPLink(ctx context.Context, req *user.AddIDPLinkRequest) (_ *user.AddIDPLinkResponse, err error) { - details, err := s.command.AddUserIDPLink(ctx, req.UserId, "", &command.AddLink{ - IDPID: req.GetIdpLink().GetIdpId(), - DisplayName: req.GetIdpLink().GetUserName(), - IDPExternalID: req.GetIdpLink().GetUserId(), +func (s *Server) AddIDPLink(ctx context.Context, req *connect.Request[user.AddIDPLinkRequest]) (_ *connect.Response[user.AddIDPLinkResponse], err error) { + details, err := s.command.AddUserIDPLink(ctx, req.Msg.GetUserId(), "", &command.AddLink{ + IDPID: req.Msg.GetIdpLink().GetIdpId(), + DisplayName: req.Msg.GetIdpLink().GetUserName(), + IDPExternalID: req.Msg.GetIdpLink().GetUserId(), }) if err != nil { return nil, err } - return &user.AddIDPLinkResponse{ + return connect.NewResponse(&user.AddIDPLinkResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ListIDPLinks(ctx context.Context, req *user.ListIDPLinksRequest) (_ *user.ListIDPLinksResponse, err error) { - queries, err := ListLinkedIDPsRequestToQuery(req) +func (s *Server) ListIDPLinks(ctx context.Context, req *connect.Request[user.ListIDPLinksRequest]) (_ *connect.Response[user.ListIDPLinksResponse], err error) { + queries, err := ListLinkedIDPsRequestToQuery(req.Msg) if err != nil { return nil, err } @@ -34,10 +36,10 @@ func (s *Server) ListIDPLinks(ctx context.Context, req *user.ListIDPLinksRequest if err != nil { return nil, err } - return &user.ListIDPLinksResponse{ + return connect.NewResponse(&user.ListIDPLinksResponse{ Result: IDPLinksToPb(res.Links), Details: object.ToListDetails(res.SearchResponse), - }, nil + }), nil } func ListLinkedIDPsRequestToQuery(req *user.ListIDPLinksRequest) (*query.IDPUserLinksSearchQuery, error) { @@ -72,14 +74,14 @@ func IDPLinkToPb(link *query.IDPUserLink) *user.IDPLink { } } -func (s *Server) RemoveIDPLink(ctx context.Context, req *user.RemoveIDPLinkRequest) (*user.RemoveIDPLinkResponse, error) { - objectDetails, err := s.command.RemoveUserIDPLink(ctx, RemoveIDPLinkRequestToDomain(ctx, req)) +func (s *Server) RemoveIDPLink(ctx context.Context, req *connect.Request[user.RemoveIDPLinkRequest]) (*connect.Response[user.RemoveIDPLinkResponse], error) { + objectDetails, err := s.command.RemoveUserIDPLink(ctx, RemoveIDPLinkRequestToDomain(ctx, req.Msg)) if err != nil { return nil, err } - return &user.RemoveIDPLinkResponse{ + return connect.NewResponse(&user.RemoveIDPLinkResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } func RemoveIDPLinkRequestToDomain(ctx context.Context, req *user.RemoveIDPLinkRequest) *domain.UserIDPLink { diff --git a/internal/api/grpc/user/v2/intent.go b/internal/api/grpc/user/v2/intent.go index fd65d61dfb..c26adba24d 100644 --- a/internal/api/grpc/user/v2/intent.go +++ b/internal/api/grpc/user/v2/intent.go @@ -6,6 +6,7 @@ import ( "errors" "time" + "connectrpc.com/connect" oidc_pkg "github.com/zitadel/oidc/v3/pkg/oidc" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -32,18 +33,18 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.StartIdentityProviderIntentRequest) (_ *user.StartIdentityProviderIntentResponse, err error) { - switch t := req.GetContent().(type) { +func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *connect.Request[user.StartIdentityProviderIntentRequest]) (_ *connect.Response[user.StartIdentityProviderIntentResponse], err error) { + switch t := req.Msg.GetContent().(type) { case *user.StartIdentityProviderIntentRequest_Urls: - return s.startIDPIntent(ctx, req.GetIdpId(), t.Urls) + return s.startIDPIntent(ctx, req.Msg.GetIdpId(), t.Urls) case *user.StartIdentityProviderIntentRequest_Ldap: - return s.startLDAPIntent(ctx, req.GetIdpId(), t.Ldap) + return s.startLDAPIntent(ctx, req.Msg.GetIdpId(), t.Ldap) default: return nil, zerrors.ThrowUnimplementedf(nil, "USERv2-S2g21", "type oneOf %T in method StartIdentityProviderIntent not implemented", t) } } -func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) { +func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*connect.Response[user.StartIdentityProviderIntentResponse], error) { state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID)) if err != nil { return nil, err @@ -58,12 +59,12 @@ func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.Re } switch a := auth.(type) { case *idp.RedirectAuth: - return &user.StartIdentityProviderIntentResponse{ + return connect.NewResponse(&user.StartIdentityProviderIntentResponse{ Details: object.DomainToDetailsPb(details), NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: a.RedirectURL}, - }, nil + }), nil case *idp.FormAuth: - return &user.StartIdentityProviderIntentResponse{ + return connect.NewResponse(&user.StartIdentityProviderIntentResponse{ Details: object.DomainToDetailsPb(details), NextStep: &user.StartIdentityProviderIntentResponse_FormData{ FormData: &user.FormData{ @@ -71,12 +72,12 @@ func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.Re Fields: a.Fields, }, }, - }, nil + }), nil } return nil, zerrors.ThrowInvalidArgumentf(nil, "USERv2-3g2j3", "type oneOf %T in method StartIdentityProviderIntent not implemented", auth) } -func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) { +func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*connect.Response[user.StartIdentityProviderIntentResponse], error) { intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil) if err != nil { return nil, err @@ -92,7 +93,7 @@ func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredenti if err != nil { return nil, err } - return &user.StartIdentityProviderIntentResponse{ + return connect.NewResponse(&user.StartIdentityProviderIntentResponse{ Details: object.DomainToDetailsPb(details), NextStep: &user.StartIdentityProviderIntentResponse_IdpIntent{ IdpIntent: &user.IDPIntent{ @@ -101,7 +102,7 @@ func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredenti UserId: userID, }, }, - }, nil + }), nil } func (s *Server) checkLinkedExternalUser(ctx context.Context, idpID, externalUserID string) (string, error) { @@ -150,12 +151,12 @@ func (s *Server) ldapLogin(ctx context.Context, idpID, username, password string return externalUser, userID, session, nil } -func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.RetrieveIdentityProviderIntentRequest) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { - intent, err := s.command.GetIntentWriteModel(ctx, req.GetIdpIntentId(), "") +func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *connect.Request[user.RetrieveIdentityProviderIntentRequest]) (_ *connect.Response[user.RetrieveIdentityProviderIntentResponse], err error) { + intent, err := s.command.GetIntentWriteModel(ctx, req.Msg.GetIdpIntentId(), "") if err != nil { return nil, err } - if err := s.checkIntentToken(req.GetIdpIntentToken(), intent.AggregateID); err != nil { + if err := s.checkIntentToken(req.Msg.GetIdpIntentToken(), intent.AggregateID); err != nil { return nil, err } if intent.State != domain.IDPIntentStateSucceeded { @@ -203,7 +204,7 @@ func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.R } idpIntent.AddHumanUser = idpUserToAddHumanUser(idpUser, idpIntent.IdpInformation.IdpId) } - return idpIntent, nil + return connect.NewResponse(idpIntent), nil } type rawUserMapper struct { diff --git a/internal/api/grpc/user/v2/key.go b/internal/api/grpc/user/v2/key.go index 59dab44248..021f4be388 100644 --- a/internal/api/grpc/user/v2/key.go +++ b/internal/api/grpc/user/v2/key.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/command" @@ -11,16 +12,16 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) AddKey(ctx context.Context, req *user.AddKeyRequest) (*user.AddKeyResponse, error) { +func (s *Server) AddKey(ctx context.Context, req *connect.Request[user.AddKeyRequest]) (*connect.Response[user.AddKeyResponse], error) { newMachineKey := &command.MachineKey{ ObjectRoot: models.ObjectRoot{ - AggregateID: req.UserId, + AggregateID: req.Msg.GetUserId(), }, - ExpirationDate: req.GetExpirationDate().AsTime(), + ExpirationDate: req.Msg.GetExpirationDate().AsTime(), Type: domain.AuthNKeyTypeJSON, PermissionCheck: s.command.NewPermissionCheckUserWrite(ctx), } - newMachineKey.PublicKey = req.PublicKey + newMachineKey.PublicKey = req.Msg.GetPublicKey() pubkeySupplied := len(newMachineKey.PublicKey) > 0 details, err := s.command.AddUserMachineKey(ctx, newMachineKey) @@ -37,26 +38,26 @@ func (s *Server) AddKey(ctx context.Context, req *user.AddKeyRequest) (*user.Add return nil, err } } - return &user.AddKeyResponse{ + return connect.NewResponse(&user.AddKeyResponse{ KeyId: newMachineKey.KeyID, KeyContent: keyDetails, CreationDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) RemoveKey(ctx context.Context, req *user.RemoveKeyRequest) (*user.RemoveKeyResponse, error) { +func (s *Server) RemoveKey(ctx context.Context, req *connect.Request[user.RemoveKeyRequest]) (*connect.Response[user.RemoveKeyResponse], error) { machineKey := &command.MachineKey{ ObjectRoot: models.ObjectRoot{ - AggregateID: req.UserId, + AggregateID: req.Msg.GetUserId(), }, PermissionCheck: s.command.NewPermissionCheckUserWrite(ctx), - KeyID: req.KeyId, + KeyID: req.Msg.GetKeyId(), } objectDetails, err := s.command.RemoveUserMachineKey(ctx, machineKey) if err != nil { return nil, err } - return &user.RemoveKeyResponse{ + return connect.NewResponse(&user.RemoveKeyResponse{ DeletionDate: timestamppb.New(objectDetails.EventDate), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/key_query.go b/internal/api/grpc/user/v2/key_query.go index da4f47decf..e9466a791b 100644 --- a/internal/api/grpc/user/v2/key_query.go +++ b/internal/api/grpc/user/v2/key_query.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/grpc/filter/v2" @@ -12,13 +13,13 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) ListKeys(ctx context.Context, req *user.ListKeysRequest) (*user.ListKeysResponse, error) { - offset, limit, asc, err := filter.PaginationPbToQuery(s.systemDefaults, req.Pagination) +func (s *Server) ListKeys(ctx context.Context, req *connect.Request[user.ListKeysRequest]) (*connect.Response[user.ListKeysResponse], error) { + offset, limit, asc, err := filter.PaginationPbToQuery(s.systemDefaults, req.Msg.GetPagination()) if err != nil { return nil, err } - filters, err := keyFiltersToQueries(req.Filters) + filters, err := keyFiltersToQueries(req.Msg.GetFilters()) if err != nil { return nil, err } @@ -27,7 +28,7 @@ func (s *Server) ListKeys(ctx context.Context, req *user.ListKeysRequest) (*user Offset: offset, Limit: limit, Asc: asc, - SortingColumn: authnKeyFieldNameToSortingColumn(req.SortingColumn), + SortingColumn: authnKeyFieldNameToSortingColumn(req.Msg.SortingColumn), }, Queries: filters, } @@ -49,7 +50,7 @@ func (s *Server) ListKeys(ctx context.Context, req *user.ListKeysRequest) (*user ExpirationDate: timestamppb.New(key.Expiration), } } - return resp, nil + return connect.NewResponse(resp), nil } func keyFiltersToQueries(filters []*user.KeysSearchFilter) (_ []query.SearchQuery, err error) { diff --git a/internal/api/grpc/user/v2/machine.go b/internal/api/grpc/user/v2/machine.go index ad02b2289e..e5126b9019 100644 --- a/internal/api/grpc/user/v2/machine.go +++ b/internal/api/grpc/user/v2/machine.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/command" @@ -11,7 +12,7 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) createUserTypeMachine(ctx context.Context, machinePb *user.CreateUserRequest_Machine, orgId, userName, userId string) (*user.CreateUserResponse, error) { +func (s *Server) createUserTypeMachine(ctx context.Context, machinePb *user.CreateUserRequest_Machine, orgId, userName, userId string) (*connect.Response[user.CreateUserResponse], error) { cmd := &command.Machine{ Username: userName, Name: machinePb.Name, @@ -32,21 +33,21 @@ func (s *Server) createUserTypeMachine(ctx context.Context, machinePb *user.Crea if err != nil { return nil, err } - return &user.CreateUserResponse{ + return connect.NewResponse(&user.CreateUserResponse{ Id: cmd.AggregateID, CreationDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) updateUserTypeMachine(ctx context.Context, machinePb *user.UpdateUserRequest_Machine, userId string, userName *string) (*user.UpdateUserResponse, error) { +func (s *Server) updateUserTypeMachine(ctx context.Context, machinePb *user.UpdateUserRequest_Machine, userId string, userName *string) (*connect.Response[user.UpdateUserResponse], error) { cmd := updateMachineUserToCommand(userId, userName, machinePb) err := s.command.ChangeUserMachine(ctx, cmd) if err != nil { return nil, err } - return &user.UpdateUserResponse{ + return connect.NewResponse(&user.UpdateUserResponse{ ChangeDate: timestamppb.New(cmd.Details.EventDate), - }, nil + }), nil } func updateMachineUserToCommand(userId string, userName *string, machine *user.UpdateUserRequest_Machine) *command.ChangeMachine { diff --git a/internal/api/grpc/user/v2/otp.go b/internal/api/grpc/user/v2/otp.go index fd76cf2b93..2f04f438dd 100644 --- a/internal/api/grpc/user/v2/otp.go +++ b/internal/api/grpc/user/v2/otp.go @@ -3,39 +3,41 @@ package user import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) AddOTPSMS(ctx context.Context, req *user.AddOTPSMSRequest) (*user.AddOTPSMSResponse, error) { - details, err := s.command.AddHumanOTPSMS(ctx, req.GetUserId(), "") +func (s *Server) AddOTPSMS(ctx context.Context, req *connect.Request[user.AddOTPSMSRequest]) (*connect.Response[user.AddOTPSMSResponse], error) { + details, err := s.command.AddHumanOTPSMS(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.AddOTPSMSResponse{Details: object.DomainToDetailsPb(details)}, nil + return connect.NewResponse(&user.AddOTPSMSResponse{Details: object.DomainToDetailsPb(details)}), nil } -func (s *Server) RemoveOTPSMS(ctx context.Context, req *user.RemoveOTPSMSRequest) (*user.RemoveOTPSMSResponse, error) { - objectDetails, err := s.command.RemoveHumanOTPSMS(ctx, req.GetUserId(), "") +func (s *Server) RemoveOTPSMS(ctx context.Context, req *connect.Request[user.RemoveOTPSMSRequest]) (*connect.Response[user.RemoveOTPSMSResponse], error) { + objectDetails, err := s.command.RemoveHumanOTPSMS(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.RemoveOTPSMSResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil + return connect.NewResponse(&user.RemoveOTPSMSResponse{Details: object.DomainToDetailsPb(objectDetails)}), nil } -func (s *Server) AddOTPEmail(ctx context.Context, req *user.AddOTPEmailRequest) (*user.AddOTPEmailResponse, error) { - details, err := s.command.AddHumanOTPEmail(ctx, req.GetUserId(), "") +func (s *Server) AddOTPEmail(ctx context.Context, req *connect.Request[user.AddOTPEmailRequest]) (*connect.Response[user.AddOTPEmailResponse], error) { + details, err := s.command.AddHumanOTPEmail(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.AddOTPEmailResponse{Details: object.DomainToDetailsPb(details)}, nil + return connect.NewResponse(&user.AddOTPEmailResponse{Details: object.DomainToDetailsPb(details)}), nil } -func (s *Server) RemoveOTPEmail(ctx context.Context, req *user.RemoveOTPEmailRequest) (*user.RemoveOTPEmailResponse, error) { - objectDetails, err := s.command.RemoveHumanOTPEmail(ctx, req.GetUserId(), "") +func (s *Server) RemoveOTPEmail(ctx context.Context, req *connect.Request[user.RemoveOTPEmailRequest]) (*connect.Response[user.RemoveOTPEmailResponse], error) { + objectDetails, err := s.command.RemoveHumanOTPEmail(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.RemoveOTPEmailResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil + return connect.NewResponse(&user.RemoveOTPEmailResponse{Details: object.DomainToDetailsPb(objectDetails)}), nil } diff --git a/internal/api/grpc/user/v2/passkey.go b/internal/api/grpc/user/v2/passkey.go index 145c1e5716..90c6d72d13 100644 --- a/internal/api/grpc/user/v2/passkey.go +++ b/internal/api/grpc/user/v2/passkey.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/structpb" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" @@ -13,17 +14,17 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) RegisterPasskey(ctx context.Context, req *user.RegisterPasskeyRequest) (resp *user.RegisterPasskeyResponse, err error) { +func (s *Server) RegisterPasskey(ctx context.Context, req *connect.Request[user.RegisterPasskeyRequest]) (resp *connect.Response[user.RegisterPasskeyResponse], err error) { var ( - authenticator = passkeyAuthenticatorToDomain(req.GetAuthenticator()) + authenticator = passkeyAuthenticatorToDomain(req.Msg.GetAuthenticator()) ) - if code := req.GetCode(); code != nil { + if code := req.Msg.GetCode(); code != nil { return passkeyRegistrationDetailsToPb( - s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), "", authenticator, code.Id, code.Code, req.GetDomain(), s.userCodeAlg), + s.command.RegisterUserPasskeyWithCode(ctx, req.Msg.GetUserId(), "", authenticator, code.Id, code.Code, req.Msg.GetDomain(), s.userCodeAlg), ) } return passkeyRegistrationDetailsToPb( - s.command.RegisterUserPasskey(ctx, req.GetUserId(), "", req.GetDomain(), authenticator), + s.command.RegisterUserPasskey(ctx, req.Msg.GetUserId(), "", req.Msg.GetDomain(), authenticator), ) } @@ -51,86 +52,86 @@ func webAuthNRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails return object.DomainToDetailsPb(details.ObjectDetails), options, nil } -func passkeyRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*user.RegisterPasskeyResponse, error) { +func passkeyRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*connect.Response[user.RegisterPasskeyResponse], error) { objectDetails, options, err := webAuthNRegistrationDetailsToPb(details, err) if err != nil { return nil, err } - return &user.RegisterPasskeyResponse{ + return connect.NewResponse(&user.RegisterPasskeyResponse{ Details: objectDetails, PasskeyId: details.ID, PublicKeyCredentialCreationOptions: options, - }, nil + }), nil } -func (s *Server) VerifyPasskeyRegistration(ctx context.Context, req *user.VerifyPasskeyRegistrationRequest) (*user.VerifyPasskeyRegistrationResponse, error) { - pkc, err := req.GetPublicKeyCredential().MarshalJSON() +func (s *Server) VerifyPasskeyRegistration(ctx context.Context, req *connect.Request[user.VerifyPasskeyRegistrationRequest]) (*connect.Response[user.VerifyPasskeyRegistrationResponse], error) { + pkc, err := req.Msg.GetPublicKeyCredential().MarshalJSON() if err != nil { return nil, zerrors.ThrowInternal(err, "USERv2-Pha2o", "Errors.Internal") } - objectDetails, err := s.command.HumanHumanPasswordlessSetup(ctx, req.GetUserId(), "", req.GetPasskeyName(), "", pkc) + objectDetails, err := s.command.HumanHumanPasswordlessSetup(ctx, req.Msg.GetUserId(), "", req.Msg.GetPasskeyName(), "", pkc) if err != nil { return nil, err } - return &user.VerifyPasskeyRegistrationResponse{ + return connect.NewResponse(&user.VerifyPasskeyRegistrationResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } -func (s *Server) CreatePasskeyRegistrationLink(ctx context.Context, req *user.CreatePasskeyRegistrationLinkRequest) (resp *user.CreatePasskeyRegistrationLinkResponse, err error) { - switch medium := req.Medium.(type) { +func (s *Server) CreatePasskeyRegistrationLink(ctx context.Context, req *connect.Request[user.CreatePasskeyRegistrationLinkRequest]) (resp *connect.Response[user.CreatePasskeyRegistrationLinkResponse], err error) { + switch medium := req.Msg.Medium.(type) { case nil: return passkeyDetailsToPb( - s.command.AddUserPasskeyCode(ctx, req.GetUserId(), "", s.userCodeAlg), + s.command.AddUserPasskeyCode(ctx, req.Msg.GetUserId(), "", s.userCodeAlg), ) case *user.CreatePasskeyRegistrationLinkRequest_SendLink: return passkeyDetailsToPb( - s.command.AddUserPasskeyCodeURLTemplate(ctx, req.GetUserId(), "", s.userCodeAlg, medium.SendLink.GetUrlTemplate()), + s.command.AddUserPasskeyCodeURLTemplate(ctx, req.Msg.GetUserId(), "", s.userCodeAlg, medium.SendLink.GetUrlTemplate()), ) case *user.CreatePasskeyRegistrationLinkRequest_ReturnCode: return passkeyCodeDetailsToPb( - s.command.AddUserPasskeyCodeReturn(ctx, req.GetUserId(), "", s.userCodeAlg), + s.command.AddUserPasskeyCodeReturn(ctx, req.Msg.GetUserId(), "", s.userCodeAlg), ) default: return nil, zerrors.ThrowUnimplementedf(nil, "USERv2-gaD8y", "verification oneOf %T in method CreatePasskeyRegistrationLink not implemented", medium) } } -func passkeyDetailsToPb(details *domain.ObjectDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) { +func passkeyDetailsToPb(details *domain.ObjectDetails, err error) (*connect.Response[user.CreatePasskeyRegistrationLinkResponse], error) { if err != nil { return nil, err } - return &user.CreatePasskeyRegistrationLinkResponse{ + return connect.NewResponse(&user.CreatePasskeyRegistrationLinkResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func passkeyCodeDetailsToPb(details *domain.PasskeyCodeDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) { +func passkeyCodeDetailsToPb(details *domain.PasskeyCodeDetails, err error) (*connect.Response[user.CreatePasskeyRegistrationLinkResponse], error) { if err != nil { return nil, err } - return &user.CreatePasskeyRegistrationLinkResponse{ + return connect.NewResponse(&user.CreatePasskeyRegistrationLinkResponse{ Details: object.DomainToDetailsPb(details.ObjectDetails), Code: &user.PasskeyRegistrationCode{ Id: details.CodeID, Code: details.Code, }, - }, nil + }), nil } -func (s *Server) RemovePasskey(ctx context.Context, req *user.RemovePasskeyRequest) (*user.RemovePasskeyResponse, error) { - objectDetails, err := s.command.HumanRemovePasswordless(ctx, req.GetUserId(), req.GetPasskeyId(), "") +func (s *Server) RemovePasskey(ctx context.Context, req *connect.Request[user.RemovePasskeyRequest]) (*connect.Response[user.RemovePasskeyResponse], error) { + objectDetails, err := s.command.HumanRemovePasswordless(ctx, req.Msg.GetUserId(), req.Msg.GetPasskeyId(), "") if err != nil { return nil, err } - return &user.RemovePasskeyResponse{ + return connect.NewResponse(&user.RemovePasskeyResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } -func (s *Server) ListPasskeys(ctx context.Context, req *user.ListPasskeysRequest) (*user.ListPasskeysResponse, error) { +func (s *Server) ListPasskeys(ctx context.Context, req *connect.Request[user.ListPasskeysRequest]) (*connect.Response[user.ListPasskeysResponse], error) { query := new(query.UserAuthMethodSearchQueries) - err := query.AppendUserIDQuery(req.UserId) + err := query.AppendUserIDQuery(req.Msg.UserId) if err != nil { return nil, err } @@ -146,10 +147,10 @@ func (s *Server) ListPasskeys(ctx context.Context, req *user.ListPasskeysRequest if err != nil { return nil, err } - return &user.ListPasskeysResponse{ + return connect.NewResponse(&user.ListPasskeysResponse{ Details: object.ToListDetails(authMethods.SearchResponse), Result: authMethodsToPasskeyPb(authMethods), - }, nil + }), nil } func authMethodsToPasskeyPb(methods *query.AuthMethods) []*user.Passkey { diff --git a/internal/api/grpc/user/v2/passkey_test.go b/internal/api/grpc/user/v2/passkey_test.go index 9263012b98..6429dd7ce6 100644 --- a/internal/api/grpc/user/v2/passkey_test.go +++ b/internal/api/grpc/user/v2/passkey_test.go @@ -123,11 +123,11 @@ func Test_passkeyRegistrationDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := passkeyRegistrationDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.wantErr) - if !proto.Equal(tt.want, got) { + if tt.want != nil && !proto.Equal(tt.want, got.Msg) { t.Errorf("Not equal:\nExpected\n%s\nActual:%s", tt.want, got) } if tt.want != nil { - grpc.AllFieldsSet(t, got.ProtoReflect()) + grpc.AllFieldsSet(t, got.Msg.ProtoReflect()) } }) } @@ -181,7 +181,9 @@ func Test_passkeyDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := passkeyDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.args.err) - assert.Equal(t, tt.want, got) + if tt.want != nil { + assert.Equal(t, tt.want, got.Msg) + } }) } } @@ -242,9 +244,9 @@ func Test_passkeyCodeDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := passkeyCodeDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.args.err) - assert.Equal(t, tt.want, got) if tt.want != nil { - grpc.AllFieldsSet(t, got.ProtoReflect()) + assert.Equal(t, tt.want, got.Msg) + grpc.AllFieldsSet(t, got.Msg.ProtoReflect()) } }) } diff --git a/internal/api/grpc/user/v2/password.go b/internal/api/grpc/user/v2/password.go index 55cf225c4b..a256a00355 100644 --- a/internal/api/grpc/user/v2/password.go +++ b/internal/api/grpc/user/v2/password.go @@ -3,23 +3,25 @@ package user import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) PasswordReset(ctx context.Context, req *user.PasswordResetRequest) (_ *user.PasswordResetResponse, err error) { +func (s *Server) PasswordReset(ctx context.Context, req *connect.Request[user.PasswordResetRequest]) (_ *connect.Response[user.PasswordResetResponse], err error) { var details *domain.ObjectDetails var code *string - switch m := req.GetMedium().(type) { + switch m := req.Msg.GetMedium().(type) { case *user.PasswordResetRequest_SendLink: - details, code, err = s.command.RequestPasswordResetURLTemplate(ctx, req.GetUserId(), m.SendLink.GetUrlTemplate(), notificationTypeToDomain(m.SendLink.GetNotificationType())) + details, code, err = s.command.RequestPasswordResetURLTemplate(ctx, req.Msg.GetUserId(), m.SendLink.GetUrlTemplate(), notificationTypeToDomain(m.SendLink.GetNotificationType())) case *user.PasswordResetRequest_ReturnCode: - details, code, err = s.command.RequestPasswordResetReturnCode(ctx, req.GetUserId()) + details, code, err = s.command.RequestPasswordResetReturnCode(ctx, req.Msg.GetUserId()) case nil: - details, code, err = s.command.RequestPasswordReset(ctx, req.GetUserId()) + details, code, err = s.command.RequestPasswordReset(ctx, req.Msg.GetUserId()) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-SDeeg", "verification oneOf %T in method RequestPasswordReset not implemented", m) } @@ -27,10 +29,10 @@ func (s *Server) PasswordReset(ctx context.Context, req *user.PasswordResetReque return nil, err } - return &user.PasswordResetResponse{ + return connect.NewResponse(&user.PasswordResetResponse{ Details: object.DomainToDetailsPb(details), VerificationCode: code, - }, nil + }), nil } func notificationTypeToDomain(notificationType user.NotificationType) domain.NotificationType { @@ -46,16 +48,16 @@ func notificationTypeToDomain(notificationType user.NotificationType) domain.Not } } -func (s *Server) SetPassword(ctx context.Context, req *user.SetPasswordRequest) (_ *user.SetPasswordResponse, err error) { +func (s *Server) SetPassword(ctx context.Context, req *connect.Request[user.SetPasswordRequest]) (_ *connect.Response[user.SetPasswordResponse], err error) { var details *domain.ObjectDetails - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SetPasswordRequest_CurrentPassword: - details, err = s.command.ChangePassword(ctx, "", req.GetUserId(), v.CurrentPassword, req.GetNewPassword().GetPassword(), "", req.GetNewPassword().GetChangeRequired()) + details, err = s.command.ChangePassword(ctx, "", req.Msg.GetUserId(), v.CurrentPassword, req.Msg.GetNewPassword().GetPassword(), "", req.Msg.GetNewPassword().GetChangeRequired()) case *user.SetPasswordRequest_VerificationCode: - details, err = s.command.SetPasswordWithVerifyCode(ctx, "", req.GetUserId(), v.VerificationCode, req.GetNewPassword().GetPassword(), "", req.GetNewPassword().GetChangeRequired()) + details, err = s.command.SetPasswordWithVerifyCode(ctx, "", req.Msg.GetUserId(), v.VerificationCode, req.Msg.GetNewPassword().GetPassword(), "", req.Msg.GetNewPassword().GetChangeRequired()) case nil: - details, err = s.command.SetPassword(ctx, "", req.GetUserId(), req.GetNewPassword().GetPassword(), req.GetNewPassword().GetChangeRequired()) + details, err = s.command.SetPassword(ctx, "", req.Msg.GetUserId(), req.Msg.GetNewPassword().GetPassword(), req.Msg.GetNewPassword().GetChangeRequired()) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-SFdf2", "verification oneOf %T in method SetPasswordRequest not implemented", v) } @@ -63,7 +65,7 @@ func (s *Server) SetPassword(ctx context.Context, req *user.SetPasswordRequest) return nil, err } - return &user.SetPasswordResponse{ + return connect.NewResponse(&user.SetPasswordResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/pat.go b/internal/api/grpc/user/v2/pat.go index 54f6e99367..0c90eeaebd 100644 --- a/internal/api/grpc/user/v2/pat.go +++ b/internal/api/grpc/user/v2/pat.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "github.com/zitadel/oidc/v3/pkg/oidc" "google.golang.org/protobuf/types/known/timestamppb" @@ -13,13 +14,13 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) AddPersonalAccessToken(ctx context.Context, req *user.AddPersonalAccessTokenRequest) (*user.AddPersonalAccessTokenResponse, error) { +func (s *Server) AddPersonalAccessToken(ctx context.Context, req *connect.Request[user.AddPersonalAccessTokenRequest]) (*connect.Response[user.AddPersonalAccessTokenResponse], error) { newPat := &command.PersonalAccessToken{ ObjectRoot: models.ObjectRoot{ - AggregateID: req.UserId, + AggregateID: req.Msg.GetUserId(), }, PermissionCheck: s.command.NewPermissionCheckUserWrite(ctx), - ExpirationDate: req.ExpirationDate.AsTime(), + ExpirationDate: req.Msg.GetExpirationDate().AsTime(), Scopes: []string{ oidc.ScopeOpenID, oidc.ScopeProfile, @@ -32,25 +33,25 @@ func (s *Server) AddPersonalAccessToken(ctx context.Context, req *user.AddPerson if err != nil { return nil, err } - return &user.AddPersonalAccessTokenResponse{ + return connect.NewResponse(&user.AddPersonalAccessTokenResponse{ CreationDate: timestamppb.New(details.EventDate), TokenId: newPat.TokenID, Token: newPat.Token, - }, nil + }), nil } -func (s *Server) RemovePersonalAccessToken(ctx context.Context, req *user.RemovePersonalAccessTokenRequest) (*user.RemovePersonalAccessTokenResponse, error) { +func (s *Server) RemovePersonalAccessToken(ctx context.Context, req *connect.Request[user.RemovePersonalAccessTokenRequest]) (*connect.Response[user.RemovePersonalAccessTokenResponse], error) { objectDetails, err := s.command.RemovePersonalAccessToken(ctx, &command.PersonalAccessToken{ - TokenID: req.TokenId, + TokenID: req.Msg.GetTokenId(), ObjectRoot: models.ObjectRoot{ - AggregateID: req.UserId, + AggregateID: req.Msg.GetUserId(), }, PermissionCheck: s.command.NewPermissionCheckUserWrite(ctx), }) if err != nil { return nil, err } - return &user.RemovePersonalAccessTokenResponse{ + return connect.NewResponse(&user.RemovePersonalAccessTokenResponse{ DeletionDate: timestamppb.New(objectDetails.EventDate), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/pat_query.go b/internal/api/grpc/user/v2/pat_query.go index 6bbd44d511..64231c1d93 100644 --- a/internal/api/grpc/user/v2/pat_query.go +++ b/internal/api/grpc/user/v2/pat_query.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/grpc/filter/v2" @@ -12,12 +13,12 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) ListPersonalAccessTokens(ctx context.Context, req *user.ListPersonalAccessTokensRequest) (*user.ListPersonalAccessTokensResponse, error) { - offset, limit, asc, err := filter.PaginationPbToQuery(s.systemDefaults, req.Pagination) +func (s *Server) ListPersonalAccessTokens(ctx context.Context, req *connect.Request[user.ListPersonalAccessTokensRequest]) (*connect.Response[user.ListPersonalAccessTokensResponse], error) { + offset, limit, asc, err := filter.PaginationPbToQuery(s.systemDefaults, req.Msg.GetPagination()) if err != nil { return nil, err } - filters, err := patFiltersToQueries(req.Filters) + filters, err := patFiltersToQueries(req.Msg.GetFilters()) if err != nil { return nil, err } @@ -26,7 +27,7 @@ func (s *Server) ListPersonalAccessTokens(ctx context.Context, req *user.ListPer Offset: offset, Limit: limit, Asc: asc, - SortingColumn: authnPersonalAccessTokenFieldNameToSortingColumn(req.SortingColumn), + SortingColumn: authnPersonalAccessTokenFieldNameToSortingColumn(req.Msg.SortingColumn), }, Queries: filters, } @@ -48,7 +49,7 @@ func (s *Server) ListPersonalAccessTokens(ctx context.Context, req *user.ListPer ExpirationDate: timestamppb.New(pat.Expiration), } } - return resp, nil + return connect.NewResponse(resp), nil } func patFiltersToQueries(filters []*user.PersonalAccessTokensSearchFilter) (_ []query.SearchQuery, err error) { diff --git a/internal/api/grpc/user/v2/phone.go b/internal/api/grpc/user/v2/phone.go index fdd5a140c1..4be616f7ea 100644 --- a/internal/api/grpc/user/v2/phone.go +++ b/internal/api/grpc/user/v2/phone.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/domain" @@ -11,18 +12,18 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) SetPhone(ctx context.Context, req *user.SetPhoneRequest) (resp *user.SetPhoneResponse, err error) { +func (s *Server) SetPhone(ctx context.Context, req *connect.Request[user.SetPhoneRequest]) (resp *connect.Response[user.SetPhoneResponse], err error) { var phone *domain.Phone - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SetPhoneRequest_SendCode: - phone, err = s.command.ChangeUserPhone(ctx, req.GetUserId(), req.GetPhone(), s.userCodeAlg) + phone, err = s.command.ChangeUserPhone(ctx, req.Msg.GetUserId(), req.Msg.GetPhone(), s.userCodeAlg) case *user.SetPhoneRequest_ReturnCode: - phone, err = s.command.ChangeUserPhoneReturnCode(ctx, req.GetUserId(), req.GetPhone(), s.userCodeAlg) + phone, err = s.command.ChangeUserPhoneReturnCode(ctx, req.Msg.GetUserId(), req.Msg.GetPhone(), s.userCodeAlg) case *user.SetPhoneRequest_IsVerified: - phone, err = s.command.ChangeUserPhoneVerified(ctx, req.GetUserId(), req.GetPhone()) + phone, err = s.command.ChangeUserPhoneVerified(ctx, req.Msg.GetUserId(), req.Msg.GetPhone()) case nil: - phone, err = s.command.ChangeUserPhone(ctx, req.GetUserId(), req.GetPhone(), s.userCodeAlg) + phone, err = s.command.ChangeUserPhone(ctx, req.Msg.GetUserId(), req.Msg.GetPhone(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-Ahng0", "verification oneOf %T in method SetPhone not implemented", v) } @@ -30,42 +31,42 @@ func (s *Server) SetPhone(ctx context.Context, req *user.SetPhoneRequest) (resp return nil, err } - return &user.SetPhoneResponse{ + return connect.NewResponse(&user.SetPhoneResponse{ Details: &object.Details{ Sequence: phone.Sequence, ChangeDate: timestamppb.New(phone.ChangeDate), ResourceOwner: phone.ResourceOwner, }, VerificationCode: phone.PlainCode, - }, nil + }), nil } -func (s *Server) RemovePhone(ctx context.Context, req *user.RemovePhoneRequest) (resp *user.RemovePhoneResponse, err error) { +func (s *Server) RemovePhone(ctx context.Context, req *connect.Request[user.RemovePhoneRequest]) (resp *connect.Response[user.RemovePhoneResponse], err error) { details, err := s.command.RemoveUserPhone(ctx, - req.GetUserId(), + req.Msg.GetUserId(), ) if err != nil { return nil, err } - return &user.RemovePhoneResponse{ + return connect.NewResponse(&user.RemovePhoneResponse{ Details: &object.Details{ Sequence: details.Sequence, ChangeDate: timestamppb.New(details.EventDate), ResourceOwner: details.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) ResendPhoneCode(ctx context.Context, req *user.ResendPhoneCodeRequest) (resp *user.ResendPhoneCodeResponse, err error) { +func (s *Server) ResendPhoneCode(ctx context.Context, req *connect.Request[user.ResendPhoneCodeRequest]) (resp *connect.Response[user.ResendPhoneCodeResponse], err error) { var phone *domain.Phone - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.ResendPhoneCodeRequest_SendCode: - phone, err = s.command.ResendUserPhoneCode(ctx, req.GetUserId(), s.userCodeAlg) + phone, err = s.command.ResendUserPhoneCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case *user.ResendPhoneCodeRequest_ReturnCode: - phone, err = s.command.ResendUserPhoneCodeReturnCode(ctx, req.GetUserId(), s.userCodeAlg) + phone, err = s.command.ResendUserPhoneCodeReturnCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case nil: - phone, err = s.command.ResendUserPhoneCode(ctx, req.GetUserId(), s.userCodeAlg) + phone, err = s.command.ResendUserPhoneCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-ResendUserPhoneCode", "verification oneOf %T in method SetPhone not implemented", v) } @@ -73,30 +74,30 @@ func (s *Server) ResendPhoneCode(ctx context.Context, req *user.ResendPhoneCodeR return nil, err } - return &user.ResendPhoneCodeResponse{ + return connect.NewResponse(&user.ResendPhoneCodeResponse{ Details: &object.Details{ Sequence: phone.Sequence, ChangeDate: timestamppb.New(phone.ChangeDate), ResourceOwner: phone.ResourceOwner, }, VerificationCode: phone.PlainCode, - }, nil + }), nil } -func (s *Server) VerifyPhone(ctx context.Context, req *user.VerifyPhoneRequest) (*user.VerifyPhoneResponse, error) { +func (s *Server) VerifyPhone(ctx context.Context, req *connect.Request[user.VerifyPhoneRequest]) (*connect.Response[user.VerifyPhoneResponse], error) { details, err := s.command.VerifyUserPhone(ctx, - req.GetUserId(), - req.GetVerificationCode(), + req.Msg.GetUserId(), + req.Msg.GetVerificationCode(), s.userCodeAlg, ) if err != nil { return nil, err } - return &user.VerifyPhoneResponse{ + return connect.NewResponse(&user.VerifyPhoneResponse{ Details: &object.Details{ Sequence: details.Sequence, ChangeDate: timestamppb.New(details.EventDate), ResourceOwner: details.ResourceOwner, }, - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/secret.go b/internal/api/grpc/user/v2/secret.go index 1d54e1dde8..acc7aef8cb 100644 --- a/internal/api/grpc/user/v2/secret.go +++ b/internal/api/grpc/user/v2/secret.go @@ -3,37 +3,38 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) AddSecret(ctx context.Context, req *user.AddSecretRequest) (*user.AddSecretResponse, error) { +func (s *Server) AddSecret(ctx context.Context, req *connect.Request[user.AddSecretRequest]) (*connect.Response[user.AddSecretResponse], error) { newSecret := &command.GenerateMachineSecret{ PermissionCheck: s.command.NewPermissionCheckUserWrite(ctx), } - details, err := s.command.GenerateMachineSecret(ctx, req.UserId, "", newSecret) + details, err := s.command.GenerateMachineSecret(ctx, req.Msg.GetUserId(), "", newSecret) if err != nil { return nil, err } - return &user.AddSecretResponse{ + return connect.NewResponse(&user.AddSecretResponse{ CreationDate: timestamppb.New(details.EventDate), ClientSecret: newSecret.ClientSecret, - }, nil + }), nil } -func (s *Server) RemoveSecret(ctx context.Context, req *user.RemoveSecretRequest) (*user.RemoveSecretResponse, error) { +func (s *Server) RemoveSecret(ctx context.Context, req *connect.Request[user.RemoveSecretRequest]) (*connect.Response[user.RemoveSecretResponse], error) { details, err := s.command.RemoveMachineSecret( ctx, - req.UserId, + req.Msg.GetUserId(), "", s.command.NewPermissionCheckUserWrite(ctx), ) if err != nil { return nil, err } - return &user.RemoveSecretResponse{ + return connect.NewResponse(&user.RemoveSecretResponse{ DeletionDate: timestamppb.New(details.EventDate), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/server.go b/internal/api/grpc/user/v2/server.go index e3c7e8011e..1f94906853 100644 --- a/internal/api/grpc/user/v2/server.go +++ b/internal/api/grpc/user/v2/server.go @@ -2,8 +2,10 @@ package user import ( "context" + "net/http" - "google.golang.org/grpc" + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -13,12 +15,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/user/v2" + "github.com/zitadel/zitadel/pkg/grpc/user/v2/userconnect" ) -var _ user.UserServiceServer = (*Server)(nil) +var _ userconnect.UserServiceHandler = (*Server)(nil) type Server struct { - user.UnimplementedUserServiceServer systemDefaults systemdefaults.SystemDefaults command *command.Commands query *query.Queries @@ -58,8 +60,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - user.RegisterUserServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return userconnect.NewUserServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return user.File_zitadel_user_v2_user_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/user/v2/totp.go b/internal/api/grpc/user/v2/totp.go index 9e2d028d72..51b615dac5 100644 --- a/internal/api/grpc/user/v2/totp.go +++ b/internal/api/grpc/user/v2/totp.go @@ -3,42 +3,44 @@ package user import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) RegisterTOTP(ctx context.Context, req *user.RegisterTOTPRequest) (*user.RegisterTOTPResponse, error) { +func (s *Server) RegisterTOTP(ctx context.Context, req *connect.Request[user.RegisterTOTPRequest]) (*connect.Response[user.RegisterTOTPResponse], error) { return totpDetailsToPb( - s.command.AddUserTOTP(ctx, req.GetUserId(), ""), + s.command.AddUserTOTP(ctx, req.Msg.GetUserId(), ""), ) } -func totpDetailsToPb(totp *domain.TOTP, err error) (*user.RegisterTOTPResponse, error) { +func totpDetailsToPb(totp *domain.TOTP, err error) (*connect.Response[user.RegisterTOTPResponse], error) { if err != nil { return nil, err } - return &user.RegisterTOTPResponse{ + return connect.NewResponse(&user.RegisterTOTPResponse{ Details: object.DomainToDetailsPb(totp.ObjectDetails), Uri: totp.URI, Secret: totp.Secret, - }, nil + }), nil } -func (s *Server) VerifyTOTPRegistration(ctx context.Context, req *user.VerifyTOTPRegistrationRequest) (*user.VerifyTOTPRegistrationResponse, error) { - objectDetails, err := s.command.CheckUserTOTP(ctx, req.GetUserId(), req.GetCode(), "") +func (s *Server) VerifyTOTPRegistration(ctx context.Context, req *connect.Request[user.VerifyTOTPRegistrationRequest]) (*connect.Response[user.VerifyTOTPRegistrationResponse], error) { + objectDetails, err := s.command.CheckUserTOTP(ctx, req.Msg.GetUserId(), req.Msg.GetCode(), "") if err != nil { return nil, err } - return &user.VerifyTOTPRegistrationResponse{ + return connect.NewResponse(&user.VerifyTOTPRegistrationResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } -func (s *Server) RemoveTOTP(ctx context.Context, req *user.RemoveTOTPRequest) (*user.RemoveTOTPResponse, error) { - objectDetails, err := s.command.HumanRemoveTOTP(ctx, req.GetUserId(), "") +func (s *Server) RemoveTOTP(ctx context.Context, req *connect.Request[user.RemoveTOTPRequest]) (*connect.Response[user.RemoveTOTPResponse], error) { + objectDetails, err := s.command.HumanRemoveTOTP(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.RemoveTOTPResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil + return connect.NewResponse(&user.RemoveTOTPResponse{Details: object.DomainToDetailsPb(objectDetails)}), nil } diff --git a/internal/api/grpc/user/v2/totp_test.go b/internal/api/grpc/user/v2/totp_test.go index 27ce6fb469..259f5ab5c6 100644 --- a/internal/api/grpc/user/v2/totp_test.go +++ b/internal/api/grpc/user/v2/totp_test.go @@ -63,7 +63,7 @@ func Test_totpDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := totpDetailsToPb(tt.args.otp, tt.args.err) require.ErrorIs(t, err, tt.wantErr) - if !proto.Equal(tt.want, got) { + if tt.want != nil && !proto.Equal(tt.want, got.Msg) { t.Errorf("RegisterTOTPResponse =\n%v\nwant\n%v", got, tt.want) } }) diff --git a/internal/api/grpc/user/v2/u2f.go b/internal/api/grpc/user/v2/u2f.go index 60c0f5ab07..bd12ea0dac 100644 --- a/internal/api/grpc/user/v2/u2f.go +++ b/internal/api/grpc/user/v2/u2f.go @@ -3,50 +3,52 @@ package user import ( "context" + "connectrpc.com/connect" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) RegisterU2F(ctx context.Context, req *user.RegisterU2FRequest) (*user.RegisterU2FResponse, error) { +func (s *Server) RegisterU2F(ctx context.Context, req *connect.Request[user.RegisterU2FRequest]) (*connect.Response[user.RegisterU2FResponse], error) { return u2fRegistrationDetailsToPb( - s.command.RegisterUserU2F(ctx, req.GetUserId(), "", req.GetDomain()), + s.command.RegisterUserU2F(ctx, req.Msg.GetUserId(), "", req.Msg.GetDomain()), ) } -func u2fRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*user.RegisterU2FResponse, error) { +func u2fRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*connect.Response[user.RegisterU2FResponse], error) { objectDetails, options, err := webAuthNRegistrationDetailsToPb(details, err) if err != nil { return nil, err } - return &user.RegisterU2FResponse{ + return connect.NewResponse(&user.RegisterU2FResponse{ Details: objectDetails, U2FId: details.ID, PublicKeyCredentialCreationOptions: options, - }, nil + }), nil } -func (s *Server) VerifyU2FRegistration(ctx context.Context, req *user.VerifyU2FRegistrationRequest) (*user.VerifyU2FRegistrationResponse, error) { - pkc, err := req.GetPublicKeyCredential().MarshalJSON() +func (s *Server) VerifyU2FRegistration(ctx context.Context, req *connect.Request[user.VerifyU2FRegistrationRequest]) (*connect.Response[user.VerifyU2FRegistrationResponse], error) { + pkc, err := req.Msg.GetPublicKeyCredential().MarshalJSON() if err != nil { return nil, zerrors.ThrowInternal(err, "USERv2-IeTh4", "Errors.Internal") } - objectDetails, err := s.command.HumanVerifyU2FSetup(ctx, req.GetUserId(), "", req.GetTokenName(), "", pkc) + objectDetails, err := s.command.HumanVerifyU2FSetup(ctx, req.Msg.GetUserId(), "", req.Msg.GetTokenName(), "", pkc) if err != nil { return nil, err } - return &user.VerifyU2FRegistrationResponse{ + return connect.NewResponse(&user.VerifyU2FRegistrationResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } -func (s *Server) RemoveU2F(ctx context.Context, req *user.RemoveU2FRequest) (*user.RemoveU2FResponse, error) { - objectDetails, err := s.command.HumanRemoveU2F(ctx, req.GetUserId(), req.GetU2FId(), "") +func (s *Server) RemoveU2F(ctx context.Context, req *connect.Request[user.RemoveU2FRequest]) (*connect.Response[user.RemoveU2FResponse], error) { + objectDetails, err := s.command.HumanRemoveU2F(ctx, req.Msg.GetUserId(), req.Msg.GetU2FId(), "") if err != nil { return nil, err } - return &user.RemoveU2FResponse{ + return connect.NewResponse(&user.RemoveU2FResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2/u2f_test.go b/internal/api/grpc/user/v2/u2f_test.go index fae3ba1cdb..f6798a6f89 100644 --- a/internal/api/grpc/user/v2/u2f_test.go +++ b/internal/api/grpc/user/v2/u2f_test.go @@ -92,11 +92,11 @@ func Test_u2fRegistrationDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := u2fRegistrationDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.wantErr) - if !proto.Equal(tt.want, got) { + if tt.want != nil && !proto.Equal(tt.want, got.Msg) { t.Errorf("Not equal:\nExpected\n%s\nActual:%s", tt.want, got) } if tt.want != nil { - grpc.AllFieldsSet(t, got.ProtoReflect()) + grpc.AllFieldsSet(t, got.Msg.ProtoReflect()) } }) } diff --git a/internal/api/grpc/user/v2/user.go b/internal/api/grpc/user/v2/user.go index 6b4b2da75b..95c2883195 100644 --- a/internal/api/grpc/user/v2/user.go +++ b/internal/api/grpc/user/v2/user.go @@ -4,6 +4,7 @@ import ( "context" "io" + "connectrpc.com/connect" "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" @@ -15,8 +16,8 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest) (_ *user.AddHumanUserResponse, err error) { - human, err := AddUserRequestToAddHuman(req) +func (s *Server) AddHumanUser(ctx context.Context, req *connect.Request[user.AddHumanUserRequest]) (_ *connect.Response[user.AddHumanUserResponse], err error) { + human, err := AddUserRequestToAddHuman(req.Msg) if err != nil { return nil, err } @@ -24,12 +25,12 @@ func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest if err = s.command.AddUserHuman(ctx, orgID, human, false, s.userCodeAlg); err != nil { return nil, err } - return &user.AddHumanUserResponse{ + return connect.NewResponse(&user.AddHumanUserResponse{ UserId: human.ID, Details: object.DomainToDetailsPb(human.Details), EmailCode: human.EmailCode, PhoneCode: human.PhoneCode, - }, nil + }), nil } func AddUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman, error) { @@ -117,8 +118,8 @@ func genderToDomain(gender user.Gender) domain.Gender { } } -func (s *Server) UpdateHumanUser(ctx context.Context, req *user.UpdateHumanUserRequest) (_ *user.UpdateHumanUserResponse, err error) { - human, err := updateHumanUserRequestToChangeHuman(req) +func (s *Server) UpdateHumanUser(ctx context.Context, req *connect.Request[user.UpdateHumanUserRequest]) (_ *connect.Response[user.UpdateHumanUserResponse], err error) { + human, err := updateHumanUserRequestToChangeHuman(req.Msg) if err != nil { return nil, err } @@ -126,51 +127,51 @@ func (s *Server) UpdateHumanUser(ctx context.Context, req *user.UpdateHumanUserR if err != nil { return nil, err } - return &user.UpdateHumanUserResponse{ + return connect.NewResponse(&user.UpdateHumanUserResponse{ Details: object.DomainToDetailsPb(human.Details), EmailCode: human.EmailCode, PhoneCode: human.PhoneCode, - }, nil + }), nil } -func (s *Server) LockUser(ctx context.Context, req *user.LockUserRequest) (_ *user.LockUserResponse, err error) { - details, err := s.command.LockUserV2(ctx, req.UserId) +func (s *Server) LockUser(ctx context.Context, req *connect.Request[user.LockUserRequest]) (_ *connect.Response[user.LockUserResponse], err error) { + details, err := s.command.LockUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.LockUserResponse{ + return connect.NewResponse(&user.LockUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) UnlockUser(ctx context.Context, req *user.UnlockUserRequest) (_ *user.UnlockUserResponse, err error) { - details, err := s.command.UnlockUserV2(ctx, req.UserId) +func (s *Server) UnlockUser(ctx context.Context, req *connect.Request[user.UnlockUserRequest]) (_ *connect.Response[user.UnlockUserResponse], err error) { + details, err := s.command.UnlockUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.UnlockUserResponse{ + return connect.NewResponse(&user.UnlockUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) DeactivateUser(ctx context.Context, req *user.DeactivateUserRequest) (_ *user.DeactivateUserResponse, err error) { - details, err := s.command.DeactivateUserV2(ctx, req.UserId) +func (s *Server) DeactivateUser(ctx context.Context, req *connect.Request[user.DeactivateUserRequest]) (_ *connect.Response[user.DeactivateUserResponse], err error) { + details, err := s.command.DeactivateUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.DeactivateUserResponse{ + return connect.NewResponse(&user.DeactivateUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ReactivateUser(ctx context.Context, req *user.ReactivateUserRequest) (_ *user.ReactivateUserResponse, err error) { - details, err := s.command.ReactivateUserV2(ctx, req.UserId) +func (s *Server) ReactivateUser(ctx context.Context, req *connect.Request[user.ReactivateUserRequest]) (_ *connect.Response[user.ReactivateUserResponse], err error) { + details, err := s.command.ReactivateUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.ReactivateUserResponse{ + return connect.NewResponse(&user.ReactivateUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func ifNotNilPtr[v, p any](value *v, conv func(v) p) *p { @@ -182,18 +183,18 @@ func ifNotNilPtr[v, p any](value *v, conv func(v) p) *p { return &pVal } -func (s *Server) DeleteUser(ctx context.Context, req *user.DeleteUserRequest) (_ *user.DeleteUserResponse, err error) { - memberships, grants, err := s.removeUserDependencies(ctx, req.GetUserId()) +func (s *Server) DeleteUser(ctx context.Context, req *connect.Request[user.DeleteUserRequest]) (_ *connect.Response[user.DeleteUserResponse], err error) { + memberships, grants, err := s.removeUserDependencies(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - details, err := s.command.RemoveUserV2(ctx, req.UserId, "", memberships, grants...) + details, err := s.command.RemoveUserV2(ctx, req.Msg.GetUserId(), "", memberships, grants...) if err != nil { return nil, err } - return &user.DeleteUserResponse{ + return connect.NewResponse(&user.DeleteUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func (s *Server) removeUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) { @@ -268,35 +269,35 @@ func userGrantsToIDs(userGrants []*query.UserGrant) []string { return converted } -func (s *Server) ListAuthenticationMethodTypes(ctx context.Context, req *user.ListAuthenticationMethodTypesRequest) (*user.ListAuthenticationMethodTypesResponse, error) { - authMethods, err := s.query.ListUserAuthMethodTypes(ctx, req.GetUserId(), true, req.GetDomainQuery().GetIncludeWithoutDomain(), req.GetDomainQuery().GetDomain()) +func (s *Server) ListAuthenticationMethodTypes(ctx context.Context, req *connect.Request[user.ListAuthenticationMethodTypesRequest]) (*connect.Response[user.ListAuthenticationMethodTypesResponse], error) { + authMethods, err := s.query.ListUserAuthMethodTypes(ctx, req.Msg.GetUserId(), true, req.Msg.GetDomainQuery().GetIncludeWithoutDomain(), req.Msg.GetDomainQuery().GetDomain()) if err != nil { return nil, err } - return &user.ListAuthenticationMethodTypesResponse{ + return connect.NewResponse(&user.ListAuthenticationMethodTypesResponse{ Details: object.ToListDetails(authMethods.SearchResponse), AuthMethodTypes: authMethodTypesToPb(authMethods.AuthMethodTypes), - }, nil + }), nil } -func (s *Server) ListAuthenticationFactors(ctx context.Context, req *user.ListAuthenticationFactorsRequest) (*user.ListAuthenticationFactorsResponse, error) { +func (s *Server) ListAuthenticationFactors(ctx context.Context, req *connect.Request[user.ListAuthenticationFactorsRequest]) (*connect.Response[user.ListAuthenticationFactorsResponse], error) { query := new(query.UserAuthMethodSearchQueries) - if err := query.AppendUserIDQuery(req.UserId); err != nil { + if err := query.AppendUserIDQuery(req.Msg.GetUserId()); err != nil { return nil, err } authMethodsType := []domain.UserAuthMethodType{domain.UserAuthMethodTypeU2F, domain.UserAuthMethodTypeTOTP, domain.UserAuthMethodTypeOTPSMS, domain.UserAuthMethodTypeOTPEmail} - if len(req.GetAuthFactors()) > 0 { - authMethodsType = object.AuthFactorsToPb(req.GetAuthFactors()) + if len(req.Msg.GetAuthFactors()) > 0 { + authMethodsType = object.AuthFactorsToPb(req.Msg.GetAuthFactors()) } if err := query.AppendAuthMethodsQuery(authMethodsType...); err != nil { return nil, err } states := []domain.MFAState{domain.MFAStateReady} - if len(req.GetStates()) > 0 { - states = object.AuthFactorStatesToPb(req.GetStates()) + if len(req.Msg.GetStates()) > 0 { + states = object.AuthFactorStatesToPb(req.Msg.GetStates()) } if err := query.AppendStatesQuery(states...); err != nil { return nil, err @@ -307,9 +308,9 @@ func (s *Server) ListAuthenticationFactors(ctx context.Context, req *user.ListAu return nil, err } - return &user.ListAuthenticationFactorsResponse{ + return connect.NewResponse(&user.ListAuthenticationFactorsResponse{ Result: object.AuthMethodsToPb(authMethods), - }, nil + }), nil } func authMethodTypesToPb(methodTypes []domain.UserAuthMethodType) []user.AuthenticationMethodType { @@ -343,8 +344,8 @@ func authMethodTypeToPb(methodType domain.UserAuthMethodType) user.Authenticatio } } -func (s *Server) CreateInviteCode(ctx context.Context, req *user.CreateInviteCodeRequest) (*user.CreateInviteCodeResponse, error) { - invite, err := createInviteCodeRequestToCommand(req) +func (s *Server) CreateInviteCode(ctx context.Context, req *connect.Request[user.CreateInviteCodeRequest]) (*connect.Response[user.CreateInviteCodeResponse], error) { + invite, err := createInviteCodeRequestToCommand(req.Msg) if err != nil { return nil, err } @@ -352,30 +353,30 @@ func (s *Server) CreateInviteCode(ctx context.Context, req *user.CreateInviteCod if err != nil { return nil, err } - return &user.CreateInviteCodeResponse{ + return connect.NewResponse(&user.CreateInviteCodeResponse{ Details: object.DomainToDetailsPb(details), InviteCode: code, - }, nil + }), nil } -func (s *Server) ResendInviteCode(ctx context.Context, req *user.ResendInviteCodeRequest) (*user.ResendInviteCodeResponse, error) { - details, err := s.command.ResendInviteCode(ctx, req.GetUserId(), "", "") +func (s *Server) ResendInviteCode(ctx context.Context, req *connect.Request[user.ResendInviteCodeRequest]) (*connect.Response[user.ResendInviteCodeResponse], error) { + details, err := s.command.ResendInviteCode(ctx, req.Msg.GetUserId(), "", "") if err != nil { return nil, err } - return &user.ResendInviteCodeResponse{ + return connect.NewResponse(&user.ResendInviteCodeResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) VerifyInviteCode(ctx context.Context, req *user.VerifyInviteCodeRequest) (*user.VerifyInviteCodeResponse, error) { - details, err := s.command.VerifyInviteCode(ctx, req.GetUserId(), req.GetVerificationCode()) +func (s *Server) VerifyInviteCode(ctx context.Context, req *connect.Request[user.VerifyInviteCodeRequest]) (*connect.Response[user.VerifyInviteCodeResponse], error) { + details, err := s.command.VerifyInviteCode(ctx, req.Msg.GetUserId(), req.Msg.GetVerificationCode()) if err != nil { return nil, err } - return &user.VerifyInviteCodeResponse{ + return connect.NewResponse(&user.VerifyInviteCodeResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func createInviteCodeRequestToCommand(req *user.CreateInviteCodeRequest) (*command.CreateUserInvite, error) { @@ -394,33 +395,33 @@ func createInviteCodeRequestToCommand(req *user.CreateInviteCodeRequest) (*comma } } -func (s *Server) HumanMFAInitSkipped(ctx context.Context, req *user.HumanMFAInitSkippedRequest) (_ *user.HumanMFAInitSkippedResponse, err error) { - details, err := s.command.HumanMFAInitSkippedV2(ctx, req.UserId) +func (s *Server) HumanMFAInitSkipped(ctx context.Context, req *connect.Request[user.HumanMFAInitSkippedRequest]) (_ *connect.Response[user.HumanMFAInitSkippedResponse], err error) { + details, err := s.command.HumanMFAInitSkippedV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.HumanMFAInitSkippedResponse{ + return connect.NewResponse(&user.HumanMFAInitSkippedResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) CreateUser(ctx context.Context, req *user.CreateUserRequest) (*user.CreateUserResponse, error) { - switch userType := req.GetUserType().(type) { +func (s *Server) CreateUser(ctx context.Context, req *connect.Request[user.CreateUserRequest]) (*connect.Response[user.CreateUserResponse], error) { + switch userType := req.Msg.GetUserType().(type) { case *user.CreateUserRequest_Human_: - return s.createUserTypeHuman(ctx, userType.Human, req.OrganizationId, req.Username, req.UserId) + return s.createUserTypeHuman(ctx, userType.Human, req.Msg.GetOrganizationId(), req.Msg.Username, req.Msg.UserId) case *user.CreateUserRequest_Machine_: - return s.createUserTypeMachine(ctx, userType.Machine, req.OrganizationId, req.GetUsername(), req.GetUserId()) + return s.createUserTypeMachine(ctx, userType.Machine, req.Msg.GetOrganizationId(), req.Msg.GetUsername(), req.Msg.GetUserId()) default: return nil, zerrors.ThrowInternal(nil, "", "user type is not implemented") } } -func (s *Server) UpdateUser(ctx context.Context, req *user.UpdateUserRequest) (*user.UpdateUserResponse, error) { - switch userType := req.GetUserType().(type) { +func (s *Server) UpdateUser(ctx context.Context, req *connect.Request[user.UpdateUserRequest]) (*connect.Response[user.UpdateUserResponse], error) { + switch userType := req.Msg.GetUserType().(type) { case *user.UpdateUserRequest_Human_: - return s.updateUserTypeHuman(ctx, userType.Human, req.UserId, req.Username) + return s.updateUserTypeHuman(ctx, userType.Human, req.Msg.GetUserId(), req.Msg.Username) case *user.UpdateUserRequest_Machine_: - return s.updateUserTypeMachine(ctx, userType.Machine, req.UserId, req.Username) + return s.updateUserTypeMachine(ctx, userType.Machine, req.Msg.GetUserId(), req.Msg.Username) default: return nil, zerrors.ThrowUnimplemented(nil, "", "user type is not implemented") } diff --git a/internal/api/grpc/user/v2/user_query.go b/internal/api/grpc/user/v2/user_query.go index dc886462be..5f5603af31 100644 --- a/internal/api/grpc/user/v2/user_query.go +++ b/internal/api/grpc/user/v2/user_query.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "google.golang.org/protobuf/types/known/timestamppb" @@ -13,12 +14,12 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) (_ *user.GetUserByIDResponse, err error) { - resp, err := s.query.GetUserByIDWithPermission(ctx, true, req.GetUserId(), s.checkPermission) +func (s *Server) GetUserByID(ctx context.Context, req *connect.Request[user.GetUserByIDRequest]) (_ *connect.Response[user.GetUserByIDResponse], err error) { + resp, err := s.query.GetUserByIDWithPermission(ctx, true, req.Msg.GetUserId(), s.checkPermission) if err != nil { return nil, err } - return &user.GetUserByIDResponse{ + return connect.NewResponse(&user.GetUserByIDResponse{ Details: object.DomainToDetailsPb(&domain.ObjectDetails{ Sequence: resp.Sequence, CreationDate: resp.CreationDate, @@ -26,11 +27,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) ResourceOwner: resp.ResourceOwner, }), User: userToPb(resp, s.assetAPIPrefix(ctx)), - }, nil + }), nil } -func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) { - queries, err := listUsersRequestToModel(req) +func (s *Server) ListUsers(ctx context.Context, req *connect.Request[user.ListUsersRequest]) (*connect.Response[user.ListUsersResponse], error) { + queries, err := listUsersRequestToModel(req.Msg) if err != nil { return nil, err } @@ -38,10 +39,10 @@ func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*us if err != nil { return nil, err } - return &user.ListUsersResponse{ + return connect.NewResponse(&user.ListUsersResponse{ Result: UsersToPb(res.Users, s.assetAPIPrefix(ctx)), Details: object.ToListDetails(res.SearchResponse), - }, nil + }), nil } func UsersToPb(users []*query.User, assetPrefix string) []*user.User { diff --git a/internal/api/grpc/user/v2beta/email.go b/internal/api/grpc/user/v2beta/email.go index 38cc73c75c..474111f767 100644 --- a/internal/api/grpc/user/v2beta/email.go +++ b/internal/api/grpc/user/v2beta/email.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/domain" @@ -11,18 +12,18 @@ import ( user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) { +func (s *Server) SetEmail(ctx context.Context, req *connect.Request[user.SetEmailRequest]) (resp *connect.Response[user.SetEmailResponse], err error) { var email *domain.Email - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SetEmailRequest_SendCode: - email, err = s.command.ChangeUserEmailURLTemplate(ctx, req.GetUserId(), req.GetEmail(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) + email, err = s.command.ChangeUserEmailURLTemplate(ctx, req.Msg.GetUserId(), req.Msg.GetEmail(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) case *user.SetEmailRequest_ReturnCode: - email, err = s.command.ChangeUserEmailReturnCode(ctx, req.GetUserId(), req.GetEmail(), s.userCodeAlg) + email, err = s.command.ChangeUserEmailReturnCode(ctx, req.Msg.GetUserId(), req.Msg.GetEmail(), s.userCodeAlg) case *user.SetEmailRequest_IsVerified: - email, err = s.command.ChangeUserEmailVerified(ctx, req.GetUserId(), req.GetEmail()) + email, err = s.command.ChangeUserEmailVerified(ctx, req.Msg.GetUserId(), req.Msg.GetEmail()) case nil: - email, err = s.command.ChangeUserEmail(ctx, req.GetUserId(), req.GetEmail(), s.userCodeAlg) + email, err = s.command.ChangeUserEmail(ctx, req.Msg.GetUserId(), req.Msg.GetEmail(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-Ahng0", "verification oneOf %T in method SetEmail not implemented", v) } @@ -30,26 +31,26 @@ func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp return nil, err } - return &user.SetEmailResponse{ + return connect.NewResponse(&user.SetEmailResponse{ Details: &object.Details{ Sequence: email.Sequence, ChangeDate: timestamppb.New(email.ChangeDate), ResourceOwner: email.ResourceOwner, }, VerificationCode: email.PlainCode, - }, nil + }), nil } -func (s *Server) ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.ResendEmailCodeResponse, err error) { +func (s *Server) ResendEmailCode(ctx context.Context, req *connect.Request[user.ResendEmailCodeRequest]) (resp *connect.Response[user.ResendEmailCodeResponse], err error) { var email *domain.Email - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.ResendEmailCodeRequest_SendCode: - email, err = s.command.ResendUserEmailCodeURLTemplate(ctx, req.GetUserId(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) + email, err = s.command.ResendUserEmailCodeURLTemplate(ctx, req.Msg.GetUserId(), s.userCodeAlg, v.SendCode.GetUrlTemplate()) case *user.ResendEmailCodeRequest_ReturnCode: - email, err = s.command.ResendUserEmailReturnCode(ctx, req.GetUserId(), s.userCodeAlg) + email, err = s.command.ResendUserEmailReturnCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case nil: - email, err = s.command.ResendUserEmailCode(ctx, req.GetUserId(), s.userCodeAlg) + email, err = s.command.ResendUserEmailCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-faj0l0nj5x", "verification oneOf %T in method ResendEmailCode not implemented", v) } @@ -57,30 +58,30 @@ func (s *Server) ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeR return nil, err } - return &user.ResendEmailCodeResponse{ + return connect.NewResponse(&user.ResendEmailCodeResponse{ Details: &object.Details{ Sequence: email.Sequence, ChangeDate: timestamppb.New(email.ChangeDate), ResourceOwner: email.ResourceOwner, }, VerificationCode: email.PlainCode, - }, nil + }), nil } -func (s *Server) VerifyEmail(ctx context.Context, req *user.VerifyEmailRequest) (*user.VerifyEmailResponse, error) { +func (s *Server) VerifyEmail(ctx context.Context, req *connect.Request[user.VerifyEmailRequest]) (*connect.Response[user.VerifyEmailResponse], error) { details, err := s.command.VerifyUserEmail(ctx, - req.GetUserId(), - req.GetVerificationCode(), + req.Msg.GetUserId(), + req.Msg.GetVerificationCode(), s.userCodeAlg, ) if err != nil { return nil, err } - return &user.VerifyEmailResponse{ + return connect.NewResponse(&user.VerifyEmailResponse{ Details: &object.Details{ Sequence: details.Sequence, ChangeDate: timestamppb.New(details.EventDate), ResourceOwner: details.ResourceOwner, }, - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2beta/otp.go b/internal/api/grpc/user/v2beta/otp.go index c11aa4c1a4..99919ce047 100644 --- a/internal/api/grpc/user/v2beta/otp.go +++ b/internal/api/grpc/user/v2beta/otp.go @@ -3,40 +3,42 @@ package user import ( "context" + "connectrpc.com/connect" + object "github.com/zitadel/zitadel/internal/api/grpc/object/v2beta" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) AddOTPSMS(ctx context.Context, req *user.AddOTPSMSRequest) (*user.AddOTPSMSResponse, error) { - details, err := s.command.AddHumanOTPSMS(ctx, req.GetUserId(), "") +func (s *Server) AddOTPSMS(ctx context.Context, req *connect.Request[user.AddOTPSMSRequest]) (*connect.Response[user.AddOTPSMSResponse], error) { + details, err := s.command.AddHumanOTPSMS(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.AddOTPSMSResponse{Details: object.DomainToDetailsPb(details)}, nil + return connect.NewResponse(&user.AddOTPSMSResponse{Details: object.DomainToDetailsPb(details)}), nil } -func (s *Server) RemoveOTPSMS(ctx context.Context, req *user.RemoveOTPSMSRequest) (*user.RemoveOTPSMSResponse, error) { - objectDetails, err := s.command.RemoveHumanOTPSMS(ctx, req.GetUserId(), "") +func (s *Server) RemoveOTPSMS(ctx context.Context, req *connect.Request[user.RemoveOTPSMSRequest]) (*connect.Response[user.RemoveOTPSMSResponse], error) { + objectDetails, err := s.command.RemoveHumanOTPSMS(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.RemoveOTPSMSResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil + return connect.NewResponse(&user.RemoveOTPSMSResponse{Details: object.DomainToDetailsPb(objectDetails)}), nil } -func (s *Server) AddOTPEmail(ctx context.Context, req *user.AddOTPEmailRequest) (*user.AddOTPEmailResponse, error) { - details, err := s.command.AddHumanOTPEmail(ctx, req.GetUserId(), "") +func (s *Server) AddOTPEmail(ctx context.Context, req *connect.Request[user.AddOTPEmailRequest]) (*connect.Response[user.AddOTPEmailResponse], error) { + details, err := s.command.AddHumanOTPEmail(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.AddOTPEmailResponse{Details: object.DomainToDetailsPb(details)}, nil + return connect.NewResponse(&user.AddOTPEmailResponse{Details: object.DomainToDetailsPb(details)}), nil } -func (s *Server) RemoveOTPEmail(ctx context.Context, req *user.RemoveOTPEmailRequest) (*user.RemoveOTPEmailResponse, error) { - objectDetails, err := s.command.RemoveHumanOTPEmail(ctx, req.GetUserId(), "") +func (s *Server) RemoveOTPEmail(ctx context.Context, req *connect.Request[user.RemoveOTPEmailRequest]) (*connect.Response[user.RemoveOTPEmailResponse], error) { + objectDetails, err := s.command.RemoveHumanOTPEmail(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.RemoveOTPEmailResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil + return connect.NewResponse(&user.RemoveOTPEmailResponse{Details: object.DomainToDetailsPb(objectDetails)}), nil } diff --git a/internal/api/grpc/user/v2beta/passkey.go b/internal/api/grpc/user/v2beta/passkey.go index 2df267f3fd..a63ac708b4 100644 --- a/internal/api/grpc/user/v2beta/passkey.go +++ b/internal/api/grpc/user/v2beta/passkey.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/structpb" object "github.com/zitadel/zitadel/internal/api/grpc/object/v2beta" @@ -12,17 +13,17 @@ import ( user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) RegisterPasskey(ctx context.Context, req *user.RegisterPasskeyRequest) (resp *user.RegisterPasskeyResponse, err error) { +func (s *Server) RegisterPasskey(ctx context.Context, req *connect.Request[user.RegisterPasskeyRequest]) (resp *connect.Response[user.RegisterPasskeyResponse], err error) { var ( - authenticator = passkeyAuthenticatorToDomain(req.GetAuthenticator()) + authenticator = passkeyAuthenticatorToDomain(req.Msg.GetAuthenticator()) ) - if code := req.GetCode(); code != nil { + if code := req.Msg.GetCode(); code != nil { return passkeyRegistrationDetailsToPb( - s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), "", authenticator, code.Id, code.Code, req.GetDomain(), s.userCodeAlg), + s.command.RegisterUserPasskeyWithCode(ctx, req.Msg.GetUserId(), "", authenticator, code.Id, code.Code, req.Msg.GetDomain(), s.userCodeAlg), ) } return passkeyRegistrationDetailsToPb( - s.command.RegisterUserPasskey(ctx, req.GetUserId(), "", req.GetDomain(), authenticator), + s.command.RegisterUserPasskey(ctx, req.Msg.GetUserId(), "", req.Msg.GetDomain(), authenticator), ) } @@ -50,69 +51,69 @@ func webAuthNRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails return object.DomainToDetailsPb(details.ObjectDetails), options, nil } -func passkeyRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*user.RegisterPasskeyResponse, error) { +func passkeyRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*connect.Response[user.RegisterPasskeyResponse], error) { objectDetails, options, err := webAuthNRegistrationDetailsToPb(details, err) if err != nil { return nil, err } - return &user.RegisterPasskeyResponse{ + return connect.NewResponse(&user.RegisterPasskeyResponse{ Details: objectDetails, PasskeyId: details.ID, PublicKeyCredentialCreationOptions: options, - }, nil + }), nil } -func (s *Server) VerifyPasskeyRegistration(ctx context.Context, req *user.VerifyPasskeyRegistrationRequest) (*user.VerifyPasskeyRegistrationResponse, error) { - pkc, err := req.GetPublicKeyCredential().MarshalJSON() +func (s *Server) VerifyPasskeyRegistration(ctx context.Context, req *connect.Request[user.VerifyPasskeyRegistrationRequest]) (*connect.Response[user.VerifyPasskeyRegistrationResponse], error) { + pkc, err := req.Msg.GetPublicKeyCredential().MarshalJSON() if err != nil { return nil, zerrors.ThrowInternal(err, "USERv2-Pha2o", "Errors.Internal") } - objectDetails, err := s.command.HumanHumanPasswordlessSetup(ctx, req.GetUserId(), "", req.GetPasskeyName(), "", pkc) + objectDetails, err := s.command.HumanHumanPasswordlessSetup(ctx, req.Msg.GetUserId(), "", req.Msg.GetPasskeyName(), "", pkc) if err != nil { return nil, err } - return &user.VerifyPasskeyRegistrationResponse{ + return connect.NewResponse(&user.VerifyPasskeyRegistrationResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } -func (s *Server) CreatePasskeyRegistrationLink(ctx context.Context, req *user.CreatePasskeyRegistrationLinkRequest) (resp *user.CreatePasskeyRegistrationLinkResponse, err error) { - switch medium := req.Medium.(type) { +func (s *Server) CreatePasskeyRegistrationLink(ctx context.Context, req *connect.Request[user.CreatePasskeyRegistrationLinkRequest]) (resp *connect.Response[user.CreatePasskeyRegistrationLinkResponse], err error) { + switch medium := req.Msg.Medium.(type) { case nil: return passkeyDetailsToPb( - s.command.AddUserPasskeyCode(ctx, req.GetUserId(), "", s.userCodeAlg), + s.command.AddUserPasskeyCode(ctx, req.Msg.GetUserId(), "", s.userCodeAlg), ) case *user.CreatePasskeyRegistrationLinkRequest_SendLink: return passkeyDetailsToPb( - s.command.AddUserPasskeyCodeURLTemplate(ctx, req.GetUserId(), "", s.userCodeAlg, medium.SendLink.GetUrlTemplate()), + s.command.AddUserPasskeyCodeURLTemplate(ctx, req.Msg.GetUserId(), "", s.userCodeAlg, medium.SendLink.GetUrlTemplate()), ) case *user.CreatePasskeyRegistrationLinkRequest_ReturnCode: return passkeyCodeDetailsToPb( - s.command.AddUserPasskeyCodeReturn(ctx, req.GetUserId(), "", s.userCodeAlg), + s.command.AddUserPasskeyCodeReturn(ctx, req.Msg.GetUserId(), "", s.userCodeAlg), ) default: return nil, zerrors.ThrowUnimplementedf(nil, "USERv2-gaD8y", "verification oneOf %T in method CreatePasskeyRegistrationLink not implemented", medium) } } -func passkeyDetailsToPb(details *domain.ObjectDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) { +func passkeyDetailsToPb(details *domain.ObjectDetails, err error) (*connect.Response[user.CreatePasskeyRegistrationLinkResponse], error) { if err != nil { return nil, err } - return &user.CreatePasskeyRegistrationLinkResponse{ + return connect.NewResponse(&user.CreatePasskeyRegistrationLinkResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func passkeyCodeDetailsToPb(details *domain.PasskeyCodeDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) { +func passkeyCodeDetailsToPb(details *domain.PasskeyCodeDetails, err error) (*connect.Response[user.CreatePasskeyRegistrationLinkResponse], error) { if err != nil { return nil, err } - return &user.CreatePasskeyRegistrationLinkResponse{ + return connect.NewResponse(&user.CreatePasskeyRegistrationLinkResponse{ Details: object.DomainToDetailsPb(details.ObjectDetails), Code: &user.PasskeyRegistrationCode{ Id: details.CodeID, Code: details.Code, }, - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2beta/passkey_test.go b/internal/api/grpc/user/v2beta/passkey_test.go index f4a48ed941..12ef8ed02f 100644 --- a/internal/api/grpc/user/v2beta/passkey_test.go +++ b/internal/api/grpc/user/v2beta/passkey_test.go @@ -123,11 +123,11 @@ func Test_passkeyRegistrationDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := passkeyRegistrationDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.wantErr) - if !proto.Equal(tt.want, got) { + if tt.want != nil && !proto.Equal(tt.want, got.Msg) { t.Errorf("Not equal:\nExpected\n%s\nActual:%s", tt.want, got) } if tt.want != nil { - grpc.AllFieldsSet(t, got.ProtoReflect()) + grpc.AllFieldsSet(t, got.Msg.ProtoReflect()) } }) } @@ -181,7 +181,9 @@ func Test_passkeyDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := passkeyDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.args.err) - assert.Equal(t, tt.want, got) + if tt.want != nil { + assert.Equal(t, tt.want, got.Msg) + } }) } } @@ -242,9 +244,9 @@ func Test_passkeyCodeDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := passkeyCodeDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.args.err) - assert.Equal(t, tt.want, got) if tt.want != nil { - grpc.AllFieldsSet(t, got.ProtoReflect()) + assert.Equal(t, tt.want, got.Msg) + grpc.AllFieldsSet(t, got.Msg.ProtoReflect()) } }) } diff --git a/internal/api/grpc/user/v2beta/password.go b/internal/api/grpc/user/v2beta/password.go index 0de1262215..ae9a549db0 100644 --- a/internal/api/grpc/user/v2beta/password.go +++ b/internal/api/grpc/user/v2beta/password.go @@ -3,23 +3,25 @@ package user import ( "context" + "connectrpc.com/connect" + object "github.com/zitadel/zitadel/internal/api/grpc/object/v2beta" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) PasswordReset(ctx context.Context, req *user.PasswordResetRequest) (_ *user.PasswordResetResponse, err error) { +func (s *Server) PasswordReset(ctx context.Context, req *connect.Request[user.PasswordResetRequest]) (_ *connect.Response[user.PasswordResetResponse], err error) { var details *domain.ObjectDetails var code *string - switch m := req.GetMedium().(type) { + switch m := req.Msg.GetMedium().(type) { case *user.PasswordResetRequest_SendLink: - details, code, err = s.command.RequestPasswordResetURLTemplate(ctx, req.GetUserId(), m.SendLink.GetUrlTemplate(), notificationTypeToDomain(m.SendLink.GetNotificationType())) + details, code, err = s.command.RequestPasswordResetURLTemplate(ctx, req.Msg.GetUserId(), m.SendLink.GetUrlTemplate(), notificationTypeToDomain(m.SendLink.GetNotificationType())) case *user.PasswordResetRequest_ReturnCode: - details, code, err = s.command.RequestPasswordResetReturnCode(ctx, req.GetUserId()) + details, code, err = s.command.RequestPasswordResetReturnCode(ctx, req.Msg.GetUserId()) case nil: - details, code, err = s.command.RequestPasswordReset(ctx, req.GetUserId()) + details, code, err = s.command.RequestPasswordReset(ctx, req.Msg.GetUserId()) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-SDeeg", "verification oneOf %T in method RequestPasswordReset not implemented", m) } @@ -27,10 +29,10 @@ func (s *Server) PasswordReset(ctx context.Context, req *user.PasswordResetReque return nil, err } - return &user.PasswordResetResponse{ + return connect.NewResponse(&user.PasswordResetResponse{ Details: object.DomainToDetailsPb(details), VerificationCode: code, - }, nil + }), nil } func notificationTypeToDomain(notificationType user.NotificationType) domain.NotificationType { @@ -46,16 +48,16 @@ func notificationTypeToDomain(notificationType user.NotificationType) domain.Not } } -func (s *Server) SetPassword(ctx context.Context, req *user.SetPasswordRequest) (_ *user.SetPasswordResponse, err error) { +func (s *Server) SetPassword(ctx context.Context, req *connect.Request[user.SetPasswordRequest]) (_ *connect.Response[user.SetPasswordResponse], err error) { var details *domain.ObjectDetails - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SetPasswordRequest_CurrentPassword: - details, err = s.command.ChangePassword(ctx, "", req.GetUserId(), v.CurrentPassword, req.GetNewPassword().GetPassword(), "", req.GetNewPassword().GetChangeRequired()) + details, err = s.command.ChangePassword(ctx, "", req.Msg.GetUserId(), v.CurrentPassword, req.Msg.GetNewPassword().GetPassword(), "", req.Msg.GetNewPassword().GetChangeRequired()) case *user.SetPasswordRequest_VerificationCode: - details, err = s.command.SetPasswordWithVerifyCode(ctx, "", req.GetUserId(), v.VerificationCode, req.GetNewPassword().GetPassword(), "", req.GetNewPassword().GetChangeRequired()) + details, err = s.command.SetPasswordWithVerifyCode(ctx, "", req.Msg.GetUserId(), v.VerificationCode, req.Msg.GetNewPassword().GetPassword(), "", req.Msg.GetNewPassword().GetChangeRequired()) case nil: - details, err = s.command.SetPassword(ctx, "", req.GetUserId(), req.GetNewPassword().GetPassword(), req.GetNewPassword().GetChangeRequired()) + details, err = s.command.SetPassword(ctx, "", req.Msg.GetUserId(), req.Msg.GetNewPassword().GetPassword(), req.Msg.GetNewPassword().GetChangeRequired()) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-SFdf2", "verification oneOf %T in method SetPasswordRequest not implemented", v) } @@ -63,7 +65,7 @@ func (s *Server) SetPassword(ctx context.Context, req *user.SetPasswordRequest) return nil, err } - return &user.SetPasswordResponse{ + return connect.NewResponse(&user.SetPasswordResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2beta/phone.go b/internal/api/grpc/user/v2beta/phone.go index eac7eb4e31..20ef2075ab 100644 --- a/internal/api/grpc/user/v2beta/phone.go +++ b/internal/api/grpc/user/v2beta/phone.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/domain" @@ -11,18 +12,18 @@ import ( user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) SetPhone(ctx context.Context, req *user.SetPhoneRequest) (resp *user.SetPhoneResponse, err error) { +func (s *Server) SetPhone(ctx context.Context, req *connect.Request[user.SetPhoneRequest]) (resp *connect.Response[user.SetPhoneResponse], err error) { var phone *domain.Phone - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.SetPhoneRequest_SendCode: - phone, err = s.command.ChangeUserPhone(ctx, req.GetUserId(), req.GetPhone(), s.userCodeAlg) + phone, err = s.command.ChangeUserPhone(ctx, req.Msg.GetUserId(), req.Msg.GetPhone(), s.userCodeAlg) case *user.SetPhoneRequest_ReturnCode: - phone, err = s.command.ChangeUserPhoneReturnCode(ctx, req.GetUserId(), req.GetPhone(), s.userCodeAlg) + phone, err = s.command.ChangeUserPhoneReturnCode(ctx, req.Msg.GetUserId(), req.Msg.GetPhone(), s.userCodeAlg) case *user.SetPhoneRequest_IsVerified: - phone, err = s.command.ChangeUserPhoneVerified(ctx, req.GetUserId(), req.GetPhone()) + phone, err = s.command.ChangeUserPhoneVerified(ctx, req.Msg.GetUserId(), req.Msg.GetPhone()) case nil: - phone, err = s.command.ChangeUserPhone(ctx, req.GetUserId(), req.GetPhone(), s.userCodeAlg) + phone, err = s.command.ChangeUserPhone(ctx, req.Msg.GetUserId(), req.Msg.GetPhone(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-Ahng0", "verification oneOf %T in method SetPhone not implemented", v) } @@ -30,42 +31,42 @@ func (s *Server) SetPhone(ctx context.Context, req *user.SetPhoneRequest) (resp return nil, err } - return &user.SetPhoneResponse{ + return connect.NewResponse(&user.SetPhoneResponse{ Details: &object.Details{ Sequence: phone.Sequence, ChangeDate: timestamppb.New(phone.ChangeDate), ResourceOwner: phone.ResourceOwner, }, VerificationCode: phone.PlainCode, - }, nil + }), nil } -func (s *Server) RemovePhone(ctx context.Context, req *user.RemovePhoneRequest) (resp *user.RemovePhoneResponse, err error) { +func (s *Server) RemovePhone(ctx context.Context, req *connect.Request[user.RemovePhoneRequest]) (resp *connect.Response[user.RemovePhoneResponse], err error) { details, err := s.command.RemoveUserPhone(ctx, - req.GetUserId(), + req.Msg.GetUserId(), ) if err != nil { return nil, err } - return &user.RemovePhoneResponse{ + return connect.NewResponse(&user.RemovePhoneResponse{ Details: &object.Details{ Sequence: details.Sequence, ChangeDate: timestamppb.New(details.EventDate), ResourceOwner: details.ResourceOwner, }, - }, nil + }), nil } -func (s *Server) ResendPhoneCode(ctx context.Context, req *user.ResendPhoneCodeRequest) (resp *user.ResendPhoneCodeResponse, err error) { +func (s *Server) ResendPhoneCode(ctx context.Context, req *connect.Request[user.ResendPhoneCodeRequest]) (resp *connect.Response[user.ResendPhoneCodeResponse], err error) { var phone *domain.Phone - switch v := req.GetVerification().(type) { + switch v := req.Msg.GetVerification().(type) { case *user.ResendPhoneCodeRequest_SendCode: - phone, err = s.command.ResendUserPhoneCode(ctx, req.GetUserId(), s.userCodeAlg) + phone, err = s.command.ResendUserPhoneCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case *user.ResendPhoneCodeRequest_ReturnCode: - phone, err = s.command.ResendUserPhoneCodeReturnCode(ctx, req.GetUserId(), s.userCodeAlg) + phone, err = s.command.ResendUserPhoneCodeReturnCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) case nil: - phone, err = s.command.ResendUserPhoneCode(ctx, req.GetUserId(), s.userCodeAlg) + phone, err = s.command.ResendUserPhoneCode(ctx, req.Msg.GetUserId(), s.userCodeAlg) default: err = zerrors.ThrowUnimplementedf(nil, "USERv2-ResendUserPhoneCode", "verification oneOf %T in method SetPhone not implemented", v) } @@ -73,30 +74,30 @@ func (s *Server) ResendPhoneCode(ctx context.Context, req *user.ResendPhoneCodeR return nil, err } - return &user.ResendPhoneCodeResponse{ + return connect.NewResponse(&user.ResendPhoneCodeResponse{ Details: &object.Details{ Sequence: phone.Sequence, ChangeDate: timestamppb.New(phone.ChangeDate), ResourceOwner: phone.ResourceOwner, }, VerificationCode: phone.PlainCode, - }, nil + }), nil } -func (s *Server) VerifyPhone(ctx context.Context, req *user.VerifyPhoneRequest) (*user.VerifyPhoneResponse, error) { +func (s *Server) VerifyPhone(ctx context.Context, req *connect.Request[user.VerifyPhoneRequest]) (*connect.Response[user.VerifyPhoneResponse], error) { details, err := s.command.VerifyUserPhone(ctx, - req.GetUserId(), - req.GetVerificationCode(), + req.Msg.GetUserId(), + req.Msg.GetVerificationCode(), s.userCodeAlg, ) if err != nil { return nil, err } - return &user.VerifyPhoneResponse{ + return connect.NewResponse(&user.VerifyPhoneResponse{ Details: &object.Details{ Sequence: details.Sequence, ChangeDate: timestamppb.New(details.EventDate), ResourceOwner: details.ResourceOwner, }, - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2beta/query.go b/internal/api/grpc/user/v2beta/query.go index 46b009a72e..b9654ea97c 100644 --- a/internal/api/grpc/user/v2beta/query.go +++ b/internal/api/grpc/user/v2beta/query.go @@ -3,6 +3,7 @@ package user import ( "context" + "connectrpc.com/connect" "github.com/muhlemmer/gu" "google.golang.org/protobuf/types/known/timestamppb" @@ -13,23 +14,23 @@ import ( user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) (_ *user.GetUserByIDResponse, err error) { - resp, err := s.query.GetUserByIDWithPermission(ctx, true, req.GetUserId(), s.checkPermission) +func (s *Server) GetUserByID(ctx context.Context, req *connect.Request[user.GetUserByIDRequest]) (_ *connect.Response[user.GetUserByIDResponse], err error) { + resp, err := s.query.GetUserByIDWithPermission(ctx, true, req.Msg.GetUserId(), s.checkPermission) if err != nil { return nil, err } - return &user.GetUserByIDResponse{ + return connect.NewResponse(&user.GetUserByIDResponse{ Details: object.DomainToDetailsPb(&domain.ObjectDetails{ Sequence: resp.Sequence, EventDate: resp.ChangeDate, ResourceOwner: resp.ResourceOwner, }), User: userToPb(resp, s.assetAPIPrefix(ctx)), - }, nil + }), nil } -func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) { - queries, err := listUsersRequestToModel(req) +func (s *Server) ListUsers(ctx context.Context, req *connect.Request[user.ListUsersRequest]) (*connect.Response[user.ListUsersResponse], error) { + queries, err := listUsersRequestToModel(req.Msg) if err != nil { return nil, err } @@ -37,10 +38,10 @@ func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*us if err != nil { return nil, err } - return &user.ListUsersResponse{ + return connect.NewResponse(&user.ListUsersResponse{ Result: UsersToPb(res.Users, s.assetAPIPrefix(ctx)), Details: object.ToListDetails(res.SearchResponse), - }, nil + }), nil } func UsersToPb(users []*query.User, assetPrefix string) []*user.User { diff --git a/internal/api/grpc/user/v2beta/server.go b/internal/api/grpc/user/v2beta/server.go index 93af47f58b..7e3934a2c1 100644 --- a/internal/api/grpc/user/v2beta/server.go +++ b/internal/api/grpc/user/v2beta/server.go @@ -2,8 +2,10 @@ package user import ( "context" + "net/http" - "google.golang.org/grpc" + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" @@ -12,12 +14,12 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/user/v2beta/userconnect" ) -var _ user.UserServiceServer = (*Server)(nil) +var _ userconnect.UserServiceHandler = (*Server)(nil) type Server struct { - user.UnimplementedUserServiceServer command *command.Commands query *query.Queries userCodeAlg crypto.EncryptionAlgorithm @@ -54,8 +56,12 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - user.RegisterUserServiceServer(grpcServer, s) +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return userconnect.NewUserServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return user.File_zitadel_user_v2beta_user_service_proto } func (s *Server) AppName() string { diff --git a/internal/api/grpc/user/v2beta/totp.go b/internal/api/grpc/user/v2beta/totp.go index 2ef47a9817..e7bd01b2b6 100644 --- a/internal/api/grpc/user/v2beta/totp.go +++ b/internal/api/grpc/user/v2beta/totp.go @@ -3,42 +3,44 @@ package user import ( "context" + "connectrpc.com/connect" + object "github.com/zitadel/zitadel/internal/api/grpc/object/v2beta" "github.com/zitadel/zitadel/internal/domain" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) RegisterTOTP(ctx context.Context, req *user.RegisterTOTPRequest) (*user.RegisterTOTPResponse, error) { +func (s *Server) RegisterTOTP(ctx context.Context, req *connect.Request[user.RegisterTOTPRequest]) (*connect.Response[user.RegisterTOTPResponse], error) { return totpDetailsToPb( - s.command.AddUserTOTP(ctx, req.GetUserId(), ""), + s.command.AddUserTOTP(ctx, req.Msg.GetUserId(), ""), ) } -func totpDetailsToPb(totp *domain.TOTP, err error) (*user.RegisterTOTPResponse, error) { +func totpDetailsToPb(totp *domain.TOTP, err error) (*connect.Response[user.RegisterTOTPResponse], error) { if err != nil { return nil, err } - return &user.RegisterTOTPResponse{ + return connect.NewResponse(&user.RegisterTOTPResponse{ Details: object.DomainToDetailsPb(totp.ObjectDetails), Uri: totp.URI, Secret: totp.Secret, - }, nil + }), nil } -func (s *Server) VerifyTOTPRegistration(ctx context.Context, req *user.VerifyTOTPRegistrationRequest) (*user.VerifyTOTPRegistrationResponse, error) { - objectDetails, err := s.command.CheckUserTOTP(ctx, req.GetUserId(), req.GetCode(), "") +func (s *Server) VerifyTOTPRegistration(ctx context.Context, req *connect.Request[user.VerifyTOTPRegistrationRequest]) (*connect.Response[user.VerifyTOTPRegistrationResponse], error) { + objectDetails, err := s.command.CheckUserTOTP(ctx, req.Msg.GetUserId(), req.Msg.GetCode(), "") if err != nil { return nil, err } - return &user.VerifyTOTPRegistrationResponse{ + return connect.NewResponse(&user.VerifyTOTPRegistrationResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } -func (s *Server) RemoveTOTP(ctx context.Context, req *user.RemoveTOTPRequest) (*user.RemoveTOTPResponse, error) { - objectDetails, err := s.command.HumanRemoveTOTP(ctx, req.GetUserId(), "") +func (s *Server) RemoveTOTP(ctx context.Context, req *connect.Request[user.RemoveTOTPRequest]) (*connect.Response[user.RemoveTOTPResponse], error) { + objectDetails, err := s.command.HumanRemoveTOTP(ctx, req.Msg.GetUserId(), "") if err != nil { return nil, err } - return &user.RemoveTOTPResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil + return connect.NewResponse(&user.RemoveTOTPResponse{Details: object.DomainToDetailsPb(objectDetails)}), nil } diff --git a/internal/api/grpc/user/v2beta/totp_test.go b/internal/api/grpc/user/v2beta/totp_test.go index 81a54675f2..77c6e5c343 100644 --- a/internal/api/grpc/user/v2beta/totp_test.go +++ b/internal/api/grpc/user/v2beta/totp_test.go @@ -63,7 +63,7 @@ func Test_totpDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := totpDetailsToPb(tt.args.otp, tt.args.err) require.ErrorIs(t, err, tt.wantErr) - if !proto.Equal(tt.want, got) { + if tt.want != nil && !proto.Equal(tt.want, got.Msg) { t.Errorf("RegisterTOTPResponse =\n%v\nwant\n%v", got, tt.want) } }) diff --git a/internal/api/grpc/user/v2beta/u2f.go b/internal/api/grpc/user/v2beta/u2f.go index e23a22b8b5..a6823a4bc0 100644 --- a/internal/api/grpc/user/v2beta/u2f.go +++ b/internal/api/grpc/user/v2beta/u2f.go @@ -3,40 +3,42 @@ package user import ( "context" + "connectrpc.com/connect" + object "github.com/zitadel/zitadel/internal/api/grpc/object/v2beta" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) RegisterU2F(ctx context.Context, req *user.RegisterU2FRequest) (*user.RegisterU2FResponse, error) { +func (s *Server) RegisterU2F(ctx context.Context, req *connect.Request[user.RegisterU2FRequest]) (*connect.Response[user.RegisterU2FResponse], error) { return u2fRegistrationDetailsToPb( - s.command.RegisterUserU2F(ctx, req.GetUserId(), "", req.GetDomain()), + s.command.RegisterUserU2F(ctx, req.Msg.GetUserId(), "", req.Msg.GetDomain()), ) } -func u2fRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*user.RegisterU2FResponse, error) { +func u2fRegistrationDetailsToPb(details *domain.WebAuthNRegistrationDetails, err error) (*connect.Response[user.RegisterU2FResponse], error) { objectDetails, options, err := webAuthNRegistrationDetailsToPb(details, err) if err != nil { return nil, err } - return &user.RegisterU2FResponse{ + return connect.NewResponse(&user.RegisterU2FResponse{ Details: objectDetails, U2FId: details.ID, PublicKeyCredentialCreationOptions: options, - }, nil + }), nil } -func (s *Server) VerifyU2FRegistration(ctx context.Context, req *user.VerifyU2FRegistrationRequest) (*user.VerifyU2FRegistrationResponse, error) { - pkc, err := req.GetPublicKeyCredential().MarshalJSON() +func (s *Server) VerifyU2FRegistration(ctx context.Context, req *connect.Request[user.VerifyU2FRegistrationRequest]) (*connect.Response[user.VerifyU2FRegistrationResponse], error) { + pkc, err := req.Msg.GetPublicKeyCredential().MarshalJSON() if err != nil { return nil, zerrors.ThrowInternal(err, "USERv2-IeTh4", "Errors.Internal") } - objectDetails, err := s.command.HumanVerifyU2FSetup(ctx, req.GetUserId(), "", req.GetTokenName(), "", pkc) + objectDetails, err := s.command.HumanVerifyU2FSetup(ctx, req.Msg.GetUserId(), "", req.Msg.GetTokenName(), "", pkc) if err != nil { return nil, err } - return &user.VerifyU2FRegistrationResponse{ + return connect.NewResponse(&user.VerifyU2FRegistrationResponse{ Details: object.DomainToDetailsPb(objectDetails), - }, nil + }), nil } diff --git a/internal/api/grpc/user/v2beta/u2f_test.go b/internal/api/grpc/user/v2beta/u2f_test.go index 53f2a0bb8c..ac99c0d1eb 100644 --- a/internal/api/grpc/user/v2beta/u2f_test.go +++ b/internal/api/grpc/user/v2beta/u2f_test.go @@ -92,11 +92,11 @@ func Test_u2fRegistrationDetailsToPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := u2fRegistrationDetailsToPb(tt.args.details, tt.args.err) require.ErrorIs(t, err, tt.wantErr) - if !proto.Equal(tt.want, got) { + if tt.want != nil && !proto.Equal(tt.want, got.Msg) { t.Errorf("Not equal:\nExpected\n%s\nActual:%s", tt.want, got) } if tt.want != nil { - grpc.AllFieldsSet(t, got.ProtoReflect()) + grpc.AllFieldsSet(t, got.Msg.ProtoReflect()) } }) } diff --git a/internal/api/grpc/user/v2beta/user.go b/internal/api/grpc/user/v2beta/user.go index 49f0c7d9c7..e5b2094d2c 100644 --- a/internal/api/grpc/user/v2beta/user.go +++ b/internal/api/grpc/user/v2beta/user.go @@ -6,6 +6,7 @@ import ( "io" "time" + "connectrpc.com/connect" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -23,8 +24,8 @@ import ( user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" ) -func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest) (_ *user.AddHumanUserResponse, err error) { - human, err := AddUserRequestToAddHuman(req) +func (s *Server) AddHumanUser(ctx context.Context, req *connect.Request[user.AddHumanUserRequest]) (_ *connect.Response[user.AddHumanUserResponse], err error) { + human, err := AddUserRequestToAddHuman(req.Msg) if err != nil { return nil, err } @@ -32,12 +33,12 @@ func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest if err = s.command.AddUserHuman(ctx, orgID, human, false, s.userCodeAlg); err != nil { return nil, err } - return &user.AddHumanUserResponse{ + return connect.NewResponse(&user.AddHumanUserResponse{ UserId: human.ID, Details: object.DomainToDetailsPb(human.Details), EmailCode: human.EmailCode, PhoneCode: human.PhoneCode, - }, nil + }), nil } func AddUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman, error) { @@ -115,8 +116,8 @@ func genderToDomain(gender user.Gender) domain.Gender { } } -func (s *Server) UpdateHumanUser(ctx context.Context, req *user.UpdateHumanUserRequest) (_ *user.UpdateHumanUserResponse, err error) { - human, err := UpdateUserRequestToChangeHuman(req) +func (s *Server) UpdateHumanUser(ctx context.Context, req *connect.Request[user.UpdateHumanUserRequest]) (_ *connect.Response[user.UpdateHumanUserResponse], err error) { + human, err := UpdateUserRequestToChangeHuman(req.Msg) if err != nil { return nil, err } @@ -124,51 +125,51 @@ func (s *Server) UpdateHumanUser(ctx context.Context, req *user.UpdateHumanUserR if err != nil { return nil, err } - return &user.UpdateHumanUserResponse{ + return connect.NewResponse(&user.UpdateHumanUserResponse{ Details: object.DomainToDetailsPb(human.Details), EmailCode: human.EmailCode, PhoneCode: human.PhoneCode, - }, nil + }), nil } -func (s *Server) LockUser(ctx context.Context, req *user.LockUserRequest) (_ *user.LockUserResponse, err error) { - details, err := s.command.LockUserV2(ctx, req.UserId) +func (s *Server) LockUser(ctx context.Context, req *connect.Request[user.LockUserRequest]) (_ *connect.Response[user.LockUserResponse], err error) { + details, err := s.command.LockUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.LockUserResponse{ + return connect.NewResponse(&user.LockUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) UnlockUser(ctx context.Context, req *user.UnlockUserRequest) (_ *user.UnlockUserResponse, err error) { - details, err := s.command.UnlockUserV2(ctx, req.UserId) +func (s *Server) UnlockUser(ctx context.Context, req *connect.Request[user.UnlockUserRequest]) (_ *connect.Response[user.UnlockUserResponse], err error) { + details, err := s.command.UnlockUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.UnlockUserResponse{ + return connect.NewResponse(&user.UnlockUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) DeactivateUser(ctx context.Context, req *user.DeactivateUserRequest) (_ *user.DeactivateUserResponse, err error) { - details, err := s.command.DeactivateUserV2(ctx, req.UserId) +func (s *Server) DeactivateUser(ctx context.Context, req *connect.Request[user.DeactivateUserRequest]) (_ *connect.Response[user.DeactivateUserResponse], err error) { + details, err := s.command.DeactivateUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.DeactivateUserResponse{ + return connect.NewResponse(&user.DeactivateUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) ReactivateUser(ctx context.Context, req *user.ReactivateUserRequest) (_ *user.ReactivateUserResponse, err error) { - details, err := s.command.ReactivateUserV2(ctx, req.UserId) +func (s *Server) ReactivateUser(ctx context.Context, req *connect.Request[user.ReactivateUserRequest]) (_ *connect.Response[user.ReactivateUserResponse], err error) { + details, err := s.command.ReactivateUserV2(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - return &user.ReactivateUserResponse{ + return connect.NewResponse(&user.ReactivateUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func ifNotNilPtr[v, p any](value *v, conv func(v) p) *p { @@ -260,32 +261,32 @@ func SetHumanPasswordToPassword(password *user.SetPassword) *command.Password { } } -func (s *Server) AddIDPLink(ctx context.Context, req *user.AddIDPLinkRequest) (_ *user.AddIDPLinkResponse, err error) { - details, err := s.command.AddUserIDPLink(ctx, req.UserId, "", &command.AddLink{ - IDPID: req.GetIdpLink().GetIdpId(), - DisplayName: req.GetIdpLink().GetUserName(), - IDPExternalID: req.GetIdpLink().GetUserId(), +func (s *Server) AddIDPLink(ctx context.Context, req *connect.Request[user.AddIDPLinkRequest]) (_ *connect.Response[user.AddIDPLinkResponse], err error) { + details, err := s.command.AddUserIDPLink(ctx, req.Msg.GetUserId(), "", &command.AddLink{ + IDPID: req.Msg.GetIdpLink().GetIdpId(), + DisplayName: req.Msg.GetIdpLink().GetUserName(), + IDPExternalID: req.Msg.GetIdpLink().GetUserId(), }) if err != nil { return nil, err } - return &user.AddIDPLinkResponse{ + return connect.NewResponse(&user.AddIDPLinkResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } -func (s *Server) DeleteUser(ctx context.Context, req *user.DeleteUserRequest) (_ *user.DeleteUserResponse, err error) { - memberships, grants, err := s.removeUserDependencies(ctx, req.GetUserId()) +func (s *Server) DeleteUser(ctx context.Context, req *connect.Request[user.DeleteUserRequest]) (_ *connect.Response[user.DeleteUserResponse], err error) { + memberships, grants, err := s.removeUserDependencies(ctx, req.Msg.GetUserId()) if err != nil { return nil, err } - details, err := s.command.RemoveUserV2(ctx, req.UserId, "", memberships, grants...) + details, err := s.command.RemoveUserV2(ctx, req.Msg.GetUserId(), "", memberships, grants...) if err != nil { return nil, err } - return &user.DeleteUserResponse{ + return connect.NewResponse(&user.DeleteUserResponse{ Details: object.DomainToDetailsPb(details), - }, nil + }), nil } func (s *Server) removeUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) { @@ -360,18 +361,18 @@ func userGrantsToIDs(userGrants []*query.UserGrant) []string { return converted } -func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.StartIdentityProviderIntentRequest) (_ *user.StartIdentityProviderIntentResponse, err error) { - switch t := req.GetContent().(type) { +func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *connect.Request[user.StartIdentityProviderIntentRequest]) (_ *connect.Response[user.StartIdentityProviderIntentResponse], err error) { + switch t := req.Msg.GetContent().(type) { case *user.StartIdentityProviderIntentRequest_Urls: - return s.startIDPIntent(ctx, req.GetIdpId(), t.Urls) + return s.startIDPIntent(ctx, req.Msg.GetIdpId(), t.Urls) case *user.StartIdentityProviderIntentRequest_Ldap: - return s.startLDAPIntent(ctx, req.GetIdpId(), t.Ldap) + return s.startLDAPIntent(ctx, req.Msg.GetIdpId(), t.Ldap) default: return nil, zerrors.ThrowUnimplementedf(nil, "USERv2-S2g21", "type oneOf %T in method StartIdentityProviderIntent not implemented", t) } } -func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) { +func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*connect.Response[user.StartIdentityProviderIntentResponse], error) { state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID)) if err != nil { return nil, err @@ -386,12 +387,12 @@ func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.Re } switch a := auth.(type) { case *idp.RedirectAuth: - return &user.StartIdentityProviderIntentResponse{ + return connect.NewResponse(&user.StartIdentityProviderIntentResponse{ Details: object.DomainToDetailsPb(details), NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: a.RedirectURL}, - }, nil + }), nil case *idp.FormAuth: - return &user.StartIdentityProviderIntentResponse{ + return connect.NewResponse(&user.StartIdentityProviderIntentResponse{ Details: object.DomainToDetailsPb(details), NextStep: &user.StartIdentityProviderIntentResponse_FormData{ FormData: &user.FormData{ @@ -399,12 +400,12 @@ func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.Re Fields: a.Fields, }, }, - }, nil + }), nil } return nil, zerrors.ThrowInvalidArgumentf(nil, "USERv2-3g2j3", "type oneOf %T in method StartIdentityProviderIntent not implemented", auth) } -func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) { +func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*connect.Response[user.StartIdentityProviderIntentResponse], error) { intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil) if err != nil { return nil, err @@ -420,7 +421,7 @@ func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredenti if err != nil { return nil, err } - return &user.StartIdentityProviderIntentResponse{ + return connect.NewResponse(&user.StartIdentityProviderIntentResponse{ Details: object.DomainToDetailsPb(details), NextStep: &user.StartIdentityProviderIntentResponse_IdpIntent{ IdpIntent: &user.IDPIntent{ @@ -429,7 +430,7 @@ func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredenti UserId: userID, }, }, - }, nil + }), nil } func (s *Server) checkLinkedExternalUser(ctx context.Context, idpID, externalUserID string) (string, error) { @@ -483,12 +484,12 @@ func (s *Server) ldapLogin(ctx context.Context, idpID, username, password string return externalUser, userID, session, nil } -func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.RetrieveIdentityProviderIntentRequest) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { - intent, err := s.command.GetIntentWriteModel(ctx, req.GetIdpIntentId(), "") +func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *connect.Request[user.RetrieveIdentityProviderIntentRequest]) (_ *connect.Response[user.RetrieveIdentityProviderIntentResponse], err error) { + intent, err := s.command.GetIntentWriteModel(ctx, req.Msg.GetIdpIntentId(), "") if err != nil { return nil, err } - if err := s.checkIntentToken(req.GetIdpIntentToken(), intent.AggregateID); err != nil { + if err := s.checkIntentToken(req.Msg.GetIdpIntentToken(), intent.AggregateID); err != nil { return nil, err } if intent.State != domain.IDPIntentStateSucceeded { @@ -500,7 +501,7 @@ func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.R return idpIntentToIDPIntentPb(intent, s.idpAlg) } -func idpIntentToIDPIntentPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { +func idpIntentToIDPIntentPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *connect.Response[user.RetrieveIdentityProviderIntentResponse], err error) { rawInformation := new(structpb.Struct) err = rawInformation.UnmarshalJSON(intent.IDPUser) if err != nil { @@ -539,7 +540,7 @@ func idpIntentToIDPIntentPb(intent *command.IDPIntentWriteModel, alg crypto.Encr information.IdpInformation.Access = IDPSAMLResponseToPb(assertion) } - return information, nil + return connect.NewResponse(information), nil } func idpOAuthTokensToPb(idpIDToken string, idpAccessToken *crypto.CryptoValue, alg crypto.EncryptionAlgorithm) (_ *user.IDPInformation_Oauth, err error) { @@ -602,15 +603,15 @@ func (s *Server) checkIntentToken(token string, intentID string) error { return crypto.CheckToken(s.idpAlg, token, intentID) } -func (s *Server) ListAuthenticationMethodTypes(ctx context.Context, req *user.ListAuthenticationMethodTypesRequest) (*user.ListAuthenticationMethodTypesResponse, error) { - authMethods, err := s.query.ListUserAuthMethodTypes(ctx, req.GetUserId(), true, false, "") +func (s *Server) ListAuthenticationMethodTypes(ctx context.Context, req *connect.Request[user.ListAuthenticationMethodTypesRequest]) (*connect.Response[user.ListAuthenticationMethodTypesResponse], error) { + authMethods, err := s.query.ListUserAuthMethodTypes(ctx, req.Msg.GetUserId(), true, false, "") if err != nil { return nil, err } - return &user.ListAuthenticationMethodTypesResponse{ + return connect.NewResponse(&user.ListAuthenticationMethodTypesResponse{ Details: object.ToListDetails(authMethods.SearchResponse), AuthMethodTypes: authMethodTypesToPb(authMethods.AuthMethodTypes), - }, nil + }), nil } func authMethodTypesToPb(methodTypes []domain.UserAuthMethodType) []user.AuthenticationMethodType { diff --git a/internal/api/grpc/user/v2beta/user_test.go b/internal/api/grpc/user/v2beta/user_test.go index 9e398e83ff..8973d61fcc 100644 --- a/internal/api/grpc/user/v2beta/user_test.go +++ b/internal/api/grpc/user/v2beta/user_test.go @@ -322,7 +322,9 @@ func Test_idpIntentToIDPIntentPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := idpIntentToIDPIntentPb(tt.args.intent, tt.args.alg) require.ErrorIs(t, err, tt.res.err) - grpc.AllFieldsEqual(t, tt.res.resp.ProtoReflect(), got.ProtoReflect(), grpc.CustomMappers) + if tt.res.resp != nil { + grpc.AllFieldsEqual(t, tt.res.resp.ProtoReflect(), got.Msg.ProtoReflect(), grpc.CustomMappers) + } }) } } diff --git a/internal/api/grpc/webkey/v2/integration_test/webkey_integration_test.go b/internal/api/grpc/webkey/v2/integration_test/webkey_integration_test.go new file mode 100644 index 0000000000..48777927cf --- /dev/null +++ b/internal/api/grpc/webkey/v2/integration_test/webkey_integration_test.go @@ -0,0 +1,216 @@ +//go:build integration + +package webkey_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2" +) + +var ( + CTX context.Context +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + CTX = ctx + return m.Run() + }()) +} + +func TestServer_ListWebKeys(t *testing.T) { + instance, iamCtx, creationDate := createInstance(t) + // After the feature is first enabled, we can expect 2 generated keys with the default config. + checkWebKeyListState(iamCtx, t, instance, 2, "", &webkey.WebKey_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, creationDate) +} + +func TestServer_CreateWebKey(t *testing.T) { + instance, iamCtx, creationDate := createInstance(t) + client := instance.Client.WebKeyV2 + + _, err := client.CreateWebKey(iamCtx, &webkey.CreateWebKeyRequest{ + Key: &webkey.CreateWebKeyRequest_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, + }) + require.NoError(t, err) + + checkWebKeyListState(iamCtx, t, instance, 3, "", &webkey.WebKey_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, creationDate) +} + +func TestServer_ActivateWebKey(t *testing.T) { + instance, iamCtx, creationDate := createInstance(t) + client := instance.Client.WebKeyV2 + + resp, err := client.CreateWebKey(iamCtx, &webkey.CreateWebKeyRequest{ + Key: &webkey.CreateWebKeyRequest_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, + }) + require.NoError(t, err) + + _, err = client.ActivateWebKey(iamCtx, &webkey.ActivateWebKeyRequest{ + Id: resp.GetId(), + }) + require.NoError(t, err) + + checkWebKeyListState(iamCtx, t, instance, 3, resp.GetId(), &webkey.WebKey_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, creationDate) +} + +func TestServer_DeleteWebKey(t *testing.T) { + instance, iamCtx, creationDate := createInstance(t) + client := instance.Client.WebKeyV2 + + keyIDs := make([]string, 2) + for i := 0; i < 2; i++ { + resp, err := client.CreateWebKey(iamCtx, &webkey.CreateWebKeyRequest{ + Key: &webkey.CreateWebKeyRequest_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, + }) + require.NoError(t, err) + keyIDs[i] = resp.GetId() + } + _, err := client.ActivateWebKey(iamCtx, &webkey.ActivateWebKeyRequest{ + Id: keyIDs[0], + }) + require.NoError(t, err) + + ok := t.Run("cannot delete active key", func(t *testing.T) { + _, err := client.DeleteWebKey(iamCtx, &webkey.DeleteWebKeyRequest{ + Id: keyIDs[0], + }) + require.Error(t, err) + s := status.Convert(err) + assert.Equal(t, codes.FailedPrecondition, s.Code()) + assert.Contains(t, s.Message(), "COMMAND-Chai1") + }) + if !ok { + return + } + + start := time.Now() + ok = t.Run("delete inactive key", func(t *testing.T) { + resp, err := client.DeleteWebKey(iamCtx, &webkey.DeleteWebKeyRequest{ + Id: keyIDs[1], + }) + require.NoError(t, err) + require.WithinRange(t, resp.GetDeletionDate().AsTime(), start, time.Now()) + }) + if !ok { + return + } + + ok = t.Run("delete inactive key again", func(t *testing.T) { + resp, err := client.DeleteWebKey(iamCtx, &webkey.DeleteWebKeyRequest{ + Id: keyIDs[1], + }) + require.NoError(t, err) + require.WithinRange(t, resp.GetDeletionDate().AsTime(), start, time.Now()) + }) + if !ok { + return + } + + ok = t.Run("delete not existing key", func(t *testing.T) { + resp, err := client.DeleteWebKey(iamCtx, &webkey.DeleteWebKeyRequest{ + Id: "not-existing", + }) + require.NoError(t, err) + require.Nil(t, resp.DeletionDate) + }) + if !ok { + return + } + + // There are 2 keys from feature setup, +2 created, -1 deleted = 3 + checkWebKeyListState(iamCtx, t, instance, 3, keyIDs[0], &webkey.WebKey_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, creationDate) +} + +func createInstance(t *testing.T) (*integration.Instance, context.Context, *timestamppb.Timestamp) { + instance := integration.NewInstance(CTX) + creationDate := timestamppb.Now() + iamCTX := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(iamCTX, time.Minute) + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + resp, err := instance.Client.WebKeyV2.ListWebKeys(iamCTX, &webkey.ListWebKeysRequest{}) + assert.NoError(collect, err) + assert.Len(collect, resp.GetWebKeys(), 2) + + }, retryDuration, tick) + + return instance, iamCTX, creationDate +} + +func checkWebKeyListState(ctx context.Context, t *testing.T, instance *integration.Instance, nKeys int, expectActiveKeyID string, config any, creationDate *timestamppb.Timestamp) { + t.Helper() + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + resp, err := instance.Client.WebKeyV2.ListWebKeys(ctx, &webkey.ListWebKeysRequest{}) + require.NoError(collect, err) + list := resp.GetWebKeys() + assert.Len(collect, list, nKeys) + + now := time.Now() + var gotActiveKeyID string + for _, key := range list { + assert.WithinRange(collect, key.GetCreationDate().AsTime(), now.Add(-time.Minute), now.Add(time.Minute)) + assert.WithinRange(collect, key.GetChangeDate().AsTime(), now.Add(-time.Minute), now.Add(time.Minute)) + assert.NotEqual(collect, webkey.State_STATE_UNSPECIFIED, key.GetState()) + assert.NotEqual(collect, webkey.State_STATE_REMOVED, key.GetState()) + assert.Equal(collect, config, key.GetKey()) + + if key.GetState() == webkey.State_STATE_ACTIVE { + gotActiveKeyID = key.GetId() + } + } + assert.NotEmpty(collect, gotActiveKeyID) + if expectActiveKeyID != "" { + assert.Equal(collect, expectActiveKeyID, gotActiveKeyID) + } + }, retryDuration, tick) +} diff --git a/internal/api/grpc/webkey/v2/server.go b/internal/api/grpc/webkey/v2/server.go new file mode 100644 index 0000000000..a62c29e2b9 --- /dev/null +++ b/internal/api/grpc/webkey/v2/server.go @@ -0,0 +1,51 @@ +package webkey + +import ( + "net/http" + + "connectrpc.com/connect" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2/webkeyconnect" +) + +var _ webkeyconnect.WebKeyServiceHandler = (*Server)(nil) + +type Server struct { + command *command.Commands + query *query.Queries +} + +func CreateServer( + command *command.Commands, + query *query.Queries, +) *Server { + return &Server{ + command: command, + query: query, + } +} + +func (s *Server) AppName() string { + return webkey.WebKeyService_ServiceDesc.ServiceName +} + +func (s *Server) MethodPrefix() string { + return webkey.WebKeyService_ServiceDesc.ServiceName +} + +func (s *Server) AuthMethods() authz.MethodMapping { + return webkey.WebKeyService_AuthMethods +} + +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return webkeyconnect.NewWebKeyServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return webkey.File_zitadel_webkey_v2_webkey_service_proto +} diff --git a/internal/api/grpc/webkey/v2/webkey.go b/internal/api/grpc/webkey/v2/webkey.go new file mode 100644 index 0000000000..d1a10a31d0 --- /dev/null +++ b/internal/api/grpc/webkey/v2/webkey.go @@ -0,0 +1,72 @@ +package webkey + +import ( + "context" + + "connectrpc.com/connect" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2" +) + +func (s *Server) CreateWebKey(ctx context.Context, req *connect.Request[webkey.CreateWebKeyRequest]) (_ *connect.Response[webkey.CreateWebKeyResponse], err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + webKey, err := s.command.CreateWebKey(ctx, createWebKeyRequestToConfig(req.Msg)) + if err != nil { + return nil, err + } + + return connect.NewResponse(&webkey.CreateWebKeyResponse{ + Id: webKey.KeyID, + CreationDate: timestamppb.New(webKey.ObjectDetails.EventDate), + }), nil +} + +func (s *Server) ActivateWebKey(ctx context.Context, req *connect.Request[webkey.ActivateWebKeyRequest]) (_ *connect.Response[webkey.ActivateWebKeyResponse], err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + details, err := s.command.ActivateWebKey(ctx, req.Msg.GetId()) + if err != nil { + return nil, err + } + + return connect.NewResponse(&webkey.ActivateWebKeyResponse{ + ChangeDate: timestamppb.New(details.EventDate), + }), nil +} + +func (s *Server) DeleteWebKey(ctx context.Context, req *connect.Request[webkey.DeleteWebKeyRequest]) (_ *connect.Response[webkey.DeleteWebKeyResponse], err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + deletedAt, err := s.command.DeleteWebKey(ctx, req.Msg.GetId()) + if err != nil { + return nil, err + } + + var deletionDate *timestamppb.Timestamp + if !deletedAt.IsZero() { + deletionDate = timestamppb.New(deletedAt) + } + return connect.NewResponse(&webkey.DeleteWebKeyResponse{ + DeletionDate: deletionDate, + }), nil +} + +func (s *Server) ListWebKeys(ctx context.Context, _ *connect.Request[webkey.ListWebKeysRequest]) (_ *connect.Response[webkey.ListWebKeysResponse], err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + list, err := s.query.ListWebKeys(ctx) + if err != nil { + return nil, err + } + + return connect.NewResponse(&webkey.ListWebKeysResponse{ + WebKeys: webKeyDetailsListToPb(list), + }), nil +} diff --git a/internal/api/grpc/webkey/v2/webkey_converter.go b/internal/api/grpc/webkey/v2/webkey_converter.go new file mode 100644 index 0000000000..7ee7fbce05 --- /dev/null +++ b/internal/api/grpc/webkey/v2/webkey_converter.go @@ -0,0 +1,170 @@ +package webkey + +import ( + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2" +) + +func createWebKeyRequestToConfig(req *webkey.CreateWebKeyRequest) crypto.WebKeyConfig { + switch config := req.GetKey().(type) { + case *webkey.CreateWebKeyRequest_Rsa: + return rsaToCrypto(config.Rsa) + case *webkey.CreateWebKeyRequest_Ecdsa: + return ecdsaToCrypto(config.Ecdsa) + case *webkey.CreateWebKeyRequest_Ed25519: + return new(crypto.WebKeyED25519Config) + default: + return rsaToCrypto(nil) + } +} + +func rsaToCrypto(config *webkey.RSA) *crypto.WebKeyRSAConfig { + out := new(crypto.WebKeyRSAConfig) + + switch config.GetBits() { + case webkey.RSABits_RSA_BITS_UNSPECIFIED: + out.Bits = crypto.RSABits2048 + case webkey.RSABits_RSA_BITS_2048: + out.Bits = crypto.RSABits2048 + case webkey.RSABits_RSA_BITS_3072: + out.Bits = crypto.RSABits3072 + case webkey.RSABits_RSA_BITS_4096: + out.Bits = crypto.RSABits4096 + default: + out.Bits = crypto.RSABits2048 + } + + switch config.GetHasher() { + case webkey.RSAHasher_RSA_HASHER_UNSPECIFIED: + out.Hasher = crypto.RSAHasherSHA256 + case webkey.RSAHasher_RSA_HASHER_SHA256: + out.Hasher = crypto.RSAHasherSHA256 + case webkey.RSAHasher_RSA_HASHER_SHA384: + out.Hasher = crypto.RSAHasherSHA384 + case webkey.RSAHasher_RSA_HASHER_SHA512: + out.Hasher = crypto.RSAHasherSHA512 + default: + out.Hasher = crypto.RSAHasherSHA256 + } + + return out +} + +func ecdsaToCrypto(config *webkey.ECDSA) *crypto.WebKeyECDSAConfig { + out := new(crypto.WebKeyECDSAConfig) + + switch config.GetCurve() { + case webkey.ECDSACurve_ECDSA_CURVE_UNSPECIFIED: + out.Curve = crypto.EllipticCurveP256 + case webkey.ECDSACurve_ECDSA_CURVE_P256: + out.Curve = crypto.EllipticCurveP256 + case webkey.ECDSACurve_ECDSA_CURVE_P384: + out.Curve = crypto.EllipticCurveP384 + case webkey.ECDSACurve_ECDSA_CURVE_P512: + out.Curve = crypto.EllipticCurveP512 + default: + out.Curve = crypto.EllipticCurveP256 + } + + return out +} + +func webKeyDetailsListToPb(list []query.WebKeyDetails) []*webkey.WebKey { + out := make([]*webkey.WebKey, len(list)) + for i := range list { + out[i] = webKeyDetailsToPb(&list[i]) + } + return out +} + +func webKeyDetailsToPb(details *query.WebKeyDetails) *webkey.WebKey { + out := &webkey.WebKey{ + Id: details.KeyID, + CreationDate: timestamppb.New(details.CreationDate), + ChangeDate: timestamppb.New(details.ChangeDate), + State: webKeyStateToPb(details.State), + } + + switch config := details.Config.(type) { + case *crypto.WebKeyRSAConfig: + out.Key = &webkey.WebKey_Rsa{ + Rsa: webKeyRSAConfigToPb(config), + } + case *crypto.WebKeyECDSAConfig: + out.Key = &webkey.WebKey_Ecdsa{ + Ecdsa: webKeyECDSAConfigToPb(config), + } + case *crypto.WebKeyED25519Config: + out.Key = &webkey.WebKey_Ed25519{ + Ed25519: new(webkey.ED25519), + } + } + + return out +} + +func webKeyStateToPb(state domain.WebKeyState) webkey.State { + switch state { + case domain.WebKeyStateUnspecified: + return webkey.State_STATE_UNSPECIFIED + case domain.WebKeyStateInitial: + return webkey.State_STATE_INITIAL + case domain.WebKeyStateActive: + return webkey.State_STATE_ACTIVE + case domain.WebKeyStateInactive: + return webkey.State_STATE_INACTIVE + case domain.WebKeyStateRemoved: + return webkey.State_STATE_REMOVED + default: + return webkey.State_STATE_UNSPECIFIED + } +} + +func webKeyRSAConfigToPb(config *crypto.WebKeyRSAConfig) *webkey.RSA { + out := new(webkey.RSA) + + switch config.Bits { + case crypto.RSABitsUnspecified: + out.Bits = webkey.RSABits_RSA_BITS_UNSPECIFIED + case crypto.RSABits2048: + out.Bits = webkey.RSABits_RSA_BITS_2048 + case crypto.RSABits3072: + out.Bits = webkey.RSABits_RSA_BITS_3072 + case crypto.RSABits4096: + out.Bits = webkey.RSABits_RSA_BITS_4096 + } + + switch config.Hasher { + case crypto.RSAHasherUnspecified: + out.Hasher = webkey.RSAHasher_RSA_HASHER_UNSPECIFIED + case crypto.RSAHasherSHA256: + out.Hasher = webkey.RSAHasher_RSA_HASHER_SHA256 + case crypto.RSAHasherSHA384: + out.Hasher = webkey.RSAHasher_RSA_HASHER_SHA384 + case crypto.RSAHasherSHA512: + out.Hasher = webkey.RSAHasher_RSA_HASHER_SHA512 + } + + return out +} + +func webKeyECDSAConfigToPb(config *crypto.WebKeyECDSAConfig) *webkey.ECDSA { + out := new(webkey.ECDSA) + + switch config.Curve { + case crypto.EllipticCurveUnspecified: + out.Curve = webkey.ECDSACurve_ECDSA_CURVE_UNSPECIFIED + case crypto.EllipticCurveP256: + out.Curve = webkey.ECDSACurve_ECDSA_CURVE_P256 + case crypto.EllipticCurveP384: + out.Curve = webkey.ECDSACurve_ECDSA_CURVE_P384 + case crypto.EllipticCurveP512: + out.Curve = webkey.ECDSACurve_ECDSA_CURVE_P512 + } + + return out +} diff --git a/internal/api/grpc/webkey/v2/webkey_converter_test.go b/internal/api/grpc/webkey/v2/webkey_converter_test.go new file mode 100644 index 0000000000..e7387d96ad --- /dev/null +++ b/internal/api/grpc/webkey/v2/webkey_converter_test.go @@ -0,0 +1,494 @@ +package webkey + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2" +) + +func Test_createWebKeyRequestToConfig(t *testing.T) { + type args struct { + req *webkey.CreateWebKeyRequest + } + tests := []struct { + name string + args args + want crypto.WebKeyConfig + }{ + { + name: "RSA", + args: args{&webkey.CreateWebKeyRequest{ + Key: &webkey.CreateWebKeyRequest_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_3072, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA384, + }, + }, + }}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits3072, + Hasher: crypto.RSAHasherSHA384, + }, + }, + { + name: "ECDSA", + args: args{&webkey.CreateWebKeyRequest{ + Key: &webkey.CreateWebKeyRequest_Ecdsa{ + Ecdsa: &webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P384, + }, + }, + }}, + want: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP384, + }, + }, + { + name: "ED25519", + args: args{&webkey.CreateWebKeyRequest{ + Key: &webkey.CreateWebKeyRequest_Ed25519{ + Ed25519: &webkey.ED25519{}, + }, + }}, + want: &crypto.WebKeyED25519Config{}, + }, + { + name: "default", + args: args{&webkey.CreateWebKeyRequest{}}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits2048, + Hasher: crypto.RSAHasherSHA256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := createWebKeyRequestToConfig(tt.args.req) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webKeyRSAConfigToCrypto(t *testing.T) { + type args struct { + config *webkey.RSA + } + tests := []struct { + name string + args args + want *crypto.WebKeyRSAConfig + }{ + { + name: "unspecified", + args: args{&webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_UNSPECIFIED, + Hasher: webkey.RSAHasher_RSA_HASHER_UNSPECIFIED, + }}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits2048, + Hasher: crypto.RSAHasherSHA256, + }, + }, + { + name: "2048, RSA256", + args: args{&webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits2048, + Hasher: crypto.RSAHasherSHA256, + }, + }, + { + name: "3072, RSA384", + args: args{&webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_3072, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA384, + }}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits3072, + Hasher: crypto.RSAHasherSHA384, + }, + }, + { + name: "4096, RSA512", + args: args{&webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_4096, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA512, + }}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits4096, + Hasher: crypto.RSAHasherSHA512, + }, + }, + { + name: "invalid", + args: args{&webkey.RSA{ + Bits: 99, + Hasher: 99, + }}, + want: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits2048, + Hasher: crypto.RSAHasherSHA256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := rsaToCrypto(tt.args.config) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webKeyECDSAConfigToCrypto(t *testing.T) { + type args struct { + config *webkey.ECDSA + } + tests := []struct { + name string + args args + want *crypto.WebKeyECDSAConfig + }{ + { + name: "unspecified", + args: args{&webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_UNSPECIFIED, + }}, + want: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP256, + }, + }, + { + name: "P256", + args: args{&webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P256, + }}, + want: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP256, + }, + }, + { + name: "P384", + args: args{&webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P384, + }}, + want: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP384, + }, + }, + { + name: "P512", + args: args{&webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P512, + }}, + want: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP512, + }, + }, + { + name: "invalid", + args: args{&webkey.ECDSA{ + Curve: 99, + }}, + want: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ecdsaToCrypto(tt.args.config) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webKeyDetailsListToPb(t *testing.T) { + list := []query.WebKeyDetails{ + { + KeyID: "key1", + CreationDate: time.Unix(123, 456), + ChangeDate: time.Unix(789, 0), + Sequence: 123, + State: domain.WebKeyStateActive, + Config: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits3072, + Hasher: crypto.RSAHasherSHA384, + }, + }, + { + KeyID: "key2", + CreationDate: time.Unix(123, 456), + ChangeDate: time.Unix(789, 0), + Sequence: 123, + State: domain.WebKeyStateActive, + Config: &crypto.WebKeyED25519Config{}, + }, + } + want := []*webkey.WebKey{ + { + Id: "key1", + CreationDate: ×tamppb.Timestamp{Seconds: 123, Nanos: 456}, + ChangeDate: ×tamppb.Timestamp{Seconds: 789, Nanos: 0}, + State: webkey.State_STATE_ACTIVE, + Key: &webkey.WebKey_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_3072, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA384, + }, + }, + }, + { + Id: "key2", + CreationDate: ×tamppb.Timestamp{Seconds: 123, Nanos: 456}, + ChangeDate: ×tamppb.Timestamp{Seconds: 789, Nanos: 0}, + State: webkey.State_STATE_ACTIVE, + Key: &webkey.WebKey_Ed25519{ + Ed25519: &webkey.ED25519{}, + }, + }, + } + got := webKeyDetailsListToPb(list) + assert.Equal(t, want, got) +} + +func Test_webKeyDetailsToPb(t *testing.T) { + type args struct { + details *query.WebKeyDetails + } + tests := []struct { + name string + args args + want *webkey.WebKey + }{ + { + name: "RSA", + args: args{&query.WebKeyDetails{ + KeyID: "keyID", + CreationDate: time.Unix(123, 456), + ChangeDate: time.Unix(789, 0), + Sequence: 123, + State: domain.WebKeyStateActive, + Config: &crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits3072, + Hasher: crypto.RSAHasherSHA384, + }, + }}, + want: &webkey.WebKey{ + Id: "keyID", + CreationDate: ×tamppb.Timestamp{Seconds: 123, Nanos: 456}, + ChangeDate: ×tamppb.Timestamp{Seconds: 789, Nanos: 0}, + State: webkey.State_STATE_ACTIVE, + Key: &webkey.WebKey_Rsa{ + Rsa: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_3072, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA384, + }, + }, + }, + }, + { + name: "ECDSA", + args: args{&query.WebKeyDetails{ + KeyID: "keyID", + CreationDate: time.Unix(123, 456), + ChangeDate: time.Unix(789, 0), + Sequence: 123, + State: domain.WebKeyStateActive, + Config: &crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP384, + }, + }}, + want: &webkey.WebKey{ + Id: "keyID", + CreationDate: ×tamppb.Timestamp{Seconds: 123, Nanos: 456}, + ChangeDate: ×tamppb.Timestamp{Seconds: 789, Nanos: 0}, + State: webkey.State_STATE_ACTIVE, + Key: &webkey.WebKey_Ecdsa{ + Ecdsa: &webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P384, + }, + }, + }, + }, + { + name: "ED25519", + args: args{&query.WebKeyDetails{ + KeyID: "keyID", + CreationDate: time.Unix(123, 456), + ChangeDate: time.Unix(789, 0), + Sequence: 123, + State: domain.WebKeyStateActive, + Config: &crypto.WebKeyED25519Config{}, + }}, + want: &webkey.WebKey{ + Id: "keyID", + CreationDate: ×tamppb.Timestamp{Seconds: 123, Nanos: 456}, + ChangeDate: ×tamppb.Timestamp{Seconds: 789, Nanos: 0}, + State: webkey.State_STATE_ACTIVE, + Key: &webkey.WebKey_Ed25519{ + Ed25519: &webkey.ED25519{}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := webKeyDetailsToPb(tt.args.details) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webKeyStateToPb(t *testing.T) { + type args struct { + state domain.WebKeyState + } + tests := []struct { + name string + args args + want webkey.State + }{ + { + name: "unspecified", + args: args{domain.WebKeyStateUnspecified}, + want: webkey.State_STATE_UNSPECIFIED, + }, + { + name: "initial", + args: args{domain.WebKeyStateInitial}, + want: webkey.State_STATE_INITIAL, + }, + { + name: "active", + args: args{domain.WebKeyStateActive}, + want: webkey.State_STATE_ACTIVE, + }, + { + name: "inactive", + args: args{domain.WebKeyStateInactive}, + want: webkey.State_STATE_INACTIVE, + }, + { + name: "removed", + args: args{domain.WebKeyStateRemoved}, + want: webkey.State_STATE_REMOVED, + }, + { + name: "invalid", + args: args{99}, + want: webkey.State_STATE_UNSPECIFIED, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := webKeyStateToPb(tt.args.state) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webKeyRSAConfigToPb(t *testing.T) { + type args struct { + config *crypto.WebKeyRSAConfig + } + tests := []struct { + name string + args args + want *webkey.RSA + }{ + { + name: "2048, RSA256", + args: args{&crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits2048, + Hasher: crypto.RSAHasherSHA256, + }}, + want: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_2048, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA256, + }, + }, + { + name: "3072, RSA384", + args: args{&crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits3072, + Hasher: crypto.RSAHasherSHA384, + }}, + want: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_3072, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA384, + }, + }, + { + name: "4096, RSA512", + args: args{&crypto.WebKeyRSAConfig{ + Bits: crypto.RSABits4096, + Hasher: crypto.RSAHasherSHA512, + }}, + want: &webkey.RSA{ + Bits: webkey.RSABits_RSA_BITS_4096, + Hasher: webkey.RSAHasher_RSA_HASHER_SHA512, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := webKeyRSAConfigToPb(tt.args.config) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webKeyECDSAConfigToPb(t *testing.T) { + type args struct { + config *crypto.WebKeyECDSAConfig + } + tests := []struct { + name string + args args + want *webkey.ECDSA + }{ + { + name: "P256", + args: args{&crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP256, + }}, + want: &webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P256, + }, + }, + { + name: "P384", + args: args{&crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP384, + }}, + want: &webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P384, + }, + }, + { + name: "P512", + args: args{&crypto.WebKeyECDSAConfig{ + Curve: crypto.EllipticCurveP512, + }}, + want: &webkey.ECDSA{ + Curve: webkey.ECDSACurve_ECDSA_CURVE_P512, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := webKeyECDSAConfigToPb(tt.args.config) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/api/grpc/webkey/v2beta/server.go b/internal/api/grpc/webkey/v2beta/server.go index 0d4ddb19c8..b000e98104 100644 --- a/internal/api/grpc/webkey/v2beta/server.go +++ b/internal/api/grpc/webkey/v2beta/server.go @@ -1,17 +1,22 @@ package webkey import ( - "google.golang.org/grpc" + "net/http" + + "connectrpc.com/connect" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/query" webkey "github.com/zitadel/zitadel/pkg/grpc/webkey/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/webkey/v2beta/webkeyconnect" ) +var _ webkeyconnect.WebKeyServiceHandler = (*Server)(nil) + type Server struct { - webkey.UnimplementedWebKeyServiceServer command *command.Commands query *query.Queries } @@ -26,10 +31,6 @@ func CreateServer( } } -func (s *Server) RegisterServer(grpcServer *grpc.Server) { - webkey.RegisterWebKeyServiceServer(grpcServer, s) -} - func (s *Server) AppName() string { return webkey.WebKeyService_ServiceDesc.ServiceName } @@ -45,3 +46,11 @@ func (s *Server) AuthMethods() authz.MethodMapping { func (s *Server) RegisterGateway() server.RegisterGatewayFunc { return webkey.RegisterWebKeyServiceHandler } + +func (s *Server) RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler) { + return webkeyconnect.NewWebKeyServiceHandler(s, connect.WithInterceptors(interceptors...)) +} + +func (s *Server) FileDescriptor() protoreflect.FileDescriptor { + return webkey.File_zitadel_webkey_v2beta_webkey_service_proto +} diff --git a/internal/api/grpc/webkey/v2beta/webkey.go b/internal/api/grpc/webkey/v2beta/webkey.go index 469d6fc9a6..fa37cc32e3 100644 --- a/internal/api/grpc/webkey/v2beta/webkey.go +++ b/internal/api/grpc/webkey/v2beta/webkey.go @@ -3,46 +3,47 @@ package webkey import ( "context" + "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/telemetry/tracing" webkey "github.com/zitadel/zitadel/pkg/grpc/webkey/v2beta" ) -func (s *Server) CreateWebKey(ctx context.Context, req *webkey.CreateWebKeyRequest) (_ *webkey.CreateWebKeyResponse, err error) { +func (s *Server) CreateWebKey(ctx context.Context, req *connect.Request[webkey.CreateWebKeyRequest]) (_ *connect.Response[webkey.CreateWebKeyResponse], err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - webKey, err := s.command.CreateWebKey(ctx, createWebKeyRequestToConfig(req)) + webKey, err := s.command.CreateWebKey(ctx, createWebKeyRequestToConfig(req.Msg)) if err != nil { return nil, err } - return &webkey.CreateWebKeyResponse{ + return connect.NewResponse(&webkey.CreateWebKeyResponse{ Id: webKey.KeyID, CreationDate: timestamppb.New(webKey.ObjectDetails.EventDate), - }, nil + }), nil } -func (s *Server) ActivateWebKey(ctx context.Context, req *webkey.ActivateWebKeyRequest) (_ *webkey.ActivateWebKeyResponse, err error) { +func (s *Server) ActivateWebKey(ctx context.Context, req *connect.Request[webkey.ActivateWebKeyRequest]) (_ *connect.Response[webkey.ActivateWebKeyResponse], err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - details, err := s.command.ActivateWebKey(ctx, req.GetId()) + details, err := s.command.ActivateWebKey(ctx, req.Msg.GetId()) if err != nil { return nil, err } - return &webkey.ActivateWebKeyResponse{ + return connect.NewResponse(&webkey.ActivateWebKeyResponse{ ChangeDate: timestamppb.New(details.EventDate), - }, nil + }), nil } -func (s *Server) DeleteWebKey(ctx context.Context, req *webkey.DeleteWebKeyRequest) (_ *webkey.DeleteWebKeyResponse, err error) { +func (s *Server) DeleteWebKey(ctx context.Context, req *connect.Request[webkey.DeleteWebKeyRequest]) (_ *connect.Response[webkey.DeleteWebKeyResponse], err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - deletedAt, err := s.command.DeleteWebKey(ctx, req.GetId()) + deletedAt, err := s.command.DeleteWebKey(ctx, req.Msg.GetId()) if err != nil { return nil, err } @@ -51,12 +52,12 @@ func (s *Server) DeleteWebKey(ctx context.Context, req *webkey.DeleteWebKeyReque if !deletedAt.IsZero() { deletionDate = timestamppb.New(deletedAt) } - return &webkey.DeleteWebKeyResponse{ + return connect.NewResponse(&webkey.DeleteWebKeyResponse{ DeletionDate: deletionDate, - }, nil + }), nil } -func (s *Server) ListWebKeys(ctx context.Context, _ *webkey.ListWebKeysRequest) (_ *webkey.ListWebKeysResponse, err error) { +func (s *Server) ListWebKeys(ctx context.Context, _ *connect.Request[webkey.ListWebKeysRequest]) (_ *connect.Response[webkey.ListWebKeysResponse], err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -65,7 +66,7 @@ func (s *Server) ListWebKeys(ctx context.Context, _ *webkey.ListWebKeysRequest) return nil, err } - return &webkey.ListWebKeysResponse{ + return connect.NewResponse(&webkey.ListWebKeysResponse{ WebKeys: webKeyDetailsListToPb(list), - }, nil + }), nil } diff --git a/internal/integration/client.go b/internal/integration/client.go index 326d6fa8b4..c4f639b4b8 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -47,6 +47,7 @@ import ( user_pb "github.com/zitadel/zitadel/pkg/grpc/user" user_v2 "github.com/zitadel/zitadel/pkg/grpc/user/v2" user_v2beta "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" + webkey_v2 "github.com/zitadel/zitadel/pkg/grpc/webkey/v2" webkey_v2beta "github.com/zitadel/zitadel/pkg/grpc/webkey/v2beta" ) @@ -70,6 +71,7 @@ type Client struct { FeatureV2 feature.FeatureServiceClient UserSchemaV3 userschema_v3alpha.ZITADELUserSchemasClient WebKeyV2Beta webkey_v2beta.WebKeyServiceClient + WebKeyV2 webkey_v2.WebKeyServiceClient IDPv2 idp_pb.IdentityProviderServiceClient UserV3Alpha user_v3alpha.ZITADELUsersClient SAMLv2 saml_pb.SAMLServiceClient @@ -110,6 +112,7 @@ func newClient(ctx context.Context, target string) (*Client, error) { FeatureV2: feature.NewFeatureServiceClient(cc), UserSchemaV3: userschema_v3alpha.NewZITADELUserSchemasClient(cc), WebKeyV2Beta: webkey_v2beta.NewWebKeyServiceClient(cc), + WebKeyV2: webkey_v2.NewWebKeyServiceClient(cc), IDPv2: idp_pb.NewIdentityProviderServiceClient(cc), UserV3Alpha: user_v3alpha.NewZITADELUsersClient(cc), SAMLv2: saml_pb.NewSAMLServiceClient(cc), diff --git a/internal/protoc/protoc-gen-zitadel/zitadel.pb.go.tmpl b/internal/protoc/protoc-gen-zitadel/zitadel.pb.go.tmpl index adb71c42ff..0fb1c3e102 100644 --- a/internal/protoc/protoc-gen-zitadel/zitadel.pb.go.tmpl +++ b/internal/protoc/protoc-gen-zitadel/zitadel.pb.go.tmpl @@ -5,6 +5,7 @@ package {{.GoPackageName}} import ( "github.com/zitadel/zitadel/internal/api/authz" {{if .AuthContext}}"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"{{end}} + {{if .AuthContext}}"github.com/zitadel/zitadel/internal/api/grpc/server/connect_middleware"{{end}} ) var {{.ServiceName}}_AuthMethods = authz.MethodMapping { @@ -23,6 +24,13 @@ func (r *{{ $m.Name }}) OrganizationFromRequest() *middleware.Organization { Domain: r{{$m.OrgMethod}}.GetOrgDomain(), } } + +func (r *{{ $m.Name }}) OrganizationFromRequestConnect() *connect_middleware.Organization { + return &connect_middleware.Organization{ + ID: r{{$m.OrgMethod}}.GetOrgId(), + Domain: r{{$m.OrgMethod}}.GetOrgDomain(), + } +} {{ end }} {{ range $resp := .CustomHTTPResponses}} diff --git a/proto/zitadel/system.proto b/proto/zitadel/system.proto index 09b5559fb9..9b65fec600 100644 --- a/proto/zitadel/system.proto +++ b/proto/zitadel/system.proto @@ -118,7 +118,7 @@ service SystemService { // Returns a list of ZITADEL instances // - // Deprecated: Use [ListInstances](apis/resources/instance_service_v2/instance-service-list-instances.api.mdx) instead to list instances + // Deprecated: Use [ListInstances](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-list-instances.api.mdx) instead to list instances rpc ListInstances(ListInstancesRequest) returns (ListInstancesResponse) { option (google.api.http) = { post: "/instances/_search" @@ -136,7 +136,7 @@ service SystemService { // Returns the detail of an instance // - // Deprecated: Use [GetInstance](apis/resources/instance_service_v2/instance-service-get-instance.api.mdx) instead to get the details of the instance in context + // Deprecated: Use [GetInstance](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-get-instance.api.mdx) instead to get the details of the instance in context rpc GetInstance(GetInstanceRequest) returns (GetInstanceResponse) { option (google.api.http) = { get: "/instances/{instance_id}"; @@ -171,7 +171,7 @@ service SystemService { // Updates name of an existing instance // - // Deprecated: Use [UpdateInstance](apis/resources/instance_service_v2/instance-service-update-instance.api.mdx) instead to update the name of the instance in context + // Deprecated: Use [UpdateInstance](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-update-instance.api.mdx) instead to update the name of the instance in context rpc UpdateInstance(UpdateInstanceRequest) returns (UpdateInstanceResponse) { option (google.api.http) = { put: "/instances/{instance_id}" @@ -203,7 +203,7 @@ service SystemService { // Removes an instance // This might take some time // - // Deprecated: Use [DeleteInstance](apis/resources/instance_service_v2/instance-service-delete-instance.api.mdx) instead to delete an instance + // Deprecated: Use [DeleteInstance](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-delete-instance.api.mdx) instead to delete an instance rpc RemoveInstance(RemoveInstanceRequest) returns (RemoveInstanceResponse) { option (google.api.http) = { delete: "/instances/{instance_id}" @@ -234,7 +234,7 @@ service SystemService { // Checks if a domain exists // - // Deprecated: Use [ListCustomDomains](apis/resources/instance_service_v2/instance-service-list-custom-domains.api.mdx) instead to check existence of an instance + // Deprecated: Use [ListCustomDomains](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-list-custom-domains.api.mdx) instead to check existence of an instance rpc ExistsDomain(ExistsDomainRequest) returns (ExistsDomainResponse) { option (google.api.http) = { post: "/domains/{domain}/_exists"; @@ -270,7 +270,7 @@ service SystemService { // Adds a domain to an instance // - // Deprecated: Use [AddCustomDomain](apis/resources/instance_service_v2/instance-service-add-custom-domain.api.mdx) instead to add a custom domain to the instance in context + // Deprecated: Use [AddCustomDomain](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-add-custom-domain.api.mdx) instead to add a custom domain to the instance in context rpc AddDomain(AddDomainRequest) returns (AddDomainResponse) { option (google.api.http) = { post: "/instances/{instance_id}/domains"; @@ -288,7 +288,7 @@ service SystemService { // Removes the domain of an instance // - // Deprecated: Use [RemoveDomain](apis/resources/instance_service_v2/instance-service-remove-custom-domain.api.mdx) instead to remove a custom domain from the instance in context + // Deprecated: Use [RemoveDomain](apis/resources/instance_service_v2/zitadel-instance-v-2-beta-instance-service-remove-custom-domain.api.mdx) instead to remove a custom domain from the instance in context rpc RemoveDomain(RemoveDomainRequest) returns (RemoveDomainResponse) { option (google.api.http) = { delete: "/instances/{instance_id}/domains/{domain}"; diff --git a/proto/zitadel/webkey/v2/key.proto b/proto/zitadel/webkey/v2/key.proto new file mode 100644 index 0000000000..4ec85fa168 --- /dev/null +++ b/proto/zitadel/webkey/v2/key.proto @@ -0,0 +1,109 @@ +syntax = "proto3"; + +package zitadel.webkey.v2; + +import "google/protobuf/timestamp.proto"; +import "protoc-gen-openapiv2/options/annotations.proto"; +import "validate/validate.proto"; + +option go_package = "github.com/zitadel/zitadel/pkg/grpc/webkey/v2;webkey"; + +enum State { + STATE_UNSPECIFIED = 0; + // A newly created key is in the initial state and published to the public key endpoint. + STATE_INITIAL = 1; + // The active key is used to sign tokens. Only one key can be active at a time. + STATE_ACTIVE = 2; + // The inactive key is not used to sign tokens anymore, but still published to the public key endpoint. + STATE_INACTIVE = 3; + // The removed key is not used to sign tokens anymore and not published to the public key endpoint. + STATE_REMOVED = 4; +} + +message WebKey { + // The unique identifier of the key. + string id = 1 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"69629012906488334\""; + } + ]; + // The timestamp of the key creation. + google.protobuf.Timestamp creation_date = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"2024-12-18T07:50:47.492Z\""; + } + ]; + // The timestamp of the last change to the key (e.g. creation, activation, deactivation). + google.protobuf.Timestamp change_date = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"2025-01-23T10:34:18.051Z\""; + } + ]; + // State of the key + State state = 4; + // Configured type of the key (either RSA, ECDSA or ED25519) + oneof key { + RSA rsa = 5; + ECDSA ecdsa = 6; + ED25519 ed25519 = 7; + } +} + +message RSA { + // Bit size of the RSA key. Default is 2048 bits. + RSABits bits = 1 [ + (validate.rules).enum = {defined_only: true, not_in: [0]}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + default: "RSA_BITS_2048"; + } + ]; + // Signing algrithm used. Default is SHA256. + RSAHasher hasher = 2 [ + (validate.rules).enum = {defined_only: true, not_in: [0]}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + default: "RSA_HASHER_SHA256"; + } + ]; +} + +enum RSABits { + RSA_BITS_UNSPECIFIED = 0; + // 2048 bit RSA key + RSA_BITS_2048 = 1; + // 3072 bit RSA key + RSA_BITS_3072 = 2; + // 4096 bit RSA key + RSA_BITS_4096 = 3; +} + +enum RSAHasher { + RSA_HASHER_UNSPECIFIED = 0; + // SHA256 hashing algorithm resulting in the RS256 algorithm header + RSA_HASHER_SHA256 = 1; + // SHA384 hashing algorithm resulting in the RS384 algorithm header + RSA_HASHER_SHA384 = 2; + // SHA512 hashing algorithm resulting in the RS512 algorithm header + RSA_HASHER_SHA512 = 3; +} + +message ECDSA { + // Curve of the ECDSA key. Default is P-256. + ECDSACurve curve = 1 [ + (validate.rules).enum = {defined_only: true, not_in: [0]}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + default: "ECDSA_CURVE_P256"; + } + ]; +} + +enum ECDSACurve { + ECDSA_CURVE_UNSPECIFIED = 0; + // NIST P-256 curve resulting in the ES256 algorithm header + ECDSA_CURVE_P256 = 1; + // NIST P-384 curve resulting in the ES384 algorithm header + ECDSA_CURVE_P384 = 2; + // NIST P-512 curve resulting in the ES512 algorithm header + ECDSA_CURVE_P512 = 3; +} + +message ED25519 {} diff --git a/proto/zitadel/webkey/v2/webkey_service.proto b/proto/zitadel/webkey/v2/webkey_service.proto new file mode 100644 index 0000000000..f29f291c38 --- /dev/null +++ b/proto/zitadel/webkey/v2/webkey_service.proto @@ -0,0 +1,335 @@ +syntax = "proto3"; + +package zitadel.webkey.v2; + +import "google/api/annotations.proto"; +import "google/api/field_behavior.proto"; +import "google/protobuf/timestamp.proto"; +import "protoc-gen-openapiv2/options/annotations.proto"; +import "validate/validate.proto"; +import "zitadel/protoc_gen_zitadel/v2/options.proto"; +import "zitadel/webkey/v2/key.proto"; + +option go_package = "github.com/zitadel/zitadel/pkg/grpc/webkey/v2;webkey"; + +option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = { + info: { + title: "Web Key Service"; + version: "2.0"; + description: "This API is intended to manage web keys for a ZITADEL instance, used to sign and validate OIDC tokens.\n\nThe public key endpoint (outside of this service) is used to retrieve the public keys of the active and inactive keys.\n\nPlease make sure to enable the `web_key` feature flag on your instance to use this service."; + contact:{ + name: "ZITADEL" + url: "https://zitadel.com" + email: "hi@zitadel.com" + } + license: { + name: "Apache 2.0", + url: "https://github.com/zitadel/zitadel/blob/main/LICENSING.md"; + }; + }; + schemes: HTTPS; + schemes: HTTP; + + consumes: "application/json"; + produces: "application/json"; + + consumes: "application/grpc"; + produces: "application/grpc"; + + consumes: "application/grpc-web+proto"; + produces: "application/grpc-web+proto"; + + host: "$CUSTOM-DOMAIN"; + base_path: "/"; + + external_docs: { + description: "Detailed information about ZITADEL", + url: "https://zitadel.com/docs" + } + security_definitions: { + security: { + key: "OAuth2"; + value: { + type: TYPE_OAUTH2; + flow: FLOW_ACCESS_CODE; + authorization_url: "$CUSTOM-DOMAIN/oauth/v2/authorize"; + token_url: "$CUSTOM-DOMAIN/oauth/v2/token"; + scopes: { + scope: { + key: "openid"; + value: "openid"; + } + scope: { + key: "urn:zitadel:iam:org:project:id:zitadel:aud"; + value: "urn:zitadel:iam:org:project:id:zitadel:aud"; + } + } + } + } + } + security: { + security_requirement: { + key: "OAuth2"; + value: { + scope: "openid"; + scope: "urn:zitadel:iam:org:project:id:zitadel:aud"; + } + } + } + responses: { + key: "403"; + value: { + description: "Returned when the user does not have permission to access the resource."; + schema: { + json_schema: { + ref: "#/definitions/rpcStatus"; + } + } + } + } + responses: { + key: "404"; + value: { + description: "Returned when the resource does not exist."; + schema: { + json_schema: { + ref: "#/definitions/rpcStatus"; + } + } + } + } +}; + +// Service to manage web keys for OIDC token signing and validation. +// The service provides methods to create, activate, delete and list web keys. +// The public key endpoint (outside of this service) is used to retrieve the public keys of the active and inactive keys. +// +// Please make sure to enable the `web_key` feature flag on your instance to use this service. +service WebKeyService { + // Create Web Key + // + // Generate a private and public key pair. The private key can be used to sign OIDC tokens after activation. + // The public key can be used to validate OIDC tokens. + // The newly created key will have the state `STATE_INITIAL` and is published to the public key endpoint. + // Note that the JWKs OIDC endpoint returns a cacheable response. + // + // If no key type is provided, a RSA key pair with 2048 bits and SHA256 hashing will be created. + // + // Required permission: + // - `iam.web_key.write` + // + // Required feature flag: + // - `web_key` + rpc CreateWebKey(CreateWebKeyRequest) returns (CreateWebKeyResponse) { + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "iam.web_key.write" + } + }; + } + + // Activate Web Key + // + // Switch the active signing web key. The previously active key will be deactivated. + // Note that the JWKs OIDC endpoint returns a cacheable response. + // Therefore it is not advised to activate a key that has been created within the cache duration (default is 5min), + // as the public key may not have been propagated to caches and clients yet. + // + // Required permission: + // - `iam.web_key.write` + // + // Required feature flag: + // - `web_key` + rpc ActivateWebKey(ActivateWebKeyRequest) returns (ActivateWebKeyResponse) { + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "iam.web_key.write" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + responses: { + key: "200" + value: { + description: "Web key activated successfully."; + } + }; + responses: { + key: "400" + value: { + description: "The feature flag `web_key` is not enabled."; + } + }; + responses: { + key: "404" + value: { + description: "The web key to active does not exist."; + } + }; + }; + } + + // Delete Web Key + // + // Delete a web key pair. Only inactive keys can be deleted. Once a key is deleted, + // any tokens signed by this key will be invalid. + // Note that the JWKs OIDC endpoint returns a cacheable response. + // In case the web key is not found, the request will return a successful response as + // the desired state is already achieved. + // You can check the change date in the response to verify if the web key was deleted during the request. + // + // Required permission: + // - `iam.web_key.delete` + // + // Required feature flag: + // - `web_key` + rpc DeleteWebKey(DeleteWebKeyRequest) returns (DeleteWebKeyResponse) { + option (google.api.http) = { + delete: "/v2/web_keys/{id}" + }; + + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "iam.web_key.delete" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + responses: { + key: "200" + value: { + description: "Web key deleted successfully."; + } + }; + responses: { + key: "400" + value: { + description: "The feature flag `web_key` is not enabled or the web key is currently active."; + } + }; + }; + } + + // List Web Keys + // + // List all web keys and their states. + // + // Required permission: + // - `iam.web_key.read` + // + // Required feature flag: + // - `web_key` + rpc ListWebKeys(ListWebKeysRequest) returns (ListWebKeysResponse) { + option (google.api.http) = { + get: "/v2/web_keys" + }; + + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "iam.web_key.read" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + responses: { + key: "200" + value: { + description: "List of all web keys."; + } + }; + responses: { + key: "400" + value: { + description: "The feature flag `web_key` is not enabled."; + } + }; + }; + } +} + +message CreateWebKeyRequest { + // The key type to create (RSA, ECDSA, ED25519). + // If no key type is provided, a RSA key pair with 2048 bits and SHA256 hashing will be created. + oneof key { + // Create a RSA key pair and specify the bit size and hashing algorithm. + // If no bits and hasher are provided, a RSA key pair with 2048 bits and SHA256 hashing will be created. + RSA rsa = 1; + // Create a ECDSA key pair and specify the curve. + // If no curve is provided, a ECDSA key pair with P-256 curve will be created. + ECDSA ecdsa = 2; + // Create a ED25519 key pair. + ED25519 ed25519 = 3; + } + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = { + example: "{\"rsa\":{\"bits\":\"RSA_BITS_2048\",\"hasher\":\"RSA_HASHER_SHA256\"}}"; + }; +} + +message CreateWebKeyResponse { + // The unique identifier of the newly created key. + string id = 1 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"69629012906488334\""; + } + ]; + // The timestamp of the key creation. + google.protobuf.Timestamp creation_date = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"2024-12-18T07:50:47.492Z\""; + } + ]; +} + +message ActivateWebKeyRequest { + string id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (google.api.field_behavior) = REQUIRED, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1, + max_length: 200, + example: "\"69629026806489455\""; + } + ]; +} + +message ActivateWebKeyResponse { + // The timestamp of the activation of the key. + google.protobuf.Timestamp change_date = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"2025-01-23T10:34:18.051Z\""; + } + ]; +} + +message DeleteWebKeyRequest { + string id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (google.api.field_behavior) = REQUIRED, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1, + max_length: 200, + example: "\"69629026806489455\""; + } + ]; +} + +message DeleteWebKeyResponse { + // The timestamp of the deletion of the key. + // Note that the deletion date is only guaranteed to be set if the deletion was successful during the request. + // In case the deletion occurred in a previous request, the deletion date might be empty. + google.protobuf.Timestamp deletion_date = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"2025-01-23T10:34:18.051Z\""; + } + ]; +} + +message ListWebKeysRequest {} + +message ListWebKeysResponse { + repeated WebKey web_keys = 1 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "[{\"id\":\"69629012906488334\",\"creationDate\":\"2024-12-18T07:50:47.492Z\",\"changeDate\":\"2024-12-18T08:04:47.492Z\",\"state\":\"STATE_ACTIVE\",\"rsa\":{\"bits\":\"RSA_BITS_2048\",\"hasher\":\"RSA_HASHER_SHA256\"}},{\"id\":\"69629012909346200\",\"creationDate\":\"2025-01-18T12:05:47.492Z\",\"state\":\"STATE_INITIAL\",\"ecdsa\":{\"curve\":\"ECDSA_CURVE_P256\"}}]"; + } + ]; +} \ No newline at end of file