Skip to content

Commit 6fd75cf

Browse files
nagisa authored and shreyan-gupta committed
store: make Trie & TrieUpdate: Send + Sync (near#12981)
This does not yet mean that `TrieUpdate` is perfectly usable in multi-threaded (MT) scenarios -- that will remain impossible so long as any method takes `&mut self` -- but it is now possible to share a `&TrieUpdate`, which allows parallel reading of the storage. Maybe that's already good enough for some uses? The main intent is to evaluate the performance difference seen here without further changes to how `TrieUpdate` is being used and, if there is a significant impact, to investigate mitigations.
1 parent bd863a6 commit 6fd75cf

File tree

10 files changed

+84
-59
lines changed

10 files changed

+84
-59
lines changed

Cargo.lock

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,15 @@ smartstring = "1.0.1"
361361
strum = { version = "0.24", features = ["derive"] }
362362
stun = "0.4"
363363
subtle = "2.2"
364+
static_assertions = "1.1"
364365
syn = { version = "2.0.4", features = ["extra-traits", "full"] }
365366
sysinfo = "0.24.5"
366367
target-lexicon = { version = "0.12.2", default-features = false }
367368
tempfile = "3.3"
368369
testlib = { path = "test-utils/testlib" }
369370
test-log = { version = "0.2", default-features = false, features = ["trace"] }
370371
thiserror = "2.0"
372+
thread_local = "1.1"
371373
tikv-jemallocator = "0.5.0"
372374
time = { version = "0.3.9", default-features = false }
373375
tokio = { version = "1.28", default-features = false }

chain/chain/src/resharding/manager.rs

+5-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
use std::cell::RefCell;
2-
use std::io;
3-
use std::num::NonZero;
4-
use std::sync::Arc;
5-
61
use super::event_type::{ReshardingEventType, ReshardingSplitShardParams};
72
use super::types::ReshardingSender;
83
use crate::ChainStoreUpdate;
@@ -26,6 +21,9 @@ use near_store::trie::mem::memtrie_update::TrackingMode;
2621
use near_store::trie::ops::resharding::RetainMode;
2722
use near_store::trie::outgoing_metadata::ReceiptGroupsQueue;
2823
use near_store::{DBCol, ShardTries, ShardUId, Store, TrieAccess};
24+
use std::io;
25+
use std::num::NonZero;
26+
use std::sync::Arc;
2927

3028
pub struct ReshardingManager {
3129
store: Store,
@@ -252,10 +250,7 @@ impl ReshardingManager {
252250
let parent_trie = tries.get_trie_for_shard(parent_shard_uid, parent_state_root);
253251
let parent_congestion_info =
254252
parent_chunk_extra.congestion_info().expect("The congestion info must exist!");
255-
256-
let trie_recorder = RefCell::new(trie_recorder);
257-
let parent_trie = parent_trie.recording_reads_with_recorder(trie_recorder);
258-
253+
let parent_trie = parent_trie.recording_reads_with_recorder(trie_recorder.into());
259254
let child_epoch_id = self.epoch_manager.get_next_epoch_id(block.hash())?;
260255
let child_shard_layout = self.epoch_manager.get_shard_layout(&child_epoch_id)?;
261256
let child_congestion_info = Self::get_child_congestion_info(
@@ -268,7 +263,7 @@ impl ReshardingManager {
268263
)?;
269264

270265
let trie_recorder = parent_trie.take_recorder().unwrap();
271-
let partial_storage = trie_recorder.borrow_mut().recorded_storage();
266+
let partial_storage = trie_recorder.write().expect("no poison").recorded_storage();
272267
let partial_state_len = match &partial_storage.nodes {
273268
PartialState::TrieValues(values) => values.len(),
274269
};

core/store/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ rlimit.workspace = true
3434
rocksdb.workspace = true
3535
serde.workspace = true
3636
serde_json.workspace = true
37+
static_assertions.workspace = true
3738
stdx.workspace = true
3839
strum.workspace = true
3940
tempfile.workspace = true
4041
thiserror.workspace = true
42+
thread_local.workspace = true
4143
tokio.workspace = true
4244
tracing.workspace = true
4345

core/store/src/trie/accounting_cache.rs

+11-4
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@ use near_primitives::hash::CryptoHash;
77
use near_primitives::shard_layout::ShardUId;
88
use std::collections::HashMap;
99
use std::sync::Arc;
10+
use std::sync::atomic;
1011

1112
/// Switch that controls whether the `TrieAccountingCache` is enabled.
12-
pub struct TrieAccountingCacheSwitch(Arc<std::sync::atomic::AtomicBool>);
13+
pub struct TrieAccountingCacheSwitch(Arc<thread_local::ThreadLocal<atomic::AtomicBool>>);
1314

1415
impl TrieAccountingCacheSwitch {
1516
pub fn set(&self, enabled: bool) {
16-
self.0.store(enabled, std::sync::atomic::Ordering::Relaxed);
17+
self.0.get_or(Default::default).store(enabled, atomic::Ordering::Relaxed);
1718
}
1819

1920
pub fn enabled(&self) -> bool {
20-
self.0.load(std::sync::atomic::Ordering::Relaxed)
21+
self.0.get_or(Default::default).load(atomic::Ordering::Relaxed)
2122
}
2223
}
2324

@@ -92,7 +93,13 @@ impl TrieAccountingCache {
9293
}
9394
});
9495
let switch = TrieAccountingCacheSwitch(Default::default());
95-
Self { enable: switch, cache: HashMap::new(), db_read_nodes: 0, mem_read_nodes: 0, metrics }
96+
Self {
97+
enable: switch,
98+
cache: Default::default(),
99+
db_read_nodes: Default::default(),
100+
mem_read_nodes: Default::default(),
101+
metrics,
102+
}
96103
}
97104

98105
pub fn enable_switch(&self) -> TrieAccountingCacheSwitch {

core/store/src/trie/mem/iter.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ impl<'a> GenericTrieInternalStorage<MemTrieNodeId, FlatStateValue> for MemTrieIt
3939
let view = node.as_ptr(self.memtrie.arena.memory()).view();
4040
if let Some(recorder) = &self.trie.recorder {
4141
let raw_node_serialized = borsh::to_vec(&view.to_raw_trie_node_with_size()).unwrap();
42-
recorder.borrow_mut().record(&view.node_hash(), raw_node_serialized.into());
42+
recorder
43+
.write()
44+
.expect("no poison")
45+
.record(&view.node_hash(), raw_node_serialized.into());
4346
}
4447
let node = MemTrieNode::from_existing_node_view(view);
4548
Ok(node)
@@ -50,7 +53,7 @@ impl<'a> GenericTrieInternalStorage<MemTrieNodeId, FlatStateValue> for MemTrieIt
5053
let value = self.trie.deref_optimized(&optimized_value_ref)?;
5154
if let Some(recorder) = &self.trie.recorder {
5255
let value_hash = optimized_value_ref.into_value_ref().hash;
53-
recorder.borrow_mut().record(&value_hash, value.clone().into());
56+
recorder.write().expect("no poison").record(&value_hash, value.clone().into());
5457
};
5558
Ok(value)
5659
}

core/store/src/trie/mod.rs

+38-31
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,12 @@ use ops::interface::{GenericNodeOrIndex, GenericTrieNode, GenericTrieUpdate};
3636
use ops::interface::{GenericTrieValue, UpdatedNodeId};
3737
use ops::resharding::{GenericTrieUpdateRetain, RetainMode};
3838
pub use raw_node::{Children, RawTrieNode, RawTrieNodeWithSize};
39-
use std::cell::RefCell;
4039
use std::collections::{BTreeMap, HashMap, HashSet};
4140
use std::fmt::Write;
4241
use std::hash::Hash;
4342
use std::ops::DerefMut;
4443
use std::str;
45-
use std::sync::{Arc, RwLock, RwLockReadGuard};
44+
use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
4645
pub use trie_recording::{SubtreeSize, TrieRecorder, TrieRecorderStats};
4746
use trie_storage_update::{
4847
TrieStorageNodeWithSize, TrieStorageUpdate, UpdatedTrieStorageNodeWithSize,
@@ -202,12 +201,13 @@ pub struct Trie {
202201
/// (which can be toggled on the fly), trie nodes that have been looked up
203202
/// once will be guaranteed to be cached, and further reads to these nodes
204203
/// will encounter less gas cost.
205-
accounting_cache: RefCell<TrieAccountingCache>,
204+
accounting_cache: Mutex<TrieAccountingCache>,
206205
/// If present, we're capturing all trie nodes that have been accessed
207206
/// during the lifetime of this Trie struct. This is used to produce a
208207
/// state proof so that the same access pattern can be replayed using only
209208
/// the captured result.
210-
recorder: Option<RefCell<TrieRecorder>>,
209+
// FIXME: make `TrieRecorder` internally MT-safe, instead of locking the entire structure.
210+
recorder: Option<RwLock<TrieRecorder>>,
211211
/// If true, access to trie nodes (not values) charges gas and affects the
212212
/// accounting cache. If false, access to trie nodes will not charge gas or
213213
/// affect the accounting cache. Value accesses always charge gas no matter
@@ -544,11 +544,10 @@ impl Trie {
544544
flat_storage_chunk_view: Option<FlatStorageChunkView>,
545545
) -> Self {
546546
let accounting_cache = match storage.as_caching_storage() {
547-
Some(caching_storage) => RefCell::new(TrieAccountingCache::new(Some((
548-
caching_storage.shard_uid,
549-
caching_storage.is_view,
550-
)))),
551-
None => RefCell::new(TrieAccountingCache::new(None)),
547+
Some(caching_storage) => {
548+
TrieAccountingCache::new(Some((caching_storage.shard_uid, caching_storage.is_view)))
549+
}
550+
None => TrieAccountingCache::new(None),
552551
};
553552
// Technically the charge_gas_for_trie_node_access should be set based
554553
// on the flat storage protocol feature. When flat storage is enabled
@@ -562,7 +561,7 @@ impl Trie {
562561
root,
563562
charge_gas_for_trie_node_access,
564563
flat_storage_chunk_view,
565-
accounting_cache,
564+
accounting_cache: Mutex::new(accounting_cache),
566565
recorder: None,
567566
}
568567
}
@@ -580,19 +579,19 @@ impl Trie {
580579
/// Makes a new trie that has everything the same except that access
581580
/// through that trie accumulates a state proof for all nodes accessed.
582581
pub fn recording_reads_new_recorder(&self) -> Self {
583-
let recorder = RefCell::new(TrieRecorder::new(None));
582+
let recorder = RwLock::new(TrieRecorder::new(None));
584583
self.recording_reads_with_recorder(recorder)
585584
}
586585

587586
/// Makes a new trie that has everything the same except that access
588587
/// through that trie accumulates a state proof for all nodes accessed.
589588
/// We also supply a proof size limit to prevent the proof from growing too large.
590589
pub fn recording_reads_with_proof_size_limit(&self, proof_size_limit: usize) -> Self {
591-
let recorder = RefCell::new(TrieRecorder::new(Some(proof_size_limit)));
590+
let recorder = RwLock::new(TrieRecorder::new(Some(proof_size_limit)));
592591
self.recording_reads_with_recorder(recorder)
593592
}
594593

595-
pub fn recording_reads_with_recorder(&self, recorder: RefCell<TrieRecorder>) -> Self {
594+
pub fn recording_reads_with_recorder(&self, recorder: RwLock<TrieRecorder>) -> Self {
596595
let mut trie = Self::new_with_memtries(
597596
self.storage.clone(),
598597
self.memtries.clone(),
@@ -605,20 +604,22 @@ impl Trie {
605604
trie
606605
}
607606

608-
pub fn take_recorder(self) -> Option<RefCell<TrieRecorder>> {
607+
pub fn take_recorder(self) -> Option<RwLock<TrieRecorder>> {
609608
self.recorder
610609
}
611610

612611
/// Takes the recorded state proof out of the trie.
613612
pub fn recorded_storage(&self) -> Option<PartialStorage> {
614-
self.recorder.as_ref().map(|recorder| recorder.borrow_mut().recorded_storage())
613+
self.recorder
614+
.as_ref()
615+
.map(|recorder| recorder.write().expect("no poison").recorded_storage())
615616
}
616617

617618
/// Returns the in-memory size of the recorded state proof. Useful for checking size limit of state witness
618619
pub fn recorded_storage_size(&self) -> usize {
619620
self.recorder
620621
.as_ref()
621-
.map(|recorder| recorder.borrow().recorded_storage_size())
622+
.map(|recorder| recorder.read().expect("no poison").recorded_storage_size())
622623
.unwrap_or_default()
623624
}
624625

@@ -627,14 +628,14 @@ impl Trie {
627628
pub fn recorded_storage_size_upper_bound(&self) -> usize {
628629
self.recorder
629630
.as_ref()
630-
.map(|recorder| recorder.borrow().recorded_storage_size_upper_bound())
631+
.map(|recorder| recorder.read().expect("no poison").recorded_storage_size_upper_bound())
631632
.unwrap_or_default()
632633
}
633634

634635
pub fn check_proof_size_limit_exceed(&self) -> bool {
635636
self.recorder
636637
.as_ref()
637-
.map(|recorder| recorder.borrow().check_proof_size_limit_exceed())
638+
.map(|recorder| recorder.read().expect("no poison").check_proof_size_limit_exceed())
638639
.unwrap_or_default()
639640
}
640641

@@ -662,7 +663,9 @@ impl Trie {
662663
/// Get statistics about the recorded trie. Useful for observability and debugging.
663664
/// This scans all of the recorded data, so could potentially be expensive to run.
664665
pub fn recorder_stats(&self) -> Option<TrieRecorderStats> {
665-
self.recorder.as_ref().map(|recorder| recorder.borrow().get_stats(&self.root))
666+
self.recorder
667+
.as_ref()
668+
.map(|recorder| recorder.read().expect("no poison").get_stats(&self.root))
666669
}
667670

668671
pub fn get_root(&self) -> &StateRoot {
@@ -683,7 +686,7 @@ impl Trie {
683686
return;
684687
};
685688
{
686-
let mut r = recorder.borrow_mut();
689+
let mut r = recorder.write().expect("no poison");
687690
if r.codes_to_record.contains(&account_id) {
688691
return;
689692
}
@@ -695,7 +698,7 @@ impl Trie {
695698
let key = TrieKey::ContractCode { account_id };
696699
let value_ref = self.get_optimized_ref(&key.to_vec(), KeyLookupMode::FlatStorage);
697700
if let Ok(Some(value_ref)) = value_ref {
698-
let mut r = recorder.borrow_mut();
701+
let mut r = recorder.write().expect("no poison");
699702
r.record_code_len(value_ref.len());
700703
}
701704
}
@@ -711,7 +714,7 @@ impl Trie {
711714
// that it is possible to generate a continuous stream of witnesses with a fixed
712715
// size. Using static key achieves that since in case of multiple receipts garbage
713716
// data will simply be overwritten, not accumulated.
714-
recorder.borrow_mut().record_unaccounted(
717+
recorder.write().expect("no poison").record_unaccounted(
715718
&CryptoHash::hash_bytes(b"__garbage_data_key_1720025071757228"),
716719
data.into(),
717720
);
@@ -732,14 +735,15 @@ impl Trie {
732735
) -> Result<Arc<[u8]>, StorageError> {
733736
let result = if side_effects && use_accounting_cache {
734737
self.accounting_cache
735-
.borrow_mut()
738+
.lock()
739+
.unwrap()
736740
.retrieve_raw_bytes_with_accounting(hash, &*self.storage)?
737741
} else {
738742
self.storage.retrieve_raw_bytes(hash)?
739743
};
740744
if side_effects {
741745
if let Some(recorder) = &self.recorder {
742-
recorder.borrow_mut().record(hash, result.clone());
746+
recorder.write().expect("no poison").record(hash, result.clone());
743747
}
744748
}
745749
Ok(result)
@@ -1291,13 +1295,14 @@ impl Trie {
12911295
if charge_gas_for_trie_node_access {
12921296
for (node_hash, serialized_node) in &accessed_nodes {
12931297
self.accounting_cache
1294-
.borrow_mut()
1298+
.lock()
1299+
.unwrap()
12951300
.retroactively_account(*node_hash, serialized_node.clone());
12961301
}
12971302
}
12981303
if let Some(recorder) = &self.recorder {
12991304
for (node_hash, serialized_node) in accessed_nodes {
1300-
recorder.borrow_mut().record(&node_hash, serialized_node);
1305+
recorder.write().expect("no poison").record(&node_hash, serialized_node);
13011306
}
13021307
}
13031308
mem_value
@@ -1491,10 +1496,11 @@ impl Trie {
14911496
let value_hash = hash(value);
14921497
let arc_value: Arc<[u8]> = value.clone().into();
14931498
self.accounting_cache
1494-
.borrow_mut()
1499+
.lock()
1500+
.unwrap()
14951501
.retroactively_account(value_hash, arc_value.clone());
14961502
if let Some(recorder) = &self.recorder {
1497-
recorder.borrow_mut().record(&value_hash, arc_value);
1503+
recorder.write().expect("no poison").record(&value_hash, arc_value);
14981504
}
14991505
Ok(value.clone())
15001506
}
@@ -1515,7 +1521,7 @@ impl Trie {
15151521
{
15161522
// Call `get` for contract codes requested to be recorded.
15171523
let codes_to_record = if let Some(recorder) = &self.recorder {
1518-
recorder.borrow().codes_to_record.clone()
1524+
recorder.read().expect("no poison").codes_to_record.clone()
15191525
} else {
15201526
HashSet::default()
15211527
};
@@ -1537,7 +1543,8 @@ impl Trie {
15371543
{
15381544
// Get trie_update for memtrie
15391545
let guard = self.memtries.as_ref().unwrap().read().unwrap();
1540-
let mut recorder = self.recorder.as_ref().map(|recorder| recorder.borrow_mut());
1546+
let mut recorder =
1547+
self.recorder.as_ref().map(|recorder| recorder.write().expect("no poison"));
15411548
let tracking_mode = match &mut recorder {
15421549
Some(recorder) => TrackingMode::RefcountsAndAccesses(recorder.deref_mut()),
15431550
None => TrackingMode::Refcounts,
@@ -1650,7 +1657,7 @@ impl Trie {
16501657
}
16511658

16521659
pub fn get_trie_nodes_count(&self) -> TrieNodesCount {
1653-
self.accounting_cache.borrow().get_trie_nodes_count()
1660+
self.accounting_cache.lock().unwrap().get_trie_nodes_count()
16541661
}
16551662

16561663
/// Splits the trie, separating entries by the boundary account.

0 commit comments

Comments
 (0)