hc
2024-10-22 8ac6c7a54ed1b98d142dce24b11c6de6a1e239a5
kernel/net/vmw_vsock/virtio_transport_common.c
....@@ -1,20 +1,16 @@
1
+// SPDX-License-Identifier: GPL-2.0-only
12 /*
23 * common code for virtio vsock
34 *
45 * Copyright (C) 2013-2015 Red Hat, Inc.
56 * Author: Asias He <asias@redhat.com>
67 * Stefan Hajnoczi <stefanha@redhat.com>
7
- *
8
- * This work is licensed under the terms of the GNU GPL, version 2.
98 */
109 #include <linux/spinlock.h>
1110 #include <linux/module.h>
1211 #include <linux/sched/signal.h>
1312 #include <linux/ctype.h>
1413 #include <linux/list.h>
15
-#include <linux/virtio.h>
16
-#include <linux/virtio_ids.h>
17
-#include <linux/virtio_config.h>
1814 #include <linux/virtio_vsock.h>
1915 #include <uapi/linux/vsockmon.h>
2016
....@@ -27,13 +23,20 @@
2723 /* How long to wait for graceful shutdown of a connection */
2824 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
2925
26
+/* Threshold for detecting small packets to copy */
27
+#define GOOD_COPY_LEN 128
28
+
3029 uint virtio_transport_max_vsock_pkt_buf_size = 64 * 1024;
3130 module_param(virtio_transport_max_vsock_pkt_buf_size, uint, 0444);
3231 EXPORT_SYMBOL_GPL(virtio_transport_max_vsock_pkt_buf_size);
3332
34
-static const struct virtio_transport *virtio_transport_get_ops(void)
33
+static const struct virtio_transport *
34
+virtio_transport_get_ops(struct vsock_sock *vsk)
3535 {
36
- const struct vsock_transport *t = vsock_core_get_transport();
36
+ const struct vsock_transport *t = vsock_core_get_transport(vsk);
37
+
38
+ if (WARN_ON(!t))
39
+ return NULL;
3740
3841 return container_of(t, struct virtio_transport, transport);
3942 }
....@@ -69,6 +72,9 @@
6972 pkt->buf = kmalloc(len, GFP_KERNEL);
7073 if (!pkt->buf)
7174 goto out_pkt;
75
+
76
+ pkt->buf_len = len;
77
+
7278 err = memcpy_from_msg(pkt->buf, info->msg, len);
7379 if (err)
7480 goto out;
....@@ -155,19 +161,33 @@
155161
156162 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
157163 {
164
+ if (pkt->tap_delivered)
165
+ return;
166
+
158167 vsock_deliver_tap(virtio_transport_build_skb, pkt);
168
+ pkt->tap_delivered = true;
159169 }
160170 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
161171
172
+/* This function can only be used on connecting/connected sockets,
173
+ * since a socket assigned to a transport is required.
174
+ *
175
+ * Do not use on listener sockets!
176
+ */
162177 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
163178 struct virtio_vsock_pkt_info *info)
164179 {
165180 u32 src_cid, src_port, dst_cid, dst_port;
181
+ const struct virtio_transport *t_ops;
166182 struct virtio_vsock_sock *vvs;
167183 struct virtio_vsock_pkt *pkt;
168184 u32 pkt_len = info->pkt_len;
169185
170
- src_cid = vm_sockets_get_local_cid();
186
+ t_ops = virtio_transport_get_ops(vsk);
187
+ if (unlikely(!t_ops))
188
+ return -EFAULT;
189
+
190
+ src_cid = t_ops->transport.get_local_cid();
171191 src_port = vsk->local_addr.svm_port;
172192 if (!info->remote_cid) {
173193 dst_cid = vsk->remote_addr.svm_cid;
....@@ -180,8 +200,8 @@
180200 vvs = vsk->trans;
181201
182202 /* we can send less than pkt_len bytes */
183
- if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
184
- pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
203
+ if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
204
+ pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
185205
186206 /* virtio_transport_get_credit might return less than pkt_len credit */
187207 pkt_len = virtio_transport_get_credit(vvs, pkt_len);
....@@ -200,13 +220,17 @@
200220
201221 virtio_transport_inc_tx_pkt(vvs, pkt);
202222
203
- return virtio_transport_get_ops()->send_pkt(pkt);
223
+ return t_ops->send_pkt(pkt);
204224 }
205225
206
-static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
226
+static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
207227 struct virtio_vsock_pkt *pkt)
208228 {
229
+ if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
230
+ return false;
231
+
209232 vvs->rx_bytes += pkt->len;
233
+ return true;
210234 }
211235
212236 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
....@@ -218,10 +242,11 @@
218242
219243 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
220244 {
221
- spin_lock_bh(&vvs->tx_lock);
245
+ spin_lock_bh(&vvs->rx_lock);
246
+ vvs->last_fwd_cnt = vvs->fwd_cnt;
222247 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
223248 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
224
- spin_unlock_bh(&vvs->tx_lock);
249
+ spin_unlock_bh(&vvs->rx_lock);
225250 }
226251 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
227252
....@@ -262,6 +287,55 @@
262287 }
263288
264289 static ssize_t
290
+virtio_transport_stream_do_peek(struct vsock_sock *vsk,
291
+ struct msghdr *msg,
292
+ size_t len)
293
+{
294
+ struct virtio_vsock_sock *vvs = vsk->trans;
295
+ struct virtio_vsock_pkt *pkt;
296
+ size_t bytes, total = 0, off;
297
+ int err = -EFAULT;
298
+
299
+ spin_lock_bh(&vvs->rx_lock);
300
+
301
+ list_for_each_entry(pkt, &vvs->rx_queue, list) {
302
+ off = pkt->off;
303
+
304
+ if (total == len)
305
+ break;
306
+
307
+ while (total < len && off < pkt->len) {
308
+ bytes = len - total;
309
+ if (bytes > pkt->len - off)
310
+ bytes = pkt->len - off;
311
+
312
+ /* sk_lock is held by caller so no one else can dequeue.
313
+ * Unlock rx_lock since memcpy_to_msg() may sleep.
314
+ */
315
+ spin_unlock_bh(&vvs->rx_lock);
316
+
317
+ err = memcpy_to_msg(msg, pkt->buf + off, bytes);
318
+ if (err)
319
+ goto out;
320
+
321
+ spin_lock_bh(&vvs->rx_lock);
322
+
323
+ total += bytes;
324
+ off += bytes;
325
+ }
326
+ }
327
+
328
+ spin_unlock_bh(&vvs->rx_lock);
329
+
330
+ return total;
331
+
332
+out:
333
+ if (total)
334
+ err = total;
335
+ return err;
336
+}
337
+
338
+static ssize_t
265339 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
266340 struct msghdr *msg,
267341 size_t len)
....@@ -269,6 +343,7 @@
269343 struct virtio_vsock_sock *vvs = vsk->trans;
270344 struct virtio_vsock_pkt *pkt;
271345 size_t bytes, total = 0;
346
+ u32 free_space;
272347 int err = -EFAULT;
273348
274349 spin_lock_bh(&vvs->rx_lock);
....@@ -299,11 +374,24 @@
299374 virtio_transport_free_pkt(pkt);
300375 }
301376 }
377
+
378
+ free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);
379
+
302380 spin_unlock_bh(&vvs->rx_lock);
303381
304
- /* Send a credit pkt to peer */
305
- virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
306
- NULL);
382
+ /* To reduce the number of credit update messages,
383
+ * don't update credits as long as lots of space is available.
384
+ * Note: the limit chosen here is arbitrary. Setting the limit
385
+ * too high causes extra messages. Too low causes transmitter
386
+ * stalls. As stalls are in theory more expensive than extra
387
+ * messages, we set the limit to a high value. TODO: experiment
388
+ * with different values.
389
+ */
390
+ if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
391
+ virtio_transport_send_credit_update(vsk,
392
+ VIRTIO_VSOCK_TYPE_STREAM,
393
+ NULL);
394
+ }
307395
308396 return total;
309397
....@@ -319,9 +407,9 @@
319407 size_t len, int flags)
320408 {
321409 if (flags & MSG_PEEK)
322
- return -EOPNOTSUPP;
323
-
324
- return virtio_transport_stream_do_dequeue(vsk, msg, len);
410
+ return virtio_transport_stream_do_peek(vsk, msg, len);
411
+ else
412
+ return virtio_transport_stream_do_dequeue(vsk, msg, len);
325413 }
326414 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
327415
....@@ -383,20 +471,16 @@
383471
384472 vsk->trans = vvs;
385473 vvs->vsk = vsk;
386
- if (psk) {
474
+ if (psk && psk->trans) {
387475 struct virtio_vsock_sock *ptrans = psk->trans;
388476
389
- vvs->buf_size = ptrans->buf_size;
390
- vvs->buf_size_min = ptrans->buf_size_min;
391
- vvs->buf_size_max = ptrans->buf_size_max;
392477 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
393
- } else {
394
- vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
395
- vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
396
- vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
397478 }
398479
399
- vvs->buf_alloc = vvs->buf_size;
480
+ if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
481
+ vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
482
+
483
+ vvs->buf_alloc = vsk->buffer_size;
400484
401485 spin_lock_init(&vvs->rx_lock);
402486 spin_lock_init(&vvs->tx_lock);
....@@ -406,68 +490,20 @@
406490 }
407491 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
408492
409
-u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
493
+/* sk_lock held by the caller */
494
+void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
410495 {
411496 struct virtio_vsock_sock *vvs = vsk->trans;
412497
413
- return vvs->buf_size;
498
+ if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
499
+ *val = VIRTIO_VSOCK_MAX_BUF_SIZE;
500
+
501
+ vvs->buf_alloc = *val;
502
+
503
+ virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
504
+ NULL);
414505 }
415
-EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
416
-
417
-u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
418
-{
419
- struct virtio_vsock_sock *vvs = vsk->trans;
420
-
421
- return vvs->buf_size_min;
422
-}
423
-EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
424
-
425
-u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
426
-{
427
- struct virtio_vsock_sock *vvs = vsk->trans;
428
-
429
- return vvs->buf_size_max;
430
-}
431
-EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
432
-
433
-void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
434
-{
435
- struct virtio_vsock_sock *vvs = vsk->trans;
436
-
437
- if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
438
- val = VIRTIO_VSOCK_MAX_BUF_SIZE;
439
- if (val < vvs->buf_size_min)
440
- vvs->buf_size_min = val;
441
- if (val > vvs->buf_size_max)
442
- vvs->buf_size_max = val;
443
- vvs->buf_size = val;
444
- vvs->buf_alloc = val;
445
-}
446
-EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
447
-
448
-void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
449
-{
450
- struct virtio_vsock_sock *vvs = vsk->trans;
451
-
452
- if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
453
- val = VIRTIO_VSOCK_MAX_BUF_SIZE;
454
- if (val > vvs->buf_size)
455
- vvs->buf_size = val;
456
- vvs->buf_size_min = val;
457
-}
458
-EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
459
-
460
-void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
461
-{
462
- struct virtio_vsock_sock *vvs = vsk->trans;
463
-
464
- if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
465
- val = VIRTIO_VSOCK_MAX_BUF_SIZE;
466
- if (val < vvs->buf_size)
467
- vvs->buf_size = val;
468
- vvs->buf_size_max = val;
469
-}
470
-EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
506
+EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
471507
472508 int
473509 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
....@@ -559,9 +595,7 @@
559595
560596 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
561597 {
562
- struct virtio_vsock_sock *vvs = vsk->trans;
563
-
564
- return vvs->buf_size;
598
+ return vsk->buffer_size;
565599 }
566600 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
567601
....@@ -703,6 +737,23 @@
703737 return t->send_pkt(reply);
704738 }
705739
740
+/* This function should be called with sk_lock held and SOCK_DONE set */
741
+static void virtio_transport_remove_sock(struct vsock_sock *vsk)
742
+{
743
+ struct virtio_vsock_sock *vvs = vsk->trans;
744
+ struct virtio_vsock_pkt *pkt, *tmp;
745
+
746
+ /* We don't need to take rx_lock, as the socket is closing and we are
747
+ * removing it.
748
+ */
749
+ list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
750
+ list_del(&pkt->list);
751
+ virtio_transport_free_pkt(pkt);
752
+ }
753
+
754
+ vsock_remove_sock(vsk);
755
+}
756
+
706757 static void virtio_transport_wait_close(struct sock *sk, long timeout)
707758 {
708759 if (timeout) {
....@@ -735,7 +786,7 @@
735786 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
736787 vsk->close_work_scheduled = false;
737788
738
- vsock_remove_sock(vsk);
789
+ virtio_transport_remove_sock(vsk);
739790
740791 /* Release refcnt obtained when we scheduled the timeout */
741792 sock_put(sk);
....@@ -798,23 +849,16 @@
798849
799850 void virtio_transport_release(struct vsock_sock *vsk)
800851 {
801
- struct virtio_vsock_sock *vvs = vsk->trans;
802
- struct virtio_vsock_pkt *pkt, *tmp;
803852 struct sock *sk = &vsk->sk;
804853 bool remove_sock = true;
805854
806
- lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
807855 if (sk->sk_type == SOCK_STREAM)
808856 remove_sock = virtio_transport_close(vsk);
809857
810
- list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
811
- list_del(&pkt->list);
812
- virtio_transport_free_pkt(pkt);
858
+ if (remove_sock) {
859
+ sock_set_flag(sk, SOCK_DONE);
860
+ virtio_transport_remove_sock(vsk);
813861 }
814
- release_sock(sk);
815
-
816
- if (remove_sock)
817
- vsock_remove_sock(vsk);
818862 }
819863 EXPORT_SYMBOL_GPL(virtio_transport_release);
820864
....@@ -854,24 +898,64 @@
854898 return err;
855899 }
856900
901
+static void
902
+virtio_transport_recv_enqueue(struct vsock_sock *vsk,
903
+ struct virtio_vsock_pkt *pkt)
904
+{
905
+ struct virtio_vsock_sock *vvs = vsk->trans;
906
+ bool can_enqueue, free_pkt = false;
907
+
908
+ pkt->len = le32_to_cpu(pkt->hdr.len);
909
+ pkt->off = 0;
910
+
911
+ spin_lock_bh(&vvs->rx_lock);
912
+
913
+ can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
914
+ if (!can_enqueue) {
915
+ free_pkt = true;
916
+ goto out;
917
+ }
918
+
919
+ /* Try to copy small packets into the buffer of last packet queued,
920
+ * to avoid wasting memory queueing the entire buffer with a small
921
+ * payload.
922
+ */
923
+ if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
924
+ struct virtio_vsock_pkt *last_pkt;
925
+
926
+ last_pkt = list_last_entry(&vvs->rx_queue,
927
+ struct virtio_vsock_pkt, list);
928
+
929
+ /* If there is space in the last packet queued, we copy the
930
+ * new packet in its buffer.
931
+ */
932
+ if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
933
+ memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
934
+ pkt->len);
935
+ last_pkt->len += pkt->len;
936
+ free_pkt = true;
937
+ goto out;
938
+ }
939
+ }
940
+
941
+ list_add_tail(&pkt->list, &vvs->rx_queue);
942
+
943
+out:
944
+ spin_unlock_bh(&vvs->rx_lock);
945
+ if (free_pkt)
946
+ virtio_transport_free_pkt(pkt);
947
+}
948
+
857949 static int
858950 virtio_transport_recv_connected(struct sock *sk,
859951 struct virtio_vsock_pkt *pkt)
860952 {
861953 struct vsock_sock *vsk = vsock_sk(sk);
862
- struct virtio_vsock_sock *vvs = vsk->trans;
863954 int err = 0;
864955
865956 switch (le16_to_cpu(pkt->hdr.op)) {
866957 case VIRTIO_VSOCK_OP_RW:
867
- pkt->len = le32_to_cpu(pkt->hdr.len);
868
- pkt->off = 0;
869
-
870
- spin_lock_bh(&vvs->rx_lock);
871
- virtio_transport_inc_rx_pkt(vvs, pkt);
872
- list_add_tail(&pkt->list, &vvs->rx_queue);
873
- spin_unlock_bh(&vvs->rx_lock);
874
-
958
+ virtio_transport_recv_enqueue(vsk, pkt);
875959 sk->sk_data_ready(sk);
876960 return err;
877961 case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
....@@ -930,32 +1014,57 @@
9301014 return virtio_transport_send_pkt_info(vsk, &info);
9311015 }
9321016
1017
+static bool virtio_transport_space_update(struct sock *sk,
1018
+ struct virtio_vsock_pkt *pkt)
1019
+{
1020
+ struct vsock_sock *vsk = vsock_sk(sk);
1021
+ struct virtio_vsock_sock *vvs = vsk->trans;
1022
+ bool space_available;
1023
+
1024
+ /* Listener sockets are not associated with any transport, so we are
1025
+ * not able to take the state to see if there is space available in the
1026
+ * remote peer, but since they are only used to receive requests, we
1027
+ * can assume that there is always space available in the other peer.
1028
+ */
1029
+ if (!vvs)
1030
+ return true;
1031
+
1032
+ /* buf_alloc and fwd_cnt is always included in the hdr */
1033
+ spin_lock_bh(&vvs->tx_lock);
1034
+ vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
1035
+ vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
1036
+ space_available = virtio_transport_has_space(vsk);
1037
+ spin_unlock_bh(&vvs->tx_lock);
1038
+ return space_available;
1039
+}
1040
+
9331041 /* Handle server socket */
9341042 static int
935
-virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
1043
+virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
1044
+ struct virtio_transport *t)
9361045 {
9371046 struct vsock_sock *vsk = vsock_sk(sk);
9381047 struct vsock_sock *vchild;
9391048 struct sock *child;
1049
+ int ret;
9401050
9411051 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
942
- virtio_transport_reset(vsk, pkt);
1052
+ virtio_transport_reset_no_sock(t, pkt);
9431053 return -EINVAL;
9441054 }
9451055
9461056 if (sk_acceptq_is_full(sk)) {
947
- virtio_transport_reset(vsk, pkt);
1057
+ virtio_transport_reset_no_sock(t, pkt);
9481058 return -ENOMEM;
9491059 }
9501060
951
- child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
952
- sk->sk_type, 0);
1061
+ child = vsock_create_connected(sk);
9531062 if (!child) {
954
- virtio_transport_reset(vsk, pkt);
1063
+ virtio_transport_reset_no_sock(t, pkt);
9551064 return -ENOMEM;
9561065 }
9571066
958
- sk->sk_ack_backlog++;
1067
+ sk_acceptq_added(sk);
9591068
9601069 lock_sock_nested(child, SINGLE_DEPTH_NESTING);
9611070
....@@ -967,6 +1076,20 @@
9671076 vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
9681077 le32_to_cpu(pkt->hdr.src_port));
9691078
1079
+ ret = vsock_assign_transport(vchild, vsk);
1080
+ /* Transport assigned (looking at remote_addr) must be the same
1081
+ * where we received the request.
1082
+ */
1083
+ if (ret || vchild->transport != &t->transport) {
1084
+ release_sock(child);
1085
+ virtio_transport_reset_no_sock(t, pkt);
1086
+ sock_put(child);
1087
+ return ret;
1088
+ }
1089
+
1090
+ if (virtio_transport_space_update(child, pkt))
1091
+ child->sk_write_space(child);
1092
+
9701093 vsock_insert_connected(vchild);
9711094 vsock_enqueue_accept(sk, child);
9721095 virtio_transport_send_response(vchild, pkt);
....@@ -975,22 +1098,6 @@
9751098
9761099 sk->sk_data_ready(sk);
9771100 return 0;
978
-}
979
-
980
-static bool virtio_transport_space_update(struct sock *sk,
981
- struct virtio_vsock_pkt *pkt)
982
-{
983
- struct vsock_sock *vsk = vsock_sk(sk);
984
- struct virtio_vsock_sock *vvs = vsk->trans;
985
- bool space_available;
986
-
987
- /* buf_alloc and fwd_cnt is always included in the hdr */
988
- spin_lock_bh(&vvs->tx_lock);
989
- vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
990
- vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
991
- space_available = virtio_transport_has_space(vsk);
992
- spin_unlock_bh(&vvs->tx_lock);
993
- return space_available;
9941101 }
9951102
9961103 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
....@@ -1039,17 +1146,26 @@
10391146
10401147 lock_sock(sk);
10411148
1149
+ /* Check if sk has been closed before lock_sock */
1150
+ if (sock_flag(sk, SOCK_DONE)) {
1151
+ (void)virtio_transport_reset_no_sock(t, pkt);
1152
+ release_sock(sk);
1153
+ sock_put(sk);
1154
+ goto free_pkt;
1155
+ }
1156
+
10421157 space_available = virtio_transport_space_update(sk, pkt);
10431158
10441159 /* Update CID in case it has changed after a transport reset event */
1045
- vsk->local_addr.svm_cid = dst.svm_cid;
1160
+ if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
1161
+ vsk->local_addr.svm_cid = dst.svm_cid;
10461162
10471163 if (space_available)
10481164 sk->sk_write_space(sk);
10491165
10501166 switch (sk->sk_state) {
10511167 case TCP_LISTEN:
1052
- virtio_transport_recv_listen(sk, pkt);
1168
+ virtio_transport_recv_listen(sk, pkt, t);
10531169 virtio_transport_free_pkt(pkt);
10541170 break;
10551171 case TCP_SYN_SENT:
....@@ -1068,6 +1184,7 @@
10681184 virtio_transport_free_pkt(pkt);
10691185 break;
10701186 }
1187
+
10711188 release_sock(sk);
10721189
10731190 /* Release refcnt obtained when we fetched this socket out of the
....@@ -1083,7 +1200,7 @@
10831200
10841201 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
10851202 {
1086
- kfree(pkt->buf);
1203
+ kvfree(pkt->buf);
10871204 kfree(pkt);
10881205 }
10891206 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);