bitree.c 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. #include "bitree.h"
  2. #include <mm/slab.h>
  3. #include <common/errno.h>
  4. #include <debug/bug.h>
  5. #define smaller(root, a, b) (root->cmp(a, b) == -1)
  6. #define equal(root, a, b) (root->cmp(a, b) == 0)
  7. #define greater(root, a, b) (root->cmp(a, b) == 1)
  8. /**
  9. * @brief 创建二叉搜索树
  10. *
  11. * @param node 根节点
  12. * @param cmp 比较函数
  13. * @return struct bt_root_t* 树根结构体
  14. */
  15. struct bt_root_t *bt_create_tree(struct bt_node_t *node, int (*cmp)(struct bt_node_t *a, struct bt_node_t *b))
  16. {
  17. if (node == NULL || cmp == NULL)
  18. return -EINVAL;
  19. struct bt_root_t *root = (struct bt_root_t *)kmalloc(sizeof(struct bt_root_t), 0);
  20. memset((void *)root, 0, sizeof(struct bt_root_t));
  21. root->bt_node = node;
  22. root->cmp = cmp;
  23. return root;
  24. }
  25. /**
  26. * @brief 创建结点
  27. *
  28. * @param left 左子节点
  29. * @param right 右子节点
  30. * @param value 当前节点的值
  31. * @return struct bt_node_t*
  32. */
  33. struct bt_node_t *bt_create_node(struct bt_node_t *left, struct bt_node_t *right, struct bt_node_t *parent, void *value)
  34. {
  35. struct bt_node_t *node = (struct bt_node_t *)kmalloc(sizeof(struct bt_node_t), 0);
  36. FAIL_ON_TO(node == NULL, nomem);
  37. memset((void *)node, 0, sizeof(struct bt_node_t));
  38. node->left = left;
  39. node->right = right;
  40. node->value = value;
  41. node->parent = parent;
  42. return node;
  43. nomem:;
  44. return -ENOMEM;
  45. }
  46. /**
  47. * @brief 插入结点
  48. *
  49. * @param root 树根结点
  50. * @param value 待插入结点的值
  51. * @return int 返回码
  52. */
  53. int bt_insert(struct bt_root_t *root, void *value)
  54. {
  55. if (root == NULL)
  56. return -EINVAL;
  57. struct bt_node_t *this_node = root->bt_node;
  58. struct bt_node_t *last_node = NULL;
  59. struct bt_node_t *insert_node = bt_create_node(NULL, NULL, NULL, value);
  60. FAIL_ON_TO((uint64_t)insert_node == (uint64_t)(-ENOMEM), failed);
  61. while (this_node != NULL)
  62. {
  63. last_node = this_node;
  64. if (smaller(root, insert_node, this_node))
  65. this_node = this_node->left;
  66. else
  67. this_node = this_node->right;
  68. }
  69. insert_node->parent = last_node;
  70. if (unlikely(last_node == NULL))
  71. root->bt_node = insert_node;
  72. else
  73. {
  74. if (smaller(root, insert_node, last_node))
  75. last_node->left = insert_node;
  76. else
  77. last_node->right = insert_node;
  78. }
  79. return 0;
  80. failed:;
  81. return -ENOMEM;
  82. }
  83. /**
  84. * @brief 搜索值为value的结点
  85. *
  86. * @param value 值
  87. * @param ret_addr 返回的结点基地址
  88. * @return int 错误码
  89. */
  90. int bt_query(struct bt_root_t *root, void *value, uint64_t *ret_addr)
  91. {
  92. struct bt_node_t *this_node = root->bt_node;
  93. struct bt_node_t tmp_node = {0};
  94. tmp_node.value = value;
  95. while (this_node != NULL && !equal(root, this_node, &tmp_node))
  96. {
  97. if (smaller(root, &tmp_node, this_node))
  98. this_node = this_node->left;
  99. else
  100. this_node = this_node->right;
  101. }
  102. if (equal(root, this_node, &tmp_node))
  103. {
  104. *ret_addr = (uint64_t)this_node;
  105. return 0;
  106. }
  107. else
  108. {
  109. // 找不到则返回-1,且addr设为0
  110. *ret_addr = NULL;
  111. return -1;
  112. }
  113. }
  114. static struct bt_node_t *bt_get_minimum(struct bt_node_t *this_node)
  115. {
  116. while (this_node->left != NULL)
  117. this_node = this_node->left;
  118. return this_node;
  119. }
  120. /**
  121. * @brief 删除结点
  122. *
  123. * @param root 树根
  124. * @param value 待删除结点的值
  125. * @return int 返回码
  126. */
  127. int bt_delete(struct bt_root_t *root, void *value)
  128. {
  129. uint64_t tmp_addr;
  130. int retval;
  131. // 寻找待删除结点
  132. retval = bt_query(root, value, &tmp_addr);
  133. if (retval != 0 || tmp_addr == NULL)
  134. return retval;
  135. struct bt_node_t *this_node = (struct bt_node_t *)tmp_addr;
  136. struct bt_node_t *to_delete = NULL, *to_delete_son = NULL;
  137. if (this_node->left == NULL || this_node->right == NULL)
  138. to_delete = this_node;
  139. else
  140. {
  141. to_delete = bt_get_minimum(this_node->right);
  142. // 释放要被删除的值,并把下一个结点的值替换上来
  143. root->release(this_node->value);
  144. this_node->value = to_delete->value;
  145. }
  146. if (to_delete->left != NULL)
  147. to_delete_son = to_delete->left;
  148. else
  149. to_delete_son = to_delete->right;
  150. if (to_delete_son != NULL)
  151. to_delete_son->parent = to_delete->parent;
  152. if (to_delete->parent == NULL)
  153. root->bt_node = to_delete_son;
  154. else
  155. {
  156. if (to_delete->parent->left == to_delete)
  157. to_delete->parent->left = to_delete_son;
  158. else
  159. to_delete->parent->right = to_delete_son;
  160. }
  161. // 释放最终要删除的结点的对象
  162. kfree(to_delete);
  163. }