#include "ldapdns.h"
#include "ip.h"
#include "env.h"
#include "config.h"
#include "sio.h"
#include "bin.h"

#include <errno.h>
#include <stdlib.h>
#include <string.h>

/* Listening socket in standalone mode.  Stays -1 when the client
 * connection was inherited on stdin/stdout (tcpserver, xinetd); tp_close()
 * and tp_read() use that to select inherited-mode behaviour. */
static int server_fd = -1;
/* Inherited mode: the peer address/port, copied into each request by
 * tp_read().  Standalone mode: tcpserver_ip is the address bound to. */
static char tcpserver_ip[IP_LEN];
static int tcpserver_port;
/* Descriptor set rebuilt and waited on by tp_housekeeping(), then tested
 * and consumed per descriptor by tp_read(). */
static bin_t tcpserver_active;
#ifdef HAVE_IPV6
/* Nonzero when the standalone listener's $IP parsed as IPv6. */
static int using_6 = 0;
#endif

/*
 * Drop the connection owned by handler c.
 *
 * When the connection was inherited (server_fd == -1) the process is a
 * uniprocess hosted by tcpserver or similar, so it simply exits.  In
 * standalone mode the socket is closed and c->sock reset to -1 while
 * holding c->lock, the lock tp_write() checks c->sock under.
 */
void tp_close(dns_ctx *c)
{
	if (server_fd == -1)
		exit(0);

	pthread_mutex_lock(&c->lock);
	close(c->sock);
	c->sock = -1;
	pthread_mutex_unlock(&c->lock);
}
/*
 * Wait until the listening socket or any connected handler socket is
 * ready for reading.
 *
 * Does nothing in inherited mode (server_fd == -1).  Otherwise the
 * descriptor set is rebuilt from server_fd plus every handler whose sock
 * is not -1, then sio_block() waits on it.  With ldapdns.timeout_tcp > 0,
 * a sio_block() result of 0 (presumably "timed out") rebuilds the set and
 * waits again; otherwise the wait is unbounded.
 *
 * now: unused.
 */
void inline tp_housekeeping(long *now)
{
	dns_ctx *h;

	(void)now;

	if (server_fd == -1)
		return;

	for (;;) {
		sio_flush(tcpserver_active);
		sio_add(tcpserver_active, server_fd, sio_read);
		for (h = handler; h; h = h->next)
			if (h->sock != -1)
				sio_add(tcpserver_active, h->sock, sio_read);

		if (ldapdns.timeout_tcp <= 0) {
			sio_block(tcpserver_active, sio_infinity);
			return;
		}
		if (sio_block(tcpserver_active, ldapdns.timeout_tcp) != 0)
			return;
	}
}

/*
 * Decide how DNS-over-TCP clients reach us and prepare for it.
 *
 * 1. TCPREMOTEIP parses as IPv4: we were started by tcpserver (or a
 *    clone).  The peer address/port are remembered for tp_read() and a
 *    single handler is used.
 * 2. socket_peer4() on fd 0 returns nonzero: we were started by xinetd.
 *    Same single-handler setup.
 * 3. Otherwise run standalone: create a listening socket bound to $IP
 *    (default 0.0.0.0; IPv6 when HAVE_IPV6 and it parses as such) and
 *    $PORT (default 53), and mark it with ndelay_on().
 *
 * An unparsable IP, a PORT outside 1..65535, and socket/bind failures are
 * fatal.
 */
