mirror of
https://github.com/zitadel/zitadel.git
synced 2025-03-01 01:07:23 +00:00
perf(oidc): optimize token creation (#7822)
* implement code exchange * port tokenexchange to v2 tokens * implement refresh token * implement client credentials * implement jwt profile * implement device token * cleanup unused code * fix current unit tests * add user agent unit test * unit test domain package * need refresh token as argument * test commands create oidc session * test commands device auth * fix device auth build error * implicit for oidc session API * implement authorize callback handler for legacy implicit mode * upgrade oidc module to working draft * add missing auth methods and time * handle all errors in defer * do not fail auth request on error the oauth2 Go client automagically retries on any error. If we fail the auth request on the first error, the next attempt will always fail with the Errors.AuthRequest.NoCode, because the auth request state is already set to failed. The original error is then already lost and the oauth2 library does not return the original error. Therefore we should not fail the auth request. Might be worth discussing and perhaps send a bug report to Oauth2? * fix code flow tests by explicitly setting code exchanged * fix unit tests in command package * return allowed scope from client credential client * add device auth done reducer * carry nonce thru session into ID token * fix token exchange integration tests * allow project role scope prefix in client credentials client * gci formatting * do not return refresh token in client credentials and jwt profile * check org scope * solve linting issue on authorize callback error * end session based on v2 session ID * use preferred language and user agent ID for v2 access tokens * pin oidc v3.23.2 * add integration test for jwt profile and client credentials with org scopes * refresh token v1 to v2 * add user token v2 audit event * add activity trigger * cleanup and set panics for unused methods * use the encrypted code for v1 auth request get by code * add missing event translation * fix pipeline errors (hopefully) * fix another test * revert pointer usage of preferred language * solve browser info panic in device auth * remove duplicate entries in AMRToAuthMethodTypes to prevent future `mfa` claim * revoke v1 refresh token to prevent reuse * fix terminate oidc session * always return a new refresh toke in refresh token grant --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
parent
6cf9ca9f7e
commit
8e0c8393e9
36
go.mod
36
go.mod
@ -49,7 +49,7 @@ require (
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0
|
||||
github.com/pquerna/otp v1.4.0
|
||||
github.com/rakyll/statik v0.1.7
|
||||
github.com/rs/cors v1.10.1
|
||||
github.com/rs/cors v1.11.0
|
||||
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
|
||||
github.com/sony/sonyflake v1.2.0
|
||||
github.com/spf13/cobra v1.8.0
|
||||
@ -58,20 +58,20 @@ require (
|
||||
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
|
||||
github.com/ttacon/libphonenumber v1.2.1
|
||||
github.com/zitadel/logging v0.6.0
|
||||
github.com/zitadel/oidc/v3 v3.21.0
|
||||
github.com/zitadel/oidc/v3 v3.23.2
|
||||
github.com/zitadel/passwap v0.5.0
|
||||
github.com/zitadel/saml v0.1.3
|
||||
github.com/zitadel/schema v1.3.0
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.50.0
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0
|
||||
go.opentelemetry.io/otel v1.25.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.47.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.25.0
|
||||
go.opentelemetry.io/otel/metric v1.25.0
|
||||
go.opentelemetry.io/otel/sdk v1.25.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.25.0
|
||||
go.opentelemetry.io/otel/trace v1.25.0
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0
|
||||
go.opentelemetry.io/otel v1.26.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.26.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.26.0
|
||||
go.opentelemetry.io/otel/metric v1.26.0
|
||||
go.opentelemetry.io/otel/sdk v1.26.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0
|
||||
go.opentelemetry.io/otel/trace v1.26.0
|
||||
go.uber.org/mock v0.4.0
|
||||
golang.org/x/crypto v0.22.0
|
||||
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8
|
||||
@ -80,9 +80,9 @@ require (
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/text v0.14.0
|
||||
google.golang.org/api v0.172.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240412170617-26222e5d3d56
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240429193739-8cf5692501f6
|
||||
google.golang.org/grpc v1.63.2
|
||||
google.golang.org/protobuf v1.33.0
|
||||
google.golang.org/protobuf v1.34.0
|
||||
sigs.k8s.io/yaml v1.4.0
|
||||
)
|
||||
|
||||
@ -116,7 +116,7 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20240412170617-26222e5d3d56 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240412170617-26222e5d3d56 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240429193739-8cf5692501f6 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
@ -183,8 +183,8 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/prometheus/client_golang v1.19.0
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.52.3 // indirect
|
||||
github.com/prometheus/procfs v0.13.0 // indirect
|
||||
github.com/prometheus/common v0.53.0 // indirect
|
||||
github.com/prometheus/procfs v0.14.0 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/russellhaering/goxmldsig v1.4.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
@ -196,7 +196,7 @@ require (
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.2.0 // indirect
|
||||
golang.org/x/sys v0.19.0
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
|
72
go.sum
72
go.sum
@ -618,16 +618,16 @@ github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8
|
||||
github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA=
|
||||
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
||||
github.com/prometheus/common v0.15.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s=
|
||||
github.com/prometheus/common v0.52.3 h1:5f8uj6ZwHSscOGNdIQg6OiZv/ybiK2CO2q2drVZAQSA=
|
||||
github.com/prometheus/common v0.52.3/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE=
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||
github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A=
|
||||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||
github.com/prometheus/procfs v0.3.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||
github.com/prometheus/procfs v0.13.0 h1:GqzLlQyfsPbaEHaQkO7tbDlriv/4o5Hudv6OXHGKX7o=
|
||||
github.com/prometheus/procfs v0.13.0/go.mod h1:cd4PFCR54QLnGKPaKGA6l+cfuNXtht43ZKY6tow0Y1g=
|
||||
github.com/prometheus/procfs v0.14.0 h1:Lw4VdGGoKEZilJsayHf0B+9YgLGREba2C6xr+Fdfq6s=
|
||||
github.com/prometheus/procfs v0.14.0/go.mod h1:XL+Iwz8k8ZabyZfMFHPiilCniixqQarAy5Mu67pHlNQ=
|
||||
github.com/rakyll/statik v0.1.7 h1:OF3QCZUuyPxuGEP7B4ypUa7sB/iHtqOTDYZXGM8KOdQ=
|
||||
github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Unghqrcc=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
@ -639,8 +639,8 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
||||
github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo=
|
||||
github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po=
|
||||
github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys=
|
||||
@ -731,8 +731,8 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8=
|
||||
github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
github.com/zitadel/logging v0.6.0 h1:t5Nnt//r+m2ZhhoTmoPX+c96pbMarqJvW1Vq6xFTank=
|
||||
github.com/zitadel/logging v0.6.0/go.mod h1:Y4CyAXHpl3Mig6JOszcV5Rqqsojj+3n7y2F591Mp/ow=
|
||||
github.com/zitadel/oidc/v3 v3.21.0 h1:dvhPLAOCJQHxZq+1vqd2+TYu1EzwrHhnPSSh4nVamgo=
|
||||
github.com/zitadel/oidc/v3 v3.21.0/go.mod h1:3uCwJc680oWoTBdzIppMZQS+VNxq+sVcwgodbreuatM=
|
||||
github.com/zitadel/oidc/v3 v3.23.2 h1:vRUM6SKudr6WR/lqxue4cvCbgR+IdEJGVBklucKKXgk=
|
||||
github.com/zitadel/oidc/v3 v3.23.2/go.mod h1:9snlhm3W/GNURqxtchjL1AAuClWRZ2NTkn9sLs1WYfM=
|
||||
github.com/zitadel/passwap v0.5.0 h1:kFMoRyo0GnxtOz7j9+r/CsRwSCjHGRaAKoUe69NwPvs=
|
||||
github.com/zitadel/passwap v0.5.0/go.mod h1:uqY7D3jqdTFcKsW0Q3Pcv5qDMmSHpVTzUZewUKC1KZA=
|
||||
github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM=
|
||||
@ -746,28 +746,28 @@ go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk=
|
||||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.50.0 h1:zvpPXY7RfYAGSdYQLjp6zxdJNSYD/+FFoCTQN9IPxBs=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.50.0/go.mod h1:BMn8NB1vsxTljvuorms2hyOs8IBuuBEq0pl7ltOfy30=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0 h1:cEPbyTSEHlQR89XVlyo78gqluF8Y3oMeBkXGWzQsfXY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0/go.mod h1:DKdbWcT4GH1D0Y3Sqt/PFXt2naRKDWtU+eE6oLdFNA8=
|
||||
go.opentelemetry.io/otel v1.25.0 h1:gldB5FfhRl7OJQbUHt/8s0a7cE8fbsPAtdpRaApKy4k=
|
||||
go.opentelemetry.io/otel v1.25.0/go.mod h1:Wa2ds5NOXEMkCmUou1WA7ZBfLTHWIsp034OVD7AO+Vg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0 h1:dT33yIHtmsqpixFsSQPwNeY5drM9wTcoL8h0FWF4oGM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0/go.mod h1:h95q0LBGh7hlAC08X2DhSeyIG02YQ0UyioTCVAqRPmc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0 h1:vOL89uRfOCCNIjkisd0r7SEdJF3ZJFyCNY34fdZs8eU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0/go.mod h1:8GlBGcDk8KKi7n+2S4BT/CPZQYH3erLu0/k64r1MYgo=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.47.0 h1:OL6yk1Z/pEGdDnrBbxSsH+t4FY1zXfBRGd7bjwhlMLU=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.47.0/go.mod h1:xF3N4OSICZDVbbYZydz9MHFro1RjmkPUKEvar2utG+Q=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.25.0 h1:0vZZdECYzhTt9MKQZ5qQ0V+J3MFu4MQaQ3COfugF+FQ=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.25.0/go.mod h1:e7iXx3HjaSSBXfy9ykVUlupS2Vp7LBIBuT21ousM2Hk=
|
||||
go.opentelemetry.io/otel/metric v1.25.0 h1:LUKbS7ArpFL/I2jJHdJcqMGxkRdxpPHE0VU/D4NuEwA=
|
||||
go.opentelemetry.io/otel/metric v1.25.0/go.mod h1:rkDLUSd2lC5lq2dFNrX9LGAbINP5B7WBkC78RXCpH5s=
|
||||
go.opentelemetry.io/otel/sdk v1.25.0 h1:PDryEJPC8YJZQSyLY5eqLeafHtG+X7FWnf3aXMtxbqo=
|
||||
go.opentelemetry.io/otel/sdk v1.25.0/go.mod h1:oFgzCM2zdsxKzz6zwpTZYLLQsFwc+K0daArPdIhuxkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.25.0 h1:7CiHOy08LbrxMAp4vWpbiPcklunUshVpAvGBrdDRlGw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.25.0/go.mod h1:LzwoKptdbBBdYfvtGCzGwk6GWMA3aUzBOwtQpR6Nz7o=
|
||||
go.opentelemetry.io/otel/trace v1.25.0 h1:tqukZGLwQYRIFtSQM2u2+yfMVTgGVeqRLPUYx1Dq6RM=
|
||||
go.opentelemetry.io/otel/trace v1.25.0/go.mod h1:hCCs70XM/ljO+BeQkyFnbK28SBIJ/Emuha+ccrCRT7I=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 h1:A3SayB3rNyt+1S6qpI9mHPkeHTZbD7XILEqWnYZb2l0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0/go.mod h1:27iA5uvhuRNmalO+iEUdVn5ZMj2qy10Mm+XRIpRmyuU=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
|
||||
go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs=
|
||||
go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0 h1:1u/AyyOqAWzy+SkPxDpahCNZParHV8Vid1RnI2clyDE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0/go.mod h1:z46paqbJ9l7c9fIPCXTqTGwhQZ5XoTIsfeFYWboizjs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.26.0 h1:Waw9Wfpo/IXzOI8bCB7DIk+0JZcqqsyn1JFnAc+iam8=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.26.0/go.mod h1:wnJIG4fOqyynOnnQF/eQb4/16VlX2EJAHhHgqIqWfAo=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.26.0 h1:0W5o9SzoR15ocYHEQfvfipzcNog1lBxOLfnex91Hk6s=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.26.0/go.mod h1:zVZ8nz+VSggWmnh6tTsJqXQ7rU4xLwRtna1M4x5jq58=
|
||||
go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30=
|
||||
go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4=
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8=
|
||||
go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE=
|
||||
go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA=
|
||||
go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0=
|
||||
go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IOkz94=
|
||||
go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A=
|
||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
@ -991,10 +991,10 @@ google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEY
|
||||
google.golang.org/genproto v0.0.0-20210126160654-44e461bb6506/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20240412170617-26222e5d3d56 h1:LlcUFJ4BLmJVS4Kly+WCK7LQqcevmycHj88EPgyhNx8=
|
||||
google.golang.org/genproto v0.0.0-20240412170617-26222e5d3d56/go.mod h1:n1CaIKYMIlxFt1zJE/1kU40YpSL0drGMbl0Idum1VSs=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240412170617-26222e5d3d56 h1:KuFzeG+qPmpT8KpJXcrKAyeHhn64dgEICWlccP9qp0U=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240412170617-26222e5d3d56/go.mod h1:wTHjrkbcS8AoQbb/0v9bFIPItZQPAsyVfgG9YPUhjAM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240412170617-26222e5d3d56 h1:zviK8GX4VlMstrK3JkexM5UHjH1VOkRebH9y3jhSBGk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240412170617-26222e5d3d56/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240429193739-8cf5692501f6 h1:DTJM0R8LECCgFeUwApvcEJHz85HLagW8uRENYxHh1ww=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240429193739-8cf5692501f6/go.mod h1:10yRODfgim2/T8csjQsMPgZOMvtytXKTDRzH6HRGzRw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240429193739-8cf5692501f6 h1:DujSIu+2tC9Ht0aPNA7jgj23Iq8Ewi5sgkQ++wdvonE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240429193739-8cf5692501f6/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
|
||||
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM=
|
||||
@ -1022,8 +1022,8 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
|
||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
google.golang.org/protobuf v1.34.0 h1:Qo/qEd2RZPCf2nKuorzksSknv0d3ERwp1vFG38gSmH4=
|
||||
google.golang.org/protobuf v1.34.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
object_pb "github.com/zitadel/zitadel/internal/api/grpc/object"
|
||||
|
@ -112,7 +112,7 @@ func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID str
|
||||
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
||||
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
|
||||
} else {
|
||||
callback, err = oidc.CreateTokenCallbackURL(ctx, authReq, s.op.Provider())
|
||||
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"golang.org/x/text/language"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
@ -352,7 +353,12 @@ func (s *Server) checksToCommand(ctx context.Context, checks *session.Checks) ([
|
||||
if !user.State.IsEnabled() {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SESSION-Gj4ko", "Errors.User.NotActive")
|
||||
}
|
||||
sessionChecks = append(sessionChecks, command.CheckUser(user.ID, user.ResourceOwner))
|
||||
|
||||
var preferredLanguage *language.Tag
|
||||
if user.Human != nil && !user.Human.PreferredLanguage.IsRoot() {
|
||||
preferredLanguage = &user.Human.PreferredLanguage
|
||||
}
|
||||
sessionChecks = append(sessionChecks, command.CheckUser(user.ID, user.ResourceOwner, preferredLanguage))
|
||||
}
|
||||
if password := checks.GetPassword(); password != nil {
|
||||
sessionChecks = append(sessionChecks, command.CheckPassword(password.GetPassword()))
|
||||
|
@ -6,8 +6,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
@ -19,19 +21,20 @@ import (
|
||||
)
|
||||
|
||||
type accessToken struct {
|
||||
tokenID string
|
||||
userID string
|
||||
resourceOwner string
|
||||
subject string
|
||||
clientID string
|
||||
audience []string
|
||||
scope []string
|
||||
authMethods []domain.UserAuthMethodType
|
||||
authTime time.Time
|
||||
tokenCreation time.Time
|
||||
tokenExpiration time.Time
|
||||
isPAT bool
|
||||
actor *domain.TokenActor
|
||||
tokenID string
|
||||
userID string
|
||||
resourceOwner string
|
||||
subject string
|
||||
preferredLanguage *language.Tag
|
||||
clientID string
|
||||
audience []string
|
||||
scope []string
|
||||
authMethods []domain.UserAuthMethodType
|
||||
authTime time.Time
|
||||
tokenCreation time.Time
|
||||
tokenExpiration time.Time
|
||||
isPAT bool
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
|
||||
var ErrInvalidTokenFormat = errors.New("invalid token format")
|
||||
@ -73,35 +76,41 @@ func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (_ *accessTo
|
||||
}
|
||||
|
||||
func accessTokenV1(tokenID, subject string, token *model.TokenView) *accessToken {
|
||||
var preferredLanguage *language.Tag
|
||||
if token.PreferredLanguage != "" {
|
||||
preferredLanguage = gu.Ptr(language.Make(token.PreferredLanguage))
|
||||
}
|
||||
return &accessToken{
|
||||
tokenID: tokenID,
|
||||
userID: token.UserID,
|
||||
resourceOwner: token.ResourceOwner,
|
||||
subject: subject,
|
||||
clientID: token.ApplicationID,
|
||||
audience: token.Audience,
|
||||
scope: token.Scopes,
|
||||
tokenCreation: token.CreationDate,
|
||||
tokenExpiration: token.Expiration,
|
||||
isPAT: token.IsPAT,
|
||||
actor: token.Actor,
|
||||
tokenID: tokenID,
|
||||
userID: token.UserID,
|
||||
resourceOwner: token.ResourceOwner,
|
||||
subject: subject,
|
||||
preferredLanguage: preferredLanguage,
|
||||
clientID: token.ApplicationID,
|
||||
audience: token.Audience,
|
||||
scope: token.Scopes,
|
||||
tokenCreation: token.CreationDate,
|
||||
tokenExpiration: token.Expiration,
|
||||
isPAT: token.IsPAT,
|
||||
actor: token.Actor,
|
||||
}
|
||||
}
|
||||
|
||||
func accessTokenV2(tokenID, subject string, token *query.OIDCSessionAccessTokenReadModel) *accessToken {
|
||||
return &accessToken{
|
||||
tokenID: tokenID,
|
||||
userID: token.UserID,
|
||||
resourceOwner: token.ResourceOwner,
|
||||
subject: subject,
|
||||
clientID: token.ClientID,
|
||||
audience: token.Audience,
|
||||
scope: token.Scope,
|
||||
authMethods: token.AuthMethods,
|
||||
authTime: token.AuthTime,
|
||||
tokenCreation: token.AccessTokenCreation,
|
||||
tokenExpiration: token.AccessTokenExpiration,
|
||||
actor: token.Actor,
|
||||
tokenID: tokenID,
|
||||
userID: token.UserID,
|
||||
resourceOwner: token.ResourceOwner,
|
||||
subject: subject,
|
||||
preferredLanguage: token.PreferredLanguage,
|
||||
clientID: token.ClientID,
|
||||
audience: token.Audience,
|
||||
scope: token.Scope,
|
||||
authMethods: token.AuthMethods,
|
||||
authTime: token.AuthTime,
|
||||
tokenCreation: token.AccessTokenCreation,
|
||||
tokenExpiration: token.AccessTokenExpiration,
|
||||
actor: token.Actor,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,10 @@
|
||||
package oidc
|
||||
|
||||
import "github.com/zitadel/zitadel/internal/domain"
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
// Password states that the users password has been verified
|
||||
@ -87,5 +91,5 @@ func AMRToAuthMethodTypes(amr []string) []domain.UserAuthMethodType {
|
||||
authMethods = append(authMethods, domain.UserAuthMethodTypeU2F)
|
||||
}
|
||||
}
|
||||
return authMethods
|
||||
return slices.Compact(authMethods) // remove duplicate entries
|
||||
}
|
||||
|
@ -3,14 +3,17 @@ package oidc
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
@ -64,18 +67,19 @@ func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.
|
||||
return nil, err
|
||||
}
|
||||
authRequest := &command.AuthRequest{
|
||||
LoginClient: loginClient,
|
||||
ClientID: req.ClientID,
|
||||
RedirectURI: req.RedirectURI,
|
||||
State: req.State,
|
||||
Nonce: req.Nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
ResponseType: ResponseTypeToBusiness(req.ResponseType),
|
||||
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
|
||||
Prompt: PromptToBusiness(req.Prompt),
|
||||
UILocales: UILocalesToBusiness(req.UILocales),
|
||||
MaxAge: MaxAgeToBusiness(req.MaxAge),
|
||||
LoginClient: loginClient,
|
||||
ClientID: req.ClientID,
|
||||
RedirectURI: req.RedirectURI,
|
||||
State: req.State,
|
||||
Nonce: req.Nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
NeedRefreshToken: slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||
ResponseType: ResponseTypeToBusiness(req.ResponseType),
|
||||
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
|
||||
Prompt: PromptToBusiness(req.Prompt),
|
||||
UILocales: UILocalesToBusiness(req.UILocales),
|
||||
MaxAge: MaxAgeToBusiness(req.MaxAge),
|
||||
}
|
||||
if req.LoginHint != "" {
|
||||
authRequest.LoginHint = &req.LoginHint
|
||||
@ -149,28 +153,7 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
|
||||
}
|
||||
|
||||
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
plainCode, err := o.decryptGrant(code)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "OIDC-ahLi2", "Errors.User.Code.Invalid")
|
||||
}
|
||||
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
|
||||
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{authReq}, nil
|
||||
}
|
||||
resp, err := o.repo.AuthRequestByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return AuthRequestFromBusiness(resp)
|
||||
panic(o.panicErr("AuthRequestByCode"))
|
||||
}
|
||||
|
||||
// decryptGrant decrypts a code or refresh_token
|
||||
@ -201,136 +184,40 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro
|
||||
}
|
||||
|
||||
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
return o.repo.DeleteAuthRequest(ctx, id)
|
||||
panic(o.panicErr("DeleteAuthRequest"))
|
||||
}
|
||||
|
||||
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
if authReq, ok := req.(*AuthRequestV2); ok {
|
||||
activity.Trigger(ctx, "", authReq.CurrentAuthRequest.UserID, activity.OIDCAccessToken, o.eventstore.FilterToQueryReducer)
|
||||
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
|
||||
}
|
||||
|
||||
userAgentID, applicationID, userOrgID, authTime, amr, reason, actor := getInfoFromRequest(req)
|
||||
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
resp, err := o.command.AddUserToken(setContextUserSystem(ctx), userOrgID, userAgentID, applicationID, req.GetSubject(), req.GetAudience(), req.GetScopes(), amr, accessTokenLifetime, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
// trigger activity log for authentication for user
|
||||
activity.Trigger(ctx, userOrgID, req.GetSubject(), activity.OIDCAccessToken, o.eventstore.FilterToQueryReducer)
|
||||
return resp.TokenID, resp.Expiration, nil
|
||||
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (string, time.Time, error) {
|
||||
panic(o.panicErr("CreateAccessToken"))
|
||||
}
|
||||
|
||||
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
// handle V2 request directly
|
||||
switch tokenReq := req.(type) {
|
||||
case *AuthRequestV2:
|
||||
// trigger activity log for authentication for user
|
||||
activity.Trigger(ctx, "", tokenReq.GetSubject(), activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
|
||||
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
|
||||
case *RefreshTokenRequestV2:
|
||||
// trigger activity log for authentication for user
|
||||
activity.Trigger(ctx, "", tokenReq.GetSubject(), activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
|
||||
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
|
||||
}
|
||||
|
||||
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences, reason, actor := getInfoFromRequest(req)
|
||||
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, zerrors.ThrowPreconditionFailed(err, "OIDC-Df2fq", "Errors.Internal")
|
||||
}
|
||||
if request, ok := req.(op.RefreshTokenRequest); ok {
|
||||
request.SetCurrentScopes(scopes)
|
||||
}
|
||||
|
||||
accessTokenLifetime, _, refreshTokenIdleExpiration, refreshTokenExpiration, err := o.getOIDCSettings(ctx)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
resp, token, err := o.command.AddAccessAndRefreshToken(setContextUserSystem(ctx), userOrgID, userAgentID, applicationID, req.GetSubject(),
|
||||
refreshToken, req.GetAudience(), scopes, authMethodsReferences, accessTokenLifetime,
|
||||
refreshTokenIdleExpiration, refreshTokenExpiration, authTime, reason, actor) //PLANNED: lifetime from client
|
||||
if err != nil {
|
||||
if zerrors.IsErrorInvalidArgument(err) {
|
||||
err = oidc.ErrInvalidGrant().WithParent(err)
|
||||
}
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
// trigger activity log for authentication for user
|
||||
activity.Trigger(ctx, userOrgID, req.GetSubject(), activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
|
||||
return resp.TokenID, token, resp.Expiration, nil
|
||||
func (o *OPStorage) CreateAccessAndRefreshTokens(context.Context, op.TokenRequest, string) (string, string, time.Time, error) {
|
||||
panic(o.panicErr("CreateAccessAndRefreshTokens"))
|
||||
}
|
||||
|
||||
func getInfoFromRequest(req op.TokenRequest) (agentID string, clientID string, userOrgID string, authTime time.Time, amr []string, reason domain.TokenReason, actor *domain.TokenActor) {
|
||||
func (*OPStorage) panicErr(method string) error {
|
||||
return fmt.Errorf("OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues", method)
|
||||
}
|
||||
|
||||
func getInfoFromRequest(req op.TokenRequest) (agentID, clientID, userOrgID string, authTime time.Time, amr []string, preferredLanguage *language.Tag, reason domain.TokenReason, actor *domain.TokenActor) {
|
||||
switch r := req.(type) {
|
||||
case *AuthRequest:
|
||||
return r.AgentID, r.ApplicationID, r.UserOrgID, r.AuthTime, r.GetAMR(), domain.TokenReasonAuthRequest, nil
|
||||
return r.AgentID, r.ApplicationID, r.UserOrgID, r.AuthTime, r.GetAMR(), r.PreferredLanguage, domain.TokenReasonAuthRequest, nil
|
||||
case *RefreshTokenRequest:
|
||||
return r.UserAgentID, r.ClientID, "", r.AuthTime, r.AuthMethodsReferences, domain.TokenReasonRefresh, r.Actor
|
||||
return r.UserAgentID, r.ClientID, "", r.AuthTime, r.AuthMethodsReferences, nil, domain.TokenReasonRefresh, r.Actor
|
||||
case op.IDTokenRequest:
|
||||
return "", r.GetClientID(), "", r.GetAuthTime(), r.GetAMR(), domain.TokenReasonAuthRequest, nil
|
||||
return "", r.GetClientID(), "", r.GetAuthTime(), r.GetAMR(), nil, domain.TokenReasonAuthRequest, nil
|
||||
case *oidc.JWTTokenRequest:
|
||||
return "", "", "", r.GetAuthTime(), nil, domain.TokenReasonJWTProfile, nil
|
||||
return "", "", "", r.GetAuthTime(), nil, nil, domain.TokenReasonJWTProfile, nil
|
||||
case *clientCredentialsRequest:
|
||||
return "", "", "", time.Time{}, nil, domain.TokenReasonClientCredentials, nil
|
||||
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonClientCredentials, nil
|
||||
default:
|
||||
return "", "", "", time.Time{}, nil, domain.TokenReasonAuthRequest, nil
|
||||
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonAuthRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
plainToken, err := o.decryptGrant(refreshToken)
|
||||
if err != nil {
|
||||
return nil, op.ErrInvalidRefreshToken
|
||||
}
|
||||
if strings.HasPrefix(plainToken, command.IDPrefixV2) {
|
||||
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// trigger activity log for authentication for user
|
||||
activity.Trigger(ctx, "", oidcSession.UserID, activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
|
||||
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
|
||||
}
|
||||
|
||||
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// trigger activity log for use of refresh token for user
|
||||
activity.Trigger(ctx, tokenView.ResourceOwner, tokenView.UserID, activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
|
||||
return RefreshTokenRequestFromBusiness(tokenView), nil
|
||||
panic("TokenRequestByRefreshToken should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues")
|
||||
}
|
||||
|
||||
func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) {
|
||||
@ -368,18 +255,19 @@ func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionR
|
||||
}()
|
||||
|
||||
// check for the login client header
|
||||
// and if not provided, terminate the session using the V1 method
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient == "" {
|
||||
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
|
||||
}
|
||||
|
||||
// in case there are not id_token_hint, redirect to the UI and let it decide which session to terminate
|
||||
if endSessionRequest.IDTokenHintClaims == nil {
|
||||
// in case there is no id_token_hint, redirect to the UI and let it decide which session to terminate
|
||||
if headers.Get(LoginClientHeader) != "" && endSessionRequest.IDTokenHintClaims == nil {
|
||||
return o.defaultLogoutURLV2 + endSessionRequest.RedirectURI, nil
|
||||
}
|
||||
|
||||
// terminate the session of the id_token_hint
|
||||
// If there is no login client header and no id_token_hint or the id_token_hint does not have a session ID,
|
||||
// do a v1 Terminate session.
|
||||
if endSessionRequest.IDTokenHintClaims == nil || endSessionRequest.IDTokenHintClaims.SessionID == "" {
|
||||
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
|
||||
}
|
||||
|
||||
// terminate the v2 session of the id_token_hint
|
||||
_, err = o.command.TerminateSessionWithoutTokenCheck(ctx, endSessionRequest.IDTokenHintClaims.SessionID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -543,18 +431,6 @@ func setContextUserSystem(ctx context.Context) context.Context {
|
||||
return authz.SetCtxData(ctx, data)
|
||||
}
|
||||
|
||||
func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, idTokenLifetime, refreshTokenIdleExpiration, refreshTokenExpiration time.Duration, _ error) {
|
||||
oidcSettings, err := o.query.OIDCSettingsByAggID(ctx, authz.GetInstance(ctx).InstanceID())
|
||||
if err != nil && !zerrors.IsNotFound(err) {
|
||||
return time.Duration(0), time.Duration(0), time.Duration(0), time.Duration(0), err
|
||||
}
|
||||
|
||||
if oidcSettings != nil {
|
||||
return oidcSettings.AccessTokenLifetime, oidcSettings.IdTokenLifetime, oidcSettings.RefreshTokenIdleExpiration, oidcSettings.RefreshTokenExpiration, nil
|
||||
}
|
||||
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
|
||||
}
|
||||
|
||||
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
|
||||
e := struct {
|
||||
Error string `schema:"error"`
|
||||
@ -593,19 +469,140 @@ func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authoriz
|
||||
return callback, err
|
||||
}
|
||||
|
||||
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) {
|
||||
provider := s.Provider()
|
||||
opClient, err := provider.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
|
||||
client, ok := opClient.(*Client)
|
||||
if !ok {
|
||||
return "", zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
|
||||
}
|
||||
|
||||
session, state, err := s.command.CreateOIDCSessionFromAuthRequest(
|
||||
setContextUserSystem(ctx),
|
||||
req.GetID(),
|
||||
implicitFlowComplianceChecker(),
|
||||
slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
|
||||
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, provider.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, err
|
||||
}
|
||||
|
||||
func implicitFlowComplianceChecker() command.AuthRequestComplianceChecker {
|
||||
return func(_ context.Context, authReq *command.AuthRequestWriteModel) error {
|
||||
if err := authReq.CheckAuthenticated(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
authorizer := s.Provider()
|
||||
authReq, err := func() (authReq op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(r.Context())
|
||||
r = r.WithContext(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
id, err := op.ParseAuthorizeCallbackRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq, err = authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !authReq.Done() {
|
||||
return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.")
|
||||
}
|
||||
return authReq, s.authResponse(authReq, authorizer, w, r)
|
||||
}()
|
||||
if err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authResponse(authReq op.AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) {
|
||||
ctx, span := tracing.NewSpan(r.Context())
|
||||
r = r.WithContext(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, authReq.GetClientID())
|
||||
if err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
}
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
op.AuthResponseCode(w, r, authReq, authorizer)
|
||||
return nil
|
||||
}
|
||||
return s.authResponseToken(authReq, authorizer, client, w, r)
|
||||
}
|
||||
|
||||
func (s *Server) authResponseToken(authReq op.AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) {
|
||||
ctx, span := tracing.NewSpan(r.Context())
|
||||
r = r.WithContext(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
client, ok := opClient.(*Client)
|
||||
if !ok {
|
||||
return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
|
||||
}
|
||||
|
||||
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
|
||||
scope := authReq.GetScopes()
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
authReq.GetSubject(),
|
||||
userOrgID,
|
||||
client.client.ClientID,
|
||||
scope,
|
||||
authReq.GetAudience(),
|
||||
AMRToAuthMethodTypes(authMethodsReferences),
|
||||
authTime,
|
||||
authReq.GetNonce(),
|
||||
preferredLanguage,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: &userAgentID,
|
||||
},
|
||||
reason,
|
||||
actor,
|
||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||
)
|
||||
if err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
}
|
||||
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion)
|
||||
if err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
}
|
||||
|
||||
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
||||
if err = op.AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder()); err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -69,7 +70,7 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer, false)
|
||||
publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1040,3 +1041,26 @@ func (s *Server) verifyClientSecret(ctx context.Context, client *query.OIDCClien
|
||||
s.command.OIDCSecretCheckSucceeded(ctx, client.AppID, client.ProjectID, client.Settings.ResourceOwner, updated)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) checkOrgScopes(ctx context.Context, user *query.User, scopes []string) ([]string, error) {
|
||||
if slices.ContainsFunc(scopes, func(scope string) bool {
|
||||
return strings.HasPrefix(scope, domain.OrgDomainPrimaryScope)
|
||||
}) {
|
||||
org, err := s.query.OrgByID(ctx, false, user.ResourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scopes = slices.DeleteFunc(scopes, func(scope string) bool {
|
||||
if domain, ok := strings.CutPrefix(scope, domain.OrgDomainPrimaryScope); ok {
|
||||
return domain != org.Domain
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
return slices.DeleteFunc(scopes, func(scope string) bool {
|
||||
if orgID, ok := strings.CutPrefix(scope, domain.OrgIDScope); ok {
|
||||
return orgID != user.ResourceOwner
|
||||
}
|
||||
return false
|
||||
}), nil
|
||||
}
|
||||
|
@ -104,28 +104,7 @@ func (c *Client) AccessTokenType() op.AccessTokenType {
|
||||
}
|
||||
|
||||
func (c *Client) IsScopeAllowed(scope string) bool {
|
||||
if strings.HasPrefix(scope, domain.OrgDomainPrimaryScope) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(scope, domain.OrgIDScope) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(scope, domain.ProjectIDScope) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(scope, domain.SelectIDPScope) {
|
||||
return true
|
||||
}
|
||||
if scope == ScopeUserMetaData {
|
||||
return true
|
||||
}
|
||||
if scope == ScopeResourceOwner {
|
||||
return true
|
||||
}
|
||||
if scope == ScopeProjectsRoles {
|
||||
return true
|
||||
}
|
||||
return slices.Contains(c.allowedScopes, scope)
|
||||
return isScopeAllowed(scope, c.allowedScopes...)
|
||||
}
|
||||
|
||||
func (c *Client) ClockSkew() time.Duration {
|
||||
@ -249,3 +228,28 @@ func clientIDFromCredentials(cc *op.ClientCredentials) (clientID string, asserti
|
||||
}
|
||||
return cc.ClientID, false, nil
|
||||
}
|
||||
|
||||
func isScopeAllowed(scope string, allowedScopes ...string) bool {
|
||||
if strings.HasPrefix(scope, domain.OrgDomainPrimaryScope) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(scope, domain.OrgIDScope) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(scope, domain.ProjectIDScope) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(scope, domain.SelectIDPScope) {
|
||||
return true
|
||||
}
|
||||
if scope == ScopeUserMetaData {
|
||||
return true
|
||||
}
|
||||
if scope == ScopeResourceOwner {
|
||||
return true
|
||||
}
|
||||
if scope == ScopeProjectsRoles {
|
||||
return true
|
||||
}
|
||||
return slices.Contains(allowedScopes, scope)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
@ -136,9 +137,8 @@ func (c *clientCredentialsClient) RestrictAdditionalAccessTokenScopes() func(sco
|
||||
}
|
||||
}
|
||||
|
||||
// IsScopeAllowed returns null false as the check is executed during the auth request validation
|
||||
func (c *clientCredentialsClient) IsScopeAllowed(scope string) bool {
|
||||
return false
|
||||
return isScopeAllowed(scope) || strings.HasPrefix(scope, ScopeProjectRolePrefix)
|
||||
}
|
||||
|
||||
// IDTokenUserinfoClaimsAssertion returns null false as no id_token is issued
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
@ -22,7 +21,6 @@ import (
|
||||
"github.com/zitadel/zitadel/pkg/grpc/authn"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user"
|
||||
)
|
||||
|
||||
func TestServer_Introspect(t *testing.T) {
|
||||
@ -346,73 +344,3 @@ func createInvalidKeyData(t testing.TB, client *management.AddOIDCAppResponse) [
|
||||
require.NoError(t, err)
|
||||
return data
|
||||
}
|
||||
|
||||
func TestServer_CreateAccessToken_ClientCredentials(t *testing.T) {
|
||||
_, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
|
||||
require.NoError(t, err)
|
||||
|
||||
type clientDetails struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
keyData []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing client ID error",
|
||||
clientID: "",
|
||||
clientSecret: clientSecret,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "client not found error",
|
||||
clientID: "foo",
|
||||
clientSecret: clientSecret,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "machine user without secret error",
|
||||
clientID: func() string {
|
||||
name := gofakeit.Username()
|
||||
_, err := Tester.Client.Mgmt.AddMachineUser(CTX, &management.AddMachineUserRequest{
|
||||
Name: name,
|
||||
UserName: name,
|
||||
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return name
|
||||
}(),
|
||||
clientSecret: clientSecret,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong secret error",
|
||||
clientID: clientID,
|
||||
clientSecret: "bar",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), tt.clientID, tt.clientSecret, redirectURI, []string{oidc.ScopeOpenID})
|
||||
require.NoError(t, err)
|
||||
tokens, err := rp.ClientCredentials(CTX, provider, nil)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, tokens)
|
||||
assert.NotEmpty(t, tokens.AccessToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,14 +2,14 @@ package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/ui/login"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
@ -80,7 +80,7 @@ func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, devi
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scope, audience)
|
||||
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scope, audience, slices.Contains(scope, oidc.ScopeOfflineAccess))
|
||||
if err == nil {
|
||||
logger.SetFields("details", details).Debug(logMsg)
|
||||
}
|
||||
@ -88,50 +88,6 @@ func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, devi
|
||||
return err
|
||||
}
|
||||
|
||||
func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationState {
|
||||
return &op.DeviceAuthorizationState{
|
||||
ClientID: d.ClientID,
|
||||
Scopes: d.Scopes,
|
||||
Audience: d.Audience,
|
||||
Expires: d.Expires,
|
||||
Done: d.State.Done(),
|
||||
Denied: d.State.Denied(),
|
||||
Subject: d.Subject,
|
||||
AMR: AuthMethodTypesToAMR(d.UserAuthMethods),
|
||||
AuthTime: d.AuthTime,
|
||||
}
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizatonState retrieves the current state of the Device Authorization process.
|
||||
// It implements the [op.DeviceAuthorizationStorage] interface and is used by devices that
|
||||
// are polling until they successfully receive a token or we indicate a denied or expired state.
|
||||
// As generated user codes are of low entropy, this implementation also takes care or
|
||||
// device authorization request cleanup, when it has been Approved, Denied or Expired.
|
||||
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
deviceAuth, err := o.query.DeviceAuthByDeviceCode(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logging.WithFields(
|
||||
"device_code", deviceCode,
|
||||
"expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes,
|
||||
"subject", deviceAuth.Subject, "state", deviceAuth.State,
|
||||
).Debug("device authorization state")
|
||||
|
||||
// Cancel the request if it is expired, only if it wasn't Done meanwhile
|
||||
if !deviceAuth.State.Done() && deviceAuth.Expires.Before(time.Now()) {
|
||||
_, err = o.command.CancelDeviceAuth(ctx, deviceAuth.DeviceCode, domain.DeviceAuthCanceledExpired)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceAuth.State = domain.DeviceAuthStateExpired
|
||||
}
|
||||
|
||||
return newDeviceAuthorizationState(deviceAuth), nil
|
||||
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, _, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -140,10 +140,13 @@ func NewServer(
|
||||
fallbackLogger: fallbackLogger,
|
||||
hasher: hasher,
|
||||
signingKeyAlgorithm: config.SigningKeyAlgorithm,
|
||||
encAlg: encryptionAlg,
|
||||
opCrypto: op.NewAESCrypto(opConfig.CryptoKey),
|
||||
assetAPIPrefix: assets.AssetAPI(externalSecure),
|
||||
}
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
server.Handler = op.RegisterLegacyServer(server,
|
||||
server.authorizeCallbackHandler,
|
||||
op.WithFallbackLogger(fallbackLogger),
|
||||
op.WithHTTPMiddleware(
|
||||
middleware.MetricsHandler(metricTypes),
|
||||
|
@ -37,7 +37,10 @@ type Server struct {
|
||||
fallbackLogger *slog.Logger
|
||||
hasher *crypto.Hasher
|
||||
signingKeyAlgorithm string
|
||||
assetAPIPrefix func(ctx context.Context) string
|
||||
encAlg crypto.EncryptionAlgorithm
|
||||
opCrypto op.Crypto
|
||||
|
||||
assetAPIPrefix func(ctx context.Context) string
|
||||
}
|
||||
|
||||
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
|
||||
@ -153,41 +156,6 @@ func (s *Server) DeviceAuthorization(ctx context.Context, r *op.ClientRequest[oi
|
||||
return s.LegacyServer.DeviceAuthorization(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.AccessTokenRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
return s.LegacyServer.CodeExchange(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.RefreshTokenRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
return s.LegacyServer.RefreshToken(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGrantRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
return s.LegacyServer.JWTProfile(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequest[oidc.ClientCredentialsRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
return s.LegacyServer.ClientCredentialsExchange(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.DeviceAccessTokenRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
return s.LegacyServer.DeviceToken(ctx, r)
|
||||
}
|
||||
|
||||
func (s *Server) Revocation(ctx context.Context, r *op.ClientRequest[oidc.RevocationRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
@ -232,3 +200,10 @@ func (s *Server) createDiscoveryConfig(ctx context.Context, supportedUILocales o
|
||||
RequestParameterSupported: s.Provider().RequestObjectSupported(),
|
||||
}
|
||||
}
|
||||
|
||||
func response(resp any, err error) (*op.Response, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return op.NewResponse(resp), nil
|
||||
}
|
||||
|
198
internal/api/oidc/token.go
Normal file
198
internal/api/oidc/token.go
Normal file
@ -0,0 +1,198 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
/*
|
||||
For each grant-type, tokens creation follows the same rough logical steps:
|
||||
|
||||
1. Information gathering: who is requesting the token, what do we put in the claims?
|
||||
2. Decision making: is the request authorized? (valid exchange code, auth request completed, valid token etc...)
|
||||
3. Build an OIDC session in storage: inform the eventstore we are creating tokens.
|
||||
4. Use the OIDC session to encrypt and / or sign the requested tokens
|
||||
|
||||
In some cases step 1 till 3 are completely implemented in the command package,
|
||||
for example the v2 code exchange and refresh token.
|
||||
*/
|
||||
|
||||
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
|
||||
getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope)
|
||||
getSigner := s.getSignerOnce()
|
||||
|
||||
resp := &oidc.AccessTokenResponse{
|
||||
TokenType: oidc.BearerToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
ExpiresIn: timeToOIDCExpiresIn(session.Expiration),
|
||||
State: state,
|
||||
}
|
||||
|
||||
// If the session does not have a token ID, it is an implicit ID-Token only response.
|
||||
if session.TokenID != "" {
|
||||
if client.AccessTokenType() == op.AccessTokenTypeJWT {
|
||||
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
|
||||
} else {
|
||||
resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(session.Scope, oidc.ScopeOpenID) {
|
||||
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// signerFunc is a getter function that allows add-hoc retrieval of the instance's signer.
|
||||
type signerFunc func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error)
|
||||
|
||||
// getSignerOnce returns a function which retrieves the instance's signer from the database once.
|
||||
// Repeated calls of the returned function return the same results.
|
||||
func (s *Server) getSignerOnce() signerFunc {
|
||||
var (
|
||||
once sync.Once
|
||||
signer jose.Signer
|
||||
signAlg jose.SignatureAlgorithm
|
||||
err error
|
||||
)
|
||||
return func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error) {
|
||||
once.Do(func() {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var signingKey op.SigningKey
|
||||
signingKey, err = s.Provider().Storage().SigningKey(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
signAlg = signingKey.SignatureAlgorithm()
|
||||
|
||||
signer, err = op.SignerFromKey(signingKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
})
|
||||
return signer, signAlg, err
|
||||
}
|
||||
}
|
||||
|
||||
// userInfoFunc is a getter function that allows add-hoc retrieval of a user.
|
||||
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error)
|
||||
|
||||
// getUserInfoOnce returns a function which retrieves userinfo from the database once.
|
||||
// Repeated calls of the returned function return the same results.
|
||||
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc {
|
||||
var (
|
||||
once sync.Once
|
||||
userInfo *oidc.UserInfo
|
||||
err error
|
||||
)
|
||||
return func(ctx context.Context) (*oidc.UserInfo, error) {
|
||||
once.Do(func() {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
|
||||
})
|
||||
return userInfo, err
|
||||
}
|
||||
}
|
||||
|
||||
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
userInfo, err := getUserInfo(ctx)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
signer, signAlg, err := getSigningKey(ctx)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
expTime := time.Now().Add(client.IDTokenLifetime()).Add(client.ClockSkew())
|
||||
claims := oidc.NewIDTokenClaims(
|
||||
op.IssuerFromContext(ctx),
|
||||
"",
|
||||
audience,
|
||||
expTime,
|
||||
authTime,
|
||||
nonce,
|
||||
"",
|
||||
AuthMethodTypesToAMR(authMethods),
|
||||
client.GetID(),
|
||||
client.ClockSkew(),
|
||||
)
|
||||
claims.SessionID = sessionID
|
||||
claims.Actor = actorDomainToClaims(actor)
|
||||
claims.SetUserInfo(userInfo)
|
||||
if accessToken != "" {
|
||||
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signAlg)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
}
|
||||
idToken, err = crypto.Sign(claims, signer)
|
||||
return idToken, timeToOIDCExpiresIn(expTime), err
|
||||
}
|
||||
|
||||
func timeToOIDCExpiresIn(exp time.Time) uint64 {
|
||||
return uint64(time.Until(exp) / time.Second)
|
||||
}
|
||||
|
||||
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
userInfo, err := getUserInfo(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signer, _, err := getSigner(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
expTime := session.Expiration.Add(client.ClockSkew())
|
||||
claims := oidc.NewAccessTokenClaims(
|
||||
op.IssuerFromContext(ctx),
|
||||
userInfo.Subject,
|
||||
session.Audience,
|
||||
expTime,
|
||||
session.TokenID,
|
||||
client.GetID(),
|
||||
client.ClockSkew(),
|
||||
)
|
||||
claims.Actor = actorDomainToClaims(session.Actor)
|
||||
claims.Claims = userInfo.Claims
|
||||
|
||||
return crypto.Sign(claims, signer)
|
||||
}
|
||||
|
||||
// decryptCode decrypts a code or refresh_token
|
||||
func (s *Server) decryptCode(ctx context.Context, code string) (_ string, err error) {
|
||||
_, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.encAlg.DecryptString(decoded, s.encAlg.EncryptionKeyID())
|
||||
}
|
51
internal/api/oidc/token_client_credentials.go
Normal file
51
internal/api/oidc/token_client_credentials.go
Normal file
@ -0,0 +1,51 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequest[oidc.ClientCredentialsRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
span.EndWithError(err)
|
||||
err = oidcError(err)
|
||||
}()
|
||||
client, ok := r.Client.(*clientCredentialsClient)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInternal(nil, "OIDC-ga0EP", "Error.Internal")
|
||||
}
|
||||
scope, err := op.ValidateAuthReqScopes(client, r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err = s.checkOrgScopes(ctx, client.user, scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
client.user.ID,
|
||||
client.user.ResourceOwner,
|
||||
r.Data.ClientID,
|
||||
scope,
|
||||
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Now(),
|
||||
"",
|
||||
nil,
|
||||
nil,
|
||||
domain.TokenReasonClientCredentials,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
|
||||
}
|
143
internal/api/oidc/token_client_credentials_integration_test.go
Normal file
143
internal/api/oidc/token_client_credentials_integration_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
//go:build integration
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user"
|
||||
)
|
||||
|
||||
func TestServer_ClientCredentialsExchange(t *testing.T) {
|
||||
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
|
||||
require.NoError(t, err)
|
||||
|
||||
type claims struct {
|
||||
resourceOwnerID any
|
||||
resourceOwnerName any
|
||||
resourceOwnerPrimaryDomain any
|
||||
orgDomain any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
scope []string
|
||||
wantClaims claims
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing client ID error",
|
||||
clientID: "",
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "client not found error",
|
||||
clientID: "foo",
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "machine user without secret error",
|
||||
clientID: func() string {
|
||||
name := gofakeit.Username()
|
||||
_, err := Tester.Client.Mgmt.AddMachineUser(CTX, &management.AddMachineUserRequest{
|
||||
Name: name,
|
||||
UserName: name,
|
||||
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return name
|
||||
}(),
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong secret error",
|
||||
clientID: clientID,
|
||||
clientSecret: "bar",
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
{
|
||||
name: "org id and domain scope",
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{
|
||||
oidc.ScopeOpenID,
|
||||
domain.OrgIDScope + Tester.Organisation.ID,
|
||||
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
|
||||
},
|
||||
wantClaims: claims{
|
||||
resourceOwnerID: Tester.Organisation.ID,
|
||||
resourceOwnerName: Tester.Organisation.Name,
|
||||
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
|
||||
orgDomain: Tester.Organisation.Domain,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid org domain filtered",
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{
|
||||
oidc.ScopeOpenID,
|
||||
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
|
||||
domain.OrgDomainPrimaryScope + "foo"},
|
||||
wantClaims: claims{
|
||||
orgDomain: Tester.Organisation.Domain,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid org id filtered",
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID,
|
||||
domain.OrgIDScope + Tester.Organisation.ID,
|
||||
domain.OrgIDScope + "foo",
|
||||
},
|
||||
wantClaims: claims{
|
||||
resourceOwnerID: Tester.Organisation.ID,
|
||||
resourceOwnerName: Tester.Organisation.Name,
|
||||
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), tt.clientID, tt.clientSecret, redirectURI, tt.scope)
|
||||
require.NoError(t, err)
|
||||
tokens, err := rp.ClientCredentials(CTX, provider, nil)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokens)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
|
||||
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
|
||||
})
|
||||
}
|
||||
}
|
125
internal/api/oidc/token_code.go
Normal file
125
internal/api/oidc/token_code.go
Normal file
@ -0,0 +1,125 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.AccessTokenRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
span.EndWithError(err)
|
||||
err = oidcError(err)
|
||||
}()
|
||||
|
||||
client, ok := r.Client.(*Client)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInternal(nil, "OIDC-Ae2ph", "Error.Internal")
|
||||
}
|
||||
|
||||
plainCode, err := s.decryptCode(ctx, r.Data.Code)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "OIDC-ahLi2", "Errors.User.Code.Invalid")
|
||||
}
|
||||
|
||||
var (
|
||||
session *command.OIDCSession
|
||||
state string
|
||||
)
|
||||
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
|
||||
session, state, err = s.command.CreateOIDCSessionFromAuthRequest(
|
||||
setContextUserSystem(ctx),
|
||||
plainCode,
|
||||
codeExchangeComplianceChecker(client, r.Data),
|
||||
slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken),
|
||||
)
|
||||
} else {
|
||||
session, state, err = s.codeExchangeV1(ctx, client, r.Data, r.Data.Code)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
}
|
||||
|
||||
// codeExchangeV1 creates a v2 token from a v1 auth request.
|
||||
func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.AccessTokenRequest, code string) (session *command.OIDCSession, state string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
authReq, err := s.getAuthRequestV1ByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if challenge := authReq.GetCodeChallenge(); challenge != nil || client.AuthMethod() == oidc.AuthMethodNone {
|
||||
if err = op.AuthorizeCodeChallenge(req.CodeVerifier, challenge); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
if req.RedirectURI != authReq.GetRedirectURI() {
|
||||
return nil, "", oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
|
||||
}
|
||||
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
|
||||
|
||||
scope := authReq.GetScopes()
|
||||
session, err = s.command.CreateOIDCSession(ctx,
|
||||
authReq.GetSubject(),
|
||||
userOrgID,
|
||||
client.client.ClientID,
|
||||
scope,
|
||||
authReq.GetAudience(),
|
||||
AMRToAuthMethodTypes(authMethodsReferences),
|
||||
authTime,
|
||||
authReq.GetNonce(),
|
||||
preferredLanguage,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: &userAgentID,
|
||||
},
|
||||
reason,
|
||||
actor,
|
||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return session, authReq.GetState(), s.repo.DeleteAuthRequest(ctx, authReq.GetID())
|
||||
}
|
||||
|
||||
// getAuthRequestV1ByCode finds the v1 auth request by code.
|
||||
// code needs to be the encrypted version of the ID,
|
||||
// this is required by the underlying repo.
|
||||
func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (op.AuthRequest, error) {
|
||||
authReq, err := s.repo.AuthRequestByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return AuthRequestFromBusiness(authReq)
|
||||
}
|
||||
|
||||
func codeExchangeComplianceChecker(client *Client, req *oidc.AccessTokenRequest) command.AuthRequestComplianceChecker {
|
||||
return func(ctx context.Context, authReq *command.AuthRequestWriteModel) error {
|
||||
if authReq.CodeChallenge != nil || client.AuthMethod() == oidc.AuthMethodNone {
|
||||
err := op.AuthorizeCodeChallenge(req.CodeVerifier, CodeChallengeToOIDC(authReq.CodeChallenge))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if req.RedirectURI != authReq.RedirectURI {
|
||||
return oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
|
||||
}
|
||||
if err := authReq.CheckAuthenticated(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
46
internal/api/oidc/token_device.go
Normal file
46
internal/api/oidc/token_device.go
Normal file
@ -0,0 +1,46 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.DeviceAccessTokenRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
span.EndWithError(err)
|
||||
err = oidcError(err)
|
||||
}()
|
||||
|
||||
client, ok := r.Client.(*Client)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInternal(nil, "OIDC-Ae2ph", "Error.Internal")
|
||||
}
|
||||
session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode)
|
||||
if err == nil {
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||
}
|
||||
|
||||
var target command.DeviceAuthStateError
|
||||
if errors.As(err, &target) {
|
||||
state := domain.DeviceAuthState(target)
|
||||
if state == domain.DeviceAuthStateInitiated {
|
||||
return nil, oidc.ErrAuthorizationPending()
|
||||
}
|
||||
if state == domain.DeviceAuthStateExpired {
|
||||
return nil, oidc.ErrExpiredDeviceCode()
|
||||
}
|
||||
}
|
||||
return nil, oidc.ErrAccessDenied().WithParent(err)
|
||||
}
|
@ -5,9 +5,9 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@ -134,13 +134,16 @@ func (s *Server) verifyExchangeToken(ctx context.Context, client *Client, token
|
||||
return idTokenClaimsToExchangeToken(claims, resourceOwner), nil
|
||||
|
||||
case oidc.JWTTokenType:
|
||||
resourceOwner := new(string)
|
||||
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, client.client.ClockSkew, s.jwtProfileUserCheck(ctx, resourceOwner))
|
||||
var (
|
||||
resourceOwner string
|
||||
preferredLanguage *language.Tag
|
||||
)
|
||||
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, client.client.ClockSkew, s.jwtProfileUserCheck(ctx, &resourceOwner, &preferredLanguage))
|
||||
jwt, err := op.VerifyJWTAssertion(ctx, token, verifier)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowPermissionDenied(err, "OIDC-eiS6o", "Errors.TokenExchange.Token.Invalid")
|
||||
}
|
||||
return jwtToExchangeToken(jwt, *resourceOwner), nil
|
||||
return jwtToExchangeToken(jwt, resourceOwner, preferredLanguage), nil
|
||||
|
||||
case UserIDTokenType:
|
||||
user, err := s.query.GetUserByID(ctx, false, token)
|
||||
@ -156,13 +159,18 @@ func (s *Server) verifyExchangeToken(ctx context.Context, client *Client, token
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) jwtProfileUserCheck(ctx context.Context, resourceOwner *string) op.JWTProfileVerifierOption {
|
||||
// jwtProfileUserCheck finds the user by subject (user ID) and sets the resourceOwner through the pointer.
|
||||
// preferred Language is set only if it was defined for a Human user, else the pointed pointer remains nil.
|
||||
func (s *Server) jwtProfileUserCheck(ctx context.Context, resourceOwner *string, preferredLanguage **language.Tag) op.JWTProfileVerifierOption {
|
||||
return op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
|
||||
user, err := s.query.GetUserByID(ctx, false, request.Subject)
|
||||
if err != nil {
|
||||
return zerrors.ThrowPermissionDenied(err, "OIDC-Nee6r", "Errors.TokenExchange.Token.Invalid")
|
||||
}
|
||||
*resourceOwner = user.ResourceOwner
|
||||
if user.Human != nil && !user.Human.PreferredLanguage.IsRoot() {
|
||||
*preferredLanguage = &user.Human.PreferredLanguage
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@ -210,21 +218,8 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi
|
||||
// Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange.
|
||||
// When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation.
|
||||
func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) {
|
||||
var (
|
||||
userInfo *oidc.UserInfo
|
||||
signingKey op.SigningKey
|
||||
)
|
||||
if slices.Contains(scopes, oidc.ScopeOpenID) || tokenType == oidc.JWTTokenType || tokenType == oidc.IDTokenType {
|
||||
projectID := client.client.ProjectID
|
||||
userInfo, err = s.userInfo(ctx, subjectToken.userID, scopes, projectID, client.client.ProjectRoleAssertion, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signingKey, err = s.Provider().Storage().SigningKey(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes)
|
||||
getSigner := s.getSignerOnce()
|
||||
|
||||
resp := &oidc.TokenExchangeResponse{
|
||||
Scopes: scopes,
|
||||
@ -237,21 +232,23 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
|
||||
actor = actorToken.nestedActor()
|
||||
}
|
||||
|
||||
var sessionID string
|
||||
switch tokenType {
|
||||
case oidc.AccessTokenType, "":
|
||||
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeAccessToken(ctx, client, subjectToken.resourceOwner, subjectToken.userID, audience, scopes, actorToken.authMethods, actorToken.authTime, reason, actor)
|
||||
resp.AccessToken, resp.RefreshToken, sessionID, resp.ExpiresIn, err = s.createExchangeAccessToken(ctx, client, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
|
||||
resp.TokenType = oidc.BearerToken
|
||||
resp.IssuedTokenType = oidc.AccessTokenType
|
||||
|
||||
case oidc.JWTTokenType:
|
||||
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, signingKey, client, subjectToken.resourceOwner, subjectToken.userID, audience, scopes, actorToken.authMethods, actorToken.authTime, reason, actor, userInfo.Claims)
|
||||
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
|
||||
resp.TokenType = oidc.BearerToken
|
||||
resp.IssuedTokenType = oidc.JWTTokenType
|
||||
|
||||
case oidc.IDTokenType:
|
||||
resp.AccessToken, resp.ExpiresIn, err = s.createExchangeIDToken(ctx, signingKey, client, subjectToken.userID, "", audience, userInfo, actorToken.authMethods, actorToken.authTime, reason, actor)
|
||||
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
|
||||
resp.TokenType = TokenTypeNA
|
||||
resp.IssuedTokenType = oidc.IDTokenType
|
||||
|
||||
case oidc.RefreshTokenType, UserIDTokenType:
|
||||
fallthrough
|
||||
default:
|
||||
@ -262,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
|
||||
resp.IDToken, _, err = s.createExchangeIDToken(ctx, signingKey, client, subjectToken.userID, resp.AccessToken, audience, userInfo, actorToken.authMethods, actorToken.authTime, reason, actor)
|
||||
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -271,77 +268,83 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Server) createExchangeAccessToken(ctx context.Context, client *Client, resourceOwner, userID string, audience, scopes []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (accessToken string, refreshToken string, exp uint64, err error) {
|
||||
tokenInfo, refreshToken, err := s.createAccessTokenCommands(ctx, client, resourceOwner, userID, audience, scopes, authMethods, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
accessToken, err = op.CreateBearerToken(tokenInfo.TokenID, userID, s.Provider().Crypto())
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
return accessToken, refreshToken, timeToOIDCExpiresIn(tokenInfo.Expiration), nil
|
||||
}
|
||||
func (s *Server) createExchangeAccessToken(
|
||||
ctx context.Context,
|
||||
client *Client,
|
||||
userID,
|
||||
resourceOwner string,
|
||||
audience,
|
||||
scope []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken, refreshToken, sessionID string, exp uint64, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
func (s *Server) createExchangeJWT(ctx context.Context, signingKey op.SigningKey, client *Client, resourceOwner, userID string, audience, scopes []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor, privateClaims map[string]any) (accessToken string, refreshToken string, exp uint64, err error) {
|
||||
tokenInfo, refreshToken, err := s.createAccessTokenCommands(ctx, client, resourceOwner, userID, audience, scopes, authMethods, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
expTime := tokenInfo.Expiration.Add(client.ClockSkew())
|
||||
claims := oidc.NewAccessTokenClaims(op.IssuerFromContext(ctx), userID, tokenInfo.Audience, expTime, tokenInfo.TokenID, client.GetID(), client.ClockSkew())
|
||||
claims.Actor = actorDomainToClaims(tokenInfo.Actor)
|
||||
claims.Claims = privateClaims
|
||||
|
||||
signer, err := op.SignerFromKey(signingKey)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
accessToken, err = crypto.Sign(claims, signer)
|
||||
if err != nil {
|
||||
return "", "", 0, nil
|
||||
}
|
||||
return accessToken, refreshToken, timeToOIDCExpiresIn(expTime), nil
|
||||
}
|
||||
|
||||
func (s *Server) createExchangeIDToken(ctx context.Context, signingKey op.SigningKey, client *Client, userID, accessToken string, audience []string, userInfo *oidc.UserInfo, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
|
||||
expTime := time.Now().Add(client.IDTokenLifetime()).Add(client.ClockSkew())
|
||||
claims := oidc.NewIDTokenClaims(op.IssuerFromContext(ctx), userID, audience, expTime, authTime, "", "", AuthMethodTypesToAMR(authMethods), client.GetID(), client.ClockSkew())
|
||||
claims.Actor = actorDomainToClaims(actor)
|
||||
claims.SetUserInfo(userInfo)
|
||||
if accessToken != "" {
|
||||
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signingKey.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
}
|
||||
signer, err := op.SignerFromKey(signingKey)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
idToken, err = crypto.Sign(claims, signer)
|
||||
return idToken, timeToOIDCExpiresIn(expTime), err
|
||||
}
|
||||
|
||||
func timeToOIDCExpiresIn(exp time.Time) uint64 {
|
||||
return uint64(time.Until(exp) / time.Second)
|
||||
}
|
||||
|
||||
func (s *Server) createAccessTokenCommands(ctx context.Context, client *Client, resourceOwner, userID string, audience, scopes []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (tokenInfo *domain.Token, refreshToken string, err error) {
|
||||
settings := client.client.Settings
|
||||
if slices.Contains(scopes, oidc.ScopeOfflineAccess) {
|
||||
return s.command.AddAccessAndRefreshToken(
|
||||
ctx, resourceOwner, "", client.GetID(), userID, "", audience, scopes, AuthMethodTypesToAMR(authMethods),
|
||||
settings.AccessTokenLifetime, settings.RefreshTokenIdleExpiration, settings.RefreshTokenExpiration,
|
||||
authTime, reason, actor,
|
||||
)
|
||||
}
|
||||
tokenInfo, err = s.command.AddUserToken(
|
||||
ctx, resourceOwner, "", client.GetID(), userID, audience, scopes, AuthMethodTypesToAMR(authMethods),
|
||||
settings.AccessTokenLifetime,
|
||||
authTime, reason, actor,
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
userID,
|
||||
resourceOwner,
|
||||
client.client.ClientID,
|
||||
scope,
|
||||
audience,
|
||||
authMethods,
|
||||
authTime,
|
||||
"",
|
||||
preferredLanguage,
|
||||
nil,
|
||||
reason,
|
||||
actor,
|
||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||
)
|
||||
return tokenInfo, "", err
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
accessToken, err = op.CreateBearerToken(session.TokenID, userID, s.opCrypto)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
return accessToken, session.RefreshToken, session.SessionID, timeToOIDCExpiresIn(session.Expiration), nil
|
||||
}
|
||||
|
||||
func (s *Server) createExchangeJWT(
|
||||
ctx context.Context,
|
||||
client *Client,
|
||||
getUserInfo userInfoFunc,
|
||||
getSigner signerFunc,
|
||||
userID,
|
||||
resourceOwner string,
|
||||
audience,
|
||||
scope []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken string, refreshToken string, exp uint64, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
userID,
|
||||
resourceOwner,
|
||||
client.client.ClientID,
|
||||
scope,
|
||||
audience,
|
||||
authMethods,
|
||||
authTime,
|
||||
"",
|
||||
preferredLanguage,
|
||||
nil,
|
||||
reason,
|
||||
actor,
|
||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||
)
|
||||
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
return accessToken, session.RefreshToken, timeToOIDCExpiresIn(session.Expiration), nil
|
||||
}
|
||||
|
@ -4,21 +4,23 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
type exchangeToken struct {
|
||||
tokenType oidc.TokenType
|
||||
userID string
|
||||
issuer string
|
||||
resourceOwner string
|
||||
authTime time.Time
|
||||
authMethods []domain.UserAuthMethodType
|
||||
actor *domain.TokenActor
|
||||
audience []string
|
||||
scopes []string
|
||||
tokenType oidc.TokenType
|
||||
userID string
|
||||
issuer string
|
||||
resourceOwner string
|
||||
authTime time.Time
|
||||
authMethods []domain.UserAuthMethodType
|
||||
actor *domain.TokenActor
|
||||
audience []string
|
||||
scopes []string
|
||||
preferredLanguage *language.Tag
|
||||
}
|
||||
|
||||
func (et *exchangeToken) nestedActor() *domain.TokenActor {
|
||||
@ -31,27 +33,33 @@ func (et *exchangeToken) nestedActor() *domain.TokenActor {
|
||||
|
||||
func accessToExchangeToken(token *accessToken, issuer string) *exchangeToken {
|
||||
return &exchangeToken{
|
||||
tokenType: oidc.AccessTokenType,
|
||||
userID: token.userID,
|
||||
issuer: issuer,
|
||||
resourceOwner: token.resourceOwner,
|
||||
authMethods: token.authMethods,
|
||||
actor: token.actor,
|
||||
audience: token.audience,
|
||||
scopes: token.scope,
|
||||
tokenType: oidc.AccessTokenType,
|
||||
userID: token.userID,
|
||||
issuer: issuer,
|
||||
resourceOwner: token.resourceOwner,
|
||||
authMethods: token.authMethods,
|
||||
actor: token.actor,
|
||||
audience: token.audience,
|
||||
scopes: token.scope,
|
||||
preferredLanguage: token.preferredLanguage,
|
||||
}
|
||||
}
|
||||
|
||||
func idTokenClaimsToExchangeToken(claims *oidc.IDTokenClaims, resourceOwner string) *exchangeToken {
|
||||
var preferredLanguage *language.Tag
|
||||
if tag := claims.Locale.Tag(); !tag.IsRoot() {
|
||||
preferredLanguage = &tag
|
||||
}
|
||||
return &exchangeToken{
|
||||
tokenType: oidc.IDTokenType,
|
||||
userID: claims.Subject,
|
||||
issuer: claims.Issuer,
|
||||
resourceOwner: resourceOwner,
|
||||
authTime: claims.GetAuthTime(),
|
||||
authMethods: AMRToAuthMethodTypes(claims.AuthenticationMethodsReferences),
|
||||
actor: actorClaimsToDomain(claims.Actor),
|
||||
audience: claims.Audience,
|
||||
tokenType: oidc.IDTokenType,
|
||||
userID: claims.Subject,
|
||||
issuer: claims.Issuer,
|
||||
resourceOwner: resourceOwner,
|
||||
authTime: claims.GetAuthTime(),
|
||||
authMethods: AMRToAuthMethodTypes(claims.AuthenticationMethodsReferences),
|
||||
actor: actorClaimsToDomain(claims.Actor),
|
||||
audience: claims.Audience,
|
||||
preferredLanguage: preferredLanguage,
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,7 +85,7 @@ func actorDomainToClaims(actor *domain.TokenActor) *oidc.ActorClaims {
|
||||
}
|
||||
}
|
||||
|
||||
func jwtToExchangeToken(jwt *oidc.JWTTokenRequest, resourceOwner string) *exchangeToken {
|
||||
func jwtToExchangeToken(jwt *oidc.JWTTokenRequest, resourceOwner string, preferredLanguage *language.Tag) *exchangeToken {
|
||||
return &exchangeToken{
|
||||
tokenType: oidc.JWTTokenType,
|
||||
userID: jwt.Subject,
|
||||
@ -86,6 +94,7 @@ func jwtToExchangeToken(jwt *oidc.JWTTokenRequest, resourceOwner string) *exchan
|
||||
scopes: jwt.Scopes,
|
||||
authTime: jwt.IssuedAt.AsTime(),
|
||||
// audience omitted as we don't thrust audiences not signed by us
|
||||
preferredLanguage: preferredLanguage,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -587,5 +587,5 @@ func TestImpersonation_API_Call(t *testing.T) {
|
||||
_, err = Tester.Client.Admin.GetAllowedLanguages(impersonatedCTX, &admin.GetAllowedLanguagesRequest{})
|
||||
status := status.Convert(err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code())
|
||||
assert.Equal(t, "Errors.TokenExchange.Token.NotForAPI (APP-wai8O)", status.Message())
|
||||
assert.Equal(t, "Errors.TokenExchange.Token.NotForAPI (APP-Shi0J)", status.Message())
|
||||
}
|
||||
|
99
internal/api/oidc/token_jwt_profile.go
Normal file
99
internal/api/oidc/token_jwt_profile.go
Normal file
@ -0,0 +1,99 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGrantRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
span.EndWithError(err)
|
||||
err = oidcError(err)
|
||||
}()
|
||||
|
||||
user, jwtReq, err := s.verifyJWTProfile(ctx, r.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &clientCredentialsClient{
|
||||
id: jwtReq.Subject,
|
||||
user: user,
|
||||
}
|
||||
scope, err := op.ValidateAuthReqScopes(client, r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err = s.checkOrgScopes(ctx, client.user, scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
user.ID,
|
||||
user.ResourceOwner,
|
||||
jwtReq.Subject,
|
||||
scope,
|
||||
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
||||
nil,
|
||||
time.Now(),
|
||||
"",
|
||||
nil,
|
||||
nil,
|
||||
domain.TokenReasonClientCredentials,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
|
||||
}
|
||||
|
||||
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
checkSubject := func(jwt *oidc.JWTTokenRequest) (err error) {
|
||||
user, err = s.query.GetUserByID(ctx, true, jwt.Subject)
|
||||
return err
|
||||
}
|
||||
verifier := op.NewJWTProfileVerifier(
|
||||
&jwtProfileKeyStorage{query: s.query},
|
||||
op.IssuerFromContext(ctx),
|
||||
time.Hour, time.Second,
|
||||
op.SubjectCheck(checkSubject),
|
||||
)
|
||||
tokenRequest, err = op.VerifyJWTAssertion(ctx, req.Assertion, verifier)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, tokenRequest, nil
|
||||
}
|
||||
|
||||
type jwtProfileKeyStorage struct {
|
||||
query *query.Queries
|
||||
}
|
||||
|
||||
func (s *jwtProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) {
|
||||
publicKeyData, err := s.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
publicKey, err := crypto.BytesToPublicKey(publicKeyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &jose.JSONWebKey{
|
||||
KeyID: keyID,
|
||||
Use: "sig",
|
||||
Key: publicKey,
|
||||
}, nil
|
||||
}
|
103
internal/api/oidc/token_jwt_profile_integration_test.go
Normal file
103
internal/api/oidc/token_jwt_profile_integration_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
//go:build integration
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/profile"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
func TestServer_JWTProfile(t *testing.T) {
|
||||
userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
|
||||
require.NoError(t, err)
|
||||
|
||||
type claims struct {
|
||||
resourceOwnerID any
|
||||
resourceOwnerName any
|
||||
resourceOwnerPrimaryDomain any
|
||||
orgDomain any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
keyData []byte
|
||||
scope []string
|
||||
wantClaims claims
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
keyData: keyData,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
{
|
||||
name: "org id and domain scope",
|
||||
keyData: keyData,
|
||||
scope: []string{
|
||||
oidc.ScopeOpenID,
|
||||
domain.OrgIDScope + Tester.Organisation.ID,
|
||||
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
|
||||
},
|
||||
wantClaims: claims{
|
||||
resourceOwnerID: Tester.Organisation.ID,
|
||||
resourceOwnerName: Tester.Organisation.Name,
|
||||
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
|
||||
orgDomain: Tester.Organisation.Domain,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid org domain filtered",
|
||||
keyData: keyData,
|
||||
scope: []string{
|
||||
oidc.ScopeOpenID,
|
||||
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
|
||||
domain.OrgDomainPrimaryScope + "foo"},
|
||||
wantClaims: claims{
|
||||
orgDomain: Tester.Organisation.Domain,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid org id filtered",
|
||||
keyData: keyData,
|
||||
scope: []string{oidc.ScopeOpenID,
|
||||
domain.OrgIDScope + Tester.Organisation.ID,
|
||||
domain.OrgIDScope + "foo",
|
||||
},
|
||||
wantClaims: claims{
|
||||
resourceOwnerID: Tester.Organisation.ID,
|
||||
resourceOwnerName: Tester.Organisation.Name,
|
||||
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenSource, err := profile.NewJWTProfileTokenSourceFromKeyFileData(CTX, Tester.OIDCIssuer(), tt.keyData, tt.scope)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokens, err := tokenSource.TokenCtx(CTX)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokens)
|
||||
|
||||
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope)
|
||||
require.NoError(t, err)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
|
||||
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
|
||||
})
|
||||
}
|
||||
}
|
101
internal/api/oidc/token_refresh.go
Normal file
101
internal/api/oidc/token_refresh.go
Normal file
@ -0,0 +1,101 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.RefreshTokenRequest]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
span.EndWithError(err)
|
||||
err = oidcError(err)
|
||||
}()
|
||||
|
||||
client, ok := r.Client.(*Client)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInternal(nil, "OIDC-ga0EP", "Error.Internal")
|
||||
}
|
||||
|
||||
session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker())
|
||||
if err == nil {
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
} else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) {
|
||||
// We try again for v1 tokens when we encountered specific parsing error
|
||||
return s.refreshTokenV1(ctx, client, r)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// refreshTokenV1 verifies a v1 refresh token.
|
||||
// When valid a v2 OIDC session is created and v2 tokens are returned.
|
||||
// This "upgrades" existing v1 sessions to v2 session without requiring users to re-login.
|
||||
//
|
||||
// This function can be removed when we retire the v1 token repo.
|
||||
func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.ClientRequest[oidc.RefreshTokenRequest]) (_ *op.Response, err error) {
|
||||
refreshToken, err := s.repo.RefreshTokenByToken(ctx, r.Data.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err := validateRefreshTokenScopes(refreshToken.Scopes, r.Data.Scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := s.command.CreateOIDCSession(ctx,
|
||||
refreshToken.UserID,
|
||||
refreshToken.ResourceOwner,
|
||||
refreshToken.ClientID,
|
||||
scope,
|
||||
refreshToken.Audience,
|
||||
AMRToAuthMethodTypes(refreshToken.AuthMethodsReferences),
|
||||
refreshToken.AuthTime,
|
||||
"",
|
||||
nil, // Preferred language not in refresh token view
|
||||
&domain.UserAgent{
|
||||
FingerprintID: &refreshToken.UserAgentID,
|
||||
Description: &refreshToken.UserAgentID,
|
||||
},
|
||||
domain.TokenReasonRefresh,
|
||||
refreshToken.Actor,
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// make sure the v1 refresh token can't be reused.
|
||||
_, err = s.command.RevokeRefreshToken(ctx, refreshToken.UserID, refreshToken.ResourceOwner, refreshToken.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
}
|
||||
|
||||
// refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope.
|
||||
func refreshTokenComplianceChecker() command.RefreshTokenComplianceChecker {
|
||||
return func(_ context.Context, model *command.OIDCSessionWriteModel, requestedScope []string) ([]string, error) {
|
||||
return validateRefreshTokenScopes(model.Scope, requestedScope)
|
||||
}
|
||||
}
|
||||
|
||||
func validateRefreshTokenScopes(currentScope, requestedScope []string) ([]string, error) {
|
||||
if len(requestedScope) == 0 {
|
||||
return currentScope, nil
|
||||
}
|
||||
for _, s := range requestedScope {
|
||||
if !slices.Contains(currentScope, s) {
|
||||
return nil, oidc.ErrInvalidScope()
|
||||
}
|
||||
}
|
||||
return requestedScope, nil
|
||||
}
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
|
||||
@ -48,7 +49,8 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
|
||||
)
|
||||
if token.clientID != "" {
|
||||
projectID, assertion, err = s.query.GetOIDCUserinfoClientByID(ctx, token.clientID)
|
||||
if err != nil {
|
||||
// token.clientID might contain a username (e.g. client credentials) -> ignore the not found
|
||||
if err != nil && !zerrors.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -112,6 +112,7 @@ func (l *Login) handleDeviceAuthUserCode(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
authRequest, err := l.authRepo.CreateAuthRequest(ctx, &domain.AuthRequest{
|
||||
CreationDate: time.Now(),
|
||||
BrowserInfo: domain.BrowserInfoFromRequest(r),
|
||||
AgentID: userAgentID,
|
||||
ApplicationID: deviceAuthReq.ClientID,
|
||||
InstanceID: authz.GetInstance(ctx).InstanceID(),
|
||||
@ -163,7 +164,7 @@ func (l *Login) handleDeviceAuthAction(w http.ResponseWriter, r *http.Request) {
|
||||
action := mux.Vars(r)["action"]
|
||||
switch action {
|
||||
case deviceAuthAllowed:
|
||||
_, err = l.command.ApproveDeviceAuth(r.Context(), authDev.DeviceCode, authReq.UserID, authReq.UserAuthMethodTypes(), authReq.AuthTime)
|
||||
_, err = l.command.ApproveDeviceAuth(r.Context(), authDev.DeviceCode, authReq.UserID, authReq.UserOrgID, authReq.UserAuthMethodTypes(), authReq.AuthTime, authReq.PreferredLanguage, authReq.BrowserInfo.ToUserAgent())
|
||||
case deviceAuthDenied:
|
||||
_, err = l.command.CancelDeviceAuth(r.Context(), authDev.DeviceCode, domain.DeviceAuthCanceledDenied)
|
||||
default:
|
||||
|
@ -5,7 +5,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
@ -1016,6 +1018,9 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
|
||||
}
|
||||
request.DisplayName = userSession.DisplayName
|
||||
request.AvatarKey = userSession.AvatarKey
|
||||
if user.HumanView != nil && user.HumanView.PreferredLanguage != "" {
|
||||
request.PreferredLanguage = gu.Ptr(language.Make(user.HumanView.PreferredLanguage))
|
||||
}
|
||||
|
||||
isInternalLogin := request.SelectedIDPConfigID == "" && userSession.SelectedIDPConfigID == ""
|
||||
idps, err := checkExternalIDPsOfUser(ctx, repo.IDPUserLinksProvider, user.ID)
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
@ -96,8 +97,7 @@ func (repo *TokenVerifierRepo) VerifyAccessToken(ctx context.Context, tokenStrin
|
||||
return "", "", "", "", "", zerrors.ThrowUnauthenticated(nil, "APP-Reb32", "invalid token")
|
||||
}
|
||||
if strings.HasPrefix(tokenID, command.IDPrefixV2) {
|
||||
userID, clientID, resourceOwner, err = repo.verifyAccessTokenV2(ctx, tokenID, verifierClientID, projectID)
|
||||
return
|
||||
return repo.verifyAccessTokenV2(ctx, tokenID, verifierClientID, projectID)
|
||||
}
|
||||
if sessionID, ok := strings.CutPrefix(tokenID, authz.SessionTokenPrefix); ok {
|
||||
userID, clientID, resourceOwner, err = repo.verifySessionToken(ctx, sessionID, tokenString)
|
||||
@ -106,7 +106,7 @@ func (repo *TokenVerifierRepo) VerifyAccessToken(ctx context.Context, tokenStrin
|
||||
return repo.verifyAccessTokenV1(ctx, tokenID, subject, verifierClientID, projectID)
|
||||
}
|
||||
|
||||
func (repo *TokenVerifierRepo) verifyAccessTokenV1(ctx context.Context, tokenID, subject, verifierClientID, projectID string) (userID string, agentID string, clientID, prefLang, resourceOwner string, err error) {
|
||||
func (repo *TokenVerifierRepo) verifyAccessTokenV1(ctx context.Context, tokenID, subject, verifierClientID, projectID string) (userID, agentID, clientID, prefLang, resourceOwner string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@ -131,24 +131,27 @@ func (repo *TokenVerifierRepo) verifyAccessTokenV1(ctx context.Context, tokenID,
|
||||
return token.UserID, token.UserAgentID, token.ApplicationID, token.PreferredLanguage, token.ResourceOwner, nil
|
||||
}
|
||||
|
||||
func (repo *TokenVerifierRepo) verifyAccessTokenV2(ctx context.Context, token, verifierClientID, projectID string) (userID, clientID, resourceOwner string, err error) {
|
||||
func (repo *TokenVerifierRepo) verifyAccessTokenV2(ctx context.Context, token, verifierClientID, projectID string) (userID, agentID, clientID, prefLang, resourceOwner string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
activeToken, err := repo.Query.ActiveAccessTokenByToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
return "", "", "", "", "", err
|
||||
}
|
||||
if activeToken.Actor != nil {
|
||||
return "", "", "", zerrors.ThrowPermissionDenied(nil, "APP-Shi0J", "Errors.TokenExchange.Token.NotForAPI")
|
||||
return "", "", "", "", "", zerrors.ThrowPermissionDenied(nil, "APP-Shi0J", "Errors.TokenExchange.Token.NotForAPI")
|
||||
}
|
||||
if err = verifyAudience(activeToken.Audience, verifierClientID, projectID); err != nil {
|
||||
return "", "", "", err
|
||||
return "", "", "", "", "", err
|
||||
}
|
||||
if err = repo.checkAuthentication(ctx, activeToken.AuthMethods, activeToken.UserID); err != nil {
|
||||
return "", "", "", err
|
||||
return "", "", "", "", "", err
|
||||
}
|
||||
return activeToken.UserID, activeToken.ClientID, activeToken.ResourceOwner, nil
|
||||
prefLang = gu.Value(activeToken.PreferredLanguage).String()
|
||||
agentID = gu.Value(gu.Value(activeToken.UserAgent).FingerprintID)
|
||||
|
||||
return activeToken.UserID, agentID, activeToken.ClientID, prefLang, activeToken.ResourceOwner, nil
|
||||
}
|
||||
|
||||
func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID, token string) (userID, clientID, resourceOwner string, err error) {
|
||||
|
@ -12,21 +12,22 @@ import (
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
ID string
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
type CurrentAuthRequest struct {
|
||||
@ -69,6 +70,7 @@ func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest)
|
||||
authRequest.MaxAge,
|
||||
authRequest.LoginHint,
|
||||
authRequest.HintUserID,
|
||||
authRequest.NeedRefreshToken,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -148,25 +150,6 @@ func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code s
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
|
||||
if code == "" {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
|
||||
return &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
|
@ -34,6 +34,7 @@ type AuthRequestWriteModel struct {
|
||||
AuthTime time.Time
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthRequestState domain.AuthRequestState
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
|
||||
@ -64,6 +65,7 @@ func (m *AuthRequestWriteModel) Reduce() error {
|
||||
m.LoginHint = e.LoginHint
|
||||
m.HintUserID = e.HintUserID
|
||||
m.AuthRequestState = domain.AuthRequestStateAdded
|
||||
m.NeedRefreshToken = e.NeedRefreshToken
|
||||
case *authrequest.SessionLinkedEvent:
|
||||
m.SessionID = e.SessionID
|
||||
m.UserID = e.UserID
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@ -59,6 +60,7 @@ func TestCommands_AddAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
false,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -96,6 +98,7 @@ func TestCommands_AddAuthRequest(t *testing.T) {
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
false,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -223,6 +226,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
@ -263,6 +267,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -301,6 +306,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -338,6 +344,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -354,7 +361,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow.Add(-5*time.Minute)),
|
||||
"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@ -398,6 +405,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -447,6 +455,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -463,7 +472,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@ -532,6 +541,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -548,7 +558,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@ -674,6 +684,7 @@ func TestCommands_FailAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -771,6 +782,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -807,6 +819,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
true,
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
@ -841,166 +854,3 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_ExchangeAuthCode(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
authRequest *CurrentAuthRequest
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"empty code error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "",
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"no code added error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"code exchanged",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
authRequest: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_authRequestID",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
|
||||
assert.ErrorIs(t, tt.res.err, err)
|
||||
|
||||
if err == nil {
|
||||
// equal on time won't work -> test separately and clear it before comparing the rest
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authRequest, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,16 +2,20 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string, needRefreshToken bool) (*domain.ObjectDetails, error) {
|
||||
aggr := deviceauth.NewAggregate(deviceCode, authz.GetInstance(ctx).InstanceID())
|
||||
model := NewDeviceAuthWriteModel(deviceCode, aggr.ResourceOwner)
|
||||
|
||||
@ -24,6 +28,7 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
|
||||
expires,
|
||||
scopes,
|
||||
audience,
|
||||
needRefreshToken,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -36,7 +41,16 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
|
||||
return writeModelToObjectDetails(&model.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject string, authMethods []domain.UserAuthMethodType, authTime time.Time) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) ApproveDeviceAuth(
|
||||
ctx context.Context,
|
||||
deviceCode,
|
||||
userID,
|
||||
userOrgID string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) (*domain.ObjectDetails, error) {
|
||||
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -44,9 +58,7 @@ func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject st
|
||||
if !model.State.Exists() {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound")
|
||||
}
|
||||
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, aggr, subject, authMethods, authTime))
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, model.aggregate, userID, userOrgID, authMethods, authTime, preferredLanguage, userAgent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -66,9 +78,7 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
|
||||
if !model.State.Exists() {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-gee5A", "Errors.DeviceAuth.NotFound")
|
||||
}
|
||||
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, aggr, reason))
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, model.aggregate, reason))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -81,10 +91,89 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
|
||||
}
|
||||
|
||||
func (c *Commands) getDeviceAuthWriteModelByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuthWriteModel, error) {
|
||||
model := &DeviceAuthWriteModel{WriteModel: eventstore.WriteModel{AggregateID: deviceCode}}
|
||||
model := &DeviceAuthWriteModel{
|
||||
WriteModel: eventstore.WriteModel{AggregateID: deviceCode},
|
||||
}
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.aggregate = deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
type DeviceAuthStateError domain.DeviceAuthState
|
||||
|
||||
func (e DeviceAuthStateError) Error() string {
|
||||
return fmt.Sprintf("device auth state not approved: %s", domain.DeviceAuthState(e).String())
|
||||
}
|
||||
|
||||
// CreateOIDCSessionFromDeviceAuth creates a new OIDC session if the device authorization
|
||||
// flow is completed (user logged in).
|
||||
// A [DeviceAuthStateError] is returned if the device authorization was not approved,
|
||||
// containing a [domain.DeviceAuthState] which can be used to inform the client about the state.
|
||||
//
|
||||
// As devices can poll at various intervals, an explicit state takes precedence over expiry.
|
||||
// This is to prevent cases where users might approve or deny the authorization on time, but the next poll
|
||||
// happens after expiry.
|
||||
func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCode string) (_ *OIDCSession, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
deviceAuthModel, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch deviceAuthModel.State {
|
||||
case domain.DeviceAuthStateApproved:
|
||||
break
|
||||
case domain.DeviceAuthStateUndefined:
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound")
|
||||
|
||||
case domain.DeviceAuthStateInitiated:
|
||||
if deviceAuthModel.Expires.Before(time.Now()) {
|
||||
c.asyncPush(ctx, deviceauth.NewCanceledEvent(ctx, deviceAuthModel.aggregate, domain.DeviceAuthCanceledExpired))
|
||||
return nil, DeviceAuthStateError(domain.DeviceAuthStateExpired)
|
||||
}
|
||||
fallthrough
|
||||
case domain.DeviceAuthStateDenied, domain.DeviceAuthStateExpired, domain.DeviceAuthStateDone:
|
||||
fallthrough
|
||||
default:
|
||||
return nil, DeviceAuthStateError(deviceAuthModel.State)
|
||||
}
|
||||
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, deviceAuthModel.UserOrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx,
|
||||
deviceAuthModel.UserID,
|
||||
deviceAuthModel.UserOrgID,
|
||||
"",
|
||||
deviceAuthModel.ClientID,
|
||||
deviceAuthModel.Audience,
|
||||
deviceAuthModel.Scopes,
|
||||
deviceAuthModel.UserAuthMethods,
|
||||
deviceAuthModel.AuthTime,
|
||||
"",
|
||||
deviceAuthModel.PreferredLanguage,
|
||||
deviceAuthModel.UserAgent,
|
||||
)
|
||||
if err = cmd.AddAccessToken(ctx, deviceAuthModel.Scopes, deviceAuthModel.UserID, deviceAuthModel.UserOrgID, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if deviceAuthModel.NeedRefreshToken {
|
||||
if err = cmd.AddRefreshToken(ctx, deviceAuthModel.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
cmd.DeviceAuthRequestDone(ctx, deviceAuthModel.aggregate)
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
func (cmd *OIDCSessionEvents) DeviceAuthRequestDone(ctx context.Context, deviceAuthAggregate *eventstore.Aggregate) {
|
||||
cmd.events = append(cmd.events, deviceauth.NewDoneEvent(ctx, deviceAuthAggregate))
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
@ -10,16 +12,22 @@ import (
|
||||
|
||||
type DeviceAuthWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
aggregate *eventstore.Aggregate
|
||||
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
State domain.DeviceAuthState
|
||||
Subject string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
Audience []string
|
||||
State domain.DeviceAuthState
|
||||
UserID string
|
||||
UserOrgID string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteModel {
|
||||
@ -28,6 +36,7 @@ func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteM
|
||||
AggregateID: deviceCode,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
aggregate: deviceauth.NewAggregate(deviceCode, resourceOwner),
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,14 +49,21 @@ func (m *DeviceAuthWriteModel) Reduce() error {
|
||||
m.UserCode = e.UserCode
|
||||
m.Expires = e.Expires
|
||||
m.Scopes = e.Scopes
|
||||
m.Audience = e.Audience
|
||||
m.State = e.State
|
||||
m.NeedRefreshToken = e.NeedRefreshToken
|
||||
case *deviceauth.ApprovedEvent:
|
||||
m.State = domain.DeviceAuthStateApproved
|
||||
m.Subject = e.Subject
|
||||
m.UserID = e.UserID
|
||||
m.UserOrgID = e.UserOrgID
|
||||
m.UserAuthMethods = e.UserAuthMethods
|
||||
m.AuthTime = e.AuthTime
|
||||
m.PreferredLanguage = e.PreferredLanguage
|
||||
m.UserAgent = e.UserAgent
|
||||
case *deviceauth.CanceledEvent:
|
||||
m.State = e.Reason.State()
|
||||
case *deviceauth.DoneEvent:
|
||||
m.State = domain.DeviceAuthStateDone
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,16 +3,27 @@ package command
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@ -25,16 +36,17 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
require.Len(t, unique, 2)
|
||||
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
clientID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
expires time.Time
|
||||
scopes []string
|
||||
audience []string
|
||||
ctx context.Context
|
||||
clientID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
expires time.Time
|
||||
scopes []string
|
||||
audience []string
|
||||
needRefreshToken bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -46,24 +58,25 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
{
|
||||
name: "success",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t, expectPush(
|
||||
eventstore: expectEventstore(expectPush(
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
needRefreshToken: true,
|
||||
},
|
||||
wantDetails: &domain.ObjectDetails{
|
||||
ResourceOwner: "instance1",
|
||||
@ -72,24 +85,25 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
{
|
||||
name: "push error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t, expectPushFailed(pushErr,
|
||||
eventstore: expectEventstore(expectPushFailed(pushErr,
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, false,
|
||||
)),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
needRefreshToken: false,
|
||||
},
|
||||
wantErr: pushErr,
|
||||
},
|
||||
@ -97,9 +111,9 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience)
|
||||
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience, tt.args.needRefreshToken)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantDetails, gotDetails)
|
||||
})
|
||||
@ -115,11 +129,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
subject string
|
||||
authMethods []domain.UserAuthMethodType
|
||||
authTime time.Time
|
||||
ctx context.Context
|
||||
id string
|
||||
userID string
|
||||
userOrgID string
|
||||
authMethods []domain.UserAuthMethodType
|
||||
authTime time.Time
|
||||
preferredLanguage *language.Tag
|
||||
userAgent *domain.UserAgent
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -136,9 +153,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx, "123", "subj",
|
||||
ctx, "123", "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
},
|
||||
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"),
|
||||
},
|
||||
@ -153,22 +175,32 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPushFailed(pushErr,
|
||||
deviceauth.NewApprovedEvent(
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx, "123", "subj",
|
||||
ctx, "123", "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
},
|
||||
wantErr: pushErr,
|
||||
},
|
||||
@ -183,22 +215,32 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPush(
|
||||
deviceauth.NewApprovedEvent(
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx, "123", "subj",
|
||||
ctx, "123", "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
},
|
||||
wantDetails: &domain.ObjectDetails{
|
||||
ResourceOwner: "instance1",
|
||||
@ -210,7 +252,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.subject, tt.args.authMethods, tt.args.authTime)
|
||||
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, gotDetails, tt.wantDetails)
|
||||
})
|
||||
@ -258,7 +300,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPushFailed(pushErr,
|
||||
@ -283,7 +325,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPush(
|
||||
@ -310,7 +352,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPush(
|
||||
@ -338,3 +380,392 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
|
||||
ctx := authz.WithInstanceID(context.Background(), "instance1")
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
deviceCode string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *OIDCSession
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "device auth filter error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"device1",
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "not yet approved",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateInitiated),
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound"),
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPushSlow(time.Second, deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledExpired,
|
||||
)),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired),
|
||||
},
|
||||
{
|
||||
name: "already expired",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledExpired,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired),
|
||||
},
|
||||
{
|
||||
name: "denied",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledDenied,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDenied),
|
||||
},
|
||||
{
|
||||
name: "already done",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledDenied,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewDoneEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDone),
|
||||
},
|
||||
{
|
||||
name: "approved, success",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewApprovedEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"userID", "org1",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
testNow, &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(),
|
||||
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil,
|
||||
),
|
||||
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
|
||||
deviceauth.NewDoneEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
want: &OIDCSession{
|
||||
TokenID: "V2_oidcSessionID-at_accessTokenID",
|
||||
ClientID: "clientID",
|
||||
UserID: "userID",
|
||||
Audience: []string{"audience"},
|
||||
Expiration: time.Time{}.Add(time.Hour),
|
||||
Scope: []string{"openid", "offline_access"},
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
AuthTime: testNow,
|
||||
PreferredLanguage: &language.Afrikaans,
|
||||
UserAgent: &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
Reason: domain.TokenReasonAuthRequest,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "approved, with refresh token",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, true,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewApprovedEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"userID", "org1",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
testNow, &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(),
|
||||
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil,
|
||||
),
|
||||
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour,
|
||||
),
|
||||
deviceauth.NewDoneEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
want: &OIDCSession{
|
||||
TokenID: "V2_oidcSessionID-at_accessTokenID",
|
||||
ClientID: "clientID",
|
||||
UserID: "userID",
|
||||
Audience: []string{"audience"},
|
||||
Expiration: time.Time{}.Add(time.Hour),
|
||||
Scope: []string{"openid", "offline_access"},
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
AuthTime: testNow,
|
||||
PreferredLanguage: &language.Afrikaans,
|
||||
UserAgent: &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
Reason: domain.TokenReasonAuthRequest,
|
||||
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, err := c.CreateOIDCSessionFromDeviceAuth(tt.args.ctx, tt.args.deviceCode)
|
||||
c.jobs.Wait()
|
||||
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
if got != nil {
|
||||
assert.WithinRange(t, got.AuthTime, tt.want.AuthTime.Add(-time.Second), tt.want.AuthTime.Add(time.Second))
|
||||
got.AuthTime = time.Time{}
|
||||
tt.want.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -8,7 +8,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@ -17,6 +19,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@ -28,60 +31,175 @@ const (
|
||||
oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s"
|
||||
)
|
||||
|
||||
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
|
||||
return accessTokenID, accessTokenExpiration, err
|
||||
type OIDCSession struct {
|
||||
SessionID string
|
||||
TokenID string
|
||||
ClientID string
|
||||
UserID string
|
||||
Audience []string
|
||||
Expiration time.Time
|
||||
Scope []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
Nonce string
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
Reason domain.TokenReason
|
||||
Actor *domain.TokenActor
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
|
||||
type AuthRequestComplianceChecker func(context.Context, *AuthRequestWriteModel) error
|
||||
|
||||
// CreateOIDCSessionFromAuthRequest creates a new OIDC Session, creates an access token and refresh token.
|
||||
// It returns the access token id, expiration and the refresh token.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReqId string, complianceCheck AuthRequestComplianceChecker, needRefreshToken bool) (session *OIDCSession, state string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if authReqId == "" {
|
||||
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
|
||||
authReqModel, err := c.getAuthRequestWriteModel(ctx, authReqId)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, "", err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
|
||||
if authReqModel.ResponseType == domain.OIDCResponseTypeCode && authReqModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
|
||||
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
if err = cmd.AddRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
|
||||
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID())
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err = sessionModel.CheckIsActive(); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, sessionModel.UserResourceOwner)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if authReqModel.ResponseType == domain.OIDCResponseTypeCode {
|
||||
if err = cmd.SetAuthRequestCodeExchanged(ctx, authReqModel); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
if err = complianceCheck(ctx, authReqModel); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx,
|
||||
sessionModel.UserID,
|
||||
sessionModel.UserResourceOwner,
|
||||
sessionModel.AggregateID,
|
||||
authReqModel.ClientID,
|
||||
authReqModel.Audience,
|
||||
authReqModel.Scope,
|
||||
authReqModel.AuthMethods,
|
||||
authReqModel.AuthTime,
|
||||
authReqModel.Nonce,
|
||||
sessionModel.PreferredLanguage,
|
||||
sessionModel.UserAgent,
|
||||
)
|
||||
|
||||
if authReqModel.ResponseType != domain.OIDCResponseTypeIDToken {
|
||||
if err = cmd.AddAccessToken(ctx, authReqModel.Scope, sessionModel.UserID, sessionModel.UserResourceOwner, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
if authReqModel.NeedRefreshToken && needRefreshToken {
|
||||
if err = cmd.AddRefreshToken(ctx, sessionModel.UserID); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
|
||||
session, err = cmd.PushEvents(ctx)
|
||||
return session, authReqModel.State, err
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOIDCSession(ctx context.Context,
|
||||
userID,
|
||||
resourceOwner,
|
||||
clientID string,
|
||||
scope,
|
||||
audience []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
nonce string,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
needRefreshToken bool,
|
||||
) (session *OIDCSession, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reason == domain.TokenReasonImpersonation {
|
||||
if err := c.checkPermission(ctx, "impersonation", resourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd.UserImpersonated(ctx, userID, resourceOwner, clientID, actor)
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx, userID, resourceOwner, "", clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent)
|
||||
if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needRefreshToken {
|
||||
if err = cmd.AddRefreshToken(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
|
||||
|
||||
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
|
||||
// It returns the access token id and expiration and the new refresh token.
|
||||
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
|
||||
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, refreshToken string, scope []string, complianceCheck RefreshTokenComplianceChecker) (_ *OIDCSession, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
cmd, err := c.newOIDCSessionUpdateEvents(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
if err = cmd.AddAccessToken(ctx, scope, domain.TokenReasonRefresh, nil); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
scope, err = complianceCheck(ctx, cmd.oidcSessionWriteModel, scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = cmd.AddAccessToken(ctx, scope,
|
||||
cmd.oidcSessionWriteModel.UserID,
|
||||
cmd.oidcSessionWriteModel.UserResourceOwner,
|
||||
domain.TokenReasonRefresh,
|
||||
cmd.oidcSessionWriteModel.AccessTokenActor,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = cmd.RenewRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
|
||||
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (_ *OIDCSessionWriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -146,26 +264,7 @@ func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID s
|
||||
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
|
||||
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetInstance(ctx).InstanceID())
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = sessionWriteModel.CheckIsActive(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resourceOwner, err := c.getResourceOwnerOfSessionUser(ctx, sessionWriteModel.UserID, sessionWriteModel.InstanceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, resourceOwner string, pending ...eventstore.Command) (*OIDCSessionEvents, error) {
|
||||
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -179,42 +278,24 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID st
|
||||
eventstore: c.eventstore,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
events: pending,
|
||||
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, resourceOwner),
|
||||
sessionWriteModel: sessionWriteModel,
|
||||
authRequestWriteModel: authRequestWriteModel,
|
||||
accessTokenLifetime: accessTokenLifetime,
|
||||
refreshTokenLifeTime: refreshTokenLifeTime,
|
||||
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) getResourceOwnerOfSessionUser(ctx context.Context, userID, instanceID string) (string, error) {
|
||||
events, err := c.eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
InstanceID(instanceID).
|
||||
AllowTimeTravel().
|
||||
OrderAsc().
|
||||
Limit(1).
|
||||
AddQuery().
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(userID).
|
||||
Builder())
|
||||
if err != nil || len(events) != 1 {
|
||||
return "", zerrors.ThrowInternal(err, "OIDCS-sferh", "Errors.Internal")
|
||||
}
|
||||
return events[0].Aggregate().ResourceOwner, nil
|
||||
}
|
||||
|
||||
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
|
||||
func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshTokenID string, err error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
|
||||
if err != nil {
|
||||
return "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid")
|
||||
return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
_, refreshTokenID, err = parseRefreshToken(decrypted)
|
||||
return refreshTokenID, err
|
||||
return parseRefreshToken(decrypted)
|
||||
}
|
||||
|
||||
func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) {
|
||||
@ -227,8 +308,8 @@ func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID strin
|
||||
return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
oidcSessionID, refreshTokenID, err := c.decryptRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -255,13 +336,12 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID
|
||||
}
|
||||
|
||||
type OIDCSessionEvents struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
oidcSessionWriteModel *OIDCSessionWriteModel
|
||||
sessionWriteModel *SessionWriteModel
|
||||
authRequestWriteModel *AuthRequestWriteModel
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
oidcSessionWriteModel *OIDCSessionWriteModel
|
||||
|
||||
accessTokenLifetime time.Duration
|
||||
refreshTokenLifeTime time.Duration
|
||||
refreshTokenIdleLifetime time.Duration
|
||||
@ -270,44 +350,75 @@ type OIDCSessionEvents struct {
|
||||
accessTokenID string
|
||||
|
||||
// refreshToken is set by the command
|
||||
refreshToken string
|
||||
refreshTokenID string
|
||||
refreshToken string
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
|
||||
func (c *OIDCSessionEvents) AddSession(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
clientID string,
|
||||
audience,
|
||||
scope []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
nonce string,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) {
|
||||
c.events = append(c.events, oidcsession.NewAddedEvent(
|
||||
ctx,
|
||||
c.oidcSessionWriteModel.aggregate,
|
||||
c.sessionWriteModel.UserID,
|
||||
c.sessionWriteModel.AggregateID,
|
||||
c.authRequestWriteModel.ClientID,
|
||||
c.authRequestWriteModel.Audience,
|
||||
c.authRequestWriteModel.Scope,
|
||||
c.sessionWriteModel.AuthMethodTypes(),
|
||||
c.sessionWriteModel.AuthenticationTime(),
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
clientID,
|
||||
audience,
|
||||
scope,
|
||||
authMethods,
|
||||
authTime,
|
||||
nonce,
|
||||
preferredLanguage,
|
||||
userAgent,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
|
||||
func (c *OIDCSessionEvents) SetAuthRequestCodeExchanged(ctx context.Context, model *AuthRequestWriteModel) error {
|
||||
event := authrequest.NewCodeExchangedEvent(ctx, model.aggregate)
|
||||
model.AppendEvents(event)
|
||||
c.events = append(c.events, event)
|
||||
return model.Reduce()
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, reason domain.TokenReason, actor *domain.TokenActor) error {
|
||||
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context, authRequestAggregate *eventstore.Aggregate) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, authRequestAggregate))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) SetAuthRequestFailed(ctx context.Context, authRequestAggregate *eventstore.Aggregate, err error) {
|
||||
c.events = append(c.events, authrequest.NewFailedEvent(ctx, authRequestAggregate, domain.OIDCErrorReasonFromError(err)))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, userID, resourceOwner string, reason domain.TokenReason, actor *domain.TokenActor) error {
|
||||
accessTokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.accessTokenID = AccessTokenPrefix + accessTokenID
|
||||
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor))
|
||||
c.events = append(c.events,
|
||||
oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor),
|
||||
user.NewUserTokenV2AddedEvent(ctx, &user.NewAggregate(userID, resourceOwner).Aggregate, c.accessTokenID), // for user audit log
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID)
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context, userID string) (err error) {
|
||||
c.refreshTokenID, c.refreshToken, err = c.generateRefreshToken(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -321,6 +432,10 @@ func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) UserImpersonated(ctx context.Context, userID, resourceOwner, clientID string, actor *domain.TokenActor) {
|
||||
c.events = append(c.events, user.NewUserImpersonatedEvent(ctx, &user.NewAggregate(userID, resourceOwner).Aggregate, clientID, actor))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) {
|
||||
refreshTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
@ -334,18 +449,38 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
|
||||
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
|
||||
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
|
||||
session := &OIDCSession{
|
||||
SessionID: c.oidcSessionWriteModel.SessionID,
|
||||
ClientID: c.oidcSessionWriteModel.ClientID,
|
||||
UserID: c.oidcSessionWriteModel.UserID,
|
||||
Audience: c.oidcSessionWriteModel.Audience,
|
||||
Expiration: c.oidcSessionWriteModel.AccessTokenExpiration,
|
||||
Scope: c.oidcSessionWriteModel.Scope,
|
||||
AuthMethods: c.oidcSessionWriteModel.AuthMethods,
|
||||
AuthTime: c.oidcSessionWriteModel.AuthTime,
|
||||
Nonce: c.oidcSessionWriteModel.Nonce,
|
||||
PreferredLanguage: c.oidcSessionWriteModel.PreferredLanguage,
|
||||
UserAgent: c.oidcSessionWriteModel.UserAgent,
|
||||
Reason: c.oidcSessionWriteModel.AccessTokenReason,
|
||||
Actor: c.oidcSessionWriteModel.AccessTokenActor,
|
||||
RefreshToken: c.refreshToken,
|
||||
}
|
||||
if c.accessTokenID != "" {
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
|
||||
}
|
||||
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
|
||||
@ -368,3 +503,14 @@ func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime
|
||||
}
|
||||
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
|
||||
}
|
||||
|
||||
func tokenReasonToActivityMethodType(r domain.TokenReason) activity.TriggerMethod {
|
||||
if r == domain.TokenReasonUnspecified {
|
||||
return activity.Unspecified
|
||||
}
|
||||
if r == domain.TokenReasonRefresh {
|
||||
return activity.OIDCRefreshToken
|
||||
}
|
||||
// all other reasons result in an access token
|
||||
return activity.OIDCAccessToken
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
@ -13,12 +15,16 @@ type OIDCSessionWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
UserID string
|
||||
UserResourceOwner string
|
||||
PreferredLanguage *language.Tag
|
||||
SessionID string
|
||||
ClientID string
|
||||
Audience []string
|
||||
Scope []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
Nonce string
|
||||
UserAgent *domain.UserAgent
|
||||
State domain.OIDCSessionState
|
||||
AccessTokenID string
|
||||
AccessTokenCreation time.Time
|
||||
@ -85,12 +91,16 @@ func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.UserResourceOwner = e.UserResourceOwner
|
||||
wm.SessionID = e.SessionID
|
||||
wm.ClientID = e.ClientID
|
||||
wm.Audience = e.Audience
|
||||
wm.Scope = e.Scope
|
||||
wm.AuthMethods = e.AuthMethods
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.Nonce = e.Nonce
|
||||
wm.PreferredLanguage = e.PreferredLanguage
|
||||
wm.UserAgent = e.UserAgent
|
||||
wm.State = domain.OIDCSessionStateActive
|
||||
// the write model might be initialized without resource owner,
|
||||
// so update the aggregate
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
@ -56,12 +58,12 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
|
||||
}
|
||||
|
||||
// CheckUser defines a user check to be executed for a session update
|
||||
func CheckUser(id string, resourceOwner string) SessionCommand {
|
||||
func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
|
||||
return zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
|
||||
}
|
||||
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now())
|
||||
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,8 +173,8 @@ func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent
|
||||
s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate, userAgent))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) UserChecked(ctx context.Context, userID, resourceOwner string, checkedAt time.Time) error {
|
||||
s.eventCommands = append(s.eventCommands, session.NewUserCheckedEvent(ctx, s.sessionWriteModel.aggregate, userID, resourceOwner, checkedAt))
|
||||
func (s *SessionCommands) UserChecked(ctx context.Context, userID, resourceOwner string, checkedAt time.Time, preferredLanguage *language.Tag) error {
|
||||
s.eventCommands = append(s.eventCommands, session.NewUserCheckedEvent(ctx, s.sessionWriteModel.aggregate, userID, resourceOwner, checkedAt, preferredLanguage))
|
||||
// set the userID so other checks can use it
|
||||
s.sessionWriteModel.UserID = userID
|
||||
s.sessionWriteModel.UserResourceOwner = resourceOwner
|
||||
|
@ -3,6 +3,8 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@ -40,6 +42,7 @@ type SessionWriteModel struct {
|
||||
TokenID string
|
||||
UserID string
|
||||
UserResourceOwner string
|
||||
PreferredLanguage *language.Tag
|
||||
UserCheckedAt time.Time
|
||||
PasswordCheckedAt time.Time
|
||||
IntentCheckedAt time.Time
|
||||
@ -50,6 +53,7 @@ type SessionWriteModel struct {
|
||||
WebAuthNUserVerified bool
|
||||
Metadata map[string][]byte
|
||||
State domain.SessionState
|
||||
UserAgent *domain.UserAgent
|
||||
Expiration time.Time
|
||||
|
||||
WebAuthNChallenge *WebAuthNChallengeModel
|
||||
@ -137,12 +141,14 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
|
||||
func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) {
|
||||
wm.State = domain.SessionStateActive
|
||||
wm.UserAgent = e.UserAgent
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceUserChecked(e *session.UserCheckedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.UserResourceOwner = e.UserResourceOwner
|
||||
wm.UserCheckedAt = e.CheckedAt
|
||||
wm.PreferredLanguage = e.PreferredLanguage
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reducePasswordChecked(e *session.PasswordCheckedEvent) {
|
||||
|
@ -566,7 +566,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectPush(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow,
|
||||
"userID", "org1", testNow, &language.Afrikaans,
|
||||
),
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow,
|
||||
@ -585,7 +585,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckPassword("password"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@ -634,7 +634,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@ -673,7 +673,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@ -723,7 +723,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent2", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t),
|
||||
@ -751,7 +751,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectPush(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
session.NewIntentCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow),
|
||||
session.NewMetadataSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@ -766,7 +766,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@ -1188,7 +1188,7 @@ func TestCommands_TerminateSession(t *testing.T) {
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"user1", "org1", testNow),
|
||||
"user1", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@ -1229,7 +1229,7 @@ func TestCommands_TerminateSession(t *testing.T) {
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@ -1271,7 +1271,7 @@ func TestCommands_TerminateSession(t *testing.T) {
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "org2").Aggregate,
|
||||
"userID", "", testNow),
|
||||
"userID", "", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org2").Aggregate,
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
@ -13,7 +12,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -232,35 +230,6 @@ func (c *Commands) RemoveUser(ctx context.Context, userID, resourceOwner string,
|
||||
return writeModelToObjectDetails(&existingUser.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddUserToken(
|
||||
ctx context.Context,
|
||||
orgID,
|
||||
agentID,
|
||||
clientID,
|
||||
userID string,
|
||||
audience,
|
||||
scopes,
|
||||
authMethodsReferences []string,
|
||||
lifetime time.Duration,
|
||||
authTime time.Time,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (*domain.Token, error) {
|
||||
if userID == "" { //do not check for empty orgID (JWT Profile requests won't provide it, so service user requests fail)
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Dbge4", "Errors.IDMissing")
|
||||
}
|
||||
userWriteModel := NewUserWriteModel(userID, orgID)
|
||||
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, "", audience, scopes, authMethodsReferences, lifetime, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = c.eventstore.Push(ctx, cmds...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RevokeAccessToken(ctx context.Context, userID, orgID, tokenID string) (*domain.ObjectDetails, error) {
|
||||
removeEvent, accessTokenWriteModel, err := c.removeAccessToken(ctx, userID, orgID, tokenID)
|
||||
if err != nil {
|
||||
@ -277,61 +246,6 @@ func (c *Commands) RevokeAccessToken(ctx context.Context, userID, orgID, tokenID
|
||||
return writeModelToObjectDetails(&accessTokenWriteModel.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) addUserToken(ctx context.Context, userWriteModel *UserWriteModel, agentID, clientID, refreshTokenID string, audience, scopes, authMethodsReferences []string, lifetime time.Duration, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) ([]eventstore.Command, *domain.Token, error) {
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, userWriteModel)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if userWriteModel.UserState != domain.UserStateActive {
|
||||
return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-1d6Gg", "Errors.User.NotFound")
|
||||
}
|
||||
|
||||
//nolint:contextcheck
|
||||
userAgg := UserAggregateFromWriteModel(&userWriteModel.WriteModel)
|
||||
|
||||
var cmds []eventstore.Command
|
||||
if reason == domain.TokenReasonImpersonation {
|
||||
if err := c.checkPermission(ctx, "impersonation", userWriteModel.ResourceOwner, userWriteModel.AggregateID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cmds = append(cmds, user.NewUserImpersonatedEvent(ctx, userAgg, clientID, actor))
|
||||
}
|
||||
|
||||
preferredLanguage := ""
|
||||
existingHuman, err := c.getHumanWriteModelByID(ctx, userWriteModel.AggregateID, userWriteModel.ResourceOwner)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if existingHuman != nil {
|
||||
preferredLanguage = existingHuman.PreferredLanguage.String()
|
||||
}
|
||||
expiration := time.Now().UTC().Add(lifetime)
|
||||
tokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cmds = append(cmds,
|
||||
user.NewUserTokenAddedEvent(ctx, userAgg, tokenID, clientID, agentID, preferredLanguage, refreshTokenID, audience, scopes, authMethodsReferences, authTime, expiration, reason, actor),
|
||||
)
|
||||
|
||||
return cmds, &domain.Token{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: userWriteModel.AggregateID,
|
||||
},
|
||||
TokenID: tokenID,
|
||||
UserAgentID: agentID,
|
||||
ApplicationID: clientID,
|
||||
RefreshTokenID: refreshTokenID,
|
||||
Audience: audience,
|
||||
Scopes: scopes,
|
||||
Expiration: expiration,
|
||||
PreferredLanguage: preferredLanguage,
|
||||
Reason: reason,
|
||||
Actor: actor,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) removeAccessToken(ctx context.Context, userID, orgID, tokenID string) (*user.UserTokenRemovedEvent, *UserAccessTokenWriteModel, error) {
|
||||
if userID == "" || orgID == "" || tokenID == "" {
|
||||
return nil, nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Dng42", "Errors.IDMissing")
|
||||
|
@ -1,9 +1,6 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
)
|
||||
@ -81,16 +78,6 @@ func writeModelToAddress(wm *HumanAddressWriteModel) *domain.Address {
|
||||
}
|
||||
}
|
||||
|
||||
func writeModelToMachine(wm *MachineWriteModel) *domain.Machine {
|
||||
return &domain.Machine{
|
||||
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
|
||||
Username: wm.UserName,
|
||||
Name: wm.Name,
|
||||
Description: wm.Description,
|
||||
State: wm.UserState,
|
||||
}
|
||||
}
|
||||
|
||||
func keyWriteModelToMachineKey(wm *MachineKeyWriteModel) *domain.MachineKey {
|
||||
return &domain.MachineKey{
|
||||
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
|
||||
@ -100,18 +87,6 @@ func keyWriteModelToMachineKey(wm *MachineKeyWriteModel) *domain.MachineKey {
|
||||
}
|
||||
}
|
||||
|
||||
func personalTokenWriteModelToToken(wm *PersonalAccessTokenWriteModel, algorithm crypto.EncryptionAlgorithm) (*domain.Token, string, error) {
|
||||
encrypted, err := algorithm.Encrypt([]byte(wm.TokenID + ":" + wm.AggregateID))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return &domain.Token{
|
||||
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
|
||||
TokenID: wm.TokenID,
|
||||
Expiration: wm.ExpirationDate,
|
||||
}, base64.RawURLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func readModelToWebAuthNTokens(readModel HumanWebAuthNTokensReadModel) []*domain.WebAuthNToken {
|
||||
tokens := make([]*domain.WebAuthNToken, len(readModel.GetWebAuthNTokens()))
|
||||
for i, token := range readModel.GetWebAuthNTokens() {
|
||||
|
@ -2,7 +2,6 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@ -10,98 +9,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (c *Commands) AddAccessAndRefreshToken(
|
||||
ctx context.Context,
|
||||
orgID,
|
||||
agentID,
|
||||
clientID,
|
||||
userID,
|
||||
refreshToken string,
|
||||
audience,
|
||||
scopes,
|
||||
authMethodsReferences []string,
|
||||
accessLifetime,
|
||||
refreshIdleExpiration,
|
||||
refreshExpiration time.Duration,
|
||||
authTime time.Time,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken *domain.Token, newRefreshToken string, err error) {
|
||||
if refreshToken == "" {
|
||||
return c.AddNewRefreshTokenAndAccessToken(ctx, userID, orgID, agentID, clientID, audience, scopes, authMethodsReferences, refreshExpiration, accessLifetime, refreshIdleExpiration, authTime, reason, actor)
|
||||
}
|
||||
return c.RenewRefreshTokenAndAccessToken(ctx, userID, orgID, refreshToken, agentID, clientID, audience, scopes, refreshIdleExpiration, accessLifetime, actor)
|
||||
}
|
||||
|
||||
func (c *Commands) AddNewRefreshTokenAndAccessToken(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
orgID,
|
||||
agentID,
|
||||
clientID string,
|
||||
audience,
|
||||
scopes,
|
||||
authMethodsReferences []string,
|
||||
refreshExpiration,
|
||||
accessLifetime,
|
||||
refreshIdleExpiration time.Duration,
|
||||
authTime time.Time,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken *domain.Token, newRefreshToken string, err error) {
|
||||
if userID == "" || clientID == "" {
|
||||
return nil, "", zerrors.ThrowInvalidArgument(nil, "COMMAND-adg4r", "Errors.IDMissing")
|
||||
}
|
||||
userWriteModel := NewUserWriteModel(userID, orgID)
|
||||
refreshTokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, refreshTokenID, audience, scopes, authMethodsReferences, accessLifetime, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
refreshTokenEvent, newRefreshToken, err := c.addRefreshToken(ctx, accessToken, authMethodsReferences, authTime, refreshIdleExpiration, refreshExpiration, actor)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
cmds = append(cmds, refreshTokenEvent)
|
||||
_, err = c.eventstore.Push(ctx, cmds...)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return accessToken, newRefreshToken, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RenewRefreshTokenAndAccessToken(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
orgID,
|
||||
refreshToken,
|
||||
agentID,
|
||||
clientID string,
|
||||
audience,
|
||||
scopes []string,
|
||||
idleExpiration,
|
||||
accessLifetime time.Duration,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken *domain.Token, newRefreshToken string, err error) {
|
||||
renewed, err := c.renewRefreshToken(ctx, userID, orgID, refreshToken, idleExpiration)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
userWriteModel := NewUserWriteModel(userID, orgID)
|
||||
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, renewed.tokenID, audience, scopes, renewed.authMethodsReferences, accessLifetime, renewed.authTime, domain.TokenReasonRefresh, actor)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
_, err = c.eventstore.Push(ctx, append(cmds, renewed.event)...)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return accessToken, renewed.token, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RevokeRefreshToken(ctx context.Context, userID, orgID, tokenID string) (*domain.ObjectDetails, error) {
|
||||
removeEvent, refreshTokenWriteModel, err := c.removeRefreshToken(ctx, userID, orgID, tokenID)
|
||||
if err != nil {
|
||||
@ -134,70 +41,6 @@ func (c *Commands) RevokeRefreshTokens(ctx context.Context, userID, orgID string
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) addRefreshToken(ctx context.Context, accessToken *domain.Token, authMethodsReferences []string, authTime time.Time, idleExpiration, expiration time.Duration, actor *domain.TokenActor) (*user.HumanRefreshTokenAddedEvent, string, error) {
|
||||
refreshToken, err := domain.NewRefreshToken(accessToken.AggregateID, accessToken.RefreshTokenID, c.keyAlgorithm)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
refreshTokenWriteModel := NewHumanRefreshTokenWriteModel(accessToken.AggregateID, accessToken.ResourceOwner, accessToken.RefreshTokenID)
|
||||
userAgg := UserAggregateFromWriteModel(&refreshTokenWriteModel.WriteModel)
|
||||
return user.NewHumanRefreshTokenAddedEvent(ctx, userAgg, accessToken.RefreshTokenID, accessToken.ApplicationID, accessToken.UserAgentID,
|
||||
accessToken.PreferredLanguage, accessToken.Audience, accessToken.Scopes, authMethodsReferences, authTime, idleExpiration, expiration, actor),
|
||||
refreshToken, nil
|
||||
}
|
||||
|
||||
type renewedRefreshToken struct {
|
||||
event *user.HumanRefreshTokenRenewedEvent
|
||||
authTime time.Time
|
||||
authMethodsReferences []string
|
||||
tokenID string
|
||||
token string
|
||||
}
|
||||
|
||||
func (c *Commands) renewRefreshToken(ctx context.Context, userID, orgID, refreshToken string, idleExpiration time.Duration) (*renewedRefreshToken, error) {
|
||||
if refreshToken == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-DHrr3", "Errors.IDMissing")
|
||||
}
|
||||
|
||||
tokenUserID, tokenID, token, err := domain.FromRefreshToken(refreshToken, c.keyAlgorithm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenUserID != userID {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Ht2g2", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
refreshTokenWriteModel := NewHumanRefreshTokenWriteModel(userID, orgID, tokenID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, refreshTokenWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshTokenWriteModel.UserState != domain.UserStateActive {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-BHnhs", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
if refreshTokenWriteModel.RefreshToken != token ||
|
||||
refreshTokenWriteModel.IdleExpiration.Before(time.Now()) ||
|
||||
refreshTokenWriteModel.Expiration.Before(time.Now()) {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Vr43e", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
|
||||
newToken, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRefreshToken, err := domain.RefreshToken(userID, tokenID, newToken, c.keyAlgorithm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAgg := UserAggregateFromWriteModel(&refreshTokenWriteModel.WriteModel)
|
||||
return &renewedRefreshToken{
|
||||
event: user.NewHumanRefreshTokenRenewedEvent(ctx, userAgg, tokenID, newToken, idleExpiration),
|
||||
authTime: refreshTokenWriteModel.AuthTime,
|
||||
authMethodsReferences: refreshTokenWriteModel.AuthMethodsReferences,
|
||||
tokenID: tokenID,
|
||||
token: newRefreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) removeRefreshToken(ctx context.Context, userID, orgID, tokenID string) (*user.HumanRefreshTokenRemovedEvent, *HumanRefreshTokenWriteModel, error) {
|
||||
if userID == "" || orgID == "" || tokenID == "" {
|
||||
return nil, nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-GVDgf", "Errors.IDMissing")
|
||||
|
@ -2,316 +2,18 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
id_mock "github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestCommands_AddAccessAndRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
orgID string
|
||||
agentID string
|
||||
clientID string
|
||||
userID string
|
||||
refreshToken string
|
||||
audience []string
|
||||
scopes []string
|
||||
authMethodsReferences []string
|
||||
lifetime time.Duration
|
||||
authTime time.Time
|
||||
refreshIdleExpiration time.Duration
|
||||
refreshExpiration time.Duration
|
||||
reason domain.TokenReason
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
type res struct {
|
||||
token *domain.Token
|
||||
refreshToken string
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "missing ID, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args: args{},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add refresh token, user deactivated, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewUserDeactivatedEvent(context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "refreshTokenID1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
orgID: "orgID",
|
||||
agentID: "agentID",
|
||||
userID: "userID",
|
||||
clientID: "clientID",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, invalid token, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
refreshToken: "invalid",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, invalid token (invalid userID), error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID2:tokenID:token")),
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, token inactive, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenRemovedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:token")),
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, token expired, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
-1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
//fails because of timestamp equality
|
||||
//{
|
||||
// name: "push failed, error",
|
||||
// fields: fields{
|
||||
// eventstore: eventstoreExpect(t,
|
||||
// expectFilter(
|
||||
// eventFromEventPusher(user.NewHumanAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "username",
|
||||
// "firstname",
|
||||
// "lastname",
|
||||
// "nickname",
|
||||
// "displayname",
|
||||
// language.German,
|
||||
// domain.GenderUnspecified,
|
||||
// "email",
|
||||
// true,
|
||||
// )),
|
||||
// ),
|
||||
// expectFilter(
|
||||
// eventFromEventPusherWithCreationDateNow(user.NewHumanAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "username",
|
||||
// "firstname",
|
||||
// "lastname",
|
||||
// "nickname",
|
||||
// "displayname",
|
||||
// language.German,
|
||||
// domain.GenderUnspecified,
|
||||
// "email",
|
||||
// true,
|
||||
// )),
|
||||
// ),
|
||||
// expectFilter(
|
||||
// eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "tokenID",
|
||||
// "applicationID",
|
||||
// "userAgentID",
|
||||
// "de",
|
||||
// []string{"clientID1"},
|
||||
// []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
// []string{"password"},
|
||||
// time.Now(),
|
||||
// 1*time.Hour,
|
||||
// 24*time.Hour,
|
||||
// )),
|
||||
// ),
|
||||
// expectPushFailed(
|
||||
// zerrors.ThrowInternal(nil, "ERROR", "internal"),
|
||||
// []*repository.Event{
|
||||
// eventFromEventPusher(user.NewUserTokenAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "accessTokenID1",
|
||||
// "clientID",
|
||||
// "agentID",
|
||||
// "de",
|
||||
// []string{"clientID1"},
|
||||
// []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
// time.Now().Add(5*time.Minute),
|
||||
// )),
|
||||
// eventFromEventPusher(user.NewHumanRefreshTokenRenewedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "tokenID",
|
||||
// "refreshToken1",
|
||||
// 1*time.Hour,
|
||||
// )),
|
||||
// },
|
||||
// ),
|
||||
// ),
|
||||
// idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "accessTokenID1", "refreshToken1"),
|
||||
// keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
// },
|
||||
// args: args{
|
||||
// ctx: context.Background(),
|
||||
// orgID: "orgID",
|
||||
// agentID: "agentID",
|
||||
// clientID: "clientID",
|
||||
// userID: "userID",
|
||||
// refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
// audience: []string{"clientID1"},
|
||||
// scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
// authMethodsReferences: []string{"password"},
|
||||
// lifetime: 5 * time.Minute,
|
||||
// authTime: time.Now(),
|
||||
// },
|
||||
// res: res{
|
||||
// err: zerrors.IsInternal,
|
||||
// },
|
||||
//},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, gotRefresh, err := c.AddAccessAndRefreshToken(tt.args.ctx, tt.args.orgID, tt.args.agentID, tt.args.clientID, tt.args.userID, tt.args.refreshToken,
|
||||
tt.args.audience, tt.args.scopes, tt.args.authMethodsReferences, tt.args.lifetime, tt.args.refreshIdleExpiration, tt.args.refreshExpiration, tt.args.authTime, tt.args.reason, tt.args.actor)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.token, got)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefresh)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_RevokeRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
@ -669,395 +371,3 @@ func TestCommands_RevokeRefreshTokens(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func refreshTokenEncryptionAlgorithm(ctrl *gomock.Controller) crypto.EncryptionAlgorithm {
|
||||
mCrypto := crypto.NewMockEncryptionAlgorithm(ctrl)
|
||||
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
|
||||
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
|
||||
mCrypto.EXPECT().Encrypt(gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(refrehToken []byte) ([]byte, error) {
|
||||
return refrehToken, nil
|
||||
},
|
||||
)
|
||||
mCrypto.EXPECT().Decrypt(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(refrehToken []byte, keyID string) ([]byte, error) {
|
||||
if keyID != "id" {
|
||||
return nil, zerrors.ThrowInternal(nil, "id", "invalid key id")
|
||||
}
|
||||
return refrehToken, nil
|
||||
},
|
||||
)
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
func TestCommands_addRefreshToken(t *testing.T) {
|
||||
authTime := time.Now().Add(-1 * time.Hour)
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
accessToken *domain.Token
|
||||
authMethodsReferences []string
|
||||
authTime time.Time
|
||||
idleExpiration time.Duration
|
||||
expiration time.Duration
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
type res struct {
|
||||
event *user.HumanRefreshTokenAddedEvent
|
||||
refreshToken string
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
|
||||
{
|
||||
name: "add refresh Token",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
accessToken: &domain.Token{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "userID",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
TokenID: "accessTokenID1",
|
||||
ApplicationID: "clientID",
|
||||
UserAgentID: "agentID",
|
||||
RefreshTokenID: "refreshTokenID",
|
||||
Audience: []string{"clientID1"},
|
||||
Expiration: time.Now().Add(5 * time.Minute),
|
||||
Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
PreferredLanguage: "de",
|
||||
},
|
||||
authMethodsReferences: []string{"password"},
|
||||
authTime: authTime,
|
||||
idleExpiration: 1 * time.Hour,
|
||||
expiration: 10 * time.Hour,
|
||||
},
|
||||
res: res{
|
||||
event: user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "org1").Aggregate,
|
||||
"refreshTokenID",
|
||||
"clientID",
|
||||
"agentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
authTime,
|
||||
1*time.Hour,
|
||||
10*time.Hour,
|
||||
nil,
|
||||
),
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:refreshTokenID:refreshTokenID")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotEvent, gotRefreshToken, err := c.addRefreshToken(tt.args.ctx, tt.args.accessToken, tt.args.authMethodsReferences, tt.args.authTime, tt.args.idleExpiration, tt.args.expiration, tt.args.actor)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.event, gotEvent)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_renewRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
userID string
|
||||
orgID string
|
||||
refreshToken string
|
||||
idleExpiration time.Duration
|
||||
}
|
||||
type res struct {
|
||||
event *user.HumanRefreshTokenRenewedEvent
|
||||
refreshTokenID string
|
||||
newRefreshToken string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *renewedRefreshToken
|
||||
wantErr func(error) bool
|
||||
}{
|
||||
{
|
||||
name: "empty token, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "invalid token, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
refreshToken: "invalid",
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "invalid token (invalid userID), error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID2:tokenID:token")),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "token inactive, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenRemovedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:token")),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "token expired, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "user deactivated, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
user.NewUserDeactivatedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
idleExpiration: 1 * time.Hour,
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "user signedout, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanSignedOutEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"userAgentID",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
idleExpiration: 1 * time.Hour,
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "token renewed, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "refreshToken1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
idleExpiration: 1 * time.Hour,
|
||||
},
|
||||
want: &renewedRefreshToken{
|
||||
event: user.NewHumanRefreshTokenRenewedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"refreshToken1",
|
||||
1*time.Hour,
|
||||
),
|
||||
authMethodsReferences: []string{"password"},
|
||||
tokenID: "tokenID",
|
||||
token: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:refreshToken1")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, err := c.renewRefreshToken(tt.args.ctx, tt.args.userID, tt.args.orgID, tt.args.refreshToken, tt.args.idleExpiration)
|
||||
if tt.wantErr != nil && !tt.wantErr(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.wantErr == nil {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want.event, got.event)
|
||||
assert.Equal(t, tt.want.authMethodsReferences, got.authMethodsReferences)
|
||||
assert.Equal(t, tt.want.tokenID, got.tokenID)
|
||||
assert.Equal(t, tt.want.token, got.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
@ -1433,91 +1432,6 @@ func TestCommandSide_RemoveUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandSide_AddUserToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type (
|
||||
args struct {
|
||||
ctx context.Context
|
||||
orgID string
|
||||
agentID string
|
||||
clientID string
|
||||
userID string
|
||||
audience []string
|
||||
scopes []string
|
||||
authMethodsReferences []string
|
||||
lifetime time.Duration
|
||||
authTime time.Time
|
||||
reason domain.TokenReason
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
)
|
||||
type res struct {
|
||||
want *domain.Token
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "userid missing, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
orgID: "org1",
|
||||
userID: "",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user not existing, not found error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
orgID: "org1",
|
||||
userID: "user1",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsNotFound,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := r.AddUserToken(tt.args.ctx, tt.args.orgID, tt.args.agentID, tt.args.clientID, tt.args.userID, tt.args.audience, tt.args.scopes, tt.args.authMethodsReferences, tt.args.lifetime, tt.args.authTime, tt.args.reason, tt.args.actor)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_RevokeAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
|
@ -34,6 +34,7 @@ type AuthRequest struct {
|
||||
AvatarKey string
|
||||
PresignedAvatar string
|
||||
UserOrgID string
|
||||
PreferredLanguage *language.Tag
|
||||
RequestedOrgID string
|
||||
RequestedOrgName string
|
||||
RequestedPrimaryDomain string
|
||||
|
@ -11,6 +11,7 @@ type BrowserInfo struct {
|
||||
UserAgent string
|
||||
AcceptLanguage string
|
||||
RemoteIP net.IP
|
||||
Header net_http.Header
|
||||
}
|
||||
|
||||
func BrowserInfoFromRequest(r *net_http.Request) *BrowserInfo {
|
||||
@ -18,5 +19,18 @@ func BrowserInfoFromRequest(r *net_http.Request) *BrowserInfo {
|
||||
UserAgent: r.Header.Get(http_util.UserAgentHeader),
|
||||
AcceptLanguage: r.Header.Get(http_util.AcceptLanguage),
|
||||
RemoteIP: http_util.RemoteIPFromRequest(r),
|
||||
Header: r.Header,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BrowserInfo) ToUserAgent() *UserAgent {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserAgent{
|
||||
FingerprintID: &b.UserAgent,
|
||||
IP: b.RemoteIP,
|
||||
Description: &b.UserAgent,
|
||||
Header: b.Header,
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ const (
|
||||
DeviceAuthStateApproved // approved
|
||||
DeviceAuthStateDenied // denied
|
||||
DeviceAuthStateExpired // expired
|
||||
DeviceAuthStateDone // done
|
||||
|
||||
deviceAuthStateCount // invalid
|
||||
)
|
||||
|
@ -13,12 +13,13 @@ func _() {
|
||||
_ = x[DeviceAuthStateApproved-2]
|
||||
_ = x[DeviceAuthStateDenied-3]
|
||||
_ = x[DeviceAuthStateExpired-4]
|
||||
_ = x[deviceAuthStateCount-5]
|
||||
_ = x[DeviceAuthStateDone-5]
|
||||
_ = x[deviceAuthStateCount-6]
|
||||
}
|
||||
|
||||
const _DeviceAuthState_name = "undefinedinitiatedapproveddeniedexpiredinvalid"
|
||||
const _DeviceAuthState_name = "undefinedinitiatedapproveddeniedexpireddoneinvalid"
|
||||
|
||||
var _DeviceAuthState_index = [...]uint8{0, 9, 18, 26, 32, 39, 46}
|
||||
var _DeviceAuthState_index = [...]uint8{0, 9, 18, 26, 32, 39, 43, 50}
|
||||
|
||||
func (i DeviceAuthState) String() string {
|
||||
if i >= DeviceAuthState(len(_DeviceAuthState_index)-1) {
|
||||
|
@ -1,5 +1,13 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type OIDCErrorReason int32
|
||||
|
||||
const (
|
||||
@ -20,4 +28,21 @@ const (
|
||||
OIDCErrorReasonRequestNotSupported
|
||||
OIDCErrorReasonRequestURINotSupported
|
||||
OIDCErrorReasonRegistrationNotSupported
|
||||
OIDCErrorReasonInvalidGrant
|
||||
)
|
||||
|
||||
func OIDCErrorReasonFromError(err error) OIDCErrorReason {
|
||||
if errors.Is(err, oidc.ErrInvalidRequest()) {
|
||||
return OIDCErrorReasonInvalidRequest
|
||||
}
|
||||
if errors.Is(err, oidc.ErrInvalidGrant()) {
|
||||
return OIDCErrorReasonInvalidGrant
|
||||
}
|
||||
if zerrors.IsPreconditionFailed(err) {
|
||||
return OIDCErrorReasonAccessDenied
|
||||
}
|
||||
if zerrors.IsInternal(err) {
|
||||
return OIDCErrorReasonServerError
|
||||
}
|
||||
return OIDCErrorReasonUnspecified
|
||||
}
|
||||
|
51
internal/domain/oidc_error_reason_test.go
Normal file
51
internal/domain/oidc_error_reason_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestOIDCErrorReasonFromError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want OIDCErrorReason
|
||||
}{
|
||||
{
|
||||
name: "invalid request",
|
||||
err: oidc.ErrInvalidRequest().WithDescription("foo"),
|
||||
want: OIDCErrorReasonInvalidRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid grant",
|
||||
err: oidc.ErrInvalidGrant().WithDescription("foo"),
|
||||
want: OIDCErrorReasonInvalidGrant,
|
||||
},
|
||||
{
|
||||
name: "precondition failed",
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "123", "bar"),
|
||||
want: OIDCErrorReasonAccessDenied,
|
||||
},
|
||||
{
|
||||
name: "internal",
|
||||
err: zerrors.ThrowInternal(nil, "123", "bar"),
|
||||
want: OIDCErrorReasonServerError,
|
||||
},
|
||||
{
|
||||
name: "unspecified",
|
||||
err: io.ErrClosedPipe,
|
||||
want: OIDCErrorReasonUnspecified,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := OIDCErrorReasonFromError(tt.err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
@ -3,27 +3,10 @@ package domain
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
es_models.ObjectRoot
|
||||
|
||||
TokenID string
|
||||
ApplicationID string
|
||||
UserAgentID string
|
||||
RefreshTokenID string
|
||||
Audience []string
|
||||
Expiration time.Time
|
||||
Scopes []string
|
||||
PreferredLanguage string
|
||||
Reason TokenReason
|
||||
Actor *TokenActor
|
||||
}
|
||||
|
||||
func AddAudScopeToAudience(ctx context.Context, audience, scopes []string) []string {
|
||||
for _, scope := range scopes {
|
||||
if !(strings.HasPrefix(scope, ProjectIDScope) && strings.HasSuffix(scope, AudSuffix)) {
|
||||
|
@ -15,3 +15,10 @@ type UserAgent struct {
|
||||
func (ua UserAgent) IsEmpty() bool {
|
||||
return ua.FingerprintID == nil && len(ua.IP) == 0 && ua.Description == nil && ua.Header == nil
|
||||
}
|
||||
|
||||
func (ua *UserAgent) GetFingerprintID() string {
|
||||
if ua == nil || ua.FingerprintID == nil {
|
||||
return ""
|
||||
}
|
||||
return *ua.FingerprintID
|
||||
}
|
||||
|
42
internal/domain/user_agent_test.go
Normal file
42
internal/domain/user_agent_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUserAgent_GetFingerprintID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields *UserAgent
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil useragent",
|
||||
fields: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil fingerprintID",
|
||||
fields: &UserAgent{
|
||||
FingerprintID: nil,
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "value",
|
||||
fields: &UserAgent{
|
||||
FingerprintID: gu.Ptr("fp"),
|
||||
},
|
||||
want: "fp",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.fields.GetFingerprintID()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
@ -299,3 +299,24 @@ func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clien
|
||||
}
|
||||
return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil
|
||||
}
|
||||
|
||||
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) {
|
||||
name := gofakeit.Username()
|
||||
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
|
||||
Name: name,
|
||||
UserName: name,
|
||||
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{
|
||||
UserId: user.GetUserId(),
|
||||
Type: authn.KeyType_KEY_TYPE_JSON,
|
||||
ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return user.GetUserId(), keyResp.GetKeyDetails(), nil
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
@ -24,10 +26,13 @@ type OIDCSessionAccessTokenReadModel struct {
|
||||
Scope []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
Nonce string
|
||||
State domain.OIDCSessionState
|
||||
AccessTokenID string
|
||||
AccessTokenCreation time.Time
|
||||
AccessTokenExpiration time.Time
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
Reason domain.TokenReason
|
||||
Actor *domain.TokenActor
|
||||
}
|
||||
@ -79,6 +84,9 @@ func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent
|
||||
wm.Scope = e.Scope
|
||||
wm.AuthMethods = e.AuthMethods
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.Nonce = e.Nonce
|
||||
wm.PreferredLanguage = e.PreferredLanguage
|
||||
wm.UserAgent = e.UserAgent
|
||||
wm.State = domain.OIDCSessionStateActive
|
||||
}
|
||||
|
||||
@ -112,7 +120,7 @@ func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (m
|
||||
if !model.AccessTokenExpiration.After(time.Now()) {
|
||||
return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
|
||||
}
|
||||
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.AccessTokenCreation); err != nil {
|
||||
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.AccessTokenCreation, model.UserAgent.GetFingerprintID()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return model, nil
|
||||
@ -132,16 +140,17 @@ func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSe
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event occurred after a certain time
|
||||
// and will return an error if so.
|
||||
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, creation time.Time) (err error) {
|
||||
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event (or user events leading to a session termination)
|
||||
// occurred after a certain time and will return an error if so.
|
||||
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, creation time.Time, fingerprintID string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
model := &sessionTerminatedModel{
|
||||
sessionID: sessionID,
|
||||
creation: creation,
|
||||
userID: userID,
|
||||
sessionID: sessionID,
|
||||
creation: creation,
|
||||
userID: userID,
|
||||
fingerPrintID: fingerprintID,
|
||||
}
|
||||
err = q.eventstore.FilterToQueryReducer(ctx, model)
|
||||
if err != nil {
|
||||
@ -155,9 +164,10 @@ func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID,
|
||||
}
|
||||
|
||||
type sessionTerminatedModel struct {
|
||||
creation time.Time
|
||||
sessionID string
|
||||
userID string
|
||||
creation time.Time
|
||||
sessionID string
|
||||
userID string
|
||||
fingerPrintID string
|
||||
|
||||
events int
|
||||
terminated bool
|
||||
@ -195,5 +205,12 @@ func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
|
||||
user.UserLockedType,
|
||||
user.UserRemovedType,
|
||||
).
|
||||
Or(). // for specific logout on v1 sessions from the same user agent
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(s.userID).
|
||||
EventTypes(
|
||||
user.HumanSignedOutType,
|
||||
).
|
||||
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
|
||||
Builder()
|
||||
}
|
||||
|
@ -209,7 +209,7 @@ func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, i
|
||||
return key, err
|
||||
}
|
||||
|
||||
func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string, withOwnerRemoved bool) (key []byte, err error) {
|
||||
func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string) (key []byte, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
|
||||
@ -59,34 +58,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type DeviceAuth struct {
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
Audience []string
|
||||
State domain.DeviceAuthState
|
||||
Subject string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
}
|
||||
|
||||
// DeviceAuthByDeviceCode gets the current state of a Device Authorization directly from the eventstore.
|
||||
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (deviceAuth *DeviceAuth, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
model := NewDeviceAuthReadModel(deviceCode, authz.GetInstance(ctx).InstanceID())
|
||||
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !model.State.Exists() {
|
||||
return nil, zerrors.ThrowNotFound(nil, "QUERY-eeR0e", "Errors.DeviceAuth.NotExisting")
|
||||
}
|
||||
return &model.DeviceAuth, nil
|
||||
}
|
||||
|
||||
// DeviceAuthRequestByUserCode finds a Device Authorization request by User-Code from the `device_auth_requests` projection.
|
||||
func (q *Queries) DeviceAuthRequestByUserCode(ctx context.Context, userCode string) (authReq *domain.AuthRequestDevice, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
|
@ -1,59 +0,0 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
)
|
||||
|
||||
type DeviceAuthReadModel struct {
|
||||
eventstore.ReadModel
|
||||
DeviceAuth
|
||||
}
|
||||
|
||||
func NewDeviceAuthReadModel(deviceCode, resourceOwner string) *DeviceAuthReadModel {
|
||||
return &DeviceAuthReadModel{
|
||||
ReadModel: eventstore.ReadModel{
|
||||
AggregateID: deviceCode,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DeviceAuthReadModel) Reduce() error {
|
||||
for _, event := range m.Events {
|
||||
switch e := event.(type) {
|
||||
case *deviceauth.AddedEvent:
|
||||
m.ClientID = e.ClientID
|
||||
m.DeviceCode = e.DeviceCode
|
||||
m.UserCode = e.UserCode
|
||||
m.Expires = e.Expires
|
||||
m.Scopes = e.Scopes
|
||||
m.Audience = e.Audience
|
||||
m.State = e.State
|
||||
case *deviceauth.ApprovedEvent:
|
||||
m.State = domain.DeviceAuthStateApproved
|
||||
m.Subject = e.Subject
|
||||
m.UserAuthMethods = e.UserAuthMethods
|
||||
m.AuthTime = e.AuthTime
|
||||
case *deviceauth.CanceledEvent:
|
||||
m.State = e.Reason.State()
|
||||
}
|
||||
}
|
||||
|
||||
return m.ReadModel.Reduce()
|
||||
}
|
||||
|
||||
func (m *DeviceAuthReadModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
ResourceOwner(m.ResourceOwner).
|
||||
AddQuery().
|
||||
AggregateTypes(deviceauth.AggregateType).
|
||||
AggregateIDs(m.AggregateID).
|
||||
EventTypes(
|
||||
deviceauth.AddedEventType,
|
||||
deviceauth.ApprovedEventType,
|
||||
deviceauth.CanceledEventType,
|
||||
).
|
||||
Builder()
|
||||
}
|
@ -6,167 +6,18 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
db_mock "github.com/zitadel/zitadel/internal/database/mock"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
|
||||
ctx := authz.NewMockContext("inst1", "org1", "user1")
|
||||
timestamp := time.Date(2015, 12, 15, 22, 13, 45, 0, time.UTC)
|
||||
tests := []struct {
|
||||
name string
|
||||
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||
want *DeviceAuth
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "filter error",
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
wantErr: zerrors.ThrowNotFound(nil, "QUERY-eeR0e", "Errors.DeviceAuth.NotExisting"),
|
||||
},
|
||||
{
|
||||
name: "ok, initiated",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
|
||||
[]string{"projectID", "clientID"},
|
||||
)),
|
||||
),
|
||||
),
|
||||
want: &DeviceAuth{
|
||||
ClientID: "client1",
|
||||
DeviceCode: "device1",
|
||||
UserCode: "user-code",
|
||||
Expires: timestamp,
|
||||
Scopes: []string{"foo", "bar"},
|
||||
Audience: []string{"projectID", "clientID"},
|
||||
State: domain.DeviceAuthStateInitiated,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, approved",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
|
||||
[]string{"projectID", "clientID"},
|
||||
)),
|
||||
eventFromEventPusher(deviceauth.NewApprovedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
"user1", []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
|
||||
timestamp,
|
||||
)),
|
||||
),
|
||||
),
|
||||
want: &DeviceAuth{
|
||||
ClientID: "client1",
|
||||
DeviceCode: "device1",
|
||||
UserCode: "user-code",
|
||||
Expires: timestamp,
|
||||
Scopes: []string{"foo", "bar"},
|
||||
Audience: []string{"projectID", "clientID"},
|
||||
State: domain.DeviceAuthStateApproved,
|
||||
Subject: "user1",
|
||||
UserAuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
|
||||
AuthTime: timestamp,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, denied",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
|
||||
[]string{"projectID", "clientID"},
|
||||
)),
|
||||
eventFromEventPusher(deviceauth.NewCanceledEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
domain.DeviceAuthCanceledDenied,
|
||||
)),
|
||||
),
|
||||
),
|
||||
want: &DeviceAuth{
|
||||
ClientID: "client1",
|
||||
DeviceCode: "device1",
|
||||
UserCode: "user-code",
|
||||
Expires: timestamp,
|
||||
Scopes: []string{"foo", "bar"},
|
||||
Audience: []string{"projectID", "clientID"},
|
||||
State: domain.DeviceAuthStateDenied,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, expired",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
|
||||
[]string{"projectID", "clientID"},
|
||||
)),
|
||||
eventFromEventPusher(deviceauth.NewCanceledEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("device1", "instance1"),
|
||||
domain.DeviceAuthCanceledExpired,
|
||||
)),
|
||||
),
|
||||
),
|
||||
want: &DeviceAuth{
|
||||
ClientID: "client1",
|
||||
DeviceCode: "device1",
|
||||
UserCode: "user-code",
|
||||
Expires: timestamp,
|
||||
Scopes: []string{"foo", "bar"},
|
||||
Audience: []string{"projectID", "clientID"},
|
||||
State: domain.DeviceAuthStateExpired,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := &Queries{
|
||||
eventstore: tt.eventstore(t),
|
||||
}
|
||||
got, err := q.DeviceAuthByDeviceCode(ctx, "device1")
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
expectedDeviceAuthQueryC = `SELECT` +
|
||||
` projections.device_auth_requests2.client_id,` +
|
||||
|
@ -74,6 +74,10 @@ func (p *deviceAuthRequestProjection) Reducers() []handler.AggregateReducer {
|
||||
Event: deviceauth.CanceledEventType,
|
||||
Reduce: p.reduceDoneEvents,
|
||||
},
|
||||
{
|
||||
Event: deviceauth.DoneEventType,
|
||||
Reduce: p.reduceDoneEvents,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -103,7 +107,7 @@ func (p *deviceAuthRequestProjection) reduceAdded(event eventstore.Event) (*hand
|
||||
// reduceDoneEvents removes the device auth request from the projection.
|
||||
func (p *deviceAuthRequestProjection) reduceDoneEvents(event eventstore.Event) (*handler.Statement, error) {
|
||||
switch event.(type) {
|
||||
case *deviceauth.ApprovedEvent, *deviceauth.CanceledEvent:
|
||||
case *deviceauth.ApprovedEvent, *deviceauth.CanceledEvent, *deviceauth.DoneEvent:
|
||||
return handler.NewDeleteStatement(event,
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DeviceAuthRequestColumnInstanceID, event.Aggregate().InstanceID),
|
||||
|
@ -22,20 +22,21 @@ const (
|
||||
type AddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
LoginClient string `json:"login_client"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
State string `json:"state,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Scope []string `json:"scope,omitempty"`
|
||||
Audience []string `json:"audience,omitempty"`
|
||||
ResponseType domain.OIDCResponseType `json:"response_type,omitempty"`
|
||||
CodeChallenge *domain.OIDCCodeChallenge `json:"code_challenge,omitempty"`
|
||||
Prompt []domain.Prompt `json:"prompt,omitempty"`
|
||||
UILocales []string `json:"ui_locales,omitempty"`
|
||||
MaxAge *time.Duration `json:"max_age,omitempty"`
|
||||
LoginHint *string `json:"login_hint,omitempty"`
|
||||
HintUserID *string `json:"hint_user_id,omitempty"`
|
||||
LoginClient string `json:"login_client"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
State string `json:"state,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Scope []string `json:"scope,omitempty"`
|
||||
Audience []string `json:"audience,omitempty"`
|
||||
ResponseType domain.OIDCResponseType `json:"response_type,omitempty"`
|
||||
CodeChallenge *domain.OIDCCodeChallenge `json:"code_challenge,omitempty"`
|
||||
Prompt []domain.Prompt `json:"prompt,omitempty"`
|
||||
UILocales []string `json:"ui_locales,omitempty"`
|
||||
MaxAge *time.Duration `json:"max_age,omitempty"`
|
||||
LoginHint *string `json:"login_hint,omitempty"`
|
||||
HintUserID *string `json:"hint_user_id,omitempty"`
|
||||
NeedRefreshToken bool `json:"need_refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
func (e *AddedEvent) Payload() interface{} {
|
||||
@ -62,6 +63,7 @@ func NewAddedEvent(ctx context.Context,
|
||||
maxAge *time.Duration,
|
||||
loginHint,
|
||||
hintUserID *string,
|
||||
needRefreshToken bool,
|
||||
) *AddedEvent {
|
||||
return &AddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
@ -69,20 +71,21 @@ func NewAddedEvent(ctx context.Context,
|
||||
aggregate,
|
||||
AddedType,
|
||||
),
|
||||
LoginClient: loginClient,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
State: state,
|
||||
Nonce: nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
ResponseType: responseType,
|
||||
CodeChallenge: codeChallenge,
|
||||
Prompt: prompt,
|
||||
UILocales: uiLocales,
|
||||
MaxAge: maxAge,
|
||||
LoginHint: loginHint,
|
||||
HintUserID: hintUserID,
|
||||
LoginClient: loginClient,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
State: state,
|
||||
Nonce: nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
ResponseType: responseType,
|
||||
CodeChallenge: codeChallenge,
|
||||
Prompt: prompt,
|
||||
UILocales: uiLocales,
|
||||
MaxAge: maxAge,
|
||||
LoginHint: loginHint,
|
||||
HintUserID: hintUserID,
|
||||
NeedRefreshToken: needRefreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
@ -13,18 +15,20 @@ const (
|
||||
AddedEventType = eventTypePrefix + "added"
|
||||
ApprovedEventType = eventTypePrefix + "approved"
|
||||
CanceledEventType = eventTypePrefix + "canceled"
|
||||
DoneEventType = eventTypePrefix + "done"
|
||||
)
|
||||
|
||||
type AddedEvent struct {
|
||||
*eventstore.BaseEvent `json:"-"`
|
||||
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
Audience []string
|
||||
State domain.DeviceAuthState
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
Audience []string
|
||||
State domain.DeviceAuthState
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
func (e *AddedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
|
||||
@ -48,20 +52,26 @@ func NewAddedEvent(
|
||||
expires time.Time,
|
||||
scopes []string,
|
||||
audience []string,
|
||||
needRefreshToken bool,
|
||||
) *AddedEvent {
|
||||
return &AddedEvent{
|
||||
eventstore.NewBaseEventForPush(
|
||||
ctx, aggregate, AddedEventType,
|
||||
),
|
||||
clientID, deviceCode, userCode, expires, scopes, audience, domain.DeviceAuthStateInitiated}
|
||||
clientID, deviceCode, userCode, expires, scopes, audience,
|
||||
domain.DeviceAuthStateInitiated, needRefreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
type ApprovedEvent struct {
|
||||
*eventstore.BaseEvent `json:"-"`
|
||||
|
||||
Subject string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
UserID string
|
||||
UserOrgID string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
}
|
||||
|
||||
func (e *ApprovedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
|
||||
@ -79,17 +89,23 @@ func (e *ApprovedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
func NewApprovedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
subject string,
|
||||
userID,
|
||||
userOrgID string,
|
||||
userAuthMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) *ApprovedEvent {
|
||||
return &ApprovedEvent{
|
||||
eventstore.NewBaseEventForPush(
|
||||
ctx, aggregate, ApprovedEventType,
|
||||
),
|
||||
subject,
|
||||
userID,
|
||||
userOrgID,
|
||||
userAuthMethods,
|
||||
authTime,
|
||||
preferredLanguage,
|
||||
userAgent,
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,3 +130,23 @@ func (e *CanceledEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
func NewCanceledEvent(ctx context.Context, aggregate *eventstore.Aggregate, reason domain.DeviceAuthCanceled) *CanceledEvent {
|
||||
return &CanceledEvent{eventstore.NewBaseEventForPush(ctx, aggregate, CanceledEventType), reason}
|
||||
}
|
||||
|
||||
type DoneEvent struct {
|
||||
*eventstore.BaseEvent `json:"-"`
|
||||
}
|
||||
|
||||
func (e *DoneEvent) SetBaseEvent(b *eventstore.BaseEvent) {
|
||||
e.BaseEvent = b
|
||||
}
|
||||
|
||||
func (e *DoneEvent) Payload() any {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *DoneEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewDoneEvent(ctx context.Context, aggregate *eventstore.Aggregate) *DoneEvent {
|
||||
return &DoneEvent{eventstore.NewBaseEventForPush(ctx, aggregate, DoneEventType)}
|
||||
}
|
||||
|
@ -6,4 +6,5 @@ func init() {
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, AddedEventType, eventstore.GenericEventMapper[AddedEvent])
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, ApprovedEventType, eventstore.GenericEventMapper[ApprovedEvent])
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, CanceledEventType, eventstore.GenericEventMapper[CanceledEvent])
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, DoneEventType, eventstore.GenericEventMapper[DoneEvent])
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
@ -21,13 +23,17 @@ const (
|
||||
type AddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
UserID string `json:"userID"`
|
||||
SessionID string `json:"sessionID"`
|
||||
ClientID string `json:"clientID"`
|
||||
Audience []string `json:"audience"`
|
||||
Scope []string `json:"scope"`
|
||||
AuthMethods []domain.UserAuthMethodType `json:"authMethods"`
|
||||
AuthTime time.Time `json:"authTime"`
|
||||
UserID string `json:"userID"`
|
||||
UserResourceOwner string `json:"userResourceOwner"`
|
||||
SessionID string `json:"sessionID"`
|
||||
ClientID string `json:"clientID"`
|
||||
Audience []string `json:"audience"`
|
||||
Scope []string `json:"scope"`
|
||||
AuthMethods []domain.UserAuthMethodType `json:"authMethods"`
|
||||
AuthTime time.Time `json:"authTime"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
PreferredLanguage *language.Tag `json:"preferredLanguage,omitempty"`
|
||||
UserAgent *domain.UserAgent `json:"userAgent,omitempty"`
|
||||
}
|
||||
|
||||
func (e *AddedEvent) Payload() interface{} {
|
||||
@ -45,12 +51,16 @@ func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
|
||||
func NewAddedEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
clientID string,
|
||||
audience,
|
||||
scope []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
nonce string,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) *AddedEvent {
|
||||
return &AddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
@ -58,13 +68,17 @@ func NewAddedEvent(ctx context.Context,
|
||||
aggregate,
|
||||
AddedType,
|
||||
),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
ClientID: clientID,
|
||||
Audience: audience,
|
||||
Scope: scope,
|
||||
AuthMethods: authMethods,
|
||||
AuthTime: authTime,
|
||||
UserID: userID,
|
||||
UserResourceOwner: userResourceOwner,
|
||||
SessionID: sessionID,
|
||||
ClientID: clientID,
|
||||
Audience: audience,
|
||||
Scope: scope,
|
||||
AuthMethods: authMethods,
|
||||
AuthTime: authTime,
|
||||
Nonce: nonce,
|
||||
PreferredLanguage: preferredLanguage,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@ -75,9 +77,10 @@ func AddedEventMapper(event eventstore.Event) (eventstore.Event, error) {
|
||||
type UserCheckedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
UserID string `json:"userID"`
|
||||
UserResourceOwner string `json:"userResourceOwner"`
|
||||
CheckedAt time.Time `json:"checkedAt"`
|
||||
UserID string `json:"userID"`
|
||||
UserResourceOwner string `json:"userResourceOwner"`
|
||||
CheckedAt time.Time `json:"checkedAt"`
|
||||
PreferredLanguage *language.Tag `json:"preferredLanguage,omitempty"`
|
||||
}
|
||||
|
||||
func (e *UserCheckedEvent) Payload() interface{} {
|
||||
@ -94,6 +97,7 @@ func NewUserCheckedEvent(
|
||||
userID,
|
||||
userResourceOwner string,
|
||||
checkedAt time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
) *UserCheckedEvent {
|
||||
return &UserCheckedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
@ -104,6 +108,7 @@ func NewUserCheckedEvent(
|
||||
UserID: userID,
|
||||
UserResourceOwner: userResourceOwner,
|
||||
CheckedAt: checkedAt,
|
||||
PreferredLanguage: preferredLanguage,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,6 +42,7 @@ func init() {
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserReactivatedType, UserReactivatedEventMapper)
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserRemovedType, UserRemovedEventMapper)
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserTokenAddedType, UserTokenAddedEventMapper)
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserTokenV2AddedType, eventstore.GenericEventMapper[UserTokenV2AddedEvent])
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserImpersonatedType, eventstore.GenericEventMapper[UserImpersonatedEvent])
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserTokenRemovedType, UserTokenRemovedEventMapper)
|
||||
eventstore.RegisterFilterEventMapper(AggregateType, UserDomainClaimedType, DomainClaimedEventMapper)
|
||||
|
@ -19,6 +19,7 @@ const (
|
||||
UserReactivatedType = userEventTypePrefix + "reactivated"
|
||||
UserRemovedType = userEventTypePrefix + "removed"
|
||||
UserTokenAddedType = userEventTypePrefix + "token.added"
|
||||
UserTokenV2AddedType = userEventTypePrefix + "token.v2.added"
|
||||
UserTokenRemovedType = userEventTypePrefix + "token.removed"
|
||||
UserImpersonatedType = userEventTypePrefix + "impersonated"
|
||||
UserDomainClaimedType = userEventTypePrefix + "domain.claimed"
|
||||
@ -279,6 +280,39 @@ func UserTokenAddedEventMapper(event eventstore.Event) (eventstore.Event, error)
|
||||
return tokenAdded, nil
|
||||
}
|
||||
|
||||
type UserTokenV2AddedEvent struct {
|
||||
*eventstore.BaseEvent `json:"-"`
|
||||
|
||||
TokenID string `json:"tokenId,omitempty"`
|
||||
}
|
||||
|
||||
func (e *UserTokenV2AddedEvent) Payload() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *UserTokenV2AddedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
|
||||
e.BaseEvent = b
|
||||
}
|
||||
|
||||
func (e *UserTokenV2AddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewUserTokenV2AddedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
tokenID string,
|
||||
) *UserTokenV2AddedEvent {
|
||||
return &UserTokenV2AddedEvent{
|
||||
BaseEvent: eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
UserTokenV2AddedType,
|
||||
),
|
||||
TokenID: tokenID,
|
||||
}
|
||||
}
|
||||
|
||||
type UserImpersonatedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
|
@ -631,6 +631,7 @@ EventTypes:
|
||||
failed: Проверката на инициализацията е неуспешна
|
||||
token:
|
||||
added: Токенът за достъп е създаден
|
||||
v2.added: Токенът за достъп е създаден
|
||||
removed: Токенът за достъп е премахнат
|
||||
impersonated: Имитиран потребител
|
||||
username:
|
||||
|
@ -612,6 +612,7 @@ EventTypes:
|
||||
failed: Kontrola inicializace selhala
|
||||
token:
|
||||
added: Přístupový token vytvořen
|
||||
v2.added: Přístupový token vytvořen
|
||||
removed: Přístupový token odstraněn
|
||||
impersonated: Usuario suplantado
|
||||
username:
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: Benutzerinitialisierung fehlgeschlagen
|
||||
token:
|
||||
added: Access Token ausgestellt
|
||||
v2.added: Access Token ausgestellt
|
||||
removed: Access Token gelöscht
|
||||
impersonated: Benutzer hat sich als Benutzer ausgegeben
|
||||
username:
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: Initialization check failed
|
||||
token:
|
||||
added: Access Token created
|
||||
v2.added: Access Token created
|
||||
removed: Access Token removed
|
||||
impersonated: User impersonated
|
||||
username:
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: La comprobación de inicialización falló
|
||||
token:
|
||||
added: Token de acceso creado
|
||||
v2.added: Token de acceso creado
|
||||
removed: Token de acceso eliminado
|
||||
impersonated: Usuario suplantado
|
||||
username:
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: La vérification de l'initialisation a échoué
|
||||
token:
|
||||
added: Jeton d'accès créé
|
||||
v2.added: Jeton d'accès créé
|
||||
impersonated: Utilisateur usurpé l'identité
|
||||
username:
|
||||
reserved: Nom d'utilisateur réservé
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: Controllo dell'inizializzazione fallito
|
||||
token:
|
||||
added: Access Token creato
|
||||
v2.added: Access Token creato
|
||||
impersonated: Utente impersonificato
|
||||
username:
|
||||
reserved: Nome utente riservato
|
||||
|
@ -603,6 +603,7 @@ EventTypes:
|
||||
failed: 初期化チェックの失敗
|
||||
token:
|
||||
added: アクセストークンの作成
|
||||
v2.added: アクセストークンの作成
|
||||
removed: アクセストークンの削除
|
||||
impersonated: ユーザーがなりすました
|
||||
username:
|
||||
|
@ -613,6 +613,7 @@ EventTypes:
|
||||
failed: Проверката на иницијализацијата е неуспешна
|
||||
token:
|
||||
added: Креиран е токен за пристап
|
||||
v2.added: Креиран е токен за пристап
|
||||
removed: Токенот за пристап е отстранет
|
||||
impersonated: Корисникот имитиран
|
||||
username:
|
||||
|
@ -613,6 +613,7 @@ EventTypes:
|
||||
failed: Initialisatiecontrole mislukt
|
||||
token:
|
||||
added: Toegangstoken aangemaakt
|
||||
v2.added: Toegangstoken aangemaakt
|
||||
removed: Toegangstoken verwijderd
|
||||
impersonated: Gebruiker nagebootst
|
||||
username:
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: Sprawdzenie inicjujące nie powiodło się
|
||||
token:
|
||||
added: Token dostępu utworzony
|
||||
v2.added: Token dostępu utworzony
|
||||
removed: Token dostępu usunięty
|
||||
impersonated: Użytkownik podszywał się pod użytkownika
|
||||
username:
|
||||
|
@ -609,6 +609,7 @@ EventTypes:
|
||||
failed: Verificação de inicialização falhou
|
||||
token:
|
||||
added: Token de acesso criado
|
||||
v2.added: Token de acesso criado
|
||||
removed: Token de acesso removido
|
||||
impersonated: Usuário personificado
|
||||
username:
|
||||
|
@ -603,6 +603,7 @@ EventTypes:
|
||||
failed: Проверка инициализации не удалась
|
||||
token:
|
||||
added: Токен доступа создан
|
||||
v2.added: Токен доступа создан
|
||||
removed: Токен доступа удалён
|
||||
impersonated: Пользователь олицетворяет себя
|
||||
username:
|
||||
|
@ -614,6 +614,7 @@ EventTypes:
|
||||
failed: 初始化检查失败
|
||||
token:
|
||||
added: 已创建访问令牌
|
||||
v2.added: 已创建访问令牌
|
||||
impersonated: 用户冒充
|
||||
username:
|
||||
reserved: 保留用户名
|
||||
|
Loading…
x
Reference in New Issue
Block a user