/* new_l2tp_k.c
 *	Copyright 2004 Benjamin LaHaise.  All Rights Reserved.
 *	Kernel side of a simple L2TP implementation for use with Babylon.
 *
 *	Portions copied from net/ipv4/raw.c -	Alan Cox, David S. Miller, 
 *						Ross Biro, Fred N. van Kempen
 *
 *	This program is free software; you can redistribute it and/or
 *	modify it under the terms of the GNU General Public License
 *	as published by the Free Software Foundation; either version
 *	2 of the License, or (at your option) any later version.
 */
#include "bab_module.h"

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/init.h>
#include <linux/net.h>
#include <linux/errno.h>
#include <linux/sched.h>
#include <linux/file.h>
#include <linux/vmalloc.h>
#include <linux/list.h>
#include <net/sock.h>
#include <net/ip.h>

#include "aps_if.h"
#include "l2tp_linux.h"
#include "./l2tp_build.h"
#include "vercomp.h"

struct l2tp_info;
struct l2tp_tunnel;
/* Per-session state: one Babylon channel tied to a session id inside a
 * tunnel.  Instances are allocated from session_cachep by setup_session().
 */
struct l2tp_session {
	channel_t		ch;		/* kept first: the channel callbacks cast channel_t * back to the session */
	u16			session_id;	/* local id; index into tunnel->sessions[] */
	u16			peer_session_id;	/* written into outgoing data headers */
	int			recursion;	/* re-entry guard used by l2tp_bab_output() */

	struct l2tp_tunnel	*tunnel;	/* tunnel this session belongs to */
};

/* Per-tunnel state.  The embedded session table makes the structure large;
 * alloc_l2tp_tunnel() obtains it with vmalloc() and put_l2tp_tunnel()
 * releases it with vfree().
 */
struct l2tp_tunnel {
	struct list_head	list;		/* link on l2tp_info.tunnel_list */

	atomic_t		count;		/* reference count; the tunnel is torn down when it reaches zero */
	u16			tunnel_id;	/* local id; index into l2tp_tunnels[] */
	u16			peer_tunnel_id;	/* written into outgoing data headers */
	struct sock		*sk;		/* NOTE(review): no user of this field is visible in this file */

	struct socket		*tx_socket;	/* UDP socket used for transmit; set by l2tp_bind() */
	struct l2tp_info	*l2tp;		/* l2tp socket state that created the tunnel */

#define L2TP_MAX_SESSIONS	65536
	struct l2tp_session	*sessions[L2TP_MAX_SESSIONS];
		/* sessions[0] is the tunnel control session */
};

/* Global table of live tunnels, indexed by local tunnel id.
 * find_get_l2tp_tunnel() never returns the entry in slot 0.
 */
#define L2TP_MAX_TUNNELS	65536
static struct l2tp_tunnel *l2tp_tunnels[L2TP_MAX_TUNNELS];
/* Next candidate id considered by l2tp_alloc_tunnel_id(). */
static unsigned next_l2tp_tunnel_id = 1;

/* State attached to each AF_L2TP socket (stored in sk->sk_protinfo by
 * l2tp_create()).
 */
struct l2tp_info {
	struct list_head	tunnel_list;	/* tunnels created through this socket */
	struct sock		*sk;	/* used for l2tp packets to userspace */

	struct socket		*rx_socket;	/* UDP socket whose data_ready hook feeds this module; set by l2tp_bind() */
	struct sockaddr_l2tp	l2tp_id;	/* session id */
};

/* Counts of live tunnels and l2tp_info objects; printed at module unload. */
static int num_tunnels, num_infos;	/* debug */

/* Slab cache for struct l2tp_session, created in l2tp_init(). */
static kmem_cache_t *session_cachep;

/* l2tp_alloc_tunnel_id
 *	Reserve a slot in l2tp_tunnels[] for @tunnel and record the chosen id
 *	in tunnel->tunnel_id.
 *
 *	hint > 0:  claim exactly that id; fails when it is out of range or
 *		   already in use.
 *	hint <= 0: take the next free id, scanning round-robin starting at
 *		   next_l2tp_tunnel_id.
 *
 *	Id 0 is never handed out because find_get_l2tp_tunnel() refuses to
 *	look it up.
 *
 *	Returns 0 on success, -EBUSY when no id could be reserved.
 */
int l2tp_alloc_tunnel_id(struct l2tp_tunnel *tunnel, int hint)
{
	unsigned id;
	int i;

	if (hint > 0) {
		if (hint >= L2TP_MAX_TUNNELS || l2tp_tunnels[hint]) {
			pr_debug("unabled to alloc requested l2tp tunnel id (0x%04x)\n", hint);
			return -EBUSY;
		}
		tunnel->tunnel_id = hint;
		l2tp_tunnels[hint] = tunnel;
		return 0;
	}

	/* Ids 1 .. L2TP_MAX_TUNNELS-1 are usable, so that many probes visit
	 * every candidate slot exactly once.
	 */
	for (i = 1; i < L2TP_MAX_TUNNELS; i++) {
		id = next_l2tp_tunnel_id++;
		if (next_l2tp_tunnel_id >= L2TP_MAX_TUNNELS)
			next_l2tp_tunnel_id = 1;	/* wrap around, skipping 0 */

		if (!l2tp_tunnels[id]) {
			tunnel->tunnel_id = id;
			l2tp_tunnels[id] = tunnel;
			return 0;
		}
	}

	pr_debug("unabled to alloc l2tp tunnel id\n");
	return -EBUSY;
}

