Linux Audio

Check our new training course

Loading...
   1/*
   2 * linux/net/sunrpc/auth_gss/auth_gss.c
   3 *
   4 * RPCSEC_GSS client authentication.
   5 *
   6 *  Copyright (c) 2000 The Regents of the University of Michigan.
   7 *  All rights reserved.
   8 *
   9 *  Dug Song       <dugsong@monkey.org>
  10 *  Andy Adamson   <andros@umich.edu>
  11 *
  12 *  Redistribution and use in source and binary forms, with or without
  13 *  modification, are permitted provided that the following conditions
  14 *  are met:
  15 *
  16 *  1. Redistributions of source code must retain the above copyright
  17 *     notice, this list of conditions and the following disclaimer.
  18 *  2. Redistributions in binary form must reproduce the above copyright
  19 *     notice, this list of conditions and the following disclaimer in the
  20 *     documentation and/or other materials provided with the distribution.
  21 *  3. Neither the name of the University nor the names of its
  22 *     contributors may be used to endorse or promote products derived
  23 *     from this software without specific prior written permission.
  24 *
  25 *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
  26 *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  27 *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  28 *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
  29 *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  30 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  31 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
  32 *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  33 *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  34 *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  35 *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  36 */
  37
  38
  39#include <linux/module.h>
  40#include <linux/init.h>
  41#include <linux/types.h>
  42#include <linux/slab.h>
  43#include <linux/sched.h>
  44#include <linux/pagemap.h>
  45#include <linux/sunrpc/clnt.h>
  46#include <linux/sunrpc/auth.h>
  47#include <linux/sunrpc/auth_gss.h>
  48#include <linux/sunrpc/svcauth_gss.h>
  49#include <linux/sunrpc/gss_err.h>
  50#include <linux/workqueue.h>
  51#include <linux/sunrpc/rpc_pipe_fs.h>
  52#include <linux/sunrpc/gss_api.h>
  53#include <asm/uaccess.h>
  54
  55static const struct rpc_authops authgss_ops;
  56
  57static const struct rpc_credops gss_credops;
  58static const struct rpc_credops gss_nullops;
  59
  60#define GSS_RETRY_EXPIRED 5
  61static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED;
  62
  63#ifdef RPC_DEBUG
  64# define RPCDBG_FACILITY	RPCDBG_AUTH
  65#endif
  66
  67#define GSS_CRED_SLACK		(RPC_MAX_AUTH_SIZE * 2)
  68/* length of a krb5 verifier (48), plus data added before arguments when
  69 * using integrity (two 4-byte integers): */
  70#define GSS_VERF_SLACK		100
  71
  72struct gss_auth {
  73	struct kref kref;
  74	struct rpc_auth rpc_auth;
  75	struct gss_api_mech *mech;
  76	enum rpc_gss_svc service;
  77	struct rpc_clnt *client;
  78	/*
  79	 * There are two upcall pipes; dentry[1], named "gssd", is used
  80	 * for the new text-based upcall; dentry[0] is named after the
  81	 * mechanism (for example, "krb5") and exists for
  82	 * backwards-compatibility with older gssd's.
  83	 */
  84	struct rpc_pipe *pipe[2];
  85};
  86
  87/* pipe_version >= 0 if and only if someone has a pipe open. */
  88static int pipe_version = -1;
  89static atomic_t pipe_users = ATOMIC_INIT(0);
  90static DEFINE_SPINLOCK(pipe_version_lock);
  91static struct rpc_wait_queue pipe_version_rpc_waitqueue;
  92static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
  93
  94static void gss_free_ctx(struct gss_cl_ctx *);
  95static const struct rpc_pipe_ops gss_upcall_ops_v0;
  96static const struct rpc_pipe_ops gss_upcall_ops_v1;
  97
  98static inline struct gss_cl_ctx *
  99gss_get_ctx(struct gss_cl_ctx *ctx)
 100{
 101	atomic_inc(&ctx->count);
 102	return ctx;
 103}
 104
 105static inline void
 106gss_put_ctx(struct gss_cl_ctx *ctx)
 107{
 108	if (atomic_dec_and_test(&ctx->count))
 109		gss_free_ctx(ctx);
 110}
 111
 112/* gss_cred_set_ctx:
 113 * called by gss_upcall_callback and gss_create_upcall in order
 114 * to set the gss context. The actual exchange of an old context
 115 * and a new one is protected by the pipe->lock.
 116 */
 117static void
 118gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
 119{
 120	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 121
 122	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
 123		return;
 124	gss_get_ctx(ctx);
 125	rcu_assign_pointer(gss_cred->gc_ctx, ctx);
 126	set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 127	smp_mb__before_clear_bit();
 128	clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
 129}
 130
 131static const void *
 132simple_get_bytes(const void *p, const void *end, void *res, size_t len)
 133{
 134	const void *q = (const void *)((const char *)p + len);
 135	if (unlikely(q > end || q < p))
 136		return ERR_PTR(-EFAULT);
 137	memcpy(res, p, len);
 138	return q;
 139}
 140
 141static inline const void *
 142simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
 143{
 144	const void *q;
 145	unsigned int len;
 146
 147	p = simple_get_bytes(p, end, &len, sizeof(len));
 148	if (IS_ERR(p))
 149		return p;
 150	q = (const void *)((const char *)p + len);
 151	if (unlikely(q > end || q < p))
 152		return ERR_PTR(-EFAULT);
 153	dest->data = kmemdup(p, len, GFP_NOFS);
 154	if (unlikely(dest->data == NULL))
 155		return ERR_PTR(-ENOMEM);
 156	dest->len = len;
 157	return q;
 158}
 159
 160static struct gss_cl_ctx *
 161gss_cred_get_ctx(struct rpc_cred *cred)
 162{
 163	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 164	struct gss_cl_ctx *ctx = NULL;
 165
 166	rcu_read_lock();
 167	if (gss_cred->gc_ctx)
 168		ctx = gss_get_ctx(gss_cred->gc_ctx);
 169	rcu_read_unlock();
 170	return ctx;
 171}
 172
 173static struct gss_cl_ctx *
 174gss_alloc_context(void)
 175{
 176	struct gss_cl_ctx *ctx;
 177
 178	ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
 179	if (ctx != NULL) {
 180		ctx->gc_proc = RPC_GSS_PROC_DATA;
 181		ctx->gc_seq = 1;	/* NetApp 6.4R1 doesn't accept seq. no. 0 */
 182		spin_lock_init(&ctx->gc_seq_lock);
 183		atomic_set(&ctx->count,1);
 184	}
 185	return ctx;
 186}
 187
 188#define GSSD_MIN_TIMEOUT (60 * 60)
 189static const void *
 190gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
 191{
 192	const void *q;
 193	unsigned int seclen;
 194	unsigned int timeout;
 195	u32 window_size;
 196	int ret;
 197
 198	/* First unsigned int gives the lifetime (in seconds) of the cred */
 199	p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
 200	if (IS_ERR(p))
 201		goto err;
 202	if (timeout == 0)
 203		timeout = GSSD_MIN_TIMEOUT;
 204	ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4;
 205	/* Sequence number window. Determines the maximum number of simultaneous requests */
 206	p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
 207	if (IS_ERR(p))
 208		goto err;
 209	ctx->gc_win = window_size;
 210	/* gssd signals an error by passing ctx->gc_win = 0: */
 211	if (ctx->gc_win == 0) {
 212		/*
 213		 * in which case, p points to an error code. Anything other
 214		 * than -EKEYEXPIRED gets converted to -EACCES.
 215		 */
 216		p = simple_get_bytes(p, end, &ret, sizeof(ret));
 217		if (!IS_ERR(p))
 218			p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
 219						    ERR_PTR(-EACCES);
 220		goto err;
 221	}
 222	/* copy the opaque wire context */
 223	p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
 224	if (IS_ERR(p))
 225		goto err;
 226	/* import the opaque security context */
 227	p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
 228	if (IS_ERR(p))
 229		goto err;
 230	q = (const void *)((const char *)p + seclen);
 231	if (unlikely(q > end || q < p)) {
 232		p = ERR_PTR(-EFAULT);
 233		goto err;
 234	}
 235	ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, GFP_NOFS);
 236	if (ret < 0) {
 237		p = ERR_PTR(ret);
 238		goto err;
 239	}
 240	return q;
 241err:
 242	dprintk("RPC:       gss_fill_context returning %ld\n", -PTR_ERR(p));
 243	return p;
 244}
 245
 246#define UPCALL_BUF_LEN 128
 247
 248struct gss_upcall_msg {
 249	atomic_t count;
 250	uid_t	uid;
 251	struct rpc_pipe_msg msg;
 252	struct list_head list;
 253	struct gss_auth *auth;
 254	struct rpc_pipe *pipe;
 255	struct rpc_wait_queue rpc_waitqueue;
 256	wait_queue_head_t waitqueue;
 257	struct gss_cl_ctx *ctx;
 258	char databuf[UPCALL_BUF_LEN];
 259};
 260
 261static int get_pipe_version(void)
 262{
 263	int ret;
 264
 265	spin_lock(&pipe_version_lock);
 266	if (pipe_version >= 0) {
 267		atomic_inc(&pipe_users);
 268		ret = pipe_version;
 269	} else
 270		ret = -EAGAIN;
 271	spin_unlock(&pipe_version_lock);
 272	return ret;
 273}
 274
 275static void put_pipe_version(void)
 276{
 277	if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
 278		pipe_version = -1;
 279		spin_unlock(&pipe_version_lock);
 280	}
 281}
 282
 283static void
 284gss_release_msg(struct gss_upcall_msg *gss_msg)
 285{
 286	if (!atomic_dec_and_test(&gss_msg->count))
 287		return;
 288	put_pipe_version();
 289	BUG_ON(!list_empty(&gss_msg->list));
 290	if (gss_msg->ctx != NULL)
 291		gss_put_ctx(gss_msg->ctx);
 292	rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
 293	kfree(gss_msg);
 294}
 295
 296static struct gss_upcall_msg *
 297__gss_find_upcall(struct rpc_pipe *pipe, uid_t uid)
 298{
 299	struct gss_upcall_msg *pos;
 300	list_for_each_entry(pos, &pipe->in_downcall, list) {
 301		if (pos->uid != uid)
 302			continue;
 303		atomic_inc(&pos->count);
 304		dprintk("RPC:       gss_find_upcall found msg %p\n", pos);
 305		return pos;
 306	}
 307	dprintk("RPC:       gss_find_upcall found nothing\n");
 308	return NULL;
 309}
 310
 311/* Try to add an upcall to the pipefs queue.
 312 * If an upcall owned by our uid already exists, then we return a reference
 313 * to that upcall instead of adding the new upcall.
 314 */
 315static inline struct gss_upcall_msg *
 316gss_add_msg(struct gss_upcall_msg *gss_msg)
 317{
 318	struct rpc_pipe *pipe = gss_msg->pipe;
 319	struct gss_upcall_msg *old;
 320
 321	spin_lock(&pipe->lock);
 322	old = __gss_find_upcall(pipe, gss_msg->uid);
 323	if (old == NULL) {
 324		atomic_inc(&gss_msg->count);
 325		list_add(&gss_msg->list, &pipe->in_downcall);
 326	} else
 327		gss_msg = old;
 328	spin_unlock(&pipe->lock);
 329	return gss_msg;
 330}
 331
 332static void
 333__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 334{
 335	list_del_init(&gss_msg->list);
 336	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
 337	wake_up_all(&gss_msg->waitqueue);
 338	atomic_dec(&gss_msg->count);
 339}
 340
 341static void
 342gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 343{
 344	struct rpc_pipe *pipe = gss_msg->pipe;
 345
 346	if (list_empty(&gss_msg->list))
 347		return;
 348	spin_lock(&pipe->lock);
 349	if (!list_empty(&gss_msg->list))
 350		__gss_unhash_msg(gss_msg);
 351	spin_unlock(&pipe->lock);
 352}
 353
 354static void
 355gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg)
 356{
 357	switch (gss_msg->msg.errno) {
 358	case 0:
 359		if (gss_msg->ctx == NULL)
 360			break;
 361		clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
 362		gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx);
 363		break;
 364	case -EKEYEXPIRED:
 365		set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
 366	}
 367	gss_cred->gc_upcall_timestamp = jiffies;
 368	gss_cred->gc_upcall = NULL;
 369	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
 370}
 371
 372static void
 373gss_upcall_callback(struct rpc_task *task)
 374{
 375	struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred,
 376			struct gss_cred, gc_base);
 377	struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
 378	struct rpc_pipe *pipe = gss_msg->pipe;
 379
 380	spin_lock(&pipe->lock);
 381	gss_handle_downcall_result(gss_cred, gss_msg);
 382	spin_unlock(&pipe->lock);
 383	task->tk_status = gss_msg->msg.errno;
 384	gss_release_msg(gss_msg);
 385}
 386
 387static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
 388{
 389	gss_msg->msg.data = &gss_msg->uid;
 390	gss_msg->msg.len = sizeof(gss_msg->uid);
 391}
 392
 393static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
 394				struct rpc_clnt *clnt,
 395				const char *service_name)
 396{
 397	struct gss_api_mech *mech = gss_msg->auth->mech;
 398	char *p = gss_msg->databuf;
 399	int len = 0;
 400
 401	gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
 402				   mech->gm_name,
 403				   gss_msg->uid);
 404	p += gss_msg->msg.len;
 405	if (clnt->cl_principal) {
 406		len = sprintf(p, "target=%s ", clnt->cl_principal);
 407		p += len;
 408		gss_msg->msg.len += len;
 409	}
 410	if (service_name != NULL) {
 411		len = sprintf(p, "service=%s ", service_name);
 412		p += len;
 413		gss_msg->msg.len += len;
 414	}
 415	if (mech->gm_upcall_enctypes) {
 416		len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes);
 417		p += len;
 418		gss_msg->msg.len += len;
 419	}
 420	len = sprintf(p, "\n");
 421	gss_msg->msg.len += len;
 422
 423	gss_msg->msg.data = gss_msg->databuf;
 424	BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
 425}
 426
 427static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
 428				struct rpc_clnt *clnt,
 429				const char *service_name)
 430{
 431	if (pipe_version == 0)
 432		gss_encode_v0_msg(gss_msg);
 433	else /* pipe_version == 1 */
 434		gss_encode_v1_msg(gss_msg, clnt, service_name);
 435}
 436
 437static struct gss_upcall_msg *
 438gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt,
 439		uid_t uid, const char *service_name)
 440{
 441	struct gss_upcall_msg *gss_msg;
 442	int vers;
 443
 444	gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
 445	if (gss_msg == NULL)
 446		return ERR_PTR(-ENOMEM);
 447	vers = get_pipe_version();
 448	if (vers < 0) {
 449		kfree(gss_msg);
 450		return ERR_PTR(vers);
 451	}
 452	gss_msg->pipe = gss_auth->pipe[vers];
 453	INIT_LIST_HEAD(&gss_msg->list);
 454	rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
 455	init_waitqueue_head(&gss_msg->waitqueue);
 456	atomic_set(&gss_msg->count, 1);
 457	gss_msg->uid = uid;
 458	gss_msg->auth = gss_auth;
 459	gss_encode_msg(gss_msg, clnt, service_name);
 460	return gss_msg;
 461}
 462
 463static struct gss_upcall_msg *
 464gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
 465{
 466	struct gss_cred *gss_cred = container_of(cred,
 467			struct gss_cred, gc_base);
 468	struct gss_upcall_msg *gss_new, *gss_msg;
 469	uid_t uid = cred->cr_uid;
 470
 471	gss_new = gss_alloc_msg(gss_auth, clnt, uid, gss_cred->gc_principal);
 472	if (IS_ERR(gss_new))
 473		return gss_new;
 474	gss_msg = gss_add_msg(gss_new);
 475	if (gss_msg == gss_new) {
 476		int res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg);
 477		if (res) {
 478			gss_unhash_msg(gss_new);
 479			gss_msg = ERR_PTR(res);
 480		}
 481	} else
 482		gss_release_msg(gss_new);
 483	return gss_msg;
 484}
 485
 486static void warn_gssd(void)
 487{
 488	static unsigned long ratelimit;
 489	unsigned long now = jiffies;
 490
 491	if (time_after(now, ratelimit)) {
 492		printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
 493				"Please check user daemon is running.\n");
 494		ratelimit = now + 15*HZ;
 495	}
 496}
 497
 498static inline int
 499gss_refresh_upcall(struct rpc_task *task)
 500{
 501	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 502	struct gss_auth *gss_auth = container_of(cred->cr_auth,
 503			struct gss_auth, rpc_auth);
 504	struct gss_cred *gss_cred = container_of(cred,
 505			struct gss_cred, gc_base);
 506	struct gss_upcall_msg *gss_msg;
 507	struct rpc_pipe *pipe;
 508	int err = 0;
 509
 510	dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
 511								cred->cr_uid);
 512	gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
 513	if (PTR_ERR(gss_msg) == -EAGAIN) {
 514		/* XXX: warning on the first, under the assumption we
 515		 * shouldn't normally hit this case on a refresh. */
 516		warn_gssd();
 517		task->tk_timeout = 15*HZ;
 518		rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
 519		return -EAGAIN;
 520	}
 521	if (IS_ERR(gss_msg)) {
 522		err = PTR_ERR(gss_msg);
 523		goto out;
 524	}
 525	pipe = gss_msg->pipe;
 526	spin_lock(&pipe->lock);
 527	if (gss_cred->gc_upcall != NULL)
 528		rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
 529	else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
 530		task->tk_timeout = 0;
 531		gss_cred->gc_upcall = gss_msg;
 532		/* gss_upcall_callback will release the reference to gss_upcall_msg */
 533		atomic_inc(&gss_msg->count);
 534		rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
 535	} else {
 536		gss_handle_downcall_result(gss_cred, gss_msg);
 537		err = gss_msg->msg.errno;
 538	}
 539	spin_unlock(&pipe->lock);
 540	gss_release_msg(gss_msg);
 541out:
 542	dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n",
 543			task->tk_pid, cred->cr_uid, err);
 544	return err;
 545}
 546
 547static inline int
 548gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
 549{
 550	struct rpc_pipe *pipe;
 551	struct rpc_cred *cred = &gss_cred->gc_base;
 552	struct gss_upcall_msg *gss_msg;
 553	DEFINE_WAIT(wait);
 554	int err = 0;
 555
 556	dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
 557retry:
 558	gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
 559	if (PTR_ERR(gss_msg) == -EAGAIN) {
 560		err = wait_event_interruptible_timeout(pipe_version_waitqueue,
 561				pipe_version >= 0, 15*HZ);
 562		if (pipe_version < 0) {
 563			warn_gssd();
 564			err = -EACCES;
 565		}
 566		if (err)
 567			goto out;
 568		goto retry;
 569	}
 570	if (IS_ERR(gss_msg)) {
 571		err = PTR_ERR(gss_msg);
 572		goto out;
 573	}
 574	pipe = gss_msg->pipe;
 575	for (;;) {
 576		prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE);
 577		spin_lock(&pipe->lock);
 578		if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
 579			break;
 580		}
 581		spin_unlock(&pipe->lock);
 582		if (fatal_signal_pending(current)) {
 583			err = -ERESTARTSYS;
 584			goto out_intr;
 585		}
 586		schedule();
 587	}
 588	if (gss_msg->ctx)
 589		gss_cred_set_ctx(cred, gss_msg->ctx);
 590	else
 591		err = gss_msg->msg.errno;
 592	spin_unlock(&pipe->lock);
 593out_intr:
 594	finish_wait(&gss_msg->waitqueue, &wait);
 595	gss_release_msg(gss_msg);
 596out:
 597	dprintk("RPC:       gss_create_upcall for uid %u result %d\n",
 598			cred->cr_uid, err);
 599	return err;
 600}
 601
 602#define MSG_BUF_MAXSIZE 1024
 603
 604static ssize_t
 605gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
 606{
 607	const void *p, *end;
 608	void *buf;
 609	struct gss_upcall_msg *gss_msg;
 610	struct rpc_pipe *pipe = RPC_I(filp->f_dentry->d_inode)->pipe;
 611	struct gss_cl_ctx *ctx;
 612	uid_t uid;
 613	ssize_t err = -EFBIG;
 614
 615	if (mlen > MSG_BUF_MAXSIZE)
 616		goto out;
 617	err = -ENOMEM;
 618	buf = kmalloc(mlen, GFP_NOFS);
 619	if (!buf)
 620		goto out;
 621
 622	err = -EFAULT;
 623	if (copy_from_user(buf, src, mlen))
 624		goto err;
 625
 626	end = (const void *)((char *)buf + mlen);
 627	p = simple_get_bytes(buf, end, &uid, sizeof(uid));
 628	if (IS_ERR(p)) {
 629		err = PTR_ERR(p);
 630		goto err;
 631	}
 632
 633	err = -ENOMEM;
 634	ctx = gss_alloc_context();
 635	if (ctx == NULL)
 636		goto err;
 637
 638	err = -ENOENT;
 639	/* Find a matching upcall */
 640	spin_lock(&pipe->lock);
 641	gss_msg = __gss_find_upcall(pipe, uid);
 642	if (gss_msg == NULL) {
 643		spin_unlock(&pipe->lock);
 644		goto err_put_ctx;
 645	}
 646	list_del_init(&gss_msg->list);
 647	spin_unlock(&pipe->lock);
 648
 649	p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
 650	if (IS_ERR(p)) {
 651		err = PTR_ERR(p);
 652		switch (err) {
 653		case -EACCES:
 654		case -EKEYEXPIRED:
 655			gss_msg->msg.errno = err;
 656			err = mlen;
 657			break;
 658		case -EFAULT:
 659		case -ENOMEM:
 660		case -EINVAL:
 661		case -ENOSYS:
 662			gss_msg->msg.errno = -EAGAIN;
 663			break;
 664		default:
 665			printk(KERN_CRIT "%s: bad return from "
 666				"gss_fill_context: %zd\n", __func__, err);
 667			BUG();
 668		}
 669		goto err_release_msg;
 670	}
 671	gss_msg->ctx = gss_get_ctx(ctx);
 672	err = mlen;
 673
 674err_release_msg:
 675	spin_lock(&pipe->lock);
 676	__gss_unhash_msg(gss_msg);
 677	spin_unlock(&pipe->lock);
 678	gss_release_msg(gss_msg);
 679err_put_ctx:
 680	gss_put_ctx(ctx);
 681err:
 682	kfree(buf);
 683out:
 684	dprintk("RPC:       gss_pipe_downcall returning %Zd\n", err);
 685	return err;
 686}
 687
 688static int gss_pipe_open(struct inode *inode, int new_version)
 689{
 690	int ret = 0;
 691
 692	spin_lock(&pipe_version_lock);
 693	if (pipe_version < 0) {
 694		/* First open of any gss pipe determines the version: */
 695		pipe_version = new_version;
 696		rpc_wake_up(&pipe_version_rpc_waitqueue);
 697		wake_up(&pipe_version_waitqueue);
 698	} else if (pipe_version != new_version) {
 699		/* Trying to open a pipe of a different version */
 700		ret = -EBUSY;
 701		goto out;
 702	}
 703	atomic_inc(&pipe_users);
 704out:
 705	spin_unlock(&pipe_version_lock);
 706	return ret;
 707
 708}
 709
 710static int gss_pipe_open_v0(struct inode *inode)
 711{
 712	return gss_pipe_open(inode, 0);
 713}
 714
 715static int gss_pipe_open_v1(struct inode *inode)
 716{
 717	return gss_pipe_open(inode, 1);
 718}
 719
 720static void
 721gss_pipe_release(struct inode *inode)
 722{
 723	struct rpc_pipe *pipe = RPC_I(inode)->pipe;
 724	struct gss_upcall_msg *gss_msg;
 725
 726restart:
 727	spin_lock(&pipe->lock);
 728	list_for_each_entry(gss_msg, &pipe->in_downcall, list) {
 729
 730		if (!list_empty(&gss_msg->msg.list))
 731			continue;
 732		gss_msg->msg.errno = -EPIPE;
 733		atomic_inc(&gss_msg->count);
 734		__gss_unhash_msg(gss_msg);
 735		spin_unlock(&pipe->lock);
 736		gss_release_msg(gss_msg);
 737		goto restart;
 738	}
 739	spin_unlock(&pipe->lock);
 740
 741	put_pipe_version();
 742}
 743
 744static void
 745gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
 746{
 747	struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
 748
 749	if (msg->errno < 0) {
 750		dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
 751				gss_msg);
 752		atomic_inc(&gss_msg->count);
 753		gss_unhash_msg(gss_msg);
 754		if (msg->errno == -ETIMEDOUT)
 755			warn_gssd();
 756		gss_release_msg(gss_msg);
 757	}
 758}
 759
 760static void gss_pipes_dentries_destroy(struct rpc_auth *auth)
 761{
 762	struct gss_auth *gss_auth;
 763
 764	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 765	if (gss_auth->pipe[0]->dentry)
 766		rpc_unlink(gss_auth->pipe[0]->dentry);
 767	if (gss_auth->pipe[1]->dentry)
 768		rpc_unlink(gss_auth->pipe[1]->dentry);
 769}
 770
 771static int gss_pipes_dentries_create(struct rpc_auth *auth)
 772{
 773	int err;
 774	struct gss_auth *gss_auth;
 775	struct rpc_clnt *clnt;
 776
 777	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 778	clnt = gss_auth->client;
 779
 780	gss_auth->pipe[1]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
 781						      "gssd",
 782						      clnt, gss_auth->pipe[1]);
 783	if (IS_ERR(gss_auth->pipe[1]->dentry))
 784		return PTR_ERR(gss_auth->pipe[1]->dentry);
 785	gss_auth->pipe[0]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
 786						      gss_auth->mech->gm_name,
 787						      clnt, gss_auth->pipe[0]);
 788	if (IS_ERR(gss_auth->pipe[0]->dentry)) {
 789		err = PTR_ERR(gss_auth->pipe[0]->dentry);
 790		goto err_unlink_pipe_1;
 791	}
 792	return 0;
 793
 794err_unlink_pipe_1:
 795	rpc_unlink(gss_auth->pipe[1]->dentry);
 796	return err;
 797}
 798
 799static void gss_pipes_dentries_destroy_net(struct rpc_clnt *clnt,
 800					   struct rpc_auth *auth)
 801{
 802	struct net *net = rpc_net_ns(clnt);
 803	struct super_block *sb;
 804
 805	sb = rpc_get_sb_net(net);
 806	if (sb) {
 807		if (clnt->cl_dentry)
 808			gss_pipes_dentries_destroy(auth);
 809		rpc_put_sb_net(net);
 810	}
 811}
 812
 813static int gss_pipes_dentries_create_net(struct rpc_clnt *clnt,
 814					 struct rpc_auth *auth)
 815{
 816	struct net *net = rpc_net_ns(clnt);
 817	struct super_block *sb;
 818	int err = 0;
 819
 820	sb = rpc_get_sb_net(net);
 821	if (sb) {
 822		if (clnt->cl_dentry)
 823			err = gss_pipes_dentries_create(auth);
 824		rpc_put_sb_net(net);
 825	}
 826	return err;
 827}
 828
 829/*
 830 * NOTE: we have the opportunity to use different
 831 * parameters based on the input flavor (which must be a pseudoflavor)
 832 */
 833static struct rpc_auth *
 834gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
 835{
 836	struct gss_auth *gss_auth;
 837	struct rpc_auth * auth;
 838	int err = -ENOMEM; /* XXX? */
 839
 840	dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);
 841
 842	if (!try_module_get(THIS_MODULE))
 843		return ERR_PTR(err);
 844	if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
 845		goto out_dec;
 846	gss_auth->client = clnt;
 847	err = -EINVAL;
 848	gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
 849	if (!gss_auth->mech) {
 850		printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
 851				__func__, flavor);
 852		goto err_free;
 853	}
 854	gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
 855	if (gss_auth->service == 0)
 856		goto err_put_mech;
 857	auth = &gss_auth->rpc_auth;
 858	auth->au_cslack = GSS_CRED_SLACK >> 2;
 859	auth->au_rslack = GSS_VERF_SLACK >> 2;
 860	auth->au_ops = &authgss_ops;
 861	auth->au_flavor = flavor;
 862	atomic_set(&auth->au_count, 1);
 863	kref_init(&gss_auth->kref);
 864
 865	/*
 866	 * Note: if we created the old pipe first, then someone who
 867	 * examined the directory at the right moment might conclude
 868	 * that we supported only the old pipe.  So we instead create
 869	 * the new pipe first.
 870	 */
 871	gss_auth->pipe[1] = rpc_mkpipe_data(&gss_upcall_ops_v1,
 872					    RPC_PIPE_WAIT_FOR_OPEN);
 873	if (IS_ERR(gss_auth->pipe[1])) {
 874		err = PTR_ERR(gss_auth->pipe[1]);
 875		goto err_put_mech;
 876	}
 877
 878	gss_auth->pipe[0] = rpc_mkpipe_data(&gss_upcall_ops_v0,
 879					    RPC_PIPE_WAIT_FOR_OPEN);
 880	if (IS_ERR(gss_auth->pipe[0])) {
 881		err = PTR_ERR(gss_auth->pipe[0]);
 882		goto err_destroy_pipe_1;
 883	}
 884	err = gss_pipes_dentries_create_net(clnt, auth);
 885	if (err)
 886		goto err_destroy_pipe_0;
 887	err = rpcauth_init_credcache(auth);
 888	if (err)
 889		goto err_unlink_pipes;
 890
 891	return auth;
 892err_unlink_pipes:
 893	gss_pipes_dentries_destroy_net(clnt, auth);
 894err_destroy_pipe_0:
 895	rpc_destroy_pipe_data(gss_auth->pipe[0]);
 896err_destroy_pipe_1:
 897	rpc_destroy_pipe_data(gss_auth->pipe[1]);
 898err_put_mech:
 899	gss_mech_put(gss_auth->mech);
 900err_free:
 901	kfree(gss_auth);
 902out_dec:
 903	module_put(THIS_MODULE);
 904	return ERR_PTR(err);
 905}
 906
 907static void
 908gss_free(struct gss_auth *gss_auth)
 909{
 910	gss_pipes_dentries_destroy_net(gss_auth->client, &gss_auth->rpc_auth);
 911	rpc_destroy_pipe_data(gss_auth->pipe[0]);
 912	rpc_destroy_pipe_data(gss_auth->pipe[1]);
 913	gss_mech_put(gss_auth->mech);
 914
 915	kfree(gss_auth);
 916	module_put(THIS_MODULE);
 917}
 918
 919static void
 920gss_free_callback(struct kref *kref)
 921{
 922	struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
 923
 924	gss_free(gss_auth);
 925}
 926
 927static void
 928gss_destroy(struct rpc_auth *auth)
 929{
 930	struct gss_auth *gss_auth;
 931
 932	dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
 933			auth, auth->au_flavor);
 934
 935	rpcauth_destroy_credcache(auth);
 936
 937	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 938	kref_put(&gss_auth->kref, gss_free_callback);
 939}
 940
 941/*
 942 * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
 943 * to the server with the GSS control procedure field set to
 944 * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
 945 * all RPCSEC_GSS state associated with that context.
 946 */
 947static int
 948gss_destroying_context(struct rpc_cred *cred)
 949{
 950	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 951	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
 952	struct rpc_task *task;
 953
 954	if (gss_cred->gc_ctx == NULL ||
 955	    test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
 956		return 0;
 957
 958	gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
 959	cred->cr_ops = &gss_nullops;
 960
 961	/* Take a reference to ensure the cred will be destroyed either
 962	 * by the RPC call or by the put_rpccred() below */
 963	get_rpccred(cred);
 964
 965	task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
 966	if (!IS_ERR(task))
 967		rpc_put_task(task);
 968
 969	put_rpccred(cred);
 970	return 1;
 971}
 972
 973/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
 974 * to create a new cred or context, so they check that things have been
 975 * allocated before freeing them. */
 976static void
 977gss_do_free_ctx(struct gss_cl_ctx *ctx)
 978{
 979	dprintk("RPC:       gss_free_ctx\n");
 980
 981	gss_delete_sec_context(&ctx->gc_gss_ctx);
 982	kfree(ctx->gc_wire_ctx.data);
 983	kfree(ctx);
 984}
 985
 986static void
 987gss_free_ctx_callback(struct rcu_head *head)
 988{
 989	struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
 990	gss_do_free_ctx(ctx);
 991}
 992
 993static void
 994gss_free_ctx(struct gss_cl_ctx *ctx)
 995{
 996	call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
 997}
 998
 999static void
1000gss_free_cred(struct gss_cred *gss_cred)
1001{
1002	dprintk("RPC:       gss_free_cred %p\n", gss_cred);
1003	kfree(gss_cred);
1004}
1005
1006static void
1007gss_free_cred_callback(struct rcu_head *head)
1008{
1009	struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
1010	gss_free_cred(gss_cred);
1011}
1012
1013static void
1014gss_destroy_nullcred(struct rpc_cred *cred)
1015{
1016	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
1017	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
1018	struct gss_cl_ctx *ctx = gss_cred->gc_ctx;
1019
1020	RCU_INIT_POINTER(gss_cred->gc_ctx, NULL);
1021	call_rcu(&cred->cr_rcu, gss_free_cred_callback);
1022	if (ctx)
1023		gss_put_ctx(ctx);
1024	kref_put(&gss_auth->kref, gss_free_callback);
1025}
1026
/*
 * Credential destroy hook.  If gss_destroying_context() started a
 * server-side context destroy, that path finishes the credential off.
 * Otherwise the credential is torn down immediately.
 */
static void
gss_destroy_cred(struct rpc_cred *cred)
{
	if (!gss_destroying_context(cred))
		gss_destroy_nullcred(cred);
}
1035
/*
 * Look up the RPCSEC_GSS cred for the current process.  This is a thin
 * wrapper around the generic credential-cache lookup.
 */
static struct rpc_cred *
gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
{
	struct rpc_cred *found = rpcauth_lookup_credcache(auth, acred, flags);

	return found;
}
1044
1045static struct rpc_cred *
1046gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
1047{
1048	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1049	struct gss_cred	*cred = NULL;
1050	int err = -ENOMEM;
1051
1052	dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
1053		acred->uid, auth->au_flavor);
1054
1055	if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
1056		goto out_err;
1057
1058	rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
1059	/*
1060	 * Note: in order to force a call to call_refresh(), we deliberately
1061	 * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
1062	 */
1063	cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
1064	cred->gc_service = gss_auth->service;
1065	cred->gc_principal = NULL;
1066	if (acred->machine_cred)
1067		cred->gc_principal = acred->principal;
1068	kref_get(&gss_auth->kref);
1069	return &cred->gc_base;
1070
1071out_err:
1072	dprintk("RPC:       gss_create_cred failed with error %d\n", err);
1073	return ERR_PTR(err);
1074}
1075
1076static int
1077gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
1078{
1079	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1080	struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
1081	int err;
1082
1083	do {
1084		err = gss_create_upcall(gss_auth, gss_cred);
1085	} while (err == -EAGAIN);
1086	return err;
1087}
1088
1089static int
1090gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
1091{
1092	struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
1093
1094	if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
1095		goto out;
1096	/* Don't match with creds that have expired. */
1097	if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
1098		return 0;
1099	if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
1100		return 0;
1101out:
1102	if (acred->principal != NULL) {
1103		if (gss_cred->gc_principal == NULL)
1104			return 0;
1105		return strcmp(acred->principal, gss_cred->gc_principal) == 0;
1106	}
1107	if (gss_cred->gc_principal != NULL)
1108		return 0;
1109	return rc->cr_uid == acred->uid;
1110}
1111
1112/*
1113* Marshal credentials.
1114* Maybe we should keep a cached credential for performance reasons.
1115*/
1116static __be32 *
1117gss_marshal(struct rpc_task *task, __be32 *p)
1118{
1119	struct rpc_rqst *req = task->tk_rqstp;
1120	struct rpc_cred *cred = req->rq_cred;
1121	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1122						 gc_base);
1123	struct gss_cl_ctx	*ctx = gss_cred_get_ctx(cred);
1124	__be32		*cred_len;
1125	u32             maj_stat = 0;
1126	struct xdr_netobj mic;
1127	struct kvec	iov;
1128	struct xdr_buf	verf_buf;
1129
1130	dprintk("RPC: %5u gss_marshal\n", task->tk_pid);
1131
1132	*p++ = htonl(RPC_AUTH_GSS);
1133	cred_len = p++;
1134
1135	spin_lock(&ctx->gc_seq_lock);
1136	req->rq_seqno = ctx->gc_seq++;
1137	spin_unlock(&ctx->gc_seq_lock);
1138
1139	*p++ = htonl((u32) RPC_GSS_VERSION);
1140	*p++ = htonl((u32) ctx->gc_proc);
1141	*p++ = htonl((u32) req->rq_seqno);
1142	*p++ = htonl((u32) gss_cred->gc_service);
1143	p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
1144	*cred_len = htonl((p - (cred_len + 1)) << 2);
1145
1146	/* We compute the checksum for the verifier over the xdr-encoded bytes
1147	 * starting with the xid and ending at the end of the credential: */
1148	iov.iov_base = xprt_skip_transport_header(task->tk_xprt,
1149					req->rq_snd_buf.head[0].iov_base);
1150	iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
1151	xdr_buf_from_iov(&iov, &verf_buf);
1152
1153	/* set verifier flavor*/
1154	*p++ = htonl(RPC_AUTH_GSS);
1155
1156	mic.data = (u8 *)(p + 1);
1157	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1158	if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
1159		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1160	} else if (maj_stat != 0) {
1161		printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
1162		goto out_put_ctx;
1163	}
1164	p = xdr_encode_opaque(p, NULL, mic.len);
1165	gss_put_ctx(ctx);
1166	return p;
1167out_put_ctx:
1168	gss_put_ctx(ctx);
1169	return NULL;
1170}
1171
1172static int gss_renew_cred(struct rpc_task *task)
1173{
1174	struct rpc_cred *oldcred = task->tk_rqstp->rq_cred;
1175	struct gss_cred *gss_cred = container_of(oldcred,
1176						 struct gss_cred,
1177						 gc_base);
1178	struct rpc_auth *auth = oldcred->cr_auth;
1179	struct auth_cred acred = {
1180		.uid = oldcred->cr_uid,
1181		.principal = gss_cred->gc_principal,
1182		.machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0),
1183	};
1184	struct rpc_cred *new;
1185
1186	new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1187	if (IS_ERR(new))
1188		return PTR_ERR(new);
1189	task->tk_rqstp->rq_cred = new;
1190	put_rpccred(oldcred);
1191	return 0;
1192}
1193
1194static int gss_cred_is_negative_entry(struct rpc_cred *cred)
1195{
1196	if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) {
1197		unsigned long now = jiffies;
1198		unsigned long begin, expire;
1199		struct gss_cred *gss_cred; 
1200
1201		gss_cred = container_of(cred, struct gss_cred, gc_base);
1202		begin = gss_cred->gc_upcall_timestamp;
1203		expire = begin + gss_expired_cred_retry_delay * HZ;
1204
1205		if (time_in_range_open(now, begin, expire))
1206			return 1;
1207	}
1208	return 0;
1209}
1210
1211/*
1212* Refresh credentials. XXX - finish
1213*/
1214static int
1215gss_refresh(struct rpc_task *task)
1216{
1217	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1218	int ret = 0;
1219
1220	if (gss_cred_is_negative_entry(cred))
1221		return -EKEYEXPIRED;
1222
1223	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1224			!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1225		ret = gss_renew_cred(task);
1226		if (ret < 0)
1227			goto out;
1228		cred = task->tk_rqstp->rq_cred;
1229	}
1230
1231	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1232		ret = gss_refresh_upcall(task);
1233out:
1234	return ret;
1235}
1236
1237/* Dummy refresh routine: used only when destroying the context */
1238static int
1239gss_refresh_null(struct rpc_task *task)
1240{
1241	return -EACCES;
1242}
1243
1244static __be32 *
1245gss_validate(struct rpc_task *task, __be32 *p)
1246{
1247	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1248	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1249	__be32		seq;
1250	struct kvec	iov;
1251	struct xdr_buf	verf_buf;
1252	struct xdr_netobj mic;
1253	u32		flav,len;
1254	u32		maj_stat;
1255
1256	dprintk("RPC: %5u gss_validate\n", task->tk_pid);
1257
1258	flav = ntohl(*p++);
1259	if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
1260		goto out_bad;
1261	if (flav != RPC_AUTH_GSS)
1262		goto out_bad;
1263	seq = htonl(task->tk_rqstp->rq_seqno);
1264	iov.iov_base = &seq;
1265	iov.iov_len = sizeof(seq);
1266	xdr_buf_from_iov(&iov, &verf_buf);
1267	mic.data = (u8 *)p;
1268	mic.len = len;
1269
1270	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1271	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1272		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1273	if (maj_stat) {
1274		dprintk("RPC: %5u gss_validate: gss_verify_mic returned "
1275				"error 0x%08x\n", task->tk_pid, maj_stat);
1276		goto out_bad;
1277	}
1278	/* We leave it to unwrap to calculate au_rslack. For now we just
1279	 * calculate the length of the verifier: */
1280	cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
1281	gss_put_ctx(ctx);
1282	dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n",
1283			task->tk_pid);
1284	return p + XDR_QUADLEN(len);
1285out_bad:
1286	gss_put_ctx(ctx);
1287	dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid);
1288	return NULL;
1289}
1290
1291static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
1292				__be32 *p, void *obj)
1293{
1294	struct xdr_stream xdr;
1295
1296	xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p);
1297	encode(rqstp, &xdr, obj);
1298}
1299
1300static inline int
1301gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1302		   kxdreproc_t encode, struct rpc_rqst *rqstp,
1303		   __be32 *p, void *obj)
1304{
1305	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
1306	struct xdr_buf	integ_buf;
1307	__be32          *integ_len = NULL;
1308	struct xdr_netobj mic;
1309	u32		offset;
1310	__be32		*q;
1311	struct kvec	*iov;
1312	u32             maj_stat = 0;
1313	int		status = -EIO;
1314
1315	integ_len = p++;
1316	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1317	*p++ = htonl(rqstp->rq_seqno);
1318
1319	gss_wrap_req_encode(encode, rqstp, p, obj);
1320
1321	if (xdr_buf_subsegment(snd_buf, &integ_buf,
1322				offset, snd_buf->len - offset))
1323		return status;
1324	*integ_len = htonl(integ_buf.len);
1325
1326	/* guess whether we're in the head or the tail: */
1327	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1328		iov = snd_buf->tail;
1329	else
1330		iov = snd_buf->head;
1331	p = iov->iov_base + iov->iov_len;
1332	mic.data = (u8 *)(p + 1);
1333
1334	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1335	status = -EIO; /* XXX? */
1336	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1337		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1338	else if (maj_stat)
1339		return status;
1340	q = xdr_encode_opaque(p, NULL, mic.len);
1341
1342	offset = (u8 *)q - (u8 *)p;
1343	iov->iov_len += offset;
1344	snd_buf->len += offset;
1345	return 0;
1346}
1347
1348static void
1349priv_release_snd_buf(struct rpc_rqst *rqstp)
1350{
1351	int i;
1352
1353	for (i=0; i < rqstp->rq_enc_pages_num; i++)
1354		__free_page(rqstp->rq_enc_pages[i]);
1355	kfree(rqstp->rq_enc_pages);
1356}
1357
1358static int
1359alloc_enc_pages(struct rpc_rqst *rqstp)
1360{
1361	struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1362	int first, last, i;
1363
1364	if (snd_buf->page_len == 0) {
1365		rqstp->rq_enc_pages_num = 0;
1366		return 0;
1367	}
1368
1369	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1370	last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1371	rqstp->rq_enc_pages_num = last - first + 1 + 1;
1372	rqstp->rq_enc_pages
1373		= kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1374				GFP_NOFS);
1375	if (!rqstp->rq_enc_pages)
1376		goto out;
1377	for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1378		rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1379		if (rqstp->rq_enc_pages[i] == NULL)
1380			goto out_free;
1381	}
1382	rqstp->rq_release_snd_buf = priv_release_snd_buf;
1383	return 0;
1384out_free:
1385	rqstp->rq_enc_pages_num = i;
1386	priv_release_snd_buf(rqstp);
1387out:
1388	return -EAGAIN;
1389}
1390
1391static inline int
1392gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1393		  kxdreproc_t encode, struct rpc_rqst *rqstp,
1394		  __be32 *p, void *obj)
1395{
1396	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
1397	u32		offset;
1398	u32             maj_stat;
1399	int		status;
1400	__be32		*opaque_len;
1401	struct page	**inpages;
1402	int		first;
1403	int		pad;
1404	struct kvec	*iov;
1405	char		*tmp;
1406
1407	opaque_len = p++;
1408	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1409	*p++ = htonl(rqstp->rq_seqno);
1410
1411	gss_wrap_req_encode(encode, rqstp, p, obj);
1412
1413	status = alloc_enc_pages(rqstp);
1414	if (status)
1415		return status;
1416	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1417	inpages = snd_buf->pages + first;
1418	snd_buf->pages = rqstp->rq_enc_pages;
1419	snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
1420	/*
1421	 * Give the tail its own page, in case we need extra space in the
1422	 * head when wrapping:
1423	 *
1424	 * call_allocate() allocates twice the slack space required
1425	 * by the authentication flavor to rq_callsize.
1426	 * For GSS, slack is GSS_CRED_SLACK.
1427	 */
1428	if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
1429		tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
1430		memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
1431		snd_buf->tail[0].iov_base = tmp;
1432	}
1433	maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
1434	/* slack space should prevent this ever happening: */
1435	BUG_ON(snd_buf->len > snd_buf->buflen);
1436	status = -EIO;
1437	/* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
1438	 * done anyway, so it's safe to put the request on the wire: */
1439	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1440		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1441	else if (maj_stat)
1442		return status;
1443
1444	*opaque_len = htonl(snd_buf->len - offset);
1445	/* guess whether we're in the head or the tail: */
1446	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1447		iov = snd_buf->tail;
1448	else
1449		iov = snd_buf->head;
1450	p = iov->iov_base + iov->iov_len;
1451	pad = 3 - ((snd_buf->len - offset - 1) & 3);
1452	memset(p, 0, pad);
1453	iov->iov_len += pad;
1454	snd_buf->len += pad;
1455
1456	return 0;
1457}
1458
1459static int
1460gss_wrap_req(struct rpc_task *task,
1461	     kxdreproc_t encode, void *rqstp, __be32 *p, void *obj)
1462{
1463	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1464	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1465			gc_base);
1466	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1467	int             status = -EIO;
1468
1469	dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid);
1470	if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1471		/* The spec seems a little ambiguous here, but I think that not
1472		 * wrapping context destruction requests makes the most sense.
1473		 */
1474		gss_wrap_req_encode(encode, rqstp, p, obj);
1475		status = 0;
1476		goto out;
1477	}
1478	switch (gss_cred->gc_service) {
1479	case RPC_GSS_SVC_NONE:
1480		gss_wrap_req_encode(encode, rqstp, p, obj);
1481		status = 0;
1482		break;
1483	case RPC_GSS_SVC_INTEGRITY:
1484		status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj);
1485		break;
1486	case RPC_GSS_SVC_PRIVACY:
1487		status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj);
1488		break;
1489	}
1490out:
1491	gss_put_ctx(ctx);
1492	dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status);
1493	return status;
1494}
1495
1496static inline int
1497gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1498		struct rpc_rqst *rqstp, __be32 **p)
1499{
1500	struct xdr_buf	*rcv_buf = &rqstp->rq_rcv_buf;
1501	struct xdr_buf integ_buf;
1502	struct xdr_netobj mic;
1503	u32 data_offset, mic_offset;
1504	u32 integ_len;
1505	u32 maj_stat;
1506	int status = -EIO;
1507
1508	integ_len = ntohl(*(*p)++);
1509	if (integ_len & 3)
1510		return status;
1511	data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1512	mic_offset = integ_len + data_offset;
1513	if (mic_offset > rcv_buf->len)
1514		return status;
1515	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1516		return status;
1517
1518	if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1519				mic_offset - data_offset))
1520		return status;
1521
1522	if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1523		return status;
1524
1525	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1526	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1527		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1528	if (maj_stat != GSS_S_COMPLETE)
1529		return status;
1530	return 0;
1531}
1532
1533static inline int
1534gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1535		struct rpc_rqst *rqstp, __be32 **p)
1536{
1537	struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1538	u32 offset;
1539	u32 opaque_len;
1540	u32 maj_stat;
1541	int status = -EIO;
1542
1543	opaque_len = ntohl(*(*p)++);
1544	offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1545	if (offset + opaque_len > rcv_buf->len)
1546		return status;
1547	/* remove padding: */
1548	rcv_buf->len = offset + opaque_len;
1549
1550	maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1551	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1552		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1553	if (maj_stat != GSS_S_COMPLETE)
1554		return status;
1555	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1556		return status;
1557
1558	return 0;
1559}
1560
1561static int
1562gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
1563		      __be32 *p, void *obj)
1564{
1565	struct xdr_stream xdr;
1566
1567	xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p);
1568	return decode(rqstp, &xdr, obj);
1569}
1570
1571static int
1572gss_unwrap_resp(struct rpc_task *task,
1573		kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj)
1574{
1575	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1576	struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1577			gc_base);
1578	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1579	__be32		*savedp = p;
1580	struct kvec	*head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
1581	int		savedlen = head->iov_len;
1582	int             status = -EIO;
1583
1584	if (ctx->gc_proc != RPC_GSS_PROC_DATA)
1585		goto out_decode;
1586	switch (gss_cred->gc_service) {
1587	case RPC_GSS_SVC_NONE:
1588		break;
1589	case RPC_GSS_SVC_INTEGRITY:
1590		status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
1591		if (status)
1592			goto out;
1593		break;
1594	case RPC_GSS_SVC_PRIVACY:
1595		status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
1596		if (status)
1597			goto out;
1598		break;
1599	}
1600	/* take into account extra slack for integrity and privacy cases: */
1601	cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
1602						+ (savedlen - head->iov_len);
1603out_decode:
1604	status = gss_unwrap_req_decode(decode, rqstp, p, obj);
1605out:
1606	gss_put_ctx(ctx);
1607	dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
1608			status);
1609	return status;
1610}
1611
1612static const struct rpc_authops authgss_ops = {
1613	.owner		= THIS_MODULE,
1614	.au_flavor	= RPC_AUTH_GSS,
1615	.au_name	= "RPCSEC_GSS",
1616	.create		= gss_create,
1617	.destroy	= gss_destroy,
1618	.lookup_cred	= gss_lookup_cred,
1619	.crcreate	= gss_create_cred,
1620	.pipes_create	= gss_pipes_dentries_create,
1621	.pipes_destroy	= gss_pipes_dentries_destroy,
1622};
1623
1624static const struct rpc_credops gss_credops = {
1625	.cr_name	= "AUTH_GSS",
1626	.crdestroy	= gss_destroy_cred,
1627	.cr_init	= gss_cred_init,
1628	.crbind		= rpcauth_generic_bind_cred,
1629	.crmatch	= gss_match,
1630	.crmarshal	= gss_marshal,
1631	.crrefresh	= gss_refresh,
1632	.crvalidate	= gss_validate,
1633	.crwrap_req	= gss_wrap_req,
1634	.crunwrap_resp	= gss_unwrap_resp,
1635};
1636
1637static const struct rpc_credops gss_nullops = {
1638	.cr_name	= "AUTH_GSS",
1639	.crdestroy	= gss_destroy_nullcred,
1640	.crbind		= rpcauth_generic_bind_cred,
1641	.crmatch	= gss_match,
1642	.crmarshal	= gss_marshal,
1643	.crrefresh	= gss_refresh_null,
1644	.crvalidate	= gss_validate,
1645	.crwrap_req	= gss_wrap_req,
1646	.crunwrap_resp	= gss_unwrap_resp,
1647};
1648
1649static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
1650	.upcall		= rpc_pipe_generic_upcall,
1651	.downcall	= gss_pipe_downcall,
1652	.destroy_msg	= gss_pipe_destroy_msg,
1653	.open_pipe	= gss_pipe_open_v0,
1654	.release_pipe	= gss_pipe_release,
1655};
1656
1657static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
1658	.upcall		= rpc_pipe_generic_upcall,
1659	.downcall	= gss_pipe_downcall,
1660	.destroy_msg	= gss_pipe_destroy_msg,
1661	.open_pipe	= gss_pipe_open_v1,
1662	.release_pipe	= gss_pipe_release,
1663};
1664
1665static __net_init int rpcsec_gss_init_net(struct net *net)
1666{
1667	return gss_svc_init_net(net);
1668}
1669
1670static __net_exit void rpcsec_gss_exit_net(struct net *net)
1671{
1672	gss_svc_shutdown_net(net);
1673}
1674
1675static struct pernet_operations rpcsec_gss_net_ops = {
1676	.init = rpcsec_gss_init_net,
1677	.exit = rpcsec_gss_exit_net,
1678};
1679
1680/*
1681 * Initialize RPCSEC_GSS module
1682 */
1683static int __init init_rpcsec_gss(void)
1684{
1685	int err = 0;
1686
1687	err = rpcauth_register(&authgss_ops);
1688	if (err)
1689		goto out;
1690	err = gss_svc_init();
1691	if (err)
1692		goto out_unregister;
1693	err = register_pernet_subsys(&rpcsec_gss_net_ops);
1694	if (err)
1695		goto out_svc_exit;
1696	rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
1697	return 0;
1698out_svc_exit:
1699	gss_svc_shutdown();
1700out_unregister:
1701	rpcauth_unregister(&authgss_ops);
1702out:
1703	return err;
1704}
1705
1706static void __exit exit_rpcsec_gss(void)
1707{
1708	unregister_pernet_subsys(&rpcsec_gss_net_ops);
1709	gss_svc_shutdown();
1710	rpcauth_unregister(&authgss_ops);
1711	rcu_barrier(); /* Wait for completion of call_rcu()'s */
1712}
1713
1714MODULE_LICENSE("GPL");
1715module_param_named(expired_cred_retry_delay,
1716		   gss_expired_cred_retry_delay,
1717		   uint, 0644);
1718MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until "
1719		"the RPC engine retries an expired credential");
1720
1721module_init(init_rpcsec_gss)
1722module_exit(exit_rpcsec_gss)