hc
2024-01-03 2f7c68cb55ecb7331f2381deb497c27155f32faf
kernel/drivers/net/wireguard/selftest/allowedips.c
....@@ -19,32 +19,22 @@
1919
2020 #include <linux/siphash.h>
2121
22
-static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
23
- u8 cidr)
24
-{
25
- swap_endian(dst, src, bits);
26
- memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
27
- if (cidr)
28
- dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
29
-}
30
-
3122 static __init void print_node(struct allowedips_node *node, u8 bits)
3223 {
3324 char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
34
- char *fmt_declaration = KERN_DEBUG
35
- "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
25
+ char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
26
+ u8 ip1[16], ip2[16], cidr1, cidr2;
3627 char *style = "dotted";
37
- u8 ip1[16], ip2[16];
3828 u32 color = 0;
3929
30
+ if (node == NULL)
31
+ return;
4032 if (bits == 32) {
4133 fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
42
- fmt_declaration = KERN_DEBUG
43
- "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
34
+ fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
4435 } else if (bits == 128) {
4536 fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
46
- fmt_declaration = KERN_DEBUG
47
- "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
37
+ fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
4838 }
4939 if (node->peer) {
5040 hsiphash_key_t key = { { 0 } };
....@@ -55,24 +45,20 @@
5545 hsiphash_1u32(0xabad1dea, &key) % 200;
5646 style = "bold";
5747 }
58
- swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr);
59
- printk(fmt_declaration, ip1, node->cidr, style, color);
48
+ wg_allowedips_read_node(node, ip1, &cidr1);
49
+ printk(fmt_declaration, ip1, cidr1, style, color);
6050 if (node->bit[0]) {
61
- swap_endian_and_apply_cidr(ip2,
62
- rcu_dereference_raw(node->bit[0])->bits, bits,
63
- node->cidr);
64
- printk(fmt_connection, ip1, node->cidr, ip2,
65
- rcu_dereference_raw(node->bit[0])->cidr);
66
- print_node(rcu_dereference_raw(node->bit[0]), bits);
51
+ wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2);
52
+ printk(fmt_connection, ip1, cidr1, ip2, cidr2);
6753 }
6854 if (node->bit[1]) {
69
- swap_endian_and_apply_cidr(ip2,
70
- rcu_dereference_raw(node->bit[1])->bits,
71
- bits, node->cidr);
72
- printk(fmt_connection, ip1, node->cidr, ip2,
73
- rcu_dereference_raw(node->bit[1])->cidr);
74
- print_node(rcu_dereference_raw(node->bit[1]), bits);
55
+ wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2);
56
+ printk(fmt_connection, ip1, cidr1, ip2, cidr2);
7557 }
58
+ if (node->bit[0])
59
+ print_node(rcu_dereference_raw(node->bit[0]), bits);
60
+ if (node->bit[1])
61
+ print_node(rcu_dereference_raw(node->bit[1]), bits);
7662 }
7763
7864 static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
....@@ -121,8 +107,8 @@
121107 {
122108 union nf_inet_addr mask;
123109
124
- memset(&mask, 0x00, 128 / 8);
125
- memset(&mask, 0xff, cidr / 8);
110
+ memset(&mask, 0, sizeof(mask));
111
+ memset(&mask.all, 0xff, cidr / 8);
126112 if (cidr % 32)
127113 mask.all[cidr / 32] = (__force u32)htonl(
128114 (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
....@@ -149,42 +135,36 @@
149135 }
150136
151137 static __init inline bool
152
-horrible_match_v4(const struct horrible_allowedips_node *node,
153
- struct in_addr *ip)
138
+horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
154139 {
155140 return (ip->s_addr & node->mask.ip) == node->ip.ip;
156141 }
157142
158143 static __init inline bool
159
-horrible_match_v6(const struct horrible_allowedips_node *node,
160
- struct in6_addr *ip)
144
+horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
161145 {
162
- return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) ==
163
- node->ip.ip6[0] &&
164
- (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) ==
165
- node->ip.ip6[1] &&
166
- (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
167
- node->ip.ip6[2] &&
146
+ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
147
+ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
148
+ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
168149 (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
169150 }
170151
171152 static __init void
172
-horrible_insert_ordered(struct horrible_allowedips *table,
173
- struct horrible_allowedips_node *node)
153
+horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
174154 {
175155 struct horrible_allowedips_node *other = NULL, *where = NULL;
176156 u8 my_cidr = horrible_mask_to_cidr(node->mask);
177157
178158 hlist_for_each_entry(other, &table->head, table) {
179
- if (!memcmp(&other->mask, &node->mask,
180
- sizeof(union nf_inet_addr)) &&
181
- !memcmp(&other->ip, &node->ip,
182
- sizeof(union nf_inet_addr)) &&
183
- other->ip_version == node->ip_version) {
159
+ if (other->ip_version == node->ip_version &&
160
+ !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
161
+ !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) {
184162 other->value = node->value;
185163 kfree(node);
186164 return;
187165 }
166
+ }
167
+ hlist_for_each_entry(other, &table->head, table) {
188168 where = other;
189169 if (horrible_mask_to_cidr(other->mask) <= my_cidr)
190170 break;
....@@ -201,8 +181,7 @@
201181 horrible_allowedips_insert_v4(struct horrible_allowedips *table,
202182 struct in_addr *ip, u8 cidr, void *value)
203183 {
204
- struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
205
- GFP_KERNEL);
184
+ struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
206185
207186 if (unlikely(!node))
208187 return -ENOMEM;
....@@ -219,8 +198,7 @@
219198 horrible_allowedips_insert_v6(struct horrible_allowedips *table,
220199 struct in6_addr *ip, u8 cidr, void *value)
221200 {
222
- struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
223
- GFP_KERNEL);
201
+ struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
224202
225203 if (unlikely(!node))
226204 return -ENOMEM;
....@@ -234,39 +212,43 @@
234212 }
235213
236214 static __init void *
237
-horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
238
- struct in_addr *ip)
215
+horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
239216 {
240217 struct horrible_allowedips_node *node;
241
- void *ret = NULL;
242218
243219 hlist_for_each_entry(node, &table->head, table) {
244
- if (node->ip_version != 4)
245
- continue;
246
- if (horrible_match_v4(node, ip)) {
247
- ret = node->value;
248
- break;
249
- }
220
+ if (node->ip_version == 4 && horrible_match_v4(node, ip))
221
+ return node->value;
250222 }
251
- return ret;
223
+ return NULL;
252224 }
253225
254226 static __init void *
255
-horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
256
- struct in6_addr *ip)
227
+horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
257228 {
258229 struct horrible_allowedips_node *node;
259
- void *ret = NULL;
260230
261231 hlist_for_each_entry(node, &table->head, table) {
262
- if (node->ip_version != 6)
263
- continue;
264
- if (horrible_match_v6(node, ip)) {
265
- ret = node->value;
266
- break;
267
- }
232
+ if (node->ip_version == 6 && horrible_match_v6(node, ip))
233
+ return node->value;
268234 }
269
- return ret;
235
+ return NULL;
236
+}
237
+
238
+
239
+static __init void
240
+horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value)
241
+{
242
+ struct horrible_allowedips_node *node;
243
+ struct hlist_node *h;
244
+
245
+ hlist_for_each_entry_safe(node, h, &table->head, table) {
246
+ if (node->value != value)
247
+ continue;
248
+ hlist_del(&node->table);
249
+ kfree(node);
250
+ }
251
+
270252 }
271253
272254 static __init bool randomized_test(void)
....@@ -296,6 +278,7 @@
296278 goto free;
297279 }
298280 kref_init(&peers[i]->refcount);
281
+ INIT_LIST_HEAD(&peers[i]->allowedips_list);
299282 }
300283
301284 mutex_lock(&mutex);
....@@ -333,7 +316,7 @@
333316 if (wg_allowedips_insert_v4(&t,
334317 (struct in_addr *)mutated,
335318 cidr, peer, &mutex) < 0) {
336
- pr_err("allowedips random malloc: FAIL\n");
319
+ pr_err("allowedips random self-test malloc: FAIL\n");
337320 goto free_locked;
338321 }
339322 if (horrible_allowedips_insert_v4(&h,
....@@ -396,23 +379,33 @@
396379 print_tree(t.root6, 128);
397380 }
398381
399
- for (i = 0; i < NUM_QUERIES; ++i) {
400
- prandom_bytes(ip, 4);
401
- if (lookup(t.root4, 32, ip) !=
402
- horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
403
- pr_err("allowedips random self-test: FAIL\n");
404
- goto free;
382
+ for (j = 0;; ++j) {
383
+ for (i = 0; i < NUM_QUERIES; ++i) {
384
+ prandom_bytes(ip, 4);
385
+ if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
386
+ horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
387
+ pr_err("allowedips random v4 self-test: FAIL\n");
388
+ goto free;
389
+ }
390
+ prandom_bytes(ip, 16);
391
+ if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
392
+ pr_err("allowedips random v6 self-test: FAIL\n");
393
+ goto free;
394
+ }
405395 }
396
+ if (j >= NUM_PEERS)
397
+ break;
398
+ mutex_lock(&mutex);
399
+ wg_allowedips_remove_by_peer(&t, peers[j], &mutex);
400
+ mutex_unlock(&mutex);
401
+ horrible_allowedips_remove_by_value(&h, peers[j]);
406402 }
407403
408
- for (i = 0; i < NUM_QUERIES; ++i) {
409
- prandom_bytes(ip, 16);
410
- if (lookup(t.root6, 128, ip) !=
411
- horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
412
- pr_err("allowedips random self-test: FAIL\n");
413
- goto free;
414
- }
404
+ if (t.root4 || t.root6) {
405
+ pr_err("allowedips random self-test removal: FAIL\n");
406
+ goto free;
415407 }
408
+
416409 ret = true;
417410
418411 free:
....@@ -600,16 +593,20 @@
600593 wg_allowedips_remove_by_peer(&t, a, &mutex);
601594 test_negative(4, a, 192, 168, 0, 1);
602595
603
- /* These will hit the WARN_ON(len >= 128) in free_node if something
604
- * goes wrong.
596
+ /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_DEPTH) in free_node
597
+ * if something goes wrong.
605598 */
606
- for (i = 0; i < 128; ++i) {
607
- part = cpu_to_be64(~(1LLU << (i % 64)));
608
- memset(&ip, 0xff, 16);
609
- memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
599
+ for (i = 0; i < 64; ++i) {
600
+ part = cpu_to_be64(~0LLU << i);
601
+ memset(&ip, 0xff, 8);
602
+ memcpy((u8 *)&ip + 8, &part, 8);
603
+ wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
604
+ memcpy(&ip, &part, 8);
605
+ memset((u8 *)&ip + 8, 0, 8);
610606 wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
611607 }
612
-
608
+ memset(&ip, 0, 16);
609
+ wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
613610 wg_allowedips_free(&t, &mutex);
614611
615612 wg_allowedips_init(&t);