Rust:学习 Rc、Arc 和 Weak 并动手实现它

基本概念

RcArc 都是引用计数类型,它们允许多个所有者共享同一个数据。Rc 是单线程环境下的引用计数类型,而 Arc 是线程安全的引用计数类型,适用于多线程环境。WeakRcArc 的一个重要补充工具,解决了在多所有权和防止循环引用的问题。

  • 什么是引用计数? 引用计数是一种内存管理技术,通过维护一个计数器来跟踪有多少个指针指向某个数据。当引用计数变为零时,数据就会被自动释放。

  • 什么是循环引用? 循环引用是指两个或多个对象互相引用,形成一个环。这种情况下,即使没有外部引用,引用计数也不会变为零,导致内存无法被释放,造成内存泄漏。

引用计数

Rc 的使用场景

Rc(Reference Counted)用于单线程环境,它允许多个所有者共享同一份数据。Rc 的使用非常简单,因为它不需要考虑线程安全问题。

一个典型的使用 Rc 的场景是共享不可变数据。例如,我们可以使用 Rc 来共享一个不可变的字符串:

 1use std::rc::Rc;
 2
 3fn main() {
 4    let data = Rc::new("Hello, world!".to_string());
 5    let data_clone1 = Rc::clone(&data);
 6    let data_clone2 = Rc::clone(&data);
 7
 8    println!("Original: {}", data);
 9    println!("Clone 1: {}", data_clone1);
10    println!("Clone 2: {}", data_clone2);
11}

在这个例子中,data 被多个所有者共享,而每个所有者都可以安全地访问它。

为什么需要 Rc?

在 Rust 中,一个值只能有一个所有者,这样可以确保内存安全。然而,在某些情况下,我们需要多个所有者共享同一个数据。例如,在树形数据结构中,多个节点可能需要共享同一个子节点。Rc 允许我们在单线程环境下实现这种多所有权。

原子引用计数

Arc 的使用场景

Arc(Atomic Reference Counted)类似于 Rc,但它是线程安全的。Arc 使用原子操作来维护引用计数,因此可以安全地在多个线程之间共享数据。

一个典型的使用 Arc 的场景是多线程环境下的共享数据。例如,我们可以使用 Arc 来共享一个不可变的向量:

 1use std::sync::Arc;
 2use std::thread;
 3
 4fn main() {
 5    let data = Arc::new(vec![1, 2, 3, 4, 5]);
 6    let mut handles = vec![];
 7
 8    for _ in 0..5 {
 9        let data_clone = Arc::clone(&data);
10        let handle = thread::spawn(move || {
11            println!("Data: {:?}", data_clone);
12        });
13        handles.push(handle);
14    }
15
16    for handle in handles {
17        handle.join().unwrap();
18    }
19}

在这个例子中,data 被多个线程共享,而每个线程都可以安全地访问它。

为什么需要 Arc?

在多线程编程中,我们经常需要在多个线程之间共享数据。Rust 的所有权和借用系统确保了数据的安全性,但它也限制了数据的共享。Arc 允许我们在多线程环境下实现安全的多所有权,从而简化了并发编程。

安全的边界

Shared references in Rust disallow mutation by default, and Arc is no exception: you cannot generally obtain a mutable reference to something inside an Arc. If you need to mutate through an Arc, use MutexRwLock, or one of the Atomic types. –https://doc.rust-lang.org/std/sync/struct.Arc.html#thread-safety

Arc 并不保证你访问数据的原子性,只保证计数增减的原子性。如果需要通过 Arc 来变异内部数据,通常需要结合 MutexRwLock 或者一些原子类型来实现。

例子:

 1use std::sync::{Arc, RwLock};
 2use std::thread;
 3
 4fn main() {
 5    let data = Arc::new(RwLock::new(0));
 6
 7    let mut handles = vec![];
 8
 9    for _ in 0..10 {
10        let data = Arc::clone(&data);
11        let handle = thread::spawn(move || {
12            let mut num = data.write().unwrap();
13            *num += 1;
14        });
15        handles.push(handle);
16    }
17
18    for handle in handles {
19        handle.join().unwrap();
20    }
21
22    println!("Result: {}", *data.read().unwrap());
23}

反例(注意:无法编译通过):

 1use std::cell::Cell;
 2use std::sync::Arc;
 3use std::thread;
 4
 5fn main() {
 6    let data = Arc::new(Cell::new(0));
 7
 8    let mut handles = vec![];
 9
10    for _ in 0..10 {
11        let data = Arc::clone(&data);
12        let handle = thread::spawn(move || {
13            // 尝试直接修改 Arc 包裹的 Cell 中的数据
14            data.set(data.get() + 1);
15        });
16        handles.push(handle);
17    }
18
19    for handle in handles {
20        handle.join().unwrap();
21    }
22
23    println!("Result: {}", data.get());
24}

