/**
 ** Simple C relay server example: it accepts TCP connections and
 ** forwards each one to a remote server.
 **
 ** Its strategy is to sit waiting in select(), listening on
 ** all sockets; whenever input is available, control is passed to
 ** the appropriate handler.
 **
 ** This version uses the FIONREAD ioctl to work out how many bytes
 ** to read; a safer alternative is to use non-blocking I/O.
 **/

#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h> /* for FIONREAD (sigh) */
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

#include <string.h>
#include <strings.h>

#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#ifndef FIONREAD
# include <sys/filio.h> /* Solaris 2 puts it here */
#endif

/** BACKLOG is the number of pending connections to allow in listen().
 ** Many systems have a maximum of 5 for this value.
 **/
#define BACKLOG 5

/** MAXHOSTNAME is the longest host name we support, in bytes: */
#define MAXHOSTNAME 256

/* Local host name used by StartServerConnections(); the extra byte leaves
 * room for the terminating NUL. */
static char localhost[MAXHOSTNAME + 1];

/** PORT is the default network port on which to listen (-localport
 ** overrides it in main()). **/
#define PORT	7285

/* Program name used in diagnostics; main() replaces it with argv[0]. */
char *progname = "$Id$";

/* Number of client connections currently being relayed. */
static int ClientCount = 0;
/* Copy of the listening socket descriptor, set by main(). */
int PortFd;
/* Host name given with -localhost, or 0 to fall back to gethostname(). */
static char *DefaultLocalHost = 0;

/* NOTE(review): the functions below also use RemotePort, doLogging,
 * OtherEnd, logFileDescriptorTable, transactionCount and logFileTemplate,
 * none of which is declared in this file as shown -- confirm where they
 * are defined before building. */

/** declare internal functions */
static int ReadMessage(int Socket, int Output);
static void startLog(int ClientSocket, int transactionCount);
static void endLog(int ClientSocket);
static int StartServerConnections(char *ServerHost, int ServerPort);
static int OpenConnectionToServer(char *ServerHost, int ServerPort);

