@@ -73,34 +73,34 @@ def savepoint(cr):
73
73
yield
74
74
75
75
76
- def _parallel_execute_serial (cr , queries , logger = _logger ):
77
- cnt = 0
76
+ def _parallel_execute_serial (cr , queries , logger , returning_id ):
77
+ res = [] if returning_id else 0
78
78
for query in log_progress (queries , logger , qualifier = "queries" , size = len (queries )):
79
79
cr .execute (query )
80
- cnt += cr .rowcount
81
- return cnt
80
+ res += cr . fetchall () if returning_id else cr .rowcount
81
+ return res
82
82
83
83
84
84
if ThreadPoolExecutor is not None :
85
85
86
- def _parallel_execute_threaded (cr , queries , logger = _logger ):
86
+ def _parallel_execute_threaded (cr , queries , logger , returning_id ):
87
87
if not queries :
88
88
return None
89
89
90
90
if len (queries ) == 1 :
91
91
# No need to spawn other threads
92
92
cr .execute (queries [0 ])
93
- return cr .rowcount
93
+ return cr .fetchall () if returning_id else cr . rowcount
94
94
95
95
max_workers = min (get_max_workers (), len (queries ))
96
96
cursor = db_connect (cr .dbname ).cursor
97
97
98
98
def execute (query ):
99
99
with cursor () as tcr :
100
100
tcr .execute (query )
101
- cnt = tcr .rowcount
101
+ res = tcr . fetchall () if returning_id else tcr .rowcount
102
102
tcr .commit ()
103
- return cnt
103
+ return res
104
104
105
105
cr .commit ()
106
106
@@ -109,7 +109,7 @@ def execute(query):
109
109
errorcodes .SERIALIZATION_FAILURE ,
110
110
}
111
111
failed_queries = []
112
- tot_cnt = 0
112
+ tot_res = [] if returning_id else 0
113
113
with ThreadPoolExecutor (max_workers = max_workers ) as executor :
114
114
future_queries = {executor .submit (execute , q ): q for q in queries }
115
115
for future in log_progress (
@@ -121,7 +121,7 @@ def execute(query):
121
121
log_hundred_percent = True ,
122
122
):
123
123
try :
124
- tot_cnt += future .result () or 0
124
+ tot_res += future .result () or ([] if returning_id else 0 )
125
125
except psycopg2 .OperationalError as exc :
126
126
if exc .pgcode not in CONCURRENCY_ERRORCODES :
127
127
raise
@@ -131,16 +131,16 @@ def execute(query):
131
131
132
132
if failed_queries :
133
133
logger .warning ("Serialize queries that failed due to concurrency issues" )
134
- tot_cnt += _parallel_execute_serial (cr , failed_queries , logger = logger )
134
+ tot_res += _parallel_execute_serial (cr , failed_queries , logger , returning_id )
135
135
cr .commit ()
136
136
137
- return tot_cnt
137
+ return tot_res
138
138
139
139
else :
140
140
_parallel_execute_threaded = _parallel_execute_serial
141
141
142
142
143
- def parallel_execute (cr , queries , logger = _logger ):
143
+ def parallel_execute (cr , queries , logger = _logger , returning_id = False ):
144
144
"""
145
145
Execute queries in parallel.
146
146
@@ -154,15 +154,20 @@ def parallel_execute(cr, queries, logger=_logger):
154
154
155
155
:param list(str) queries: list of queries to execute concurrently
156
156
:param `~logging.Logger` logger: logger used to report the progress
157
- :return: the sum of `cr.rowcount` for each query run
157
+ :param bool returning_id: wether to return a tuple of affected ids (default: return affected row count)
158
+ :return: the sum of `cr.rowcount` for each query run, or a sorted tuple of all affected ids if `returning_id`
158
159
:rtype: int
159
160
160
161
.. warning::
162
+ - As a side effect, the cursor will be committed.
163
+
161
164
- Due to the nature of `cr.rowcount`, the return value of this function may represent an
162
165
underestimate of the real number of affected records. For instance, when some records
163
166
are deleted/updated as a result of an `ondelete` clause, they won't be taken into account.
164
167
165
- - As a side effect, the cursor will be committed.
168
+ - It would not be generally safe to use this function for selecting queries. Because of this,
169
+ `returning_id=True` is only accepted for `UPDATE/DELETE/INSERT/MERGE [...] RETURNING id` queries. Also, the
170
+ caller cannot influence the order of the returned result tuples, it is always sorted in ascending order.
166
171
167
172
.. note::
168
173
If a concurrency issue occurs, the *failing* queries will be retried sequentially.
@@ -172,7 +177,14 @@ def parallel_execute(cr, queries, logger=_logger):
172
177
if getattr (threading .current_thread (), "testing" , False )
173
178
else _parallel_execute_threaded
174
179
)
175
- return parallel_execute_impl (cr , queries , logger = _logger )
180
+
181
+ if returning_id :
182
+ returning_id_re = re .compile (r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$" )
183
+ if not all ((bool (returning_id_re .search (q )) for q in queries )):
184
+ raise ValueError ("The returning_id parameter can only be used with certain queries." )
185
+
186
+ res = parallel_execute_impl (cr , queries , logger , returning_id )
187
+ return tuple (sorted ([id for (id ,) in res ])) if returning_id else res
176
188
177
189
178
190
def format_query (cr , query , * args , ** kwargs ):
0 commit comments