diff --git a/internal/app/bridge/bridge.go b/internal/app/bridge/bridge.go
index d155354..0398280 100644
--- a/internal/app/bridge/bridge.go
+++ b/internal/app/bridge/bridge.go
@@ -19,12 +19,14 @@
package bridge
import (
+ "crypto/tls"
"time"
"github.com/ProtonMail/proton-bridge/internal/api"
"github.com/ProtonMail/proton-bridge/internal/app/base"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
+ pkgTLS "github.com/ProtonMail/proton-bridge/internal/config/tls"
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/frontend"
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
@@ -58,9 +60,9 @@ func New(base *base.Base) *cli.App {
}
func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
- tls, err := b.TLS.GetConfig()
+ tlsConfig, err := loadTLSConfig(b)
if err != nil {
- logrus.WithError(err).Fatal("Failed to create TLS config")
+ logrus.WithError(err).Fatal("Failed to load TLS config")
}
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.CrashHandler, b.Listener, b.CM, b.Creds)
@@ -78,7 +80,7 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
imap.NewIMAPServer(
c.String("log-imap") == "client" || c.String("log-imap") == "all",
c.String("log-imap") == "server" || c.String("log-imap") == "all",
- imapPort, tls, imapBackend, b.Listener).ListenAndServe()
+ imapPort, tlsConfig, imapBackend, b.Listener).ListenAndServe()
}()
go func() {
@@ -87,7 +89,7 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
useSSL := b.Settings.GetBool(settings.SMTPSSLKey)
smtp.NewSMTPServer(
c.Bool("log-smtp"),
- smtpPort, useSSL, tls, smtpBackend, b.Listener).ListenAndServe()
+ smtpPort, useSSL, tlsConfig, smtpBackend, b.Listener).ListenAndServe()
}()
// Bridge supports no-window option which we should use for autostart.
@@ -140,6 +142,44 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
return f.Loop()
}
+func loadTLSConfig(b *base.Base) (*tls.Config, error) {
+ if !b.TLS.HasCerts() {
+ if err := generateTLSCerts(b); err != nil {
+ return nil, err
+ }
+ }
+
+ tlsConfig, err := b.TLS.GetConfig()
+ if err == nil {
+ return tlsConfig, nil
+ }
+
+ logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates")
+
+ if err := generateTLSCerts(b); err != nil {
+ return nil, err
+ }
+
+ return b.TLS.GetConfig()
+}
+
+func generateTLSCerts(b *base.Base) error {
+ template, err := pkgTLS.NewTLSTemplate()
+ if err != nil {
+ return errors.Wrap(err, "failed to generate TLS template")
+ }
+
+ if err := b.TLS.GenerateCerts(template); err != nil {
+ return errors.Wrap(err, "failed to generate TLS certs")
+ }
+
+ if err := b.TLS.InstallCerts(); err != nil {
+ return errors.Wrap(err, "failed to install TLS certs")
+ }
+
+ return nil
+}
+
func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) {
version, err := u.Check()
if err != nil {
diff --git a/internal/config/tls/cert_store_darwin.go b/internal/config/tls/cert_store_darwin.go
new file mode 100644
index 0000000..ee14a42
--- /dev/null
+++ b/internal/config/tls/cert_store_darwin.go
@@ -0,0 +1,53 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package tls
+
+import "os/exec"
+
+func addTrustedCert(certPath string) error {
+ return exec.Command( // nolint[gosec]
+ "/usr/bin/security",
+ "execute-with-privileges",
+ "/usr/bin/security",
+ "add-trusted-cert",
+ "-d",
+ "-r", "trustRoot",
+ "-p", "ssl",
+ "-k", "/Library/Keychains/System.keychain",
+ certPath,
+ ).Run()
+}
+
+func removeTrustedCert(certPath string) error {
+ return exec.Command( // nolint[gosec]
+ "/usr/bin/security",
+ "execute-with-privileges",
+ "/usr/bin/security",
+ "remove-trusted-cert",
+ "-d",
+ certPath,
+ ).Run()
+}
+
+func (t *TLS) InstallCerts() error {
+ return addTrustedCert(t.getTLSCertPath())
+}
+
+func (t *TLS) UninstallCerts() error {
+ return removeTrustedCert(t.getTLSCertPath())
+}
diff --git a/internal/config/tls/cert_store_linux.go b/internal/config/tls/cert_store_linux.go
new file mode 100644
index 0000000..01a138b
--- /dev/null
+++ b/internal/config/tls/cert_store_linux.go
@@ -0,0 +1,26 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package tls
+
+func (t *TLS) InstallCerts() error {
+ return nil // Linux doesn't have a root cert store.
+}
+
+func (t *TLS) UninstallCerts() error {
+ return nil // Linux doesn't have a root cert store.
+}
diff --git a/internal/config/tls/cert_store_windows.go b/internal/config/tls/cert_store_windows.go
new file mode 100644
index 0000000..0fed515
--- /dev/null
+++ b/internal/config/tls/cert_store_windows.go
@@ -0,0 +1,26 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package tls
+
+func (t *TLS) InstallCerts() error {
+ return nil // NOTE(GODT-986): Install certs to root cert store?
+}
+
+func (t *TLS) UninstallCerts() error {
+ return nil // NOTE(GODT-986): Uninstall certs from root cert store?
+}
diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go
index 0b9e27e..0ce76cd 100644
--- a/internal/config/tls/tls.go
+++ b/internal/config/tls/tls.go
@@ -28,12 +28,10 @@ import (
"math/big"
"net"
"os"
- "os/exec"
"path/filepath"
- "runtime"
"time"
- "github.com/sirupsen/logrus"
+ "github.com/pkg/errors"
)
type TLS struct {
@@ -46,24 +44,32 @@ func New(settingsPath string) *TLS {
}
}
-var tlsTemplate = x509.Certificate{ //nolint[gochecknoglobals]
- SerialNumber: big.NewInt(-1),
- Subject: pkix.Name{
- Country: []string{"CH"},
- Organization: []string{"Proton Technologies AG"},
- OrganizationalUnit: []string{"ProtonMail"},
- CommonName: "127.0.0.1",
- },
- KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
- ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
- BasicConstraintsValid: true,
- IsCA: true,
- IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
- NotBefore: time.Now(),
- NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour),
+// NewTLSTemplate creates a new TLS template certificate with a random serial number.
+func NewTLSTemplate() (*x509.Certificate, error) {
+ serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to generate serial number")
+ }
+
+ return &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{
+ Country: []string{"CH"},
+ Organization: []string{"Proton Technologies AG"},
+ OrganizationalUnit: []string{"ProtonMail"},
+ CommonName: "127.0.0.1",
+ },
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+ BasicConstraintsValid: true,
+ IsCA: true,
+ IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(20 * 365 * 24 * time.Hour),
+ }, nil
}
-var ErrTLSCertExpireSoon = fmt.Errorf("TLS certificate will expire soon")
+var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon")
// getTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP).
func (t *TLS) getTLSCertPath() string {
@@ -75,110 +81,78 @@ func (t *TLS) getTLSKeyPath() string {
return filepath.Join(t.settingsPath, "key.pem")
}
-// GenerateConfig generates certs and keys at the given filepaths and returns a TLS Config which holds them.
-// See https://golang.org/src/crypto/tls/generate_cert.go
-func (t *TLS) GenerateConfig() (tlsConfig *tls.Config, err error) {
- priv, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- err = fmt.Errorf("failed to generate private key: %s", err)
- return
+// HasCerts returns whether TLS certs have been generated.
+func (t *TLS) HasCerts() bool {
+ if _, err := os.Stat(t.getTLSCertPath()); err != nil {
+ return false
+ }
+
+ if _, err := os.Stat(t.getTLSKeyPath()); err != nil {
+ return false
}
- serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
- serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ return true
+}
+
+// GenerateCerts generates certs from the given template.
+func (t *TLS) GenerateCerts(template *x509.Certificate) error {
+ priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
- err = fmt.Errorf("failed to generate serial number: %s", err)
- return
+ return errors.Wrap(err, "failed to generate private key")
}
- tlsTemplate.SerialNumber = serialNumber
- derBytes, err := x509.CreateCertificate(rand.Reader, &tlsTemplate, &tlsTemplate, &priv.PublicKey, priv)
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
- err = fmt.Errorf("failed to create certificate: %s", err)
- return
+ return errors.Wrap(err, "failed to create certificate")
}
certOut, err := os.Create(t.getTLSCertPath())
if err != nil {
- return
+ return err
}
- defer certOut.Close() //nolint[errcheck]
- err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
- if err != nil {
- return
+ defer certOut.Close() // nolint[errcheck]
+
+ if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
+ return err
}
keyOut, err := os.OpenFile(t.getTLSKeyPath(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
- return
- }
- defer keyOut.Close() //nolint[errcheck]
- err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
- if err != nil {
- return
+ return err
}
+ defer keyOut.Close() // nolint[errcheck]
- return loadTLSConfig(t.getTLSCertPath(), t.getTLSKeyPath())
-}
-
-// GetConfig tries to load TLS config or generate new one which is then returned.
-func (t *TLS) GetConfig() (tlsConfig *tls.Config, err error) {
- certPath := t.getTLSCertPath()
- keyPath := t.getTLSKeyPath()
- tlsConfig, err = loadTLSConfig(certPath, keyPath)
- if err != nil {
- logrus.WithError(err).Warn("Cannot load cert, generating a new one")
- tlsConfig, err = t.GenerateConfig()
- if err != nil {
- return
- }
-
- if runtime.GOOS == "darwin" {
- if err := exec.Command( // nolint[gosec]
- "/usr/bin/security",
- "execute-with-privileges",
- "/usr/bin/security",
- "add-trusted-cert",
- "-d",
- "-r", "trustRoot",
- "-p", "ssl",
- "-k", "/Library/Keychains/System.keychain",
- certPath,
- ).Run(); err != nil {
- logrus.WithError(err).Error("Failed to add cert to system keychain")
- }
- }
+ if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
+ return err
}
- tlsConfig.ServerName = "127.0.0.1"
- tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
-
- caCertPool := x509.NewCertPool()
- caCertPool.AddCert(tlsConfig.Certificates[0].Leaf)
- tlsConfig.RootCAs = caCertPool
- tlsConfig.ClientCAs = caCertPool
-
- return tlsConfig, err
+ return nil
}
-func loadTLSConfig(certPath, keyPath string) (tlsConfig *tls.Config, err error) {
- c, err := tls.LoadX509KeyPair(certPath, keyPath)
+// GetConfig tries to load TLS config or generate new one which is then returned.
+func (t *TLS) GetConfig() (*tls.Config, error) {
+ c, err := tls.LoadX509KeyPair(t.getTLSCertPath(), t.getTLSKeyPath())
if err != nil {
- return
+ return nil, errors.Wrap(err, "failed to load keypair")
}
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
if err != nil {
- return
- }
-
- tlsConfig = &tls.Config{
- Certificates: []tls.Certificate{c},
+ return nil, errors.Wrap(err, "failed to parse certificate")
}
if time.Now().Add(31 * 24 * time.Hour).After(c.Leaf.NotAfter) {
- err = ErrTLSCertExpireSoon
- return
+ return nil, ErrTLSCertExpiresSoon
}
- return
+
+ caCertPool := x509.NewCertPool()
+ caCertPool.AddCert(c.Leaf)
+
+ return &tls.Config{
+ Certificates: []tls.Certificate{c},
+ ServerName: "127.0.0.1",
+ ClientAuth: tls.VerifyClientCertIfGiven,
+ RootCAs: caCertPool,
+ ClientCAs: caCertPool,
+ }, nil
}
diff --git a/internal/config/tls/tls_test.go b/internal/config/tls/tls_test.go
index 292682d..5c41a46 100644
--- a/internal/config/tls/tls_test.go
+++ b/internal/config/tls/tls_test.go
@@ -19,46 +19,59 @@ package tls
import (
"io/ioutil"
- "os"
- "path/filepath"
- "runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
)
-func TestTLSKeyRenewal(t *testing.T) {
- // Remove keys.
- configPath := "/tmp"
- certPath := filepath.Join(configPath, "cert.pem")
- keyPath := filepath.Join(configPath, "key.pem")
- _ = os.Remove(certPath)
- _ = os.Remove(keyPath)
-
+func TestGetOldConfig(t *testing.T) {
dir, err := ioutil.TempDir("", "test-tls")
require.NoError(t, err)
+ // Create new tls object.
tls := New(dir)
- // Put old key there.
+ // Create new TLS template.
+ tlsTemplate, err := NewTLSTemplate()
+ require.NoError(t, err)
+
+ // Make the template be an old key.
tlsTemplate.NotBefore = time.Now().Add(-365 * 24 * time.Hour)
tlsTemplate.NotAfter = time.Now()
- cert, err := tls.GenerateConfig()
- require.Equal(t, err, ErrTLSCertExpireSoon)
- require.Equal(t, len(cert.Certificates), 1)
- time.Sleep(time.Second)
- now, notValidAfter := time.Now(), cert.Certificates[0].Leaf.NotAfter
- require.True(t, now.After(notValidAfter), "old certificate expected to not be valid at %v but have valid until %v", now, notValidAfter)
-
- // Renew key.
+
+ // Generate the certs from the template.
+ require.NoError(t, tls.GenerateCerts(tlsTemplate))
+
+ // Generate the config from the certs -- it's going to expire soon so we don't want to use it.
+ _, err = tls.GetConfig()
+ require.Equal(t, err, ErrTLSCertExpiresSoon)
+}
+
+func TestGetValidConfig(t *testing.T) {
+ dir, err := ioutil.TempDir("", "test-tls")
+ require.NoError(t, err)
+
+ // Create new tls object.
+ tls := New(dir)
+
+ // Create new TLS template.
+ tlsTemplate, err := NewTLSTemplate()
+ require.NoError(t, err)
+
+ // Make the template be a new key.
tlsTemplate.NotBefore = time.Now()
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
- cert, err = tls.GetConfig()
- if runtime.GOOS != "darwin" { // Darwin is not supported.
- require.NoError(t, err)
- }
- require.Equal(t, len(cert.Certificates), 1)
- now, notValidAfter = time.Now(), cert.Certificates[0].Leaf.NotAfter
+
+ // Generate the certs from the template.
+ require.NoError(t, tls.GenerateCerts(tlsTemplate))
+
+ // Generate the config from the certs -- it's not going to expire soon so we want to use it.
+ config, err := tls.GetConfig()
+ require.NoError(t, err)
+ require.Equal(t, len(config.Certificates), 1)
+
+ // Check the cert is valid.
+ now, notValidAfter := time.Now(), config.Certificates[0].Leaf.NotAfter
require.False(t, now.After(notValidAfter), "new certificate expected to be valid at %v but have valid until %v", now, notValidAfter)
}