C++:深度使用 shared_ptr 的一些经验教训

看了网上,大多数都是告诉你什么不要循环引用、树形节点父子用 rc,子父用 weak 这种泛泛而谈。凡是论打破循环引用,总是说加 weak,却不提这样改可能导致的过早失效等问题。

最近在写一个试验性的库,高强度使用了这些指针。所以总结一下经验。

shared_ptr 意味着什么?

“共享所有权”一定是第一个跳入你脑海的词汇,因为它的名字就反映了这一点。但我想说的是,有时候我们可以尝试从依赖的视角来分析这个问题。

如果你的结构体或者类含有一个共享指针的话,这意味着:

  1. 你的结构体强依赖于它指向的东西

  2. 并且他指向的东西也可能被别的东西所强依赖

  3. 并且你需要确保它们都依赖的是同一个实例。

这就引申出了一个共享指针的使用原则:依赖关系是如何的?

依赖这个东西看起来简单但是一旦套入到一个实际场景就很容易犯错误。

例如很多人以为是工人离开老板就不能活,但是实际上反而是老板离不开工人。这里构成的依赖关系就是老板依赖于工人。

gh

(图中箭头表示“依赖于”)

很多时候这种东西是下意识的,因此可能会在后面给你带来困扰。

例如当建模一个符号 A 时,A 的值和梯度谁依赖谁?实际上梯度依赖于值,但大家会下意识地将其并列:

struct A {
  TensorPtr value;
  TensorPtr grad;
};

当然,这种设计也并不能说是错的,万事万物视具体实际情况而定。

原则一:如果A 都没有了那么 B 也无法存在,那么说明 B 依赖于 A

这里实际上就涉及到的生命周期。就是说你所强依赖的东西,他的生命期一定要比你长。是的,虽然我们不写 Rust,C++ 编译器也不会帮你检查,但是你违反这个规则,就会产生隐患。

从而我们可以推导第二个原则。

原则二:如果 A 的生命期要大于 B 的生命期并且 B 强依赖 A,那么建模时,可以考虑 A 作为 B 的一个所有权指针而存在。

例如,一个函数有输入、算符、和输出。一种直觉上的建模方式是:

1struct Fn {
2  Rc<Var> in;
3  Rc<Var> out;
4
5  void apply() { ... }
6}

但是实际上,Fn 依赖于 in,但并不依赖于 out。我们可以考虑让外部负责维护这个 out。

