diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index 6f70ed00f..af009af4f 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -14,6 +14,7 @@ "strconv" "strings" + "go4.org/mem" "inet.af/netaddr" "tailscale.com/types/wgkey" ) @@ -56,25 +57,14 @@ func parseEndpoint(s string) (host string, port uint16, err error) { return host, uint16(uport), nil } -func parseKeyHex(s string) (*wgkey.Key, error) { - k, err := hex.DecodeString(s) - if err != nil { - return nil, &ParseError{"Invalid key: " + err.Error(), s} +// memROCut separates a mem.RO at the separator if it exists, otherwise +// it returns two empty ROs and reports that it was not found. +func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { + if i := mem.IndexByte(s, sep); i >= 0 { + return s.SliceTo(i), s.SliceFrom(i + 1), true } - if len(k) != wgkey.Size { - return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} - } - var key wgkey.Key - copy(key[:], k) - return &key, nil -} - -// stringsCut is strings.Cut from proposed https://github.com/golang/go/issues/46336. -func stringsCut(s, sep string) (before, after string, found bool) { - if i := strings.Index(s, sep); i >= 0 { - return s[:i], s[i+len(sep):], true - } - return s, "", false + found = false + return } // FromUAPI generates a Config from r. @@ -87,23 +77,23 @@ func FromUAPI(r io.Reader) (*Config, error) { scanner := bufio.NewScanner(r) for scanner.Scan() { - line := scanner.Text() - if line == "" { + line := mem.B(scanner.Bytes()) + if line.Len() == 0 { continue } - key, value, ok := stringsCut(line, "=") + key, value, ok := memROCut(line, '=') if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line) + return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) } - valueBytes := scanner.Bytes()[len(key)+1:] + valueBytes := scanner.Bytes()[key.Len()+1:] - if key == "public_key" { + if key.EqualString("public_key") { if deviceConfig { deviceConfig = false } // Load/create the peer we are now configuring. var err error - peer, err = cfg.handlePublicKeyLine(value) + peer, err = cfg.handlePublicKeyLine(valueBytes) if err != nil { return nil, err } @@ -112,7 +102,7 @@ func FromUAPI(r io.Reader) (*Config, error) { var err error if deviceConfig { - err = cfg.handleDeviceLine(key, value) + err = cfg.handleDeviceLine(key, value, valueBytes) } else { err = cfg.handlePeerLine(peer, key, value, valueBytes) } @@ -128,63 +118,73 @@ func FromUAPI(r io.Reader) (*Config, error) { return cfg, nil } -func (cfg *Config) handleDeviceLine(key, value string) error { - switch key { - case "private_key": - k, err := parseKeyHex(value) - if err != nil { - return err - } - // wireguard-go guarantees not to send zero value; private keys are already clamped. - cfg.PrivateKey = wgkey.Private(*k) - case "listen_port": - // ignore - case "fwmark": - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %v", key) +func parseKeyHex(s []byte, dst []byte) error { + n, err := hex.Decode(dst, s) + if err != nil { + return &ParseError{"Invalid key: " + err.Error(), string(s)} + } + if n != wgkey.Size { + return &ParseError{"Keys must decode to exactly 32 bytes", string(s)} } return nil } -func (cfg *Config) handlePublicKeyLine(value string) (*Peer, error) { - k, err := parseKeyHex(value) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, Peer{}) - peer := &cfg.Peers[len(cfg.Peers)-1] - peer.PublicKey = *k - return peer, nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, key, value string, valueBytes []byte) error { - switch key { - case "endpoint": - err := json.Unmarshal(valueBytes, &peer.Endpoints) - if err != nil { +func (cfg *Config) handleDeviceLine(key, value mem.RO, valueBytes []byte) error { + switch { + case key.EqualString("private_key"): + // wireguard-go guarantees not to send zero value; private keys are already clamped. + if err := parseKeyHex(valueBytes, cfg.PrivateKey[:]); err != nil { return err } - case "persistent_keepalive_interval": - n, err := strconv.ParseUint(value, 10, 16) + case key.EqualString("listen_port") || key.EqualString("fwmark"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", key.StringCopy()) + } + return nil +} + +func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { + p := Peer{} + if err := parseKeyHex(valueBytes, p.PublicKey[:]); err != nil { + return nil, err + } + cfg.Peers = append(cfg.Peers, p) + return &cfg.Peers[len(cfg.Peers)-1], nil +} + +func (cfg *Config) handlePeerLine(peer *Peer, key, value mem.RO, valueBytes []byte) error { + switch { + case key.EqualString("endpoint"): + if err := json.Unmarshal(valueBytes, &peer.Endpoints); err != nil { + return err + } + case key.EqualString("persistent_keepalive_interval"): + n, err := mem.ParseUint(value, 10, 16) if err != nil { return err } peer.PersistentKeepalive = uint16(n) - case "allowed_ip": - ipp, err := netaddr.ParseIPPrefix(value) + case key.EqualString("allowed_ip"): + ipp := netaddr.IPPrefix{} + err := ipp.UnmarshalText(valueBytes) if err != nil { return err } peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case "protocol_version": - if value != "1" { - return fmt.Errorf("invalid protocol version: %v", value) + case key.EqualString("protocol_version"): + if !value.EqualString("1") { + return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) } - case "replace_allowed_ips", "preshared_key", "last_handshake_time_sec", "last_handshake_time_nsec", "tx_bytes", "rx_bytes": - // ignore + case key.EqualString("replace_allowed_ips") || + key.EqualString("preshared_key") || + key.EqualString("last_handshake_time_sec") || + key.EqualString("last_handshake_time_nsec") || + key.EqualString("tx_bytes") || + key.EqualString("rx_bytes"): + // ignore default: - return fmt.Errorf("unexpected IpcGetOperation key: %v", key) + return fmt.Errorf("unexpected IpcGetOperation key: %q", key.StringCopy()) } return nil }