#!/usr/bin/env python3
import argparse
import serial
import os
import shutil
import subprocess
import sys
import time
from pexpect import fdpexpect

# Command-line interface of the script.  ArgumentDefaultsHelpFormatter makes
# --help show the default value next to every option's help text.
parser = argparse.ArgumentParser(
    description="Connect to serial console to execute stuff",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--port",
    required=True,
    help="serial console device to connect to (e.g. /dev/pts/X)",
)
# --hostname and --user are used to build the login and shell prompt patterns.
parser.add_argument(
    "--hostname", default="bookworm", help="hostname of the system for login process"
)
parser.add_argument("--user", default="root", help="user name to use for login")
parser.add_argument("--password", default="grml", help="password for login")
# Integer options carry real int defaults so that parser.get_default() and the
# parsed namespace agree on the type, without relying on argparse converting a
# string default through ``type``.
parser.add_argument(
    "--timeout",
    default=180,
    type=int,
    help="Maximum time for finding the login prompt, in seconds",
)
parser.add_argument(
    "--screenshot",
    default="screenshot.jpg",
    help="file name for screenshot captured via VNC on error",
)
parser.add_argument(
    "--tries",
    default=12,
    type=int,
    help="Number of retries for finding the login prompt",
)
parser.add_argument(
    "--poweroff",
    action="store_true",
    default=False,
    help='send "poweroff" command after all other commands',
)
# One or more commands; each is sent to the remote shell as its own line.
parser.add_argument("command", nargs="+", help="command to execute after logging in")


def print_ser_lines(ser):
    """Print everything ``ser.readlines()`` returns, one entry per line.

    Each entry is written to stdout prefixed with ``<<`` and is printed as
    returned, without decoding (for a serial port these are expected to be
    bytes objects, so they appear in their ``b'...'`` form).
    """
    received = ser.readlines()
    for raw_line in received:
        print("<<", raw_line)


def write_ser_line(ser, text):
    """Send ``text`` followed by a newline to ``ser`` and flush it.

    The text is echoed to stdout with a ``>>`` prefix before it is written,
    so the script's output shows what was sent.
    """
    print(">>", text)
    payload = ("%s\n" % text).encode()
    ser.write(payload)
    ser.flush()


def login(ser, hostname, user, password, timeout=5):
    """Log in on the serial console unless a shell prompt is already there.

    A newline is sent first; if a ``user@hostname`` prompt shows up within
    ``timeout`` seconds the session is treated as already logged in and the
    function returns. Otherwise it waits for ``"<hostname> login:"``, sends
    the user name and the password (one second apart) and waits for the
    ``user@hostname`` prompt.

    Args:
        ser: open serial port object; must provide ``fileno()``, ``write()``
            and ``flush()``.
        hostname: host name expected in the login and shell prompts.
        user: account name to log in with.
        password: password sent after the user name.
        timeout: seconds to wait for each expected prompt.

    Raises:
        Whatever ``child.expect()`` raises when the login prompt or the final
        shell prompt does not appear in time (presumably pexpect's TIMEOUT or
        EOF; confirm against the pexpect documentation).
    """
    # NOTE: pexpect compiles string patterns as regular expressions, so
    # regex metacharacters in user/hostname are not matched literally.
    shell_prompt = "%s@%s" % (user, hostname)

    child = fdpexpect.fdspawn(ser.fileno())
    child.sendline("\n")

    try:
        child.expect(shell_prompt, timeout=timeout)
    except Exception:
        # Best effort probe: no shell prompt (or any other expect failure)
        # simply means we still have to log in. Unlike a bare ``except:``,
        # this lets KeyboardInterrupt/SystemExit propagate so the script can
        # still be aborted while probing.
        pass
    else:
        # Shell prompt already present: nothing to do.
        return

    print("Waiting for login prompt...")
    child.expect("%s login:" % hostname, timeout=timeout)
    print("Logging in...")
    write_ser_line(ser, user)
    # Pause between the inputs before sending the next one.
    time.sleep(1)
    write_ser_line(ser, password)
    time.sleep(1)

    print("Waiting for shell prompt...")
    child.expect(shell_prompt, timeout=timeout)


def capture_vnc_screenshot(screenshot_file):
    """Try to save a VNC screenshot of ``localhost`` to ``screenshot_file``.

    Uses the external ``vncsnapshot`` program. If it is not found on PATH,
    or if it exits with a non-zero status, only a warning is printed; the
    function never raises because of a failed capture.
    """
    if shutil.which("vncsnapshot") is None:
        print("WARN: vncsnapshot not available, skipping vnc snapshot capturing.")
        return

    print("Trying to capture screenshot via vncsnapshot to", screenshot_file)

    snapshot_cmd = ["vncsnapshot", "localhost", screenshot_file]
    exit_status = subprocess.Popen(snapshot_cmd).wait()
    if exit_status == 0:
        print("Screenshot file '%s' available" % os.path.abspath(screenshot_file))
    else:
        print("WARN: failed to capture vnc snapshot :(")


def main():
    """Log in via the serial console, run the requested commands, then exit.

    Parses the command line, opens the serial port at 115200 baud and tries
    to log in up to ``--tries`` times. When login succeeds, every positional
    command is sent as its own line and the console output is echoed; with
    ``--poweroff`` a final ``poweroff`` command is sent. When login fails, a
    VNC screenshot is attempted and the process exits with status 1.

    The serial port is closed on every exit path.
    """
    args = parser.parse_args()
    hostname = args.hostname
    password = args.password
    port = args.port
    user = args.user
    commands = args.command
    screenshot_file = args.screenshot

    ser = serial.Serial(port, 115200)
    try:
        # Start from clean buffers so stale console output is not matched.
        ser.flushInput()
        ser.flushOutput()

        success = False
        ts_start = time.time()
        for i in range(args.tries):
            try:
                print("Logging into %s via serial console [try %s]" % (port, i))
                login(ser, hostname, user, password)
                success = True
                break
            except Exception as except_inst:
                print("Login failure (try %s):" % (i,), except_inst, file=sys.stderr)
                time.sleep(5)
            # Only checked after a failed attempt and its pause, so the total
            # wait can overshoot --timeout by up to one attempt.
            if time.time() - ts_start > args.timeout:
                print("Timeout reached waiting for login prompt", file=sys.stderr)
                break

        if not success:
            capture_vnc_screenshot(screenshot_file)
            sys.exit(1)

        write_ser_line(ser, "")
        # Bound the reads below: readlines() returns once no more data
        # arrives within this many seconds.
        ser.timeout = 5
        print_ser_lines(ser)
        print("Running commands...")
        for command in commands:
            write_ser_line(ser, command)
            print_ser_lines(ser)
        if args.poweroff:
            print("Sending final poweroff command...")
            write_ser_line(ser, "poweroff")
            ser.flush()
            # after poweroff, the serial device will probably vanish. do not attempt reading from it anymore.
    finally:
        # Release the port on every path, including sys.exit() above.
        ser.close()


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
