1
1
import logging
2
2
from abc import ABC , abstractmethod
3
3
from importlib .metadata import PackageNotFoundError , version
4
- from typing import Any
4
+ from typing import Any , Generic , List , TypeVar
5
5
6
+ from attr import dataclass
6
7
from langchain_core .documents import Document
7
8
8
9
from quivr_core .files .file import FileExtension , QuivrFile
11
12
logger = logging .getLogger ("quivr_core" )
12
13
13
14
15
+ R = TypeVar ("R" , covariant = True )
16
+
17
+
18
+ @dataclass
19
+ class ProcessedDocument (Generic [R ]):
20
+ chunks : List [Document ]
21
+ processor_cls : str
22
+ processor_response : R
23
+
24
+
14
25
# TODO: processors should be cached somewhere ?
15
26
# The processor should be cached by processor type
16
27
# The cache should use a single
17
- class ProcessorBase (ABC ):
28
+ class ProcessorBase (ABC , Generic [ R ] ):
18
29
supported_extensions : list [FileExtension | str ]
19
30
20
- def check_supported (self , file : QuivrFile ):
31
+ def check_supported (self , file : QuivrFile ) -> None :
21
32
if file .file_extension not in self .supported_extensions :
22
33
raise ValueError (f"can't process a file of type { file .file_extension } " )
23
34
@@ -26,7 +37,7 @@ def check_supported(self, file: QuivrFile):
26
37
def processor_metadata (self ) -> dict [str , Any ]:
27
38
raise NotImplementedError
28
39
29
- async def process_file (self , file : QuivrFile ) -> list [ Document ]:
40
+ async def process_file (self , file : QuivrFile ) -> ProcessedDocument [ R ]:
30
41
logger .debug (f"Processing file { file } " )
31
42
self .check_supported (file )
32
43
docs = await self .process_file_inner (file )
@@ -35,7 +46,7 @@ async def process_file(self, file: QuivrFile) -> list[Document]:
35
46
except PackageNotFoundError :
36
47
qvr_version = "dev"
37
48
38
- for idx , doc in enumerate (docs , start = 1 ):
49
+ for idx , doc in enumerate (docs . chunks , start = 1 ):
39
50
if "original_file_name" in doc .metadata :
40
51
doc .page_content = f"Filename: { doc .metadata ['original_file_name' ]} Content: { doc .page_content } "
41
52
doc .page_content = doc .page_content .replace ("\u0000 " , "" )
@@ -56,5 +67,5 @@ async def process_file(self, file: QuivrFile) -> list[Document]:
56
67
return docs
57
68
58
69
@abstractmethod
59
- async def process_file_inner (self , file : QuivrFile ) -> list [ Document ]:
70
+ async def process_file_inner (self , file : QuivrFile ) -> ProcessedDocument [ R ]:
60
71
raise NotImplementedError
0 commit comments