From 04dd17822334871b23ea2862f7798fb0e0007777 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Sat, 11 May 2024 08:53:19 +0000
Subject: [PATCH] change otg to host mode

---
 kernel/drivers/net/wireguard/netlink.c |   62 +++++++++++++-----------------
 1 files changed, 27 insertions(+), 35 deletions(-)

diff --git a/kernel/drivers/net/wireguard/netlink.c b/kernel/drivers/net/wireguard/netlink.c
index a4377ad..f5bc279 100644
--- a/kernel/drivers/net/wireguard/netlink.c
+++ b/kernel/drivers/net/wireguard/netlink.c
@@ -17,17 +17,13 @@
 #include <net/sock.h>
 #include <crypto/algapi.h>
 
-struct __uapi_kernel_timespec {
-	int64_t tv_sec, tv_nsec;
-};
-
 static struct genl_family genl_family;
 
 static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
 	[WGDEVICE_A_IFINDEX]		= { .type = NLA_U32 },
 	[WGDEVICE_A_IFNAME]		= { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
-	[WGDEVICE_A_PRIVATE_KEY]	= { .len = NOISE_PUBLIC_KEY_LEN },
-	[WGDEVICE_A_PUBLIC_KEY]		= { .len = NOISE_PUBLIC_KEY_LEN },
+	[WGDEVICE_A_PRIVATE_KEY]	= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
+	[WGDEVICE_A_PUBLIC_KEY]		= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
 	[WGDEVICE_A_FLAGS]		= { .type = NLA_U32 },
 	[WGDEVICE_A_LISTEN_PORT]	= { .type = NLA_U16 },
 	[WGDEVICE_A_FWMARK]		= { .type = NLA_U32 },
@@ -35,12 +31,12 @@
 };
 
 static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
-	[WGPEER_A_PUBLIC_KEY]				= { .len = NOISE_PUBLIC_KEY_LEN },
-	[WGPEER_A_PRESHARED_KEY]			= { .len = NOISE_SYMMETRIC_KEY_LEN },
+	[WGPEER_A_PUBLIC_KEY]				= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
+	[WGPEER_A_PRESHARED_KEY]			= NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
 	[WGPEER_A_FLAGS]				= { .type = NLA_U32 },
-	[WGPEER_A_ENDPOINT]				= { .len = sizeof(struct sockaddr) },
+	[WGPEER_A_ENDPOINT]				= NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
 	[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]	= { .type = NLA_U16 },
-	[WGPEER_A_LAST_HANDSHAKE_TIME]			= { .len = sizeof(struct __uapi_kernel_timespec) },
+	[WGPEER_A_LAST_HANDSHAKE_TIME]			= NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
 	[WGPEER_A_RX_BYTES]				= { .type = NLA_U64 },
 	[WGPEER_A_TX_BYTES]				= { .type = NLA_U64 },
 	[WGPEER_A_ALLOWEDIPS]				= { .type = NLA_NESTED },
@@ -49,7 +45,7 @@
 
 static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
 	[WGALLOWEDIP_A_FAMILY]		= { .type = NLA_U16 },
-	[WGALLOWEDIP_A_IPADDR]		= { .len = sizeof(struct in_addr) },
+	[WGALLOWEDIP_A_IPADDR]		= NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
 	[WGALLOWEDIP_A_CIDR_MASK]	= { .type = NLA_U8 }
 };
 
@@ -125,7 +121,7 @@
 		goto err;
 
 	if (!allowedips_node) {
-		const struct __uapi_kernel_timespec last_handshake = {
+		const struct __kernel_timespec last_handshake = {
 			.tv_sec = peer->walltime_last_handshake.tv_sec,
 			.tv_nsec = peer->walltime_last_handshake.tv_nsec
 		};
@@ -202,15 +198,9 @@
 
 static int wg_get_device_start(struct netlink_callback *cb)
 {
-	struct nlattr **attrs = genl_family_attrbuf(&genl_family);
 	struct wg_device *wg;
-	int ret;
 
-	ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs,
-			  genl_family.maxattr, device_policy, NULL);
-	if (ret < 0)
-		return ret;
-	wg = lookup_interface(attrs, cb->skb);
+	wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
 	if (IS_ERR(wg))
 		return PTR_ERR(wg);
 	DUMP_CTX(cb)->wg = wg;
@@ -446,14 +436,13 @@
 	if (attrs[WGPEER_A_ENDPOINT]) {
 		struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
 		size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
+		struct endpoint endpoint = { { { 0 } } };
 
-		if ((len == sizeof(struct sockaddr_in) &&
-		     addr->sa_family == AF_INET) ||
-		    (len == sizeof(struct sockaddr_in6) &&
-		     addr->sa_family == AF_INET6)) {
-			struct endpoint endpoint = { { { 0 } } };
-
-			memcpy(&endpoint.addr, addr, len);
+		if (len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) {
+			endpoint.addr4 = *(struct sockaddr_in *)addr;
+			wg_socket_set_peer_endpoint(peer, &endpoint);
+		} else if (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6) {
+			endpoint.addr6 = *(struct sockaddr_in6 *)addr;
 			wg_socket_set_peer_endpoint(peer, &endpoint);
 		}
 	}
@@ -557,6 +546,7 @@
 		u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
 		u8 public_key[NOISE_PUBLIC_KEY_LEN];
 		struct wg_peer *peer, *temp;
+		bool send_staged_packets;
 
 		if (!crypto_memneq(wg->static_identity.static_private,
 				   private_key, NOISE_PUBLIC_KEY_LEN))
@@ -575,14 +565,17 @@
 		}
 
 		down_write(&wg->static_identity.lock);
-		wg_noise_set_static_identity_private_key(&wg->static_identity,
-							 private_key);
-		list_for_each_entry_safe(peer, temp, &wg->peer_list,
-					 peer_list) {
+		send_staged_packets = !wg->static_identity.has_identity && netif_running(wg->dev);
+		wg_noise_set_static_identity_private_key(&wg->static_identity, private_key);
+		send_staged_packets = send_staged_packets && wg->static_identity.has_identity;
+
+		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
+		list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) {
 			wg_noise_precompute_static_static(peer);
 			wg_noise_expire_current_peer_keypairs(peer);
+			if (send_staged_packets)
+				wg_packet_send_staged_packets(peer);
 		}
-		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
 		up_write(&wg->static_identity.lock);
 	}
 skip_set_private_key:
@@ -620,13 +613,11 @@
 		.start = wg_get_device_start,
 		.dumpit = wg_get_device_dump,
 		.done = wg_get_device_done,
-		.flags = GENL_UNS_ADMIN_PERM,
-		.policy = device_policy
+		.flags = GENL_UNS_ADMIN_PERM
 	}, {
 		.cmd = WG_CMD_SET_DEVICE,
 		.doit = wg_set_device,
-		.flags = GENL_UNS_ADMIN_PERM,
-		.policy = device_policy
+		.flags = GENL_UNS_ADMIN_PERM
 	}
 };
 
@@ -637,6 +628,7 @@
 	.version = WG_GENL_VERSION,
 	.maxattr = WGDEVICE_A_MAX,
 	.module = THIS_MODULE,
+	.policy = device_policy,
 	.netnsok = true
 };
 

--
Gitblit v1.6.2