From d2ccde1c8e90d38cee87a1b0309ad2827f3fd30d Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Mon, 11 Dec 2023 02:45:28 +0000
Subject: [PATCH] add boot partition  size

---
 kernel/arch/x86/net/bpf_jit_comp.c | 1221 ++++++++++++++++++++++++++++++++++++++++++++++++++++------
 1 files changed, 1,093 insertions(+), 128 deletions(-)

diff --git a/kernel/arch/x86/net/bpf_jit_comp.c b/kernel/arch/x86/net/bpf_jit_comp.c
index 81c3d4b..547ec83 100644
--- a/kernel/arch/x86/net/bpf_jit_comp.c
+++ b/kernel/arch/x86/net/bpf_jit_comp.c
@@ -1,21 +1,20 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * bpf_jit_comp.c: BPF JIT compiler
  *
  * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
  * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
- *
- * This program is free software; you can redistribute it and/or
- * modify it under the terms of the GNU General Public License
- * as published by the Free Software Foundation; version 2
- * of the License.
  */
 #include <linux/netdevice.h>
 #include <linux/filter.h>
 #include <linux/if_vlan.h>
 #include <linux/bpf.h>
-
+#include <linux/memory.h>
+#include <linux/sort.h>
+#include <asm/extable.h>
 #include <asm/set_memory.h>
 #include <asm/nospec-branch.h>
+#include <asm/text-patching.h>
 
 static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
 {
@@ -100,6 +99,7 @@
 
 /* Pick a register outside of BPF range for JIT internal work */
 #define AUX_REG (MAX_BPF_JIT_REG + 1)
+#define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
 
 /*
  * The following table maps BPF registers to x86-64 registers.
@@ -108,8 +108,8 @@
  * register in load/store instructions, it always needs an
  * extra byte of encoding and is callee saved.
  *
- * Also x86-64 register R9 is unused. x86-64 register R10 is
- * used for blinding (if enabled).
+ * x86-64 register R9 is not used by BPF programs, but can be used by BPF
+ * trampoline. x86-64 register R10 is used for blinding (if enabled).
  */
 static const int reg2hex[] = {
 	[BPF_REG_0] = 0,  /* RAX */
@@ -125,6 +125,20 @@
 	[BPF_REG_FP] = 5, /* RBP readonly */
 	[BPF_REG_AX] = 2, /* R10 temp register */
 	[AUX_REG] = 3,    /* R11 temp register */
+	[X86_REG_R9] = 1, /* R9 register, 6th function argument */
+};
+
+static const int reg2pt_regs[] = {
+	[BPF_REG_0] = offsetof(struct pt_regs, ax),
+	[BPF_REG_1] = offsetof(struct pt_regs, di),
+	[BPF_REG_2] = offsetof(struct pt_regs, si),
+	[BPF_REG_3] = offsetof(struct pt_regs, dx),
+	[BPF_REG_4] = offsetof(struct pt_regs, cx),
+	[BPF_REG_5] = offsetof(struct pt_regs, r8),
+	[BPF_REG_6] = offsetof(struct pt_regs, bx),
+	[BPF_REG_7] = offsetof(struct pt_regs, r13),
+	[BPF_REG_8] = offsetof(struct pt_regs, r14),
+	[BPF_REG_9] = offsetof(struct pt_regs, r15),
 };
 
 /*
@@ -139,6 +153,7 @@
 			     BIT(BPF_REG_7) |
 			     BIT(BPF_REG_8) |
 			     BIT(BPF_REG_9) |
+			     BIT(X86_REG_R9) |
 			     BIT(BPF_REG_AX));
 }
 
@@ -197,36 +212,206 @@
 
 struct jit_context {
 	int cleanup_addr; /* Epilogue code offset */
+
+	/*
+	 * Program specific offsets of labels in the code; these rely on the
+	 * JIT doing at least 2 passes, recording the position on the first
+	 * pass, only to generate the correct offset on the second pass.
+	 */
+	int tail_call_direct_label;
+	int tail_call_indirect_label;
 };
 
 /* Maximum number of bytes emitted while JITing one eBPF insn */
 #define BPF_MAX_INSN_SIZE	128
 #define BPF_INSN_SAFETY		64
 
-#define PROLOGUE_SIZE		20
+/* Number of bytes emit_patch() needs to generate instructions */
+#define X86_PATCH_SIZE		5
+/* Number of bytes that will be skipped on tailcall */
+#define X86_TAIL_CALL_OFFSET	11
 
