From 9999e48639b3cecb08ffb37358bcba3b48161b29 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Fri, 10 May 2024 08:50:17 +0000
Subject: [PATCH] add ax88772_rst

---
 kernel/fs/afs/addr_list.c |  310 ++++++++++++++++++++++++++------------------------
 1 files changed, 161 insertions(+), 149 deletions(-)

diff --git a/kernel/fs/afs/addr_list.c b/kernel/fs/afs/addr_list.c
index 025a9a5..de1ae0b 100644
--- a/kernel/fs/afs/addr_list.c
+++ b/kernel/fs/afs/addr_list.c
@@ -1,12 +1,8 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /* Server address list management
  *
  * Copyright (C) 2017 Red Hat, Inc. All Rights Reserved.
  * Written by David Howells (dhowells@redhat.com)
- *
- * This program is free software; you can redistribute it and/or
- * modify it under the terms of the GNU General Public Licence
- * as published by the Free Software Foundation; either version
- * 2 of the Licence, or (at your option) any later version.
  */
 
 #include <linux/slab.h>
@@ -17,18 +13,13 @@
 #include "internal.h"
 #include "afs_fs.h"
 
-//#define AFS_MAX_ADDRESSES
-//	((unsigned int)((PAGE_SIZE - sizeof(struct afs_addr_list)) /
-//			sizeof(struct sockaddr_rxrpc)))
-#define AFS_MAX_ADDRESSES ((unsigned int)(sizeof(unsigned long) * 8))
-
 /*
  * Release an address list.
  */
 void afs_put_addrlist(struct afs_addr_list *alist)
 {
 	if (alist && refcount_dec_and_test(&alist->usage))
-		call_rcu(&alist->rcu, (rcu_callback_t)kfree);
+		kfree_rcu(alist, rcu);
 }
 
 /*
@@ -43,11 +34,15 @@
 
 	_enter("%u,%u,%u", nr, service, port);
 
+	if (nr > AFS_MAX_ADDRESSES)
+		nr = AFS_MAX_ADDRESSES;
+
 	alist = kzalloc(struct_size(alist, addrs, nr), GFP_KERNEL);
 	if (!alist)
 		return NULL;
 
 	refcount_set(&alist->usage, 1);
+	alist->max_addrs = nr;
 
 	for (i = 0; i < nr; i++) {
 		struct sockaddr_rxrpc *srx = &alist->addrs[i];
@@ -65,19 +60,25 @@
 /*
  * Parse a text string consisting of delimited addresses.
  */
