From 0e642c85a6010e7eeac3ddfe5e5d5a492036baef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emmanuel=20Beno=C3=AEt?= Date: Fri, 5 Nov 2021 13:40:47 +0100 Subject: [PATCH] LDAP connection code * Code that connects to LDAP servers and send queries * Helper functions to fetch CA certificates or EE certificates --- go.mod | 4 ++ ldap.go | 208 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 ldap.go diff --git a/go.mod b/go.mod index 80a2a59..f1b9aff 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,16 @@ go 1.17 require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gemnasium/logrus-graylog-hook/v3 v3.0.3 + github.com/go-ldap/ldap/v3 v3.4.1 github.com/karrick/golf v1.4.0 github.com/sirupsen/logrus v1.8.1 gopkg.in/yaml.v2 v2.4.0 ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect + github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect github.com/pkg/errors v0.8.1 // indirect + golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect ) diff --git a/ldap.go b/ldap.go new file mode 100644 index 0000000..fd40a0b --- /dev/null +++ b/ldap.go @@ -0,0 +1,208 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + + ldap "github.com/go-ldap/ldap/v3" + "github.com/sirupsen/logrus" +) + +type ( + // LDAP connection encapsulation. This includes the connection itself, as well as a logger + // that includes fields related to the LDAP server and a copy of the initial configuration. + tLdapConn struct { + Config tLdapConfig + conn *ldap.Conn + log *logrus.Entry + server int + counter uint + } + + // LDAP group members + ldapGroupMembers map[string][]string +) + +// Try to establish a connection to one of the servers +func getLdapConnection(cfg tLdapConfig) *tLdapConn { + for i := range cfg.Servers { + conn := getLdapServerConnection(cfg, i) + if conn != nil { + return conn + } + } + return nil +} + +// Establish a connection to a LDAP server +func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn { + if server < 0 || server >= len(cfg.Servers) { + logrus.Panic("Invalid server index %d", server) + } + + scfg := cfg.Servers[server] + dest := fmt.Sprintf("%s:%d", scfg.Host, scfg.Port) + log := log.WithFields(logrus.Fields{ + "ldap_server": dest, + "ldap_tls": scfg.TLS, + }) + log.Trace("Establishing LDAP connection") + + tlsConfig := &tls.Config{ + InsecureSkipVerify: scfg.TLSNoVerify, + } + if scfg.TLS != "no" && scfg.CaChain != "" { + log := log.WithField("cachain", scfg.CaChain) + data, err := ioutil.ReadFile(scfg.CaChain) + if err != nil { + log.WithField("error", err).Error("Failed to read CA certificate chain") + return nil + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(data) { + log.Error("Could not add CA certificates") + return nil + } + tlsConfig.RootCAs = pool + } + + var err error + var lc *ldap.Conn + if scfg.TLS == "yes" { + lc, err = ldap.DialTLS("tcp", dest, tlsConfig) + } else { + lc, err = ldap.Dial("tcp", dest) + } + if err != nil { + log.WithField("error", err).Error("Failed to connect to the LDAP server") + return nil + } + + if scfg.TLS == "starttls" { + err = lc.StartTLS(tlsConfig) + if err != nil { + lc.Close() + log.WithField("error", err).Error("StartTLS failed") + return nil + } + } + + if scfg.BindUser != "" { + log = log.WithField("ldap_user", scfg.BindUser) + err := lc.Bind(scfg.BindUser, scfg.BindPassword) + if err != nil { + lc.Close() + log.WithField("error", err).Error("Could not bind") + return nil + } + } + log.Debug("LDAP connection established") + return &tLdapConn{ + Config: cfg, + conn: lc, + log: log, + server: server, + } +} + +// Run a LDAP query to obtain a single object. +func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry) { + log := conn.log.WithFields(logrus.Fields{ + "dn": dn, + "attributes": attrs, + }) + log.Trace("Accessing DN") + conn.counter++ + req := ldap.NewSearchRequest( + dn, + ldap.ScopeBaseObject, ldap.NeverDerefAliases, 1, 0, false, + "(objectClass=*)", attrs, nil) + res, err := conn.conn.Search(req) + if err != nil { + log := log.WithField("error", err) + ldapError, ok := err.(*ldap.Error) + if ok { + log = log.WithFields(logrus.Fields{ + "ldap_result": ldapError.ResultCode, + "ldap_message": ldapError.Error(), + }) + } + log.Error("LDAP query failed") + return false, nil + } + if len(res.Entries) > 1 { + log.WithField("results", len(res.Entries)). + Warning("LDAP search returned more than 1 record") + return false, nil + } + log.Trace("Obtained LDAP object") + return true, res.Entries[0] +} + +// Close a LDAP connection +func (conn *tLdapConn) close() { + conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection") + conn.conn.Close() +} + +// Get an end entity's certificate from the LDAP +func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) { + eec := conn.Config.Structure.EndEntityCertificate + success, entry := conn.getObject(dn, []string{eec}) + if !success { + return nil, fmt.Errorf("Could not read certificate from '%s'", dn) + } + values := entry.GetRawAttributeValues(eec) + nFound := len(values) + if nFound != 1 { + return nil, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, eec, nFound) + } + _, err := x509.ParseCertificate(values[0]) + if err != nil { + return nil, fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, eec, err) + } + return values[0], nil +} + +// Get a CA certificate, as well as the value of the chaining field, from +// the LDAP. +func (conn *tLdapConn) getCaCertificate(dn string) ([]byte, string, error) { + cc := conn.Config.Structure.CACertificate + chain := conn.Config.Structure.CAChaining + attrs := []string{cc} + if chain != "" { + attrs = append(attrs, chain) + } + + success, entry := conn.getObject(dn, attrs) + if !success { + return nil, "", fmt.Errorf("Could not read certificate from '%s'", dn) + } + + var ca_cert []byte = nil + var chain_dn string = "" + + values := entry.GetRawAttributeValues(cc) + nFound := len(values) + if nFound > 1 { + return ca_cert, chain_dn, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, cc, nFound) + } else if nFound == 1 { + ca_cert = values[0] + _, err := x509.ParseCertificate(ca_cert) + if err != nil { + return nil, "", fmt.Errorf("DN %s - invalid certificate in attribute %s : %w", dn, cc, err) + } + } + + chval := entry.GetAttributeValues(chain) + nFound = len(chval) + if nFound > 1 { + return ca_cert, chain_dn, fmt.Errorf("DN %s - one value expected for %s, %d values found", dn, chain, nFound) + } else if nFound == 1 { + chain_dn = chval[0] + } + + return ca_cert, chain_dn, nil +}