   1// SPDX-License-Identifier: GPL-2.0
   2/* BPF JIT compiler for RV64G
   3 *
   4 * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
   5 *
   6 */
   7
   8#include <linux/bitfield.h>
   9#include <linux/bpf.h>
  10#include <linux/filter.h>
  11#include <linux/memory.h>
  12#include <linux/stop_machine.h>
  13#include <asm/text-patching.h>
  14#include <asm/cfi.h>
  15#include <asm/percpu.h>
  16#include "bpf_jit.h"
  17
  18#define RV_MAX_REG_ARGS 8
  19#define RV_FENTRY_NINSNS 2
  20#define RV_FENTRY_NBYTES (RV_FENTRY_NINSNS * 4)
  21#define RV_KCFI_NINSNS (IS_ENABLED(CONFIG_CFI_CLANG) ? 1 : 0)
  22/* imm that allows emit_imm to emit max count insns */
  23#define RV_MAX_COUNT_IMM 0x7FFF7FF7FF7FF7FF
  24
  25#define RV_REG_TCC RV_REG_A6
   26#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if the program makes calls */
  27#define RV_REG_ARENA RV_REG_S7 /* For storing arena_vm_start */
  28
  29static const int regmap[] = {
  30	[BPF_REG_0] =	RV_REG_A5,
  31	[BPF_REG_1] =	RV_REG_A0,
  32	[BPF_REG_2] =	RV_REG_A1,
  33	[BPF_REG_3] =	RV_REG_A2,
  34	[BPF_REG_4] =	RV_REG_A3,
  35	[BPF_REG_5] =	RV_REG_A4,
  36	[BPF_REG_6] =	RV_REG_S1,
  37	[BPF_REG_7] =	RV_REG_S2,
  38	[BPF_REG_8] =	RV_REG_S3,
  39	[BPF_REG_9] =	RV_REG_S4,
  40	[BPF_REG_FP] =	RV_REG_S5,
  41	[BPF_REG_AX] =	RV_REG_T0,
  42};
  43
  44static const int pt_regmap[] = {
  45	[RV_REG_A0] = offsetof(struct pt_regs, a0),
  46	[RV_REG_A1] = offsetof(struct pt_regs, a1),
  47	[RV_REG_A2] = offsetof(struct pt_regs, a2),
  48	[RV_REG_A3] = offsetof(struct pt_regs, a3),
  49	[RV_REG_A4] = offsetof(struct pt_regs, a4),
  50	[RV_REG_A5] = offsetof(struct pt_regs, a5),
  51	[RV_REG_S1] = offsetof(struct pt_regs, s1),
  52	[RV_REG_S2] = offsetof(struct pt_regs, s2),
  53	[RV_REG_S3] = offsetof(struct pt_regs, s3),
  54	[RV_REG_S4] = offsetof(struct pt_regs, s4),
  55	[RV_REG_S5] = offsetof(struct pt_regs, s5),
  56	[RV_REG_T0] = offsetof(struct pt_regs, t0),
  57};
  58
  59enum {
  60	RV_CTX_F_SEEN_TAIL_CALL =	0,
  61	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
  62	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
  63	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
  64	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
  65	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
  66	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
  67	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
  68};
  69
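/* Map a BPF register to its RISC-V counterpart. Uses of callee-saved
 * registers are recorded in ctx->flags so that only the registers the
 * program actually touches need to be spilled and restored.
 */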
  70static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
  71{
  72	u8 reg = regmap[bpf_reg];
  73
  74	switch (reg) {
  75	case RV_CTX_F_SEEN_S1:
  76	case RV_CTX_F_SEEN_S2:
  77	case RV_CTX_F_SEEN_S3:
  78	case RV_CTX_F_SEEN_S4:
  79	case RV_CTX_F_SEEN_S5:
  80	case RV_CTX_F_SEEN_S6:
  81		__set_bit(reg, &ctx->flags);
  82	}
  83	return reg;
  84};
  85
  86static bool seen_reg(int reg, struct rv_jit_context *ctx)
  87{
  88	switch (reg) {
  89	case RV_CTX_F_SEEN_CALL:
  90	case RV_CTX_F_SEEN_S1:
  91	case RV_CTX_F_SEEN_S2:
  92	case RV_CTX_F_SEEN_S3:
  93	case RV_CTX_F_SEEN_S4:
  94	case RV_CTX_F_SEEN_S5:
  95	case RV_CTX_F_SEEN_S6:
  96		return test_bit(reg, &ctx->flags);
  97	}
  98	return false;
  99}
 100
 101static void mark_fp(struct rv_jit_context *ctx)
 102{
 103	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
 104}
 105
 106static void mark_call(struct rv_jit_context *ctx)
 107{
 108	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
 109}
 110
 111static bool seen_call(struct rv_jit_context *ctx)
 112{
 113	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
 114}
 115
 116static void mark_tail_call(struct rv_jit_context *ctx)
 117{
 118	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
 119}
 120
 121static bool seen_tail_call(struct rv_jit_context *ctx)
 122{
 123	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
 124}
 125
 126static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
 127{
 128	mark_tail_call(ctx);
 129
 130	if (seen_call(ctx)) {
 131		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
 132		return RV_REG_S6;
 133	}
 134	return RV_REG_A6;
 135}
 136
 137static bool is_32b_int(s64 val)
 138{
 139	return -(1L << 31) <= val && val < (1L << 31);
 140}
 141
 142static bool in_auipc_jalr_range(s64 val)
 143{
 144	/*
 145	 * auipc+jalr can reach any signed PC-relative offset in the range
 146	 * [-2^31 - 2^11, 2^31 - 2^11).
 147	 */
 148	return (-(1L << 31) - (1L << 11)) <= val &&
 149		val < ((1L << 31) - (1L << 11));
 150}
 151
 152/* Modify rd pointer to alternate reg to avoid corrupting original reg */
 153static void emit_sextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx)
 154{
 155	emit_sextw(ra, *rd, ctx);
 156	*rd = ra;
 157}
 158
 159static void emit_zextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx)
 160{
 161	emit_zextw(ra, *rd, ctx);
 162	*rd = ra;
 163}
 164
 165/* Emit fixed-length instructions for address */
 166static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
 167{
 168	/*
 169	 * Use the ro_insns(RX) to calculate the offset as the BPF program will
 170	 * finally run from this memory region.
 171	 */
 172	u64 ip = (u64)(ctx->ro_insns + ctx->ninsns);
 173	s64 off = addr - ip;
 174	s64 upper = (off + (1 << 11)) >> 12;
 175	s64 lower = off & 0xfff;
 176
 177	if (extra_pass && !in_auipc_jalr_range(off)) {
 178		pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
 179		return -ERANGE;
 180	}
 181
 182	emit(rv_auipc(rd, upper), ctx);
 183	emit(rv_addi(rd, rd, lower), ctx);
 184	return 0;
 185}
 186
 187/* Emit variable-length instructions for 32-bit and 64-bit imm */
 188static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
 189{
 190	/* Note that the immediate from the add is sign-extended,
 191	 * which means that we need to compensate this by adding 2^12,
 192	 * when the 12th bit is set. A simpler way of doing this, and
 193	 * getting rid of the check, is to just add 2**11 before the
 194	 * shift. The "Loading a 32-Bit constant" example from the
 195	 * "Computer Organization and Design, RISC-V edition" book by
 196	 * Patterson/Hennessy highlights this fact.
 197	 *
 198	 * This also means that we need to process LSB to MSB.
 199	 */
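	/* For example: val = 0x12345fff has bit 11 set; adding 1 << 11 first
	 * gives upper = 0x12346 and lower = -1 (0xfff sign-extended), so
	 * lui 0x12346 followed by addiw -1 rebuilds the original value.
	 */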
 200	s64 upper = (val + (1 << 11)) >> 12;
 201	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
 202	 * and addi are signed and RVC checks will perform signed comparisons.
 203	 */
 204	s64 lower = ((val & 0xfff) << 52) >> 52;
 205	int shift;
 206
 207	if (is_32b_int(val)) {
 208		if (upper)
 209			emit_lui(rd, upper, ctx);
 210
 211		if (!upper) {
 212			emit_li(rd, lower, ctx);
 213			return;
 214		}
 215
 216		emit_addiw(rd, rd, lower, ctx);
 217		return;
 218	}
 219
 220	shift = __ffs(upper);
 221	upper >>= shift;
 222	shift += 12;
 223
 224	emit_imm(rd, upper, ctx);
 225
 226	emit_slli(rd, rd, shift, ctx);
 227	if (lower)
 228		emit_addi(rd, rd, lower, ctx);
 229}
 230
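/* Restore the callee-saved registers spilled by the prologue, release the
 * stack frame, and either return to the caller or, for a tail call, jump
 * into the target program past its kcfi/fentry/TCC-init instructions.
 */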
 231static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
 232{
 233	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
 234
 235	if (seen_reg(RV_REG_RA, ctx)) {
 236		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
 237		store_offset -= 8;
 238	}
 239	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
 240	store_offset -= 8;
 241	if (seen_reg(RV_REG_S1, ctx)) {
 242		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
 243		store_offset -= 8;
 244	}
 245	if (seen_reg(RV_REG_S2, ctx)) {
 246		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
 247		store_offset -= 8;
 248	}
 249	if (seen_reg(RV_REG_S3, ctx)) {
 250		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
 251		store_offset -= 8;
 252	}
 253	if (seen_reg(RV_REG_S4, ctx)) {
 254		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
 255		store_offset -= 8;
 256	}
 257	if (seen_reg(RV_REG_S5, ctx)) {
 258		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
 259		store_offset -= 8;
 260	}
 261	if (seen_reg(RV_REG_S6, ctx)) {
 262		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
 263		store_offset -= 8;
 264	}
 265	if (ctx->arena_vm_start) {
 266		emit_ld(RV_REG_ARENA, store_offset, RV_REG_SP, ctx);
 267		store_offset -= 8;
 268	}
 269
 270	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
 271	/* Set return value. */
 272	if (!is_tail_call)
 273		emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
 274	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
 275		  /* kcfi, fentry and TCC init insns will be skipped on tailcall */
 276		  is_tail_call ? (RV_KCFI_NINSNS + RV_FENTRY_NINSNS + 1) * 4 : 0,
 277		  ctx);
 278}
 279
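/* Emit a single conditional branch; branch offsets are encoded in units of
 * 2 bytes, hence the rvoff >> 1 below.
 */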
 280static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
 281		     struct rv_jit_context *ctx)
 282{
 283	switch (cond) {
 284	case BPF_JEQ:
 285		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
 286		return;
 287	case BPF_JGT:
 288		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
 289		return;
 290	case BPF_JLT:
 291		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
 292		return;
 293	case BPF_JGE:
 294		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
 295		return;
 296	case BPF_JLE:
 297		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
 298		return;
 299	case BPF_JNE:
 300		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
 301		return;
 302	case BPF_JSGT:
 303		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
 304		return;
 305	case BPF_JSLT:
 306		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
 307		return;
 308	case BPF_JSGE:
 309		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
 310		return;
 311	case BPF_JSLE:
 312		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
 313	}
 314}
 315
 316static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
 317			struct rv_jit_context *ctx)
 318{
 319	s64 upper, lower;
 320
 321	if (is_13b_int(rvoff)) {
 322		emit_bcc(cond, rd, rs, rvoff, ctx);
 323		return;
 324	}
 325
 326	/* Adjust for jal */
 327	rvoff -= 4;
 328
 329	/* Transform, e.g.:
 330	 *   bne rd,rs,foo
 331	 * to
 332	 *   beq rd,rs,<.L1>
 333	 *   (auipc foo)
 334	 *   jal(r) foo
 335	 * .L1
 336	 */
 337	cond = invert_bpf_cond(cond);
 338	if (is_21b_int(rvoff)) {
 339		emit_bcc(cond, rd, rs, 8, ctx);
 340		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
 341		return;
 342	}
 343
  344	/* 32b: No need for an additional rvoff adjustment, since we
 345	 * get that from the auipc at PC', where PC = PC' + 4.
 346	 */
 347	upper = (rvoff + (1 << 11)) >> 12;
 348	lower = rvoff & 0xfff;
 349
 350	emit_bcc(cond, rd, rs, 12, ctx);
 351	emit(rv_auipc(RV_REG_T1, upper), ctx);
 352	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
 353}
 354
 355static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
 356{
 357	int tc_ninsn, off, start_insn = ctx->ninsns;
 358	u8 tcc = rv_tail_call_reg(ctx);
 359
 360	/* a0: &ctx
 361	 * a1: &array
 362	 * a2: index
 363	 *
 364	 * if (index >= array->map.max_entries)
 365	 *	goto out;
 366	 */
 367	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
 368		   ctx->offset[0];
 369	emit_zextw(RV_REG_A2, RV_REG_A2, ctx);
 370
 371	off = offsetof(struct bpf_array, map.max_entries);
 372	if (is_12b_check(off, insn))
 373		return -1;
 374	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
 375	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
 376	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
 377
 378	/* if (--TCC < 0)
 379	 *     goto out;
 380	 */
 381	emit_addi(RV_REG_TCC, tcc, -1, ctx);
 382	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
 383	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
 384
 385	/* prog = array->ptrs[index];
 386	 * if (!prog)
 387	 *     goto out;
 388	 */
 389	emit_sh3add(RV_REG_T2, RV_REG_A2, RV_REG_A1, ctx);
 390	off = offsetof(struct bpf_array, ptrs);
 391	if (is_12b_check(off, insn))
 392		return -1;
 393	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
 394	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
 395	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
 396
 397	/* goto *(prog->bpf_func + 4); */
 398	off = offsetof(struct bpf_prog, bpf_func);
 399	if (is_12b_check(off, insn))
 400		return -1;
 401	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
 402	__build_epilogue(true, ctx);
 403	return 0;
 404}
 405
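/* Resolve the RISC-V destination and source registers for a BPF insn,
 * marking any callee-saved registers it uses along the way.
 */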
 406static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
 407		      struct rv_jit_context *ctx)
 408{
 409	u8 code = insn->code;
 410
 411	switch (code) {
 412	case BPF_JMP | BPF_JA:
 413	case BPF_JMP | BPF_CALL:
 414	case BPF_JMP | BPF_EXIT:
 415	case BPF_JMP | BPF_TAIL_CALL:
 416		break;
 417	default:
 418		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
 419	}
 420
 421	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
 422	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
 423	    code & BPF_LDX || code & BPF_STX)
 424		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
 425}
 426
 427static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
 428			      struct rv_jit_context *ctx)
 429{
 430	s64 upper, lower;
 431
 432	if (rvoff && fixed_addr && is_21b_int(rvoff)) {
 433		emit(rv_jal(rd, rvoff >> 1), ctx);
 434		return 0;
 435	} else if (in_auipc_jalr_range(rvoff)) {
 436		upper = (rvoff + (1 << 11)) >> 12;
 437		lower = rvoff & 0xfff;
 438		emit(rv_auipc(RV_REG_T1, upper), ctx);
 439		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
 440		return 0;
 441	}
 442
 443	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
 444	return -ERANGE;
 445}
 446
 447static bool is_signed_bpf_cond(u8 cond)
 448{
 449	return cond == BPF_JSGT || cond == BPF_JSLT ||
 450		cond == BPF_JSGE || cond == BPF_JSLE;
 451}
 452
 453static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
 454{
 455	s64 off = 0;
 456	u64 ip;
 457
 458	if (addr && ctx->insns && ctx->ro_insns) {
 459		/*
 460		 * Use the ro_insns(RX) to calculate the offset as the BPF
 461		 * program will finally run from this memory region.
 462		 */
 463		ip = (u64)(long)(ctx->ro_insns + ctx->ninsns);
 464		off = addr - ip;
 465	}
 466
 467	return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
 468}
 469
 470static inline void emit_kcfi(u32 hash, struct rv_jit_context *ctx)
 471{
 472	if (IS_ENABLED(CONFIG_CFI_CLANG))
 473		emit(hash, ctx);
 474}
 475
 476static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
 477			struct rv_jit_context *ctx)
 478{
 479	u8 r0;
 480	int jmp_offset;
 481
 482	if (off) {
 483		if (is_12b_int(off)) {
 484			emit_addi(RV_REG_T1, rd, off, ctx);
 485		} else {
 486			emit_imm(RV_REG_T1, off, ctx);
 487			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
 488		}
 489		rd = RV_REG_T1;
 490	}
 491
 492	switch (imm) {
 493	/* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
 494	case BPF_ADD:
 495		emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
 496		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
 497		break;
 498	case BPF_AND:
 499		emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
 500		     rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
 501		break;
 502	case BPF_OR:
 503		emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
 504		     rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
 505		break;
 506	case BPF_XOR:
 507		emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
 508		     rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
 509		break;
 510	/* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
 511	case BPF_ADD | BPF_FETCH:
 512		emit(is64 ? rv_amoadd_d(rs, rs, rd, 1, 1) :
 513		     rv_amoadd_w(rs, rs, rd, 1, 1), ctx);
 514		if (!is64)
 515			emit_zextw(rs, rs, ctx);
 516		break;
 517	case BPF_AND | BPF_FETCH:
 518		emit(is64 ? rv_amoand_d(rs, rs, rd, 1, 1) :
 519		     rv_amoand_w(rs, rs, rd, 1, 1), ctx);
 520		if (!is64)
 521			emit_zextw(rs, rs, ctx);
 522		break;
 523	case BPF_OR | BPF_FETCH:
 524		emit(is64 ? rv_amoor_d(rs, rs, rd, 1, 1) :
 525		     rv_amoor_w(rs, rs, rd, 1, 1), ctx);
 526		if (!is64)
 527			emit_zextw(rs, rs, ctx);
 528		break;
 529	case BPF_XOR | BPF_FETCH:
 530		emit(is64 ? rv_amoxor_d(rs, rs, rd, 1, 1) :
 531		     rv_amoxor_w(rs, rs, rd, 1, 1), ctx);
 532		if (!is64)
 533			emit_zextw(rs, rs, ctx);
 534		break;
 535	/* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
 536	case BPF_XCHG:
 537		emit(is64 ? rv_amoswap_d(rs, rs, rd, 1, 1) :
 538		     rv_amoswap_w(rs, rs, rd, 1, 1), ctx);
 539		if (!is64)
 540			emit_zextw(rs, rs, ctx);
 541		break;
 542	/* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
 543	case BPF_CMPXCHG:
 544		r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
 545		if (is64)
 546			emit_mv(RV_REG_T2, r0, ctx);
 547		else
 548			emit_addiw(RV_REG_T2, r0, 0, ctx);
 549		emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
 550		     rv_lr_w(r0, 0, rd, 0, 0), ctx);
 551		jmp_offset = ninsns_rvoff(8);
 552		emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
 553		emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 1) :
 554		     rv_sc_w(RV_REG_T3, rs, rd, 0, 1), ctx);
 555		jmp_offset = ninsns_rvoff(-6);
 556		emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
 557		emit(rv_fence(0x3, 0x3), ctx);
 558		break;
 559	}
 560}
 561
 562#define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
 563#define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
 564#define REG_DONT_CLEAR_MARKER	0	/* RV_REG_ZERO unused in pt_regmap */
 565
 566bool ex_handler_bpf(const struct exception_table_entry *ex,
 567		    struct pt_regs *regs)
 568{
 569	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
 570	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
 571
 572	if (regs_offset != REG_DONT_CLEAR_MARKER)
 573		*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
 574	regs->epc = (unsigned long)&ex->fixup - offset;
 575
 576	return true;
 577}
 578
 579/* For accesses to BTF pointers, add an entry to the exception table */
 580static int add_exception_handler(const struct bpf_insn *insn,
 581				 struct rv_jit_context *ctx,
 582				 int dst_reg, int insn_len)
 583{
 584	struct exception_table_entry *ex;
 585	unsigned long pc;
 586	off_t ins_offset;
 587	off_t fixup_offset;
 588
 589	if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
 590	    (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
 591	     BPF_MODE(insn->code) != BPF_PROBE_MEM32))
 592		return 0;
 593
 594	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
 595		return -EINVAL;
 596
 597	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
 598		return -EINVAL;
 599
 600	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
 601		return -EINVAL;
 602
 603	ex = &ctx->prog->aux->extable[ctx->nexentries];
 604	pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len];
 605
 606	/*
 607	 * This is the relative offset of the instruction that may fault from
 608	 * the exception table itself. This will be written to the exception
 609	 * table and if this instruction faults, the destination register will
 610	 * be set to '0' and the execution will jump to the next instruction.
 611	 */
 612	ins_offset = pc - (long)&ex->insn;
 613	if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
 614		return -ERANGE;
 615
 616	/*
 617	 * Since the extable follows the program, the fixup offset is always
 618	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
 619	 * to keep things simple, and put the destination register in the upper
 620	 * bits. We don't need to worry about buildtime or runtime sort
 621	 * modifying the upper bits because the table is already sorted, and
 622	 * isn't part of the main exception table.
 623	 *
 624	 * The fixup_offset is set to the next instruction from the instruction
 625	 * that may fault. The execution will jump to this after handling the
 626	 * fault.
 627	 */
 628	fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
 629	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
 630		return -ERANGE;
 631
 632	/*
 633	 * The offsets above have been calculated using the RO buffer but we
 634	 * need to use the R/W buffer for writes.
  635	 * Switch ex to the R/W buffer for writing.
 636	 */
 637	ex = (void *)ctx->insns + ((void *)ex - (void *)ctx->ro_insns);
 638
 639	ex->insn = ins_offset;
 640
 641	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
 642		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
 643	ex->type = EX_TYPE_BPF;
 644
 645	ctx->nexentries++;
 646	return 0;
 647}
 648
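/* Emit either an auipc+jalr jump/call to target, or two nops when target is
 * NULL; bpf_arch_text_poke() uses this to build the old and new fentry
 * sequences that it compares and patches.
 */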
 649static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
 650{
 651	s64 rvoff;
 652	struct rv_jit_context ctx;
 653
 654	ctx.ninsns = 0;
 655	ctx.insns = (u16 *)insns;
 656
 657	if (!target) {
 658		emit(rv_nop(), &ctx);
 659		emit(rv_nop(), &ctx);
 660		return 0;
 661	}
 662
 663	rvoff = (s64)(target - ip);
 664	return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
 665}
 666
 667int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
 668		       void *old_addr, void *new_addr)
 669{
 670	u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
 671	bool is_call = poke_type == BPF_MOD_CALL;
 672	int ret;
 673
 674	if (!is_kernel_text((unsigned long)ip) &&
 675	    !is_bpf_text_address((unsigned long)ip))
 676		return -ENOTSUPP;
 677
 678	ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
 679	if (ret)
 680		return ret;
 681
 682	if (memcmp(ip, old_insns, RV_FENTRY_NBYTES))
 683		return -EFAULT;
 684
 685	ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
 686	if (ret)
 687		return ret;
 688
 689	cpus_read_lock();
 690	mutex_lock(&text_mutex);
 691	if (memcmp(ip, new_insns, RV_FENTRY_NBYTES))
 692		ret = patch_text(ip, new_insns, RV_FENTRY_NBYTES);
 693	mutex_unlock(&text_mutex);
 694	cpus_read_unlock();
 695
 696	return ret;
 697}
 698
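/* Spill the traced function's arguments to the trampoline stack: the first
 * RV_MAX_REG_ARGS arguments come from a0-a7, the rest are copied from the
 * traced function's own stack frame.
 */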
 699static void store_args(int nr_arg_slots, int args_off, struct rv_jit_context *ctx)
 700{
 701	int i;
 702
 703	for (i = 0; i < nr_arg_slots; i++) {
 704		if (i < RV_MAX_REG_ARGS) {
 705			emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
 706		} else {
 707			/* skip slots for T0 and FP of traced function */
 708			emit_ld(RV_REG_T1, 16 + (i - RV_MAX_REG_ARGS) * 8, RV_REG_FP, ctx);
 709			emit_sd(RV_REG_FP, -args_off, RV_REG_T1, ctx);
 710		}
 711		args_off -= 8;
 712	}
 713}
 714
 715static void restore_args(int nr_reg_args, int args_off, struct rv_jit_context *ctx)
 716{
 717	int i;
 718
 719	for (i = 0; i < nr_reg_args; i++) {
 720		emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
 721		args_off -= 8;
 722	}
 723}
 724
 725static void restore_stack_args(int nr_stack_args, int args_off, int stk_arg_off,
 726			       struct rv_jit_context *ctx)
 727{
 728	int i;
 729
 730	for (i = 0; i < nr_stack_args; i++) {
 731		emit_ld(RV_REG_T1, -(args_off - RV_MAX_REG_ARGS * 8), RV_REG_FP, ctx);
 732		emit_sd(RV_REG_FP, -stk_arg_off, RV_REG_T1, ctx);
 733		args_off -= 8;
 734		stk_arg_off -= 8;
 735	}
 736}
 737
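/* Emit the enter-helper/prog/exit-helper sequence for one trampoline link.
 * The prog body is skipped when the enter helper returns 0, and the return
 * value is optionally saved to the trampoline stack.
 */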
 738static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
 739			   int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
 740{
 741	int ret, branch_off;
 742	struct bpf_prog *p = l->link.prog;
 743	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
 744
 745	if (l->cookie) {
 746		emit_imm(RV_REG_T1, l->cookie, ctx);
 747		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
 748	} else {
 749		emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
 750	}
 751
 752	/* arg1: prog */
 753	emit_imm(RV_REG_A0, (const s64)p, ctx);
 754	/* arg2: &run_ctx */
 755	emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
 756	ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
 757	if (ret)
 758		return ret;
 759
 760	/* store prog start time */
 761	emit_mv(RV_REG_S1, RV_REG_A0, ctx);
 762
 763	/* if (__bpf_prog_enter(prog) == 0)
 764	 *	goto skip_exec_of_prog;
 765	 */
 766	branch_off = ctx->ninsns;
 767	/* nop reserved for conditional jump */
 768	emit(rv_nop(), ctx);
 769
 770	/* arg1: &args_off */
 771	emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
 772	if (!p->jited)
 773		/* arg2: progs[i]->insnsi for interpreter */
 774		emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
 775	ret = emit_call((const u64)p->bpf_func, true, ctx);
 776	if (ret)
 777		return ret;
 778
 779	if (save_ret) {
 780		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
 781		emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
 782	}
 783
 784	/* update branch with beqz */
 785	if (ctx->insns) {
 786		int offset = ninsns_rvoff(ctx->ninsns - branch_off);
 787		u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
 788		*(u32 *)(ctx->insns + branch_off) = insn;
 789	}
 790
 791	/* arg1: prog */
 792	emit_imm(RV_REG_A0, (const s64)p, ctx);
 793	/* arg2: prog start time */
 794	emit_mv(RV_REG_A1, RV_REG_S1, ctx);
 795	/* arg3: &run_ctx */
 796	emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
 797	ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
 798
 799	return ret;
 800}
 801
 802static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
 803					 const struct btf_func_model *m,
 804					 struct bpf_tramp_links *tlinks,
 805					 void *func_addr, u32 flags,
 806					 struct rv_jit_context *ctx)
 807{
 808	int i, ret, offset;
 809	int *branches_off = NULL;
 810	int stack_size = 0, nr_arg_slots = 0;
 811	int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off, stk_arg_off;
 812	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
 813	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
 814	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
 815	bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
 816	void *orig_call = func_addr;
 817	bool save_ret;
 818	u32 insn;
 819
 820	/* Two types of generated trampoline stack layout:
 821	 *
 822	 * 1. trampoline called from function entry
 823	 * --------------------------------------
 824	 * FP + 8	    [ RA to parent func	] return address to parent
 825	 *					  function
 826	 * FP + 0	    [ FP of parent func ] frame pointer of parent
 827	 *					  function
 828	 * FP - 8           [ T0 to traced func ] return address of traced
 829	 *					  function
 830	 * FP - 16	    [ FP of traced func ] frame pointer of traced
 831	 *					  function
 832	 * --------------------------------------
 833	 *
 834	 * 2. trampoline called directly
 835	 * --------------------------------------
 836	 * FP - 8	    [ RA to caller func ] return address to caller
 837	 *					  function
 838	 * FP - 16	    [ FP of caller func	] frame pointer of caller
 839	 *					  function
 840	 * --------------------------------------
 841	 *
 842	 * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
 843	 *					  BPF_TRAMP_F_RET_FENTRY_RET
 844	 *                  [ argN              ]
 845	 *                  [ ...               ]
 846	 * FP - args_off    [ arg1              ]
 847	 *
 848	 * FP - nregs_off   [ regs count        ]
 849	 *
 850	 * FP - ip_off      [ traced func	] BPF_TRAMP_F_IP_ARG
 851	 *
 852	 * FP - run_ctx_off [ bpf_tramp_run_ctx ]
 853	 *
 854	 * FP - sreg_off    [ callee saved reg	]
 855	 *
 856	 *		    [ pads              ] pads for 16 bytes alignment
 857	 *
 858	 *		    [ stack_argN        ]
 859	 *		    [ ...               ]
 860	 * FP - stk_arg_off [ stack_arg1        ] BPF_TRAMP_F_CALL_ORIG
 861	 */
 862
 863	if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
 864		return -ENOTSUPP;
 865
 866	if (m->nr_args > MAX_BPF_FUNC_ARGS)
 867		return -ENOTSUPP;
 868
 869	for (i = 0; i < m->nr_args; i++)
 870		nr_arg_slots += round_up(m->arg_size[i], 8) / 8;
 871
  872	/* room in the trampoline frame to store the return address and frame pointer */
 873	stack_size += 16;
 874
 875	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
 876	if (save_ret) {
 877		stack_size += 16; /* Save both A5 (BPF R0) and A0 */
 878		retval_off = stack_size;
 879	}
 880
 881	stack_size += nr_arg_slots * 8;
 882	args_off = stack_size;
 883
 884	stack_size += 8;
 885	nregs_off = stack_size;
 886
 887	if (flags & BPF_TRAMP_F_IP_ARG) {
 888		stack_size += 8;
 889		ip_off = stack_size;
 890	}
 891
 892	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
 893	run_ctx_off = stack_size;
 894
 895	stack_size += 8;
 896	sreg_off = stack_size;
 897
 898	if ((flags & BPF_TRAMP_F_CALL_ORIG) && (nr_arg_slots - RV_MAX_REG_ARGS > 0))
 899		stack_size += (nr_arg_slots - RV_MAX_REG_ARGS) * 8;
 900
 901	stack_size = round_up(stack_size, STACK_ALIGN);
 902
  903	/* room for stack arguments must be at the top of the stack */
 904	stk_arg_off = stack_size;
 905
 906	if (!is_struct_ops) {
  907		/* For a trampoline called from function entry, both the
  908		 * traced function's frame and the trampoline's own frame
  909		 * need to be set up.
  910		 */
 911		emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
 912		emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
 913		emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
 914		emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
 915
 916		emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
 917		emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
 918		emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
 919		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
 920	} else {
 921		/* emit kcfi hash */
 922		emit_kcfi(cfi_get_func_hash(func_addr), ctx);
  923		/* For a trampoline called directly, only the trampoline's
  924		 * own frame needs to be set up.
  925		 */
 926		emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
 927		emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
 928		emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
 929		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
 930	}
 931
 932	/* callee saved register S1 to pass start time */
 933	emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
 934
 935	/* store ip address of the traced function */
 936	if (flags & BPF_TRAMP_F_IP_ARG) {
 937		emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
 938		emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
 939	}
 940
 941	emit_li(RV_REG_T1, nr_arg_slots, ctx);
 942	emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
 943
 944	store_args(nr_arg_slots, args_off, ctx);
 945
 946	/* skip to actual body of traced function */
 947	if (flags & BPF_TRAMP_F_SKIP_FRAME)
 948		orig_call += RV_FENTRY_NINSNS * 4;
 949
 950	if (flags & BPF_TRAMP_F_CALL_ORIG) {
 951		emit_imm(RV_REG_A0, ctx->insns ? (const s64)im : RV_MAX_COUNT_IMM, ctx);
 952		ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
 953		if (ret)
 954			return ret;
 955	}
 956
 957	for (i = 0; i < fentry->nr_links; i++) {
 958		ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
 959				      flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
 960		if (ret)
 961			return ret;
 962	}
 963
 964	if (fmod_ret->nr_links) {
 965		branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
 966		if (!branches_off)
 967			return -ENOMEM;
 968
 969		/* cleanup to avoid garbage return value confusion */
 970		emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
 971		for (i = 0; i < fmod_ret->nr_links; i++) {
 972			ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
 973					      run_ctx_off, true, ctx);
 974			if (ret)
 975				goto out;
 976			emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
 977			branches_off[i] = ctx->ninsns;
 978			/* nop reserved for conditional jump */
 979			emit(rv_nop(), ctx);
 980		}
 981	}
 982
 983	if (flags & BPF_TRAMP_F_CALL_ORIG) {
 984		restore_args(min_t(int, nr_arg_slots, RV_MAX_REG_ARGS), args_off, ctx);
 985		restore_stack_args(nr_arg_slots - RV_MAX_REG_ARGS, args_off, stk_arg_off, ctx);
 986		ret = emit_call((const u64)orig_call, true, ctx);
 987		if (ret)
 988			goto out;
 989		emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
 990		emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
 991		im->ip_after_call = ctx->ro_insns + ctx->ninsns;
 992		/* 2 nops reserved for auipc+jalr pair */
 993		emit(rv_nop(), ctx);
 994		emit(rv_nop(), ctx);
 995	}
 996
 997	/* update branches saved in invoke_bpf_mod_ret with bnez */
 998	for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
 999		offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
1000		insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
1001		*(u32 *)(ctx->insns + branches_off[i]) = insn;
1002	}
1003
1004	for (i = 0; i < fexit->nr_links; i++) {
1005		ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
1006				      run_ctx_off, false, ctx);
1007		if (ret)
1008			goto out;
1009	}
1010
1011	if (flags & BPF_TRAMP_F_CALL_ORIG) {
1012		im->ip_epilogue = ctx->ro_insns + ctx->ninsns;
1013		emit_imm(RV_REG_A0, ctx->insns ? (const s64)im : RV_MAX_COUNT_IMM, ctx);
1014		ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
1015		if (ret)
1016			goto out;
1017	}
1018
1019	if (flags & BPF_TRAMP_F_RESTORE_REGS)
1020		restore_args(min_t(int, nr_arg_slots, RV_MAX_REG_ARGS), args_off, ctx);
1021
1022	if (save_ret) {
1023		emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
1024		emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx);
1025	}
1026
1027	emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
1028
1029	if (!is_struct_ops) {
1030		/* trampoline called from function entry */
1031		emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
1032		emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1033		emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1034
1035		emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
1036		emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
1037		emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
1038
1039		if (flags & BPF_TRAMP_F_SKIP_FRAME)
1040			/* return to parent function */
1041			emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1042		else
1043			/* return to traced function */
1044			emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
1045	} else {
1046		/* trampoline called directly */
1047		emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
1048		emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1049		emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1050
1051		emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1052	}
1053
1054	ret = ctx->ninsns;
1055out:
1056	kfree(branches_off);
1057	return ret;
1058}
1059
1060int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
1061			     struct bpf_tramp_links *tlinks, void *func_addr)
1062{
1063	struct bpf_tramp_image im;
1064	struct rv_jit_context ctx;
1065	int ret;
1066
1067	ctx.ninsns = 0;
1068	ctx.insns = NULL;
1069	ctx.ro_insns = NULL;
1070	ret = __arch_prepare_bpf_trampoline(&im, m, tlinks, func_addr, flags, &ctx);
1071
1072	return ret < 0 ? ret : ninsns_rvoff(ctx.ninsns);
1073}
1074
1075void *arch_alloc_bpf_trampoline(unsigned int size)
1076{
1077	return bpf_prog_pack_alloc(size, bpf_fill_ill_insns);
1078}
1079
1080void arch_free_bpf_trampoline(void *image, unsigned int size)
1081{
1082	bpf_prog_pack_free(image, size);
1083}
1084
1085int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *ro_image,
1086				void *ro_image_end, const struct btf_func_model *m,
1087				u32 flags, struct bpf_tramp_links *tlinks,
1088				void *func_addr)
1089{
1090	int ret;
1091	void *image, *res;
1092	struct rv_jit_context ctx;
1093	u32 size = ro_image_end - ro_image;
1094
1095	image = kvmalloc(size, GFP_KERNEL);
1096	if (!image)
1097		return -ENOMEM;
1098
1099	ctx.ninsns = 0;
1100	ctx.insns = image;
1101	ctx.ro_insns = ro_image;
1102	ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1103	if (ret < 0)
1104		goto out;
1105
1106	if (WARN_ON(size < ninsns_rvoff(ctx.ninsns))) {
1107		ret = -E2BIG;
1108		goto out;
1109	}
1110
1111	res = bpf_arch_text_copy(ro_image, image, size);
1112	if (IS_ERR(res)) {
1113		ret = PTR_ERR(res);
1114		goto out;
1115	}
1116
1117	bpf_flush_icache(ro_image, ro_image_end);
1118out:
1119	kvfree(image);
1120	return ret < 0 ? ret : size;
1121}
1122
1123int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1124		      bool extra_pass)
1125{
1126	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1127		    BPF_CLASS(insn->code) == BPF_JMP;
1128	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1129	struct bpf_prog_aux *aux = ctx->prog->aux;
1130	u8 rd = -1, rs = -1, code = insn->code;
1131	s16 off = insn->off;
1132	s32 imm = insn->imm;
1133
1134	init_regs(&rd, &rs, insn, ctx);
1135
1136	switch (code) {
1137	/* dst = src */
1138	case BPF_ALU | BPF_MOV | BPF_X:
1139	case BPF_ALU64 | BPF_MOV | BPF_X:
1140		if (insn_is_cast_user(insn)) {
1141			emit_mv(RV_REG_T1, rs, ctx);
1142			emit_zextw(RV_REG_T1, RV_REG_T1, ctx);
1143			emit_imm(rd, (ctx->user_vm_start >> 32) << 32, ctx);
1144			emit(rv_beq(RV_REG_T1, RV_REG_ZERO, 4), ctx);
1145			emit_or(RV_REG_T1, rd, RV_REG_T1, ctx);
1146			emit_mv(rd, RV_REG_T1, ctx);
1147			break;
1148		} else if (insn_is_mov_percpu_addr(insn)) {
1149			if (rd != rs)
1150				emit_mv(rd, rs, ctx);
1151#ifdef CONFIG_SMP
1152			/* Load current CPU number in T1 */
1153			emit_ld(RV_REG_T1, offsetof(struct thread_info, cpu),
1154				RV_REG_TP, ctx);
1155			/* Load address of __per_cpu_offset array in T2 */
1156			emit_addr(RV_REG_T2, (u64)&__per_cpu_offset, extra_pass, ctx);
1157			/* Get address of __per_cpu_offset[cpu] in T1 */
1158			emit_sh3add(RV_REG_T1, RV_REG_T1, RV_REG_T2, ctx);
1159			/* Load __per_cpu_offset[cpu] in T1 */
1160			emit_ld(RV_REG_T1, 0, RV_REG_T1, ctx);
1161			/* Add the offset to Rd */
1162			emit_add(rd, rd, RV_REG_T1, ctx);
1163#endif
1164		}
1165		if (imm == 1) {
1166			/* Special mov32 for zext */
1167			emit_zextw(rd, rd, ctx);
1168			break;
1169		}
1170		switch (insn->off) {
1171		case 0:
1172			emit_mv(rd, rs, ctx);
1173			break;
1174		case 8:
1175			emit_sextb(rd, rs, ctx);
1176			break;
1177		case 16:
1178			emit_sexth(rd, rs, ctx);
1179			break;
1180		case 32:
1181			emit_sextw(rd, rs, ctx);
1182			break;
1183		}
1184		if (!is64 && !aux->verifier_zext)
1185			emit_zextw(rd, rd, ctx);
1186		break;
1187
1188	/* dst = dst OP src */
1189	case BPF_ALU | BPF_ADD | BPF_X:
1190	case BPF_ALU64 | BPF_ADD | BPF_X:
1191		emit_add(rd, rd, rs, ctx);
1192		if (!is64 && !aux->verifier_zext)
1193			emit_zextw(rd, rd, ctx);
1194		break;
1195	case BPF_ALU | BPF_SUB | BPF_X:
1196	case BPF_ALU64 | BPF_SUB | BPF_X:
1197		if (is64)
1198			emit_sub(rd, rd, rs, ctx);
1199		else
1200			emit_subw(rd, rd, rs, ctx);
1201
1202		if (!is64 && !aux->verifier_zext)
1203			emit_zextw(rd, rd, ctx);
1204		break;
1205	case BPF_ALU | BPF_AND | BPF_X:
1206	case BPF_ALU64 | BPF_AND | BPF_X:
1207		emit_and(rd, rd, rs, ctx);
1208		if (!is64 && !aux->verifier_zext)
1209			emit_zextw(rd, rd, ctx);
1210		break;
1211	case BPF_ALU | BPF_OR | BPF_X:
1212	case BPF_ALU64 | BPF_OR | BPF_X:
1213		emit_or(rd, rd, rs, ctx);
1214		if (!is64 && !aux->verifier_zext)
1215			emit_zextw(rd, rd, ctx);
1216		break;
1217	case BPF_ALU | BPF_XOR | BPF_X:
1218	case BPF_ALU64 | BPF_XOR | BPF_X:
1219		emit_xor(rd, rd, rs, ctx);
1220		if (!is64 && !aux->verifier_zext)
1221			emit_zextw(rd, rd, ctx);
1222		break;
1223	case BPF_ALU | BPF_MUL | BPF_X:
1224	case BPF_ALU64 | BPF_MUL | BPF_X:
1225		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1226		if (!is64 && !aux->verifier_zext)
1227			emit_zextw(rd, rd, ctx);
1228		break;
1229	case BPF_ALU | BPF_DIV | BPF_X:
1230	case BPF_ALU64 | BPF_DIV | BPF_X:
1231		if (off)
1232			emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
1233		else
1234			emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1235		if (!is64 && !aux->verifier_zext)
1236			emit_zextw(rd, rd, ctx);
1237		break;
1238	case BPF_ALU | BPF_MOD | BPF_X:
1239	case BPF_ALU64 | BPF_MOD | BPF_X:
1240		if (off)
1241			emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
1242		else
1243			emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1244		if (!is64 && !aux->verifier_zext)
1245			emit_zextw(rd, rd, ctx);
1246		break;
1247	case BPF_ALU | BPF_LSH | BPF_X:
1248	case BPF_ALU64 | BPF_LSH | BPF_X:
1249		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1250		if (!is64 && !aux->verifier_zext)
1251			emit_zextw(rd, rd, ctx);
1252		break;
1253	case BPF_ALU | BPF_RSH | BPF_X:
1254	case BPF_ALU64 | BPF_RSH | BPF_X:
1255		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1256		if (!is64 && !aux->verifier_zext)
1257			emit_zextw(rd, rd, ctx);
1258		break;
1259	case BPF_ALU | BPF_ARSH | BPF_X:
1260	case BPF_ALU64 | BPF_ARSH | BPF_X:
1261		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1262		if (!is64 && !aux->verifier_zext)
1263			emit_zextw(rd, rd, ctx);
1264		break;
1265
1266	/* dst = -dst */
1267	case BPF_ALU | BPF_NEG:
1268	case BPF_ALU64 | BPF_NEG:
1269		emit_sub(rd, RV_REG_ZERO, rd, ctx);
1270		if (!is64 && !aux->verifier_zext)
1271			emit_zextw(rd, rd, ctx);
1272		break;
1273
1274	/* dst = BSWAP##imm(dst) */
1275	case BPF_ALU | BPF_END | BPF_FROM_LE:
1276		switch (imm) {
1277		case 16:
1278			emit_zexth(rd, rd, ctx);
1279			break;
1280		case 32:
1281			if (!aux->verifier_zext)
1282				emit_zextw(rd, rd, ctx);
1283			break;
1284		case 64:
1285			/* Do nothing */
1286			break;
1287		}
1288		break;
1289	case BPF_ALU | BPF_END | BPF_FROM_BE:
1290	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1291		emit_bswap(rd, imm, ctx);
1292		break;
1293
1294	/* dst = imm */
1295	case BPF_ALU | BPF_MOV | BPF_K:
1296	case BPF_ALU64 | BPF_MOV | BPF_K:
1297		emit_imm(rd, imm, ctx);
1298		if (!is64 && !aux->verifier_zext)
1299			emit_zextw(rd, rd, ctx);
1300		break;
1301
1302	/* dst = dst OP imm */
1303	case BPF_ALU | BPF_ADD | BPF_K:
1304	case BPF_ALU64 | BPF_ADD | BPF_K:
1305		if (is_12b_int(imm)) {
1306			emit_addi(rd, rd, imm, ctx);
1307		} else {
1308			emit_imm(RV_REG_T1, imm, ctx);
1309			emit_add(rd, rd, RV_REG_T1, ctx);
1310		}
1311		if (!is64 && !aux->verifier_zext)
1312			emit_zextw(rd, rd, ctx);
1313		break;
1314	case BPF_ALU | BPF_SUB | BPF_K:
1315	case BPF_ALU64 | BPF_SUB | BPF_K:
1316		if (is_12b_int(-imm)) {
1317			emit_addi(rd, rd, -imm, ctx);
1318		} else {
1319			emit_imm(RV_REG_T1, imm, ctx);
1320			emit_sub(rd, rd, RV_REG_T1, ctx);
1321		}
1322		if (!is64 && !aux->verifier_zext)
1323			emit_zextw(rd, rd, ctx);
1324		break;
1325	case BPF_ALU | BPF_AND | BPF_K:
1326	case BPF_ALU64 | BPF_AND | BPF_K:
1327		if (is_12b_int(imm)) {
1328			emit_andi(rd, rd, imm, ctx);
1329		} else {
1330			emit_imm(RV_REG_T1, imm, ctx);
1331			emit_and(rd, rd, RV_REG_T1, ctx);
1332		}
1333		if (!is64 && !aux->verifier_zext)
1334			emit_zextw(rd, rd, ctx);
1335		break;
1336	case BPF_ALU | BPF_OR | BPF_K:
1337	case BPF_ALU64 | BPF_OR | BPF_K:
1338		if (is_12b_int(imm)) {
1339			emit(rv_ori(rd, rd, imm), ctx);
1340		} else {
1341			emit_imm(RV_REG_T1, imm, ctx);
1342			emit_or(rd, rd, RV_REG_T1, ctx);
1343		}
1344		if (!is64 && !aux->verifier_zext)
1345			emit_zextw(rd, rd, ctx);
1346		break;
1347	case BPF_ALU | BPF_XOR | BPF_K:
1348	case BPF_ALU64 | BPF_XOR | BPF_K:
1349		if (is_12b_int(imm)) {
1350			emit(rv_xori(rd, rd, imm), ctx);
1351		} else {
1352			emit_imm(RV_REG_T1, imm, ctx);
1353			emit_xor(rd, rd, RV_REG_T1, ctx);
1354		}
1355		if (!is64 && !aux->verifier_zext)
1356			emit_zextw(rd, rd, ctx);
1357		break;
1358	case BPF_ALU | BPF_MUL | BPF_K:
1359	case BPF_ALU64 | BPF_MUL | BPF_K:
1360		emit_imm(RV_REG_T1, imm, ctx);
1361		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1362		     rv_mulw(rd, rd, RV_REG_T1), ctx);
1363		if (!is64 && !aux->verifier_zext)
1364			emit_zextw(rd, rd, ctx);
1365		break;
1366	case BPF_ALU | BPF_DIV | BPF_K:
1367	case BPF_ALU64 | BPF_DIV | BPF_K:
1368		emit_imm(RV_REG_T1, imm, ctx);
1369		if (off)
1370			emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
1371			     rv_divw(rd, rd, RV_REG_T1), ctx);
1372		else
1373			emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1374			     rv_divuw(rd, rd, RV_REG_T1), ctx);
1375		if (!is64 && !aux->verifier_zext)
1376			emit_zextw(rd, rd, ctx);
1377		break;
1378	case BPF_ALU | BPF_MOD | BPF_K:
1379	case BPF_ALU64 | BPF_MOD | BPF_K:
1380		emit_imm(RV_REG_T1, imm, ctx);
1381		if (off)
1382			emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
1383			     rv_remw(rd, rd, RV_REG_T1), ctx);
1384		else
1385			emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1386			     rv_remuw(rd, rd, RV_REG_T1), ctx);
1387		if (!is64 && !aux->verifier_zext)
1388			emit_zextw(rd, rd, ctx);
1389		break;
1390	case BPF_ALU | BPF_LSH | BPF_K:
1391	case BPF_ALU64 | BPF_LSH | BPF_K:
1392		emit_slli(rd, rd, imm, ctx);
1393
1394		if (!is64 && !aux->verifier_zext)
1395			emit_zextw(rd, rd, ctx);
1396		break;
1397	case BPF_ALU | BPF_RSH | BPF_K:
1398	case BPF_ALU64 | BPF_RSH | BPF_K:
1399		if (is64)
1400			emit_srli(rd, rd, imm, ctx);
1401		else
1402			emit(rv_srliw(rd, rd, imm), ctx);
1403
1404		if (!is64 && !aux->verifier_zext)
1405			emit_zextw(rd, rd, ctx);
1406		break;
1407	case BPF_ALU | BPF_ARSH | BPF_K:
1408	case BPF_ALU64 | BPF_ARSH | BPF_K:
1409		if (is64)
1410			emit_srai(rd, rd, imm, ctx);
1411		else
1412			emit(rv_sraiw(rd, rd, imm), ctx);
1413
1414		if (!is64 && !aux->verifier_zext)
1415			emit_zextw(rd, rd, ctx);
1416		break;
1417
1418	/* JUMP off */
1419	case BPF_JMP | BPF_JA:
1420	case BPF_JMP32 | BPF_JA:
1421		if (BPF_CLASS(code) == BPF_JMP)
1422			rvoff = rv_offset(i, off, ctx);
1423		else
1424			rvoff = rv_offset(i, imm, ctx);
1425		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1426		if (ret)
1427			return ret;
1428		break;
1429
1430	/* IF (dst COND src) JUMP off */
1431	case BPF_JMP | BPF_JEQ | BPF_X:
1432	case BPF_JMP32 | BPF_JEQ | BPF_X:
1433	case BPF_JMP | BPF_JGT | BPF_X:
1434	case BPF_JMP32 | BPF_JGT | BPF_X:
1435	case BPF_JMP | BPF_JLT | BPF_X:
1436	case BPF_JMP32 | BPF_JLT | BPF_X:
1437	case BPF_JMP | BPF_JGE | BPF_X:
1438	case BPF_JMP32 | BPF_JGE | BPF_X:
1439	case BPF_JMP | BPF_JLE | BPF_X:
1440	case BPF_JMP32 | BPF_JLE | BPF_X:
1441	case BPF_JMP | BPF_JNE | BPF_X:
1442	case BPF_JMP32 | BPF_JNE | BPF_X:
1443	case BPF_JMP | BPF_JSGT | BPF_X:
1444	case BPF_JMP32 | BPF_JSGT | BPF_X:
1445	case BPF_JMP | BPF_JSLT | BPF_X:
1446	case BPF_JMP32 | BPF_JSLT | BPF_X:
1447	case BPF_JMP | BPF_JSGE | BPF_X:
1448	case BPF_JMP32 | BPF_JSGE | BPF_X:
1449	case BPF_JMP | BPF_JSLE | BPF_X:
1450	case BPF_JMP32 | BPF_JSLE | BPF_X:
1451	case BPF_JMP | BPF_JSET | BPF_X:
1452	case BPF_JMP32 | BPF_JSET | BPF_X:
1453		rvoff = rv_offset(i, off, ctx);
1454		if (!is64) {
1455			s = ctx->ninsns;
1456			if (is_signed_bpf_cond(BPF_OP(code))) {
1457				emit_sextw_alt(&rs, RV_REG_T1, ctx);
1458				emit_sextw_alt(&rd, RV_REG_T2, ctx);
1459			} else {
1460				emit_zextw_alt(&rs, RV_REG_T1, ctx);
1461				emit_zextw_alt(&rd, RV_REG_T2, ctx);
1462			}
1463			e = ctx->ninsns;
1464
1465			/* Adjust for extra insns */
1466			rvoff -= ninsns_rvoff(e - s);
1467		}
1468
1469		if (BPF_OP(code) == BPF_JSET) {
1470			/* Adjust for and */
1471			rvoff -= 4;
1472			emit_and(RV_REG_T1, rd, rs, ctx);
1473			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1474		} else {
1475			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1476		}
1477		break;
1478
1479	/* IF (dst COND imm) JUMP off */
1480	case BPF_JMP | BPF_JEQ | BPF_K:
1481	case BPF_JMP32 | BPF_JEQ | BPF_K:
1482	case BPF_JMP | BPF_JGT | BPF_K:
1483	case BPF_JMP32 | BPF_JGT | BPF_K:
1484	case BPF_JMP | BPF_JLT | BPF_K:
1485	case BPF_JMP32 | BPF_JLT | BPF_K:
1486	case BPF_JMP | BPF_JGE | BPF_K:
1487	case BPF_JMP32 | BPF_JGE | BPF_K:
1488	case BPF_JMP | BPF_JLE | BPF_K:
1489	case BPF_JMP32 | BPF_JLE | BPF_K:
1490	case BPF_JMP | BPF_JNE | BPF_K:
1491	case BPF_JMP32 | BPF_JNE | BPF_K:
1492	case BPF_JMP | BPF_JSGT | BPF_K:
1493	case BPF_JMP32 | BPF_JSGT | BPF_K:
1494	case BPF_JMP | BPF_JSLT | BPF_K:
1495	case BPF_JMP32 | BPF_JSLT | BPF_K:
1496	case BPF_JMP | BPF_JSGE | BPF_K:
1497	case BPF_JMP32 | BPF_JSGE | BPF_K:
1498	case BPF_JMP | BPF_JSLE | BPF_K:
1499	case BPF_JMP32 | BPF_JSLE | BPF_K:
1500		rvoff = rv_offset(i, off, ctx);
1501		s = ctx->ninsns;
1502		if (imm)
1503			emit_imm(RV_REG_T1, imm, ctx);
1504		rs = imm ? RV_REG_T1 : RV_REG_ZERO;
1505		if (!is64) {
1506			if (is_signed_bpf_cond(BPF_OP(code))) {
1507				emit_sextw_alt(&rd, RV_REG_T2, ctx);
1508				/* rs has been sign extended */
1509			} else {
1510				emit_zextw_alt(&rd, RV_REG_T2, ctx);
1511				if (imm)
1512					emit_zextw(rs, rs, ctx);
1513			}
1514		}
1515		e = ctx->ninsns;
1516
1517		/* Adjust for extra insns */
1518		rvoff -= ninsns_rvoff(e - s);
1519		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1520		break;
1521
1522	case BPF_JMP | BPF_JSET | BPF_K:
1523	case BPF_JMP32 | BPF_JSET | BPF_K:
1524		rvoff = rv_offset(i, off, ctx);
1525		s = ctx->ninsns;
1526		if (is_12b_int(imm)) {
1527			emit_andi(RV_REG_T1, rd, imm, ctx);
1528		} else {
1529			emit_imm(RV_REG_T1, imm, ctx);
1530			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1531		}
1532		/* For jset32, we should clear the upper 32 bits of t1, but
1533		 * sign-extension is sufficient here and saves one instruction,
1534		 * as t1 is used only in comparison against zero.
1535		 */
1536		if (!is64 && imm < 0)
1537			emit_sextw(RV_REG_T1, RV_REG_T1, ctx);
1538		e = ctx->ninsns;
1539		rvoff -= ninsns_rvoff(e - s);
1540		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1541		break;
1542
1543	/* function call */
1544	case BPF_JMP | BPF_CALL:
1545	{
1546		bool fixed_addr;
1547		u64 addr;
1548
1549		/* Inline calls to bpf_get_smp_processor_id()
1550		 *
1551		 * RV_REG_TP holds the address of the current CPU's task_struct and thread_info is
1552		 * at offset 0 in task_struct.
1553		 * Load cpu from thread_info:
1554		 *     Set R0 to ((struct thread_info *)(RV_REG_TP))->cpu
1555		 *
 1556		 * This replicates the implementation of raw_smp_processor_id() on RISC-V.
1557		 */
1558		if (insn->src_reg == 0 && insn->imm == BPF_FUNC_get_smp_processor_id) {
1559			/* Load current CPU number in R0 */
1560			emit_ld(bpf_to_rv_reg(BPF_REG_0, ctx), offsetof(struct thread_info, cpu),
1561				RV_REG_TP, ctx);
1562			break;
1563		}
1564
1565		mark_call(ctx);
1566		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1567					    &addr, &fixed_addr);
1568		if (ret < 0)
1569			return ret;
1570
1571		if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
1572			const struct btf_func_model *fm;
1573			int idx;
1574
1575			fm = bpf_jit_find_kfunc_model(ctx->prog, insn);
1576			if (!fm)
1577				return -EINVAL;
1578
1579			for (idx = 0; idx < fm->nr_args; idx++) {
1580				u8 reg = bpf_to_rv_reg(BPF_REG_1 + idx, ctx);
1581
1582				if (fm->arg_size[idx] == sizeof(int))
1583					emit_sextw(reg, reg, ctx);
1584			}
1585		}
1586
1587		ret = emit_call(addr, fixed_addr, ctx);
1588		if (ret)
1589			return ret;
1590
1591		if (insn->src_reg != BPF_PSEUDO_CALL)
1592			emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1593		break;
1594	}
1595	/* tail call */
1596	case BPF_JMP | BPF_TAIL_CALL:
1597		if (emit_bpf_tail_call(i, ctx))
1598			return -1;
1599		break;
1600
1601	/* function return */
1602	case BPF_JMP | BPF_EXIT:
1603		if (i == ctx->prog->len - 1)
1604			break;
1605
1606		rvoff = epilogue_offset(ctx);
1607		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1608		if (ret)
1609			return ret;
1610		break;
1611
1612	/* dst = imm64 */
1613	case BPF_LD | BPF_IMM | BPF_DW:
1614	{
1615		struct bpf_insn insn1 = insn[1];
1616		u64 imm64;
1617
1618		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1619		if (bpf_pseudo_func(insn)) {
1620			/* fixed-length insns for extra jit pass */
1621			ret = emit_addr(rd, imm64, extra_pass, ctx);
1622			if (ret)
1623				return ret;
1624		} else {
1625			emit_imm(rd, imm64, ctx);
1626		}
1627
1628		return 1;
1629	}
1630
1631	/* LDX: dst = *(unsigned size *)(src + off) */
1632	case BPF_LDX | BPF_MEM | BPF_B:
1633	case BPF_LDX | BPF_MEM | BPF_H:
1634	case BPF_LDX | BPF_MEM | BPF_W:
1635	case BPF_LDX | BPF_MEM | BPF_DW:
1636	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1637	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1638	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1639	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1640	/* LDSX: dst = *(signed size *)(src + off) */
1641	case BPF_LDX | BPF_MEMSX | BPF_B:
1642	case BPF_LDX | BPF_MEMSX | BPF_H:
1643	case BPF_LDX | BPF_MEMSX | BPF_W:
1644	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1645	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1646	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1647	/* LDX | PROBE_MEM32: dst = *(unsigned size *)(src + RV_REG_ARENA + off) */
1648	case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
1649	case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
1650	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
1651	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
1652	{
1653		int insn_len, insns_start;
1654		bool sign_ext;
1655
1656		sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
1657			   BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
1658
1659		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
1660			emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx);
1661			rs = RV_REG_T2;
1662		}
1663
1664		switch (BPF_SIZE(code)) {
1665		case BPF_B:
1666			if (is_12b_int(off)) {
1667				insns_start = ctx->ninsns;
1668				if (sign_ext)
1669					emit(rv_lb(rd, off, rs), ctx);
1670				else
1671					emit(rv_lbu(rd, off, rs), ctx);
1672				insn_len = ctx->ninsns - insns_start;
1673				break;
1674			}
1675
1676			emit_imm(RV_REG_T1, off, ctx);
1677			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1678			insns_start = ctx->ninsns;
1679			if (sign_ext)
1680				emit(rv_lb(rd, 0, RV_REG_T1), ctx);
1681			else
1682				emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1683			insn_len = ctx->ninsns - insns_start;
1684			break;
1685		case BPF_H:
1686			if (is_12b_int(off)) {
1687				insns_start = ctx->ninsns;
1688				if (sign_ext)
1689					emit(rv_lh(rd, off, rs), ctx);
1690				else
1691					emit(rv_lhu(rd, off, rs), ctx);
1692				insn_len = ctx->ninsns - insns_start;
1693				break;
1694			}
1695
1696			emit_imm(RV_REG_T1, off, ctx);
1697			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1698			insns_start = ctx->ninsns;
1699			if (sign_ext)
1700				emit(rv_lh(rd, 0, RV_REG_T1), ctx);
1701			else
1702				emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1703			insn_len = ctx->ninsns - insns_start;
1704			break;
1705		case BPF_W:
1706			if (is_12b_int(off)) {
1707				insns_start = ctx->ninsns;
1708				if (sign_ext)
1709					emit(rv_lw(rd, off, rs), ctx);
1710				else
1711					emit(rv_lwu(rd, off, rs), ctx);
1712				insn_len = ctx->ninsns - insns_start;
1713				break;
1714			}
1715
1716			emit_imm(RV_REG_T1, off, ctx);
1717			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1718			insns_start = ctx->ninsns;
1719			if (sign_ext)
1720				emit(rv_lw(rd, 0, RV_REG_T1), ctx);
1721			else
1722				emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1723			insn_len = ctx->ninsns - insns_start;
1724			break;
1725		case BPF_DW:
1726			if (is_12b_int(off)) {
1727				insns_start = ctx->ninsns;
1728				emit_ld(rd, off, rs, ctx);
1729				insn_len = ctx->ninsns - insns_start;
1730				break;
1731			}
1732
1733			emit_imm(RV_REG_T1, off, ctx);
1734			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1735			insns_start = ctx->ninsns;
1736			emit_ld(rd, 0, RV_REG_T1, ctx);
1737			insn_len = ctx->ninsns - insns_start;
1738			break;
1739		}
1740
1741		ret = add_exception_handler(insn, ctx, rd, insn_len);
1742		if (ret)
1743			return ret;
1744
1745		if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
1746			return 1;
1747		break;
1748	}
1749	/* speculation barrier */
1750	case BPF_ST | BPF_NOSPEC:
1751		break;
1752
1753	/* ST: *(size *)(dst + off) = imm */
1754	case BPF_ST | BPF_MEM | BPF_B:
1755		emit_imm(RV_REG_T1, imm, ctx);
1756		if (is_12b_int(off)) {
1757			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1758			break;
1759		}
1760
1761		emit_imm(RV_REG_T2, off, ctx);
1762		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1763		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1764		break;
1765
1766	case BPF_ST | BPF_MEM | BPF_H:
1767		emit_imm(RV_REG_T1, imm, ctx);
1768		if (is_12b_int(off)) {
1769			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1770			break;
1771		}
1772
1773		emit_imm(RV_REG_T2, off, ctx);
1774		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1775		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1776		break;
1777	case BPF_ST | BPF_MEM | BPF_W:
1778		emit_imm(RV_REG_T1, imm, ctx);
1779		if (is_12b_int(off)) {
1780			emit_sw(rd, off, RV_REG_T1, ctx);
1781			break;
1782		}
1783
1784		emit_imm(RV_REG_T2, off, ctx);
1785		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1786		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1787		break;
1788	case BPF_ST | BPF_MEM | BPF_DW:
1789		emit_imm(RV_REG_T1, imm, ctx);
1790		if (is_12b_int(off)) {
1791			emit_sd(rd, off, RV_REG_T1, ctx);
1792			break;
1793		}
1794
1795		emit_imm(RV_REG_T2, off, ctx);
1796		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1797		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1798		break;
1799
1800	case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
1801	case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
1802	case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
1803	case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
1804	{
1805		int insn_len, insns_start;
1806
1807		emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx);
1808		rd = RV_REG_T3;
1809
1810		/* Load imm to a register then store it */
1811		emit_imm(RV_REG_T1, imm, ctx);
1812
1813		switch (BPF_SIZE(code)) {
1814		case BPF_B:
1815			if (is_12b_int(off)) {
1816				insns_start = ctx->ninsns;
1817				emit(rv_sb(rd, off, RV_REG_T1), ctx);
1818				insn_len = ctx->ninsns - insns_start;
1819				break;
1820			}
1821
1822			emit_imm(RV_REG_T2, off, ctx);
1823			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1824			insns_start = ctx->ninsns;
1825			emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1826			insn_len = ctx->ninsns - insns_start;
1827			break;
1828		case BPF_H:
1829			if (is_12b_int(off)) {
1830				insns_start = ctx->ninsns;
1831				emit(rv_sh(rd, off, RV_REG_T1), ctx);
1832				insn_len = ctx->ninsns - insns_start;
1833				break;
1834			}
1835
1836			emit_imm(RV_REG_T2, off, ctx);
1837			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1838			insns_start = ctx->ninsns;
1839			emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1840			insn_len = ctx->ninsns - insns_start;
1841			break;
1842		case BPF_W:
1843			if (is_12b_int(off)) {
1844				insns_start = ctx->ninsns;
1845				emit_sw(rd, off, RV_REG_T1, ctx);
1846				insn_len = ctx->ninsns - insns_start;
1847				break;
1848			}
1849
1850			emit_imm(RV_REG_T2, off, ctx);
1851			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1852			insns_start = ctx->ninsns;
1853			emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1854			insn_len = ctx->ninsns - insns_start;
1855			break;
1856		case BPF_DW:
1857			if (is_12b_int(off)) {
1858				insns_start = ctx->ninsns;
1859				emit_sd(rd, off, RV_REG_T1, ctx);
1860				insn_len = ctx->ninsns - insns_start;
1861				break;
1862			}
1863
1864			emit_imm(RV_REG_T2, off, ctx);
1865			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1866			insns_start = ctx->ninsns;
1867			emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1868			insn_len = ctx->ninsns - insns_start;
1869			break;
1870		}
1871
1872		ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
1873					    insn_len);
1874		if (ret)
1875			return ret;
1876
1877		break;
1878	}
1879
1880	/* STX: *(size *)(dst + off) = src */
1881	case BPF_STX | BPF_MEM | BPF_B:
1882		if (is_12b_int(off)) {
1883			emit(rv_sb(rd, off, rs), ctx);
1884			break;
1885		}
1886
1887		emit_imm(RV_REG_T1, off, ctx);
1888		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1889		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1890		break;
1891	case BPF_STX | BPF_MEM | BPF_H:
1892		if (is_12b_int(off)) {
1893			emit(rv_sh(rd, off, rs), ctx);
1894			break;
1895		}
1896
1897		emit_imm(RV_REG_T1, off, ctx);
1898		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1899		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1900		break;
1901	case BPF_STX | BPF_MEM | BPF_W:
1902		if (is_12b_int(off)) {
1903			emit_sw(rd, off, rs, ctx);
1904			break;
1905		}
1906
1907		emit_imm(RV_REG_T1, off, ctx);
1908		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1909		emit_sw(RV_REG_T1, 0, rs, ctx);
1910		break;
1911	case BPF_STX | BPF_MEM | BPF_DW:
1912		if (is_12b_int(off)) {
1913			emit_sd(rd, off, rs, ctx);
1914			break;
1915		}
1916
1917		emit_imm(RV_REG_T1, off, ctx);
1918		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1919		emit_sd(RV_REG_T1, 0, rs, ctx);
1920		break;
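	/* Atomic read-modify-write operations are handled in emit_atomic(),
	 * presumably via RISC-V AMO instructions, with an LR/SC loop where
	 * no single AMO fits (e.g. cmpxchg).
	 */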
1921	case BPF_STX | BPF_ATOMIC | BPF_W:
1922	case BPF_STX | BPF_ATOMIC | BPF_DW:
1923		emit_atomic(rd, rs, off, imm,
1924			    BPF_SIZE(code) == BPF_DW, ctx);
1925		break;
1926
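	/* Arena stores from a register: same rebasing through RV_REG_ARENA
	 * and exception-table handling as the BPF_ST arena cases above.
	 */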
1927	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
1928	case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
1929	case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
1930	case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
1931	{
1932		int insn_len, insns_start;
1933
1934		emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx);
1935		rd = RV_REG_T2;
1936
1937		switch (BPF_SIZE(code)) {
1938		case BPF_B:
1939			if (is_12b_int(off)) {
1940				insns_start = ctx->ninsns;
1941				emit(rv_sb(rd, off, rs), ctx);
1942				insn_len = ctx->ninsns - insns_start;
1943				break;
1944			}
1945
1946			emit_imm(RV_REG_T1, off, ctx);
1947			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1948			insns_start = ctx->ninsns;
1949			emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1950			insn_len = ctx->ninsns - insns_start;
1951			break;
1952		case BPF_H:
1953			if (is_12b_int(off)) {
1954				insns_start = ctx->ninsns;
1955				emit(rv_sh(rd, off, rs), ctx);
1956				insn_len = ctx->ninsns - insns_start;
1957				break;
1958			}
1959
1960			emit_imm(RV_REG_T1, off, ctx);
1961			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1962			insns_start = ctx->ninsns;
1963			emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1964			insn_len = ctx->ninsns - insns_start;
1965			break;
1966		case BPF_W:
1967			if (is_12b_int(off)) {
1968				insns_start = ctx->ninsns;
1969				emit_sw(rd, off, rs, ctx);
1970				insn_len = ctx->ninsns - insns_start;
1971				break;
1972			}
1973
1974			emit_imm(RV_REG_T1, off, ctx);
1975			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1976			insns_start = ctx->ninsns;
1977			emit_sw(RV_REG_T1, 0, rs, ctx);
1978			insn_len = ctx->ninsns - insns_start;
1979			break;
1980		case BPF_DW:
1981			if (is_12b_int(off)) {
1982				insns_start = ctx->ninsns;
1983				emit_sd(rd, off, rs, ctx);
1984				insn_len = ctx->ninsns - insns_start;
1985				break;
1986			}
1987
1988			emit_imm(RV_REG_T1, off, ctx);
1989			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1990			insns_start = ctx->ninsns;
1991			emit_sd(RV_REG_T1, 0, rs, ctx);
1992			insn_len = ctx->ninsns - insns_start;
1993			break;
1994		}
1995
1996		ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
1997					    insn_len);
1998		if (ret)
1999			return ret;
2000
2001		break;
2002	}
2003
2004	default:
2005		pr_err("bpf-jit: unknown opcode %02x\n", code);
2006		return -EINVAL;
2007	}
2008
2009	return 0;
2010}
2011
2012void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
2013{
2014	int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
2015
2016	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);
2017	if (bpf_stack_adjust)
2018		mark_fp(ctx);
2019
2020	if (seen_reg(RV_REG_RA, ctx))
2021		stack_adjust += 8;
2022	stack_adjust += 8; /* RV_REG_FP */
2023	if (seen_reg(RV_REG_S1, ctx))
2024		stack_adjust += 8;
2025	if (seen_reg(RV_REG_S2, ctx))
2026		stack_adjust += 8;
2027	if (seen_reg(RV_REG_S3, ctx))
2028		stack_adjust += 8;
2029	if (seen_reg(RV_REG_S4, ctx))
2030		stack_adjust += 8;
2031	if (seen_reg(RV_REG_S5, ctx))
2032		stack_adjust += 8;
2033	if (seen_reg(RV_REG_S6, ctx))
2034		stack_adjust += 8;
2035	if (ctx->arena_vm_start)
2036		stack_adjust += 8;
2037
2038	stack_adjust = round_up(stack_adjust, STACK_ALIGN);
2039	stack_adjust += bpf_stack_adjust;
2040
2041	store_offset = stack_adjust - 8;
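	/* Frame layout, high to low addresses: the callee-saved registers
	 * (ra, fp, s1..s6 and, if used, the arena base) are spilled at the
	 * top of the frame; the BPF program stack occupies the lowest
	 * bpf_stack_adjust bytes, with BPF_REG_FP (s5) pointing at its top.
	 */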
2042
2043	/* emit kcfi type preamble immediately before the first insn */
2044	emit_kcfi(is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash, ctx);
2045
2046	/* nops reserved for auipc+jalr pair */
2047	for (i = 0; i < RV_FENTRY_NINSNS; i++)
2048		emit(rv_nop(), ctx);
2049
2050	/* The first instruction always sets the tail-call counter (TCC)
2051	 * register. Tail calls skip it by entering at a fixed byte offset,
2052	 * so force a 4-byte (non-compressed) encoding here.
2053	 */
2054	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
2055
2056	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
2057
2058	if (seen_reg(RV_REG_RA, ctx)) {
2059		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
2060		store_offset -= 8;
2061	}
2062	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
2063	store_offset -= 8;
2064	if (seen_reg(RV_REG_S1, ctx)) {
2065		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
2066		store_offset -= 8;
2067	}
2068	if (seen_reg(RV_REG_S2, ctx)) {
2069		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
2070		store_offset -= 8;
2071	}
2072	if (seen_reg(RV_REG_S3, ctx)) {
2073		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
2074		store_offset -= 8;
2075	}
2076	if (seen_reg(RV_REG_S4, ctx)) {
2077		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
2078		store_offset -= 8;
2079	}
2080	if (seen_reg(RV_REG_S5, ctx)) {
2081		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
2082		store_offset -= 8;
2083	}
2084	if (seen_reg(RV_REG_S6, ctx)) {
2085		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
2086		store_offset -= 8;
2087	}
2088	if (ctx->arena_vm_start) {
2089		emit_sd(RV_REG_SP, store_offset, RV_REG_ARENA, ctx);
2090		store_offset -= 8;
2091	}
2092
2093	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
2094
2095	if (bpf_stack_adjust)
2096		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
2097
2098	/* Program contains calls and tail calls, so RV_REG_TCC needs
2099	 * to be saved across calls.
2100	 */
2101	if (seen_tail_call(ctx) && seen_call(ctx))
2102		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
2103
2104	ctx->stack_size = stack_adjust;
2105
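	/* Keep the arena base in a callee-saved register (s7) so each
	 * PROBE_MEM32 access needs only a single add to rebase its pointer.
	 */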
2106	if (ctx->arena_vm_start)
2107		emit_imm(RV_REG_ARENA, ctx->arena_vm_start, ctx);
2108}
2109
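/* Restore the registers spilled by the prologue and unwind the stack frame;
 * with is_tail_call == false, __build_epilogue() ends in a normal return
 * rather than the tail-call jump.
 */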
2110void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
2111{
2112	__build_epilogue(false, ctx);
2113}
2114
2115bool bpf_jit_supports_kfunc_call(void)
2116{
2117	return true;
2118}
2119
2120bool bpf_jit_supports_ptr_xchg(void)
2121{
2122	return true;
2123}
2124
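/* Advertise arena support to the BPF core; the PROBE_MEM32 load/store cases
 * above supply the address rebasing this relies on.
 */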
2125bool bpf_jit_supports_arena(void)
2126{
2127	return true;
2128}
2129
2130bool bpf_jit_supports_percpu_insn(void)
2131{
2132	return true;
2133}
2134
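/* Helpers listed here are emitted inline by the JIT rather than called;
 * bpf_get_smp_processor_id() is presumably lowered to a direct read of the
 * CPU number relative to the thread pointer, avoiding a function call.
 */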
2135bool bpf_jit_inlines_helper_call(s32 imm)
2136{
2137	switch (imm) {
2138	case BPF_FUNC_get_smp_processor_id:
2139		return true;
2140	default:
2141		return false;
2142	}
2143}