diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 06de8b00d4..246be154a6 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -110,17 +110,7 @@ func grantsToScopes(grants []*grant_model.UserGrantView) []string { func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - var userAgentID, applicationID, userOrgID string - var authTime time.Time - var authMethodsReferences []string - authReq, ok := req.(*AuthRequest) - if ok { - userAgentID = authReq.AgentID - applicationID = authReq.ApplicationID - userOrgID = authReq.UserOrgID - authTime = authReq.AuthTime - authMethodsReferences = authReq.GetAMR() - } + userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req) resp, token, err := o.command.AddAccessAndRefreshToken(ctx, userOrgID, userAgentID, applicationID, req.GetSubject(), refreshToken, req.GetAudience(), req.GetScopes(), authMethodsReferences, o.defaultAccessTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, authTime) //PLANNED: lifetime from client @@ -130,6 +120,18 @@ func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.Tok return resp.TokenID, token, resp.Expiration, nil } +func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time, []string) { + authReq, ok := req.(*AuthRequest) + if ok { + return authReq.AgentID, authReq.ApplicationID, authReq.UserOrgID, authReq.AuthTime, authReq.GetAMR() + } + refreshReq, ok := req.(*RefreshTokenRequest) + if ok { + return refreshReq.UserAgentID, refreshReq.ClientID, "", refreshReq.AuthTime, refreshReq.AuthMethodsReferences + } + return "", "", "", time.Time{}, nil +} + func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { tokenView, err := o.repo.RefreshTokenByID(ctx, refreshToken) if err != nil { diff --git a/internal/command/user.go b/internal/command/user.go index 75f9360e23..f28ddc2877 100644 --- a/internal/command/user.go +++ b/internal/command/user.go @@ -211,7 +211,7 @@ func (c *Commands) RemoveUser(ctx context.Context, userID, resourceOwner string, } func (c *Commands) AddUserToken(ctx context.Context, orgID, agentID, clientID, userID string, audience, scopes []string, lifetime time.Duration) (*domain.Token, error) { - if orgID == "" || userID == "" { + if userID == "" { //do not check for empty orgID (JWT Profile requests won't provide it, so service user requests fail) return nil, caos_errs.ThrowInvalidArgument(nil, "COMMAND-Dbge4", "Errors.IDMissing") } userWriteModel := NewUserWriteModel(userID, orgID) diff --git a/internal/user/repository/view/model/refresh_token.go b/internal/user/repository/view/model/refresh_token.go index beed817751..cad29f5de3 100644 --- a/internal/user/repository/view/model/refresh_token.go +++ b/internal/user/repository/view/model/refresh_token.go @@ -24,14 +24,14 @@ const ( ) type RefreshTokenView struct { - ID string `json:"tokenId" gorm:"column:id"` + ID string `json:"tokenId" gorm:"column:id;primary_key"` CreationDate time.Time `json:"-" gorm:"column:creation_date"` ChangeDate time.Time `json:"-" gorm:"column:change_date"` ResourceOwner string `json:"-" gorm:"column:resource_owner"` Token string `json:"-" gorm:"column:token"` - UserID string `json:"-" gorm:"column:user_id;primary_key"` - ClientID string `json:"clientID" gorm:"column:client_id;primary_key"` - UserAgentID string `json:"userAgentId" gorm:"column:user_agent_id;primary_key"` + UserID string `json:"-" gorm:"column:user_id"` + ClientID string `json:"clientID" gorm:"column:client_id"` + UserAgentID string `json:"userAgentId" gorm:"column:user_agent_id"` Audience pq.StringArray `json:"audience" gorm:"column:audience"` Scopes pq.StringArray `json:"scopes" gorm:"column:scopes"` AuthMethodsReferences pq.StringArray `json:"authMethodsReference" gorm:"column:amr"` diff --git a/internal/user/repository/view/refresh_token_view.go b/internal/user/repository/view/refresh_token_view.go index ca342e87c0..149219fd54 100644 --- a/internal/user/repository/view/refresh_token_view.go +++ b/internal/user/repository/view/refresh_token_view.go @@ -1,13 +1,14 @@ package view import ( + "github.com/jinzhu/gorm" + "github.com/lib/pq" + "github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/user/model" usr_model "github.com/caos/zitadel/internal/user/repository/view/model" "github.com/caos/zitadel/internal/view/repository" - "github.com/jinzhu/gorm" - "github.com/lib/pq" ) func RefreshTokenByID(db *gorm.DB, table, tokenID string) (*usr_model.RefreshTokenView, error) { @@ -35,7 +36,10 @@ func RefreshTokensByUserID(db *gorm.DB, table, userID string) ([]*usr_model.Refr } func PutRefreshToken(db *gorm.DB, table string, token *usr_model.RefreshTokenView) error { - save := repository.PrepareSave(table) + save := repository.PrepareSaveOnConflict(table, + []string{"client_id", "user_agent_id", "user_id"}, + []string{"id", "creation_date", "change_date", "token", "auth_time", "idle_expiration", "expiration", "sequence", "scopes", "audience", "amr"}, + ) return save(db, token) } diff --git a/internal/user/repository/view/token_view.go b/internal/user/repository/view/token_view.go index 93af59ce4a..7daa4cdcab 100644 --- a/internal/user/repository/view/token_view.go +++ b/internal/user/repository/view/token_view.go @@ -1,13 +1,14 @@ package view import ( + "github.com/jinzhu/gorm" + "github.com/lib/pq" + "github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/user/model" usr_model "github.com/caos/zitadel/internal/user/repository/view/model" "github.com/caos/zitadel/internal/view/repository" - "github.com/jinzhu/gorm" - "github.com/lib/pq" ) func TokenByID(db *gorm.DB, table, tokenID string) (*usr_model.TokenView, error) { diff --git a/internal/view/repository/requests.go b/internal/view/repository/requests.go index 5cd983e346..2e8744cb1c 100644 --- a/internal/view/repository/requests.go +++ b/internal/view/repository/requests.go @@ -3,6 +3,7 @@ package repository import ( "errors" "fmt" + "strings" "github.com/caos/logging" "github.com/jinzhu/gorm" @@ -80,6 +81,21 @@ func PrepareSave(table string) func(db *gorm.DB, object interface{}) error { } } +func PrepareSaveOnConflict(table string, conflictColumns, updateColumns []string) func(db *gorm.DB, object interface{}) error { + updates := make([]string, len(updateColumns)) + for i, column := range updateColumns { + updates[i] = column + "=excluded." + column + } + onConflict := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", strings.Join(conflictColumns, ","), strings.Join(updates, ",")) + return func(db *gorm.DB, object interface{}) error { + err := db.Table(table).Set("gorm:insert_option", onConflict).Save(object).Error + if err != nil { + return caos_errs.ThrowInternal(err, "VIEW-AfC7G", "unable to put object to view") + } + return nil + } +} + func PrepareDeleteByKey(table string, key ColumnKey, id interface{}) func(db *gorm.DB) error { return func(db *gorm.DB) error { err := db.Table(table). diff --git a/migrations/cockroach/V1.45__refresh_tokens.sql b/migrations/cockroach/V1.45__refresh_tokens.sql new file mode 100644 index 0000000000..9097112b2a --- /dev/null +++ b/migrations/cockroach/V1.45__refresh_tokens.sql @@ -0,0 +1,2 @@ +ALTER TABLE auth.refresh_tokens ALTER COLUMN id SET NOT NULL; +ALTER TABLE auth.refresh_tokens ALTER PRIMARY KEY USING COLUMNS(id); \ No newline at end of file