Skip to content

Commit 8fadca8

Browse files
authored
Add --ignore-embedded-text (#47)
This PR adds the `--ignore-embedded-text` flag to the `scan` command so that it is possible to ignore embedded text in document types that support it (e.g. PDFs). The motivation for this feature is that some PDFs have OCR results embedded that are low quality and should be ignored. The addition to `load_document` is tested via: ``` pytest -sv tests/test_end_to_end.py::test_run_with_ignore_embedded_text ``` Notably, the answer with Tesseract only is actually incorrect, because it _seems_ Tesseract is missing entire columns.
1 parent 13c127b commit 8fadca8

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

src/docquery/cmd/scan.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ def build_parser(subparsers, parent_parser):
2727
parser.add_argument(
2828
"--ocr", choices=list(OCR_MAPPING.keys()), default=None, help="The OCR engine you would like to use"
2929
)
30+
parser.add_argument(
31+
"--ignore-embedded-text",
32+
dest="use_embedded_text",
33+
action="store_false",
34+
help="Do not try and extract embedded text from document types that might provide it (e.g. PDFs)",
35+
)
3036
parser.add_argument(
3137
"--classify",
3238
default=False,
@@ -58,7 +64,7 @@ def main(args):
5864
for p in paths:
5965
try:
6066
log.info(f"Loading {p}")
61-
docs.append((p, load_document(str(p), ocr_reader=args.ocr)))
67+
docs.append((p, load_document(str(p), ocr_reader=args.ocr, use_embedded_text=args.use_embedded_text)))
6268
except UnsupportedDocument as e:
6369
log.warning(f"Cannot load {p}: {e}. Skipping...")
6470

src/docquery/document.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ def _generate_document_output(
101101

102102

103103
class PDFDocument(Document):
104-
def __init__(self, b, ocr_reader, **kwargs):
104+
def __init__(self, b, ocr_reader, use_embedded_text, **kwargs):
105105
self.b = b
106106
self.ocr_reader = ocr_reader
107+
self.use_embedded_text = use_embedded_text
107108

108109
super().__init__(**kwargs)
109110

@@ -125,7 +126,7 @@ def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
125126
boxes_by_page = []
126127
dimensions_by_page = []
127128
for i, page in enumerate(pdf.pages):
128-
extracted_words = page.extract_words()
129+
extracted_words = page.extract_words() if self.use_embedded_text else []
129130

130131
if len(extracted_words) == 0:
131132
words, boxes = self.ocr_reader.apply_ocr(images[i])
@@ -234,7 +235,7 @@ def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
234235

235236

236237
@validate_arguments
237-
def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None):
238+
def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None, use_embedded_text=True):
238239
base_path = os.path.basename(fpath).split("?")[0].strip()
239240
doc_type = mimetypes.guess_type(base_path)[0]
240241
if fpath.startswith("http://") or fpath.startswith("https://"):
@@ -255,7 +256,7 @@ def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None
255256
raise NoOCRReaderFound(f"{ocr_reader} is not a supported OCRReader class")
256257

257258
if doc_type == "application/pdf":
258-
return PDFDocument(b.read(), ocr_reader=ocr_reader)
259+
return PDFDocument(b.read(), ocr_reader=ocr_reader, use_embedded_text=use_embedded_text)
259260
elif doc_type == "text/html":
260261
return WebDocument(fpath)
261262
else:

tests/test_end_to_end.py

+14
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ class Example(BaseModel):
6767
"question": "What are net sales for 2020?",
6868
"answers": {
6969
"LayoutLMv1": [{"score": 0.9429, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}],
70+
# (The answer with `use_embedded_text=False` relies entirely on Tesseract, and it is incorrect because it
71+
# misses 3,750 altogether.)
72+
"LayoutLMv1__use_embedded_text=False": [
73+
{"score": 0.3078, "answer": "$ 3,980", "word_ids": [11, 12], "page": 0}
74+
],
7075
"LayoutLMv1-Invoices": [{"score": 0.9956, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}],
7176
"Donut": [{"answer": "$ 3,750"}],
7277
},
@@ -132,3 +137,12 @@ def test_run_with_choosen_OCR_instance():
132137
for qa in example.qa_pairs:
133138
resp = pipe(question=qa.question, **document.context, top_k=1)
134139
assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1"]
140+
141+
142+
def test_run_with_ignore_embedded_text():
143+
example = EXAMPLES[2]
144+
document = load_document(example.path, use_embedded_text=False)
145+
pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"])
146+
for qa in example.qa_pairs:
147+
resp = pipe(question=qa.question, **document.context, top_k=1)
148+
assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1__use_embedded_text=False"]

0 commit comments

Comments
 (0)