void tp_initialize(void)
{
	char *x;
	int port;

	server_fd = -1;
	ldapdns.dns_threads = 1;	/* always 1; we use select */

	x = env_get("TCPREMOTEIP");
	if (!x || !ipv4_scan(x, tcpserver_ip) ) {
		tcpserver_ip[0] = 0;
		tcpserver_ip[1] = 0;
		tcpserver_ip[2] = 0;
		tcpserver_ip[3] = 0;
	} else {
		/* tcpserver or clone */
		ldapdns.ldap_threads = 1;
		ldapdns.handlers = 1;

		x = env_get("TCPREMOTEPORT");
		if (!x) {
			tcpserver_port = 0;
		} else if ((tcpserver_port = atoi(x)) < 1) {
			tcpserver_port = 0;
		}
		return;
	}

	if (socket_peer4(0, tcpserver_ip, &tcpserver_port)) {
		/* okay, we're running xinetd */
		ldapdns.ldap_threads = 1;
		ldapdns.dns_threads = 1;
		ldapdns.handlers = 1;
		return;
	}

	/* okay, we're not attached to a socket; let's change that */
	x = env_get("IP");
	if (!x)
		x = "0.0.0.0";
	if (!ipv4_scan(x, tcpserver_ip)) {
#ifdef HAVE_IPV6
		if (ipv6_scan(x, tcpserver_ip))
			using_6 = 1;
		else
#endif
		fatal("cannot parse IP: %s", x);
	}

	x = env_get("PORT");
	if (!x)
		port = 53;
	else {
		char *end;
		long v;

		/* strtol instead of atoi: reject trailing junk ("53x"),
		 * overflow, and anything outside the TCP port range, which
		 * would otherwise be silently truncated by the bind call */
		errno = 0;
		v = strtol(x, &end, 10);
		if (end == x || *end != '\0' || errno == ERANGE
				|| v < 1 || v > 65535)
			fatal("cannot parse PORT: %s", x);
		port = (int)v;
		if (port != 53)
			warning("running on non-standard port: %d", port);
	}

#ifdef HAVE_IPV6
	if (using_6)
		server_fd = socket_tcp6();
	else
#endif
	server_fd = socket_tcp4();
	if (server_fd == -1)
		cfatal("socket_tcp: %s");
#ifdef HAVE_IPV6
	if (using_6) {
		if (socket_bind6_reuse(server_fd, tcpserver_ip, port) == -1)
			cfatal("socket_bind6_reuse: %s");
	} else
#endif
	if (socket_bind4_reuse(server_fd, tcpserver_ip, port) == -1)
		cfatal("socket_bind4_reuse: %s");
	socket_listen(server_fd);
	ndelay_on(server_fd);
	bin_init(tcpserver_active);
}
/*
 * Write len bytes starting at buf to handler c's socket.  Short writes,
 * EINTR and EAGAIN/EWOULDBLOCK are retried.
 *
 * c->lock is held around every write() so that a concurrent close (which
 * sets c->sock to -1 under the same lock) is noticed between chunks
 * instead of writing to a stale descriptor.
 *
 * Returns 1 once every byte has been written.  Returns 0 if the socket was
 * closed underneath us or write() failed with any other error (EPIPE,
 * ECONNRESET, ...).
 */
static int write_all_locked(dns_ctx *c, const void *buf, int len)
{
	int done, r, err;

	for (done = 0; done < len;) {
		pthread_mutex_lock(&c->lock);
		if (c->sock == -1) {
			/* stolen out from under us */
			pthread_mutex_unlock(&c->lock);
			return 0;
		}
		r = write(c->sock, (const char *)buf + done, len - done);
		err = errno;	/* capture before anything else can touch it */
		pthread_mutex_unlock(&c->lock);

		if (r == -1) {
			if (err == EINTR)
				continue;
			if (err == EAGAIN || err == EWOULDBLOCK) {
				/* non-blocking socket is full; back off briefly */
				usleep(100);
				continue;
			}
			/* failed output; retrying would spin forever */
			return 0;
		}
		if (r == 0) {
			/* sleep for a moment; just in case it was kernel related */
			usleep(100);
			continue;
		}
		done += r;
	}
	return 1;
}

/*
 * Send c->response to the client as one DNS-over-TCP message: a 2-byte
 * network-order length prefix followed by the message bytes.
 *
 * This COULD hang on a client that stops reading.  The expectation is that
 * the client hangs up eventually and write() then fails (e.g. EPIPE).
 *
 * Returns 1 if the whole message was written, 0 if the connection was
 * closed or a write failed.
 */
int inline tp_write(dns_ctx *c)
{
	unsigned short ntcplen;

	/* host -> network order */
	ntcplen = htons(clen(c->response));

	/* the prefix goes through the same retry loop as the body: a lost
	 * or partial prefix would desynchronise the whole stream */
	if (!write_all_locked(c, &ntcplen, 2))
		return 0;
	return write_all_locked(c, caddr(c->response), clen(c->response));
}
/*
 * Remove the request that tp_read() last handed out from the front of
 * c->request_buf.
 *
 * request_buf holds raw TCP stream bytes: a 2-byte length prefix, then
 * tcplen bytes of DNS message, then possibly the start of the next
 * pipelined request.  c->tcppos counts the buffered bytes.  c->tcplen is 0
 * until the current prefix has been parsed.
 *
 * Once the current request is complete (tcppos >= tcplen + 2):
 * - its bytes are dropped,
 * - any trailing bytes are moved to the start of the buffer,
 * - the parse state is reset.
 *
 * Returns 1 when bytes of a following request remain buffered.  Returns 0
 * otherwise, including when the current request is still incomplete.
 */
