You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
445 lines
12 KiB
445 lines
12 KiB
// Copyright (c) 2020 Proton Technologies AG |
|
// |
|
// This file is part of ProtonMail Bridge. |
|
// |
|
// ProtonMail Bridge is free software: you can redistribute it and/or modify |
|
// it under the terms of the GNU General Public License as published by |
|
// the Free Software Foundation, either version 3 of the License, or |
|
// (at your option) any later version. |
|
// |
|
// ProtonMail Bridge is distributed in the hope that it will be useful, |
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
// GNU General Public License for more details. |
|
// |
|
// You should have received a copy of the GNU General Public License |
|
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>. |
|
|
|
package pmapi |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"encoding/json" |
|
"errors" |
|
"fmt" |
|
"io" |
|
"io/ioutil" |
|
"math/rand" |
|
"net/http" |
|
"reflect" |
|
"strconv" |
|
"strings" |
|
"sync" |
|
"time" |
|
|
|
pmcrypto "github.com/ProtonMail/gopenpgp/crypto" |
|
"github.com/jaytaylor/html2text" |
|
"github.com/sirupsen/logrus" |
|
) |
|
|
|
// Version of the API. |
|
const Version = 3 |
|
|
|
// API return codes. |
|
const ( |
|
ForceUpgradeBadAPIVersion = 5003 |
|
ForceUpgradeInvalidAPI = 5004 |
|
ForceUpgradeBadAppVersion = 5005 |
|
APIOffline = 7001 |
|
ImportMessageTooLong = 36022 |
|
BansRequests = 85131 |
|
) |
|
|
|
// The output errors. |
|
var ( |
|
ErrInvalidToken = errors.New("refresh token invalid") |
|
ErrAPINotReachable = errors.New("cannot reach the server") |
|
ErrUpgradeApplication = errors.New("application upgrade required") |
|
|
|
ErrNoSuchAPIID = errors.New("no such API ID") |
|
) |
|
|
|
type ErrUnauthorized struct { |
|
error |
|
} |
|
|
|
func (err *ErrUnauthorized) Error() string { |
|
return fmt.Sprintf("unauthorized access: %+v", err.error.Error()) |
|
} |
|
|
|
// ClientConfig contains Client configuration. |
|
type ClientConfig struct { |
|
// The client application name and version. |
|
AppVersion string |
|
|
|
// The client ID. |
|
ClientID string |
|
|
|
// The sentry DSN. |
|
SentryDSN string |
|
|
|
// Transport specifies the mechanism by which individual HTTP requests are made. |
|
// If nil, http.DefaultTransport is used. |
|
// TODO: This could be removed entirely and set in the client manager via SetClientRoundTripper. |
|
Transport http.RoundTripper |
|
|
|
// Timeout specifies the timeout from request to getting response headers to our API. |
|
// Passed to http.Client, empty means no timeout. |
|
Timeout time.Duration |
|
|
|
// FirstReadTimeout specifies the timeout from getting response to the first read of body response. |
|
// This timeout is applied only when MinSpeed is used. |
|
// Default is 5 minutes. |
|
FirstReadTimeout time.Duration |
|
|
|
// MinSpeed specifies minimum Bytes per second or the request will be canceled. |
|
// Zero means no limitation. |
|
MinSpeed int64 |
|
} |
|
|
|
// Client to communicate with API. |
|
type Client struct { |
|
cm *ClientManager |
|
hc *http.Client |
|
|
|
uid string |
|
accessToken string |
|
userID string |
|
requestLocker sync.Locker |
|
keyLocker sync.Locker |
|
|
|
user *User |
|
addresses AddressList |
|
kr *pmcrypto.KeyRing |
|
|
|
log *logrus.Entry |
|
} |
|
|
|
// newClient creates a new API client. |
|
func newClient(cm *ClientManager, userID string) *Client { |
|
return &Client{ |
|
cm: cm, |
|
hc: getHTTPClient(cm.GetConfig()), |
|
userID: userID, |
|
requestLocker: &sync.Mutex{}, |
|
keyLocker: &sync.Mutex{}, |
|
log: logrus.WithField("pkg", "pmapi").WithField("userID", userID), |
|
} |
|
} |
|
|
|
// getHTTPClient returns a http client configured by the given client config. |
|
func getHTTPClient(cfg *ClientConfig) (hc *http.Client) { |
|
hc = &http.Client{Timeout: cfg.Timeout} |
|
|
|
if cfg.Transport == nil { |
|
if defaultTransport != nil { |
|
hc.Transport = defaultTransport |
|
} |
|
return |
|
} |
|
|
|
// In future use Clone here. |
|
// https://go-review.googlesource.com/c/go/+/174597/ |
|
if cfgTransport, ok := cfg.Transport.(*http.Transport); ok { |
|
transport := &http.Transport{} |
|
*transport = *cfgTransport //nolint |
|
if transport.Proxy == nil { |
|
transport.Proxy = http.ProxyFromEnvironment |
|
} |
|
hc.Transport = transport |
|
return |
|
} |
|
|
|
hc.Transport = cfg.Transport |
|
|
|
return hc |
|
} |
|
|
|
// Do makes an API request. It does not check for HTTP status code errors. |
|
func (c *Client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) { |
|
// Copy the request body in case we need to retry it. |
|
var bodyBuffer []byte |
|
if req.Body != nil { |
|
defer req.Body.Close() //nolint[errcheck] |
|
bodyBuffer, err = ioutil.ReadAll(req.Body) |
|
|
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
r := bytes.NewReader(bodyBuffer) |
|
req.Body = ioutil.NopCloser(r) |
|
} |
|
|
|
return c.doBuffered(req, bodyBuffer, retryUnauthorized) |
|
} |
|
|
|
// If needed it retries using req and buffered body. |
|
func (c *Client) doBuffered(req *http.Request, bodyBuffer []byte, retryUnauthorized bool) (res *http.Response, err error) { // nolint[funlen] |
|
isAuthReq := strings.Contains(req.URL.Path, "/auth") |
|
|
|
req.Header.Set("x-pm-appversion", c.cm.GetConfig().AppVersion) |
|
req.Header.Set("x-pm-apiversion", strconv.Itoa(Version)) |
|
|
|
if c.uid != "" { |
|
req.Header.Set("x-pm-uid", c.uid) |
|
} |
|
|
|
if c.accessToken != "" { |
|
req.Header.Set("Authorization", "Bearer "+c.accessToken) |
|
} |
|
|
|
c.log.Debugln("Requesting ", req.Method, req.URL.RequestURI()) |
|
if logrus.GetLevel() == logrus.TraceLevel { |
|
head := "" |
|
for i, v := range req.Header { |
|
head += i + ": " |
|
head += strings.Join(v, "") |
|
head += "\n" |
|
} |
|
c.log.Tracef("REQHEAD \n%s", head) |
|
c.log.Tracef("REQBODY '%s'", string(bodyBuffer)) |
|
} |
|
|
|
hasBody := len(bodyBuffer) > 0 |
|
if res, err = c.hc.Do(req); err != nil { |
|
if res == nil { |
|
c.log.WithError(err).Error("Cannot get response") |
|
err = ErrAPINotReachable |
|
} |
|
return |
|
} |
|
|
|
resDate := res.Header.Get("Date") |
|
if resDate != "" { |
|
if serverTime, err := http.ParseTime(resDate); err == nil { |
|
pmcrypto.GetGopenPGP().UpdateTime(serverTime.Unix()) |
|
} |
|
} |
|
|
|
if res.StatusCode == http.StatusUnauthorized { |
|
if hasBody { |
|
r := bytes.NewReader(bodyBuffer) |
|
req.Body = ioutil.NopCloser(r) |
|
} |
|
|
|
if !isAuthReq { |
|
_, _ = io.Copy(ioutil.Discard, res.Body) |
|
_ = res.Body.Close() |
|
return c.handleStatusUnauthorized(req, bodyBuffer, res, retryUnauthorized) |
|
} |
|
} |
|
|
|
// Retry induced by HTTP status code> |
|
retryAfter := 10 |
|
doRetry := res.StatusCode == http.StatusTooManyRequests |
|
if doRetry { |
|
if headerAfter, err := strconv.Atoi(res.Header.Get("Retry-After")); err == nil && headerAfter > 0 { |
|
retryAfter = headerAfter |
|
} |
|
// To avoid spikes when all clients retry at the same time, we add some random wait. |
|
retryAfter += rand.Intn(10) |
|
|
|
if hasBody { |
|
r := bytes.NewReader(bodyBuffer) |
|
req.Body = ioutil.NopCloser(r) |
|
} |
|
|
|
c.log.Warningf("Retrying %s after %ds induced by http code %d", req.URL.Path, retryAfter, res.StatusCode) |
|
time.Sleep(time.Duration(retryAfter) * time.Second) |
|
_, _ = io.Copy(ioutil.Discard, res.Body) |
|
_ = res.Body.Close() |
|
return c.doBuffered(req, bodyBuffer, false) |
|
} |
|
|
|
return res, err |
|
} |
|
|
|
// DoJSON performs the request and unmarshals the response as JSON into data. |
|
// If the API returns a non-2xx HTTP status code, the error returned will contain status |
|
// and response as plaintext. API errors must be checked by the caller. |
|
// It is performed buffered, in case we need to retry. |
|
func (c *Client) DoJSON(req *http.Request, data interface{}) error { |
|
// Copy the request body in case we need to retry it |
|
var reqBodyBuffer []byte |
|
|
|
if req.Body != nil { |
|
defer req.Body.Close() //nolint[errcheck] |
|
var err error |
|
if reqBodyBuffer, err = ioutil.ReadAll(req.Body); err != nil { |
|
return err |
|
} |
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) |
|
} |
|
|
|
return c.doJSONBuffered(req, reqBodyBuffer, data) |
|
} |
|
|
|
// doJSONBuffered performs a buffered json request (see DoJSON for more information). |
|
func (c *Client) doJSONBuffered(req *http.Request, reqBodyBuffer []byte, data interface{}) error { // nolint[funlen] |
|
req.Header.Set("Accept", "application/vnd.protonmail.v1+json") |
|
|
|
var cancelRequest context.CancelFunc |
|
if c.cm.GetConfig().MinSpeed > 0 { |
|
var ctx context.Context |
|
ctx, cancelRequest = context.WithCancel(req.Context()) |
|
defer func() { |
|
cancelRequest() |
|
}() |
|
req = req.WithContext(ctx) |
|
} |
|
|
|
res, err := c.doBuffered(req, reqBodyBuffer, false) |
|
if err != nil { |
|
return err |
|
} |
|
defer res.Body.Close() //nolint[errcheck] |
|
|
|
var resBody []byte |
|
if c.cm.GetConfig().MinSpeed == 0 { |
|
resBody, err = ioutil.ReadAll(res.Body) |
|
} else { |
|
resBody, err = c.readAllMinSpeed(res.Body, cancelRequest) |
|
} |
|
|
|
// The server response may contain data which we want to have in memory |
|
// for as little time as possible (such as keys). Go is garbage collected, |
|
// so we are not in charge of when the memory will actually be cleared. |
|
// We can at least try to rewrite the original data to mitigate this problem. |
|
defer func() { |
|
for i := 0; i < len(resBody); i++ { |
|
resBody[i] = byte(65) |
|
} |
|
}() |
|
|
|
if logrus.GetLevel() == logrus.TraceLevel { |
|
head := "" |
|
for i, v := range res.Header { |
|
head += i + ": " |
|
head += strings.Join(v, "") |
|
head += "\n" |
|
} |
|
c.log.Tracef("RESHEAD \n%s", head) |
|
c.log.Tracef("RESBODY '%s'", resBody) |
|
} |
|
|
|
if err != nil { |
|
return err |
|
} |
|
|
|
// Retry induced by API code. |
|
errCode := &Res{} |
|
if err := json.Unmarshal(resBody, errCode); err == nil { |
|
if errCode.Code == BansRequests { |
|
retryAfter := 3 |
|
c.log.Warningf("Retrying %s after %ds induced by API code %d", req.URL.Path, retryAfter, errCode.Code) |
|
time.Sleep(time.Duration(retryAfter) * time.Second) |
|
if len(reqBodyBuffer) > 0 { |
|
req.Body = ioutil.NopCloser(bytes.NewReader(reqBodyBuffer)) |
|
} |
|
return c.doJSONBuffered(req, reqBodyBuffer, data) |
|
} |
|
} |
|
|
|
if err := json.Unmarshal(resBody, data); err != nil { |
|
// Check to see if this is due to a non 2xx HTTP status code. |
|
if res.StatusCode != http.StatusOK { |
|
r := bytes.NewReader(bytes.ReplaceAll(resBody, []byte("\n"), []byte("\\n"))) |
|
plaintext, err := html2text.FromReader(r) |
|
if err == nil { |
|
return fmt.Errorf("Error: \n\n" + res.Status + "\n\n" + plaintext) |
|
} |
|
} |
|
|
|
if errJS, ok := err.(*json.SyntaxError); ok { |
|
return fmt.Errorf("invalid json %v (offset:%d) ", errJS.Error(), errJS.Offset) |
|
} |
|
|
|
return fmt.Errorf("unmarshal fail: %v ", err) |
|
} |
|
|
|
// Set StatusCode in case data struct supports that field. |
|
// It's safe to set StatusCode, server returns Code. StatusCode should be preferred over Code. |
|
dataValue := reflect.ValueOf(data).Elem() |
|
statusCodeField := dataValue.FieldByName("StatusCode") |
|
if statusCodeField.IsValid() && statusCodeField.CanSet() && statusCodeField.Kind() == reflect.Int { |
|
statusCodeField.SetInt(int64(res.StatusCode)) |
|
} |
|
|
|
if res.StatusCode != http.StatusOK { |
|
c.log.Warnf("request %s %s NOT OK: %s", req.Method, req.URL.Path, res.Status) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (c *Client) readAllMinSpeed(data io.Reader, cancelRequest context.CancelFunc) ([]byte, error) { |
|
firstReadTimeout := c.cm.GetConfig().FirstReadTimeout |
|
if firstReadTimeout == 0 { |
|
firstReadTimeout = 5 * time.Minute |
|
} |
|
timer := time.AfterFunc(firstReadTimeout, func() { |
|
cancelRequest() |
|
}) |
|
var buffer bytes.Buffer |
|
for { |
|
_, err := io.CopyN(&buffer, data, c.cm.GetConfig().MinSpeed) |
|
timer.Stop() |
|
timer.Reset(1 * time.Second) |
|
if err == io.EOF { |
|
break |
|
} else if err != nil { |
|
return nil, err |
|
} |
|
} |
|
return ioutil.ReadAll(&buffer) |
|
} |
|
|
|
func (c *Client) refreshAccessToken() (err error) { |
|
c.log.Debug("Refreshing token") |
|
|
|
refreshToken := c.cm.GetToken(c.userID) |
|
|
|
if refreshToken == "" { |
|
c.sendAuth(nil) |
|
return ErrInvalidToken |
|
} |
|
|
|
if _, err := c.AuthRefresh(refreshToken); err != nil { |
|
c.sendAuth(nil) |
|
return err |
|
} |
|
|
|
return |
|
} |
|
|
|
func (c *Client) handleStatusUnauthorized(req *http.Request, reqBodyBuffer []byte, res *http.Response, retry bool) (retryRes *http.Response, err error) { |
|
c.log.Info("Handling unauthorized status") |
|
|
|
// If this is not a retry, then it is the first time handling status unauthorized, |
|
// so try again without refreshing the access token. |
|
if !retry { |
|
c.log.Debug("Handling unauthorized status by retrying") |
|
c.requestLocker.Lock() |
|
defer c.requestLocker.Unlock() |
|
|
|
_, _ = io.Copy(ioutil.Discard, res.Body) |
|
_ = res.Body.Close() |
|
return c.doBuffered(req, reqBodyBuffer, true) |
|
} |
|
|
|
// This is already a retry, so we will try to refresh the access token before trying again. |
|
if err = c.refreshAccessToken(); err != nil { |
|
c.log.WithError(err).Warn("Cannot refresh token") |
|
err = &ErrUnauthorized{err} |
|
return |
|
} |
|
_, err = io.Copy(ioutil.Discard, res.Body) |
|
if err != nil { |
|
c.log.WithError(err).Warn("Failed to read out response body") |
|
} |
|
_ = res.Body.Close() |
|
return c.doBuffered(req, reqBodyBuffer, true) |
|
}
|
|
|