Skip to content

Commit bdc7b0a

Browse files
author
Maksym Lysak
committed
Improvements and corrections
Signed-off-by: Maksym Lysak <[email protected]>
1 parent 842e25c commit bdc7b0a

File tree

2 files changed

+142
-81
lines changed

2 files changed

+142
-81
lines changed

docling/utils/api_image_request.py

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,51 +23,68 @@ def api_image_request(
2323
**params,
2424
) -> Tuple[str, Optional[int], VlmStopReason]:
2525
img_io = BytesIO()
26-
image.save(img_io, "PNG")
27-
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
28-
messages = [
29-
{
30-
"role": "user",
31-
"content": [
26+
image = image.copy()
27+
image = image.convert("RGBA")
28+
good_image = True
29+
try:
30+
image.save(img_io, "PNG")
31+
except:
32+
good_image = False
33+
_log.error("Error, corrupter PNG of size: {}".format(image.size))
34+
35+
if good_image:
36+
try:
37+
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
38+
39+
messages = [
3240
{
33-
"type": "image_url",
34-
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
35-
},
36-
{
37-
"type": "text",
38-
"text": prompt,
39-
},
40-
],
41-
}
42-
]
43-
44-
payload = {
45-
"messages": messages,
46-
**params,
47-
}
41+
"role": "user",
42+
"content": [
43+
{
44+
"type": "image_url",
45+
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
46+
},
47+
{
48+
"type": "text",
49+
"text": prompt,
50+
},
51+
],
52+
}
53+
]
54+
55+
payload = {
56+
"messages": messages,
57+
**params,
58+
}
59+
60+
headers = headers or {}
61+
62+
r = requests.post(
63+
str(url),
64+
headers=headers,
65+
json=payload,
66+
timeout=timeout,
67+
)
68+
if not r.ok:
69+
_log.error(f"Error calling the API. Response was {r.text}")
70+
# image.show()
71+
# r.raise_for_status()
72+
73+
api_resp = OpenAiApiResponse.model_validate_json(r.text)
74+
generated_text = api_resp.choices[0].message.content.strip()
75+
num_tokens = api_resp.usage.total_tokens
76+
stop_reason = (
77+
VlmStopReason.LENGTH
78+
if api_resp.choices[0].finish_reason == "length"
79+
else VlmStopReason.END_OF_SEQUENCE
80+
)
4881

49-
headers = headers or {}
50-
51-
r = requests.post(
52-
str(url),
53-
headers=headers,
54-
json=payload,
55-
timeout=timeout,
56-
)
57-
if not r.ok:
58-
_log.error(f"Error calling the API. Response was {r.text}")
59-
r.raise_for_status()
60-
61-
api_resp = OpenAiApiResponse.model_validate_json(r.text)
62-
generated_text = api_resp.choices[0].message.content.strip()
63-
num_tokens = api_resp.usage.total_tokens
64-
stop_reason = (
65-
VlmStopReason.LENGTH
66-
if api_resp.choices[0].finish_reason == "length"
67-
else VlmStopReason.END_OF_SEQUENCE
68-
)
69-
70-
return generated_text, num_tokens, stop_reason
82+
return generated_text, num_tokens,
83+
except Exception as e:
84+
_log.error(f"Error, could not process request: {e}")
85+
return "", 0, "bad request"
86+
else:
87+
return "", 0, "bad image"
7188

7289

