|
14 | 14 | ## decent library "Constant-Time Toolkit" (https://github.com/pornin/CTTK)
|
15 | 15 | ## Copyright (c) 2018 Thomas Pornin <[email protected]>
|
16 | 16 |
|
| 17 | +import std/macros |
| 18 | + |
| 19 | +proc replaceNodes(node: NimNode, what: NimNode, by: NimNode): NimNode = |
| 20 | + # Replace "what" ident node by "by" |
| 21 | + if node.kind in {nnkIdent, nnkSym}: |
| 22 | + if node.eqIdent(what): by else: node |
| 23 | + elif node.len == 0: |
| 24 | + node |
| 25 | + else: |
| 26 | + let rTree = node.kind.newTree() |
| 27 | + for child in node: |
| 28 | + rTree.add replaceNodes(child, what, by) |
| 29 | + rTree |
| 30 | + |
| 31 | +macro unroll(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped): untyped = |
| 32 | + ## unroll idx over the range [start, stopEx), repeating the body for each |
| 33 | + ## iteration |
| 34 | + result = newStmtList() |
| 35 | + for i in start ..< stopEx: |
| 36 | + # block unrolledIter_{idx}{i}: body |
| 37 | + result.add nnkBlockStmt.newTree( |
| 38 | + ident("unrolledIter_" & $idx & $i), body.replaceNodes(idx, newLit i) |
| 39 | + ) |
| 40 | + |
17 | 41 | type
|
18 | 42 | HexFlags* {.pure.} = enum
|
19 | 43 | LowerCase, ## Produce lowercase hexadecimal characters
|
@@ -171,15 +195,32 @@ proc stripSpaces*(s: string): string =
|
171 | 195 | if i in allowed:
|
172 | 196 | result &= i
|
173 | 197 |
|
174 |
| -proc burnMem*(p: pointer, size: Natural) = |
175 |
| - var sp {.volatile.} = cast[ptr byte](p) |
176 |
| - var c = size |
177 |
| - if not isNil(sp): |
178 |
| - zeroMem(p, size) |
179 |
| - while c > 0: |
180 |
| - sp[] = 0 |
181 |
| - sp = cast[ptr byte](cast[uint](sp) + 1) |
182 |
| - dec(c) |
| 198 | +when defined(linux): |
| 199 | + proc c_explicit_bzero( |
| 200 | + s: pointer, n: csize_t |
| 201 | + ) {.importc: "explicit_bzero", header: "string.h".} |
| 202 | + |
| 203 | + proc burnMem*(p: pointer, size: Natural) = |
| 204 | + c_explicit_bzero(p, csize_t size) |
| 205 | + |
| 206 | +elif defined(windows): |
| 207 | + proc cSecureZeroMemory( |
| 208 | + s: pointer, n: csize_t |
| 209 | + ) {.importc: "SecureZeroMemory", header: "windows.h".} |
| 210 | + |
| 211 | + proc burnMem*(p: pointer, size: Natural) = |
| 212 | + cSecureZeroMemory(p, csize_t size) |
| 213 | + |
| 214 | +else: |
| 215 | + proc burnMem*(p: pointer, size: Natural) = |
| 216 | + var sp {.volatile.} = cast[ptr byte](p) |
| 217 | + var c = size |
| 218 | + if not isNil(sp): |
| 219 | + zeroMem(p, size) |
| 220 | + while c > 0: |
| 221 | + sp[] = 0 |
| 222 | + sp = cast[ptr byte](cast[uint](sp) + 1) |
| 223 | + dec(c) |
183 | 224 |
|
184 | 225 | proc burnArray*[T](a: var openArray[T]) {.inline.} =
|
185 | 226 | if len(a) > 0:
|
@@ -360,13 +401,85 @@ template copyMem*[A, B](dst: var openArray[A], dsto: int,
|
360 | 401 | else:
|
361 | 402 | copyMem(addr dst[dsto], unsafeAddr src[srco], length * sizeof(B))
|
362 | 403 |
|
363 |
| -template compareMem*[T](a, b: openArray[T]): bool = |
364 |
| - if len(a) != len(b): |
365 |
| - return false |
| 404 | +template offset(p: pointer, n: Natural | uint): pointer = |
| 405 | + cast[pointer](cast[uint](p) + uint n) |
| 406 | + |
| 407 | +template equalMemFull( |
| 408 | + aParam, bParam: pointer, limbs: static Natural, Limb: type SomeUnsignedInt |
| 409 | +): bool = |
| 410 | + # Length known at runtime (and assumed to be small!) - unroll the loop |
| 411 | + var |
| 412 | + res = Limb(0) |
| 413 | + aa {.noinit.}, bb {.noinit.}: Limb |
| 414 | + |
| 415 | + let |
| 416 | + a = aParam |
| 417 | + b = bParam |
| 418 | + |
| 419 | + unroll i, 0, limbs: |
| 420 | + copyMem(addr aa, a.offset((limbs - i - 1) * sizeof(Limb)), sizeof(Limb)) |
| 421 | + copyMem(addr bb, b.offset((limbs - i - 1) * sizeof(Limb)), sizeof(Limb)) |
| 422 | + res = res or (aa xor bb) |
| 423 | + |
| 424 | + res == 0 |
| 425 | + |
| 426 | +template equalMemFull( |
| 427 | + aParam, bParam: pointer, limbsParam: Natural, Limb: type SomeUnsignedInt |
| 428 | +): bool = |
366 | 429 | var
|
367 |
| - n = len(a) |
368 |
| - res = 0'u8 |
369 |
| - while n > 0: |
370 |
| - dec(n) |
371 |
| - res = res or (a[n] xor b[n]) |
372 |
| - res == 0'u8 |
| 430 | + res = Limb(0) |
| 431 | + aa {.noinit.}, bb {.noinit.}: Limb |
| 432 | + |
| 433 | + let |
| 434 | + a = aParam |
| 435 | + b = bParam |
| 436 | + limbs = uint limbsParam # avoid range checks |
| 437 | + |
| 438 | + for i in uint(0)..<limbs: |
| 439 | + copyMem( |
| 440 | + addr aa, a.offset((limbs - i - 1) * uint sizeof(Limb)), sizeof(Limb)) |
| 441 | + copyMem( |
| 442 | + addr bb, b.offset((limbs - i - 1) * uint sizeof(Limb)), sizeof(Limb)) |
| 443 | + res = res or (aa xor bb) |
| 444 | + |
| 445 | + res == 0 |
| 446 | + |
| 447 | +proc equalMemFull*(a, b: pointer, len: static Natural): bool = |
| 448 | + when len mod sizeof(uint64) == 0: |
| 449 | + equalMemFull(a, b, len div sizeof(uint64), uint64) |
| 450 | + elif len mod sizeof(uint32) == 0: |
| 451 | + equalMemFull(a, b, len div sizeof(uint32), uint32) |
| 452 | + elif len mod sizeof(uint16) == 0: |
| 453 | + equalMemFull(a, b, len div sizeof(uint16), uint16) |
| 454 | + else: |
| 455 | + equalMemFull(a, b, len, uint8) |
| 456 | + |
| 457 | +proc equalMemFull*[I; T](a, b: array[I, T]): bool = |
| 458 | + when nimvm: |
| 459 | + a == b |
| 460 | + else: |
| 461 | + const bytes = a.len * sizeof(T) |
| 462 | + equalMemFull(unsafeAddr a[0], unsafeAddr b[0], bytes) |
| 463 | + |
| 464 | +proc equalMemFull*[T](a, b: openArray[T]): bool = |
| 465 | + when nimvm: |
| 466 | + a == b |
| 467 | + else: |
| 468 | + if a.len == b.len: |
| 469 | + if a.len == 0: |
| 470 | + true |
| 471 | + else: |
| 472 | + let |
| 473 | + bytes = a.len * sizeof(T) |
| 474 | + ap = unsafeAddr a[0] |
| 475 | + bp = unsafeAddr b[0] |
| 476 | + if bytes mod sizeof(uint64) == 0: |
| 477 | + equalMemFull(ap, bp, bytes div sizeof(uint64), uint64) |
| 478 | + elif bytes mod sizeof(uint32) == 0: |
| 479 | + equalMemFull(ap, bp, bytes div sizeof(uint32), uint32) |
| 480 | + elif bytes mod sizeof(uint16) == 0: |
| 481 | + equalMemFull(ap, bp, bytes div sizeof(uint16), uint16) |
| 482 | + else: |
| 483 | + equalMemFull(ap, bp, bytes, uint8) |
| 484 | + else: |
| 485 | + false |
0 commit comments