Skip to content

Commit a005757

Browse files
authored
rust oak_runtime: Fix any possible lock inversion issues (#784)
Fixes #780
1 parent 589b2d2 commit a005757

File tree

1 file changed

+18
-5
lines changed
  • oak/server/rust/oak_runtime/src/runtime

1 file changed

+18
-5
lines changed

oak/server/rust/oak_runtime/src/runtime/mod.rs

+18-5
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ impl Runtime {
502502
handles_capacity: usize,
503503
) -> Result<Option<ReadStatus>, OakStatus> {
504504
self.validate_handle_access(node_id, reference)?;
505-
self.channels
505+
let result = self.channels
506506
.with_channel(self.channels.get_reader_channel(reference)?, |channel| {
507507
let mut messages = channel.messages.write().unwrap();
508508
match messages.front() {
@@ -517,7 +517,6 @@ impl Runtime {
517517
ReadStatus::NeedsCapacity(req_bytes_capacity, req_handles_capacity)
518518
} else {
519519
let msg = messages.pop_front().expect( "Front element disappeared while we were holding the write lock!");
520-
self.track_handles_in_node(node_id, msg.channels.clone());
521520
ReadStatus::Success(msg)
522521
},
523522
))
@@ -530,7 +529,15 @@ impl Runtime {
530529
}
531530
}
532531
}
533-
})
532+
});
533+
534+
// Add handles outside the channels lock so we don't hold the node lock inside the channel
535+
// lock.
536+
if let Ok(Some(ReadStatus::Success(ref msg))) = result {
537+
self.track_handles_in_node(node_id, msg.channels.clone());
538+
}
539+
540+
result
534541
}
535542

536543
/// Return the direction of a [`Handle`]. This is useful when reading
@@ -605,6 +612,13 @@ impl Runtime {
605612
.remove(&node_id)
606613
.expect("remove_node_id: Node didn't exist!");
607614
}
615+
616+
/// Add an [`NodeId`] [`Node`] pair to the [`Runtime`]. This method temporarily holds the node
617+
/// write lock.
618+
fn add_running_node(&self, reference: NodeId, node: Node) {
619+
let mut nodes = self.nodes.write().unwrap();
620+
nodes.insert(reference, node);
621+
}
608622
}
609623

610624
/// A reference to a [`Runtime`].
@@ -639,7 +653,6 @@ impl RuntimeRef {
639653
// to do that we first need to provide a reference to the caller node as a parameter to this
640654
// function.
641655

642-
let mut nodes = self.nodes.write().unwrap();
643656
let reference = self.new_node_reference();
644657

645658
let reader = self.channels.duplicate_reference(reader)?;
@@ -666,7 +679,7 @@ impl RuntimeRef {
666679

667680
// If the node was successfully started, insert it in the list of currently running
668681
// nodes.
669-
nodes.insert(
682+
self.add_running_node(
670683
reference,
671684
Node {
672685
reference,

0 commit comments

Comments
 (0)