/* alloc_l2tp_tunnel
 *	Create a zeroed tunnel owned by @l2tp, give it an id (see
 *	l2tp_alloc_tunnel_id() for the meaning of @hint) and link it on
 *	l2tp->tunnel_list.  The new tunnel starts with a reference count of 1.
 *	Returns NULL when memory or a tunnel id cannot be obtained.
 */
static struct l2tp_tunnel *alloc_l2tp_tunnel(struct l2tp_info *l2tp, int hint)
{
	struct l2tp_tunnel *t = vmalloc(sizeof(struct l2tp_tunnel));

	if (t == NULL)
		return NULL;
	memset(t, 0, sizeof(struct l2tp_tunnel));

	atomic_set(&t->count, 1);
	t->l2tp = l2tp;

	if (l2tp_alloc_tunnel_id(t, hint) != 0) {
		vfree(t);
		return NULL;
	}

	num_tunnels++;
	list_add(&t->list, &l2tp->tunnel_list);
	pr_debug("alloc'd l2tp tunnel[0x%04x]\n", t->tunnel_id);
	return t;
}

/* Take an additional reference on tunnel x; dropped with put_l2tp_tunnel(). */
#define get_l2tp_tunnel(x)	(atomic_inc(&(x)->count))

/* find_get_l2tp_tunnel
 *	Look up the tunnel with local id @hint.  On a hit the tunnel's
 *	reference count is raised and the caller must balance it with
 *	put_l2tp_tunnel().  Ids <= 0 or >= L2TP_MAX_TUNNELS, and empty
 *	slots, yield NULL.
 */
static struct l2tp_tunnel *find_get_l2tp_tunnel(int hint)
{
	struct l2tp_tunnel *found = NULL;

	if (hint > 0 && hint < L2TP_MAX_TUNNELS) {
		found = l2tp_tunnels[hint];
		if (found != NULL)
			get_l2tp_tunnel(found);
	}
	return found;
}

static void free_l2tp_session(struct l2tp_session *session);
static void put_l2tp_tunnel(struct l2tp_tunnel *tunnel)
{
	int i;
	if (!atomic_dec_and_test(&tunnel->count))
		return;

	pr_debug("freeing l2tp_tunnel: %p\n", tunnel);
	l2tp_tunnels[tunnel->tunnel_id] = NULL;
	list_del(&tunnel->list);

	if (tunnel->tx_socket) {
		fput(tunnel->tx_socket->file);	/* sockfd_put */
		tunnel->tx_socket = NULL;
	}

	for (i=0; i<65536; i++) {
		if (tunnel->sessions[i])
			free_l2tp_session(tunnel->sessions[i]);
	}

	vfree(tunnel);
	num_tunnels --;
}

/* alloc_l2tp_info
 *	Allocate and zero the per-socket l2tp state for @sk, with an empty
 *	tunnel list.  Returns NULL when the allocation fails.
 */
struct l2tp_info *alloc_l2tp_info(struct sock *sk)
{
	struct l2tp_info *info;

	info = kmalloc(sizeof(*info), GFP_KERNEL);
	if (info == NULL)
		return NULL;

	memset(info, 0, sizeof(struct l2tp_info));
	info->sk = sk;
	INIT_LIST_HEAD(&info->tunnel_list);

	num_infos++;
	return info;
}

/* free_l2tp_info
 *	Drop one reference on every tunnel linked on @l2tp and free the
 *	l2tp_info itself.
 *
 *	NOTE(review): a tunnel that still has other references survives the
 *	put below and keeps pointing at the freed l2tp_info through
 *	tunnel->l2tp and its list linkage -- confirm no such reference can
 *	outlive the socket.
 */
void free_l2tp_info(struct l2tp_info *l2tp)
{
	struct list_head *pos, *next;

	pr_debug("free_l2tp_info(%p)\n", l2tp);
	/* the _safe variant is required: put_l2tp_tunnel() may unlink pos */
	list_for_each_safe(pos, next, &l2tp->tunnel_list) {
		/* list_entry() stays correct wherever 'list' sits in the struct */
		struct l2tp_tunnel *tunnel = list_entry(pos, struct l2tp_tunnel, list);

		pr_debug("free_l2tp_info: freeing tunnel %p\n", tunnel);
		put_l2tp_tunnel(tunnel);
	}
	kfree(l2tp);
	num_infos --;
}

/* Return the l2tp_info stored in sk->sk_protinfo (set by l2tp_create(),
 * cleared by l2tp_release()); may be NULL.
 */
static inline struct l2tp_info *sk_l2tp_info(struct sock *sk)
{
	return sk->sk_protinfo;
}

/* skb_pull4
 *	Consume a 32-bit big-endian value from the front of @skb and return
 *	it in host order.  Returns ~0 and leaves the skb untouched when the
 *	buffer has no data pointer or holds fewer than 4 bytes.
 */
u32 skb_pull4(struct sk_buff *skb)
{
	u32 val;
	unsigned char *data = skb->data;

	if (!data)
		return ~0;

	/* check before reading so a short buffer is never parsed */
	if (skb->len < 4) {
		printk("skb_pull4: short\n");
		return ~0;
	}

	val = data[0];
	val <<= 8;
	val |= data[1];
	val <<= 8;
	val |= data[2];
	val <<= 8;
	val |= data[3];

	skb_pull(skb, 4);
	return val;
}

