Skip to content

Commit 41da9bd

Browse files
committed
Add support for refresh token for linkedin provider
1 parent 779abc8 commit 41da9bd

4 files changed

Lines changed: 118 additions & 13 deletions

File tree

providers/linkedin/linkedin.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package linkedin
33

44
import (
5+
"context"
56
"encoding/json"
67
"errors"
78
"fmt"
@@ -82,9 +83,10 @@ func (p *Provider) BeginAuth(state string) (goth.Session, error) {
8283
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
8384
s := session.(*Session)
8485
user := goth.User{
85-
AccessToken: s.AccessToken,
86-
Provider: p.Name(),
87-
ExpiresAt: s.ExpiresAt,
86+
AccessToken: s.AccessToken,
87+
Provider: p.Name(),
88+
ExpiresAt: s.ExpiresAt,
89+
RefreshToken: s.RefreshToken,
8890
}
8991

9092
if user.AccessToken == "" {
@@ -267,12 +269,18 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config {
267269
return c
268270
}
269271

270-
// RefreshToken refresh token is not provided by linkedin
271-
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
272-
return nil, errors.New("Refresh token is not provided by linkedin")
272+
// RefreshTokenAvailable tells whether a refresh token is provided by the auth provider or not
273+
func (p *Provider) RefreshTokenAvailable() bool {
274+
return true
273275
}
274276

275-
// RefreshTokenAvailable refresh token is not provided by linkedin
276-
func (p *Provider) RefreshTokenAvailable() bool {
277-
return false
277+
// RefreshToken gets a new access token using the refresh token
278+
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
279+
token := &oauth2.Token{RefreshToken: refreshToken}
280+
ts := p.config.TokenSource(context.Background(), token)
281+
newToken, err := ts.Token()
282+
if err != nil {
283+
return nil, err
284+
}
285+
return newToken, err
278286
}

providers/linkedin/linkedin_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ func Test_SessionFromJSON(t *testing.T) {
4747

4848
provider := linkedinProvider()
4949

50-
s, err := provider.UnmarshalSession(`{"AuthURL":"http://linkedin.com/auth_url","AccessToken":"1234567890"}`)
50+
s, err := provider.UnmarshalSession(`{"AuthURL":"http://linkedin.com/auth_url","AccessToken":"1234567890","RefreshToken":"987654321"}`)
5151
a.NoError(err)
5252
session := s.(*linkedin.Session)
5353
a.Equal(session.AuthURL, "http://linkedin.com/auth_url")
5454
a.Equal(session.AccessToken, "1234567890")
55+
a.Equal(session.RefreshToken, "987654321")
5556
}
5657

5758
func linkedinProvider() *linkedin.Provider {

providers/linkedin/session.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ import (
1010

1111
// Session stores data during the auth process with LinkedIn.
1212
type Session struct {
13-
AuthURL string
14-
AccessToken string
15-
ExpiresAt time.Time
13+
AuthURL string
14+
AccessToken string
15+
ExpiresAt time.Time
16+
RefreshToken string
1617
}
1718

1819
// GetAuthURL will return the URL set by calling the `BeginAuth` function on the LinkedIn provider.
@@ -37,6 +38,7 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,
3738

3839
s.AccessToken = token.AccessToken
3940
s.ExpiresAt = token.Expiry
41+
s.RefreshToken = token.RefreshToken
4042
return token.AccessToken, err
4143
}
4244

providers/linkedin/session_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,37 @@
11
package linkedin_test
22

33
import (
4+
"errors"
5+
"io"
6+
"net/http"
7+
"strings"
48
"testing"
9+
"time"
510

611
"github.com/markbates/goth"
712
"github.com/markbates/goth/providers/linkedin"
813
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/mock"
15+
"github.com/stretchr/testify/require"
916
)
1017

18+
type MockParams struct {
19+
params map[string]string
20+
}
21+
22+
func (m *MockParams) Get(key string) string {
23+
return m.params[key]
24+
}
25+
26+
type MockedHTTPClient struct {
27+
mock.Mock
28+
}
29+
30+
func (m *MockedHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
31+
args := m.Mock.Called(req)
32+
return args.Get(0).(*http.Response), args.Error(1)
33+
}
34+
1135
func Test_Implements_Session(t *testing.T) {
1236
t.Parallel()
1337
a := assert.New(t)
@@ -46,3 +70,73 @@ func Test_String(t *testing.T) {
4670

4771
a.Equal(s.String(), s.Marshal())
4872
}
73+
74+
func Test_Authorize(t *testing.T) {
75+
session := &linkedin.Session{}
76+
params := &MockParams{
77+
params: map[string]string{
78+
"code": "authorization_code",
79+
},
80+
}
81+
82+
t.Run("happy path", func(t *testing.T) {
83+
mockClient := new(MockedHTTPClient)
84+
p := linkedinProvider()
85+
p.HTTPClient = &http.Client{Transport: mockClient}
86+
mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{
87+
StatusCode: http.StatusOK,
88+
Body: io.NopCloser(strings.NewReader(`{"access_token":"test_token","expires_in":3600, "refresh_token":"refresh_token"}`)),
89+
}, nil)
90+
token, err := session.Authorize(p, params)
91+
require.NoError(t, err)
92+
assert.Equal(t, "test_token", token)
93+
assert.Equal(t, session.AccessToken, "test_token")
94+
assert.WithinDuration(t, session.ExpiresAt, time.Now().Add(3600*time.Second), 1*time.Second)
95+
assert.Equal(t, session.RefreshToken, "refresh_token")
96+
})
97+
98+
t.Run("error on request", func(t *testing.T) {
99+
mockClient := new(MockedHTTPClient)
100+
p := linkedinProvider()
101+
p.HTTPClient = &http.Client{Transport: mockClient}
102+
mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error"))
103+
_, err := session.Authorize(p, params)
104+
require.Error(t, err)
105+
})
106+
107+
t.Run("non-200 status code", func(t *testing.T) {
108+
mockClient := new(MockedHTTPClient)
109+
p := linkedinProvider()
110+
p.HTTPClient = &http.Client{Transport: mockClient}
111+
mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{
112+
StatusCode: http.StatusForbidden,
113+
Body: io.NopCloser(strings.NewReader(``)),
114+
}, nil)
115+
_, err := session.Authorize(p, params)
116+
require.Error(t, err)
117+
})
118+
119+
t.Run("error on response decode", func(t *testing.T) {
120+
mockClient := new(MockedHTTPClient)
121+
p := linkedinProvider()
122+
p.HTTPClient = &http.Client{Transport: mockClient}
123+
mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{
124+
StatusCode: http.StatusOK,
125+
Body: io.NopCloser(strings.NewReader(`not a json`)),
126+
}, nil)
127+
_, err := session.Authorize(p, params)
128+
require.Error(t, err)
129+
})
130+
131+
t.Run("error code in response", func(t *testing.T) {
132+
mockClient := new(MockedHTTPClient)
133+
p := linkedinProvider()
134+
p.HTTPClient = &http.Client{Transport: mockClient}
135+
mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{
136+
StatusCode: http.StatusOK,
137+
Body: io.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)),
138+
}, nil)
139+
_, err := session.Authorize(p, params)
140+
require.Error(t, err)
141+
})
142+
}

0 commit comments

Comments
 (0)