Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions api/oss/src/services/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,17 +982,19 @@ async def check_if_user_exists_and_create_organization(user_email: str):


async def check_if_user_invitation_exists(email: str, organization_id: str):
"""Check if a user invitation with the given email and organization_id exists."""
"""Check if a user invitation with the given email and organization_id exists.
Email comparison is case-insensitive to handle SuperTokens email normalization.
"""

project_db = await get_project_by_organization_id(organization_id=organization_id)
if not project_db:
raise NoResultFound("Project not found for user invitation in organization.")

async with engine.core_session() as session:
result = await session.execute(
select(InvitationDB).filter_by(
email=email,
project_id=project_db.id,
select(InvitationDB).where(
func.lower(InvitationDB.email) == func.lower(email),
InvitationDB.project_id == project_db.id,
)
)
user_invitation = result.scalars().first()
Expand Down Expand Up @@ -1131,6 +1133,7 @@ async def delete_accounts() -> None:

async def create_accounts(payload: dict) -> UserDB:
"""Create a new account in the database.
Email is normalized to lowercase to match SuperTokens behavior.

Args:
payload (dict): The payload to create the user
Expand All @@ -1142,8 +1145,12 @@ async def create_accounts(payload: dict) -> UserDB:
# pop required fields for organization & workspace creation
organization_id = payload.pop("organization_id")

# Normalize email to lowercase to match SuperTokens behavior
normalized_email = payload["email"].lower()
payload["email"] = normalized_email

# create user
user_info = {**payload, "username": payload["email"].split("@")[0]}
user_info = {**payload, "username": normalized_email.split("@")[0]}
user_db = await user_service.create_new_user(payload=user_info)

# only update organization to have user_db as its "owner" if it does not yet have one
Expand All @@ -1161,7 +1168,7 @@ async def create_accounts(payload: dict) -> UserDB:

# update user invitation in the case the user was invited
invitation = await get_project_invitation_by_email(
project_id=str(project_db.id), email=payload["email"]
project_id=str(project_db.id), email=normalized_email
)
if invitation is not None:
await update_invitation(
Expand Down Expand Up @@ -1370,6 +1377,7 @@ async def get_workspaces() -> List[WorkspaceDB]:

async def remove_user_from_workspace(project_id: str, email: str):
"""Remove a user from a workspace.
Email comparison is case-insensitive to handle SuperTokens email normalization.

Args:
project_id (str): The ID of the project
Expand Down Expand Up @@ -1461,6 +1469,7 @@ async def get_user_with_id(user_id: str) -> UserDB:
async def get_user_with_email(email: str):
"""
Retrieves a user from the database based on their email address.
Email comparison is case-insensitive to handle SuperTokens email normalization.

Args:
email (str): The email address of the user to retrieve.
Expand All @@ -1480,7 +1489,9 @@ async def get_user_with_email(email: str):
raise Exception("Please provide a valid email address")

async with engine.core_session() as session:
result = await session.execute(select(UserDB).filter_by(email=email))
result = await session.execute(
select(UserDB).where(func.lower(UserDB.email) == func.lower(email))
)
user = result.scalars().first()
return user

Expand Down Expand Up @@ -1625,6 +1636,7 @@ async def get_default_project_id_from_workspace(

async def get_project_invitation_by_email(project_id: str, email: str) -> InvitationDB:
"""Get project invitation by project ID and email.
Email comparison is case-insensitive to handle SuperTokens email normalization.

Args:
project_id (str): The ID of the project.
Expand All @@ -1636,8 +1648,9 @@ async def get_project_invitation_by_email(project_id: str, email: str) -> Invita

async with engine.core_session() as session:
result = await session.execute(
select(InvitationDB).filter_by(
project_id=uuid.UUID(project_id), email=email
select(InvitationDB).where(
InvitationDB.project_id == uuid.UUID(project_id),
func.lower(InvitationDB.email) == func.lower(email)
)
)
invitation = result.scalars().first()
Expand Down Expand Up @@ -1773,6 +1786,7 @@ async def get_project_invitation_by_token_and_email(
project_id: str, token: str, email: str
) -> InvitationDB:
"""Get project invitation by project ID, token and email.
Email comparison is case-insensitive to handle SuperTokens email normalization.

Args:
project_id (str): The ID of the project.
Expand All @@ -1785,8 +1799,10 @@ async def get_project_invitation_by_token_and_email(

async with engine.core_session() as session:
result = await session.execute(
select(InvitationDB).filter_by(
project_id=uuid.UUID(project_id), token=token, email=email
select(InvitationDB).where(
InvitationDB.project_id == uuid.UUID(project_id),
InvitationDB.token == token,
func.lower(InvitationDB.email) == func.lower(email)
)
)
invitation = result.scalars().first()
Expand Down
38 changes: 28 additions & 10 deletions api/oss/src/services/organization_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_invitation_token(token_length: int = 16):
async def check_existing_invitation(project_id: str, email: str):
"""
Checks if there is an existing invitation for a given project and email address.
Email is normalized to lowercase to match SuperTokens behavior.

Args:
project_id (str): The ID of the project for which the invitation is being checked.
Expand All @@ -26,14 +27,17 @@ async def check_existing_invitation(project_id: str, email: str):
Returns:
- invitation (InvitationDB): The existing invitation if it is valid and not expired. Otherwise, returns None.
"""

# Normalize email to lowercase to match SuperTokens behavior
normalized_email = email.lower()

invitation = await db_manager.get_project_invitation_by_email(
project_id=project_id, email=email
project_id=project_id, email=normalized_email
)
if (
invitation is not None
and not invitation.used
and invitation.email == email
and invitation.email.lower() == normalized_email
and str(invitation.project_id) == project_id
):
if invitation.expiration_date > datetime.now(timezone.utc):
Expand All @@ -50,6 +54,7 @@ async def check_existing_invitation(project_id: str, email: str):
async def check_valid_invitation(project_id: str, email: str, token: str):
"""
Check if a project invitation is valid for a given user and token.
Email is normalized to lowercase to match SuperTokens behavior.

Args:
project_id (str): The ID of the project for which the invitation is being checked.
Expand All @@ -60,9 +65,12 @@ async def check_valid_invitation(project_id: str, email: str, token: str):
InvitationDB or None: Returns the invitation object if it's valid and not expired.
Returns None if the invitation is not found or has expired.
"""

# Normalize email to lowercase to match SuperTokens behavior
normalized_email = email.lower()

invitation = await db_manager.get_project_invitation_by_token_and_email(
project_id, token, email
project_id, token, normalized_email
)
if invitation is not None and invitation.expiration_date > datetime.now(
timezone.utc
Expand Down Expand Up @@ -135,6 +143,7 @@ async def send_invitation_email(
async def create_invitation(role: str, project_id: str, email: str):
"""
Creates a new invitation for a user to join an organization.
Email is normalized to lowercase to match SuperTokens behavior.

Args:
role (str): The role to be assigned to the invited user in the organization.
Expand All @@ -148,12 +157,15 @@ async def create_invitation(role: str, project_id: str, email: str):

token = generate_invitation_token()
expiration_date = datetime.now(timezone.utc) + timedelta(days=7)

# Normalize email to lowercase to match SuperTokens behavior
normalized_email = email.lower()

invitation = await db_manager.create_user_invitation_to_organization(
project_id=project_id,
token=token,
role=role,
email=email,
email=normalized_email,
expiration_date=expiration_date,
)
return invitation
Expand All @@ -178,16 +190,19 @@ async def invite_user_to_organization(

user_performing_action = await db_manager.get_user_with_id(user_id=user_id)

# Normalize email to lowercase to match SuperTokens behavior
normalized_email = payload.email.lower()

# Check that the user is not inviting themselves
if payload.email == user_performing_action.email:
if normalized_email == user_performing_action.email.lower():
raise HTTPException(
status_code=400,
detail="You cannot invite yourself to a workspace",
)

# Check if the user is already a member of the workspace
existing_invitation, existing_role = await check_existing_invitation(
project_id=project_id, email=payload.email
project_id=project_id, email=normalized_email
)
if existing_invitation or existing_role:
raise HTTPException(
Expand All @@ -196,12 +211,12 @@ async def invite_user_to_organization(
)

# Create a new invitation since user hasn't been invited
invitation = await create_invitation("editor", project_id, payload.email)
invitation = await create_invitation("editor", project_id, normalized_email)

# Get project by id
project_db = await db_manager.get_project_by_id(project_id=project_id)

# Send the invitation email
# Send the invitation email (use original email for display purposes)
send_email = await send_invitation_email(
payload.email,
invitation.token, # type: ignore
Expand Down Expand Up @@ -240,15 +255,18 @@ async def resend_user_organization_invite(

user_performing_action = await db_manager.get_user_with_id(user_id=user_id)

# Normalize email to lowercase to match SuperTokens behavior
normalized_email = payload.email.lower()

# Check if the email address already has a valid, unused invitation for the workspace
existing_invitation, existing_role = await check_existing_invitation(
project_id, payload.email
project_id, normalized_email
)
if existing_invitation:
invitation = existing_invitation
elif existing_role:
# Create a new invitation
invitation = await create_invitation("editor", project_id, payload.email)
invitation = await create_invitation("editor", project_id, normalized_email)

# Get project by id
project_db = await db_manager.get_project_by_id(project_id=project_id)
Expand Down