// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/iommu.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu_types.h"
#include "amd_iommu_proto.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES		0x10000
#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	atomic_t count;				/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	int pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	struct list_head list;
	u16 devid;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u16 devid;
	u16 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static spinlock_t state_lock;

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

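/*
 * Build the 16-bit PCI requester ID for a device: bus number in the
 * high byte, devfn in the low byte.
 */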
static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->devid == devid)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(devid);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

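/*
 * The per-device PASID state table is a small radix tree with 512
 * (2^9) entries per level; the walk below extracts 9 bits of the PASID
 * per level and, when @alloc is true, allocates missing intermediate
 * levels with GFP_ATOMIC.
 */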
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  int pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   int pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, int pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   int pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		atomic_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (atomic_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	atomic_dec(&pasid_state->count);
	wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid, no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

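/*
 * MMU notifier callback: if the invalidated range lies within a single
 * page, flush only that page for this PASID, otherwise flush the whole
 * TLB for the PASID.
 */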
static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release		= mn_release,
	.invalidate_range	= mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

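/*
 * A fault could not be resolved. If the device driver registered an
 * invalid-PPR callback, let it choose the PPR response code; otherwise
 * answer the request with PPR_INVALID.
 */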
static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

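/*
 * Return true if the access requested by the fault is not permitted
 * by the VMA's protection flags.
 */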
static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

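/*
 * Work queue handler: resolve a queued PPR request by faulting the
 * page into the bound mm via handle_mm_fault(), then complete the PRI
 * tag so the device receives its response.
 */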
static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	down_read(&mm->mmap_sem);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags);
out:
	up_read(&mm->mmap_sem);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

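/*
 * Called by the AMD IOMMU driver for every incoming PPR request. This
 * path cannot sleep (note the GFP_ATOMIC allocation), so the real
 * fault handling is deferred to the iommu_wq work queue; requests for
 * a known device but an unknown or invalid PASID are answered with
 * PPR_INVALID directly.
 */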
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid;
	int ret;
	struct iommu_dev_data *dev_data;
	struct pci_dev *pdev = NULL;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	devid = iommu_fault->device_id;
	pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;
	dev_data = get_dev_data(&pdev->dev);

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	ret = NOTIFY_DONE;
	if (translation_pre_enabled(amd_iommu_rlookup_table[devid])
		&& dev_data->defer_attach) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

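/*
 * Bind a PASID of @pdev to the address space of @task: register an
 * mmu_notifier on the task's mm and program the mm's page-table root
 * into the IOMMU's GCR3 table for this PASID.
 */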
int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid < 0 || pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;


	atomic_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid < 0 || pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

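/*
 * Enable IOMMUv2 (PASID/PRI) support for a PCI device: allocate a
 * device_state with room for @pasids PASIDs, set up a direct-mapped
 * v2-enabled domain and attach the device's IOMMU group to it.
 */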
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->devid = devid;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(devid) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(devid);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	put_device_state(dev_state);
	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

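/*
 * Register a driver callback that handle_fault_error() consults when a
 * PPR request cannot be resolved by handle_mm_fault().
 */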
int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	spin_lock_init(&state_lock);

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state;
	int i;

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	for (i = 0; i < MAX_DEVICES; ++i) {
		dev_state = get_device_state(i);

		if (dev_state == NULL)
			continue;

		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		amd_iommu_free_device(dev_state->pdev);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);