zitadel/internal/actions/http_module.go
Tim Möhlmann f680dd934d
refactor: rename package errors to zerrors (#7039)
* chore: rename package errors to zerrors

* rename package errors to gerrors

* fix error related linting issues

* fix zitadel error assertion

* fix gosimple linting issues

* fix deprecated linting issues

* resolve gci linting issues

* fix import structure

---------

Co-authored-by: Elio Bischof <elio@zitadel.com>
2023-12-08 15:30:55 +01:00

197 lines
4.5 KiB
Go

package actions
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/dop251/goja"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/zerrors"
)
// WithHTTP returns an Option that registers the `zitadel/http` module.
//
// Every time the module is loaded, a fresh *http.Client is created whose
// Transport is the deny-list aware transport of this package. The exported
// `fetch` function builds its requests with ctx.
func WithHTTP(ctx context.Context) Option {
	loadModule := func(runtime *goja.Runtime, module *goja.Object) {
		client := &http.Client{Transport: new(transport)}
		requireHTTP(ctx, client, runtime, module)
	}
	return func(c *runConfig) {
		c.modules["zitadel/http"] = loadModule
	}
}
// HTTP holds what the `fetch` function of the http module needs: the
// JavaScript runtime results are converted into, and the client that sends
// the requests.
type HTTP struct {
	runtime *goja.Runtime
	client  *http.Client
}

// requireHTTP fills the exports of module with a `fetch` function that uses
// client and builds requests with ctx. A failure to set the export is only
// logged as a warning.
func requireHTTP(ctx context.Context, client *http.Client, runtime *goja.Runtime, module *goja.Object) {
	h := &HTTP{runtime: runtime, client: client}
	exports := module.Get("exports").(*goja.Object)
	err := exports.Set("fetch", h.fetch(ctx))
	logging.OnError(err).Warn("unable to set module")
}
// fetchConfig collects the settings of a single fetch call.
type fetchConfig struct {
	Method  string
	Headers http.Header
	Body    io.Reader
}

// defaultFetchConfig is the configuration every fetch call starts from: a GET
// request without a body that sends and accepts JSON.
//
// NOTE: copying this struct value does not copy the Headers map; clone it
// before modifying so the defaults are never changed.
var defaultFetchConfig = fetchConfig{
	Method: http.MethodGet,
	Headers: http.Header{
		"Accept":       {"application/json"},
		"Content-Type": {"application/json"},
	},
}
func (c *HTTP) fetchConfigFromArg(arg *goja.Object, config *fetchConfig) (err error) {
for _, key := range arg.Keys() {
switch key {
case "headers":
config.Headers = parseHeaders(arg.Get(key).ToObject(c.runtime))
case "method":
config.Method = arg.Get(key).String()
case "body":
body, err := arg.Get(key).ToObject(c.runtime).MarshalJSON()
if err != nil {
return err
}
config.Body = bytes.NewReader(body)
default:
return zerrors.ThrowInvalidArgument(nil, "ACTIO-OfUeA", "key is invalid")
}
}
return nil
}
// response is the value a fetch call hands back to the script (it is passed
// through runtime.ToValue, so its exported fields and methods are visible
// to JavaScript; keep their names stable).
type response struct {
	// Body is the raw response body.
	Body string
	// Status is the HTTP status code.
	Status int
	// Headers carries response header fields when the producer sets them.
	Headers map[string][]string
	runtime *goja.Runtime
}

// Json decodes Body as JSON and returns the result as a JavaScript value.
// It panics with the decoding error if Body is not valid JSON.
func (r *response) Json() goja.Value {
	var decoded any
	err := json.Unmarshal([]byte(r.Body), &decoded)
	if err != nil {
		panic(err)
	}
	return r.runtime.ToValue(decoded)
}

