Minor refactoring

* Made some LDAP methods public
  * Made the LDAP connection's config field private
This commit is contained in:
Emmanuel BENOîT 2021-11-06 10:05:45 +01:00
parent f95da0e3e8
commit a2606b5b89
3 changed files with 28 additions and 23 deletions

View file

@ -201,13 +201,13 @@ func (b *tCertificateBuilder) appendPem(input string) error {
// Append the main, end-entity certificate from the LDAP // Append the main, end-entity certificate from the LDAP
func (b *tCertificateBuilder) appendCertificate() error { func (b *tCertificateBuilder) appendCertificate() error {
if b.config.Certificate != "" { if b.config.Certificate != "" {
dn := b.conn.Config.Structure.BaseDN dn := b.conn.BaseDN()
if dn != "" { if dn != "" {
dn = "," + dn dn = "," + dn
} }
dn = b.config.Certificate + dn dn = b.config.Certificate + dn
b.logger.WithField("dn", dn).Debug("Adding EE certificate from LDAP") b.logger.WithField("dn", dn).Debug("Adding EE certificate from LDAP")
data, err := b.conn.getEndEntityCertificate(dn) data, err := b.conn.GetEndEntityCertificate(dn)
if err != nil { if err != nil {
return err return err
} }
@ -230,13 +230,13 @@ func (b *tCertificateBuilder) appendCaCertificates() error {
// Append CA certificates based on a list of DNs // Append CA certificates based on a list of DNs
func (b *tCertificateBuilder) appendListedCaCerts() error { func (b *tCertificateBuilder) appendListedCaCerts() error {
bdn := b.conn.Config.Structure.BaseDN bdn := b.conn.BaseDN()
if bdn != "" { if bdn != "" {
bdn = "," + bdn bdn = "," + bdn
} }
for _, dn := range b.config.CACertificates { for _, dn := range b.config.CACertificates {
b.logger.WithField("dn", dn+bdn).Debug("Adding CA certificate from LDAP") b.logger.WithField("dn", dn+bdn).Debug("Adding CA certificate from LDAP")
data, _, err := b.conn.getCaCertificate(dn + bdn) data, _, err := b.conn.GetCaCertificate(dn + bdn)
if err != nil { if err != nil {
return err return err
} }
@ -252,11 +252,11 @@ func (b *tCertificateBuilder) appendListedCaCerts() error {
func (b *tCertificateBuilder) appendChainedCaCerts() error { func (b *tCertificateBuilder) appendChainedCaCerts() error {
nFound := 0 nFound := 0
dn := b.config.CAChainOf dn := b.config.CAChainOf
if b.conn.Config.Structure.BaseDN != "" { if b.conn.BaseDN() != "" {
dn = dn + "," + b.conn.Config.Structure.BaseDN dn = dn + "," + b.conn.BaseDN()
} }
for { for {
data, nextDn, err := b.conn.getCaCertificate(dn) data, nextDn, err := b.conn.GetCaCertificate(dn)
if err != nil { if err != nil {
return err return err
} }

33
ldap.go
View file

@ -15,7 +15,7 @@ type (
// LDAP connection encapsulation. This includes the connection itself, as well as a logger // LDAP connection encapsulation. This includes the connection itself, as well as a logger
// that includes fields related to the LDAP server and a copy of the initial configuration. // that includes fields related to the LDAP server and a copy of the initial configuration.
tLdapConn struct { tLdapConn struct {
Config tLdapConfig config tLdapConfig
conn *ldap.Conn conn *ldap.Conn
log *logrus.Entry log *logrus.Entry
server int server int
@ -27,7 +27,7 @@ type (
) )
// Try to establish a connection to one of the servers // Try to establish a connection to one of the servers
func getLdapConnection(cfg tLdapConfig) *tLdapConn { func NewLdapConnection(cfg tLdapConfig) *tLdapConn {
for i := range cfg.Servers { for i := range cfg.Servers {
conn := getLdapServerConnection(cfg, i) conn := getLdapServerConnection(cfg, i)
if conn != nil { if conn != nil {
@ -37,6 +37,17 @@ func getLdapConnection(cfg tLdapConfig) *tLdapConn {
return nil return nil
} }
// Close a LDAP connection
func (conn *tLdapConn) Close() {
conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection")
conn.conn.Close()
}
// Get the base DN
func (conn *tLdapConn) BaseDN() string {
return conn.config.Structure.BaseDN
}
// Establish a connection to a LDAP server // Establish a connection to a LDAP server
func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn { func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
if server < 0 || server >= len(cfg.Servers) { if server < 0 || server >= len(cfg.Servers) {
@ -101,7 +112,7 @@ func getLdapServerConnection(cfg tLdapConfig, server int) *tLdapConn {
} }
log.Debug("LDAP connection established") log.Debug("LDAP connection established")
return &tLdapConn{ return &tLdapConn{
Config: cfg, config: cfg,
conn: lc, conn: lc,
log: log, log: log,
server: server, server: server,
@ -142,15 +153,9 @@ func (conn *tLdapConn) getObject(dn string, attrs []string) (bool, *ldap.Entry)
return true, res.Entries[0] return true, res.Entries[0]
} }
// Close a LDAP connection
func (conn *tLdapConn) close() {
conn.log.WithField("queries", conn.counter).Debug("Closing LDAP connection")
conn.conn.Close()
}
// Get an end entity's certificate from the LDAP // Get an end entity's certificate from the LDAP
func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) { func (conn *tLdapConn) GetEndEntityCertificate(dn string) ([]byte, error) {
eec := conn.Config.Structure.EndEntityCertificate eec := conn.config.Structure.EndEntityCertificate
success, entry := conn.getObject(dn, []string{eec}) success, entry := conn.getObject(dn, []string{eec})
if !success { if !success {
return nil, fmt.Errorf("Could not read certificate from '%s'", dn) return nil, fmt.Errorf("Could not read certificate from '%s'", dn)
@ -173,9 +178,9 @@ func (conn *tLdapConn) getEndEntityCertificate(dn string) ([]byte, error) {
// Get a CA certificate, as well as the value of the chaining field, from // Get a CA certificate, as well as the value of the chaining field, from
// the LDAP. // the LDAP.
func (conn *tLdapConn) getCaCertificate(dn string) ([]byte, string, error) { func (conn *tLdapConn) GetCaCertificate(dn string) ([]byte, string, error) {
cc := conn.Config.Structure.CACertificate cc := conn.config.Structure.CACertificate
chain := conn.Config.Structure.CAChaining chain := conn.config.Structure.CAChaining
attrs := []string{cc} attrs := []string{cc}
if chain != "" { if chain != "" {
attrs = append(attrs, chain) attrs = append(attrs, chain)

View file

@ -71,11 +71,11 @@ func main() {
} }
listener.Close() listener.Close()
conn := getLdapConnection(cfg.LdapConfig) conn := NewLdapConnection(cfg.LdapConfig)
if conn == nil { if conn == nil {
return return
} }
defer conn.close() defer conn.Close()
for i := range cfg.Certificates { for i := range cfg.Certificates {
builder := NewCertificateBuilder(conn, &cfg.Certificates[i]) builder := NewCertificateBuilder(conn, &cfg.Certificates[i])
err := builder.Build() err := builder.Build()