u16 skb_pull2(struct sk_buff *skb)
{
	u16 val;
	unsigned char *data = skb->data;

	val = data[0];
	val <<= 8;
	val |= data[1];

	if (!skb_pull(skb, 2))
		printk("skb_pull2: short\n");
	return val;
}

/* l2tp_ReInput
 *	Channel ReInput callback.  Prepends a 6-byte L2TP data header (flags
 *	0x0002, local tunnel id, local session id, all big-endian) to @skb
 *	and queues it on the owning l2tp socket for userspace.  The skb is
 *	freed if it cannot be queued.
 */
void l2tp_ReInput(channel_t *ch, struct sk_buff *skb)
{
	struct l2tp_session *sess = (void *)ch;
	struct l2tp_tunnel *tun = sess->tunnel;
	struct l2tp_info *l2tp = tun->l2tp;
	u8 *hdr;

	pr_debug("l2tp_ReInput(%p/%p)\n", skb, l2tp);

	hdr = skb_push(skb, 6);
	hdr[0] = 0x00;
	hdr[1] = 0x02;
	hdr[2] = tun->tunnel_id >> 8;
	hdr[3] = tun->tunnel_id;
	hdr[4] = sess->session_id >> 8;
	hdr[5] = sess->session_id;

	if (sock_queue_rcv_skb(l2tp->sk, skb) != 0)
		kfree_skb(skb);
}

/* __l2tp_data_ready
 *	Take one datagram off the UDP receive socket @sk and dispatch it.
 *	Reached through the sk_data_ready hook installed by l2tp_bind().
 *
 *	Control packets (T bit set) and data packets whose tunnel or session
 *	is unknown here are queued on the l2tp socket, UDP header stripped,
 *	for userspace.  Data packets for a known session have their L2TP
 *	header removed and are handed to ch_Input().  Packets that are too
 *	short for the header their flags announce are discarded.
 *
 *	Returns 0 when no datagram was available or the payload went to
 *	ch_Input(), 1 when the packet was queued to userspace or discarded.
 */
static int __l2tp_data_ready(struct sock *sk)
{
	struct l2tp_info *l2tp = sk->sk_user_data;
	struct l2tp_tunnel *tunnel;
	struct l2tp_session *session;
	struct sk_buff *skb;
	u16 flags, len, tunnel_id, session_id, offset = 0;
	u16 *data;
	unsigned int hdr_len;
	int pull_len;
	int err;

	pr_debug("__l2tp_data_ready(%p)\n", l2tp);

	skb = skb_recv_datagram(sk, 0, 1, &err);
	if (NULL == skb)
		return 0;

	skb_orphan(skb);
	skb->sk = NULL;
	skb->ip_summed = CHECKSUM_NONE;

	/* The UDP header and the 16-bit L2TP flags word must be present in
	 * the linear area before anything is parsed.
	 */
	if (!pskb_may_pull(skb, 8 + 2)) {
		pr_debug("__l2tp_data_ready: runt packet (skb->len = %u)\n", skb->len);
		goto discard;
	}

	/* remove the UDP header from skb */
	skb_pull(skb, 8);

	flags = ntohs(*(u16 *)skb->data);
	if (L2TPF_Ver2 != (flags & L2TPF_Ver)) {
		printk(KERN_INFO "l2tp: packet flags (0x%04x)not ver 2?\n", flags);
		goto discard;
	}

	/* Work out how much header the flags announce and make sure it has
	 * really arrived before reading any of it.  Control packets are
	 * parsed by userspace, so only the optional length field matters
	 * for them here.
	 */
	hdr_len = 2;				/* flags */
	if (flags & L2TPF_L)
		hdr_len += 2;			/* length */
	if (!(flags & L2TPF_T)) {
		hdr_len += 4;			/* tunnel id, session id */
		if (flags & L2TPF_S)
			hdr_len += 4;		/* Ns, Nr */
		if (flags & L2TPF_O)
			hdr_len += 2;		/* offset size */
	}
	if (!pskb_may_pull(skb, hdr_len)) {
		pr_debug("__l2tp_data_ready: truncated header (need %u, skb->len = %u)\n",
			hdr_len, skb->len);
		goto discard;
	}

	/* pskb_may_pull() may have moved the head: fetch the pointer afresh
	 * and step over the flags word already decoded above.
	 */
	data = (u16 *)skb->data + 1;

	if (flags & L2TPF_L)
		len = ntohs(*data++);
	else
		len = skb->len;

	if (skb->len < len) {
		printk("discarding short packet (l2tp = %d, skb->len = %d)\n",
			len, skb->len);
		goto discard;
	}

	if (flags & L2TPF_T) {
		/* control packet -- pass on to userspace */
queue:
		err = sock_queue_rcv_skb(l2tp->sk, skb);
		if (err) {
			pr_debug("__l2tp_data_ready: queue_rcv_skb failed (%d)\n", err);
			goto discard;
		}
		return 1;
	}

	tunnel_id = ntohs(*data++);
	session_id = ntohs(*data++);

	pr_debug("l2tp data packet %d.%d\n", tunnel_id, session_id);

	/* lookup the tunnel -- does it exist?
	 * NOTE(review): nothing here verifies that the tunnel belongs to the
	 * l2tp socket this packet arrived for (tunnel->l2tp == l2tp).
	 */
	tunnel = l2tp_tunnels[tunnel_id % L2TP_MAX_TUNNELS];
	if (!tunnel) {
		pr_debug("__l2tp_data_ready: no such tunnel 0x%04x\n", tunnel_id);
		goto queue;
	}

	if (!session_id) {
		pr_debug("__l2tp_data_ready: data packet for 0 session???\n");
		goto queue;
	}

	session = tunnel->sessions[session_id % L2TP_MAX_SESSIONS];
	if (!session) {
		pr_debug("__l2tp_data_ready: no such session 0x%04x.0x%04x\n",
			tunnel_id, session_id);
		goto queue;
	}

	if (flags & L2TPF_S)
		data += 2;	/* step over Ns and Nr; they are not acted on here */

	if (flags & L2TPF_O)
		offset = ntohs(*data++);

	pull_len = (unsigned char *)data - skb->data;
	pull_len += offset;

	/* The offset comes from the wire: refuse to strip more than we hold
	 * rather than forwarding a packet with its header still attached.
	 */
	if ((unsigned int)pull_len > skb->len) {
		pr_debug("__l2tp_data_ready: bad offset 0x%04x (skb->len = %u)\n",
			offset, skb->len);
		goto discard;
	}

	/* byte count includes the L2TP header, as measured before the pull */
	session->ch.CH_rx_bytes += skb->len;

	skb_pull(skb, pull_len);

	pr_debug("flags = 0x%04x, len = 0x%04x, tunnel = 0x%04x, session = 0x%04x\n",
		flags, len, tunnel_id, session_id);

	/* assuming we have a channel, receive the skb. */
	skb->dev = NULL;
	dst_release(skb->dst);
	skb->dst = NULL;
#ifdef CONFIG_NETFILTER
	nf_conntrack_put(skb->nfct);
	skb->nfct = NULL;
#ifdef CONFIG_NETFILTER_DEBUG
	skb->nf_debug = 0;
#endif
#endif

	ch_Input(&session->ch, skb);
	return 0;

discard:
	skb_free_datagram(sk, skb);
	return 1;
}

