.. | .. |
19 | 19 | |
20 | 20 | #include <linux/siphash.h> |
21 | 21 | |
22 | | -static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, |
23 | | - u8 cidr) |
24 | | -{ |
25 | | - swap_endian(dst, src, bits); |
26 | | - memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); |
27 | | - if (cidr) |
28 | | - dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8); |
29 | | -} |
30 | | - |
31 | 22 | static __init void print_node(struct allowedips_node *node, u8 bits) |
32 | 23 | { |
33 | 24 | char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; |
34 | | - char *fmt_declaration = KERN_DEBUG |
35 | | - "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; |
| 25 | + char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; |
| 26 | + u8 ip1[16], ip2[16], cidr1, cidr2; |
36 | 27 | char *style = "dotted"; |
37 | | - u8 ip1[16], ip2[16]; |
38 | 28 | u32 color = 0; |
39 | 29 | |
| 30 | + if (node == NULL) |
| 31 | + return; |
40 | 32 | if (bits == 32) { |
41 | 33 | fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; |
42 | | - fmt_declaration = KERN_DEBUG |
43 | | - "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; |
| 34 | + fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; |
44 | 35 | } else if (bits == 128) { |
45 | 36 | fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; |
46 | | - fmt_declaration = KERN_DEBUG |
47 | | - "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; |
| 37 | + fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; |
48 | 38 | } |
49 | 39 | if (node->peer) { |
50 | 40 | hsiphash_key_t key = { { 0 } }; |
.. | .. |
55 | 45 | hsiphash_1u32(0xabad1dea, &key) % 200; |
56 | 46 | style = "bold"; |
57 | 47 | } |
58 | | - swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); |
59 | | - printk(fmt_declaration, ip1, node->cidr, style, color); |
| 48 | + wg_allowedips_read_node(node, ip1, &cidr1); |
| 49 | + printk(fmt_declaration, ip1, cidr1, style, color); |
60 | 50 | if (node->bit[0]) { |
61 | | - swap_endian_and_apply_cidr(ip2, |
62 | | - rcu_dereference_raw(node->bit[0])->bits, bits, |
63 | | - node->cidr); |
64 | | - printk(fmt_connection, ip1, node->cidr, ip2, |
65 | | - rcu_dereference_raw(node->bit[0])->cidr); |
66 | | - print_node(rcu_dereference_raw(node->bit[0]), bits); |
| 51 | + wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2); |
| 52 | + printk(fmt_connection, ip1, cidr1, ip2, cidr2); |
67 | 53 | } |
68 | 54 | if (node->bit[1]) { |
69 | | - swap_endian_and_apply_cidr(ip2, |
70 | | - rcu_dereference_raw(node->bit[1])->bits, |
71 | | - bits, node->cidr); |
72 | | - printk(fmt_connection, ip1, node->cidr, ip2, |
73 | | - rcu_dereference_raw(node->bit[1])->cidr); |
74 | | - print_node(rcu_dereference_raw(node->bit[1]), bits); |
| 55 | + wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2); |
| 56 | + printk(fmt_connection, ip1, cidr1, ip2, cidr2); |
75 | 57 | } |
| 58 | + if (node->bit[0]) |
| 59 | + print_node(rcu_dereference_raw(node->bit[0]), bits); |
| 60 | + if (node->bit[1]) |
| 61 | + print_node(rcu_dereference_raw(node->bit[1]), bits); |
76 | 62 | } |
77 | 63 | |
78 | 64 | static __init void print_tree(struct allowedips_node __rcu *top, u8 bits) |
.. | .. |
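
A note on the hunks above: print_node() drops its private swap_endian_and_apply_cidr() helper and instead pulls each node's address and prefix length out through wg_allowedips_read_node(), printing both child connections before recursing. For readers puzzling over the bit arithmetic in the deleted helper, here is a small self-contained userspace sketch of that masking step; the name apply_cidr and the 10.1.2.3/21 demo value are illustrative only and not part of the patch.

```c
/* Illustrative sketch of the masking the removed helper performed: zero every
 * whole byte past the prefix, then keep only the leading (cidr % 8) bits of
 * the last partial byte. Demo code, not the kernel selftest.
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>

static void apply_cidr(uint8_t *addr, uint8_t bits, uint8_t cidr)
{
	memset(addr + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
	if (cidr)
		addr[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
}

int main(void)
{
	uint8_t ip[4] = { 10, 1, 2, 3 };

	apply_cidr(ip, 32, 21);
	printf("%u.%u.%u.%u\n", ip[0], ip[1], ip[2], ip[3]);	/* 10.1.0.0 */
	return 0;
}
```

For /21 the last partial byte is ANDed with 0xf8, so 10.1.2.3 collapses to its network address 10.1.0.0.
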
121 | 107 | { |
122 | 108 | union nf_inet_addr mask; |
123 | 109 | |
124 | | - memset(&mask, 0x00, 128 / 8); |
125 | | - memset(&mask, 0xff, cidr / 8); |
| 110 | + memset(&mask, 0, sizeof(mask)); |
| 111 | + memset(&mask.all, 0xff, cidr / 8); |
126 | 112 | if (cidr % 32) |
127 | 113 | mask.all[cidr / 32] = (__force u32)htonl( |
128 | 114 | (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); |
.. | .. |
149 | 135 | } |
150 | 136 | |
151 | 137 | static __init inline bool |
152 | | -horrible_match_v4(const struct horrible_allowedips_node *node, |
153 | | - struct in_addr *ip) |
| 138 | +horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) |
154 | 139 | { |
155 | 140 | return (ip->s_addr & node->mask.ip) == node->ip.ip; |
156 | 141 | } |
157 | 142 | |
158 | 143 | static __init inline bool |
159 | | -horrible_match_v6(const struct horrible_allowedips_node *node, |
160 | | - struct in6_addr *ip) |
| 144 | +horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) |
161 | 145 | { |
162 | | - return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == |
163 | | - node->ip.ip6[0] && |
164 | | - (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == |
165 | | - node->ip.ip6[1] && |
166 | | - (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == |
167 | | - node->ip.ip6[2] && |
| 146 | + return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && |
| 147 | + (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && |
| 148 | + (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && |
168 | 149 | (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; |
169 | 150 | } |
170 | 151 | |
171 | 152 | static __init void |
172 | | -horrible_insert_ordered(struct horrible_allowedips *table, |
173 | | - struct horrible_allowedips_node *node) |
| 153 | +horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) |
174 | 154 | { |
175 | 155 | struct horrible_allowedips_node *other = NULL, *where = NULL; |
176 | 156 | u8 my_cidr = horrible_mask_to_cidr(node->mask); |
177 | 157 | |
178 | 158 | hlist_for_each_entry(other, &table->head, table) { |
179 | | - if (!memcmp(&other->mask, &node->mask, |
180 | | - sizeof(union nf_inet_addr)) && |
181 | | - !memcmp(&other->ip, &node->ip, |
182 | | - sizeof(union nf_inet_addr)) && |
183 | | - other->ip_version == node->ip_version) { |
| 159 | + if (other->ip_version == node->ip_version && |
| 160 | + !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && |
| 161 | + !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) { |
184 | 162 | other->value = node->value; |
185 | 163 | kfree(node); |
186 | 164 | return; |
187 | 165 | } |
| 166 | + } |
| 167 | + hlist_for_each_entry(other, &table->head, table) { |
188 | 168 | where = other; |
189 | 169 | if (horrible_mask_to_cidr(other->mask) <= my_cidr) |
190 | 170 | break; |
.. | .. |
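
The hunk above splits horrible_insert_ordered() into two passes: the list is first scanned in full for an entry with the same IP version, mask, and address (which is then simply updated in place), and only afterwards is the insertion point chosen so that entries stay ordered by descending prefix length. That ordering is what lets the brute-force reference lookup return the longest-prefix match with a plain first-match linear scan. A hedged userspace sketch of that idea follows; the struct, the cidr_mask() helper, and the demo prefixes are illustrative, not from the patch.

```c
/* Keeping entries sorted by descending prefix length means the first match in
 * a linear scan is automatically the most specific one. Demo code only.
 */
#include <stdio.h>
#include <stdint.h>

struct entry {
	uint32_t ip;		/* host byte order for simplicity */
	uint8_t cidr;
	const char *name;
};

static uint32_t cidr_mask(uint8_t cidr)
{
	return cidr ? ~UINT32_C(0) << (32 - cidr) : 0;
}

int main(void)
{
	/* Already ordered by descending cidr, the invariant the insert keeps. */
	const struct entry table[] = {
		{ 0x0a010200, 24, "10.1.2.0/24" },
		{ 0x0a010000, 16, "10.1.0.0/16" },
		{ 0x00000000,  0, "0.0.0.0/0"   },
	};
	const uint32_t query = 0x0a010203;	/* 10.1.2.3 */

	for (size_t i = 0; i < sizeof(table) / sizeof(table[0]); ++i) {
		if ((query & cidr_mask(table[i].cidr)) == table[i].ip) {
			printf("first match: %s\n", table[i].name);
			break;
		}
	}
	return 0;
}
```
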
201 | 181 | horrible_allowedips_insert_v4(struct horrible_allowedips *table, |
202 | 182 | struct in_addr *ip, u8 cidr, void *value) |
203 | 183 | { |
204 | | - struct horrible_allowedips_node *node = kzalloc(sizeof(*node), |
205 | | - GFP_KERNEL); |
| 184 | + struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); |
206 | 185 | |
207 | 186 | if (unlikely(!node)) |
208 | 187 | return -ENOMEM; |
.. | .. |
219 | 198 | horrible_allowedips_insert_v6(struct horrible_allowedips *table, |
220 | 199 | struct in6_addr *ip, u8 cidr, void *value) |
221 | 200 | { |
222 | | - struct horrible_allowedips_node *node = kzalloc(sizeof(*node), |
223 | | - GFP_KERNEL); |
| 201 | + struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); |
224 | 202 | |
225 | 203 | if (unlikely(!node)) |
226 | 204 | return -ENOMEM; |
.. | .. |
234 | 212 | } |
235 | 213 | |
236 | 214 | static __init void * |
237 | | -horrible_allowedips_lookup_v4(struct horrible_allowedips *table, |
238 | | - struct in_addr *ip) |
| 215 | +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) |
239 | 216 | { |
240 | 217 | struct horrible_allowedips_node *node; |
241 | | - void *ret = NULL; |
242 | 218 | |
243 | 219 | hlist_for_each_entry(node, &table->head, table) { |
244 | | - if (node->ip_version != 4) |
245 | | - continue; |
246 | | - if (horrible_match_v4(node, ip)) { |
247 | | - ret = node->value; |
248 | | - break; |
249 | | - } |
| 220 | + if (node->ip_version == 4 && horrible_match_v4(node, ip)) |
| 221 | + return node->value; |
250 | 222 | } |
251 | | - return ret; |
| 223 | + return NULL; |
252 | 224 | } |
253 | 225 | |
254 | 226 | static __init void * |
255 | | -horrible_allowedips_lookup_v6(struct horrible_allowedips *table, |
256 | | - struct in6_addr *ip) |
| 227 | +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) |
257 | 228 | { |
258 | 229 | struct horrible_allowedips_node *node; |
259 | | - void *ret = NULL; |
260 | 230 | |
261 | 231 | hlist_for_each_entry(node, &table->head, table) { |
262 | | - if (node->ip_version != 6) |
263 | | - continue; |
264 | | - if (horrible_match_v6(node, ip)) { |
265 | | - ret = node->value; |
266 | | - break; |
267 | | - } |
| 232 | + if (node->ip_version == 6 && horrible_match_v6(node, ip)) |
| 233 | + return node->value; |
268 | 234 | } |
269 | | - return ret; |
| 235 | + return NULL; |
| 236 | +} |
| 237 | + |
| 238 | + |
| 239 | +static __init void |
| 240 | +horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value) |
| 241 | +{ |
| 242 | + struct horrible_allowedips_node *node; |
| 243 | + struct hlist_node *h; |
| 244 | + |
| 245 | + hlist_for_each_entry_safe(node, h, &table->head, table) { |
| 246 | + if (node->value != value) |
| 247 | + continue; |
| 248 | + hlist_del(&node->table); |
| 249 | + kfree(node); |
| 250 | + } |
| 251 | + |
270 | 252 | } |
271 | 253 | |
272 | 254 | static __init bool randomized_test(void) |
.. | .. |
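
The new horrible_allowedips_remove_by_value() gives the brute-force side a counterpart to wg_allowedips_remove_by_peer(), which the reworked randomized_test() below calls in tandem with it. It iterates with hlist_for_each_entry_safe(), the variant that caches the next pointer in a temporary so the current entry can be unlinked and freed mid-walk. A minimal userspace analogue of that delete-while-iterating pattern, using hypothetical names rather than the kernel's hlist API:

```c
/* Delete every matching entry while walking a singly linked list: keep the
 * address of the link pointing at the current node, so the node can be
 * unlinked and freed without losing the iteration position. Demo code only.
 */
#include <stdio.h>
#include <stdlib.h>

struct node {
	int value;
	struct node *next;
};

static void remove_by_value(struct node **head, int value)
{
	struct node **link = head;

	while (*link) {
		struct node *cur = *link;

		if (cur->value == value) {
			*link = cur->next;	/* unlink before freeing */
			free(cur);
		} else {
			link = &cur->next;
		}
	}
}

int main(void)
{
	struct node *head = NULL;
	const int vals[] = { 3, 1, 2, 1 };

	for (int i = 0; i < 4; ++i) {	/* builds the list 1 -> 2 -> 1 -> 3 */
		struct node *n = malloc(sizeof(*n));

		if (!n)
			return 1;
		n->value = vals[i];
		n->next = head;
		head = n;
	}
	remove_by_value(&head, 1);
	for (struct node *n = head; n; n = n->next)
		printf("%d ", n->value);	/* prints: 2 3 */
	printf("\n");
	return 0;
}
```
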
296 | 278 | goto free; |
297 | 279 | } |
298 | 280 | kref_init(&peers[i]->refcount); |
| 281 | + INIT_LIST_HEAD(&peers[i]->allowedips_list); |
299 | 282 | } |
300 | 283 | |
301 | 284 | mutex_lock(&mutex); |
.. | .. |
333 | 316 | if (wg_allowedips_insert_v4(&t, |
334 | 317 | (struct in_addr *)mutated, |
335 | 318 | cidr, peer, &mutex) < 0) { |
336 | | - pr_err("allowedips random malloc: FAIL\n"); |
| 319 | + pr_err("allowedips random self-test malloc: FAIL\n"); |
337 | 320 | goto free_locked; |
338 | 321 | } |
339 | 322 | if (horrible_allowedips_insert_v4(&h, |
.. | .. |
396 | 379 | print_tree(t.root6, 128); |
397 | 380 | } |
398 | 381 | |
399 | | - for (i = 0; i < NUM_QUERIES; ++i) { |
400 | | - prandom_bytes(ip, 4); |
401 | | - if (lookup(t.root4, 32, ip) != |
402 | | - horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { |
403 | | - pr_err("allowedips random self-test: FAIL\n"); |
404 | | - goto free; |
| 382 | + for (j = 0;; ++j) { |
| 383 | + for (i = 0; i < NUM_QUERIES; ++i) { |
| 384 | + prandom_bytes(ip, 4); |
| 385 | + if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { |
| 386 | + horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); |
| 387 | + pr_err("allowedips random v4 self-test: FAIL\n"); |
| 388 | + goto free; |
| 389 | + } |
| 390 | + prandom_bytes(ip, 16); |
| 391 | + if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { |
| 392 | + pr_err("allowedips random v6 self-test: FAIL\n"); |
| 393 | + goto free; |
| 394 | + } |
405 | 395 | } |
| 396 | + if (j >= NUM_PEERS) |
| 397 | + break; |
| 398 | + mutex_lock(&mutex); |
| 399 | + wg_allowedips_remove_by_peer(&t, peers[j], &mutex); |
| 400 | + mutex_unlock(&mutex); |
| 401 | + horrible_allowedips_remove_by_value(&h, peers[j]); |
406 | 402 | } |
407 | 403 | |
408 | | - for (i = 0; i < NUM_QUERIES; ++i) { |
409 | | - prandom_bytes(ip, 16); |
410 | | - if (lookup(t.root6, 128, ip) != |
411 | | - horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { |
412 | | - pr_err("allowedips random self-test: FAIL\n"); |
413 | | - goto free; |
414 | | - } |
| 404 | + if (t.root4 || t.root6) { |
| 405 | + pr_err("allowedips random self-test removal: FAIL\n"); |
| 406 | + goto free; |
415 | 407 | } |
| 408 | + |
416 | 409 | ret = true; |
417 | 410 | |
418 | 411 | free: |
.. | .. |
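
In the reworked query phase above, the v4 and v6 checks now run inside one outer loop: after each full round of NUM_QUERIES random lookups is compared against the brute-force tables, peer j is removed from both the trie and the reference list, so every intermediate state (and finally the empty table, whose root4/root6 must be NULL) gets re-verified. A trivial sketch of that loop shape, with a stand-in counter instead of the real structures:

```c
/* The "for (j = 0;; ++j)" with a late break verifies the table NUM_PEERS + 1
 * times: once with everything present and once after each removal.
 * Illustrative stand-in only.
 */
#include <stdio.h>

#define NUM_PEERS 3

int main(void)
{
	int present = NUM_PEERS;

	for (int j = 0;; ++j) {
		printf("verify lookups with %d peers present\n", present);
		if (j >= NUM_PEERS)
			break;
		--present;	/* stands in for removing peers[j] from both structures */
	}
	return 0;
}
```
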
600 | 593 | wg_allowedips_remove_by_peer(&t, a, &mutex); |
601 | 594 | test_negative(4, a, 192, 168, 0, 1); |
602 | 595 | |
603 | | - /* These will hit the WARN_ON(len >= 128) in free_node if something |
604 | | - * goes wrong. |
| 596 | + /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_DEPTH) in free_node |
| 597 | + * if something goes wrong. |
605 | 598 | */ |
606 | | - for (i = 0; i < 128; ++i) { |
607 | | - part = cpu_to_be64(~(1LLU << (i % 64))); |
608 | | - memset(&ip, 0xff, 16); |
609 | | - memcpy((u8 *)&ip + (i < 64) * 8, &part, 8); |
| 599 | + for (i = 0; i < 64; ++i) { |
| 600 | + part = cpu_to_be64(~0LLU << i); |
| 601 | + memset(&ip, 0xff, 8); |
| 602 | + memcpy((u8 *)&ip + 8, &part, 8); |
| 603 | + wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex); |
| 604 | + memcpy(&ip, &part, 8); |
| 605 | + memset((u8 *)&ip + 8, 0, 8); |
610 | 606 | wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex); |
611 | 607 | } |
612 | | - |
| 608 | + memset(&ip, 0, 16); |
| 609 | + wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex); |
613 | 610 | wg_allowedips_free(&t, &mutex); |
614 | 611 | |
615 | 612 | wg_allowedips_init(&t); |
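
The rewritten stress loop above inserts 129 distinct /128 entries (two per iteration plus the final all-zero address) that share progressively longer common prefixes, so the trie ends up essentially one long chain; freeing it then exercises the WARN_ON(len >= MAX_ALLOWEDIPS_DEPTH) path mentioned in the comment. Below is a userspace sketch that just prints the generated addresses to make the pattern visible; the helper name show() is illustrative and htobe64() assumes a glibc/Linux host.

```c
/* Print the /128 addresses the selftest loop builds: the high half all ones
 * with an ever-shorter run of ones in the low half, the same run placed in
 * the high half with a zero low half, and finally ::/128. Demo code only.
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include <endian.h>
#include <arpa/inet.h>

static void show(const unsigned char ip[16])
{
	char buf[INET6_ADDRSTRLEN];

	printf("%s/128\n", inet_ntop(AF_INET6, ip, buf, sizeof(buf)));
}

int main(void)
{
	unsigned char ip[16];
	uint64_t part;

	for (int i = 0; i < 64; ++i) {
		part = htobe64(~UINT64_C(0) << i);	/* 64 - i leading ones */
		memset(ip, 0xff, 8);
		memcpy(ip + 8, &part, 8);
		show(ip);
		memcpy(ip, &part, 8);
		memset(ip + 8, 0, 8);
		show(ip);
	}
	memset(ip, 0, 16);
	show(ip);
	return 0;
}
```
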