diff --git a/apps/studio/src/hooks/tests/use-add-site.test.tsx b/apps/studio/src/hooks/tests/use-add-site.test.tsx index 8646746814..df6488eb00 100644 --- a/apps/studio/src/hooks/tests/use-add-site.test.tsx +++ b/apps/studio/src/hooks/tests/use-add-site.test.tsx @@ -9,6 +9,7 @@ import { useContentTabs } from 'src/hooks/use-content-tabs'; import { useSiteDetails } from 'src/hooks/use-site-details'; import { store } from 'src/stores'; import { setProviderConstants } from 'src/stores/provider-constants-slice'; +import { useConnectSiteMutation } from 'src/stores/sync/connected-sites'; import type { SyncSite } from 'src/modules/sync/types'; import type { WPCOM } from 'wpcom/types'; @@ -16,6 +17,13 @@ vi.mock( 'src/hooks/use-site-details' ); vi.mock( 'src/hooks/use-feature-flags' ); vi.mock( 'src/hooks/use-auth' ); vi.mock( 'src/hooks/use-content-tabs' ); +vi.mock( 'src/stores/sync/connected-sites', async ( importOriginal ) => { + const original = await importOriginal< typeof import('src/stores/sync/connected-sites') >(); + return { + ...original, + useConnectSiteMutation: vi.fn(), + }; +} ); const mockPullSiteThunk = vi.hoisted( () => vi.fn() ); @@ -37,7 +45,7 @@ vi.mock( 'src/hooks/use-import-export', () => ( { } ), } ) ); -const mockConnectWpcomSites = vi.fn().mockResolvedValue( undefined ); +const mockConnectSite = vi.fn().mockReturnValue( { unwrap: () => Promise.resolve( [] ) } ); const mockShowOpenFolderDialog = vi.fn(); const mockGenerateProposedSitePath = vi.fn().mockResolvedValue( { path: '/default/path', @@ -53,7 +61,6 @@ vi.mock( 'src/lib/get-ipc-api', () => ( { showOpenFolderDialog: mockShowOpenFolderDialog, showNotification: vi.fn(), getAllCustomDomains: vi.fn().mockResolvedValue( [] ), - connectWpcomSites: mockConnectWpcomSites, getConnectedWpcomSites: vi.fn().mockResolvedValue( [] ), comparePaths: mockComparePaths, } ), @@ -78,6 +85,11 @@ describe( 'useAddSite', () => { type: 'syncOperations/pullSite', } ) ); + vi.mocked( useConnectSiteMutation ).mockReturnValue( [ + mockConnectSite, + { isLoading: false, reset: vi.fn() }, + ] as unknown as ReturnType< typeof useConnectSiteMutation > ); + // Prepopulate store with provider constants store.dispatch( setProviderConstants( { @@ -271,12 +283,10 @@ describe( 'useAddSite', () => { await result.current.handleCreateSite( formValues ); } ); - expect( mockConnectWpcomSites ).toHaveBeenCalledWith( [ - { - sites: [ remoteSite ], - localSiteId: createdSite.id, - }, - ] ); + expect( mockConnectSite ).toHaveBeenCalledWith( { + site: remoteSite, + localSiteId: createdSite.id, + } ); expect( mockPullSiteThunk ).toHaveBeenCalledWith( { client: mockClient, connectedSite: remoteSite, diff --git a/apps/studio/src/hooks/use-add-site.ts b/apps/studio/src/hooks/use-add-site.ts index d1effbfd2d..44f3cb4480 100644 --- a/apps/studio/src/hooks/use-add-site.ts +++ b/apps/studio/src/hooks/use-add-site.ts @@ -285,7 +285,10 @@ export function useAddSite() { body: __( 'Your new site was imported' ), } ); } else if ( selectedRemoteSite && client ) { - await connectSite( { site: selectedRemoteSite, localSiteId: newSite.id } ); + await connectSite( { + site: selectedRemoteSite, + localSiteId: newSite.id, + } ).unwrap(); const pullOptions: SyncOption[] = [ 'all' ]; void dispatch( syncOperationsThunks.pullSite( { diff --git a/apps/studio/src/modules/sync/index.tsx b/apps/studio/src/modules/sync/index.tsx index 0891afb2a2..23d5218fa4 100644 --- a/apps/studio/src/modules/sync/index.tsx +++ b/apps/studio/src/modules/sync/index.tsx @@ -163,14 +163,19 @@ export function ContentTabSync( { selectedSite }: { selectedSite: SiteDetails } return ; } - const handleConnect = async ( newConnectedSite: SyncSite ) => { + const handleConnect = async ( remoteSite: SyncSite ) => { try { - await connectSite( { site: newConnectedSite, localSiteId: selectedSite.id } ); + await connectSite( { + site: remoteSite, + localSiteId: selectedSite.id, + } ).unwrap(); + return true; } catch ( error ) { getIpcApi().showErrorMessageBox( { title: __( 'Failed to connect to site' ), message: __( 'Please try again.' ), } ); + return false; } }; @@ -188,8 +193,10 @@ export function ContentTabSync( { selectedSite }: { selectedSite: SiteDetails } dispatch( connectedSitesActions.openModal( reduxModalMode ) ); setSelectedRemoteSite( selectedSiteFromList ); } else { - await handleConnect( selectedSiteFromList ); - dispatch( connectedSitesActions.closeModal() ); + const didConnect = await handleConnect( selectedSiteFromList ); + if ( didConnect ) { + dispatch( connectedSitesActions.closeModal() ); + } } }; diff --git a/apps/studio/src/stores/sync/connected-sites.ts b/apps/studio/src/stores/sync/connected-sites.ts index 2d6ac1e763..a10a2a2054 100644 --- a/apps/studio/src/stores/sync/connected-sites.ts +++ b/apps/studio/src/stores/sync/connected-sites.ts @@ -2,6 +2,7 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'; import { getIpcApi } from 'src/lib/get-ipc-api'; import { RootState } from 'src/stores'; +import { wpcomSitesApi } from 'src/stores/sync/wpcom-sites'; import type { SyncSite, SyncModalMode } from 'src/modules/sync/types'; type ConnectedSitesState = { @@ -73,6 +74,17 @@ export const connectedSitesSelectors = { Boolean( state.connectedSites.loadingSiteIds[ id ] ), }; +async function persistConnectedSite( site: SyncSite, localSiteId: string ) { + await getIpcApi().connectWpcomSites( [ + { + sites: [ site ], + localSiteId, + }, + ] ); + + return getIpcApi().getConnectedWpcomSites( localSiteId ); +} + export const connectedSitesApi = createApi( { reducerPath: 'connectedSitesApi', baseQuery: fetchBaseQuery(), @@ -97,15 +109,41 @@ export const connectedSitesApi = createApi( { connectSite: builder.mutation< SyncSite[], { site: SyncSite; localSiteId: string } >( { queryFn: async ( { site, localSiteId } ) => { - await getIpcApi().connectWpcomSites( [ - { - sites: [ site ], - localSiteId, - }, - ] ); + const actualConnectedSites = await persistConnectedSite( site, localSiteId ); + return { data: actualConnectedSites }; + }, + invalidatesTags: ( result, error, { localSiteId } ) => [ + { type: 'ConnectedSites', localSiteId }, + ], + } ), - const actualConnectedSites = await getIpcApi().getConnectedWpcomSites( localSiteId ); + connectSiteById: builder.mutation< + SyncSite[], + { remoteSiteId: number; localSiteId: string; userId?: number } + >( { + queryFn: async ( { remoteSiteId, localSiteId, userId }, api ) => { + const connectedSites = await getIpcApi().getConnectedWpcomSites( localSiteId ); + const { data: remoteSites = [] } = await api.dispatch( + wpcomSitesApi.endpoints.getWpComSites.initiate( + { + connectedSiteIds: connectedSites.map( ( site ) => site.id ), + userId, + }, + { forceRefetch: true } + ) + ); + const siteToConnect = remoteSites.find( ( site ) => site.id === remoteSiteId ); + + if ( ! siteToConnect ) { + return { + error: { + status: 'CUSTOM_ERROR', + error: 'Site not found in WordPress.com sites', + }, + }; + } + const actualConnectedSites = await persistConnectedSite( siteToConnect, localSiteId ); return { data: actualConnectedSites }; }, invalidatesTags: ( result, error, { localSiteId } ) => [ @@ -165,6 +203,7 @@ export const connectedSitesApi = createApi( { export const { useGetConnectedSitesForLocalSiteQuery, useConnectSiteMutation, + useConnectSiteByIdMutation, useDisconnectSiteMutation, useUpdateSiteTimestampMutation, } = connectedSitesApi;