Skip to content

Commit 59fe92a

Browse files
Merge pull request #155 from NotJoeMartinez/embedding_text_splitter
Embedding text splitter
2 parents d93a95a + 9307e12 commit 59fe92a

File tree

9 files changed

+34
-153
lines changed

9 files changed

+34
-153
lines changed

Diff for: .gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,5 @@ UCYO_jab_esuFRV4b17AJtAw
173173
.ignore
174174
.ignore/
175175
tests/test_data/
176-
.idea
176+
.idea
177+
*.sh

Diff for: README.md

+6-3
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,25 @@ This requires an OpenAI API key set in the environment variable `OPENAI_API_KEY`
9696
you can pass the key with the `--openai-api-key` flag.
9797

9898

99-
## `get-embedings`
99+
## `embeddings`
100100
Fetches OpenAI embeddings for specified channel
101101
```bash
102102

103103
# make sure openAI key is set
104104
# export OPENAI_API_KEY="[yourOpenAIKey]"
105105

106-
yt-fts get-embeddings --channel "3Blue1Brown"
106+
yt-fts embeddings --channel "3Blue1Brown"
107+
108+
# specify time interval in seconds to split text by default is 10
109+
yt-fts embeddings --interval 60 --channel "3Blue1Brown"
107110
```
108111

109112
After the embeddings are saved you will see a `(ss)` next to the channel name when you
110113
list channels and you will be able to use the `vsearch` command for that channel.
111114

112115
## `vsearch` (Semantic Search)
113116
`vsearch` is for "Vector search". This requires that you enable semantic
114-
search for a channel with `get-embeddings`. It has the same options as
117+
search for a channel with `embeddings`. It has the same options as
115118
`search` but output will be sorted by similarity to the search string and
116119
the default return limit is 10.
117120

Diff for: tests/basic.sh

-134
This file was deleted.

Diff for: tests/view_chromadb.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import chromadb
22
import sys
33
from openai import OpenAI
4-
from yt_fts.embeddings import get_embedding
4+
from yt_fts.get_embeddings import get_embedding
55
from yt_fts.config import get_or_make_chroma_path
66
from yt_fts.utils import time_to_secs
77
from yt_fts.db_utils import get_channel_name_from_video_id, get_title_from_db
File renamed without changes.

Diff for: yt_fts/list.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def list_channels(channel_id=None):
110110
console.print("")
111111

112112

113-
# not dry but for some reason importing from embeddings.py causes slow down
113+
# not dry but for some reason importing from get_embeddings.py causes slow down
114114
def check_ss_enabled(channel_id=None):
115115
from yt_fts.config import get_db_path
116116

Diff for: yt_fts/utils.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def enable_ss(channel_id):
145145
con.close()
146146

147147

148-
def split_subtitles(video_id):
148+
def split_subtitles(video_id, interval=60):
149149
from datetime import datetime
150150
from .db_utils import get_subs_by_video_id
151151

@@ -172,8 +172,13 @@ def time_to_seconds(time_str):
172172

173173
interval_texts = {}
174174
for start, start_time_str, text in converted_data:
175-
interval = int(start // 10) * 10
176-
key = interval_texts.setdefault(interval, {'start_time': start_time_str, 'texts': []})
175+
split_interval = int(start // interval) * interval
176+
177+
key = interval_texts.setdefault(split_interval, {
178+
'start_time': start_time_str,
179+
'texts': []
180+
})
181+
177182
key['texts'].append(text)
178183

179184
result = [(data['start_time'], ' '.join(data['texts']).strip()) for data in interval_texts.values()]

Diff for: yt_fts/vector_search.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlite_utils import Database
55

66
from .utils import time_to_secs, bold_query_matches
7-
from .embeddings import get_embedding
7+
from .get_embeddings import get_embedding
88
from .config import get_chroma_client
99
from .db_utils import (
1010
get_channel_name_from_video_id,

Diff for: yt_fts/yt_fts.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def delete(channel):
166166
console.print("[bold]Are you sure you want to delete this channel and all its data?[/bold]")
167167
confirm = input("(Y/n): ")
168168

169-
if confirm == "y":
169+
if confirm.lower() == "y":
170170
delete_channel(channel_id)
171171
print(f"Deleted channel {channel_name}: {channel_url}")
172172
else:
@@ -303,17 +303,23 @@ def vsearch(text, channel, video, limit, export, openai_api_key):
303303
@cli.command(
304304
help="""
305305
Generate embeddings for a channel using OpenAI's embeddings API.
306-
307306
Requires an OpenAI API key to be set as an environment variable OPENAI_API_KEY.
308307
"""
309308
)
310-
@click.option("-c", "--channel", default=None, help="The name or id of the channel to generate embeddings for")
311-
@click.option("--openai-api-key", default=None,
312-
help="OpenAI API key. If not provided, the script will attempt to read it from the OPENAI_API_KEY "
313-
"environment variable.")
314-
def get_embeddings(channel, openai_api_key):
309+
@click.option("-c", "--channel",
310+
default=None,
311+
help="The name or id of the channel to generate embeddings for")
312+
@click.option("--openai-api-key",
313+
default=None,
314+
help="OpenAI API key. If not provided, the script will attempt to read it from"
315+
" the OPENAI_API_KEY environment variable.")
316+
@click.option("-i", "--interval",
317+
default=10,
318+
type=int,
319+
help="Interval in seconds to split the transcripts into chunks")
320+
def embeddings(channel, openai_api_key, interval=10):
315321
from yt_fts.db_utils import get_vid_ids_by_channel_id
316-
from yt_fts.embeddings import add_embeddings_to_chroma
322+
from yt_fts.get_embeddings import add_embeddings_to_chroma
317323
from yt_fts.utils import split_subtitles, check_ss_enabled, enable_ss
318324
from openai import OpenAI
319325

@@ -342,7 +348,7 @@ def get_embeddings(channel, openai_api_key):
342348

343349
channel_subs = []
344350
for vid_id in channel_video_ids:
345-
split_subs = split_subtitles(vid_id[0])
351+
split_subs = split_subtitles(vid_id[0], interval=interval)
346352
if split_subs is None:
347353
continue
348354
for sub in split_subs:

0 commit comments

Comments
 (0)