diff --git a/api/oss/src/services/db_manager.py b/api/oss/src/services/db_manager.py index 643d3712de..5e5a60b6d0 100644 --- a/api/oss/src/services/db_manager.py +++ b/api/oss/src/services/db_manager.py @@ -982,7 +982,9 @@ async def check_if_user_exists_and_create_organization(user_email: str): async def check_if_user_invitation_exists(email: str, organization_id: str): - """Check if a user invitation with the given email and organization_id exists.""" + """Check if a user invitation with the given email and organization_id exists. + Email comparison is case-insensitive to handle SuperTokens email normalization. + """ project_db = await get_project_by_organization_id(organization_id=organization_id) if not project_db: @@ -990,9 +992,9 @@ async def check_if_user_invitation_exists(email: str, organization_id: str): async with engine.core_session() as session: result = await session.execute( - select(InvitationDB).filter_by( - email=email, - project_id=project_db.id, + select(InvitationDB).where( + func.lower(InvitationDB.email) == func.lower(email), + InvitationDB.project_id == project_db.id, ) ) user_invitation = result.scalars().first() @@ -1131,6 +1133,7 @@ async def delete_accounts() -> None: async def create_accounts(payload: dict) -> UserDB: """Create a new account in the database. + Email is normalized to lowercase to match SuperTokens behavior. Args: payload (dict): The payload to create the user @@ -1142,8 +1145,12 @@ async def create_accounts(payload: dict) -> UserDB: # pop required fields for organization & workspace creation organization_id = payload.pop("organization_id") + # Normalize email to lowercase to match SuperTokens behavior + normalized_email = payload["email"].lower() + payload["email"] = normalized_email + # create user - user_info = {**payload, "username": payload["email"].split("@")[0]} + user_info = {**payload, "username": normalized_email.split("@")[0]} user_db = await user_service.create_new_user(payload=user_info) # only update organization to have user_db as its "owner" if it does not yet have one @@ -1161,7 +1168,7 @@ async def create_accounts(payload: dict) -> UserDB: # update user invitation in the case the user was invited invitation = await get_project_invitation_by_email( - project_id=str(project_db.id), email=payload["email"] + project_id=str(project_db.id), email=normalized_email ) if invitation is not None: await update_invitation( @@ -1370,6 +1377,7 @@ async def get_workspaces() -> List[WorkspaceDB]: async def remove_user_from_workspace(project_id: str, email: str): """Remove a user from a workspace. + Email comparison is case-insensitive to handle SuperTokens email normalization. Args: project_id (str): The ID of the project @@ -1461,6 +1469,7 @@ async def get_user_with_id(user_id: str) -> UserDB: async def get_user_with_email(email: str): """ Retrieves a user from the database based on their email address. + Email comparison is case-insensitive to handle SuperTokens email normalization. Args: email (str): The email address of the user to retrieve. @@ -1480,7 +1489,9 @@ async def get_user_with_email(email: str): raise Exception("Please provide a valid email address") async with engine.core_session() as session: - result = await session.execute(select(UserDB).filter_by(email=email)) + result = await session.execute( + select(UserDB).where(func.lower(UserDB.email) == func.lower(email)) + ) user = result.scalars().first() return user @@ -1625,6 +1636,7 @@ async def get_default_project_id_from_workspace( async def get_project_invitation_by_email(project_id: str, email: str) -> InvitationDB: """Get project invitation by project ID and email. + Email comparison is case-insensitive to handle SuperTokens email normalization. Args: project_id (str): The ID of the project. @@ -1636,8 +1648,9 @@ async def get_project_invitation_by_email(project_id: str, email: str) -> Invita async with engine.core_session() as session: result = await session.execute( - select(InvitationDB).filter_by( - project_id=uuid.UUID(project_id), email=email + select(InvitationDB).where( + InvitationDB.project_id == uuid.UUID(project_id), + func.lower(InvitationDB.email) == func.lower(email) ) ) invitation = result.scalars().first() @@ -1773,6 +1786,7 @@ async def get_project_invitation_by_token_and_email( project_id: str, token: str, email: str ) -> InvitationDB: """Get project invitation by project ID, token and email. + Email comparison is case-insensitive to handle SuperTokens email normalization. Args: project_id (str): The ID of the project. @@ -1785,8 +1799,10 @@ async def get_project_invitation_by_token_and_email( async with engine.core_session() as session: result = await session.execute( - select(InvitationDB).filter_by( - project_id=uuid.UUID(project_id), token=token, email=email + select(InvitationDB).where( + InvitationDB.project_id == uuid.UUID(project_id), + InvitationDB.token == token, + func.lower(InvitationDB.email) == func.lower(email) ) ) invitation = result.scalars().first() diff --git a/api/oss/src/services/organization_service.py b/api/oss/src/services/organization_service.py index 030ce6bd3f..e945194493 100644 --- a/api/oss/src/services/organization_service.py +++ b/api/oss/src/services/organization_service.py @@ -18,6 +18,7 @@ def generate_invitation_token(token_length: int = 16): async def check_existing_invitation(project_id: str, email: str): """ Checks if there is an existing invitation for a given project and email address. + Email is normalized to lowercase to match SuperTokens behavior. Args: project_id (str): The ID of the project for which the invitation is being checked. @@ -26,14 +27,17 @@ async def check_existing_invitation(project_id: str, email: str): Returns: - invitation (InvitationDB): The existing invitation if it is valid and not expired. Otherwise, returns None. """ + + # Normalize email to lowercase to match SuperTokens behavior + normalized_email = email.lower() invitation = await db_manager.get_project_invitation_by_email( - project_id=project_id, email=email + project_id=project_id, email=normalized_email ) if ( invitation is not None and not invitation.used - and invitation.email == email + and invitation.email.lower() == normalized_email and str(invitation.project_id) == project_id ): if invitation.expiration_date > datetime.now(timezone.utc): @@ -50,6 +54,7 @@ async def check_existing_invitation(project_id: str, email: str): async def check_valid_invitation(project_id: str, email: str, token: str): """ Check if a project invitation is valid for a given user and token. + Email is normalized to lowercase to match SuperTokens behavior. Args: project_id (str): The ID of the project for which the invitation is being checked. @@ -60,9 +65,12 @@ async def check_valid_invitation(project_id: str, email: str, token: str): InvitationDB or None: Returns the invitation object if it's valid and not expired. Returns None if the invitation is not found or has expired. """ + + # Normalize email to lowercase to match SuperTokens behavior + normalized_email = email.lower() invitation = await db_manager.get_project_invitation_by_token_and_email( - project_id, token, email + project_id, token, normalized_email ) if invitation is not None and invitation.expiration_date > datetime.now( timezone.utc @@ -135,6 +143,7 @@ async def send_invitation_email( async def create_invitation(role: str, project_id: str, email: str): """ Creates a new invitation for a user to join an organization. + Email is normalized to lowercase to match SuperTokens behavior. Args: role (str): The role to be assigned to the invited user in the organization. @@ -148,12 +157,15 @@ async def create_invitation(role: str, project_id: str, email: str): token = generate_invitation_token() expiration_date = datetime.now(timezone.utc) + timedelta(days=7) + + # Normalize email to lowercase to match SuperTokens behavior + normalized_email = email.lower() invitation = await db_manager.create_user_invitation_to_organization( project_id=project_id, token=token, role=role, - email=email, + email=normalized_email, expiration_date=expiration_date, ) return invitation @@ -178,8 +190,11 @@ async def invite_user_to_organization( user_performing_action = await db_manager.get_user_with_id(user_id=user_id) + # Normalize email to lowercase to match SuperTokens behavior + normalized_email = payload.email.lower() + # Check that the user is not inviting themselves - if payload.email == user_performing_action.email: + if normalized_email == user_performing_action.email.lower(): raise HTTPException( status_code=400, detail="You cannot invite yourself to a workspace", @@ -187,7 +202,7 @@ async def invite_user_to_organization( # Check if the user is already a member of the workspace existing_invitation, existing_role = await check_existing_invitation( - project_id=project_id, email=payload.email + project_id=project_id, email=normalized_email ) if existing_invitation or existing_role: raise HTTPException( @@ -196,12 +211,12 @@ async def invite_user_to_organization( ) # Create a new invitation since user hasn't been invited - invitation = await create_invitation("editor", project_id, payload.email) + invitation = await create_invitation("editor", project_id, normalized_email) # Get project by id project_db = await db_manager.get_project_by_id(project_id=project_id) - # Send the invitation email + # Send the invitation email (use original email for display purposes) send_email = await send_invitation_email( payload.email, invitation.token, # type: ignore @@ -240,15 +255,18 @@ async def resend_user_organization_invite( user_performing_action = await db_manager.get_user_with_id(user_id=user_id) + # Normalize email to lowercase to match SuperTokens behavior + normalized_email = payload.email.lower() + # Check if the email address already has a valid, unused invitation for the workspace existing_invitation, existing_role = await check_existing_invitation( - project_id, payload.email + project_id, normalized_email ) if existing_invitation: invitation = existing_invitation elif existing_role: # Create a new invitation - invitation = await create_invitation("editor", project_id, payload.email) + invitation = await create_invitation("editor", project_id, normalized_email) # Get project by id project_db = await db_manager.get_project_by_id(project_id=project_id)