-struct afs_addr_list *afs_parse_text_addrs(const char *text, size_t len,
-					   char delim,
-					   unsigned short service,
-					   unsigned short port)
+struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
+					       const char *text, size_t len,
+					       char delim,
+					       unsigned short service,
+					       unsigned short port)
 {
+	struct afs_vlserver_list *vllist;
 	struct afs_addr_list *alist;
 	const char *p, *end = text + len;
+	const char *problem;
 	unsigned int nr = 0;
+	int ret = -ENOMEM;
 
 	_enter("%*.*s,%c", (int)len, (int)len, text, delim);
 
-	if (!len)
+	if (!len) {
+		_leave(" = -EDESTADDRREQ [empty]");
 		return ERR_PTR(-EDESTADDRREQ);
+	}
 
 	if (delim == ':' && (memchr(text, ',', len) || !memchr(text, '.', len)))
 		delim = ',';
@@ -85,18 +86,24 @@
 	/* Count the addresses */
 	p = text;
 	do {
-		if (!*p)
-			return ERR_PTR(-EINVAL);
+		if (!*p) {
+			problem = "nul";
+			goto inval;
+		}
 		if (*p == delim)
 			continue;
 		nr++;
 		if (*p == '[') {
 			p++;
-			if (p == end)
-				return ERR_PTR(-EINVAL);
+			if (p == end) {
+				problem = "brace1";
+				goto inval;
+			}
 			p = memchr(p, ']', end - p);
-			if (!p)
-				return ERR_PTR(-EINVAL);
+			if (!p) {
+				problem = "brace2";
+				goto inval;
+			}
 			p++;
 			if (p >= end)
 				break;
@@ -109,18 +116,27 @@
 	} while (p < end);
 
 	_debug("%u/%u addresses", nr, AFS_MAX_ADDRESSES);
-	if (nr > AFS_MAX_ADDRESSES)
-		nr = AFS_MAX_ADDRESSES;
 
-	alist = afs_alloc_addrlist(nr, service, port);
-	if (!alist)
+	vllist = afs_alloc_vlserver_list(1);
+	if (!vllist)
 		return ERR_PTR(-ENOMEM);
+
+	vllist->nr_servers = 1;
+	vllist->servers[0].server = afs_alloc_vlserver("<dummy>", 7, AFS_VL_PORT);
+	if (!vllist->servers[0].server)
+		goto error_vl;
+
+	alist = afs_alloc_addrlist(nr, service, AFS_VL_PORT);
+	if (!alist)
+		goto error;
 
 	/* Extract the addresses */
 	p = text;
 	do {
-		struct sockaddr_rxrpc *srx = &alist->addrs[alist->nr_addrs];
 		const char *q, *stop;
+		unsigned int xport = port;
+		__be32 x[4];
+		int family;
 
 		if (*p == delim) {
 			p++;
@@ -136,58 +152,74 @@
 					break;
 		}
 
-		if (in4_pton(p, q - p,
-			     (u8 *)&srx->transport.sin6.sin6_addr.s6_addr32[3],
-			     -1, &stop)) {
-			srx->transport.sin6.sin6_addr.s6_addr32[0] = 0;
-			srx->transport.sin6.sin6_addr.s6_addr32[1] = 0;
-			srx->transport.sin6.sin6_addr.s6_addr32[2] = htonl(0xffff);
-		} else if (in6_pton(p, q - p,
-				    srx->transport.sin6.sin6_addr.s6_addr,
-				    -1, &stop)) {
-			/* Nothing to do */
+		if (in4_pton(p, q - p, (u8 *)&x[0], -1, &stop)) {
+			family = AF_INET;
+		} else if (in6_pton(p, q - p, (u8 *)x, -1, &stop)) {
+			family = AF_INET6;
 		} else {
+			problem = "family";
 			goto bad_address;
 		}
 
-		if (stop != q)
-			goto bad_address;
-
 		p = q;
+		if (stop != p) {
+			problem = "nostop";
+			goto bad_address;
+		}
+
 		if (q < end && *q == ']')
 			p++;
 
 		if (p < end) {
 			if (*p == '+') {
 				/* Port number specification "+1234" */
-				unsigned int xport = 0;
+				xport = 0;
 				p++;
-				if (p >= end || !isdigit(*p))
+				if (p >= end || !isdigit(*p)) {
+					problem = "port";
 					goto bad_address;
+				}
 				do {
 					xport *= 10;
 					xport += *p - '0';
-					if (xport > 65535)
+					if (xport > 65535) {
+						problem = "pval";
 						goto bad_address;
+					}
 					p++;
 				} while (p < end && isdigit(*p));
-				srx->transport.sin6.sin6_port = htons(xport);
 			} else if (*p == delim) {
 				p++;
 			} else {
+				problem = "weird";
 				goto bad_address;
 			}
 		}
 
-		alist->nr_addrs++;
-	} while (p < end && alist->nr_addrs < AFS_MAX_ADDRESSES);
+		if (family == AF_INET)
+			afs_merge_fs_addr4(alist, x[0], xport);
+		else
+			afs_merge_fs_addr6(alist, x, xport);
 
+	} while (p < end);
+
+	rcu_assign_pointer(vllist->servers[0].server->addresses, alist);
 	_leave(" = [nr %u]", alist->nr_addrs);
-	return alist;
+	return vllist;
 
-bad_address:
-	kfree(alist);
+inval:
+	_leave(" = -EINVAL [%s %zu %*.*s]",
+	       problem, p - text, (int)len, (int)len, text);
 	return ERR_PTR(-EINVAL);
+bad_address:
+	_leave(" = -EINVAL [%s %zu %*.*s]",
+	       problem, p - text, (int)len, (int)len, text);
+	ret = -EINVAL;
+error:
+	afs_put_addrlist(alist);
+error_vl:
+	afs_put_vlserverlist(net, vllist);
+	return ERR_PTR(ret);
 }
 
 /*
@@ -206,30 +238,34 @@
 /*
  * Perform a DNS query for VL servers and build a up an address list.
  */
-struct afs_addr_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry)
+struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry)
 {
-	struct afs_addr_list *alist;
-	char *vllist = NULL;
+	struct afs_vlserver_list *vllist;
+	char *result = NULL;
 	int ret;
 
 	_enter("%s", cell->name);
 
-	ret = dns_query("afsdb", cell->name, cell->name_len,
-			"", &vllist, _expiry);
-	if (ret < 0)
+	ret = dns_query(cell->net->net, "afsdb", cell->name, cell->name_len,
+			"srv=1", &result, _expiry, true);
+	if (ret < 0) {
+		_leave(" = %d [dns]", ret);
 		return ERR_PTR(ret);
-
-	alist = afs_parse_text_addrs(vllist, strlen(vllist), ',',
-				     VL_SERVICE, AFS_VL_PORT);
-	if (IS_ERR(alist)) {
-		kfree(vllist);
-		if (alist != ERR_PTR(-ENOMEM))
-			pr_err("Failed to parse DNS data\n");
-		return alist;
 	}
 
-	kfree(vllist);
-	return alist;
+	if (*_expiry == 0)
+		*_expiry = ktime_get_real_seconds() + 60;
+
+	if (ret > 1 && result[0] == 0)
+		vllist = afs_extract_vlserver_list(cell, result, ret);
+	else
+		vllist = afs_parse_text_addrs(cell->net, result, ret, ',',
+					      VL_SERVICE, AFS_VL_PORT);
+	kfree(result);
+	if (IS_ERR(vllist) && vllist != ERR_PTR(-ENOMEM))
+		pr_err("Failed to parse DNS data %ld\n", PTR_ERR(vllist));
+
+	return vllist;
 }
 
 /*
@@ -237,19 +273,23 @@
  */
 void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port)
 {
-	struct sockaddr_in6 *a;
-	__be16 xport = htons(port);
+	struct sockaddr_rxrpc *srx;
+	u32 addr = ntohl(xdr);
 	int i;
 
+	if (alist->nr_addrs >= alist->max_addrs)
+		return;
+
 	for (i = 0; i < alist->nr_ipv4; i++) {
-		a = &alist->addrs[i].transport.sin6;
-		if (xdr == a->sin6_addr.s6_addr32[3] &&
-		    xport == a->sin6_port)
+		struct sockaddr_in *a = &alist->addrs[i].transport.sin;
+		u32 a_addr = ntohl(a->sin_addr.s_addr);
+		u16 a_port = ntohs(a->sin_port);
+
+		if (addr == a_addr && port == a_port)
 			return;
-		if (xdr == a->sin6_addr.s6_addr32[3] &&
-		    (u16 __force)xport < (u16 __force)a->sin6_port)
+		if (addr == a_addr && port < a_port)
 			break;
-		if ((u32 __force)xdr < (u32 __force)a->sin6_addr.s6_addr32[3])
+		if (addr < a_addr)
 			break;
 	}
 
@@ -258,12 +298,13 @@
 			alist->addrs + i,
 			sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
 
-	a = &alist->addrs[i].transport.sin6;
-	a->sin6_port		  = xport;
-	a->sin6_addr.s6_addr32[0] = 0;
-	a->sin6_addr.s6_addr32[1] = 0;
-	a->sin6_addr.s6_addr32[2] = htonl(0xffff);
-	a->sin6_addr.s6_addr32[3] = xdr;
+	srx = &alist->addrs[i];
+	srx->srx_family = AF_RXRPC;
+	srx->transport_type = SOCK_DGRAM;
+	srx->transport_len = sizeof(srx->transport.sin);
+	srx->transport.sin.sin_family = AF_INET;
+	srx->transport.sin.sin_port = htons(port);
+	srx->transport.sin.sin_addr.s_addr = xdr;
 	alist->nr_ipv4++;
 	alist->nr_addrs++;
 }