1struct Fn {
2  Uniq<Var> in;
3  Var* out;
4
5  void apply (...
6  void set_out (...
7}

这样,当组成链式调用的时候,你得到的是一条单向的依赖链条。而你的输出,被用户或另一个函数持有的变量所有。

gh

(上图的箭头是数据流方向)

数据流与依赖方向是平行的,这样的代码我们甚至不需要 shared_ptr,相比于共同持有所有权会更好维护。

实际上,除非代码里明显用到了图类结构,否则大多数情况 unique_ptr 可以满足需求。

如果由于实际需要,依然需要用 shared_ptr(例如为了节省内存,允许多个 Fn 共用一个输入),单向的依赖也会更好管理。

1struct Fn {
2  Rc<Var> in;
3  Weak<Var> out;
4
5  void apply (...
6  void set_out (...
7}

shared_ptr 成环并不一定内存泄漏

一个常见的误区是 shared_ptr 成环一定会导致无法析构,从而内存泄漏。让我们举一个反例。

 1#include <iostream>
 2#include <memory>
 3
 4struct Node {
 5    int value;
 6    std::shared_ptr<Node> next;
 7    std::weak_ptr<Node> prev;
 8
 9    Node(int val) : value(val) {}
10    ~Node() {
11        std::cout << "Node with value " << value << " destroyed" << std::endl;
12    }
13};
14
15class CircularLinkedList {
16public:
17    CircularLinkedList() : head(nullptr) {}
18
19    ~CircularLinkedList() {
20        if (head) {
21            head->next.reset();
22        }
23    }
24
25    void append(int value) {
26        std::shared_ptr<Node> newNode = std::make_shared<Node>(value);
27        if (!head) {
28            head = newNode;
29            head->next = head;
30            head->prev = head;
31        } else {
32            std::shared_ptr<Node> tail = head->prev.lock();
33            tail->next = newNode;
34            newNode->prev = tail;
35            newNode->next = head;
36            head->prev = newNode;
37        }
38    }
39
40    void display() {
41        if (!head) {
42            std::cout << "List is empty" << std::endl;
43            return;
44        }
45
46        std::shared_ptr<Node> current = head;
47        do {
48            std::cout << current->value << " ";
49            current = current->next;
50        } while (current != head);
51        std::cout << std::endl;
52    }
53
54private:
55    std::shared_ptr<Node> head;
56};
57
58int main() {
59    CircularLinkedList list;
60    list.append(1);
61    list.append(2);
62    list.append(3);
63    list.append(4);
64
65    list.display(); // Output: 1 2 3 4
66
67    return 0;
68}

上面的代码,我们虽然让 shared_ptr 成环了,但是并不影响 CircularLinkedList 的实例 list 离开作用域后调用析构,只要我们在析构函数中破环,就能避免泄漏。

不过,如果你的对象都是同生共死,而且需要相互引用的,那么实际上并不推荐使用共享指针。更推荐的方法是使用 Arena allocator 或者 GC。

是的,C++ 其实很容易自己实现 GC,尽管可能会导致代码比较丑陋。

unique_ptr 并不等于不会导致循环引用

一个常见的误区是,认为独占指针就不会导致循环引用。实际上独占指针并不阻止你相互独占。

一个简单的例子:

 1#include <iostream>
 2#include <memory>
 3
 4class B;
 5
 6class A {
 7public:
 8    std::unique_ptr<B> b_ptr;
 9    ~A() { std::cout << "A destroyed\n"; }
10};
11
12class B {
13public:
14    std::unique_ptr<A> a_ptr;
15    ~B() { std::cout << "B destroyed\n"; }
16};
17
18int main() {
19    auto a = std::make_unique<A>();
20    auto b = std::make_unique<B>();
21
22    a->b_ptr = std::move(b);
23    a->b_ptr->a_ptr = std::move(a);
24
25    return 0;
26}

而且独占指针一旦和共享指针联合使用的话更容易产生循环引用的问题,这里不展开。

一个图结构的建模小技巧

最后提供一个利用智能指针管理图结构的技巧,尤其是有环图。

这个方法是,通过使用一个上下文对象(如 Graph 类)来管理所有节点的所有权,可以避免循环引用问题。所有的边可以使用 std::weak_ptr 来管理,这样每个节点的生命周期都由 Graph 对象管理,而不是由节点之间的引用关系管理。

 1#include <iostream>
 2#include <memory>
 3#include <vector>
 4#include <unordered_map>
 5#include <algorithm>
 6
 7class Node {
 8public:
 9    int value;
10    std::vector<std::weak_ptr<Node>> neighbors;
11
12    Node(int val) : value(val) {
13        std::cout << "Node " << value << " created.\n";
14    }
15
16    ~Node() {
17        std::cout << "Node " << value << " destroyed.\n";
18    }
19};
20
21class Graph {
22public:
23    std::unordered_map<int, std::shared_ptr<Node>> nodes;
24
25    void addNode(int value) {
26        nodes[value] = std::make_shared<Node>(value);
27    }
28
29    void addEdge(int from, int to) {
30        auto fromNode = nodes[from];
31        auto toNode = nodes[to];
32        if (fromNode && toNode) {
33            fromNode->neighbors.push_back(toNode);
34        }
35    }
36
37    void removeNode(int value) {
38        auto it = nodes.find(value);
39        if (it != nodes.end()) {
40            for (auto& [otherValue, otherNode] : nodes) {
41                if (otherValue != value) {
42                    otherNode->neighbors.erase(
43                        std::remove_if(otherNode->neighbors.begin(), otherNode->neighbors.end(),
44                                       [&it](const std::weak_ptr<Node>& weak_neighbor) {
45                                           auto neighbor = weak_neighbor.lock();
46                                           return neighbor && neighbor->value == it->first;
47                                       }),
48                        otherNode->neighbors.end());
49                }
50            }
51            nodes.erase(it);
52        }
53    }
54
55    std::shared_ptr<Node> getNode(int value) {
56        return nodes[value];
57    }
58};
59
60int main() {
61    Graph graph;
62
63    graph.addNode(1);
64    graph.addNode(2);
65    graph.addNode(3);
66
67    graph.addEdge(1, 2);
68    graph.addEdge(2, 3);
69    graph.addEdge(3, 1);
70
71    std::cout << "Initial graph:\n";
72    for (const auto& [value, node] : graph.nodes) {
73        std::cout << "Node " << value << " has neighbors: ";
74        for (const auto& weak_neighbor : node->neighbors) {
75            if (auto neighbor = weak_neighbor.lock()) {
76                std::cout << neighbor->value << " ";
77            }
78        }
79        std::cout << "\n";
80    }
81
82    graph.removeNode(2);
83
84    std::cout << "After removing node 2:\n";
85
86    for (const auto& [value, node] : graph.nodes) {
87        std::cout << "Node " << value << " has neighbors: ";
88        for (const auto& weak_neighbor : node->neighbors) {
89            if (auto neighbor = weak_neighbor.lock()) {
90                std::cout << neighbor->value << " ";
91            } else {
92                std::cout << "(expired) ";
93            }
94        }
95        std::cout << "\n";
96    }
97
98    return 0;
99}