/*
  yubikey-server-c — a yubikey validation server written in C
  Copyright (C) 2009 Tollef Fog Heen <tfheen@err.no>
  
  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License version 2 as
  published by the Free Software Foundation.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License along
  with this program; if not, write to the Free Software Foundation, Inc.,
  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#define _GNU_SOURCE

#include <errno.h>
#include <assert.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdint.h>
#include <stdarg.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <sys/types.h>
#include <sys/select.h>
#include <arpa/inet.h>
#include <syslog.h>
#include <pwd.h>
#include <grp.h>
#include <microhttpd.h>
#include <libpq-fe.h>
#include <yubikey.h>
#include <gcrypt.h>

#include "util.h"
#include "config.h"

/*
 * UNUSED(x): mark a function parameter as intentionally unused.
 *
 * With GCC the parameter is renamed to UNUSED_x and tagged
 * __attribute__((unused)).  This silences unused-parameter warnings, and
 * any accidental use of the plain name becomes a compile error.  Splint
 * (__LCLINT__) gets its @unused@ annotation comment.  Other compilers see
 * the name unchanged.  A pre-existing UNUSED definition is left alone.
 */
#ifdef UNUSED
#elif defined(__GNUC__)
# define UNUSED(x) UNUSED_ ## x __attribute__((unused))
#elif defined(__LCLINT__)
# define UNUSED(x) /*@unused@*/ x
#else
# define UNUSED(x) x
#endif

struct ykc_config {
	char *pidfile;
	char *user;
	int uid;
	char *group;
	int gid;
	char *dbdef;
	int port;
	int log_otp;
};

/*
 * Process-wide PostgreSQL connection.  main() opens it with PQconnectdb()
 * and the query helpers in this file use it.  NOTE(review): it has
 * external linkage; if nothing outside this file references it, it could
 * be made static.
 */
PGconn     *db_conn;

/*
 * Per-key state.  get_data_for_uid() loads it from the yubikey table and
 * set_data_for_uid() writes the counters back.  handle_request() frees
 * the string members.
 */
struct ykc_stats {
	int active;		/* NOTE(review): not assigned or read in the visible code */
	char *public_id;	/* set to NULL by get_data_for_uid() */
	char *secret_uid;	/* malloc'd copy of the private uid bytes from the DB */
	char *secret_key;	/* malloc'd copy of the key; asserted YUBIKEY_KEY_SIZE bytes */
	int session_counter;	/* stored session counter, compared to detect replays */
	int session_use;	/* stored use counter within that session */
};

/*
 * Return a freshly malloc()ed timestamp string for the current request,
 * or NULL if memory cannot be allocated.  The caller must free() it.
 *
 * The buffer is sized from the template "YYYY-mm-ddTHH:MM:SSZMSMS" and
 * filled by ysc_strftime() using the format "%FT%TZ%v".
 * NOTE(review): %v is not a standard strftime conversion.  Presumably
 * ysc_strftime expands it to the sub-second suffix the template reserves
 * room for, and a NULL time argument means "now" -- confirm in util.c.
 */
char *get_timestamp(void)
{
	size_t len = strlen("YYYY-mm-ddTHH:MM:SSZMSMS");
	char *ts = malloc(len + 1);

	/* Never hand an unchecked allocation to the formatter. */
	if (ts == NULL)
		return NULL;

	ysc_strftime(ts, len + 1, "%FT%TZ%v", NULL);
	return ts;
}

/*
 * Compare two NUL-terminated strings without stopping at the first
 * differing byte.  The running time then reveals only the lengths, not
 * how long a matching prefix an attacker has guessed.
 * Returns non-zero when the strings are equal.
 */
static int sig_equal(const char *a, const char *b)
{
	size_t a_len = strlen(a);
	size_t b_len = strlen(b);
	unsigned char diff = 0;
	size_t i;

	if (a_len != b_len)
		return 0;
	for (i = 0; i < a_len; i++)
		diff |= (unsigned char) (a[i] ^ b[i]);
	return diff == 0;
}