// Text returns Body unchanged as a JavaScript string.
func (r *response) Text() goja.Value {
	text := r.Body
	return r.runtime.ToValue(text)
}
// fetch returns the JavaScript `fetch(url, [config])` function of the module.
//
// The returned function builds the request via buildHTTPRequest and sends it
// with c.client. If ctx has a deadline, that deadline is also applied as the
// client timeout. It reads the whole response body and returns a response
// value with the status code, the response headers and the body. Transport
// errors and body read errors cause a panic.
func (c *HTTP) fetch(ctx context.Context) func(call goja.FunctionCall) goja.Value {
	return func(call goja.FunctionCall) goja.Value {
		req := c.buildHTTPRequest(ctx, call.Arguments)
		if deadline, ok := ctx.Deadline(); ok {
			c.client.Timeout = time.Until(deadline)
		}

		res, err := c.client.Do(req)
		if err != nil {
			logging.WithError(err).Debug("call failed")
			panic(err)
		}
		defer res.Body.Close()

		body, err := io.ReadAll(res.Body)
		if err != nil {
			logging.WithError(err).Warn("unable to parse body")
			panic("unable to read response body")
		}
		return c.runtime.ToValue(&response{
			Status: res.StatusCode,
			// Previously never set, so scripts could not inspect the
			// headers of the response.
			Headers: res.Header,
			Body:    string(body),
			runtime: c.runtime,
		})
	}
}
// buildHTTPRequest converts the arguments of a JavaScript fetch call into an
// *http.Request bound to ctx.
//
// The first argument is required and must be the request URL as a string.
// The second argument is optional. When it is present and not undefined/null,
// it must be an object that may contain these keys (any other key is rejected):
//   - `headers`: object mapping header names to a string (split on commas and
//     trimmed) or to an array of strings; replaces the default headers
//   - `body`: value that is JSON-encoded and sent as the request body
//   - `method`: HTTP method, GET by default
//
// Invalid arguments cause a panic.
func (c *HTTP) buildHTTPRequest(ctx context.Context, args []goja.Value) (req *http.Request) {
	if len(args) > 2 {
		logging.WithFields("count", len(args)).Debug("more than 2 args provided")
		panic("too many args")
	}
	if len(args) == 0 {
		panic("no url provided")
	}
	rawURL, ok := args[0].Export().(string)
	if !ok {
		panic("url must be a string")
	}

	config := defaultFetchConfig
	// Work on a private copy of the default headers. Otherwise every request
	// would share (and could mutate) the package-level default map.
	config.Headers = defaultFetchConfig.Headers.Clone()
	if len(args) == 2 && !goja.IsUndefined(args[1]) && !goja.IsNull(args[1]) {
		if err := c.fetchConfigFromArg(args[1].ToObject(c.runtime), &config); err != nil {
			panic(err)
		}
	}

	var err error
	req, err = http.NewRequestWithContext(ctx, config.Method, rawURL, config.Body)
	if err != nil {
		panic(err)
	}
	req.Header = config.Headers
	return req
}
// parseHeaders converts a JavaScript object into an http.Header.
//
// A string value is split on commas. An array value must hold strings (any
// other element panics). Each resulting value is trimmed of surrounding
// whitespace and added under the property name. Properties of any other type
// are skipped.
func parseHeaders(headers *goja.Object) http.Header {
	names := headers.Keys()
	parsed := make(http.Header, len(names))
	for _, name := range names {
		var raw []string
		switch value := headers.Get(name).Export().(type) {
		case string:
			raw = strings.Split(value, ",")
		case []any:
			raw = make([]string, 0, len(value))
			for _, item := range value {
				raw = append(raw, item.(string))
			}
		}
		for _, v := range raw {
			parsed.Add(name, strings.TrimSpace(v))
		}
	}
	return parsed
}
// transport is the http.RoundTripper used by the http module's client.
//
// It rejects requests whose host is matched by the deny list of httpConfig
// (declared elsewhere in this package) and delegates every other request,
// including all requests when httpConfig is nil, to http.DefaultTransport.
type transport struct{}

// RoundTrip implements http.RoundTripper.
func (*transport) RoundTrip(req *http.Request) (*http.Response, error) {
	if httpConfig != nil && isHostBlocked(httpConfig.DenyList, req.URL) {
		// The RoundTripper contract requires closing the request body even
		// when the request is never sent.
		if req.Body != nil {
			_ = req.Body.Close() // the request is rejected anyway; a close error adds nothing
		}
		return nil, zerrors.ThrowInvalidArgument(nil, "ACTIO-N72d0", "host is denied")
	}
	return http.DefaultTransport.RoundTrip(req)
}
// isHostBlocked reports whether any entry of denyList matches the host name
// of address (as returned by url.URL.Hostname, i.e. without a port).
// An empty denyList blocks nothing and does not inspect address.
func isHostBlocked(denyList []AddressChecker, address *url.URL) bool {
	if len(denyList) == 0 {
		return false
	}
	host := address.Hostname()
	for i := range denyList {
		if denyList[i].Matches(host) {
			return true
		}
	}
	return false
}

// AddressChecker is a single deny-list entry. Matches is called with a host
// name and reports whether that host is covered by the entry.
type AddressChecker interface {
	Matches(string) bool
}