33import asyncio
44from http .client import HTTPException
55import io
6- from fastapi .responses import StreamingResponse
6+ import sys
7+ import tempfile
8+ import uuid
9+ import atexit
10+ import threading
11+ from fastapi .responses import StreamingResponse , FileResponse
12+ import os
13+ from openai import BaseModel
714import pandas as pd
15+ from apps .system .models .user import UserModel
16+ from common .core .deps import SessionDep
817
918
19+ class RowValidator :
20+ def __init__ (self , success : bool = False , row = list [str ], error_info : dict = None ):
21+ self .success = success
22+ self .row = row
23+ self .dict_data = {}
24+ self .error_info = error_info or {}
25+ class CellValidator :
26+ def __init__ (self , success : bool = False , value : str | int | list = None , message : str = "" ):
27+ self .success = success
28+ self .value = value
29+ self .message = message
30+
31+ class UploadResultDTO (BaseModel ):
32+ successCount : int
33+ errorCount : int
34+ dataKey : str | None = None
35+
36+
1037async def downTemplate (trans ):
1138 def inner ():
1239 data = {
@@ -57,8 +84,257 @@ def inner():
5784 result = await asyncio .to_thread (inner )
5885 return StreamingResponse (result , media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" )
5986
60- async def batchUpload (trans , file ):
87+ async def batchUpload (session : SessionDep , trans , file ) -> UploadResultDTO :
6188 ALLOWED_EXTENSIONS = {"xlsx" , "xls" }
6289 if not file .filename .lower ().endswith (tuple (ALLOWED_EXTENSIONS )):
6390 raise HTTPException (400 , "Only support .xlsx/.xls" )
64- pass
91+
92+ # Support FastAPI UploadFile (async read) and file-like objects.
93+ NA_VALUES = ['' , 'NA' , 'N/A' , 'NULL' ]
94+ df = None
95+ # If file provides an async read (UploadFile), read bytes first
96+ if hasattr (file , 'read' ) and asyncio .iscoroutinefunction (getattr (file , 'read' )):
97+ content = await file .read ()
98+ df = pd .read_excel (io .BytesIO (content ), sheet_name = 0 , na_values = NA_VALUES )
99+ else :
100+ # If it's a Starlette UploadFile-like with a .file attribute, use that
101+ if hasattr (file , 'file' ):
102+ fobj = file .file
103+ try :
104+ fobj .seek (0 )
105+ except Exception :
106+ pass
107+ df = pd .read_excel (fobj , sheet_name = 0 , na_values = NA_VALUES )
108+ else :
109+ # fallback: assume a path or file-like object
110+ try :
111+ file .seek (0 )
112+ except Exception :
113+ pass
114+ df = pd .read_excel (file , sheet_name = 0 , na_values = NA_VALUES )
115+ head_list = list (df .columns )
116+ i18n_head_list = get_i18n_head_list ()
117+ if not validate_head (trans = trans , head_i18n_list = i18n_head_list , head_list = head_list ):
118+ raise HTTPException (400 , "Excel header validation failed" )
119+ success_list = []
120+ error_list = []
121+ for row in df .itertuples ():
122+ row_validator = validate_row (trans = trans , head_i18n_list = i18n_head_list , row = row )
123+ if row_validator .success :
124+ success_list .append (row_validator .dict_data )
125+ else :
126+ error_list .append (row_validator )
127+ error_file_id = None
128+ if error_list :
129+ error_file_id = generate_error_file (error_list , head_list )
130+ result = UploadResultDTO (successCount = len (success_list ), errorCount = len (error_list ), dataKey = error_file_id )
131+ if success_list :
132+ user_po_list = [UserModel .model_validate (row ) for row in success_list ]
133+ session .add_all (user_po_list )
134+ session .commit ()
135+ return result
136+
137+ def get_i18n_head_list ():
138+ return [
139+ 'i18n_user.account' ,
140+ 'i18n_user.name' ,
141+ 'i18n_user.email' ,
142+ 'i18n_user.workspace' ,
143+ 'i18n_user.role' ,
144+ 'i18n_user.status' ,
145+ 'i18n_user.origin' ,
146+ 'i18n_user.platform_user_id' ,
147+ ]
148+
149+ def validate_head (trans , head_i18n_list : list [str ], head_list : list ):
150+ if len (head_list ) != len (head_i18n_list ):
151+ return False
152+ for i in range (len (head_i18n_list )):
153+ if head_list [i ] != trans (head_i18n_list [i ]):
154+ return False
155+ return True
156+
157+
158+
159+ def validate_row (trans , head_i18n_list : list [str ], row ):
160+ validator = RowValidator (success = True , row = [], error_info = {})
161+ for i in range (len (head_i18n_list )):
162+ col_name = trans (head_i18n_list [i ])
163+ row_value = getattr (row , col_name )
164+ validator .row .append (row_value )
165+ _attr_name = f"{ head_i18n_list [i ].split ('.' )[- 1 ]} "
166+ _method_name = f"validate_{ _attr_name } "
167+ cellValidator = dynamic_call (_method_name , row_value )
168+ if not cellValidator .success :
169+ validator .success = False
170+ validator .error_info [i ] = cellValidator .message
171+ else :
172+ validator .dict_data [_attr_name ] = cellValidator .value
173+ return validator
174+
175+ def generate_error_file (error_list : list [RowValidator ], head_list : list [str ]) -> str :
176+ # If no errors, return empty string
177+ if not error_list :
178+ return ""
179+
180+ # Build DataFrame from error rows (only include rows that had errors)
181+ df_rows = [err .row for err in error_list ]
182+ df = pd .DataFrame (df_rows , columns = head_list )
183+
184+ tmp = tempfile .NamedTemporaryFile (delete = False , suffix = ".xlsx" )
185+ tmp_name = tmp .name
186+ tmp .close ()
187+
188+ with pd .ExcelWriter (tmp_name , engine = 'xlsxwriter' , engine_kwargs = {'options' : {'strings_to_numbers' : False }}) as writer :
189+ df .to_excel (writer , sheet_name = 'Errors' , index = False )
190+
191+ workbook = writer .book
192+ worksheet = writer .sheets ['Errors' ]
193+
194+ # header format similar to downTemplate
195+ header_format = workbook .add_format ({
196+ 'bold' : True ,
197+ 'font_size' : 12 ,
198+ 'font_name' : '微软雅黑' ,
199+ 'align' : 'center' ,
200+ 'valign' : 'vcenter' ,
201+ 'border' : 0 ,
202+ 'text_wrap' : False ,
203+ })
204+
205+ # apply header format and column widths
206+ for i , col in enumerate (df .columns ):
207+ max_length = max (
208+ len (str (col ).encode ('utf-8' )) * 1.1 ,
209+ (df [col ].astype (str )).apply (len ).max () if len (df ) > 0 else 0
210+ )
211+ worksheet .set_column (i , i , max_length + 12 )
212+ worksheet .write (0 , i , col , header_format )
213+
214+ worksheet .set_row (0 , 30 )
215+ for row_idx in range (1 , len (df ) + 1 ):
216+ worksheet .set_row (row_idx , 25 )
217+
218+ red_format = workbook .add_format ({'font_color' : 'red' })
219+
220+ # Add comments and set red font for each erroneous cell.
221+ # Note: pandas wrote header at row 0, data starts from row 1 in the sheet.
222+ for sheet_row_idx , err in enumerate (error_list , start = 1 ):
223+ for col_idx , message in err .error_info .items ():
224+ if message :
225+ comment_text = str (message )
226+ worksheet .write_comment (sheet_row_idx , col_idx , comment_text )
227+ try :
228+ cell_value = df .iat [sheet_row_idx - 1 , col_idx ]
229+ except Exception :
230+ cell_value = None
231+ worksheet .write (sheet_row_idx , col_idx , cell_value , red_format )
232+
233+ # register temp file in map and return an opaque file id
234+ file_id = uuid .uuid4 ().hex
235+ with _TEMP_FILE_LOCK :
236+ _TEMP_FILE_MAP [file_id ] = tmp_name
237+
238+ return file_id
239+
240+
241+ def download_error_file (file_id : str ) -> FileResponse :
242+ """Return a FileResponse for the given generated file id.
243+
244+ Look up the actual temp path from the internal map. Only files
245+ created by `generate_error_file` are allowed.
246+ """
247+ if not file_id :
248+ raise HTTPException (400 , "file_id required" )
249+
250+ with _TEMP_FILE_LOCK :
251+ file_path = _TEMP_FILE_MAP .get (file_id )
252+
253+ if not file_path :
254+ raise HTTPException (404 , "File not found" )
255+
256+ # ensure file is inside tempdir
257+ tempdir = tempfile .gettempdir ()
258+ try :
259+ common = os .path .commonpath ([tempdir , os .path .abspath (file_path )])
260+ except Exception :
261+ raise HTTPException (403 , "Unauthorized file access" )
262+
263+ if os .path .abspath (common ) != os .path .abspath (tempdir ):
264+ raise HTTPException (403 , "Unauthorized file access" )
265+
266+ if not os .path .exists (file_path ):
267+ raise HTTPException (404 , "File not found" )
268+
269+ return FileResponse (
270+ path = file_path ,
271+ media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" ,
272+ filename = os .path .basename (file_path ),
273+ )
274+
275+ def validate_account (value : str ) -> CellValidator :
276+ return CellValidator (True , value , None )
277+ def validate_name (value : str ) -> CellValidator :
278+ return CellValidator (True , value , None )
279+ def validate_email (value : str ) -> CellValidator :
280+ return CellValidator (True , value , None )
281+ def validate_workspace (value : str ) -> CellValidator :
282+ return CellValidator (True , value , None )
283+ def validate_role (value : str ) -> CellValidator :
284+ return CellValidator (True , value , None )
285+ def validate_status (value : str ) -> CellValidator :
286+ if value == '已启用' : return CellValidator (True , 1 , None )
287+ if value == '已禁用' : return CellValidator (True , 0 , None )
288+ return CellValidator (False , None , "状态只能是已启用或已禁用" )
289+ def validate_origin (value : str ) -> CellValidator :
290+ if value == '本地创建' : return CellValidator (True , 0 , None )
291+ return CellValidator (False , None , "不支持当前来源" )
292+ def validate_platform_id (value : str ) -> CellValidator :
293+ return CellValidator (True , value , None )
294+
295+ _method_cache = {
296+ 'validate_account' : validate_account ,
297+ 'validate_name' : validate_name ,
298+ 'validate_email' : validate_email ,
299+ 'validate_workspace' : validate_workspace ,
300+ 'validate_role' : validate_role ,
301+ 'validate_status' : validate_status ,
302+ 'validate_origin' : validate_origin ,
303+ 'validate_platform_user_id' : validate_platform_id ,
304+ }
305+ _module = sys .modules [__name__ ]
306+ def dynamic_call (method_name : str , * args , ** kwargs ):
307+ if method_name in _method_cache :
308+ return _method_cache [method_name ](* args , ** kwargs )
309+
310+ if hasattr (_module , method_name ):
311+ func = getattr (_module , method_name )
312+ _method_cache [method_name ] = func
313+ return func (* args , ** kwargs )
314+
315+ raise AttributeError (f"Function '{ method_name } ' not found" )
316+
317+
318+ # Map of file_id -> temp path for generated error files
319+ _TEMP_FILE_MAP : dict [str , str ] = {}
320+ _TEMP_FILE_LOCK = threading .Lock ()
321+
322+
323+ def _cleanup_temp_files ():
324+ with _TEMP_FILE_LOCK :
325+ for fid , path in list (_TEMP_FILE_MAP .items ()):
326+ try :
327+ if os .path .exists (path ):
328+ os .remove (path )
329+ except Exception :
330+ pass
331+ _TEMP_FILE_MAP .clear ()
332+
333+
334+ atexit .register (_cleanup_temp_files )
335+
336+
337+
338+
339+
340+
0 commit comments