@@ -273,18 +314,20 @@
  */
 void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port)
 {
-	struct sockaddr_in6 *a;
-	__be16 xport = htons(port);
+	struct sockaddr_rxrpc *srx;
 	int i, diff;
 
+	if (alist->nr_addrs >= alist->max_addrs)
+		return;
+
 	for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
-		a = &alist->addrs[i].transport.sin6;
+		struct sockaddr_in6 *a = &alist->addrs[i].transport.sin6;
+		u16 a_port = ntohs(a->sin6_port);
+
 		diff = memcmp(xdr, &a->sin6_addr, 16);
-		if (diff == 0 &&
-		    xport == a->sin6_port)
+		if (diff == 0 && port == a_port)
 			return;
-		if (diff == 0 &&
-		    (u16 __force)xport < (u16 __force)a->sin6_port)
+		if (diff == 0 && port < a_port)
 			break;
 		if (diff < 0)
 			break;
@@ -295,12 +338,13 @@
 			alist->addrs + i,
 			sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
 
-	a = &alist->addrs[i].transport.sin6;
-	a->sin6_port		  = xport;
-	a->sin6_addr.s6_addr32[0] = xdr[0];
-	a->sin6_addr.s6_addr32[1] = xdr[1];
-	a->sin6_addr.s6_addr32[2] = xdr[2];
-	a->sin6_addr.s6_addr32[3] = xdr[3];
+	srx = &alist->addrs[i];
+	srx->srx_family = AF_RXRPC;
+	srx->transport_type = SOCK_DGRAM;
+	srx->transport_len = sizeof(srx->transport.sin6);
+	srx->transport.sin6.sin6_family = AF_INET6;
+	srx->transport.sin6.sin6_port = htons(port);
+	memcpy(&srx->transport.sin6.sin6_addr, xdr, 16);
 	alist->nr_addrs++;
 }
 
