LDAP connection code

* Code that connects to LDAP servers and send queries
  * Helper functions to fetch CA certificates or EE certificates
This commit is contained in:
Emmanuel BENOîT 2021-11-05 13:40:47 +01:00
parent f971c1e961
commit 0e642c85a6
2 changed files with 212 additions and 0 deletions

4
go.mod
View file

@ -5,12 +5,16 @@ go 1.17
require ( require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gemnasium/logrus-graylog-hook/v3 v3.0.3 github.com/gemnasium/logrus-graylog-hook/v3 v3.0.3
github.com/go-ldap/ldap/v3 v3.4.1
github.com/karrick/golf v1.4.0 github.com/karrick/golf v1.4.0
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
) )
require ( require (
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect
github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect
github.com/pkg/errors v0.8.1 // indirect github.com/pkg/errors v0.8.1 // indirect
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect
) )

208
ldap.go Normal file
View file

@ -0,0 +1,208 @@
package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
ldap "github.com/go-ldap/ldap/v3"
"github.com/sirupsen/logrus"
)
type (
// LDAP connection encapsulation. This includes the connection itself, as well as a logger
// that includes fields related to the LDAP server and a copy of the initial configuration.
tLdapConn struct {
Config tLdapConfig
conn *ldap.Conn
log *logrus.Entry
server int
counter uint
}
// LDAP group members
ldapGroupMembers map[string][]string
)
// Try to establish a connection to one of the servers
func getLdapConnection(cfg tLdapConfig) *tLdapConn {
for i := range cfg.Servers {
conn := getLdapServerConnection(cfg, i)
if conn != nil {
return conn
}
}
return nil
}
// Establish a connection to a LDAP server
func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
if server < 0 || server >= len(cfg.Servers) {
logrus.Panic("Invalid server index %d", server)
}
scfg := cfg.Servers[server]
dest := fmt.Sprintf("%s:%d", scfg.Host, scfg.Port)
log := log.WithFields(logrus.Fields{
"ldap_server": dest,
"ldap_tls": scfg.TLS,
})
log.Trace("Establishing LDAP connection")
tlsConfig := &tls.Config{
InsecureSkipVerify: scfg.TLSNoVerify,
}
if scfg.TLS != "no" && scfg.CaChain != "" {
log := log.WithField("cachain", scfg.CaChain)
data, err := ioutil.ReadFile(scfg.CaChain)
if err != nil {
log.WithField("error", err).Error("Failed to read CA certificate chain")
return nil
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(data) {
log.Error("Could not add CA certificates")
return nil
}
tlsConfig.RootCAs = pool
}
var err error
var lc *ldap.Conn
if scfg.TLS == "yes" {
lc, err = ldap.DialTLS("tcp", dest, tlsConfig)
} else {
lc, err = ldap.Dial("tcp", dest)
}
if err != nil {
log.WithField("error", err).Error("Failed to connect to the LDAP server")
return nil
}
if scfg.TLS == "starttls" {
err = lc.StartTLS(tlsConfig)
if err != nil {
lc.Close()
log.WithField("error", err).Error("StartTLS failed")
return nil
}
}
if scfg.BindUser != "" {
log = log.WithField("ldap_user", scfg.BindUser)
err := lc.Bind(scfg.BindUser, scfg.BindPassword)
if err != nil {
lc.Close()
log.WithField("error", err).Error("Could not bind")
return nil
}
}
log.Debug("LDAP connection established")
return &tLdapConn{
Config: cfg,
conn: lc,
log: log,
server: server,
}
}
// Run a LDAP query to obtain a single object.
func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry) {
log := conn.log.WithFields(logrus.Fields{
"dn": dn,
"attributes": attrs,
})
log.Trace("Accessing DN")
conn.counter++
req := ldap.NewSearchRequest(
dn,
ldap.ScopeBaseObject, ldap.NeverDerefAliases, 1, 0, false,
"(objectClass=*)", attrs, nil)
res, err := conn.conn.Search(req)
if err != nil {
log := log.WithField("error", err)
ldapError, ok := err.(*ldap.Error)
if ok {
log = log.WithFields(logrus.Fields{
"ldap_result": ldapError.ResultCode,
"ldap_message": ldapError.Error(),
})
}
log.Error("LDAP query failed")
return false, nil
}
if len(res.Entries) > 1 {
log.WithField("results", len(res.Entries)).
Warning("LDAP search returned more than 1 record")
return false, nil
}
log.Trace("Obtained LDAP object")
return true, res.Entries[0]
}
// Close a LDAP connection
func (conn *tLdapConn) close() {
conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection")
conn.conn.Close()
}
// Get an end entity's certificate from the LDAP
func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) {
eec := conn.Config.Structure.EndEntityCertificate
success, entry := conn.getObject(dn, []string{eec})
if !success {
return nil, fmt.Errorf("Could not read certificate from '%s'", dn)
}
values := entry.GetRawAttributeValues(eec)
nFound := len(values)
if nFound != 1 {
return nil, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, eec, nFound)
}
_, err := x509.ParseCertificate(values[0])
if err != nil {
return nil, fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, eec, err)
}
return values[0], nil
}
// Get a CA certificate, as well as the value of the chaining field, from
// the LDAP.
func (conn *tLdapConn) getCaCertificate(dn string) ([]byte, string, error) {
cc := conn.Config.Structure.CACertificate
chain := conn.Config.Structure.CAChaining
attrs := []string{cc}
if chain != "" {
attrs = append(attrs, chain)
}
success, entry := conn.getObject(dn, attrs)
if !success {
return nil, "", fmt.Errorf("Could not read certificate from '%s'", dn)
}
var ca_cert []byte = nil
var chain_dn string = ""
values := entry.GetRawAttributeValues(cc)
nFound := len(values)
if nFound > 1 {
return ca_cert, chain_dn, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, cc, nFound)
} else if nFound == 1 {
ca_cert = values[0]
_, err := x509.ParseCertificate(ca_cert)
if err != nil {
return nil, "", fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, cc, err)
}
}
chval := entry.GetAttributeValues(chain)
nFound = len(chval)
if nFound > 1 {
return ca_cert, chain_dn, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, chain, nFound)
} else if nFound == 1 {
chain_dn = chval[0]
}
return ca_cert, chain_dn, nil
}