From a2606b5b891dd608aab3145dcbe6c1c8324129d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20Beno=C3=AEt?= <tseeker@nocternity.net>
Date: Sat, 6 Nov 2021 10:05:45 +0100
Subject: [PATCH] Minor refactoring

  * Made some LDAP methods public
  * Made the LDAP connection's config field private
---
 buildcert.go | 14 +++++++-------
 ldap.go      | 33 +++++++++++++++++++--------------
 main.go      |  4 ++--
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/buildcert.go b/buildcert.go
index b24f4de..1028c86 100644
--- a/buildcert.go
+++ b/buildcert.go
@@ -201,13 +201,13 @@ func (b *tCertificateBuilder) appendPem(input string) error {
 // Append the main, end-entity certificate from the LDAP
 func (b *tCertificateBuilder) appendCertificate() error {
 	if b.config.Certificate != "" {
-		dn := b.conn.Config.Structure.BaseDN
+		dn := b.conn.BaseDN()
 		if dn != "" {
 			dn = "," + dn
 		}
 		dn = b.config.Certificate + dn
 		b.logger.WithField("dn", dn).Debug("Adding EE certificate from LDAP")
-		data, err := b.conn.getEndEntityCertificate(dn)
+		data, err := b.conn.GetEndEntityCertificate(dn)
 		if err != nil {
 			return err
 		}
@@ -230,13 +230,13 @@ func (b *tCertificateBuilder) appendCaCertificates() error {
 
 // Append CA certificates based on a list of DNs
 func (b *tCertificateBuilder) appendListedCaCerts() error {
-	bdn := b.conn.Config.Structure.BaseDN
+	bdn := b.conn.BaseDN()
 	if bdn != "" {
 		bdn = "," + bdn
 	}
 	for _, dn := range b.config.CACertificates {
 		b.logger.WithField("dn", dn+bdn).Debug("Adding CA certificate from LDAP")
-		data, _, err := b.conn.getCaCertificate(dn + bdn)
+		data, _, err := b.conn.GetCaCertificate(dn + bdn)
 		if err != nil {
 			return err
 		}
@@ -252,11 +252,11 @@ func (b *tCertificateBuilder) appendListedCaCerts() error {
 func (b *tCertificateBuilder) appendChainedCaCerts() error {
 	nFound := 0
 	dn := b.config.CAChainOf
-	if b.conn.Config.Structure.BaseDN != "" {
-		dn = dn + "," + b.conn.Config.Structure.BaseDN
+	if b.conn.BaseDN() != "" {
+		dn = dn + "," + b.conn.BaseDN()
 	}
 	for {
-		data, nextDn, err := b.conn.getCaCertificate(dn)
+		data, nextDn, err := b.conn.GetCaCertificate(dn)
 		if err != nil {
 			return err
 		}
diff --git a/ldap.go b/ldap.go
index cac1983..da74b60 100644
--- a/ldap.go
+++ b/ldap.go
@@ -15,7 +15,7 @@ type (
 	// LDAP connection encapsulation. This includes the connection itself, as well as a logger
 	// that includes fields related to the LDAP server and a copy of the initial configuration.
 	tLdapConn struct {
-		Config  tLdapConfig
+		config  tLdapConfig
 		conn    *ldap.Conn
 		log     *logrus.Entry
 		server  int
@@ -27,7 +27,7 @@ type (
 )
 
 // Try to establish a connection to one of the servers
-func getLdapConnection(cfg tLdapConfig) *tLdapConn {
+func NewLdapConnection(cfg tLdapConfig) *tLdapConn {
 	for i := range cfg.Servers {
 		conn := getLdapServerConnection(cfg, i)
 		if conn != nil {
@@ -37,6 +37,17 @@ func getLdapConnection(cfg tLdapConfig) *tLdapConn {
 	return nil
 }
 
+// Close a LDAP connection
+func (conn *tLdapConn) Close() {
+	conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection")
+	conn.conn.Close()
+}
+
+// Get the base DN
+func (conn *tLdapConn) BaseDN() string {
+	return conn.config.Structure.BaseDN
+}
+
 // Establish a connection to a LDAP server
 func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
 	if server < 0 || server >= len(cfg.Servers) {
@@ -101,7 +112,7 @@ func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
 	}
 	log.Debug("LDAP connection established")
 	return &tLdapConn{
-		Config: cfg,
+		config: cfg,
 		conn:   lc,
 		log:    log,
 		server: server,
@@ -142,15 +153,9 @@ func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry)
 	return true, res.Entries[0]
 }
 
-// Close a LDAP connection
-func (conn *tLdapConn) close() {
-	conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection")
-	conn.conn.Close()
-}
-
 // Get an end entity's certificate from the LDAP
-func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) {
-	eec := conn.Config.Structure.EndEntityCertificate
+func (conn *tLdapConn) GetEndEntityCertificate(dn string) ([]byte, error) {
+	eec := conn.config.Structure.EndEntityCertificate
 	success, entry := conn.getObject(dn, []string{eec})
 	if !success {
 		return nil, fmt.Errorf("Could not read certificate from '%s'", dn)
@@ -173,9 +178,9 @@ func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) {
 
 // Get a CA certificate, as well as the value of the chaining field, from
 // the LDAP.
-func (conn *tLdapConn) getCaCertificate(dn string) ([]byte, string, error) {
-	cc := conn.Config.Structure.CACertificate
-	chain := conn.Config.Structure.CAChaining
+func (conn *tLdapConn) GetCaCertificate(dn string) ([]byte, string, error) {
+	cc := conn.config.Structure.CACertificate
+	chain := conn.config.Structure.CAChaining
 	attrs := []string{cc}
 	if chain != "" {
 		attrs = append(attrs, chain)
diff --git a/main.go b/main.go
index 45f3026..4971a5f 100644
--- a/main.go
+++ b/main.go
@@ -71,11 +71,11 @@ func main() {
 	}
 	listener.Close()
 
-	conn := getLdapConnection(cfg.LdapConfig)
+	conn := NewLdapConnection(cfg.LdapConfig)
 	if conn == nil {
 		return
 	}
-	defer conn.close()
+	defer conn.Close()
 	for i := range cfg.Certificates {
 		builder := NewCertificateBuilder(conn, &cfg.Certificates[i])
 		err := builder.Build()