int
main(argc, argv)
    int argc;
    char *argv[];
{
    char *DefaultHost = "localhost";
    char *RemoteHost = "chow.groveware.com";
    int DefaultPort = PORT;
    int Socket;
    int i;
    int width;
    fd_set readfds;
    fd_set ReadList;
    fd_set ClientList;
    struct timeval timeout;

    progname = argv[0];

    for (i = 1; i < argc; i++) {
	if (argv[i][0] != '-') {
	    break;
	}
	if (!strcmp(argv[i], "-display") || !strcmp(argv[i], "-dpy")) {
	    char *p = argv[i + 1];
	    RemoteHost = p;
	    while (*p) {
		if (*p == ':') {
		    *p = '\0';
		    p++;
		    if (*p) {
			int port;

			port = atoi(p);
			RemotePort += port;
		    }
		    break;
		}
		p++;
	    }
	    i++; /* skip the argument */
	} else if (!strcmp(argv[i], "-remotehost")) {
	    RemoteHost = argv[i + 1];
	    i++;
	} else if (!strcmp(argv[i], "-remoteport")) {
	    RemotePort = atoi(argv[i + 1]);
	    i++;
	} else if (!strcmp(argv[i], "-localport")) {
	    DefaultPort = atoi(argv[i + 1]);
	    i++;
	} else if (!strcmp(argv[i], "-localhost")) {
	    DefaultLocalHost = argv[i + 1];
	    i++;
	} else if (!strcmp(argv[i], "-log")) {
	    doLogging = 1;
	} else {
	    fprintf(stderr, "usage: %s [-remotehost name|IP] [-remoteport N] [-localport N] [-localhost name|IP] [-log]\n",
		progname
	    );
	    exit(2);
	}
    }

    if ((Socket = StartServerConnections(DefaultHost, DefaultPort)) < 0) {
	fprintf(stderr, "%s: could not listen on port, exiting\n", progname);
	exit(1);
    }

    PortFd = Socket;

    width = getdtablesize(); /* or 20 if you don't have getdtablesize() */
    OtherEnd = (int *) calloc(width * 8, sizeof(int));
    logFileDescriptorTable = (int *) calloc(width * 8, sizeof(int));
    if (!OtherEnd || !logFileDescriptorTable) {
	fprintf(stderr, "failed to allocate %d entries for file descriptors\n",
	    width
	);
	exit(1);
	/* FIXME: reduce width to 255 (say) and try again */
    }

    FD_ZERO(&ReadList);
    FD_ZERO(&ClientList);
    FD_SET(Socket, &ReadList);

    OtherEnd[Socket] = -1;

    for (;;) {
	int ClientSocket;
	struct sockaddr_in isa;
	int i;
	int nReady;

	FD_ZERO(&readfds);
	for (i = 0; i < width; i++) {
	    if (FD_ISSET(i, &ReadList)) {
		FD_SET(i, &readfds);
	    }
	}

	if (ClientCount == 0) {
	    nReady = select(width, &readfds, (fd_set *) 0, (fd_set *) 0,
						(struct timeval *) 0);
	} else {
	    timeout.tv_sec = 3600L; /* after an hour of solitde, die */
	    timeout.tv_usec = 0;

	    nReady = select(width, &readfds, (fd_set *) 0, (fd_set *) 0, &timeout);
	}

	if (nReady == 0) {
	    fprintf(stderr, "Got a timeout...");
	    exit(1);
	}

	/* Socket hre is the Socket associated with the tcp/ip port */
	if (FD_ISSET(Socket, &readfds)) {
	    int x;
	    int sz = (int) sizeof(isa);

	    FD_CLR(Socket, &readfds);
	    --nReady;

	    ClientSocket = accept(
		Socket, (struct sockaddr *) &isa, &sz
	    );

	    if (ClientSocket < 0) {
		perror("accept()");
		continue;
	    }
	    /* got a new client, file descriptor is ClientSocket */

	    /* try and open a socket to the remote server */
	    x = OpenConnectionToServer(RemoteHost, RemotePort);
	    if (x != -1) {

#ifdef DEBUG
		fprintf(stderr, "[%d, other end is %d]\n", ClientSocket, x);
#endif

		if (OtherEnd[x] != 0) {
		    fprintf(stderr,
			"%s: Internal error, other end of %d is %d\n",
			progname, x, OtherEnd[x]
		    );
		    exit(1);
		}

		if (OtherEnd[ClientSocket] != 0) {
		    fprintf(stderr,
			"%s: Internal error %d, other end of %d is %d\n",
			progname, ClientSocket, __LINE__, OtherEnd[ClientSocket]
		    );
		    exit(1);
		}

		OtherEnd[x] = ClientSocket;
		OtherEnd[ClientSocket] = x;

		/* add them to the list of things that are interesting */
		FD_SET(ClientSocket, &ReadList);
		FD_SET(x, &ReadList);

		/* and mark that it is done for this time round the loop: */
		FD_CLR(ClientSocket, &readfds);

		/* and we have a new client: */
		++ClientCount;

		++transactionCount;
		/* log if necessary */
		if (doLogging) {
		    startLog(ClientSocket, transactionCount);
		}
	    }
	}

	{
	    int Count = nReady;

	    for (i = 0; i < width; i++) {
		if (FD_ISSET(i, &readfds)) {
		    int Output;

		    Output = OtherEnd[i];
		    if (ReadMessage(i, Output) < 0) {
#ifdef DEBUG
	fprintf(stderr, "Client %d ReadMessage < 0\n", Socket);
#endif
			if (doLogging) {
			    endLog(ClientSocket);
			}
			(void) close(i);
			FD_CLR(i, &ReadList);
			(void) close(Output);
			OtherEnd[Output] = 0; /* this end! */
			FD_CLR(Output, &ReadList);
			OtherEnd[i] = 0;
			if (--ClientCount <= 0) {
			    ClientCount = 0;
#ifdef DIE_ON_LAST_EXIT
			    fprintf(stderr, "\n\r%s: last client died.\r\n",
				progname);
			    exit(0);
#else
			    break;
#endif
			}
		    }

		    --Count;
		}
		if (Count <= 0) break;
	    }
	}

    } /* endfor */
    /*NOTREACHED*/
    return 0;
}

