diff --git a/cmd/check_ssl_certificate/main.go b/cmd/check_ssl_certificate/main.go index f1c5426..8eca0eb 100644 --- a/cmd/check_ssl_certificate/main.go +++ b/cmd/check_ssl_certificate/main.go @@ -4,6 +4,8 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "net" + "net/textproto" "os" "strings" "time" @@ -14,6 +16,72 @@ import ( "github.com/karrick/golf" ) +//-------------------------------------------------------------------------------------------------------- + +// Interface that can be implemented to fetch TLS certificates. +type certGetter interface { + getCertificate(tlsConfig *tls.Config, address string) (*x509.Certificate, error) +} + +// Full TLS certificate fetcher +type fullTLSGetter struct{} + +func (f fullTLSGetter) getCertificate(tlsConfig *tls.Config, address string) (*x509.Certificate, error) { + conn, err := tls.Dial("tcp", address, tlsConfig) + if err != nil { + return nil, err + } + defer conn.Close() + if err := conn.Handshake(); err != nil { + return nil, err + } + return conn.ConnectionState().PeerCertificates[0], nil +} + +// SMTP STARTTLS getter +type smtpGetter struct{} + +func (f smtpGetter) cmd(tcon *textproto.Conn, expectCode int, text string) (int, string, error) { + id, err := tcon.Cmd("%s", text) + if err != nil { + return 0, "", err + } + tcon.StartResponse(id) + defer tcon.EndResponse(id) + return tcon.ReadResponse(expectCode) +} + +func (f smtpGetter) getCertificate(tlsConfig *tls.Config, address string) (*x509.Certificate, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + return nil, err + } + text := textproto.NewConn(conn) + defer text.Close() + if _, _, err := text.ReadResponse(220); err != nil { + return nil, err + } + if _, _, err := f.cmd(text, 250, "HELO localhost"); err != nil { + return nil, err + } + if _, _, err := f.cmd(text, 220, "STARTTLS"); err != nil { + return nil, err + } + t := tls.Client(conn, tlsConfig) + if err := t.Handshake(); err != nil { + return nil, err + } + return t.ConnectionState().PeerCertificates[0], nil +} + +// Supported StartTLS protocols +var certGetters map[string]certGetter = map[string]certGetter{ + "": fullTLSGetter{}, + "smtp": &smtpGetter{}, +} + +//-------------------------------------------------------------------------------------------------------- + // Command line flags that have been parsed. type programFlags struct { hostname string // Main host name to connect to @@ -22,12 +90,14 @@ type programFlags struct { crit int // Threshold for critical state (days) ignoreCnOnly bool // Do not warn about SAN-less certificates extraNames []string // Extra names the certificate should include + startTLS string // Protocol to use before requesting a switch to TLS. } // Program data including configuration and runtime data. type checkProgram struct { programFlags // Flags from the command line plugin *plugin.Plugin // Plugin output state + getter certGetter // Certificate getter certificate *x509.Certificate // X.509 certificate from the server } @@ -49,6 +119,8 @@ func (flags *programFlags) parseArguments() { "Do not issue warnings regarding certificates that do not use SANs at all.") golf.StringVarP(&names, 'a', "additional-names", "", "A comma-separated list of names that the certificate should also provide.") + golf.StringVarP(&flags.startTLS, 's', "start-tls", "", + "Protocol to use before requesting a switch to TLS. Supported protocols: smtp.") golf.Parse() if help { golf.Usage() @@ -90,6 +162,11 @@ func (program *checkProgram) checkFlags() bool { program.plugin.SetState(plugin.UNKNOWN, "nonsensical thresholds") return false } + if _, ok := certGetters[program.startTLS]; !ok { + errstr := fmt.Sprintf("unsupported StartTLS protocol %s", program.startTLS) + program.plugin.SetState(plugin.UNKNOWN, errstr) + return false + } program.hostname = strings.ToLower(program.hostname) return true } @@ -102,18 +179,13 @@ func (program *checkProgram) getCertificate() error { MinVersion: tls.VersionTLS10, } connString := fmt.Sprintf("%s:%d", program.hostname, program.port) - conn, err := tls.Dial("tcp", connString, tlsConfig) - if err != nil { - return fmt.Errorf("connection failed: %s", err.Error()) - } - defer conn.Close() - if err := conn.Handshake(); err != nil { - return fmt.Errorf("handshake failed: %s", err.Error()) - } - program.certificate = conn.ConnectionState().PeerCertificates[0] - return nil + certificate, err := certGetters[program.startTLS].getCertificate(tlsConfig, connString) + program.certificate = certificate + return err } +// Check that the CN of a certificate that doesn't contain a SAN actually +// matches the requested host name. func (program *checkProgram) checkSANlessCertificate() bool { if !program.ignoreCnOnly || len(program.extraNames) != 0 { program.plugin.SetState(plugin.WARNING,