#include <fcntl.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "l2tp_linux.h"
#include "l2tpd.h"
#include "l2tp_tunnel.h"

/*
 * make_nonblock: turn on O_NONBLOCK for fd while keeping its other file
 * status flags.  Failures are reported with perror() and otherwise
 * ignored; if the current flags cannot be read, fd is left untouched.
 */
static void make_nonblock(int fd)
{
	int fl = fcntl(fd, F_GETFL);
	if (-1 == fl) {
		/* Stop here: continuing would pass -1 | O_NONBLOCK to F_SETFL,
		 * i.e. request every flag bit. */
		perror("l2tp_tunnel_t: fcntl(F_GETFL)");
		return;
	}
	/* perror() appends ": <reason>\n" itself, so no newline in the tag. */
	if (-1 == fcntl(fd, F_SETFL, fl | O_NONBLOCK))
		perror("l2tp_tunnel_t: fcntl(F_SETFL)");
}

/*
 * l2tp_peer_t: set up the sockets shared by all tunnels to one remote peer.
 *
 *  parent - owning daemon; this peer shares parent->l2tp_tunnels.
 *  local  - address the UDP socket is bound to (copied into local_sin).
 *  remote - address the UDP socket is connected to (copied into remote_sin).
 *  secret - optional shared secret, may be NULL.  Only the pointer is
 *           stored, not a copy.  NOTE(review): the caller presumably keeps
 *           the string alive for the lifetime of this peer -- confirm.
 *
 * Creates a non-blocking UDP socket (SO_REUSEADDR, bound to `local`,
 * connected to `remote`), then a non-blocking AF_L2TP socket bound with
 * that UDP socket as sl_rx_sfd, and registers the AF_L2TP socket for read
 * events.  Any failure prints a diagnostic and exits with status 1.
 */
