package main

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"os"

	ldap "github.com/go-ldap/ldap/v3"
	"github.com/sirupsen/logrus"
)

// tLdapConn encapsulates a LDAP connection. Besides the connection itself, it
// carries a logger that includes fields related to the LDAP server, and a copy
// of the initial configuration.
type tLdapConn struct {
	config  tLdapConfig   // copy of the configuration the connection was created from
	conn    *ldap.Conn    // underlying LDAP connection
	log     *logrus.Entry // logger carrying the fields that identify the LDAP server
	server  int           // index of the connected server in config.Servers
	counter uint          // number of queries issued through this connection
}

// NewLdapConnection tries each configured server in order and returns the
// first connection that could be established. It returns nil when no server
// could be reached (or when no server is configured).
func NewLdapConnection(cfg tLdapConfig) *tLdapConn {
	for idx := 0; idx < len(cfg.Servers); idx++ {
		if established := getLdapServerConnection(cfg, idx); established != nil {
			return established
		}
	}
	return nil
}

// Close logs the number of queries that went through the connection, then
// closes the underlying LDAP connection.
func (conn *tLdapConn) Close() {
	entry := conn.log.WithField("queries", conn.counter)
	entry.Debug("Closing LDAP connection")
	conn.conn.Close()
}

// BaseDN returns the base DN found in the structure section of the
// configuration this connection was created from.
func (conn *tLdapConn) BaseDN() string {
	baseDN := conn.config.Structure.BaseDN
	return baseDN
}

// getLdapServerConnection establishes a connection to the LDAP server found at
// index `server` of cfg.Servers.
//
// The server's TLS setting selects the transport: "yes" dials ldaps://, any
// other value dials ldap://, and "starttls" additionally upgrades the plain
// connection using StartTLS. When TLS is not "no" and a CA chain file is
// configured, its certificates are used as the trusted roots. If a bind user
// is configured, a bind is attempted using the configured password.
//
// Every failure is logged and results in a nil return value; connections that
// were opened before the failure are closed. The function panics if `server`
// is not a valid index into cfg.Servers.
func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
	if server < 0 || server >= len(cfg.Servers) {
		logrus.Panicf("Invalid server index %d", server)
	}

	scfg := cfg.Servers[server]
	dest := fmt.Sprintf("%s:%d", scfg.Host, scfg.Port)
	log := log.WithFields(logrus.Fields{
		"ldap_server": dest,
		"ldap_tls":    scfg.TLS,
	})
	log.Trace("Establishing LDAP connection")

	// ServerName must be set explicitly. A ldaps:// dial infers it from the
	// address, but StartTLS passes the configuration to the TLS client as is,
	// and the TLS client refuses to handshake when neither ServerName nor
	// InsecureSkipVerify is set. Without it, "starttls" with verification
	// enabled could never succeed.
	tlsConfig := &tls.Config{ServerName: scfg.Host}
	if scfg.TLSNoVerify != nil {
		tlsConfig.InsecureSkipVerify = *scfg.TLSNoVerify
	}
	if scfg.TLS != "no" && scfg.CaChain != "" {
		log := log.WithField("cachain", scfg.CaChain)
		data, err := os.ReadFile(scfg.CaChain)
		if err != nil {
			log.WithField("error", err).Error("Failed to read CA certificate chain")
			return nil
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(data) {
			log.Error("Could not add CA certificates")
			return nil
		}
		tlsConfig.RootCAs = pool
	}

	// Only "yes" uses TLS from the start; "starttls" begins as a plain
	// connection and is upgraded below.
	scheme := "ldap://"
	if scfg.TLS == "yes" {
		scheme = "ldaps://"
	}
	lc, err := ldap.DialURL(scheme+dest, ldap.DialWithTLSConfig(tlsConfig))
	if err != nil {
		log.WithField("error", err).Error("Failed to connect to the LDAP server")
		return nil
	}

	if scfg.TLS == "starttls" {
		if err := lc.StartTLS(tlsConfig); err != nil {
			lc.Close()
			log.WithField("error", err).Error("StartTLS failed")
			return nil
		}
	}

	if scfg.BindUser != "" {
		// From here on, log entries also identify the bind user.
		log = log.WithField("ldap_user", scfg.BindUser)
		if err := lc.Bind(scfg.BindUser, scfg.BindPassword); err != nil {
			lc.Close()
			log.WithField("error", err).Error("Could not bind")
			return nil
		}
	}
	log.Debug("LDAP connection established")
	return &tLdapConn{
		config: cfg,
		conn:   lc,
		log:    log,
		server: server,
	}
}

