Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ tasks:
generate:schema:
desc: Generate PowerSync and SQLite schema (uses local control plane)
cmds:
- go generate ./internal/powersync
- go generate ./internal/sqlite
- sed '/^-- Auto-generated/d; /^-- Run .task generate/d' ../control-plane/internal/powersync/schema.sql > internal/chat/tools/query_schema.sql
- doppler run -- go generate ./internal/powersync
- doppler run -- go generate ./internal/sqlite
- sed '/^-- Auto-generated/d; /^-- Run .task generate/d' internal/sqlite/schema.sql > internal/app/chattools/query_schema.sql

# ===========================================================================
# Build
Expand Down
635 changes: 256 additions & 379 deletions internal/app/chattools/query_schema.sql

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion internal/app/onboarding/preflight/preflight_effects.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,30 @@ import (

tea "charm.land/bubbletea/v2"

"github.com/usetero/cli/internal/auth"
"github.com/usetero/cli/internal/core/bootstrap"
"github.com/usetero/cli/internal/domain"
)

func (m *Model) checkAuth() tea.Cmd {
return func() tea.Msg {
hasValidAuth := false
var user *auth.User
if m.auth.IsAuthenticated() {
if _, err := m.auth.GetAccessToken(m.ctx); err == nil {
hasValidAuth = true
if userID, err := m.auth.GetUserID(m.ctx); err == nil && userID != "" {
user = &auth.User{ID: userID}
} else {
// Avoid getting stuck in sync with a valid token but no user identity.
_ = m.auth.ClearTokens()
hasValidAuth = false
}
} else {
_ = m.auth.ClearTokens()
}
}
return preflightAuthCheckCompletedMsg{hasValidAuth: hasValidAuth}
return preflightAuthCheckCompletedMsg{hasValidAuth: hasValidAuth, user: user}
}
}

Expand Down
69 changes: 69 additions & 0 deletions internal/app/onboarding/preflight/preflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ import (
"errors"
"testing"

"github.com/usetero/cli/internal/auth/authtest"
graphql "github.com/usetero/cli/internal/boundary/graphql"
"github.com/usetero/cli/internal/core/bootstrap"
"github.com/usetero/cli/internal/domain"
"github.com/usetero/cli/internal/log/logtest"
"github.com/usetero/cli/internal/preferences/preferencestest"
"github.com/usetero/cli/internal/styles"
)

func TestResolveOrg(t *testing.T) {
Expand Down Expand Up @@ -70,3 +75,67 @@ func TestPreflightOutcomeForError(t *testing.T) {
t.Fatalf("expected inconclusive outcome for generic error, got %v", outcome)
}
}

func TestCheckAuthHydratesUserWhenTokenValid(t *testing.T) {
t.Parallel()

auth := &authtest.MockAuth{
IsAuthenticatedFunc: func() bool { return true },
GetAccessTokenFunc: func(context.Context) (string, error) { return "token", nil },
GetUserIDFunc: func(context.Context) (string, error) { return "user-1", nil },
}

m := New(
context.Background(),
styles.NewTheme(true),
graphql.ServiceSet{},
auth,
preferencestest.NewMockUserPreferences(),
preferencestest.NewMockOrgPreferences(),
logtest.NewScope(t),
)

msg := m.checkAuth()().(preflightAuthCheckCompletedMsg)
if !msg.hasValidAuth {
t.Fatal("expected valid auth")
}
if msg.user == nil || msg.user.ID != "user-1" {
t.Fatalf("expected hydrated user id user-1, got %#v", msg.user)
}
}

func TestCheckAuthClearsInvalidAuthWhenUserIDMissing(t *testing.T) {
t.Parallel()

cleared := false
auth := &authtest.MockAuth{
IsAuthenticatedFunc: func() bool { return true },
GetAccessTokenFunc: func(context.Context) (string, error) { return "token", nil },
GetUserIDFunc: func(context.Context) (string, error) { return "", errors.New("missing sub") },
ClearTokensFunc: func() error {
cleared = true
return nil
},
}

m := New(
context.Background(),
styles.NewTheme(true),
graphql.ServiceSet{},
auth,
preferencestest.NewMockUserPreferences(),
preferencestest.NewMockOrgPreferences(),
logtest.NewScope(t),
)

msg := m.checkAuth()().(preflightAuthCheckCompletedMsg)
if msg.hasValidAuth {
t.Fatal("expected invalid auth when user id cannot be resolved")
}
if msg.user != nil {
t.Fatalf("expected nil user, got %#v", msg.user)
}
if !cleared {
t.Fatal("expected tokens to be cleared")
}
}
2 changes: 2 additions & 0 deletions internal/app/onboarding/preflight/preflight_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package preflight

