mirror of
https://github.com/zitadel/zitadel.git
synced 2025-10-24 15:40:00 +00:00
148 lines
3.7 KiB
Go
148 lines
3.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
_ "embed"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"text/template"
|
|
|
|
"google.golang.org/protobuf/compiler/protogen"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/descriptorpb"
|
|
"google.golang.org/protobuf/types/pluginpb"
|
|
|
|
protoc_gen_zitadel "github.com/zitadel/zitadel/pkg/grpc/protoc/v2"
|
|
)
|
|
|
|
var (
|
|
//go:embed zitadel.pb.go.tmpl
|
|
zitadelTemplate []byte
|
|
)
|
|
|
|
type authMethods struct {
|
|
GoPackageName string
|
|
ProtoPackageName string
|
|
ServiceName string
|
|
AuthOptions []authOption
|
|
AuthContext []authContext
|
|
CustomHTTPResponses []httpResponse
|
|
}
|
|
|
|
// authOption describes the permission configured for a single RPC method.
type authOption struct {
	// Name is the name of the RPC method.
	Name string
	// Permission is the permission taken from the method's auth option.
	Permission string
	// CheckFieldName is not populated by the code visible in this file.
	CheckFieldName string
}
|
|
|
|
// authContext links a request message to the accessor that yields its
// organisation field.
type authContext struct {
	// Name is the name of the method's input message.
	Name string
	// OrgMethod is the accessor expression built by buildAuthContextField,
	// in the form ".Get<GoFieldName>()".
	OrgMethod string
}
|
|
|
|
// httpResponse records the custom success status code configured for a
// response message.
type httpResponse struct {
	// Name is the name of the method's output message.
	Name string
	// Code is the success code taken from the method's HTTP response option.
	Code int32
}
|
|
|
|
func main() {
|
|
input, _ := io.ReadAll(os.Stdin)
|
|
var req pluginpb.CodeGeneratorRequest
|
|
err := proto.Unmarshal(input, &req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
opts := protogen.Options{}
|
|
plugin, err := opts.New(&req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
|
|
|
|
tmpl := loadTemplate(zitadelTemplate)
|
|
|
|
for _, file := range plugin.Files {
|
|
methods := new(authMethods)
|
|
for _, service := range file.Services {
|
|
methods.ServiceName = service.GoName
|
|
methods.GoPackageName = string(file.GoPackageName)
|
|
methods.ProtoPackageName = *file.Proto.Package
|
|
for _, method := range service.Methods {
|
|
options := method.Desc.Options().(*descriptorpb.MethodOptions)
|
|
if options == nil {
|
|
continue
|
|
}
|
|
ext := proto.GetExtension(options, protoc_gen_zitadel.E_Options).(*protoc_gen_zitadel.Options)
|
|
if ext == nil {
|
|
continue
|
|
}
|
|
if ext.AuthOption != nil {
|
|
generateAuthOption(methods, ext.AuthOption, method)
|
|
}
|
|
if ext.HttpResponse != nil {
|
|
methods.CustomHTTPResponses = append(methods.CustomHTTPResponses, httpResponse{Name: string(method.Output.Desc.Name()), Code: ext.HttpResponse.SuccessCode})
|
|
}
|
|
}
|
|
}
|
|
if len(methods.AuthOptions) > 0 {
|
|
generateFile(tmpl, methods, file, plugin)
|
|
}
|
|
}
|
|
|
|
// Generate a response from our plugin and marshall as protobuf
|
|
stdout := plugin.Response()
|
|
out, err := proto.Marshal(stdout)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Write the response to stdout, to be picked up by protoc
|
|
_, err = fmt.Fprint(os.Stdout, string(out))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func generateAuthOption(methods *authMethods, protoAuthOption *protoc_gen_zitadel.AuthOption, method *protogen.Method) {
|
|
methods.AuthOptions = append(methods.AuthOptions, authOption{Name: string(method.Desc.Name()), Permission: protoAuthOption.Permission})
|
|
if protoAuthOption.OrgField == "" {
|
|
return
|
|
}
|
|
orgMethod := buildAuthContextField(method.Input.Fields, protoAuthOption.OrgField)
|
|
if orgMethod != "" {
|
|
methods.AuthContext = append(methods.AuthContext, authContext{Name: string(method.Input.Desc.Name()), OrgMethod: orgMethod})
|
|
}
|
|
}
|
|
|
|
func generateFile(tmpl *template.Template, methods *authMethods, protoFile *protogen.File, plugin *protogen.Plugin) {
|
|
var buffer bytes.Buffer
|
|
err := tmpl.Execute(&buffer, &methods)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
filename := protoFile.GeneratedFilenamePrefix + ".pb.zitadel.go"
|
|
file := plugin.NewGeneratedFile(filename, ".")
|
|
|
|
_, err = file.Write(buffer.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// loadTemplate parses templateData as an unnamed text/template and returns
// it. It panics with the parse error if the data is not a valid template.
func loadTemplate(templateData []byte) *template.Template {
	parsed, err := template.New("").Parse(string(templateData))
	if err != nil {
		panic(err)
	}
	return parsed
}
|
|
|
|
func buildAuthContextField(fields []*protogen.Field, fieldName string) string {
|
|
for _, field := range fields {
|
|
if string(field.Desc.Name()) == fieldName {
|
|
return ".Get" + field.GoName + "()"
|
|
}
|
|
}
|
|
return ""
|
|
}
|