From 4de336c0bdd018187de5654c58a0865b0d3d8494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=91=9B=E9=94=90?= Date: Thu, 8 Jan 2026 14:26:07 +0800 Subject: [PATCH 01/38] feat:bug fix --- frontend/components/agent/AgentImportWizard.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/components/agent/AgentImportWizard.tsx b/frontend/components/agent/AgentImportWizard.tsx index 10e7b0899..6100bacb8 100644 --- a/frontend/components/agent/AgentImportWizard.tsx +++ b/frontend/components/agent/AgentImportWizard.tsx @@ -786,7 +786,7 @@ export default function AgentImportWizard({ const result = await importAgent(importData, { forceImport: false }); if (result.success) { - message.success(t("market.install.success", "Agent installed successfully!")); + // Parent component will show a success message; invalidate and notify parent queryClient.invalidateQueries({ queryKey: ["agents"] }); onImportComplete?.(); handleCancel(); // Close wizard after success From ceab2628176955990c89a6f6c04f6a98677355d2 Mon Sep 17 00:00:00 2001 From: Jasonxia007 Date: Wed, 31 Dec 2025 11:12:27 +0800 Subject: [PATCH 02/38] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20[WIP]=20User=20Manag?= =?UTF-8?q?ement=20Part1:=20Add=20database=20fields=20and=20db=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/consts/exceptions.py | 6 + backend/database/db_models.py | 87 +++ backend/database/group_db.py | 295 +++++++++ backend/database/invitation_db.py | 258 ++++++++ backend/database/knowledge_db.py | 4 + backend/database/role_permission_db.py | 236 +++++++ backend/database/tenant_config_db.py | 16 + backend/database/tool_db.py | 2 +- backend/database/user_tenant_db.py | 15 +- backend/utils/str_utils.py | 31 + ...0_1226_add_invitation_and_group_system.sql | 146 +++++ test/backend/database/test_group_db.py | 575 ++++++++++++++++++ test/backend/database/test_invitation_db.py | 513 ++++++++++++++++ test/backend/database/test_knowledge_db.py | 56 +- .../database/test_role_permission_db.py | 467 ++++++++++++++ test/backend/database/test_user_tenant_db.py | 101 +-- test/backend/utils/test_str_utils.py | 33 +- 17 files changed, 2789 insertions(+), 52 deletions(-) create mode 100644 backend/database/group_db.py create mode 100644 backend/database/invitation_db.py create mode 100644 backend/database/role_permission_db.py create mode 100644 docker/sql/v1.8.0_1226_add_invitation_and_group_system.sql create mode 100644 test/backend/database/test_group_db.py create mode 100644 test/backend/database/test_invitation_db.py create mode 100644 test/backend/database/test_role_permission_db.py diff --git a/backend/consts/exceptions.py b/backend/consts/exceptions.py index 068249998..815ed0eef 100644 --- a/backend/consts/exceptions.py +++ b/backend/consts/exceptions.py @@ -58,6 +58,12 @@ class TimeoutException(Exception): pass + +class ValidationError(Exception): + """Raised when validation fails.""" + pass + + class NotFoundException(Exception): """Raised when not found exception occurs.""" pass diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 301dd64aa..3f1875de3 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -217,6 +217,7 @@ class AgentInfo(TableBase): Text, doc="Manually entered by the user to describe the entire business process") business_logic_model_name = Column(String(100), doc="Model name used for business logic prompt generation") business_logic_model_id = Column(Integer, doc="Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id") + group_ids = Column(String, doc="Agent group IDs list") class ToolInstance(TableBase): @@ -251,6 +252,9 @@ class KnowledgeRecord(TableBase): knowledge_sources = Column(String(300), doc="Knowledge base sources") embedding_model_name = Column(String(200), doc="Embedding model name, used to record the embedding model used by the knowledge base") tenant_id = Column(String(100), doc="Tenant ID") + group_ids = Column(String, doc="Knowledge base group IDs list") + ingroup_permission = Column( + String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") class TenantConfig(TableBase): @@ -322,6 +326,7 @@ class UserTenant(TableBase): primary_key=True, nullable=False, doc="User tenant relationship ID, unique primary key") user_id = Column(String(100), nullable=False, doc="User ID") tenant_id = Column(String(100), nullable=False, doc="Tenant ID") + user_role = Column(String(30), doc="User role: SU, ADMIN, DEV, USER") class AgentRelation(TableBase): @@ -355,3 +360,85 @@ class PartnerMappingId(TableBase): 30), doc="Type of the external - internal mapping, value set: CONVERSATION") tenant_id = Column(String(100), doc="Tenant ID") user_id = Column(String(100), doc="User ID") + + +class TenantInvitationCode(TableBase): + """ + Tenant invitation code information table + """ + __tablename__ = "tenant_invitation_code_t" + __table_args__ = {"schema": SCHEMA} + + invitation_id = Column(Integer, Sequence("tenant_invitation_code_t_invitation_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Invitation ID, primary key") + tenant_id = Column(String(100), nullable=False, + doc="Tenant ID, foreign key") + invitation_code = Column(String(100), nullable=False, + unique=True, doc="Invitation code") + group_ids = Column(String, doc="Associated group IDs list") + capacity = Column(Integer, nullable=False, default=1, + doc="Invitation code capacity") + expiry_date = Column(TIMESTAMP(timezone=False), + doc="Invitation code expiry date") + status = Column(String(30), nullable=False, + doc="Invitation code status: IN_USE, EXPIRE, DISABLE, RUN_OUT") + code_type = Column(String(30), nullable=False, + doc="Invitation code type: ADMIN_INVITE, DEV_INVITE, USER_INVITE") + + +class TenantInvitationRecord(TableBase): + """ + Tenant invitation record table + """ + __tablename__ = "tenant_invitation_record_t" + __table_args__ = {"schema": SCHEMA} + + invitation_record_id = Column(Integer, Sequence("tenant_invitation_record_t_invitation_record_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Invitation record ID, primary key") + invitation_id = Column(Integer, nullable=False, + doc="Invitation ID, foreign key") + user_id = Column(String(100), nullable=False, doc="User ID") + + +class TenantGroupInfo(TableBase): + """ + Tenant group information table + """ + __tablename__ = "tenant_group_info_t" + __table_args__ = {"schema": SCHEMA} + + group_id = Column(Integer, Sequence("tenant_group_info_t_group_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Group ID, primary key") + tenant_id = Column(String(100), nullable=False, + doc="Tenant ID, foreign key") + group_name = Column(String(100), nullable=False, doc="Group name") + group_description = Column(String(500), doc="Group description") + + +class TenantGroupUser(TableBase): + """ + Tenant group user membership table + """ + __tablename__ = "tenant_group_user_t" + __table_args__ = {"schema": SCHEMA} + + group_user_id = Column(Integer, Sequence("tenant_group_user_t_group_user_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Group user ID, primary key") + group_id = Column(Integer, nullable=False, doc="Group ID, foreign key") + user_id = Column(String(100), nullable=False, doc="User ID, foreign key") + + +class RolePermission(TableBase): + """ + Role permission configuration table + """ + __tablename__ = "role_permission_t" + __table_args__ = {"schema": SCHEMA} + + role_permission_id = Column(Integer, Sequence("role_permission_t_role_permission_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Role permission ID, primary key") + user_role = Column(String(30), nullable=False, + doc="User role: SU, ADMIN, DEV, USER") + permission_category = Column(String(30), doc="Permission category") + permission_type = Column(String(30), doc="Permission type") + permission_subtype = Column(String(30), doc="Permission subtype") diff --git a/backend/database/group_db.py b/backend/database/group_db.py new file mode 100644 index 000000000..487d86715 --- /dev/null +++ b/backend/database/group_db.py @@ -0,0 +1,295 @@ +""" +Database operations for group management +""" +from typing import Any, Dict, List, Optional, Union + +from database.client import as_dict, get_db_session +from database.db_models import TenantGroupInfo, TenantGroupUser +from utils.str_utils import convert_string_to_list + + +def query_groups(group_id: Union[int, str, List[int]]) -> Union[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Query group(s) by group ID(s) + + Args: + group_id: Group ID(s) - can be int, comma-separated string, or list of ints + + Returns: + Single group dict if int provided, list of group dicts if string/list provided + """ + # Convert input to list of integers + if isinstance(group_id, int): + group_ids = [group_id] + return_single = True + elif isinstance(group_id, str): + group_ids = convert_string_to_list(group_id) + return_single = False + elif isinstance(group_id, list): + group_ids = group_id + return_single = False + else: + raise ValueError("group_id must be int, str, or List[int]") + + if not group_ids: + return [] if not return_single else None + + with get_db_session() as session: + result = session.query(TenantGroupInfo).filter( + TenantGroupInfo.group_id.in_(group_ids), + TenantGroupInfo.delete_flag == "N" + ).all() + + groups = [as_dict(record) for record in result] + + # Return single result if single ID was provided + if return_single: + return groups[0] if groups else None + else: + return groups + + +def query_groups_by_tenant(tenant_id: str) -> List[Dict[str, Any]]: + """ + Query all groups for a tenant + + Args: + tenant_id (str): Tenant ID + + Returns: + List[Dict[str, Any]]: List of group records + """ + with get_db_session() as session: + result = session.query(TenantGroupInfo).filter( + TenantGroupInfo.tenant_id == tenant_id, + TenantGroupInfo.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def add_group(tenant_id: str, group_name: str, group_description: Optional[str] = None, + created_by: Optional[str] = None) -> int: + """ + Add a new group + + Args: + tenant_id (str): Tenant ID + group_name (str): Group name + group_description (Optional[str]): Group description + created_by (Optional[str]): Created by user + + Returns: + int: Created group ID + """ + with get_db_session() as session: + group = TenantGroupInfo( + tenant_id=tenant_id, + group_name=group_name, + group_description=group_description, + created_by=created_by, + updated_by=created_by + ) + session.add(group) + session.flush() # To get the ID + return group.group_id + + +def modify_group(group_id: int, updates: Dict[str, Any], updated_by: Optional[str] = None) -> bool: + """ + Modify group information + + Args: + group_id (int): Group ID + updates (Dict[str, Any]): Fields to update + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether update was successful + """ + with get_db_session() as session: + update_data = updates.copy() + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(TenantGroupInfo).filter( + TenantGroupInfo.group_id == group_id, + TenantGroupInfo.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def remove_group(group_id: int, updated_by: Optional[str] = None) -> bool: + """ + Remove group (soft delete) + + Args: + group_id (int): Group ID + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether removal was successful + """ + with get_db_session() as session: + update_data: Dict[str, Any] = {"delete_flag": "Y"} + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(TenantGroupInfo).filter( + TenantGroupInfo.group_id == group_id, + TenantGroupInfo.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def add_user_to_group(group_id: int, user_id: str, created_by: Optional[str] = None) -> int: + """ + Add user to group + + Args: + group_id (int): Group ID + user_id (str): User ID + created_by (Optional[str]): Created by user + + Returns: + int: Created group user ID + """ + with get_db_session() as session: + group_user = TenantGroupUser( + group_id=group_id, + user_id=user_id, + created_by=created_by, + updated_by=created_by + ) + session.add(group_user) + session.flush() # To get the ID + return group_user.group_user_id + + +def remove_user_from_group(group_id: int, user_id: str, updated_by: Optional[str] = None) -> bool: + """ + Remove user from group + + Args: + group_id (int): Group ID + user_id (str): User ID + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether removal was successful + """ + with get_db_session() as session: + update_data: Dict[str, Any] = {"delete_flag": "Y"} + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(TenantGroupUser).filter( + TenantGroupUser.group_id == group_id, + TenantGroupUser.user_id == user_id, + TenantGroupUser.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def query_group_users(group_id: int) -> List[Dict[str, Any]]: + """ + Query all users in a group + + Args: + group_id (int): Group ID + + Returns: + List[Dict[str, Any]]: List of group user records + """ + with get_db_session() as session: + result = session.query(TenantGroupUser).filter( + TenantGroupUser.group_id == group_id, + TenantGroupUser.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def query_group_ids_by_user(user_id: str) -> List[int]: + """ + Query all group IDs for a user + + Args: + user_id (str): User ID + + Returns: + List[int]: List of group IDs + """ + with get_db_session() as session: + result = session.query(TenantGroupUser.group_id).filter( + TenantGroupUser.user_id == user_id, + TenantGroupUser.delete_flag == "N" + ).all() + + return [record[0] for record in result] + + +def query_groups_by_user(user_id: str) -> List[Dict[str, Any]]: + """ + Query all groups for a user + + Args: + user_id (str): User ID + + Returns: + List[Dict[str, Any]]: List of group records + """ + with get_db_session() as session: + result = session.query(TenantGroupInfo).join( + TenantGroupUser, + TenantGroupInfo.group_id == TenantGroupUser.group_id + ).filter( + TenantGroupUser.user_id == user_id, + TenantGroupUser.delete_flag == "N", + TenantGroupInfo.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def check_user_in_group(user_id: str, group_id: int) -> bool: + """ + Check if user is in a specific group + + Args: + user_id (str): User ID + group_id (int): Group ID + + Returns: + bool: Whether user is in the group + """ + with get_db_session() as session: + result = session.query(TenantGroupUser).filter( + TenantGroupUser.group_id == group_id, + TenantGroupUser.user_id == user_id, + TenantGroupUser.delete_flag == "N" + ).first() + + return result is not None + + +def count_group_users(group_id: int) -> int: + """ + Count users in a group + + Args: + group_id (int): Group ID + + Returns: + int: Number of users in the group + """ + with get_db_session() as session: + result = session.query(TenantGroupUser).filter( + TenantGroupUser.group_id == group_id, + TenantGroupUser.delete_flag == "N" + ).count() + + return result diff --git a/backend/database/invitation_db.py b/backend/database/invitation_db.py new file mode 100644 index 000000000..c9c71b30f --- /dev/null +++ b/backend/database/invitation_db.py @@ -0,0 +1,258 @@ +""" +Database operations for invitation code management +""" +from typing import Any, Dict, List, Optional + +from database.client import as_dict, get_db_session +from database.db_models import TenantInvitationCode, TenantInvitationRecord +from utils.str_utils import convert_list_to_string + + +def query_invitation_by_code(invitation_code: str) -> Optional[Dict[str, Any]]: + """ + Query invitation by invitation code + + Args: + invitation_code (str): Invitation code + + Returns: + Optional[Dict[str, Any]]: Invitation record + """ + with get_db_session() as session: + result = session.query(TenantInvitationCode).filter( + TenantInvitationCode.invitation_code == invitation_code, + TenantInvitationCode.delete_flag == "N" + ).first() + + if result: + return as_dict(result) + return None + + +def query_invitation_by_id(invitation_id: int) -> Optional[Dict[str, Any]]: + """ + Query invitation by ID + + Args: + invitation_id (int): Invitation ID + + Returns: + Optional[Dict[str, Any]]: Invitation record + """ + with get_db_session() as session: + result = session.query(TenantInvitationCode).filter( + TenantInvitationCode.invitation_id == invitation_id, + TenantInvitationCode.delete_flag == "N" + ).first() + + if result: + return as_dict(result) + return None + + +def query_invitations_by_tenant(tenant_id: str) -> List[Dict[str, Any]]: + """ + Query all invitations for a tenant + + Args: + tenant_id (str): Tenant ID + + Returns: + List[Dict[str, Any]]: List of invitation records + """ + with get_db_session() as session: + result = session.query(TenantInvitationCode).filter( + TenantInvitationCode.tenant_id == tenant_id, + TenantInvitationCode.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def add_invitation(tenant_id: str, invitation_code: str, code_type: str, group_ids: Optional[List[int]] = None, + capacity: int = 1, expiry_date: Optional[str] = None, + status: str = "IN_USE", created_by: Optional[str] = None) -> int: + """ + Add a new invitation + + Args: + tenant_id (str): Tenant ID + invitation_code (str): Invitation code + code_type (str): Invitation code type (ADMIN_INVITE, DEV_INVITE, USER_INVITE) + group_ids (Optional[List[int]]): Associated group IDs + capacity (int): Invitation capacity + expiry_date (Optional[str]): Expiry date + status (str): Status + created_by (Optional[str]): Created by user + + Returns: + int: Created invitation ID + """ + with get_db_session() as session: + invitation = TenantInvitationCode( + tenant_id=tenant_id, + invitation_code=invitation_code, + code_type=code_type, + group_ids=convert_list_to_string(group_ids), + capacity=capacity, + expiry_date=expiry_date, + status=status, + created_by=created_by, + updated_by=created_by + ) + session.add(invitation) + session.flush() # To get the ID + return invitation.invitation_id + + +def modify_invitation(invitation_id: int, updates: Dict[str, Any], updated_by: Optional[str] = None) -> bool: + """ + Modify invitation + + Args: + invitation_id (int): Invitation ID + updates (Dict[str, Any]): Fields to update + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether update was successful + """ + with get_db_session() as session: + update_data = updates.copy() + if updated_by: + update_data["updated_by"] = updated_by + + # Convert group_ids list to string if present + if "group_ids" in update_data and isinstance(update_data["group_ids"], list): + update_data["group_ids"] = convert_list_to_string(update_data["group_ids"]) + + result = session.query(TenantInvitationCode).filter( + TenantInvitationCode.invitation_id == invitation_id, + TenantInvitationCode.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def remove_invitation(invitation_id: int, updated_by: Optional[str] = None) -> bool: + """ + Remove invitation (soft delete) + + Args: + invitation_id (int): Invitation ID + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether removal was successful + """ + with get_db_session() as session: + update_data: Dict[str, Any] = {"delete_flag": "Y"} + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(TenantInvitationCode).filter( + TenantInvitationCode.invitation_id == invitation_id, + TenantInvitationCode.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def query_invitation_records(invitation_id: int) -> List[Dict[str, Any]]: + """ + Query invitation records by invitation ID + + Args: + invitation_id (int): Invitation ID + + Returns: + List[Dict[str, Any]]: List of invitation records + """ + with get_db_session() as session: + result = session.query(TenantInvitationRecord).filter( + TenantInvitationRecord.invitation_id == invitation_id, + TenantInvitationRecord.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def add_invitation_record(invitation_id: int, user_id: str, created_by: Optional[str] = None) -> int: + """ + Add invitation usage record + + Args: + invitation_id (int): Invitation ID + user_id (str): User ID + created_by (Optional[str]): Created by user + + Returns: + int: Created invitation record ID + """ + with get_db_session() as session: + record = TenantInvitationRecord( + invitation_id=invitation_id, + user_id=user_id, + created_by=created_by, + updated_by=created_by + ) + session.add(record) + session.flush() # To get the ID + return record.invitation_record_id + + +def query_invitation_records_by_user(user_id: str) -> List[Dict[str, Any]]: + """ + Query invitation records by user ID + + Args: + user_id (str): User ID + + Returns: + List[Dict[str, Any]]: List of invitation records + """ + with get_db_session() as session: + result = session.query(TenantInvitationRecord).filter( + TenantInvitationRecord.user_id == user_id, + TenantInvitationRecord.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def count_invitation_usage(invitation_id: int) -> int: + """ + Count usage for an invitation code + + Args: + invitation_id (int): Invitation ID + + Returns: + int: Number of times the invitation has been used + """ + with get_db_session() as session: + result = session.query(TenantInvitationRecord).filter( + TenantInvitationRecord.invitation_id == invitation_id, + TenantInvitationRecord.delete_flag == "N" + ).count() + + return result + + +def query_invitation_status(invitation_code: str) -> Optional[str]: + """ + Query invitation status + + Args: + invitation_code (str): Invitation code + + Returns: + Optional[str]: Invitation status if exists, None otherwise + """ + with get_db_session() as session: + invitation = session.query(TenantInvitationCode).filter( + TenantInvitationCode.invitation_code == invitation_code, + TenantInvitationCode.delete_flag == "N" + ).first() + + return invitation.status if invitation else None diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index 6faccdafa..f270cf7bf 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -6,6 +6,7 @@ from database.client import as_dict, get_db_session from database.db_models import KnowledgeRecord +from utils.str_utils import convert_list_to_string def _generate_index_name(knowledge_id: int) -> str: @@ -40,6 +41,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "knowledge_name") or query.get("index_name") # Prepare data dictionary + group_ids = query.get("group_ids") data: Dict[str, Any] = { "knowledge_describe": query.get("knowledge_describe", ""), "created_by": query.get("user_id"), @@ -48,6 +50,8 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "tenant_id": query.get("tenant_id"), "embedding_model_name": query.get("embedding_model_name"), "knowledge_name": knowledge_name, + "group_ids": convert_list_to_string(group_ids) if isinstance(group_ids, list) else group_ids, + "ingroup_permission": query.get("ingroup_permission"), } # For backward compatibility: if caller explicitly provides index_name, diff --git a/backend/database/role_permission_db.py b/backend/database/role_permission_db.py new file mode 100644 index 000000000..71714430b --- /dev/null +++ b/backend/database/role_permission_db.py @@ -0,0 +1,236 @@ +""" +Database operations for role permission management +""" +from typing import Any, Dict, List, Optional + +from database.client import as_dict, get_db_session +from database.db_models import RolePermission + + +def get_role_permissions(user_role: str) -> List[Dict[str, Any]]: + """ + Get all permissions for a user role + + Args: + user_role (str): User role (SU, ADMIN, DEV, USER) + + Returns: + List[Dict[str, Any]]: List of role permission records + """ + with get_db_session() as session: + result = session.query(RolePermission).filter( + RolePermission.user_role == user_role, + RolePermission.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def get_all_role_permissions() -> List[Dict[str, Any]]: + """ + Get all role permissions + + Returns: + List[Dict[str, Any]]: List of all role permission records + """ + with get_db_session() as session: + result = session.query(RolePermission).filter( + RolePermission.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def create_role_permission(user_role: str, permission_category: Optional[str] = None, + permission_type: Optional[str] = None, permission_subtype: Optional[str] = None, + created_by: Optional[str] = None) -> int: + """ + Create a new role permission + + Args: + user_role (str): User role + permission_category (Optional[str]): Permission category + permission_type (Optional[str]): Permission type + permission_subtype (Optional[str]): Permission subtype + created_by (Optional[str]): Created by user + + Returns: + int: Created role permission ID + """ + with get_db_session() as session: + permission = RolePermission( + user_role=user_role, + permission_category=permission_category, + permission_type=permission_type, + permission_subtype=permission_subtype, + created_by=created_by, + updated_by=created_by + ) + session.add(permission) + session.flush() # To get the ID + return permission.role_permission_id + + +def update_role_permission(role_permission_id: int, updates: Dict[str, Any], + updated_by: Optional[str] = None) -> bool: + """ + Update role permission + + Args: + role_permission_id (int): Role permission ID + updates (Dict[str, Any]): Fields to update + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether update was successful + """ + with get_db_session() as session: + update_data = updates.copy() + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(RolePermission).filter( + RolePermission.role_permission_id == role_permission_id, + RolePermission.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def soft_delete_role_permission(role_permission_id: int, updated_by: Optional[str] = None) -> bool: + """ + Soft delete role permission + + Args: + role_permission_id (int): Role permission ID + updated_by (Optional[str]): Updated by user + + Returns: + bool: Whether deletion was successful + """ + with get_db_session() as session: + update_data: Dict[str, Any] = {"delete_flag": "Y"} + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(RolePermission).filter( + RolePermission.role_permission_id == role_permission_id, + RolePermission.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result > 0 + + +def delete_role_permissions_by_role(user_role: str, updated_by: Optional[str] = None) -> int: + """ + Delete all permissions for a user role + + Args: + user_role (str): User role + updated_by (Optional[str]): Updated by user + + Returns: + int: Number of deleted permissions + """ + with get_db_session() as session: + update_data: Dict[str, Any] = {"delete_flag": "Y"} + if updated_by: + update_data["updated_by"] = updated_by + + result = session.query(RolePermission).filter( + RolePermission.user_role == user_role, + RolePermission.delete_flag == "N" + ).update(update_data, synchronize_session=False) + + return result + + +def check_role_permission(user_role: str, permission_category: Optional[str] = None, + permission_type: Optional[str] = None, permission_subtype: Optional[str] = None) -> bool: + """ + Check if a role has specific permission + + Args: + user_role (str): User role + permission_category (Optional[str]): Permission category + permission_type (Optional[str]): Permission type + permission_subtype (Optional[str]): Permission subtype + + Returns: + bool: Whether the role has the permission + """ + with get_db_session() as session: + query = session.query(RolePermission).filter( + RolePermission.user_role == user_role, + RolePermission.delete_flag == "N" + ) + + if permission_category: + query = query.filter(RolePermission.permission_category == permission_category) + if permission_type: + query = query.filter(RolePermission.permission_type == permission_type) + if permission_subtype: + query = query.filter(RolePermission.permission_subtype == permission_subtype) + + result = query.first() + return result is not None + + +def get_permissions_by_category(permission_category: str) -> List[Dict[str, Any]]: + """ + Get all permissions for a specific category + + Args: + permission_category (str): Permission category + + Returns: + List[Dict[str, Any]]: List of role permission records + """ + with get_db_session() as session: + result = session.query(RolePermission).filter( + RolePermission.permission_category == permission_category, + RolePermission.delete_flag == "N" + ).all() + + return [as_dict(record) for record in result] + + +def initialize_default_permissions() -> None: + """ + Initialize default role permissions + This should be called during system setup + """ + default_permissions = [ + # SUPER_ADMIN permissions (SU) + {"user_role": "SU", "permission_category": "SYSTEM", "permission_type": "ALL", "permission_subtype": "FULL_ACCESS"}, + # ADMIN permissions + {"user_role": "ADMIN", "permission_category": "USER_MANAGEMENT", "permission_type": "USER", "permission_subtype": "CRUD"}, + {"user_role": "ADMIN", "permission_category": "GROUP_MANAGEMENT", "permission_type": "GROUP", "permission_subtype": "CRUD"}, + {"user_role": "ADMIN", "permission_category": "KNOWLEDGE_BASE", "permission_type": "KNOWLEDGE", "permission_subtype": "CRUD"}, + {"user_role": "ADMIN", "permission_category": "AGENT_MANAGEMENT", "permission_type": "AGENT", "permission_subtype": "CRUD"}, + {"user_role": "ADMIN", "permission_category": "INVITATION_MANAGEMENT", "permission_type": "INVITATION", "permission_subtype": "CRUD"}, + # DEV permissions + {"user_role": "DEV", "permission_category": "KNOWLEDGE_BASE", "permission_type": "KNOWLEDGE", "permission_subtype": "READ"}, + {"user_role": "DEV", "permission_category": "KNOWLEDGE_BASE", "permission_type": "KNOWLEDGE", "permission_subtype": "CREATE"}, + {"user_role": "DEV", "permission_category": "AGENT_MANAGEMENT", "permission_type": "AGENT", "permission_subtype": "READ"}, + {"user_role": "DEV", "permission_category": "AGENT_MANAGEMENT", "permission_type": "AGENT", "permission_subtype": "CREATE"}, + # USER permissions + {"user_role": "USER", "permission_category": "KNOWLEDGE_BASE", "permission_type": "KNOWLEDGE", "permission_subtype": "READ"}, + {"user_role": "USER", "permission_category": "AGENT_MANAGEMENT", "permission_type": "AGENT", "permission_subtype": "READ"}, + ] + + for permission in default_permissions: + # Check if permission already exists + if not check_role_permission( + user_role=permission["user_role"], + permission_category=permission["permission_category"], + permission_type=permission["permission_type"], + permission_subtype=permission["permission_subtype"] + ): + create_role_permission( + user_role=permission["user_role"], + permission_category=permission["permission_category"], + permission_type=permission["permission_type"], + permission_subtype=permission["permission_subtype"], + created_by="SYSTEM" + ) diff --git a/backend/database/tenant_config_db.py b/backend/database/tenant_config_db.py index 2c97df457..fd4389e17 100644 --- a/backend/database/tenant_config_db.py +++ b/backend/database/tenant_config_db.py @@ -137,3 +137,19 @@ def update_config_by_tenant_config_id_and_data(tenant_config_id: int, insert_dat session.rollback() logger.error(f"update config by tenant config id and data failed, error: {e}") return False + + +def get_all_tenant_ids(): + """ + Get all tenant IDs that have tenant configurations + + Returns: + List[str]: List of tenant IDs + """ + with get_db_session() as session: + result = session.query(TenantConfig.tenant_id).filter( + TenantConfig.config_key == "TENANT_NAME", + TenantConfig.delete_flag == "N" + ).distinct().all() + + return [row[0] for row in result] diff --git a/backend/database/tool_db.py b/backend/database/tool_db.py index edc256fde..ff9c1488c 100644 --- a/backend/database/tool_db.py +++ b/backend/database/tool_db.py @@ -182,7 +182,7 @@ def search_tools_for_sub_agent(agent_id, tenant_id): ToolInstance.agent_id == agent_id, ToolInstance.tenant_id == tenant_id, ToolInstance.delete_flag != 'Y', - ToolInstance.enabled == True + ToolInstance.enabled ) tool_instances = query.all() diff --git a/backend/database/user_tenant_db.py b/backend/database/user_tenant_db.py index 16c8c0928..960b38855 100644 --- a/backend/database/user_tenant_db.py +++ b/backend/database/user_tenant_db.py @@ -1,11 +1,12 @@ """ Database operations for user tenant relationship management """ -from typing import Any, Dict, Optional +from typing import Any, List, Dict, Optional from consts.const import DEFAULT_TENANT_ID from database.client import as_dict, get_db_session from database.db_models import UserTenant +from utils.str_utils import convert_list_to_string def get_user_tenant_by_user_id(user_id: str) -> Optional[Dict[str, Any]]: @@ -32,7 +33,7 @@ def get_user_tenant_by_user_id(user_id: str) -> Optional[Dict[str, Any]]: def get_all_tenant_ids() -> list[str]: """ Get all unique tenant IDs from the database - + Returns: list[str]: List of unique tenant IDs """ @@ -40,28 +41,30 @@ def get_all_tenant_ids() -> list[str]: result = session.query(UserTenant.tenant_id).filter( UserTenant.delete_flag == "N" ).distinct().all() - + tenant_ids = [row[0] for row in result] - + # Add default tenant_id if not already in the list if DEFAULT_TENANT_ID not in tenant_ids: tenant_ids.append(DEFAULT_TENANT_ID) - + return tenant_ids -def insert_user_tenant(user_id: str, tenant_id: str): +def insert_user_tenant(user_id: str, tenant_id: str, user_role: str = "USER"): """ Insert user tenant relationship Args: user_id (str): User ID tenant_id (str): Tenant ID + user_role (str): User role (SUPER_ADMIN, ADMIN, DEV, USER) """ with get_db_session() as session: user_tenant = UserTenant( user_id=user_id, tenant_id=tenant_id, + user_role=user_role, created_by=user_id, updated_by=user_id ) diff --git a/backend/utils/str_utils.py b/backend/utils/str_utils.py index a20b4e4b7..dc7887595 100644 --- a/backend/utils/str_utils.py +++ b/backend/utils/str_utils.py @@ -1,4 +1,5 @@ import re +from typing import List, Optional def remove_think_blocks(text: str) -> str: @@ -6,3 +7,33 @@ def remove_think_blocks(text: str) -> str: if not text: return text return re.sub(r"(?:)?.*?", "", text, flags=re.DOTALL | re.IGNORECASE) + + +def convert_list_to_string(items: Optional[List[int]]) -> str: + """ + Convert list of integers to comma-separated string for database storage + + Args: + items: List of integers or None + + Returns: + Comma-separated string, empty string if None + """ + if items is None: + return "" + return ",".join(str(item) for item in items) + + +def convert_string_to_list(items_str: Optional[str]) -> List[int]: + """ + Convert comma-separated string to list of integers for processing + + Args: + items_str: Comma-separated string or None + + Returns: + List of integers, empty list if None or empty string + """ + if not items_str or items_str.strip() == "": + return [] + return [int(item.strip()) for item in items_str.split(",") if item.strip().isdigit()] diff --git a/docker/sql/v1.8.0_1226_add_invitation_and_group_system.sql b/docker/sql/v1.8.0_1226_add_invitation_and_group_system.sql new file mode 100644 index 000000000..a8376162c --- /dev/null +++ b/docker/sql/v1.8.0_1226_add_invitation_and_group_system.sql @@ -0,0 +1,146 @@ +-- Add invitation code and group management system +-- This migration adds invitation codes, groups, and permission management features + +-- 1. Create tenant_invitation_code_t table for invitation codes +CREATE TABLE IF NOT EXISTS nexent.tenant_invitation_code_t ( + invitation_id SERIAL PRIMARY KEY, + tenant_id VARCHAR(100) NOT NULL, + invitation_code VARCHAR(100) NOT NULL, + group_ids VARCHAR, -- int4 list + capacity INT4 NOT NULL DEFAULT 1, + expiry_date TIMESTAMP(6) WITHOUT TIME ZONE, + status VARCHAR(30) NOT NULL, + code_type VARCHAR(30) NOT NULL, + create_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +-- Add comments for tenant_invitation_code_t table +COMMENT ON TABLE nexent.tenant_invitation_code_t IS 'Tenant invitation code information table'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.invitation_id IS 'Invitation ID, primary key'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.tenant_id IS 'Tenant ID, foreign key'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.invitation_code IS 'Invitation code'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.group_ids IS 'Associated group IDs list'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.capacity IS 'Invitation code capacity'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.expiry_date IS 'Invitation code expiry date'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.status IS 'Invitation code status: IN_USE, EXPIRE, DISABLE, RUN_OUT'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.code_type IS 'Invitation code type: ADMIN_INVITE, DEV_INVITE, USER_INVITE'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.create_time IS 'Create time'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.created_by IS 'Created by'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.updated_by IS 'Updated by'; +COMMENT ON COLUMN nexent.tenant_invitation_code_t.delete_flag IS 'Delete flag, Y/N'; + +-- 2. Create tenant_invitation_record_t table for invitation usage records +CREATE TABLE IF NOT EXISTS nexent.tenant_invitation_record_t ( + invitation_record_id SERIAL PRIMARY KEY, + invitation_id INT4 NOT NULL, + user_id VARCHAR(100) NOT NULL, + create_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +-- Add comments for tenant_invitation_record_t table +COMMENT ON TABLE nexent.tenant_invitation_record_t IS 'Tenant invitation record table'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.invitation_record_id IS 'Invitation record ID, primary key'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.invitation_id IS 'Invitation ID, foreign key'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.user_id IS 'User ID'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.create_time IS 'Create time'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.created_by IS 'Created by'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.updated_by IS 'Updated by'; +COMMENT ON COLUMN nexent.tenant_invitation_record_t.delete_flag IS 'Delete flag, Y/N'; + +-- 3. Create tenant_group_info_t table for group information +CREATE TABLE IF NOT EXISTS nexent.tenant_group_info_t ( + group_id SERIAL PRIMARY KEY, + tenant_id VARCHAR(100) NOT NULL, + group_name VARCHAR(100) NOT NULL, + group_description VARCHAR(500), + create_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +-- Add comments for tenant_group_info_t table +COMMENT ON TABLE nexent.tenant_group_info_t IS 'Tenant group information table'; +COMMENT ON COLUMN nexent.tenant_group_info_t.group_id IS 'Group ID, primary key'; +COMMENT ON COLUMN nexent.tenant_group_info_t.tenant_id IS 'Tenant ID, foreign key'; +COMMENT ON COLUMN nexent.tenant_group_info_t.group_name IS 'Group name'; +COMMENT ON COLUMN nexent.tenant_group_info_t.group_description IS 'Group description'; +COMMENT ON COLUMN nexent.tenant_group_info_t.create_time IS 'Create time'; +COMMENT ON COLUMN nexent.tenant_group_info_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.tenant_group_info_t.created_by IS 'Created by'; +COMMENT ON COLUMN nexent.tenant_group_info_t.updated_by IS 'Updated by'; +COMMENT ON COLUMN nexent.tenant_group_info_t.delete_flag IS 'Delete flag, Y/N'; + +-- 4. Create tenant_group_user_t table for group user membership +CREATE TABLE IF NOT EXISTS nexent.tenant_group_user_t ( + group_user_id SERIAL PRIMARY KEY, + group_id INT4 NOT NULL, + user_id VARCHAR(100) NOT NULL, + create_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +-- Add comments for tenant_group_user_t table +COMMENT ON TABLE nexent.tenant_group_user_t IS 'Tenant group user membership table'; +COMMENT ON COLUMN nexent.tenant_group_user_t.group_user_id IS 'Group user ID, primary key'; +COMMENT ON COLUMN nexent.tenant_group_user_t.group_id IS 'Group ID, foreign key'; +COMMENT ON COLUMN nexent.tenant_group_user_t.user_id IS 'User ID, foreign key'; +COMMENT ON COLUMN nexent.tenant_group_user_t.create_time IS 'Create time'; +COMMENT ON COLUMN nexent.tenant_group_user_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.tenant_group_user_t.created_by IS 'Created by'; +COMMENT ON COLUMN nexent.tenant_group_user_t.updated_by IS 'Updated by'; +COMMENT ON COLUMN nexent.tenant_group_user_t.delete_flag IS 'Delete flag, Y/N'; + +-- 5. Add fields to user_tenant_t table +ALTER TABLE nexent.user_tenant_t +ADD COLUMN IF NOT EXISTS user_role VARCHAR(30); + +-- Add comments for new fields in user_tenant_t table +COMMENT ON COLUMN nexent.user_tenant_t.user_role IS 'User role: SU, ADMIN, DEV, USER'; + +-- 6. Create role_permission_t table for role permissions +CREATE TABLE IF NOT EXISTS nexent.role_permission_t ( + role_permission_id SERIAL PRIMARY KEY, + user_role VARCHAR(30) NOT NULL, + permission_category VARCHAR(30), + permission_type VARCHAR(30), + permission_subtype VARCHAR(30) +); + +-- Add comments for role_permission_t table +COMMENT ON TABLE nexent.role_permission_t IS 'Role permission configuration table'; +COMMENT ON COLUMN nexent.role_permission_t.role_permission_id IS 'Role permission ID, primary key'; +COMMENT ON COLUMN nexent.role_permission_t.user_role IS 'User role: SU, ADMIN, DEV, USER'; +COMMENT ON COLUMN nexent.role_permission_t.permission_category IS 'Permission category'; +COMMENT ON COLUMN nexent.role_permission_t.permission_type IS 'Permission type'; +COMMENT ON COLUMN nexent.role_permission_t.permission_subtype IS 'Permission subtype'; + +-- 7. Add fields to knowledge_record_t table +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS group_ids VARCHAR, -- int4 list +ADD COLUMN IF NOT EXISTS ingroup_permission VARCHAR(30); + +-- Add comments for new fields in knowledge_record_t table +COMMENT ON COLUMN nexent.knowledge_record_t.group_ids IS 'Knowledge base group IDs list'; +COMMENT ON COLUMN nexent.knowledge_record_t.ingroup_permission IS 'In-group permission: EDIT, READ_ONLY, PRIVATE'; + +-- 8. Add fields to ag_tenant_agent_t table +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS group_ids VARCHAR; -- int4 list + +-- Add comments for new fields in ag_tenant_agent_t table +COMMENT ON COLUMN nexent.ag_tenant_agent_t.group_ids IS 'Agent group IDs list'; diff --git a/test/backend/database/test_group_db.py b/test/backend/database/test_group_db.py new file mode 100644 index 000000000..e76de74c5 --- /dev/null +++ b/test/backend/database/test_group_db.py @@ -0,0 +1,575 @@ +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../..")) + +import pytest +from unittest.mock import MagicMock + +# First mock the consts module to avoid ModuleNotFoundError +consts_mock = MagicMock() +consts_mock.const = MagicMock() +# Set required constants in consts.const +consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000" +consts_mock.const.MINIO_ACCESS_KEY = "test_access_key" +consts_mock.const.MINIO_SECRET_KEY = "test_secret_key" +consts_mock.const.MINIO_REGION = "us-east-1" +consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_mock.const.POSTGRES_HOST = "localhost" +consts_mock.const.POSTGRES_USER = "test_user" +consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_mock.const.POSTGRES_DB = "test_db" +consts_mock.const.POSTGRES_PORT = 5432 +consts_mock.const.DEFAULT_TENANT_ID = "default_tenant" + +# Add the mocked consts module to sys.modules +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_mock.const + +# Mock utils module +utils_mock = MagicMock() +utils_mock.auth_utils = MagicMock() +utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") + +# Mock str_utils module +str_utils_mock = MagicMock() + + +def mock_convert_string_to_list(s): + """Mock implementation of convert_string_to_list that converts comma-separated string to int list""" + if isinstance(s, str) and s: + return [int(x.strip()) for x in s.split(',') if x.strip()] + return [] + + +str_utils_mock.convert_string_to_list = mock_convert_string_to_list +utils_mock.str_utils = str_utils_mock + +# Add the mocked utils module to sys.modules +sys.modules['utils'] = utils_mock +sys.modules['utils.auth_utils'] = utils_mock.auth_utils +sys.modules['utils.str_utils'] = str_utils_mock + +# Provide a stub for the `boto3` module so that it can be imported safely even +# if the testing environment does not have it available. +boto3_mock = MagicMock() +sys.modules['boto3'] = boto3_mock + +# Mock the entire client module +client_mock = MagicMock() +client_mock.MinioClient = MagicMock() +client_mock.PostgresClient = MagicMock() +client_mock.db_client = MagicMock() +client_mock.get_db_session = MagicMock() +client_mock.as_dict = MagicMock() +client_mock.filter_property = MagicMock() + +# Add the mocked client module to sys.modules +sys.modules['database.client'] = client_mock +sys.modules['backend.database.client'] = client_mock + +# Mock db_models module +db_models_mock = MagicMock() +db_models_mock.TenantGroupInfo = MagicMock() +db_models_mock.TenantGroupUser = MagicMock() + +class MockTenantGroupInfo: + def __init__(self, **kwargs): + self.group_id = kwargs.get('group_id', 1) + self.tenant_id = kwargs.get('tenant_id', 'test_tenant') + self.group_name = kwargs.get('group_name', 'test_group') + self.group_description = kwargs.get('group_description', 'test description') + self.created_by = kwargs.get('created_by', 'test_user') + self.updated_by = kwargs.get('updated_by', 'test_user') + self.delete_flag = kwargs.get('delete_flag', 'N') + self.create_time = kwargs.get('create_time', '2024-01-01 00:00:00') + self.update_time = kwargs.get('update_time', '2024-01-01 00:00:00') + +class MockTenantGroupUser: + def __init__(self, **kwargs): + self.group_user_id = kwargs.get('group_user_id', 1) + self.group_id = kwargs.get('group_id', 1) + self.user_id = kwargs.get('user_id', 'test_user') + self.created_by = kwargs.get('created_by', 'test_user') + self.updated_by = kwargs.get('updated_by', 'test_user') + self.delete_flag = kwargs.get('delete_flag', 'N') + self.create_time = kwargs.get('create_time', '2024-01-01 00:00:00') + self.update_time = kwargs.get('update_time', '2024-01-01 00:00:00') + + +# Add the mocked db_models module to sys.modules +sys.modules['database.db_models'] = db_models_mock +sys.modules['backend.database.db_models'] = db_models_mock + +# Mock exceptions module +exceptions_mock = MagicMock() +sys.modules['consts.exceptions'] = exceptions_mock +sys.modules['backend.consts.exceptions'] = exceptions_mock + +# Mock sqlalchemy module +sqlalchemy_mock = MagicMock() +sqlalchemy_mock.exc = MagicMock() + +class MockSQLAlchemyError(Exception): + pass + +sqlalchemy_mock.exc.SQLAlchemyError = MockSQLAlchemyError + +# Add the mocked sqlalchemy module to sys.modules +sys.modules['sqlalchemy'] = sqlalchemy_mock +sys.modules['sqlalchemy.exc'] = sqlalchemy_mock.exc + +# Now we can safely import the module under test +from backend.database.group_db import ( + query_groups, + query_groups_by_tenant, + add_group, + modify_group, + remove_group, + add_user_to_group, + remove_user_from_group, + query_group_users, + query_groups_by_user, + query_group_ids_by_user, + check_user_in_group, + count_group_users +) + + +@pytest.fixture +def mock_session(): + """Create mock database session""" + mock_session = MagicMock() + mock_query = MagicMock() + mock_session.query.return_value = mock_query + return mock_session, mock_query + + +def test_get_group_by_id_success(monkeypatch, mock_session): + """Test successfully retrieving group by ID""" + session, query = mock_session + + mock_group = MockTenantGroupInfo() + mock_group.group_id = 123 + mock_group.group_name = "test_group" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_group] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.group_db.as_dict", lambda obj: obj.__dict__) + + result = query_groups(123) + + assert result is not None + assert result["group_id"] == 123 + assert result["group_name"] == "test_group" + + +def test_get_group_by_id_not_found(monkeypatch, mock_session): + """Test retrieving non-existent group""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = query_groups(999) + + assert result is None + + +def test_get_group_by_id_with_string_input(monkeypatch, mock_session): + """Test retrieving groups by comma-separated string""" + session, query = mock_session + + mock_group1 = MockTenantGroupInfo(group_id=1, group_name="group1") + mock_group2 = MockTenantGroupInfo(group_id=2, group_name="group2") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_group1, mock_group2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.group_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.group_db.as_dict", + lambda obj: obj.__dict__) + + result = query_groups("1,2") + + assert len(result) == 2 + assert result[0]["group_id"] == 1 + assert result[1]["group_id"] == 2 + + +def test_get_group_by_id_with_list_input(monkeypatch, mock_session): + """Test retrieving groups by list of IDs""" + session, query = mock_session + + mock_group1 = MockTenantGroupInfo(group_id=1, group_name="group1") + mock_group2 = MockTenantGroupInfo(group_id=3, group_name="group3") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_group1, mock_group2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.group_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.group_db.as_dict", + lambda obj: obj.__dict__) + + result = query_groups([1, 3]) + + assert len(result) == 2 + assert result[0]["group_id"] == 1 + assert result[1]["group_id"] == 3 + + +def test_get_group_by_id_empty_string(monkeypatch, mock_session): + """Test retrieving groups with empty string""" + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = None + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = query_groups("") + + assert result == [] + + +def test_get_group_by_id_empty_list(monkeypatch, mock_session): + """Test retrieving groups with empty list""" + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = None + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = query_groups([]) + + assert result == [] + + +def test_get_group_by_id_invalid_type(): + """Test get_group_by_id with invalid input type""" + import pytest + + with pytest.raises(ValueError, match="group_id must be int, str, or List\\[int\\]"): + query_groups(1.5) # float is not supported + + +def test_get_groups_by_tenant_success(monkeypatch, mock_session): + """Test retrieving groups by tenant""" + session, query = mock_session + + mock_group1 = MockTenantGroupInfo(group_id=1, group_name="group1") + mock_group2 = MockTenantGroupInfo(group_id=2, group_name="group2") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_group1, mock_group2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.group_db.as_dict", lambda obj: obj.__dict__) + + result = query_groups_by_tenant("test_tenant") + + assert len(result) == 2 + assert result[0]["group_name"] == "group1" + assert result[1]["group_name"] == "group2" + + +def test_create_group_success(monkeypatch, mock_session): + """Test successfully creating group""" + session, _ = mock_session + session.add = MagicMock() + + mock_group = MockTenantGroupInfo() + mock_group.group_id = 123 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.group_db.TenantGroupInfo', return_value=mock_group): + result = add_group( + tenant_id="test_tenant", + group_name="test_group", + group_description="test description", + created_by="test_user" + ) + + assert result == 123 + session.add.assert_called_once_with(mock_group) + session.flush.assert_called_once() + + +def test_update_group_success(monkeypatch, mock_session): + """Test successfully updating group""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = modify_group( + group_id=123, + updates={"group_name": "new_name", "group_description": "new description"}, + updated_by="test_user" + ) + + assert result is True + + +def test_soft_delete_group_success(monkeypatch, mock_session): + """Test successfully soft deleting group""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = remove_group(group_id=123, updated_by="test_user") + + assert result is True + + +def test_add_user_to_group_success(monkeypatch, mock_session): + """Test successfully adding user to group""" + session, _ = mock_session + session.add = MagicMock() + + mock_group_user = MockTenantGroupUser() + mock_group_user.group_user_id = 456 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.group_db.TenantGroupUser', return_value=mock_group_user): + result = add_user_to_group( + group_id=123, + user_id="test_user", + created_by="test_user" + ) + + assert result == 456 + session.add.assert_called_once_with(mock_group_user) + session.flush.assert_called_once() + + +def test_remove_user_from_group_success(monkeypatch, mock_session): + """Test successfully removing user from group""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = remove_user_from_group( + group_id=123, + user_id="test_user", + updated_by="test_user" + ) + + assert result is True + + +def test_get_group_users_success(monkeypatch, mock_session): + """Test retrieving users in a group""" + session, query = mock_session + + mock_user1 = MockTenantGroupUser(group_user_id=1, user_id="user1") + mock_user2 = MockTenantGroupUser(group_user_id=2, user_id="user2") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_user1, mock_user2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.group_db.as_dict", lambda obj: obj.__dict__) + + result = query_group_users(123) + + assert len(result) == 2 + assert result[0]["user_id"] == "user1" + assert result[1]["user_id"] == "user2" + + +def test_get_groups_by_user_success(monkeypatch, mock_session): + """Test retrieving groups for a user""" + session, query = mock_session + + mock_group1 = MockTenantGroupInfo(group_id=1, group_name="group1") + mock_group2 = MockTenantGroupInfo(group_id=2, group_name="group2") + + # Mock the join query + mock_join = MagicMock() + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_group1, mock_group2] + mock_join.filter.return_value = mock_filter + query.join.return_value = mock_join + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.group_db.as_dict", lambda obj: obj.__dict__) + + result = query_groups_by_user("test_user") + + assert len(result) == 2 + assert result[0]["group_name"] == "group1" + assert result[1]["group_name"] == "group2" + + +def test_get_group_ids_by_user_success(monkeypatch, mock_session): + """Test retrieving group IDs for a user""" + session, _ = mock_session + + # Create a mock query that returns tuples of group_ids + mock_specific_query = MagicMock() + mock_filter = MagicMock() + mock_filter.all.return_value = [(1,), (2,), (3,)] + mock_specific_query.filter.return_value = mock_filter + + def mock_query_func(*args, **kwargs): + return mock_specific_query + + session.query = mock_query_func + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = query_group_ids_by_user("test_user") + + assert result == [1, 2, 3] + + +def test_is_user_in_group_true(monkeypatch, mock_session): + """Test checking if user is in group - user is in group""" + session, query = mock_session + + mock_group_user = MockTenantGroupUser() + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_group_user + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = check_user_in_group("test_user", 123) + + assert result is True + + +def test_is_user_in_group_false(monkeypatch, mock_session): + """Test checking if user is in group - user is not in group""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = check_user_in_group("test_user", 123) + + assert result is False + + +def test_get_group_user_count_success(monkeypatch, mock_session): + """Test getting group user count""" + session, _ = mock_session + + # Create a mock query that returns count + mock_specific_query = MagicMock() + mock_filter = MagicMock() + mock_filter.count.return_value = 5 + mock_specific_query.filter.return_value = mock_filter + + def mock_query_func(*args, **kwargs): + return mock_specific_query + + session.query = mock_query_func + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + result = count_group_users(123) + + assert result == 5 + + +def test_database_error_handling(monkeypatch, mock_session): + """Test database error handling""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.group_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + query_groups(123) diff --git a/test/backend/database/test_invitation_db.py b/test/backend/database/test_invitation_db.py new file mode 100644 index 000000000..6f4be3ef1 --- /dev/null +++ b/test/backend/database/test_invitation_db.py @@ -0,0 +1,513 @@ +import sys +import pytest +from unittest.mock import MagicMock + +# First mock the consts module to avoid ModuleNotFoundError +consts_mock = MagicMock() +consts_mock.const = MagicMock() +# Set required constants in consts.const +consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000" +consts_mock.const.MINIO_ACCESS_KEY = "test_access_key" +consts_mock.const.MINIO_SECRET_KEY = "test_secret_key" +consts_mock.const.MINIO_REGION = "us-east-1" +consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_mock.const.POSTGRES_HOST = "localhost" +consts_mock.const.POSTGRES_USER = "test_user" +consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_mock.const.POSTGRES_DB = "test_db" +consts_mock.const.POSTGRES_PORT = 5432 +consts_mock.const.DEFAULT_TENANT_ID = "default_tenant" + +# Add the mocked consts module to sys.modules +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_mock.const + +# Mock utils module +utils_mock = MagicMock() +utils_mock.auth_utils = MagicMock() +utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") +utils_mock.str_utils = MagicMock() +utils_mock.str_utils.convert_list_to_string = MagicMock(side_effect=lambda x: ",".join(str(i) for i in x) if x else "") + +# Add the mocked utils module to sys.modules +sys.modules['utils'] = utils_mock +sys.modules['utils.auth_utils'] = utils_mock.auth_utils +sys.modules['utils.str_utils'] = utils_mock.str_utils + +# Provide a stub for the `boto3` module so that it can be imported safely even +# if the testing environment does not have it available. +boto3_mock = MagicMock() +sys.modules['boto3'] = boto3_mock + +# Mock the entire client module +client_mock = MagicMock() +client_mock.MinioClient = MagicMock() +client_mock.PostgresClient = MagicMock() +client_mock.db_client = MagicMock() +client_mock.get_db_session = MagicMock() +client_mock.as_dict = MagicMock() +client_mock.filter_property = MagicMock() + +# Add the mocked client module to sys.modules +sys.modules['database.client'] = client_mock +sys.modules['backend.database.client'] = client_mock + +# Mock db_models module +db_models_mock = MagicMock() +db_models_mock.TenantInvitationCode = MagicMock() +db_models_mock.TenantInvitationRecord = MagicMock() + +class MockTenantInvitationCode: + def __init__(self, **kwargs): + self.invitation_id = kwargs.get('invitation_id', 1) + self.tenant_id = kwargs.get('tenant_id', 'test_tenant') + self.invitation_code = kwargs.get('invitation_code', 'test_code') + self.group_ids = kwargs.get('group_ids', '1,2,3') + self.capacity = kwargs.get('capacity', 5) + self.expiry_date = kwargs.get('expiry_date', '2024-12-31 23:59:59') + self.status = kwargs.get('status', 'IN_USE') + self.code_type = kwargs.get('code_type', 'ADMIN_INVITE') + self.created_by = kwargs.get('created_by', 'test_user') + self.updated_by = kwargs.get('updated_by', 'test_user') + self.delete_flag = kwargs.get('delete_flag', 'N') + self.create_time = kwargs.get('create_time', '2024-01-01 00:00:00') + self.update_time = kwargs.get('update_time', '2024-01-01 00:00:00') + +class MockTenantInvitationRecord: + def __init__(self, **kwargs): + self.invitation_record_id = kwargs.get('invitation_record_id', 1) + self.invitation_id = kwargs.get('invitation_id', 1) + self.user_id = kwargs.get('user_id', 'test_user') + self.created_by = kwargs.get('created_by', 'test_user') + self.updated_by = kwargs.get('updated_by', 'test_user') + self.delete_flag = kwargs.get('delete_flag', 'N') + self.create_time = kwargs.get('create_time', '2024-01-01 00:00:00') + self.update_time = kwargs.get('update_time', '2024-01-01 00:00:00') + + +# Add the mocked db_models module to sys.modules +sys.modules['database.db_models'] = db_models_mock +sys.modules['backend.database.db_models'] = db_models_mock + +# Mock exceptions module +exceptions_mock = MagicMock() +sys.modules['consts.exceptions'] = exceptions_mock +sys.modules['backend.consts.exceptions'] = exceptions_mock + +# Mock sqlalchemy module +sqlalchemy_mock = MagicMock() +sqlalchemy_mock.exc = MagicMock() + +class MockSQLAlchemyError(Exception): + pass + +sqlalchemy_mock.exc.SQLAlchemyError = MockSQLAlchemyError + +# Add the mocked sqlalchemy module to sys.modules +sys.modules['sqlalchemy'] = sqlalchemy_mock +sys.modules['sqlalchemy.exc'] = sqlalchemy_mock.exc + +# Now we can safely import the module under test +from backend.database.invitation_db import ( + query_invitation_by_code, + query_invitation_by_id, + query_invitations_by_tenant, + add_invitation, + modify_invitation, + remove_invitation, + query_invitation_records, + add_invitation_record, + count_invitation_usage, + query_invitation_status +) + + +@pytest.fixture +def mock_session(): + """Create mock database session""" + mock_session = MagicMock() + mock_query = MagicMock() + mock_session.query.return_value = mock_query + return mock_session, mock_query + + +def test_query_invitation_by_code_success(monkeypatch, mock_session): + """Test successfully retrieving invitation code by code""" + session, query = mock_session + + mock_invitation = MockTenantInvitationCode() + mock_invitation.invitation_id = 123 + mock_invitation.invitation_code = "test_code" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_invitation + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.invitation_db.as_dict", lambda obj: obj.__dict__) + + result = query_invitation_by_code("test_code") + + assert result is not None + assert result["invitation_code"] == "test_code" + assert result["invitation_id"] == 123 + + +def test_query_invitation_by_code_not_found(monkeypatch, mock_session): + """Test retrieving non-existent invitation code""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = query_invitation_by_code("nonexistent_code") + + assert result is None + + +def test_query_invitation_by_id_success(monkeypatch, mock_session): + """Test retrieving invitation code by ID""" + session, query = mock_session + + mock_invitation = MockTenantInvitationCode() + mock_invitation.invitation_id = 123 + mock_invitation.invitation_code = "test_code" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_invitation + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.invitation_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.invitation_db.as_dict", lambda obj: obj.__dict__) + + result = query_invitation_by_id(123) + + assert result is not None + assert result["invitation_code"] == "test_code" + assert result["invitation_id"] == 123 + + +def test_query_invitation_by_id_not_found(monkeypatch, mock_session): + """Test retrieving non-existent invitation code by ID""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = query_invitation_by_id(999) + + assert result is None + + +def test_query_invitations_by_tenant_success(monkeypatch, mock_session): + """Test retrieving invitation codes by tenant""" + session, query = mock_session + + mock_invitation1 = MockTenantInvitationCode(invitation_id=1, invitation_code="code1") + mock_invitation2 = MockTenantInvitationCode(invitation_id=2, invitation_code="code2") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_invitation1, mock_invitation2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.invitation_db.as_dict", lambda obj: obj.__dict__) + + result = query_invitations_by_tenant("test_tenant") + + assert len(result) == 2 + assert result[0]["invitation_code"] == "code1" + assert result[1]["invitation_code"] == "code2" + + +def test_add_invitation_success(monkeypatch, mock_session): + """Test successfully creating invitation code""" + session, _ = mock_session + session.add = MagicMock() + + mock_invitation = MockTenantInvitationCode() + mock_invitation.invitation_id = 123 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.invitation_db.TenantInvitationCode', return_value=mock_invitation): + result = add_invitation( + tenant_id="test_tenant", + invitation_code="test_code", + code_type="ADMIN_INVITE", + group_ids=[1, 2, 3], + capacity=5, + expiry_date="2024-12-31", + status="IN_USE", + created_by="test_user" + ) + + assert result == 123 + session.add.assert_called_once_with(mock_invitation) + session.flush.assert_called_once() + + +def test_add_invitation_with_group_ids_list(monkeypatch, mock_session): + """Test successfully creating invitation code with group IDs as list""" + session, _ = mock_session + session.add = MagicMock() + + mock_invitation = MockTenantInvitationCode() + mock_invitation.invitation_id = 123 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.invitation_db.TenantInvitationCode', return_value=mock_invitation) as mock_constructor: + result = add_invitation( + tenant_id="test_tenant", + invitation_code="test_code", + code_type="ADMIN_INVITE", + group_ids=[1, 2, 3], + capacity=5, + expiry_date="2024-12-31", + status="IN_USE", + created_by="test_user" + ) + + assert result == 123 + # Verify TenantInvitationCode was called with group_ids converted to string + mock_constructor.assert_called_once_with( + tenant_id="test_tenant", + invitation_code="test_code", + code_type="ADMIN_INVITE", + group_ids="1,2,3", # Should be converted to comma-separated string + capacity=5, + expiry_date="2024-12-31", + status="IN_USE", + created_by="test_user", + updated_by="test_user" + ) + session.add.assert_called_once_with(mock_invitation) + session.flush.assert_called_once() + + +def test_modify_invitation_success(monkeypatch, mock_session): + """Test successfully updating invitation code""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = modify_invitation( + invitation_id=123, + updates={"status": "DISABLE", "capacity": 10}, + updated_by="test_user" + ) + + assert result is True + + +def test_remove_invitation_success(monkeypatch, mock_session): + """Test successfully soft deleting invitation code""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = remove_invitation(invitation_id=123, updated_by="test_user") + + assert result is True + + +def test_query_invitation_records_success(monkeypatch, mock_session): + """Test retrieving invitation records by invitation ID""" + session, query = mock_session + + mock_record1 = MockTenantInvitationRecord(invitation_record_id=1, user_id="user1") + mock_record2 = MockTenantInvitationRecord(invitation_record_id=2, user_id="user2") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1, mock_record2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.invitation_db.as_dict", lambda obj: obj.__dict__) + + result = query_invitation_records(123) + + assert len(result) == 2 + assert result[0]["user_id"] == "user1" + assert result[1]["user_id"] == "user2" + + +def test_add_invitation_record_success(monkeypatch, mock_session): + """Test successfully creating invitation record""" + session, _ = mock_session + session.add = MagicMock() + + mock_record = MockTenantInvitationRecord() + mock_record.invitation_record_id = 456 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.invitation_db.TenantInvitationRecord', return_value=mock_record): + result = add_invitation_record( + invitation_id=123, + user_id="test_user", + created_by="test_user" + ) + + assert result == 456 + session.add.assert_called_once_with(mock_record) + session.flush.assert_called_once() + + +def test_count_invitation_usage_success(monkeypatch, mock_session): + """Test getting invitation usage count""" + session, _ = mock_session + + # Create a mock query that returns count + mock_specific_query = MagicMock() + mock_filter = MagicMock() + mock_filter.count.return_value = 3 + mock_specific_query.filter.return_value = mock_filter + + def mock_query_func(*args, **kwargs): + return mock_specific_query + + session.query = mock_query_func + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = count_invitation_usage(123) + + assert result == 3 + + +def test_get_invitation_status_in_use(monkeypatch, mock_session): + """Test getting invitation status when in use""" + session, query = mock_session + + mock_invitation = MockTenantInvitationCode() + mock_invitation.status = "IN_USE" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_invitation + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = query_invitation_status("test_code") + + assert result == "IN_USE" + + +def test_get_invitation_status_expired(monkeypatch, mock_session): + """Test getting invitation status when expired""" + session, query = mock_session + + mock_invitation = MockTenantInvitationCode() + mock_invitation.status = "EXPIRE" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_invitation + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = query_invitation_status("test_code") + + assert result == "EXPIRE" + + +def test_get_invitation_status_not_found(monkeypatch, mock_session): + """Test getting invitation status when it doesn't exist""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + result = query_invitation_status("nonexistent_code") + + assert result is None + + +def test_database_error_handling(monkeypatch, mock_session): + """Test database error handling""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.invitation_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + query_invitation_by_code("test_code") diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index af337eb8d..415e2b8bd 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -26,10 +26,13 @@ utils_mock = MagicMock() utils_mock.auth_utils = MagicMock() utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") +utils_mock.str_utils = MagicMock() +utils_mock.str_utils.convert_list_to_string = MagicMock(side_effect=lambda x: ",".join(str(i) for i in x) if x else "") # Add the mocked utils module to sys.modules sys.modules['utils'] = utils_mock sys.modules['utils.auth_utils'] = utils_mock.auth_utils +sys.modules['utils.str_utils'] = utils_mock.str_utils # Provide a stub for the `boto3` module so that it can be imported safely even # if the testing environment does not have it available. @@ -78,6 +81,8 @@ def __init__(self, **kwargs): self.knowledge_sources = kwargs.get('knowledge_sources', 'elasticsearch') self.tenant_id = kwargs.get('tenant_id', 'test_tenant') self.embedding_model_name = kwargs.get('embedding_model_name', 'test_model') + self.group_ids = kwargs.get('group_ids', '1,2,3') # New field + self.ingroup_permission = kwargs.get('ingroup_permission', 'READ_ONLY') # New field, corrected name self.delete_flag = kwargs.get('delete_flag', 'N') self.update_time = kwargs.get('update_time', "2023-01-01 00:00:00") @@ -91,6 +96,8 @@ def __init__(self, **kwargs): knowledge_sources = MagicMock(name="knowledge_sources_column") tenant_id = MagicMock(name="tenant_id_column") embedding_model_name = MagicMock(name="embedding_model_name_column") + group_ids = MagicMock(name="group_ids_column") # New field + ingroup_permission = MagicMock(name="ingroup_permission_column") # New field, corrected name delete_flag = MagicMock(name="delete_flag_column") update_time = MagicMock(name="update_time_column") @@ -146,7 +153,9 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): "user_id": "test_user", "tenant_id": "test_tenant", "embedding_model_name": "test_model", - "knowledge_name": "test_knowledge" + "knowledge_name": "test_knowledge", + "group_ids": [1, 2, 3], + "ingroup_permission": "READ_ONLY" } # Mock KnowledgeRecord constructor @@ -163,6 +172,51 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): session.commit.assert_called_once() +def test_create_knowledge_record_with_group_ids_list(monkeypatch, mock_session): + """Test successful creation of knowledge record with group IDs as list""" + session, _ = mock_session + + # Create mock knowledge record + mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge") + mock_record.knowledge_id = 123 + mock_record.index_name = "test_knowledge" + + # Mock database session context + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Prepare test data with group_ids as list + test_query = { + "index_name": "test_knowledge", + "knowledge_describe": "Test knowledge description", + "user_id": "test_user", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model", + "knowledge_name": "test_knowledge", + "group_ids": [1, 2, 3], + "ingroup_permission": "READ_ONLY" + } + + # Mock KnowledgeRecord constructor + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record) as mock_constructor: + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 123, + "index_name": "test_knowledge", + "knowledge_name": "test_knowledge", + } + # Verify KnowledgeRecord was called with group_ids converted to string + mock_constructor.assert_called_once() + call_kwargs = mock_constructor.call_args[1] # Get kwargs from the call + assert call_kwargs["group_ids"] == "1,2,3" # Should be converted to comma-separated string + session.add.assert_called_once_with(mock_record) + assert session.flush.call_count == 1 + session.commit.assert_called_once() + + def test_create_knowledge_record_exception(monkeypatch, mock_session): """Test exception during knowledge record creation""" session, _ = mock_session diff --git a/test/backend/database/test_role_permission_db.py b/test/backend/database/test_role_permission_db.py new file mode 100644 index 000000000..5bfbd6d76 --- /dev/null +++ b/test/backend/database/test_role_permission_db.py @@ -0,0 +1,467 @@ +import sys +import pytest +from unittest.mock import MagicMock + +# First mock the consts module to avoid ModuleNotFoundError +consts_mock = MagicMock() +consts_mock.const = MagicMock() +# Set required constants in consts.const +consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000" +consts_mock.const.MINIO_ACCESS_KEY = "test_access_key" +consts_mock.const.MINIO_SECRET_KEY = "test_secret_key" +consts_mock.const.MINIO_REGION = "us-east-1" +consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_mock.const.POSTGRES_HOST = "localhost" +consts_mock.const.POSTGRES_USER = "test_user" +consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_mock.const.POSTGRES_DB = "test_db" +consts_mock.const.POSTGRES_PORT = 5432 +consts_mock.const.DEFAULT_TENANT_ID = "default_tenant" + +# Add the mocked consts module to sys.modules +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_mock.const + +# Mock utils module +utils_mock = MagicMock() +utils_mock.auth_utils = MagicMock() +utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") + +# Add the mocked utils module to sys.modules +sys.modules['utils'] = utils_mock +sys.modules['utils.auth_utils'] = utils_mock.auth_utils + +# Provide a stub for the `boto3` module so that it can be imported safely even +# if the testing environment does not have it available. +boto3_mock = MagicMock() +sys.modules['boto3'] = boto3_mock + +# Mock the entire client module +client_mock = MagicMock() +client_mock.MinioClient = MagicMock() +client_mock.PostgresClient = MagicMock() +client_mock.db_client = MagicMock() +client_mock.get_db_session = MagicMock() +client_mock.as_dict = MagicMock() +client_mock.filter_property = MagicMock() + +# Add the mocked client module to sys.modules +sys.modules['database.client'] = client_mock +sys.modules['backend.database.client'] = client_mock + +# Mock db_models module +db_models_mock = MagicMock() +db_models_mock.RolePermission = MagicMock() + +class MockRolePermission: + def __init__(self, **kwargs): + self.role_permission_id = kwargs.get('role_permission_id', 1) + self.user_role = kwargs.get('user_role', 'USER') + self.permission_category = kwargs.get('permission_category', 'SYSTEM') + self.permission_type = kwargs.get('permission_type', 'READ') + self.permission_subtype = kwargs.get('permission_subtype', 'BASIC') + self.created_by = kwargs.get('created_by', 'test_user') + self.updated_by = kwargs.get('updated_by', 'test_user') + self.delete_flag = kwargs.get('delete_flag', 'N') + self.create_time = kwargs.get('create_time', '2024-01-01 00:00:00') + self.update_time = kwargs.get('update_time', '2024-01-01 00:00:00') + + +# Add the mocked db_models module to sys.modules +sys.modules['database.db_models'] = db_models_mock +sys.modules['backend.database.db_models'] = db_models_mock + +# Mock exceptions module +exceptions_mock = MagicMock() +sys.modules['consts.exceptions'] = exceptions_mock +sys.modules['backend.consts.exceptions'] = exceptions_mock + +# Mock sqlalchemy module +sqlalchemy_mock = MagicMock() +sqlalchemy_mock.exc = MagicMock() + +class MockSQLAlchemyError(Exception): + pass + +sqlalchemy_mock.exc.SQLAlchemyError = MockSQLAlchemyError + +# Add the mocked sqlalchemy module to sys.modules +sys.modules['sqlalchemy'] = sqlalchemy_mock +sys.modules['sqlalchemy.exc'] = sqlalchemy_mock.exc + +# Now we can safely import the module under test +from backend.database.role_permission_db import ( + get_role_permissions, + get_all_role_permissions, + create_role_permission, + update_role_permission, + soft_delete_role_permission, + delete_role_permissions_by_role, + check_role_permission, + get_permissions_by_category, + initialize_default_permissions +) + + +@pytest.fixture +def mock_session(): + """Create mock database session""" + mock_session = MagicMock() + mock_query = MagicMock() + mock_session.query.return_value = mock_query + return mock_session, mock_query + + +def test_get_role_permissions_success(monkeypatch, mock_session): + """Test successfully retrieving role permissions""" + session, query = mock_session + + mock_permission1 = MockRolePermission( + role_permission_id=1, + user_role="USER", + permission_category="KNOWLEDGE_BASE", + permission_type="KNOWLEDGE", + permission_subtype="READ" + ) + mock_permission2 = MockRolePermission( + role_permission_id=2, + user_role="USER", + permission_category="AGENT_MANAGEMENT", + permission_type="AGENT", + permission_subtype="READ" + ) + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_permission1, mock_permission2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.role_permission_db.as_dict", lambda obj: obj.__dict__) + + result = get_role_permissions("USER") + + assert len(result) == 2 + assert result[0]["user_role"] == "USER" + assert result[0]["permission_category"] == "KNOWLEDGE_BASE" + assert result[1]["permission_category"] == "AGENT_MANAGEMENT" + + +def test_get_all_role_permissions_success(monkeypatch, mock_session): + """Test retrieving all role permissions""" + session, query = mock_session + + mock_permission1 = MockRolePermission(user_role="USER") + mock_permission2 = MockRolePermission(user_role="ADMIN") + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_permission1, mock_permission2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.role_permission_db.as_dict", lambda obj: obj.__dict__) + + result = get_all_role_permissions() + + assert len(result) == 2 + assert result[0]["user_role"] == "USER" + assert result[1]["user_role"] == "ADMIN" + + +def test_create_role_permission_success(monkeypatch, mock_session): + """Test successfully creating role permission""" + session, _ = mock_session + session.add = MagicMock() + + mock_permission = MockRolePermission() + mock_permission.role_permission_id = 123 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.role_permission_db.RolePermission', return_value=mock_permission): + result = create_role_permission( + user_role="USER", + permission_category="KNOWLEDGE_BASE", + permission_type="KNOWLEDGE", + permission_subtype="READ", + created_by="test_user" + ) + + assert result == 123 + session.add.assert_called_once_with(mock_permission) + session.flush.assert_called_once() + + +def test_update_role_permission_success(monkeypatch, mock_session): + """Test successfully updating role permission""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + result = update_role_permission( + role_permission_id=123, + updates={"permission_category": "new_category"}, + updated_by="test_user" + ) + + assert result is True + + +def test_soft_delete_role_permission_success(monkeypatch, mock_session): + """Test successfully soft deleting role permission""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 1 # 1 row affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + result = soft_delete_role_permission(role_permission_id=123, updated_by="test_user") + + assert result is True + + +def test_delete_role_permissions_by_role_success(monkeypatch, mock_session): + """Test successfully deleting all permissions for a role""" + session, query = mock_session + + # Setup query filter().update() chain + mock_update = MagicMock() + mock_update.return_value = 3 # 3 rows affected + mock_filter = MagicMock() + mock_filter.update = mock_update + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + result = delete_role_permissions_by_role("USER", updated_by="test_user") + + assert result == 3 + + +def test_check_role_permission_true(monkeypatch, mock_session): + """Test checking role permission - permission exists""" + session, query = mock_session + + mock_permission = MockRolePermission() + + # Mock chain: query.filter().filter().filter().filter().first() + mock_filter_final = MagicMock() + mock_filter_final.first.return_value = mock_permission + + mock_filter3 = MagicMock() + mock_filter3.filter.return_value = mock_filter_final + + mock_filter2 = MagicMock() + mock_filter2.filter.return_value = mock_filter3 + + mock_filter1 = MagicMock() + mock_filter1.filter.return_value = mock_filter2 + + query.filter.return_value = mock_filter1 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + result = check_role_permission( + user_role="USER", + permission_category="KNOWLEDGE_BASE", + permission_type="KNOWLEDGE", + permission_subtype="READ" + ) + + assert result is True + + +def test_check_role_permission_false(monkeypatch, mock_session): + """Test checking role permission - permission does not exist""" + session, query = mock_session + + # Mock chain: query.filter().filter().filter().filter().first() + mock_filter_final = MagicMock() + mock_filter_final.first.return_value = None + + mock_filter3 = MagicMock() + mock_filter3.filter.return_value = mock_filter_final + + mock_filter2 = MagicMock() + mock_filter2.filter.return_value = mock_filter3 + + mock_filter1 = MagicMock() + mock_filter1.filter.return_value = mock_filter2 + + query.filter.return_value = mock_filter1 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + result = check_role_permission( + user_role="USER", + permission_category="NONEXISTENT", + permission_type="NONEXISTENT", + permission_subtype="NONEXISTENT" + ) + + assert result is False + + +def test_get_permissions_by_category_success(monkeypatch, mock_session): + """Test retrieving permissions by category""" + session, query = mock_session + + mock_permission1 = MockRolePermission( + role_permission_id=1, + user_role="USER", + permission_category="KNOWLEDGE_BASE" + ) + mock_permission2 = MockRolePermission( + role_permission_id=2, + user_role="ADMIN", + permission_category="KNOWLEDGE_BASE" + ) + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_permission1, mock_permission2] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.role_permission_db.as_dict", lambda obj: obj.__dict__) + + result = get_permissions_by_category("KNOWLEDGE_BASE") + + assert len(result) == 2 + assert all(perm["permission_category"] == "KNOWLEDGE_BASE" for perm in result) + + +def test_initialize_default_permissions_success(monkeypatch, mock_session): + """Test initializing default permissions""" + session, query = mock_session + + # Mock that permissions don't exist yet + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + # Mock create_role_permission to avoid actual creation + def mock_create(*args, **kwargs): + return 1 + + monkeypatch.setattr("backend.database.role_permission_db.create_role_permission", mock_create) + + # Should not raise any exception + initialize_default_permissions() + + # Verify create_role_permission was called multiple times for default permissions + # (We can't easily count calls with this mock setup, but we can ensure no exception) + + +def test_database_error_handling(monkeypatch, mock_session): + """Test database error handling""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + get_role_permissions("USER") + + +def test_create_role_permission_with_none_fields(monkeypatch, mock_session): + """Test creating role permission with None fields""" + session, _ = mock_session + session.add = MagicMock() + + mock_permission = MockRolePermission() + mock_permission.role_permission_id = 123 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + from unittest.mock import patch + with patch('backend.database.role_permission_db.RolePermission', return_value=mock_permission): + result = create_role_permission( + user_role="USER", + permission_category=None, + permission_type=None, + permission_subtype=None, + created_by="test_user" + ) + + assert result == 123 + session.add.assert_called_once_with(mock_permission) + + +def test_check_role_permission_partial_match(monkeypatch, mock_session): + """Test checking role permission with partial criteria""" + session, query = mock_session + + mock_permission = MockRolePermission() + + # Mock filter chain for partial matching + mock_filter1 = MagicMock() + mock_filter2 = MagicMock() + mock_filter3 = MagicMock() + mock_filter3.first.return_value = mock_permission + + query.filter.return_value = mock_filter1 + mock_filter1.filter.return_value = mock_filter2 + mock_filter2.filter.return_value = mock_filter3 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.role_permission_db.get_db_session", lambda: mock_ctx) + + result = check_role_permission( + user_role="USER", + permission_category="KNOWLEDGE_BASE" + # Only checking category, not type or subtype + ) + + assert result is True diff --git a/test/backend/database/test_user_tenant_db.py b/test/backend/database/test_user_tenant_db.py index 305eb3806..f78cd7a34 100644 --- a/test/backend/database/test_user_tenant_db.py +++ b/test/backend/database/test_user_tenant_db.py @@ -29,10 +29,14 @@ utils_mock = MagicMock() utils_mock.auth_utils = MagicMock() utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") +utils_mock.str_utils = MagicMock() +utils_mock.str_utils.convert_list_to_string = MagicMock( + side_effect=lambda x: ",".join(str(i) for i in x) if x else "") # Add the mocked utils module to sys.modules sys.modules['utils'] = utils_mock sys.modules['utils.auth_utils'] = utils_mock.auth_utils +sys.modules['utils.str_utils'] = utils_mock.str_utils # Provide a stub for the `boto3` module so that it can be imported safely even # if the testing environment does not have it available. @@ -75,6 +79,8 @@ class MockUserTenant: def __init__(self): self.user_id = "test_user_id" self.tenant_id = "test_tenant_id" + self.group_ids = "1,2,3" # New field + self.user_role = "USER" # New field with correct role value self.delete_flag = "N" self.created_by = "test_user_id" self.updated_by = "test_user_id" @@ -83,6 +89,8 @@ def __init__(self): self.__dict__ = { "user_id": "test_user_id", "tenant_id": "test_tenant_id", + "group_ids": "1,2,3", + "user_role": "USER", "delete_flag": "N", "created_by": "test_user_id", "updated_by": "test_user_id", @@ -102,57 +110,59 @@ def test_get_user_tenant_by_user_id_success(monkeypatch, mock_session): """Test successful retrieval of user tenant relationship by user ID""" session, query = mock_session mock_user_tenant = MockUserTenant() - + mock_first = MagicMock() mock_first.return_value = mock_user_tenant mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) monkeypatch.setattr("backend.database.user_tenant_db.as_dict", lambda obj: obj.__dict__) - + result = get_user_tenant_by_user_id("test_user_id") - + assert result is not None assert result["user_id"] == "test_user_id" assert result["tenant_id"] == "test_tenant_id" + assert result["group_ids"] == "1,2,3" + assert result["user_role"] == "USER" assert result["delete_flag"] == "N" def test_get_user_tenant_by_user_id_not_found(monkeypatch, mock_session): """Test retrieval of user tenant relationship when record does not exist""" session, query = mock_session - + mock_first = MagicMock() mock_first.return_value = None mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + result = get_user_tenant_by_user_id("nonexistent_user_id") - + assert result is None def test_get_user_tenant_by_user_id_database_error(monkeypatch, mock_session): """Test database error when retrieving user tenant relationship - exception should propagate""" from sqlalchemy.exc import SQLAlchemyError - + session, query = mock_session query.filter.side_effect = SQLAlchemyError("Database error") - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + # Should raise SQLAlchemyError with pytest.raises(SQLAlchemyError): get_user_tenant_by_user_id("test_user_id") @@ -161,31 +171,31 @@ def test_insert_user_tenant_success(monkeypatch, mock_session): """Test successful insertion of user tenant relationship""" session, _ = mock_session session.add = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) monkeypatch.setattr("backend.database.user_tenant_db.UserTenant", lambda **kwargs: MagicMock()) - + # Should not raise any exception insert_user_tenant("test_user_id", "test_tenant_id") - + session.add.assert_called_once() def test_insert_user_tenant_failure(monkeypatch, mock_session): """Test failure of user tenant relationship insertion - exception should propagate""" from sqlalchemy.exc import SQLAlchemyError - + session, _ = mock_session session.add = MagicMock(side_effect=SQLAlchemyError("Database error")) - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) monkeypatch.setattr("backend.database.user_tenant_db.UserTenant", lambda **kwargs: MagicMock()) - + # Should raise SQLAlchemyError with pytest.raises(SQLAlchemyError): insert_user_tenant("test_user_id", "test_tenant_id") @@ -194,51 +204,54 @@ def test_insert_user_tenant_with_empty_user_id(monkeypatch, mock_session): """Test insertion of user tenant relationship with empty user ID""" session, _ = mock_session session.add = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + # Mock UserTenant constructor to capture the arguments mock_user_tenant_instance = MagicMock() mock_user_tenant_constructor = MagicMock(return_value=mock_user_tenant_instance) monkeypatch.setattr("backend.database.user_tenant_db.UserTenant", mock_user_tenant_constructor) - + # Should not raise any exception insert_user_tenant("", "test_tenant_id") - + # Verify UserTenant was called with correct parameters mock_user_tenant_constructor.assert_called_once_with( user_id="", tenant_id="test_tenant_id", + user_role="USER", created_by="", updated_by="" ) session.add.assert_called_once_with(mock_user_tenant_instance) + def test_insert_user_tenant_with_empty_tenant_id(monkeypatch, mock_session): """Test insertion of user tenant relationship with empty tenant ID""" session, _ = mock_session session.add = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + # Mock UserTenant constructor to capture the arguments mock_user_tenant_instance = MagicMock() mock_user_tenant_constructor = MagicMock(return_value=mock_user_tenant_instance) monkeypatch.setattr("backend.database.user_tenant_db.UserTenant", mock_user_tenant_constructor) - + # Should not raise any exception insert_user_tenant("test_user_id", "") - + # Verify UserTenant was called with correct parameters mock_user_tenant_constructor.assert_called_once_with( user_id="test_user_id", tenant_id="", + user_role="USER", created_by="test_user_id", updated_by="test_user_id" ) @@ -248,10 +261,10 @@ def test_insert_user_tenant_with_empty_tenant_id(monkeypatch, mock_session): def test_user_tenant_lifecycle(monkeypatch, mock_session): """Test complete user tenant lifecycle: insert and then retrieve""" session, query = mock_session - + # Mock database operations for insertion session.add = MagicMock() - + # Mock database operations for retrieval mock_user_tenant = MockUserTenant() mock_first = MagicMock() @@ -259,48 +272,50 @@ def test_user_tenant_lifecycle(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + # Create a proper mock UserTenant class with attributes mock_user_tenant_class = MagicMock() mock_user_tenant_class.user_id = MagicMock() mock_user_tenant_class.delete_flag = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) monkeypatch.setattr("backend.database.user_tenant_db.UserTenant", mock_user_tenant_class) monkeypatch.setattr("backend.database.user_tenant_db.as_dict", lambda obj: obj.__dict__) - + # 1. Insert user tenant relationship - should not raise exception insert_user_tenant("test_user_id", "test_tenant_id") session.add.assert_called_once() - + # 2. Retrieve user tenant relationship result = get_user_tenant_by_user_id("test_user_id") assert result is not None assert result["user_id"] == "test_user_id" assert result["tenant_id"] == "test_tenant_id" + assert result["group_ids"] == "1,2,3" + assert result["user_role"] == "USER" assert result["delete_flag"] == "N" def test_get_user_tenant_by_user_id_with_deleted_record(monkeypatch, mock_session): """Test retrieval of user tenant relationship when record is marked as deleted""" session, query = mock_session - + # Mock a deleted record (should not be returned) mock_first = MagicMock() mock_first.return_value = None # Filter should exclude deleted records mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + result = get_user_tenant_by_user_id("deleted_user_id") - + assert result is None # Verify that the filter was called with correct conditions query.filter.assert_called_once() @@ -309,17 +324,17 @@ def test_get_user_tenant_by_user_id_with_deleted_record(monkeypatch, mock_sessio def test_get_all_tenant_ids_empty_database(monkeypatch, mock_session): """Test get_all_tenant_ids when database is empty - should return only DEFAULT_TENANT_ID""" session, query = mock_session - + # Mock empty database result query.filter.return_value.distinct.return_value.all.return_value = [] - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + result = get_all_tenant_ids() - + assert result == ["default_tenant"] # DEFAULT_TENANT_ID from consts_mock assert len(result) == 1 @@ -327,7 +342,7 @@ def test_get_all_tenant_ids_empty_database(monkeypatch, mock_session): def test_get_all_tenant_ids_with_existing_tenants(monkeypatch, mock_session): """Test get_all_tenant_ids with existing tenants - should include all plus DEFAULT_TENANT_ID""" session, query = mock_session - + # Mock database result with existing tenants mock_tenants = [ ("tenant_1",), @@ -335,14 +350,14 @@ def test_get_all_tenant_ids_with_existing_tenants(monkeypatch, mock_session): ("tenant_3",) ] query.filter.return_value.distinct.return_value.all.return_value = mock_tenants - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx) - + result = get_all_tenant_ids() - + assert len(result) == 4 # 3 existing + 1 default assert "tenant_1" in result assert "tenant_2" in result diff --git a/test/backend/utils/test_str_utils.py b/test/backend/utils/test_str_utils.py index a66487d93..a5429f80e 100644 --- a/test/backend/utils/test_str_utils.py +++ b/test/backend/utils/test_str_utils.py @@ -1,5 +1,5 @@ import pytest -from backend.utils.str_utils import remove_think_blocks +from backend.utils.str_utils import remove_think_blocks, convert_list_to_string class TestStrUtils: @@ -8,6 +8,7 @@ class TestStrUtils: def setup_method(self): """Setup before each test method""" self.remove_think_blocks = remove_think_blocks + self.convert_list_to_string = convert_list_to_string def test_remove_think_blocks_no_tags(self): """Text without any think tags remains unchanged""" @@ -63,6 +64,36 @@ def test_remove_think_blocks_case_insensitive(self): result = self.remove_think_blocks(text) assert result == " tags" + def test_convert_list_to_string_none_input(self): + """None input should return empty string""" + result = self.convert_list_to_string(None) + assert result == "" + + def test_convert_list_to_string_empty_list(self): + """Empty list should return empty string""" + result = self.convert_list_to_string([]) + assert result == "" + + def test_convert_list_to_string_single_item(self): + """Single item list should return single item as string""" + result = self.convert_list_to_string([42]) + assert result == "42" + + def test_convert_list_to_string_multiple_items(self): + """Multiple items should be joined with commas""" + result = self.convert_list_to_string([1, 2, 3]) + assert result == "1,2,3" + + def test_convert_list_to_string_mixed_types(self): + """List with mixed integer types should work correctly""" + result = self.convert_list_to_string([1, 2, 3, 10]) + assert result == "1,2,3,10" + + def test_convert_list_to_string_zero_and_negative(self): + """Zero and negative numbers should be handled correctly""" + result = self.convert_list_to_string([0, -1, 5]) + assert result == "0,-1,5" + if __name__ == "__main__": pytest.main() From 2405b03503148830d96b193c3677a2e36eaafb11 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Fri, 9 Jan 2026 21:28:39 +0800 Subject: [PATCH 03/38] =?UTF-8?q?=E2=9C=A8Added=20ModelEngine=20ENV=20Conf?= =?UTF-8?q?ig?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/consts/const.py | 14 +- backend/services/config_sync_service.py | 3 +- backend/services/model_provider_service.py | 3 - docker/.env.example | 2 - docker/deploy.sh | 2 +- frontend/app/[locale]/layout.tsx | 25 ++- .../models/components/modelConfig.tsx | 113 +++++----- frontend/hooks/useConfig.ts | 27 ++- frontend/lib/config.ts | 209 ++++++++++-------- frontend/server.js | 88 ++++---- frontend/services/api.ts | 3 + frontend/services/configService.ts | 32 +-- frontend/types/modelConfig.ts | 2 +- 13 files changed, 293 insertions(+), 230 deletions(-) diff --git a/backend/consts/const.py b/backend/consts/const.py index 7fd0e0098..a0039e196 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -16,11 +16,6 @@ class VectorDatabaseType(str, Enum): ELASTICSEARCH = "elasticsearch" -# ModelEngine Configuration -MODEL_ENGINE_HOST = os.getenv('MODEL_ENGINE_HOST') -MODEL_ENGINE_APIKEY = os.getenv('MODEL_ENGINE_APIKEY') - - # Elasticsearch Configuration ES_HOST = os.getenv("ELASTICSEARCH_HOST") ES_API_KEY = os.getenv("ELASTICSEARCH_API_KEY") @@ -129,8 +124,10 @@ class VectorDatabaseType(str, Enum): DISABLE_CELERY_FLOWER = os.getenv( "DISABLE_CELERY_FLOWER", "false").lower() == "true" DOCKER_ENVIRONMENT = os.getenv("DOCKER_ENVIRONMENT", "false").lower() == "true" -NEXENT_MCP_DOCKER_IMAGE = os.getenv("NEXENT_MCP_DOCKER_IMAGE", "nexent/nexent-mcp:latest") -ENABLE_UPLOAD_IMAGE = os.getenv("ENABLE_UPLOAD_IMAGE", "false").lower() == "true" +NEXENT_MCP_DOCKER_IMAGE = os.getenv( + "NEXENT_MCP_DOCKER_IMAGE", "nexent/nexent-mcp:latest") +ENABLE_UPLOAD_IMAGE = os.getenv( + "ENABLE_UPLOAD_IMAGE", "false").lower() == "true" # Celery Configuration @@ -286,5 +283,8 @@ class VectorDatabaseType(str, Enum): DEFAULT_EN_TITLE = "New Conversation" +# Model Engine Configuration +MODEL_ENGINE_ENABLED = os.getenv("MODEL_ENGINE_ENABLED") + # APP Version APP_VERSION = "v1.7.9.1" diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 54b477a0e..464f3ab11 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -49,7 +49,8 @@ def handle_model_config(tenant_id: str, user_id: str, config_key: str, model_id: return current_model_id = tenant_config_dict.get(config_key) - current_model_id = int(current_model_id) if str(current_model_id).isdigit() else None + current_model_id = int(current_model_id) if str( + current_model_id).isdigit() else None if current_model_id == model_id: tenant_config_manager.update_single_config(tenant_id, config_key) diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index 24fb5bc16..3e67a804f 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -9,12 +9,9 @@ DEFAULT_LLM_MAX_TOKENS, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE, - MODEL_ENGINE_HOST, - MODEL_ENGINE_APIKEY, ) from consts.model import ModelConnectStatusEnum, ModelRequest from consts.provider import SILICON_GET_URL, ProviderEnum -from consts.exceptions import TimeoutException from database.model_management_db import get_models_by_tenant_factory_type from services.model_health_service import embedding_dimension_check from utils.model_name_utils import split_repo_name, add_repo_to_name diff --git a/docker/.env.example b/docker/.env.example index 04a9cfa5a..bd1ad2ee5 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -71,8 +71,6 @@ REDIS_BACKEND_URL=redis://redis:6379/1 # Model Engine Config MODEL_ENGINE_ENABLED=false -MODEL_ENGINE_HOST="" -MODEL_ENGINE_API_KEY="" # Supabase Config DASHBOARD_USERNAME=supabase diff --git a/docker/deploy.sh b/docker/deploy.sh index 2545bf2dc..b149b1cd9 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -916,7 +916,7 @@ main_deploy() { echo "--------------------------------" echo "" - APP_VERSION="$(get_app_version)" + APP_VERSION="latest" if [ -z "$APP_VERSION" ]; then echo "❌ Failed to get app version, please check the backend/consts/const.py file" exit 1 diff --git a/frontend/app/[locale]/layout.tsx b/frontend/app/[locale]/layout.tsx index 26a6f9d6b..638cbc530 100644 --- a/frontend/app/[locale]/layout.tsx +++ b/frontend/app/[locale]/layout.tsx @@ -19,9 +19,9 @@ export async function generateMetadata(props: { params: Promise<{ locale?: string }>; }): Promise { const { locale } = await props.params; - const resolvedLocale = (["zh", "en"].includes(locale ?? "") - ? locale - : "zh") as "zh" | "en"; + const resolvedLocale = ( + ["zh", "en"].includes(locale ?? "") ? locale : "zh" + ) as "zh" | "en"; let messages: any = {}; if (["zh", "en"].includes(resolvedLocale)) { @@ -65,9 +65,12 @@ export default async function RootLayout({ params: Promise<{ locale?: string }>; }) { const { locale } = await params; - const resolvedLocale = (["zh", "en"].includes(locale ?? "") - ? locale - : "zh") as "zh" | "en"; + const resolvedLocale = ( + ["zh", "en"].includes(locale ?? "") ? locale : "zh" + ) as "zh" | "en"; + + // 获取环境变量 + const modelEngineEnabled = process.env.MODEL_ENGINE_ENABLED || "true"; return ( @@ -82,6 +85,16 @@ export default async function RootLayout({ {children} +