
Commit 5f07404

pg: add returning_id option to parallel_execute
The default behavior is unchanged. This adds the possibility to parallelize modifying queries that have a `RETURNING id` clause. For those, return the resulting ids (in a defined order) instead of the affected row count. To avoid misuse, add a warning to the docstring and try to detect queries that do not match the intended form, raising an error if any are found.
1 parent 05aae80 commit 5f07404
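
Sketched below is how a caller might use the new flag. This is an illustration only, not part of the commit: the `util` namespace, the open cursor `cr`, and the `res_partner` table/column names are assumptions from a typical upgrade-script context.

# Hypothetical usage sketch; assumes `util` is this package and `cr` is an open cursor.
queries = [
    "UPDATE res_partner SET active = FALSE WHERE id BETWEEN %d AND %d RETURNING id" % bounds
    for bounds in [(1, 1000), (1001, 2000)]
]

# Default call (unchanged behavior): returns the summed `cr.rowcount`.
# count = util.parallel_execute(cr, queries)

# With returning_id=True, the queries still run in parallel, but the affected
# ids come back as a single tuple, sorted in ascending order.
ids = util.parallel_execute(cr, queries, returning_id=True)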

1 file changed: +28 −16 lines
src/util/pg.py

@@ -73,34 +73,34 @@ def savepoint(cr):
         yield


-def _parallel_execute_serial(cr, queries, logger=_logger):
-    cnt = 0
+def _parallel_execute_serial(cr, queries, logger, returning_id):
+    res = [] if returning_id else 0
     for query in log_progress(queries, logger, qualifier="queries", size=len(queries)):
         cr.execute(query)
-        cnt += cr.rowcount
-    return cnt
+        res += cr.fetchall() if returning_id else cr.rowcount
+    return res


 if ThreadPoolExecutor is not None:

-    def _parallel_execute_threaded(cr, queries, logger=_logger):
+    def _parallel_execute_threaded(cr, queries, logger, returning_id):
         if not queries:
             return None

         if len(queries) == 1:
             # No need to spawn other threads
             cr.execute(queries[0])
-            return cr.rowcount
+            return cr.fetchall() if returning_id else cr.rowcount

         max_workers = min(get_max_workers(), len(queries))
         cursor = db_connect(cr.dbname).cursor

         def execute(query):
             with cursor() as tcr:
                 tcr.execute(query)
-                cnt = tcr.rowcount
+                res = tcr.fetchall() if returning_id else tcr.rowcount
                 tcr.commit()
-            return cnt
+            return res

         cr.commit()

@@ -109,7 +109,7 @@ def execute(query):
             errorcodes.SERIALIZATION_FAILURE,
         }
         failed_queries = []
-        tot_cnt = 0
+        tot_res = [] if returning_id else 0
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             future_queries = {executor.submit(execute, q): q for q in queries}
             for future in log_progress(
@@ -121,7 +121,7 @@ def execute(query):
                 log_hundred_percent=True,
             ):
                 try:
-                    tot_cnt += future.result() or 0
+                    tot_res += future.result() or ([] if returning_id else 0)
                 except psycopg2.OperationalError as exc:
                     if exc.pgcode not in CONCURRENCY_ERRORCODES:
                         raise
@@ -131,16 +131,16 @@ def execute(query):

         if failed_queries:
             logger.warning("Serialize queries that failed due to concurrency issues")
-            tot_cnt += _parallel_execute_serial(cr, failed_queries, logger=logger)
+            tot_res += _parallel_execute_serial(cr, failed_queries, logger, returning_id)
             cr.commit()

-        return tot_cnt
+        return tot_res

 else:
     _parallel_execute_threaded = _parallel_execute_serial


-def parallel_execute(cr, queries, logger=_logger):
+def parallel_execute(cr, queries, logger=_logger, returning_id=False):
     """
     Execute queries in parallel.

@@ -154,15 +154,20 @@ def parallel_execute(cr, queries, logger=_logger):

     :param list(str) queries: list of queries to execute concurrently
     :param `~logging.Logger` logger: logger used to report the progress
-    :return: the sum of `cr.rowcount` for each query run
+    :param bool returning_id: whether to return a tuple of affected ids (default: return affected row count)
+    :return: the sum of `cr.rowcount` for each query run, or a joined array of all result tuples if `returning_id`
     :rtype: int

     .. warning::
+        - As a side effect, the cursor will be committed.
+
         - Due to the nature of `cr.rowcount`, the return value of this function may represent an
          underestimate of the real number of affected records. For instance, when some records
          are deleted/updated as a result of an `ondelete` clause, they won't be taken into account.

-        - As a side effect, the cursor will be committed.
+        - It would not be generally safe to use this function for selecting queries. Because of this,
+          `returning_id=True` is only accepted for `UPDATE/DELETE/INSERT/MERGE [...] RETURNING id` queries. Also, the
+          caller cannot influence the order of the returned result tuples; they are always sorted in ascending order.

     .. note::
         If a concurrency issue occurs, the *failing* queries will be retried sequentially.
@@ -172,7 +177,14 @@ def parallel_execute(cr, queries, logger=_logger):
         if getattr(threading.current_thread(), "testing", False)
         else _parallel_execute_threaded
     )
-    return parallel_execute_impl(cr, queries, logger=_logger)
+
+    if returning_id:
+        returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
+        if not all((bool(returning_id_re.search(q)) for q in queries)):
+            raise ValueError("The returning_id parameter can only be used with certain queries.")
+
+    res = parallel_execute_impl(cr, queries, logger, returning_id)
+    return tuple(sorted([id for (id,) in res])) if returning_id else res


 def format_query(cr, query, *args, **kwargs):
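
To make the guard and the result handling concrete, here is a small standalone sketch (not part of the commit). It reuses the `returning_id_re` pattern from the diff on a few made-up query strings, and shows how the fetched 1-tuples are flattened into the sorted tuple of ids.

import re

# Pattern copied from the diff above; the sample queries below are made up.
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")

samples = [
    "UPDATE res_partner SET active = FALSE WHERE id < 10 RETURNING id",    # accepted
    "DELETE FROM res_partner WHERE id < 10 RETURNING res_partner.id",      # accepted (qualified column)
    "SELECT id FROM res_partner",                                          # rejected: not a modifying query
    "UPDATE res_partner SET active = FALSE WHERE id < 10 RETURNING name",  # rejected: must return `id`
]
for q in samples:
    print(bool(returning_id_re.search(q)), q)

# The rows collected via `fetchall()` are 1-tuples; `parallel_execute` flattens
# and sorts them into a single ascending tuple of ids.
rows = [(7,), (3,), (5,)]
print(tuple(sorted(id for (id,) in rows)))  # -> (3, 5, 7)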

0 commit comments
