From 95099d4622f8cb224d94e314c7a8e0df60b13f87 Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Sat, 09 Dec 2023 08:38:01 +0000
Subject: [PATCH] enable docker ppp

---
 kernel/drivers/virtio/virtio_balloon.c |  588 ++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 501 insertions(+), 87 deletions(-)

diff --git a/kernel/drivers/virtio/virtio_balloon.c b/kernel/drivers/virtio/virtio_balloon.c
index 1afcbef..8985fc2 100644
--- a/kernel/drivers/virtio/virtio_balloon.c
+++ b/kernel/drivers/virtio/virtio_balloon.c
@@ -1,22 +1,9 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /*
  * Virtio balloon implementation, inspired by Dor Laor and Marcelo
  * Tosatti's implementations.
  *
  *  Copyright 2008 Rusty Russell IBM Corporation
- *
- *  This program is free software; you can redistribute it and/or modify
- *  it under the terms of the GNU General Public License as published by
- *  the Free Software Foundation; either version 2 of the License, or
- *  (at your option) any later version.
- *
- *  This program is distributed in the hope that it will be useful,
- *  but WITHOUT ANY WARRANTY; without even the implied warranty of
- *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- *  GNU General Public License for more details.
- *
- *  You should have received a copy of the GNU General Public License
- *  along with this program; if not, write to the Free Software
- *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
  */
 
 #include <linux/virtio.h>
@@ -27,10 +14,13 @@
 #include <linux/slab.h>
 #include <linux/module.h>
 #include <linux/balloon_compaction.h>
+#include <linux/oom.h>
 #include <linux/wait.h>
 #include <linux/mm.h>
 #include <linux/mount.h>
 #include <linux/magic.h>
+#include <linux/pseudo_fs.h>
+#include <linux/page_reporting.h>
 
 /*
  * Balloon device works in 4K page units.  So each page is pointed to by
@@ -39,15 +29,44 @@
  */
 #define VIRTIO_BALLOON_PAGES_PER_PAGE (unsigned)(PAGE_SIZE >> VIRTIO_BALLOON_PFN_SHIFT)
 #define VIRTIO_BALLOON_ARRAY_PFNS_MAX 256
-#define VIRTBALLOON_OOM_NOTIFY_PRIORITY 80
+/* Maximum number of (4k) pages to deflate on OOM notifications. */
+#define VIRTIO_BALLOON_OOM_NR_PAGES 256
+#define VIRTIO_BALLOON_OOM_NOTIFY_PRIORITY 80
+
+#define VIRTIO_BALLOON_FREE_PAGE_ALLOC_FLAG (__GFP_NORETRY | __GFP_NOWARN | \
+					     __GFP_NOMEMALLOC)
+/* The order of free page blocks to report to host */
+#define VIRTIO_BALLOON_HINT_BLOCK_ORDER (MAX_ORDER - 1)
+/* The size of a free page block in bytes */
+#define VIRTIO_BALLOON_HINT_BLOCK_BYTES \
+	(1 << (VIRTIO_BALLOON_HINT_BLOCK_ORDER + PAGE_SHIFT))
+#define VIRTIO_BALLOON_HINT_BLOCK_PAGES (1 << VIRTIO_BALLOON_HINT_BLOCK_ORDER)
 
 #ifdef CONFIG_BALLOON_COMPACTION
 static struct vfsmount *balloon_mnt;
 #endif
 
