.. | .. |
---|
| 1 | +// SPDX-License-Identifier: GPL-2.0-only |
---|
1 | 2 | #include <linux/module.h> |
---|
2 | 3 | #include <linux/moduleparam.h> |
---|
3 | 4 | #include <linux/rbtree_augmented.h> |
---|
.. | .. |
---|
76 | 77 | } |
---|
77 | 78 | |
---|
78 | 79 | |
---|
79 | | -static inline u32 augment_recompute(struct test_node *node) |
---|
80 | | -{ |
---|
81 | | - u32 max = node->val, child_augmented; |
---|
82 | | - if (node->rb.rb_left) { |
---|
83 | | - child_augmented = rb_entry(node->rb.rb_left, struct test_node, |
---|
84 | | - rb)->augmented; |
---|
85 | | - if (max < child_augmented) |
---|
86 | | - max = child_augmented; |
---|
87 | | - } |
---|
88 | | - if (node->rb.rb_right) { |
---|
89 | | - child_augmented = rb_entry(node->rb.rb_right, struct test_node, |
---|
90 | | - rb)->augmented; |
---|
91 | | - if (max < child_augmented) |
---|
92 | | - max = child_augmented; |
---|
93 | | - } |
---|
94 | | - return max; |
---|
95 | | -} |
---|
| 80 | +#define NODE_VAL(node) ((node)->val) |
---|
96 | 81 | |
---|
97 | | -RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb, |
---|
98 | | - u32, augmented, augment_recompute) |
---|
| 82 | +RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks, |
---|
| 83 | + struct test_node, rb, u32, augmented, NODE_VAL) |
---|
99 | 84 | |
---|
100 | 85 | static void insert_augmented(struct test_node *node, |
---|
101 | 86 | struct rb_root_cached *root) |
---|
.. | .. |
---|
237 | 222 | check(nr_nodes); |
---|
238 | 223 | for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { |
---|
239 | 224 | struct test_node *node = rb_entry(rb, struct test_node, rb); |
---|
240 | | - WARN_ON_ONCE(node->augmented != augment_recompute(node)); |
---|
| 225 | + u32 subtree, max = node->val; |
---|
| 226 | + if (node->rb.rb_left) { |
---|
| 227 | + subtree = rb_entry(node->rb.rb_left, struct test_node, |
---|
| 228 | + rb)->augmented; |
---|
| 229 | + if (max < subtree) |
---|
| 230 | + max = subtree; |
---|
| 231 | + } |
---|
| 232 | + if (node->rb.rb_right) { |
---|
| 233 | + subtree = rb_entry(node->rb.rb_right, struct test_node, |
---|
| 234 | + rb)->augmented; |
---|
| 235 | + if (max < subtree) |
---|
| 236 | + max = subtree; |
---|
| 237 | + } |
---|
| 238 | + WARN_ON_ONCE(node->augmented != max); |
---|
241 | 239 | } |
---|
242 | 240 | } |
---|
243 | 241 | |
---|