/*
 * Potamus: an audio player
 * Copyright (C) 2004, 2005, 2006, 2007, 2013 Adam Sampson <ats@offog.org>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 */

#include <stdio.h>
#include <string.h>
#include <glib.h>
#include <mad.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <stdlib.h>
#include "buffer.h"
#include "format.h"
#include "input.h"
#include "input-mad.h"

// This is largely inspired by the madlld example.
// The dithering code is based on that from madplay.

// Per-channel state for the noise-shaped, dithered requantisation done by
// conv_sample().
typedef struct {
	// Recent quantisation errors: error[0] is the latest, error[1] and
	// error[2] are older (scaled) values fed back into the next sample.
	mad_fixed_t error[3];
	// Previous output of mad_prng(), used to form the dither noise.
	mad_fixed_t random;
} aomad_dither;

// Largest channel count handled; aomad_get_audio() aborts via g_error()
// on streams with more channels than this.
#define MAX_CHANNELS 2
// Bytes of compressed input held at once (not counting the guard bytes).
#define INPUT_BUF_SIZE (5 * 8192)
// Decoder state for one open file, stored in input->data.
typedef struct {
	// libmad bitstream reading from buf.
	struct mad_stream stream;
	// Most recently decoded frame.
	struct mad_frame frame;
	// PCM synthesised from frame.
	struct mad_synth synth;
	// Current playback position.
	mad_timer_t pos;
	// Total duration, the sum of every frame header's duration
	// (computed by scan_file()).
	mad_timer_t length;
	// Target position while seeking.
	mad_timer_t seek_dest;
	// TRUE while frames are being skipped on the way to seek_dest.
	gboolean seeking;
	// Compressed input. The extra MAD_BUFFER_GUARD bytes leave room
	// for the zero padding appended once the whole file has been read.
	unsigned char buf[INPUT_BUF_SIZE + MAD_BUFFER_GUARD];
	// The open input file.
	FILE *f;
	// File size as reported by fstat() in aomad_open().
	off_t file_size;
	// Start of the zero guard bytes inside buf; NULL until they have
	// been added.
	unsigned char *guard_pos;
	// Dither/noise-shaping state for each output channel.
	aomad_dither dither[MAX_CHANNELS];
} aomad;

// Put a->stream back into its empty starting state, pointing at a->buf
// with no data. This also forgets any guard bytes and cancels any seek in
// progress.
static void init_stream(aomad *a) {
	// Our own bookkeeping is independent of the libmad stream.
	a->seeking = FALSE;
	a->guard_pos = NULL;

	mad_stream_init(&a->stream);
	a->stream.error = MAD_ERROR_NONE;
	a->stream.bufend = a->buf;
	a->stream.buffer = a->buf;
}

// Linear congruential PRNG, the same one madplay uses (as do VLC and
// various other open source media programs that need audio dithering).
// Takes the previous state and returns the next one, reduced to 32 bits.
static inline unsigned long mad_prng(unsigned long state) {
	const unsigned long multiplier = 0x0019660dUL;
	const unsigned long increment = 0x3c6ef35fUL;

	unsigned long next = state * multiplier + increment;
	return next & 0xffffffffUL;
}

// Requantise one libmad fixed-point sample to a signed integer with
// `bits` bits of resolution. This follows madplay's linear dither: noise
// shaping by error feedback, a rounding bias, random dither, clipping,
// then truncation.
//
// in:   sample in mad_fixed_t format (MAD_F_ONE represents 1.0)
// d:    this channel's dither state; updated in place
// bits: output resolution; must be between 1 and MAD_F_FRACBITS so the
//       shift counts below are valid
//
// Returns the quantised sample right-aligned. Negative results are
// converted to unsigned long, so callers should use only the low `bits`
// bits.
static inline unsigned long conv_sample(mad_fixed_t in, aomad_dither *d,
                                        int bits) {
	// Noise shaping: feed back a filtered history of previous errors.
	in += d->error[0] - d->error[1] + d->error[2];
	d->error[2] = d->error[1];
	d->error[1] = d->error[0] / 2;

	// Bias by half an output LSB so that the truncation below rounds.
	mad_fixed_t out = in + (1L << (MAD_F_FRACBITS + 1 - bits - 1));
	unsigned int scalebits = MAD_F_FRACBITS + 1 - bits;
	mad_fixed_t mask = (1L << scalebits) - 1;

	// Dither: add the difference between this and the previous PRNG
	// output, restricted to the bits that are about to be discarded.
	mad_fixed_t random = mad_prng(d->random);
	out += (random & mask) - (d->random & mask);
	d->random = random;

	// Clip to the representable range. The noise-shaped input must be
	// clamped too. Otherwise the error fed back below (in - out) grows
	// without bound during sustained clipping, distorting the following
	// samples and eventually overflowing mad_fixed_t.
	if (out < -MAD_F_ONE) {
		out = -MAD_F_ONE;
		if (in < -MAD_F_ONE)
			in = -MAD_F_ONE;
	} else if (out > MAD_F_ONE - 1) {
		out = MAD_F_ONE - 1;
		if (in > MAD_F_ONE - 1)
			in = MAD_F_ONE - 1;
	}

	// Quantise, and remember the error for noise shaping.
	out &= ~mask;
	d->error[0] = in - out;

	return out >> scalebits;
}

