package ldap

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"net"
	"net/url"
	"strconv"
	"time"
	"unicode/utf8"

	"github.com/go-ldap/ldap/v3"
	"github.com/zitadel/logging"
	"golang.org/x/text/language"

	"github.com/zitadel/zitadel/internal/domain"
	"github.com/zitadel/zitadel/internal/idp"
)

var (
	// ErrNoSingleUser is returned when the user search does not yield exactly one entry.
	ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
	// ErrFailedLogin is returned when binding with the user's own credentials fails.
	ErrFailedLogin = errors.New("user failed to login")
)

// Compile-time assertion that *Session satisfies the idp.Session interface.
var _ idp.Session = (*Session)(nil)

// Session holds the state of a single LDAP login attempt.
type Session struct {
	// Provider supplies the server list, bind settings and attribute names read by FetchUser.
	Provider *Provider
	// loginUrl is the URL handed to idp.Redirect by GetAuth.
	loginUrl string
	// User is the username inserted (escaped) into the user filters of the search query.
	User     string
	// Password is the credential used to bind as the entry found for User.
	Password string
	// Entry is the LDAP entry found for User; FetchUser sets it after a successful user bind.
	Entry    *ldap.Entry
}

// GetAuth implements the [idp.Session] interface.
// It hands the session's login URL to idp.Redirect and passes both results
// through unchanged; the context is not used.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
	target, redirect := idp.Redirect(s.loginUrl)
	return target, redirect
}

// FetchUser authenticates the session's user against the provider's LDAP
// servers and maps the resulting entry to a user.
//
// The servers are tried in order. ErrFailedLogin and ErrNoSingleUser abort the
// whole process immediately; any other error makes the loop continue with the
// next server, and the error of the last attempted server is returned if none
// succeeded. On success the found entry is stored in s.Entry before it is
// mapped.
//
// An error is also returned when no server produced an entry at all (for
// example because no servers are configured), instead of dereferencing a nil
// entry while mapping.
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
	var user *ldap.Entry
	for _, server := range s.Provider.servers {
		user, err = tryBind(server,
			s.Provider.startTLS,
			s.Provider.bindDN,
			s.Provider.bindPassword,
			s.Provider.baseDN,
			s.Provider.getNecessaryAttributes(),
			s.Provider.userObjectClasses,
			s.Provider.userFilters,
			s.User,
			s.Password, s.Provider.timeout)
		// Invalid credentials or an ambiguous/missing user are final: asking
		// another server must not turn them into a success.
		// errors.Is reports false for a nil error, so no separate nil check is needed.
		if errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser) {
			return nil, err
		}
		// If a user bind was successful and user is filled continue with login, otherwise try next server
		if err == nil && user != nil {
			break
		}
	}
	if err != nil {
		return nil, err
	}
	// Reached with a nil entry when the server list is empty (the loop never ran).
	if user == nil {
		return nil, errors.New("ldap: no server returned a user entry")
	}
	s.Entry = user

	mapped, err := mapLDAPEntryToUser(
		user,
		s.Provider.idAttribute,
		s.Provider.firstNameAttribute,
		s.Provider.lastNameAttribute,
		s.Provider.displayNameAttribute,
		s.Provider.nickNameAttribute,
		s.Provider.preferredUsernameAttribute,
		s.Provider.emailAttribute,
		s.Provider.emailVerifiedAttribute,
		s.Provider.phoneAttribute,
		s.Provider.phoneVerifiedAttribute,
		s.Provider.preferredLanguageAttribute,
		s.Provider.avatarURLAttribute,
		s.Provider.profileAttribute,
	)
	if err != nil {
		// Return a literal nil so the interface result is not a typed-nil *User.
		return nil, err
	}
	return mapped, nil
}

// tryBind runs one complete login attempt against a single LDAP server:
// it opens a connection, binds with bindDN/bindPassword and then delegates to
// trySearchAndUserBind to look up the user and verify the user's password.
// The connection is closed before tryBind returns.
func tryBind(
	server string,
	startTLS bool,
	bindDN string,
	bindPassword string,
	baseDN string,
	attributes []string,
	objectClasses []string,
	userFilters []string,
	username string,
	password string,
	timeout time.Duration,
) (*ldap.Entry, error) {
	conn, err := getConnection(server, startTLS, timeout)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	// Bind with the configured bind DN first; the user search runs on this connection.
	err = conn.Bind(bindDN, bindPassword)
	if err != nil {
		return nil, err
	}

	entry, err := trySearchAndUserBind(conn, baseDN, attributes, objectClasses, userFilters, username, password, timeout)
	return entry, err
}

