Skip to content

Commit 2927e5b

Browse files
Tommy TIANAirSwapTeam
Tommy TIAN
andauthored
Add support for OpenZeppelin sort-pair verifcation (#21)
* feat: add sortpairs in config to support openzeppelin verify * Refactor for OpenZeppelin sort-pair verification support. * Refactor. Signed-off-by: txaty <[email protected]> * Add unit tests. Signed-off-by: txaty <[email protected]> * Fix bugs and add unit tests. Signed-off-by: txaty <[email protected]> Signed-off-by: txaty <[email protected]> Co-authored-by: Vian <[email protected]>
1 parent 6ccb37a commit 2927e5b

File tree

4 files changed

+141
-36
lines changed

4 files changed

+141
-36
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ module github.com/txaty/go-merkletree
33
go 1.19
44

55
require (
6-
github.com/agiledragon/gomonkey/v2 v2.8.0
6+
github.com/agiledragon/gomonkey/v2 v2.9.0
77
github.com/txaty/gool v0.1.4
88
)

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
github.com/agiledragon/gomonkey/v2 v2.8.0 h1:u2K2nNGyk0ippzklz1CWalllEB9ptD+DtSXeCX5O000=
2-
github.com/agiledragon/gomonkey/v2 v2.8.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY=
1+
github.com/agiledragon/gomonkey/v2 v2.9.0 h1:PDiKKybR596O6FHW+RVSG0Z7uGCBNbmbUXh3uCNQ7Hc=
2+
github.com/agiledragon/gomonkey/v2 v2.9.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY=
33
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
44
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
55
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=

merkle_tree.go

+68-22
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,16 @@ type DataBlock interface {
5252
Serialize() ([]byte, error)
5353
}
5454

55+
type concatHashFuncType func([]byte, []byte) []byte
56+
5557
// HashFuncType is the signature of the hash functions used for Merkle Tree generation.
5658
type HashFuncType func([]byte) ([]byte, error)
5759

5860
// Config is the configuration of Merkle Tree.
5961
type Config struct {
62+
// appendFunc is the function for concatenating two hashes.
63+
// If SortSiblingPairs in Config is true, then the sibling pairs are first sorted and then concatenated.
64+
concatHashFunc concatHashFuncType
6065
// Customizable hash function used for tree generation.
6166
HashFunc HashFuncType
6267
// Number of goroutines run in parallel.
@@ -70,6 +75,9 @@ type Config struct {
7075
// If true, generate a dummy node with random hash value.
7176
// Otherwise, then the odd node situation is handled by duplicating the previous node.
7277
NoDuplicates bool
78+
// SortSiblingPairs is the parameter for OpenZeppelin compatibility.
79+
// If set to `true`, the hashing sibling pairs are sorted.
80+
SortSiblingPairs bool
7381
}
7482

7583
// MerkleTree implements the Merkle Tree structure
@@ -121,27 +129,35 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
121129
if config == nil {
122130
config = new(Config)
123131
}
124-
if config.HashFunc == nil {
125-
if config.RunInParallel {
126-
config.HashFunc = defaultHashFuncParal
132+
m = &MerkleTree{Config: config}
133+
// hash function initialization
134+
if m.HashFunc == nil {
135+
if m.RunInParallel {
136+
m.HashFunc = defaultHashFuncParal // parallelized hash function must be concurrent safe
127137
} else {
128-
config.HashFunc = defaultHashFunc
138+
m.HashFunc = defaultHashFunc
129139
}
130140
}
131141
// If the configuration mode is not set, then set it to ModeProofGen by default.
132-
if config.Mode == 0 {
133-
config.Mode = ModeProofGen
142+
if m.Mode == 0 {
143+
m.Mode = ModeProofGen
134144
}
135145
// If RunInParallel is true and NumRoutines is unset, then set NumRoutines to the number of CPU.
136-
if config.RunInParallel && config.NumRoutines == 0 {
137-
config.NumRoutines = runtime.NumCPU()
146+
if m.RunInParallel && m.NumRoutines == 0 {
147+
m.NumRoutines = runtime.NumCPU()
148+
}
149+
// hash concatenation function initialization
150+
if m.SortSiblingPairs {
151+
m.concatHashFunc = concatSortHash
152+
} else {
153+
m.concatHashFunc = concatHash
138154
}
139-
m = &MerkleTree{Config: config}
140155
m.Depth = calTreeDepth(len(blocks))
156+
// generic wait group initialization (for parallelized computation) and leaf generation
141157
var wp *gool.Pool[argType, error]
142158
if m.RunInParallel {
143159
// task channel capacity is passed as 0, so use the default value: 2 * numWorkers
144-
wp = gool.NewPool[argType, error](config.NumRoutines, 0)
160+
wp = gool.NewPool[argType, error](m.NumRoutines, 0)
145161
defer wp.Close()
146162
m.Leaves, err = m.leafGenParal(blocks, wp)
147163
if err != nil {
@@ -153,6 +169,7 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
153169
return
154170
}
155171
}
172+
156173
if m.Mode == ModeProofGen {
157174
if m.RunInParallel {
158175
err = m.proofGenParal(wp)
@@ -191,9 +208,21 @@ func New(config *Config, blocks []DataBlock) (m *MerkleTree, err error) {
191208
}
192209
return
193210
}
211+
194212
return nil, errors.New("invalid configuration mode")
195213
}
196214

215+
func concatHash(b1 []byte, b2 []byte) []byte {
216+
return append(b1, b2...)
217+
}
218+
219+
func concatSortHash(b1 []byte, b2 []byte) []byte {
220+
if bytes.Compare(b1, b2) < 0 {
221+
return append(b1, b2...)
222+
}
223+
return append(b2, b1...)
224+
}
225+
197226
// calTreeDepth calculates the tree depth,
198227
// the tree depth is then used to declare the capacity of the proof slices.
199228
func calTreeDepth(blockLen int) uint32 {
@@ -227,7 +256,8 @@ func (m *MerkleTree) proofGen() (err error) {
227256
m.updateProofs(buf, numLeaves, 0)
228257
for step := 1; step < int(m.Depth); step++ {
229258
for idx := 0; idx < prevLen; idx += 2 {
230-
buf[idx>>1], err = m.HashFunc(append(buf[idx], buf[idx+1]...))
259+
260+
buf[idx>>1], err = m.HashFunc(m.Config.concatHashFunc(buf[idx], buf[idx+1]))
231261
if err != nil {
232262
return
233263
}
@@ -239,7 +269,8 @@ func (m *MerkleTree) proofGen() (err error) {
239269
}
240270
m.updateProofs(buf, prevLen, step)
241271
}
242-
m.Root, err = m.HashFunc(append(buf[0], buf[1]...))
272+
273+
m.Root, err = m.HashFunc(m.Config.concatHashFunc(buf[0], buf[1]))
243274
return
244275
}
245276

@@ -317,7 +348,8 @@ func (m *MerkleTree) proofGenParal(wp *gool.Pool[argType, error]) (err error) {
317348
}
318349
m.updateProofsParal(buf1, prevLen, step, wp)
319350
}
320-
m.Root, err = m.HashFunc(append(buf1[0], buf1[1]...))
351+
352+
m.Root, err = m.HashFunc(m.Config.concatHashFunc(buf1[0], buf1[1]))
321353
return
322354
}
323355

@@ -563,7 +595,9 @@ func (m *MerkleTree) treeBuild(wp *gool.Pool[argType, error]) (err error) {
563595
}
564596
} else {
565597
for j := 0; j < prevLen; j += 2 {
566-
m.tree[i+1][j>>1], err = m.HashFunc(append(m.tree[i][j], m.tree[i][j+1]...))
598+
m.tree[i+1][j>>1], err = m.HashFunc(
599+
m.Config.concatHashFunc(m.tree[i][j], m.tree[i][j+1]),
600+
)
567601
if err != nil {
568602
return
569603
}
@@ -574,7 +608,8 @@ func (m *MerkleTree) treeBuild(wp *gool.Pool[argType, error]) (err error) {
574608
return
575609
}
576610
}
577-
m.Root, err = m.HashFunc(append(m.tree[m.Depth-1][0], m.tree[m.Depth-1][1]...))
611+
612+
m.Root, err = m.HashFunc(m.Config.concatHashFunc(m.tree[m.Depth-1][0], m.tree[m.Depth-1][1]))
578613
if err != nil {
579614
return
580615
}
@@ -614,37 +649,48 @@ func treeBuildHandler(arg argType) error {
614649

615650
// Verify verifies the data block with the Merkle Tree proof
616651
func (m *MerkleTree) Verify(dataBlock DataBlock, proof *Proof) (bool, error) {
617-
return Verify(dataBlock, proof, m.Root, m.HashFunc)
652+
return Verify(dataBlock, proof, m.Root, m.Config)
618653
}
619654

620655
// Verify verifies the data block with the Merkle Tree proof and Merkle root hash
621-
func Verify(dataBlock DataBlock, proof *Proof, root []byte, hashFunc HashFuncType) (bool, error) {
656+
func Verify(dataBlock DataBlock, proof *Proof, root []byte, config *Config) (bool, error) {
622657
if dataBlock == nil {
623658
return false, errors.New("data block is nil")
624659
}
625660
if proof == nil {
626661
return false, errors.New("proof is nil")
627662
}
628-
if hashFunc == nil {
629-
hashFunc = defaultHashFunc
663+
664+
if config == nil {
665+
config = new(Config)
666+
}
667+
668+
if config.HashFunc == nil {
669+
config.HashFunc = defaultHashFunc
630670
}
671+
672+
if config.concatHashFunc == nil {
673+
config.concatHashFunc = concatHash
674+
}
675+
631676
var (
632677
data, err = dataBlock.Serialize()
633678
hash []byte
634679
)
680+
635681
if err != nil {
636682
return false, err
637683
}
638-
hash, err = hashFunc(data)
684+
hash, err = config.HashFunc(data)
639685
if err != nil {
640686
return false, err
641687
}
642688
path := proof.Path
643689
for _, n := range proof.Siblings {
644690
if path&1 == 1 {
645-
hash, err = hashFunc(append(hash, n...))
691+
hash, err = config.HashFunc(config.concatHashFunc(hash, n))
646692
} else {
647-
hash, err = hashFunc(append(n, hash...))
693+
hash, err = config.HashFunc(config.concatHashFunc(n, hash))
648694
}
649695
if err != nil {
650696
return false, err

merkle_tree_test.go

+70-11
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ func TestMerkleTreeNew_proofGen(t *testing.T) {
155155
},
156156
wantErr: false,
157157
},
158+
{
159+
name: "test_8_sorted",
160+
args: args{
161+
blocks: genTestDataBlocks(8),
162+
config: &Config{
163+
SortSiblingPairs: true,
164+
},
165+
},
166+
wantErr: false,
167+
},
158168
{
159169
name: "test_hash_func_error",
160170
args: args{
@@ -924,7 +934,7 @@ func TestVerify(t *testing.T) {
924934
dataBlock DataBlock
925935
proof *Proof
926936
root []byte
927-
hashFunc HashFuncType
937+
config *Config
928938
}
929939
tests := []struct {
930940
name string
@@ -939,7 +949,21 @@ func TestVerify(t *testing.T) {
939949
dataBlock: blocks[0],
940950
proof: m.Proofs[0],
941951
root: m.Root,
942-
hashFunc: m.HashFunc,
952+
config: &Config{
953+
HashFunc: m.HashFunc,
954+
concatHashFunc: func(left, right []byte) []byte {
955+
return append(left, right...)
956+
},
957+
},
958+
},
959+
want: true,
960+
},
961+
{
962+
name: "test_config_nil",
963+
args: args{
964+
dataBlock: blocks[0],
965+
proof: m.Proofs[0],
966+
root: m.Root,
943967
},
944968
want: true,
945969
},
@@ -949,7 +973,12 @@ func TestVerify(t *testing.T) {
949973
dataBlock: blocks[0],
950974
proof: m.Proofs[0],
951975
root: []byte("test_wrong_root"),
952-
hashFunc: m.HashFunc,
976+
config: &Config{
977+
HashFunc: m.HashFunc,
978+
concatHashFunc: func(left, right []byte) []byte {
979+
return append(left, right...)
980+
},
981+
},
953982
},
954983
want: false,
955984
},
@@ -959,7 +988,12 @@ func TestVerify(t *testing.T) {
959988
dataBlock: blocks[0],
960989
proof: m.Proofs[0],
961990
root: m.Root,
962-
hashFunc: func([]byte) ([]byte, error) { return []byte("test_wrong_hash_hash"), nil },
991+
config: &Config{
992+
HashFunc: func([]byte) ([]byte, error) { return []byte("test_wrong_hash_hash"), nil },
993+
concatHashFunc: func(left, right []byte) []byte {
994+
return append(left, right...)
995+
},
996+
},
963997
},
964998
want: false,
965999
},
@@ -969,7 +1003,12 @@ func TestVerify(t *testing.T) {
9691003
dataBlock: blocks[0],
9701004
proof: nil,
9711005
root: m.Root,
972-
hashFunc: m.HashFunc,
1006+
config: &Config{
1007+
HashFunc: m.HashFunc,
1008+
concatHashFunc: func(left, right []byte) []byte {
1009+
return append(left, right...)
1010+
},
1011+
},
9731012
},
9741013
want: false,
9751014
wantErr: true,
@@ -980,7 +1019,12 @@ func TestVerify(t *testing.T) {
9801019
dataBlock: nil,
9811020
proof: m.Proofs[0],
9821021
root: m.Root,
983-
hashFunc: m.HashFunc,
1022+
config: &Config{
1023+
HashFunc: m.HashFunc,
1024+
concatHashFunc: func(left, right []byte) []byte {
1025+
return append(left, right...)
1026+
},
1027+
},
9841028
},
9851029
want: false,
9861030
wantErr: true,
@@ -991,7 +1035,12 @@ func TestVerify(t *testing.T) {
9911035
dataBlock: blocks[0],
9921036
proof: m.Proofs[0],
9931037
root: m.Root,
994-
hashFunc: nil,
1038+
config: &Config{
1039+
HashFunc: nil,
1040+
concatHashFunc: func(left, right []byte) []byte {
1041+
return append(left, right...)
1042+
},
1043+
},
9951044
},
9961045
want: true,
9971046
wantErr: false,
@@ -1002,8 +1051,13 @@ func TestVerify(t *testing.T) {
10021051
dataBlock: blocks[0],
10031052
proof: m.Proofs[0],
10041053
root: m.Root,
1005-
hashFunc: func([]byte) ([]byte, error) {
1006-
return nil, errors.New("test_hash_func_err")
1054+
config: &Config{
1055+
HashFunc: func([]byte) ([]byte, error) {
1056+
return nil, errors.New("test_hash_func_err")
1057+
},
1058+
concatHashFunc: func(left, right []byte) []byte {
1059+
return append(left, right...)
1060+
},
10071061
},
10081062
},
10091063
want: false,
@@ -1015,7 +1069,12 @@ func TestVerify(t *testing.T) {
10151069
dataBlock: blocks[0],
10161070
proof: m.Proofs[0],
10171071
root: m.Root,
1018-
hashFunc: m.HashFunc,
1072+
config: &Config{
1073+
HashFunc: m.HashFunc,
1074+
concatHashFunc: func(left, right []byte) []byte {
1075+
return append(left, right...)
1076+
},
1077+
},
10191078
},
10201079
mock: func() {
10211080
patches.ApplyMethod(reflect.TypeOf(&mockDataBlock{}), "Serialize",
@@ -1033,7 +1092,7 @@ func TestVerify(t *testing.T) {
10331092
tt.mock()
10341093
}
10351094
defer patches.Reset()
1036-
got, err := Verify(tt.args.dataBlock, tt.args.proof, tt.args.root, tt.args.hashFunc)
1095+
got, err := Verify(tt.args.dataBlock, tt.args.proof, tt.args.root, tt.args.config)
10371096
if (err != nil) != tt.wantErr {
10381097
t.Errorf("Verify() error = %v, wantErr %v", err, tt.wantErr)
10391098
return

0 commit comments

Comments
 (0)