From 8ac6c7a54ed1b98d142dce24b11c6de6a1e239a5 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Tue, 22 Oct 2024 10:36:11 +0000
Subject: [PATCH] Switch the 4G dial-up to QMI; quectel-CM must be run in the background on the system

---
 kernel/net/vmw_vsock/af_vsock.c |  493 ++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 364 insertions(+), 129 deletions(-)

diff --git a/kernel/net/vmw_vsock/af_vsock.c b/kernel/net/vmw_vsock/af_vsock.c
index 22931a5..06dddb5 100644
--- a/kernel/net/vmw_vsock/af_vsock.c
+++ b/kernel/net/vmw_vsock/af_vsock.c
@@ -1,16 +1,8 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * VMware vSockets Driver
  *
  * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the Free
- * Software Foundation version 2 and no later version.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
- * more details.
  */
 
 /* Implementation notes:
@@ -134,18 +126,19 @@
  */
 #define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
 
-static const struct vsock_transport *transport;
+#define VSOCK_DEFAULT_BUFFER_SIZE     (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
+
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
+/* Transport used for local communication */
+static const struct vsock_transport *transport_local;
 static DEFINE_MUTEX(vsock_register_mutex);
-
-/**** EXPORTS ****/
-
-/* Get the ID of the local context.  This is transport dependent. */
-
-int vm_sockets_get_local_cid(void)
-{
-	return transport->get_local_cid();
-}
-EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
 
 /**** UTILS ****/
 
@@ -196,7 +189,7 @@
 	return __vsock_bind(sk, &local_addr);
 }
 
-static int __init vsock_init_tables(void)
+static void vsock_init_tables(void)
 {
 	int i;
 
@@ -205,7 +198,6 @@
 
 	for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
 		INIT_LIST_HEAD(&vsock_connected_table[i]);
-	return 0;
 }
 
 static void __vsock_insert_bound(struct list_head *list,
@@ -238,9 +230,15 @@
 {
 	struct vsock_sock *vsk;
 
-	list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
-		if (addr->svm_port == vsk->local_addr.svm_port)
+	list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
+		if (vsock_addr_equals_addr(addr, &vsk->local_addr))
 			return sk_vsock(vsk);
+
+		if (addr->svm_port == vsk->local_addr.svm_port &&
+		    (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
+		     addr->svm_cid == VMADDR_CID_ANY))
+			return sk_vsock(vsk);
+	}
 
 	return NULL;
 }
@@ -390,6 +388,112 @@
 }
 EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
 
+static bool vsock_use_local_transport(unsigned int remote_cid)
+{
+	if (!transport_local)
+		return false;
+
+	if (remote_cid == VMADDR_CID_LOCAL)
+		return true;
+
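+	/* With a g2h transport loaded we are running as a guest, so "local"
+	 * only matches our own CID; otherwise assume we are the host, whose
+	 * CID is VMADDR_CID_HOST.
+	 */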
+	if (transport_g2h) {
+		return remote_cid == transport_g2h->get_local_cid();
+	} else {
+		return remote_cid == VMADDR_CID_HOST;
+	}
+}
+
+static void vsock_deassign_transport(struct vsock_sock *vsk)
+{
+	if (!vsk->transport)
+		return;
+
+	vsk->transport->destruct(vsk);
+	module_put(vsk->transport->module);
+	vsk->transport = NULL;
+}
+
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream sockets this must be called when vsk->remote_addr is set
+ * (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ *  - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
+ *    g2h is not loaded, will use local transport;
+ *  - remote CID <= VMADDR_CID_HOST or h2g is not loaded will use guest->host
+ *    transport;
+ *  - remote CID > VMADDR_CID_HOST will use host->guest transport;
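+ *
+ * For example, on a guest whose g2h transport reports local CID 3 (an
+ * arbitrary value), a connect() to CID 2 (VMADDR_CID_HOST) selects the
+ * guest->host transport, while CID 1 or CID 3 selects the local transport
+ * if it is loaded; on the host, a guest CID such as 5 selects host->guest.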
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+	const struct vsock_transport *new_transport;
+	struct sock *sk = sk_vsock(vsk);
+	unsigned int remote_cid = vsk->remote_addr.svm_cid;
+	int ret;
+
+	switch (sk->sk_type) {
+	case SOCK_DGRAM:
+		new_transport = transport_dgram;
+		break;
+	case SOCK_STREAM:
+		if (vsock_use_local_transport(remote_cid))
+			new_transport = transport_local;
+		else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g)
+			new_transport = transport_g2h;
+		else
+			new_transport = transport_h2g;
+		break;
+	default:
+		return -ESOCKTNOSUPPORT;
+	}
+
+	if (vsk->transport) {
+		if (vsk->transport == new_transport)
+			return 0;
+
+		/* transport->release() must be called with sock lock acquired.
+		 * This path can only be taken during vsock_stream_connect(),
+		 * where we already hold the sock lock.
+		 * In the other cases, this function is called on a new socket
+		 * which is not assigned to any transport.
+		 */
+		vsk->transport->release(vsk);
+		vsock_deassign_transport(vsk);
+	}
+
+	/* We increase the module refcnt to prevent the transport from being
+	 * unloaded while there are open sockets assigned to it.
+	 */
+	if (!new_transport || !try_module_get(new_transport->module))
+		return -ENODEV;
+
+	ret = new_transport->init(vsk, psk);
+	if (ret) {
+		module_put(new_transport->module);
+		return ret;
+	}
+
+	vsk->transport = new_transport;
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
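+/* Check whether @cid is a CID this endpoint can be reached at, i.e. it is
+ * served by one of the currently registered transports.
+ */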
+bool vsock_find_cid(unsigned int cid)
+{
+	if (transport_g2h && cid == transport_g2h->get_local_cid())
+		return true;
+
+	if (transport_h2g && cid == VMADDR_CID_HOST)
+		return true;
+
+	if (transport_local && cid == VMADDR_CID_LOCAL)
+		return true;
+
+	return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
 static struct sock *vsock_dequeue_accept(struct sock *listener)
 {
 	struct vsock_sock *vlistener;
@@ -426,7 +530,12 @@
 
 static int vsock_send_shutdown(struct sock *sk, int mode)
 {
-	return transport->shutdown(vsock_sk(sk), mode);
+	struct vsock_sock *vsk = vsock_sk(sk);
+
+	if (!vsk->transport)
+		return -ENODEV;
+
+	return vsk->transport->shutdown(vsk, mode);
 }
 
 static void vsock_pending_work(struct work_struct *work)
@@ -447,7 +556,7 @@
 	if (vsock_is_pending(sk)) {
 		vsock_remove_pending(listener, sk);
 
-		listener->sk_ack_backlog--;
+		sk_acceptq_removed(listener);
 	} else if (!vsk->rejected) {
 		/* We are not on the pending list and accept() did not reject
 		 * us, so we must have been accepted by our user process.  We
@@ -481,7 +590,7 @@
 static int __vsock_bind_stream(struct vsock_sock *vsk,
 			       struct sockaddr_vm *addr)
 {
-	static u32 port = 0;
+	static u32 port;
 	struct sockaddr_vm new_addr;
 
 	if (!port)
@@ -536,13 +645,12 @@
 static int __vsock_bind_dgram(struct vsock_sock *vsk,
 			      struct sockaddr_vm *addr)
 {
-	return transport->dgram_bind(vsk, addr);
+	return vsk->transport->dgram_bind(vsk, addr);
 }
 
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 {
 	struct vsock_sock *vsk = vsock_sk(sk);
-	u32 cid;
 	int retval;
 
 	/* First ensure this socket isn't already bound. */
@@ -552,10 +660,9 @@
 	/* Now bind to the provided address or select appropriate values if
 	 * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY).  Note that
 	 * like AF_INET prevents binding to a non-local IP address (in most
-	 * cases), we only allow binding to the local CID.
+	 * cases), we only allow binding to a local CID.
 	 */
-	cid = transport->get_local_cid();
-	if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+	if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
 		return -EADDRNOTAVAIL;
 
 	switch (sk->sk_socket->type) {
@@ -579,12 +686,12 @@
 
 static void vsock_connect_timeout(struct work_struct *work);
 
-struct sock *__vsock_create(struct net *net,
-			    struct socket *sock,
-			    struct sock *parent,
-			    gfp_t priority,
-			    unsigned short type,
-			    int kern)
+static struct sock *__vsock_create(struct net *net,
+				   struct socket *sock,
+				   struct sock *parent,
+				   gfp_t priority,
+				   unsigned short type,
+				   int kern)
 {
 	struct sock *sk;
 	struct vsock_sock *psk;
@@ -628,39 +735,30 @@
 		vsk->trusted = psk->trusted;
 		vsk->owner = get_cred(psk->owner);
 		vsk->connect_timeout = psk->connect_timeout;
+		vsk->buffer_size = psk->buffer_size;
+		vsk->buffer_min_size = psk->buffer_min_size;
+		vsk->buffer_max_size = psk->buffer_max_size;
 		security_sk_clone(parent, sk);
 	} else {
 		vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN);
 		vsk->owner = get_current_cred();
 		vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
+		vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
+		vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
+		vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
 	}
-
-	if (transport->init(vsk, psk) < 0) {
-		sk_free(sk);
-		return NULL;
-	}
-
-	if (sock)
-		vsock_insert_unbound(vsk);
 
 	return sk;
 }
-EXPORT_SYMBOL_GPL(__vsock_create);
 
 static void __vsock_release(struct sock *sk, int level)
 {
 	if (sk) {
-		struct sk_buff *skb;
 		struct sock *pending;
 		struct vsock_sock *vsk;
 
 		vsk = vsock_sk(sk);
 		pending = NULL;	/* Compiler warning. */
-
-		/* The release call is supposed to use lock_sock_nested()
-		 * rather than lock_sock(), if a sock lock should be acquired.
-		 */
-		transport->release(vsk);
 
 		/* When "level" is SINGLE_DEPTH_NESTING, use the nested
 		 * version to avoid the warning "possible recursive locking
@@ -668,11 +766,16 @@
 		 * is the same as lock_sock(sk).
 		 */
 		lock_sock_nested(sk, level);
+
+		if (vsk->transport)
+			vsk->transport->release(vsk);
+		else if (sk->sk_type == SOCK_STREAM)
+			vsock_remove_sock(vsk);
+
 		sock_orphan(sk);
 		sk->sk_shutdown = SHUTDOWN_MASK;
 
-		while ((skb = skb_dequeue(&sk->sk_receive_queue)))
-			kfree_skb(skb);
+		skb_queue_purge(&sk->sk_receive_queue);
 
 		/* Clean up any sockets that never were accepted. */
 		while ((pending = vsock_dequeue_accept(sk)) != NULL) {
@@ -689,7 +792,7 @@
 {
 	struct vsock_sock *vsk = vsock_sk(sk);
 
-	transport->destruct(vsk);
+	vsock_deassign_transport(vsk);
 
 	/* When clearing these addresses, there's no need to set the family and
 	 * possibly register the address family with the kernel.
@@ -711,15 +814,22 @@
 	return err;
 }
 
+struct sock *vsock_create_connected(struct sock *parent)
+{
+	return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
+			      parent->sk_type, 0);
+}
+EXPORT_SYMBOL_GPL(vsock_create_connected);
+
 s64 vsock_stream_has_data(struct vsock_sock *vsk)
 {
-	return transport->stream_has_data(vsk);
+	return vsk->transport->stream_has_data(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_data);
 
 s64 vsock_stream_has_space(struct vsock_sock *vsk)
 {
-	return transport->stream_has_space(vsk);
+	return vsk->transport->stream_has_space(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_space);
 
@@ -890,7 +1000,11 @@
 			mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
 
 	} else if (sock->type == SOCK_STREAM) {
+		const struct vsock_transport *transport;
+
 		lock_sock(sk);
+
+		transport = vsk->transport;
 
 		/* Listening sockets that have connections in their accept
 		 * queue can be read.
@@ -900,7 +1014,7 @@
 			mask |= EPOLLIN | EPOLLRDNORM;
 
 		/* If there is something in the queue then we can read. */
-		if (transport->stream_is_active(vsk) &&
+		if (transport && transport->stream_is_active(vsk) &&
 		    !(sk->sk_shutdown & RCV_SHUTDOWN)) {
 			bool data_ready_now = false;
 			int ret = transport->notify_poll_in(
@@ -924,7 +1038,7 @@
 		}
 
 		/* Connected sockets that can produce data can be written. */
-		if (sk->sk_state == TCP_ESTABLISHED) {
+		if (transport && sk->sk_state == TCP_ESTABLISHED) {
 			if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
 				bool space_avail_now = false;
 				int ret = transport->notify_poll_out(
@@ -965,6 +1079,7 @@
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	struct sockaddr_vm *remote_addr;
+	const struct vsock_transport *transport;
 
 	if (msg->msg_flags & MSG_OOB)
 		return -EOPNOTSUPP;
@@ -975,6 +1090,8 @@
 	vsk = vsock_sk(sk);
 
 	lock_sock(sk);
+
+	transport = vsk->transport;
 
 	err = vsock_auto_bind(vsk);
 	if (err)
@@ -1057,8 +1174,8 @@
 	if (err)
 		goto out;
 
-	if (!transport->dgram_allow(remote_addr->svm_cid,
-				    remote_addr->svm_port)) {
+	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
+					 remote_addr->svm_port)) {
 		err = -EINVAL;
 		goto out;
 	}
@@ -1074,7 +1191,9 @@
 static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
 			       size_t len, int flags)
 {
-	return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags);
+	struct vsock_sock *vsk = vsock_sk(sock->sk);
+
+	return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
 }
 
 static const struct proto_ops vsock_dgram_ops = {
@@ -1090,8 +1209,6 @@
 	.ioctl = sock_no_ioctl,
 	.listen = sock_no_listen,
 	.shutdown = vsock_shutdown,
-	.setsockopt = sock_no_setsockopt,
-	.getsockopt = sock_no_getsockopt,
 	.sendmsg = vsock_dgram_sendmsg,
 	.recvmsg = vsock_dgram_recvmsg,
 	.mmap = sock_no_mmap,
@@ -1100,7 +1217,9 @@
 
 static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
 {
-	if (!transport->cancel_pkt)
+	const struct vsock_transport *transport = vsk->transport;
+
+	if (!transport || !transport->cancel_pkt)
 		return -EOPNOTSUPP;
 
 	return transport->cancel_pkt(vsk);
@@ -1118,6 +1237,7 @@
 	if (sk->sk_state == TCP_SYN_SENT &&
 	    (sk->sk_shutdown != SHUTDOWN_MASK)) {
 		sk->sk_state = TCP_CLOSE;
+		sk->sk_socket->state = SS_UNCONNECTED;
 		sk->sk_err = ETIMEDOUT;
 		sk->sk_error_report(sk);
 		vsock_transport_cancel_pkt(vsk);
@@ -1133,6 +1253,7 @@
 	int err;
 	struct sock *sk;
 	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
 	struct sockaddr_vm *remote_addr;
 	long timeout;
 	DEFINE_WAIT(wait);
@@ -1169,18 +1290,25 @@
 			goto out;
 		}
 
+		/* Set the remote address that we are connecting to. */
+		memcpy(&vsk->remote_addr, remote_addr,
+		       sizeof(vsk->remote_addr));
+
+		err = vsock_assign_transport(vsk, NULL);
+		if (err)
+			goto out;
+
+		transport = vsk->transport;
+
 		/* The hypervisor and well-known contexts do not have socket
 		 * endpoints.
 		 */
-		if (!transport->stream_allow(remote_addr->svm_cid,
+		if (!transport ||
+		    !transport->stream_allow(remote_addr->svm_cid,
 					     remote_addr->svm_port)) {
 			err = -ENETUNREACH;
 			goto out;
 		}
-
-		/* Set the remote address that we are connecting to. */
-		memcpy(&vsk->remote_addr, remote_addr,
-		       sizeof(vsk->remote_addr));
 
 		err = vsock_auto_bind(vsk);
 		if (err)
@@ -1215,7 +1343,14 @@
 			 * timeout fires.
 			 */
 			sock_hold(sk);
-			schedule_delayed_work(&vsk->connect_work, timeout);
+
+			/* If the timeout function is already scheduled,
+			 * reschedule it, then ungrab the socket refcount to
+			 * reschedule it, then drop the extra socket reference
+			 * to keep the refcount balanced.
+			if (mod_delayed_work(system_wq, &vsk->connect_work,
+					     timeout))
+				sock_put(sk);
 
 			/* Skip ahead to preserve error code set above. */
 			goto out_wait;
@@ -1232,7 +1367,7 @@
 			vsock_transport_cancel_pkt(vsk);
 			vsock_remove_connected(vsk);
 			goto out_wait;
-		} else if (timeout == 0) {
+		} else if ((sk->sk_state != TCP_ESTABLISHED) && (timeout == 0)) {
 			err = -ETIMEDOUT;
 			sk->sk_state = TCP_CLOSE;
 			sock->state = SS_UNCONNECTED;
@@ -1312,7 +1447,7 @@
 		err = -listener->sk_err;
 
 	if (connected) {
-		listener->sk_ack_backlog--;
+		sk_acceptq_removed(listener);
 
 		lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
 		vconnected = vsock_sk(connected);
@@ -1377,15 +1512,33 @@
 	return err;
 }
 
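+/* Clamp the requested size to the socket's min/max bounds and give the
+ * transport a chance to adjust or act on the new value via
+ * notify_buffer_size() before it is stored in vsk->buffer_size.
+ */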
+static void vsock_update_buffer_size(struct vsock_sock *vsk,
+				     const struct vsock_transport *transport,
+				     u64 val)
+{
+	if (val > vsk->buffer_max_size)
+		val = vsk->buffer_max_size;
+
+	if (val < vsk->buffer_min_size)
+		val = vsk->buffer_min_size;
+
+	if (val != vsk->buffer_size &&
+	    transport && transport->notify_buffer_size)
+		transport->notify_buffer_size(vsk, &val);
+
+	vsk->buffer_size = val;
+}
+
 static int vsock_stream_setsockopt(struct socket *sock,
 				   int level,
 				   int optname,
-				   char __user *optval,
+				   sockptr_t optval,
 				   unsigned int optlen)
 {
 	int err;
 	struct sock *sk;
 	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
 	u64 val;
 
 	if (level != AF_VSOCK)
@@ -1397,7 +1550,7 @@
 			err = -EINVAL;			  \
 			goto exit;			  \
 		}					  \
-		if (copy_from_user(&_v, optval, sizeof(_v)) != 0) {	\
+		if (copy_from_sockptr(&_v, optval, sizeof(_v)) != 0) {	\
 			err = -EFAULT;					\
 			goto exit;					\
 		}							\
@@ -1409,24 +1562,28 @@
 
 	lock_sock(sk);
 
+	transport = vsk->transport;
+
 	switch (optname) {
 	case SO_VM_SOCKETS_BUFFER_SIZE:
 		COPY_IN(val);
-		transport->set_buffer_size(vsk, val);
+		vsock_update_buffer_size(vsk, transport, val);
 		break;
 
 	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
 		COPY_IN(val);
-		transport->set_max_buffer_size(vsk, val);
+		vsk->buffer_max_size = val;
+		vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
 		break;
 
 	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
 		COPY_IN(val);
-		transport->set_min_buffer_size(vsk, val);
+		vsk->buffer_min_size = val;
+		vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
 		break;
 
 	case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
-		struct timeval tv;
+		struct __kernel_old_timeval tv;
 		COPY_IN(tv);
 		if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC &&
 		    tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) {
@@ -1489,22 +1646,22 @@
 
 	switch (optname) {
 	case SO_VM_SOCKETS_BUFFER_SIZE:
-		val = transport->get_buffer_size(vsk);
+		val = vsk->buffer_size;
 		COPY_OUT(val);
 		break;
 
 	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
-		val = transport->get_max_buffer_size(vsk);
+		val = vsk->buffer_max_size;
 		COPY_OUT(val);
 		break;
 
 	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
-		val = transport->get_min_buffer_size(vsk);
+		val = vsk->buffer_min_size;
 		COPY_OUT(val);
 		break;
 
 	case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
-		struct timeval tv;
+		struct __kernel_old_timeval tv;
 		tv.tv_sec = vsk->connect_timeout / HZ;
 		tv.tv_usec =
 		    (vsk->connect_timeout -
@@ -1530,6 +1687,7 @@
 {
 	struct sock *sk;
 	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
 	ssize_t total_written;
 	long timeout;
 	int err;
@@ -1546,6 +1704,8 @@
 
 	lock_sock(sk);
 
+	transport = vsk->transport;
+
 	/* Callers should not provide a destination with stream sockets. */
 	if (msg->msg_namelen) {
 		err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
@@ -1559,7 +1719,7 @@
 		goto out;
 	}
 
-	if (sk->sk_state != TCP_ESTABLISHED ||
+	if (!transport || sk->sk_state != TCP_ESTABLISHED ||
 	    !vsock_addr_bound(&vsk->local_addr)) {
 		err = -ENOTCONN;
 		goto out;
@@ -1669,6 +1829,7 @@
 {
 	struct sock *sk;
 	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
 	int err;
 	size_t target;
 	ssize_t copied;
@@ -1683,7 +1844,9 @@
 
 	lock_sock(sk);
 
-	if (sk->sk_state != TCP_ESTABLISHED) {
+	transport = vsk->transport;
+
+	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
 		/* Recvmsg is supposed to return 0 if a peer performs an
 		 * orderly shutdown. Differentiate between that case and when a
 		 * peer has not connected or a local shutdown occured with the
@@ -1857,6 +2020,10 @@
 static int vsock_create(struct net *net, struct socket *sock,
 			int protocol, int kern)
 {
+	struct vsock_sock *vsk;
+	struct sock *sk;
+	int ret;
+
 	if (!sock)
 		return -EINVAL;
 
@@ -1876,7 +2043,23 @@
 
 	sock->state = SS_UNCONNECTED;
 
-	return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM;
+	sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
+	if (!sk)
+		return -ENOMEM;
+
+	vsk = vsock_sk(sk);
+
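+	/* Datagram sockets use the single registered DGRAM transport, so it
+	 * can be assigned right away; stream sockets get their transport
+	 * later, once the remote CID is known (see vsock_assign_transport()).
+	 */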
+	if (sock->type == SOCK_DGRAM) {
+		ret = vsock_assign_transport(vsk, NULL);
+		if (ret < 0) {
+			sock_put(sk);
+			return ret;
+		}
+	}
+
+	vsock_insert_unbound(vsk);
+
+	return 0;
 }
 
 static const struct net_proto_family vsock_family_ops = {
@@ -1889,11 +2072,20 @@
 			       unsigned int cmd, void __user *ptr)
 {
 	u32 __user *p = ptr;
+	u32 cid = VMADDR_CID_ANY;
 	int retval = 0;
 
 	switch (cmd) {
 	case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
-		if (put_user(transport->get_local_cid(), p) != 0)
+		/* To be compatible with the VMCI behavior, we prioritize the
+		 * guest CID instead of the well-known host CID (VMADDR_CID_HOST).
+		 */
+		if (transport_g2h)
+			cid = transport_g2h->get_local_cid();
+		else if (transport_h2g)
+			cid = transport_h2g->get_local_cid();
+
+		if (put_user(cid, p) != 0)
 			retval = -EFAULT;
 		break;
 
@@ -1933,24 +2125,13 @@
 	.fops		= &vsock_device_ops,
 };
 
-int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
+static int __init vsock_init(void)
 {
-	int err = mutex_lock_interruptible(&vsock_register_mutex);
+	int err = 0;
 
-	if (err)
-		return err;
+	vsock_init_tables();
 
-	if (transport) {
-		err = -EBUSY;
-		goto err_busy;
-	}
-
-	/* Transport must be the owner of the protocol so that it can't
-	 * unload while there are open sockets.
-	 */
-	vsock_proto.owner = owner;
-	transport = t;
-
+	vsock_proto.owner = THIS_MODULE;
 	vsock_device.minor = MISC_DYNAMIC_MINOR;
 	err = misc_register(&vsock_device);
 	if (err) {
@@ -1971,7 +2152,6 @@
 		goto err_unregister_proto;
 	}
 
-	mutex_unlock(&vsock_register_mutex);
 	return 0;
 
 err_unregister_proto:
@@ -1979,44 +2159,99 @@
 err_deregister_misc:
 	misc_deregister(&vsock_device);
 err_reset_transport:
-	transport = NULL;
+	return err;
+}
+
+static void __exit vsock_exit(void)
+{
+	misc_deregister(&vsock_device);
+	sock_unregister(AF_VSOCK);
+	proto_unregister(&vsock_proto);
+}
+
+const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
+{
+	return vsk->transport;
+}
+EXPORT_SYMBOL_GPL(vsock_core_get_transport);
+
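+/* Register @t for the transport roles requested in @features.  The roles are
+ * validated first and then assigned atomically under vsock_register_mutex;
+ * -EBUSY is returned if any requested role already has a transport.
+ */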
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+	const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local;
+	int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+	if (err)
+		return err;
+
+	t_h2g = transport_h2g;
+	t_g2h = transport_g2h;
+	t_dgram = transport_dgram;
+	t_local = transport_local;
+
+	if (features & VSOCK_TRANSPORT_F_H2G) {
+		if (t_h2g) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_h2g = t;
+	}
+
+	if (features & VSOCK_TRANSPORT_F_G2H) {
+		if (t_g2h) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_g2h = t;
+	}
+
+	if (features & VSOCK_TRANSPORT_F_DGRAM) {
+		if (t_dgram) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_dgram = t;
+	}
+
+	if (features & VSOCK_TRANSPORT_F_LOCAL) {
+		if (t_local) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_local = t;
+	}
+
+	transport_h2g = t_h2g;
+	transport_g2h = t_g2h;
+	transport_dgram = t_dgram;
+	transport_local = t_local;
+
 err_busy:
 	mutex_unlock(&vsock_register_mutex);
 	return err;
 }
-EXPORT_SYMBOL_GPL(__vsock_core_init);
+EXPORT_SYMBOL_GPL(vsock_core_register);
 
-void vsock_core_exit(void)
+void vsock_core_unregister(const struct vsock_transport *t)
 {
 	mutex_lock(&vsock_register_mutex);
 
-	misc_deregister(&vsock_device);
-	sock_unregister(AF_VSOCK);
-	proto_unregister(&vsock_proto);
+	if (transport_h2g == t)
+		transport_h2g = NULL;
 
-	/* We do not want the assignment below re-ordered. */
-	mb();
-	transport = NULL;
+	if (transport_g2h == t)
+		transport_g2h = NULL;
+
+	if (transport_dgram == t)
+		transport_dgram = NULL;
+
+	if (transport_local == t)
+		transport_local = NULL;
 
 	mutex_unlock(&vsock_register_mutex);
 }
-EXPORT_SYMBOL_GPL(vsock_core_exit);
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
 
-const struct vsock_transport *vsock_core_get_transport(void)
-{
-	/* vsock_register_mutex not taken since only the transport uses this
-	 * function and only while registered.
-	 */
-	return transport;
-}
-EXPORT_SYMBOL_GPL(vsock_core_get_transport);
-
-static void __exit vsock_exit(void)
-{
-	/* Do nothing.  This function makes this module removable. */
-}
-
-module_init(vsock_init_tables);
+module_init(vsock_init);
 module_exit(vsock_exit);
 
 MODULE_AUTHOR("VMware, Inc.");

--
Gitblit v1.6.2