+enum virtio_balloon_vq {
+	VIRTIO_BALLOON_VQ_INFLATE,
+	VIRTIO_BALLOON_VQ_DEFLATE,
+	VIRTIO_BALLOON_VQ_STATS,
+	VIRTIO_BALLOON_VQ_FREE_PAGE,
+	VIRTIO_BALLOON_VQ_REPORTING,
+	VIRTIO_BALLOON_VQ_MAX
+};
+
+enum virtio_balloon_config_read {
+	VIRTIO_BALLOON_CONFIG_READ_CMD_ID = 0,
+};
+
 struct virtio_balloon {
 	struct virtio_device *vdev;
-	struct virtqueue *inflate_vq, *deflate_vq, *stats_vq;
+	struct virtqueue *inflate_vq, *deflate_vq, *stats_vq, *free_page_vq;
+
+	/* Balloon's own wq for cpu-intensive work items */
+	struct workqueue_struct *balloon_wq;
+	/* The free page reporting work item submitted to the balloon wq */
+	struct work_struct report_free_page_work;
 
 	/* The balloon servicing is delegated to a freezable workqueue. */
 	struct work_struct update_balloon_stats_work;
@@ -56,6 +75,24 @@
 	/* Prevent updating balloon when it is being canceled. */
 	spinlock_t stop_update_lock;
 	bool stop_update;
+	/* Bitmap to indicate if reading the related config fields are needed */
+	unsigned long config_read_bitmap;
+
+	/* The list of allocated free pages, waiting to be given back to mm */
+	struct list_head free_page_list;
+	spinlock_t free_page_list_lock;
+	/* The number of free page blocks on the above list */
+	unsigned long num_free_page_blocks;
+	/*
+	 * The cmd id received from host.
+	 * Read it via virtio_balloon_cmd_id_received to get the latest value
+	 * sent from host.
+	 */
+	u32 cmd_id_received_cache;
+	/* The cmd id that is actively in use */
+	__virtio32 cmd_id_active;
+	/* Buffer to store the stop sign */
+	__virtio32 cmd_id_stop;
 
 	/* Waiting for host to ack the pages we released. */
 	wait_queue_head_t acked;
@@ -80,11 +117,18 @@
 	/* Memory statistics */
 	struct virtio_balloon_stat stats[VIRTIO_BALLOON_S_NR];
 
-	/* To register a shrinker to shrink memory upon memory pressure */
+	/* Shrinker to return free pages - VIRTIO_BALLOON_F_FREE_PAGE_HINT */
 	struct shrinker shrinker;
+
+	/* OOM notifier to deflate on OOM - VIRTIO_BALLOON_F_DEFLATE_ON_OOM */
+	struct notifier_block oom_nb;
+
+	/* Free page reporting device */
+	struct virtqueue *reporting_vq;
+	struct page_reporting_dev_info pr_dev_info;
 };
 
-static struct virtio_device_id id_table[] = {
+static const struct virtio_device_id id_table[] = {
 	{ VIRTIO_ID_BALLOON, VIRTIO_DEV_ANY_ID },
 	{ 0 },
 };
@@ -119,6 +163,33 @@
 	/* When host has read buffer, this completes via balloon_ack */
 	wait_event(vb->acked, virtqueue_get_buf(vq, &len));
 
+}
+
+static int virtballoon_free_page_report(struct page_reporting_dev_info *pr_dev_info,
+				   struct scatterlist *sg, unsigned int nents)
+{
+	struct virtio_balloon *vb =
+		container_of(pr_dev_info, struct virtio_balloon, pr_dev_info);
+	struct virtqueue *vq = vb->reporting_vq;
+	unsigned int unused, err;
+
+	/* We should always be able to add these buffers to an empty queue. */
+	err = virtqueue_add_inbuf(vq, sg, nents, vb, GFP_NOWAIT | __GFP_NOWARN);
+
+	/*
+	 * In the extremely unlikely case that something has occurred and we
+	 * are able to trigger an error we will simply display a warning
+	 * and exit without actually processing the pages.
+	 */
+	if (WARN_ON_ONCE(err))
+		return err;
+
+	virtqueue_kick(vq);
+
+	/* When host has read buffer, this completes via balloon_ack */
+	wait_event(vb->acked, virtqueue_get_buf(vq, &unused));
+
+	return 0;
 }
 
 static void set_page_pfns(struct virtio_balloon *vb,
@@ -322,31 +393,65 @@
 	virtqueue_kick(vq);
 }
 