-/*
- * Emit x86-64 prologue code for BPF program and check its size.
- * bpf_tail_call helper will skip it while jumping into another program
- */
-static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
+static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
 {
 	u8 *prog = *pprog;
 	int cnt = 0;
 
+	if (callee_regs_used[0])
+		EMIT1(0x53);         /* push rbx */
+	if (callee_regs_used[1])
+		EMIT2(0x41, 0x55);   /* push r13 */
+	if (callee_regs_used[2])
+		EMIT2(0x41, 0x56);   /* push r14 */
+	if (callee_regs_used[3])
+		EMIT2(0x41, 0x57);   /* push r15 */
+	*pprog = prog;
+}
+
+static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (callee_regs_used[3])
+		EMIT2(0x41, 0x5F);   /* pop r15 */
+	if (callee_regs_used[2])
+		EMIT2(0x41, 0x5E);   /* pop r14 */
+	if (callee_regs_used[1])
+		EMIT2(0x41, 0x5D);   /* pop r13 */
+	if (callee_regs_used[0])
+		EMIT1(0x5B);         /* pop rbx */
+	*pprog = prog;
+}
+
+/*
+ * Emit x86-64 prologue code for BPF program.
+ * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
+ * while jumping to another program
+ */
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
+			  bool tail_call_reachable, bool is_subprog)
+{
+	u8 *prog = *pprog;
+	int cnt = X86_PATCH_SIZE;
+
+	/* BPF trampoline can be made to work without these nops,
+	 * but let's waste 5 bytes for now and optimize later
+	 */
+	memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
+	prog += cnt;
+	if (!ebpf_from_cbpf) {
+		if (tail_call_reachable && !is_subprog)
+			EMIT2(0x31, 0xC0); /* xor eax, eax */
+		else
+			EMIT2(0x66, 0x90); /* nop2 */
+	}
 	EMIT1(0x55);             /* push rbp */
 	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
 	/* sub rsp, rounded_stack_depth */
-	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
-	EMIT1(0x53);             /* push rbx */
-	EMIT2(0x41, 0x55);       /* push r13 */
-	EMIT2(0x41, 0x56);       /* push r14 */
-	EMIT2(0x41, 0x57);       /* push r15 */
-	if (!ebpf_from_cbpf) {
-		/* zero init tail_call_cnt */
-		EMIT2(0x6a, 0x00);
-		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
+	if (stack_depth)
+		EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
+	if (tail_call_reachable)
+		EMIT1(0x50);         /* push rax */
+	*pprog = prog;
+}
+
+static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+	s64 offset;
+
+	offset = func - (ip + X86_PATCH_SIZE);
+	if (!is_simm32(offset)) {
+		pr_err("Target call %p is out of range\n", func);
+		return -ERANGE;
 	}
+	EMIT1_off32(opcode, offset);
+	*pprog = prog;
+	return 0;
+}
+
+static int emit_call(u8 **pprog, void *func, void *ip)
+{
+	return emit_patch(pprog, func, ip, 0xE8);
+}
+
+static int emit_jump(u8 **pprog, void *func, void *ip)
+{
+	return emit_patch(pprog, func, ip, 0xE9);
+}
+
+static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
+				void *old_addr, void *new_addr,
+				const bool text_live)
+{
+	const u8 *nop_insn = ideal_nops[NOP_ATOMIC5];
+	u8 old_insn[X86_PATCH_SIZE];
+	u8 new_insn[X86_PATCH_SIZE];
+	u8 *prog;
+	int ret;
+
+	memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
+	if (old_addr) {
+		prog = old_insn;
+		ret = t == BPF_MOD_CALL ?
+		      emit_call(&prog, old_addr, ip) :
+		      emit_jump(&prog, old_addr, ip);
+		if (ret)
+			return ret;
+	}
+
+	memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
+	if (new_addr) {
+		prog = new_insn;
+		ret = t == BPF_MOD_CALL ?
+		      emit_call(&prog, new_addr, ip) :
+		      emit_jump(&prog, new_addr, ip);
+		if (ret)
+			return ret;
+	}
+
+	ret = -EBUSY;
+	mutex_lock(&text_mutex);
+	if (memcmp(ip, old_insn, X86_PATCH_SIZE))
+		goto out;
+	ret = 1;
+	if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
+		if (text_live)
+			text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
+		else
+			memcpy(ip, new_insn, X86_PATCH_SIZE);
+		ret = 0;
+	}
+out:
+	mutex_unlock(&text_mutex);
+	return ret;
+}
+
+int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
+		       void *old_addr, void *new_addr)
+{
+	if (!is_kernel_text((long)ip) &&
+	    !is_bpf_text_address((long)ip))
+		/* BPF poking in modules is not supported */
+		return -EINVAL;
+
+	return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
+}
+
+#define EMIT_LFENCE()	EMIT3(0x0F, 0xAE, 0xE8)
+
+static void emit_indirect_jump(u8 **pprog, int reg, u8 *ip)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+#ifdef CONFIG_RETPOLINE
+	if (cpu_feature_enabled(X86_FEATURE_RETPOLINE_LFENCE)) {
+		EMIT_LFENCE();
+		EMIT2(0xFF, 0xE0 + reg);
+	} else if (cpu_feature_enabled(X86_FEATURE_RETPOLINE)) {
+		emit_jump(&prog, &__x86_indirect_thunk_array[reg], ip);
+	} else
+#endif
+	EMIT2(0xFF, 0xE0 + reg);
+
+	*pprog = prog;
+}
+
+static void emit_return(u8 **pprog, u8 *ip)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
+		emit_jump(&prog, &__x86_return_thunk, ip);
+	} else {
+		EMIT1(0xC3);		/* ret */
+		if (IS_ENABLED(CONFIG_SLS))
+			EMIT1(0xCC);	/* int3 */
+	}
+
 	*pprog = prog;
 }
 
@@ -244,11 +429,13 @@
  *   goto *(prog->bpf_func + prologue_size);
  * out:
  */
-static void emit_bpf_tail_call(u8 **pprog)
+static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
+					u32 stack_depth, u8 *ip,
+					struct jit_context *ctx)
 {
-	u8 *prog = *pprog;
-	int label1, label2, label3;
-	int cnt = 0;
+	int tcc_off = -4 - round_up(stack_depth, 8);
+	u8 *prog = *pprog, *start = *pprog;
+	int cnt = 0, offset;
 
 	/*
 	 * rdi - pointer to ctx
@@ -263,52 +450,143 @@
 	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
 	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
 	      offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
-	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
-	label1 = cnt;
+
+	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
+	EMIT2(X86_JBE, offset);                   /* jbe out */
 
 	/*
 	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 	 *	goto out;
 	 */
-	EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+	EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
 	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
-	EMIT2(X86_JA, OFFSET2);                   /* ja out */
-	label2 = cnt;
+
+	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
+	EMIT2(X86_JA, offset);                    /* ja out */
 	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
-	EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+	EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
 
 	/* prog = array->ptrs[index]; */
-	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
+	EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
 		    offsetof(struct bpf_array, ptrs));
 
 	/*
 	 * if (prog == NULL)
 	 *	goto out;
 	 */
-	EMIT3(0x48, 0x85, 0xC0);		  /* test rax,rax */
-#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
-	EMIT2(X86_JE, OFFSET3);                   /* je out */
-	label3 = cnt;
+	EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
 
-	/* goto *(prog->bpf_func + prologue_size); */
-	EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
+	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
+	EMIT2(X86_JE, offset);                    /* je out */
+
+	pop_callee_regs(&prog, callee_regs_used);
+
+	EMIT1(0x58);                              /* pop rax */
+	if (stack_depth)
+		EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
+			    round_up(stack_depth, 8));
+
+	/* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
+	EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
 	      offsetof(struct bpf_prog, bpf_func));
-	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */
-
+	EMIT4(0x48, 0x83, 0xC1,                   /* add rcx, X86_TAIL_CALL_OFFSET */
+	      X86_TAIL_CALL_OFFSET);
 	/*
-	 * Wow we're ready to jump into next BPF program
+	 * Now we're ready to jump into next BPF program
 	 * rdi == ctx (1st arg)
-	 * rax == prog->bpf_func + prologue_size
+	 * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
 	 */
