/* radius.cc
 * Copyright (C) 1997-2000 SpellCaster Telecommunications Inc.
 * $Id: radius.cc,v 1.4 2004/10/18 02:17:38 bcrl Exp $
 * Released under the GNU Public License. See LICENSE file for details.
 */
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <fcntl.h>
#include <unistd.h>

#include "config.h"
#include "radius.h"
#include "md5.h"
#include "babd.h"

/* Interval passed to the request timer's Start().  The "5s" note implies the
 * timer counts in 1/100 s ticks -- NOTE(review): confirm against the timer
 * implementation. */
#define RETRANS_TIME	(5*100)		/* 5s */
/* Upper bound on RadiusReq::m_retrans_count checked by TimerExpired(). */
#define RETRANS_MAX	10

/* Shared secret mixed into every authenticator / password-hiding MD5.
 * NOTE(review): hard-coded to "testing123" and nothing visible here assigns
 * it; radius_secret_len must be kept equal to strlen(radius_secret). */
static char radius_secret[256] = "testing123";
static int radius_secret_len = 10;

/* Authentication servers, linked through RadiusClient::m_next by
 * setup_radius(); RadiusReq::Go(CfgMessage_t) only uses the first one. */
static RadiusClient *first_radius_client;
static RadiusClient *last_radius_client;

/* Accounting servers, same layout; Go(AcctMessage_t *) uses the first. */
static RadiusClient *first_radius_acct_client;
static RadiusClient *last_radius_acct_client;

/*
 * Create a UDP socket connected to the RADIUS server at @sin and register it
 * for read events.  Socket errors are fatal: the process exits.
 *
 * The request table, the next identifier and the list link are initialized
 * explicitly.  QueueReq() and SelectEvent() read m_reqs[] and m_next_ident,
 * and setup_radius() links clients through m_next.  Nothing else shown here
 * sets them for a freshly allocated client.
 */