import (
"github.com/usetero/cli/internal/auth"
"github.com/usetero/cli/internal/core/bootstrap"
"github.com/usetero/cli/internal/domain"
)
Expand All @@ -11,6 +12,7 @@ type preflightResolutionCompletedMsg struct {

type preflightAuthCheckCompletedMsg struct {
hasValidAuth bool
user *auth.User
}

type preflightOrganizationsLoadedMsg struct {
Expand Down
1 change: 1 addition & 0 deletions internal/app/onboarding/preflight/preflight_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func (m *Model) Update(msg tea.Msg) tea.Cmd {

func (m *Model) handleAuthChecked(msg preflightAuthCheckCompletedMsg) tea.Cmd {
m.state.HasValidAuth = msg.hasValidAuth
m.state.User = msg.user
if !m.state.HasValidAuth {
return m.emitResult()
}
Expand Down
24 changes: 22 additions & 2 deletions internal/auth/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,13 @@ func (s *Service) GetAccessToken(ctx context.Context) (string, error) {
return "", errors.New("no access token found")
}

workosOrgID := domain.WorkosOrganizationID("")

// Check if token is expired
claims, err := ParseToken(accessToken)
if err == nil && claims.OrgID != "" {
workosOrgID = domain.WorkosOrganizationID(claims.OrgID)
}
if err != nil || claims.IsExpired() {
s.scope.Debug("access token expired, refreshing")
refreshToken, err := s.storage.Get("refresh_token")
Expand All @@ -213,7 +218,7 @@ func (s *Service) GetAccessToken(ctx context.Context) (string, error) {
return "", errors.New("no refresh token found")
}

resp, err := s.provider.RefreshToken(ctx, refreshToken)
resp, err := s.refreshTokenForScope(ctx, refreshToken, workosOrgID)
if err != nil {
s.scope.Error("failed to refresh token", "error", err)
return "", err
Expand All @@ -237,6 +242,13 @@ func (s *Service) GetAccessToken(ctx context.Context) (string, error) {
func (s *Service) ForceRefreshAccessToken(ctx context.Context) (string, error) {
s.scope.Debug("force-refreshing access token")

workosOrgID := domain.WorkosOrganizationID("")
if accessToken, err := s.storage.Get("access_token"); err == nil && accessToken != "" {
if claims, parseErr := ParseToken(accessToken); parseErr == nil && claims.OrgID != "" {
workosOrgID = domain.WorkosOrganizationID(claims.OrgID)
}
}

refreshToken, err := s.storage.Get("refresh_token")
if err != nil {
s.scope.Error("failed to get refresh token", "error", err)
Expand All @@ -246,7 +258,7 @@ func (s *Service) ForceRefreshAccessToken(ctx context.Context) (string, error) {
return "", errors.New("no refresh token found")
}

resp, err := s.provider.RefreshToken(ctx, refreshToken)
resp, err := s.refreshTokenForScope(ctx, refreshToken, workosOrgID)
if err != nil {
s.scope.Error("failed to refresh token", "error", err)
return "", err
Expand Down Expand Up @@ -345,3 +357,11 @@ func (s *Service) saveTokens(accessToken, refreshToken string) error {
}
return nil
}

func (s *Service) refreshTokenForScope(ctx context.Context, refreshToken string, workosOrgID domain.WorkosOrganizationID) (*RefreshResponse, error) {
if workosOrgID != "" {
s.scope.Debug("refreshing token with preserved organization scope", "workos_org_id", workosOrgID)
return s.provider.RefreshTokenWithOrganization(ctx, refreshToken, workosOrgID)
}
return s.provider.RefreshToken(ctx, refreshToken)
}
104 changes: 104 additions & 0 deletions internal/auth/auth_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/usetero/cli/internal/auth"
"github.com/usetero/cli/internal/auth/authtest"
"github.com/usetero/cli/internal/domain"
"github.com/usetero/cli/internal/log/logtest"
)

Expand Down Expand Up @@ -108,6 +109,52 @@ func TestService_GetAccessToken(t *testing.T) {
t.Fatal("expected error, got nil")
}
})

t.Run("preserves org scope when refreshing expired token", func(t *testing.T) {
t.Parallel()
expiredOrgToken := makeTestTokenWithOrg(time.Now().Add(-10*time.Minute), "org_123")
newToken := makeTestTokenWithOrg(time.Now().Add(10*time.Minute), "org_123")

storage := &authtest.MockSecureStorage{
GetFunc: func(key string) (string, error) {
switch key {
case "access_token":
return expiredOrgToken, nil
case "refresh_token":
return "refresh_token_value", nil
}
return "", nil
},
}

provider := &authtest.MockOAuthProvider{
RefreshTokenWithOrganizationFunc: func(ctx context.Context, refreshToken string, organizationID domain.WorkosOrganizationID) (*auth.RefreshResponse, error) {
if refreshToken != "refresh_token_value" {
t.Errorf("unexpected refresh token: %s", refreshToken)
}
if organizationID != "org_123" {
t.Errorf("unexpected organization ID: %s", organizationID)
}
return &auth.RefreshResponse{
AccessToken: newToken,
RefreshToken: "new_refresh_token",
}, nil
},
RefreshTokenFunc: func(ctx context.Context, refreshToken string) (*auth.RefreshResponse, error) {
t.Fatal("expected org-scoped refresh path")
return nil, nil
},
}

svc := auth.NewService(provider, storage, logtest.NewScope(t))
token, err := svc.GetAccessToken(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token != newToken {
t.Errorf("got %q, want %q", token, newToken)
}
})
}

func TestService_ForceRefreshAccessToken(t *testing.T) {
Expand Down Expand Up @@ -218,6 +265,52 @@ func TestService_ForceRefreshAccessToken(t *testing.T) {
t.Fatal("expected error, got nil")
}
})

t.Run("preserves org scope on force refresh", func(t *testing.T) {
t.Parallel()
currentOrgToken := makeTestTokenWithOrg(time.Now().Add(10*time.Minute), "org_abc")
newToken := makeTestTokenWithOrg(time.Now().Add(20*time.Minute), "org_abc")

storage := &authtest.MockSecureStorage{
GetFunc: func(key string) (string, error) {
switch key {
case "access_token":
return currentOrgToken, nil
case "refresh_token":
return "refresh_token_value", nil
}
return "", nil
},
}

provider := &authtest.MockOAuthProvider{
RefreshTokenWithOrganizationFunc: func(ctx context.Context, refreshToken string, organizationID domain.WorkosOrganizationID) (*auth.RefreshResponse, error) {
if refreshToken != "refresh_token_value" {
t.Errorf("unexpected refresh token: %s", refreshToken)
}
if organizationID != "org_abc" {
t.Errorf("unexpected organization ID: %s", organizationID)
}
return &auth.RefreshResponse{
AccessToken: newToken,
RefreshToken: "new_refresh_token",
}, nil
},
RefreshTokenFunc: func(ctx context.Context, refreshToken string) (*auth.RefreshResponse, error) {
t.Fatal("expected org-scoped refresh path")
return nil, nil
},
}

svc := auth.NewService(provider, storage, logtest.NewScope(t))
token, err := svc.ForceRefreshAccessToken(context.Background())
if err != nil {
t.Fatalf("ForceRefreshAccessToken error: %v", err)
}
if token != newToken {
t.Errorf("got %q, want %q", token, newToken)
}
})
}

func TestService_RefreshTokenWithoutOrganization(t *testing.T) {
Expand Down Expand Up @@ -299,3 +392,14 @@ func makeTestToken(exp time.Time) string {
sig := base64.RawURLEncoding.EncodeToString([]byte("signature"))
return header + "." + payloadEnc + "." + sig
}

func makeTestTokenWithOrg(exp time.Time, orgID string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
payload, _ := json.Marshal(map[string]interface{}{
"exp": exp.Unix(),
"org_id": orgID,
})
payloadEnc := base64.RawURLEncoding.EncodeToString(payload)
sig := base64.RawURLEncoding.EncodeToString([]byte("signature"))
return header + "." + payloadEnc + "." + sig
}
16 changes: 8 additions & 8 deletions internal/boundary/graphql/apitest/mock_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ type MockClient struct {
CreateMessageFunc func(ctx context.Context, input gen.CreateMessageInput) (*gen.CreateMessageResponse, error)
EnableServiceFunc func(ctx context.Context, serviceID string) (*gen.EnableServiceResponse, error)
DisableServiceFunc func(ctx context.Context, serviceID string) (*gen.DisableServiceResponse, error)
ApproveLogEventPolicyFunc func(ctx context.Context, policyID string) (*gen.ApproveLogEventPolicyResponse, error)
DismissLogEventPolicyFunc func(ctx context.Context, policyID string) (*gen.DismissLogEventPolicyResponse, error)
ApproveLogEventRecommendationFunc func(ctx context.Context, recommendationID string, targets gen.EnforcementTargetInput) (*gen.ApproveLogEventRecommendationResponse, error)
DismissLogEventRecommendationFunc func(ctx context.Context, recommendationID string) (*gen.DismissLogEventRecommendationResponse, error)
}

// NewMockClient creates a MockClient with sensible defaults.
Expand Down Expand Up @@ -165,16 +165,16 @@ func (m *MockClient) DisableService(ctx context.Context, serviceID string) (*gen
return nil, nil
}

func (m *MockClient) ApproveLogEventPolicy(ctx context.Context, policyID string) (*gen.ApproveLogEventPolicyResponse, error) {
if m.ApproveLogEventPolicyFunc != nil {
return m.ApproveLogEventPolicyFunc(ctx, policyID)
func (m *MockClient) ApproveLogEventRecommendation(ctx context.Context, recommendationID string, targets gen.EnforcementTargetInput) (*gen.ApproveLogEventRecommendationResponse, error) {
if m.ApproveLogEventRecommendationFunc != nil {
return m.ApproveLogEventRecommendationFunc(ctx, recommendationID, targets)
}
return nil, nil
}

func (m *MockClient) DismissLogEventPolicy(ctx context.Context, policyID string) (*gen.DismissLogEventPolicyResponse, error) {
if m.DismissLogEventPolicyFunc != nil {
return m.DismissLogEventPolicyFunc(ctx, policyID)
func (m *MockClient) DismissLogEventRecommendation(ctx context.Context, recommendationID string) (*gen.DismissLogEventRecommendationResponse, error) {
if m.DismissLogEventRecommendationFunc != nil {
return m.DismissLogEventRecommendationFunc(ctx, recommendationID)
}
return nil, nil
}
Loading
Loading