-	RETPOLINE_RAX_BPF_JIT();
+	emit_indirect_jump(&prog, 1 /* rcx */, ip + (prog - start));
 
 	/* out: */
-	BUILD_BUG_ON(cnt - label1 != OFFSET1);
-	BUILD_BUG_ON(cnt - label2 != OFFSET2);
-	BUILD_BUG_ON(cnt - label3 != OFFSET3);
+	ctx->tail_call_indirect_label = prog - start;
 	*pprog = prog;
+}
+
+static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
+				      u8 **pprog, u8 *ip,
+				      bool *callee_regs_used, u32 stack_depth,
+				      struct jit_context *ctx)
+{
+	int tcc_off = -4 - round_up(stack_depth, 8);
+	u8 *prog = *pprog, *start = *pprog;
+	int cnt = 0, offset;
+
+	/*
+	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
+	 *	goto out;
+	 */
+	EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
+	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
+
+	offset = ctx->tail_call_direct_label - (prog + 2 - start);
+	EMIT2(X86_JA, offset);                        /* ja out */
+	EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
+	EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
+
+	poke->tailcall_bypass = ip + (prog - start);
+	poke->adj_off = X86_TAIL_CALL_OFFSET;
+	poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
+	poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
+
+	emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
+		  poke->tailcall_bypass);
+
+	pop_callee_regs(&prog, callee_regs_used);
+	EMIT1(0x58);                                  /* pop rax */
+	if (stack_depth)
+		EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
+
+	memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
+	prog += X86_PATCH_SIZE;
+
+	/* out: */
+	ctx->tail_call_direct_label = prog - start;
+
+	*pprog = prog;
+}
+
+static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
+{
+	struct bpf_jit_poke_descriptor *poke;
+	struct bpf_array *array;
+	struct bpf_prog *target;
+	int i, ret;
+
+	for (i = 0; i < prog->aux->size_poke_tab; i++) {
+		poke = &prog->aux->poke_tab[i];
+		WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable));
+
+		if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
+			continue;
+
+		array = container_of(poke->tail_call.map, struct bpf_array, map);
+		mutex_lock(&array->aux->poke_mutex);
+		target = array->ptrs[poke->tail_call.key];
+		if (target) {
+			/* Plain memcpy is used when image is not live yet
+			 * and still not locked as read-only. Once poke
+			 * location is active (poke->tailcall_target_stable),
+			 * any parallel bpf_arch_text_poke() might occur
+			 * still on the read-write image until we finally
+			 * locked it as read-only. Both modifications on
+			 * the given image are under text_mutex to avoid
+			 * interference.
+			 */
+			ret = __bpf_arch_text_poke(poke->tailcall_target,
+						   BPF_MOD_JUMP, NULL,
+						   (u8 *)target->bpf_func +
+						   poke->adj_off, false);
+			BUG_ON(ret < 0);
+			ret = __bpf_arch_text_poke(poke->tailcall_bypass,
+						   BPF_MOD_JUMP,
+						   (u8 *)poke->tailcall_target +
+						   X86_PATCH_SIZE, NULL, false);
+			BUG_ON(ret < 0);
+		}
+		WRITE_ONCE(poke->tailcall_target_stable, true);
+		mutex_unlock(&array->aux->poke_mutex);
+	}
 }
 
 static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
@@ -394,21 +672,141 @@
 	*pprog = prog;
 }
 
