feat: allow ECDSA and ED25519 public keys

This commit is contained in:
Livio Spring 2024-07-05 14:54:11 +02:00
parent 5ca8ad2075
commit 5d4e4f5e2c
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
4 changed files with 54 additions and 46 deletions

View File

@ -2,7 +2,7 @@ package authz
import (
"context"
"crypto/rsa"
"crypto"
"errors"
"os"
"sync"
@ -11,7 +11,7 @@ import (
"github.com/go-jose/go-jose/v4"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/crypto"
zcrypto "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -44,7 +44,7 @@ func StartSystemTokenVerifierFromConfig(issuer string, keys map[string]*SystemAP
systemJWTProfile: op.NewJWTProfileVerifier(
&systemJWTStorage{
keys: keys,
cachedKeys: make(map[string]*rsa.PublicKey),
cachedKeys: make(map[string]crypto.PublicKey),
},
issuer,
1*time.Hour,
@ -77,7 +77,7 @@ func (s *SystemTokenVerifierFromConfig) VerifySystemToken(ctx context.Context, t
type systemJWTStorage struct {
keys map[string]*SystemAPIUser
mutex sync.Mutex
cachedKeys map[string]*rsa.PublicKey
cachedKeys map[string]crypto.PublicKey
}
type SystemAPIUser struct {
@ -86,7 +86,7 @@ type SystemAPIUser struct {
Memberships Memberships
}
func (s *SystemAPIUser) readKey() (*rsa.PublicKey, error) {
func (s *SystemAPIUser) readKey() (crypto.PublicKey, error) {
if s.Path != "" {
var err error
s.KeyData, err = os.ReadFile(s.Path)
@ -94,7 +94,7 @@ func (s *SystemAPIUser) readKey() (*rsa.PublicKey, error) {
return nil, zerrors.ThrowInternal(err, "AUTHZ-JK31F", "Errors.NotFound")
}
}
return crypto.BytesToPublicKey(s.KeyData)
return zcrypto.BytesToPublicKey(s.KeyData)
}
func (s *systemJWTStorage) GetKeyByIDAndClientID(_ context.Context, _, userID string) (*jose.JSONWebKey, error) {

View File

@ -2,6 +2,9 @@ package crypto
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@ -169,8 +172,9 @@ func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) {
}
var ErrEmpty = errors.New("cannot decode, empty data")
var ErrNoPublicKey = errors.New("key is no supported public key type")
func BytesToPublicKey(pub []byte) (*rsa.PublicKey, error) {
func BytesToPublicKey(pub []byte) (crypto.PublicKey, error) {
if len(pub) == 0 {
return nil, ErrEmpty
}
@ -178,15 +182,18 @@ func BytesToPublicKey(pub []byte) (*rsa.PublicKey, error) {
if block == nil {
return nil, ErrEmpty
}
ifc, err := x509.ParsePKIXPublicKey(block.Bytes)
key, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
key, ok := ifc.(*rsa.PublicKey)
if !ok {
return nil, err
switch key.(type) {
case *rsa.PublicKey,
*ecdsa.PublicKey,
ed25519.PublicKey:
return key, nil
default:
return nil, ErrNoPublicKey
}
return key, nil
}
func EncryptKeys(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, alg EncryptionAlgorithm) (*CryptoValue, *CryptoValue, error) {

View File

@ -2,7 +2,7 @@ package query
import (
"context"
"crypto/rsa"
"crypto"
"database/sql"
"time"
@ -10,7 +10,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/crypto"
zcrypto "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query/projection"
@ -29,7 +29,7 @@ type Key interface {
type PrivateKey interface {
Key
Expiry() time.Time
Key() *crypto.CryptoValue
Key() *zcrypto.CryptoValue
}
type PublicKey interface {
@ -77,21 +77,21 @@ func (k *key) Sequence() uint64 {
type privateKey struct {
key
expiry time.Time
privateKey *crypto.CryptoValue
privateKey *zcrypto.CryptoValue
}
func (k *privateKey) Expiry() time.Time {
return k.expiry
}
func (k *privateKey) Key() *crypto.CryptoValue {
func (k *privateKey) Key() *zcrypto.CryptoValue {
return k.privateKey
}
type rsaPublicKey struct {
key
expiry time.Time
publicKey *rsa.PublicKey
publicKey crypto.PublicKey
}
func (r *rsaPublicKey) Expiry() time.Time {
@ -281,7 +281,7 @@ func preparePublicKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectB
if err != nil {
return nil, err
}
k.publicKey, err = crypto.BytesToPublicKey(keyValue)
k.publicKey, err = zcrypto.BytesToPublicKey(keyValue)
if err != nil {
return nil, err
}
@ -356,7 +356,7 @@ type PublicKeyReadModel struct {
eventstore.ReadModel
Algorithm string
Key *crypto.CryptoValue
Key *zcrypto.CryptoValue
Expiry time.Time
Usage domain.KeyUsage
}
@ -410,11 +410,11 @@ func (q *Queries) GetPublicKeyByID(ctx context.Context, keyID string) (_ PublicK
if model.Algorithm == "" || model.Key == nil {
return nil, zerrors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
}
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
keyValue, err := zcrypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Ie4oh", "Errors.Internal")
}
publicKey, err := crypto.BytesToPublicKey(keyValue)
publicKey, err := zcrypto.BytesToPublicKey(keyValue)
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Kai2Z", "Errors.Internal")
}

View File

@ -2,6 +2,7 @@ package query
import (
"context"
"crypto"
"crypto/rsa"
"database/sql"
"database/sql/driver"
@ -18,7 +19,7 @@ import (
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
zcrypto "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
@ -215,8 +216,8 @@ func Test_KeyPrepares(t *testing.T) {
use: domain.KeyUsageSigning,
},
expiry: testNow,
privateKey: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
privateKey: &zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("privateKey"),
@ -276,7 +277,7 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
tests := []struct {
name string
eventstore func(*testing.T) *eventstore.Eventstore
encryption func(*testing.T) *crypto.MockEncryptionAlgorithm
encryption func(*testing.T) *zcrypto.MockEncryptionAlgorithm
want *rsaPublicKey
wantErr error
}{
@ -307,14 +308,14 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
Version: key_repo.AggregateVersion,
},
domain.KeyUsageSigning, "alg",
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
&zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "alg",
KeyID: "keyID",
Crypted: []byte("private"),
},
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
&zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "alg",
KeyID: "keyID",
Crypted: []byte("public"),
@ -324,8 +325,8 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
)),
),
),
encryption: func(t *testing.T) *crypto.MockEncryptionAlgorithm {
encryption := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
encryption: func(t *testing.T) *zcrypto.MockEncryptionAlgorithm {
encryption := zcrypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
expect := encryption.EXPECT()
expect.Algorithm().Return("alg")
expect.DecryptionKeyIDs().Return([]string{})
@ -346,14 +347,14 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
Version: key_repo.AggregateVersion,
},
domain.KeyUsageSigning, "alg",
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
&zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "alg",
KeyID: "keyID",
Crypted: []byte("private"),
},
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
&zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "alg",
KeyID: "keyID",
Crypted: []byte("public"),
@ -363,8 +364,8 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
)),
),
),
encryption: func(t *testing.T) *crypto.MockEncryptionAlgorithm {
encryption := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
encryption: func(t *testing.T) *zcrypto.MockEncryptionAlgorithm {
encryption := zcrypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
expect := encryption.EXPECT()
expect.Algorithm().Return("alg")
expect.DecryptionKeyIDs().Return([]string{"keyID"})
@ -386,14 +387,14 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
Version: key_repo.AggregateVersion,
},
domain.KeyUsageSigning, "alg",
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
&zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "alg",
KeyID: "keyID",
Crypted: []byte("private"),
},
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
&zcrypto.CryptoValue{
CryptoType: zcrypto.TypeEncryption,
Algorithm: "alg",
KeyID: "keyID",
Crypted: []byte("public"),
@ -403,8 +404,8 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
)),
),
),
encryption: func(t *testing.T) *crypto.MockEncryptionAlgorithm {
encryption := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
encryption: func(t *testing.T) *zcrypto.MockEncryptionAlgorithm {
encryption := zcrypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
expect := encryption.EXPECT()
expect.Algorithm().Return("alg")
expect.DecryptionKeyIDs().Return([]string{"keyID"})
@ -419,8 +420,8 @@ func TestQueries_GetPublicKeyByID(t *testing.T) {
use: domain.KeyUsageSigning,
},
expiry: future,
publicKey: func() *rsa.PublicKey {
publicKey, err := crypto.BytesToPublicKey([]byte(pubKey))
publicKey: func() crypto.PublicKey {
publicKey, err := zcrypto.BytesToPublicKey([]byte(pubKey))
if err != nil {
panic(err)
}