+static inline s64 towards_target(struct virtio_balloon *vb)
+{
+	s64 target;
+	u32 num_pages;
+
+	/* Legacy balloon config space is LE, unlike all other devices. */
+	virtio_cread_le(vb->vdev, struct virtio_balloon_config, num_pages,
+			&num_pages);
+
+	target = num_pages;
+	return target - vb->num_pages;
+}
+
+/* Gives back @num_to_return blocks of free pages to mm. */
+static unsigned long return_free_pages_to_mm(struct virtio_balloon *vb,
+					     unsigned long num_to_return)
+{
+	struct page *page;
+	unsigned long num_returned;
+
+	spin_lock_irq(&vb->free_page_list_lock);
+	for (num_returned = 0; num_returned < num_to_return; num_returned++) {
+		page = balloon_page_pop(&vb->free_page_list);
+		if (!page)
+			break;
+		free_pages((unsigned long)page_address(page),
+			   VIRTIO_BALLOON_HINT_BLOCK_ORDER);
+	}
+	vb->num_free_page_blocks -= num_returned;
+	spin_unlock_irq(&vb->free_page_list_lock);
+
+	return num_returned;
+}
+
+static void virtio_balloon_queue_free_page_work(struct virtio_balloon *vb)
+{
+	if (!virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
+		return;
+
+	/* No need to queue the work if the bit was already set. */
+	if (test_and_set_bit(VIRTIO_BALLOON_CONFIG_READ_CMD_ID,
+			     &vb->config_read_bitmap))
+		return;
+
+	queue_work(vb->balloon_wq, &vb->report_free_page_work);
+}
+
 static void virtballoon_changed(struct virtio_device *vdev)
 {
 	struct virtio_balloon *vb = vdev->priv;
 	unsigned long flags;
 
 	spin_lock_irqsave(&vb->stop_update_lock, flags);
-	if (!vb->stop_update)
-		queue_work(system_freezable_wq, &vb->update_balloon_size_work);
+	if (!vb->stop_update) {
+		queue_work(system_freezable_wq,
+			   &vb->update_balloon_size_work);
+		virtio_balloon_queue_free_page_work(vb);
+	}
 	spin_unlock_irqrestore(&vb->stop_update_lock, flags);
-}
-
-static inline s64 towards_target(struct virtio_balloon *vb)
-{
-	s64 target;
-	u32 num_pages;
-
-	virtio_cread(vb->vdev, struct virtio_balloon_config, num_pages,
-		     &num_pages);
-
-	/* Legacy balloon config space is LE, unlike all other devices. */
-	if (!virtio_has_feature(vb->vdev, VIRTIO_F_VERSION_1))
-		num_pages = le32_to_cpu((__force __le32)num_pages);
-
-	target = num_pages;
-	return target - vb->num_pages;
 }
 
 static void update_balloon_size(struct virtio_balloon *vb)
@@ -354,11 +459,8 @@
 	u32 actual = vb->num_pages;
 
 	/* Legacy balloon config space is LE, unlike all other devices. */
-	if (!virtio_has_feature(vb->vdev, VIRTIO_F_VERSION_1))
-		actual = (__force u32)cpu_to_le32(actual);
-
-	virtio_cwrite(vb->vdev, struct virtio_balloon_config, actual,
-		      &actual);
+	virtio_cwrite_le(vb->vdev, struct virtio_balloon_config, actual,
+			 &actual);
 }
 
 static void update_balloon_stats_func(struct work_struct *work)
@@ -379,9 +481,12 @@
 			  update_balloon_size_work);
 	diff = towards_target(vb);
 
+	if (!diff)
+		return;
+
 	if (diff > 0)
 		diff -= fill_balloon(vb, diff);
-	else if (diff < 0)
+	else
 		diff += leak_balloon(vb, -diff);
 	update_balloon_size(vb);
 
