base.py
import re, os, signal
import praw, psycopg2
import requests
import xml.etree.ElementTree as ET
from sys import exit
from datetime import datetime, timezone, timedelta
from time import sleep
from contextlib import closing


class APIProcess(object):
    base36_pattern = '[0-9a-z]+'

    def __init__(self, source_version):
        # reuse requests session for PRAW and external requests (feeds, etc.)
        self.http_session = requests.Session()
        # PRAW reads the env variable 'praw_site' automatically
        self.reddit = praw.Reddit(
            config_interpolation='basic',
            requestor_kwargs={'session': self.http_session}
        )
        # make sure we use the same user-agent on all requests
        self.http_session.headers.update({'user-agent': self.reddit.config.user_agent})
        # check db connectivity and setup tables
        self.db = psycopg2.connect(os.getenv('DATABASE_URL'))
        self.init_db(source_version)

    def _to_short_id(self, full_id):
        """Remove prefix from base36 id"""
        return full_id.split('_').pop()

    def _to_full_id(self, kind, short_id):
        """Add prefix to base36 id"""
        prefix = self.reddit.config.kinds[kind]
        return f'{prefix}_{short_id}'
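
    # For example (illustrative values, not from the original source): with PRAW's
    # default kind map, _to_full_id('submission', 'abc123') returns 't3_abc123',
    # and _to_short_id('t3_abc123') returns 'abc123'.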

    def _exit_handler(self, signum, frame):
        # disconnect db
        self.db.commit()
        self.db.close()
        # end process
        exit(0)

    def init_db(self, source_version):
        with closing(self.db.cursor()) as cur:
            # get current db version (if any)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS kv_store(
                    key VARCHAR(256),
                    value VARCHAR(256),
                    CONSTRAINT kv_store_pkey PRIMARY KEY(key));
                SELECT value FROM kv_store
                WHERE key='version';
            """)
            res = cur.fetchone()
            if res is not None:
                db_version = res[0]
                if db_version != source_version:
                    raise Exception(f"Version mismatch (current: {source_version}; db: {db_version})")
            else:
                # save current db version
                cur.execute("""
                    INSERT INTO kv_store(key, value)
                    VALUES ('version', %s);
                """, (source_version,))
        self.db.commit()

    def setup_interrupt_handlers(self):
        signal.signal(signal.SIGINT, self._exit_handler)
        signal.signal(signal.SIGTERM, self._exit_handler)
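
    # Illustrative call order for a long-running worker built on this class
    # (a sketch, not part of the original source; 'bot-1.0' is a placeholder
    # version string):
    #
    #     proc = APIProcess('bot-1.0')       # connects to the db, checks the stored version
    #     proc.setup_interrupt_handlers()    # SIGINT/SIGTERM commit and close the db
    #     ...                                # work against proc.reddit / proc.db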


class XMLProcess(APIProcess):
    ATOM_NS = {'atom': 'http://www.w3.org/2005/Atom'}

    def __init__(self, source_version, subreddit, path, kind_class):
        # create PRAW instance and db connection
        super().__init__(source_version)
        # feed global params
        self.subreddit, self.path, self.kind_class = subreddit, path, kind_class
        self.kind = self.reddit.config.kinds[self.kind_class.__name__.lower()]
        # private (stateful) vars
        self._last_timestamp = None
        self._after_full_id = None
        self._db_key = f'{self.subreddit}_{self.path}_after_full_id'

    def _exit_handler(self, signum, frame):
        if self._after_full_id is not None:
            # save last id seen for this path
            with closing(self.db.cursor()) as cur:
                cur.execute("""
                    INSERT INTO kv_store(key, value)
                    VALUES(%s, %s)
                    ON CONFLICT ON CONSTRAINT kv_store_pkey
                    DO UPDATE SET value=%s WHERE kv_store.key=%s;
                """, (self._db_key, self._after_full_id,
                      self._after_full_id, self._db_key))
            self.db.commit()
        # terminate process
        super()._exit_handler(signum, frame)

    def _query_feed(self, **query):
        """Query the subreddit feed with the given params, return a list of entries from oldest to newest"""
        feed_url = f'https://www.reddit.com/r/{self.subreddit}/{self.path}/.rss'
        # feeds seem to be limited to 100 results per page
        query.setdefault('limit', 100)
        # the error message from the feed says not to request more than once every
        # two seconds and has a retry counter, but it doesn't seem reliable
        delay, max_delay = 15, 120
        while True:
            if self._last_timestamp is not None:
                wait = (self._last_timestamp + timedelta(seconds=delay) - datetime.utcnow()).total_seconds()
                if wait > 0:
                    #XXX log
                    print(f'@@@ wait {wait}s')
                    sleep(wait)
            response = self.http_session.get(feed_url, params=query)
            self._last_timestamp = datetime.utcnow()
            if response.text.startswith('<!doctype html>'):
                #XXX warning
                print(f'@@@ html doctype')
                print(response.text)
            else:
                root_elem = ET.fromstring(response.text)
                entries = root_elem.findall('atom:entry', self.ATOM_NS)
                if len(entries) > 0:
                    break
            delay = min(2 * delay, max_delay)
        result = []
        for entry_elem in entries:
            author_elem = entry_elem.find('atom:author', self.ATOM_NS)
            author = {
                'name': author_elem.find('atom:name', self.ATOM_NS).text[3:],
                'uri': author_elem.find('atom:uri', self.ATOM_NS).text,
            }
            category_elem = entry_elem.find('atom:category', self.ATOM_NS)
            category = category_elem.attrib
            content = entry_elem.find('atom:content', self.ATOM_NS).text
            full_id = entry_elem.find('atom:id', self.ATOM_NS).text
            link = entry_elem.find('atom:link', self.ATOM_NS).attrib['href']
            updated = entry_elem.find('atom:updated', self.ATOM_NS).text
            updated_dt = datetime.fromisoformat(updated).astimezone(timezone.utc)
            title = entry_elem.find('atom:title', self.ATOM_NS).text
            result.append({
                'author_name': author['name'],
                'author_uri': author['uri'],
                'category': category,
                'content': content,
                'id': full_id,
                'link': link,
                'updated': updated_dt,
                'title': title,
            })
        result.reverse()
        return result
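
    # Each dict returned by _query_feed() has the keys 'author_name', 'author_uri',
    # 'category', 'content', 'id' (the full id, e.g. 't3_...' for submissions),
    # 'link', 'updated' (timezone-aware UTC datetime) and 'title'.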

    def get_last_entry(self):
        """Retrieve the newest submission from the subreddit"""
        entries = self._query_feed(limit=1)
        self._after_full_id = entries[0]['id']

    def iter_entries(self, after=None, reset=False):
        """Infinite generator that yields entry dicts in the order they were published"""
        if reset:
            self.get_last_entry()
        elif after is None:
            # check if we've stored the id on the previous run
            with closing(self.db.cursor()) as cur:
                cur.execute("""
                    SELECT value FROM kv_store
                    WHERE key=%s;
                """, (self._db_key,))
                res = cur.fetchone()
            if res is not None:
                self._after_full_id = res[0]
            else:
                self.get_last_entry()
        elif re.fullmatch(self.kind + '_' + self.base36_pattern, after):
            # received full_id
            self._after_full_id = after
        elif re.fullmatch(self.base36_pattern, after):
            # received short_id
            self._after_full_id = f'{self.kind}_{after}'
        else:
            # extract id from submission url or raise ValueError
            after_short_id = self.kind_class.id_from_url(after)
            self._after_full_id = f'{self.kind}_{after_short_id}'
        while True:
            # we'll output from oldest to newest but Reddit shows newest first on its feed
            # in order to retrieve the entries published "after" the given one
            # we need to ask for the ones that appear "before" that one in the feed
            for entry_dict in self._query_feed(before=self._after_full_id):
                yield entry_dict
                # keep track of the id for the next _query_feed() call
                self._after_full_id = entry_dict['id']
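

# Illustrative usage (a sketch, not part of the original module): stream new
# submissions from a subreddit, resuming from the id saved in kv_store by a
# previous run. 'bot-1.0' and 'AskReddit' are placeholder values; Submission
# is praw.models.Submission, whose id_from_url() lets iter_entries() accept a
# plain submission URL as the 'after' argument.
#
#     from praw.models import Submission
#
#     proc = XMLProcess('bot-1.0', 'AskReddit', 'new', Submission)
#     proc.setup_interrupt_handlers()
#     for entry in proc.iter_entries():
#         print(entry['updated'], entry['title'], entry['link'])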