2024-02-20 102a0743326a03cd1a1202ceda21e175b7d3575c
kernel/drivers/net/wireguard/peer.c
@@ -15,6 +15,7 @@
 #include <linux/rcupdate.h>
 #include <linux/list.h>
 
+static struct kmem_cache *peer_cache;
 static atomic64_t peer_counter = ATOMIC64_INIT(0);
 
 struct wg_peer *wg_peer_create(struct wg_device *wg,
@@ -29,30 +30,25 @@
 	if (wg->num_peers >= MAX_PEERS_PER_DEVICE)
 		return ERR_PTR(ret);
 
-	peer = kzalloc(sizeof(*peer), GFP_KERNEL);
+	peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL);
 	if (unlikely(!peer))
 		return ERR_PTR(ret);
-	peer->device = wg;
+	if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)))
+		goto err;
 
+	peer->device = wg;
 	wg_noise_handshake_init(&peer->handshake, &wg->static_identity,
 				public_key, preshared_key, peer);
-	if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
-		goto err_1;
-	if (wg_packet_queue_init(&peer->tx_queue, wg_packet_tx_worker, false,
-				 MAX_QUEUED_PACKETS))
-		goto err_2;
-	if (wg_packet_queue_init(&peer->rx_queue, NULL, false,
-				 MAX_QUEUED_PACKETS))
-		goto err_3;
-
 	peer->internal_id = atomic64_inc_return(&peer_counter);
 	peer->serial_work_cpu = nr_cpumask_bits;
 	wg_cookie_init(&peer->latest_cookie);
 	wg_timers_init(peer);
 	wg_cookie_checker_precompute_peer_keys(peer);
 	spin_lock_init(&peer->keypairs.keypair_update_lock);
-	INIT_WORK(&peer->transmit_handshake_work,
-		  wg_packet_handshake_send_worker);
+	INIT_WORK(&peer->transmit_handshake_work, wg_packet_handshake_send_worker);
+	INIT_WORK(&peer->transmit_packet_work, wg_packet_tx_worker);
+	wg_prev_queue_init(&peer->tx_queue);
+	wg_prev_queue_init(&peer->rx_queue);
 	rwlock_init(&peer->endpoint_lock);
 	kref_init(&peer->refcount);
 	skb_queue_head_init(&peer->staged_packet_queue);
@@ -68,12 +64,8 @@
 	pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id);
 	return peer;
 
-err_3:
-	wg_packet_queue_free(&peer->tx_queue, false);
-err_2:
-	dst_cache_destroy(&peer->endpoint_cache);
-err_1:
-	kfree(peer);
+err:
+	kmem_cache_free(peer_cache, peer);
 	return ERR_PTR(ret);
 }
 
@@ -97,7 +89,7 @@
 	/* Mark as dead, so that we don't allow jumping contexts after. */
 	WRITE_ONCE(peer->is_dead, true);
 
-	/* The caller must now synchronize_rcu() for this to take effect. */
+	/* The caller must now synchronize_net() for this to take effect. */
 }
 
 static void peer_remove_after_dead(struct wg_peer *peer)
@@ -169,7 +161,7 @@
 	lockdep_assert_held(&peer->device->device_update_lock);
 
 	peer_make_dead(peer);
-	synchronize_rcu();
+	synchronize_net();
 	peer_remove_after_dead(peer);
 }
 
@@ -187,7 +179,7 @@
 		peer_make_dead(peer);
 		list_add_tail(&peer->peer_list, &dead_peers);
 	}
-	synchronize_rcu();
+	synchronize_net();
 	list_for_each_entry_safe(peer, temp, &dead_peers, peer_list)
 		peer_remove_after_dead(peer);
 }
@@ -197,13 +189,13 @@
 	struct wg_peer *peer = container_of(rcu, struct wg_peer, rcu);
 
 	dst_cache_destroy(&peer->endpoint_cache);
-	wg_packet_queue_free(&peer->rx_queue, false);
-	wg_packet_queue_free(&peer->tx_queue, false);
+	WARN_ON(wg_prev_queue_peek(&peer->tx_queue) || wg_prev_queue_peek(&peer->rx_queue));
 
 	/* The final zeroing takes care of clearing any remaining handshake key
 	 * material and other potentially sensitive information.
 	 */
-	kzfree(peer);
+	memzero_explicit(peer, sizeof(*peer));
+	kmem_cache_free(peer_cache, peer);
 }
 
 static void kref_release(struct kref *refcount)
@@ -235,3 +227,14 @@
 		return;
 	kref_put(&peer->refcount, kref_release);
 }
+
+int __init wg_peer_init(void)
+{
+	peer_cache = KMEM_CACHE(wg_peer, 0);
+	return peer_cache ? 0 : -ENOMEM;
+}
+
+void wg_peer_uninit(void)
+{
+	kmem_cache_destroy(peer_cache);
+}
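
Note: the hunks above move struct wg_peer off the general-purpose kzalloc()/kfree()
path and onto a dedicated slab cache, created once in wg_peer_init() and torn down
in wg_peer_uninit(), with the per-peer free path zeroing the object explicitly before
returning it to the cache. As a minimal, standalone sketch of that same kmem_cache
lifecycle (hypothetical example_obj/example_cache names, not part of this driver),
a cache is created at module init, objects come from kmem_cache_zalloc() and go back
with kmem_cache_free(), and the cache itself is destroyed only after every object
allocated from it has been freed:

/* Sketch only: illustrates the kmem_cache lifecycle adopted by the patch above,
 * using hypothetical names; not WireGuard code.
 */
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/types.h>

struct example_obj {
	u64 id;
	char name[32];
};

static struct kmem_cache *example_cache;

static int __init example_init(void)
{
	struct example_obj *obj;

	/* Size and alignment are derived from the struct; NULL on failure. */
	example_cache = KMEM_CACHE(example_obj, 0);
	if (!example_cache)
		return -ENOMEM;

	/* Allocate one zeroed object from the cache, then return it. */
	obj = kmem_cache_zalloc(example_cache, GFP_KERNEL);
	if (!obj) {
		kmem_cache_destroy(example_cache);
		return -ENOMEM;
	}
	obj->id = 1;
	kmem_cache_free(example_cache, obj);
	return 0;
}

static void __exit example_exit(void)
{
	/* Only valid once all objects allocated from the cache are freed. */
	kmem_cache_destroy(example_cache);
}

module_init(example_init);
module_exit(example_exit);
MODULE_LICENSE("GPL");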