From f95da0e3e88b6a0ee9ab913665b27faa45639f22 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20Beno=C3=AEt?= <tseeker@nocternity.net>
Date: Fri, 5 Nov 2021 17:16:44 +0100
Subject: [PATCH] Write certificate file and set privileges

---
 buildcert.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++--
 main.go      |  16 +++++-
 2 files changed, 151 insertions(+), 5 deletions(-)

diff --git a/buildcert.go b/buildcert.go
index 07996bf..b24f4de 100644
--- a/buildcert.go
+++ b/buildcert.go
@@ -4,18 +4,35 @@ import (
 	"encoding/pem"
 	"fmt"
 	"io/ioutil"
+	"os"
+	"os/user"
+	"strconv"
+	"syscall"
+
+	"github.com/sirupsen/logrus"
 )
 
 // Max supported CA chain length
 const MAX_CA_CHAIN_LENGTH = 8
 
 type (
+	// Structure that describes the existing file for a certificate.
+	tExistingFileInfo struct {
+		owner uint32
+		group uint32
+		mode  os.FileMode
+	}
+
 	// Certificate building, including the configuration, LDAP connection,
 	// and the array of chunks that's being built.
 	tCertificateBuilder struct {
-		config *tCertificateFileConfig
-		conn   *tLdapConn
-		data   [][]byte
+		config   *tCertificateFileConfig
+		conn     *tLdapConn
+		logger   *logrus.Entry
+		data     [][]byte
+		text     []byte
+		existing *tExistingFileInfo
+		changed  bool
 	}
 )
 
@@ -25,6 +42,7 @@ func NewCertificateBuilder(conn *tLdapConn, config *tCertificateFileConfig) tCer
 	return tCertificateBuilder{
 		config: config,
 		conn:   conn,
+		logger: log.WithField("file", config.Path),
 		data:   make([][]byte, 0),
 	}
 }