@@ -391,26 +496,52 @@
 
 static int init_vqs(struct virtio_balloon *vb)
 {
-	struct virtqueue *vqs[3];
-	vq_callback_t *callbacks[] = { balloon_ack, balloon_ack, stats_request };
-	static const char * const names[] = { "inflate", "deflate", "stats" };
-	int err, nvqs;
+	struct virtqueue *vqs[VIRTIO_BALLOON_VQ_MAX];
+	vq_callback_t *callbacks[VIRTIO_BALLOON_VQ_MAX];
+	const char *names[VIRTIO_BALLOON_VQ_MAX];
+	int err;
 
 	/*
-	 * We expect two virtqueues: inflate and deflate, and
-	 * optionally stat.
+	 * Inflateq and deflateq are used unconditionally. The names[]
+	 * will be NULL if the related feature is not enabled, which will
+	 * cause no allocation for the corresponding virtqueue in find_vqs.
 	 */
-	nvqs = virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ) ? 3 : 2;
-	err = virtio_find_vqs(vb->vdev, nvqs, vqs, callbacks, names, NULL);
+	callbacks[VIRTIO_BALLOON_VQ_INFLATE] = balloon_ack;
+	names[VIRTIO_BALLOON_VQ_INFLATE] = "inflate";
+	callbacks[VIRTIO_BALLOON_VQ_DEFLATE] = balloon_ack;
+	names[VIRTIO_BALLOON_VQ_DEFLATE] = "deflate";
+	callbacks[VIRTIO_BALLOON_VQ_STATS] = NULL;
+	names[VIRTIO_BALLOON_VQ_STATS] = NULL;
+	callbacks[VIRTIO_BALLOON_VQ_FREE_PAGE] = NULL;
+	names[VIRTIO_BALLOON_VQ_FREE_PAGE] = NULL;
+	names[VIRTIO_BALLOON_VQ_REPORTING] = NULL;
+
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ)) {
+		names[VIRTIO_BALLOON_VQ_STATS] = "stats";
+		callbacks[VIRTIO_BALLOON_VQ_STATS] = stats_request;
+	}
+
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT)) {
+		names[VIRTIO_BALLOON_VQ_FREE_PAGE] = "free_page_vq";
+		callbacks[VIRTIO_BALLOON_VQ_FREE_PAGE] = NULL;
+	}
+
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_REPORTING)) {
+		names[VIRTIO_BALLOON_VQ_REPORTING] = "reporting_vq";
+		callbacks[VIRTIO_BALLOON_VQ_REPORTING] = balloon_ack;
+	}
+
+	err = vb->vdev->config->find_vqs(vb->vdev, VIRTIO_BALLOON_VQ_MAX,
+					 vqs, callbacks, names, NULL, NULL);
 	if (err)
 		return err;
 
