From a2606b5b891dd608aab3145dcbe6c1c8324129d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emmanuel=20Beno=C3=AEt?= Date: Sat, 6 Nov 2021 10:05:45 +0100 Subject: [PATCH] Minor refactoring * Made some LDAP methods public * Made the LDAP connection's config field private --- buildcert.go | 14 +++++++------- ldap.go | 33 +++++++++++++++++++-------------- main.go | 4 ++-- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/buildcert.go b/buildcert.go index b24f4de..1028c86 100644 --- a/buildcert.go +++ b/buildcert.go @@ -201,13 +201,13 @@ func (b *tCertificateBuilder) appendPem(input string) error { // Append the main, end-entity certificate from the LDAP func (b *tCertificateBuilder) appendCertificate() error { if b.config.Certificate != "" { - dn := b.conn.Config.Structure.BaseDN + dn := b.conn.BaseDN() if dn != "" { dn = "," + dn } dn = b.config.Certificate + dn b.logger.WithField("dn", dn).Debug("Adding EE certificate from LDAP") - data, err := b.conn.getEndEntityCertificate(dn) + data, err := b.conn.GetEndEntityCertificate(dn) if err != nil { return err } @@ -230,13 +230,13 @@ func (b *tCertificateBuilder) appendCaCertificates() error { // Append CA certificates based on a list of DNs func (b *tCertificateBuilder) appendListedCaCerts() error { - bdn := b.conn.Config.Structure.BaseDN + bdn := b.conn.BaseDN() if bdn != "" { bdn = "," + bdn } for _, dn := range b.config.CACertificates { b.logger.WithField("dn", dn+bdn).Debug("Adding CA certificate from LDAP") - data, _, err := b.conn.getCaCertificate(dn + bdn) + data, _, err := b.conn.GetCaCertificate(dn + bdn) if err != nil { return err } @@ -252,11 +252,11 @@ func (b *tCertificateBuilder) appendListedCaCerts() error { func (b *tCertificateBuilder) appendChainedCaCerts() error { nFound := 0 dn := b.config.CAChainOf - if b.conn.Config.Structure.BaseDN != "" { - dn = dn + "," + b.conn.Config.Structure.BaseDN + if b.conn.BaseDN() != "" { + dn = dn + "," + b.conn.BaseDN() } for { - data, nextDn, err := b.conn.getCaCertificate(dn) + data, nextDn, err := b.conn.GetCaCertificate(dn) if err != nil { return err } diff --git a/ldap.go b/ldap.go index cac1983..da74b60 100644 --- a/ldap.go +++ b/ldap.go @@ -15,7 +15,7 @@ type ( // LDAP connection encapsulation. This includes the connection itself, as well as a logger // that includes fields related to the LDAP server and a copy of the initial configuration. tLdapConn struct { - Config tLdapConfig + config tLdapConfig conn *ldap.Conn log *logrus.Entry server int @@ -27,7 +27,7 @@ type ( ) // Try to establish a connection to one of the servers -func getLdapConnection(cfg tLdapConfig) *tLdapConn { +func NewLdapConnection(cfg tLdapConfig) *tLdapConn { for i := range cfg.Servers { conn := getLdapServerConnection(cfg, i) if conn != nil { @@ -37,6 +37,17 @@ func getLdapConnection(cfg tLdapConfig) *tLdapConn { return nil } +// Close a LDAP connection +func (conn *tLdapConn) Close() { + conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection") + conn.conn.Close() +} + +// Get the base DN +func (conn *tLdapConn) BaseDN() string { + return conn.config.Structure.BaseDN +} + // Establish a connection to a LDAP server func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn { if server < 0 || server >= len(cfg.Servers) { @@ -101,7 +112,7 @@ func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn { } log.Debug("LDAP connection established") return &tLdapConn{ - Config: cfg, + config: cfg, conn: lc, log: log, server: server, @@ -142,15 +153,9 @@ func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry) return true, res.Entries[0] } -// Close a LDAP connection -func (conn *tLdapConn) close() { - conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection") - conn.conn.Close() -} - // Get an end entity's certificate from the LDAP -func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) { - eec := conn.Config.Structure.EndEntityCertificate +func (conn *tLdapConn) GetEndEntityCertificate(dn string) ([]byte, error) { + eec := conn.config.Structure.EndEntityCertificate success, entry := conn.getObject(dn, []string{eec}) if !success { return nil, fmt.Errorf("Could not read certificate from '%s'", dn) @@ -173,9 +178,9 @@ func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) { // Get a CA certificate, as well as the value of the chaining field, from // the LDAP. -func (conn *tLdapConn) getCaCertificate(dn string) ([]byte, string, error) { - cc := conn.Config.Structure.CACertificate - chain := conn.Config.Structure.CAChaining +func (conn *tLdapConn) GetCaCertificate(dn string) ([]byte, string, error) { + cc := conn.config.Structure.CACertificate + chain := conn.config.Structure.CAChaining attrs := []string{cc} if chain != "" { attrs = append(attrs, chain) diff --git a/main.go b/main.go index 45f3026..4971a5f 100644 --- a/main.go +++ b/main.go @@ -71,11 +71,11 @@ func main() { } listener.Close() - conn := getLdapConnection(cfg.LdapConfig) + conn := NewLdapConnection(cfg.LdapConfig) if conn == nil { return } - defer conn.close() + defer conn.Close() for i := range cfg.Certificates { builder := NewCertificateBuilder(conn, &cfg.Certificates[i]) err := builder.Build()