/* l2tp_data_ready
 *	sk_data_ready callback installed on the UDP receive socket by
 *	l2tp_bind().  Processes a single datagram per call; @count is unused.
 */
static void l2tp_data_ready(struct sock *sk, int count)
{
	/* hook in case we need to loop for multiple packets */
	__l2tp_data_ready(sk);
}

/* l2tp_connect
 *	connect() handler: record the peer's tunnel id on an existing tunnel
 *	that belongs to this socket.
 *
 *	Returns 0 on success.  Returns -EINVAL when the address is too short,
 *	the socket has no l2tp state, the tunnel named by sl_tunnel does not
 *	exist (id 0 is never found by find_get_l2tp_tunnel()), or the tunnel
 *	belongs to a different l2tp socket.
 *
 *	NOTE(review): sl_tunnel and sl_peer_tunnel are used as given here,
 *	whereas l2tp_bind() converts sl_tunnel with ntohs() -- confirm the
 *	byte order userspace is expected to supply.
 *	NOTE(review): connect() offers no way to tear a tunnel down; tunnels
 *	are released through free_l2tp_info().
 */
int l2tp_connect(struct socket *sock, struct sockaddr *sa, int sockaddr_len, int flags)
{
	struct l2tp_info *l2tp;
	struct l2tp_tunnel *tunnel;
	struct sockaddr_l2tp *sl = (struct sockaddr_l2tp *)sa;

	/* test the sign first: a negative int would be converted to a huge
	 * unsigned value by the comparison against sizeof
	 */
	if (sockaddr_len < 0 || sockaddr_len < sizeof(struct sockaddr_l2tp))
		return -EINVAL;

	if (!sock->sk)
		return -EINVAL;

	l2tp = sk_l2tp_info(sock->sk);
	if (!l2tp)
		return -EINVAL;

	tunnel = find_get_l2tp_tunnel(sl->sl_tunnel);
	if (!tunnel)
		return -EINVAL;

	if (tunnel->l2tp != l2tp) {
		printk("l2tp_connect: l2tp mismatch\n");
		put_l2tp_tunnel(tunnel);
		return -EINVAL;
	}

	tunnel->peer_tunnel_id = sl->sl_peer_tunnel;
	put_l2tp_tunnel(tunnel);	/* drop the lookup reference */

	return 0;
}