-	vb->inflate_vq = vqs[0];
-	vb->deflate_vq = vqs[1];
+	vb->inflate_vq = vqs[VIRTIO_BALLOON_VQ_INFLATE];
+	vb->deflate_vq = vqs[VIRTIO_BALLOON_VQ_DEFLATE];
 	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ)) {
 		struct scatterlist sg;
 		unsigned int num_stats;
-		vb->stats_vq = vqs[2];
+		vb->stats_vq = vqs[VIRTIO_BALLOON_VQ_STATS];
 
 		/*
 		 * Prime this virtqueue with one buffer so the hypervisor can
@@ -428,7 +559,176 @@
 		}
 		virtqueue_kick(vb->stats_vq);
 	}
+
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
+		vb->free_page_vq = vqs[VIRTIO_BALLOON_VQ_FREE_PAGE];
+
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_REPORTING))
+		vb->reporting_vq = vqs[VIRTIO_BALLOON_VQ_REPORTING];
+
 	return 0;
+}
+
+static u32 virtio_balloon_cmd_id_received(struct virtio_balloon *vb)
+{
+	if (test_and_clear_bit(VIRTIO_BALLOON_CONFIG_READ_CMD_ID,
+			       &vb->config_read_bitmap)) {
+		/* Legacy balloon config space is LE, unlike all other devices. */
+		virtio_cread_le(vb->vdev, struct virtio_balloon_config,
+				free_page_hint_cmd_id,
+				&vb->cmd_id_received_cache);
+	}
+
+	return vb->cmd_id_received_cache;
+}
+
+static int send_cmd_id_start(struct virtio_balloon *vb)
+{
+	struct scatterlist sg;
+	struct virtqueue *vq = vb->free_page_vq;
+	int err, unused;
+
+	/* Detach all the used buffers from the vq */
+	while (virtqueue_get_buf(vq, &unused))
+		;
+
+	vb->cmd_id_active = cpu_to_virtio32(vb->vdev,
+					virtio_balloon_cmd_id_received(vb));
+	sg_init_one(&sg, &vb->cmd_id_active, sizeof(vb->cmd_id_active));
+	err = virtqueue_add_outbuf(vq, &sg, 1, &vb->cmd_id_active, GFP_KERNEL);
+	if (!err)
+		virtqueue_kick(vq);
+	return err;
+}
+
+static int send_cmd_id_stop(struct virtio_balloon *vb)
+{
+	struct scatterlist sg;
+	struct virtqueue *vq = vb->free_page_vq;
+	int err, unused;
+
+	/* Detach all the used buffers from the vq */
+	while (virtqueue_get_buf(vq, &unused))
+		;
+
+	sg_init_one(&sg, &vb->cmd_id_stop, sizeof(vb->cmd_id_stop));
+	err = virtqueue_add_outbuf(vq, &sg, 1, &vb->cmd_id_stop, GFP_KERNEL);
+	if (!err)
+		virtqueue_kick(vq);
+	return err;
+}
+
+static int get_free_page_and_send(struct virtio_balloon *vb)
+{
+	struct virtqueue *vq = vb->free_page_vq;
+	struct page *page;
+	struct scatterlist sg;
+	int err, unused;
+	void *p;
+
+	/* Detach all the used buffers from the vq */
+	while (virtqueue_get_buf(vq, &unused))
+		;
+
+	page = alloc_pages(VIRTIO_BALLOON_FREE_PAGE_ALLOC_FLAG,
+			   VIRTIO_BALLOON_HINT_BLOCK_ORDER);
+	/*
+	 * When the allocation returns NULL, it indicates that we have got all
+	 * the possible free pages, so return -EINTR to stop.
+	 */
+	if (!page)
+		return -EINTR;
+
+	p = page_address(page);
+	sg_init_one(&sg, p, VIRTIO_BALLOON_HINT_BLOCK_BYTES);
+	/* There is always 1 entry reserved for the cmd id to use. */
+	if (vq->num_free > 1) {
+		err = virtqueue_add_inbuf(vq, &sg, 1, p, GFP_KERNEL);
+		if (unlikely(err)) {
+			free_pages((unsigned long)p,
+				   VIRTIO_BALLOON_HINT_BLOCK_ORDER);
+			return err;
+		}
+		virtqueue_kick(vq);
+		spin_lock_irq(&vb->free_page_list_lock);
+		balloon_page_push(&vb->free_page_list, page);
+		vb->num_free_page_blocks++;
+		spin_unlock_irq(&vb->free_page_list_lock);
+	} else {
+		/*
+		 * The vq has no available entry to add this page block, so
+		 * just free it.
+		 */
+		free_pages((unsigned long)p, VIRTIO_BALLOON_HINT_BLOCK_ORDER);
+	}
+
+	return 0;
+}
+
+static int send_free_pages(struct virtio_balloon *vb)
+{
+	int err;
+	u32 cmd_id_active;
+
+	while (1) {
+		/*
+		 * If a stop id or a new cmd id was just received from host,
+		 * stop the reporting.
+		 */
+		cmd_id_active = virtio32_to_cpu(vb->vdev, vb->cmd_id_active);
+		if (unlikely(cmd_id_active !=
+			     virtio_balloon_cmd_id_received(vb)))
+			break;
+
+		/*
+		 * The free page blocks are allocated and sent to host one by
+		 * one.
+		 */
+		err = get_free_page_and_send(vb);
+		if (err == -EINTR)
+			break;
+		else if (unlikely(err))
+			return err;
+	}
+
+	return 0;
+}
+
+static void virtio_balloon_report_free_page(struct virtio_balloon *vb)
+{
+	int err;
+	struct device *dev = &vb->vdev->dev;
+
+	/* Start by sending the received cmd id to host with an outbuf. */
+	err = send_cmd_id_start(vb);
+	if (unlikely(err))
+		dev_err(dev, "Failed to send a start id, err = %d\n", err);
+
+	err = send_free_pages(vb);
+	if (unlikely(err))
+		dev_err(dev, "Failed to send a free page, err = %d\n", err);
+
+	/* End by sending a stop id to host with an outbuf. */
+	err = send_cmd_id_stop(vb);
+	if (unlikely(err))
+		dev_err(dev, "Failed to send a stop id, err = %d\n", err);
+}
+
+static void report_free_page_func(struct work_struct *work)
+{
+	struct virtio_balloon *vb = container_of(work, struct virtio_balloon,
+						 report_free_page_work);
+	u32 cmd_id_received;
+
+	cmd_id_received = virtio_balloon_cmd_id_received(vb);
+	if (cmd_id_received == VIRTIO_BALLOON_CMD_ID_DONE) {
+		/* Pass ULONG_MAX to give back all the free pages */
+		return_free_pages_to_mm(vb, ULONG_MAX);
+	} else if (cmd_id_received != VIRTIO_BALLOON_CMD_ID_STOP &&
+		   cmd_id_received !=
+		   virtio32_to_cpu(vb->vdev, vb->cmd_id_active)) {
+		virtio_balloon_report_free_page(vb);
+	}
 }
 
 #ifdef CONFIG_BALLOON_COMPACTION