int
StartServerConnections(ServerHost, ServerPort)
    char *ServerHost;
    int ServerPort;
{
    extern char *progname;
    extern int errno;
    extern struct servent *getservbyname();

    struct sockaddr_in sa;
    struct hostent *hp;
    int Socket = -1;

    if (!ServerPort) ServerPort = PORT;

    (void) bzero((char *)&sa, sizeof sa);

    if (DefaultLocalHost) {
	(void) strcpy(localhost, DefaultLocalHost);
    } else {
	localhost[0] = '\0'; /* paranoia in case the next call fails: */
	(void) gethostname(localhost, MAXHOSTNAME);
    }

    if ((hp = gethostbyname(localhost)) == NULL) {
	(void) fprintf(stderr, "%s: can't get local host info (%s)\n",
				progname, localhost);
	exit(1);
    }

    sa.sin_port = htons(ServerPort);
    sa.sin_family = hp->h_addrtype;

    if ((Socket = socket(hp->h_addrtype, SOCK_STREAM, 0)) < 0) {
	int e = errno;
	(void) fprintf(stderr, "%s: socket() failed [%d]: ", progname, e);
	errno = e;
	perror("socket");
	return -1;
    }

#if 0 /* why commented out?? */
    (void) bcopy((char *) hp->h_addr, (char *) &sa.sin_addr, hp->h_length);
#endif


    /* allow us to use the address without waiting: */
    {
	int level = SOL_SOCKET;
	int True = 1;

	if (
	    setsockopt(Socket, level, SO_REUSEADDR, (void *) &True, sizeof(True))
	    <
	    0
	) {
	    perror("warning: setsockopt for SO_REUSEADDR");
	}
    }

    /* On 4.2BSD add an extra argument of 0 to bind() */
    if (bind(Socket, (struct sockaddr *) &sa, sizeof sa) < 0) {
	int e = errno;
	(void) fprintf(stderr, "%s: bind() failed [%d]: ", progname, e);
	errno = e;
	perror("");
	return -1;
    }

    if (listen(Socket, BACKLOG) < 0) {
	int e = errno;
	(void) fprintf(stderr, "%s: ", progname);
	errno = e;
	perror("listen");
	return -1;
    }

    fprintf(stderr, "%s: waiting on %s port %d\n",
				progname, localhost, ServerPort);

    return Socket;
}

/** ReadMessage: relay the bytes currently waiting on Socket to Output.
 **
 ** FIONREAD gives the number of pending bytes.  They are read in chunks of
 ** at most BUFSIZ, and each chunk is written to Output in full, resuming
 ** after partial writes.  When doLogging is set, each chunk is also written
 ** to logFileDescriptorTable[Output] if that entry is a valid descriptor.
 **
 ** Returns -1 if the peer has gone (FIONREAD fails or reports no data) or
 ** on a read/write error, and 0 otherwise (including read() hitting EOF).
 **
 ** NOTE: the writes block, so a slow Output stalls the whole server.
 **/
static int
ReadMessage(int Socket, int Output)
{
    int BytesAvailable = -1; /* so we can tell if it was changed */
    char buf[BUFSIZ];

    if (ioctl(Socket, FIONREAD, &BytesAvailable) == -1 || BytesAvailable <= 0) {
	/* The end of the sock...
	 * our client has died...
	 */
#ifdef DEBUG
	fprintf(stderr, "Client %d has died.\n", Socket);
#endif
	return -1;
    }

    while (BytesAvailable > 0) {
	size_t BytesToRead = BytesAvailable < (int) sizeof buf
	    ? (size_t) BytesAvailable : sizeof buf;
	ssize_t nRead;
	ssize_t nSent = 0;

	nRead = read(Socket, buf, BytesToRead);
	if (nRead < 0) {
	    if (errno == EINTR) {
		continue;
	    }
	    perror("socket read");
	    return -1;
	}
	if (nRead == 0) {
	    return 0;
	}
	BytesAvailable -= (int) nRead;

	/* now send the data on, resuming after any partial write */
	while (nSent < nRead) {
	    ssize_t nWritten = write(Output, buf + nSent, (size_t) (nRead - nSent));

	    if (nWritten < 0 && errno == EINTR) {
		continue;
	    }
	    if (nWritten <= 0) {
		perror("write");
		return -1;
	    }
	    nSent += nWritten;
	}

	if (doLogging) {
	    int logfd = logFileDescriptorTable[Output];

	    if (logfd >= 0) {
		(void) write(logfd, buf, (size_t) nRead); /* best effort */
	    }
	}
    }
    return 0;
}


/** OpenConnectionToServer: connect a TCP socket to ServerHost:ServerPort.
 **
 ** Returns the connected descriptor.  Any failure (host lookup, no IPv4
 ** address, socket(), connect()) prints a diagnostic and exits the whole
 ** process, so as written this never returns -1.
 **/