+/* LDX: dst_reg = *(u8*)(src_reg + off) */
+static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	switch (size) {
+	case BPF_B:
+		/* Emit 'movzx rax, byte ptr [rax + off]' */
+		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
+		break;
+	case BPF_H:
+		/* Emit 'movzx rax, word ptr [rax + off]' */
+		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
+		break;
+	case BPF_W:
+		/* Emit 'mov eax, dword ptr [rax+0x14]' */
+		if (is_ereg(dst_reg) || is_ereg(src_reg))
+			EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
+		else
+			EMIT1(0x8B);
+		break;
+	case BPF_DW:
+		/* Emit 'mov rax, qword ptr [rax+0x14]' */
+		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
+		break;
+	}
+	/*
+	 * If insn->off == 0 we can save one extra byte, but
+	 * special case of x86 R13 which always needs an offset
+	 * is not worth the hassle
+	 */
+	if (is_imm8(off))
+		EMIT2(add_2reg(0x40, src_reg, dst_reg), off);
+	else
+		EMIT1_off32(add_2reg(0x80, src_reg, dst_reg), off);
+	*pprog = prog;
+}
+
+/* STX: *(u8*)(dst_reg + off) = src_reg */
+static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	switch (size) {
+	case BPF_B:
+		/* Emit 'mov byte ptr [rax + off], al' */
+		if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
+			/* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
+			EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
+		else
+			EMIT1(0x88);
+		break;
+	case BPF_H:
+		if (is_ereg(dst_reg) || is_ereg(src_reg))
+			EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
+		else
+			EMIT2(0x66, 0x89);
+		break;
+	case BPF_W:
+		if (is_ereg(dst_reg) || is_ereg(src_reg))
+			EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
+		else
+			EMIT1(0x89);
+		break;
+	case BPF_DW:
+		EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
+		break;
+	}
+	if (is_imm8(off))
+		EMIT2(add_2reg(0x40, dst_reg, src_reg), off);
+	else
+		EMIT1_off32(add_2reg(0x80, dst_reg, src_reg), off);
+	*pprog = prog;
+}
+
+static bool ex_handler_bpf(const struct exception_table_entry *x,
+			   struct pt_regs *regs, int trapnr,
+			   unsigned long error_code, unsigned long fault_addr)
+{
+	u32 reg = x->fixup >> 8;
+
+	/* jump over faulting load and clear dest register */
+	*(unsigned long *)((void *)regs + reg) = 0;
+	regs->ip += x->fixup & 0xff;
+	return true;
+}
+
+static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
+			     bool *regs_used, bool *tail_call_seen)
+{
+	int i;
+
+	for (i = 1; i <= insn_cnt; i++, insn++) {
+		if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
+			*tail_call_seen = true;
+		if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
+			regs_used[0] = true;
+		if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
+			regs_used[1] = true;
+		if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
+			regs_used[2] = true;
+		if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
+			regs_used[3] = true;
+	}
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		  int oldproglen, struct jit_context *ctx)
 {
+	bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
 	struct bpf_insn *insn = bpf_prog->insnsi;
+	bool callee_regs_used[4] = {};
 	int insn_cnt = bpf_prog->len;
+	bool tail_call_seen = false;
 	bool seen_exit = false;
 	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
-	int i, cnt = 0;
+	int i, cnt = 0, excnt = 0;
 	int proglen = 0;
 	u8 *prog = temp;
 
-	emit_prologue(&prog, bpf_prog->aux->stack_depth,
-		      bpf_prog_was_classic(bpf_prog));
+	detect_reg_usage(insn, insn_cnt, callee_regs_used,
+			 &tail_call_seen);
 
-	for (i = 0; i < insn_cnt; i++, insn++) {
+	/* tail call's presence in current prog implies it is reachable */
+	tail_call_reachable |= tail_call_seen;
+
+	emit_prologue(&prog, bpf_prog->aux->stack_depth,
+		      bpf_prog_was_classic(bpf_prog), tail_call_reachable,
+		      bpf_prog->aux->func_idx != 0);
+	push_callee_regs(&prog, callee_regs_used);
+	addrs[0] = prog - temp;
+
+	for (i = 1; i <= insn_cnt; i++, insn++) {
 		const s32 imm32 = insn->imm;
 		u32 dst_reg = insn->dst_reg;
 		u32 src_reg = insn->src_reg;
@@ -734,8 +1132,7 @@
 			/* speculation barrier */
 		case BPF_ST | BPF_NOSPEC:
 			if (boot_cpu_has(X86_FEATURE_XMM2))
-				/* Emit 'lfence' */
-				EMIT3(0x0F, 0xAE, 0xE8);
+				EMIT_LFENCE();
 			break;
 
 			/* ST: *(u8*)(dst_reg + off) = imm */
@@ -770,63 +1167,64 @@
 
 			/* STX: *(u8*)(dst_reg + off) = src_reg */
 		case BPF_STX | BPF_MEM | BPF_B:
-			/* Emit 'mov byte ptr [rax + off], al' */
-			if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
-				/* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
-				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
-			else
-				EMIT1(0x88);
-			goto stx;
 		case BPF_STX | BPF_MEM | BPF_H:
-			if (is_ereg(dst_reg) || is_ereg(src_reg))
-				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
-			else
-				EMIT2(0x66, 0x89);
-			goto stx;
 		case BPF_STX | BPF_MEM | BPF_W:
-			if (is_ereg(dst_reg) || is_ereg(src_reg))
-				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
-			else
-				EMIT1(0x89);
-			goto stx;
 		case BPF_STX | BPF_MEM | BPF_DW:
-			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
-stx:			if (is_imm8(insn->off))
-				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
-			else
-				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
-					    insn->off);
+			emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
 			break;
 
 			/* LDX: dst_reg = *(u8*)(src_reg + off) */
 		case BPF_LDX | BPF_MEM | BPF_B:
-			/* Emit 'movzx rax, byte ptr [rax + off]' */
-			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
-			goto ldx;
+		case BPF_LDX | BPF_PROBE_MEM | BPF_B:
 		case BPF_LDX | BPF_MEM | BPF_H:
-			/* Emit 'movzx rax, word ptr [rax + off]' */
-			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
-			goto ldx;
+		case BPF_LDX | BPF_PROBE_MEM | BPF_H:
 		case BPF_LDX | BPF_MEM | BPF_W:
-			/* Emit 'mov eax, dword ptr [rax+0x14]' */
-			if (is_ereg(dst_reg) || is_ereg(src_reg))
-				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
-			else
-				EMIT1(0x8B);
-			goto ldx;
+		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
 		case BPF_LDX | BPF_MEM | BPF_DW:
-			/* Emit 'mov rax, qword ptr [rax+0x14]' */
-			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
-ldx:			/*
-			 * If insn->off == 0 we can save one extra byte, but
-			 * special case of x86 R13 which always needs an offset
-			 * is not worth the hassle
-			 */
-			if (is_imm8(insn->off))
-				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
-			else
-				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
-					    insn->off);
+		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
+			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
+				struct exception_table_entry *ex;
+				u8 *_insn = image + proglen;
+				s64 delta;
+
+				if (!bpf_prog->aux->extable)
+					break;
+
+				if (excnt >= bpf_prog->aux->num_exentries) {
+					pr_err("ex gen bug\n");
+					return -EFAULT;
+				}
+				ex = &bpf_prog->aux->extable[excnt++];
+
+				delta = _insn - (u8 *)&ex->insn;
+				if (!is_simm32(delta)) {
+					pr_err("extable->insn doesn't fit into 32-bit\n");
+					return -EFAULT;
+				}
+				ex->insn = delta;
+
+				delta = (u8 *)ex_handler_bpf - (u8 *)&ex->handler;
+				if (!is_simm32(delta)) {
+					pr_err("extable->handler doesn't fit into 32-bit\n");
+					return -EFAULT;
+				}
+				ex->handler = delta;
+
+				if (dst_reg > BPF_REG_9) {
+					pr_err("verifier error\n");
+					return -EFAULT;
+				}
+				/*
+				 * Compute size of x86 insn and its target dest x86 register.
+				 * ex_handler_bpf() will use lower 8 bits to adjust
+				 * pt_regs->ip to jump over this x86 instruction
+				 * and upper bits to figure out which pt_regs to zero out.
+				 * End result: x86 insn "mov rbx, qword ptr [rax+0x14]"
+				 * of 4 bytes will be ignored and rbx will be zero inited.
+				 */
+				ex->fixup = (prog - temp) | (reg2pt_regs[dst_reg] << 8);
+			}
 			break;
 
 			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
@@ -849,17 +1247,31 @@
 			/* call */
 		case BPF_JMP | BPF_CALL:
 			func = (u8 *) __bpf_call_base + imm32;
-			jmp_offset = func - (image + addrs[i]);
-			if (!imm32 || !is_simm32(jmp_offset)) {
-				pr_err("unsupported BPF func %d addr %p image %p\n",
-				       imm32, func, image);
-				return -EINVAL;
+			if (tail_call_reachable) {
+				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
+				EMIT3_off32(0x48, 0x8B, 0x85,
+					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
+				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
+					return -EINVAL;
+			} else {
+				if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
+					return -EINVAL;
 			}
-			EMIT1_off32(0xE8, jmp_offset);
 			break;
 
 		case BPF_JMP | BPF_TAIL_CALL:
-			emit_bpf_tail_call(&prog);
+			if (imm32)
+				emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
+							  &prog, image + addrs[i - 1],
+							  callee_regs_used,
+							  bpf_prog->aux->stack_depth,
+							  ctx);
+			else
+				emit_bpf_tail_call_indirect(&prog,
+							    callee_regs_used,
+							    bpf_prog->aux->stack_depth,
+							    image + addrs[i - 1],
+							    ctx);
 			break;
 
 			/* cond jump */
@@ -873,20 +1285,41 @@
 		case BPF_JMP | BPF_JSLT | BPF_X:
 		case BPF_JMP | BPF_JSGE | BPF_X:
 		case BPF_JMP | BPF_JSLE | BPF_X:
+		case BPF_JMP32 | BPF_JEQ | BPF_X:
+		case BPF_JMP32 | BPF_JNE | BPF_X:
+		case BPF_JMP32 | BPF_JGT | BPF_X:
+		case BPF_JMP32 | BPF_JLT | BPF_X:
+		case BPF_JMP32 | BPF_JGE | BPF_X:
+		case BPF_JMP32 | BPF_JLE | BPF_X:
+		case BPF_JMP32 | BPF_JSGT | BPF_X:
+		case BPF_JMP32 | BPF_JSLT | BPF_X:
+		case BPF_JMP32 | BPF_JSGE | BPF_X:
+		case BPF_JMP32 | BPF_JSLE | BPF_X:
 			/* cmp dst_reg, src_reg */
-			EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x39,
-			      add_2reg(0xC0, dst_reg, src_reg));
+			if (BPF_CLASS(insn->code) == BPF_JMP)
+				EMIT1(add_2mod(0x48, dst_reg, src_reg));
+			else if (is_ereg(dst_reg) || is_ereg(src_reg))
+				EMIT1(add_2mod(0x40, dst_reg, src_reg));
+			EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
 			goto emit_cond_jmp;
 
 		case BPF_JMP | BPF_JSET | BPF_X:
+		case BPF_JMP32 | BPF_JSET | BPF_X:
 			/* test dst_reg, src_reg */
-			EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x85,
-			      add_2reg(0xC0, dst_reg, src_reg));
+			if (BPF_CLASS(insn->code) == BPF_JMP)
+				EMIT1(add_2mod(0x48, dst_reg, src_reg));
+			else if (is_ereg(dst_reg) || is_ereg(src_reg))
+				EMIT1(add_2mod(0x40, dst_reg, src_reg));
+			EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
 			goto emit_cond_jmp;
 
 		case BPF_JMP | BPF_JSET | BPF_K:
