11import argparse
22import logging
33import os
4+ import re
45from collections .abc import Iterable
56from concurrent .futures import ThreadPoolExecutor
67from pathlib import Path
1718 DocItem ,
1819 GraphCell ,
1920 KeyValueItem ,
21+ FormItem ,
2022 PictureItem ,
2123 RichTableCell ,
2224 TableCell ,
5456LM_STUDIO_MODEL = "nanonets-ocr2-3b"
5557
5658DEFAULT_PROMPT = "Extract the text from the above document as if you were reading it naturally. Output pure text, no html and no markdown. Pay attention on line breaks and don't miss text after line break. Put all text in one line."
57- VERBOSE = False
59+ VERBOSE = True
5860SHOW_IMAGE = False
5961
6062
63+ def safe_crop (img : Image .Image , bbox ):
64+ left , top , right , bottom = bbox
65+ # Clamp to image boundaries
66+ left = max (0 , min (left , img .width ))
67+ top = max (0 , min (top , img .height ))
68+ right = max (0 , min (right , img .width ))
69+ bottom = max (0 , min (bottom , img .height ))
70+ return img .crop ((left , top , right , bottom ))
71+
72+
73+ def no_long_repeats (s : str , threshold : int ) -> bool :
74+ """
75+ Returns False if the string `s` contains more than `threshold`
76+ identical characters in a row, otherwise True.
77+ """
78+ pattern = r'(.)\1{' + str (threshold ) + ',}'
79+ return re .search (pattern , s ) is None
80+
81+
6182class PostOcrEnrichmentElement (BaseModel ):
6283 model_config = ConfigDict (arbitrary_types_allowed = True )
6384
@@ -136,7 +157,7 @@ def prepare_element(
136157 allowed = (DocItem , TableItem , GraphCell )
137158 assert isinstance (element , allowed )
138159
139- if isinstance (element , KeyValueItem ):
160+ if isinstance (element , ( KeyValueItem , FormItem ) ):
140161 # Yield from the graphCells inside here.
141162 result = []
142163 for c in element .graph .cells :
@@ -164,6 +185,9 @@ def prepare_element(
164185 cropped_image = conv_res .document .pages [
165186 page_ix
166187 ].image .pil_image .crop (expanded_bbox .as_tuple ())
188+
189+ # cropped_image = safe_crop(conv_res.document.pages[page_ix].image.pil_image, expanded_bbox.as_tuple())
190+
167191 # cropped_image.show()
168192 result .append (
169193 PostOcrEnrichmentElement (item = c , image = [cropped_image ])
@@ -202,6 +226,8 @@ def prepare_element(
202226 cropped_image = conv_res .document .pages [
203227 page_ix
204228 ].image .pil_image .crop (expanded_bbox .as_tuple ())
229+
230+ # cropped_image = safe_crop(conv_res.document.pages[page_ix].image.pil_image, expanded_bbox.as_tuple())
205231 # cropped_image.show()
206232 result .append (
207233 PostOcrEnrichmentElement (
@@ -234,15 +260,27 @@ def prepare_element(
234260 ):
235261 good_bbox = False
236262
237- if good_bbox :
238- cropped_image = conv_res .document .pages [
239- page_ix
240- ].image .pil_image .crop (expanded_bbox .as_tuple ())
241- multiple_crops .append (cropped_image )
242- # cropped_image.show()
263+ if hasattr (element , "text" ):
264+ if good_bbox :
265+ cropped_image = conv_res .document .pages [
266+ page_ix
267+ ].image .pil_image .crop (expanded_bbox .as_tuple ())
268+ # cropped_image = safe_crop(conv_res.document.pages[page_ix].image.pil_image, expanded_bbox.as_tuple())
269+
270+ multiple_crops .append (cropped_image )
271+ print ("" )
272+ print ("cropped image size: {}" .format (cropped_image .size ))
273+ print (type (element ))
274+ if hasattr (element , "text" ):
275+ print ("OLD TEXT: {}" .format (element .text ))
276+ # cropped_image.show()
277+ else :
278+ print ("Not a text element" )
243279 if len (multiple_crops ) > 0 :
280+ # good crops
244281 return [PostOcrEnrichmentElement (item = element , image = multiple_crops )]
245282 else :
283+ # nothing
246284 return []
247285
248286 @classmethod
@@ -260,8 +298,9 @@ def __init__(
260298 ):
261299 self .enabled = enabled
262300 self .options = options
263- self .concurrency = 4
301+ self .concurrency = 2
264302 self .expansion_factor = 0.05
303+ # self.expansion_factor = 0.0
265304 self .elements_batch_size = 4
266305 self ._accelerator_options = accelerator_options
267306 self ._artifacts_path = (
@@ -282,7 +321,8 @@ def _api_request(image: Image.Image) -> str:
282321 image = image ,
283322 prompt = self .options .prompt ,
284323 url = self .options .url ,
285- timeout = self .options .timeout ,
324+ # timeout=self.options.timeout,
325+ timeout = 30 ,
286326 headers = self .options .headers ,
287327 ** self .options .params ,
288328 )
@@ -343,36 +383,36 @@ def clean_html_tags(text):
343383 return text
344384
345385 output = clean_html_tags (output ).strip ()
346-
347- if VERBOSE :
348- if isinstance (item , (TextItem )):
349- print (f"OLD TEXT: { item .text } " )
350-
351- # Re-populate text
352- if isinstance (item , (TextItem , GraphCell )):
353- if img_ind > 0 :
354- # Concat texts across several provenances
355- item .text += " " + output
356- item .orig += " " + output
357- else :
386+ if no_long_repeats (output , 50 ):
387+ if VERBOSE :
388+ if isinstance (item , (TextItem )):
389+ print (f"OLD TEXT: { item .text } " )
390+
391+ # Re-populate text
392+ if isinstance (item , (TextItem , GraphCell )):
393+ if img_ind > 0 :
394+ # Concat texts across several provenances
395+ item .text += " " + output
396+ item .orig += " " + output
397+ else :
398+ item .text = output
399+ item .orig = output
400+ elif isinstance (item , (TableCell , RichTableCell )):
358401 item .text = output
359- item .orig = output
360- elif isinstance (item , (TableCell , RichTableCell )):
361- item .text = output
362- elif isinstance (item , PictureItem ):
363- pass
364- else :
365- raise ValueError (f"Unknown item type: { type (item )} " )
402+ elif isinstance (item , PictureItem ):
403+ pass
404+ else :
405+ raise ValueError (f"Unknown item type: { type (item )} " )
366406
367- if VERBOSE :
368- if isinstance (item , (TextItem )):
369- print (f"NEW TEXT: { item .text } " )
407+ if VERBOSE :
408+ if isinstance (item , (TextItem )):
409+ print (f"NEW TEXT: { item .text } " )
370410
371- # Take care of charspans for relevant types
372- if isinstance (item , GraphCell ):
373- item .prov .charspan = (0 , len (item .text ))
374- elif isinstance (item , TextItem ):
375- item .prov [0 ].charspan = (0 , len (item .text ))
411+ # Take care of charspans for relevant types
412+ if isinstance (item , GraphCell ):
413+ item .prov .charspan = (0 , len (item .text ))
414+ elif isinstance (item , TextItem ):
415+ item .prov [0 ].charspan = (0 , len (item .text ))
376416
377417 yield item
378418
@@ -382,7 +422,8 @@ def convert_pdf(pdf_path: Path, out_intermediate_json: Path):
382422 pipeline_options = PdfPipelineOptions ()
383423 pipeline_options .generate_page_images = True
384424 pipeline_options .generate_picture_images = True
385- pipeline_options .images_scale = 4.0
425+ # pipeline_options.images_scale = 4.0
426+ pipeline_options .images_scale = 2.0
386427
387428 doc_converter = (
388429 DocumentConverter ( # all of the below is optional, has internal defaults.
@@ -424,6 +465,7 @@ def post_process_json(in_json: Path, out_final_json: Path):
424465 )
425466 )
426467
468+ # try:
427469 doc_converter = DocumentConverter (
428470 format_options = {
429471 InputFormat .JSON_DOCLING : FormatOption (
@@ -440,6 +482,8 @@ def post_process_json(in_json: Path, out_final_json: Path):
440482 md = result .document .export_to_markdown ()
441483 print ("*** MARKDOWN ***" )
442484 print (md )
485+ # except:
486+ # print("ERROR IN OCR for: {}".format(in_json))
443487
444488
445489def process_pdf (pdf_path : Path , scratch_dir : Path , out_dir : Path ):
0 commit comments