From 6778948f9de86c3cfaf36725a7c87dcff9ba247f Mon Sep 17 00:00:00 2001
From: hc <hc@nodka.com>
Date: Mon, 11 Dec 2023 08:20:59 +0000
Subject: [PATCH] kernel_5.10 no rt

---
 kernel/arch/arm/net/bpf_jit_32.c |  144 ++++++++++++++++++++++++++++++++---------------
 1 files changed, 98 insertions(+), 46 deletions(-)

diff --git a/kernel/arch/arm/net/bpf_jit_32.c b/kernel/arch/arm/net/bpf_jit_32.c
index dade3a3..1214e39 100644
--- a/kernel/arch/arm/net/bpf_jit_32.c
+++ b/kernel/arch/arm/net/bpf_jit_32.c
@@ -1,12 +1,9 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * Just-In-Time compiler for eBPF filters on 32bit ARM
  *
  * Copyright (c) 2017 Shubham Bansal <illusionist.neo@gmail.com>
  * Copyright (c) 2011 Mircea Gherzan <mgherzan@gmail.com>
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation; version 2 of the License.
  */
 
 #include <linux/bpf.h>
@@ -755,7 +752,8 @@
 
 		/* ALU operation */
 		emit_alu_r(rd[1], rs, true, false, op, ctx);
-		emit_a32_mov_i(rd[0], 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(rd[0], 0, ctx);
 	}
 
 	arm_bpf_put_reg64(dst, rd, ctx);