// Refill a->buf from a->f and hand the result to libmad.
//
// Bytes libmad has not consumed yet (from stream.next_frame onwards) are
// first moved to the front of a->buf, and new file data is appended after
// them. Once the end of the file is reached, MAD_BUFFER_GUARD zero bytes
// are appended after the real data (once only), so that libmad can decode
// the final frame. a->guard_pos records where those bytes start.
//
// Returns 1 if the stream was (re)loaded, or 0 if there is no more data.
// Aborts via g_error() on a read error.
static int aomad_fetch_data(aomad *a) {
	if (a->stream.next_frame != NULL) {
		memmove(a->buf, a->stream.next_frame,
			a->stream.bufend - a->stream.next_frame);
		a->stream.bufend -=
			a->stream.next_frame - a->stream.buffer;
	}

	int used = a->stream.bufend - a->stream.buffer;
	int n = fread(a->buf + used, 1, INPUT_BUF_SIZE - used, a->f);
	if (ferror(a->f))
		g_error("error reading file");

	// Decide EOF from the stdio stream itself. The underlying descriptor
	// is unreliable here: stdio may already have read ahead to the end of
	// the file while unread bytes remain in its own buffer.
	//
	// If fread() happened to stop exactly at the end without setting the
	// EOF flag, the next call reads 0 bytes, sets the flag, and adds the
	// guard then.
	if (feof(a->f) && a->guard_pos == NULL) {
		// The data just read starts at a->buf + used, so the guard
		// must go after it, not at a->buf + n.
		a->guard_pos = a->buf + used + n;
		memset(a->guard_pos, 0, MAD_BUFFER_GUARD);
		n += MAD_BUFFER_GUARD;
	}
	if (n == 0 && feof(a->f))
		return 0;
	a->stream.bufend += n;

	// This looks silly, but looking at the libmad code it
	// does the right thing (resets some other internal
	// state as well as setting the buffer pointers).
	mad_stream_buffer(&a->stream, a->stream.buffer,
			  a->stream.bufend - a->stream.buffer);
	a->stream.error = 0;

	return 1;
}

// Decide what to do about the error libmad left in a->stream.error after
// a failed decode.
//
// Returns 0 when sync was lost inside the zero guard bytes appended at
// EOF, meaning the stream is finished. Returns 1 when decoding should
// continue: either the buffer needs more data (MAD_ERROR_BUFLEN) or the
// error is recoverable, in which case a message is printed. Any other
// error is reported with g_error(), which is fatal in GLib, so the -1
// that follows it is a formality.
static int handle_decode_error(aomad *a) {
	const struct mad_stream *s = &a->stream;
	enum mad_error err = s->error;

	if (err == MAD_ERROR_LOSTSYNC && a->guard_pos != NULL
	    && s->this_frame >= a->guard_pos)
		return 0;

	if (err == MAD_ERROR_BUFLEN)
		return 1;

	if (!MAD_RECOVERABLE(err)) {
		g_error("unrecoverable error in stream: %s",
		        mad_stream_errorstr(s));
		return -1;
	}

	// FIXME check for an ID3 block
	printf("recoverable error: %s\n", mad_stream_errorstr(s));
	return 1;
}

// Walk every frame header in the file and add the durations into
// a->length. Then reset the stream and rewind the file so that playback
// starts from the beginning.
//
// Any decode error that handle_decode_error() does not allow to continue
// is treated as the end of the file.
//
// Returns 0 on success, or -1 if the file cannot be rewound.
static int scan_file(aomad *a) {
	struct mad_header hdr;
	mad_header_init(&hdr);

	for (;;) {
		int starved = a->stream.buffer == a->stream.bufend
		              || a->stream.error == MAD_ERROR_BUFLEN;
		if (starved && aomad_fetch_data(a) == 0)
			break;

		if (mad_header_decode(&hdr, &a->stream) != 0) {
			if (handle_decode_error(a) <= 0)
				break;
			continue;
		}
		mad_timer_add(&a->length, hdr.duration);
	}

	mad_header_finish(&hdr);

	// Start decoding afresh from the top of the file.
	mad_stream_finish(&a->stream);
	init_stream(a);

	return fseek(a->f, 0, SEEK_SET) < 0 ? -1 : 0;
}