/*
 * Verify a client request signature.
 *
 * Computes HMAC-SHA1 over "id=<id>&otp=<otp>" keyed with key/key_len,
 * base64-encodes the digest with ysc_b64_encode(), and compares the
 * result with h.
 *
 * Returns 0 when h matches.  Returns -1 on mismatch, on a NULL argument,
 * or on any allocation or libgcrypt failure.
 */
int validate_signature(const char *key, size_t key_len, const char *h,
		       const char *id, const char *otp)
{
	char *line = NULL;
	char *our_sig = NULL;
	gcry_md_hd_t hd = NULL;	/* stays NULL until gcry_md_open() succeeds */
	int r = -1;

	if (h == NULL || id == NULL || otp == NULL)
		return -1;

	/* asprintf() leaves its output pointer undefined on failure, so only
	 * the return value is a reliable error indicator. */
	if (asprintf(&line, "id=%s&otp=%s", id, otp) < 0) {
		line = NULL;
		goto free_mem;
	}
	if (gcry_md_open(&hd, GCRY_MD_SHA1, GCRY_MD_FLAG_HMAC) != 0) {
		hd = NULL;
		goto free_mem;
	}
	if (gcry_md_setkey(hd, key, key_len) != 0)
		goto free_mem;

	gcry_md_write(hd, line, strlen(line));
	gcry_md_final(hd);
	our_sig = ysc_b64_encode((char *) gcry_md_read(hd, 0),
				 gcry_md_get_algo_dlen(GCRY_MD_SHA1));
	if (our_sig != NULL && sig_equal(our_sig, h))
		r = 0;

free_mem:
	if (hd != NULL)
		gcry_md_close(hd);
	free(line);
	free(our_sig);

	return r;
}

/*
 * Sign a response for the client.
 *
 * Builds "info=<info>&status=<status>&timestamp=<timestamp>", leaving out
 * the info= part when info is NULL.  Computes HMAC-SHA1 over that string
 * keyed with key/key_len and returns the base64 encoding produced by
 * ysc_b64_encode().  The caller must free() the result.
 *
 * Returns NULL if status or timestamp is NULL, or on any allocation or
 * libgcrypt failure.
 */
char *sign_request(char *key, size_t key_len, char *info, char *status,
		   char *timestamp)
{
	char *line = NULL;
	char *ret = NULL;
	gcry_md_hd_t hd;
	int n;

	/* Passing NULL to a %s conversion is undefined behaviour. */
	if (status == NULL || timestamp == NULL)
		return NULL;

	if (info != NULL) {
		n = asprintf(&line, "info=%s&status=%s&timestamp=%s", info,
			     status, timestamp);
	} else {
		n = asprintf(&line, "status=%s&timestamp=%s", status,
			     timestamp);
	}
	/* On failure asprintf() leaves line undefined: do not use or free it. */
	if (n < 0)
		return NULL;

	if (gcry_md_open(&hd, GCRY_MD_SHA1, GCRY_MD_FLAG_HMAC) != 0) {
		free(line);
		return NULL;
	}
	if (gcry_md_setkey(hd, key, key_len) == 0) {
		gcry_md_write(hd, line, strlen(line));
		gcry_md_final(hd);
		ret = ysc_b64_encode((char *) gcry_md_read(hd, 0),
				     gcry_md_get_algo_dlen(GCRY_MD_SHA1));
	}
	gcry_md_close(hd);
	free(line);
	return ret;
}

static int send_response(struct MHD_Connection *conn,
			 const char *signature,
			 const char *status,
			 const char *info,
			 const char *timestamp)
{
	char *resp_text, *t;
	int r;
	size_t r_l;
	struct MHD_Response *response;

	r_l = strlen("h=\nstatus=\ntimestamp=\ninfo="); /* This is a maximum
						* of static strings */
	r_l += (signature != NULL ? strlen(signature) : 0);
	r_l += (status != NULL ? strlen(status) : 0);
	r_l += (info != NULL ? strlen(info) : 0);
	r_l += (timestamp != NULL ? strlen(timestamp) : 0);

	resp_text = malloc(r_l + 1);
	t = resp_text;
	if (signature) {
		t += sprintf(t, "h=%s\n", signature);
	}
	if (info) {
		t += sprintf(t, "info=%s\n", info);
	}
	if (timestamp) {
		t += sprintf(t, "timestamp=%s\n", timestamp);
	}
	if (status) {
		t += sprintf(t, "status=%s\n", status);
	}
	/* XXX error checking above */

	response = MHD_create_response_from_data(strlen(resp_text), resp_text,
						 MHD_YES, MHD_YES);
	r = MHD_queue_response(conn, MHD_HTTP_OK, response);
	MHD_destroy_response(response);
	free(resp_text);
	if (r == MHD_YES)
		return 0;
	return -1;
}