// getConnection dials the LDAP server at the given URL and optionally upgrades
// the connection with StartTLS.
//
// A timeout of 0 is replaced by ldap.DefaultTimeout for the dialer. On any
// error after a successful dial the connection is closed before returning, so
// the caller only owns (and must close) a connection when err is nil.
func getConnection(
	server string,
	startTLS bool,
	timeout time.Duration,
) (*ldap.Conn, error) {
	if timeout == 0 {
		timeout = ldap.DefaultTimeout
	}

	conn, err := ldap.DialURL(server, ldap.DialWithDialer(&net.Dialer{Timeout: timeout}))
	if err != nil {
		return nil, err
	}

	u, err := url.Parse(server)
	if err != nil {
		// Do not leak the already established connection.
		conn.Close()
		return nil, err
	}
	// NOTE(review): StartTLS is normally used to upgrade a plain ldap://
	// connection; confirm that gating it on the "ldaps" scheme is intended.
	if u.Scheme == "ldaps" && startTLS {
		// ServerName must be the bare host name; u.Host may carry a ":port" suffix.
		if err = conn.StartTLS(&tls.Config{ServerName: u.Hostname()}); err != nil {
			conn.Close()
			return nil, err
		}
	}
	return conn, nil
}

// trySearchAndUserBind searches below baseDN for the entry belonging to
// username and verifies the password by binding as that entry on conn.
//
// The search filter requires all objectClasses and, in addition, a match of
// the username on at least one of the userFilters attributes. The timeout,
// truncated to whole seconds, is sent as the time limit of the search request.
//
// It returns ErrNoSingleUser when the search does not yield exactly one entry
// and ErrFailedLogin when the bind as the found entry fails; errors of the
// search itself are returned unchanged. After a successful call conn is bound
// as the user.
func trySearchAndUserBind(
	conn *ldap.Conn,
	baseDN string,
	attributes []string,
	objectClasses []string,
	userFilters []string,
	username string,
	password string,
	timeout time.Duration,
) (*ldap.Entry, error) {
	// Build one term per user filter so that queriesOrToSearchQuery can
	// actually OR them. Passing the pre-concatenated terms as a single
	// argument would skip the "(|...)" wrapper and leave them AND-ed with the
	// object classes.
	userQueries := make([]string, 0, len(userFilters))
	for _, filter := range userFilters {
		userQueries = append(userQueries, userFiltersToSearchQuery([]string{filter}, username))
	}
	searchQuery := queriesAndToSearchQuery(
		objectClassesToSearchQuery(objectClasses),
		queriesOrToSearchQuery(userQueries...),
	)

	// Search for user with the unique attribute for the userDN
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
		searchQuery,
		attributes,
		nil,
	)

	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}
	// Zero matches and ambiguous matches are both rejected.
	if len(sr.Entries) != 1 {
		logging.WithFields("entries", len(sr.Entries)).Info("ldap: no single user found")
		return nil, ErrNoSingleUser
	}

	user := sr.Entries[0]
	// Bind as the user to verify their password
	if err = conn.Bind(user.DN, password); err != nil {
		// The bind error is only logged; callers get the ErrFailedLogin sentinel.
		logging.WithFields("userDN", user.DN).WithError(err).Info("ldap user bind failed")
		return nil, ErrFailedLogin
	}
	return user, nil
}

// queriesAndToSearchQuery combines LDAP filter fragments with a logical AND.
// No fragments yield an empty string, a single fragment is returned as is,
// and two or more are concatenated inside "(&" ... ")".
func queriesAndToSearchQuery(queries ...string) string {
	switch len(queries) {
	case 0:
		return ""
	case 1:
		return queries[0]
	}
	combined := "(&"
	for i := range queries {
		combined += queries[i]
	}
	combined += ")"
	return combined
}

