Commit 6c59219

Merge pull request #156 from NotJoeMartinez/llm_rag_integration
Llm rag integration
2 parents 59fe92a + 28cb5bb · commit 6c59219

3 files changed: +285 -1 lines changed

tests/test_download.py (+18 -1)
@@ -7,6 +7,16 @@

 CONFIG_DIR = os.path.expanduser('~/.config/yt-fts')

+
+@pytest.fixture(scope="session", autouse=True)
+def cleanup_after_tests():
+    yield
+    if os.path.exists(CONFIG_DIR):
+        shutil.rmtree(CONFIG_DIR)
+    if os.path.exists(f"{CONFIG_DIR}_backup"):
+        shutil.move(f"{CONFIG_DIR}_backup", CONFIG_DIR)
+
+
 @pytest.fixture
 def runner():
     return CliRunner()
@@ -15,7 +25,13 @@ def runner():
 def reset_testing_env():
     if os.path.exists(CONFIG_DIR):
         if os.environ.get('YT_FTS_TEST_RESET', 'true').lower() == 'true':
+
+            if os.path.exists(CONFIG_DIR):
+                if not os.path.exists(f"{CONFIG_DIR}_backup"):
+                    shutil.copytree(CONFIG_DIR, f"{CONFIG_DIR}_backup")
+
             shutil.rmtree(CONFIG_DIR)
+
         else:
             print('running tests with existing db')

@@ -100,6 +116,7 @@ def test_playlist_download(runner, capsys):
     assert subtitle_count >= 20970, f"Expected 20970 subtitles, but got {subtitle_count}"


+
+
 if __name__ == "__main__":
     pytest.main([__file__])
-

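These test changes protect a developer's real ~/.config/yt-fts: with YT_FTS_TEST_RESET left at its default of 'true', reset_testing_env() copies the directory to ~/.config/yt-fts_backup before wiping it, and the session-scoped, autouse cleanup_after_tests fixture moves the backup back into place once the whole suite has finished. As a rough, self-contained illustration of the same backup/restore pattern (the app name and fixture name below are made up for the example, not part of yt-fts):

    # Sketch of a session-scoped backup/restore fixture; "example-app" and the
    # fixture name are illustrative only.
    import os
    import shutil
    import pytest

    CONFIG_DIR = os.path.expanduser('~/.config/example-app')

    @pytest.fixture(scope="session", autouse=True)
    def preserve_user_config():
        # Back up any real config before the test session touches it.
        if os.path.exists(CONFIG_DIR) and not os.path.exists(f"{CONFIG_DIR}_backup"):
            shutil.copytree(CONFIG_DIR, f"{CONFIG_DIR}_backup")
        yield
        # Restore it after the last test, discarding whatever the tests wrote.
        if os.path.exists(CONFIG_DIR):
            shutil.rmtree(CONFIG_DIR)
        if os.path.exists(f"{CONFIG_DIR}_backup"):
            shutil.move(f"{CONFIG_DIR}_backup", CONFIG_DIR)
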
yt_fts/llm.py (new file, +230)
@@ -0,0 +1,230 @@
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+from rich.prompt import Prompt
+from rich.text import Text
+import textwrap
+import sys
+import traceback
+from openai import OpenAI
+from .db_utils import (
+    get_channel_id_from_input,
+    get_channel_name_from_video_id,
+    get_title_from_db
+)
+from .get_embeddings import get_embedding
+from .utils import time_to_secs
+from .config import get_chroma_client
+
+class LLMHandler:
+    def __init__(self, openai_api_key: str, channel: str):
+        self.openai_client = OpenAI(api_key=openai_api_key)
+        self.channel_id = get_channel_id_from_input(channel)
+        self.chroma_client = get_chroma_client()
+        self.console = Console()
+        self.max_width = 80
+
+    def init_llm(self, prompt: str):
+        messages = self.start_llm(prompt)
+        self.display_message(messages[-1]["content"], "assistant")
+
+        while True:
+            user_input = Prompt.ask("> ")
+            if user_input.lower() == "exit":
+                self.console.print("Goodbye!", style="bold red")
+                sys.exit(0)
+            messages.append({"role": "user", "content": user_input})
+            messages = self.continue_llm(messages)
+            self.display_message(messages[-1]["content"], "assistant")
+
+    def display_message(self, content: str, role: str):
+        if role == "assistant":
+            wrapped_content = self.wrap_text(content)
+            md = Markdown(wrapped_content)
+            # self.console.print(Panel(md, expand=False, border_style="green"))
+            self.console.print(md)
+        else:
+            wrapped_content = self.wrap_text(content)
+            self.console.print(Text(wrapped_content, style="bold blue"))
+
+    def wrap_text(self, text: str) -> str:
+        lines = text.split('\n')
+        wrapped_lines = []
+
+        for line in lines:
+            # If the line is a code block, don't wrap it
+            if line.strip().startswith('```') or line.strip().startswith('`'):
+                wrapped_lines.append(line)
+            else:
+                # Wrap the line
+                wrapped = textwrap.wrap(line, width=self.max_width, break_long_words=False, replace_whitespace=False)
+                wrapped_lines.extend(wrapped)
+
+
+        # Join the wrapped lines back together
+        return " \n".join(wrapped_lines)
+
+
+    def start_llm(self, prompt: str) -> list:
+        try:
+            context = self.create_context(prompt)
+            user_str = f"Context: {context}\n\n---\n\nQuestion: {prompt}\nAnswer:"
+            system_prompt = """
+            Answer the question based on the context below, The context are
+            subtitles and timestamped links from videos related to the question.
+            In your answer, provide the link to the video where the answer can
+            be found. and if the question can't be answered based on the context,
+            say \"I don't know\" AND ONLY I don't know\n\n
+            """
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_str},
+            ]
+
+            response_text = self.get_completion(messages)
+
+            if "i don't know" in response_text.lower():
+                expanded_query = self.get_expand_context_query(messages)
+                expanded_context = self.create_context(expanded_query)
+                messages.append({
+                    "role": "user",
+                    "content": f"Okay here is some more context:\n---\n\n{expanded_context}\n\n---"
+                })
+                response_text = self.get_completion(messages)
+
+            messages.append({
+                "role": "assistant",
+                "content": response_text
+            })
+            return messages
+
+        except Exception as e:
+            self.display_error(e)
+
+    def continue_llm(self, messages: list) -> list:
+        try:
+            response_text = self.get_completion(messages)
+
+            if "i don't know" in response_text.lower():
+                expanded_query = self.get_expand_context_query(messages)
+                self.console.print(f"[italic]Expanding context with query: {expanded_query}[/italic]")
+                expanded_context = self.create_context(expanded_query)
+                messages.append({
+                    "role": "user",
+                    "content": f"Okay here is some more context:\n---\n\n{expanded_context}\n\n---"
+                })
+                response_text = self.get_completion(messages)
+
+            messages.append({
+                "role": "assistant",
+                "content": response_text
+            })
+            return messages
+
+        except Exception as e:
+            self.display_error(e)
+
+    def display_error(self, error: Exception):
+        self.console.print(Panel(str(error), title="Error", border_style="red"))
+        traceback.print_exc()
+        sys.exit(1)
+
+    def create_context(self, text: str) -> str:
+        collection = self.chroma_client.get_collection(name="subEmbeddings")
+        search_embedding = get_embedding(text, "text-embedding-ada-002", self.openai_client)
+        scope_options = {"channel_id": self.channel_id}
+
+        chroma_res = collection.query(
+            query_embeddings=[search_embedding],
+            n_results=10,
+            where=scope_options,
+        )
+
+        documents = chroma_res["documents"][0]
+        metadata = chroma_res["metadatas"][0]
+        distances = chroma_res["distances"][0]
+
+        res = []
+        for i in range(len(documents)):
+            text = documents[i]
+            video_id = metadata[i]["video_id"]
+            start_time = metadata[i]["start_time"]
+            link = f"https://youtu.be/{video_id}?t={time_to_secs(start_time)}"
+            channel_name = get_channel_name_from_video_id(video_id)
+            channel_id = metadata[i]["channel_id"]
+            title = get_title_from_db(video_id)
+
+            match = {
+                "distance": distances[i],
+                "channel_name": channel_name,
+                "channel_id": channel_id,
+                "video_title": title,
+                "subs": text,
+                "start_time": start_time,
+                "video_id": video_id,
+                "link": link,
+            }
+            res.append(match)
+
+        return self.format_context(res)
+
+    def get_expand_context_query(self, messages: list) -> str:
+        try:
+            system_prompt = """
+            Your task is to generate a question to input into a vector search
+            engine of youtube subitles to find strings that can answer the question
+            asked in the previous message.
+            """
+            formatted_context = self.format_message_history_context(messages)
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": formatted_context},
+            ]
+
+            return self.get_completion(messages)
+
+        except Exception as e:
+            self.display_error(e)
+
+    def get_completion(self, messages: list) -> str:
+        try:
+            response = self.openai_client.chat.completions.create(
+                model="gpt-4",
+                messages=messages,
+                temperature=0,
+                max_tokens=2000,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                stop=None,
+            )
+            return response.choices[0].message.content
+
+        except Exception as e:
+            self.display_error(e)
+
+    @staticmethod
+    def format_message_history_context(messages: list) -> str:
+        formatted_context = ""
+        for message in messages:
+            if message["role"] == "system":
+                continue
+            role = message["role"]
+            content = message["content"]
+            formatted_context += f"{role}: {content}\n"
+        return formatted_context
+
+    @staticmethod
+    def format_context(chroma_res: list) -> str:
+        formatted_context = ""
+        for obj in chroma_res:
+            tmp = f"""
+            Video Title: {obj["video_title"]}
+            Text: {obj["subs"]}
+            Time: {obj["start_time"]}
+            Similarity: {obj["distance"]}
+            Link: {obj["link"]}
+            -------------------------
+            """
+            formatted_context += tmp
+        return formatted_context

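yt_fts/llm.py wires the retrieval-augmented chat loop together: create_context() embeds the query with text-embedding-ada-002, pulls the ten nearest subtitle chunks for the channel from the "subEmbeddings" Chroma collection, and formats them with titles and timestamped links; get_completion() sends the running message list to gpt-4; and whenever the model answers "I don't know", get_expand_context_query() asks the model for a better search query and retries with the extra context. A minimal, non-interactive sketch of driving the class directly follows; it assumes the channel's subtitles have already been embedded (via the existing embeddings command) and that OPENAI_API_KEY is set, and the channel name is a placeholder:

    # Hedged sketch: one retrieval-augmented round trip without the chat loop.
    # "Some Channel" is a placeholder; the channel must already have embeddings.
    import os
    from yt_fts.llm import LLMHandler

    handler = LLMHandler(os.environ["OPENAI_API_KEY"], "Some Channel")

    # start_llm() builds the context from Chroma, queries gpt-4, and returns
    # the full message history; the last entry is the assistant's answer.
    messages = handler.start_llm("What does this channel say about unit testing?")
    print(messages[-1]["content"])
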
yt_fts/yt_fts.py (+37)
@@ -365,6 +365,43 @@ def embeddings(channel, openai_api_key, interval=10):
     sys.exit(0)


+@cli.command(
+    name="llm",
+    help="""
+    Interactive LLM chat bot RAG bot, needs to be run on a channel with
+    Embeddings.
+    """
+)
+@click.argument("prompt", required=True)
+@click.option("-c",
+              "--channel",
+              default=None,
+              required=True,
+              help="The name or id of the channel to generate embeddings for")
+@click.option("--openai-api-key",
+              default=None,
+              help="OpenAI API key. If not provided, the script will attempt to read it from"
+                   " the OPENAI_API_KEY environment variable.")
+def llm(prompt, channel, openai_api_key=None):
+    from yt_fts.llm import LLMHandler
+
+    if openai_api_key is None:
+        openai_api_key = os.environ.get("OPENAI_API_KEY")
+
+    if openai_api_key is None:
+        console.print("""
+        [bold][red]Error:[/red][/bold] OPENAI_API_KEY environment variable not set, Run:
+
+        export OPENAI_API_KEY=<your_key> to set the key
+        """)
+        sys.exit(1)
+
+    llm_handler = LLMHandler(openai_api_key, channel)
+    llm_handler.init_llm(prompt)
+
+    sys.exit(0)
+
+
 @cli.command(
     help="""
     Show config settings

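With the command registered, a typical invocation would look roughly like the following; the yt-fts entry-point name and the channel value are assumptions here (the prompt argument and the -c/--channel option come from the definition above), and the channel must already have embeddings:

    export OPENAI_API_KEY=<your_key>
    yt-fts llm --channel "ChannelName" "What does this channel say about X?"

The chat then stays interactive through init_llm() until the user types "exit".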