/*
 * Look up the API shared secret for client id in the shared_secret table.
 *
 * On success, returns 0 and stores a malloc()ed copy of the secret in
 * *shared_secret and its length in *shared_secret_len.  The copy is raw
 * bytes from a binary-format result and is not NUL-terminated.  The
 * caller must free() it.
 *
 * Returns -1 if id is NULL, the query fails, the id is unknown, the
 * stored secret is SQL NULL, or memory runs out.  In those cases
 * *shared_secret is NULL and *shared_secret_len is 0.
 */
static int get_shared_secret(const char *id, char **shared_secret,
			     size_t *shared_secret_len)
{
	const char *paramValues[1];
	PGresult   *res;
	size_t len;
	int r = -1;

	*shared_secret = NULL;
	*shared_secret_len = 0;

	/* Binding NULL would query for SQL NULL and hand NULL to %s below. */
	if (id == NULL) {
		syslog(LOG_INFO, "Request without an id parameter");
		return -1;
	}

	paramValues[0] = id;
	res = PQexecParams(db_conn,
			   "SELECT secret FROM shared_secret WHERE secret_id = $1",
			   1,       /* one param */
			   NULL,    /* let the backend deduce param type */
			   paramValues,
			   NULL,    /* don't need param lengths since text */
			   NULL,    /* default to all text params */
			   1);      /* ask for binary results */

	if (PQresultStatus(res) != PGRES_TUPLES_OK) {
		syslog(LOG_ERR, "Failed to get shared secret for id=%s: %s",
		       id, PQerrorMessage(db_conn));
		goto free_mem;
	}
	if (PQntuples(res) == 0) {
		syslog(LOG_INFO, "No such id: %s", id);
		goto free_mem;
	}
	if (PQgetisnull(res, 0, 0)) {
		syslog(LOG_ERR, "NULL shared secret for id=%s", id);
		goto free_mem;
	}

	len = (size_t) PQgetlength(res, 0, 0);
	/* +1 so an empty secret never becomes an ambiguous malloc(0). */
	*shared_secret = malloc(len + 1);
	if (*shared_secret == NULL) {
		syslog(LOG_ERR, "Out of memory copying shared secret for id=%s",
		       id);
		goto free_mem;
	}
	memcpy(*shared_secret, PQgetvalue(res, 0, 0), len);
	*shared_secret_len = len;
	r = 0;

free_mem:
	PQclear(res);
	return r;
}

/*
 * Split a modhex OTP into its public id prefix and its encrypted token.
 *
 * Modhex uses two characters per byte, so the token is always the last
 * YUBIKEY_BLOCK_SIZE * 2 characters of otp.  Everything before the token
 * is the public id.
 *
 * On success, returns 0 and stores malloc()ed, NUL-terminated copies in
 * *user (public id) and *s_otp (token); the caller frees both.  On
 * failure, returns -1 and sets both outputs to NULL, so the caller may
 * free() them unconditionally.
 */
static int split_otp(const char *otp, char **user, char **s_otp)
{
	const size_t token_len = YUBIKEY_BLOCK_SIZE * 2;
	size_t otp_len, user_len;

	*user = NULL;
	*s_otp = NULL;

	otp_len = strlen(otp);
	assert(yubikey_modhex_p(otp));
	/* Checked at run time as well: with NDEBUG an assert would vanish and
	 * otp_len - token_len would wrap around. */
	if (otp_len <= token_len)
		return -1;

	user_len = otp_len - token_len;
	*user = malloc(user_len + 1);
	if (*user == NULL)
		return -1;
	memcpy(*user, otp, user_len);
	(*user)[user_len] = '\0';

	*s_otp = malloc(token_len + 1);
	if (*s_otp == NULL) {
		free(*user);
		*user = NULL;	/* never leave the caller a dangling pointer */
		return -1;
	}
	memcpy(*s_otp, otp + user_len, token_len);
	(*s_otp)[token_len] = '\0';

	return 0;
}