// Open `fn` for decoding and attach freshly allocated decoder state to
// p->data.
//
// The whole file is scanned once (scan_file()) to compute its total
// length, and then rewound.
//
// Returns 0 on success, or -1 if the file cannot be opened, fstat()ed or
// rewound. p->data is set even on failure, and aomad_close() copes with
// a->f == NULL. Aborts via g_error() if allocation fails.
static int aomad_open(input *p, const char *fn) {
	aomad *a = malloc(sizeof *a);
	if (a == NULL)
		g_error("out of memory");
	p->data = a;

	// Initialise every field before anything can fail, so that the state
	// is consistent whatever the caller does with it after an error.
	init_stream(a);
	mad_frame_init(&a->frame);
	mad_synth_init(&a->synth);
	mad_timer_reset(&a->pos);
	mad_timer_reset(&a->length);
	mad_timer_reset(&a->seek_dest);
	memset(a->dither, 0, sizeof a->dither);
	a->file_size = 0;

	// MPEG audio is binary data: "rb" avoids newline translation on
	// platforms that distinguish text and binary streams.
	a->f = fopen(fn, "rb");
	if (a->f == NULL)
		return -1;

	struct stat st;
	if (fstat(fileno(a->f), &st) < 0)
		return -1;
	a->file_size = st.st_size;

	if (scan_file(a) < 0)
		return -1;

	return 0;
}

// Decode the next audio frame and append it to buf as interleaved,
// little-endian PCM with nbits-bit samples of bytes_per_sample(nbits)
// bytes each. p->fmt and p->bitrate are updated to describe the frame.
//
// While a->seeking is set (see aomad_set_pos()), frames are only
// header-decoded, with their durations added to a->pos, until a->pos
// reaches a->seek_dest. The next two fully decoded frames are then thrown
// away to get the decoder back in sync. The second of those is run
// through the synth before output resumes.
//
// Returns the number of bytes appended. Returns 0 if no more input can be
// fetched, or the non-positive result of handle_decode_error() (0 at the
// end of the stream) if decoding stops.
static int aomad_get_audio(input *p, buffer *buf) {
	aomad *a = (aomad *) p->data;

	// Frames still to be thrown away after a seek completes.
	int discard = 0;
	while (1) {
		// Refill when the buffer is empty or libmad asked for more.
		if (a->stream.buffer == a->stream.bufend
		    || a->stream.error == MAD_ERROR_BUFLEN) {
			if (aomad_fetch_data(a) == 0) {
				// This shouldn't happen because we'll
				// hit the guard data first, but just
				// in case...
				return 0;
			}
		}
		// Note: while seeking, only the header is decoded; the
		// short-circuit && keeps mad_header_decode() from running
		// otherwise.
		if (a->seeking
		    && mad_header_decode(&a->frame.header,
					 &a->stream) == 0) {
			// Seeking -- ignore frame until we've reached
			// where we want to go.
			mad_timer_add(&a->pos, a->frame.header.duration);
			// Keep skipping while pos is still before seek_dest.
			if (mad_timer_compare(a->pos,
					      a->seek_dest) == -1)
				continue;

			a->seeking = FALSE;
			discard = 2;
		} else if ((!a->seeking)
			   && mad_frame_decode(&a->frame,
					       &a->stream) == 0) {
			if (discard > 0) {
				// We've just finished seeking -- throw
				// away the frame to get back in sync.
				mad_timer_add(&a->pos,
					      a->frame.header.duration);
				--discard;
				if (discard == 0) {
					// Last discarded frame -- get
					// the synth back in sync.
					mad_synth_frame(&a->synth,
							&a->frame);
				}
			} else {
				// Got a frame.
				break;
			}
		} else if (discard > 0) {
			// Error while discarding frames; ignore it.
		} else {
			// rc == 1 means keep going; rc <= 0 ends this call.
			int rc = handle_decode_error(a);
			if (rc <= 0)
				return rc;
		}
	}

	mad_timer_add(&a->pos, a->frame.header.duration);
	mad_synth_frame(&a->synth, &a->frame);

	int nbits = 24;
	int nbytes = bytes_per_sample(nbits);
	int nchannels = MAD_NCHANNELS(&a->frame.header);

	// One read cursor per channel into the synth's PCM arrays.
	mad_fixed_t *in[MAX_CHANNELS];
	if (nchannels > MAX_CHANNELS)
		g_error("too many channels");

	size_t out_size = a->synth.pcm.length * nchannels * nbytes;
	unsigned char *out = buffer_reserve(buf, out_size);
	if (out == NULL)
		g_error("out of memory");

	p->fmt.bits = nbits;
	p->fmt.rate = a->frame.header.samplerate;
	p->fmt.channels = nchannels;
	p->fmt.byte_format = END_LITTLE;
	p->bitrate = a->frame.header.bitrate / 1000.0;

	// Interleave channels sample by sample, writing each sample low
	// byte first to match END_LITTLE.
	for (int j = 0; j < nchannels; j++)
		in[j] = a->synth.pcm.samples[j];
	for (int i = 0; i < a->synth.pcm.length; i++) {
		for (int j = 0; j < nchannels; j++) {
			unsigned long sample = conv_sample(*in[j]++,
			                                   &a->dither[j],
			                                   nbits);
			// NOTE(review): the 24-bit value is shifted into
			// the top of a wider word; presumably
			// bytes_per_sample(24) is 4 -- TODO confirm.
			if (nbits == 24)
				sample <<= 8;
			int n = nbytes;
			while (n-- > 0) {
				*out++ = sample & 0xFF;
				sample >>= 8;
			}
		}
	}
	// buffer_reserve() only made room; commit the bytes just written.
	buf->used += out_size;

	return out_size;
}