@@ -506,46 +806,39 @@
 	return MIGRATEPAGE_SUCCESS;
 }
 
-static struct dentry *balloon_mount(struct file_system_type *fs_type,
-		int flags, const char *dev_name, void *data)
+static int balloon_init_fs_context(struct fs_context *fc)
 {
-	static const struct dentry_operations ops = {
-		.d_dname = simple_dname,
-	};
-
-	return mount_pseudo(fs_type, "balloon-kvm:", NULL, &ops,
-				BALLOON_KVM_MAGIC);
+	return init_pseudo(fc, BALLOON_KVM_MAGIC) ? 0 : -ENOMEM;
 }
 
 static struct file_system_type balloon_fs = {
 	.name           = "balloon-kvm",
-	.mount          = balloon_mount,
+	.init_fs_context = balloon_init_fs_context,
 	.kill_sb        = kill_anon_super,
 };
 
 #endif /* CONFIG_BALLOON_COMPACTION */
 
+static unsigned long shrink_free_pages(struct virtio_balloon *vb,
+				       unsigned long pages_to_free)
+{
+	unsigned long blocks_to_free, blocks_freed;
+
+	pages_to_free = round_up(pages_to_free,
+				 VIRTIO_BALLOON_HINT_BLOCK_PAGES);
+	blocks_to_free = pages_to_free / VIRTIO_BALLOON_HINT_BLOCK_PAGES;
+	blocks_freed = return_free_pages_to_mm(vb, blocks_to_free);
+
+	return blocks_freed * VIRTIO_BALLOON_HINT_BLOCK_PAGES;
+}
+
 static unsigned long virtio_balloon_shrinker_scan(struct shrinker *shrinker,
 						  struct shrink_control *sc)
 {
-	unsigned long pages_to_free, pages_freed = 0;
 	struct virtio_balloon *vb = container_of(shrinker,
 					struct virtio_balloon, shrinker);
 
-	pages_to_free = sc->nr_to_scan * VIRTIO_BALLOON_PAGES_PER_PAGE;
-
-	/*
-	 * One invocation of leak_balloon can deflate at most
-	 * VIRTIO_BALLOON_ARRAY_PFNS_MAX balloon pages, so we call it
-	 * multiple times to deflate pages till reaching pages_to_free.
-	 */
-	while (vb->num_pages && pages_to_free) {
-		pages_to_free -= pages_freed;
-		pages_freed += leak_balloon(vb, pages_to_free);
-	}
-	update_balloon_size(vb);
-
-	return pages_freed / VIRTIO_BALLOON_PAGES_PER_PAGE;
+	return shrink_free_pages(vb, sc->nr_to_scan);
 }
 
 static unsigned long virtio_balloon_shrinker_count(struct shrinker *shrinker,
@@ -554,7 +847,21 @@
 	struct virtio_balloon *vb = container_of(shrinker,
 					struct virtio_balloon, shrinker);
 
-	return vb->num_pages / VIRTIO_BALLOON_PAGES_PER_PAGE;
+	return vb->num_free_page_blocks * VIRTIO_BALLOON_HINT_BLOCK_PAGES;
+}
+
+static int virtio_balloon_oom_notify(struct notifier_block *nb,
+				     unsigned long dummy, void *parm)
+{
+	struct virtio_balloon *vb = container_of(nb,
+						 struct virtio_balloon, oom_nb);
+	unsigned long *freed = parm;
+
+	*freed += leak_balloon(vb, VIRTIO_BALLOON_OOM_NR_PAGES) /
+		  VIRTIO_BALLOON_PAGES_PER_PAGE;
+	update_balloon_size(vb);
+
+	return NOTIFY_OK;
 }
 
 static void virtio_balloon_unregister_shrinker(struct virtio_balloon *vb)