弱引用

Weak 的使用场景

Weak 通常与 RcArc 一起使用,用于打破循环引用。例如,在树形数据结构中,父节点和子节点可能需要相互引用。我们可以使用 Weak 来引用父节点,从而避免循环引用。

使用 Rc 和 Weak 打破循环引用

假设我们有一个简单的树形数据结构,其中每个节点都有一个指向父节点的引用和一个指向子节点的引用:

 1use std::rc::{Rc, Weak};
 2use std::cell::RefCell;
 3
 4struct Node {
 5    value: i32,
 6    parent: RefCell<Weak<Node>>,
 7    children: RefCell<Vec<Rc<Node>>>,
 8}
 9
10fn main() {
11    let leaf = Rc::new(Node {
12        value: 3,
13        parent: RefCell::new(Weak::new()),
14        children: RefCell::new(vec![]),
15    });
16
17    let branch = Rc::new(Node {
18        value: 5,
19        parent: RefCell::new(Weak::new()),
20        children: RefCell::new(vec![Rc::clone(&leaf)]),
21    });
22
23    *leaf.parent.borrow_mut() = Rc::downgrade(&branch);
24
25    println!("Leaf parent: {:?}", leaf.parent.borrow().upgrade().map(|node| node.value));
26}

在这个例子中,leaf 节点引用 branch 节点作为父节点,但使用 Weak 避免了循环引用。Rc::downgrade 方法创建一个 Weak 引用,而 Weak::upgrade 方法则尝试将 Weak 引用转换回 Rc 引用,如果数据已经被释放,则返回 None

引用转换补充

下面对 upgrade 和 downgrade 补充。

  • Arc::downgrade: 类似于 Rc::downgrade,但是用于多线程环境下的 Arc<T>。它创建一个 Weak<T> 弱引用,不会阻止 T 被释放。

    1use std::sync::Arc;
    2use std::sync::Weak;
    3
    4fn main() {
    5    let strong_arc = Arc::new(5);
    6    let weak_arc: Weak<i32> = Arc::downgrade(&strong_arc);
    7}
    
  • Weak::upgradeArc 上的 Weak 类型): 类似于 Weak::upgrade,但是用于多线程环境下的 Weak<T> 弱引用。尝试从 Weak<T> 恢复到 Arc<T>,如果成功则增加引用计数。

     1use std::sync::Arc;
     2use std::sync::Weak;
     3
     4fn main() {
     5    let strong_arc = Arc::new(5);
     6    let weak_arc: Weak<i32> = Arc::downgrade(&strong_arc);
     7
     8    match weak_arc.upgrade() {
     9        Some(strong_again) => println!("Got strong Arc again: {}", strong_again),
    10        None => println!("The value has been dropped"),
    11    }
    12}
    

图结构的应用

在图数据结构中,节点之间可能存在复杂的引用关系,容易形成循环引用。我们可以使用 Weak 来引用其他节点,从而避免循环引用。例如:

 1use std::rc::{Rc, Weak};
 2use std::cell::RefCell;
 3
 4struct Node {
 5    value: i32,
 6    neighbors: RefCell<Vec<Weak<Node>>>,
 7}
 8
 9fn main() {
10    let node1 = Rc::new(Node {
11        value: 1,
12        neighbors: RefCell::new(vec![]),
13    });
14
15    let node2 = Rc::new(Node {
16        value: 2,
17        neighbors: RefCell::new(vec![Rc::downgrade(&node1)]),
18    });
19
20    node1.neighbors.borrow_mut().push(Rc::downgrade(&node2));
21
22    println!("Node1 neighbors: {:?}", node1.neighbors.borrow().len());
23    println!("Node2 neighbors: {:?}", node2.neighbors.borrow().len());
24}

在这个例子中,node1node2 互相引用,但使用 Weak 避免了循环引用。

我们进一步让结构更复杂看看。下面实现一个四节点全联通图结构,每个节点使用 Arc,而节点之间的连接则使用 Weak 引用,以防止强引用计数过高且形成循环。

 1use std::sync::{Arc, Weak};
 2use std::cell::RefCell;
 3
 4#[derive(Debug)]
 5struct Node {
 6    value: i32,
 7    neighbors: RefCell<Vec<Weak<Node>>>,
 8}
 9