static int aomad_get_pos(input *p, double *pos) {
	aomad *a = (aomad *) p->data;

	*pos = mad_timer_count(a->pos, MAD_UNITS_MILLISECONDS) / 1000.0L;

	return 0;
}

static int aomad_get_len(input *p, double *len) {
	aomad *a = (aomad *) p->data;

	*len = mad_timer_count(a->length, MAD_UNITS_MILLISECONDS) / 1000.0L;

	return 0;
}

// Report that this backend supports repositioning (see aomad_set_pos()).
// The answer does not depend on the input, so p is unused.
static int aomad_get_seekable(input *p) {
	(void) p;
	return 1;
}

// Seek to `pos` seconds from the start of the file.
//
// The file is rewound, the decoder is reset and a->seeking is set.
// aomad_get_audio() then skips frames by their header durations until
// a->seek_dest is reached. Negative and NaN positions are treated as 0.
//
// Returns 0 on success, or -1 if the file cannot be rewound.
static int aomad_set_pos(input *p, double pos) {
	aomad *a = (aomad *) p->data;

	// Converting a negative or NaN double to an unsigned integer type is
	// undefined behaviour, so clamp before it reaches mad_timer_set().
	if (!(pos > 0.0))
		pos = 0.0;

	if (fseek(a->f, 0, SEEK_SET) < 0)
		return -1;

	// Discard data left in the buffer.
	mad_stream_finish(&a->stream);
	init_stream(a);

	mad_timer_reset(&a->pos);
	mad_timer_set(&a->seek_dest, 0, (unsigned long) (pos * 1000), 1000);
	a->seeking = TRUE;

	// Silence any state carried over from before the seek.
	mad_frame_mute(&a->frame);
	mad_synth_mute(&a->synth);

	return 0;
}

// Release the decoder state in p->data, close the file, and free p
// itself. Always returns 0.
static int aomad_close(input *p) {
	aomad *state = (aomad *) p->data;

	// libmad objects are finished in reverse order of initialisation.
	mad_synth_finish(&state->synth);
	mad_frame_finish(&state->frame);
	mad_stream_finish(&state->stream);

	// f is NULL if aomad_open() could not open the file.
	FILE *f = state->f;
	if (f)
		fclose(f);

	free(state);
	free(p);
	return 0;
}

// Create an input backend that decodes MPEG audio with libmad. The
// decoder state itself is allocated later, by the open callback.
input *input_new_mad(void) {
	input *ip = input_alloc();

	// Lifecycle.
	ip->open = aomad_open;
	ip->close = aomad_close;

	// Decoding.
	ip->get_audio = aomad_get_audio;

	// Position and length queries.
	ip->get_seekable = aomad_get_seekable;
	ip->get_pos = aomad_get_pos;
	ip->set_pos = aomad_set_pos;
	ip->get_len = aomad_get_len;

	return ip;
}