@@ -309,25 +353,33 @@
  */
 bool afs_iterate_addresses(struct afs_addr_cursor *ac)
 {
-	_enter("%hu+%hd", ac->start, (short)ac->index);
+	unsigned long set, failed;
+	int index;
 
 	if (!ac->alist)
 		return false;
 
-	if (ac->begun) {
-		ac->index++;
-		if (ac->index == ac->alist->nr_addrs)
-			ac->index = 0;
+	set = ac->alist->responded;
+	failed = ac->alist->failed;
+	_enter("%lx-%lx-%lx,%d", set, failed, ac->tried, ac->index);
 
-		if (ac->index == ac->start) {
-			ac->error = -EDESTADDRREQ;
-			return false;
-		}
-	}
+	ac->nr_iterations++;
 
-	ac->begun = true;
+	set &= ~(failed | ac->tried);
+
+	if (!set)
+		return false;
+
+	index = READ_ONCE(ac->alist->preferred);
+	if (test_bit(index, &set))
+		goto selected;
+
+	index = __ffs(set);
+
+selected:
+	ac->index = index;
+	set_bit(index, &ac->tried);
 	ac->responded = false;
-	ac->addr = &ac->alist->addrs[ac->index];
 	return true;
 }
 
@@ -340,53 +392,13 @@
 
 	alist = ac->alist;
 	if (alist) {
-		if (ac->responded && ac->index != ac->start)
-			WRITE_ONCE(alist->index, ac->index);
+		if (ac->responded &&
+		    ac->index != alist->preferred &&
+		    test_bit(ac->alist->preferred, &ac->tried))
+			WRITE_ONCE(alist->preferred, ac->index);
 		afs_put_addrlist(alist);
+		ac->alist = NULL;
 	}
 
-	ac->addr = NULL;
-	ac->alist = NULL;
-	ac->begun = false;
 	return ac->error;
-}
-
-/*
- * Set the address cursor for iterating over VL servers.
- */
-int afs_set_vl_cursor(struct afs_addr_cursor *ac, struct afs_cell *cell)
-{
-	struct afs_addr_list *alist;
-	int ret;
-
-	if (!rcu_access_pointer(cell->vl_addrs)) {
-		ret = wait_on_bit(&cell->flags, AFS_CELL_FL_NO_LOOKUP_YET,
-				  TASK_INTERRUPTIBLE);
-		if (ret < 0)
-			return ret;
-
-		if (!rcu_access_pointer(cell->vl_addrs) &&
-		    ktime_get_real_seconds() < cell->dns_expiry)
-			return cell->error;
-	}
-
-	read_lock(&cell->vl_addrs_lock);
-	alist = rcu_dereference_protected(cell->vl_addrs,
-					  lockdep_is_held(&cell->vl_addrs_lock));
-	if (alist->nr_addrs > 0)
-		afs_get_addrlist(alist);
-	else
-		alist = NULL;
-	read_unlock(&cell->vl_addrs_lock);
-
-	if (!alist)
-		return -EDESTADDRREQ;
-
-	ac->alist = alist;
-	ac->addr = NULL;
-	ac->start = READ_ONCE(alist->index);
-	ac->index = ac->start;
-	ac->error = 0;
-	ac->begun = false;
-	return 0;
 }

--
Gitblit v1.6.2