parent
ce4a75caf5
commit
e333ccd29e
9 changed files with 260 additions and 7 deletions
@ -0,0 +1,61 @@ |
||||
package cookies |
||||
|
||||
import ( |
||||
"net/http" |
||||
"net/http/cookiejar" |
||||
"net/url" |
||||
"sync" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
type Jar struct { |
||||
jar *cookiejar.Jar |
||||
persister *Persister |
||||
locker sync.Locker |
||||
} |
||||
|
||||
func New(persister *Persister) (*Jar, error) { |
||||
jar, err := cookiejar.New(nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cookies, err := persister.Load() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
for rawURL, cookies := range cookies { |
||||
url, err := url.Parse(rawURL) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
jar.SetCookies(url, cookies) |
||||
} |
||||
|
||||
return &Jar{ |
||||
jar: jar, |
||||
persister: persister, |
||||
locker: &sync.Mutex{}, |
||||
}, nil |
||||
} |
||||
|
||||
func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { |
||||
j.locker.Lock() |
||||
defer j.locker.Unlock() |
||||
|
||||
j.jar.SetCookies(u, cookies) |
||||
|
||||
if err := j.persister.Persist(u.String(), cookies); err != nil { |
||||
logrus.WithError(err).Warn("Failed to persist cookie") |
||||
} |
||||
} |
||||
|
||||
func (j *Jar) Cookies(u *url.URL) []*http.Cookie { |
||||
j.locker.Lock() |
||||
defer j.locker.Unlock() |
||||
|
||||
return j.jar.Cookies(u) |
||||
} |
||||
@ -0,0 +1,80 @@ |
||||
package cookies |
||||
|
||||
import ( |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestJar(t *testing.T) { |
||||
testCookies := []testCookie{ |
||||
{"TestName1", "TestValue1"}, |
||||
{"TestName2", "TestValue2"}, |
||||
{"TestName3", "TestValue3"}, |
||||
} |
||||
|
||||
ts := getTestServer(t, testCookies...) |
||||
defer ts.Close() |
||||
|
||||
jar, err := New(NewPersister(make(testPersister))) |
||||
require.NoError(t, err) |
||||
|
||||
client := &http.Client{Jar: jar} |
||||
|
||||
setRes, err := client.Get(ts.URL + "/set") |
||||
if err != nil { |
||||
t.FailNow() |
||||
} |
||||
require.NoError(t, setRes.Body.Close()) |
||||
|
||||
getRes, err := client.Get(ts.URL + "/get") |
||||
if err != nil { |
||||
t.FailNow() |
||||
} |
||||
require.NoError(t, getRes.Body.Close()) |
||||
} |
||||
|
||||
type testCookie struct { |
||||
name, value string |
||||
} |
||||
|
||||
func getTestServer(t *testing.T, wantCookies ...testCookie) *httptest.Server { |
||||
mux := http.NewServeMux() |
||||
|
||||
mux.HandleFunc("/set", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
for _, cookie := range wantCookies { |
||||
http.SetCookie(w, &http.Cookie{ |
||||
Name: cookie.name, |
||||
Value: cookie.value, |
||||
}) |
||||
} |
||||
|
||||
w.WriteHeader(http.StatusOK) |
||||
})) |
||||
|
||||
mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
require.Len(t, r.Cookies(), len(wantCookies)) |
||||
|
||||
for k, v := range r.Cookies() { |
||||
assert.Equal(t, wantCookies[k].name, v.Name) |
||||
assert.Equal(t, wantCookies[k].value, v.Value) |
||||
} |
||||
|
||||
w.WriteHeader(http.StatusOK) |
||||
})) |
||||
|
||||
return httptest.NewServer(mux) |
||||
} |
||||
|
||||
type testPersister map[string]string |
||||
|
||||
func (p testPersister) Set(key, value string) { |
||||
p[key] = value |
||||
} |
||||
|
||||
func (p testPersister) Get(key string) string { |
||||
return p[key] |
||||
} |
||||
@ -0,0 +1,91 @@ |
||||
package cookies |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/preferences" |
||||
) |
||||
|
||||
type Persister struct { |
||||
prefs GetterSetter |
||||
} |
||||
|
||||
type GetterSetter interface { |
||||
Get(string) string |
||||
Set(string, string) |
||||
} |
||||
|
||||
func NewPersister(prefs GetterSetter) *Persister { |
||||
return &Persister{prefs: prefs} |
||||
} |
||||
|
||||
func (p *Persister) Persist(url string, cookies []*http.Cookie) error { |
||||
b, err := json.Marshal(cookies) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
val, err := p.load() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
val[url] = string(b) |
||||
|
||||
return p.save(val) |
||||
} |
||||
|
||||
func (p *Persister) Load() (map[string][]*http.Cookie, error) { |
||||
res := make(map[string][]*http.Cookie) |
||||
|
||||
val, err := p.load() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
for url, rawCookies := range val { |
||||
var cookies []*http.Cookie |
||||
|
||||
if err := json.Unmarshal([]byte(rawCookies), &cookies); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
res[url] = cookies |
||||
} |
||||
|
||||
return res, nil |
||||
} |
||||
|
||||
type dataStructure map[string]string |
||||
|
||||
func (p *Persister) load() (dataStructure, error) { |
||||
b := p.prefs.Get(preferences.CookiesKey) |
||||
|
||||
if b == "" { |
||||
if err := p.save(make(dataStructure)); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return p.load() |
||||
} |
||||
|
||||
var val dataStructure |
||||
|
||||
if err := json.Unmarshal([]byte(b), &val); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return val, nil |
||||
} |
||||
|
||||
func (p *Persister) save(val dataStructure) error { |
||||
b, err := json.Marshal(val) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
p.prefs.Set(preferences.CookiesKey, string(b)) |
||||
|
||||
return nil |
||||
} |
||||
Loading…
Reference in new issue