graylog-groups/ldap.go

213 lines
5.4 KiB
Go
Raw Permalink Normal View History

package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"strings"
"github.com/go-ldap/ldap"
"github.com/sirupsen/logrus"
)
type (
// LDAP connection encapsulation. This includes the connection itself, as well as a logger
// that includes fields related to the LDAP server and a copy of the initial configuration.
ldapConn struct {
conn *ldap.Conn
log *logrus.Entry
2021-02-13 18:26:37 +01:00
cfg ldapConfig
usernames map[string]string
counter uint
}
// LDAP group members
2021-02-13 18:26:37 +01:00
ldapGroupMembers map[string][]string
)
// Establish a connection to the LDAP server
2021-02-13 18:26:37 +01:00
func getLdapConnection(cfg ldapConfig) *ldapConn {
dest := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
log := log.WithFields(logrus.Fields{
"ldap_server": dest,
2021-02-13 18:26:37 +01:00
"ldap_tls": cfg.TLS,
})
log.Trace("Establishing LDAP connection")
tlsConfig := &tls.Config{
2021-02-13 18:26:37 +01:00
InsecureSkipVerify: cfg.TLSNoVerify,
}
2021-02-13 18:26:37 +01:00
if cfg.TLS != "no" && cfg.CaChain != "" {
log := log.WithField("cachain", cfg.CaChain)
data, err := ioutil.ReadFile(cfg.CaChain)
if err != nil {
log.WithField("error", err).Fatal("Failed to read CA certificate chain")
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(data) {
log.Fatal("Could not add CA certificates")
}
tlsConfig.RootCAs = pool
}
var err error
var lc *ldap.Conn
2021-02-13 18:26:37 +01:00
if cfg.TLS == "yes" {
lc, err = ldap.DialTLS("tcp", dest, tlsConfig)
} else {
lc, err = ldap.Dial("tcp", dest)
}
if err != nil {
log.WithField("error", err).Fatal("Failed to connect to the LDAP server")
}
2021-02-13 18:26:37 +01:00
if cfg.TLS == "starttls" {
err = lc.StartTLS(tlsConfig)
if err != nil {
lc.Close()
log.WithField("error", err).Fatal("StartTLS failed")
}
}
if cfg.BindUser != "" {
log = log.WithField("ldap_user", cfg.BindUser)
err := lc.Bind(cfg.BindUser, cfg.BindPassword)
if err != nil {
lc.Close()
log.WithField("error", err).Fatal("Could not bind")
}
}
log.Debug("LDAP connection established")
return &ldapConn{
conn: lc,
log: log,
cfg: cfg,
usernames: make(map[string]string),
}
}
// query reads the single LDAP object at dn, requesting only the attributes
// listed in attrs, using a base-object search with filter (objectClass=*).
//
// It returns (true, entry) when exactly one entry comes back. It returns
// (false, nil) when the server reports that the DN does not exist, or when the
// search yields anything other than exactly one entry (logged as a warning).
// Any other search error is logged at Fatal level. Every call increments the
// connection's query counter.
func (conn *ldapConn) query(dn string, attrs []string) (bool, *ldap.Entry) {
	log := conn.log.WithFields(logrus.Fields{
		"dn":         dn,
		"attributes": attrs,
	})
	log.Trace("Accessing DN")
	conn.counter++
	req := ldap.NewSearchRequest(
		dn,
		ldap.ScopeBaseObject, ldap.NeverDerefAliases, 1, 0, false,
		"(objectClass=*)", attrs, nil)
	res, err := conn.conn.Search(req)
	if err != nil {
		ldapError, ok := err.(*ldap.Error)
		if ok && ldapError.ResultCode == ldap.LDAPResultNoSuchObject {
			log.Trace("DN not found")
			return false, nil
		}
		log.WithField("error", err).Fatal("LDAP query failed")
		// Only reached if logrus' exit function was overridden; res may be nil.
		return false, nil
	}
	switch n := len(res.Entries); {
	case n == 0:
		// A successful search may still return nothing (for instance if the
		// entry is hidden from the bound user); indexing Entries[0] would
		// panic.
		log.Warning("LDAP search returned no record")
		return false, nil
	case n > 1:
		log.WithField("results", n).
			Warning("LDAP search returned more than 1 record")
		return false, nil
	}
	log.Trace("Obtained LDAP object")
	return true, res.Entries[0]
}
// close logs how many queries were issued over this connection, then closes
// the underlying LDAP connection.
func (conn *ldapConn) close() {
	issued := conn.counter
	entry := conn.log.WithField("queries", issued)
	entry.Debug("Closing LDAP connection")
	conn.conn.Close()
}
// readUsername fetches the LDAP entry at dn and returns the value of its
// cfg.UsernameAttr attribute. It returns false (with an empty string) if the
// entry cannot be read, or if the attribute does not have exactly one value;
// the latter case is logged as a warning.
func (conn *ldapConn) readUsername(dn string) (bool, string) {
	attr := conn.cfg.UsernameAttr
	entryLog := conn.log.WithField("dn", dn)
	entryLog.Debug("LDAP username lookup")

	found, entry := conn.query(dn, []string{attr})
	if !found {
		return false, ""
	}

	switch values := entry.GetAttributeValues(attr); len(values) {
	case 1:
		entryLog.WithField("username", values[0]).Trace("Mapped DN to username")
		return true, values[0]
	default:
		entryLog.WithField("count", len(values)).
			Warning("Attribute does not have 1 value exactly.")
		return false, ""
	}
}
// usernameFromMember turns a group member value into a username.
//
// A value containing no "=" is taken to already be a username. Otherwise it is
// treated as a DN. If cfg.UsernameAttr is set, the username is read from that
// attribute of the DN's LDAP entry. If not, the text between the first "=" and
// the first "," is used, or everything after the "=" when there is no ",".
// Returns false (with an empty string) if no username could be determined.
func (conn *ldapConn) usernameFromMember(member string) (bool, string) {
	eqPos := strings.Index(member, "=")
	if eqPos == -1 {
		return true, member
	}
	if conn.cfg.UsernameAttr != "" {
		return conn.readUsername(member)
	}
	commaPos := strings.Index(member, ",")
	if commaPos == -1 {
		return true, member[eqPos+1:]
	}
	if eqPos > commaPos {
		// Log through the connection's logger, like the other methods, so the
		// warning carries the LDAP server fields.
		conn.log.WithField("member", member).Warning("Couldn't extract user name")
		return false, ""
	}
	return true, member[eqPos+1 : commaPos]
}
// getUsername maps a group member value to a username. Successful results are
// memoised in conn.usernames, so each distinct member is resolved at most once
// per connection. Failed lookups are not cached and are retried on the next
// call.
func (conn *ldapConn) getUsername(member string) (bool, string) {
	if cached, hit := conn.usernames[member]; hit {
		return true, cached
	}
	resolved, name := conn.usernameFromMember(member)
	if !resolved {
		return false, name
	}
	conn.usernames[member] = name
	return true, name
}
// getGroupMembers reads the group entry at DN group and converts the values of
// its first non-empty member attribute, taken in the order listed in
// cfg.MemberFields, into usernames. Values that cannot be mapped to a username
// are skipped. The result is nil if the group entry could not be read.
func (conn *ldapConn) getGroupMembers(group string) (members []string) {
	groupLog := conn.log.WithField("group", group)
	groupLog.Trace("Obtaining group members")

	found, entry := conn.query(group, conn.cfg.MemberFields)
	if !found {
		return nil
	}

	// Only the first attribute that actually has values is used.
	var raw []string
	for _, field := range conn.cfg.MemberFields {
		if raw = entry.GetAttributeValues(field); len(raw) != 0 {
			break
		}
	}

	for _, value := range raw {
		if ok, name := conn.getUsername(value); ok {
			members = append(members, name)
		}
	}

	groupLog.WithField("members", members).Info("Obtained group members")
	return members
}
// Read the list of group members from the LDAP server for all groups in the mapping section.
2021-02-13 18:26:37 +01:00
func readLdapGroups(cfg configuration) ldapGroupMembers {
conn := getLdapConnection(cfg.LDAP)
defer conn.close()
2021-02-13 18:26:37 +01:00
groups := make(ldapGroupMembers)
for group := range cfg.Mapping {
groups[group] = conn.getGroupMembers(group)
}
return groups
}