Linux Audio

Check our new training course

Loading...
Note: File does not exist in v4.6.
  1// SPDX-License-Identifier: GPL-2.0
  2/*
  3 * Amazon Nitro Secure Module driver.
  4 *
  5 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
  6 *
  7 * The Nitro Secure Module implements commands via CBOR over virtio.
  8 * This driver exposes a raw message ioctls on /dev/nsm that user
  9 * space can use to issue these commands.
 10 */
 11
 12#include <linux/file.h>
 13#include <linux/fs.h>
 14#include <linux/interrupt.h>
 15#include <linux/hw_random.h>
 16#include <linux/miscdevice.h>
 17#include <linux/module.h>
 18#include <linux/mutex.h>
 19#include <linux/slab.h>
 20#include <linux/string.h>
 21#include <linux/uaccess.h>
 22#include <linux/uio.h>
 23#include <linux/virtio_config.h>
 24#include <linux/virtio_ids.h>
 25#include <linux/virtio.h>
 26#include <linux/wait.h>
 27#include <uapi/linux/nsm.h>
 28
 29/* Timeout for NSM virtqueue respose in milliseconds. */
 30#define NSM_DEFAULT_TIMEOUT_MSECS (120000) /* 2 minutes */
 31
 32/* Maximum length input data */
 33struct nsm_data_req {
 34	u32 len;
 35	u8  data[NSM_REQUEST_MAX_SIZE];
 36};
 37
 38/* Maximum length output data */
 39struct nsm_data_resp {
 40	u32 len;
 41	u8  data[NSM_RESPONSE_MAX_SIZE];
 42};
 43
 44/* Full NSM request/response message */
 45struct nsm_msg {
 46	struct nsm_data_req req;
 47	struct nsm_data_resp resp;
 48};
 49
 50struct nsm {
 51	struct virtio_device *vdev;
 52	struct virtqueue     *vq;
 53	struct mutex          lock;
 54	struct completion     cmd_done;
 55	struct miscdevice     misc;
 56	struct hwrng          hwrng;
 57	struct work_struct    misc_init;
 58	struct nsm_msg        msg;
 59};
 60
 61/* NSM device ID */
 62static const struct virtio_device_id id_table[] = {
 63	{ VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID },
 64	{ 0 },
 65};
 66
 67static struct nsm *file_to_nsm(struct file *file)
 68{
 69	return container_of(file->private_data, struct nsm, misc);
 70}
 71
 72static struct nsm *hwrng_to_nsm(struct hwrng *rng)
 73{
 74	return container_of(rng, struct nsm, hwrng);
 75}
 76
 77#define CBOR_TYPE_MASK  0xE0
 78#define CBOR_TYPE_MAP 0xA0
 79#define CBOR_TYPE_TEXT 0x60
 80#define CBOR_TYPE_ARRAY 0x40
 81#define CBOR_HEADER_SIZE_SHORT 1
 82
 83#define CBOR_SHORT_SIZE_MAX_VALUE 23
 84#define CBOR_LONG_SIZE_U8  24
 85#define CBOR_LONG_SIZE_U16 25
 86#define CBOR_LONG_SIZE_U32 26
 87#define CBOR_LONG_SIZE_U64 27
 88
 89static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size)
 90{
 91	if (cbor_object_size == 0 || cbor_object == NULL)
 92		return false;
 93
 94	return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY;
 95}
 96
 97static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array)
 98{
 99	u8 cbor_short_size;
100	void *array_len_p;
101	u64 array_len;
102	u64 array_offset;
103
104	if (!cbor_object_is_array(cbor_object, cbor_object_size))
105		return -EFAULT;
106
107	cbor_short_size = (cbor_object[0] & 0x1F);
108
109	/* Decoding byte array length */
110	array_offset = CBOR_HEADER_SIZE_SHORT;
111	if (cbor_short_size >= CBOR_LONG_SIZE_U8)
112		array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8);
113
114	if (cbor_object_size < array_offset)
115		return -EFAULT;
116
117	array_len_p = &cbor_object[1];
118
119	switch (cbor_short_size) {
120	case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */
121		array_len = cbor_short_size;
122		break;
123	case CBOR_LONG_SIZE_U8:
124		array_len = *(u8 *)array_len_p;
125		break;
126	case CBOR_LONG_SIZE_U16:
127		array_len = be16_to_cpup((__be16 *)array_len_p);
128		break;
129	case CBOR_LONG_SIZE_U32:
130		array_len = be32_to_cpup((__be32 *)array_len_p);
131		break;
132	case CBOR_LONG_SIZE_U64:
133		array_len = be64_to_cpup((__be64 *)array_len_p);
134		break;
135	}
136
137	if (cbor_object_size < array_offset)
138		return -EFAULT;
139
140	if (cbor_object_size - array_offset < array_len)
141		return -EFAULT;
142
143	if (array_len > INT_MAX)
144		return -EFAULT;
145
146	*cbor_array = cbor_object + array_offset;
147	return array_len;
148}
149
150/* Copy the request of a raw message to kernel space */
151static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req,
152			struct nsm_raw *raw)
153{
154	/* Verify the user input size. */
155	if (raw->request.len > sizeof(req->data))
156		return -EMSGSIZE;
157
158	/* Copy the request payload */
159	if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr),
160			   raw->request.len))
161		return -EFAULT;
162
163	req->len = raw->request.len;
164
165	return 0;
166}
167
168/* Copy the response of a raw message back to user-space */
169static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp,
170			  struct nsm_raw *raw)
171{
172	/* Truncate any message that does not fit. */
173	raw->response.len = min_t(u64, raw->response.len, resp->len);
174
175	/* Copy the response content to user space */
176	if (copy_to_user(u64_to_user_ptr(raw->response.addr),
177			 resp->data, raw->response.len))
178		return -EFAULT;
179
180	return 0;
181}
182
183/* Virtqueue interrupt handler */
184static void nsm_vq_callback(struct virtqueue *vq)
185{
186	struct nsm *nsm = vq->vdev->priv;
187
188	complete(&nsm->cmd_done);
189}
190
191/* Forward a message to the NSM device and wait for the response from it */
192static int nsm_sendrecv_msg_locked(struct nsm *nsm)
193{
194	struct device *dev = &nsm->vdev->dev;
195	struct scatterlist sg_in, sg_out;
196	struct nsm_msg *msg = &nsm->msg;
197	struct virtqueue *vq = nsm->vq;
198	unsigned int len;
199	void *queue_buf;
200	bool kicked;
201	int rc;
202
203	/* Initialize scatter-gather lists with request and response buffers. */
204	sg_init_one(&sg_out, msg->req.data, msg->req.len);
205	sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data));
206
207	init_completion(&nsm->cmd_done);
208	/* Add the request buffer (read by the device). */
209	rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL);
210	if (rc)
211		return rc;
212
213	/* Add the response buffer (written by the device). */
214	rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL);
215	if (rc)
216		goto cleanup;
217
218	kicked = virtqueue_kick(vq);
219	if (!kicked) {
220		/* Cannot kick the virtqueue. */
221		rc = -EIO;
222		goto cleanup;
223	}
224
225	/* If the kick succeeded, wait for the device's response. */
226	if (!wait_for_completion_io_timeout(&nsm->cmd_done,
227		msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) {
228		rc = -ETIMEDOUT;
229		goto cleanup;
230	}
231
232	queue_buf = virtqueue_get_buf(vq, &len);
233	if (!queue_buf || (queue_buf != msg->req.data)) {
234		dev_err(dev, "wrong request buffer.");
235		rc = -ENODATA;
236		goto cleanup;
237	}
238
239	queue_buf = virtqueue_get_buf(vq, &len);
240	if (!queue_buf || (queue_buf != msg->resp.data)) {
241		dev_err(dev, "wrong response buffer.");
242		rc = -ENODATA;
243		goto cleanup;
244	}
245
246	msg->resp.len = len;
247
248	rc = 0;
249
250cleanup:
251	if (rc) {
252		/* Clean the virtqueue. */
253		while (virtqueue_get_buf(vq, &len) != NULL)
254			;
255	}
256
257	return rc;
258}
259
260static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req)
261{
262	/*
263	 * 69                          # text(9)
264	 *     47657452616E646F6D      # "GetRandom"
265	 */
266	const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"),
267			       'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' };
268
269	memcpy(req->data, request, sizeof(request));
270	req->len = sizeof(request);
271
272	return 0;
273}
274
275static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp,
276				 void *out, size_t max)
277{
278	/*
279	 * A1                          # map(1)
280	 *     69                      # text(9) - Name of field
281	 *         47657452616E646F6D  # "GetRandom"
282	 * A1                          # map(1) - The field itself
283	 *     66                      # text(6)
284	 *         72616E646F6D        # "random"
285	 *	# The rest of the response is random data
286	 */
287	const u8 response[] = { CBOR_TYPE_MAP + 1,
288				CBOR_TYPE_TEXT + strlen("GetRandom"),
289				'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm',
290				CBOR_TYPE_MAP + 1,
291				CBOR_TYPE_TEXT + strlen("random"),
292				'r', 'a', 'n', 'd', 'o', 'm' };
293	struct device *dev = &nsm->vdev->dev;
294	u8 *rand_data = NULL;
295	u8 *resp_ptr = resp->data;
296	u64 resp_len = resp->len;
297	int rc;
298
299	if ((resp->len < sizeof(response) + 1) ||
300	    (memcmp(resp_ptr, response, sizeof(response)) != 0)) {
301		dev_err(dev, "Invalid response for GetRandom");
302		return -EFAULT;
303	}
304
305	resp_ptr += sizeof(response);
306	resp_len -= sizeof(response);
307
308	rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data);
309	if (rc < 0) {
310		dev_err(dev, "GetRandom: Invalid CBOR encoding\n");
311		return rc;
312	}
313
314	rc = min_t(size_t, rc, max);
315	memcpy(out, rand_data, rc);
316
317	return rc;
318}
319
320/*
321 * HwRNG implementation
322 */
323static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait)
324{
325	struct nsm *nsm = hwrng_to_nsm(rng);
326	struct device *dev = &nsm->vdev->dev;
327	int rc = 0;
328
329	/* NSM always needs to wait for a response */
330	if (!wait)
331		return 0;
332
333	mutex_lock(&nsm->lock);
334
335	rc = fill_req_get_random(nsm, &nsm->msg.req);
336	if (rc != 0)
337		goto out;
338
339	rc = nsm_sendrecv_msg_locked(nsm);
340	if (rc != 0)
341		goto out;
342
343	rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max);
344	if (rc < 0)
345		goto out;
346
347	dev_dbg(dev, "RNG: returning rand bytes = %d", rc);
348out:
349	mutex_unlock(&nsm->lock);
350	return rc;
351}
352
353static long nsm_dev_ioctl(struct file *file, unsigned int cmd,
354	unsigned long arg)
355{
356	void __user *argp = u64_to_user_ptr((u64)arg);
357	struct nsm *nsm = file_to_nsm(file);
358	struct nsm_raw raw;
359	int r = 0;
360
361	if (cmd != NSM_IOCTL_RAW)
362		return -EINVAL;
363
364	if (_IOC_SIZE(cmd) != sizeof(raw))
365		return -EINVAL;
366
367	/* Copy user argument struct to kernel argument struct */
368	r = -EFAULT;
369	if (copy_from_user(&raw, argp, _IOC_SIZE(cmd)))
370		goto out;
371
372	mutex_lock(&nsm->lock);
373
374	/* Convert kernel argument struct to device request */
375	r = fill_req_raw(nsm, &nsm->msg.req, &raw);
376	if (r)
377		goto out;
378
379	/* Send message to NSM and read reply */
380	r = nsm_sendrecv_msg_locked(nsm);
381	if (r)
382		goto out;
383
384	/* Parse device response into kernel argument struct */
385	r = parse_resp_raw(nsm, &nsm->msg.resp, &raw);
386	if (r)
387		goto out;
388
389	/* Copy kernel argument struct back to user argument struct */
390	r = -EFAULT;
391	if (copy_to_user(argp, &raw, sizeof(raw)))
392		goto out;
393
394	r = 0;
395
396out:
397	mutex_unlock(&nsm->lock);
398	return r;
399}
400
401static int nsm_device_init_vq(struct virtio_device *vdev)
402{
403	struct virtqueue *vq = virtio_find_single_vq(vdev,
404		nsm_vq_callback, "nsm.vq.0");
405	struct nsm *nsm = vdev->priv;
406
407	if (IS_ERR(vq))
408		return PTR_ERR(vq);
409
410	nsm->vq = vq;
411
412	return 0;
413}
414
415static const struct file_operations nsm_dev_fops = {
416	.unlocked_ioctl = nsm_dev_ioctl,
417	.compat_ioctl = compat_ptr_ioctl,
418};
419
420/* Handler for probing the NSM device */
421static int nsm_device_probe(struct virtio_device *vdev)
422{
423	struct device *dev = &vdev->dev;
424	struct nsm *nsm;
425	int rc;
426
427	nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL);
428	if (!nsm)
429		return -ENOMEM;
430
431	vdev->priv = nsm;
432	nsm->vdev = vdev;
433
434	rc = nsm_device_init_vq(vdev);
435	if (rc) {
436		dev_err(dev, "queue failed to initialize: %d.\n", rc);
437		goto err_init_vq;
438	}
439
440	mutex_init(&nsm->lock);
441
442	/* Register as hwrng provider */
443	nsm->hwrng = (struct hwrng) {
444		.read = nsm_rng_read,
445		.name = "nsm-hwrng",
446		.quality = 1000,
447	};
448
449	rc = hwrng_register(&nsm->hwrng);
450	if (rc) {
451		dev_err(dev, "RNG initialization error: %d.\n", rc);
452		goto err_hwrng;
453	}
454
455	/* Register /dev/nsm device node */
456	nsm->misc = (struct miscdevice) {
457		.minor	= MISC_DYNAMIC_MINOR,
458		.name	= "nsm",
459		.fops	= &nsm_dev_fops,
460		.mode	= 0666,
461	};
462
463	rc = misc_register(&nsm->misc);
464	if (rc) {
465		dev_err(dev, "misc device registration error: %d.\n", rc);
466		goto err_misc;
467	}
468
469	return 0;
470
471err_misc:
472	hwrng_unregister(&nsm->hwrng);
473err_hwrng:
474	vdev->config->del_vqs(vdev);
475err_init_vq:
476	return rc;
477}
478
479/* Handler for removing the NSM device */
480static void nsm_device_remove(struct virtio_device *vdev)
481{
482	struct nsm *nsm = vdev->priv;
483
484	hwrng_unregister(&nsm->hwrng);
485
486	vdev->config->del_vqs(vdev);
487	misc_deregister(&nsm->misc);
488}
489
490/* NSM device configuration structure */
491static struct virtio_driver virtio_nsm_driver = {
492	.feature_table             = 0,
493	.feature_table_size        = 0,
494	.feature_table_legacy      = 0,
495	.feature_table_size_legacy = 0,
496	.driver.name               = KBUILD_MODNAME,
497	.driver.owner              = THIS_MODULE,
498	.id_table                  = id_table,
499	.probe                     = nsm_device_probe,
500	.remove                    = nsm_device_remove,
501};
502
503module_virtio_driver(virtio_nsm_driver);
504MODULE_DEVICE_TABLE(virtio, id_table);
505MODULE_DESCRIPTION("Virtio NSM driver");
506MODULE_LICENSE("GPL");