diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 3e0e60d7..7ca3d571 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -10,9 +10,9 @@ from sqlalchemy import and_, select from starlette.responses import JSONResponse -from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ +from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, list_chats, get_chat_with_records, create_chat, rename_chat, \ delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \ - format_json_data, format_json_list_data, get_chart_config, list_recent_questions,get_chat as get_chat_exec + format_json_data, format_json_list_data, get_chart_config, list_recent_questions,get_chat as get_chat_exec, rename_chat_with_user from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \ ChatInfo, Chat, ChatFinishStep from apps.chat.task.llm import LLMService @@ -52,7 +52,7 @@ def inner(): return await asyncio.to_thread(inner) -@router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data") +""" @router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data") async def chat_record_data(session: SessionDep, chat_record_id: int): def inner(): data = get_chat_chart_data(chat_record_id=chat_record_id, session=session) @@ -67,10 +67,27 @@ def inner(): data = get_chat_predict_data(chat_record_id=chat_record_id, session=session) return format_json_list_data(data) + return await asyncio.to_thread(inner) """ + +@router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data") +async def chat_record_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int): + def inner(): + data = get_chart_data_with_user(chat_record_id=chat_record_id, session=session, current_user=current_user) + return format_json_data(data) + return await asyncio.to_thread(inner) -@router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat") +@router.get("/record/{chat_record_id}/predict_data", summary=f"{PLACEHOLDER_PREFIX}get_chart_predict_data") +async def chat_predict_data(session: SessionDep, current_user: CurrentUser, chat_record_id: int): + def inner(): + data = get_chat_predict_data_with_user(chat_record_id=chat_record_id, session=session, current_user=current_user) + return format_json_list_data(data) + + return await asyncio.to_thread(inner) + + +""" @router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat") @system_log(LogConfig( operation_type=OperationType.UPDATE, module=OperationModules.CHAT, @@ -83,10 +100,24 @@ async def rename(session: SessionDep, chat: RenameChat): raise HTTPException( status_code=500, detail=str(e) - ) + ) """ +@router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat") +@system_log(LogConfig( + operation_type=OperationType.UPDATE, + module=OperationModules.CHAT, + resource_id_expr="chat.id" +)) +async def rename(session: SessionDep, current_user: CurrentUser, chat: RenameChat): + try: + return rename_chat_with_user(session=session, current_user=current_user, rename_object=chat) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e) + ) -@router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat") +""" @router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat") @system_log(LogConfig( operation_type=OperationType.DELETE, module=OperationModules.CHAT, @@ -100,8 +131,23 @@ async def delete(session: SessionDep, chart_id: int, brief: str): raise HTTPException( status_code=500, detail=str(e) - ) + ) """ +@router.delete("/{chart_id}/{brief}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat") +@system_log(LogConfig( + operation_type=OperationType.DELETE, + module=OperationModules.CHAT, + resource_id_expr="chart_id", + remark_expr="brief" +)) +async def delete(session: SessionDep, current_user: CurrentUser, chart_id: int, brief: str): + try: + return delete_chat_with_user(session=session, current_user=current_user, chart_id=chart_id) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e) + ) @router.post("/start", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}start_chat") @require_permissions(permission=SqlbotPermission(type='ds', keyExpression="create_chat_obj.datasource")) @@ -391,14 +437,18 @@ def _err(_e: Exception): @router.get("/record/{chat_record_id}/excel/export", summary=f"{PLACEHOLDER_PREFIX}export_chart_data") -async def export_excel(session: SessionDep, chat_record_id: int, trans: Trans): +async def export_excel(session: SessionDep, current_user: CurrentUser, chat_record_id: int, trans: Trans): chat_record = session.get(ChatRecord, chat_record_id) if not chat_record: raise HTTPException( status_code=500, detail=f"ChatRecord with id {chat_record_id} not found" ) - + if chat_record.create_by != current_user.id: + raise HTTPException( + status_code=500, + detail=f"ChatRecord with id {chat_record_id} not Owned by the current user" + ) is_predict_data = chat_record.predict_record_id is not None _origin_data = format_json_data(get_chat_chart_data(chat_record_id=chat_record_id, session=session)) diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index ae425ea0..2baa0910 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -59,6 +59,21 @@ def list_recent_questions(session: SessionDep, current_user: CurrentUser, dataso ) return [record[0] for record in chat_records] if chat_records else [] +def rename_chat_with_user(session: SessionDep, current_user: CurrentUser, rename_object: RenameChat) -> str: + chat = session.get(Chat, rename_object.id) + if not chat: + raise Exception(f"Chat with id {rename_object.id} not found") + if chat.create_by != current_user.id: + raise Exception(f"Chat with id {rename_object.id} not Owned by the current user") + chat.brief = rename_object.brief.strip()[:20] + chat.brief_generate = rename_object.brief_generate + session.add(chat) + session.flush() + session.refresh(chat) + + brief = chat.brief + session.commit() + return brief def rename_chat(session: SessionDep, rename_object: RenameChat) -> str: chat = session.get(Chat, rename_object.id) @@ -86,6 +101,17 @@ def delete_chat(session, chart_id) -> str: return f'Chat with id {chart_id} has been deleted' +def delete_chat_with_user(session, current_user: CurrentUser, chart_id) -> str: + chat = session.query(Chat).filter(Chat.id == chart_id).first() + if not chat: + return f'Chat with id {chart_id} has been deleted' + if chat.create_by != current_user.id: + raise Exception(f"Chat with id {chart_id} not Owned by the current user") + session.delete(chat) + session.commit() + + return f'Chat with id {chart_id} has been deleted' + def get_chart_config(session: SessionDep, chart_record_id: int): stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chart_record_id)) @@ -173,6 +199,15 @@ def get_chat_chart_config(session: SessionDep, chat_record_id: int): pass return {} +def get_chart_data_with_user(session: SessionDep, current_user: CurrentUser, chat_record_id: int): + stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id, ChatRecord.create_by == current_user.id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.data) + except Exception: + pass + return {} def get_chat_chart_data(session: SessionDep, chat_record_id: int): stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id)) @@ -184,6 +219,15 @@ def get_chat_chart_data(session: SessionDep, chat_record_id: int): pass return {} +def get_chat_predict_data_with_user(session: SessionDep, current_user: CurrentUser, chat_record_id: int): + stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chat_record_id, ChatRecord.create_by == current_user.id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.predict_data) + except Exception: + pass + return {} def get_chat_predict_data(session: SessionDep, chat_record_id: int): stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chat_record_id)) @@ -210,7 +254,8 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr chat = session.get(Chat, chart_id) if not chat: raise Exception(f"Chat with id {chart_id} not found") - + if chat.create_by != current_user.id: + raise Exception(f"Chat with id {chart_id} not Owned by the current user") chat_info = ChatInfo(**chat.model_dump()) if current_assistant and current_assistant.type in dynamic_ds_types: diff --git a/backend/apps/dashboard/api/dashboard_api.py b/backend/apps/dashboard/api/dashboard_api.py index 55b6054a..8ccc0c98 100644 --- a/backend/apps/dashboard/api/dashboard_api.py +++ b/backend/apps/dashboard/api/dashboard_api.py @@ -19,8 +19,15 @@ async def list_resource_api(session: SessionDep, dashboard: QueryDashboard, curr @router.post("/load_resource", summary=f"{PLACEHOLDER_PREFIX}load_resource_api") -async def load_resource_api(session: SessionDep, dashboard: QueryDashboard): - return load_resource(session=session, dashboard=dashboard) +async def load_resource_api(session: SessionDep, current_user: CurrentUser, dashboard: QueryDashboard): + resource_dict = load_resource(session=session, dashboard=dashboard) + if resource_dict and resource_dict.get("create_by") != str(current_user.id): + raise HTTPException( + status_code=403, + detail="You do not have permission to access this resource" + ) + + return resource_dict @router.post("/create_resource", response_model=BaseDashboard, summary=f"{PLACEHOLDER_PREFIX}create_resource_api") @@ -45,8 +52,8 @@ async def update_resource_api(session: SessionDep, user: CurrentUser, dashboard: resource_id_expr="resource_id", remark_expr="name" )) -async def delete_resource_api(session: SessionDep, resource_id: str, name: str): - return delete_resource(session, resource_id) +async def delete_resource_api(session: SessionDep, current_user: CurrentUser, resource_id: str, name: str): + return delete_resource(session, current_user, resource_id) @router.post("/create_canvas", response_model=BaseDashboard, summary=f"{PLACEHOLDER_PREFIX}create_canvas_api") diff --git a/backend/apps/dashboard/crud/dashboard_service.py b/backend/apps/dashboard/crud/dashboard_service.py index 6abc5df4..566636af 100644 --- a/backend/apps/dashboard/crud/dashboard_service.py +++ b/backend/apps/dashboard/crud/dashboard_service.py @@ -27,7 +27,7 @@ def list_resource(session: SessionDep, dashboard: QueryDashboard, current_user: nodes = [DashboardBaseResponse(**row) for row in result.mappings()] tree = build_tree_generic(nodes, root_pid="root") return tree - + def load_resource(session: SessionDep, dashboard: QueryDashboard): sql = text(""" @@ -130,7 +130,12 @@ def validate_name(session: SessionDep,user: CurrentUser, dashboard: QueryDashbo return not session.query(query.exists()).scalar() -def delete_resource(session: SessionDep, resource_id: str): +def delete_resource(session: SessionDep, current_user: CurrentUser, resource_id: str): + coreDashboard = session.get(CoreDashboard, resource_id) + if not coreDashboard: + raise ValueError(f"Resource with id {resource_id} does not exist") + if coreDashboard.create_by != str(current_user.id): + raise ValueError(f"Resource with id {resource_id} not owned by the current user") sql = text("DELETE FROM core_dashboard WHERE id = :resource_id") result = session.execute(sql, {"resource_id": resource_id}) session.commit() diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py index b6aa04a4..e05306f4 100644 --- a/backend/apps/datasource/api/datasource.py +++ b/backend/apps/datasource/api/datasource.py @@ -83,6 +83,7 @@ def inner(): @router.post("/chooseTables/{id}", response_model=None, summary=f"{PLACEHOLDER_PREFIX}ds_choose_tables") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'], permission=SqlbotPermission(type='ds', keyExpression="id"))) async def choose_tables(session: SessionDep, trans: Trans, tables: List[CoreTable], id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id")): def inner(): @@ -117,6 +118,7 @@ async def get_tables(session: SessionDep, id: int = Path(..., description=f"{PLA @router.post("/getTablesByConf", response_model=List[TableSchemaResponse], summary=f"{PLACEHOLDER_PREFIX}ds_get_tables") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def get_tables_by_conf(session: SessionDep, trans: Trans, ds: CoreDatasource): try: def inner(): @@ -135,6 +137,7 @@ def inner(): @router.post("/getSchemaByConf", response_model=List[str], summary=f"{PLACEHOLDER_PREFIX}ds_get_schema") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def get_schema_by_conf(session: SessionDep, trans: Trans, ds: CoreDatasource): try: def inner(): @@ -154,6 +157,7 @@ def inner(): @router.post("/getFields/{id}/{table_name}", response_model=List[ColumnSchemaResponse], summary=f"{PLACEHOLDER_PREFIX}ds_get_fields") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'], type='ds', keyExpression="id")) async def get_fields(session: SessionDep, id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id"), table_name: str = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_table_name")): @@ -174,7 +178,7 @@ class TestObj(BaseModel): # not used, just do test -@router.post("/execSql/{id}", include_in_schema=False) +""" @router.post("/execSql/{id}", include_in_schema=False) async def exec_sql(session: SessionDep, id: int, obj: TestObj): def inner(): data = execSql(session, id, obj.sql) @@ -187,31 +191,36 @@ def inner(): return data - return await asyncio.to_thread(inner) + return await asyncio.to_thread(inner) """ @router.post("/tableList/{id}", response_model=List[CoreTable], summary=f"{PLACEHOLDER_PREFIX}ds_table_list") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'], type='ds', keyExpression="id")) async def table_list(session: SessionDep, id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id")): return get_tables_by_ds_id(session, id) @router.post("/fieldList/{id}", response_model=List[CoreField], summary=f"{PLACEHOLDER_PREFIX}ds_field_list") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def field_list(session: SessionDep, field: FieldObj, id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_table_id")): return get_fields_by_table_id(session, id, field) @router.post("/editLocalComment", include_in_schema=False) +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def edit_local(session: SessionDep, data: TableObj): update_table_and_fields(session, data) @router.post("/editTable", response_model=None, summary=f"{PLACEHOLDER_PREFIX}ds_edit_table") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def edit_table(session: SessionDep, table: CoreTable): updateTable(session, table) @router.post("/editField", response_model=None, summary=f"{PLACEHOLDER_PREFIX}ds_edit_field") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def edit_field(session: SessionDep, field: CoreField): updateField(session, field)