@@ -32,6 +50,7 @@ func NewCertificateBuilder(conn *tLdapConn, config *tCertificateFileConfig) tCer
 // Build the certificate file's data, returning any error that occurs while
 // reading the source data.
 func (b *tCertificateBuilder) Build() error {
+	b.logger.Debug("Checking for updates")
 	err := b.appendPemFiles(b.config.PrependFiles)
 	if err != nil {
 		return err
@@ -51,13 +70,102 @@ func (b *tCertificateBuilder) Build() error {
 	if b.config.Reverse {
 		b.reverseChunks()
 	}
+	b.generateText()
 	return nil
 }
 
+// Check whether the data should be written to disk.
+func (b *tCertificateBuilder) MustWrite() bool {
+	info, err := os.Lstat(b.config.Path)
+	if err != nil {
+		return true
+	}
+
+	sys_stat := info.Sys().(*syscall.Stat_t)
+	eif := &tExistingFileInfo{}
+	eif.mode = info.Mode()
+	eif.owner = sys_stat.Uid
+	eif.group = sys_stat.Gid
+	b.existing = eif
+
+	if sys_stat.Size != int64(len(b.text)) {
+		return true
+	}
+	existing, err := ioutil.ReadFile(b.config.Path)
+	if err != nil {
+		return true
+	}
+	for i, ch := range b.text {
+		if ch != existing[i] {
+			return true
+		}
+	}
+	return false
+}
+
+// Write the file's data
+func (b *tCertificateBuilder) WriteFile() error {
+	log.WithField("file", b.config.Path).Info("Writing certificate data to file")
+	err := ioutil.WriteFile(b.config.Path, b.text, b.config.Mode)
+	if err == nil {
+		b.changed = true
+	}
+	return err
+}
+
+// Update the file's owner and group
+func (b *tCertificateBuilder) UpdatePrivileges() error {
+	update_mode := !b.changed && b.existing.mode != b.config.Mode
+	if update_mode {
+		err := os.Chmod(b.config.Path, b.config.Mode)
+		if err != nil {
+			return err
+		}
+	}
+
+	log := b.logger
+	set_uid, set_gid := -1, -1
+	if b.config.Owner != "" {
+		usr, err := user.Lookup(b.config.Owner)
+		if err != nil {
+			return err
+		}
+		uid, err := strconv.Atoi(usr.Uid)
+		if b.changed || b.existing == nil || b.existing.owner != uint32(uid) {
+			set_uid = uid
+			log = log.WithField("uid", set_uid)
+		}
+	}
+	if b.config.Group != "" {
+		group, err := user.LookupGroup(b.config.Group)
+		if err != nil {
+			return err
+		}
+		gid, err := strconv.Atoi(group.Gid)
+		if b.changed || b.existing == nil || b.existing.group != uint32(gid) {
+			set_gid = gid
+			log = log.WithField("gid", set_gid)
+		}
+	}
+	if set_gid != -1 || set_uid != -1 {
+		log.Info("Updating file owner/group")
+		err := os.Chown(b.config.Path, set_uid, set_gid)
+		if err == nil {
+			b.changed = true
+		}
+		return err
+	} else {
+		b.changed = b.changed || update_mode
+		log.Debug("No update to privileges")
+		return nil
+	}
+}
+
 // Append PEM files from a list.
 func (b *tCertificateBuilder) appendPemFiles(files []string) error {
 	for _, path := range files {
 		var err error
+		b.logger.WithField("source", path).Debug("Adding PEM file")
 		err = b.appendPem(path)
 		if err != nil {
 			return err
@@ -98,6 +206,7 @@ func (b *tCertificateBuilder) appendCertificate() error {
 			dn = "," + dn
 		}
 		dn = b.config.Certificate + dn
+		b.logger.WithField("dn", dn).Debug("Adding EE certificate from LDAP")
 		data, err := b.conn.getEndEntityCertificate(dn)
 		if err != nil {
 			return err
@@ -126,6 +235,7 @@ func (b *tCertificateBuilder) appendListedCaCerts() error {
 		bdn = "," + bdn
 	}
 	for _, dn := range b.config.CACertificates {
+		b.logger.WithField("dn", dn+bdn).Debug("Adding CA certificate from LDAP")
 		data, _, err := b.conn.getCaCertificate(dn + bdn)
 		if err != nil {
 			return err
@@ -154,6 +264,7 @@ func (b *tCertificateBuilder) appendChainedCaCerts() error {
 			if data == nil {
 				return fmt.Errorf("No CA certificate at DN '%s'", dn)
 			}
+			b.logger.WithField("dn", dn).Debug("Adding CA certificate from LDAP chain")
 			b.data = append(b.data, data)
 		}
 		if nextDn == "" {
@@ -169,9 +280,32 @@ func (b *tCertificateBuilder) appendChainedCaCerts() error {
 
 // Reverse the chunks in the list
 func (b *tCertificateBuilder) reverseChunks() {
+	b.logger.Debug("Reversing PEM list")
 	l := len(b.data) / 2
 	for i := 0; i < l/2; i++ {
 		j := l - i - 1
 		b.data[i], b.data[j] = b.data[j], b.data[i]
 	}
 }
+
+// Generate the final text of the file
+func (b *tCertificateBuilder) generateText() {
+	size := int64(0)
+	for i := range b.data {
+		size += int64(len(b.data[i]))
+		if i != 0 && b.data[i-1][len(b.data[i-1])-1] != '\n' {
+			size++
+		}
+	}
+	b.text = make([]byte, size)
+	pos := 0
+	for i := range b.data {
+		copied := copy(b.text[pos:], b.data[i])
+		pos += copied
+		if i != 0 && b.data[i-1][len(b.data[i-1])-1] != '\n' {
+			b.text[pos] = '\n'
+			pos++
+		}
+	}
+	b.logger.WithField("size", size).Debug("Data generated")
+}
diff --git a/main.go b/main.go
index c8b6b25..45f3026 100644
--- a/main.go
+++ b/main.go
@@ -80,9 +80,21 @@ func main() {
 		builder := NewCertificateBuilder(conn, &cfg.Certificates[i])
 		err := builder.Build()
 		if err != nil {
-			log.WithField("error", err).Error("Failed to build data for certificate '", cfg.Certificates[i].Path)
+			log.WithField("error", err).Error("Failed to build data for certificate '", cfg.Certificates[i].Path, "'")
 			continue
 		}
-		// FIXME: check existing file, try to write
+		if builder.MustWrite() {
+			err := builder.WriteFile()
+			if err != nil {
+				log.WithField("error", err).Error("Failed to write '", cfg.Certificates[i].Path, "'")
+				continue
+			}
+		}
+		err = builder.UpdatePrivileges()
+		if err != nil {
+			log.WithField("error", err).Error("Failed to update privileges on '", cfg.Certificates[i].Path, "'")
+			continue
+		}
+		// TODO builder.RunCommandsIfChanged()
 	}
 }