10impl Node {
11    fn new(value: i32) -> Arc<Self> {
12        Arc::new(Node {
13            value,
14            neighbors: RefCell::new(Vec::new()),
15        })
16    }
17
18    fn add_neighbor(node: &Arc<Self>, neighbor: &Arc<Self>) {
19        node.neighbors.borrow_mut().push(Arc::downgrade(neighbor));
20    }
21}
22
23fn main() {
24    // 创建四个节点
25    let node1 = Node::new(1);
26    let node2 = Node::new(2);
27    let node3 = Node::new(3);
28    let node4 = Node::new(4);
29
30    // 创建全联通图
31    Node::add_neighbor(&node1, &node2);
32    Node::add_neighbor(&node1, &node3);
33    Node::add_neighbor(&node1, &node4);
34
35    Node::add_neighbor(&node2, &node1);
36    Node::add_neighbor(&node2, &node3);
37    Node::add_neighbor(&node2, &node4);
38
39    Node::add_neighbor(&node3, &node1);
40    Node::add_neighbor(&node3, &node2);
41    Node::add_neighbor(&node3, &node4);
42
43    Node::add_neighbor(&node4, &node1);
44    Node::add_neighbor(&node4, &node2);
45    Node::add_neighbor(&node4, &node3);
46
47    // 打印每个节点和它的邻居
48    fn print_neighbors(node: &Arc<Node>) {
49        let neighbors = node.neighbors.borrow();
50        let neighbor_values: Vec<_> = neighbors.iter()
51            .filter_map(|weak| weak.upgrade())
52            .map(|neighbor| neighbor.value)
53            .collect();
54        println!("Node {}: {:?}", node.value, neighbor_values);
55    }
56
57    print_neighbors(&node1);
58    print_neighbors(&node2);
59    print_neighbors(&node3);
60    print_neighbors(&node4);
61}

实现引用计数

初始版本

下面是一个最简化的版本用来帮助理解 Rc 的核心原理。

  • 使用 NonNull 来表示非空指针。

  • 实现 Deref trait 以便 Rc 可以像普通引用一样被解引用。比如 *rc

  • Box::into_raw 将一个 Box<T> 转换成一个裸指针。原来的 Box<T> 实例不再负责管理那块内存。稍后使用 Box::from_raw 重新获取所有权,从而离开作用域才能释放。

  • clone 时增加计数,drop 时减少计数。

 1use std::ops::Deref;
 2use std::ptr::NonNull;
 3
 4struct Rc<T> {
 5    ptr: NonNull<Inner<T>>,
 6}
 7
 8struct Inner<T> {
 9    value: T,
10    ref_count: usize,
11}
12
13impl<T> Rc<T> {
14    fn new(value: T) -> Self {
15        let inner = Box::new(Inner {
16            value,
17            ref_count: 1,
18        });
19        Rc {
20            ptr: unsafe { NonNull::new_unchecked(Box::into_raw(inner)) },
21        }
22    }
23
24    fn clone(&self) -> Self {
25        unsafe {
26            (*self.ptr.as_ptr()).ref_count += 1;
27        }
28        Rc { ptr: self.ptr }
29    }
30}
31
32impl<T> Deref for Rc<T> {
33    type Target = T;
34
35    fn deref(&self) -> &Self::Target {
36        unsafe { &(*self.ptr.as_ptr()).value }
37    }
38}
39
40impl<T> Drop for Rc<T> {
41    fn drop(&mut self) {
42        unsafe {
43            let inner = self.ptr.as_ptr();
44            (*inner).ref_count -= 1;
45            if (*inner).ref_count == 0 {
46                Box::from_raw(inner); // 自动调用 drop 来释放内存
47            }
48        }
49    }
50}
51
52fn main() {
53    let rc1 = Rc::new(5);
54    let rc2 = rc1.clone();
55    let rc3 = rc1.clone();
56
57    println!("rc1: {}", *rc1);
58    println!("rc2: {}", *rc2);
59    println!("rc3: {}", *rc3);
60}

实现 Weak