@@ -612,27 +919,107 @@
 	vb->vb_dev_info.inode = alloc_anon_inode(balloon_mnt->mnt_sb);
 	if (IS_ERR(vb->vb_dev_info.inode)) {
 		err = PTR_ERR(vb->vb_dev_info.inode);
-		kern_unmount(balloon_mnt);
-		goto out_del_vqs;
+		goto out_kern_unmount;
 	}
 	vb->vb_dev_info.inode->i_mapping->a_ops = &balloon_aops;
 #endif
-	/*
-	 * We continue to use VIRTIO_BALLOON_F_DEFLATE_ON_OOM to decide if a
-	 * shrinker needs to be registered to relieve memory pressure.
-	 */
-	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_DEFLATE_ON_OOM)) {
+	if (virtio_has_feature(vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT)) {
+		/*
+		 * There is always one entry reserved for cmd id, so the ring
+		 * size needs to be at least two to report free page hints.
+		 */
+		if (virtqueue_get_vring_size(vb->free_page_vq) < 2) {
+			err = -ENOSPC;
+			goto out_iput;
+		}
+		vb->balloon_wq = alloc_workqueue("balloon-wq",
+					WQ_FREEZABLE | WQ_CPU_INTENSIVE, 0);
+		if (!vb->balloon_wq) {
+			err = -ENOMEM;
+			goto out_iput;
+		}
+		INIT_WORK(&vb->report_free_page_work, report_free_page_func);
+		vb->cmd_id_received_cache = VIRTIO_BALLOON_CMD_ID_STOP;
+		vb->cmd_id_active = cpu_to_virtio32(vb->vdev,
+						  VIRTIO_BALLOON_CMD_ID_STOP);
+		vb->cmd_id_stop = cpu_to_virtio32(vb->vdev,
+						  VIRTIO_BALLOON_CMD_ID_STOP);
+		spin_lock_init(&vb->free_page_list_lock);
+		INIT_LIST_HEAD(&vb->free_page_list);
+		/*
+		 * We're allowed to reuse any free pages, even if they are
+		 * still to be processed by the host.
+		 */
 		err = virtio_balloon_register_shrinker(vb);
 		if (err)
-			goto out_del_vqs;
+			goto out_del_balloon_wq;
 	}
+
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_DEFLATE_ON_OOM)) {
+		vb->oom_nb.notifier_call = virtio_balloon_oom_notify;
+		vb->oom_nb.priority = VIRTIO_BALLOON_OOM_NOTIFY_PRIORITY;
+		err = register_oom_notifier(&vb->oom_nb);
+		if (err < 0)
+			goto out_unregister_shrinker;
+	}
+
+	if (virtio_has_feature(vdev, VIRTIO_BALLOON_F_PAGE_POISON)) {
+		/* Start with poison val of 0 representing general init */
+		__u32 poison_val = 0;
+
+		/*
+		 * Let the hypervisor know that we are expecting a
+		 * specific value to be written back in balloon pages.
+		 *
+		 * If the PAGE_POISON value was larger than a byte we would
+		 * need to byte swap poison_val here to guarantee it is
+		 * little-endian. However for now it is a single byte so we
+		 * can pass it as-is.
+		 */
+		if (!want_init_on_free())
+			memset(&poison_val, PAGE_POISON, sizeof(poison_val));
+
+		virtio_cwrite_le(vb->vdev, struct virtio_balloon_config,
+				 poison_val, &poison_val);
+	}
+
+	vb->pr_dev_info.report = virtballoon_free_page_report;
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_REPORTING)) {
+		unsigned int capacity;
+
+		capacity = virtqueue_get_vring_size(vb->reporting_vq);
+		if (capacity < PAGE_REPORTING_CAPACITY) {
+			err = -ENOSPC;
+			goto out_unregister_oom;
+		}
+
+		err = page_reporting_register(&vb->pr_dev_info);
+		if (err)
+			goto out_unregister_oom;
+	}
+
 	virtio_device_ready(vdev);
 
 	if (towards_target(vb))
 		virtballoon_changed(vdev);
 	return 0;
 
