| .. | .. |
|---|
| 1 | +// SPDX-License-Identifier: GPL-2.0-only |
|---|
| 1 | 2 | #include <linux/module.h> |
|---|
| 2 | 3 | #include <linux/moduleparam.h> |
|---|
| 3 | 4 | #include <linux/rbtree_augmented.h> |
|---|
| .. | .. |
|---|
| 76 | 77 | } |
|---|
| 77 | 78 | |
|---|
| 78 | 79 | |
|---|
| 79 | | -static inline u32 augment_recompute(struct test_node *node) |
|---|
| 80 | | -{ |
|---|
| 81 | | - u32 max = node->val, child_augmented; |
|---|
| 82 | | - if (node->rb.rb_left) { |
|---|
| 83 | | - child_augmented = rb_entry(node->rb.rb_left, struct test_node, |
|---|
| 84 | | - rb)->augmented; |
|---|
| 85 | | - if (max < child_augmented) |
|---|
| 86 | | - max = child_augmented; |
|---|
| 87 | | - } |
|---|
| 88 | | - if (node->rb.rb_right) { |
|---|
| 89 | | - child_augmented = rb_entry(node->rb.rb_right, struct test_node, |
|---|
| 90 | | - rb)->augmented; |
|---|
| 91 | | - if (max < child_augmented) |
|---|
| 92 | | - max = child_augmented; |
|---|
| 93 | | - } |
|---|
| 94 | | - return max; |
|---|
| 95 | | -} |
|---|
| 80 | +#define NODE_VAL(node) ((node)->val) |
|---|
| 96 | 81 | |
|---|
| 97 | | -RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb, |
|---|
| 98 | | - u32, augmented, augment_recompute) |
|---|
| 82 | +RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks, |
|---|
| 83 | + struct test_node, rb, u32, augmented, NODE_VAL) |
|---|
| 99 | 84 | |
|---|
| 100 | 85 | static void insert_augmented(struct test_node *node, |
|---|
| 101 | 86 | struct rb_root_cached *root) |
|---|
| .. | .. |
|---|
| 237 | 222 | check(nr_nodes); |
|---|
| 238 | 223 | for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { |
|---|
| 239 | 224 | struct test_node *node = rb_entry(rb, struct test_node, rb); |
|---|
| 240 | | - WARN_ON_ONCE(node->augmented != augment_recompute(node)); |
|---|
| 225 | + u32 subtree, max = node->val; |
|---|
| 226 | + if (node->rb.rb_left) { |
|---|
| 227 | + subtree = rb_entry(node->rb.rb_left, struct test_node, |
|---|
| 228 | + rb)->augmented; |
|---|
| 229 | + if (max < subtree) |
|---|
| 230 | + max = subtree; |
|---|
| 231 | + } |
|---|
| 232 | + if (node->rb.rb_right) { |
|---|
| 233 | + subtree = rb_entry(node->rb.rb_right, struct test_node, |
|---|
| 234 | + rb)->augmented; |
|---|
| 235 | + if (max < subtree) |
|---|
| 236 | + max = subtree; |
|---|
| 237 | + } |
|---|
| 238 | + WARN_ON_ONCE(node->augmented != max); |
|---|
| 241 | 239 | } |
|---|
| 242 | 240 | } |
|---|
| 243 | 241 | |
|---|