/*
 * Persist stats->session_counter and stats->session_use for the yubikey
 * row whose public_id is uid.  No other field of *stats is read.
 *
 * Returns 0 when the UPDATE succeeded and matched at least one row.
 * Returns -1 otherwise, after logging the reason.
 */
static int set_data_for_uid(char *uid, struct ykc_stats *stats)
{
	PGresult *res;
	const char *paramValues[3];
	/* Decimal text of any int: at most 3 digits per byte, plus an
	 * optional sign and the terminating NUL, so snprintf never truncates. */
	char ctr[3 * sizeof(int) + 2], use[3 * sizeof(int) + 2];
	const char *affected;
	int r = 0;

	snprintf(ctr, sizeof(ctr), "%d", stats->session_counter);
	snprintf(use, sizeof(use), "%d", stats->session_use);

	paramValues[0] = uid;
	paramValues[1] = ctr;
	paramValues[2] = use;
	res = PQexecParams(db_conn,
			   "UPDATE yubikey SET session_counter = $2, "
			   "session_use = $3 WHERE public_id = $1",
			   3,       /* number of params */
			   NULL,    /* let the backend deduce param type */
			   paramValues,
			   NULL,
			   NULL,
			   1);      /* ask for binary results */

	if (PQresultStatus(res) != PGRES_COMMAND_OK) {
		syslog(LOG_ERR, "UPDATE for %s failed: %s", uid,
		       PQerrorMessage(db_conn));
		r = -1;
		goto free_mem;
	}

	/* An UPDATE that matched no row leaves the stored counters unchanged,
	 * so report it as a failure rather than silently succeeding. */
	affected = PQcmdTuples(res);
	if (affected[0] == '\0' || strcmp(affected, "0") == 0) {
		syslog(LOG_ERR, "UPDATE for %s matched no rows", uid);
		r = -1;
	}

free_mem:
	PQclear(res);
	return r;
}

static int get_data_for_uid(char *uid, struct ykc_stats *stats)
{
	PGresult *res;
	const char *paramValues[1];
	char *tmp;
	int r = 0;

	paramValues[0] = uid;
	res = PQexecParams(db_conn,
			   "SELECT active, secret_uid, secret_key, "
			   "session_counter, session_use FROM yubikey "
			   "WHERE public_id = $1",
			   1,       /* one param */
			   NULL,    /* let the backend deduce param type */
			   paramValues,
			   NULL,
			   NULL,    /* default to all text params */
			   1);      /* ask for binary results */

	if (PQresultStatus(res) != PGRES_TUPLES_OK)
	{
		syslog(LOG_ERR, "Failed to get shared secret for uid=%s: %s",
		       uid, PQerrorMessage(db_conn));
		r = -1;
		goto free_mem;
	}
	if (PQntuples(res) == 0) {
		syslog(LOG_INFO, "uid %s not found in database", uid);
		/* XXX Better handling */
		r = -1;
		goto free_mem;
	}
	assert(PQgetlength(res, 0, PQfnumber(res, "secret_key")) == YUBIKEY_KEY_SIZE);
	stats->secret_key = ysc_memdup(
		PQgetvalue(res, 0, PQfnumber(res, "secret_key")),
		PQgetlength(res, 0, PQfnumber(res, "secret_key")));
	if (stats->secret_key == NULL) {
		r = -1;
		goto free_mem;
	}

	stats->secret_uid = ysc_memdup(
		PQgetvalue(res, 0, PQfnumber(res, "secret_uid")),
		PQgetlength(res, 0, PQfnumber(res, "secret_uid")));
	if (stats->secret_uid == NULL) {
		r = -1;
		goto free_mem;
	}

	tmp = PQgetvalue(res, 0, PQfnumber(res, "session_counter"));
	assert(tmp != NULL); /* DB schema should enforce this */
	stats->session_counter = ntohl(*((uint32_t *) tmp));
	tmp = PQgetvalue(res, 0, PQfnumber(res, "session_use"));
	assert(tmp != NULL); /* DB schema should enforce this */
	stats->session_use = ntohl(*((uint32_t *) tmp));
	stats->public_id = NULL;

free_mem:
	PQclear(res);
	return r;
}