@@ -777,8 +775,9 @@
 				  struct jit_ctx *ctx) {
 	if (!is64) {
 		emit_a32_mov_r(dst_lo, src_lo, ctx);
-		/* Zero out high 4 bytes */
-		emit_a32_mov_i(dst_hi, 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			/* Zero out high 4 bytes */
+			emit_a32_mov_i(dst_hi, 0, ctx);
 	} else if (__LINUX_ARM_ARCH__ < 6 &&
 		   ctx->cpu_architecture < CPU_ARCH_ARMv5TE) {
 		/* complete 8 byte move */
@@ -814,6 +813,9 @@
 		break;
 	case BPF_RSH:
 		emit(ARM_LSR_I(rd, rd, val), ctx);
+		break;
+	case BPF_ARSH:
+		emit(ARM_ASR_I(rd, rd, val), ctx);
 		break;
 	case BPF_NEG:
 		emit(ARM_RSB_I(rd, rd, val), ctx);
@@ -880,8 +882,8 @@
 	emit(ARM_SUBS_I(tmp2[0], rt, 32), ctx);
 	emit(ARM_MOV_SR(ARM_LR, rd[1], SRTYPE_LSR, rt), ctx);
 	emit(ARM_ORR_SR(ARM_LR, ARM_LR, rd[0], SRTYPE_ASL, ARM_IP), ctx);
-	_emit(ARM_COND_MI, ARM_B(0), ctx);
-	emit(ARM_ORR_SR(ARM_LR, ARM_LR, rd[0], SRTYPE_ASR, tmp2[0]), ctx);
+	_emit(ARM_COND_PL,
+	      ARM_ORR_SR(ARM_LR, ARM_LR, rd[0], SRTYPE_ASR, tmp2[0]), ctx);
 	emit(ARM_MOV_SR(ARM_IP, rd[0], SRTYPE_ASR, rt), ctx);
 
 	arm_bpf_put_reg32(dst_lo, ARM_LR, ctx);
@@ -1095,17 +1097,20 @@
 	case BPF_B:
 		/* Load a Byte */
 		emit(ARM_LDRB_I(rd[1], rm, off), ctx);
-		emit_a32_mov_i(rd[0], 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(rd[0], 0, ctx);
 		break;
 	case BPF_H:
 		/* Load a HalfWord */
 		emit(ARM_LDRH_I(rd[1], rm, off), ctx);
-		emit_a32_mov_i(rd[0], 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(rd[0], 0, ctx);
 		break;
 	case BPF_W:
 		/* Load a Word */
 		emit(ARM_LDR_I(rd[1], rm, off), ctx);
-		emit_a32_mov_i(rd[0], 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(rd[0], 0, ctx);
 		break;
 	case BPF_DW:
 		/* Load a Double Word */
@@ -1118,12 +1123,17 @@
 
 /* Arithmatic Operation */
 static inline void emit_ar_r(const u8 rd, const u8 rt, const u8 rm,
-			     const u8 rn, struct jit_ctx *ctx, u8 op) {
+			     const u8 rn, struct jit_ctx *ctx, u8 op,
+			     bool is_jmp64) {
 	switch (op) {
 	case BPF_JSET:
-		emit(ARM_AND_R(ARM_IP, rt, rn), ctx);
-		emit(ARM_AND_R(ARM_LR, rd, rm), ctx);
-		emit(ARM_ORRS_R(ARM_IP, ARM_LR, ARM_IP), ctx);
+		if (is_jmp64) {
+			emit(ARM_AND_R(ARM_IP, rt, rn), ctx);
+			emit(ARM_AND_R(ARM_LR, rd, rm), ctx);
+			emit(ARM_ORRS_R(ARM_IP, ARM_LR, ARM_IP), ctx);
+		} else {
+			emit(ARM_ANDS_R(ARM_IP, rt, rn), ctx);
+		}
 		break;
 	case BPF_JEQ:
 	case BPF_JNE:
@@ -1131,18 +1141,25 @@
 	case BPF_JGE:
 	case BPF_JLE:
 	case BPF_JLT:
-		emit(ARM_CMP_R(rd, rm), ctx);
-		_emit(ARM_COND_EQ, ARM_CMP_R(rt, rn), ctx);
+		if (is_jmp64) {
+			emit(ARM_CMP_R(rd, rm), ctx);
+			/* Only compare low halve if high halve are equal. */
+			_emit(ARM_COND_EQ, ARM_CMP_R(rt, rn), ctx);
+		} else {
+			emit(ARM_CMP_R(rt, rn), ctx);
+		}
 		break;
 	case BPF_JSLE:
 	case BPF_JSGT:
 		emit(ARM_CMP_R(rn, rt), ctx);
-		emit(ARM_SBCS_R(ARM_IP, rm, rd), ctx);
+		if (is_jmp64)
+			emit(ARM_SBCS_R(ARM_IP, rm, rd), ctx);
 		break;
 	case BPF_JSLT:
 	case BPF_JSGE:
 		emit(ARM_CMP_R(rt, rn), ctx);
-		emit(ARM_SBCS_R(ARM_IP, rd, rm), ctx);
+		if (is_jmp64)
+			emit(ARM_SBCS_R(ARM_IP, rd, rm), ctx);
 		break;
 	}
 }
@@ -1281,12 +1298,9 @@
 
 static void build_prologue(struct jit_ctx *ctx)
 {
-	const s8 r0 = bpf2a32[BPF_REG_0][1];
-	const s8 r2 = bpf2a32[BPF_REG_1][1];
-	const s8 r3 = bpf2a32[BPF_REG_1][0];
-	const s8 r4 = bpf2a32[BPF_REG_6][1];
-	const s8 fplo = bpf2a32[BPF_REG_FP][1];
-	const s8 fphi = bpf2a32[BPF_REG_FP][0];
+	const s8 arm_r0 = bpf2a32[BPF_REG_0][1];
+	const s8 *bpf_r1 = bpf2a32[BPF_REG_1];
+	const s8 *bpf_fp = bpf2a32[BPF_REG_FP];
 	const s8 *tcc = bpf2a32[TCALL_CNT];
 
 	/* Save callee saved registers. */
@@ -1299,8 +1313,10 @@
 	emit(ARM_PUSH(CALLEE_PUSH_MASK), ctx);
 	emit(ARM_MOV_R(ARM_FP, ARM_SP), ctx);
 #endif
-	/* Save frame pointer for later */
-	emit(ARM_SUB_I(ARM_IP, ARM_SP, SCRATCH_SIZE), ctx);
+	/* mov r3, #0 */
+	/* sub r2, sp, #SCRATCH_SIZE */
+	emit(ARM_MOV_I(bpf_r1[0], 0), ctx);
+	emit(ARM_SUB_I(bpf_r1[1], ARM_SP, SCRATCH_SIZE), ctx);
 
 	ctx->stack_size = imm8m(STACK_SIZE);
 
@@ -1308,18 +1324,15 @@
 	emit(ARM_SUB_I(ARM_SP, ARM_SP, ctx->stack_size), ctx);
 
 	/* Set up BPF prog stack base register */
-	emit_a32_mov_r(fplo, ARM_IP, ctx);
-	emit_a32_mov_i(fphi, 0, ctx);
+	emit_a32_mov_r64(true, bpf_fp, bpf_r1, ctx);
 
-	/* mov r4, 0 */
-	emit(ARM_MOV_I(r4, 0), ctx);
+	/* Initialize Tail Count */
+	emit(ARM_MOV_I(bpf_r1[1], 0), ctx);
+	emit_a32_mov_r64(true, tcc, bpf_r1, ctx);
 
 	/* Move BPF_CTX to BPF_R1 */
-	emit(ARM_MOV_R(r3, r4), ctx);
-	emit(ARM_MOV_R(r2, r0), ctx);
-	/* Initialize Tail Count */
-	emit(ARM_STR_I(r4, ARM_FP, EBPF_SCRATCH_TO_ARM_FP(tcc[0])), ctx);
-	emit(ARM_STR_I(r4, ARM_FP, EBPF_SCRATCH_TO_ARM_FP(tcc[1])), ctx);
+	emit(ARM_MOV_R(bpf_r1[1], arm_r0), ctx);
+
 	/* end of prologue */
 }
 
@@ -1382,6 +1395,11 @@
 	case BPF_ALU64 | BPF_MOV | BPF_X:
 		switch (BPF_SRC(code)) {
 		case BPF_X:
+			if (imm == 1) {
+				/* Special mov32 for zext */
+				emit_a32_mov_i(dst_hi, 0, ctx);
+				break;
+			}
 			emit_a32_mov_r64(is64, dst, src, ctx);
 			break;
 		case BPF_K:
@@ -1412,7 +1430,6 @@
 	case BPF_ALU | BPF_MUL | BPF_X:
 	case BPF_ALU | BPF_LSH | BPF_X:
 	case BPF_ALU | BPF_RSH | BPF_X:
-	case BPF_ALU | BPF_ARSH | BPF_K:
 	case BPF_ALU | BPF_ARSH | BPF_X:
 	case BPF_ALU64 | BPF_ADD | BPF_K:
 	case BPF_ALU64 | BPF_ADD | BPF_X:
@@ -1461,22 +1478,26 @@
 		}
 		emit_udivmod(rd_lo, rd_lo, rt, ctx, BPF_OP(code));
 		arm_bpf_put_reg32(dst_lo, rd_lo, ctx);
-		emit_a32_mov_i(dst_hi, 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(dst_hi, 0, ctx);
 		break;
 	case BPF_ALU64 | BPF_DIV | BPF_K:
 	case BPF_ALU64 | BPF_DIV | BPF_X:
 	case BPF_ALU64 | BPF_MOD | BPF_K:
 	case BPF_ALU64 | BPF_MOD | BPF_X:
 		goto notyet;
-	/* dst = dst >> imm */
 	/* dst = dst << imm */
-	case BPF_ALU | BPF_RSH | BPF_K:
+	/* dst = dst >> imm */
+	/* dst = dst >> imm (signed) */
 	case BPF_ALU | BPF_LSH | BPF_K:
+	case BPF_ALU | BPF_RSH | BPF_K:
+	case BPF_ALU | BPF_ARSH | BPF_K:
 		if (unlikely(imm > 31))
 			return -EINVAL;
 		if (imm)
 			emit_a32_alu_i(dst_lo, imm, ctx, BPF_OP(code));
-		emit_a32_mov_i(dst_hi, 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(dst_hi, 0, ctx);
 		break;
 	/* dst = dst << imm */
 	case BPF_ALU64 | BPF_LSH | BPF_K:
@@ -1511,7 +1532,8 @@
 	/* dst = ~dst */
 	case BPF_ALU | BPF_NEG:
 		emit_a32_alu_i(dst_lo, 0, ctx, BPF_OP(code));
-		emit_a32_mov_i(dst_hi, 0, ctx);
+		if (!ctx->prog->aux->verifier_zext)
+			emit_a32_mov_i(dst_hi, 0, ctx);
 		break;
 	/* dst = ~dst (64 bit) */
 	case BPF_ALU64 | BPF_NEG:
@@ -1567,11 +1589,13 @@
 #else /* ARMv6+ */
 			emit(ARM_UXTH(rd[1], rd[1]), ctx);
 #endif
-			emit(ARM_EOR_R(rd[0], rd[0], rd[0]), ctx);
+			if (!ctx->prog->aux->verifier_zext)
+				emit(ARM_EOR_R(rd[0], rd[0], rd[0]), ctx);
 			break;
 		case 32:
 			/* zero-extend 32 bits into 64 bits */
-			emit(ARM_EOR_R(rd[0], rd[0], rd[0]), ctx);
+			if (!ctx->prog->aux->verifier_zext)
+				emit(ARM_EOR_R(rd[0], rd[0], rd[0]), ctx);
 			break;
 		case 64:
 			/* nop */
@@ -1653,6 +1677,17 @@
 	case BPF_JMP | BPF_JLT | BPF_X:
 	case BPF_JMP | BPF_JSLT | BPF_X:
 	case BPF_JMP | BPF_JSLE | BPF_X:
+	case BPF_JMP32 | BPF_JEQ | BPF_X:
+	case BPF_JMP32 | BPF_JGT | BPF_X:
+	case BPF_JMP32 | BPF_JGE | BPF_X:
+	case BPF_JMP32 | BPF_JNE | BPF_X:
+	case BPF_JMP32 | BPF_JSGT | BPF_X:
+	case BPF_JMP32 | BPF_JSGE | BPF_X:
+	case BPF_JMP32 | BPF_JSET | BPF_X:
+	case BPF_JMP32 | BPF_JLE | BPF_X:
+	case BPF_JMP32 | BPF_JLT | BPF_X:
+	case BPF_JMP32 | BPF_JSLT | BPF_X:
+	case BPF_JMP32 | BPF_JSLE | BPF_X:
 		/* Setup source registers */
 		rm = arm_bpf_get_reg32(src_hi, tmp2[0], ctx);
 		rn = arm_bpf_get_reg32(src_lo, tmp2[1], ctx);
@@ -1679,6 +1714,17 @@
 	case BPF_JMP | BPF_JLE | BPF_K:
 	case BPF_JMP | BPF_JSLT | BPF_K:
 	case BPF_JMP | BPF_JSLE | BPF_K:
+	case BPF_JMP32 | BPF_JEQ | BPF_K:
+	case BPF_JMP32 | BPF_JGT | BPF_K:
+	case BPF_JMP32 | BPF_JGE | BPF_K:
+	case BPF_JMP32 | BPF_JNE | BPF_K:
+	case BPF_JMP32 | BPF_JSGT | BPF_K:
+	case BPF_JMP32 | BPF_JSGE | BPF_K:
+	case BPF_JMP32 | BPF_JSET | BPF_K:
+	case BPF_JMP32 | BPF_JLT | BPF_K:
+	case BPF_JMP32 | BPF_JLE | BPF_K:
+	case BPF_JMP32 | BPF_JSLT | BPF_K:
+	case BPF_JMP32 | BPF_JSLE | BPF_K:
 		if (off == 0)
 			break;
 		rm = tmp2[0];
@@ -1690,7 +1736,8 @@
 		rd = arm_bpf_get_reg64(dst, tmp, ctx);
 
 		/* Check for the condition */
-		emit_ar_r(rd[0], rd[1], rm, rn, ctx, BPF_OP(code));
+		emit_ar_r(rd[0], rd[1], rm, rn, ctx, BPF_OP(code),
+			  BPF_CLASS(code) == BPF_JMP);
 
 		/* Setup JUMP instruction */
 		jmp_offset = bpf2a32_offset(i+off, i, ctx);
@@ -1841,6 +1888,11 @@
 	/* Nothing to do here. We support Internal BPF. */
 }
 
+bool bpf_jit_needs_zext(void)
+{
+	return true;
+}
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
 	struct bpf_prog *tmp, *orig_prog = prog;

--
Gitblit v1.6.2