| .. | .. |
|---|
| 19 | 19 | |
|---|
| 20 | 20 | #include <linux/siphash.h> |
|---|
| 21 | 21 | |
|---|
| 22 | | -static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, |
|---|
| 23 | | - u8 cidr) |
|---|
| 24 | | -{ |
|---|
| 25 | | - swap_endian(dst, src, bits); |
|---|
| 26 | | - memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); |
|---|
| 27 | | - if (cidr) |
|---|
| 28 | | - dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8); |
|---|
| 29 | | -} |
|---|
| 30 | | - |
|---|
| 31 | 22 | static __init void print_node(struct allowedips_node *node, u8 bits) |
|---|
| 32 | 23 | { |
|---|
| 33 | 24 | char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; |
|---|
| 34 | | - char *fmt_declaration = KERN_DEBUG |
|---|
| 35 | | - "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; |
|---|
| 25 | + char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; |
|---|
| 26 | + u8 ip1[16], ip2[16], cidr1, cidr2; |
|---|
| 36 | 27 | char *style = "dotted"; |
|---|
| 37 | | - u8 ip1[16], ip2[16]; |
|---|
| 38 | 28 | u32 color = 0; |
|---|
| 39 | 29 | |
|---|
| 30 | + if (node == NULL) |
|---|
| 31 | + return; |
|---|
| 40 | 32 | if (bits == 32) { |
|---|
| 41 | 33 | fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; |
|---|
| 42 | | - fmt_declaration = KERN_DEBUG |
|---|
| 43 | | - "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; |
|---|
| 34 | + fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; |
|---|
| 44 | 35 | } else if (bits == 128) { |
|---|
| 45 | 36 | fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; |
|---|
| 46 | | - fmt_declaration = KERN_DEBUG |
|---|
| 47 | | - "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; |
|---|
| 37 | + fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; |
|---|
| 48 | 38 | } |
|---|
| 49 | 39 | if (node->peer) { |
|---|
| 50 | 40 | hsiphash_key_t key = { { 0 } }; |
|---|
| .. | .. |
|---|
| 55 | 45 | hsiphash_1u32(0xabad1dea, &key) % 200; |
|---|
| 56 | 46 | style = "bold"; |
|---|
| 57 | 47 | } |
|---|
| 58 | | - swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); |
|---|
| 59 | | - printk(fmt_declaration, ip1, node->cidr, style, color); |
|---|
| 48 | + wg_allowedips_read_node(node, ip1, &cidr1); |
|---|
| 49 | + printk(fmt_declaration, ip1, cidr1, style, color); |
|---|
| 60 | 50 | if (node->bit[0]) { |
|---|
| 61 | | - swap_endian_and_apply_cidr(ip2, |
|---|
| 62 | | - rcu_dereference_raw(node->bit[0])->bits, bits, |
|---|
| 63 | | - node->cidr); |
|---|
| 64 | | - printk(fmt_connection, ip1, node->cidr, ip2, |
|---|
| 65 | | - rcu_dereference_raw(node->bit[0])->cidr); |
|---|
| 66 | | - print_node(rcu_dereference_raw(node->bit[0]), bits); |
|---|
| 51 | + wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2); |
|---|
| 52 | + printk(fmt_connection, ip1, cidr1, ip2, cidr2); |
|---|
| 67 | 53 | } |
|---|
| 68 | 54 | if (node->bit[1]) { |
|---|
| 69 | | - swap_endian_and_apply_cidr(ip2, |
|---|
| 70 | | - rcu_dereference_raw(node->bit[1])->bits, |
|---|
| 71 | | - bits, node->cidr); |
|---|
| 72 | | - printk(fmt_connection, ip1, node->cidr, ip2, |
|---|
| 73 | | - rcu_dereference_raw(node->bit[1])->cidr); |
|---|
| 74 | | - print_node(rcu_dereference_raw(node->bit[1]), bits); |
|---|
| 55 | + wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2); |
|---|
| 56 | + printk(fmt_connection, ip1, cidr1, ip2, cidr2); |
|---|
| 75 | 57 | } |
|---|
| 58 | + if (node->bit[0]) |
|---|
| 59 | + print_node(rcu_dereference_raw(node->bit[0]), bits); |
|---|
| 60 | + if (node->bit[1]) |
|---|
| 61 | + print_node(rcu_dereference_raw(node->bit[1]), bits); |
|---|
| 76 | 62 | } |
|---|
| 77 | 63 | |
|---|
| 78 | 64 | static __init void print_tree(struct allowedips_node __rcu *top, u8 bits) |
|---|
| .. | .. |
|---|
| 121 | 107 | { |
|---|
| 122 | 108 | union nf_inet_addr mask; |
|---|
| 123 | 109 | |
|---|
| 124 | | - memset(&mask, 0x00, 128 / 8); |
|---|
| 125 | | - memset(&mask, 0xff, cidr / 8); |
|---|
| 110 | + memset(&mask, 0, sizeof(mask)); |
|---|
| 111 | + memset(&mask.all, 0xff, cidr / 8); |
|---|
| 126 | 112 | if (cidr % 32) |
|---|
| 127 | 113 | mask.all[cidr / 32] = (__force u32)htonl( |
|---|
| 128 | 114 | (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); |
|---|
| .. | .. |
|---|
| 149 | 135 | } |
|---|
| 150 | 136 | |
|---|
| 151 | 137 | static __init inline bool |
|---|
| 152 | | -horrible_match_v4(const struct horrible_allowedips_node *node, |
|---|
| 153 | | - struct in_addr *ip) |
|---|
| 138 | +horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) |
|---|
| 154 | 139 | { |
|---|
| 155 | 140 | return (ip->s_addr & node->mask.ip) == node->ip.ip; |
|---|
| 156 | 141 | } |
|---|
| 157 | 142 | |
|---|
| 158 | 143 | static __init inline bool |
|---|
| 159 | | -horrible_match_v6(const struct horrible_allowedips_node *node, |
|---|
| 160 | | - struct in6_addr *ip) |
|---|
| 144 | +horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) |
|---|
| 161 | 145 | { |
|---|
| 162 | | - return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == |
|---|
| 163 | | - node->ip.ip6[0] && |
|---|
| 164 | | - (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == |
|---|
| 165 | | - node->ip.ip6[1] && |
|---|
| 166 | | - (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == |
|---|
| 167 | | - node->ip.ip6[2] && |
|---|
| 146 | + return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && |
|---|
| 147 | + (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && |
|---|
| 148 | + (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && |
|---|
| 168 | 149 | (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; |
|---|
| 169 | 150 | } |
|---|
| 170 | 151 | |
|---|
| 171 | 152 | static __init void |
|---|
| 172 | | -horrible_insert_ordered(struct horrible_allowedips *table, |
|---|
| 173 | | - struct horrible_allowedips_node *node) |
|---|
| 153 | +horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) |
|---|
| 174 | 154 | { |
|---|
| 175 | 155 | struct horrible_allowedips_node *other = NULL, *where = NULL; |
|---|
| 176 | 156 | u8 my_cidr = horrible_mask_to_cidr(node->mask); |
|---|
| 177 | 157 | |
|---|
| 178 | 158 | hlist_for_each_entry(other, &table->head, table) { |
|---|
| 179 | | - if (!memcmp(&other->mask, &node->mask, |
|---|
| 180 | | - sizeof(union nf_inet_addr)) && |
|---|
| 181 | | - !memcmp(&other->ip, &node->ip, |
|---|
| 182 | | - sizeof(union nf_inet_addr)) && |
|---|
| 183 | | - other->ip_version == node->ip_version) { |
|---|
| 159 | + if (other->ip_version == node->ip_version && |
|---|
| 160 | + !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && |
|---|
| 161 | + !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) { |
|---|
| 184 | 162 | other->value = node->value; |
|---|
| 185 | 163 | kfree(node); |
|---|
| 186 | 164 | return; |
|---|
| 187 | 165 | } |
|---|
| 166 | + } |
|---|
| 167 | + hlist_for_each_entry(other, &table->head, table) { |
|---|
| 188 | 168 | where = other; |
|---|
| 189 | 169 | if (horrible_mask_to_cidr(other->mask) <= my_cidr) |
|---|
| 190 | 170 | break; |
|---|
| .. | .. |
|---|
| 201 | 181 | horrible_allowedips_insert_v4(struct horrible_allowedips *table, |
|---|
| 202 | 182 | struct in_addr *ip, u8 cidr, void *value) |
|---|
| 203 | 183 | { |
|---|
| 204 | | - struct horrible_allowedips_node *node = kzalloc(sizeof(*node), |
|---|
| 205 | | - GFP_KERNEL); |
|---|
| 184 | + struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); |
|---|
| 206 | 185 | |
|---|
| 207 | 186 | if (unlikely(!node)) |
|---|
| 208 | 187 | return -ENOMEM; |
|---|
| .. | .. |
|---|
| 219 | 198 | horrible_allowedips_insert_v6(struct horrible_allowedips *table, |
|---|
| 220 | 199 | struct in6_addr *ip, u8 cidr, void *value) |
|---|
| 221 | 200 | { |
|---|
| 222 | | - struct horrible_allowedips_node *node = kzalloc(sizeof(*node), |
|---|
| 223 | | - GFP_KERNEL); |
|---|
| 201 | + struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); |
|---|
| 224 | 202 | |
|---|
| 225 | 203 | if (unlikely(!node)) |
|---|
| 226 | 204 | return -ENOMEM; |
|---|
| .. | .. |
|---|
| 234 | 212 | } |
|---|
| 235 | 213 | |
|---|
| 236 | 214 | static __init void * |
|---|
| 237 | | -horrible_allowedips_lookup_v4(struct horrible_allowedips *table, |
|---|
| 238 | | - struct in_addr *ip) |
|---|
| 215 | +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) |
|---|
| 239 | 216 | { |
|---|
| 240 | 217 | struct horrible_allowedips_node *node; |
|---|
| 241 | | - void *ret = NULL; |
|---|
| 242 | 218 | |
|---|
| 243 | 219 | hlist_for_each_entry(node, &table->head, table) { |
|---|
| 244 | | - if (node->ip_version != 4) |
|---|
| 245 | | - continue; |
|---|
| 246 | | - if (horrible_match_v4(node, ip)) { |
|---|
| 247 | | - ret = node->value; |
|---|
| 248 | | - break; |
|---|
| 249 | | - } |
|---|
| 220 | + if (node->ip_version == 4 && horrible_match_v4(node, ip)) |
|---|
| 221 | + return node->value; |
|---|
| 250 | 222 | } |
|---|
| 251 | | - return ret; |
|---|
| 223 | + return NULL; |
|---|
| 252 | 224 | } |
|---|
| 253 | 225 | |
|---|
| 254 | 226 | static __init void * |
|---|
| 255 | | -horrible_allowedips_lookup_v6(struct horrible_allowedips *table, |
|---|
| 256 | | - struct in6_addr *ip) |
|---|
| 227 | +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) |
|---|
| 257 | 228 | { |
|---|
| 258 | 229 | struct horrible_allowedips_node *node; |
|---|
| 259 | | - void *ret = NULL; |
|---|
| 260 | 230 | |
|---|
| 261 | 231 | hlist_for_each_entry(node, &table->head, table) { |
|---|
| 262 | | - if (node->ip_version != 6) |
|---|
| 263 | | - continue; |
|---|
| 264 | | - if (horrible_match_v6(node, ip)) { |
|---|
| 265 | | - ret = node->value; |
|---|
| 266 | | - break; |
|---|
| 267 | | - } |
|---|
| 232 | + if (node->ip_version == 6 && horrible_match_v6(node, ip)) |
|---|
| 233 | + return node->value; |
|---|
| 268 | 234 | } |
|---|
| 269 | | - return ret; |
|---|
| 235 | + return NULL; |
|---|
| 236 | +} |
|---|
| 237 | + |
|---|
| 238 | + |
|---|
| 239 | +static __init void |
|---|
| 240 | +horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value) |
|---|
| 241 | +{ |
|---|
| 242 | + struct horrible_allowedips_node *node; |
|---|
| 243 | + struct hlist_node *h; |
|---|
| 244 | + |
|---|
| 245 | + hlist_for_each_entry_safe(node, h, &table->head, table) { |
|---|
| 246 | + if (node->value != value) |
|---|
| 247 | + continue; |
|---|
| 248 | + hlist_del(&node->table); |
|---|
| 249 | + kfree(node); |
|---|
| 250 | + } |
|---|
| 251 | + |
|---|
| 270 | 252 | } |
|---|
| 271 | 253 | |
|---|
| 272 | 254 | static __init bool randomized_test(void) |
|---|
| .. | .. |
|---|
| 296 | 278 | goto free; |
|---|
| 297 | 279 | } |
|---|
| 298 | 280 | kref_init(&peers[i]->refcount); |
|---|
| 281 | + INIT_LIST_HEAD(&peers[i]->allowedips_list); |
|---|
| 299 | 282 | } |
|---|
| 300 | 283 | |
|---|
| 301 | 284 | mutex_lock(&mutex); |
|---|
| .. | .. |
|---|
| 333 | 316 | if (wg_allowedips_insert_v4(&t, |
|---|
| 334 | 317 | (struct in_addr *)mutated, |
|---|
| 335 | 318 | cidr, peer, &mutex) < 0) { |
|---|
| 336 | | - pr_err("allowedips random malloc: FAIL\n"); |
|---|
| 319 | + pr_err("allowedips random self-test malloc: FAIL\n"); |
|---|
| 337 | 320 | goto free_locked; |
|---|
| 338 | 321 | } |
|---|
| 339 | 322 | if (horrible_allowedips_insert_v4(&h, |
|---|
| .. | .. |
|---|
| 396 | 379 | print_tree(t.root6, 128); |
|---|
| 397 | 380 | } |
|---|
| 398 | 381 | |
|---|
| 399 | | - for (i = 0; i < NUM_QUERIES; ++i) { |
|---|
| 400 | | - prandom_bytes(ip, 4); |
|---|
| 401 | | - if (lookup(t.root4, 32, ip) != |
|---|
| 402 | | - horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { |
|---|
| 403 | | - pr_err("allowedips random self-test: FAIL\n"); |
|---|
| 404 | | - goto free; |
|---|
| 382 | + for (j = 0;; ++j) { |
|---|
| 383 | + for (i = 0; i < NUM_QUERIES; ++i) { |
|---|
| 384 | + prandom_bytes(ip, 4); |
|---|
| 385 | + if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { |
|---|
| 386 | + horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); |
|---|
| 387 | + pr_err("allowedips random v4 self-test: FAIL\n"); |
|---|
| 388 | + goto free; |
|---|
| 389 | + } |
|---|
| 390 | + prandom_bytes(ip, 16); |
|---|
| 391 | + if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { |
|---|
| 392 | + pr_err("allowedips random v6 self-test: FAIL\n"); |
|---|
| 393 | + goto free; |
|---|
| 394 | + } |
|---|
| 405 | 395 | } |
|---|
| 396 | + if (j >= NUM_PEERS) |
|---|
| 397 | + break; |
|---|
| 398 | + mutex_lock(&mutex); |
|---|
| 399 | + wg_allowedips_remove_by_peer(&t, peers[j], &mutex); |
|---|
| 400 | + mutex_unlock(&mutex); |
|---|
| 401 | + horrible_allowedips_remove_by_value(&h, peers[j]); |
|---|
| 406 | 402 | } |
|---|
| 407 | 403 | |
|---|
| 408 | | - for (i = 0; i < NUM_QUERIES; ++i) { |
|---|
| 409 | | - prandom_bytes(ip, 16); |
|---|
| 410 | | - if (lookup(t.root6, 128, ip) != |
|---|
| 411 | | - horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { |
|---|
| 412 | | - pr_err("allowedips random self-test: FAIL\n"); |
|---|
| 413 | | - goto free; |
|---|
| 414 | | - } |
|---|
| 404 | + if (t.root4 || t.root6) { |
|---|
| 405 | + pr_err("allowedips random self-test removal: FAIL\n"); |
|---|
| 406 | + goto free; |
|---|
| 415 | 407 | } |
|---|
| 408 | + |
|---|
| 416 | 409 | ret = true; |
|---|
| 417 | 410 | |
|---|
| 418 | 411 | free: |
|---|
| .. | .. |
|---|
| 600 | 593 | wg_allowedips_remove_by_peer(&t, a, &mutex); |
|---|
| 601 | 594 | test_negative(4, a, 192, 168, 0, 1); |
|---|
| 602 | 595 | |
|---|
| 603 | | - /* These will hit the WARN_ON(len >= 128) in free_node if something |
|---|
| 604 | | - * goes wrong. |
|---|
| 596 | + /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_BITS) in free_node |
|---|
| 597 | + * if something goes wrong. |
|---|
| 605 | 598 | */ |
|---|
| 606 | | - for (i = 0; i < 128; ++i) { |
|---|
| 599 | + for (i = 0; i < MAX_ALLOWEDIPS_BITS; ++i) { |
|---|
| 607 | 600 | part = cpu_to_be64(~(1LLU << (i % 64))); |
|---|
| 608 | 601 | memset(&ip, 0xff, 16); |
|---|
| 609 | 602 | memcpy((u8 *)&ip + (i < 64) * 8, &part, 8); |
|---|