You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
341 lines
8.9 KiB
341 lines
8.9 KiB
package pmapi |
|
|
|
import ( |
|
"fmt" |
|
"net/http" |
|
"sync" |
|
"time" |
|
|
|
"github.com/pkg/errors" |
|
"github.com/sirupsen/logrus" |
|
) |
|
|
|
// defaultProxyUseDuration is how long requests keep going through a proxy after the
// client manager first switches away from the standard API, before it reverts to RootURL.
var defaultProxyUseDuration = time.Hour * 24
|
|
|
// ClientManager is a manager of clients. |
|
type ClientManager struct { |
|
// newClient is used to create new Clients. By default this creates pmapi clients but it can be overridden to |
|
// create other types of clients (e.g. for integration tests). |
|
newClient func(userID string) Client |
|
|
|
config *ClientConfig |
|
roundTripper http.RoundTripper |
|
|
|
// TODO: These need to be Client (not *client) because we might need to create *FakePMAPI for integration tests. |
|
// But that screws up other things like not being able to clear sensitive info during logout |
|
// unless the client interface contains a method for that. |
|
clients map[string]Client |
|
clientsLocker sync.Locker |
|
|
|
tokens map[string]string |
|
tokensLocker sync.Locker |
|
|
|
expirations map[string]*tokenExpiration |
|
expirationsLocker sync.Locker |
|
|
|
host, scheme string |
|
hostLocker sync.Locker |
|
|
|
bridgeAuths chan ClientAuth |
|
clientAuths chan ClientAuth |
|
|
|
allowProxy bool |
|
proxyProvider *proxyProvider |
|
proxyUseDuration time.Duration |
|
} |
|
|
|
// ClientAuth holds an API auth produced by a Client for a specific user. |
|
type ClientAuth struct { |
|
UserID string |
|
Auth *Auth |
|
} |
|
|
|
// tokenExpiration tracks the pending expiry of a single access token.
type tokenExpiration struct {
	// timer fires once the token has expired.
	timer *time.Timer

	// cancel is closed when this expiry is superseded, so its watcher can stop waiting.
	cancel chan struct{}
}
|
|
|
// NewClientManager creates a new ClientMan which manages clients configured with the given client config. |
|
func NewClientManager(config *ClientConfig) (cm *ClientManager) { |
|
cm = &ClientManager{ |
|
config: config, |
|
roundTripper: http.DefaultTransport, |
|
|
|
clients: make(map[string]Client), |
|
clientsLocker: &sync.Mutex{}, |
|
|
|
tokens: make(map[string]string), |
|
tokensLocker: &sync.Mutex{}, |
|
|
|
expirations: make(map[string]*tokenExpiration), |
|
expirationsLocker: &sync.Mutex{}, |
|
|
|
host: RootURL, |
|
scheme: RootScheme, |
|
hostLocker: &sync.Mutex{}, |
|
|
|
bridgeAuths: make(chan ClientAuth), |
|
clientAuths: make(chan ClientAuth), |
|
|
|
proxyProvider: newProxyProvider(dohProviders, proxyQuery), |
|
proxyUseDuration: defaultProxyUseDuration, |
|
} |
|
|
|
cm.newClient = func(userID string) Client { |
|
return newClient(cm, userID) |
|
} |
|
|
|
go cm.forwardClientAuths() |
|
|
|
return |
|
} |
|
|
|
func (cm *ClientManager) SetClientConstructor(f func(userID string) Client) { |
|
cm.newClient = f |
|
} |
|
|
|
// SetRoundTripper sets the roundtripper used by clients created by this client manager. |
|
func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) { |
|
cm.roundTripper = rt |
|
} |
|
|
|
// GetClient returns a client for the given userID. |
|
// If the client does not exist already, it is created. |
|
func (cm *ClientManager) GetClient(userID string) Client { |
|
cm.clientsLocker.Lock() |
|
defer cm.clientsLocker.Unlock() |
|
|
|
if client, ok := cm.clients[userID]; ok { |
|
return client |
|
} |
|
|
|
cm.clients[userID] = cm.newClient(userID) |
|
|
|
return cm.clients[userID] |
|
} |
|
|
|
// GetAnonymousClient returns an anonymous client. It replaces any anonymous client that was already created. |
|
func (cm *ClientManager) GetAnonymousClient() Client { |
|
cm.clientsLocker.Lock() |
|
defer cm.clientsLocker.Unlock() |
|
|
|
if client, ok := cm.clients[""]; ok { |
|
client.DeleteAuth() |
|
} |
|
|
|
cm.clients[""] = cm.newClient("") |
|
|
|
return cm.clients[""] |
|
} |
|
|
|
// LogoutClient logs out the client with the given userID and ensures its sensitive data is successfully cleared. |
|
func (cm *ClientManager) LogoutClient(userID string) { |
|
client, ok := cm.clients[userID] |
|
|
|
if !ok { |
|
return |
|
} |
|
|
|
delete(cm.clients, userID) |
|
|
|
go func() { |
|
if err := client.DeleteAuth(); err != nil { |
|
// TODO: Retry if the request failed. |
|
} |
|
client.ClearData() |
|
cm.clearToken(userID) |
|
}() |
|
|
|
return |
|
} |
|
|
|
// GetRootURL returns the full root URL (scheme+host). |
|
func (cm *ClientManager) GetRootURL() string { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
return fmt.Sprintf("%v://%v", cm.scheme, cm.host) |
|
} |
|
|
|
// getHost returns the host to make requests to. |
|
// It does not include the protocol i.e. no "https://" (use getScheme for that). |
|
func (cm *ClientManager) getHost() string { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
return cm.host |
|
} |
|
|
|
// getScheme returns the scheme with which to make requests to the host. |
|
func (cm *ClientManager) getScheme() string { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
return cm.scheme |
|
} |
|
|
|
// IsProxyAllowed returns whether the user has allowed us to switch to a proxy if need be. |
|
func (cm *ClientManager) IsProxyAllowed() bool { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
return cm.allowProxy |
|
} |
|
|
|
// AllowProxy allows the client manager to switch clients over to a proxy if need be. |
|
func (cm *ClientManager) AllowProxy() { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
cm.allowProxy = true |
|
} |
|
|
|
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be. |
|
func (cm *ClientManager) DisallowProxy() { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
cm.allowProxy = false |
|
cm.host = RootURL |
|
} |
|
|
|
// IsProxyEnabled returns whether we are currently proxying requests. |
|
func (cm *ClientManager) IsProxyEnabled() bool { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
return cm.host != RootURL |
|
} |
|
|
|
// switchToReachableServer switches to using a reachable server (either proxy or standard API). |
|
func (cm *ClientManager) switchToReachableServer() (proxy string, err error) { |
|
cm.hostLocker.Lock() |
|
defer cm.hostLocker.Unlock() |
|
|
|
logrus.Info("Attempting to switch to a proxy") |
|
|
|
if proxy, err = cm.proxyProvider.findReachableServer(); err != nil { |
|
err = errors.Wrap(err, "failed to find a usable proxy") |
|
return |
|
} |
|
|
|
logrus.WithField("proxy", proxy).Info("Switching to a proxy") |
|
|
|
// If the host is currently the RootURL, it's the first time we are enabling a proxy. |
|
// This means we want to disable it again in 24 hours. |
|
if cm.host == RootURL { |
|
go func() { |
|
<-time.After(cm.proxyUseDuration) |
|
cm.host = RootURL |
|
}() |
|
} |
|
|
|
cm.host = proxy |
|
|
|
return |
|
} |
|
|
|
// GetToken returns the token for the given userID. |
|
func (cm *ClientManager) GetToken(userID string) string { |
|
cm.tokensLocker.Lock() |
|
defer cm.tokensLocker.Unlock() |
|
|
|
return cm.tokens[userID] |
|
} |
|
|
|
// GetAuthUpdateChannel returns a channel on which client auths can be received. |
|
func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth { |
|
return cm.bridgeAuths |
|
} |
|
|
|
// getClientAuthChannel returns a channel on which clients should send auths. |
|
func (cm *ClientManager) getClientAuthChannel() chan ClientAuth { |
|
return cm.clientAuths |
|
} |
|
|
|
// forwardClientAuths handles all incoming auths from clients before forwarding them on the bridge auth channel. |
|
func (cm *ClientManager) forwardClientAuths() { |
|
for auth := range cm.clientAuths { |
|
logrus.Debug("ClientManager received auth from client") |
|
cm.handleClientAuth(auth) |
|
logrus.Debug("ClientManager is forwarding auth to bridge") |
|
cm.bridgeAuths <- auth |
|
} |
|
} |
|
|
|
// setToken sets the token for the given userID with the given expiration time. |
|
func (cm *ClientManager) setToken(userID, token string, expiration time.Duration) { |
|
// We don't want to set tokens of anonymous clients. |
|
if userID == "" { |
|
return |
|
} |
|
|
|
cm.tokensLocker.Lock() |
|
defer cm.tokensLocker.Unlock() |
|
|
|
logrus.WithField("userID", userID).Info("Updating token") |
|
|
|
cm.tokens[userID] = token |
|
|
|
cm.setTokenExpiration(userID, expiration) |
|
} |
|
|
|
// setTokenExpiration will ensure the token is refreshed if it expires. |
|
// If the token already has an expiration time set, it is replaced. |
|
func (cm *ClientManager) setTokenExpiration(userID string, expiration time.Duration) { |
|
cm.expirationsLocker.Lock() |
|
defer cm.expirationsLocker.Unlock() |
|
|
|
if exp, ok := cm.expirations[userID]; ok { |
|
exp.timer.Stop() |
|
close(exp.cancel) |
|
} |
|
|
|
cm.expirations[userID] = &tokenExpiration{ |
|
timer: time.NewTimer(expiration), |
|
cancel: make(chan struct{}), |
|
} |
|
|
|
go cm.watchTokenExpiration(userID) |
|
} |
|
|
|
func (cm *ClientManager) clearToken(userID string) { |
|
cm.tokensLocker.Lock() |
|
defer cm.tokensLocker.Unlock() |
|
|
|
logrus.WithField("userID", userID).Info("Clearing token") |
|
|
|
delete(cm.tokens, userID) |
|
} |
|
|
|
// handleClientAuth updates or clears client authorisation based on auths received. |
|
func (cm *ClientManager) handleClientAuth(ca ClientAuth) { |
|
// If we aren't managing this client, there's nothing to do. |
|
if _, ok := cm.clients[ca.UserID]; !ok { |
|
logrus.WithField("userID", ca.UserID).Info("Handling auth for unmanaged client") |
|
return |
|
} |
|
|
|
// If the auth is nil, we should clear the token. |
|
// TODO: Maybe we should trigger a client logout here? Then we don't have to remember to log it out ourself. |
|
if ca.Auth == nil { |
|
cm.clearToken(ca.UserID) |
|
return |
|
} |
|
|
|
cm.setToken(ca.UserID, ca.Auth.GenToken(), time.Duration(ca.Auth.ExpiresIn)*time.Second) |
|
} |
|
|
|
func (cm *ClientManager) watchTokenExpiration(userID string) { |
|
expiration := cm.expirations[userID] |
|
|
|
select { |
|
case <-expiration.timer.C: |
|
logrus.WithField("userID", userID).Info("Auth token expired! Refreshing") |
|
cm.clients[userID].AuthRefresh(cm.tokens[userID]) |
|
|
|
case <-expiration.cancel: |
|
logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") |
|
} |
|
}
|
|
|