Skip to content

Commit 747e45a

Browse files
authored
Improve burmMem / equalMemFull performance (#79)
* Improve burmMem / equalMemFull performance * Linux and Windows define special `c` functions for clearing memory that the compiler is guaranteed to not remove - in addition to being safer, they are typically also faster (because they don't have to operate byte-by-byte) * `equalMemFull` can be implemented using unrolling and larger limbs than `byte` resulting in a significant speedup - this PR also removes some of the range checking done by previous code that would introduce branching in the C code where non is desired * include windows
1 parent a0b65f2 commit 747e45a

File tree

3 files changed

+137
-28
lines changed

3 files changed

+137
-28
lines changed

nimcrypto/bcmode.nim

+1-1
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ proc decrypt*[T](ctx: var GCM[T], input: openArray[byte],
11071107
let uselen = min(len(tag), 16)
11081108
ctx.decrypt(input, output)
11091109
ctx.getTag(dataTag.toOpenArray(0, uselen - 1))
1110-
compareMem(tag.toOpenArray(0, uselen - 1), dataTag.toOpenArray(0, uselen - 1))
1110+
equalMemFull(tag.toOpenArray(0, uselen - 1), dataTag.toOpenArray(0, uselen - 1))
11111111

11121112
proc clear*[T](ctx: var GCM[T]) {.inline.} =
11131113
## Clear ``GCM[T]`` context ``ctx``.

nimcrypto/hash.nim

+5-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
## This module provides helper procedures for calculating secure digests
1111
## supported by `nimcrypto` library.
12-
import utils
12+
import ./utils
1313

1414
const
1515
MaxMDigestLength* = 64
@@ -152,14 +152,10 @@ proc `==`*[A, B](d1: MDigest[A], d2: MDigest[B]): bool =
152152
## Check for equality between two ``MDigest`` objects ``d1`` and ``d2``.
153153
## If size in bits of ``d1`` is not equal to size in bits of ``d2`` then
154154
## digests considered as not equal.
155-
if d1.bits != d2.bits:
156-
return false
157-
var n = len(d1.data)
158-
var res = 0
159-
while n > 0:
160-
dec(n)
161-
res = res or int(d1.data[n] xor d2.data[n])
162-
result = (res == 0)
155+
when d1.bits == d2.bits:
156+
equalMemFull(d1.data, d2.data)
157+
else:
158+
false
163159

164160
when true:
165161
proc toDigestAux(n: static int, s: static string): MDigest[n] =

nimcrypto/utils.nim

+131-18
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,30 @@
1414
## decent library "Constant-Time Toolkit" (https://github.com/pornin/CTTK)
1515
## Copyright (c) 2018 Thomas Pornin <[email protected]>
1616

17+
import std/macros
18+
19+
proc replaceNodes(node: NimNode, what: NimNode, by: NimNode): NimNode =
20+
# Replace "what" ident node by "by"
21+
if node.kind in {nnkIdent, nnkSym}:
22+
if node.eqIdent(what): by else: node
23+
elif node.len == 0:
24+
node
25+
else:
26+
let rTree = node.kind.newTree()
27+
for child in node:
28+
rTree.add replaceNodes(child, what, by)
29+
rTree
30+
31+
macro unroll(idx: untyped{nkIdent}, start, stopEx: static int, body: untyped): untyped =
32+
## unroll idx over the range [start, stopEx), repeating the body for each
33+
## iteration
34+
result = newStmtList()
35+
for i in start ..< stopEx:
36+
# block unrolledIter_{idx}{i}: body
37+
result.add nnkBlockStmt.newTree(
38+
ident("unrolledIter_" & $idx & $i), body.replaceNodes(idx, newLit i)
39+
)
40+
1741
type
1842
HexFlags* {.pure.} = enum
1943
LowerCase, ## Produce lowercase hexadecimal characters
@@ -171,15 +195,32 @@ proc stripSpaces*(s: string): string =
171195
if i in allowed:
172196
result &= i
173197

174-
proc burnMem*(p: pointer, size: Natural) =
175-
var sp {.volatile.} = cast[ptr byte](p)
176-
var c = size
177-
if not isNil(sp):
178-
zeroMem(p, size)
179-
while c > 0:
180-
sp[] = 0
181-
sp = cast[ptr byte](cast[uint](sp) + 1)
182-
dec(c)
198+
when defined(linux):
199+
proc c_explicit_bzero(
200+
s: pointer, n: csize_t
201+
) {.importc: "explicit_bzero", header: "string.h".}
202+
203+
proc burnMem*(p: pointer, size: Natural) =
204+
c_explicit_bzero(p, csize_t size)
205+
206+
elif defined(windows):
207+
proc cSecureZeroMemory(
208+
s: pointer, n: csize_t
209+
) {.importc: "SecureZeroMemory", header: "windows.h".}
210+
211+
proc burnMem*(p: pointer, size: Natural) =
212+
cSecureZeroMemory(p, csize_t size)
213+
214+
else:
215+
proc burnMem*(p: pointer, size: Natural) =
216+
var sp {.volatile.} = cast[ptr byte](p)
217+
var c = size
218+
if not isNil(sp):
219+
zeroMem(p, size)
220+
while c > 0:
221+
sp[] = 0
222+
sp = cast[ptr byte](cast[uint](sp) + 1)
223+
dec(c)
183224

184225
proc burnArray*[T](a: var openArray[T]) {.inline.} =
185226
if len(a) > 0:
@@ -360,13 +401,85 @@ template copyMem*[A, B](dst: var openArray[A], dsto: int,
360401
else:
361402
copyMem(addr dst[dsto], unsafeAddr src[srco], length * sizeof(B))
362403

363-
template compareMem*[T](a, b: openArray[T]): bool =
364-
if len(a) != len(b):
365-
return false
404+
template offset(p: pointer, n: Natural | uint): pointer =
405+
cast[pointer](cast[uint](p) + uint n)
406+
407+
template equalMemFull(
408+
aParam, bParam: pointer, limbs: static Natural, Limb: type SomeUnsignedInt
409+
): bool =
410+
# Length known at runtime (and assumed to be small!) - unroll the loop
411+
var
412+
res = Limb(0)
413+
aa {.noinit.}, bb {.noinit.}: Limb
414+
415+
let
416+
a = aParam
417+
b = bParam
418+
419+
unroll i, 0, limbs:
420+
copyMem(addr aa, a.offset((limbs - i - 1) * sizeof(Limb)), sizeof(Limb))
421+
copyMem(addr bb, b.offset((limbs - i - 1) * sizeof(Limb)), sizeof(Limb))
422+
res = res or (aa xor bb)
423+
424+
res == 0
425+
426+
template equalMemFull(
427+
aParam, bParam: pointer, limbsParam: Natural, Limb: type SomeUnsignedInt
428+
): bool =
366429
var
367-
n = len(a)
368-
res = 0'u8
369-
while n > 0:
370-
dec(n)
371-
res = res or (a[n] xor b[n])
372-
res == 0'u8
430+
res = Limb(0)
431+
aa {.noinit.}, bb {.noinit.}: Limb
432+
433+
let
434+
a = aParam
435+
b = bParam
436+
limbs = uint limbsParam # avoid range checks
437+
438+
for i in uint(0)..<limbs:
439+
copyMem(
440+
addr aa, a.offset((limbs - i - 1) * uint sizeof(Limb)), sizeof(Limb))
441+
copyMem(
442+
addr bb, b.offset((limbs - i - 1) * uint sizeof(Limb)), sizeof(Limb))
443+
res = res or (aa xor bb)
444+
445+
res == 0
446+
447+
proc equalMemFull*(a, b: pointer, len: static Natural): bool =
448+
when len mod sizeof(uint64) == 0:
449+
equalMemFull(a, b, len div sizeof(uint64), uint64)
450+
elif len mod sizeof(uint32) == 0:
451+
equalMemFull(a, b, len div sizeof(uint32), uint32)
452+
elif len mod sizeof(uint16) == 0:
453+
equalMemFull(a, b, len div sizeof(uint16), uint16)
454+
else:
455+
equalMemFull(a, b, len, uint8)
456+
457+
proc equalMemFull*[I; T](a, b: array[I, T]): bool =
458+
when nimvm:
459+
a == b
460+
else:
461+
const bytes = a.len * sizeof(T)
462+
equalMemFull(unsafeAddr a[0], unsafeAddr b[0], bytes)
463+
464+
proc equalMemFull*[T](a, b: openArray[T]): bool =
465+
when nimvm:
466+
a == b
467+
else:
468+
if a.len == b.len:
469+
if a.len == 0:
470+
true
471+
else:
472+
let
473+
bytes = a.len * sizeof(T)
474+
ap = unsafeAddr a[0]
475+
bp = unsafeAddr b[0]
476+
if bytes mod sizeof(uint64) == 0:
477+
equalMemFull(ap, bp, bytes div sizeof(uint64), uint64)
478+
elif bytes mod sizeof(uint32) == 0:
479+
equalMemFull(ap, bp, bytes div sizeof(uint32), uint32)
480+
elif bytes mod sizeof(uint16) == 0:
481+
equalMemFull(ap, bp, bytes div sizeof(uint16), uint16)
482+
else:
483+
equalMemFull(ap, bp, bytes, uint8)
484+
else:
485+
false

0 commit comments

Comments
 (0)