diff --git a/packages/assets-controller/CHANGELOG.md b/packages/assets-controller/CHANGELOG.md index 2739ca21e15..df94c65ed79 100644 --- a/packages/assets-controller/CHANGELOG.md +++ b/packages/assets-controller/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Add `createParallelBalanceMiddleware(sources, options?)` to run balance sources in parallel with chain distribution (same strategy as subscription): each chain is assigned to one source via `getActiveChains`, so there is no overlap. Uses `Promise.allSettled` so failures in one source do not block others. Optional `options.fallbackMiddlewares` run only for remaining chains (no balance after primary run) + +### Changed + +- **BREAKING:** Require `previousChains` in `handleActiveChainsUpdate(dataSourceId, activeChains, previousChains)` and in the `onActiveChainsUpdated` callback used by data sources; the third parameter is no longer optional. Callers and data sources must pass the previous chain list for correct added/removed chain diff computation ([#7867](https://github.com/MetaMask/core/pull/7867)) + ### Removed - **BREAKING:** Remove `initDataSources` and related exports (`InitDataSourcesOptions`, `DataSources`, `DataSourceActions`, `DataSourceEvents`, `DataSourceAllowedActions`, `DataSourceAllowedEvents`, `RootMessenger`). Initialize assets by creating `AssetsController` with `queryApiClient`; the controller instantiates all data sources internally ([#7859](https://github.com/MetaMask/core/pull/7859)) diff --git a/packages/assets-controller/src/AssetsController.test.ts b/packages/assets-controller/src/AssetsController.test.ts index 404de13eb61..283065628a5 100644 --- a/packages/assets-controller/src/AssetsController.test.ts +++ b/packages/assets-controller/src/AssetsController.test.ts @@ -384,17 +384,6 @@ describe('AssetsController', () => { }); }); - describe('registerDataSources', () => { - it('registers data sources in constructor', async () => { - await withController(({ controller }) => { - // The controller registers these data sources in the constructor: - // 'BackendWebsocketDataSource', 'AccountsApiDataSource', 'SnapDataSource', 'RpcDataSource' - // We verify initialization completed without error - expect(controller.state).toBeDefined(); - }); - }); - }); - describe('getAssetMetadata', () => { it('returns metadata for existing asset', async () => { const initialState: Partial = { @@ -489,7 +478,7 @@ describe('AssetsController', () => { describe('handleActiveChainsUpdate', () => { it('updates data source chains', async () => { await withController(({ controller }) => { - controller.handleActiveChainsUpdate('TestDataSource', ['eip155:1']); + controller.handleActiveChainsUpdate('TestDataSource', ['eip155:1'], []); // Should not throw expect(controller.state).toBeDefined(); @@ -498,7 +487,7 @@ describe('AssetsController', () => { it('handles empty chains array', async () => { await withController(({ controller }) => { - controller.handleActiveChainsUpdate('TestDataSource', []); + controller.handleActiveChainsUpdate('TestDataSource', [], []); expect(controller.state).toBeDefined(); }); @@ -507,10 +496,10 @@ describe('AssetsController', () => { it('triggers fetch when chains are added', async () => { await withController(async ({ controller }) => { // First set no chains - controller.handleActiveChainsUpdate('TestDataSource', []); + controller.handleActiveChainsUpdate('TestDataSource', [], []); // Then add chains - this should trigger fetch for added chains - controller.handleActiveChainsUpdate('TestDataSource', ['eip155:1']); + controller.handleActiveChainsUpdate('TestDataSource', ['eip155:1'], []); // Allow async operations to complete await new Promise(process.nextTick); diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index f6d25a358db..96661eeaca5 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -30,7 +30,11 @@ import BigNumberJS from 'bignumber.js'; import { isEqual } from 'lodash'; import type { AssetsControllerMethodActions } from './AssetsController-method-action-types'; -import type { SubscriptionRequest } from './data-sources/AbstractDataSource'; +import type { + AbstractDataSource, + DataSourceState, + SubscriptionRequest, +} from './data-sources/AbstractDataSource'; import { AccountsApiDataSource } from './data-sources/AccountsApiDataSource'; import { BackendWebsocketDataSource } from './data-sources/BackendWebsocketDataSource'; import { PriceDataSource } from './data-sources/PriceDataSource'; @@ -39,6 +43,7 @@ import { RpcDataSource } from './data-sources/RpcDataSource'; import { SnapDataSource } from './data-sources/SnapDataSource'; import { TokenDataSource } from './data-sources/TokenDataSource'; import { projectLogger, createModuleLogger } from './logger'; +import { createParallelBalanceMiddleware } from './middlewares/parallelBalanceMiddleware'; import { DetectionMiddleware } from './middlewares/DetectionMiddleware'; import type { AccountId, @@ -55,7 +60,6 @@ import type { DataResponse, NextFunction, Middleware, - DataSourceDefinition, SubscriptionResponse, Asset, AssetsControllerStateInternal, @@ -383,13 +387,6 @@ export class AssetsController extends BaseController< ); } - /** - * Registered data sources with their available chains. - * Updated continuously and independently from subscription flows. - * Key: sourceId, Value: Set of currently available chainIds - */ - readonly #dataSources: Map> = new Map(); - readonly #backendWebsocketDataSource: BackendWebsocketDataSource; readonly #accountsApiDataSource: AccountsApiDataSource; @@ -398,6 +395,25 @@ export class AssetsController extends BaseController< readonly #rpcDataSource: RpcDataSource; + /** + * Subscription balance data sources in assignment priority order (first that supports a chain gets it). + * + * @returns The four balance data source instances in priority order. + */ + get #subscriptionBalanceDataSources(): [ + BackendWebsocketDataSource, + AccountsApiDataSource, + SnapDataSource, + RpcDataSource, + ] { + return [ + this.#backendWebsocketDataSource, + this.#accountsApiDataSource, + this.#snapDataSource, + this.#rpcDataSource, + ]; + } + readonly #priceDataSource: PriceDataSource; readonly #detectionMiddleware: DetectionMiddleware; @@ -426,27 +442,29 @@ export class AssetsController extends BaseController< this.#defaultUpdateInterval = defaultUpdateInterval; const rpcConfig = rpcDataSourceConfig ?? {}; + const onActiveChainsUpdated = ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ): void => + this.handleActiveChainsUpdate(dataSourceName, chains, previousChains); + this.#backendWebsocketDataSource = new BackendWebsocketDataSource({ messenger: this.messenger, queryApiClient, - onActiveChainsUpdated: (chains): void => - this.handleActiveChainsUpdate('BackendWebsocketDataSource', chains), + onActiveChainsUpdated, }); this.#accountsApiDataSource = new AccountsApiDataSource({ queryApiClient, - onActiveChainsUpdated: (chains): void => { - this.handleActiveChainsUpdate('AccountsApiDataSource', chains); - }, + onActiveChainsUpdated, }); this.#snapDataSource = new SnapDataSource({ messenger: this.messenger, - onActiveChainsUpdated: (chains): void => - this.handleActiveChainsUpdate('SnapDataSource', chains), + onActiveChainsUpdated, }); this.#rpcDataSource = new RpcDataSource({ messenger: this.messenger, - onActiveChainsUpdated: (chains): void => - this.handleActiveChainsUpdate('RpcDataSource', chains), + onActiveChainsUpdated, ...rpcConfig, }); this.#tokenDataSource = new TokenDataSource({ @@ -457,11 +475,6 @@ export class AssetsController extends BaseController< }); this.#detectionMiddleware = new DetectionMiddleware(); - this.#dataSources.set('BackendWebsocketDataSource', new Set()); - this.#dataSources.set('AccountsApiDataSource', new Set()); - this.#dataSources.set('SnapDataSource', new Set()); - this.#dataSources.set('RpcDataSource', new Set()); - if (!this.#isEnabled) { log('AssetsController is disabled, skipping initialization'); return; @@ -476,34 +489,6 @@ export class AssetsController extends BaseController< this.#registerActionHandlers(); } - /** - * Returns the balance data source instance for subscribe/unsubscribe by sourceId. - * - * @param sourceId - Data source identifier (e.g. 'BackendWebsocketDataSource'). - * @returns The balance data source instance, or undefined if not found. - */ - #getBalanceDataSource( - sourceId: string, - ): - | BackendWebsocketDataSource - | AccountsApiDataSource - | SnapDataSource - | RpcDataSource - | undefined { - switch (sourceId) { - case 'BackendWebsocketDataSource': - return this.#backendWebsocketDataSource; - case 'AccountsApiDataSource': - return this.#accountsApiDataSource; - case 'SnapDataSource': - return this.#snapDataSource; - case 'RpcDataSource': - return this.#rpcDataSource; - default: - return undefined; - } - } - // ============================================================================ // INITIALIZATION // ============================================================================ @@ -603,44 +588,27 @@ export class AssetsController extends BaseController< ); } - // ============================================================================ - // DATA SOURCE MANAGEMENT - // ============================================================================ - - /** - * Register data sources with the controller. - * Order of the array determines subscription order. - * - * Data sources report chain changes via the onActiveChainsUpdated callback passed at construction. - * - * @param dataSourceIds - Array of data source identifiers to register. - */ - registerDataSources(dataSourceIds: DataSourceDefinition[]): void { - for (const id of dataSourceIds) { - log('Registering data source', { id }); - - // Initialize available chains tracking for this source - this.#dataSources.set(id, new Set()); - } - } - // ============================================================================ // DATA SOURCE CHAIN MANAGEMENT // ============================================================================ /** - * Handle when a data source's active chains change. - * Active chains are chains that are both supported AND available. - * Updates centralized chain tracking and triggers re-selection if needed. + * Handle when a data source's supported chains change. + * Used to refresh balance subscriptions and run a one-time fetch when a new chain is supported. + * + * - On any add/remove: re-subscribes to data sources so chain assignment stays correct. + * - When chains are added: fetches balances for the new chains (for selected accounts on enabled networks). * - * Called from the onActiveChainsUpdated callbacks passed to data sources at construction. + * Controller does not store chains; sources report via this callback. previousChains is required for diff. * * @param dataSourceId - The identifier of the data source reporting the change. - * @param activeChains - Array of currently active chain IDs for this source. + * @param activeChains - Currently active (supported and available) chain IDs for this source. + * @param previousChains - Previous chains; used to compute added/removed. */ handleActiveChainsUpdate( dataSourceId: string, activeChains: ChainId[], + previousChains: ChainId[], ): void { log('Data source active chains changed', { dataSourceId, @@ -648,30 +616,15 @@ export class AssetsController extends BaseController< chains: activeChains, }); - // When BackendWebsocketDataSource is updated via AccountsApiDataSource callback, sync its state - if (dataSourceId === 'BackendWebsocketDataSource') { - this.#backendWebsocketDataSource.setActiveChainsFromAccountsApi( - activeChains, - ); - } - - const previousChains = this.#dataSources.get(dataSourceId) ?? new Set(); - const newChains = new Set(activeChains); + const previous: ChainId[] = previousChains; - // Update centralized available chains tracking - this.#dataSources.set(dataSourceId, newChains); - - // Check for changes - const addedChains = activeChains.filter( - (chain) => !previousChains.has(chain), - ); - const removedChains = Array.from(previousChains).filter( - (chain) => !newChains.has(chain), - ); + const previousSet = new Set(previous); + const addedChains = activeChains.filter((ch) => !previousSet.has(ch)); + const removedChains = previous.filter((ch) => !activeChains.includes(ch)); if (addedChains.length > 0 || removedChains.length > 0) { // Refresh subscriptions to use updated data source availability - this.#subscribeToDataSources(); + this.#subscribeAssets(); } // If chains were added and we have selected accounts, do one-time fetch @@ -768,9 +721,28 @@ export class AssetsController extends BaseController< }); const response = await this.#executeMiddlewares( [ - this.#accountsApiDataSource.assetsMiddleware, - this.#snapDataSource.assetsMiddleware, - this.#rpcDataSource.assetsMiddleware, + createParallelBalanceMiddleware( + [ + { + middleware: this.#accountsApiDataSource.assetsMiddleware, + getActiveChains: () => + this.#accountsApiDataSource.getActiveChainsSync(), + }, + { + middleware: this.#snapDataSource.assetsMiddleware, + getActiveChains: () => + this.#snapDataSource.getActiveChainsSync(), + }, + { + middleware: this.#rpcDataSource.assetsMiddleware, + getActiveChains: () => + this.#rpcDataSource.getActiveChainsSync(), + }, + ], + { + fallbackMiddlewares: [this.#rpcDataSource.assetsMiddleware], + }, + ), this.#detectionMiddleware.assetsMiddleware, this.#tokenDataSource.assetsMiddleware, this.#priceDataSource.assetsMiddleware, @@ -981,48 +953,6 @@ export class AssetsController extends BaseController< // SUBSCRIPTIONS // ============================================================================ - /** - * Assign chains to data sources based on availability. - * Returns a map of sourceId -> chains to handle. - * - * @param requestedChains - Array of chain IDs to assign to data sources. - * @returns Map of sourceId to array of assigned chain IDs. - */ - #assignChainsToDataSources( - requestedChains: ChainId[], - ): Map { - const assignment = new Map(); - const remainingChains = new Set(requestedChains); - - for (const sourceId of this.#dataSources.keys()) { - // Get available chains for this data source - const availableChains = this.#dataSources.get(sourceId); - if (!availableChains || availableChains.size === 0) { - continue; - } - - const chainsForThisSource: ChainId[] = []; - - for (const chainId of remainingChains) { - // Check if this chain is available on this source - if (availableChains.has(chainId)) { - chainsForThisSource.push(chainId); - remainingChains.delete(chainId); - } - } - - if (chainsForThisSource.length > 0) { - assignment.set(sourceId, chainsForThisSource); - log('Assigned chains to data source', { - sourceId, - chains: chainsForThisSource, - }); - } - } - - return assignment; - } - /** * Subscribe to price updates for all assets held by the given accounts. * Polls PriceDataSource which fetches prices from balance state. @@ -1369,7 +1299,7 @@ export class AssetsController extends BaseController< enabledChainCount: this.#enabledChains.size, }); - this.#subscribeToDataSources(); + this.#subscribeAssets(); if (this.#selectedAccounts.length > 0) { this.getAssets(this.#selectedAccounts, { chainIds: [...this.#enabledChains], @@ -1398,10 +1328,14 @@ export class AssetsController extends BaseController< // Convert to array first to avoid modifying map during iteration const subscriptionKeys = [...this.#activeSubscriptions.keys()]; for (const subscriptionKey of subscriptionKeys) { - // Extract sourceId from subscription key (format: "ds:${sourceId}") if (subscriptionKey.startsWith('ds:')) { const sourceId = subscriptionKey.slice(3); - this.#unsubscribeDataSource(sourceId); + const source = this.#subscriptionBalanceDataSources.find( + (ds) => ds.getName() === sourceId, + ); + if (source) { + this.#unsubscribeDataSource(source); + } } } this.#activeSubscriptions.clear(); @@ -1410,20 +1344,22 @@ export class AssetsController extends BaseController< /** * Subscribe to asset updates for all selected accounts. */ - #subscribeToDataSources(): void { + #subscribeAssets(): void { if (this.#selectedAccounts.length === 0) { return; } // Subscribe to balance updates (batched by data source) - this.#subscribeAssetsBalance(); + this.#subscribeAssetsBalance(this.#selectedAccounts, [ + ...this.#enabledChains, + ]); // Subscribe to price updates for all assets held by selected accounts this.subscribeAssetsPrice(this.#selectedAccounts, [...this.#enabledChains]); } /** - * Subscribe to balance updates for all selected accounts. + * Subscribe to balance updates for the given accounts and chains. * * Strategy to minimize data source calls: * 1. Collect all chains to subscribe based on enabled networks @@ -1432,52 +1368,46 @@ export class AssetsController extends BaseController< * * This ensures we make minimal subscriptions to each data source while covering * all accounts and chains. + * + * @param accounts - Accounts to subscribe balance updates for. + * @param chainIds - Chain IDs to subscribe for. */ - #subscribeAssetsBalance(): void { - // Step 1: Build chain -> accounts mapping based on account scopes and enabled networks + #subscribeAssetsBalance( + accounts: InternalAccount[], + chainIds: ChainId[], + ): void { const chainToAccounts = this.#buildChainToAccountsMap( - this.#selectedAccounts, - this.#enabledChains, + accounts, + new Set(chainIds), ); - - // Step 2: Split by data source active chains (ordered by priority) - // Get all chains that need to be subscribed const remainingChains = new Set(chainToAccounts.keys()); - // Assign chains to data sources based on availability (ordered by priority) - const chainAssignment = this.#assignChainsToDataSources( - Array.from(remainingChains), - ); - - log('Subscribe - chain assignment', { - totalChains: remainingChains.size, - dataSourceAssignments: Array.from(chainAssignment.entries()).map( - ([sourceId, chains]) => ({ sourceId, chainCount: chains.length }), - ), - }); - - // Subscribe to each data source with its assigned chains and relevant accounts - for (const sourceId of this.#dataSources.keys()) { - const assignedChains = chainAssignment.get(sourceId); + for (const source of this.#subscriptionBalanceDataSources) { + const availableChains = new Set(source.getActiveChainsSync()); + const assignedChains: ChainId[] = []; - if (!assignedChains || assignedChains.length === 0) { - // Unsubscribe from data sources with no assigned chains - this.#unsubscribeDataSource(sourceId); - continue; + for (const chainId of remainingChains) { + if (availableChains.has(chainId)) { + assignedChains.push(chainId); + remainingChains.delete(chainId); + } } - // Collect unique accounts that need any of the assigned chains - const accountsForSource = this.#getAccountsForChains( - assignedChains, - chainToAccounts, - ); - - if (accountsForSource.length === 0) { + if (assignedChains.length === 0) { + this.#unsubscribeDataSource(source); continue; } - // Subscribe with ONE call per data source - this.#subscribeToDataSource(sourceId, accountsForSource, assignedChains); + const seenIds = new Set(); + const accountsForSource = assignedChains + .flatMap((chainId) => chainToAccounts.get(chainId) ?? []) + .filter( + (account) => + !seenIds.has(account.id) && (seenIds.add(account.id), true), + ); + if (accountsForSource.length > 0) { + this.#subscribeDataSource(source, accountsForSource, assignedChains); + } } } @@ -1494,64 +1424,36 @@ export class AssetsController extends BaseController< chainsToSubscribe: Set, ): Map { const chainToAccounts = new Map(); - for (const account of accounts) { - const accountChains = this.#getEnabledChainsForAccount(account); - - for (const chainId of accountChains) { + for (const chainId of this.#getEnabledChainsForAccount(account)) { if (!chainsToSubscribe.has(chainId)) { continue; } - - const existingAccounts = chainToAccounts.get(chainId) ?? []; - existingAccounts.push(account); - chainToAccounts.set(chainId, existingAccounts); - } - } - - return chainToAccounts; - } - - /** - * Get unique accounts that need any of the specified chains. - * - * @param chains - Array of chain IDs to find accounts for. - * @param chainToAccounts - Map of chainId to accounts. - * @returns Array of unique accounts that need any of the specified chains. - */ - #getAccountsForChains( - chains: ChainId[], - chainToAccounts: Map, - ): InternalAccount[] { - const accountIds = new Set(); - const accounts: InternalAccount[] = []; - - for (const chainId of chains) { - const chainAccounts = chainToAccounts.get(chainId) ?? []; - for (const account of chainAccounts) { - if (!accountIds.has(account.id)) { - accountIds.add(account.id); - accounts.push(account); + let list = chainToAccounts.get(chainId); + if (!list) { + list = []; + chainToAccounts.set(chainId, list); } + list.push(account); } } - - return accounts; + return chainToAccounts; } /** * Subscribe to a specific data source with accounts and chains. - * Uses the data source ID as the subscription key for batching. + * Uses the data source name as the subscription key for batching. * - * @param sourceId - The data source identifier. + * @param source - The balance data source instance. * @param accounts - Array of accounts to subscribe for. * @param chains - Array of chain IDs to subscribe for. */ - #subscribeToDataSource( - sourceId: string, + #subscribeDataSource( + source: AbstractDataSource, accounts: InternalAccount[], chains: ChainId[], ): void { + const sourceId = source.getName(); const subscriptionKey = `ds:${sourceId}`; const existingSubscription = this.#activeSubscriptions.get(subscriptionKey); const isUpdate = existingSubscription !== undefined; @@ -1576,11 +1478,7 @@ export class AssetsController extends BaseController< getAssetsState: () => this.state, }; - const balanceDs = this.#getBalanceDataSource(sourceId); - if (!balanceDs) { - return; - } - balanceDs.subscribe(subscribeReq).catch((error) => { + source.subscribe(subscribeReq).catch((error) => { console.error( `[AssetsController] Failed to subscribe to '${sourceId}':`, error, @@ -1604,17 +1502,16 @@ export class AssetsController extends BaseController< /** * Unsubscribe from a data source if we have an active subscription. * - * @param sourceId - The data source identifier to unsubscribe from. + * @param source - The balance data source instance to unsubscribe from. */ - #unsubscribeDataSource(sourceId: string): void { - const subscriptionKey = `ds:${sourceId}`; + #unsubscribeDataSource( + source: AbstractDataSource, + ): void { + const subscriptionKey = `ds:${source.getName()}`; const existingSubscription = this.#activeSubscriptions.get(subscriptionKey); if (existingSubscription) { - const balanceDs = this.#getBalanceDataSource(sourceId); - if (balanceDs) { - balanceDs.unsubscribe(subscriptionKey).catch(() => undefined); - } + source.unsubscribe(subscriptionKey).catch(() => undefined); existingSubscription.unsubscribe(); } } @@ -1699,7 +1596,7 @@ export class AssetsController extends BaseController< }); // Subscribe and fetch for the new account group - this.#subscribeToDataSources(); + this.#subscribeAssets(); if (accounts.length > 0) { await this.getAssets(accounts, { chainIds: [...this.#enabledChains], @@ -1742,7 +1639,7 @@ export class AssetsController extends BaseController< // The data will simply not be updated until the network is re-enabled. // Refresh subscriptions for new chain set - this.#subscribeToDataSources(); + this.#subscribeAssets(); // Do one-time fetch for newly enabled chains if (addedChains.length > 0 && this.#selectedAccounts.length > 0) { @@ -1798,7 +1695,7 @@ export class AssetsController extends BaseController< destroy(): void { log('Destroying AssetsController', { - dataSourceCount: this.#dataSources.size, + dataSourceCount: this.#subscriptionBalanceDataSources.length, subscriptionCount: this.#activeSubscriptions.size, }); @@ -1815,9 +1712,6 @@ export class AssetsController extends BaseController< (this.#rpcDataSource as { destroy: () => void }).destroy(); } - // Clear data sources - this.#dataSources.clear(); - // Stop all active subscriptions this.#stop(); diff --git a/packages/assets-controller/src/README.md b/packages/assets-controller/src/README.md index f70ecc282bb..80496f16e26 100644 --- a/packages/assets-controller/src/README.md +++ b/packages/assets-controller/src/README.md @@ -88,24 +88,9 @@ registerActionHandlers() └── AssetsController:assetsUpdate // Data sources push asset updates ``` -#### 1.5 Register Data Sources +#### 1.5 Balance data source priority -```typescript -registerDataSources([ - 'BackendWebsocketDataSource', // Real-time push updates (highest priority) - 'AccountsApiDataSource', // HTTP polling fallback - 'SnapDataSource', // Solana/Bitcoin/Tron snaps - 'RpcDataSource', // Direct blockchain queries (lowest priority) -]); -``` - -**Registration order determines subscription priority**: - -- Data sources are processed in registration order -- Earlier sources get first pick for chain assignment -- Later sources act as fallbacks for remaining chains - -Data sources report their active chains by calling `AssetsController:activeChainsUpdate` action. +Built-in balance data sources are fixed and processed in priority order: BackendWebsocketDataSource, AccountsApiDataSource, SnapDataSource, RpcDataSource. Earlier sources get first pick for chain assignment; later sources act as fallbacks. Data sources report active chains via the `onActiveChainsUpdated` callback passed at construction. #### 1.6 Middleware Chains @@ -145,15 +130,15 @@ When the keyring unlocks: ``` start() // Called by KeyringController:unlock │ -├── subscribeToDataSources() +├── subscribeAssets() │ │ -│ ├── subscribeAssetsBalance() +│ ├── subscribeAssetsBalance(selectedAccounts, enabledChains) │ │ │ │ │ ├── Build chain → accounts mapping based on account scopes │ │ ├── assignChainsToDataSources(enabledChains) // Order-based assignment │ │ │ │ │ └── For each dataSource (in registration order): -│ │ └── subscribeToDataSource(sourceId, accounts, chains) +│ │ └── subscribeDataSource(sourceId, accounts, chains) │ │ └── Call {sourceId}:subscribe via Messenger │ │ │ └── subscribeAssetsPrice(selectedAccounts, enabledChains) @@ -899,7 +884,7 @@ flowchart LR end subgraph Subscribe["Subscription Flow"] - S1[subscribeToDataSources] + S1[subscribeAssets] S2[assignChainsToDataSources] S3[Call DataSource:subscribe] S4[DataSource calls AssetsController:assetsUpdate] diff --git a/packages/assets-controller/src/data-sources/AbstractDataSource.ts b/packages/assets-controller/src/data-sources/AbstractDataSource.ts index e37f6c0327e..d4b63e32f40 100644 --- a/packages/assets-controller/src/data-sources/AbstractDataSource.ts +++ b/packages/assets-controller/src/data-sources/AbstractDataSource.ts @@ -96,6 +96,15 @@ export abstract class AbstractDataSource< return this.state.activeChains; } + /** + * Get currently active chains synchronously (no state duplication in controller). + * + * @returns Array of currently active chain IDs. + */ + getActiveChainsSync(): ChainId[] { + return this.state.activeChains; + } + /** * Subscribe to updates for the given request. */ diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts index f50e23b2ead..586c4c5082a 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts @@ -153,8 +153,8 @@ async function setupController( const controller = new AccountsApiDataSource({ queryApiClient: apiClient as unknown as AccountsApiDataSourceOptions['queryApiClient'], - onActiveChainsUpdated: (chains): void => - activeChainsUpdateHandler('AccountsApiDataSource', chains), + onActiveChainsUpdated: (dataSourceName, chains, previousChains): void => + activeChainsUpdateHandler(dataSourceName, chains, previousChains), }); // Wait for async initialization @@ -215,6 +215,7 @@ describe('AccountsApiDataSource', () => { expect(activeChainsUpdateHandler).toHaveBeenCalledWith( 'AccountsApiDataSource', [CHAIN_MAINNET, CHAIN_POLYGON, CHAIN_ARBITRUM], + [], ); controller.destroy(); @@ -244,6 +245,7 @@ describe('AccountsApiDataSource', () => { expect(activeChainsUpdateHandler).toHaveBeenCalledWith( 'AccountsApiDataSource', [expected], + [], ); controller.destroy(); diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index bcf7ebfb3cd..d307223c80d 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -51,8 +51,12 @@ const defaultState: AccountsApiDataSourceState = { export type AccountsApiDataSourceOptions = { /** ApiPlatformClient for API calls with caching */ queryApiClient: ApiPlatformClient; - /** Called when active chains are updated (e.g. to sync BackendWebsocketDataSource). */ - onActiveChainsUpdated: (chains: ChainId[]) => void; + /** Called when active chains are updated. Pass dataSourceName so the controller knows the source. */ + onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; pollInterval?: number; state?: Partial; }; @@ -104,7 +108,11 @@ export class AccountsApiDataSource extends AbstractDataSource< typeof CONTROLLER_NAME, AccountsApiDataSourceState > { - readonly #onActiveChainsUpdated: (chains: ChainId[]) => void; + readonly #onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; readonly #pollInterval: number; @@ -134,8 +142,9 @@ export class AccountsApiDataSource extends AbstractDataSource< async #initializeActiveChains(): Promise { try { const chains = await this.#fetchActiveChains(); + const previous = [...this.state.activeChains]; this.updateActiveChains(chains, (updatedChains) => - this.#onActiveChainsUpdated(updatedChains), + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous), ); // Periodically refresh active chains (every 20 minutes) @@ -163,8 +172,9 @@ export class AccountsApiDataSource extends AbstractDataSource< ); if (added.length > 0 || removed.length > 0) { + const previous = [...this.state.activeChains]; this.updateActiveChains(chains, (updatedChains) => - this.#onActiveChainsUpdated(updatedChains), + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous), ); } } catch (error) { diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts index 906ce1bdc83..186636622e9 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts @@ -176,8 +176,8 @@ function setupController( const controller = new BackendWebsocketDataSource({ messenger: controllerMessenger as unknown as AssetsControllerMessenger, queryApiClient: queryApiClient as unknown as ApiPlatformClient, - onActiveChainsUpdated: (chains): void => - activeChainsUpdateHandler('BackendWebsocketDataSource', chains), + onActiveChainsUpdated: (dataSourceName, chains, previousChains): void => + activeChainsUpdateHandler(dataSourceName, chains, previousChains), state: { activeChains: initialActiveChains }, }); @@ -195,7 +195,11 @@ function setupController( const triggerActiveChainsUpdate = (chains: ChainId[]): void => { controller.setActiveChainsFromAccountsApi(chains); - activeChainsUpdateHandler('BackendWebsocketDataSource', chains); + activeChainsUpdateHandler( + 'BackendWebsocketDataSource', + chains, + initialActiveChains, + ); }; return { @@ -242,6 +246,7 @@ describe('BackendWebsocketDataSource', () => { expect(activeChainsUpdateHandler).toHaveBeenCalledWith( 'BackendWebsocketDataSource', [CHAIN_MAINNET, CHAIN_POLYGON], + [], ); controller.destroy(); @@ -257,6 +262,7 @@ describe('BackendWebsocketDataSource', () => { expect(activeChainsUpdateHandler).toHaveBeenCalledWith( 'BackendWebsocketDataSource', [CHAIN_MAINNET, CHAIN_BASE], + [], ); controller.destroy(); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts index d9d6417a654..47bc2e25c99 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts @@ -64,8 +64,12 @@ export type BackendWebsocketDataSourceOptions = { messenger: AssetsControllerMessenger; /** ApiPlatformClient for fetching supported networks at init (same as AccountsApiDataSource). */ queryApiClient: ApiPlatformClient; - /** Called when active chains are updated (e.g. to notify AssetsController). */ - onActiveChainsUpdated: (chains: ChainId[]) => void; + /** Called when active chains are updated. Pass dataSourceName so the controller knows the source. */ + onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; state?: Partial; }; @@ -207,7 +211,11 @@ export class BackendWebsocketDataSource extends AbstractDataSource< readonly #apiClient: ApiPlatformClient; - readonly #onActiveChainsUpdated: (chains: ChainId[]) => void; + readonly #onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; /** Chains refresh timer */ #chainsRefreshTimer: ReturnType | null = null; @@ -242,8 +250,9 @@ export class BackendWebsocketDataSource extends AbstractDataSource< async #initializeActiveChains(): Promise { try { const chains = await this.#fetchActiveChains(); + const previous = [...this.state.activeChains]; this.updateActiveChains(chains, (updatedChains) => - this.#onActiveChainsUpdated(updatedChains), + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous), ); this.#chainsRefreshTimer = setInterval(() => { @@ -266,8 +275,9 @@ export class BackendWebsocketDataSource extends AbstractDataSource< ); if (added.length > 0 || removed.length > 0) { + const previous = [...this.state.activeChains]; this.updateActiveChains(chains, (updatedChains) => - this.#onActiveChainsUpdated(updatedChains), + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous), ); } } catch (error) { @@ -379,8 +389,9 @@ export class BackendWebsocketDataSource extends AbstractDataSource< * @param chains - Array of supported chain IDs. */ updateSupportedChains(chains: ChainId[]): void { + const previous = [...this.state.activeChains]; this.updateActiveChains(chains, (updatedChains) => - this.#onActiveChainsUpdated(updatedChains), + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous), ); } diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts index 467825cd0ce..0fa47597c41 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts @@ -120,7 +120,11 @@ type WithControllerCallback = ({ }: { controller: RpcDataSource; messenger: RootMessenger; - onActiveChainsUpdated: (chains: ChainId[]) => void; + onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; }) => Promise | ReturnValue; async function withController( @@ -211,8 +215,15 @@ async function withController( ); const onActiveChainsUpdated = - (options as { onActiveChainsUpdated?: (chains: ChainId[]) => void }) - .onActiveChainsUpdated ?? jest.fn(); + ( + options as { + onActiveChainsUpdated?: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; + } + ).onActiveChainsUpdated ?? jest.fn(); const controller = new RpcDataSource({ messenger: rpcDataSourceMessenger as unknown as AssetsControllerMessenger, onActiveChainsUpdated, @@ -283,11 +294,46 @@ describe('RpcDataSource', () => { it('reports active chains on initialization', async () => { await withController(async ({ onActiveChainsUpdated }) => { - expect(onActiveChainsUpdated).toHaveBeenCalledWith([ - MOCK_CHAIN_ID_CAIP, - ]); + expect(onActiveChainsUpdated).toHaveBeenCalledWith( + 'RpcDataSource', + [MOCK_CHAIN_ID_CAIP], + [], + ); }); }); + + it('updates state.activeChains before calling onActiveChainsUpdated so getActiveChainsSync returns new chains', async () => { + let source: RpcDataSource | null = null; + await withController( + { + options: { + onActiveChainsUpdated: ( + _name: string, + newChains: ChainId[], + _previousChains: ChainId[], + ) => { + // Simulate AssetsController: when handling the callback it calls + // source.getActiveChainsSync() to get available chains for subscriptions. + if (source !== null) { + expect(source.getActiveChainsSync()).toEqual(newChains); + } + }, + }, + }, + async ({ controller, messenger }) => { + source = controller; + // Trigger callback again via network state change (first call is during construction, before source is set). + const newNetworkState = createMockNetworkState(NetworkStatus.Available); + (messenger.publish as CallableFunction)( + 'NetworkController:stateChange', + newNetworkState, + [], + ); + await new Promise(process.nextTick); + expect(source.getActiveChainsSync()).toContain(MOCK_CHAIN_ID_CAIP); + }, + ); + }); }); describe('getName', () => { diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.ts b/packages/assets-controller/src/data-sources/RpcDataSource.ts index f729a4b6398..a509b33eb44 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.ts @@ -151,8 +151,12 @@ export type RpcDataSourceConfig = { export type RpcDataSourceOptions = { /** The AssetsController messenger (shared by all data sources). */ messenger: AssetsControllerMessenger; - /** Called when active chains are updated (e.g. to notify AssetsController). */ - onActiveChainsUpdated: (chains: ChainId[]) => void; + /** Called when active chains are updated. Pass dataSourceName so the controller knows the source. */ + onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; /** Request timeout in ms */ timeout?: number; /** Balance polling interval in ms (default: 30s) */ @@ -221,7 +225,11 @@ export class RpcDataSource extends AbstractDataSource< > { readonly #messenger: AssetsControllerMessenger; - readonly #onActiveChainsUpdated: (chains: ChainId[]) => void; + readonly #onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; readonly #timeout: number; @@ -582,19 +590,22 @@ export class RpcDataSource extends AbstractDataSource< }); // Check if chains changed - const previousChains = new Set(this.#activeChains); + const previousChains = [...this.#activeChains]; + const previousSet = new Set(previousChains); const hasChanges = - previousChains.size !== activeChains.length || - activeChains.some((chain) => !previousChains.has(chain)); + previousChains.length !== activeChains.length || + activeChains.some((chain) => !previousSet.has(chain)); - // Update internal state + // Update internal state and data source state before notifying, so that + // when the controller handles the callback and calls getActiveChainsSync(), + // it receives the updated chains (same order as AbstractDataSource.updateActiveChains). this.#chainStatuses = chainStatuses; this.#activeChains = activeChains; + this.state.activeChains = activeChains; if (hasChanges) { - this.#onActiveChainsUpdated(activeChains); + this.#onActiveChainsUpdated(this.getName(), activeChains, previousChains); } - this.state.activeChains = activeChains; } #getProvider(chainId: ChainId): Web3Provider | undefined { diff --git a/packages/assets-controller/src/data-sources/SnapDataSource.ts b/packages/assets-controller/src/data-sources/SnapDataSource.ts index 42941f31065..04dcb4b57db 100644 --- a/packages/assets-controller/src/data-sources/SnapDataSource.ts +++ b/packages/assets-controller/src/data-sources/SnapDataSource.ts @@ -156,8 +156,12 @@ export type SnapDataSourceAllowedActions = export type SnapDataSourceOptions = { /** The AssetsController messenger (shared by all data sources). */ messenger: AssetsControllerMessenger; - /** Called when this data source's active chains change. */ - onActiveChainsUpdated: (chains: ChainId[]) => void; + /** Called when this data source's active chains change. Pass dataSourceName so the controller knows the source. */ + onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; /** Configured networks to support (defaults to all snap networks) */ configuredNetworks?: ChainId[]; /** Default polling interval in ms for subscriptions */ @@ -194,7 +198,11 @@ export class SnapDataSource extends AbstractDataSource< > { readonly #messenger: AssetsControllerMessenger; - readonly #onActiveChainsUpdated: (chains: ChainId[]) => void; + readonly #onActiveChainsUpdated: ( + dataSourceName: string, + chains: ChainId[], + previousChains: ChainId[], + ) => void; /** Bound handler for snap keyring balance updates, stored for cleanup */ readonly #handleSnapBalancesUpdatedBound: ( @@ -388,8 +396,9 @@ export class SnapDataSource extends AbstractDataSource< // Notify if chains changed try { + const previous = [...this.state.activeChains]; this.updateActiveChains(supportedChains, (updatedChains) => { - this.#onActiveChainsUpdated(updatedChains); + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous); }); } catch { // AssetsController not ready yet - expected during initialization @@ -398,8 +407,9 @@ export class SnapDataSource extends AbstractDataSource< log('Keyring snap discovery failed', { error }); this.state.chainToSnap = {}; try { + const previous = [...this.state.activeChains]; this.updateActiveChains([], (updatedChains) => { - this.#onActiveChainsUpdated(updatedChains); + this.#onActiveChainsUpdated(this.getName(), updatedChains, previous); }); } catch { // AssetsController not ready yet - expected during initialization diff --git a/packages/assets-controller/src/index.ts b/packages/assets-controller/src/index.ts index 307a796cccd..d1dccf31e73 100644 --- a/packages/assets-controller/src/index.ts +++ b/packages/assets-controller/src/index.ts @@ -72,9 +72,6 @@ export type { FetchContext, FetchNextFunction, FetchMiddleware, - // Data source registration - DataSourceDefinition, - RegisteredDataSource, SubscriptionResponse, // Combined asset type Asset, @@ -153,7 +150,12 @@ export type { } from './data-sources'; // Middlewares -export { DetectionMiddleware } from './middlewares'; +export { + createParallelBalanceMiddleware, + type BalanceMiddlewareSource, + type ParallelBalanceMiddlewareOptions, + DetectionMiddleware, +} from './middlewares'; // Utilities export { normalizeAssetId } from './utils'; diff --git a/packages/assets-controller/src/middlewares/index.ts b/packages/assets-controller/src/middlewares/index.ts index 2c54fa8b313..aaf45df7d2a 100644 --- a/packages/assets-controller/src/middlewares/index.ts +++ b/packages/assets-controller/src/middlewares/index.ts @@ -1 +1,6 @@ export { DetectionMiddleware } from './DetectionMiddleware'; +export { + createParallelBalanceMiddleware, + type BalanceMiddlewareSource, + type ParallelBalanceMiddlewareOptions, +} from './parallelBalanceMiddleware'; diff --git a/packages/assets-controller/src/middlewares/parallelBalanceMiddleware.test.ts b/packages/assets-controller/src/middlewares/parallelBalanceMiddleware.test.ts new file mode 100644 index 00000000000..fdc0cba691a --- /dev/null +++ b/packages/assets-controller/src/middlewares/parallelBalanceMiddleware.test.ts @@ -0,0 +1,374 @@ +import { createParallelBalanceMiddleware } from './parallelBalanceMiddleware'; +import type { + Context, + DataRequest, + Caip19AssetId, + AccountId, + ChainId, + AssetsControllerStateInternal, +} from '../types'; + +const MOCK_ACCOUNT_ID = 'mock-account-id' as AccountId; +const MOCK_CHAIN = 'eip155:1' as ChainId; +const MOCK_ASSET = 'eip155:1/slip44:60' as Caip19AssetId; + +function createDataRequest(overrides?: Partial): DataRequest { + return { + chainIds: [MOCK_CHAIN], + accountsWithSupportedChains: [], + dataTypes: ['balance'], + ...overrides, + } as DataRequest; +} + +function createContext(overrides?: Partial): Context { + return { + request: createDataRequest(), + response: {}, + getAssetsState: jest.fn().mockReturnValue({ + assetsMetadata: {}, + assetsBalance: {}, + customAssets: {}, + } as AssetsControllerStateInternal), + ...overrides, + }; +} + +describe('createParallelBalanceMiddleware', () => { + it('calls next with unchanged context when middlewares array is empty', async () => { + const middleware = createParallelBalanceMiddleware([]); + const context = createContext(); + const next = jest.fn().mockResolvedValue(context); + + await middleware(context, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(next).toHaveBeenCalledWith(context); + }); + + it('merges responses from all middlewares and passes to next', async () => { + const CHAIN_2 = 'eip155:137' as ChainId; + const middlewareA = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET]: { amount: '100' }, + }, + }; + return next(ctx); + }); + const middlewareB = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { + ['eip155:137/slip44:60' as Caip19AssetId]: { amount: '200' }, + }, + }; + return next(ctx); + }); + + const parallel = createParallelBalanceMiddleware([ + { middleware: middlewareA, getActiveChains: () => [MOCK_CHAIN] }, + { middleware: middlewareB, getActiveChains: () => [CHAIN_2] }, + ]); + const context = createContext({ + request: createDataRequest({ chainIds: [MOCK_CHAIN, CHAIN_2] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(next).toHaveBeenCalledTimes(1); + const passedContext = next.mock.calls[0][0]; + expect(passedContext.response.assetsBalance).toBeDefined(); + expect(passedContext.response.assetsBalance?.[MOCK_ACCOUNT_ID]).toEqual({ + [MOCK_ASSET]: { amount: '100' }, + 'eip155:137/slip44:60': { amount: '200' }, + }); + expect(middlewareA).toHaveBeenCalledWith( + expect.objectContaining({ + request: expect.objectContaining({ chainIds: [MOCK_CHAIN] }), + }), + expect.any(Function), + ); + expect(middlewareB).toHaveBeenCalledWith( + expect.objectContaining({ + request: expect.objectContaining({ chainIds: [CHAIN_2] }), + }), + expect.any(Function), + ); + }); + + it('uses fallback when one middleware throws: merges results from others', async () => { + const CHAIN_2 = 'eip155:137' as ChainId; + const goodMiddleware = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET]: { amount: '42' }, + }, + }; + return next(ctx); + }); + const failingMiddleware = jest.fn(async () => { + throw new Error('Source unavailable'); + }); + + const parallel = createParallelBalanceMiddleware([ + { middleware: goodMiddleware, getActiveChains: () => [MOCK_CHAIN] }, + { middleware: failingMiddleware, getActiveChains: () => [CHAIN_2] }, + ]); + const context = createContext({ + request: createDataRequest({ chainIds: [MOCK_CHAIN, CHAIN_2] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(next).toHaveBeenCalledTimes(1); + const passedContext = next.mock.calls[0][0]; + expect(passedContext.response.assetsBalance).toEqual({ + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET]: { amount: '42' }, + }, + }); + }); + + it('passes through to next with initial response when all middlewares fail', async () => { + const failing1 = jest.fn(async () => { + throw new Error('Fail 1'); + }); + const failing2 = jest.fn(async () => { + throw new Error('Fail 2'); + }); + + const parallel = createParallelBalanceMiddleware([ + { middleware: failing1, getActiveChains: () => [MOCK_CHAIN] }, + { middleware: failing2, getActiveChains: () => ['eip155:137' as ChainId] }, + ]); + const context = createContext({ + request: createDataRequest({ chainIds: [MOCK_CHAIN, 'eip155:137' as ChainId] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(next.mock.calls[0][0].response).toEqual({}); + }); + + it('runs middlewares in parallel', async () => { + const CHAIN_2 = 'eip155:137' as ChainId; + const order: number[] = []; + const delay = (ms: number) => + new Promise((resolve) => setTimeout(resolve, ms)); + + const slowMiddleware = jest.fn(async (ctx: Context, next) => { + order.push(1); + await delay(20); + order.push(2); + ctx.response.assetsBalance = { [MOCK_ACCOUNT_ID]: {} }; + return next(ctx); + }); + const fastMiddleware = jest.fn(async (ctx: Context, next) => { + order.push(3); + ctx.response.assetsBalance = { [MOCK_ACCOUNT_ID]: {} }; + return next(ctx); + }); + + const parallel = createParallelBalanceMiddleware([ + { middleware: slowMiddleware, getActiveChains: () => [MOCK_CHAIN] }, + { middleware: fastMiddleware, getActiveChains: () => [CHAIN_2] }, + ]); + const context = createContext({ + request: createDataRequest({ chainIds: [MOCK_CHAIN, CHAIN_2] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(order).toEqual([1, 3, 2]); + }); + + it('merges errors from multiple responses', async () => { + const CHAIN_2 = 'eip155:2' as ChainId; + const CHAIN_3 = 'eip155:3' as ChainId; + const middlewareWithError = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = {}; + ctx.response.errors = { [CHAIN_2]: 'RPC failed' }; + return next(ctx); + }); + const otherMiddleware = jest.fn(async (ctx: Context, next) => { + ctx.response.errors = { [CHAIN_3]: 'Timeout' }; + return next(ctx); + }); + + const parallel = createParallelBalanceMiddleware([ + { middleware: middlewareWithError, getActiveChains: () => [CHAIN_2] }, + { middleware: otherMiddleware, getActiveChains: () => [CHAIN_3] }, + ]); + const context = createContext({ + request: createDataRequest({ chainIds: [CHAIN_2, CHAIN_3] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(next.mock.calls[0][0].response.errors).toEqual({ + 'eip155:2': 'RPC failed', + 'eip155:3': 'Timeout', + }); + }); + + describe('fallback for remaining chains (no balance from primary)', () => { + const CHAIN_2 = 'eip155:137' as ChainId; + const ASSET_CHAIN_2 = 'eip155:137/slip44:60' as Caip19AssetId; + + it('runs fallback for remaining chains when primary returns no balance', async () => { + const primary = jest.fn(async (ctx: Context, next) => { + ctx.response.errors = { [MOCK_CHAIN]: 'Accounts API down' }; + return next(ctx); + }); + const fallback = jest.fn(async (ctx: Context, next) => { + expect(ctx.request.chainIds).toEqual([MOCK_CHAIN]); + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET]: { amount: '99' }, + }, + }; + return next(ctx); + }); + + const parallel = createParallelBalanceMiddleware( + [{ middleware: primary, getActiveChains: () => [MOCK_CHAIN] }], + { fallbackMiddlewares: [fallback] }, + ); + const context = createContext(); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(fallback).toHaveBeenCalledTimes(1); + const passed = next.mock.calls[0][0].response; + expect(passed.assetsBalance?.[MOCK_ACCOUNT_ID]?.[MOCK_ASSET]).toEqual({ + amount: '99', + }); + expect(passed.errors).toBeUndefined(); + }); + + it('runs fallback only for chains with no balance (others are not remaining)', async () => { + const primary = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { + [ASSET_CHAIN_2]: { amount: '200' }, + }, + }; + return next(ctx); + }); + const fallback = jest.fn(async (ctx: Context, next) => { + expect(ctx.request.chainIds).toEqual([MOCK_CHAIN]); + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET]: { amount: '1' }, + }, + }; + return next(ctx); + }); + + const parallel = createParallelBalanceMiddleware( + [{ middleware: primary, getActiveChains: () => [CHAIN_2] }], + { fallbackMiddlewares: [fallback] }, + ); + const context = createContext({ + request: createDataRequest({ chainIds: [MOCK_CHAIN, CHAIN_2] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(fallback).toHaveBeenCalledTimes(1); + const passed = next.mock.calls[0][0].response; + expect(passed.assetsBalance?.[MOCK_ACCOUNT_ID]).toEqual({ + [MOCK_ASSET]: { amount: '1' }, + [ASSET_CHAIN_2]: { amount: '200' }, + }); + }); + + it('does not run fallback when all chains have balance after primary run', async () => { + const primary = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET]: { amount: '10' } }, + }; + return next(ctx); + }); + const fallback = jest.fn(async (ctx: Context, next) => next(ctx)); + + const parallel = createParallelBalanceMiddleware( + [{ middleware: primary, getActiveChains: () => [MOCK_CHAIN] }], + { fallbackMiddlewares: [fallback] }, + ); + const context = createContext(); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + expect(fallback).not.toHaveBeenCalled(); + expect(next.mock.calls[0][0].response.assetsBalance?.[MOCK_ACCOUNT_ID]).toEqual({ + [MOCK_ASSET]: { amount: '10' }, + }); + }); + + it('keeps errors for remaining chains that fallback did not supply balance for', async () => { + const primary = jest.fn(async (ctx: Context, next) => { + ctx.response.errors = { + [MOCK_CHAIN]: 'Failed', + [CHAIN_2]: 'Failed', + }; + return next(ctx); + }); + const fallback = jest.fn(async (ctx: Context, next) => { + ctx.response.assetsBalance = { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET]: { amount: '1' } }, + }; + return next(ctx); + }); + + const parallel = createParallelBalanceMiddleware( + [ + { + middleware: primary, + getActiveChains: () => [MOCK_CHAIN, CHAIN_2], + }, + ], + { fallbackMiddlewares: [fallback] }, + ); + const context = createContext({ + request: createDataRequest({ chainIds: [MOCK_CHAIN, CHAIN_2] }), + }); + const next = jest.fn().mockImplementation((ctx: Context) => + Promise.resolve(ctx), + ); + + await parallel(context, next); + + const passed = next.mock.calls[0][0].response; + expect(passed.assetsBalance?.[MOCK_ACCOUNT_ID]?.[MOCK_ASSET]).toEqual({ + amount: '1', + }); + expect(passed.errors).toEqual({ [CHAIN_2]: 'Failed' }); + }); + }); +}); diff --git a/packages/assets-controller/src/middlewares/parallelBalanceMiddleware.ts b/packages/assets-controller/src/middlewares/parallelBalanceMiddleware.ts new file mode 100644 index 00000000000..1a40f423345 --- /dev/null +++ b/packages/assets-controller/src/middlewares/parallelBalanceMiddleware.ts @@ -0,0 +1,291 @@ +import { createModuleLogger, projectLogger } from '../logger'; +import type { + AccountId, + AssetBalance, + Caip19AssetId, + ChainId, + Context, + DataResponse, + Middleware, + NextFunction, +} from '../types'; + +// ============================================================================ +// LOGGING +// ============================================================================ + +const LOGGER_NAME = 'ParallelBalanceMiddleware'; +const log = createModuleLogger(projectLogger, LOGGER_NAME); + +// ============================================================================ +// CHAIN / RESPONSE HELPERS +// ============================================================================ + +/** CAIP-19 asset IDs are "chainId/namespace:reference"; return the chainId prefix. */ +function getChainIdFromAssetId(assetId: Caip19AssetId): ChainId { + const idx = assetId.indexOf('/'); + return (idx === -1 ? assetId : assetId.slice(0, idx)) as ChainId; +} + +/** Collect chain IDs that have at least one balance in the response. */ +function getChainsWithBalance(response: DataResponse): Set { + const chains = new Set(); + if (!response.assetsBalance) return chains; + for (const accountBalances of Object.values(response.assetsBalance)) { + for (const assetId of Object.keys(accountBalances)) { + chains.add(getChainIdFromAssetId(assetId as Caip19AssetId)); + } + } + return chains; +} + +/** + * Chains that have no balance in the merged response (remaining after primary run). + * Fallback runs only for these; it does not depend on Promise success or response.errors. + */ +function getRemainingChains( + requestChainIds: ChainId[], + mergedResponse: DataResponse, +): ChainId[] { + const chainsWithBalance = getChainsWithBalance(mergedResponse); + return requestChainIds.filter((chainId) => !chainsWithBalance.has(chainId)); +} + +/** + * Distribute chain IDs across sources by support (same strategy as subscription). + * Each chain is assigned to the first source that supports it; no overlap. + * + * @param requestChainIds - Chains requested. + * @param sources - Sources with getActiveChains (order = priority). + * @returns Map of source index -> assigned chain IDs. + */ +function distributeChainsToSources( + requestChainIds: ChainId[], + sources: { getActiveChains: () => ChainId[] }[], +): Map { + const remaining = new Set(requestChainIds); + const assignment = new Map(); + + for (let i = 0; i < sources.length; i++) { + const available = new Set(sources[i].getActiveChains()); + const assigned: ChainId[] = []; + for (const chainId of remaining) { + if (available.has(chainId)) { + assigned.push(chainId); + remaining.delete(chainId); + } + } + if (assigned.length > 0) { + assignment.set(i, assigned); + } + } + + return assignment; +} + +// ============================================================================ +// MERGE HELPERS +// ============================================================================ + +/** + * Merge multiple DataResponses into one. + * Later responses overwrite earlier for the same keys (same semantics as sequential chain). + * + * @param responses - Array of responses to merge (e.g. from parallel balance middlewares). + * @returns Single merged DataResponse. + */ +function mergeDataResponses(responses: DataResponse[]): DataResponse { + const merged: DataResponse = {}; + + for (const response of responses) { + if (response.assetsBalance) { + merged.assetsBalance ??= {}; + for (const [accountId, accountBalances] of Object.entries( + response.assetsBalance, + )) { + merged.assetsBalance[accountId as AccountId] = { + ...merged.assetsBalance[accountId as AccountId], + ...(accountBalances as Record), + }; + } + } + if (response.assetsMetadata) { + merged.assetsMetadata = { + ...merged.assetsMetadata, + ...response.assetsMetadata, + }; + } + if (response.assetsPrice) { + merged.assetsPrice = { + ...merged.assetsPrice, + ...response.assetsPrice, + }; + } + if (response.errors) { + merged.errors = { + ...merged.errors, + ...response.errors, + }; + } + if (response.detectedAssets) { + merged.detectedAssets ??= {}; + for (const [accountId, assetIds] of Object.entries( + response.detectedAssets, + )) { + const existing = merged.detectedAssets[accountId as AccountId] ?? []; + const combined = [...new Set([...existing, ...assetIds])]; + merged.detectedAssets[accountId as AccountId] = combined; + } + } + } + + return merged; +} + +// ============================================================================ +// PARALLEL BALANCE MIDDLEWARE +// ============================================================================ + +/** + * A balance source that can be assigned a subset of chains (same idea as subscription). + */ +export type BalanceMiddlewareSource = { + middleware: Middleware; + getActiveChains: () => ChainId[]; +}; + +export type ParallelBalanceMiddlewareOptions = { + /** + * Middlewares to run only for remaining chains (chains with no balance after + * the primary run), e.g. RPC when Accounts API did not return balance for those chains. + */ + fallbackMiddlewares?: Middleware[]; +}; + +/** + * Creates a single middleware that distributes chains across balance sources (like + * subscription), runs each source in parallel with only its assigned chains, and + * merges responses. No overlap: each chain is assigned to at most one source. + * + * If `options.fallbackMiddlewares` is set, they run only for remaining chains + * (chains with no balance after the primary run). + * + * @param sources - Balance sources (middleware + getActiveChains), in priority order. + * @param options - Optional; fallbackMiddlewares run for remaining chains only. + * @returns A middleware that distributes chains, runs primary in parallel, then fallback, and merges. + */ +export function createParallelBalanceMiddleware( + sources: BalanceMiddlewareSource[], + options: ParallelBalanceMiddlewareOptions = {}, +): Middleware { + const { fallbackMiddlewares = [] } = options; + + return async (context: Context, next: NextFunction): Promise => { + if (sources.length === 0 && fallbackMiddlewares.length === 0) { + return next(context); + } + + const noopNext: NextFunction = async (ctx) => ctx; + + const runOne = async ( + middleware: Middleware, + ctx: Context, + ): Promise => { + try { + return await middleware(ctx, noopNext); + } catch (error) { + log('Balance middleware failed', { error }); + return null; + } + }; + + // Primary: distribute chains to sources (no overlap), then run in parallel + let mergedResponse: DataResponse = context.response; + if (sources.length > 0 && context.request.chainIds.length > 0) { + const assignment = distributeChainsToSources( + context.request.chainIds, + sources, + ); + const results = await Promise.allSettled( + Array.from(assignment.entries()).map(([index, chainIds]) => + runOne(sources[index].middleware, { + ...context, + request: { ...context.request, chainIds }, + response: {}, + }), + ), + ); + const contextsToMerge: Context[] = []; + for (const result of results) { + if (result.status === 'fulfilled' && result.value !== null) { + contextsToMerge.push(result.value); + } + } + mergedResponse = + contextsToMerge.length > 0 + ? mergeDataResponses(contextsToMerge.map((ctx) => ctx.response)) + : context.response; + } + + // Fallback: for remaining chains (no balance from primary), run fallback middlewares + if (fallbackMiddlewares.length > 0 && context.request.chainIds.length > 0) { + const remainingChains = getRemainingChains( + context.request.chainIds, + mergedResponse, + ); + if (remainingChains.length > 0) { + log('Fallback for remaining chains', { + chainIds: remainingChains, + }); + const fallbackContext: Context = { + ...context, + request: { + ...context.request, + chainIds: remainingChains, + }, + response: {}, + }; + const fallbackResults = await Promise.allSettled( + fallbackMiddlewares.map((m) => + runOne(m, { ...fallbackContext, response: {} }), + ), + ); + const fallbackContexts: Context[] = []; + for (const result of fallbackResults) { + if (result.status === 'fulfilled' && result.value !== null) { + fallbackContexts.push(result.value); + } + } + if (fallbackContexts.length > 0) { + const fallbackMerged = mergeDataResponses( + fallbackContexts.map((ctx) => ctx.response), + ); + mergedResponse = mergeDataResponses([ + mergedResponse, + fallbackMerged, + ]); + // Drop errors for chains that now have balance from fallback + if (mergedResponse.errors && Object.keys(mergedResponse.errors).length > 0) { + const chainsWithBalanceAfterFallback = + getChainsWithBalance(mergedResponse); + const stillFailing: Record = {}; + for (const [chainId, message] of Object.entries( + mergedResponse.errors, + )) { + if (!chainsWithBalanceAfterFallback.has(chainId as ChainId)) { + stillFailing[chainId as ChainId] = message; + } + } + mergedResponse.errors = + Object.keys(stillFailing).length > 0 ? stillFailing : undefined; + } + } + } + } + + return next({ + ...context, + response: mergedResponse, + }); + }; +} diff --git a/packages/assets-controller/src/types.ts b/packages/assets-controller/src/types.ts index 471986f18c0..11b3f204c93 100644 --- a/packages/assets-controller/src/types.ts +++ b/packages/assets-controller/src/types.ts @@ -517,23 +517,7 @@ export type FetchNextFunction = NextFunction; export type FetchMiddleware = Middleware; /** - * Data source ID. - * - * Data sources follow a standard messenger pattern: - * - `${id}:getActiveChains` - action to get active chains - * - `${id}:activeChainsUpdated` - event when chains change - * - * Registration order determines subscription order. - */ -export type DataSourceDefinition = string; - -/** - * Registered data source - */ -export type RegisteredDataSource = DataSourceDefinition; - -/** - * Subscription response + * Subscription response returned when subscribing to asset updates. */ export type SubscriptionResponse = { /** Chains actively subscribed */