-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathauth_service.go
More file actions
367 lines (314 loc) · 11.3 KB
/
auth_service.go
File metadata and controls
367 lines (314 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
package auth
import (
"context"
"errors"
"time"
"github.com/usetero/cli/internal/domain"
"github.com/usetero/cli/internal/log"
)
// Auth provides authentication operations.
//
// The method descriptions below reflect how Service, the implementation in
// this file, behaves.
type Auth interface {
	// StartDeviceAuth initiates the device authorization flow and returns
	// the codes and verification URIs reported by the OAuth provider.
	StartDeviceAuth(ctx context.Context) (*DeviceAuth, error)
	// WaitForAuth blocks while polling the provider with deviceCode, waiting
	// interval between attempts, until authentication completes, fails, or
	// ctx is done.
	WaitForAuth(ctx context.Context, deviceCode string, interval time.Duration) (*Result, error)

	// IsAuthenticated reports whether an access token is currently stored.
	IsAuthenticated() bool
	// GetAccessToken returns the stored access token, refreshing it first
	// when it is expired.
	GetAccessToken(ctx context.Context) (string, error)
	// GetUserID returns the user ID carried by the current access token.
	GetUserID(ctx context.Context) (string, error)

	// RefreshTokenWithoutOrganization exchanges the stored refresh token for
	// a new access token that has no organization scope.
	RefreshTokenWithoutOrganization(ctx context.Context) (string, error)
	// RefreshTokenWithOrganization exchanges the stored refresh token for a
	// new access token scoped to workosOrgID.
	RefreshTokenWithOrganization(ctx context.Context, workosOrgID domain.WorkosOrganizationID) (string, error)

	// ClearTokens removes the stored access and refresh tokens.
	ClearTokens() error
}
// Service implements the authentication business logic.
//
// It sits between an OAuth provider and a secure key-value store: the
// provider issues and refreshes tokens, and Service persists them under the
// "access_token" and "refresh_token" keys, translating those domain concepts
// to and from generic storage operations.
type Service struct {
	// provider performs the device authorization and token refresh calls.
	provider OAuthProvider
	// storage holds the persisted tokens.
	storage SecureStorage
	// scope is the logger used by every method of the service.
	scope log.Scope
}
// Compile-time assertion that *Service satisfies the Auth interface; the
// build fails here if a method is missing or its signature drifts.
var _ Auth = (*Service)(nil)
// NewService builds a Service backed by the given OAuth provider and secure
// storage. Log output from the service goes to the "auth" child of scope.
func NewService(provider OAuthProvider, storage SecureStorage, scope log.Scope) *Service {
	svc := new(Service)
	svc.provider = provider
	svc.storage = storage
	svc.scope = scope.Child("auth")
	return svc
}
// DeviceAuth carries the device authorization details returned when the
// flow is started: what to show the user, plus the values needed to poll.
type DeviceAuth struct {
	// UserCode is the code displayed to the user.
	UserCode string
	// VerificationURI is the verification address reported by the provider.
	VerificationURI string
	// VerificationURIComplete is the provider's "complete" variant of the
	// verification address.
	VerificationURIComplete string
	// DeviceCode is kept internally for polling.
	DeviceCode string
	// ExpiresIn is the expiry value reported by the provider.
	// NOTE(review): presumably seconds - confirm against the provider.
	ExpiresIn int
	// Interval is the polling interval reported by the provider.
	// NOTE(review): presumably seconds - confirm against the provider.
	Interval int
}
// Result is what a successful authentication yields: the issued tokens and
// the user they belong to.
type Result struct {
	// AccessToken is the access token issued by the provider.
	AccessToken string
	// RefreshToken is the refresh token issued alongside the access token.
	RefreshToken string
	// User describes the authenticated user.
	User User
}
// User describes an authenticated user.
type User struct {
	// ID is the user's identifier.
	ID string
	// Email is the user's email address.
	Email string
	// EmailVerified records whether the email address has been verified.
	EmailVerified bool
	// FirstName is the user's given name.
	FirstName string
	// LastName is the user's family name.
	LastName string
}
// StartDeviceAuth kicks off the device authorization flow by asking the
// provider for a new device grant.
//
// On success it returns the provider's response copied into a DeviceAuth.
// On failure the provider error is logged and returned unchanged.
func (s *Service) StartDeviceAuth(ctx context.Context) (*DeviceAuth, error) {
	s.scope.Debug("starting device authorization flow")

	grant, err := s.provider.AuthorizeDevice(ctx)
	if err != nil {
		s.scope.Error("failed to start device authorization", "error", err)
		return nil, err
	}

	s.scope.Debug("device authorization started",
		"user_code", grant.UserCode,
		"expires_in", grant.ExpiresIn,
		"interval", grant.Interval)

	// Copy field by field so callers depend on DeviceAuth rather than on the
	// provider's response type.
	auth := new(DeviceAuth)
	auth.UserCode = grant.UserCode
	auth.VerificationURI = grant.VerificationURI
	auth.VerificationURIComplete = grant.VerificationURIComplete
	auth.DeviceCode = grant.DeviceCode
	auth.ExpiresIn = grant.ExpiresIn
	auth.Interval = grant.Interval
	return auth, nil
}
// WaitForAuth polls the OAuth provider until the user completes
// authentication or an error occurs. It is a blocking call.
//
// The provider is polled with deviceCode once per interval. A non-positive
// interval falls back to 5 seconds, the default polling interval of the
// OAuth 2.0 device flow (RFC 8628), because time.NewTicker panics on a
// non-positive duration.
//
// Outcome of each poll:
//   - success: the tokens are saved to storage and returned in a Result.
//   - authorization pending: polling continues at the current interval.
//   - slow down: the interval is doubled and polling continues.
//   - expired device code or access denied: a descriptive error is returned.
//   - any other error: it is logged and returned unchanged.
//
// If ctx is done while waiting for the next poll, ctx.Err() is returned.
func (s *Service) WaitForAuth(ctx context.Context, deviceCode string, interval time.Duration) (*Result, error) {
	// Guard against a zero or negative interval (for example when the
	// provider omitted the optional value); NewTicker would panic on it.
	const fallbackInterval = 5 * time.Second
	if interval <= 0 {
		s.scope.Debug("non-positive polling interval, using default",
			"interval", interval, "default", fallbackInterval)
		interval = fallbackInterval
	}

	s.scope.Debug("starting authentication polling", "interval", interval)

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	currentInterval := interval

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()

		case <-ticker.C:
			resp, err := s.provider.PollAuthentication(ctx, deviceCode)
			if err == nil {
				// Success: persist the tokens before reporting back.
				s.scope.Info("authentication successful", "user_id", resp.User.ID, "email", resp.User.Email)
				if err := s.saveTokens(resp.AccessToken, resp.RefreshToken); err != nil {
					s.scope.Error("failed to save tokens", "error", err)
					return nil, err
				}
				return &Result{
					AccessToken:  resp.AccessToken,
					RefreshToken: resp.RefreshToken,
					User:         resp.User,
				}, nil
			}

			// Classify the failure by its concrete error type.
			var pendingErr *AuthorizationPendingError
			var slowDownErr *SlowDownError
			var expiredErr *ExpiredTokenError
			var deniedErr *AccessDeniedError
			switch {
			case errors.As(err, &pendingErr):
				// The user has not finished yet; keep polling.
				s.scope.Debug("authorization pending, continuing to poll")
				continue
			case errors.As(err, &slowDownErr):
				// The provider asked for fewer requests; back off.
				currentInterval = currentInterval * 2
				ticker.Reset(currentInterval)
				s.scope.Debug("slowing down polling", "new_interval", currentInterval)
				continue
			case errors.As(err, &expiredErr):
				s.scope.Error("device code expired")
				return nil, errors.New("device code expired - press 'r' to restart")
			case errors.As(err, &deniedErr):
				s.scope.Info("user denied authorization")
				return nil, errors.New("user denied authorization")
			default:
				// Unrecognized error: surface it to the caller as is.
				s.scope.Error("authentication polling failed", "error", err)
				return nil, err
			}
		}
	}
}
// IsAuthenticated reports whether a non-empty access token is present in
// storage. It does not parse the token or check its expiry. A storage error
// is logged and treated as "not authenticated".
func (s *Service) IsAuthenticated() bool {
	token, err := s.storage.Get("access_token")
	if err == nil {
		return token != ""
	}
	s.scope.Error("failed to check authentication", "error", err)
	return false
}
// GetUserID returns the WorkOS user ID, taken from the Sub claim of the
// current access token.
//
// The token is obtained through GetAccessToken, so an expired token is
// refreshed first. Errors from fetching or parsing the token are returned
// unchanged; an empty Sub claim yields its own error.
func (s *Service) GetUserID(ctx context.Context) (string, error) {
	accessToken, tokenErr := s.GetAccessToken(ctx)
	if tokenErr != nil {
		return "", tokenErr
	}

	claims, parseErr := ParseToken(accessToken)
	if parseErr != nil {
		return "", parseErr
	}

	if userID := claims.Sub; userID != "" {
		return userID, nil
	}
	return "", errors.New("token has no user ID")
}
// GetAccessToken returns a usable access token.
//
// The stored token is returned as is when it parses and has not expired.
// Otherwise (unparseable or expired) the stored refresh token is exchanged
// for a new pair, which is saved before the new access token is returned.
// If the old token carried an organization ID, the refresh keeps that
// organization scope.
//
// An error is returned when either token is missing from storage, or when
// reading, refreshing, or saving fails.
func (s *Service) GetAccessToken(ctx context.Context) (string, error) {
	stored, err := s.storage.Get("access_token")
	switch {
	case err != nil:
		s.scope.Error("failed to get access token", "error", err)
		return "", err
	case stored == "":
		return "", errors.New("no access token found")
	}

	// Remember the organization scope of the current token, if readable, and
	// return early while the token is still valid.
	var orgID domain.WorkosOrganizationID
	if claims, parseErr := ParseToken(stored); parseErr == nil {
		if claims.OrgID != "" {
			orgID = domain.WorkosOrganizationID(claims.OrgID)
		}
		if !claims.IsExpired() {
			return stored, nil
		}
	}

	// Reached for an expired token and for one that could not be parsed.
	s.scope.Debug("access token expired, refreshing")

	refreshToken, err := s.storage.Get("refresh_token")
	switch {
	case err != nil:
		s.scope.Error("failed to get refresh token", "error", err)
		return "", err
	case refreshToken == "":
		return "", errors.New("no refresh token found")
	}

	renewed, err := s.refreshTokenForScope(ctx, refreshToken, orgID)
	if err != nil {
		s.scope.Error("failed to refresh token", "error", err)
		return "", err
	}

	if saveErr := s.saveTokens(renewed.AccessToken, renewed.RefreshToken); saveErr != nil {
		s.scope.Error("failed to save refreshed tokens", "error", saveErr)
		return "", saveErr
	}

	s.scope.Debug("token refreshed successfully")
	return renewed.AccessToken, nil
}
// ForceRefreshAccessToken exchanges the refresh token for a new access token
// without consulting the local expiry check. It is meant for the case where
// a server rejects a token the client still considers valid (e.g. due to
// clock skew).
//
// When the currently stored access token can be read and parsed and carries
// an organization ID, the refresh keeps that organization scope; any problem
// reading or parsing it is ignored and the refresh proceeds unscoped. The
// new tokens are saved before the new access token is returned.
func (s *Service) ForceRefreshAccessToken(ctx context.Context) (string, error) {
	s.scope.Debug("force-refreshing access token")

	// Best effort: recover the organization scope from the current token.
	var orgID domain.WorkosOrganizationID
	current, getErr := s.storage.Get("access_token")
	if getErr == nil && current != "" {
		claims, parseErr := ParseToken(current)
		if parseErr == nil && claims.OrgID != "" {
			orgID = domain.WorkosOrganizationID(claims.OrgID)
		}
	}

	refreshToken, err := s.storage.Get("refresh_token")
	switch {
	case err != nil:
		s.scope.Error("failed to get refresh token", "error", err)
		return "", err
	case refreshToken == "":
		return "", errors.New("no refresh token found")
	}

	renewed, err := s.refreshTokenForScope(ctx, refreshToken, orgID)
	if err != nil {
		s.scope.Error("failed to refresh token", "error", err)
		return "", err
	}

	if saveErr := s.saveTokens(renewed.AccessToken, renewed.RefreshToken); saveErr != nil {
		s.scope.Error("failed to save refreshed tokens", "error", saveErr)
		return "", saveErr
	}

	s.scope.Debug("token force-refreshed successfully")
	return renewed.AccessToken, nil
}
// ClearTokens removes all stored authentication tokens.
//
// Both deletions are always attempted: a failure to delete the access token
// does not stop the refresh token from being deleted, so one storage error
// cannot leave the refresh token behind untouched. Each failure is logged.
// The returned error is the access-token error if that deletion failed,
// otherwise the refresh-token error, otherwise nil.
func (s *Service) ClearTokens() error {
	s.scope.Info("clearing authentication tokens")

	accessErr := s.storage.Delete("access_token")
	if accessErr != nil {
		s.scope.Error("failed to delete access token", "error", accessErr)
	}

	// Attempted regardless of the outcome above.
	refreshErr := s.storage.Delete("refresh_token")
	if refreshErr != nil {
		s.scope.Error("failed to delete refresh token", "error", refreshErr)
	}

	if accessErr != nil {
		return accessErr
	}
	if refreshErr != nil {
		return refreshErr
	}
	return nil
}
// RefreshTokenWithoutOrganization exchanges the stored refresh token for an
// access token that carries no organization scope. It serves bootstrap flows
// in which the user needs a user-scoped token to create an organization.
//
// The new token pair is saved to storage, and the new access token is
// returned so callers can update their API clients.
func (s *Service) RefreshTokenWithoutOrganization(ctx context.Context) (string, error) {
	s.scope.Debug("refreshing token without organization scope")

	stored, err := s.storage.Get("refresh_token")
	switch {
	case err != nil:
		s.scope.Error("failed to get refresh token", "error", err)
		return "", err
	case stored == "":
		return "", errors.New("no refresh token found")
	}

	renewed, err := s.provider.RefreshToken(ctx, stored)
	if err != nil {
		s.scope.Error("failed to refresh token without organization", "error", err)
		return "", err
	}

	if saveErr := s.saveTokens(renewed.AccessToken, renewed.RefreshToken); saveErr != nil {
		s.scope.Error("failed to save refreshed tokens", "error", saveErr)
		return "", saveErr
	}

	s.scope.Info("token refreshed without organization scope")
	return renewed.AccessToken, nil
}
// RefreshTokenWithOrganization exchanges the stored refresh token for an
// access token scoped to workosOrgID. It is used after an organization has
// been created or selected, to obtain a token with the org_id claim.
//
// The new token pair is saved to storage, and the new access token is
// returned so callers can update their API clients.
func (s *Service) RefreshTokenWithOrganization(ctx context.Context, workosOrgID domain.WorkosOrganizationID) (string, error) {
	s.scope.Debug("refreshing token with organization", "workos_org_id", workosOrgID)

	stored, err := s.storage.Get("refresh_token")
	switch {
	case err != nil:
		s.scope.Error("failed to get refresh token", "error", err)
		return "", err
	case stored == "":
		return "", errors.New("no refresh token found")
	}

	renewed, err := s.provider.RefreshTokenWithOrganization(ctx, stored, workosOrgID)
	if err != nil {
		s.scope.Error("failed to refresh token with organization", "error", err)
		return "", err
	}

	if saveErr := s.saveTokens(renewed.AccessToken, renewed.RefreshToken); saveErr != nil {
		s.scope.Error("failed to save refreshed tokens", "error", saveErr)
		return "", saveErr
	}

	s.scope.Info("token refreshed with organization scope", "workos_org_id", workosOrgID)
	return renewed.AccessToken, nil
}
// saveTokens writes the token pair to secure storage, the access token under
// "access_token" and then the refresh token under "refresh_token". It stops
// at the first write that fails and returns that error.
func (s *Service) saveTokens(accessToken, refreshToken string) error {
	entries := [...]struct{ key, value string }{
		{"access_token", accessToken},
		{"refresh_token", refreshToken},
	}
	for _, entry := range entries {
		if err := s.storage.Set(entry.key, entry.value); err != nil {
			return err
		}
	}
	return nil
}
// refreshTokenForScope exchanges refreshToken at the provider. An empty
// workosOrgID selects a plain refresh; a non-empty one selects an
// organization-scoped refresh so the new token keeps that organization.
func (s *Service) refreshTokenForScope(ctx context.Context, refreshToken string, workosOrgID domain.WorkosOrganizationID) (*RefreshResponse, error) {
	if workosOrgID == "" {
		return s.provider.RefreshToken(ctx, refreshToken)
	}

	s.scope.Debug("refreshing token with preserved organization scope", "workos_org_id", workosOrgID)
	return s.provider.RefreshTokenWithOrganization(ctx, refreshToken, workosOrgID)
}