static int handle_request(void * priv,
			  struct MHD_Connection *conn,
			  const char * url,
			  const char *UNUSED(method),
			  const char *UNUSED(version),
			  const char *UNUSED(upload_data),
			  size_t *UNUSED(upload_data_size),
			  void **UNUSED(con_cls))
{
	const char *id = NULL, *otp = NULL, *h = NULL;
	char *uid = NULL, *otp_token = NULL;
	char *signature = NULL, *status = NULL, *info = NULL, *timestamp = NULL;
	char *shared_secret = NULL;
	size_t shared_secret_len;
	yubikey_token_st token;
	struct ykc_stats stats;
	struct ykc_config *conf = (struct ykc_config*) priv;
	memset(&token, '\0', sizeof(token));
	memset(&stats, '\0', sizeof(stats));

	timestamp = get_timestamp();
	assert(timestamp != NULL);
	/* Parse query string, grab id, otp and h (optional) */

	id = MHD_lookup_connection_value(conn, MHD_GET_ARGUMENT_KIND, "id");
	otp = MHD_lookup_connection_value(conn, MHD_GET_ARGUMENT_KIND, "otp");
	h = MHD_lookup_connection_value(conn, MHD_GET_ARGUMENT_KIND, "h");
	syslog(LOG_DEBUG, "Got new connection with parameters: "
	       "url=%s id=%s otp=%s, h=%s\n", url, id,
	       (conf->log_otp ? otp : "<hidden>"),  h);

	/* Do query to grab shared secret, we need this later anyway */
	if (get_shared_secret(id, &shared_secret, &shared_secret_len) < 0) {
		/* XXX: Something blew up, assume no such ID */
		status = "NO_SUCH_CLIENT";
		send_response(conn, NULL, status, NULL, timestamp);
		goto free_mem;
	}

	if (otp == NULL) {
		info = "otp";
		status = "MISSING_PARAMETER";
		signature = sign_request(shared_secret, shared_secret_len,
					 info, status, timestamp);
		send_response(conn, signature, status, NULL, timestamp);
		goto free_mem;
	}

	if (! yubikey_modhex_p(otp) ||
	    strlen(otp) < (YUBIKEY_BLOCK_SIZE * 2 + 1)) {
		status = "BAD_OTP";
		signature = sign_request(shared_secret, shared_secret_len,
					 NULL, status, timestamp);
		send_response(conn, signature, status, NULL, timestamp);
		goto free_mem;
	}

	if (h != NULL) {
		if (validate_signature(shared_secret, shared_secret_len, h, 
				       id, otp) < 0) {
		status = "BAD_SIGNATURE";
		signature = sign_request(shared_secret, shared_secret_len,
					 NULL, status, timestamp);
		send_response(conn, signature, status, NULL, timestamp);
		goto free_mem;
		}
	}

	/* Validate OTP */
	/* Find public uid, if possible */
	split_otp(otp, &uid, &otp_token);
	if (get_data_for_uid(uid, &stats) < 0) {
		status = "BAD_OTP";
		signature = sign_request(shared_secret, shared_secret_len,
					 NULL, status, timestamp);
		send_response(conn, signature, status, NULL, timestamp);
		goto free_mem;
	}
	/* Argh, yubikey_parse takes in one modhex-ed token (but
	 * requires us to strip the public id first, and an unencoded aes key*/
	yubikey_parse((uint8_t*)(otp_token), (const uint8_t *)stats.secret_key, &token);
	if (!yubikey_crc_ok_p((void*)&token) ||
	    memcmp(token.uid, stats.secret_uid, YUBIKEY_UID_SIZE) != 0) {
		status = "BAD_OTP";
		signature = sign_request(shared_secret, shared_secret_len,
					 NULL, status, timestamp);
		send_response(conn, signature, status, NULL, timestamp);

		goto free_mem;
		return MHD_YES;
	}
	if (yubikey_counter(token.ctr) < stats.session_counter ||
	    (yubikey_counter(token.ctr) == stats.session_counter &&
	     token.use <= stats.session_use)) {
		/* Replay */
		status = "REPLAYED_OTP";
		signature = sign_request(shared_secret, shared_secret_len,
					 NULL, status, timestamp);
		send_response(conn, signature, status, NULL, timestamp);
		syslog(LOG_NOTICE, "Replay attempt for otp=%s, id=%s, uid=%s",
		       otp, id, uid);
		goto free_mem;
	}

	/* Update status, if appropriate */
	free(stats.public_id);
	free(stats.secret_uid);
	free(stats.secret_key);
	memset(&stats, 0, sizeof(struct ykc_stats));
	stats.session_counter = yubikey_counter(token.ctr);
	stats.session_use = token.use;
	set_data_for_uid(uid, &stats);
	/* Generate response, sign it */
	syslog(LOG_INFO, "OK request for otp=%s, id=%s, uid=%s",
	       otp, id, uid);
	status = "OK";
	signature = sign_request(shared_secret, shared_secret_len,
				 NULL, status, timestamp);
	send_response(conn, signature, status, NULL, timestamp);
free_mem:
	free(timestamp);
	free(shared_secret);
	free(signature);
	free(stats.public_id);
	free(stats.secret_uid);
	free(stats.secret_key);
	free(uid);
	free(otp_token);
	return MHD_YES;
}