7390
def api_image_request_streaming(

docs/examples/post_process_ocr_with_vlm.py

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import logging
33
import os
4+
import re
45
from collections.abc import Iterable
56
from concurrent.futures import ThreadPoolExecutor
67
from pathlib import Path
@@ -17,6 +18,7 @@
1718
DocItem,
1819
GraphCell,
1920
KeyValueItem,
21+
FormItem,
2022
PictureItem,
2123
RichTableCell,
2224
TableCell,
@@ -54,10 +56,29 @@
5456
LM_STUDIO_MODEL = "nanonets-ocr2-3b"
5557

5658
DEFAULT_PROMPT = "Extract the text from the above document as if you were reading it naturally. Output pure text, no html and no markdown. Pay attention on line breaks and don't miss text after line break. Put all text in one line."
57-
VERBOSE = False
59+
VERBOSE = True
5860
SHOW_IMAGE = False
5961

6062

63+
def safe_crop(img: "Image.Image", bbox) -> "Image.Image":
    """Crop ``img`` to ``bbox`` after forcing the box inside the image.

    Args:
        img: Image to crop. Only its ``width``, ``height`` and ``crop``
            members are used.
        bbox: Four values ``(left, top, right, bottom)`` in the pixel
            coordinates of ``img``.

    Returns:
        The result of ``img.crop`` for the normalized, clamped box. The box
        may be empty (zero width or height) when ``bbox`` lies entirely
        outside the image.
    """
    x0, y0, x1, y1 = bbox
    # Order each pair of edges before clamping: a box with right < left or
    # bottom < top would otherwise be handed to ``crop`` unchanged, which
    # recent Pillow versions reject with an error.
    left, right = min(x0, x1), max(x0, x1)
    top, bottom = min(y0, y1), max(y0, y1)
    # Clamp to image boundaries
    left = max(0, min(left, img.width))
    top = max(0, min(top, img.height))
    right = max(0, min(right, img.width))
    bottom = max(0, min(bottom, img.height))
    return img.crop((left, top, right, bottom))
71+
72+
73+
def no_long_repeats(s: str, threshold: int) -> bool:
    """
    Return False if the string `s` contains more than `threshold`
    identical characters in a row, otherwise True.

    Every character counts, including newlines, so a long run of blank
    lines is reported the same way as a long run of any other character.
    An empty string has no runs and yields True.
    """
    # One captured character followed by at least `threshold` copies of
    # itself, i.e. a run of threshold + 1 or more identical characters.
    pattern = r"(.)\1{" + str(threshold) + r",}"
    # DOTALL makes "." match "\n" as well; without it runs of newlines
    # would never be detected.
    return re.search(pattern, s, flags=re.DOTALL) is None
80+
81+
6182
class PostOcrEnrichmentElement(BaseModel):
6283
model_config = ConfigDict(arbitrary_types_allowed=True)
6384

@@ -136,7 +157,7 @@ def prepare_element(
136157
allowed = (DocItem, TableItem, GraphCell)
137158
assert isinstance(element, allowed)
138159

139-
if isinstance(element, KeyValueItem):
160+
if isinstance(element, (KeyValueItem, FormItem)):
140161
# Yield from the graphCells inside here.
141162
result = []
142163
for c in element.graph.cells:
@@ -164,6 +185,9 @@ def prepare_element(
164185
cropped_image = conv_res.document.pages[
165186
page_ix
166187
].image.pil_image.crop(expanded_bbox.as_tuple())
188+
189+
# cropped_image = safe_crop(conv_res.document.pages[page_ix].image.pil_image, expanded_bbox.as_tuple())
190+
167191
# cropped_image.show()
168192
result.append(
169193
PostOcrEnrichmentElement(item=c, image=[cropped_image])
@@ -202,6 +226,8 @@ def prepare_element(
202226
cropped_image = conv_res.document.pages[
203227
page_ix
204228
].image.pil_image.crop(expanded_bbox.as_tuple())
229+
230+
# cropped_image = safe_crop(conv_res.document.pages[page_ix].image.pil_image, expanded_bbox.as_tuple())
205231
# cropped_image.show()
206232
result.append(
207233
PostOcrEnrichmentElement(
@@ -234,15 +260,27 @@ def prepare_element(
234260
):
235261
good_bbox = False
236262

237-
if good_bbox:
238-
cropped_image = conv_res.document.pages[
239-
page_ix
240-
].image.pil_image.crop(expanded_bbox.as_tuple())
241-
multiple_crops.append(cropped_image)
242-
# cropped_image.show()
263+
if hasattr(element, "text"):
264+
if good_bbox:
265+
cropped_image = conv_res.document.pages[
266+
page_ix
267+
].image.pil_image.crop(expanded_bbox.as_tuple())
268+
# cropped_image = safe_crop(conv_res.document.pages[page_ix].image.pil_image, expanded_bbox.as_tuple())
269+
270+
multiple_crops.append(cropped_image)
271+
print("")
272+
print("cropped image size: {}".format(cropped_image.size))
273+
print(type(element))
274+
if hasattr(element, "text"):
275+
print("OLD TEXT: {}".format(element.text))
276+
# cropped_image.show()
277+
else:
278+
print("Not a text element")
243279
if len(multiple_crops) > 0:
280+
# good crops
244281
return [PostOcrEnrichmentElement(item=element, image=multiple_crops)]
245282
else:
283+
# nothing
246284
return []
247285

248286
@classmethod
@@ -260,8 +298,9 @@ def __init__(
260298
):
261299
self.enabled = enabled
262300
self.options = options
263-
self.concurrency = 4
301+
self.concurrency = 2
264302
self.expansion_factor = 0.05
303+
# self.expansion_factor = 0.0
265304
self.elements_batch_size = 4
266305
self._accelerator_options = accelerator_options
267306
self._artifacts_path = (
@@ -282,7 +321,8 @@ def _api_request(image: Image.Image) -> str:
282321
image=image,
283322
prompt=self.options.prompt,
284323
url=self.options.url,
285-
timeout=self.options.timeout,
324+
# timeout=self.options.timeout,
325+
timeout=30,
286326
headers=self.options.headers,
287327
**self.options.params,
288328
)
@@ -343,36 +383,36 @@ def clean_html_tags(text):
343383
return text
344384

345385
output = clean_html_tags(output).strip()
346-
347-
if VERBOSE:
348-
if isinstance(item, (TextItem)):
349-
print(f"OLD TEXT: {item.text}")
350-
351-
# Re-populate text
352-
if isinstance(item, (TextItem, GraphCell)):
353-
if img_ind > 0:
354-
# Concat texts across several provenances
355-
item.text += " " + output
356-
item.orig += " " + output
357-
else:
386+
if no_long_repeats(output, 50):
387+
if VERBOSE:
388+
if isinstance(item, (TextItem)):
389+
print(f"OLD TEXT: {item.text}")
390+
391+
# Re-populate text
392+
if isinstance(item, (TextItem, GraphCell)):
393+
if img_ind > 0:
394+
# Concat texts across several provenances
395+
item.text += " " + output
396+
item.orig += " " + output
397+
else:
398+
item.text = output
399+
item.orig = output
400+
elif isinstance(item, (TableCell, RichTableCell)):
358401
item.text = output
359-
item.orig = output
360-
elif isinstance(item, (TableCell, RichTableCell)):
361-
item.text = output
362-
elif isinstance(item, PictureItem):
363-
pass
364-
else:
365-
raise ValueError(f"Unknown item type: {type(item)}")
402+
elif isinstance(item, PictureItem):
403+
pass
404+
else:
405+
raise ValueError(f"Unknown item type: {type(item)}")
366406

367-
if VERBOSE:
368-
if isinstance(item, (TextItem)):
369-
print(f"NEW TEXT: {item.text}")
407+
if VERBOSE:
408+
if isinstance(item, (TextItem)):
409+
print(f"NEW TEXT: {item.text}")
370410

371-
# Take care of charspans for relevant types
372-
if isinstance(item, GraphCell):
373-
item.prov.charspan = (0, len(item.text))
374-
elif isinstance(item, TextItem):
375-
item.prov[0].charspan = (0, len(item.text))
411+
# Take care of charspans for relevant types
412+
if isinstance(item, GraphCell):
413+
item.prov.charspan = (0, len(item.text))
414+
elif isinstance(item, TextItem):
415+
item.prov[0].charspan = (0, len(item.text))
376416

377417
yield item
378418

@@ -382,7 +422,8 @@ def convert_pdf(pdf_path: Path, out_intermediate_json: Path):
382422
pipeline_options = PdfPipelineOptions()
383423
pipeline_options.generate_page_images = True
384424
pipeline_options.generate_picture_images = True
385-
pipeline_options.images_scale = 4.0
425+
# pipeline_options.images_scale = 4.0
426+
pipeline_options.images_scale = 2.0
386427

387428
doc_converter = (
388429
DocumentConverter( # all of the below is optional, has internal defaults.
@@ -424,6 +465,7 @@ def post_process_json(in_json: Path, out_final_json: Path):
424465
)
425466
)
426467

468+
# try:
427469
doc_converter = DocumentConverter(
428470
format_options={
429471
InputFormat.JSON_DOCLING: FormatOption(
@@ -440,6 +482,8 @@ def post_process_json(in_json: Path, out_final_json: Path):
440482
md = result.document.export_to_markdown()
441483
print("*** MARKDOWN ***")
442484
print(md)
485+
# except:
486+
# print("ERROR IN OCR for: {}".format(in_json))
443487

444488

445489
def process_pdf(pdf_path: Path, scratch_dir: Path, out_dir: Path):

0 commit comments

Comments
 (0)