static int inline trash_message (dns_ctx *c)
{
	int used;

	if (c->tcplen == 0)
		return 0;

	/* tcplen excludes the 2 prefix bytes, but the buffer holds them */
	used = c->tcplen + 2;
	if (c->tcppos < used)
		return 0;

	/* shift the pipelined tail to the front; the regions overlap */
	memmove(c->request_buf, c->request_buf + used, c->tcppos - used);
	c->tcppos -= used;
	c->tcplen = 0;
	c->request_len = 0;
	c->request_pos = 0;
	return c->tcppos != 0;
}

/*
 * Try to assemble one complete DNS-over-TCP request for handler c.
 *
 * Inherited mode (server_fd == -1): the client is on stdin.  The peer
 * address/port come from tp_initialize(), and reading continues until a
 * whole request is buffered.
 *
 * Standalone mode:
 * - A handler without a socket accepts a new connection if sio_test()
 *   reports the listening socket ready.
 * - A handler with a socket reads if sio_test() reports it ready.
 *   Otherwise it tries to parse a request already buffered behind the
 *   previous one (pipelining).
 *
 * Return values:
 *  1   A complete request is in c->request_buf, with
 *      request_len = tcplen + 2.  dns_packet_copy() has read the 2-byte
 *      prefix (presumably advancing request_pos past it).
 *  0   Nothing is ready yet, or the connection was dropped.
 * -1   The handler is idle and no new connection is waiting.
 *
 * Fatal stream errors close the socket (standalone) or exit the process
 * (inherited mode).
 */