// getObject runs a LDAP query to obtain a single object.
//
// The search uses `dn` as its base with the base-object scope, never
// dereferences aliases, has a size limit of 1 and no time limit, and requests
// the attributes listed in `attrs`. The connection's query counter is
// incremented for every call, whether the query succeeds or not.
//
// It returns (true, entry) when exactly one entry was found. In all other
// cases (query error, no entry, more than one entry) the problem is logged and
// (false, nil) is returned.
func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry) {
	log := conn.log.WithFields(logrus.Fields{
		"dn":         dn,
		"attributes": attrs,
	})
	log.Trace("Accessing DN")
	conn.counter++
	req := ldap.NewSearchRequest(
		dn,
		ldap.ScopeBaseObject, ldap.NeverDerefAliases, 1, 0, false,
		"(objectClass=*)", attrs, nil)
	res, err := conn.conn.Search(req)
	if err != nil {
		log := log.WithField("error", err)
		// Add the LDAP-specific details when the error carries them.
		if ldapError, ok := err.(*ldap.Error); ok {
			log = log.WithFields(logrus.Fields{
				"ldap_result":  ldapError.ResultCode,
				"ldap_message": ldapError.Error(),
			})
		}
		log.Error("LDAP query failed")
		return false, nil
	}

	switch nFound := len(res.Entries); {
	case nFound == 0:
		// A search can succeed without returning any entry; indexing the
		// result in that case would panic.
		log.Warning("LDAP search returned no record")
		return false, nil
	case nFound > 1:
		log.WithField("results", nFound).
			Warning("LDAP search returned more than 1 record")
		return false, nil
	}
	log.Trace("Obtained LDAP object")
	return true, res.Entries[0]
}

// GetEndEntityCertificate reads an end entity's certificate from the LDAP
// object at `dn`, using the attribute configured as the end entity certificate
// attribute.
//
// The attribute must hold exactly one value, and that value must parse as a
// X.509 certificate. On success the certificate is returned PEM-encoded;
// otherwise nil and an error describing the problem are returned.
func (conn *tLdapConn) GetEndEntityCertificate(dn string) ([]byte, error) {
	attr := conn.config.Structure.EndEntityCertificate
	ok, entry := conn.getObject(dn, []string{attr})
	if !ok {
		return nil, fmt.Errorf("Could not read certificate from '%s'", dn)
	}

	rawValues := entry.GetRawAttributeValues(attr)
	if count := len(rawValues); count != 1 {
		return nil, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, attr, count)
	}

	// Parsing is only used to validate the data; the raw bytes are what gets
	// encoded.
	der := rawValues[0]
	if _, err := x509.ParseCertificate(der); err != nil {
		return nil, fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, attr, err)
	}

	block := pem.Block{Type: "CERTIFICATE", Bytes: der}
	return pem.EncodeToMemory(&block), nil
}

// GetCaCertificate reads a CA certificate, as well as the value of the
// chaining field, from the LDAP object at `dn`.
//
// The chaining attribute is only requested when one is configured. Both the
// certificate and the chaining value are optional: a missing certificate
// yields a nil slice, a missing chaining value yields an empty string. When
// present, the certificate must parse as a X.509 certificate and is returned
// PEM-encoded.
//
// An error is returned if the object cannot be read, if either attribute holds
// more than one value, or if the certificate is invalid. When the chaining
// attribute holds more than one value, the error is returned along with the
// certificate that was read, if there was one.
func (conn *tLdapConn) GetCaCertificate(dn string) ([]byte, string, error) {
	certAttr := conn.config.Structure.CACertificate
	chainAttr := conn.config.Structure.CAChaining

	wanted := []string{certAttr}
	if chainAttr != "" {
		wanted = append(wanted, chainAttr)
	}

	ok, entry := conn.getObject(dn, wanted)
	if !ok {
		return nil, "", fmt.Errorf("Could not read certificate from '%s'", dn)
	}

	var caCert []byte
	rawCerts := entry.GetRawAttributeValues(certAttr)
	switch count := len(rawCerts); {
	case count > 1:
		return nil, "", fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, certAttr, count)
	case count == 1:
		// Parsing is only used to validate the data; the raw bytes are what
		// gets encoded.
		if _, err := x509.ParseCertificate(rawCerts[0]); err != nil {
			return nil, "", fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, certAttr, err)
		}
		block := pem.Block{Type: "CERTIFICATE", Bytes: rawCerts[0]}
		caCert = pem.EncodeToMemory(&block)
	}

	var chainDN string
	chainValues := entry.GetAttributeValues(chainAttr)
	switch count := len(chainValues); {
	case count > 1:
		return caCert, "", fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, chainAttr, count)
	case count == 1:
		chainDN = chainValues[0]
	}

	return caCert, chainDN, nil
}