# Module-level OCR model, loaded once from the configured default model
# name so request handlers can share it (reassigned by extract_text when
# a different model is requested).
model = VisionEncoderDecoderModel.from_pretrained(default_model_name)
def create_app():
    """Create and configure the Flask application.

    Registers the OCR and model-listing endpoints plus JSON error
    handlers for HTTP 400 and 500.

    Returns:
        Flask: The configured Flask application instance.
    """
    app = Flask(__name__)

    @app.route('/ocr', methods=['POST'])
    def ocr():
        """Run OCR on an uploaded file sent via multipart POST.

        Form fields:
            file: the uploaded document (required).
            model: OCR model name; defaults to ``default_model_name``.
            threshold_value, kernel_width, kernel_height, min_area:
                integer tuning parameters forwarded to line segmentation.

        Returns:
            Response: JSON with the extracted text, or a JSON error
            payload (400 for bad client input, 500 for OCR failures).
        """
        # A missing upload is a client mistake: report it as a clear 400
        # instead of letting the KeyError surface as a server error.
        file = request.files.get('file')
        if file is None:
            return jsonify({'error': 'Bad request',
                            'details': "missing 'file' upload"}), 400

        model_name = request.form.get('model', default_model_name)

        # Parameters for segment_lines. Non-integer values are likewise a
        # client error, so convert them under a 400 guard rather than
        # letting ValueError become an unhandled 500.
        try:
            threshold_value = int(request.form.get('threshold_value', 150))
            kernel_width = int(request.form.get('kernel_width', 20))
            kernel_height = int(request.form.get('kernel_height', 1))
            min_area = int(request.form.get('min_area', 50))
        except ValueError as e:
            return jsonify({'error': 'Bad request', 'details': str(e)}), 400

        try:
            text = extract_text(file, model_name=model_name,
                                threshold_value=threshold_value,
                                kernel_width=kernel_width,
                                kernel_height=kernel_height,
                                min_area=min_area)
            return jsonify({'text': text})
        except Exception as e:
            # Any OCR/processing failure is rendered through the shared
            # 500 handler so error responses stay uniform.
            return internal_server_error(e)

    @app.route('/models', methods=['GET'])
    def list_models():
        """Return the list of supported OCR models as a JSON response."""
        models = get_supported_models()
        return jsonify({'supported_models': models})

    @app.errorhandler(400)
    def bad_request(error):
        """Render HTTP 400 errors as JSON including the error description.

        Args:
            error: The error object provided by Flask.

        Returns:
            Response: JSON body plus the 400 status code.
        """
        return jsonify({'error': 'Bad request', 'details': str(error)}), 400

    @app.errorhandler(500)
    def internal_server_error(error):
        """Render HTTP 500 errors as JSON including the error description.

        Args:
            error: The error object provided by Flask.

        Returns:
            Response: JSON body plus the 500 status code.
        """
        return jsonify({'error': 'Internal server error', 'details': str(error)}), 500

    return app
4690
47- def segment_lines (image ):
48- # Convert to grayscale
91+ def segment_lines (image , threshold_value = 150 , kernel_width = 20 , kernel_height = 1 , min_area = 50 ):
92+ """Segments an image into lines based on provided image processing parameters.
93+
94+ Args:
95+ image (Image): The image to process.
96+ threshold_value (int): Value for thresholding operation.
97+ kernel_width (int): Width of the kernel for morphological operations.
98+ kernel_height (int): Height of the kernel for morphological operations.
99+ min_area (int): Minimum area to consider a contour as a line.
100+
101+ Returns:
102+ list: A list of cropped images, each containing a line of text.
103+ """
49104 gray = cv2 .cvtColor (np .array (image ), cv2 .COLOR_BGR2GRAY )
105+ _ , thresh = cv2 .threshold (gray , threshold_value , 255 , cv2 .THRESH_BINARY_INV )
50106
51- # Apply adaptive thresholding
52- thresh = cv2 .adaptiveThreshold (gray , 255 , cv2 .ADAPTIVE_THRESH_GAUSSIAN_C ,
53- cv2 .THRESH_BINARY_INV , 11 , 2 )
54-
55- # Use morphological operations to close gaps in between lines of text
56- kernel = cv2 .getStructuringElement (cv2 .MORPH_RECT , (20 , 1 )) # Adjust the kernel size as needed
107+ kernel = cv2 .getStructuringElement (cv2 .MORPH_RECT , (kernel_width , kernel_height ))
57108 thresh = cv2 .morphologyEx (thresh , cv2 .MORPH_CLOSE , kernel )
58109
59- # Find contours
60110 contours , _ = cv2 .findContours (thresh , cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_SIMPLE )
61111
62112 lines = []
63- min_area = 50 # Adjust this value based on expected smallest area of text line
64113 for cnt in contours :
65114 area = cv2 .contourArea (cnt )
66115 if area > min_area :
@@ -70,7 +119,18 @@ def segment_lines(image):
70119 return lines
71120
72121
73- def extract_text (file_stream , model_name = None ):
122+ def extract_text (file_stream , model_name = None , threshold_value = 150 , kernel_width = 20 , kernel_height = 1 , min_area = 50 ):
123+ """Extracts text from a PDF file using OCR, configurable via POST parameters.
124+
125+ Args:
126+ file_stream (io.BytesIO): The file stream of the PDF.
127+ model_name (str): The model name for the OCR processor, defaults to a pre-set model.
128+ threshold_value (int), kernel_width (int), kernel_height (int), min_area (int):
129+ Parameters forwarded to the segment_lines function for image processing.
130+
131+ Returns:
132+ str: Extracted text from the PDF document.
133+ """
74134 global processor , model
75135 if model_name and model_name != default_model_name :
76136 try :
@@ -85,7 +145,6 @@ def extract_text(file_stream, model_name=None):
85145 device = 'cuda' if torch .cuda .is_available () else 'cpu'
86146 model = model .to (device )
87147
88- # Create a directory to store the line images for debugging
89148 debug_dir = "debug_line_images"
90149 os .makedirs (debug_dir , exist_ok = True )
91150 line_counter = 0
@@ -94,14 +153,12 @@ def extract_text(file_stream, model_name=None):
94153 img = page .get_pixmap ()
95154 img_bytes = img .tobytes ()
96155 image = Image .open (io .BytesIO (img_bytes ))
97- lines = segment_lines (image )
156+ lines = segment_lines (image , threshold_value , kernel_width , kernel_height , min_area )
98157 for line in lines :
99- # Save each line image for debugging
100158 line_image_path = os .path .join (debug_dir , f"line_{ page_number } _{ line_counter } .png" )
101159 line .save (line_image_path )
102160 line_counter += 1
103161
104- # Continue with OCR processing
105162 inputs = processor (images = line , return_tensors = "pt" ).to (device )
106163 outputs = model .generate (** inputs )
107164 text += processor .batch_decode (outputs , skip_special_tokens = True )[0 ] + "\n "
@@ -111,6 +168,11 @@ def extract_text(file_stream, model_name=None):
111168 return text
112169
113170def get_supported_models ():
171+ """Lists all OCR models supported by the application.
172+
173+ Returns:
174+ list: A list of supported model identifiers.
175+ """
114176 return [
115177 "microsoft/trocr-large-handwritten" ,
116178 "microsoft/trocr-large-printed" ,
0 commit comments