/*
 * Print command-line help to stdout and exit with status 0.
 * Declared (void) so it has a real prototype; an empty () list leaves the
 * parameters unspecified before C23.
 */
void print_usage(void)
{
	fputs("yubikeyd [-c conffile] [-V] [-f] [-h]\n"
	      "\n"
	      "-c conffile - use configuration file\n"
	      "-V - print version\n"
	      "-f - keep in foreground\n"
	      "-h - this help\n",
	      stdout);
	exit(0);
}

/*
 * Write "yubikeyd <VERSION>" followed by a newline to standard output,
 * then terminate the process with status 0.  Never returns.
 */
void print_version(void)
{
	fputs("yubikeyd ", stdout);
	fputs(VERSION, stdout);
	putchar('\n');
	exit(0);
}

/*
 * Read "key=value" settings from file into *c.
 *
 * Syntax:
 *  - Empty lines and lines starting with '#' are skipped.
 *  - Lines without '=' are skipped.
 *  - The key is the text before the first '='; the value is the rest of
 *    the line up to the newline, taken verbatim (no whitespace trimming).
 *
 * Recognised keys: pidfile, user, group, dbdef, port and log_otp.  "user"
 * and "group" are also resolved to numeric ids.  Other keys are ignored.
 * Fields whose key is absent keep whatever value *c already held.
 *
 * Exits the process with status 1 if:
 *  - file is NULL or cannot be opened, or
 *  - it names a user or group that does not exist.
 * Returns 0 otherwise.
 */
int parse_config(const char *file, struct ykc_config *c)
{
	FILE *f;
	char line[4096];
	char *key, *value;

	if (file == NULL) {
		fprintf(stderr, "No configuration file given; exiting\n");
		exit(1);
	}

	f = fopen(file, "r");
	if (f == NULL) {
		fprintf(stderr, "Can't open configuration file %s: %s; exiting\n",
			file, strerror(errno));
		exit(1);
	}

	while (fgets(line, sizeof(line), f) != NULL) {
		if (line[0] == '#' || line[0] == '\n')
			continue;
		if ((value = strchr(line, '\n')) != NULL) {
			*value = '\0';
		}
		if ((value = strchr(line, '=')) == NULL) {
			/* XXX complain */
			continue;
		}
		key = line;
		*value = '\0';
		value++;

		if (strcmp(key, "pidfile") == 0) {
			c->pidfile = strdup(value);
			continue;
		}
		if (strcmp(key, "user") == 0) {
			/* getpwnam() returns NULL for an unknown user. */
			struct passwd *p = getpwnam(value);

			if (p == NULL) {
				fprintf(stderr, "Unknown user %s in %s; exiting\n",
					value, file);
				fclose(f);
				exit(1);
			}
			c->user = strdup(value);
			c->uid = p->pw_uid;
			continue;
		}
		if (strcmp(key, "group") == 0) {
			/* getgrnam() returns NULL for an unknown group. */
			struct group *g = getgrnam(value);

			if (g == NULL) {
				fprintf(stderr, "Unknown group %s in %s; exiting\n",
					value, file);
				fclose(f);
				exit(1);
			}
			c->group = strdup(value);
			c->gid = g->gr_gid;
			continue;
		}
		if (strcmp(key, "dbdef") == 0) {
			c->dbdef = strdup(value);
			continue;
		}
		if (strcmp(key, "port") == 0) {
			c->port = strtol(value, NULL, 0);
			continue;
		}
		if (strcmp(key, "log_otp") == 0) {
			c->log_otp = strtol(value, NULL, 0);
			continue;
		}
	}
	fclose(f);
	return 0;
}