+		case BPF_JMP32 | BPF_JSET | BPF_K:
 			/* test dst_reg, imm32 */
-			EMIT1(add_1mod(0x48, dst_reg));
+			if (BPF_CLASS(insn->code) == BPF_JMP)
+				EMIT1(add_1mod(0x48, dst_reg));
+			else if (is_ereg(dst_reg))
+				EMIT1(add_1mod(0x40, dst_reg));
 			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
 			goto emit_cond_jmp;
 
@@ -900,8 +1333,31 @@
 		case BPF_JMP | BPF_JSLT | BPF_K:
 		case BPF_JMP | BPF_JSGE | BPF_K:
 		case BPF_JMP | BPF_JSLE | BPF_K:
+		case BPF_JMP32 | BPF_JEQ | BPF_K:
+		case BPF_JMP32 | BPF_JNE | BPF_K:
+		case BPF_JMP32 | BPF_JGT | BPF_K:
+		case BPF_JMP32 | BPF_JLT | BPF_K:
+		case BPF_JMP32 | BPF_JGE | BPF_K:
+		case BPF_JMP32 | BPF_JLE | BPF_K:
+		case BPF_JMP32 | BPF_JSGT | BPF_K:
+		case BPF_JMP32 | BPF_JSLT | BPF_K:
+		case BPF_JMP32 | BPF_JSGE | BPF_K:
+		case BPF_JMP32 | BPF_JSLE | BPF_K:
+			/* test dst_reg, dst_reg to save one extra byte */
+			if (imm32 == 0) {
+				if (BPF_CLASS(insn->code) == BPF_JMP)
+					EMIT1(add_2mod(0x48, dst_reg, dst_reg));
+				else if (is_ereg(dst_reg))
+					EMIT1(add_2mod(0x40, dst_reg, dst_reg));
+				EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
+				goto emit_cond_jmp;
+			}
+
 			/* cmp dst_reg, imm8/32 */
-			EMIT1(add_1mod(0x48, dst_reg));
+			if (BPF_CLASS(insn->code) == BPF_JMP)
+				EMIT1(add_1mod(0x48, dst_reg));
+			else if (is_ereg(dst_reg))
+				EMIT1(add_1mod(0x40, dst_reg));
 
 			if (is_imm8(imm32))
 				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
@@ -998,14 +1454,9 @@
 			seen_exit = true;
 			/* Update cleanup_addr */
 			ctx->cleanup_addr = proglen;
-			if (!bpf_prog_was_classic(bpf_prog))
-				EMIT1(0x5B); /* get rid of tail_call_cnt */
-			EMIT2(0x41, 0x5F);   /* pop r15 */
-			EMIT2(0x41, 0x5E);   /* pop r14 */
-			EMIT2(0x41, 0x5D);   /* pop r13 */
-			EMIT1(0x5B);         /* pop rbx */
+			pop_callee_regs(&prog, callee_regs_used);
 			EMIT1(0xC9);         /* leave */
-			EMIT1(0xC3);         /* ret */
+			emit_return(&prog, image + addrs[i - 1] + (prog - temp));
 			break;
 
 		default:
@@ -1045,7 +1496,506 @@
 		addrs[i] = proglen;
 		prog = temp;
 	}