static int l2tp_bind(struct socket *sock, struct sockaddr *_sa, int sa_len)
{
	struct l2tp_tunnel *tunnel = NULL;
	struct sockaddr_l2tp *sa = (void *)_sa;
	struct sock *sk = sock->sk;
	struct l2tp_info *l2tp = sk_l2tp_info(sk);
	int err;

	if (sa_len < 0 || sa_len < sizeof(*sa))
		return -EINVAL;

	if (sa->sl_family != AF_L2TP)
		return -EINVAL;

	if (sa->sl_rx_sfd != -1 && sa->sl_session) {
		printk("l2tp_bind: cannot bind session(0x%04x) to rx_sfd(%d)\n",
			sa->sl_session, sa->sl_rx_sfd);
		return -EINVAL;
	}

	if (sa->sl_tx_sfd != -1 && sa->sl_session) {
		printk("l2tp_bind: cannot bind session(0x%04x) to tx_sfd(%d)\n",
			sa->sl_session, sa->sl_tx_sfd);
		return -EINVAL;
	}

	/* subtle sematics: new tunnels will only be created for session 
	 * ids of 0, which is to say for a new control session.
	 */
	tunnel = find_get_l2tp_tunnel(ntohs(sa->sl_tunnel));
	if (!tunnel && !sa->sl_session)
		tunnel = alloc_l2tp_tunnel(l2tp, ntohs(sa->sl_tunnel));

	if (!tunnel) {
		printk("l2tp_bind: no tunnel\n");
		return -EBUSY;
	}
	sa->sl_tunnel = htons(tunnel->tunnel_id);

	if (sa->sl_rx_sfd != -1) {
		err = -EBUSY;
		if (SS_UNCONNECTED != sock->state)
			goto out_err;

		if (l2tp->rx_socket) {
			printk("attempt to bind busy l2tp to rx socket\n");
			goto out_err;
		}

		l2tp->rx_socket = sockfd_lookup(sa->sl_rx_sfd, &err);
		if (l2tp->rx_socket == NULL)
			goto out_err;

		pr_debug("cool: have a socket (%p)->sk (%p)->socket = %p.\n",
			l2tp->rx_socket, l2tp->rx_socket->sk,
			l2tp->rx_socket->sk ? l2tp->rx_socket->sk->socket : (void*)-1);

		l2tp->rx_socket->sk->sk_user_data = l2tp;
		l2tp->rx_socket->sk->sk_data_ready = l2tp_data_ready;
		l2tp->rx_socket->sk->sk_sndbuf = 262144;
		l2tp->rx_socket->sk->sk_rcvbuf = 262144;

		sock->state = SS_CONNECTED;
	}

	if (sa->sl_tx_sfd != -1) {
		err = -EBUSY;
		if (SS_CONNECTED != sock->state) {
			printk("attempt to bind tunnel tx without rx socket\n");
			goto out_err;
		}

		if (tunnel->tx_socket) {
			printk("attempt to bind busy tunnel to tx socket\n");
			goto out_err;
		}

		tunnel->tx_socket = sockfd_lookup(sa->sl_tx_sfd, &err);
		if (tunnel->tx_socket == NULL)
			goto out_err;

		pr_debug("cool: have a tx socket (%p)->sk (%p)->tx_socket = %p.\n",
			tunnel->tx_socket, tunnel->tx_socket->sk,
			tunnel->tx_socket->sk ? tunnel->tx_socket->sk->socket : (void*)-1);

		tunnel->tx_socket->sk->sk_sndbuf = 262144;
		tunnel->tx_socket->sk->sk_rcvbuf = 262144;
	}

	if (SS_CONNECTED != sock->state) {
		printk("l2tp: post-sockets, not connected!\n");
		goto out_err;
	}


	err = -EINVAL;
	if (!sa->sl_session && !l2tp->rx_socket) {
		printk("no session no rx socket\n");
		goto out_err;
	}

	err = -EINVAL;
	if (sa->sl_session && !tunnel->tx_socket) {
		printk("have session no tx socket\n");
		goto out_err;
	}

	err = -EINVAL;
	if (sa->sl_session && l2tp->rx_socket) {
		printk("have session have socket\n");
		goto out_err;
	}

	err = -EBUSY;
	if (tunnel->sessions[sa->sl_session]) {
		printk("session 0x%04x in use\n", sa->sl_session);
		goto out_err;
	}


	memcpy(&l2tp->l2tp_id, sa, sizeof(l2tp->l2tp_id));

	pr_debug("l2tp_bind: okay. sock->sk = %p\n", sock->sk);
	return 0;

out_err:
	if (tunnel)
		put_l2tp_tunnel(tunnel);
	return err;
}

static int l2tp_release(struct socket *sock)
{
	struct l2tp_info *l2tp;
	/* FIXME: release any bound sockets, any children */
	pr_debug("l2tp_release\n");
	if (!sock->sk) {
		printk("l2tp_release: no sock?\n");
		return 0;
	}

	l2tp = sk_l2tp_info(sock->sk);
	if (l2tp && l2tp->rx_socket)	/* sock may not have been bound */
		sock_set_flag(l2tp->rx_socket->sk, SOCK_DEAD);	/* prevent data_ready() */

	sock->sk->sk_protinfo = NULL;

	if (l2tp) {
		struct socket *rx_socket = l2tp->rx_socket;
		l2tp->rx_socket = NULL;
		free_l2tp_info(l2tp);
		if (NULL != rx_socket) {
			pr_debug("releasing rx_socket\n");
			fput(rx_socket->file);	/* sockfd_put */
		}
	}

	sk_free(sock->sk);
	sock->sk = NULL;
	pr_debug("l2tp_release: done\n");
	return 0;
}

/*
 *	Derived from ipv4/raw.c.
 *
 *	Deliver one queued packet to the caller: copy up to @len bytes into
 *	msg->msg_iov, setting MSG_TRUNC when the packet is longer, and fill
 *	in a zeroed AF_L2TP source address when one was requested.
 *
 *	Returns the number of bytes copied, 0 for MSG_ERRQUEUE requests,
 *	-EOPNOTSUPP for MSG_OOB, or the error reported by the datagram
 *	receive/copy helpers.
 */
