diff --git a/buildcert.go b/buildcert.go index b6e182f..dc2362d 100644 --- a/buildcert.go +++ b/buildcert.go @@ -32,11 +32,9 @@ type ( // and the array of chunks that's being built. tCertificateBuilder struct { // The certificate file's current configuration - config *tCertificateFileConfig + Config *tCertificateFileConfig // The LDAP connection to read data from conn *tLdapConn - // The command that caused the update - command TCommand // The logger to use logger *logrus.Entry // The list of DNs that are involved in generating this certificate. If the @@ -57,13 +55,12 @@ type ( // Initialize a certificate file building using a LDAP connection and // certificate file configuration. -func NewCertificateBuilder(conn *tLdapConn, config *tCertificateFileConfig, cmd *TCommand) tCertificateBuilder { - return tCertificateBuilder{ - config: config, - conn: conn, - command: *cmd, - logger: log.WithField("file", config.Path), - data: make([][]byte, 0), +func NewCertificateBuilder(conn *tLdapConn, config *tCertificateFileConfig) *tCertificateBuilder { + return &tCertificateBuilder{ + Config: config, + conn: conn, + logger: log.WithField("file", config.Path), + data: make([][]byte, 0), } } @@ -71,7 +68,7 @@ func NewCertificateBuilder(conn *tLdapConn, config *tCertificateFileConfig, cmd // reading the source data. func (b *tCertificateBuilder) Build() error { b.logger.Debug("Checking for updates") - err := b.appendPemFiles(b.config.PrependFiles) + err := b.appendPemFiles(b.Config.PrependFiles) if err != nil { return err } @@ -83,38 +80,40 @@ func (b *tCertificateBuilder) Build() error { if err != nil { return err } - err = b.appendPemFiles(b.config.AppendFiles) + err = b.appendPemFiles(b.Config.AppendFiles) if err != nil { return err } - if b.config.Reverse { + if b.Config.Reverse { b.reverseChunks() } b.generateText() return nil } -// Check whether the command's selector matches one of the current certificate -// file's DNs. -func (b *tCertificateBuilder) SelectorMatches() bool { - if b.command.Selector == "*" { +// Check whether a selector matches one of the current certificate file's DNs. +func (b *tCertificateBuilder) SelectorMatches(selector string) bool { + if selector == "*" { return true } - sel := strings.ToLower(b.command.Selector) + sel := strings.ToLower(selector) for _, v := range b.dnList { if strings.ToLower(v) == sel { return true } } - b.logger.WithField("selector", b.command.Selector).Debug("Selector does not match.") + b.logger.WithField("selector", selector).Debug("Selector does not match.") return false } // Check whether the data should be written to disk. This also caches the // file's owner, group and mode. If the update is being forced it will return // `true` even if nothing changed. -func (b *tCertificateBuilder) MustWrite() bool { - info, err := os.Lstat(b.config.Path) +// +// Note: file information will be read even when updates are forced, because +// it is used later to set file privileges. +func (b *tCertificateBuilder) MustWrite(force bool) bool { + info, err := os.Lstat(b.Config.Path) if err != nil { return true } @@ -126,10 +125,10 @@ func (b *tCertificateBuilder) MustWrite() bool { eif.group = sys_stat.Gid b.existing = eif - if b.command.Force || sys_stat.Size != int64(len(b.text)) { + if force || sys_stat.Size != int64(len(b.text)) { return true } - existing, err := ioutil.ReadFile(b.config.Path) + existing, err := ioutil.ReadFile(b.Config.Path) if err != nil { return true } @@ -143,8 +142,8 @@ func (b *tCertificateBuilder) MustWrite() bool { // Write the file's data func (b *tCertificateBuilder) WriteFile() error { - log.WithField("file", b.config.Path).Info("Writing certificate data to file") - err := ioutil.WriteFile(b.config.Path, b.text, b.config.Mode) + log.WithField("file", b.Config.Path).Info("Writing certificate data to file") + err := ioutil.WriteFile(b.Config.Path, b.text, b.Config.Mode) if err == nil { b.changed = true } @@ -153,9 +152,9 @@ func (b *tCertificateBuilder) WriteFile() error { // Update the file's owner and group func (b *tCertificateBuilder) UpdatePrivileges() error { - update_mode := !b.changed && b.existing.mode != b.config.Mode + update_mode := !b.changed && b.existing.mode != b.Config.Mode if update_mode { - err := os.Chmod(b.config.Path, b.config.Mode) + err := os.Chmod(b.Config.Path, b.Config.Mode) if err != nil { return err } @@ -163,8 +162,8 @@ func (b *tCertificateBuilder) UpdatePrivileges() error { log := b.logger set_uid, set_gid := -1, -1 - if b.config.Owner != "" { - usr, err := user.Lookup(b.config.Owner) + if b.Config.Owner != "" { + usr, err := user.Lookup(b.Config.Owner) if err != nil { return err } @@ -174,8 +173,8 @@ func (b *tCertificateBuilder) UpdatePrivileges() error { log = log.WithField("uid", set_uid) } } - if b.config.Group != "" { - group, err := user.LookupGroup(b.config.Group) + if b.Config.Group != "" { + group, err := user.LookupGroup(b.Config.Group) if err != nil { return err } @@ -187,7 +186,7 @@ func (b *tCertificateBuilder) UpdatePrivileges() error { } if set_gid != -1 || set_uid != -1 { log.Info("Updating file owner/group") - err := os.Chown(b.config.Path, set_uid, set_gid) + err := os.Chown(b.Config.Path, set_uid, set_gid) if err == nil { b.changed = true } @@ -199,19 +198,21 @@ func (b *tCertificateBuilder) UpdatePrivileges() error { } } -// Run the necessary commands if the certificate file has been modified in +// Did this certificate change in any way? +func (b *tCertificateBuilder) Changed() bool { + return b.changed +} + +// Run the commands from the pre_commands section of the configuration. +// This is only called if the certificate file has been modified in // any way. Execution will stop at the first failure. -func (b *tCertificateBuilder) RunCommandsIfChanged() error { - if !b.changed { - log.Debug("Not running commands") - return nil - } - for i := range b.config.AfterUpdate.PreCommands { +func (b *tCertificateBuilder) RunPreCommands() error { + for i := range b.Config.AfterUpdate.PreCommands { err := b.RunCommand(i) if err != nil { return fmt.Errorf( "Failed while executing command '%s': %w", - b.config.AfterUpdate.PreCommands[i], + b.Config.AfterUpdate.PreCommands[i], err, ) } @@ -224,9 +225,9 @@ func (b *tCertificateBuilder) RunCommand(pos int) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - log := b.logger.WithField("command", b.config.AfterUpdate.PreCommands[pos]) + log := b.logger.WithField("command", b.Config.AfterUpdate.PreCommands[pos]) log.Debug("Executing command") - cmd := exec.CommandContext(ctx, "sh", "-c", b.config.AfterUpdate.PreCommands[pos]) + cmd := exec.CommandContext(ctx, "sh", "-c", b.Config.AfterUpdate.PreCommands[pos]) output, err := cmd.CombinedOutput() if len(output) != 0 { if utf8.Valid(output) { @@ -282,12 +283,12 @@ func (b *tCertificateBuilder) appendPem(input string) error { // Append the main, end-entity certificate from the LDAP func (b *tCertificateBuilder) appendCertificate() error { - if b.config.Certificate != "" { + if b.Config.Certificate != "" { dn := b.conn.BaseDN() if dn != "" { dn = "," + dn } - dn = b.config.Certificate + dn + dn = b.Config.Certificate + dn b.dnList = append(b.dnList, strings.ToLower(dn)) b.logger.WithField("dn", dn).Debug("Adding EE certificate from LDAP") data, err := b.conn.GetEndEntityCertificate(dn) @@ -302,9 +303,9 @@ func (b *tCertificateBuilder) appendCertificate() error { // Append all CA certificates, reading the list from the LDAP or from the // configuration. func (b *tCertificateBuilder) appendCaCertificates() error { - if len(b.config.CACertificates) != 0 { + if len(b.Config.CACertificates) != 0 { return b.appendListedCaCerts() - } else if b.config.CAChainOf != "" { + } else if b.Config.CAChainOf != "" { return b.appendChainedCaCerts() } else { return nil @@ -317,7 +318,7 @@ func (b *tCertificateBuilder) appendListedCaCerts() error { if bdn != "" { bdn = "," + bdn } - for _, dn := range b.config.CACertificates { + for _, dn := range b.Config.CACertificates { full_dn := dn + bdn b.dnList = append(b.dnList, strings.ToLower(full_dn)) b.logger.WithField("dn", full_dn).Debug("Adding CA certificate from LDAP") @@ -336,7 +337,7 @@ func (b *tCertificateBuilder) appendListedCaCerts() error { // Append CA certificates by following a chain starting at some DN func (b *tCertificateBuilder) appendChainedCaCerts() error { nFound := 0 - dn := b.config.CAChainOf + dn := b.Config.CAChainOf if b.conn.BaseDN() != "" { dn = dn + "," + b.conn.BaseDN() } diff --git a/socket.go b/socket.go index 2042c92..5a497b9 100644 --- a/socket.go +++ b/socket.go @@ -100,7 +100,7 @@ func executeFromSocket(cfg *tConfiguration, conn net.Conn) TCommandType { "force": command.Force, "selector": command.Selector, }).Info("Update request received") - success := executeUpdate(cfg, command) + success := executeUpdate(cfg, command.Selector, command.Force) conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) var bval byte if success { @@ -142,45 +142,3 @@ func parseCommand(n int, buf []byte) *TCommand { log.Warn("Invalid command received") return nil } - -func executeUpdate(cfg *tConfiguration, cmd *TCommand) bool { - conn := NewLdapConnection(cfg.LdapConfig) - if conn == nil { - return false - } - defer conn.Close() - - had_errors := false - for i := range cfg.Certificates { - builder := NewCertificateBuilder(conn, &cfg.Certificates[i], cmd) - err := builder.Build() - if err != nil { - log.WithField("error", err).Error("Failed to build data for certificate '", cfg.Certificates[i].Path, "'") - had_errors = true - continue - } - if !builder.SelectorMatches() { - continue - } - if builder.MustWrite() { - err := builder.WriteFile() - if err != nil { - log.WithField("error", err).Error("Failed to write '", cfg.Certificates[i].Path, "'") - had_errors = true - continue - } - } - err = builder.UpdatePrivileges() - if err != nil { - log.WithField("error", err).Error("Failed to update privileges on '", cfg.Certificates[i].Path, "'") - had_errors = true - continue - } - err = builder.RunCommandsIfChanged() - if err != nil { - log.WithField("error", err).Error("Failed to run commands after update of '", cfg.Certificates[i].Path, "'") - had_errors = true - } - } - return !had_errors -} diff --git a/update.go b/update.go new file mode 100644 index 0000000..74a8412 --- /dev/null +++ b/update.go @@ -0,0 +1,259 @@ +package main + +import ( + "context" + "fmt" + "os/exec" + "time" + "unicode/utf8" + + "github.com/sirupsen/logrus" +) + +type ( + tUpdate struct { + // The current configuration + config *tConfiguration + // The selector for this update + selector string + // Whether the update must be forced. + force bool + // Certificate builders for each configured certificate file + builders []*tCertificateBuilder + // Whether errors occurred during the update. + errors bool + } +) + +// Start a new update, based on the specified configuration. The update's +// parameters (selector and force flag) will be stored as well. +func NewUpdate(cfg *tConfiguration, selector string, force bool) tUpdate { + return tUpdate{ + config: cfg, + selector: selector, + force: force, + builders: make([]*tCertificateBuilder, len(cfg.Certificates)), + } +} + +// Execute the update. Builders will be initialized and filtered based on the +// selector, then used to write the certificates to files. After that, commands +// and handlers will be executed. +func (u *tUpdate) Execute() bool { + u.initBuilders() + u.writeFiles() + u.runPreCommands() + handlers := u.enumerateHandlers() + failedHandlers := u.runHandlers(handlers) + u.disableBuildersWithFailedHandlers(failedHandlers) + u.runPostCommands() + return !u.errors +} + +// Initialise builders for all certificates that need to be updated. If errors +// occur while preparing one of the certificates, or if it doesn't match the +// selector, the builder will not be kept. +func (u *tUpdate) initBuilders() { + ldap := NewLdapConnection(u.config.LdapConfig) + if ldap == nil { + return + } + defer ldap.Close() + for i := range u.config.Certificates { + builder := NewCertificateBuilder(ldap, &u.config.Certificates[i]) + err := builder.Build() + if err != nil { + log.WithField("error", err).Error( + "Failed to build data for certificate '", + builder.Config.Path, "'", + ) + u.errors = true + } else if builder.SelectorMatches(u.selector) { + u.builders[i] = builder + } + } +} + +// Write certificates to disk and set file ownership/privileges for all builders +// that were initalised. +func (u *tUpdate) writeFiles() { + for i, builder := range u.builders { + if builder == nil { + continue + } + + if builder.MustWrite(u.force) { + err := builder.WriteFile() + if err != nil { + log.WithField("error", err).Error( + "Failed to write '", + builder.Config.Path, "'", + ) + u.errors = true + continue + } + } + err := builder.UpdatePrivileges() + if err != nil { + log.WithField("error", err).Error( + "Failed to update privileges on '", + builder.Config.Path, "'", + ) + u.errors = true + continue + } + if !builder.Changed() { + u.builders[i] = nil + } + } +} + +// Run pre-commands for all builders. +func (u *tUpdate) runPreCommands() { + for i, builder := range u.builders { + if builder == nil { + continue + } + + commands := u.config.Certificates[i].AfterUpdate.PreCommands + if len(commands) == 0 { + continue + } + + l := log.WithField("file", u.config.Certificates[i].Path) + l.Info("Running pre-commands") + err := u.runCommands(commands, l) + if err == nil { + continue + } + + l.WithField("error", err).Error("Failed to run pre-commands") + u.builders[i] = nil + u.errors = true + } +} + +// Returns a list of all handlers that must be executed based on the builders +// still listed as active. +func (u *tUpdate) enumerateHandlers() []string { + handlers := make(map[string]bool) + for i, builder := range u.builders { + if builder == nil { + continue + } + for _, handler := range u.config.Certificates[i].AfterUpdate.Handlers { + handlers[handler] = true + } + } + hdl_list := []string{} + for handler := range handlers { + hdl_list = append(hdl_list, handler) + } + return hdl_list +} + +// Execute commands for all listed handlers, returning a map of handlers that +// failed to execute. +func (u *tUpdate) runHandlers(handlers []string) map[string]bool { + failures := make(map[string]bool) + for _, handler := range handlers { + l := log.WithField("handler", handler) + l.Info("Running handler") + err := u.runCommands(u.config.Handlers[handler], l) + if err == nil { + continue + } + l.WithField("error", err).Error("Failed to run handler commands") + failures[handler] = true + u.errors = true + } + return failures +} + +// Disable builders that have one of the failed handlers in their list of +// handlers. +func (u *tUpdate) disableBuildersWithFailedHandlers(failedHandlers map[string]bool) { + for i, builder := range u.builders { + if builder == nil { + continue + } + for _, handler := range u.config.Certificates[i].AfterUpdate.Handlers { + if _, exists := failedHandlers[handler]; exists { + log.WithFields(logrus.Fields{ + "handler": handler, + "file": u.config.Certificates[i].Path, + }).Debug("Disabling builder due to failed handler") + u.builders[i] = nil + break + } + } + } +} + +// Run post-commands for all builders. +func (u *tUpdate) runPostCommands() { + for i, builder := range u.builders { + if builder == nil { + continue + } + + commands := u.config.Certificates[i].AfterUpdate.PostCommands + if len(commands) == 0 { + continue + } + + l := log.WithField("file", u.config.Certificates[i].Path) + l.Info("Running post-commands") + err := u.runCommands(commands, l) + if err == nil { + continue + } + + l.WithField("error", err).Error("Failed to run post-commands") + u.builders[i] = nil + u.errors = true + } +} + +// Run a list of commands +func (u *tUpdate) runCommands(commands []string, log *logrus.Entry) error { + for i := range commands { + err := u.runCommand(commands[i], log) + if err != nil { + return fmt.Errorf( + "Failed while executing command '%s': %w", + commands[i], err, + ) + } + } + return nil +} + +// Run a command through the `sh` shell. +func (b *tUpdate) runCommand(command string, log *logrus.Entry) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + log = log.WithField("command", command) + log.Debug("Executing command") + cmd := exec.CommandContext(ctx, "sh", "-c", command) + output, err := cmd.CombinedOutput() + if len(output) != 0 { + if utf8.Valid(output) { + log = log.WithField("output", string(output)) + } else { + log = log.WithField("output", string(output)) + } + } + if err == nil { + log.Info("Command executed") + } else { + log.WithField("error", err).Error("Command failed") + } + return err +} + +func executeUpdate(cfg *tConfiguration, selector string, force bool) bool { + ex := NewUpdate(cfg, selector, force) + return ex.Execute() +}