改动如下:

  • 增加弱引用计数:用于跟踪有多少个弱引用指向相同的对象。

  • Weak 结构体:用于表示弱引用。

  • 管理弱引用的生命周期:在强引用计数和弱引用计数都为零时释放资源。

  1use std::cell::Cell;
  2use std::ops::Deref;
  3use std::ptr::NonNull;
  4
  5struct Rc<T> {
  6    ptr: NonNull<Inner<T>>,
  7}
  8
  9struct Weak<T> {
 10    ptr: NonNull<Inner<T>>,
 11}
 12
 13struct Inner<T> {
 14    value: T,
 15    strong_count: Cell<usize>,
 16    weak_count: Cell<usize>,
 17}
 18
 19impl<T> Rc<T> {
 20    fn new(value: T) -> Self {
 21        let inner = Box::new(Inner {
 22            value,
 23            strong_count: Cell::new(1),
 24            weak_count: Cell::new(0),
 25        });
 26        Rc {
 27            ptr: unsafe { NonNull::new_unchecked(Box::into_raw(inner)) },
 28        }
 29    }
 30
 31    fn downgrade(&self) -> Weak<T> {
 32        self.inner().weak_count.set(self.weak_count() + 1);
 33        Weak { ptr: self.ptr }
 34    }
 35
 36    fn strong_count(&self) -> usize {
 37        self.inner().strong_count.get()
 38    }
 39
 40    fn weak_count(&self) -> usize {
 41        self.inner().weak_count.get()
 42    }
 43
 44    fn inner(&self) -> &Inner<T> {
 45        unsafe { self.ptr.as_ref() }
 46    }
 47}
 48
 49impl<T> Clone for Rc<T> {
 50    fn clone(&self) -> Self {
 51        self.inner().strong_count.set(self.strong_count() + 1);
 52        Rc { ptr: self.ptr }
 53    }
 54}
 55
 56impl<T> Deref for Rc<T> {
 57    type Target = T;
 58
 59    fn deref(&self) -> &Self::Target {
 60        &self.inner().value
 61    }
 62}
 63
 64impl<T> Drop for Rc<T> {
 65    fn drop(&mut self) {
 66        let strong_count = self.strong_count();
 67        if strong_count > 1 {
 68            self.inner().strong_count.set(strong_count - 1);
 69        } else {
 70            let weak_count = self.weak_count();
 71            if weak_count == 0 {
 72                unsafe { Box::from_raw(self.ptr.as_ptr()); } // 释放 Inner
 73            } else {
 74                self.inner().strong_count.set(0);
 75            }
 76        }
 77    }
 78}
 79
 80impl<T> Weak<T> {
 81    fn upgrade(&self) -> Option<Rc<T>> {
 82        let strong_count = self.strong_count();
 83        if strong_count == 0 {
 84            None
 85        } else {
 86            self.inner().strong_count.set(strong_count + 1);
 87            Some(Rc { ptr: self.ptr })
 88        }
 89    }
 90
 91    fn strong_count(&self) -> usize {
 92        self.inner().strong_count.get()
 93    }
 94
 95    fn weak_count(&self) -> usize {
 96        self.inner().weak_count.get()
 97    }
 98
 99    fn inner(&self) -> &Inner<T> {
100        unsafe { self.ptr.as_ref() }
101    }
102}
103
104impl<T> Clone for Weak<T> {
105    fn clone(&self) -> Self {
106        self.inner().weak_count.set(self.weak_count() + 1);
107        Weak { ptr: self.ptr }
108    }
109}
110
111impl<T> Drop for Weak<T> {
112    fn drop(&mut self) {
113        let weak_count = self.weak_count();
114        if weak_count > 1 {
115            self.inner().weak_count.set(weak_count - 1);
116        } else {
117            let strong_count = self.strong_count();
118            if strong_count == 0 {
119                unsafe { Box::from_raw(self.ptr.as_ptr()); } // 释放 Inner
120            } else {
121                self.inner().weak_count.set(0);
122            }
123        }
124    }
125}
126
127fn main() {
128    let rc1 = Rc::new(5);
129    let weak1 = rc1.downgrade();
130    let rc2 = rc1.clone();
131
132    println!("Strong count: {}", rc1.strong_count());
133    println!("Weak count: {}", rc1.weak_count());
134
135    if let Some(rc3) = weak1.upgrade() {
136        println!("Upgraded value: {}", *rc3);
137    } else {
138        println!("Upgrade failed");
139    }
140
141    drop(rc1);
142    drop(rc2);
143
144    if let Some(rc3) = weak1.upgrade() {
145        println!("Upgraded value after drop: {}", *rc3);
146    } else {
147        println!("Upgrade failed after drop");
148    }
149}

实现 Arc