int __l2tp_recvmsg(struct sock *sk, struct msghdr *msg, int len,
		int noblock, int flags, int *addr_len)
{
	struct sockaddr_l2tp *from = (struct sockaddr_l2tp *)msg->msg_name;
	struct sk_buff *skb;
	int nbytes;
	int err = -EOPNOTSUPP;

	if (flags & MSG_OOB)
		return err;

	if (addr_len)
		*addr_len = sizeof(*from);

	if (flags & MSG_ERRQUEUE) {
		/*err = ip_recv_error(sk, msg, len);*/
		/* FIXME: pass on errors */
		return 0;
	}

	skb = skb_recv_datagram(sk, flags, noblock, &err);
	if (!skb)
		return err;

	nbytes = skb->len;
	if (len < nbytes) {
		msg->msg_flags |= MSG_TRUNC;
		nbytes = len;
	}

	err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, nbytes);
	if (!err) {
		sock_recv_timestamp(msg, sk, skb);

		/* Copy the address. */
		if (from) {
			memset(from, 0, sizeof(*from));
			from->sl_family = AF_L2TP;
			/* FIXME: what's the address? */
		}
	}

	skb_free_datagram(sk, skb);
	return err ? err : nbytes;
}


/* l2tp_recvmsg (derived from af_inet.c)
 *	proto_ops recvmsg entry point; the prototype depends on the kernel
 *	version.  Splits MSG_DONTWAIT out of @flags, delegates to
 *	__l2tp_recvmsg() and, on success, reports the address length through
 *	msg->msg_namelen.
 */
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,0))
static int l2tp_recvmsg(struct kiocb *iocb, struct socket *sock,
			struct msghdr *msg, size_t size,
			int flags)
#else
static int l2tp_recvmsg(struct socket *sock, struct msghdr *msg, int size,
			int flags, struct scm_cookie *scm)
#endif
{
	int addr_len = 0;
	int ret;

	ret = __l2tp_recvmsg(sock->sk, msg, size, flags & MSG_DONTWAIT,
			     flags & ~MSG_DONTWAIT, &addr_len);
	if (ret < 0)
		return ret;

	msg->msg_namelen = addr_len;
	return ret;
}

/* l2tp_getname
 *	getsockname() handler: copy out the address saved by l2tp_bind().
 *	Peer name requests are rejected with -EINVAL.
 */
static int l2tp_getname(struct socket *sock, struct sockaddr *sa,
			int *lenp, int peer)
{
	struct l2tp_info *info = sk_l2tp_info(sock->sk);

	if (peer != 0)
		return -EINVAL;

	*lenp = sizeof(info->l2tp_id);
	memcpy(sa, &info->l2tp_id, sizeof(info->l2tp_id));
	return 0;
}

/* Channel 'use' callback (installed by setup_session()); only logs. */
void l2tp_bab_use(channel_t *ch)
{
	printk("l2tp_bab_use\n");
}

/* Channel 'unuse' callback (installed by setup_session()); only logs. */
void l2tp_bab_unuse(channel_t *ch)
{
	printk("l2tp_bab_unuse\n");
}

/* l2tp_bab_output
 *	Channel Output callback: wrap @skb in UDP and L2TP data headers and
 *	send it through the tunnel's transmit socket.
 *
 *	The 14 bytes pushed in front of the payload are the 8-byte UDP header
 *	(ports taken from the transmit socket, checksum 0) followed by the
 *	6-byte L2TP header (flags 0x0002, peer tunnel id, peer session id).
 *
 *	Always returns 0 and always calls ch.OutputComplete().  The skb is
 *	consumed on every path.  Re-entry on the same session, a missing
 *	transmit socket and allocation failures are counted as errors and
 *	the packet is dropped.
 */
int l2tp_bab_output(channel_t *ch, struct sk_buff *skb)
{
	struct l2tp_session *session = (void *)ch;
	struct socket *tx_socket;
	unsigned int tx_len;
	u16 *data;
	int err;

	/* recursion idea taken from ip_gre */
	if (unlikely(session->recursion++))
		goto tx_error;

	pr_debug("l2tp_bab_output(%p)\n", skb);

	/* a session can exist before a transmit socket has been bound */
	tx_socket = session->tunnel->tx_socket;
	if (unlikely(!tx_socket)) {
		session->ch.stats.tx_dropped++;
		goto tx_error;
	}

	pr_debug("l2tp_bab_output: headroom = %d\n", skb_headroom(skb));
	if (skb_headroom(skb) < 40 || skb_cloned(skb) || skb_shared(skb)) {
		struct sk_buff *newskb = skb_realloc_headroom(skb, 40);

		pr_debug("realloc'd\n");
		if (!newskb) {
			printk("l2tp_bab_output: skb_realloc_headroom failed\n");
			session->ch.stats.tx_dropped++;
			goto tx_error;
		}
		if (skb->sk)
			skb_set_owner_w(newskb, skb->sk);
		dev_kfree_skb(skb);
		skb = newskb;
	}

	skb_orphan(skb);
	pr_debug("memset'd\n");
	memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt));
	pr_debug("dst'd\n");
	dst_release(skb->dst);
	skb->dst = NULL;

	skb->ip_summed = CHECKSUM_NONE;
	pr_debug("l2tp_bab_output: len = %d\n", skb->len);
	data = (u16 *)skb_push(skb, 14);
	pr_debug("l2tp_bab_output: len = %d\n", skb->len);

	skb->h.raw = (u8*)data;
 {
	/* The socket keeps its ports in network byte order, which is what the
	 * UDP header needs: copy them through without conversion.
	 */
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,0))
	struct inet_opt *inet = inet_sk(tx_socket->sk);
	u16 sport = inet->sport;
	u16 dport = inet->dport;
