2021-11-05 13:40:47 +01:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/tls"
|
|
|
|
"crypto/x509"
|
2021-11-05 14:55:51 +01:00
|
|
|
"encoding/pem"
|
2021-11-05 13:40:47 +01:00
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
|
|
|
|
ldap "github.com/go-ldap/ldap/v3"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
|
|
)
|
|
|
|
|
|
|
|
type (
|
|
|
|
// LDAP connection encapsulation. This includes the connection itself, as well as a logger
|
|
|
|
// that includes fields related to the LDAP server and a copy of the initial configuration.
|
|
|
|
tLdapConn struct {
|
2021-11-06 10:05:45 +01:00
|
|
|
config tLdapConfig
|
2021-11-05 13:40:47 +01:00
|
|
|
conn *ldap.Conn
|
|
|
|
log *logrus.Entry
|
|
|
|
server int
|
|
|
|
counter uint
|
|
|
|
}
|
|
|
|
|
|
|
|
// LDAP group members
|
|
|
|
ldapGroupMembers map[string][]string
|
|
|
|
)
|
|
|
|
|
|
|
|
// Try to establish a connection to one of the servers
|
2021-11-06 10:05:45 +01:00
|
|
|
func NewLdapConnection(cfg tLdapConfig) *tLdapConn {
|
2021-11-05 13:40:47 +01:00
|
|
|
for i := range cfg.Servers {
|
|
|
|
conn := getLdapServerConnection(cfg, i)
|
|
|
|
if conn != nil {
|
|
|
|
return conn
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2021-11-06 10:05:45 +01:00
|
|
|
// Close a LDAP connection
|
|
|
|
func (conn *tLdapConn) Close() {
|
|
|
|
conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection")
|
|
|
|
conn.conn.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get the base DN
|
|
|
|
func (conn *tLdapConn) BaseDN() string {
|
|
|
|
return conn.config.Structure.BaseDN
|
|
|
|
}
|
|
|
|
|
2021-11-05 13:40:47 +01:00
|
|
|
// Establish a connection to a LDAP server
|
|
|
|
func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
|
|
|
|
if server < 0 || server >= len(cfg.Servers) {
|
|
|
|
logrus.Panic("Invalid server index %d", server)
|
|
|
|
}
|
|
|
|
|
|
|
|
scfg := cfg.Servers[server]
|
|
|
|
dest := fmt.Sprintf("%s:%d", scfg.Host, scfg.Port)
|
|
|
|
log := log.WithFields(logrus.Fields{
|
|
|
|
"ldap_server": dest,
|
|
|
|
"ldap_tls": scfg.TLS,
|
|
|
|
})
|
|
|
|
log.Trace("Establishing LDAP connection")
|
|
|
|
|
2021-12-05 17:21:52 +01:00
|
|
|
tlsConfig := &tls.Config{}
|
|
|
|
if scfg.TLSNoVerify != nil {
|
|
|
|
tlsConfig.InsecureSkipVerify = *scfg.TLSNoVerify
|
2021-11-05 13:40:47 +01:00
|
|
|
}
|
|
|
|
if scfg.TLS != "no" && scfg.CaChain != "" {
|
|
|
|
log := log.WithField("cachain", scfg.CaChain)
|
|
|
|
data, err := ioutil.ReadFile(scfg.CaChain)
|
|
|
|
if err != nil {
|
|
|
|
log.WithField("error", err).Error("Failed to read CA certificate chain")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
pool := x509.NewCertPool()
|
|
|
|
if !pool.AppendCertsFromPEM(data) {
|
|
|
|
log.Error("Could not add CA certificates")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
tlsConfig.RootCAs = pool
|
|
|
|
}
|
|
|
|
|
|
|
|
var err error
|
|
|
|
var lc *ldap.Conn
|
|
|
|
if scfg.TLS == "yes" {
|
|
|
|
lc, err = ldap.DialTLS("tcp", dest, tlsConfig)
|
|
|
|
} else {
|
|
|
|
lc, err = ldap.Dial("tcp", dest)
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
log.WithField("error", err).Error("Failed to connect to the LDAP server")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if scfg.TLS == "starttls" {
|
|
|
|
err = lc.StartTLS(tlsConfig)
|
|
|
|
if err != nil {
|
|
|
|
lc.Close()
|
|
|
|
log.WithField("error", err).Error("StartTLS failed")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if scfg.BindUser != "" {
|
|
|
|
log = log.WithField("ldap_user", scfg.BindUser)
|
|
|
|
err := lc.Bind(scfg.BindUser, scfg.BindPassword)
|
|
|
|
if err != nil {
|
|
|
|
lc.Close()
|
|
|
|
log.WithField("error", err).Error("Could not bind")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
log.Debug("LDAP connection established")
|
|
|
|
return &tLdapConn{
|
2021-11-06 10:05:45 +01:00
|
|
|
config: cfg,
|
2021-11-05 13:40:47 +01:00
|
|
|
conn: lc,
|
|
|
|
log: log,
|
|
|
|
server: server,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Run a LDAP query to obtain a single object.
|
|
|
|
func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry) {
|
|
|
|
log := conn.log.WithFields(logrus.Fields{
|
|
|
|
"dn": dn,
|
|
|
|
"attributes": attrs,
|
|
|
|
})
|
|
|
|
log.Trace("Accessing DN")
|
|
|
|
conn.counter++
|
|
|
|
req := ldap.NewSearchRequest(
|
|
|
|
dn,
|
|
|
|
ldap.ScopeBaseObject, ldap.NeverDerefAliases, 1, 0, false,
|
|
|
|
"(objectClass=*)", attrs, nil)
|
|
|
|
res, err := conn.conn.Search(req)
|
|
|
|
if err != nil {
|
|
|
|
log := log.WithField("error", err)
|
|
|
|
ldapError, ok := err.(*ldap.Error)
|
|
|
|
if ok {
|
|
|
|
log = log.WithFields(logrus.Fields{
|
|
|
|
"ldap_result": ldapError.ResultCode,
|
|
|
|
"ldap_message": ldapError.Error(),
|
|
|
|
})
|
|
|
|
}
|
|
|
|
log.Error("LDAP query failed")
|
|
|
|
return false, nil
|
|
|
|
}
|
|
|
|
if len(res.Entries) > 1 {
|
|
|
|
log.WithField("results", len(res.Entries)).
|
|
|
|
Warning("LDAP search returned more than 1 record")
|
|
|
|
return false, nil
|
|
|
|
}
|
|
|
|
log.Trace("Obtained LDAP object")
|
|
|
|
return true, res.Entries[0]
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get an end entity's certificate from the LDAP
|
2021-11-06 10:05:45 +01:00
|
|
|
func (conn *tLdapConn) GetEndEntityCertificate(dn string) ([]byte, error) {
|
|
|
|
eec := conn.config.Structure.EndEntityCertificate
|
2021-11-05 13:40:47 +01:00
|
|
|
success, entry := conn.getObject(dn, []string{eec})
|
|
|
|
if !success {
|
|
|
|
return nil, fmt.Errorf("Could not read certificate from '%s'", dn)
|
|
|
|
}
|
|
|
|
values := entry.GetRawAttributeValues(eec)
|
|
|
|
nFound := len(values)
|
|
|
|
if nFound != 1 {
|
|
|
|
return nil, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, eec, nFound)
|
|
|
|
}
|
|
|
|
_, err := x509.ParseCertificate(values[0])
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, eec, err)
|
|
|
|
}
|
2021-11-05 14:55:51 +01:00
|
|
|
data := pem.EncodeToMemory(&pem.Block{
|
|
|
|
Type: "CERTIFICATE",
|
|
|
|
Bytes: values[0],
|
|
|
|
})
|
|
|
|
return data, nil
|
2021-11-05 13:40:47 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Get a CA certificate, as well as the value of the chaining field, from
|
|
|
|
// the LDAP.
|
2021-11-06 10:05:45 +01:00
|
|
|
func (conn *tLdapConn) GetCaCertificate(dn string) ([]byte, string, error) {
|
|
|
|
cc := conn.config.Structure.CACertificate
|
|
|
|
chain := conn.config.Structure.CAChaining
|
2021-11-05 13:40:47 +01:00
|
|
|
attrs := []string{cc}
|
|
|
|
if chain != "" {
|
|
|
|
attrs = append(attrs, chain)
|
|
|
|
}
|
|
|
|
|
|
|
|
success, entry := conn.getObject(dn, attrs)
|
|
|
|
if !success {
|
|
|
|
return nil, "", fmt.Errorf("Could not read certificate from '%s'", dn)
|
|
|
|
}
|
|
|
|
|
|
|
|
var ca_cert []byte = nil
|
|
|
|
var chain_dn string = ""
|
|
|
|
|
|
|
|
values := entry.GetRawAttributeValues(cc)
|
|
|
|
nFound := len(values)
|
|
|
|
if nFound > 1 {
|
|
|
|
return ca_cert, chain_dn, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, cc, nFound)
|
|
|
|
} else if nFound == 1 {
|
2021-11-05 14:55:51 +01:00
|
|
|
_, err := x509.ParseCertificate(values[0])
|
2021-11-05 13:40:47 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, "", fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, cc, err)
|
|
|
|
}
|
2021-11-05 14:55:51 +01:00
|
|
|
ca_cert = pem.EncodeToMemory(&pem.Block{
|
|
|
|
Type: "CERTIFICATE",
|
|
|
|
Bytes: values[0],
|
|
|
|
})
|
2021-11-05 13:40:47 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
chval := entry.GetAttributeValues(chain)
|
|
|
|
nFound = len(chval)
|
|
|
|
if nFound > 1 {
|
|
|
|
return ca_cert, chain_dn, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, chain, nFound)
|
|
|
|
} else if nFound == 1 {
|
|
|
|
chain_dn = chval[0]
|
|
|
|
}
|
|
|
|
|
|
|
|
return ca_cert, chain_dn, nil
|
|
|
|
}
|