int main(int argc, char ** argv)
{
	struct MHD_Daemon *d;
	int opt;
	const char *config;
	int foreground = 0;
	struct ykc_config conf;
	FILE *pidfd;

	while ((opt = getopt(argc, argv, "c:p:vfh")) != -1) {
	  switch (opt) {
	  case 'c':
	    config = optarg;
	    break;
	  case 'f':
	    foreground = 1;
	    break;
	  case 'V':
	    print_version();
	    break;
	  case 'h':
	  default:
	    print_usage();
	    break;
	  }
	}

	openlog("yubikeyd", LOG_PID, LOG_AUTHPRIV);
	syslog(LOG_NOTICE, "yubikeyd version %s starting up", VERSION);

	/* XXX return value */
	parse_config(config, &conf);

	d = MHD_start_daemon(MHD_USE_DEBUG,
			     conf.port,
			     NULL, /* Access policy handler */
			     NULL, /* Data to access policy handler */
			     handle_request, /* default handler for all URIs */
			     &conf, /* Data for default handler */
			     MHD_OPTION_END);
	if (d == NULL) {
		syslog(LOG_ERR, "could not start daemon, unsure why\n");
		perror("Error starting yubikeyd");
		exit(1);
	}

	/*XXX return value */
	if (setregid(conf.gid, conf.gid) < 0) {
		perror("changing group id");
		exit(1);
	}
	if (setreuid(conf.uid, conf.uid) < 0) {
		perror("changing group id");
		exit(1);
	}

	/* XXX check errors */
	gcry_check_version("0");
	gcry_control(GCRYCTL_INIT_SECMEM, 16384);

	db_conn = PQconnectdb(conf.dbdef);
	if (PQstatus(db_conn) != CONNECTION_OK) {
		syslog(LOG_ERR, "connection to database failed: %s",
		       PQerrorMessage(db_conn));
		exit(1);
	}

	if (!foreground)
		daemon(0, 0);

	unlink(conf.pidfile); /* XXX ignore errors? */
	pidfd = fopen(conf.pidfile, "wx");
	if (! pidfd) {
		perror("Error opening pid file");
		syslog(LOG_ERR, "could not open pid file: %s\n",
		       strerror(errno));
		exit(1);
	}
	fprintf(pidfd, "%d\n", getpid());
	fclose(pidfd);

	while (1) {
		fd_set rs, ws, es;
		int max_fd = 0;
		unsigned long long timeout;
		struct timeval tv;

		FD_ZERO(&rs);
		FD_ZERO(&ws);
		FD_ZERO(&es);

		if (MHD_get_fdset(d, &rs, &ws, &es, &max_fd) == MHD_NO) {
			MHD_stop_daemon(d);
			exit(1);
		}
		if (MHD_get_timeout(d, &timeout) == MHD_NO) {
			timeout = 0;
		}
		tv.tv_usec = (timeout % 1000) * 1000;
		tv.tv_sec = timeout / 1000;
		if (timeout == 0) {
			select(max_fd+1, &rs, &ws, &es, NULL);
		} else {
			select(max_fd+1, &rs, &ws, &es, &tv);
		}
		if (MHD_run(d) == MHD_NO) {
			MHD_stop_daemon(d);
			exit(1);
		}
	}
	MHD_stop_daemon (d);
	PQfinish(db_conn);
	return 0;
}
