@@ -52,11 +52,16 @@ type DataBlock interface {
52
52
Serialize () ([]byte , error )
53
53
}
54
54
55
+ type concatHashFuncType func ([]byte , []byte ) []byte
56
+
55
57
// HashFuncType is the signature of the hash functions used for Merkle Tree generation.
56
58
type HashFuncType func ([]byte ) ([]byte , error )
57
59
58
60
// Config is the configuration of Merkle Tree.
59
61
type Config struct {
62
+ // appendFunc is the function for concatenating two hashes.
63
+ // If SortSiblingPairs in Config is true, then the sibling pairs are first sorted and then concatenated.
64
+ concatHashFunc concatHashFuncType
60
65
// Customizable hash function used for tree generation.
61
66
HashFunc HashFuncType
62
67
// Number of goroutines run in parallel.
@@ -70,6 +75,9 @@ type Config struct {
70
75
// If true, generate a dummy node with random hash value.
71
76
// Otherwise, then the odd node situation is handled by duplicating the previous node.
72
77
NoDuplicates bool
78
+ // SortSiblingPairs is the parameter for OpenZeppelin compatibility.
79
+ // If set to `true`, the hashing sibling pairs are sorted.
80
+ SortSiblingPairs bool
73
81
}
74
82
75
83
// MerkleTree implements the Merkle Tree structure
@@ -121,27 +129,35 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
121
129
if config == nil {
122
130
config = new (Config )
123
131
}
124
- if config .HashFunc == nil {
125
- if config .RunInParallel {
126
- config .HashFunc = defaultHashFuncParal
132
+ m = & MerkleTree {Config : config }
133
+ // hash function initialization
134
+ if m .HashFunc == nil {
135
+ if m .RunInParallel {
136
+ m .HashFunc = defaultHashFuncParal // parallelized hash function must be concurrent safe
127
137
} else {
128
- config .HashFunc = defaultHashFunc
138
+ m .HashFunc = defaultHashFunc
129
139
}
130
140
}
131
141
// If the configuration mode is not set, then set it to ModeProofGen by default.
132
- if config .Mode == 0 {
133
- config .Mode = ModeProofGen
142
+ if m .Mode == 0 {
143
+ m .Mode = ModeProofGen
134
144
}
135
145
// If RunInParallel is true and NumRoutines is unset, then set NumRoutines to the number of CPU.
136
- if config .RunInParallel && config .NumRoutines == 0 {
137
- config .NumRoutines = runtime .NumCPU ()
146
+ if m .RunInParallel && m .NumRoutines == 0 {
147
+ m .NumRoutines = runtime .NumCPU ()
148
+ }
149
+ // hash concatenation function initialization
150
+ if m .SortSiblingPairs {
151
+ m .concatHashFunc = concatSortHash
152
+ } else {
153
+ m .concatHashFunc = concatHash
138
154
}
139
- m = & MerkleTree {Config : config }
140
155
m .Depth = calTreeDepth (len (blocks ))
156
+ // generic wait group initialization (for parallelized computation) and leaf generation
141
157
var wp * gool.Pool [argType , error ]
142
158
if m .RunInParallel {
143
159
// task channel capacity is passed as 0, so use the default value: 2 * numWorkers
144
- wp = gool .NewPool [argType , error ](config .NumRoutines , 0 )
160
+ wp = gool .NewPool [argType , error ](m .NumRoutines , 0 )
145
161
defer wp .Close ()
146
162
m .Leaves , err = m .leafGenParal (blocks , wp )
147
163
if err != nil {
@@ -153,6 +169,7 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
153
169
return
154
170
}
155
171
}
172
+
156
173
if m .Mode == ModeProofGen {
157
174
if m .RunInParallel {
158
175
err = m .proofGenParal (wp )
@@ -191,9 +208,21 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
191
208
}
192
209
return
193
210
}
211
+
194
212
return nil , errors .New ("invalid configuration mode" )
195
213
}
196
214
215
+ func concatHash (b1 []byte , b2 []byte ) []byte {
216
+ return append (b1 , b2 ... )
217
+ }
218
+
219
+ func concatSortHash (b1 []byte , b2 []byte ) []byte {
220
+ if bytes .Compare (b1 , b2 ) < 0 {
221
+ return append (b1 , b2 ... )
222
+ }
223
+ return append (b2 , b1 ... )
224
+ }
225
+
197
226
// calTreeDepth calculates the tree depth,
198
227
// the tree depth is then used to declare the capacity of the proof slices.
199
228
func calTreeDepth (blockLen int ) uint32 {
@@ -227,7 +256,8 @@ func (m *MerkleTree) proofGen() (err error) {
227
256
m .updateProofs (buf , numLeaves , 0 )
228
257
for step := 1 ; step < int (m .Depth ); step ++ {
229
258
for idx := 0 ; idx < prevLen ; idx += 2 {
230
- buf [idx >> 1 ], err = m .HashFunc (append (buf [idx ], buf [idx + 1 ]... ))
259
+
260
+ buf [idx >> 1 ], err = m .HashFunc (m .Config .concatHashFunc (buf [idx ], buf [idx + 1 ]))
231
261
if err != nil {
232
262
return
233
263
}
@@ -239,7 +269,8 @@ func (m *MerkleTree) proofGen() (err error) {
239
269
}
240
270
m .updateProofs (buf , prevLen , step )
241
271
}
242
- m .Root , err = m .HashFunc (append (buf [0 ], buf [1 ]... ))
272
+
273
+ m .Root , err = m .HashFunc (m .Config .concatHashFunc (buf [0 ], buf [1 ]))
243
274
return
244
275
}
245
276
@@ -317,7 +348,8 @@ func (m *MerkleTree) proofGenParal(wp *gool.Pool[argType, error]) (err error) {
317
348
}
318
349
m .updateProofsParal (buf1 , prevLen , step , wp )
319
350
}
320
- m .Root , err = m .HashFunc (append (buf1 [0 ], buf1 [1 ]... ))
351
+
352
+ m .Root , err = m .HashFunc (m .Config .concatHashFunc (buf1 [0 ], buf1 [1 ]))
321
353
return
322
354
}
323
355
@@ -563,7 +595,9 @@ func (m *MerkleTree) treeBuild(wp *gool.Pool[argType, error]) (err error) {
563
595
}
564
596
} else {
565
597
for j := 0 ; j < prevLen ; j += 2 {
566
- m .tree [i + 1 ][j >> 1 ], err = m .HashFunc (append (m.tree [i ][j ], m .tree [i ][j + 1 ]... ))
598
+ m .tree [i + 1 ][j >> 1 ], err = m .HashFunc (
599
+ m .Config .concatHashFunc (m.tree [i ][j ], m .tree [i ][j + 1 ]),
600
+ )
567
601
if err != nil {
568
602
return
569
603
}
@@ -574,7 +608,8 @@ func (m *MerkleTree) treeBuild(wp *gool.Pool[argType, error]) (err error) {
574
608
return
575
609
}
576
610
}
577
- m .Root , err = m .HashFunc (append (m .tree [m .Depth - 1 ][0 ], m .tree [m .Depth - 1 ][1 ]... ))
611
+
612
+ m .Root , err = m .HashFunc (m .Config .concatHashFunc (m .tree [m .Depth - 1 ][0 ], m .tree [m .Depth - 1 ][1 ]))
578
613
if err != nil {
579
614
return
580
615
}
@@ -614,37 +649,48 @@ func treeBuildHandler(arg argType) error {
614
649
615
650
// Verify verifies the data block with the Merkle Tree proof
616
651
func (m * MerkleTree ) Verify (dataBlock DataBlock , proof * Proof ) (bool , error ) {
617
- return Verify (dataBlock , proof , m .Root , m .HashFunc )
652
+ return Verify (dataBlock , proof , m .Root , m .Config )
618
653
}
619
654
620
655
// Verify verifies the data block with the Merkle Tree proof and Merkle root hash
621
- func Verify (dataBlock DataBlock , proof * Proof , root []byte , hashFunc HashFuncType ) (bool , error ) {
656
+ func Verify (dataBlock DataBlock , proof * Proof , root []byte , config * Config ) (bool , error ) {
622
657
if dataBlock == nil {
623
658
return false , errors .New ("data block is nil" )
624
659
}
625
660
if proof == nil {
626
661
return false , errors .New ("proof is nil" )
627
662
}
628
- if hashFunc == nil {
629
- hashFunc = defaultHashFunc
663
+
664
+ if config == nil {
665
+ config = new (Config )
666
+ }
667
+
668
+ if config .HashFunc == nil {
669
+ config .HashFunc = defaultHashFunc
630
670
}
671
+
672
+ if config .concatHashFunc == nil {
673
+ config .concatHashFunc = concatHash
674
+ }
675
+
631
676
var (
632
677
data , err = dataBlock .Serialize ()
633
678
hash []byte
634
679
)
680
+
635
681
if err != nil {
636
682
return false , err
637
683
}
638
- hash , err = hashFunc (data )
684
+ hash , err = config . HashFunc (data )
639
685
if err != nil {
640
686
return false , err
641
687
}
642
688
path := proof .Path
643
689
for _ , n := range proof .Siblings {
644
690
if path & 1 == 1 {
645
- hash , err = hashFunc ( append (hash , n ... ))
691
+ hash , err = config . HashFunc ( config . concatHashFunc (hash , n ))
646
692
} else {
647
- hash , err = hashFunc ( append (n , hash ... ))
693
+ hash , err = config . HashFunc ( config . concatHashFunc (n , hash ))
648
694
}
649
695
if err != nil {
650
696
return false , err
0 commit comments