#define my_ip_queue_xmit(x,y)	my_ip_queue_xmit(x,y)
#else
	u16 sport = tx_socket->sk->sport;
	u16 dport = tx_socket->sk->dport;
#define my_ip_queue_xmit(x,y)	ip_queue_xmit(x)
#endif

	*data++ = sport;	// udp source
	*data++ = dport;	// udp dest
 }
	*data++ = htons(skb->len);	// udp len
	*data++ = htons(0x0000);	// udp csum
	*data++ = htons(0x0002);
	*data++ = htons(session->tunnel->peer_tunnel_id);
	*data++ = htons(session->peer_session_id);

	skb_set_owner_w(skb, tx_socket->sk);

	/* the skb belongs to the IP layer once queued; note its length now */
	tx_len = skb->len;
	err = my_ip_queue_xmit(skb, 1);
	if (err)
		printk("l2tp_bab_output: err = %d\n", err);
	else
		session->ch.CH_tx_bytes += tx_len;

	session->recursion--;
	session->ch.OutputComplete(&session->ch);
	return 0;

tx_error:
	pr_debug("tx_error\n");
	session->ch.stats.collisions++;
	dev_kfree_skb(skb);
	session->recursion--;
	pr_debug("tx_error done\n");
	session->ch.OutputComplete(&session->ch);
	return 0;
}

/* Channel Connect callback (installed by setup_session()); not supported:
 * logs and returns -EINVAL.
 */
int l2tp_bab_connect(channel_t *ch, const char *num, u32 flags)
{
	printk("l2tp_bab_connect\n");
	return -EINVAL;
}

/* Channel Hangup callback (installed by setup_session()); not supported:
 * logs and returns -EINVAL.
 */
int l2tp_bab_hangup(channel_t *ch)
{
	printk("l2tp_bab_hangup\n");
	return -EINVAL;
}

/* Channel ioctl callback (installed by setup_session()); not supported:
 * logs and returns -EINVAL.
 */
int l2tp_bab_ioctl(channel_t *ch, unsigned int cmd, unsigned long arg)
{
	printk("l2tp_bab_ioctl\n");
	return -EINVAL;
}

/* setup_session
 *	Create the session @session_id in @tunnel: allocate it from the slab
 *	cache, fill in the channel callbacks, register the channel, bring it
 *	up in the connected state with PPP framing in both directions, and
 *	publish it in tunnel->sessions[].
 *
 *	Returns the new session, or NULL when allocation or channel
 *	registration fails.
 */
static struct l2tp_session *setup_session(struct l2tp_tunnel *tunnel, u16 session_id, u16 peer_session_id)
{
	struct l2tp_session *sess;
	channel_t *chan;
	int rc;

	pr_debug("setup_session\n");
	sess = kmem_cache_alloc(session_cachep, GFP_KERNEL);
	if (sess == NULL)
		return NULL;

	memset(sess, 0, sizeof(struct l2tp_session));
	chan = &sess->ch;

	strcpy(chan->device_name, "l2tp0");
	strcpy(chan->dev_class, "l2tp");

	sess->tunnel = tunnel;
	sess->session_id = session_id;
	sess->peer_session_id = peer_session_id;

	chan->mru = 1442;
	chan->use = l2tp_bab_use;
	chan->unuse = l2tp_bab_unuse;
	chan->Output = l2tp_bab_output;
	chan->Connect = l2tp_bab_connect;
	chan->Hangup = l2tp_bab_hangup;
	chan->ioctl = l2tp_bab_ioctl;
	chan->ReInput = l2tp_ReInput;

	/* the channel stays busy until registration has finished */
	set_busy(chan);

	rc = RegisterChannel(chan);
	if (rc != 0) {
		printk("RegisterChannel: %d\n", rc);
		kmem_cache_free(session_cachep, sess);
		return NULL;
	}
	pr_debug("session setup!\n");

	clear_busy(chan);
	chan->state = CS_CONNECTED;
	chan->Open(chan);
	chan->Up(chan);
	ch_ioctl(NULL, NULL, chan->link, BIOC_SETLCFL, BF_PPP);
	ch_ioctl(NULL, NULL, chan->link, BIOC_SETRCFL, BF_PPP);

	tunnel->sessions[session_id] = sess;
	return sess;
}

/* free_l2tp_session
 *	Unregister the session's channel and return the session to the slab
 *	cache.  The tunnel->sessions[] slot is not cleared here.
 */
static void free_l2tp_session(struct l2tp_session *session)
{
	UnregisterChannel(&session->ch);
	kmem_cache_free(session_cachep, session);
}

/* l2tp_join_bundle
 *	Make sure the session described by @j exists on a tunnel owned by
 *	@l2tp, creating it on first use, then forward @cmd with j->arg to the
 *	channel layer for that session's link.
 *
 *	Returns the ch_ioctl() result, -EINVAL for an unknown or foreign
 *	tunnel, a missing peer tunnel id or a zero session id, and -ENOMEM
 *	when the session cannot be created.
 *
 *	NOTE(review): j->session indexes tunnel->sessions[] directly; this
 *	assumes its type cannot exceed L2TP_MAX_SESSIONS - 1 -- confirm
 *	against struct l2tp_join_bundle.
 */
