@@ -10,7 +10,6 @@ use alloy_primitives::{
10
10
} ;
11
11
use alloy_rlp:: { Decodable , Encodable } ;
12
12
use core:: fmt;
13
- use itertools:: Itertools ;
14
13
use reth_execution_errors:: { SparseStateTrieErrorKind , SparseStateTrieResult , SparseTrieErrorKind } ;
15
14
use reth_primitives_traits:: Account ;
16
15
use reth_tracing:: tracing:: trace;
@@ -282,34 +281,11 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
282
281
branch_node_hash_masks : HashMap < Nibbles , TrieMask > ,
283
282
branch_node_tree_masks : HashMap < Nibbles , TrieMask > ,
284
283
) -> SparseStateTrieResult < ( ) > {
285
- let len = account_subtree. len ( ) ;
286
- let ( mut account_nodes, branch_children) = account_subtree
287
- . into_inner ( )
288
- . into_iter ( )
289
- . filter_map ( |( path, bytes) | {
290
- self . metrics . increment_total_account_nodes ( ) ;
291
- // If the node is already revealed, skip it.
292
- if self . revealed_account_paths . contains ( & path) {
293
- self . metrics . increment_skipped_account_nodes ( ) ;
294
- return None
295
- }
296
-
297
- Some ( TrieNode :: decode ( & mut & bytes[ ..] ) . map ( |node| ( path, node) ) )
298
- } )
299
- . fold_ok (
300
- ( Vec :: with_capacity ( len) , 0usize ) ,
301
- |( mut nodes, mut children) , ( path, node) | {
302
- if let TrieNode :: Branch ( branch) = & node {
303
- children += branch. state_mask . count_ones ( ) as usize ;
304
- }
305
-
306
- nodes. push ( ( path, node) ) ;
307
-
308
- ( nodes, children)
309
- } ,
310
- ) ?;
311
- account_nodes. sort_unstable_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
312
- let mut account_nodes = account_nodes. into_iter ( ) . peekable ( ) ;
284
+ let DecodedProofNodes { nodes, total_nodes, skipped_nodes, branch_children } =
285
+ DecodedProofNodes :: new ( account_subtree, & self . revealed_account_paths ) ?;
286
+ self . metrics . increment_total_account_nodes ( total_nodes) ;
287
+ self . metrics . increment_skipped_account_nodes ( skipped_nodes) ;
288
+ let mut account_nodes = nodes. into_iter ( ) . peekable ( ) ;
313
289
314
290
if let Some ( root_node) = Self :: validate_root_node_decoded ( & mut account_nodes) ? {
315
291
// Reveal root node if it wasn't already.
@@ -355,35 +331,10 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
355
331
) -> SparseStateTrieResult < ( ) > {
356
332
let revealed_nodes = self . revealed_storage_paths . entry ( account) . or_default ( ) ;
357
333
358
- let len = storage_subtree. subtree . len ( ) ;
359
- let ( mut nodes, branch_children) = storage_subtree
360
- . subtree
361
- . into_inner ( )
362
- . into_iter ( )
363
- . filter_map ( |( path, bytes) | {
364
- self . metrics . increment_total_storage_nodes ( ) ;
365
- // If the node is already revealed, skip it.
366
- if revealed_nodes. contains ( & path) {
367
- self . metrics . increment_skipped_storage_nodes ( ) ;
368
- return None
369
- }
370
-
371
- Some ( TrieNode :: decode ( & mut & bytes[ ..] ) . map ( |node| ( path, node) ) )
372
- } )
373
- . fold_ok (
374
- ( Vec :: with_capacity ( len) , 0usize ) ,
375
- |( mut nodes, mut children) , ( path, node) | {
376
- if let TrieNode :: Branch ( branch) = & node {
377
- children += branch. state_mask . count_ones ( ) as usize ;
378
- }
379
-
380
- nodes. push ( ( path, node) ) ;
381
-
382
- ( nodes, children)
383
- } ,
384
- ) ?;
385
-
386
- nodes. sort_unstable_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
334
+ let DecodedProofNodes { nodes, total_nodes, skipped_nodes, branch_children } =
335
+ DecodedProofNodes :: new ( storage_subtree. subtree , revealed_nodes) ?;
336
+ self . metrics . increment_total_storage_nodes ( total_nodes) ;
337
+ self . metrics . increment_skipped_storage_nodes ( skipped_nodes) ;
387
338
let mut nodes = nodes. into_iter ( ) . peekable ( ) ;
388
339
389
340
if let Some ( root_node) = Self :: validate_root_node_decoded ( & mut nodes) ? {
@@ -542,8 +493,6 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
542
493
& self ,
543
494
proof : & mut Peekable < I > ,
544
495
) -> SparseStateTrieResult < Option < TrieNode > > {
545
- let mut proof = proof. into_iter ( ) . peekable ( ) ;
546
-
547
496
// Validate root node.
548
497
let Some ( ( path, node) ) = proof. next ( ) else { return Ok ( None ) } ;
549
498
if !path. is_empty ( ) {
@@ -559,11 +508,10 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
559
508
Ok ( Some ( root_node) )
560
509
}
561
510
511
+ /// Validates the decoded root node of the proof and returns it if it exists and is valid.
562
512
fn validate_root_node_decoded < I : Iterator < Item = ( Nibbles , TrieNode ) > > (
563
513
proof : & mut Peekable < I > ,
564
514
) -> SparseStateTrieResult < Option < TrieNode > > {
565
- let mut proof = proof. into_iter ( ) . peekable ( ) ;
566
-
567
515
// Validate root node.
568
516
let Some ( ( path, root_node) ) = proof. next ( ) else { return Ok ( None ) } ;
569
517
if !path. is_empty ( ) {
@@ -574,7 +522,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
574
522
. into ( ) )
575
523
}
576
524
577
- // Decode root node and perform sanity check.
525
+ // Perform sanity check.
578
526
if matches ! ( root_node, TrieNode :: EmptyRoot ) && proof. peek ( ) . is_some ( ) {
579
527
return Err ( SparseStateTrieErrorKind :: InvalidRootNode {
580
528
path,
@@ -833,6 +781,49 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
833
781
}
834
782
}
835
783
784
+ /// Decoded proof nodes returned by [`decode_proof_nodes`].
785
+ #[ derive( Debug ) ]
786
+ struct DecodedProofNodes {
787
+ /// Filtered, decoded and sorted proof nodes.
788
+ nodes : Vec < ( Nibbles , TrieNode ) > ,
789
+ /// Number of nodes in the proof.
790
+ total_nodes : u64 ,
791
+ /// Number of nodes that were skipped because they were already revealed.
792
+ skipped_nodes : u64 ,
793
+ /// Number of children of all branch nodes in the proof.
794
+ branch_children : usize ,
795
+ }
796
+
797
+ impl DecodedProofNodes {
798
+ fn new ( proof_nodes : ProofNodes , revealed_nodes : & HashSet < Nibbles > ) -> alloy_rlp:: Result < Self > {
799
+ let mut result = Self {
800
+ nodes : Vec :: with_capacity ( proof_nodes. len ( ) ) ,
801
+ total_nodes : 0 ,
802
+ skipped_nodes : 0 ,
803
+ branch_children : 0 ,
804
+ } ;
805
+
806
+ for ( path, bytes) in proof_nodes. into_inner ( ) {
807
+ result. total_nodes += 1 ;
808
+ // If the node is already revealed, skip it.
809
+ if revealed_nodes. contains ( & path) {
810
+ result. skipped_nodes += 1 ;
811
+ continue
812
+ }
813
+
814
+ let node = TrieNode :: decode ( & mut & bytes[ ..] ) ?;
815
+ if let TrieNode :: Branch ( branch) = & node {
816
+ result. branch_children += branch. state_mask . count_ones ( ) as usize ;
817
+ }
818
+
819
+ result. nodes . push ( ( path, node) ) ;
820
+ }
821
+
822
+ result. nodes . sort_unstable_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
823
+ Ok ( result)
824
+ }
825
+ }
826
+
836
827
#[ cfg( test) ]
837
828
mod tests {
838
829
use super :: * ;
0 commit comments