l2tp_peer_t::l2tp_peer_t(l2tpd_t *parent, struct sockaddr_in *local, struct sockaddr_in *remote, const char *secret)
{
	l2tpd = parent;
	l2tp_tunnels = parent->l2tp_tunnels;

	udp_fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if (udp_fd < 0) {
		perror("socket(l2tp_tunnel_t)");
		exit(1);
	}

	if (secret) {
		m_secret = secret;
		m_secret_len = strlen(m_secret);
	} else {
		m_secret = NULL;
		m_secret_len = 0;
	}
	make_nonblock(udp_fd);

	/* NOTE(review): local_sin keeps the address as requested; it is not
	 * refreshed from getsockname() below, so a requested port of 0 stays 0. */
	local_sin = *local;
	remote_sin = *remote;

	int one = 1;
	if (setsockopt(udp_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
		perror("l2tpd_t:setsockopt(udp_fd, SO_REUSEADDR)");
		exit(1);
	}

	/* remote and local are printed by separate calls because inet_ntoa()
	 * returns a static buffer. */
	fprintf(stderr, "remote: %s:%d ",
		inet_ntoa(remote->sin_addr),
		ntohs(remote->sin_port));

	if (bind(udp_fd, (struct sockaddr *)local, sizeof(*local)) < 0) {
		perror("l2tp_peer_t::l2tp_peer_t: bind(udp)");
		exit(1);
	}

	if (connect(udp_fd, (struct sockaddr *)remote, sizeof(*remote)) < 0) {
		perror("connect(l2tp_tunnel_t)");
		exit(1);
	}

	/* Ask the kernel which local address the socket actually ended up
	 * with; kept in its own variable instead of repointing `local`. */
	struct sockaddr_in bound;
	socklen_t bound_len = sizeof(bound);
	if (getsockname(udp_fd, (struct sockaddr *)&bound, &bound_len) < 0) {
		perror("getsockname(l2tp_tunnel_t)");
		exit(1);
	}

	fprintf(stderr, "local: %s:%d\n",
		inet_ntoa(bound.sin_addr),
		ntohs(bound.sin_port));

	/* create and bind l2tp to this new udp socket */
	l2tp_fd = socket(AF_L2TP, SOCK_DGRAM, 0);
	if (l2tp_fd < 0) {
		perror("socket(l2tpd_t: l2tp)");
		exit(1);
	}

	make_nonblock(l2tp_fd);

	/* Receive side only: the UDP socket is passed as sl_rx_sfd with no
	 * transmit socket and tunnel/session ids of 0. */
	struct sockaddr_l2tp l2tp_sa;
	memset(&l2tp_sa, 0, sizeof(l2tp_sa));
	l2tp_sa.sl_family = AF_L2TP;
	l2tp_sa.sl_rx_sfd = udp_fd;
	l2tp_sa.sl_tx_sfd = -1;
	l2tp_sa.sl_tunnel = htons(0);
	l2tp_sa.sl_session = htons(0);
	if (bind(l2tp_fd, (struct sockaddr *)&l2tp_sa, sizeof(l2tp_sa))) {
		perror("l2tp_peer_t: bind(l2tp)");
		exit(1);
	}

	this->SelectSetEvents(l2tp_fd, SEL_READ);
}

/*
 * make_tunnel: allocate a new tunnel object for this peer, constructed
 * with tunnel id 0, and hand it back to the caller.
 *
 * NOTE(review): id 0 looks like "not assigned yet" (handle_packet treats
 * slot 0 as a tunnel that is still opening) -- confirm against the
 * l2tp_tunnel_t constructor.
 */
l2tp_tunnel_t *l2tp_peer_t::make_tunnel(void)
{
	l2tp_tunnel_t *fresh = new l2tp_tunnel_t(this, 0);
	return fresh;
}

/*
 * alloc_tunnel_id: create a new tunnel in the kernel on this peer's UDP
 * socket and record `tunnel` in the shared table under the id it got.
 *
 * The bind() requests tunnel id 0 with udp_fd as the transmit socket, and
 * the id actually assigned is read back with getsockname().
 * NOTE(review): this presumes the AF_L2TP family treats id 0 as "pick one
 * for me" -- confirm against l2tp_linux.h / the kernel module.
 *
 * Returns the new tunnel id, or -1 if either kernel call fails.
 */
int l2tp_peer_t::alloc_tunnel_id(l2tp_tunnel_t *tunnel)
{
	struct sockaddr_l2tp l2tp_sa;
	socklen_t sa_len = sizeof(l2tp_sa);

	// Create the new tunnel in the kernel
	memset(&l2tp_sa, 0, sizeof(l2tp_sa));
	l2tp_sa.sl_family = AF_L2TP;
	l2tp_sa.sl_rx_sfd = -1;
	l2tp_sa.sl_tx_sfd = udp_fd;
	l2tp_sa.sl_tunnel = htons(0);
	l2tp_sa.sl_session = htons(0);
	l2tp_sa.sl_peer_tunnel = htons(0);
	l2tp_sa.sl_peer_session = htons(0);
	if (bind(l2tp_fd, (struct sockaddr *)&l2tp_sa, sizeof(l2tp_sa))) {
		perror("l2tp_peer_t::alloc_tunnel_id: bind(l2tp)");
		return -1;
	}

	/* Report failure the same way as bind() above rather than exiting the
	 * daemon.  NOTE(review): the kernel tunnel created by bind() cannot be
	 * torn down here because its id is unknown. */
	if (getsockname(l2tp_fd, (struct sockaddr *)&l2tp_sa, &sa_len) < 0) {
		perror("l2tp_peer_t::alloc_tunnel_id: getsockname(l2tp)");
		return -1;
	}

	u16 id = ntohs(l2tp_sa.sl_tunnel);

	/* The slot is overwritten regardless; the message flags a stale entry. */
	if (l2tp_tunnels[id])
		fprintf(stderr, "eek!  tunnel %d already busy!\n", id);

	l2tp_tunnels[id] = tunnel;
	return id;
}

/*
 * remove_tunnel_id: free table slot `id` if it is held by `tunnel`, then
 * ask the kernel to destroy that tunnel.
 *
 * The kernel request is a connect() on l2tp_fd with both socket fds set to
 * -1 and sl_tunnel set to tunnel->tunnel_id.  NOTE(review): this presumes
 * AF_L2TP interprets that as "destroy tunnel" -- confirm in l2tp_linux.h.
 *
 * An out-of-range id, a NULL tunnel, or a slot owned by a different tunnel
 * is reported on stderr and otherwise ignored.
 */
void l2tp_peer_t::remove_tunnel_id(int id, l2tp_tunnel_t *tunnel)
{
	/* The table is indexed by 16-bit tunnel ids (slots 0..65535); reject
	 * anything else before indexing, and never dereference a NULL tunnel
	 * (an empty slot would otherwise compare equal to NULL). */
	if (id < 0 || id > 0xffff || !tunnel || l2tp_tunnels[id] != tunnel) {
		fprintf(stderr, "l2tp_peer_t::remove_tunnel_id -- BUG!\n");
		return;
	}

	struct sockaddr_l2tp l2tp_sa;

	l2tp_tunnels[id] = NULL;

	// Destroy the tunnel in the kernel
	memset(&l2tp_sa, 0, sizeof(l2tp_sa));
	l2tp_sa.sl_family = AF_L2TP;
	l2tp_sa.sl_rx_sfd = -1;
	l2tp_sa.sl_tx_sfd = -1;
	l2tp_sa.sl_tunnel = htons(tunnel->tunnel_id);
	l2tp_sa.sl_session = htons(0);
	l2tp_sa.sl_peer_tunnel = htons(0);
	l2tp_sa.sl_peer_session = htons(0);
	if (connect(l2tp_fd, (struct sockaddr *)&l2tp_sa, sizeof(l2tp_sa)))
		perror("l2tp_peer_t::remove_tunnel_id: connect(l2tp)");
}

/*
 * dump_sessions: write this peer's remote and local addresses to cfd,
 * then have every tunnel in the shared table whose peer is this object
 * dump its own sessions, in table-slot order.
 */
void l2tp_peer_t::dump_sessions(ctrlfd_t *cfd)
{
	const unsigned slot_count = 65536;

	/* inet_ntoa() hands back a single static buffer, so each address
	 * must be formatted by its own printf call. */
	cfd->printf("peer(%s:%d) ",
		inet_ntoa(remote_sin.sin_addr),
		ntohs(remote_sin.sin_port));
	cfd->printf("local(%s:%d)\n",
		inet_ntoa(local_sin.sin_addr),
		ntohs(local_sin.sin_port));

	for (unsigned slot = 0; slot < slot_count; ++slot) {
		l2tp_tunnel_t *t = l2tp_tunnels[slot];
		if (t == NULL || t->peer != this)
			continue;
		t->dump_sessions(cfd);
	}
}


//	l2tp_tunnel = new l2tp_tunnel_t(&our_sin, sin, tunnel);
//	l2tp_tunnels[l2tp_tunnel->tunnel_id] = l2tp_tunnel;

/* handle_packet:
 *	this function should only get called for the first incoming UDP 
 *	packet for a tunnel.  Its job is to setup a socket specifically 
 *	for the new tunnel and kick off the tunnel with the packet.  Just 
 *	in case, we also pass on any control packets to established tunnel 
 *	(this can happen if the packet is queued up in the kernel before 
 *	the new tunnel has time to setup the kernel side).
 */
void l2tp_peer_t::handle_packet(char *buf, unsigned size, struct sockaddr_in *sin)
{
	l2tp_tunnel_t	*l2tp_tunnel;
	u16 *data = (u16 *)buf;
	u16 flags, len, tunnel, session;

	if (debug)
		fprintf(stderr, "l2tp_peer_t::handle_packet\n");

	flags = ntohs(*data++);
	if (flags & L2TPF_L)
		len = ntohs(*data++);
	else
		len = size;
	tunnel = ntohs(*data++);
	session = ntohs(*data++);

	l2tp_tunnel = l2tp_tunnels[tunnel];
	if (!l2tp_tunnel) {
#if 0
		if (tunnel || session) {
			fprintf(stderr, "not creating tunnel for %d.%d\n",
				tunnel, session);
			return;
		}
#endif
		// This tunnel doesn't exist.  Maybe we're still opening.
		if (l2tp_tunnels[0])
			l2tp_tunnel = l2tp_tunnels[0];
		else if (!tunnel)
			l2tp_tunnel = make_tunnel();
		else
			l2tp_tunnel = new l2tp_tunnel_t(this, tunnel);
	}

	if (l2tp_tunnel)
		l2tp_tunnel->handle_packet(buf, size, sin);
}

/*
 * find_next_tunnel: iterate this peer's tunnels in table-slot order.
 *
 * With tunnel == NULL the scan starts at slot 0; otherwise it resumes at
 * the slot after tunnel->tunnel_id (truncated to 16 bits).  Returns the
 * first tunnel found whose peer is this object, or NULL once slot 0xffff
 * has been passed.
 */
l2tp_tunnel_t *l2tp_peer_t::find_next_tunnel(l2tp_tunnel_t *tunnel)
{
	unsigned slot = 0;

	if (tunnel != NULL) {
		slot = static_cast<u16>(tunnel->tunnel_id + 1);
		/* Wrapping to 0 means the previous tunnel was in the last slot. */
		if (slot == 0)
			return NULL;
	}

	for (; slot <= 0xffff; ++slot) {
		l2tp_tunnel_t *candidate = l2tp_tunnels[slot];
		if (candidate != NULL && candidate->peer == this)
			return candidate;
	}
	return NULL;
}

/*
 * find_avail_tunnel: return the first of this peer's tunnels (in slot
 * order) for which is_idle() is false; if there is none, create a new one
 * with make_tunnel().
 * NOTE(review): confirm that "not idle" is the intended meaning of
 * "available" here.
 */
l2tp_tunnel_t *l2tp_peer_t::find_avail_tunnel(void)
{
	for (l2tp_tunnel_t *t = find_next_tunnel(NULL); t != NULL;
	     t = find_next_tunnel(t)) {
		if (!t->is_idle())
			return t;
	}

	return make_tunnel();
}