Arc 的主要区别在于线程安全的,因此需要使用 AtomicUsize 而不是 Cell<usize> 来管理引用计数。

  1use std::sync::atomic::{AtomicUsize, Ordering};
  2use std::sync::Arc as StdArc;
  3use std::ops::Deref;
  4use std::ptr::NonNull;
  5
  6struct Arc<T> {
  7    ptr: NonNull<Inner<T>>,
  8}
  9
 10struct Weak<T> {
 11    ptr: NonNull<Inner<T>>,
 12}
 13
 14struct Inner<T> {
 15    value: T,
 16    strong_count: AtomicUsize,
 17    weak_count: AtomicUsize,
 18}
 19
 20impl<T> Arc<T> {
 21    fn new(value: T) -> Self {
 22        let inner = Box::new(Inner {
 23            value,
 24            strong_count: AtomicUsize::new(1),
 25            weak_count: AtomicUsize::new(0),
 26        });
 27        Arc {
 28            ptr: unsafe { NonNull::new_unchecked(Box::into_raw(inner)) },
 29        }
 30    }
 31
 32    fn downgrade(&self) -> Weak<T> {
 33        self.inner().weak_count.fetch_add(1, Ordering::Relaxed);
 34        Weak { ptr: self.ptr }
 35    }
 36
 37    fn strong_count(&self) -> usize {
 38        self.inner().strong_count.load(Ordering::Relaxed)
 39    }
 40
 41    fn weak_count(&self) -> usize {
 42        self.inner().weak_count.load(Ordering::Relaxed)
 43    }
 44
 45    fn inner(&self) -> &Inner<T> {
 46        unsafe { self.ptr.as_ref() }
 47    }
 48}
 49
 50impl<T> Clone for Arc<T> {
 51    fn clone(&self) -> Self {
 52        self.inner().strong_count.fetch_add(1, Ordering::Relaxed);
 53        Arc { ptr: self.ptr }
 54    }
 55}
 56
 57impl<T> Deref for Arc<T> {
 58    type Target = T;
 59
 60    fn deref(&self) -> &Self::Target {
 61        &self.inner().value
 62    }
 63}
 64
 65impl<T> Drop for Arc<T> {
 66    fn drop(&mut self) {
 67        if self.inner().strong_count.fetch_sub(1, Ordering::Release) == 1 {
 68            std::sync::atomic::fence(Ordering::Acquire);
 69            if self.weak_count() == 0 {
 70                unsafe { Box::from_raw(self.ptr.as_ptr()); } // 释放 Inner
 71            } else {
 72                self.inner().strong_count.store(0, Ordering::Relaxed);
 73            }
 74        }
 75    }
 76}
 77
 78impl<T> Weak<T> {
 79    fn upgrade(&self) -> Option<Arc<T>> {
 80        let strong_count = self.strong_count();
 81        if strong_count == 0 {
 82            None
 83        } else {
 84            self.inner().strong_count.fetch_add(1, Ordering::Relaxed);
 85            Some(Arc { ptr: self.ptr })
 86        }
 87    }
 88
 89    fn strong_count(&self) -> usize {
 90        self.inner().strong_count.load(Ordering::Relaxed)
 91    }
 92
 93    fn weak_count(&self) -> usize {
 94        self.inner().weak_count.load(Ordering::Relaxed)
 95    }
 96
 97    fn inner(&self) -> &Inner<T> {
 98        unsafe { self.ptr.as_ref() }
 99    }
100}
101
102impl<T> Clone for Weak<T> {
103    fn clone(&self) -> Self {
104        self.inner().weak_count.fetch_add(1, Ordering::Relaxed);
105        Weak { ptr: self.ptr }
106    }
107}
108
109impl<T> Drop for Weak<T> {
110    fn drop(&mut self) {
111        if self.inner().weak_count.fetch_sub(1, Ordering::Release) == 1 {
112            std::sync::atomic::fence(Ordering::Acquire);
113            if self.strong_count() == 0 {
114                unsafe { Box::from_raw(self.ptr.as_ptr()); } // 释放 Inner
115            }
116        }
117    }
118}
119
120fn main() {
121    let arc1 = Arc::new(5);
122    let weak1 = arc1.downgrade();
123    let arc2 = arc1.clone();
124
125    println!("Strong count: {}", arc1.strong_count());
126    println!("Weak count: {}", arc1.weak_count());
127
128    if let Some(arc3) = weak1.upgrade() {
129        println!("Upgraded value: {}", *arc3);
130    } else {
131        println!("Upgrade failed");
132    }
133
134    drop(arc1);
135    drop(arc2);
136
137    if let Some(arc3) = weak1.upgrade() {
138        println!("Upgraded value after drop: {}", *arc3);
139    } else {
140        println!("Upgrade failed after drop");
141    }
142}

参考资料