// queriesOrToSearchQuery combines LDAP filter fragments with a logical OR.
// No fragments yield an empty string, a single fragment is returned as is,
// and two or more are concatenated inside "(|" ... ")".
func queriesOrToSearchQuery(queries ...string) string {
	switch len(queries) {
	case 0:
		return ""
	case 1:
		return queries[0]
	}
	combined := "(|"
	for i := range queries {
		combined += queries[i]
	}
	combined += ")"
	return combined
}

// objectClassesToSearchQuery renders one "(objectClass=<class>)" term per
// class and concatenates them without a surrounding operator. The class names
// are inserted verbatim. An empty or nil slice yields an empty string.
func objectClassesToSearchQuery(classes []string) string {
	var query string
	for i := 0; i < len(classes); i++ {
		query = query + "(objectClass=" + classes[i] + ")"
	}
	return query
}

// userFiltersToSearchQuery renders one "(<filter>=<username>)" term per filter
// attribute and concatenates them without a surrounding operator. The username
// is passed through ldap.EscapeFilter; the filter attribute names are inserted
// verbatim. An empty or nil slice yields an empty string.
func userFiltersToSearchQuery(filters []string, username string) string {
	escaped := ldap.EscapeFilter(username)
	var query string
	for i := range filters {
		query = query + "(" + filters[i] + "=" + escaped + ")"
	}
	return query
}

// mapLDAPEntryToUser builds a User from the attributes of an LDAP entry.
//
// Each *Attribute parameter names the LDAP attribute to read for the
// corresponding user field. The two "verified" attributes are parsed with
// strconv.ParseBool; an empty or missing value counts as false, and an
// unparsable value is returned as an error. The id, name, nickname and
// preferred username values go through getAttributeValue; all other values are
// read directly from the entry.
func mapLDAPEntryToUser(
	user *ldap.Entry,
	idAttribute,
	firstNameAttribute,
	lastNameAttribute,
	displayNameAttribute,
	nickNameAttribute,
	preferredUsernameAttribute,
	emailAttribute,
	emailVerifiedAttribute,
	phoneAttribute,
	phoneVerifiedAttribute,
	preferredLanguageAttribute,
	avatarURLAttribute,
	profileAttribute string,
) (_ *User, err error) {
	// parseFlag reads a boolean attribute, treating an empty value as false.
	parseFlag := func(attribute string) (bool, error) {
		raw := user.GetAttributeValue(attribute)
		if raw == "" {
			return false, nil
		}
		return strconv.ParseBool(raw)
	}

	emailVerified, err := parseFlag(emailVerifiedAttribute)
	if err != nil {
		return nil, err
	}
	phoneVerified, err := parseFlag(phoneVerifiedAttribute)
	if err != nil {
		return nil, err
	}

	email := domain.EmailAddress(user.GetAttributeValue(emailAttribute))
	phone := domain.PhoneNumber(user.GetAttributeValue(phoneAttribute))
	preferredLanguage := language.Make(user.GetAttributeValue(preferredLanguageAttribute))

	mapped := NewUser(
		getAttributeValue(user, idAttribute),
		getAttributeValue(user, firstNameAttribute),
		getAttributeValue(user, lastNameAttribute),
		getAttributeValue(user, displayNameAttribute),
		getAttributeValue(user, nickNameAttribute),
		getAttributeValue(user, preferredUsernameAttribute),
		email,
		emailVerified,
		phone,
		phoneVerified,
		preferredLanguage,
		user.GetAttributeValue(avatarURLAttribute),
		user.GetAttributeValue(profileAttribute),
	)
	return mapped, nil
}

// getAttributeValue returns the value of the named attribute as a string.
// An empty attribute name yields an empty string without touching the entry.
// If the value is not valid UTF-8, the raw attribute bytes are returned
// base64-encoded (standard encoding) instead.
func getAttributeValue(user *ldap.Entry, attribute string) string {
	if attribute == "" {
		return ""
	}
	if text := user.GetAttributeValue(attribute); utf8.ValidString(text) {
		return text
	}
	raw := user.GetRawAttributeValue(attribute)
	return base64.StdEncoding.EncodeToString(raw)
}