static int l2tp_join_bundle(struct l2tp_info *l2tp, unsigned int cmd, struct l2tp_join_bundle *j)
{
	struct l2tp_tunnel *tunnel;
	struct l2tp_session *session;
	int ret;

	pr_debug("l2tp_join_bundle\n");
	tunnel = find_get_l2tp_tunnel(j->tunnel);
	if (!tunnel) {
		printk("where's the tunnel(%d)?\n", j->tunnel);
		return -EINVAL;
	}

	/* from here on every exit must give back the lookup reference */
	ret = -EINVAL;
	if (tunnel->l2tp != l2tp) {
		printk("where's the tunnel(%d)?\n", j->tunnel);
		goto out_put;
	}

	if (!tunnel->peer_tunnel_id)
		tunnel->peer_tunnel_id = j->peer_tunnel;

	if (!tunnel->peer_tunnel_id) {
		printk("what's the tunnel(%d)'s peer???\n", j->tunnel);
		goto out_put;
	}

	if (!j->session || !j->peer_session) {
		printk("session not set %d/%d\n", j->session, j->peer_session);
		goto out_put;
	}

	session = tunnel->sessions[j->session];
	if (!session)
		session = setup_session(tunnel, j->session, j->peer_session);

	if (!session) {
		printk("where's the session(%d)?\n", j->session);
		ret = -ENOMEM;
		goto out_put;
	}

	ret = ch_ioctl(NULL, NULL, session->ch.link, cmd, j->arg);

out_put:
	put_l2tp_tunnel(tunnel);
	return ret;
}

static int l2tp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
{
	struct sock *sk = sock->sk;
	struct l2tp_info *l2tp = sk_l2tp_info(sk);
	struct l2tp_join_bundle tmp;

	copy_from_user(&tmp, (void *)arg, sizeof(tmp));
	return l2tp_join_bundle(l2tp, cmd, &tmp);
}

static struct proto_ops l2tp_proto_ops = {
	family:		PF_L2TP,

	mmap:		sock_no_mmap,
	sendpage:	sock_no_sendpage,

	poll:		datagram_poll,

	getname:	l2tp_getname,
	bind:		l2tp_bind,
	connect:	l2tp_connect,
	release:	l2tp_release,
	recvmsg:	l2tp_recvmsg,
	ioctl:		l2tp_ioctl,
};

/* l2tp_create
 *	prepare a new l2tp socket.
 *
 *	Only protocol 0 is accepted (-EINVAL otherwise).  Allocates the sock
 *	and its l2tp_info, initialises the sock, enlarges its buffers and
 *	installs l2tp_proto_ops.  Returns 0 on success or -ENOBUFS when an
 *	allocation fails.
 */
static int l2tp_create(struct socket *sock, int protocol)
{
	struct sock *sk;
	struct l2tp_info *l2tp;

	if (protocol)
		return -EINVAL;

	sk = my_sk_alloc(AF_L2TP, GFP_KERNEL, 1, NULL);
	if (unlikely(!sk))
		return -ENOBUFS;

	l2tp = alloc_l2tp_info(sk);
	if (!l2tp) {
		sk_free(sk);
		return -ENOBUFS;
	}

	sock_init_data(sock, sk);

	/* sock_init_data() resets the buffer limits to the system defaults,
	 * so the larger limits have to be applied after it has run.
	 */
	sk->sk_sndbuf = 262144;
	sk->sk_rcvbuf = 262144;
	sk->sk_protinfo = l2tp;

	sock->ops = &l2tp_proto_ops;

	pr_debug("l2tp_create: good\n");
	return 0;
}

/* Protocol family descriptor registered by l2tp_init(); socket creation for
 * PF_L2TP is routed to l2tp_create().
 */
static struct net_proto_family l2tp_family_ops = {
	family:	PF_L2TP,
	create:	l2tp_create,
};

/* module setup is easy -- just register our protocol family and let er rip.
 *
 * Creates the session slab cache first and destroys it again if the family
 * cannot be registered.  Returns 0 on success, -ENOMEM when the cache
 * cannot be created, or the sock_register() error.
 */
static int __init l2tp_init(void)
{
	int ret;

	session_cachep = kmem_cache_create("l2tp_session", sizeof(struct l2tp_session), 0, 0, NULL, NULL);
	if (!session_cachep) {
		printk(KERN_ERR "l2tp: can't create session slab cache\n");
		return -ENOMEM;
	}

	ret = sock_register(&l2tp_family_ops);
	if (ret) {
		/* newline added so the message is not merged with the next log line */
		printk(KERN_ERR "l2tp: can't register socket family\n");
		kmem_cache_destroy(session_cachep);
		return ret;
	}

	printk(KERN_NOTICE "l2tp is loaded (build %d)\n", (int)L2TP_BUILD);
	return 0;
}

/* cleanup is easy too: unregister the protocol family and make sure all 
 * our data structures are freed.
 *
 * The family is unregistered before the session cache is destroyed.  The
 * final message reports the num_tunnels and num_infos counters.
 */
static void l2tp_cleanup(void)
{
	sock_unregister(AF_L2TP);
	kmem_cache_destroy(session_cachep);
	/* FIXME: assert that all structures are gone */

	printk(KERN_NOTICE "no more l2tp (build %d) (num_tunnels %d  num_infos %d)\n",
		(int)L2TP_BUILD, num_tunnels, num_infos);
}

module_init(l2tp_init);
module_exit(l2tp_cleanup);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Benjamin LaHaise <bcrl@kvack.org>");
MODULE_DESCRIPTION("L2TP/Babylon");