+
+	if (image && excnt != bpf_prog->aux->num_exentries) {
+		pr_err("extable is not populated\n");
+		return -EFAULT;
+	}
 	return proglen;
+}
+
+static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
+		      int stack_size)
+{
+	int i;
+	/* Store function arguments to stack.
+	 * For a function that accepts two pointers the sequence will be:
+	 * mov QWORD PTR [rbp-0x10],rdi
+	 * mov QWORD PTR [rbp-0x8],rsi
+	 */
+	for (i = 0; i < min(nr_args, 6); i++)
+		emit_stx(prog, bytes_to_bpf_size(m->arg_size[i]),
+			 BPF_REG_FP,
+			 i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
+			 -(stack_size - i * 8));
+}
+
+static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
+			 int stack_size)
+{
+	int i;
+
+	/* Restore function arguments from stack.
+	 * For a function that accepts two pointers the sequence will be:
+	 * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
+	 * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
+	 */
+	for (i = 0; i < min(nr_args, 6); i++)
+		emit_ldx(prog, bytes_to_bpf_size(m->arg_size[i]),
+			 i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
+			 BPF_REG_FP,
+			 -(stack_size - i * 8));
+}
+
+static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
+			   struct bpf_prog *p, int stack_size, bool save_ret)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (p->aux->sleepable) {
+		if (emit_call(&prog, __bpf_prog_enter_sleepable, prog))
+			return -EINVAL;
+	} else {
+		if (emit_call(&prog, __bpf_prog_enter, prog))
+			return -EINVAL;
+		/* remember prog start time returned by __bpf_prog_enter */
+		emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
+	}
+
+	/* arg1: lea rdi, [rbp - stack_size] */
+	EMIT4(0x48, 0x8D, 0x7D, -stack_size);
+	/* arg2: progs[i]->insnsi for interpreter */
+	if (!p->jited)
+		emit_mov_imm64(&prog, BPF_REG_2,
+			       (long) p->insnsi >> 32,
+			       (u32) (long) p->insnsi);
+	/* call JITed bpf program or interpreter */
+	if (emit_call(&prog, p->bpf_func, prog))
+		return -EINVAL;
+
+	/*
+	 * BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
+	 * of the previous call which is then passed on the stack to
+	 * the next BPF program.
+	 *
+	 * BPF_TRAMP_FENTRY trampoline may need to return the return
+	 * value of BPF_PROG_TYPE_STRUCT_OPS prog.
+	 */
+	if (save_ret)
+		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+
+	if (p->aux->sleepable) {
+		if (emit_call(&prog, __bpf_prog_exit_sleepable, prog))
+			return -EINVAL;
+	} else {
+		/* arg1: mov rdi, progs[i] */
+		emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32,
+			       (u32) (long) p);
+		/* arg2: mov rsi, rbx <- start time in nsec */
+		emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
+		if (emit_call(&prog, __bpf_prog_exit, prog))
+			return -EINVAL;
+	}
+
+	*pprog = prog;
+	return 0;
+}
+
+static void emit_nops(u8 **pprog, unsigned int len)
+{
+	unsigned int i, noplen;
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	while (len > 0) {
+		noplen = len;
+
+		if (noplen > ASM_NOP_MAX)
+			noplen = ASM_NOP_MAX;
+
+		for (i = 0; i < noplen; i++)
+			EMIT1(ideal_nops[noplen][i]);
+		len -= noplen;
+	}
+
+	*pprog = prog;
+}
+
+static void emit_align(u8 **pprog, u32 align)
+{
+	u8 *target, *prog = *pprog;
+
+	target = PTR_ALIGN(prog, align);
+	if (target != prog)
+		emit_nops(&prog, target - prog);
+
+	*pprog = prog;
+}
+
+static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+	s64 offset;
+
+	offset = func - (ip + 2 + 4);
+	if (!is_simm32(offset)) {
+		pr_err("Target %p is out of range\n", func);
+		return -EINVAL;
+	}
+	EMIT2_off32(0x0F, jmp_cond + 0x10, offset);
+	*pprog = prog;
+	return 0;
+}
+
+static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
+		      struct bpf_tramp_progs *tp, int stack_size,
+		      bool save_ret)
+{
+	int i;
+	u8 *prog = *pprog;
+
+	for (i = 0; i < tp->nr_progs; i++) {
+		if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size,
+				    save_ret))
+			return -EINVAL;
+	}
+	*pprog = prog;
+	return 0;
+}
+
+static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
+			      struct bpf_tramp_progs *tp, int stack_size,
+			      u8 **branches)
+{
+	u8 *prog = *pprog;
+	int i, cnt = 0;
+
+	/* The first fmod_ret program will receive a garbage return value.
+	 * Set this to 0 to avoid confusing the program.
+	 */
+	emit_mov_imm32(&prog, false, BPF_REG_0, 0);
+	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+	for (i = 0; i < tp->nr_progs; i++) {
+		if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size, true))
+			return -EINVAL;
+
+		/* mod_ret prog stored return value into [rbp - 8]. Emit:
+		 * if (*(u64 *)(rbp - 8) !=  0)
+		 *	goto do_fexit;
+		 */
+		/* cmp QWORD PTR [rbp - 0x8], 0x0 */
+		EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00);
+
+		/* Save the location of the branch and Generate 6 nops
+		 * (4 bytes for an offset and 2 bytes for the jump) These nops
+		 * are replaced with a conditional jump once do_fexit (i.e. the
+		 * start of the fexit invocation) is finalized.
+		 */
+		branches[i] = prog;
+		emit_nops(&prog, 4 + 2);
+	}
+
+	*pprog = prog;
+	return 0;
+}
+
+static bool is_valid_bpf_tramp_flags(unsigned int flags)
+{
+	if ((flags & BPF_TRAMP_F_RESTORE_REGS) &&
+	    (flags & BPF_TRAMP_F_SKIP_FRAME))
+		return false;
+
+	/*
+	 * BPF_TRAMP_F_RET_FENTRY_RET is only used by bpf_struct_ops,
+	 * and it must be used alone.
+	 */
+	if ((flags & BPF_TRAMP_F_RET_FENTRY_RET) &&
+	    (flags & ~BPF_TRAMP_F_RET_FENTRY_RET))
+		return false;
+
+	return true;
+}
+
+/* Example:
+ * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
+ * its 'struct btf_func_model' will be nr_args=2
+ * The assembly code when eth_type_trans is executing after trampoline:
+ *
+ * push rbp
+ * mov rbp, rsp
+ * sub rsp, 16                     // space for skb and dev
+ * push rbx                        // temp regs to pass start time
+ * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
+ * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
+ * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
+ * mov rbx, rax                    // remember start time in bpf stats are enabled
+ * lea rdi, [rbp - 16]             // R1==ctx of bpf prog
+ * call addr_of_jited_FENTRY_prog
+ * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
+ * mov rsi, rbx                    // prog start time
+ * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
+ * mov rdi, qword ptr [rbp - 16]   // restore skb pointer from stack
+ * mov rsi, qword ptr [rbp - 8]    // restore dev pointer from stack
+ * pop rbx
+ * leave
+ * ret
+ *
+ * eth_type_trans has 5 byte nop at the beginning. These 5 bytes will be
+ * replaced with 'call generated_bpf_trampoline'. When it returns
+ * eth_type_trans will continue executing with original skb and dev pointers.
+ *
+ * The assembly code when eth_type_trans is called from trampoline:
+ *
+ * push rbp
+ * mov rbp, rsp
+ * sub rsp, 24                     // space for skb, dev, return value
+ * push rbx                        // temp regs to pass start time
+ * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
+ * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
+ * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
+ * mov rbx, rax                    // remember start time if bpf stats are enabled
+ * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
+ * call addr_of_jited_FENTRY_prog  // bpf prog can access skb and dev
+ * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
+ * mov rsi, rbx                    // prog start time
+ * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
+ * mov rdi, qword ptr [rbp - 24]   // restore skb pointer from stack
+ * mov rsi, qword ptr [rbp - 16]   // restore dev pointer from stack
+ * call eth_type_trans+5           // execute body of eth_type_trans
+ * mov qword ptr [rbp - 8], rax    // save return value
+ * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
+ * mov rbx, rax                    // remember start time in bpf stats are enabled
+ * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
+ * call addr_of_jited_FEXIT_prog   // bpf prog can access skb, dev, return value
+ * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
+ * mov rsi, rbx                    // prog start time
+ * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
+ * mov rax, qword ptr [rbp - 8]    // restore eth_type_trans's return value
+ * pop rbx
+ * leave
+ * add rsp, 8                      // skip eth_type_trans's frame
+ * ret                             // return to its caller
+ */
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
+				const struct btf_func_model *m, u32 flags,
+				struct bpf_tramp_progs *tprogs,
+				void *orig_call)
+{
+	int ret, i, cnt = 0, nr_args = m->nr_args;
+	int stack_size = nr_args * 8;
+	struct bpf_tramp_progs *fentry = &tprogs[BPF_TRAMP_FENTRY];
+	struct bpf_tramp_progs *fexit = &tprogs[BPF_TRAMP_FEXIT];
+	struct bpf_tramp_progs *fmod_ret = &tprogs[BPF_TRAMP_MODIFY_RETURN];
+	u8 **branches = NULL;
+	u8 *prog;
+	bool save_ret;
+
+	/* x86-64 supports up to 6 arguments. 7+ can be added in the future */
+	if (nr_args > 6)
+		return -ENOTSUPP;
+
+	if (!is_valid_bpf_tramp_flags(flags))
+		return -EINVAL;
+
+	/* room for return value of orig_call or fentry prog */
+	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
+	if (save_ret)
+		stack_size += 8;
+
+	if (flags & BPF_TRAMP_F_SKIP_FRAME)
+		/* skip patched call instruction and point orig_call to actual
+		 * body of the kernel function.
+		 */
+		orig_call += X86_PATCH_SIZE;
+
+	prog = image;
+
+	EMIT1(0x55);		 /* push rbp */
+	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
+	EMIT4(0x48, 0x83, 0xEC, stack_size); /* sub rsp, stack_size */
+	EMIT1(0x53);		 /* push rbx */
+
+	save_regs(m, &prog, nr_args, stack_size);
+
+	if (flags & BPF_TRAMP_F_CALL_ORIG) {
+		/* arg1: mov rdi, im */
+		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
+		if (emit_call(&prog, __bpf_tramp_enter, prog)) {
+			ret = -EINVAL;
+			goto cleanup;
+		}
+	}
+
+	if (fentry->nr_progs)
+		if (invoke_bpf(m, &prog, fentry, stack_size,
+			       flags & BPF_TRAMP_F_RET_FENTRY_RET))
+			return -EINVAL;
+
+	if (fmod_ret->nr_progs) {
+		branches = kcalloc(fmod_ret->nr_progs, sizeof(u8 *),
+				   GFP_KERNEL);
+		if (!branches)
+			return -ENOMEM;
+
+		if (invoke_bpf_mod_ret(m, &prog, fmod_ret, stack_size,
+				       branches)) {
+			ret = -EINVAL;
+			goto cleanup;
+		}
+	}
+
+	if (flags & BPF_TRAMP_F_CALL_ORIG) {
+		restore_regs(m, &prog, nr_args, stack_size);
+
+		/* call original function */
+		if (emit_call(&prog, orig_call, prog)) {
+			ret = -EINVAL;
+			goto cleanup;
+		}
+		/* remember return value in a stack for bpf prog to access */
+		emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+		im->ip_after_call = prog;
+		memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
+		prog += X86_PATCH_SIZE;
+	}
+
+	if (fmod_ret->nr_progs) {
+		/* From Intel 64 and IA-32 Architectures Optimization
+		 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
+		 * Coding Rule 11: All branch targets should be 16-byte
+		 * aligned.
+		 */
+		emit_align(&prog, 16);
+		/* Update the branches saved in invoke_bpf_mod_ret with the
+		 * aligned address of do_fexit.
+		 */
+		for (i = 0; i < fmod_ret->nr_progs; i++)
+			emit_cond_near_jump(&branches[i], prog, branches[i],
+					    X86_JNE);
+	}
+
+	if (fexit->nr_progs)
+		if (invoke_bpf(m, &prog, fexit, stack_size, false)) {
+			ret = -EINVAL;
+			goto cleanup;
+		}
+
+	if (flags & BPF_TRAMP_F_RESTORE_REGS)
+		restore_regs(m, &prog, nr_args, stack_size);
+
+	/* This needs to be done regardless. If there were fmod_ret programs,
+	 * the return value is only updated on the stack and still needs to be
+	 * restored to R0.
+	 */
+	if (flags & BPF_TRAMP_F_CALL_ORIG) {
+		im->ip_epilogue = prog;
+		/* arg1: mov rdi, im */
+		emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
+		if (emit_call(&prog, __bpf_tramp_exit, prog)) {
+			ret = -EINVAL;
+			goto cleanup;
+		}
+	}
+	/* restore return value of orig_call or fentry prog back into RAX */
+	if (save_ret)
+		emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
+
+	EMIT1(0x5B); /* pop rbx */
+	EMIT1(0xC9); /* leave */
+	if (flags & BPF_TRAMP_F_SKIP_FRAME)
+		/* skip our return address and return to parent */
+		EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
+	emit_return(&prog, prog);
+	/* Make sure the trampoline generation logic doesn't overflow */
+	if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) {
+		ret = -EFAULT;
+		goto cleanup;
+	}
+	ret = prog - (u8 *)image;
+
+cleanup:
+	kfree(branches);
+	return ret;
+}
+
+static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs)
+{
+	u8 *jg_reloc, *prog = *pprog;
+	int pivot, err, jg_bytes = 1, cnt = 0;
+	s64 jg_offset;
+
+	if (a == b) {
+		/* Leaf node of recursion, i.e. not a range of indices
+		 * anymore.
+		 */
+		EMIT1(add_1mod(0x48, BPF_REG_3));	/* cmp rdx,func */
+		if (!is_simm32(progs[a]))
+			return -1;
+		EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3),
+			    progs[a]);
+		err = emit_cond_near_jump(&prog,	/* je func */
+					  (void *)progs[a], prog,
+					  X86_JE);
+		if (err)
+			return err;
+
+		emit_indirect_jump(&prog, 2 /* rdx */, prog);
+
+		*pprog = prog;
+		return 0;
+	}
+
+	/* Not a leaf node, so we pivot, and recursively descend into
+	 * the lower and upper ranges.
+	 */
+	pivot = (b - a) / 2;
+	EMIT1(add_1mod(0x48, BPF_REG_3));		/* cmp rdx,func */
+	if (!is_simm32(progs[a + pivot]))
+		return -1;
+	EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a + pivot]);
+
+	if (pivot > 2) {				/* jg upper_part */
+		/* Require near jump. */
+		jg_bytes = 4;
+		EMIT2_off32(0x0F, X86_JG + 0x10, 0);
+	} else {
+		EMIT2(X86_JG, 0);
+	}
+	jg_reloc = prog;
+
+	err = emit_bpf_dispatcher(&prog, a, a + pivot,	/* emit lower_part */
+				  progs);
+	if (err)
+		return err;
+
+	/* From Intel 64 and IA-32 Architectures Optimization
+	 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
+	 * Coding Rule 11: All branch targets should be 16-byte
+	 * aligned.
+	 */
+	emit_align(&prog, 16);
+	jg_offset = prog - jg_reloc;
+	emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes);
+
+	err = emit_bpf_dispatcher(&prog, a + pivot + 1,	/* emit upper_part */
+				  b, progs);
+	if (err)
+		return err;
+
+	*pprog = prog;
+	return 0;
+}
+
+static int cmp_ips(const void *a, const void *b)
+{
+	const s64 *ipa = a;
+	const s64 *ipb = b;
+
+	if (*ipa > *ipb)
+		return 1;
+	if (*ipa < *ipb)
+		return -1;
+	return 0;
+}
+
+int arch_prepare_bpf_dispatcher(void *image, s64 *funcs, int num_funcs)
+{
+	u8 *prog = image;
+
+	sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL);
+	return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs);
 }
 
 struct x64_jit_data {
@@ -1103,7 +2053,7 @@
 		extra_pass = true;
 		goto skip_init_addrs;
 	}