RadiusClient::RadiusClient(struct sockaddr_in sin)
{
	m_next = NULL;
	m_next_ident = 0;
	for (unsigned i = 0; i < RADIUSCLIENT_N_REQS; i++)
		m_reqs[i] = NULL;

	m_sockfd = socket(sin.sin_family, SOCK_DGRAM, IPPROTO_UDP);
	if (m_sockfd < 0) {
		perror("RadiusClient::RadiusClient(): socket");
		exit(1);
	}

	/* connect() lets us use plain read()/write() and filters replies to
	 * this server's address. */
	if (connect(m_sockfd, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
		perror("RadiusClient::RadiusClient(): connect");
		exit(1);
	}

	this->SelectSetEvents(m_sockfd, SEL_READ);
	fprintf(stderr, "radius started\n");
}

/*
 * the write callback is only installed when we have a packet to transmit, but
 * it was delayed previously.  Hopefully select wakes us up before the request
 * times out.
 */
void RadiusClient::SelectEvent(int fd, SelectEventType event)
{
	if (event & SEL_READ) {
		unsigned char buf[RAD_MAX_PKT_SIZE];
		int l;

		l = read(fd, buf, RAD_MAX_PKT_SIZE);

		if (l > 0 && l < RAD_MAX_PKT_SIZE) {
			unsigned ident = buf[1];
			RadiusReq *req = m_reqs[ident];

			fprintf(stderr, "RadiusClient:SelectEvent %p\n", req);

			if (req) {
				CuscArray *pkt = new CuscArray(buf, (uint)l);
				req->RadiusRsp(pkt);
				delete pkt;
			}
		}
		else if (l < 0)	/* DEBUG */
			perror("read(radius)");
	}

	if (event & SEL_WRITE) {
		/* okay, no more requests, I hope */
		SelectRemoveEvent(fd, SEL_WRITE);
	}
}

/*
 * Give @req a free identifier and transmit it.  The search starts at
 * m_next_ident and probes each of the RADIUSCLIENT_N_REQS slots at most once.
 * If every slot is busy, the request waits on m_queue without an identifier.
 */
void RadiusClient::QueueReq(RadiusReq *req)
{
	unsigned probes = 0;

	while (probes < RADIUSCLIENT_N_REQS && m_reqs[m_next_ident]) {
		m_next_ident = (m_next_ident + 1) % RADIUSCLIENT_N_REQS;
		probes++;
	}

	if (probes == RADIUSCLIENT_N_REQS) {
		/* table full: park it until a slot frees up */
		m_queue.Append(req);
		return;
	}

	m_reqs[m_next_ident] = req;
	req->SetIdent(m_next_ident);
	m_next_ident = (m_next_ident + 1) % RADIUSCLIENT_N_REQS;
	Retrans(req);
}

/*
 * (Re)send @req's packet on the connected socket.  Requests that have not
 * been assigned an identifier yet (still waiting on the queue) are skipped.
 * Called both for the first transmission and from the retransmit timer, so
 * the error message names this function rather than a caller.
 */
void RadiusClient::Retrans(RadiusReq *req)
{
	if (!req->HaveIdent())
		return;
	int err = write(m_sockfd, req->m_pkt.m_start, req->m_pkt.GetLength());
	if (err < 0)
		perror("RadiusClient::Retrans(): write");
}

void RadiusReq::TimerExpired(void)
{
	if (m_retrans_count++ < RETRANS_MAX) {
		m_client->Retrans(this);
	}

	/* If we've already hit our callback, the request must have expired. */
	if (m_cfg_done) {
		if (m_client)
			m_client->RemoveReq(this, m_ident);
		m_client = NULL;
		return;
	}

	/* Damn.  Well, callback so the protocol can report that auth failed. */
	RadiusCallback();
}

/*
 * Report the outcome to the requester and free this request.
 *
 * The completion hook (m_cfg.reserved1, if set) is invoked at most once,
 * guarded by m_cfg_done.  The object is deleted unconditionally, so callers
 * must not touch it after this returns.
 */
void RadiusReq::RadiusCallback(void)
{
	const int first_time = !m_cfg_done;

	m_cfg_done = 1;
	if (first_time && m_cfg.reserved1)
		m_cfg.reserved1(m_cfg.reserved2, &m_cfg.options);

	delete this;
}

/*
 * Handle a reply datagram that SelectEvent() matched to this request by
 * identifier.
 *
 * The Response Authenticator is verified first (RFC 2865 section 3):
 * MD5(Code+ID+Length+RequestAuth+Attributes+Secret).  Datagrams that are
 * short, fail verification or carry malformed attributes are discarded, and
 * the request stays pending on its retransmit timer.
 *
 * A verified reply completes the request: the attributes are copied into
 * m_cfg.options, is_valid is set from the code (Access-Accept only), the
 * timer is stopped and RadiusCallback() reports the result and deletes this.
 */
void RadiusReq::RadiusRsp(CuscArray *rsppkt)
{
	u8 code;
	u8 ident;
	u16 len;
	MD5_CTX md5;
	int rxlen = rsppkt->GetLength();

	/* code(1) + identifier(1) + length(2) + authenticator(16) */
	if (rxlen < 4 + 16) {
		fprintf(stderr, "bad radius response\n");
		return;
	}

	/*
	 * RFC 2865 section 3: octets past the Length field are padding and are
	 * ignored; a datagram shorter than Length is invalid.
	 */
	int pktlen = ((u8)rsppkt->m_start[2] << 8) | (u8)rsppkt->m_start[3];
	if (pktlen < 4 + 16 || pktlen > rxlen) {
		fprintf(stderr, "bad radius response\n");
		return;
	}
	int trailing = pktlen - (4 + 16);	/* attribute octets */

	MD5Init(&md5);
	MD5Update(&md5, rsppkt->m_start, 4);		/* code+id+length */
	MD5Update(&md5, m_authenticator, 16);		/* our Request Authenticator */
	if (trailing)
		MD5Update(&md5, rsppkt->m_start+4+16, trailing);/* attributes */
	MD5Update(&md5, (u8*)radius_secret, radius_secret_len);
	MD5Final(&md5);

	code = rsppkt->Pull8();
	ident = rsppkt->Pull8();
	len = rsppkt->Pull16();

	fprintf(stderr, "RadiusRspr[%d, %d, %d]\n", code, ident, len);

	/* pull off authenticator */
	u8 authenticator[16];
	rsppkt->Pull(authenticator, 16);

	if (memcmp(authenticator, md5.digest, 16)) {
		fprintf(stderr, "radius: bad authenticator\n");
		return;
	}

	/* okay, process the attribute-value pairs (padding is never parsed) */
	int remaining = trailing;
	while (remaining >= 2) {
		u8 type = rsppkt->Pull8();
		u8 length = rsppkt->Pull8();

		/* length covers type+length+value and must fit in the packet */
		if (length < 2 || length > remaining)
			goto out_badpkt;
		remaining -= length;
		length -= 2;

		/* a value is at most 253 octets, so data[] always holds it */
		u8 data[256];
		rsppkt->Pull(data, length);

		/* 4-octet integer/address values; memcpy avoids an unaligned,
		 * type-punned load through a u32 pointer. */
		u32 val = 0;
		if (length == 4)
			memcpy(&val, data, 4);

		switch (type) {
		case RT_USER_PASSWORD:
			break;

		case RT_CHAP_PASSWORD:
			break;

		case RT_SERVICE_TYPE:
			if (length != 4)
				goto out_badpkt;
			switch (ntohl(val)) {
			case 2:	/* Framed==PPP */
				break;
			default:
				goto denied;
			}
			break;

		case RT_FRAMED_IP_ADDRESS:
			if (length != 4)
				goto out_badpkt;
			m_cfg.options.rem_ip = val;	/* raw, network byte order */
			break;

		case RT_FRAMED_IP_NETMASK:
			if (length != 4)
				goto out_badpkt;
			m_cfg.options.netmask = val;	/* raw, network byte order */
			m_cfg.options.netroute = 1;
			break;

		case RT_FRAMED_COMPRESSION:
			if (length != 4)
				goto out_badpkt;
			switch (ntohl(val)) {
			case 0:		m_cfg.options.vjc = 0;	break;	/* None */
			case 1:		m_cfg.options.vjc = 1;	break;	/* VJ TCP/IP */
			default:	break;
			}
			break;

		case RT_IDLE_TIMEOUT:
			if (length != 4)
				goto out_badpkt;
			m_cfg.options.droptime = ntohl(val);
			break;

		case RT_PORT_LIMIT:
			if (length != 4)
				goto out_badpkt;
			m_cfg.options.max_links = ntohl(val);
			break;

		default:
			break;
		}
	}

	if (code == RC_ACCESS_ACCEPT)
		m_cfg.options.is_valid = 1;
	else
		m_cfg.options.is_valid = 0;
	Stop();
	RadiusCallback();	/* reports the result and deletes this */
	return;

denied:
	/* Authenticated reply asking for a service we don't provide: fail
	 * now rather than leaving the request pending. */
	m_cfg.options.is_valid = 0;
	Stop();
	RadiusCallback();	/* deletes this */
	return;

out_badpkt:
	/* malformed attributes: drop the datagram, request stays pending */
	return;
}

/* Lazily opened /dev/urandom descriptor, kept open for reuse. */
static int random_fd = -1;

/*
 * Fill buf[0..len) with bytes from /dev/urandom.
 *
 * Reads interrupted by a signal (EINTR) are retried.  If the device cannot
 * be opened or read, the error is reported and the remaining bytes are
 * filled from rand().  This ensures the caller never uses uninitialized
 * memory as an authenticator.  That fallback is NOT cryptographically
 * strong.  A len of 0 does no I/O.
 */
static void get_random(u8 *buf, int len)
{
	if (random_fd < 0) {
		random_fd = open("/dev/urandom", O_RDONLY);
		if (random_fd < 0)
			perror("open(/dev/urandom, O_RDONLY)");
	}

	while (len > 0 && random_fd >= 0) {
		int c = read(random_fd, buf, len);
		if (c < 0 && errno == EINTR)
			continue;
		if (c <= 0) {
			perror("get_random: read");
			break;
		}
		buf += c;
		len -= c;
	}

	/* weak fallback: never leave the caller's buffer uninitialized */
	while (len-- > 0)
		*buf++ = (u8)rand();
}

/*
 * Create an idle request with a fresh random Request Authenticator.
 *
 * Every member read later is initialized here.  In particular, m_client
 * starts as NULL (TimerExpired() tests it for NULL) and m_ident starts as 0
 * until an identifier is assigned.  MD5(secret + Request Authenticator) is
 * precomputed into m_md5.  That digest is the first block of the RFC 2865
 * User-Password hiding keystream used when Go() builds a PAP request.
 */
RadiusReq::RadiusReq(void)
{
	m_cfg_done = 0;
	m_have_ident = 0;
	m_retrans_count = 0;
	m_client = NULL;
	m_ident = 0;
	memset(&m_cfg, 0, sizeof m_cfg);

	get_random(m_authenticator, sizeof(m_authenticator));

	MD5Init(&m_md5);
	MD5Update(&m_md5, (u8*)radius_secret, radius_secret_len);
	MD5Update(&m_md5, m_authenticator, 16);	/* Request Authenticator */
	MD5Final(&m_md5);
}

/*
 * Build and send an Access-Request for the credentials in @cfgreq.
 *
 * Attributes: User-Name, then either CHAP-Password (copied verbatim from
 * options.passwd) or a hidden User-Password, then Service-Type=Framed and
 * Framed-Protocol=PPP.  The identifier byte is a placeholder until the
 * client assigns one when the request is queued.  The retransmit timer is
 * started and the request is handed to the first configured authentication
 * server.
 *
 * String values are clamped so the 1-octet attribute length cannot wrap:
 * User-Name to 253 octets and the PAP password to 128 octets (the RFC 2865
 * maximum).  The plaintext password is deliberately never logged.
 */
void RadiusReq::Go(CfgMessage_t cfgreq)
{
	m_pkt.Put(m_authenticator, 16);
	m_cfg = cfgreq;

	/* User Name attribute */
	m_pkt.Put8(RT_USER_NAME);
	int len = strlen(m_cfg.options.user);
	if (len > 253)
		len = 253;
	m_pkt.Put8(len+2);
	m_pkt.Put((u8*)m_cfg.options.user, len);

	if (m_cfg.options.proto_id == PAUTH_CHAP) {
		int pwlen = strlen(m_cfg.options.passwd);
		if (pwlen > 253)
			pwlen = 253;

		m_pkt.Put8(RT_CHAP_PASSWORD);
		m_pkt.Put8(pwlen+2);
		m_pkt.Put((u8*)m_cfg.options.passwd, pwlen);
	} else {	/* PAP */
		/*
		 * RFC 2865 section 5.2: NUL-pad the password to a multiple of 16
		 * octets (16..128), then
		 *   c(1) = p(1) ^ MD5(S + RA)
		 *   c(i) = p(i) ^ MD5(S + c(i-1))
		 * m_md5 already holds MD5(S + RA) from the constructor.
		 */
		int pwlen = strlen(m_cfg.options.passwd);
		if (pwlen > 128)
			pwlen = 128;
		int padded = (pwlen + 15) & ~15;
		if (padded == 0)
			padded = 16;	/* the value is never shorter than 16 */

		m_pkt.Put8(RT_USER_PASSWORD);
		m_pkt.Put8(padded + 2);

		u8 keystream[16];
		memcpy(keystream, m_md5.digest, 16);
		for (int off = 0; off < padded; off += 16) {
			u8 cipher[16];
			for (int j = 0; j < 16; j++) {
				u8 plain = (off + j < pwlen)
					? (u8)m_cfg.options.passwd[off + j] : 0;
				cipher[j] = plain ^ keystream[j];
			}
			m_pkt.Put(cipher, 16);

			/* next block's keystream chains on this ciphertext */
			MD5_CTX next;
			MD5Init(&next);
			MD5Update(&next, (u8*)radius_secret, radius_secret_len);
			MD5Update(&next, cipher, 16);
			MD5Final(&next);
			memcpy(keystream, next.digest, 16);
		}
	}

	/* we only do framed ppp */
	m_pkt.Put8(RT_SERVICE_TYPE);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32(2);	/* Framed */

	m_pkt.Put8(RT_FRAMED_PROTOCOL);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32(1);	/* PPP */

	/* we push these onto the front last as we need to know the length */
	m_pkt.Push16(m_pkt.GetLength()+4);
	m_pkt.Push8(0);			/* placeholder for identifier */
	m_pkt.Push8(RC_ACCESS_REQUEST);	/* code */

	Start(RETRANS_TIME);

	m_client = first_radius_client;
	m_client->QueueReq(this);
}

/*
 * Build and send an Accounting-Request describing @acct.
 *
 * Attributes: User-Name, Acct-Status-Type, Acct-Session-Id
 * ("<port>.<ifname>"), and the low 32 bits of the input/output octet
 * counters.  The gigaword attributes are compiled out.  The identifier byte
 * is a placeholder until the request is queued on the first configured
 * accounting server.
 *
 * String values are clamped to 253 octets so the 1-octet attribute length
 * cannot wrap.  The session id is formatted with snprintf because port and
 * ifname are not length-bounded here.
 */
void RadiusReq::Go(AcctMessage_t *acct)
{
	m_pkt.Put(m_authenticator, 16);
	/* User Name attribute */
	m_pkt.Put8(RT_USER_NAME);
	int len = strlen(acct->user);
	if (len > 253)
		len = 253;
	m_pkt.Put8(len+2);
	m_pkt.Put((u8*)acct->user, len);

	/* */
	m_pkt.Put8(RT_ACCT_STATUS_TYPE);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32(acct->type);

	char str[256];
	len = snprintf(str, sizeof(str), "%s.%s", acct->port, acct->ifname);
	if (len < 0)
		len = 0;
	if (len > 253)
		len = 253;	/* also <= sizeof(str)-1, so within what was written */
	m_pkt.Put8(RT_ACCT_SESSION_ID);
	m_pkt.Put8(2+len);
	m_pkt.Put((u8*)str, len);

	/* byte counters*/
#if 0
	m_pkt.Put8(RT_ACCT_INPUT_GIGAWORDS);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32((u32)(acct->in_octets >> 32));	
#endif
	m_pkt.Put8(RT_ACCT_INPUT_OCTETS);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32((u32)acct->in_octets);	
#if 0
	m_pkt.Put8(RT_ACCT_OUTPUT_GIGAWORDS);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32((u32)(acct->out_octets >> 32));	
#endif
	m_pkt.Put8(RT_ACCT_OUTPUT_OCTETS);
	m_pkt.Put8(6);	/* Length */
	m_pkt.Put32((u32)acct->out_octets);	

	/* we push these onto the front last as we need to know the length */
	m_pkt.Push16(m_pkt.GetLength()+4);
	m_pkt.Push8(0);				/* placeholder for identifier */
	m_pkt.Push8(RC_ACCOUNTING_REQUEST);	/* code */

	Start(RETRANS_TIME);

	m_client = first_radius_acct_client;
	m_client->QueueReq(this);
}

/*
 * Record the identifier chosen for this request and patch it into byte 1
 * of the already-built packet.  Packets shorter than 2 octets are left
 * untouched and the request is not marked as having an identifier.
 *
 * Accounting-Request packets also get their Request Authenticator computed
 * here.  RFC 2866 section 3 defines it as
 *   MD5(Code + Identifier + Length + 16 zero octets + attributes + secret),
 * which depends on the identifier, so it cannot be computed earlier.
 * Go(AcctMessage_t *) only stored the random value from the constructor,
 * which an RFC 2866 server is expected to reject.  m_authenticator is
 * updated as well because RadiusRsp() verifies the reply against it.
 */
void RadiusReq::SetIdent(u8 ident)
{
	fprintf(stderr, "SetIdent(0x%02x)\n", ident);
	if (m_pkt.GetLength() < 2)
		return;

	m_pkt.m_start[1] = ident;
	m_ident = ident;	/* used by TimerExpired()'s RemoveReq() */
	m_have_ident = 1;

	if (m_pkt.m_start[0] == RC_ACCOUNTING_REQUEST &&
	    m_pkt.GetLength() >= 4 + 16) {
		MD5_CTX md5;

		memset(m_pkt.m_start + 4, 0, 16);	/* authenticator := zeros */
		MD5Init(&md5);
		MD5Update(&md5, m_pkt.m_start, m_pkt.GetLength());
		MD5Update(&md5, (u8*)radius_secret, radius_secret_len);
		MD5Final(&md5);
		memcpy(m_pkt.m_start + 4, md5.digest, 16);
		memcpy(m_authenticator, md5.digest, 16);
	}
}

/*
 * Send an accounting record for @acct.  The request is not tracked here:
 * it is freed by RadiusReq::RadiusCallback(), which deletes the object.
 */
void radius_acct(AcctMessage_t *acct)
{
	(new RadiusReq())->Go(acct);
}

/*=========================================*/
/*
 * Register a RADIUS server whose address string @str is parsed by
 * strtoip() (presumably dotted-quad IPv4 -- confirm).  A non-zero @acct
 * selects the accounting port 1813 and the accounting server list.
 * Otherwise the authentication port 1812 and list are used.
 *
 * A connected RadiusClient is appended to the chosen list and the matching
 * am_radius_*_client flag is set.  The sockaddr is zeroed first so sin_zero
 * and any implementation-specific fields are not stack garbage.  The new
 * tail's m_next is explicitly terminated.
 */
void setup_radius(const char *str, int acct)
{
	struct sockaddr_in sin;

	memset(&sin, 0, sizeof(sin));
	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = htonl(strtoip(str));
	sin.sin_port = acct ? htons(1813) : htons(1812);

	RadiusClient *client = new RadiusClient(sin);
	client->m_next = NULL;

	RadiusClient **first = acct ? &first_radius_acct_client
				    : &first_radius_client;
	RadiusClient **last = acct ? &last_radius_acct_client
				   : &last_radius_client;

	if (!*first)
		*first = client;
	if (*last)
		(*last)->m_next = client;
	*last = client;

	if (acct)
		am_radius_acct_client = 1;
	else
		am_radius_client = 1;
}