+out_unregister_oom:
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_DEFLATE_ON_OOM))
+		unregister_oom_notifier(&vb->oom_nb);
+out_unregister_shrinker:
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
+		virtio_balloon_unregister_shrinker(vb);
+out_del_balloon_wq:
+	if (virtio_has_feature(vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
+		destroy_workqueue(vb->balloon_wq);
+out_iput:
+#ifdef CONFIG_BALLOON_COMPACTION
+	iput(vb->vb_dev_info.inode);
+out_kern_unmount:
+	kern_unmount(balloon_mnt);
 out_del_vqs:
+#endif
 	vdev->config->del_vqs(vdev);
 out_free_vb:
 	kfree(vb);
@@ -647,6 +1034,10 @@
 		leak_balloon(vb, vb->num_pages);
 	update_balloon_size(vb);
 
+	/* There might be free pages that are being reported: release them. */
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
+		return_free_pages_to_mm(vb, ULONG_MAX);
+
 	/* Now we reset the device so we can clean up the queues. */
 	vb->vdev->config->reset(vb->vdev);
 
@@ -657,13 +1048,22 @@
 {
 	struct virtio_balloon *vb = vdev->priv;
 
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_REPORTING))
+		page_reporting_unregister(&vb->pr_dev_info);
 	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_DEFLATE_ON_OOM))
+		unregister_oom_notifier(&vb->oom_nb);
+	if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
 		virtio_balloon_unregister_shrinker(vb);
 	spin_lock_irq(&vb->stop_update_lock);
 	vb->stop_update = true;
 	spin_unlock_irq(&vb->stop_update_lock);
 	cancel_work_sync(&vb->update_balloon_size_work);
 	cancel_work_sync(&vb->update_balloon_stats_work);
+
+	if (virtio_has_feature(vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT)) {
+		cancel_work_sync(&vb->report_free_page_work);
+		destroy_workqueue(vb->balloon_wq);
+	}
 
 	remove_common(vb);
 #ifdef CONFIG_BALLOON_COMPACTION
@@ -708,7 +1108,18 @@
 
 static int virtballoon_validate(struct virtio_device *vdev)
 {
-	__virtio_clear_bit(vdev, VIRTIO_F_IOMMU_PLATFORM);
+	/*
+	 * Inform the hypervisor that our pages are poisoned or
+	 * initialized. If we cannot do that then we should disable
+	 * page reporting as it could potentially change the contents
+	 * of our free pages.
+	 */
+	if (!want_init_on_free() && !page_poisoning_enabled_static())
+		__virtio_clear_bit(vdev, VIRTIO_BALLOON_F_PAGE_POISON);
+	else if (!virtio_has_feature(vdev, VIRTIO_BALLOON_F_PAGE_POISON))
+		__virtio_clear_bit(vdev, VIRTIO_BALLOON_F_REPORTING);
+
+	__virtio_clear_bit(vdev, VIRTIO_F_ACCESS_PLATFORM);
 	return 0;
 }
 
@@ -716,6 +1127,9 @@
 	VIRTIO_BALLOON_F_MUST_TELL_HOST,
 	VIRTIO_BALLOON_F_STATS_VQ,
 	VIRTIO_BALLOON_F_DEFLATE_ON_OOM,
+	VIRTIO_BALLOON_F_FREE_PAGE_HINT,
+	VIRTIO_BALLOON_F_PAGE_POISON,
+	VIRTIO_BALLOON_F_REPORTING,
 };
 
 static struct virtio_driver virtio_balloon_driver = {

--
Gitblit v1.6.2