-	addrs = kmalloc_array(prog->len, sizeof(*addrs), GFP_KERNEL);
+	addrs = kvmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
 	if (!addrs) {
 		prog = orig_prog;
 		goto out_addrs;
@@ -1113,7 +2063,7 @@
 	 * Before first pass, make a rough estimation of addrs[]
 	 * each BPF instruction is translated to less than 64 bytes
 	 */
-	for (proglen = 0, i = 0; i < prog->len; i++) {
+	for (proglen = 0, i = 0; i <= prog->len; i++) {
 		proglen += 64;
 		addrs[i] = proglen;
 	}
@@ -1145,12 +2095,24 @@
 			break;
 		}
 		if (proglen == oldproglen) {
-			header = bpf_jit_binary_alloc(proglen, &image,
-						      1, jit_fill_hole);
+			/*
+			 * The number of entries in extable is the number of BPF_LDX
+			 * insns that access kernel memory via "pointer to BTF type".
+			 * The verifier changed their opcode from LDX|MEM|size
+			 * to LDX|PROBE_MEM|size to make JITing easier.
+			 */
+			u32 align = __alignof__(struct exception_table_entry);
+			u32 extable_size = prog->aux->num_exentries *
+				sizeof(struct exception_table_entry);
+
+			/* allocate module memory for x86 insns and extable */
+			header = bpf_jit_binary_alloc(roundup(proglen, align) + extable_size,
+						      &image, align, jit_fill_hole);
 			if (!header) {
 				prog = orig_prog;
 				goto out_addrs;
 			}
+			prog->aux->extable = (void *) image + roundup(proglen, align);
 		}
 		oldproglen = proglen;
 		cond_resched();
@@ -1161,6 +2123,7 @@
 
 	if (image) {
 		if (!prog->is_func || extra_pass) {
+			bpf_tail_call_direct_fixup(prog);
 			bpf_jit_binary_lock_ro(header);
 		} else {
 			jit_data->addrs = addrs;
@@ -1177,8 +2140,10 @@
 	}
 
 	if (!image || !prog->is_func || extra_pass) {
+		if (image)
+			bpf_prog_fill_jited_linfo(prog, addrs + 1);
 out_addrs:
-		kfree(addrs);
+		kvfree(addrs);
 		kfree(jit_data);
 		prog->aux->jit_data = NULL;
 	}

--
Gitblit v1.6.2