Designing and Implementing Merkle Trees and Their Algorithms
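Hi, I'm Pluveto, and I'm learning to become a blockchain engineer. Today's topic is the Merkle tree. The Merkle 🌲 is a cleverly conceived, elegantly simple data structure. It can quickly (more precisely, in logarithmic time, $O(\log n)$) verify whether a data block belongs to a larger data set, and it can even tell you where the block is.
I'll implement it in Python so that more people can follow along. I'll also point out a few pitfalls that other articles are likely to skip. Let's get started.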
Just stacking hashes on top of each other…
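Before writing any code, let's sort out the idea. A Merkle tree is also called a hash tree. Suppose we have a list $l$ with 4 elements:
```
Alice Bob Caro David
```
We can easily compute the hash of each one, written $H(A)$, $H(B)$, $H(C)$, $H(D)$, and use them as the leaf nodes.
Next we compute $H(AB) = H(H(A) \cdot H(B))$ and $H(CD) = H(H(C) \cdot H(D))$, where $\cdot$ denotes concatenation. This gives us two hash values, which form the second-to-last level.
Finally, computing $H(ABCD) = H(H(AB) \cdot H(CD))$ gives us the root node.
```
        H(ABCD)
       /       \
    H(AB)     H(CD)
    /   \     /   \
  H(A) H(B) H(C) H(D)
```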
This tree looks easy to understand, but there is enormous power behind it!
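To make the definitions concrete, here is a minimal sketch that computes the root of the four-leaf tree by hand. It assumes SHA-256 as $H$ and plain byte concatenation as $\cdot$. The article does not fix a hash function at this point, so both choices are assumptions.

```python
import hashlib


def H(data: bytes) -> bytes:
    """Hash function for this sketch (SHA-256 is an assumption)."""
    return hashlib.sha256(data).digest()


# Leaves: H(A), H(B), H(C), H(D)
h_a, h_b, h_c, h_d = (H(s.encode("utf-8")) for s in ["Alice", "Bob", "Caro", "David"])

# Second-to-last level: H(AB) = H(H(A) · H(B)), H(CD) = H(H(C) · H(D))
h_ab = H(h_a + h_b)
h_cd = H(h_c + h_d)

# Root: H(ABCD) = H(H(AB) · H(CD))
h_abcd = H(h_ab + h_cd)
print(h_abcd.hex())
```

Changing any single name, or swapping the order of two children, produces a different `h_abcd`. The rest of the article relies on exactly this property.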
Saying No to Lies: Discovering the Merkle Tree
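Now suppose you are the blockchain system and you know everything about this tree. You also keep the real values of A, B, C and D (Alice, Bob, …) somewhere else, but reading data from there is very slow.
Now Lao Wang claims that he has put a data block D into the tree.
How do we know whether he is lying?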
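Option 1: Scan the entire database
Much like substring matching, we take the data D that Lao Wang uploaded and search the entire database (think of it as one very long string) until we find D. Then we tell Lao Wang: you weren't lying.
However, the system is so large that by the time we find the data, the person in front of you might be Lao Wang's grandson…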
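Option 2: Record the hash of every piece of data, then search the hashes
This time we don't search the raw data, because anyone could hand you a huge blob to search for, which would amount to a DDoS attack on you. Instead, you record the hash of every transaction (a transaction is just a piece of data). To check whether a transaction exists, you scan the hashes of all transactions in the block. (Lao Wang could get lucky if a hash he made up happens to collide with one in the list, but the probability is negligible.)
This is much faster, until you discover an even cleverer approach.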
Option 3: Hash tree
Look back at the tree from the beginning:
```
        H(ABCD)
       /       \
    H(AB)     H(CD)
    /   \     /   \
  H(A) H(B) H(C) H(D)
```
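You notice that:
- If Lao Wang gives you D and $H(C)$, you can compute $H(CD)$ (hashing D yourself to get $H(D)$).
- If he also gives you $H(AB)$, you can compute $H(ABCD)$.
- If the $H(ABCD)$ you computed differs from the real $H(ABCD)$, Lao Wang lied to you.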
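This works because if D or any element of $[H(AB), H(C)]$ is invalid, the computed $H(ABCD)$ will differ from the real one, unless someone finds a hash collision, which is computationally infeasible for a secure hash function. We call $[H(AB), H(C)]$ a proof.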
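The bullet steps above can be sketched as follows. The sketch assumes SHA-256 and concatenation as before. It also hard-codes the fact that D is the right child at both levels, so each sibling is concatenated on the left. All names are illustrative.

```python
import hashlib
from typing import List


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


# What the verifier trusts: the real root, computed once from the full data set.
leaves = [H(s.encode("utf-8")) for s in ["Alice", "Bob", "Caro", "David"]]
h_ab, h_cd = H(leaves[0] + leaves[1]), H(leaves[2] + leaves[3])
trusted_root = H(h_ab + h_cd)

# What Lao Wang sends: the data block D and the proof [H(AB), H(C)] (top-down).
claimed_block = b"David"
proof = [h_ab, leaves[2]]


def recompute_root(block: bytes, proof: List[bytes]) -> bytes:
    """Rebuild the root bottom-up for a block that is the right child at every level."""
    current = H(block)
    for sibling in reversed(proof):  # the proof is top-down, so walk it backwards
        current = H(sibling + current)  # in this example the sibling is always on the left
    return current


print(recompute_root(claimed_block, proof) == trusted_root)  # True
print(recompute_root(b"Mallory", proof) == trusted_root)  # False
```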
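The proof also lets us pin down where the element is (0 means left, 1 means right):
- Start from the root. The location sequence is empty: [].
- The first element of the proof is $H(AB)$, which is the left child, so we take the other branch, $H(CD)$. The location sequence becomes [1].
- The next element is $H(C)$, so we take its sibling branch, $H(D)$. The location sequence becomes [1, 1].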
Congratulations, you have discovered the Merkle tree and its properties.
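The position walk above can be sketched in a few lines. Like `get_location` later on, the sketch records only branch choices and gives the root no digit. It assumes the tree is stored as a list of levels, where `levels[0]` holds the leaf hashes and `levels[-1]` holds only the root, and that the leaf count is a power of two. That layout and the helper names `build_levels` and `locate` are assumptions of this sketch, not the Node-based structure built below.

```python
import hashlib
from typing import List, Optional


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def build_levels(leaves: List[bytes]) -> List[List[bytes]]:
    """Build all levels bottom-up; assumes len(leaves) is a power of two."""
    levels = [leaves]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        levels.append([H(prev[i] + prev[i + 1]) for i in range(0, len(prev), 2)])
    return levels


def locate(levels: List[List[bytes]], proof: List[bytes]) -> Optional[List[int]]:
    """Follow a top-down proof from the root; 0 = go left, 1 = go right."""
    path: List[int] = []
    index = 0  # position of the current node within its level
    for depth, sibling in zip(range(len(levels) - 2, -1, -1), proof):
        left, right = levels[depth][2 * index], levels[depth][2 * index + 1]
        if sibling == left:  # the sibling is on the left, so we go right
            index = 2 * index + 1
            path.append(1)
        elif sibling == right:  # the sibling is on the right, so we go left
            index = 2 * index
            path.append(0)
        else:
            return None  # the proof does not match this tree
    return path


leaves = [H(s.encode("utf-8")) for s in ["Alice", "Bob", "Caro", "David"]]
levels = build_levels(leaves)
h_ab = levels[1][0]
print(locate(levels, [h_ab, leaves[2]]))  # [1, 1]
```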
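Implementing the Merkle Tree
Defining the tree node
Since this is a binary tree, we only define left and right children. We also add an is_copied field that indicates whether a node was produced by copying. To guarantee a binary tree, whenever there is an odd number of nodes we duplicate the last one so that every node has a partner.
Notes:
- The content field exists for learning and debugging only; remove it in production.
- For performance-critical environments, it is better to implement this in another language and optimize the algorithm.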
```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Node:
    """
    Represents a binary tree node, in our case a Merkle Tree node.

    Attributes:
        left: left child node
        right: right child node
        value: hash value of the node
        content: content of the node
        is_copied: whether the node is a copy, this is because we duplicate
            the last element if the number of elements is odd when building
    """

    left: Optional["Node"]
    right: Optional["Node"]
    value: bytes
    content: bytes  # just for debugging, remove in production
    is_copied: bool = False

    def __str__(self) -> str:
        # hash digests are generally not valid UTF-8, so show them as hex
        return self.value.hex()

    def copy(self) -> "Node":
        """
        Get a duplicate of the node.
        We mark the node with is_copied=True when duplicating the last element.
        """
        return Node(self.left, self.right, self.value, self.content, True)
```
Tree construction algorithm
First, define the basic shape of the class.
```python
class MerkleTree:
    """
    Represents a Merkle Tree, for which every leaf node is labelled with the hash
    of a data block, and every non-leaf node is labelled with the crypto hash of
    the labels of its child nodes. It is used to verify the integrity of blocks.
    """

    def __init__(self, values: List[bytes], hash_fn: HashFn) -> None:
        self._hash_fn = hash_fn
        self._root: Node = self._buildTree(values)

    def __str__(self) -> str:
        return self._root.value.hex()

    @property
    def root(self) -> Node:
        """
        Return a copy of the root node of the Merkle Tree.
        """
        return self._root.copy()
```
The _buildTree function builds the tree. When there is an odd number of nodes, the last one is duplicated to complete the pair.
```python
    def _buildTree(self, values: List[bytes]) -> Node:
        if not values:
            # without this guard an empty list recurses until RecursionError
            raise ValueError("cannot build a Merkle tree from an empty list")
        leaves: List[Node] = [Node(None, None, self._hash_fn(e), e) for e in values]
        return self._buildTreeRec(leaves)

    def _buildTreeRec(self, nodes: List[Node]) -> Node:
        # duplicate last elem if odd number of elements
        if len(nodes) % 2 == 1:
            nodes.append(nodes[-1].copy())

        half: int = len(nodes) // 2
        if len(nodes) == 2:
            value = self._hash_fn(bytes_xor(nodes[0].value, nodes[1].value))
            return Node(nodes[0], nodes[1], value, nodes[0].content + nodes[1].content)

        left: Node = self._buildTreeRec(nodes[:half])
        right: Node = self._buildTreeRec(nodes[half:])
        value: bytes = self._hash_fn(bytes_xor(left.value, right.value))
        return Node(left, right, value, left.content + right.content)
```
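Note that `_buildTreeRec` combines children with `bytes_xor`, not with the concatenation $\cdot$ used in the definition $H(AB) = H(H(A) \cdot H(B))$. This has two consequences:
- XOR is commutative, so the order of the children is not committed to.
- Because $x \oplus x = 0$, hashing an odd node together with its own copy always yields $H(0\ldots0)$. With the code above, for example, the lists [A, B, C] and [A, B, D] produce the same root.

Below is a sketch of a pairing step that concatenates instead and still duplicates the last node on odd levels. The helpers `hash_pair` and `build_root` are hypothetical and not part of the article's code. The sketch builds level by level rather than by recursive halving, so for some leaf counts (for example 1, 5 or 6) its tree shape differs from `_buildTreeRec`.

```python
import hashlib
from typing import Callable, List

HashFn = Callable[[bytes], bytes]


def sha256(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def hash_pair(left: bytes, right: bytes, hash_fn: HashFn) -> bytes:
    """Parent hash = H(left · right); unlike XOR, the order of children matters."""
    return hash_fn(left + right)


def build_root(values: List[bytes], hash_fn: HashFn) -> bytes:
    """Level-by-level construction; duplicates the last node of any odd level."""
    if not values:
        raise ValueError("cannot build a Merkle tree from an empty list")
    level = [hash_fn(v) for v in values]
    while len(level) > 1:
        if len(level) % 2 == 1:
            level.append(level[-1])  # same odd-length padding idea as the article
        level = [
            hash_pair(level[i], level[i + 1], hash_fn) for i in range(0, len(level), 2)
        ]
    return level[0]


print(build_root([b"A", b"B", b"C"], sha256) == build_root([b"A", b"B", b"D"], sha256))  # False
```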
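Finding a leaf's location
This function finds the location by traversing the tree, so its complexity is $O(n)$. **In practice, the location should be derived from the element's index in the original list** rather than by calling this function, at least when the number of leaves is a power of two. For example, with 8 leaves, index 6 is 110 in binary, so the location sequence is [1, 1, 0]. Smaller indices are padded with leading zeros to the tree depth, so index 1 becomes [0, 0, 1].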
```python
    def get_location(self, block_hash: bytes) -> Optional[List[int]]:
        """
        Get the location of a block hash in a Merkle Tree
        """
        return self._get_location_rec(self._root, block_hash, [])

    def _get_location_rec(
        self, node: Optional[Node], block_hash: bytes, path: List[int]
    ) -> Optional[List[int]]:
        if node is None:
            return None

        if node.value == block_hash and not node.is_copied:
            return path

        left_path = self._get_location_rec(node.left, block_hash, path + [0])
        if left_path is not None:
            return left_path

        right_path = self._get_location_rec(node.right, block_hash, path + [1])
        if right_path is not None:
            return right_path

        return None
```
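The index-to-binary shortcut mentioned above can be sketched like this. It assumes the number of leaves is a power of two. With the article's recursive halving, other sizes shift some positions (with 6 leaves, for instance, leaf 3 ends up at [1, 0, 0]), so the shortcut does not hold in general. `location_from_index` is a hypothetical helper.

```python
from typing import List


def location_from_index(index: int, num_leaves: int) -> List[int]:
    """Return the top-down path (0 = left, 1 = right) of leaf `index`.

    Assumes num_leaves is a power of two, so every leaf sits at depth log2(num_leaves).
    """
    if num_leaves < 1 or num_leaves & (num_leaves - 1):
        raise ValueError("num_leaves must be a power of two")
    if not 0 <= index < num_leaves:
        raise IndexError("leaf index out of range")
    depth = num_leaves.bit_length() - 1
    # The binary digits of the index, most significant first, are the branch choices.
    return [(index >> shift) & 1 for shift in range(depth - 1, -1, -1)]


print(location_from_index(6, 8))  # [1, 1, 0]
```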
Generating the proof
This is the key part. The idea is to walk down from the root along the location path and record the sibling hash at every step.
```python
    def get_proof(self, block_hash: bytes) -> Optional[List[bytes]]:
        """
        Get the proof of a block hash in a Merkle Tree.
        The proof is a list of sibling hashes of the block hash.

        Note:
            This function returns hashes in the top-down order. So, don't forget
            to reverse the list when you want to verify the proof.
        """
        location = self.get_location(block_hash)
        if location is None:
            return None

        return self._get_proof_rec(self._root, location, 0, [])

    def _get_proof_rec(
        self, node: Optional[Node], location: List[int], index: int, proof: List[bytes]
    ) -> Optional[List[bytes]]:
        if node is None or index >= len(location):
            return proof

        if node.right and location[index] == 0:
            proof.append(node.right.value)
            return self._get_proof_rec(node.left, location, index + 1, proof)

        elif node.left and location[index] == 1:
            proof.append(node.left.value)
            return self._get_proof_rec(node.right, location, index + 1, proof)

        return None
```
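`get_proof` first calls `get_location`, which may visit every node, so it costs $O(n)$ overall even though the descent itself is only $O(\log n)$. When the leaf index is already known, a proof can be collected in $O(\log n)$ from a level-by-level layout instead. This sketch makes the same assumptions as before: `levels[0]` holds the leaf hashes, the leaf count is a power of two, and `build_levels` and `proof_for_index` are hypothetical names. It returns siblings top-down, matching the order `get_proof` uses.

```python
import hashlib
from typing import List


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def build_levels(leaves: List[bytes]) -> List[List[bytes]]:
    """levels[0] are the leaf hashes, levels[-1] == [root]; assumes a power-of-two count."""
    levels = [leaves]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        levels.append([H(prev[i] + prev[i + 1]) for i in range(0, len(prev), 2)])
    return levels


def proof_for_index(levels: List[List[bytes]], index: int) -> List[bytes]:
    """Collect sibling hashes bottom-up, then reverse them into top-down order."""
    proof = []
    for level in levels[:-1]:  # every level except the root
        sibling = index ^ 1  # flip the lowest bit: 2k <-> 2k + 1
        proof.append(level[sibling])
        index //= 2  # move to the parent's position
    return proof[::-1]


leaves = [H(s.encode("utf-8")) for s in ["Alice", "Bob", "Caro", "David"]]
levels = build_levels(leaves)
print(proof_for_index(levels, 3) == [levels[1][0], leaves[2]])  # True, i.e. [H(AB), H(C)]
```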
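Implementing verification
Verification is simple: iterate over the proof in reverse order, hashing step by step, and finally compare the result with the root hash.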
```python
def verify_proof(
    root_hash: bytes, block_hash: bytes, proof: List[bytes], hash_fn: HashFn
) -> bool:
    """verify if a block hash is in a Merkle Tree with a given root hash and proof"""
    current_hash = block_hash
    # get_proof returns siblings top-down, so walk them bottom-up here
    for sibling_hash in reversed(proof):
        current_hash = hash_fn(bytes_xor(current_hash, sibling_hash))

    return current_hash == root_hash
```
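Because `verify_proof` combines hashes with `bytes_xor`, it is not bound to positions and it does not check the proof length. XOR is its own inverse, so anyone who knows the two children of the root can craft a one-element "proof" for an arbitrary hash. Those children are revealed by any honest proof. The sketch below shows this in two parts:
- `xor_verify` is a hypothetical function that reproduces the article's XOR check with SHA-256 fixed, to demonstrate the forgery.
- `verify_with_location` is a hypothetical location-aware variant. It assumes concatenation, $H(\text{left} \cdot \text{right})$, and takes the top-down location sequence produced by `get_location`.

```python
import hashlib
from typing import List


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def bytes_xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def xor_verify(root: bytes, block_hash: bytes, proof: List[bytes]) -> bool:
    """Same logic as the article's verify_proof, with SHA-256 fixed."""
    current = block_hash
    for sibling in reversed(proof):
        current = H(bytes_xor(current, sibling))
    return current == root


# Honest XOR tree over four leaves (same shape as the article's code).
leaves = [H(s.encode("utf-8")) for s in ["Alice", "Bob", "Caro", "David"]]
left, right = H(bytes_xor(leaves[0], leaves[1])), H(bytes_xor(leaves[2], leaves[3]))
xor_root = H(bytes_xor(left, right))

# Forgery: pick any hash, then choose the single "sibling" that cancels it out.
fake = H(b"never inserted")
forged_proof = [bytes_xor(bytes_xor(fake, left), right)]
print(xor_verify(xor_root, fake, forged_proof))  # True, although fake is not a leaf


def verify_with_location(
    root: bytes, block_hash: bytes, proof: List[bytes], location: List[int]
) -> bool:
    """Concatenation-based check; location is top-down, 0 = left child, 1 = right child."""
    if len(proof) != len(location):
        return False  # exactly one sibling per level is required
    current = block_hash
    for sibling, bit in zip(reversed(proof), reversed(location)):
        current = H(current + sibling) if bit == 0 else H(sibling + current)
    return current == root


h_ab, h_cd = H(leaves[0] + leaves[1]), H(leaves[2] + leaves[3])
concat_root = H(h_ab + h_cd)
print(verify_with_location(concat_root, leaves[3], [h_ab, leaves[2]], [1, 1]))  # True
```

A production verifier would also pin the expected tree depth, or, as RFC 6962 does, hash leaves and internal nodes with different one-byte prefixes. Either measure stops an internal node from being passed off as a leaf.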
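Testing the Merkle Tree
The complete implementation (imported by the tests as the module merkle_tree) and the test code are as follows: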
```python
import dataclasses

from typing import Callable, Optional, List

HashFn = Callable[[bytes], bytes]


def bytes_xor(a: bytes, b: bytes) -> bytes:
    assert len(a) == len(b), "length of a and b should be equal"
    return bytes([_a ^ _b for _a, _b in zip(a, b)])


def verify_proof(
    root_hash: bytes, block_hash: bytes, proof: List[bytes], hash_fn: HashFn
) -> bool:
    """verify if a block hash is in a Merkle Tree with a given root hash and proof"""
    current_hash = block_hash
    # get_proof returns siblings top-down, so walk them bottom-up here
    for sibling_hash in reversed(proof):
        current_hash = hash_fn(bytes_xor(current_hash, sibling_hash))

    return current_hash == root_hash


@dataclasses.dataclass
class Node:
    """
    Represents a binary tree node, in our case a Merkle Tree node.

    Attributes:
        left: left child node
        right: right child node
        value: hash value of the node
        content: content of the node
        is_copied: whether the node is a copy, this is because we duplicate
            the last element if the number of elements is odd when building
    """

    left: Optional["Node"]
    right: Optional["Node"]
    value: bytes
    content: bytes  # just for debugging, remove in production
    is_copied: bool = False

    def __str__(self) -> str:
        # hash digests are generally not valid UTF-8, so show them as hex
        return self.value.hex()

    def copy(self) -> "Node":
        """
        Get a duplicate of the node.
        We mark the node with is_copied=True when duplicating the last element.
        """
        return Node(self.left, self.right, self.value, self.content, True)


class MerkleTree:
    """
    Represents a Merkle Tree, for which every leaf node is labelled with the hash
    of a data block, and every non-leaf node is labelled with the crypto hash of
    the labels of its child nodes. It is used to verify the integrity of blocks.
    """

    def __init__(self, values: List[bytes], hash_fn: HashFn) -> None:
        self._hash_fn = hash_fn
        self._root: Node = self._buildTree(values)

    def __str__(self) -> str:
        return self._root.value.hex()

    @property
    def root(self) -> Node:
        """
        Return a copy of the root node of the Merkle Tree.
        """
        return self._root.copy()

    def _buildTree(self, values: List[bytes]) -> Node:
        if not values:
            # without this guard an empty list recurses until RecursionError
            raise ValueError("cannot build a Merkle tree from an empty list")
        leaves: List[Node] = [Node(None, None, self._hash_fn(e), e) for e in values]
        return self._buildTreeRec(leaves)

    def _buildTreeRec(self, nodes: List[Node]) -> Node:
        # duplicate last elem if odd number of elements
        if len(nodes) % 2 == 1:
            nodes.append(nodes[-1].copy())

        half: int = len(nodes) // 2
        if len(nodes) == 2:
            value = self._hash_fn(bytes_xor(nodes[0].value, nodes[1].value))
            return Node(nodes[0], nodes[1], value, nodes[0].content + nodes[1].content)

        left: Node = self._buildTreeRec(nodes[:half])
        right: Node = self._buildTreeRec(nodes[half:])
        value: bytes = self._hash_fn(bytes_xor(left.value, right.value))
        return Node(left, right, value, left.content + right.content)

    def compare_trees(self, other: "MerkleTree") -> bool:
        """
        Compare the root hashes of two Merkle Trees
        """
        return self._root.value == other._root.value

    def verify_block(self, root_hash: bytes, block_hash: bytes) -> bool:
        """
        Verify if a block hash is in a Merkle Tree with a given root hash
        """
        return self._root.value == root_hash and self._verify_block_rec(
            self._root, block_hash
        )

    def _verify_block_rec(self, node: Optional[Node], block_hash: bytes) -> bool:
        if node is None:
            return False

        return (
            node.value == block_hash
            or self._verify_block_rec(node.left, block_hash)
            or self._verify_block_rec(node.right, block_hash)
        )

    def get_location(self, block_hash: bytes) -> Optional[List[int]]:
        """
        Get the location of a block hash in a Merkle Tree
        """
        return self._get_location_rec(self._root, block_hash, [])

    def _get_location_rec(
        self, node: Optional[Node], block_hash: bytes, path: List[int]
    ) -> Optional[List[int]]:
        if node is None:
            return None

        if node.value == block_hash and not node.is_copied:
            return path

        left_path = self._get_location_rec(node.left, block_hash, path + [0])
        if left_path is not None:
            return left_path

        right_path = self._get_location_rec(node.right, block_hash, path + [1])
        if right_path is not None:
            return right_path

        return None

    def get_proof(self, block_hash: bytes) -> Optional[List[bytes]]:
        """
        Get the proof of a block hash in a Merkle Tree.
        The proof is a list of sibling hashes of the block hash.

        Note:
            This function returns hashes in the top-down order. So, don't forget
            to reverse the list when you want to verify the proof.
        """
        location = self.get_location(block_hash)
        if location is None:
            return None

        return self._get_proof_rec(self._root, location, 0, [])

    def _get_proof_rec(
        self, node: Optional[Node], location: List[int], index: int, proof: List[bytes]
    ) -> Optional[List[bytes]]:
        if node is None or index >= len(location):
            return proof

        if node.right and location[index] == 0:
            proof.append(node.right.value)
            return self._get_proof_rec(node.left, location, index + 1, proof)

        elif node.left and location[index] == 1:
            proof.append(node.left.value)
            return self._get_proof_rec(node.right, location, index + 1, proof)

        return None

    def print_tree(self, brief: bool = True) -> None:
        """
        Print the Merkle Tree in a tree structure.
        """
        self._print_tree_rec(self._root, 0, brief)

    def _print_tree_rec(self, node: Optional[Node], level: int, brief: bool) -> None:
        """helper function for print_tree"""
        if node is None:
            return

        # show only the first 4 bytes of each hash in brief mode
        value = (node.value[:4] if brief else node.value).hex()
        content = node.content
        print(f'{"  " * level}{value=}, {content=}')
        self._print_tree_rec(node.left, level + 1, brief)
        self._print_tree_rec(node.right, level + 1, brief)
```
The tests:
```python
import hashlib
import unittest
from merkle_tree import MerkleTree, verify_proof


def sha256(val: bytes) -> bytes:
    return hashlib.sha256(val).digest()


class TestMerkleTree(unittest.TestCase):
    def setUp(self):
        """set up a MerkleTree with some testing data"""
        self._data = list(
            map(
                lambda x: x.encode("utf-8"),
                [
                    # https://en.wikipedia.org/wiki/Classical_Chinese_poetry
                    "Li Bai",
                    "Du Fu",
                    "Wang Wei",
                    "Bai Juyi",
                    "Su Shi",
                    "Li Shangyin",
                    "Li Qingzhao",
                    "Wang Anshi",
                ],
            )
        )
        self._hash_fn = sha256
        self._tree = MerkleTree(self._data, self._hash_fn)
        self._tree.print_tree()

    def test_verify_block(self):
        """test the verify_block method"""
        root_hash = self._tree.root.value
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            # verify if the block hash is in the tree
            self.assertTrue(self._tree.verify_block(root_hash, block_hash))

        # verify an invalid block hash
        self.assertFalse(
            self._tree.verify_block(root_hash, self._hash_fn("invalid".encode("utf-8")))
        )

    def test_get_location(self):
        """test the get_location method"""
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            location = self._tree.get_location(block_hash)
            # compare the location with the expected value location(block_hash)
            assert location is not None, "location should not be None"
            # location should conform to the binary counting sequence 000...111
            self.assertEqual(location, [int(x) for x in format(i, "b").zfill(3)])

        self.assertIsNone(
            self._tree.get_location(self._hash_fn("invalid".encode("utf-8")))
        )

    def test_get_proof(self):
        """test the get_proof method"""
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            # get the proof of the block hash in the tree
            proof = self._tree.get_proof(block_hash)
            assert proof is not None, "proof should not be None"
            self.assertEqual(len(proof), 3)
            self.assertTrue(
                verify_proof(self._tree.root.value, block_hash, proof, self._hash_fn),
                f"proof {proof} is invalid for block {block_hash}",
            )

        self.assertIsNone(
            self._tree.get_proof(self._hash_fn("invalid".encode("utf-8")))
        )

    def test_verify_proof(self):
        """test the verify_proof method"""
        root_hash = self._tree.root.value
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            # get the proof of the block hash in the tree
            proof = self._tree.get_proof(block_hash)
            assert proof is not None, "proof should not be None"
            self.assertTrue(verify_proof(root_hash, block_hash, proof, self._hash_fn))

        # verify an invalid proof
        self.assertFalse(
            verify_proof(
                root_hash, self._hash_fn("invalid".encode("utf-8")), [], self._hash_fn
            )
        )


if __name__ == "__main__":
    unittest.main()
```
Conclusion
Thanks for reading. The source code for this article is available at pluveto/merkle-tree.