int inline tp_read(dns_ctx *c)
{
	int fd = -1, len = 0;
	unsigned short ntcplen;
	list_t ax;
	str_t retbuf;

	if (server_fd == -1) {
		c->ip[0] = tcpserver_ip[0];
		c->ip[1] = tcpserver_ip[1];
		c->ip[2] = tcpserver_ip[2];
		c->ip[3] = tcpserver_ip[3];
		c->port = tcpserver_port;
		c->sock = 1;
		fd = 0;
	} else {
		/* could be hung at this point */
		pthread_mutex_lock(&c->lock);
		if (c->sock == -1) {
			int nsock;

			pthread_mutex_unlock(&c->lock);

			/* see if we can accept server_fd */
			if (!sio_test(tcpserver_active, server_fd)) {
				/* we'll come back later */
				return -1;
			}
			sio_remove(tcpserver_active, server_fd);

			/* we'll have to do a round before this guy can have
			 * activity performed.  Note that the IPv4 accept must
			 * sit outside the #ifdef, or nothing is ever accepted
			 * when HAVE_IPV6 is undefined.
			 */
#ifdef HAVE_IPV6
			if (using_6)
				nsock = socket_accept6(server_fd, c->ip, &c->port);
			else
#endif
			nsock = socket_accept4(server_fd, c->ip, &c->port);

			/* publish the descriptor under the lock that tp_write()
			 * checks c->sock with */
			pthread_mutex_lock(&c->lock);
			c->sock = nsock;
			c->tcppos = 0;
			c->tcplen = 0;
			c->request_len = 0;
			c->request_pos = 0;
			pthread_mutex_unlock(&c->lock);
			return 0;
		}

		/* both sides are socket */
		if (!sio_test(tcpserver_active, c->sock)) {
			pthread_mutex_unlock(&c->lock);
			/* no new bytes, but a pipelined request may already
			 * be buffered */
			if (trash_message(c))
				goto parse_l;
			return 0;
		}
		sio_remove(tcpserver_active, c->sock);
		fd = c->sock;
		pthread_mutex_unlock(&c->lock);
	}

	/* parse anything already buffered before reading more; otherwise a
	 * client that pipelined its queries and now waits for answers would
	 * stall us in read() */
	if (trash_message(c))
		goto parse_l;

reread_l:
	for (;;) {
		/* the original design assumes this thread is the only one
		 * that closes fd, so it cannot be stolen out from under us */
		len = read(fd, c->request_buf + c->tcppos, 512 - c->tcppos);
		if (len > 0)
			break;
		if (len == 0)
			goto FATAL;	/* hung up */
		if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
			return 0;	/* we'll be back */
		/* ECONNRESET, EBADF, ...: retrying would spin forever */
		goto FATAL;
	}
	/* account for the bytes before parsing, so a lone first byte of the
	 * length prefix is kept rather than overwritten by the next read */
	c->tcppos += len;

parse_l:
	if (c->tcplen == 0) {
		/* the BEGINNING of a tcp message is the 2-byte packet length */
		if (c->tcppos < 2)
			goto incomplete_l;
		/* let dns_packet_copy() see exactly the buffered bytes */
		c->request_len = c->tcppos;
		c->request_pos = 0;
		if (!dns_packet_copy(c, (char *)&ntcplen, 2))
			goto incomplete_l;
		/* Clib */
		c->tcplen = ntohs(ntcplen);
		if (c->tcplen == 0) {
			/* not a DNS message; also, tcplen == 0 means "prefix
			 * not parsed yet" to this state machine */
			goto FATAL;
		}
		if (c->tcplen > 512 - 2) {
			/* prefix + message must fit in the 512-byte buffer */
			warning("read too long");
			goto FATAL;
		}
	}

	/* the buffer holds the prefix too, so the message needs tcplen + 2 */
	if (c->tcppos < c->tcplen + 2)
		goto incomplete_l;

	/* NOTE(review): any previous c->update / c->axfr_base values are
	 * overwritten here without mem_free(); presumably the caller releases
	 * them after answering - confirm. */
	if (ldapdns.update)
		c->update = str_dup(ldapdns.update);
	else
		c->update = 0;
	if (ldapdns.axfr_base)
		c->axfr_base = str_dup(ldapdns.axfr_base);
	else
		c->axfr_base = 0;
	c->request_len = c->tcplen + 2;

	/* not in server mode? we don't need to do anything else */
	if (server_fd == -1)
		return 1;

	/* calculated AXFR support.  Each entry appears to be a family tag
	 * (0x04 / 0x06), a subnet, then an optional name at +9 / +33:
	 * - empty name: AXFR is disabled for matching clients;
	 * - otherwise the name becomes the DNS-encoded AXFR base.
	 * A match on a non-null subnet stops the scan.  A match on the null
	 * subnet (presumably the catch-all) lets later entries override it.
	 */
	for (ax = ldapdns.swaxfr; ax; ax = ax->next) {
		if (!ax->str) continue;
		if (ax->str[0] == 0x04) {
			if (ipv4_in_subnet(ax->str+1, c->ip)) {
				if (!ax->str[9]) {
					if (c->axfr_base) mem_free(c->axfr_base);
					c->axfr_base = 0;
					if (!ipv4_null(ax->str+1))
						break;
					continue;
				}
				name_to_dns(retbuf, ax->str + 9);
				if (c->axfr_base) mem_free(c->axfr_base);
				c->axfr_base = str(retbuf);
				if (!ipv4_null(ax->str+1))
					break;
			}
#ifdef HAVE_IPV6
		} else if (ax->str[0] == 0x06) {
			if (ipv6_in_subnet(ax->str+1, c->ip)) {
				if (!ax->str[33]) {
					if (c->axfr_base) mem_free(c->axfr_base);
					c->axfr_base = 0;
					/* IPv6 entry: test with ipv6_null */
					if (!ipv6_null(ax->str+1))
						break;
					continue;
				}
				name_to_dns(retbuf, ax->str + 33);
				if (c->axfr_base) mem_free(c->axfr_base);
				c->axfr_base = str(retbuf);
				if (!ipv6_null(ax->str+1))
					break;
			}
#endif
		}
	}

	return 1;

incomplete_l:
	/* inherited mode has nothing else to do until more bytes arrive */
	if (server_fd == -1)
		goto reread_l;
	return 0;

FATAL:
	if (server_fd == -1)
		exit(0);

	/* lock to remove */
	pthread_mutex_lock(&c->lock);
	close(c->sock);
	c->sock = -1;
	pthread_mutex_unlock(&c->lock);
	return 0;
}

