Linux Audio

Check our new training course

Loading...
Note: File does not exist in v5.14.15.
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
   3 *
   4 * The iopt_pages is the center of the storage and motion of PFNs. Each
   5 * iopt_pages represents a logical linear array of full PFNs. The array is 0
   6 * based and has npages in it. Accessors use 'index' to refer to the entry in
   7 * this logical array, regardless of its storage location.
   8 *
   9 * PFNs are stored in a tiered scheme:
  10 *  1) iopt_pages::pinned_pfns xarray
  11 *  2) An iommu_domain
  12 *  3) The origin of the PFNs, i.e. the userspace pointer
  13 *
  14 * PFN have to be copied between all combinations of tiers, depending on the
  15 * configuration.
  16 *
  17 * When a PFN is taken out of the userspace pointer it is pinned exactly once.
  18 * The storage locations of the PFN's index are tracked in the two interval
  19 * trees. If no interval includes the index then it is not pinned.
  20 *
  21 * If access_itree includes the PFN's index then an in-kernel access has
  22 * requested the page. The PFN is stored in the xarray so other requestors can
  23 * continue to find it.
  24 *
  25 * If the domains_itree includes the PFN's index then an iommu_domain is storing
  26 * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
  27 * duplicating storage the xarray is not used if only iommu_domains are using
  28 * the PFN's index.
  29 *
  30 * As a general principle this is designed so that destroy never fails. This
  31 * means removing an iommu_domain or releasing a in-kernel access will not fail
  32 * due to insufficient memory. In practice this means some cases have to hold
  33 * PFNs in the xarray even though they are also being stored in an iommu_domain.
  34 *
  35 * While the iopt_pages can use an iommu_domain as storage, it does not have an
  36 * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
  37 * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
  38 * and reference their own slice of the PFN array, with sub page granularity.
  39 *
  40 * In this file the term 'last' indicates an inclusive and closed interval, eg
  41 * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
  42 * no PFNs.
  43 *
  44 * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
  45 * last_iova + 1 can overflow. An iopt_pages index will always be much less than
  46 * ULONG_MAX so last_index + 1 cannot overflow.
  47 */
  48#include <linux/overflow.h>
  49#include <linux/slab.h>
  50#include <linux/iommu.h>
  51#include <linux/sched/mm.h>
  52#include <linux/highmem.h>
  53#include <linux/kthread.h>
  54#include <linux/iommufd.h>
  55
  56#include "io_pagetable.h"
  57#include "double_span.h"
  58
  59#ifndef CONFIG_IOMMUFD_TEST
  60#define TEMP_MEMORY_LIMIT 65536
  61#else
  62#define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
  63#endif
  64#define BATCH_BACKUP_SIZE 32
  65
  66/*
  67 * More memory makes pin_user_pages() and the batching more efficient, but as
  68 * this is only a performance optimization don't try too hard to get it. A 64k
  69 * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
  70 * pfn_batch. Various destroy paths cannot fail and provide a small amount of
  71 * stack memory as a backup contingency. If backup_len is given this cannot
  72 * fail.
  73 */
  74static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
  75{
  76	void *res;
  77
  78	if (WARN_ON(*size == 0))
  79		return NULL;
  80
  81	if (*size < backup_len)
  82		return backup;
  83
  84	if (!backup && iommufd_should_fail())
  85		return NULL;
  86
  87	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
  88	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
  89	if (res)
  90		return res;
  91	*size = PAGE_SIZE;
  92	if (backup_len) {
  93		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
  94		if (res)
  95			return res;
  96		*size = backup_len;
  97		return backup;
  98	}
  99	return kmalloc(*size, GFP_KERNEL);
 100}
 101
 102void interval_tree_double_span_iter_update(
 103	struct interval_tree_double_span_iter *iter)
 104{
 105	unsigned long last_hole = ULONG_MAX;
 106	unsigned int i;
 107
 108	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
 109		if (interval_tree_span_iter_done(&iter->spans[i])) {
 110			iter->is_used = -1;
 111			return;
 112		}
 113
 114		if (iter->spans[i].is_hole) {
 115			last_hole = min(last_hole, iter->spans[i].last_hole);
 116			continue;
 117		}
 118
 119		iter->is_used = i + 1;
 120		iter->start_used = iter->spans[i].start_used;
 121		iter->last_used = min(iter->spans[i].last_used, last_hole);
 122		return;
 123	}
 124
 125	iter->is_used = 0;
 126	iter->start_hole = iter->spans[0].start_hole;
 127	iter->last_hole =
 128		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
 129}
 130
 131void interval_tree_double_span_iter_first(
 132	struct interval_tree_double_span_iter *iter,
 133	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
 134	unsigned long first_index, unsigned long last_index)
 135{
 136	unsigned int i;
 137
 138	iter->itrees[0] = itree1;
 139	iter->itrees[1] = itree2;
 140	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
 141		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
 142					      first_index, last_index);
 143	interval_tree_double_span_iter_update(iter);
 144}
 145
 146void interval_tree_double_span_iter_next(
 147	struct interval_tree_double_span_iter *iter)
 148{
 149	unsigned int i;
 150
 151	if (iter->is_used == -1 ||
 152	    iter->last_hole == iter->spans[0].last_index) {
 153		iter->is_used = -1;
 154		return;
 155	}
 156
 157	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
 158		interval_tree_span_iter_advance(
 159			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
 160	interval_tree_double_span_iter_update(iter);
 161}
 162
 163static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
 164{
 165	int rc;
 166
 167	rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
 168	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 169		WARN_ON(rc || pages->npinned > pages->npages);
 170}
 171
 172static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
 173{
 174	int rc;
 175
 176	rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
 177	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 178		WARN_ON(rc || pages->npinned > pages->npages);
 179}
 180
 181static void iopt_pages_err_unpin(struct iopt_pages *pages,
 182				 unsigned long start_index,
 183				 unsigned long last_index,
 184				 struct page **page_list)
 185{
 186	unsigned long npages = last_index - start_index + 1;
 187
 188	unpin_user_pages(page_list, npages);
 189	iopt_pages_sub_npinned(pages, npages);
 190}
 191
 192/*
 193 * index is the number of PAGE_SIZE units from the start of the area's
 194 * iopt_pages. If the iova is sub page-size then the area has an iova that
 195 * covers a portion of the first and last pages in the range.
 196 */
 197static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
 198					     unsigned long index)
 199{
 200	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 201		WARN_ON(index < iopt_area_index(area) ||
 202			index > iopt_area_last_index(area));
 203	index -= iopt_area_index(area);
 204	if (index == 0)
 205		return iopt_area_iova(area);
 206	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
 207}
 208
 209static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
 210						  unsigned long index)
 211{
 212	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 213		WARN_ON(index < iopt_area_index(area) ||
 214			index > iopt_area_last_index(area));
 215	if (index == iopt_area_last_index(area))
 216		return iopt_area_last_iova(area);
 217	return iopt_area_iova(area) - area->page_offset +
 218	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
 219}
 220
 221static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
 222			       size_t size)
 223{
 224	size_t ret;
 225
 226	ret = iommu_unmap(domain, iova, size);
 227	/*
 228	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
 229	 * something other than exactly as requested. This implies that the
 230	 * iommu driver may not fail unmap for reasons beyond bad agruments.
 231	 * Particularly, the iommu driver may not do a memory allocation on the
 232	 * unmap path.
 233	 */
 234	WARN_ON(ret != size);
 235}
 236
 237static void iopt_area_unmap_domain_range(struct iopt_area *area,
 238					 struct iommu_domain *domain,
 239					 unsigned long start_index,
 240					 unsigned long last_index)
 241{
 242	unsigned long start_iova = iopt_area_index_to_iova(area, start_index);
 243
 244	iommu_unmap_nofail(domain, start_iova,
 245			   iopt_area_index_to_iova_last(area, last_index) -
 246				   start_iova + 1);
 247}
 248
 249static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
 250						     unsigned long index)
 251{
 252	struct interval_tree_node *node;
 253
 254	node = interval_tree_iter_first(&pages->domains_itree, index, index);
 255	if (!node)
 256		return NULL;
 257	return container_of(node, struct iopt_area, pages_node);
 258}
 259
 260/*
 261 * A simple datastructure to hold a vector of PFNs, optimized for contiguous
 262 * PFNs. This is used as a temporary holding memory for shuttling pfns from one
 263 * place to another. Generally everything is made more efficient if operations
 264 * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
 265 * better cache locality, etc
 266 */
 267struct pfn_batch {
 268	unsigned long *pfns;
 269	u32 *npfns;
 270	unsigned int array_size;
 271	unsigned int end;
 272	unsigned int total_pfns;
 273};
 274
 275static void batch_clear(struct pfn_batch *batch)
 276{
 277	batch->total_pfns = 0;
 278	batch->end = 0;
 279	batch->pfns[0] = 0;
 280	batch->npfns[0] = 0;
 281}
 282
 283/*
 284 * Carry means we carry a portion of the final hugepage over to the front of the
 285 * batch
 286 */
 287static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
 288{
 289	if (!keep_pfns)
 290		return batch_clear(batch);
 291
 292	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 293		WARN_ON(!batch->end ||
 294			batch->npfns[batch->end - 1] < keep_pfns);
 295
 296	batch->total_pfns = keep_pfns;
 297	batch->npfns[0] = keep_pfns;
 298	batch->pfns[0] = batch->pfns[batch->end - 1] +
 299			 (batch->npfns[batch->end - 1] - keep_pfns);
 300	batch->end = 0;
 301}
 302
 303static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
 304{
 305	if (!batch->total_pfns)
 306		return;
 307	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 308		WARN_ON(batch->total_pfns != batch->npfns[0]);
 309	skip_pfns = min(batch->total_pfns, skip_pfns);
 310	batch->pfns[0] += skip_pfns;
 311	batch->npfns[0] -= skip_pfns;
 312	batch->total_pfns -= skip_pfns;
 313}
 314
 315static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
 316			size_t backup_len)
 317{
 318	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
 319	size_t size = max_pages * elmsz;
 320
 321	batch->pfns = temp_kmalloc(&size, backup, backup_len);
 322	if (!batch->pfns)
 323		return -ENOMEM;
 324	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
 325		return -EINVAL;
 326	batch->array_size = size / elmsz;
 327	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
 328	batch_clear(batch);
 329	return 0;
 330}
 331
 332static int batch_init(struct pfn_batch *batch, size_t max_pages)
 333{
 334	return __batch_init(batch, max_pages, NULL, 0);
 335}
 336
 337static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
 338			      void *backup, size_t backup_len)
 339{
 340	__batch_init(batch, max_pages, backup, backup_len);
 341}
 342
 343static void batch_destroy(struct pfn_batch *batch, void *backup)
 344{
 345	if (batch->pfns != backup)
 346		kfree(batch->pfns);
 347}
 348
 349/* true if the pfn was added, false otherwise */
 350static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
 351{
 352	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
 353
 354	if (batch->end &&
 355	    pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
 356	    batch->npfns[batch->end - 1] != MAX_NPFNS) {
 357		batch->npfns[batch->end - 1]++;
 358		batch->total_pfns++;
 359		return true;
 360	}
 361	if (batch->end == batch->array_size)
 362		return false;
 363	batch->total_pfns++;
 364	batch->pfns[batch->end] = pfn;
 365	batch->npfns[batch->end] = 1;
 366	batch->end++;
 367	return true;
 368}
 369
 370/*
 371 * Fill the batch with pfns from the domain. When the batch is full, or it
 372 * reaches last_index, the function will return. The caller should use
 373 * batch->total_pfns to determine the starting point for the next iteration.
 374 */
 375static void batch_from_domain(struct pfn_batch *batch,
 376			      struct iommu_domain *domain,
 377			      struct iopt_area *area, unsigned long start_index,
 378			      unsigned long last_index)
 379{
 380	unsigned int page_offset = 0;
 381	unsigned long iova;
 382	phys_addr_t phys;
 383
 384	iova = iopt_area_index_to_iova(area, start_index);
 385	if (start_index == iopt_area_index(area))
 386		page_offset = area->page_offset;
 387	while (start_index <= last_index) {
 388		/*
 389		 * This is pretty slow, it would be nice to get the page size
 390		 * back from the driver, or have the driver directly fill the
 391		 * batch.
 392		 */
 393		phys = iommu_iova_to_phys(domain, iova) - page_offset;
 394		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
 395			return;
 396		iova += PAGE_SIZE - page_offset;
 397		page_offset = 0;
 398		start_index++;
 399	}
 400}
 401
 402static struct page **raw_pages_from_domain(struct iommu_domain *domain,
 403					   struct iopt_area *area,
 404					   unsigned long start_index,
 405					   unsigned long last_index,
 406					   struct page **out_pages)
 407{
 408	unsigned int page_offset = 0;
 409	unsigned long iova;
 410	phys_addr_t phys;
 411
 412	iova = iopt_area_index_to_iova(area, start_index);
 413	if (start_index == iopt_area_index(area))
 414		page_offset = area->page_offset;
 415	while (start_index <= last_index) {
 416		phys = iommu_iova_to_phys(domain, iova) - page_offset;
 417		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
 418		iova += PAGE_SIZE - page_offset;
 419		page_offset = 0;
 420		start_index++;
 421	}
 422	return out_pages;
 423}
 424
 425/* Continues reading a domain until we reach a discontinuity in the pfns. */
 426static void batch_from_domain_continue(struct pfn_batch *batch,
 427				       struct iommu_domain *domain,
 428				       struct iopt_area *area,
 429				       unsigned long start_index,
 430				       unsigned long last_index)
 431{
 432	unsigned int array_size = batch->array_size;
 433
 434	batch->array_size = batch->end;
 435	batch_from_domain(batch, domain, area, start_index, last_index);
 436	batch->array_size = array_size;
 437}
 438
 439/*
 440 * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
 441 * mode permits splitting a mapped area up, and then one of the splits is
 442 * unmapped. Doing this normally would cause us to violate our invariant of
 443 * pairing map/unmap. Thus, to support old VFIO compatibility disable support
 444 * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
 445 * PAGE_SIZE units, not larger or smaller.
 446 */
 447static int batch_iommu_map_small(struct iommu_domain *domain,
 448				 unsigned long iova, phys_addr_t paddr,
 449				 size_t size, int prot)
 450{
 451	unsigned long start_iova = iova;
 452	int rc;
 453
 454	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
 455		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
 456			size % PAGE_SIZE);
 457
 458	while (size) {
 459		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot);
 460		if (rc)
 461			goto err_unmap;
 462		iova += PAGE_SIZE;
 463		paddr += PAGE_SIZE;
 464		size -= PAGE_SIZE;
 465	}
 466	return 0;
 467
 468err_unmap:
 469	if (start_iova != iova)
 470		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
 471	return rc;
 472}
 473
 474static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
 475			   struct iopt_area *area, unsigned long start_index)
 476{
 477	bool disable_large_pages = area->iopt->disable_large_pages;
 478	unsigned long last_iova = iopt_area_last_iova(area);
 479	unsigned int page_offset = 0;
 480	unsigned long start_iova;
 481	unsigned long next_iova;
 482	unsigned int cur = 0;
 483	unsigned long iova;
 484	int rc;
 485
 486	/* The first index might be a partial page */
 487	if (start_index == iopt_area_index(area))
 488		page_offset = area->page_offset;
 489	next_iova = iova = start_iova =
 490		iopt_area_index_to_iova(area, start_index);
 491	while (cur < batch->end) {
 492		next_iova = min(last_iova + 1,
 493				next_iova + batch->npfns[cur] * PAGE_SIZE -
 494					page_offset);
 495		if (disable_large_pages)
 496			rc = batch_iommu_map_small(
 497				domain, iova,
 498				PFN_PHYS(batch->pfns[cur]) + page_offset,
 499				next_iova - iova, area->iommu_prot);
 500		else
 501			rc = iommu_map(domain, iova,
 502				       PFN_PHYS(batch->pfns[cur]) + page_offset,
 503				       next_iova - iova, area->iommu_prot);
 504		if (rc)
 505			goto err_unmap;
 506		iova = next_iova;
 507		page_offset = 0;
 508		cur++;
 509	}
 510	return 0;
 511err_unmap:
 512	if (start_iova != iova)
 513		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
 514	return rc;
 515}
 516
 517static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
 518			      unsigned long start_index,
 519			      unsigned long last_index)
 520{
 521	XA_STATE(xas, xa, start_index);
 522	void *entry;
 523
 524	rcu_read_lock();
 525	while (true) {
 526		entry = xas_next(&xas);
 527		if (xas_retry(&xas, entry))
 528			continue;
 529		WARN_ON(!xa_is_value(entry));
 530		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
 531		    start_index == last_index)
 532			break;
 533		start_index++;
 534	}
 535	rcu_read_unlock();
 536}
 537
 538static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
 539				    unsigned long start_index,
 540				    unsigned long last_index)
 541{
 542	XA_STATE(xas, xa, start_index);
 543	void *entry;
 544
 545	xas_lock(&xas);
 546	while (true) {
 547		entry = xas_next(&xas);
 548		if (xas_retry(&xas, entry))
 549			continue;
 550		WARN_ON(!xa_is_value(entry));
 551		if (!batch_add_pfn(batch, xa_to_value(entry)))
 552			break;
 553		xas_store(&xas, NULL);
 554		if (start_index == last_index)
 555			break;
 556		start_index++;
 557	}
 558	xas_unlock(&xas);
 559}
 560
 561static void clear_xarray(struct xarray *xa, unsigned long start_index,
 562			 unsigned long last_index)
 563{
 564	XA_STATE(xas, xa, start_index);
 565	void *entry;
 566
 567	xas_lock(&xas);
 568	xas_for_each(&xas, entry, last_index)
 569		xas_store(&xas, NULL);
 570	xas_unlock(&xas);
 571}
 572
 573static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
 574			   unsigned long last_index, struct page **pages)
 575{
 576	struct page **end_pages = pages + (last_index - start_index) + 1;
 577	struct page **half_pages = pages + (end_pages - pages) / 2;
 578	XA_STATE(xas, xa, start_index);
 579
 580	do {
 581		void *old;
 582
 583		xas_lock(&xas);
 584		while (pages != end_pages) {
 585			/* xarray does not participate in fault injection */
 586			if (pages == half_pages && iommufd_should_fail()) {
 587				xas_set_err(&xas, -EINVAL);
 588				xas_unlock(&xas);
 589				/* aka xas_destroy() */
 590				xas_nomem(&xas, GFP_KERNEL);
 591				goto err_clear;
 592			}
 593
 594			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
 595			if (xas_error(&xas))
 596				break;
 597			WARN_ON(old);
 598			pages++;
 599			xas_next(&xas);
 600		}
 601		xas_unlock(&xas);
 602	} while (xas_nomem(&xas, GFP_KERNEL));
 603
 604err_clear:
 605	if (xas_error(&xas)) {
 606		if (xas.xa_index != start_index)
 607			clear_xarray(xa, start_index, xas.xa_index - 1);
 608		return xas_error(&xas);
 609	}
 610	return 0;
 611}
 612
 613static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
 614			     size_t npages)
 615{
 616	struct page **end = pages + npages;
 617
 618	for (; pages != end; pages++)
 619		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
 620			break;
 621}
 622
 623static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
 624			unsigned int first_page_off, size_t npages)
 625{
 626	unsigned int cur = 0;
 627
 628	while (first_page_off) {
 629		if (batch->npfns[cur] > first_page_off)
 630			break;
 631		first_page_off -= batch->npfns[cur];
 632		cur++;
 633	}
 634
 635	while (npages) {
 636		size_t to_unpin = min_t(size_t, npages,
 637					batch->npfns[cur] - first_page_off);
 638
 639		unpin_user_page_range_dirty_lock(
 640			pfn_to_page(batch->pfns[cur] + first_page_off),
 641			to_unpin, pages->writable);
 642		iopt_pages_sub_npinned(pages, to_unpin);
 643		cur++;
 644		first_page_off = 0;
 645		npages -= to_unpin;
 646	}
 647}
 648
 649static void copy_data_page(struct page *page, void *data, unsigned long offset,
 650			   size_t length, unsigned int flags)
 651{
 652	void *mem;
 653
 654	mem = kmap_local_page(page);
 655	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
 656		memcpy(mem + offset, data, length);
 657		set_page_dirty_lock(page);
 658	} else {
 659		memcpy(data, mem + offset, length);
 660	}
 661	kunmap_local(mem);
 662}
 663
 664static unsigned long batch_rw(struct pfn_batch *batch, void *data,
 665			      unsigned long offset, unsigned long length,
 666			      unsigned int flags)
 667{
 668	unsigned long copied = 0;
 669	unsigned int npage = 0;
 670	unsigned int cur = 0;
 671
 672	while (cur < batch->end) {
 673		unsigned long bytes = min(length, PAGE_SIZE - offset);
 674
 675		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
 676			       offset, bytes, flags);
 677		offset = 0;
 678		length -= bytes;
 679		data += bytes;
 680		copied += bytes;
 681		npage++;
 682		if (npage == batch->npfns[cur]) {
 683			npage = 0;
 684			cur++;
 685		}
 686		if (!length)
 687			break;
 688	}
 689	return copied;
 690}
 691
 692/* pfn_reader_user is just the pin_user_pages() path */
 693struct pfn_reader_user {
 694	struct page **upages;
 695	size_t upages_len;
 696	unsigned long upages_start;
 697	unsigned long upages_end;
 698	unsigned int gup_flags;
 699	/*
 700	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
 701	 * neither
 702	 */
 703	int locked;
 704};
 705
 706static void pfn_reader_user_init(struct pfn_reader_user *user,
 707				 struct iopt_pages *pages)
 708{
 709	user->upages = NULL;
 710	user->upages_start = 0;
 711	user->upages_end = 0;
 712	user->locked = -1;
 713
 714	user->gup_flags = FOLL_LONGTERM;
 715	if (pages->writable)
 716		user->gup_flags |= FOLL_WRITE;
 717}
 718
 719static void pfn_reader_user_destroy(struct pfn_reader_user *user,
 720				    struct iopt_pages *pages)
 721{
 722	if (user->locked != -1) {
 723		if (user->locked)
 724			mmap_read_unlock(pages->source_mm);
 725		if (pages->source_mm != current->mm)
 726			mmput(pages->source_mm);
 727		user->locked = -1;
 728	}
 729
 730	kfree(user->upages);
 731	user->upages = NULL;
 732}
 733
 734static int pfn_reader_user_pin(struct pfn_reader_user *user,
 735			       struct iopt_pages *pages,
 736			       unsigned long start_index,
 737			       unsigned long last_index)
 738{
 739	bool remote_mm = pages->source_mm != current->mm;
 740	unsigned long npages;
 741	uintptr_t uptr;
 742	long rc;
 743
 744	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
 745	    WARN_ON(last_index < start_index))
 746		return -EINVAL;
 747
 748	if (!user->upages) {
 749		/* All undone in pfn_reader_destroy() */
 750		user->upages_len =
 751			(last_index - start_index + 1) * sizeof(*user->upages);
 752		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
 753		if (!user->upages)
 754			return -ENOMEM;
 755	}
 756
 757	if (user->locked == -1) {
 758		/*
 759		 * The majority of usages will run the map task within the mm
 760		 * providing the pages, so we can optimize into
 761		 * get_user_pages_fast()
 762		 */
 763		if (remote_mm) {
 764			if (!mmget_not_zero(pages->source_mm))
 765				return -EFAULT;
 766		}
 767		user->locked = 0;
 768	}
 769
 770	npages = min_t(unsigned long, last_index - start_index + 1,
 771		       user->upages_len / sizeof(*user->upages));
 772
 773
 774	if (iommufd_should_fail())
 775		return -EFAULT;
 776
 777	uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
 778	if (!remote_mm)
 779		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
 780					 user->upages);
 781	else {
 782		if (!user->locked) {
 783			mmap_read_lock(pages->source_mm);
 784			user->locked = 1;
 785		}
 786		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
 787					   user->gup_flags, user->upages, NULL,
 788					   &user->locked);
 789	}
 790	if (rc <= 0) {
 791		if (WARN_ON(!rc))
 792			return -EFAULT;
 793		return rc;
 794	}
 795	iopt_pages_add_npinned(pages, rc);
 796	user->upages_start = start_index;
 797	user->upages_end = start_index + rc;
 798	return 0;
 799}
 800
 801/* This is the "modern" and faster accounting method used by io_uring */
 802static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
 803{
 804	unsigned long lock_limit;
 805	unsigned long cur_pages;
 806	unsigned long new_pages;
 807
 808	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
 809		     PAGE_SHIFT;
 810	do {
 811		cur_pages = atomic_long_read(&pages->source_user->locked_vm);
 812		new_pages = cur_pages + npages;
 813		if (new_pages > lock_limit)
 814			return -ENOMEM;
 815	} while (atomic_long_cmpxchg(&pages->source_user->locked_vm, cur_pages,
 816				     new_pages) != cur_pages);
 817	return 0;
 818}
 819
 820static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
 821{
 822	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
 823		return;
 824	atomic_long_sub(npages, &pages->source_user->locked_vm);
 825}
 826
 827/* This is the accounting method used for compatibility with VFIO */
 828static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
 829			       bool inc, struct pfn_reader_user *user)
 830{
 831	bool do_put = false;
 832	int rc;
 833
 834	if (user && user->locked) {
 835		mmap_read_unlock(pages->source_mm);
 836		user->locked = 0;
 837		/* If we had the lock then we also have a get */
 838	} else if ((!user || !user->upages) &&
 839		   pages->source_mm != current->mm) {
 840		if (!mmget_not_zero(pages->source_mm))
 841			return -EINVAL;
 842		do_put = true;
 843	}
 844
 845	mmap_write_lock(pages->source_mm);
 846	rc = __account_locked_vm(pages->source_mm, npages, inc,
 847				 pages->source_task, false);
 848	mmap_write_unlock(pages->source_mm);
 849
 850	if (do_put)
 851		mmput(pages->source_mm);
 852	return rc;
 853}
 854
 855static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
 856			    bool inc, struct pfn_reader_user *user)
 857{
 858	int rc = 0;
 859
 860	switch (pages->account_mode) {
 861	case IOPT_PAGES_ACCOUNT_NONE:
 862		break;
 863	case IOPT_PAGES_ACCOUNT_USER:
 864		if (inc)
 865			rc = incr_user_locked_vm(pages, npages);
 866		else
 867			decr_user_locked_vm(pages, npages);
 868		break;
 869	case IOPT_PAGES_ACCOUNT_MM:
 870		rc = update_mm_locked_vm(pages, npages, inc, user);
 871		break;
 872	}
 873	if (rc)
 874		return rc;
 875
 876	pages->last_npinned = pages->npinned;
 877	if (inc)
 878		atomic64_add(npages, &pages->source_mm->pinned_vm);
 879	else
 880		atomic64_sub(npages, &pages->source_mm->pinned_vm);
 881	return 0;
 882}
 883
 884static void update_unpinned(struct iopt_pages *pages)
 885{
 886	if (WARN_ON(pages->npinned > pages->last_npinned))
 887		return;
 888	if (pages->npinned == pages->last_npinned)
 889		return;
 890	do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
 891			 NULL);
 892}
 893
 894/*
 895 * Changes in the number of pages pinned is done after the pages have been read
 896 * and processed. If the user lacked the limit then the error unwind will unpin
 897 * everything that was just pinned. This is because it is expensive to calculate
 898 * how many pages we have already pinned within a range to generate an accurate
 899 * prediction in advance of doing the work to actually pin them.
 900 */
 901static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
 902					 struct iopt_pages *pages)
 903{
 904	unsigned long npages;
 905	bool inc;
 906
 907	lockdep_assert_held(&pages->mutex);
 908
 909	if (pages->npinned == pages->last_npinned)
 910		return 0;
 911
 912	if (pages->npinned < pages->last_npinned) {
 913		npages = pages->last_npinned - pages->npinned;
 914		inc = false;
 915	} else {
 916		if (iommufd_should_fail())
 917			return -ENOMEM;
 918		npages = pages->npinned - pages->last_npinned;
 919		inc = true;
 920	}
 921	return do_update_pinned(pages, npages, inc, user);
 922}
 923
 924/*
 925 * PFNs are stored in three places, in order of preference:
 926 * - The iopt_pages xarray. This is only populated if there is a
 927 *   iopt_pages_access
 928 * - The iommu_domain under an area
 929 * - The original PFN source, ie pages->source_mm
 930 *
 931 * This iterator reads the pfns optimizing to load according to the
 932 * above order.
 933 */
 934struct pfn_reader {
 935	struct iopt_pages *pages;
 936	struct interval_tree_double_span_iter span;
 937	struct pfn_batch batch;
 938	unsigned long batch_start_index;
 939	unsigned long batch_end_index;
 940	unsigned long last_index;
 941
 942	struct pfn_reader_user user;
 943};
 944
 945static int pfn_reader_update_pinned(struct pfn_reader *pfns)
 946{
 947	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
 948}
 949
 950/*
 951 * The batch can contain a mixture of pages that are still in use and pages that
 952 * need to be unpinned. Unpin only pages that are not held anywhere else.
 953 */
 954static void pfn_reader_unpin(struct pfn_reader *pfns)
 955{
 956	unsigned long last = pfns->batch_end_index - 1;
 957	unsigned long start = pfns->batch_start_index;
 958	struct interval_tree_double_span_iter span;
 959	struct iopt_pages *pages = pfns->pages;
 960
 961	lockdep_assert_held(&pages->mutex);
 962
 963	interval_tree_for_each_double_span(&span, &pages->access_itree,
 964					   &pages->domains_itree, start, last) {
 965		if (span.is_used)
 966			continue;
 967
 968		batch_unpin(&pfns->batch, pages, span.start_hole - start,
 969			    span.last_hole - span.start_hole + 1);
 970	}
 971}
 972
 973/* Process a single span to load it from the proper storage */
 974static int pfn_reader_fill_span(struct pfn_reader *pfns)
 975{
 976	struct interval_tree_double_span_iter *span = &pfns->span;
 977	unsigned long start_index = pfns->batch_end_index;
 978	struct iopt_area *area;
 979	int rc;
 980
 981	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
 982	    WARN_ON(span->last_used < start_index))
 983		return -EINVAL;
 984
 985	if (span->is_used == 1) {
 986		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
 987				  start_index, span->last_used);
 988		return 0;
 989	}
 990
 991	if (span->is_used == 2) {
 992		/*
 993		 * Pull as many pages from the first domain we find in the
 994		 * target span. If it is too small then we will be called again
 995		 * and we'll find another area.
 996		 */
 997		area = iopt_pages_find_domain_area(pfns->pages, start_index);
 998		if (WARN_ON(!area))
 999			return -EINVAL;
1000
1001		/* The storage_domain cannot change without the pages mutex */
1002		batch_from_domain(
1003			&pfns->batch, area->storage_domain, area, start_index,
1004			min(iopt_area_last_index(area), span->last_used));
1005		return 0;
1006	}
1007
1008	if (start_index >= pfns->user.upages_end) {
1009		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1010					 span->last_hole);
1011		if (rc)
1012			return rc;
1013	}
1014
1015	batch_from_pages(&pfns->batch,
1016			 pfns->user.upages +
1017				 (start_index - pfns->user.upages_start),
1018			 pfns->user.upages_end - start_index);
1019	return 0;
1020}
1021
1022static bool pfn_reader_done(struct pfn_reader *pfns)
1023{
1024	return pfns->batch_start_index == pfns->last_index + 1;
1025}
1026
1027static int pfn_reader_next(struct pfn_reader *pfns)
1028{
1029	int rc;
1030
1031	batch_clear(&pfns->batch);
1032	pfns->batch_start_index = pfns->batch_end_index;
1033
1034	while (pfns->batch_end_index != pfns->last_index + 1) {
1035		unsigned int npfns = pfns->batch.total_pfns;
1036
1037		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1038		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1039			return -EINVAL;
1040
1041		rc = pfn_reader_fill_span(pfns);
1042		if (rc)
1043			return rc;
1044
1045		if (WARN_ON(!pfns->batch.total_pfns))
1046			return -EINVAL;
1047
1048		pfns->batch_end_index =
1049			pfns->batch_start_index + pfns->batch.total_pfns;
1050		if (pfns->batch_end_index == pfns->span.last_used + 1)
1051			interval_tree_double_span_iter_next(&pfns->span);
1052
1053		/* Batch is full */
1054		if (npfns == pfns->batch.total_pfns)
1055			return 0;
1056	}
1057	return 0;
1058}
1059
1060static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1061			   unsigned long start_index, unsigned long last_index)
1062{
1063	int rc;
1064
1065	lockdep_assert_held(&pages->mutex);
1066
1067	pfns->pages = pages;
1068	pfns->batch_start_index = start_index;
1069	pfns->batch_end_index = start_index;
1070	pfns->last_index = last_index;
1071	pfn_reader_user_init(&pfns->user, pages);
1072	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1073	if (rc)
1074		return rc;
1075	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1076					     &pages->domains_itree, start_index,
1077					     last_index);
1078	return 0;
1079}
1080
1081/*
1082 * There are many assertions regarding the state of pages->npinned vs
1083 * pages->last_pinned, for instance something like unmapping a domain must only
1084 * decrement the npinned, and pfn_reader_destroy() must be called only after all
1085 * the pins are updated. This is fine for success flows, but error flows
1086 * sometimes need to release the pins held inside the pfn_reader before going on
1087 * to complete unmapping and releasing pins held in domains.
1088 */
1089static void pfn_reader_release_pins(struct pfn_reader *pfns)
1090{
1091	struct iopt_pages *pages = pfns->pages;
1092
1093	if (pfns->user.upages_end > pfns->batch_end_index) {
1094		size_t npages = pfns->user.upages_end - pfns->batch_end_index;
1095
1096		/* Any pages not transferred to the batch are just unpinned */
1097		unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
1098						      pfns->user.upages_start),
1099				 npages);
1100		iopt_pages_sub_npinned(pages, npages);
1101		pfns->user.upages_end = pfns->batch_end_index;
1102	}
1103	if (pfns->batch_start_index != pfns->batch_end_index) {
1104		pfn_reader_unpin(pfns);
1105		pfns->batch_start_index = pfns->batch_end_index;
1106	}
1107}
1108
1109static void pfn_reader_destroy(struct pfn_reader *pfns)
1110{
1111	struct iopt_pages *pages = pfns->pages;
1112
1113	pfn_reader_release_pins(pfns);
1114	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1115	batch_destroy(&pfns->batch, NULL);
1116	WARN_ON(pages->last_npinned != pages->npinned);
1117}
1118
1119static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1120			    unsigned long start_index, unsigned long last_index)
1121{
1122	int rc;
1123
1124	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1125	    WARN_ON(last_index < start_index))
1126		return -EINVAL;
1127
1128	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1129	if (rc)
1130		return rc;
1131	rc = pfn_reader_next(pfns);
1132	if (rc) {
1133		pfn_reader_destroy(pfns);
1134		return rc;
1135	}
1136	return 0;
1137}
1138
1139struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
1140				    bool writable)
1141{
1142	struct iopt_pages *pages;
1143
1144	/*
1145	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1146	 * below from overflow
1147	 */
1148	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1149		return ERR_PTR(-EINVAL);
1150
1151	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1152	if (!pages)
1153		return ERR_PTR(-ENOMEM);
1154
1155	kref_init(&pages->kref);
1156	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1157	mutex_init(&pages->mutex);
1158	pages->source_mm = current->mm;
1159	mmgrab(pages->source_mm);
1160	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1161	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
1162	pages->access_itree = RB_ROOT_CACHED;
1163	pages->domains_itree = RB_ROOT_CACHED;
1164	pages->writable = writable;
1165	if (capable(CAP_IPC_LOCK))
1166		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1167	else
1168		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1169	pages->source_task = current->group_leader;
1170	get_task_struct(current->group_leader);
1171	pages->source_user = get_uid(current_user());
1172	return pages;
1173}
1174
1175void iopt_release_pages(struct kref *kref)
1176{
1177	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1178
1179	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1180	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1181	WARN_ON(pages->npinned);
1182	WARN_ON(!xa_empty(&pages->pinned_pfns));
1183	mmdrop(pages->source_mm);
1184	mutex_destroy(&pages->mutex);
1185	put_task_struct(pages->source_task);
1186	free_uid(pages->source_user);
1187	kfree(pages);
1188}
1189
1190static void
1191iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1192		       struct iopt_pages *pages, struct iommu_domain *domain,
1193		       unsigned long start_index, unsigned long last_index,
1194		       unsigned long *unmapped_end_index,
1195		       unsigned long real_last_index)
1196{
1197	while (start_index <= last_index) {
1198		unsigned long batch_last_index;
1199
1200		if (*unmapped_end_index <= last_index) {
1201			unsigned long start =
1202				max(start_index, *unmapped_end_index);
1203
1204			batch_from_domain(batch, domain, area, start,
1205					  last_index);
1206			batch_last_index = start + batch->total_pfns - 1;
1207		} else {
1208			batch_last_index = last_index;
1209		}
1210
1211		/*
1212		 * unmaps must always 'cut' at a place where the pfns are not
1213		 * contiguous to pair with the maps that always install
1214		 * contiguous pages. Thus, if we have to stop unpinning in the
1215		 * middle of the domains we need to keep reading pfns until we
1216		 * find a cut point to do the unmap. The pfns we read are
1217		 * carried over and either skipped or integrated into the next
1218		 * batch.
1219		 */
1220		if (batch_last_index == last_index &&
1221		    last_index != real_last_index)
1222			batch_from_domain_continue(batch, domain, area,
1223						   last_index + 1,
1224						   real_last_index);
1225
1226		if (*unmapped_end_index <= batch_last_index) {
1227			iopt_area_unmap_domain_range(
1228				area, domain, *unmapped_end_index,
1229				start_index + batch->total_pfns - 1);
1230			*unmapped_end_index = start_index + batch->total_pfns;
1231		}
1232
1233		/* unpin must follow unmap */
1234		batch_unpin(batch, pages, 0,
1235			    batch_last_index - start_index + 1);
1236		start_index = batch_last_index + 1;
1237
1238		batch_clear_carry(batch,
1239				  *unmapped_end_index - batch_last_index - 1);
1240	}
1241}
1242
1243static void __iopt_area_unfill_domain(struct iopt_area *area,
1244				      struct iopt_pages *pages,
1245				      struct iommu_domain *domain,
1246				      unsigned long last_index)
1247{
1248	struct interval_tree_double_span_iter span;
1249	unsigned long start_index = iopt_area_index(area);
1250	unsigned long unmapped_end_index = start_index;
1251	u64 backup[BATCH_BACKUP_SIZE];
1252	struct pfn_batch batch;
1253
1254	lockdep_assert_held(&pages->mutex);
1255
1256	/*
1257	 * For security we must not unpin something that is still DMA mapped,
1258	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1259	 * This creates a complexity where we need to skip over unpinning pages
1260	 * held in the xarray, but continue to unmap from the domain.
1261	 *
1262	 * The domain unmap cannot stop in the middle of a contiguous range of
1263	 * PFNs. To solve this problem the unpinning step will read ahead to the
1264	 * end of any contiguous span, unmap that whole span, and then only
1265	 * unpin the leading part that does not have any accesses. The residual
1266	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1267	 * batch as they are moved to the front of the PFN list and continue on
1268	 * to the next iteration(s).
1269	 */
1270	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1271	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1272					   &pages->access_itree, start_index,
1273					   last_index) {
1274		if (span.is_used) {
1275			batch_skip_carry(&batch,
1276					 span.last_used - span.start_used + 1);
1277			continue;
1278		}
1279		iopt_area_unpin_domain(&batch, area, pages, domain,
1280				       span.start_hole, span.last_hole,
1281				       &unmapped_end_index, last_index);
1282	}
1283	/*
1284	 * If the range ends in a access then we do the residual unmap without
1285	 * any unpins.
1286	 */
1287	if (unmapped_end_index != last_index + 1)
1288		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1289					     last_index);
1290	WARN_ON(batch.total_pfns);
1291	batch_destroy(&batch, backup);
1292	update_unpinned(pages);
1293}
1294
1295static void iopt_area_unfill_partial_domain(struct iopt_area *area,
1296					    struct iopt_pages *pages,
1297					    struct iommu_domain *domain,
1298					    unsigned long end_index)
1299{
1300	if (end_index != iopt_area_index(area))
1301		__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
1302}
1303
1304/**
1305 * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1306 * @area: The IOVA range to unmap
1307 * @domain: The domain to unmap
1308 *
1309 * The caller must know that unpinning is not required, usually because there
1310 * are other domains in the iopt.
1311 */
1312void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1313{
1314	iommu_unmap_nofail(domain, iopt_area_iova(area),
1315			   iopt_area_length(area));
1316}
1317
1318/**
1319 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
1320 * @area: IOVA area to use
1321 * @pages: page supplier for the area (area->pages is NULL)
1322 * @domain: Domain to unmap from
1323 *
1324 * The domain should be removed from the domains_itree before calling. The
1325 * domain will always be unmapped, but the PFNs may not be unpinned if there are
1326 * still accesses.
1327 */
1328void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
1329			     struct iommu_domain *domain)
1330{
1331	__iopt_area_unfill_domain(area, pages, domain,
1332				  iopt_area_last_index(area));
1333}
1334
1335/**
1336 * iopt_area_fill_domain() - Map PFNs from the area into a domain
1337 * @area: IOVA area to use
1338 * @domain: Domain to load PFNs into
1339 *
1340 * Read the pfns from the area's underlying iopt_pages and map them into the
1341 * given domain. Called when attaching a new domain to an io_pagetable.
1342 */
1343int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1344{
1345	unsigned long done_end_index;
1346	struct pfn_reader pfns;
1347	int rc;
1348
1349	lockdep_assert_held(&area->pages->mutex);
1350
1351	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1352			      iopt_area_last_index(area));
1353	if (rc)
1354		return rc;
1355
1356	while (!pfn_reader_done(&pfns)) {
1357		done_end_index = pfns.batch_start_index;
1358		rc = batch_to_domain(&pfns.batch, domain, area,
1359				     pfns.batch_start_index);
1360		if (rc)
1361			goto out_unmap;
1362		done_end_index = pfns.batch_end_index;
1363
1364		rc = pfn_reader_next(&pfns);
1365		if (rc)
1366			goto out_unmap;
1367	}
1368
1369	rc = pfn_reader_update_pinned(&pfns);
1370	if (rc)
1371		goto out_unmap;
1372	goto out_destroy;
1373
1374out_unmap:
1375	pfn_reader_release_pins(&pfns);
1376	iopt_area_unfill_partial_domain(area, area->pages, domain,
1377					done_end_index);
1378out_destroy:
1379	pfn_reader_destroy(&pfns);
1380	return rc;
1381}
1382
1383/**
1384 * iopt_area_fill_domains() - Install PFNs into the area's domains
1385 * @area: The area to act on
1386 * @pages: The pages associated with the area (area->pages is NULL)
1387 *
1388 * Called during area creation. The area is freshly created and not inserted in
1389 * the domains_itree yet. PFNs are read and loaded into every domain held in the
1390 * area's io_pagetable and the area is installed in the domains_itree.
1391 *
1392 * On failure all domains are left unchanged.
1393 */
1394int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1395{
1396	unsigned long done_first_end_index;
1397	unsigned long done_all_end_index;
1398	struct iommu_domain *domain;
1399	unsigned long unmap_index;
1400	struct pfn_reader pfns;
1401	unsigned long index;
1402	int rc;
1403
1404	lockdep_assert_held(&area->iopt->domains_rwsem);
1405
1406	if (xa_empty(&area->iopt->domains))
1407		return 0;
1408
1409	mutex_lock(&pages->mutex);
1410	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1411			      iopt_area_last_index(area));
1412	if (rc)
1413		goto out_unlock;
1414
1415	while (!pfn_reader_done(&pfns)) {
1416		done_first_end_index = pfns.batch_end_index;
1417		done_all_end_index = pfns.batch_start_index;
1418		xa_for_each(&area->iopt->domains, index, domain) {
1419			rc = batch_to_domain(&pfns.batch, domain, area,
1420					     pfns.batch_start_index);
1421			if (rc)
1422				goto out_unmap;
1423		}
1424		done_all_end_index = done_first_end_index;
1425
1426		rc = pfn_reader_next(&pfns);
1427		if (rc)
1428			goto out_unmap;
1429	}
1430	rc = pfn_reader_update_pinned(&pfns);
1431	if (rc)
1432		goto out_unmap;
1433
1434	area->storage_domain = xa_load(&area->iopt->domains, 0);
1435	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1436	goto out_destroy;
1437
1438out_unmap:
1439	pfn_reader_release_pins(&pfns);
1440	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1441		unsigned long end_index;
1442
1443		if (unmap_index < index)
1444			end_index = done_first_end_index;
1445		else
1446			end_index = done_all_end_index;
1447
1448		/*
1449		 * The area is not yet part of the domains_itree so we have to
1450		 * manage the unpinning specially. The last domain does the
1451		 * unpin, every other domain is just unmapped.
1452		 */
1453		if (unmap_index != area->iopt->next_domain_id - 1) {
1454			if (end_index != iopt_area_index(area))
1455				iopt_area_unmap_domain_range(
1456					area, domain, iopt_area_index(area),
1457					end_index - 1);
1458		} else {
1459			iopt_area_unfill_partial_domain(area, pages, domain,
1460							end_index);
1461		}
1462	}
1463out_destroy:
1464	pfn_reader_destroy(&pfns);
1465out_unlock:
1466	mutex_unlock(&pages->mutex);
1467	return rc;
1468}
1469
1470/**
1471 * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1472 * @area: The area to act on
1473 * @pages: The pages associated with the area (area->pages is NULL)
1474 *
1475 * Called during area destruction. This unmaps the iova's covered by all the
1476 * area's domains and releases the PFNs.
1477 */
1478void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1479{
1480	struct io_pagetable *iopt = area->iopt;
1481	struct iommu_domain *domain;
1482	unsigned long index;
1483
1484	lockdep_assert_held(&iopt->domains_rwsem);
1485
1486	mutex_lock(&pages->mutex);
1487	if (!area->storage_domain)
1488		goto out_unlock;
1489
1490	xa_for_each(&iopt->domains, index, domain)
1491		if (domain != area->storage_domain)
1492			iopt_area_unmap_domain_range(
1493				area, domain, iopt_area_index(area),
1494				iopt_area_last_index(area));
1495
1496	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1497	iopt_area_unfill_domain(area, pages, area->storage_domain);
1498	area->storage_domain = NULL;
1499out_unlock:
1500	mutex_unlock(&pages->mutex);
1501}
1502
1503static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1504				    struct iopt_pages *pages,
1505				    unsigned long start_index,
1506				    unsigned long end_index)
1507{
1508	while (start_index <= end_index) {
1509		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1510					end_index);
1511		batch_unpin(batch, pages, 0, batch->total_pfns);
1512		start_index += batch->total_pfns;
1513		batch_clear(batch);
1514	}
1515}
1516
1517/**
1518 * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1519 * @pages: The pages to act on
1520 * @start_index: Starting PFN index
1521 * @last_index: Last PFN index
1522 *
1523 * Called when an iopt_pages_access is removed, removes pages from the itree.
1524 * The access should already be removed from the access_itree.
1525 */
1526void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1527			      unsigned long start_index,
1528			      unsigned long last_index)
1529{
1530	struct interval_tree_double_span_iter span;
1531	u64 backup[BATCH_BACKUP_SIZE];
1532	struct pfn_batch batch;
1533	bool batch_inited = false;
1534
1535	lockdep_assert_held(&pages->mutex);
1536
1537	interval_tree_for_each_double_span(&span, &pages->access_itree,
1538					   &pages->domains_itree, start_index,
1539					   last_index) {
1540		if (!span.is_used) {
1541			if (!batch_inited) {
1542				batch_init_backup(&batch,
1543						  last_index - start_index + 1,
1544						  backup, sizeof(backup));
1545				batch_inited = true;
1546			}
1547			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1548						span.last_hole);
1549		} else if (span.is_used == 2) {
1550			/* Covered by a domain */
1551			clear_xarray(&pages->pinned_pfns, span.start_used,
1552				     span.last_used);
1553		}
1554		/* Otherwise covered by an existing access */
1555	}
1556	if (batch_inited)
1557		batch_destroy(&batch, backup);
1558	update_unpinned(pages);
1559}
1560
1561/**
1562 * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1563 * @pages: The pages to act on
1564 * @start_index: The first page index in the range
1565 * @last_index: The last page index in the range
1566 * @out_pages: The output array to return the pages
1567 *
1568 * This can be called if the caller is holding a refcount on an
1569 * iopt_pages_access that is known to have already been filled. It quickly reads
1570 * the pages directly from the xarray.
1571 *
1572 * This is part of the SW iommu interface to read pages for in-kernel use.
1573 */
1574void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1575				 unsigned long start_index,
1576				 unsigned long last_index,
1577				 struct page **out_pages)
1578{
1579	XA_STATE(xas, &pages->pinned_pfns, start_index);
1580	void *entry;
1581
1582	rcu_read_lock();
1583	while (start_index <= last_index) {
1584		entry = xas_next(&xas);
1585		if (xas_retry(&xas, entry))
1586			continue;
1587		WARN_ON(!xa_is_value(entry));
1588		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1589		start_index++;
1590	}
1591	rcu_read_unlock();
1592}
1593
1594static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1595				       unsigned long start_index,
1596				       unsigned long last_index,
1597				       struct page **out_pages)
1598{
1599	while (start_index != last_index + 1) {
1600		unsigned long domain_last;
1601		struct iopt_area *area;
1602
1603		area = iopt_pages_find_domain_area(pages, start_index);
1604		if (WARN_ON(!area))
1605			return -EINVAL;
1606
1607		domain_last = min(iopt_area_last_index(area), last_index);
1608		out_pages = raw_pages_from_domain(area->storage_domain, area,
1609						  start_index, domain_last,
1610						  out_pages);
1611		start_index = domain_last + 1;
1612	}
1613	return 0;
1614}
1615
1616static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1617				   struct pfn_reader_user *user,
1618				   unsigned long start_index,
1619				   unsigned long last_index,
1620				   struct page **out_pages)
1621{
1622	unsigned long cur_index = start_index;
1623	int rc;
1624
1625	while (cur_index != last_index + 1) {
1626		user->upages = out_pages + (cur_index - start_index);
1627		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1628		if (rc)
1629			goto out_unpin;
1630		cur_index = user->upages_end;
1631	}
1632	return 0;
1633
1634out_unpin:
1635	if (start_index != cur_index)
1636		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1637				     out_pages);
1638	return rc;
1639}
1640
1641/**
1642 * iopt_pages_fill_xarray() - Read PFNs
1643 * @pages: The pages to act on
1644 * @start_index: The first page index in the range
1645 * @last_index: The last page index in the range
1646 * @out_pages: The output array to return the pages, may be NULL
1647 *
1648 * This populates the xarray and returns the pages in out_pages. As the slow
1649 * path this is able to copy pages from other storage tiers into the xarray.
1650 *
1651 * On failure the xarray is left unchanged.
1652 *
1653 * This is part of the SW iommu interface to read pages for in-kernel use.
1654 */
1655int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1656			   unsigned long last_index, struct page **out_pages)
1657{
1658	struct interval_tree_double_span_iter span;
1659	unsigned long xa_end = start_index;
1660	struct pfn_reader_user user;
1661	int rc;
1662
1663	lockdep_assert_held(&pages->mutex);
1664
1665	pfn_reader_user_init(&user, pages);
1666	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1667	interval_tree_for_each_double_span(&span, &pages->access_itree,
1668					   &pages->domains_itree, start_index,
1669					   last_index) {
1670		struct page **cur_pages;
1671
1672		if (span.is_used == 1) {
1673			cur_pages = out_pages + (span.start_used - start_index);
1674			iopt_pages_fill_from_xarray(pages, span.start_used,
1675						    span.last_used, cur_pages);
1676			continue;
1677		}
1678
1679		if (span.is_used == 2) {
1680			cur_pages = out_pages + (span.start_used - start_index);
1681			iopt_pages_fill_from_domain(pages, span.start_used,
1682						    span.last_used, cur_pages);
1683			rc = pages_to_xarray(&pages->pinned_pfns,
1684					     span.start_used, span.last_used,
1685					     cur_pages);
1686			if (rc)
1687				goto out_clean_xa;
1688			xa_end = span.last_used + 1;
1689			continue;
1690		}
1691
1692		/* hole */
1693		cur_pages = out_pages + (span.start_hole - start_index);
1694		rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1695					     span.last_hole, cur_pages);
1696		if (rc)
1697			goto out_clean_xa;
1698		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1699				     span.last_hole, cur_pages);
1700		if (rc) {
1701			iopt_pages_err_unpin(pages, span.start_hole,
1702					     span.last_hole, cur_pages);
1703			goto out_clean_xa;
1704		}
1705		xa_end = span.last_hole + 1;
1706	}
1707	rc = pfn_reader_user_update_pinned(&user, pages);
1708	if (rc)
1709		goto out_clean_xa;
1710	user.upages = NULL;
1711	pfn_reader_user_destroy(&user, pages);
1712	return 0;
1713
1714out_clean_xa:
1715	if (start_index != xa_end)
1716		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1717	user.upages = NULL;
1718	pfn_reader_user_destroy(&user, pages);
1719	return rc;
1720}
1721
1722/*
1723 * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1724 * do every scenario and is fully consistent with what an iommu_domain would
1725 * see.
1726 */
1727static int iopt_pages_rw_slow(struct iopt_pages *pages,
1728			      unsigned long start_index,
1729			      unsigned long last_index, unsigned long offset,
1730			      void *data, unsigned long length,
1731			      unsigned int flags)
1732{
1733	struct pfn_reader pfns;
1734	int rc;
1735
1736	mutex_lock(&pages->mutex);
1737
1738	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1739	if (rc)
1740		goto out_unlock;
1741
1742	while (!pfn_reader_done(&pfns)) {
1743		unsigned long done;
1744
1745		done = batch_rw(&pfns.batch, data, offset, length, flags);
1746		data += done;
1747		length -= done;
1748		offset = 0;
1749		pfn_reader_unpin(&pfns);
1750
1751		rc = pfn_reader_next(&pfns);
1752		if (rc)
1753			goto out_destroy;
1754	}
1755	if (WARN_ON(length != 0))
1756		rc = -EINVAL;
1757out_destroy:
1758	pfn_reader_destroy(&pfns);
1759out_unlock:
1760	mutex_unlock(&pages->mutex);
1761	return rc;
1762}
1763
1764/*
1765 * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1766 * memory allocations or interval tree searches.
1767 */
1768static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1769			      unsigned long offset, void *data,
1770			      unsigned long length, unsigned int flags)
1771{
1772	struct page *page = NULL;
1773	int rc;
1774
1775	if (!mmget_not_zero(pages->source_mm))
1776		return iopt_pages_rw_slow(pages, index, index, offset, data,
1777					  length, flags);
1778
1779	if (iommufd_should_fail()) {
1780		rc = -EINVAL;
1781		goto out_mmput;
1782	}
1783
1784	mmap_read_lock(pages->source_mm);
1785	rc = pin_user_pages_remote(
1786		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1787		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1788		NULL, NULL);
1789	mmap_read_unlock(pages->source_mm);
1790	if (rc != 1) {
1791		if (WARN_ON(rc >= 0))
1792			rc = -EINVAL;
1793		goto out_mmput;
1794	}
1795	copy_data_page(page, data, offset, length, flags);
1796	unpin_user_page(page);
1797	rc = 0;
1798
1799out_mmput:
1800	mmput(pages->source_mm);
1801	return rc;
1802}
1803
1804/**
1805 * iopt_pages_rw_access - Copy to/from a linear slice of the pages
1806 * @pages: pages to act on
1807 * @start_byte: First byte of pages to copy to/from
1808 * @data: Kernel buffer to get/put the data
1809 * @length: Number of bytes to copy
1810 * @flags: IOMMUFD_ACCESS_RW_* flags
1811 *
1812 * This will find each page in the range, kmap it and then memcpy to/from
1813 * the given kernel buffer.
1814 */
1815int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1816			 void *data, unsigned long length, unsigned int flags)
1817{
1818	unsigned long start_index = start_byte / PAGE_SIZE;
1819	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1820	bool change_mm = current->mm != pages->source_mm;
1821	int rc = 0;
1822
1823	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1824	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1825		change_mm = true;
1826
1827	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1828		return -EPERM;
1829
1830	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1831		if (start_index == last_index)
1832			return iopt_pages_rw_page(pages, start_index,
1833						  start_byte % PAGE_SIZE, data,
1834						  length, flags);
1835		return iopt_pages_rw_slow(pages, start_index, last_index,
1836					  start_byte % PAGE_SIZE, data, length,
1837					  flags);
1838	}
1839
1840	/*
1841	 * Try to copy using copy_to_user(). We do this as a fast path and
1842	 * ignore any pinning inconsistencies, unlike a real DMA path.
1843	 */
1844	if (change_mm) {
1845		if (!mmget_not_zero(pages->source_mm))
1846			return iopt_pages_rw_slow(pages, start_index,
1847						  last_index,
1848						  start_byte % PAGE_SIZE, data,
1849						  length, flags);
1850		kthread_use_mm(pages->source_mm);
1851	}
1852
1853	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1854		if (copy_to_user(pages->uptr + start_byte, data, length))
1855			rc = -EFAULT;
1856	} else {
1857		if (copy_from_user(data, pages->uptr + start_byte, length))
1858			rc = -EFAULT;
1859	}
1860
1861	if (change_mm) {
1862		kthread_unuse_mm(pages->source_mm);
1863		mmput(pages->source_mm);
1864	}
1865
1866	return rc;
1867}
1868
1869static struct iopt_pages_access *
1870iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1871			    unsigned long last)
1872{
1873	struct interval_tree_node *node;
1874
1875	lockdep_assert_held(&pages->mutex);
1876
1877	/* There can be overlapping ranges in this interval tree */
1878	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1879	     node; node = interval_tree_iter_next(node, index, last))
1880		if (node->start == index && node->last == last)
1881			return container_of(node, struct iopt_pages_access,
1882					    node);
1883	return NULL;
1884}
1885
1886/**
1887 * iopt_area_add_access() - Record an in-knerel access for PFNs
1888 * @area: The source of PFNs
1889 * @start_index: First page index
1890 * @last_index: Inclusive last page index
1891 * @out_pages: Output list of struct page's representing the PFNs
1892 * @flags: IOMMUFD_ACCESS_RW_* flags
1893 *
1894 * Record that an in-kernel access will be accessing the pages, ensure they are
1895 * pinned, and return the PFNs as a simple list of 'struct page *'.
1896 *
1897 * This should be undone through a matching call to iopt_area_remove_access()
1898 */
1899int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1900			  unsigned long last_index, struct page **out_pages,
1901			  unsigned int flags)
1902{
1903	struct iopt_pages *pages = area->pages;
1904	struct iopt_pages_access *access;
1905	int rc;
1906
1907	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1908		return -EPERM;
1909
1910	mutex_lock(&pages->mutex);
1911	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1912	if (access) {
1913		area->num_accesses++;
1914		access->users++;
1915		iopt_pages_fill_from_xarray(pages, start_index, last_index,
1916					    out_pages);
1917		mutex_unlock(&pages->mutex);
1918		return 0;
1919	}
1920
1921	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1922	if (!access) {
1923		rc = -ENOMEM;
1924		goto err_unlock;
1925	}
1926
1927	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1928	if (rc)
1929		goto err_free;
1930
1931	access->node.start = start_index;
1932	access->node.last = last_index;
1933	access->users = 1;
1934	area->num_accesses++;
1935	interval_tree_insert(&access->node, &pages->access_itree);
1936	mutex_unlock(&pages->mutex);
1937	return 0;
1938
1939err_free:
1940	kfree(access);
1941err_unlock:
1942	mutex_unlock(&pages->mutex);
1943	return rc;
1944}
1945
1946/**
1947 * iopt_area_remove_access() - Release an in-kernel access for PFNs
1948 * @area: The source of PFNs
1949 * @start_index: First page index
1950 * @last_index: Inclusive last page index
1951 *
1952 * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1953 * must stop using the PFNs before calling this.
1954 */
1955void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1956			     unsigned long last_index)
1957{
1958	struct iopt_pages *pages = area->pages;
1959	struct iopt_pages_access *access;
1960
1961	mutex_lock(&pages->mutex);
1962	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1963	if (WARN_ON(!access))
1964		goto out_unlock;
1965
1966	WARN_ON(area->num_accesses == 0 || access->users == 0);
1967	area->num_accesses--;
1968	access->users--;
1969	if (access->users)
1970		goto out_unlock;
1971
1972	interval_tree_remove(&access->node, &pages->access_itree);
1973	iopt_pages_unfill_xarray(pages, start_index, last_index);
1974	kfree(access);
1975out_unlock:
1976	mutex_unlock(&pages->mutex);
1977}