Skip to content
26 changes: 18 additions & 8 deletions apps/studio/src/hooks/tests/use-add-site.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@ import { useContentTabs } from 'src/hooks/use-content-tabs';
import { useSiteDetails } from 'src/hooks/use-site-details';
import { store } from 'src/stores';
import { setProviderConstants } from 'src/stores/provider-constants-slice';
import { useConnectSiteMutation } from 'src/stores/sync/connected-sites';
import type { SyncSite } from 'src/modules/sync/types';
import type { WPCOM } from 'wpcom/types';

vi.mock( 'src/hooks/use-site-details' );
vi.mock( 'src/hooks/use-feature-flags' );
vi.mock( 'src/hooks/use-auth' );
vi.mock( 'src/hooks/use-content-tabs' );
vi.mock( 'src/stores/sync/connected-sites', async ( importOriginal ) => {
const original = await importOriginal< typeof import('src/stores/sync/connected-sites') >();
return {
...original,
useConnectSiteMutation: vi.fn(),
};
} );

const mockPullSiteThunk = vi.hoisted( () => vi.fn() );

Expand All @@ -37,7 +45,7 @@ vi.mock( 'src/hooks/use-import-export', () => ( {
} ),
} ) );

const mockConnectWpcomSites = vi.fn().mockResolvedValue( undefined );
const mockConnectSite = vi.fn().mockReturnValue( { unwrap: () => Promise.resolve( [] ) } );
const mockShowOpenFolderDialog = vi.fn();
const mockGenerateProposedSitePath = vi.fn().mockResolvedValue( {
path: '/default/path',
Expand All @@ -53,7 +61,6 @@ vi.mock( 'src/lib/get-ipc-api', () => ( {
showOpenFolderDialog: mockShowOpenFolderDialog,
showNotification: vi.fn(),
getAllCustomDomains: vi.fn().mockResolvedValue( [] ),
connectWpcomSites: mockConnectWpcomSites,
getConnectedWpcomSites: vi.fn().mockResolvedValue( [] ),
comparePaths: mockComparePaths,
} ),
Expand All @@ -78,6 +85,11 @@ describe( 'useAddSite', () => {
type: 'syncOperations/pullSite',
} ) );

vi.mocked( useConnectSiteMutation ).mockReturnValue( [
mockConnectSite,
{ isLoading: false, reset: vi.fn() },
] as unknown as ReturnType< typeof useConnectSiteMutation > );

// Prepopulate store with provider constants
store.dispatch(
setProviderConstants( {
Expand Down Expand Up @@ -271,12 +283,10 @@ describe( 'useAddSite', () => {
await result.current.handleCreateSite( formValues );
} );

expect( mockConnectWpcomSites ).toHaveBeenCalledWith( [
{
sites: [ remoteSite ],
localSiteId: createdSite.id,
},
] );
expect( mockConnectSite ).toHaveBeenCalledWith( {
site: remoteSite,
localSiteId: createdSite.id,
} );
expect( mockPullSiteThunk ).toHaveBeenCalledWith( {
client: mockClient,
connectedSite: remoteSite,
Expand Down
5 changes: 4 additions & 1 deletion apps/studio/src/hooks/use-add-site.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ export function useAddSite() {
body: __( 'Your new site was imported' ),
} );
} else if ( selectedRemoteSite && client ) {
await connectSite( { site: selectedRemoteSite, localSiteId: newSite.id } );
await connectSite( {
site: selectedRemoteSite,
localSiteId: newSite.id,
} ).unwrap();
const pullOptions: SyncOption[] = [ 'all' ];
void dispatch(
syncOperationsThunks.pullSite( {
Expand Down
15 changes: 11 additions & 4 deletions apps/studio/src/modules/sync/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,19 @@ export function ContentTabSync( { selectedSite }: { selectedSite: SiteDetails }
return <NoAuthSyncTab />;
}

const handleConnect = async ( newConnectedSite: SyncSite ) => {
const handleConnect = async ( remoteSite: SyncSite ) => {
try {
await connectSite( { site: newConnectedSite, localSiteId: selectedSite.id } );
await connectSite( {
site: remoteSite,
localSiteId: selectedSite.id,
} ).unwrap();
return true;
} catch ( error ) {
getIpcApi().showErrorMessageBox( {
title: __( 'Failed to connect to site' ),
message: __( 'Please try again.' ),
} );
return false;
}
};

Expand All @@ -188,8 +193,10 @@ export function ContentTabSync( { selectedSite }: { selectedSite: SiteDetails }
dispatch( connectedSitesActions.openModal( reduxModalMode ) );
setSelectedRemoteSite( selectedSiteFromList );
} else {
await handleConnect( selectedSiteFromList );
dispatch( connectedSitesActions.closeModal() );
const didConnect = await handleConnect( selectedSiteFromList );
if ( didConnect ) {
dispatch( connectedSitesActions.closeModal() );
}
}
};

Expand Down
53 changes: 46 additions & 7 deletions apps/studio/src/stores/sync/connected-sites.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';
import { getIpcApi } from 'src/lib/get-ipc-api';
import { RootState } from 'src/stores';
import { wpcomSitesApi } from 'src/stores/sync/wpcom-sites';
import type { SyncSite, SyncModalMode } from 'src/modules/sync/types';

type ConnectedSitesState = {
Expand Down Expand Up @@ -73,6 +74,17 @@ export const connectedSitesSelectors = {
Boolean( state.connectedSites.loadingSiteIds[ id ] ),
};

async function persistConnectedSite( site: SyncSite, localSiteId: string ) {
await getIpcApi().connectWpcomSites( [
{
sites: [ site ],
localSiteId,
},
] );

return getIpcApi().getConnectedWpcomSites( localSiteId );
}

export const connectedSitesApi = createApi( {
reducerPath: 'connectedSitesApi',
baseQuery: fetchBaseQuery(),
Expand All @@ -97,15 +109,41 @@ export const connectedSitesApi = createApi( {

connectSite: builder.mutation< SyncSite[], { site: SyncSite; localSiteId: string } >( {
queryFn: async ( { site, localSiteId } ) => {
await getIpcApi().connectWpcomSites( [
{
sites: [ site ],
localSiteId,
},
] );
const actualConnectedSites = await persistConnectedSite( site, localSiteId );
return { data: actualConnectedSites };
},
invalidatesTags: ( result, error, { localSiteId } ) => [
{ type: 'ConnectedSites', localSiteId },
],
} ),

const actualConnectedSites = await getIpcApi().getConnectedWpcomSites( localSiteId );
connectSiteById: builder.mutation<
SyncSite[],
{ remoteSiteId: number; localSiteId: string; userId?: number }
>( {
queryFn: async ( { remoteSiteId, localSiteId, userId }, api ) => {
const connectedSites = await getIpcApi().getConnectedWpcomSites( localSiteId );
const { data: remoteSites = [] } = await api.dispatch(
wpcomSitesApi.endpoints.getWpComSites.initiate(
{
connectedSiteIds: connectedSites.map( ( site ) => site.id ),
userId,
},
{ forceRefetch: true }
)
);
const siteToConnect = remoteSites.find( ( site ) => site.id === remoteSiteId );

if ( ! siteToConnect ) {
return {
error: {
status: 'CUSTOM_ERROR',
error: 'Site not found in WordPress.com sites',
},
};
}

const actualConnectedSites = await persistConnectedSite( siteToConnect, localSiteId );
return { data: actualConnectedSites };
},
invalidatesTags: ( result, error, { localSiteId } ) => [
Expand Down Expand Up @@ -165,6 +203,7 @@ export const connectedSitesApi = createApi( {
export const {
useGetConnectedSitesForLocalSiteQuery,
useConnectSiteMutation,
useConnectSiteByIdMutation,
useDisconnectSiteMutation,
useUpdateSiteTimestampMutation,
} = connectedSitesApi;