@@ -191,14 +191,11 @@ function kernel_BTRS!(
191
191
@inbounds n = count[I1]
192
192
@inbounds p = prob[CartesianIndex (I1, I2)]
193
193
end
194
- # BTRS approximations work well for p <= 0.5
195
- # invert p and set `invert` flag
196
- (invert = p > 0.5f0 ) && (p = 1 - p)
197
194
else
198
195
n = 0
199
196
p = 0f0
200
197
end
201
-
198
+
202
199
# SAMPLER
203
200
# edge cases
204
201
if p <= 0 || n <= 0
@@ -213,66 +210,120 @@ function kernel_BTRS!(
213
210
rand (Float32) < p && (k += 1 )
214
211
ctr += 1
215
212
end
216
- # Use inversion algorithm for n*p < 10
217
- elseif n * p < 10f0
218
- logp = CUDA. log (1f0 - p)
219
- geom_sum = 0f0
220
- k = 0
221
- while true
222
- geom = ceil (CUDA. log (rand (Float32)) / logp)
223
- geom_sum += geom
224
- geom_sum > n && break
225
- k += 1
226
- end
227
- # BTRS algorithm
228
- else
229
- r = p/ (1f0 - p)
230
- s = p* (1f0 - p)
231
-
232
- stddev = sqrt (n * s)
233
- b = 1.15f0 + 2.53f0 * stddev
234
- a = - 0.0873f0 + 0.0248f0 * b + 0.01f0 * p
235
- c = n * p + 0.5f0
236
- v_r = 0.92f0 - 4.2f0 / b
237
-
238
- alpha = (2.83f0 + 5.1f0 / b) * stddev;
239
- m = floor ((n + 1 ) * p)
240
-
241
- ks = 0f0
242
-
243
- while true
244
- usample = rand (Float32) - 0.5f0
245
- vsample = rand (Float32)
246
-
247
- us = 0.5f0 - abs (usample)
248
- ks = floor ((2 * a / us + b) * usample + c)
249
-
250
- if us >= 0.07f0 && vsample <= v_r
251
- break
213
+ elseif p <= 0.5f0
214
+ # Use inversion algorithm for n*p < 10
215
+ if n * p < 10f0
216
+ logp = CUDA. log (1f0 - p)
217
+ geom_sum = 0f0
218
+ k = 0
219
+ while true
220
+ geom = ceil (CUDA. log (rand (Float32)) / logp)
221
+ geom_sum += geom
222
+ geom_sum > n && break
223
+ k += 1
252
224
end
225
+ # BTRS algorithm
226
+ else
227
+ r = p/ (1f0 - p)
228
+ s = p* (1f0 - p)
229
+
230
+ stddev = sqrt (n * s)
231
+ b = 1.15f0 + 2.53f0 * stddev
232
+ a = - 0.0873f0 + 0.0248f0 * b + 0.01f0 * p
233
+ c = n * p + 0.5f0
234
+ v_r = 0.92f0 - 4.2f0 / b
235
+
236
+ alpha = (2.83f0 + 5.1f0 / b) * stddev;
237
+ m = floor ((n + 1 ) * p)
253
238
254
- if ks < 0 || ks > n
255
- continue
239
+ ks = 0f0
240
+
241
+ while true
242
+ usample = rand (Float32) - 0.5f0
243
+ vsample = rand (Float32)
244
+
245
+ us = 0.5f0 - abs (usample)
246
+ ks = floor ((2 * a / us + b) * usample + c)
247
+
248
+ if us >= 0.07f0 && vsample <= v_r
249
+ break
250
+ end
251
+
252
+ if ks < 0 || ks > n
253
+ continue
254
+ end
255
+
256
+ v2 = CUDA. log (vsample * alpha / (a / (us * us) + b))
257
+ ub = (m + 0.5f0 ) * CUDA. log ((m + 1 ) / (r * (n - m + 1 ))) +
258
+ (n + 1 ) * CUDA. log ((n - m + 1 ) / (n - ks + 1 )) +
259
+ (ks + 0.5f0 ) * CUDA. log (r * (n - ks + 1 ) / (ks + 1 )) +
260
+ stirling_approx_tail (m) + stirling_approx_tail (n - m) - stirling_approx_tail (ks) - stirling_approx_tail (n - ks)
261
+ if v2 <= ub
262
+ break
263
+ end
264
+ end
265
+ k = Int (ks)
266
+ end
267
+ elseif p > 0.5f0
268
+ p = 1 - p
269
+ # Use inversion algorithm for n*p < 10
270
+ if n * p < 10f0
271
+ logp = CUDA. log (1f0 - p)
272
+ geom_sum = 0f0
273
+ k = 0
274
+ while true
275
+ geom = ceil (CUDA. log (rand (Float32)) / logp)
276
+ geom_sum += geom
277
+ geom_sum > n && break
278
+ k += 1
256
279
end
280
+ # BTRS algorithm
281
+ else
282
+ r = p/ (1f0 - p)
283
+ s = p* (1f0 - p)
284
+
285
+ stddev = sqrt (n * s)
286
+ b = 1.15f0 + 2.53f0 * stddev
287
+ a = - 0.0873f0 + 0.0248f0 * b + 0.01f0 * p
288
+ c = n * p + 0.5f0
289
+ v_r = 0.92f0 - 4.2f0 / b
290
+
291
+ alpha = (2.83f0 + 5.1f0 / b) * stddev;
292
+ m = floor ((n + 1 ) * p)
257
293
258
- v2 = CUDA. log (vsample * alpha / (a / (us * us) + b))
259
- ub = (m + 0.5f0 ) * CUDA. log ((m + 1 ) / (r * (n - m + 1 ))) +
260
- (n + 1 ) * CUDA. log ((n - m + 1 ) / (n - ks + 1 )) +
261
- (ks + 0.5f0 ) * CUDA. log (r * (n - ks + 1 ) / (ks + 1 )) +
262
- stirling_approx_tail (m) + stirling_approx_tail (n - m) - stirling_approx_tail (ks) - stirling_approx_tail (n - ks)
263
- if v2 <= ub
264
- break
294
+ ks = 0f0
295
+
296
+ while true
297
+ usample = rand (Float32) - 0.5f0
298
+ vsample = rand (Float32)
299
+
300
+ us = 0.5f0 - abs (usample)
301
+ ks = floor ((2 * a / us + b) * usample + c)
302
+
303
+ if us >= 0.07f0 && vsample <= v_r
304
+ break
305
+ end
306
+
307
+ if ks < 0 || ks > n
308
+ continue
309
+ end
310
+
311
+ v2 = CUDA. log (vsample * alpha / (a / (us * us) + b))
312
+ ub = (m + 0.5f0 ) * CUDA. log ((m + 1 ) / (r * (n - m + 1 ))) +
313
+ (n + 1 ) * CUDA. log ((n - m + 1 ) / (n - ks + 1 )) +
314
+ (ks + 0.5f0 ) * CUDA. log (r * (n - ks + 1 ) / (ks + 1 )) +
315
+ stirling_approx_tail (m) + stirling_approx_tail (n - m) - stirling_approx_tail (ks) - stirling_approx_tail (n - ks)
316
+ if v2 <= ub
317
+ break
318
+ end
265
319
end
320
+ k = Int (ks)
266
321
end
267
- k = Int (ks)
322
+ k = n - k
268
323
end
269
324
270
325
if i <= length (A)
271
- if invert
272
- @inbounds A[i] = n - k
273
- else
274
- @inbounds A[i] = k
275
- end
326
+ @inbounds A[i] = k
276
327
end
277
328
offset += window
278
329
end
0 commit comments