import sys
import textwrap
import traceback

from openai import OpenAI
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt
from rich.text import Text

from .config import get_chroma_client
from .db_utils import (
    get_channel_id_from_input,
    get_channel_name_from_video_id,
    get_title_from_db,
)
from .get_embeddings import get_embedding
from .utils import time_to_secs
+ class LLMHandler :
20
+ def __init__ (self , openai_api_key : str , channel : str ):
21
+ self .openai_client = OpenAI (api_key = openai_api_key )
22
+ self .channel_id = get_channel_id_from_input (channel )
23
+ self .chroma_client = get_chroma_client ()
24
+ self .console = Console ()
25
+ self .max_width = 80
26
+
27
+ def init_llm (self , prompt : str ):
28
+ messages = self .start_llm (prompt )
29
+ self .display_message (messages [- 1 ]["content" ], "assistant" )
30
+
31
+ while True :
32
+ user_input = Prompt .ask ("> " )
33
+ if user_input .lower () == "exit" :
34
+ self .console .print ("Goodbye!" , style = "bold red" )
35
+ sys .exit (0 )
36
+ messages .append ({"role" : "user" , "content" : user_input })
37
+ messages = self .continue_llm (messages )
38
+ self .display_message (messages [- 1 ]["content" ], "assistant" )
39
+
40
+ def display_message (self , content : str , role : str ):
41
+ if role == "assistant" :
42
+ wrapped_content = self .wrap_text (content )
43
+ md = Markdown (wrapped_content )
44
+ # self.console.print(Panel(md, expand=False, border_style="green"))
45
+ self .console .print (md )
46
+ else :
47
+ wrapped_content = self .wrap_text (content )
48
+ self .console .print (Text (wrapped_content , style = "bold blue" ))
49
+
50
+ def wrap_text (self , text : str ) -> str :
51
+ lines = text .split ('\n ' )
52
+ wrapped_lines = []
53
+
54
+ for line in lines :
55
+ # If the line is a code block, don't wrap it
56
+ if line .strip ().startswith ('```' ) or line .strip ().startswith ('`' ):
57
+ wrapped_lines .append (line )
58
+ else :
59
+ # Wrap the line
60
+ wrapped = textwrap .wrap (line , width = self .max_width , break_long_words = False , replace_whitespace = False )
61
+ wrapped_lines .extend (wrapped )
62
+
63
+
64
+ # Join the wrapped lines back together
65
+ return " \n " .join (wrapped_lines )
66
+
67
+
68
+ def start_llm (self , prompt : str ) -> list :
69
+ try :
70
+ context = self .create_context (prompt )
71
+ user_str = f"Context: { context } \n \n ---\n \n Question: { prompt } \n Answer:"
72
+ system_prompt = """
73
+ Answer the question based on the context below, The context are
74
+ subtitles and timestamped links from videos related to the question.
75
+ In your answer, provide the link to the video where the answer can
76
+ be found. and if the question can't be answered based on the context,
77
+ say \" I don't know\" AND ONLY I don't know\n \n
78
+ """
79
+ messages = [
80
+ {"role" : "system" , "content" : system_prompt },
81
+ {"role" : "user" , "content" : user_str },
82
+ ]
83
+
84
+ response_text = self .get_completion (messages )
85
+
86
+ if "i don't know" in response_text .lower ():
87
+ expanded_query = self .get_expand_context_query (messages )
88
+ expanded_context = self .create_context (expanded_query )
89
+ messages .append ({
90
+ "role" : "user" ,
91
+ "content" : f"Okay here is some more context:\n ---\n \n { expanded_context } \n \n ---"
92
+ })
93
+ response_text = self .get_completion (messages )
94
+
95
+ messages .append ({
96
+ "role" : "assistant" ,
97
+ "content" : response_text
98
+ })
99
+ return messages
100
+
101
+ except Exception as e :
102
+ self .display_error (e )
103
+
104
+ def continue_llm (self , messages : list ) -> list :
105
+ try :
106
+ response_text = self .get_completion (messages )
107
+
108
+ if "i don't know" in response_text .lower ():
109
+ expanded_query = self .get_expand_context_query (messages )
110
+ self .console .print (f"[italic]Expanding context with query: { expanded_query } [/italic]" )
111
+ expanded_context = self .create_context (expanded_query )
112
+ messages .append ({
113
+ "role" : "user" ,
114
+ "content" : f"Okay here is some more context:\n ---\n \n { expanded_context } \n \n ---"
115
+ })
116
+ response_text = self .get_completion (messages )
117
+
118
+ messages .append ({
119
+ "role" : "assistant" ,
120
+ "content" : response_text
121
+ })
122
+ return messages
123
+
124
+ except Exception as e :
125
+ self .display_error (e )
126
+
127
+ def display_error (self , error : Exception ):
128
+ self .console .print (Panel (str (error ), title = "Error" , border_style = "red" ))
129
+ traceback .print_exc ()
130
+ sys .exit (1 )
131
+
132
+ def create_context (self , text : str ) -> str :
133
+ collection = self .chroma_client .get_collection (name = "subEmbeddings" )
134
+ search_embedding = get_embedding (text , "text-embedding-ada-002" , self .openai_client )
135
+ scope_options = {"channel_id" : self .channel_id }
136
+
137
+ chroma_res = collection .query (
138
+ query_embeddings = [search_embedding ],
139
+ n_results = 10 ,
140
+ where = scope_options ,
141
+ )
142
+
143
+ documents = chroma_res ["documents" ][0 ]
144
+ metadata = chroma_res ["metadatas" ][0 ]
145
+ distances = chroma_res ["distances" ][0 ]
146
+
147
+ res = []
148
+ for i in range (len (documents )):
149
+ text = documents [i ]
150
+ video_id = metadata [i ]["video_id" ]
151
+ start_time = metadata [i ]["start_time" ]
152
+ link = f"https://youtu.be/{ video_id } ?t={ time_to_secs (start_time )} "
153
+ channel_name = get_channel_name_from_video_id (video_id )
154
+ channel_id = metadata [i ]["channel_id" ]
155
+ title = get_title_from_db (video_id )
156
+
157
+ match = {
158
+ "distance" : distances [i ],
159
+ "channel_name" : channel_name ,
160
+ "channel_id" : channel_id ,
161
+ "video_title" : title ,
162
+ "subs" : text ,
163
+ "start_time" : start_time ,
164
+ "video_id" : video_id ,
165
+ "link" : link ,
166
+ }
167
+ res .append (match )
168
+
169
+ return self .format_context (res )
170
+
171
+ def get_expand_context_query (self , messages : list ) -> str :
172
+ try :
173
+ system_prompt = """
174
+ Your task is to generate a question to input into a vector search
175
+ engine of youtube subitles to find strings that can answer the question
176
+ asked in the previous message.
177
+ """
178
+ formatted_context = self .format_message_history_context (messages )
179
+ messages = [
180
+ {"role" : "system" , "content" : system_prompt },
181
+ {"role" : "user" , "content" : formatted_context },
182
+ ]
183
+
184
+ return self .get_completion (messages )
185
+
186
+ except Exception as e :
187
+ self .display_error (e )
188
+
189
+ def get_completion (self , messages : list ) -> str :
190
+ try :
191
+ response = self .openai_client .chat .completions .create (
192
+ model = "gpt-4" ,
193
+ messages = messages ,
194
+ temperature = 0 ,
195
+ max_tokens = 2000 ,
196
+ top_p = 1 ,
197
+ frequency_penalty = 0 ,
198
+ presence_penalty = 0 ,
199
+ stop = None ,
200
+ )
201
+ return response .choices [0 ].message .content
202
+
203
+ except Exception as e :
204
+ self .display_error (e )
205
+
206
+ @staticmethod
207
+ def format_message_history_context (messages : list ) -> str :
208
+ formatted_context = ""
209
+ for message in messages :
210
+ if message ["role" ] == "system" :
211
+ continue
212
+ role = message ["role" ]
213
+ content = message ["content" ]
214
+ formatted_context += f"{ role } : { content } \n "
215
+ return formatted_context
216
+
217
+ @staticmethod
218
+ def format_context (chroma_res : list ) -> str :
219
+ formatted_context = ""
220
+ for obj in chroma_res :
221
+ tmp = f"""
222
+ Video Title: { obj ["video_title" ]}
223
+ Text: { obj ["subs" ]}
224
+ Time: { obj ["start_time" ]}
225
+ Similarity: { obj ["distance" ]}
226
+ Link: { obj ["link" ]}
227
+ -------------------------
228
+ """
229
+ formatted_context += tmp
230
+ return formatted_context