static int
OpenConnectionToServer(char *ServerHost, int ServerPort)
{
    int Socket;
    struct sockaddr_in sa;
    struct hostent *hp;

    if ((hp = gethostbyname(ServerHost)) == NULL) {
	(void) fprintf(stderr, "%s: lookup for host \"%s\" failed\n",
						progname, ServerHost);
	exit(1);
    }

    /* sa is a sockaddr_in: only a 4-byte IPv4 address fits in sin_addr */
    if (hp->h_addrtype != AF_INET || hp->h_length != (int) sizeof sa.sin_addr) {
	(void) fprintf(stderr, "%s: host \"%s\" has no IPv4 address\n",
						progname, ServerHost);
	exit(1);
    }

    memset(&sa, 0, sizeof sa);
    memcpy(&sa.sin_addr, hp->h_addr, sizeof sa.sin_addr);

    sa.sin_family = AF_INET;

    sa.sin_port = htons(ServerPort);

    if ((Socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
	perror("socket");
	exit(1);
    }

#ifdef DEBUG
    (void) fprintf(stderr,
	"%s: contacting %s:%s, socket %d port %d\n",
	progname, inet_ntoa(sa.sin_addr), ServerHost, Socket, ServerPort
    );
#endif

    if (connect(Socket, (struct sockaddr *) &sa, sizeof sa) < 0) {
	int e = errno;

	(void) fprintf(stderr, "%s: connect failed -- error %d: ", progname, e);
	errno = e;
	perror("connect");
	exit(1);
    }

    return Socket;
}

/** startLog: open the log file for a new connection and record it for
 ** both ends.
 **
 ** The file name is logFileTemplate formatted with transactionCount.  Per
 ** the sprintf() use this is a printf-style template, presumably holding
 ** one integer conversion (TODO confirm where it is defined).  The file is
 ** created or truncated, and its descriptor is stored in
 ** logFileDescriptorTable for ClientSocket and for OtherEnd[ClientSocket],
 ** so main() must pair the sockets before calling this.  A header line
 ** naming both descriptors is then written.
 **
 ** If the file cannot be opened, both table entries are set to -1 so
 ** nothing is logged for this connection.  If the header write fails with
 ** ENOSPC, logging is turned off (doLogging = 0) and this file is closed.
 **/
static void
startLog(int ClientSocket, int transactionCount)
{
    int fd;
    int ServerSide = OtherEnd[ClientSocket];
    char newFileName[1000];
    char firstLogLine[3000];
    int len;
    int nWritten;

    (void) snprintf(newFileName, sizeof newFileName, logFileTemplate,
	transactionCount);
    /* O_TRUNC: a reused name must not keep bytes from an older log */
    fd = open(newFileName, O_CREAT | O_WRONLY | O_TRUNC, 0644);

    if (fd < 0) {
	perror(newFileName);
	logFileDescriptorTable[ClientSocket] = -1;
	logFileDescriptorTable[ServerSide] = -1;
	return;
    }

    /* remember the log file for both ends of the socket */
    logFileDescriptorTable[ClientSocket] = fd;
    logFileDescriptorTable[ServerSide] = fd;

    len = snprintf(firstLogLine, sizeof firstLogLine,
	"Log for fds %d and %d in %s\n",
	ClientSocket, ServerSide, newFileName
    );
    if (len < 0) {
	len = 0;
    } else if (len >= (int) sizeof firstLogLine) {
	len = (int) sizeof firstLogLine - 1;
    }
    nWritten = (int) write(fd, firstLogLine, (size_t) len);

    if (nWritten != len) {
	int e = errno; /* only meaningful when nWritten == -1 */

	if (nWritten == -1 && e == ENOSPC) {
	    fprintf(stderr,
		"warning: no disk space left; disabling logging\n");
	    doLogging = 0;
	    /* main() only calls endLog() while doLogging is set: release now */
	    (void) close(fd);
	    logFileDescriptorTable[ClientSocket] = -1;
	    logFileDescriptorTable[ServerSide] = -1;
	    return;
	}

	if (nWritten == -1) {
	    fprintf(stderr, "warning: can't log to \"%s\": ", newFileName);
	} else {
	    fprintf(stderr,
		"warning: wrote only %d initial log bytes to \"%s\", not %d:",
		nWritten, newFileName, len
	    );
	}
	errno = e;
	perror("write");
    }
}

/** endLog: close the log file shared by ClientSocket and its partner, and
 ** mark both log slots as unused (-1).
 ** Call this while OtherEnd[ClientSocket] still names the partner.
 **/
static void
endLog(int ClientSocket)
{
    int *slot = &logFileDescriptorTable[ClientSocket];
    int partner = OtherEnd[ClientSocket];

    (void) close(*slot);
    *slot = -1;
    logFileDescriptorTable[partner] = -1;
}

