hc
2024-02-20 102a0743326a03cd1a1202ceda21e175b7d3575c
kernel/net/l2tp/l2tp_core.c
....@@ -104,9 +104,9 @@
104104 /* per-net private data for this module */
105105 static unsigned int l2tp_net_id;
106106 struct l2tp_net {
107
- struct list_head l2tp_tunnel_list;
108
- /* Lock for write access to l2tp_tunnel_list */
109
- spinlock_t l2tp_tunnel_list_lock;
107
+ /* Lock for write access to l2tp_tunnel_idr */
108
+ spinlock_t l2tp_tunnel_idr_lock;
109
+ struct idr l2tp_tunnel_idr;
110110 struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2];
111111 /* Lock for write access to l2tp_session_hlist */
112112 spinlock_t l2tp_session_hlist_lock;
....@@ -208,13 +208,10 @@
208208 struct l2tp_tunnel *tunnel;
209209
210210 rcu_read_lock_bh();
211
- list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
212
- if (tunnel->tunnel_id == tunnel_id &&
213
- refcount_inc_not_zero(&tunnel->ref_count)) {
214
- rcu_read_unlock_bh();
215
-
216
- return tunnel;
217
- }
211
+ tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id);
212
+ if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) {
213
+ rcu_read_unlock_bh();
214
+ return tunnel;
218215 }
219216 rcu_read_unlock_bh();
220217
....@@ -224,13 +221,14 @@
224221
225222 struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
226223 {
227
- const struct l2tp_net *pn = l2tp_pernet(net);
224
+ struct l2tp_net *pn = l2tp_pernet(net);
225
+ unsigned long tunnel_id, tmp;
228226 struct l2tp_tunnel *tunnel;
229227 int count = 0;
230228
231229 rcu_read_lock_bh();
232
- list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
233
- if (++count > nth &&
230
+ idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
231
+ if (tunnel && ++count > nth &&
234232 refcount_inc_not_zero(&tunnel->ref_count)) {
235233 rcu_read_unlock_bh();
236234 return tunnel;
....@@ -1043,7 +1041,7 @@
10431041 IPCB(skb)->flags &= ~(IPSKB_XFRM_TUNNEL_SIZE | IPSKB_XFRM_TRANSFORMED | IPSKB_REROUTED);
10441042 nf_reset_ct(skb);
10451043
1046
- bh_lock_sock(sk);
1044
+ bh_lock_sock_nested(sk);
10471045 if (sock_owned_by_user(sk)) {
10481046 kfree_skb(skb);
10491047 ret = NET_XMIT_DROP;
....@@ -1150,8 +1148,10 @@
11501148 }
11511149
11521150 /* Remove hooks into tunnel socket */
1151
+ write_lock_bh(&sk->sk_callback_lock);
11531152 sk->sk_destruct = tunnel->old_sk_destruct;
11541153 sk->sk_user_data = NULL;
1154
+ write_unlock_bh(&sk->sk_callback_lock);
11551155
11561156 /* Call the original destructor */
11571157 if (sk->sk_destruct)
....@@ -1227,6 +1227,15 @@
12271227 l2tp_tunnel_delete(tunnel);
12281228 }
12291229
1230
+static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel)
1231
+{
1232
+ struct l2tp_net *pn = l2tp_pernet(net);
1233
+
1234
+ spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
1235
+ idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id);
1236
+ spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
1237
+}
1238
+
12301239 /* Workqueue tunnel deletion function */
12311240 static void l2tp_tunnel_del_work(struct work_struct *work)
12321241 {
....@@ -1234,7 +1243,6 @@
12341243 del_work);
12351244 struct sock *sk = tunnel->sock;
12361245 struct socket *sock = sk->sk_socket;
1237
- struct l2tp_net *pn;
12381246
12391247 l2tp_tunnel_closeall(tunnel);
12401248
....@@ -1248,12 +1256,7 @@
12481256 }
12491257 }
12501258
1251
- /* Remove the tunnel struct from the tunnel list */
1252
- pn = l2tp_pernet(tunnel->l2tp_net);
1253
- spin_lock_bh(&pn->l2tp_tunnel_list_lock);
1254
- list_del_rcu(&tunnel->list);
1255
- spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
1256
-
1259
+ l2tp_tunnel_remove(tunnel->l2tp_net, tunnel);
12571260 /* drop initial ref */
12581261 l2tp_tunnel_dec_refcount(tunnel);
12591262
....@@ -1384,8 +1387,6 @@
13841387 return err;
13851388 }
13861389
1387
-static struct lock_class_key l2tp_socket_class;
1388
-
13891390 int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id,
13901391 struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp)
13911392 {
....@@ -1455,11 +1456,18 @@
14551456 int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
14561457 struct l2tp_tunnel_cfg *cfg)
14571458 {
1458
- struct l2tp_tunnel *tunnel_walk;
1459
- struct l2tp_net *pn;
1459
+ struct l2tp_net *pn = l2tp_pernet(net);
1460
+ u32 tunnel_id = tunnel->tunnel_id;
14601461 struct socket *sock;
14611462 struct sock *sk;
14621463 int ret;
1464
+
1465
+ spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
1466
+ ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id,
1467
+ GFP_ATOMIC);
1468
+ spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
1469
+ if (ret)
1470
+ return ret == -ENOSPC ? -EEXIST : ret;
14631471
14641472 if (tunnel->fd < 0) {
14651473 ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id,
....@@ -1471,30 +1479,16 @@
14711479 sock = sockfd_lookup(tunnel->fd, &ret);
14721480 if (!sock)
14731481 goto err;
1474
-
1475
- ret = l2tp_validate_socket(sock->sk, net, tunnel->encap);
1476
- if (ret < 0)
1477
- goto err_sock;
14781482 }
1479
-
1480
- tunnel->l2tp_net = net;
1481
- pn = l2tp_pernet(net);
14821483
14831484 sk = sock->sk;
1484
- sock_hold(sk);
1485
- tunnel->sock = sk;
1486
-
1487
- spin_lock_bh(&pn->l2tp_tunnel_list_lock);
1488
- list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) {
1489
- if (tunnel_walk->tunnel_id == tunnel->tunnel_id) {
1490
- spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
1491
- sock_put(sk);
1492
- ret = -EEXIST;
1493
- goto err_sock;
1494
- }
1495
- }
1496
- list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
1497
- spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
1485
+ lock_sock(sk);
1486
+ write_lock_bh(&sk->sk_callback_lock);
1487
+ ret = l2tp_validate_socket(sk, net, tunnel->encap);
1488
+ if (ret < 0)
1489
+ goto err_inval_sock;
1490
+ rcu_assign_sk_user_data(sk, tunnel);
1491
+ write_unlock_bh(&sk->sk_callback_lock);
14981492
14991493 if (tunnel->encap == L2TP_ENCAPTYPE_UDP) {
15001494 struct udp_tunnel_sock_cfg udp_cfg = {
....@@ -1505,15 +1499,20 @@
15051499 };
15061500
15071501 setup_udp_tunnel_sock(net, sock, &udp_cfg);
1508
- } else {
1509
- sk->sk_user_data = tunnel;
15101502 }
15111503
15121504 tunnel->old_sk_destruct = sk->sk_destruct;
15131505 sk->sk_destruct = &l2tp_tunnel_destruct;
1514
- lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class,
1515
- "l2tp_sock");
15161506 sk->sk_allocation = GFP_ATOMIC;
1507
+ release_sock(sk);
1508
+
1509
+ sock_hold(sk);
1510
+ tunnel->sock = sk;
1511
+ tunnel->l2tp_net = net;
1512
+
1513
+ spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
1514
+ idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id);
1515
+ spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
15171516
15181517 trace_register_tunnel(tunnel);
15191518
....@@ -1522,12 +1521,16 @@
15221521
15231522 return 0;
15241523
1525
-err_sock:
1524
+err_inval_sock:
1525
+ write_unlock_bh(&sk->sk_callback_lock);
1526
+ release_sock(sk);
1527
+
15261528 if (tunnel->fd < 0)
15271529 sock_release(sock);
15281530 else
15291531 sockfd_put(sock);
15301532 err:
1533
+ l2tp_tunnel_remove(net, tunnel);
15311534 return ret;
15321535 }
15331536 EXPORT_SYMBOL_GPL(l2tp_tunnel_register);
....@@ -1641,8 +1644,8 @@
16411644 struct l2tp_net *pn = net_generic(net, l2tp_net_id);
16421645 int hash;
16431646
1644
- INIT_LIST_HEAD(&pn->l2tp_tunnel_list);
1645
- spin_lock_init(&pn->l2tp_tunnel_list_lock);
1647
+ idr_init(&pn->l2tp_tunnel_idr);
1648
+ spin_lock_init(&pn->l2tp_tunnel_idr_lock);
16461649
16471650 for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
16481651 INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]);
....@@ -1656,11 +1659,13 @@
16561659 {
16571660 struct l2tp_net *pn = l2tp_pernet(net);
16581661 struct l2tp_tunnel *tunnel = NULL;
1662
+ unsigned long tunnel_id, tmp;
16591663 int hash;
16601664
16611665 rcu_read_lock_bh();
1662
- list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
1663
- l2tp_tunnel_delete(tunnel);
1666
+ idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
1667
+ if (tunnel)
1668
+ l2tp_tunnel_delete(tunnel);
16641669 }
16651670 rcu_read_unlock_bh();
16661671
....@@ -1670,6 +1675,7 @@
16701675
16711676 for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
16721677 WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
1678
+ idr_destroy(&pn->l2tp_tunnel_idr);
16731679 }
16741680
16751681 static struct pernet_operations l2tp_net_ops = {