Skip to content

Commit 9bc2679

Browse files
committed
Update app.py
Added inline comments. The optional parameters — model, threshold_value, kernel_width, kernel_height, and min_area — are now read from the form data; defaults are used for any parameter not provided.
1 parent 487babf commit 9bc2679

File tree

1 file changed

+82
-20
lines changed

1 file changed

+82
-20
lines changed

src/app.py

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,50 +17,99 @@
1717
model = VisionEncoderDecoderModel.from_pretrained(default_model_name)
1818

1919
def create_app():
    """Creates and configures an instance of a Flask application.

    Returns:
        Flask: The Flask application instance.
    """
    app = Flask(__name__)

    @app.route('/ocr', methods=['POST'])
    def ocr():
        """Handles OCR processing for uploaded files through POST requests.

        Retrieves file and OCR settings from the request, processes the file to extract text,
        and returns the extracted text or an error message in JSON format.

        Returns:
            Response: JSON response containing extracted text or error message.
        """
        # Reject early with a JSON 400 instead of letting a missing file part
        # surface as an opaque framework error.
        if 'file' not in request.files:
            return jsonify({'error': 'Bad request', 'details': "Missing 'file' in request"}), 400
        file = request.files['file']
        model_name = request.form.get('model', default_model_name)

        # Parameters for segment_lines. A non-integer value is a client error,
        # so convert inside a try and answer 400 rather than crashing with a 500.
        try:
            threshold_value = int(request.form.get('threshold_value', 150))
            kernel_width = int(request.form.get('kernel_width', 20))
            kernel_height = int(request.form.get('kernel_height', 1))
            min_area = int(request.form.get('min_area', 50))
        except ValueError as e:
            return jsonify({'error': 'Bad request', 'details': str(e)}), 400

        try:
            text = extract_text(file, model_name=model_name, threshold_value=threshold_value,
                                kernel_width=kernel_width, kernel_height=kernel_height, min_area=min_area)
            return jsonify({'text': text})
        except Exception as e:
            # Delegate to the registered 500 handler so error formatting
            # stays in one place.
            return internal_server_error(e)

    @app.route('/models', methods=['GET'])
    def list_models():
        """Provides a list of supported OCR models via a GET request.

        Returns:
            Response: JSON response containing a list of supported OCR models.
        """
        models = get_supported_models()
        return jsonify({'supported_models': models})

    @app.errorhandler(400)
    def bad_request(error):
        """Handles HTTP 400 errors by returning a JSON formatted bad request message along with error details.

        Args:
            error: The error object provided by Flask.

        Returns:
            Response: JSON response indicating a bad request and including the error description.
        """
        return jsonify({'error': 'Bad request', 'details': str(error)}), 400

    @app.errorhandler(500)
    def internal_server_error(error):
        """Handles HTTP 500 errors by returning a JSON formatted internal server error message along with error details.

        Args:
            error: The error object provided by Flask.

        Returns:
            Response: JSON response indicating an internal server error and including the error description.
        """
        return jsonify({'error': 'Internal server error', 'details': str(error)}), 500

    return app
4690

47-
def segment_lines(image):
48-
# Convert to grayscale
91+
def segment_lines(image, threshold_value=150, kernel_width=20, kernel_height=1, min_area=50):
92+
"""Segments an image into lines based on provided image processing parameters.
93+
94+
Args:
95+
image (Image): The image to process.
96+
threshold_value (int): Value for thresholding operation.
97+
kernel_width (int): Width of the kernel for morphological operations.
98+
kernel_height (int): Height of the kernel for morphological operations.
99+
min_area (int): Minimum area to consider a contour as a line.
100+
101+
Returns:
102+
list: A list of cropped images, each containing a line of text.
103+
"""
49104
gray = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2GRAY)
105+
_, thresh = cv2.threshold(gray, threshold_value, 255, cv2.THRESH_BINARY_INV)
50106

51-
# Apply adaptive thresholding
52-
thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
53-
cv2.THRESH_BINARY_INV, 11, 2)
54-
55-
# Use morphological operations to close gaps in between lines of text
56-
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (20, 1)) # Adjust the kernel size as needed
107+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_width, kernel_height))
57108
thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
58109

59-
# Find contours
60110
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
61111

62112
lines = []
63-
min_area = 50 # Adjust this value based on expected smallest area of text line
64113
for cnt in contours:
65114
area = cv2.contourArea(cnt)
66115
if area > min_area:
@@ -70,7 +119,18 @@ def segment_lines(image):
70119
return lines
71120

72121

73-
def extract_text(file_stream, model_name=None):
122+
def extract_text(file_stream, model_name=None, threshold_value=150, kernel_width=20, kernel_height=1, min_area=50):
123+
"""Extracts text from a PDF file using OCR, configurable via POST parameters.
124+
125+
Args:
126+
file_stream (io.BytesIO): The file stream of the PDF.
127+
model_name (str): The model name for the OCR processor, defaults to a pre-set model.
128+
threshold_value (int), kernel_width (int), kernel_height (int), min_area (int):
129+
Parameters forwarded to the segment_lines function for image processing.
130+
131+
Returns:
132+
str: Extracted text from the PDF document.
133+
"""
74134
global processor, model
75135
if model_name and model_name != default_model_name:
76136
try:
@@ -85,7 +145,6 @@ def extract_text(file_stream, model_name=None):
85145
device = 'cuda' if torch.cuda.is_available() else 'cpu'
86146
model = model.to(device)
87147

88-
# Create a directory to store the line images for debugging
89148
debug_dir = "debug_line_images"
90149
os.makedirs(debug_dir, exist_ok=True)
91150
line_counter = 0
@@ -94,14 +153,12 @@ def extract_text(file_stream, model_name=None):
94153
img = page.get_pixmap()
95154
img_bytes = img.tobytes()
96155
image = Image.open(io.BytesIO(img_bytes))
97-
lines = segment_lines(image)
156+
lines = segment_lines(image, threshold_value, kernel_width, kernel_height, min_area)
98157
for line in lines:
99-
# Save each line image for debugging
100158
line_image_path = os.path.join(debug_dir, f"line_{page_number}_{line_counter}.png")
101159
line.save(line_image_path)
102160
line_counter += 1
103161

104-
# Continue with OCR processing
105162
inputs = processor(images=line, return_tensors="pt").to(device)
106163
outputs = model.generate(**inputs)
107164
text += processor.batch_decode(outputs, skip_special_tokens=True)[0] + "\n"
@@ -111,6 +168,11 @@ def extract_text(file_stream, model_name=None):
111168
return text
112169

113170
def get_supported_models():
171+
"""Lists all OCR models supported by the application.
172+
173+
Returns:
174+
list: A list of supported model identifiers.
175+
"""
114176
return [
115177
"microsoft/trocr-large-handwritten",
116178
"microsoft/trocr-large-printed",

0 commit comments

Comments
 (0)