#!/usr/bin/env python3
"""lfit-server - Setup and launcher for the stable-diffusion.cpp sd-server (Vulkan).

This script helps you download, configure, and start the sd-server binary that LFIT
talks to. It is NOT the server itself (that's architecture-specific) — it's a launcher
that ensures the correct flags and model paths are set up.

Usage:
  lfit-server setup          Download sd-server binary and verify model files
  lfit-server start          Start the server (foreground)
  lfit-server start --daemon Start the server (background, daemonized)
  lfit-server status         Check if the server is running
  lfit-server stop           Stop a running server

Environment variables:
  LFIT_MODEL_DIR    - Directory containing SDXL model files (default: ~/.lfit/models/)
  LFIT_LORA_DIR     - Directory containing LoRA .safetensors files (default: ~/.lfit/models/loras/)
  LFIT_SERVER_URL   - URL of the sd-server instance (default: http://127.0.0.1:7860)
  LFIT_SD_BIN       - Path to sd-server binary (default: ~/.lfit/bin/sd-server)
"""
import argparse
import json
import os
import platform
import socket
import subprocess
import sys
import urllib.request
import urllib.error

MODEL_DIR = os.environ.get("LFIT_MODEL_DIR", os.path.expanduser("~/.lfit/models/"))
LORA_DIR = os.environ.get("LFIT_LORA_DIR", os.path.expanduser("~/.lfit/models/loras/"))
SERVER_URL = os.environ.get("LFIT_SERVER_URL", "http://127.0.0.1:7860")
SD_BIN = os.environ.get("LFIT_SD_BIN", os.path.expanduser("~/.lfit/bin/sd-server"))

SDXL_MODEL = "sd_xl_base_1.0.safetensors"
LORA_8STEP = "sdxl_lightning_8step_lora.safetensors"
LISTEN_IP = "127.0.0.1"
LISTEN_PORT = "7860"

RELEASES_URL = "https://api.github.com/repos/leejet/stable-diffusion.cpp/releases/latest"


def _detect_arch():
    s = platform.system().lower()
    m = platform.machine().lower()
    if s == "linux":
        if m in ("x86_64", "amd64"): return "linux", "x86_64"
        if m in ("aarch64", "arm64"): return "linux", "aarch64"
    if s == "darwin":
        if m == "arm64": return "macos", "arm64"
        return "macos", "x86_64"
    if s == "windows": return "windows", "x86_64"
    return s, m


# Spellings that may appear in release asset names for each (os, arch) pair
# returned by _detect_arch().
# NOTE(review): the "win"/"x64" aliases are based on the naming of
# stable-diffusion.cpp release archives; verify against current assets.
_OS_ALIASES = {
    "linux": ("linux",),
    "macos": ("macos", "darwin"),
    "windows": ("windows", "win"),
}
_ARCH_ALIASES = {
    "x86_64": ("x86_64", "x64", "amd64"),
    "aarch64": ("aarch64", "arm64"),
    "arm64": ("arm64", "aarch64"),
}
# Executable names looked for inside a downloaded zip archive.
_SERVER_NAMES = ("sd-server", "sd-server.exe")


def _name_has_tag(name, tag):
    """Return True if *tag* occurs in *name* without touching other letters.

    A plain substring test would let "win" match inside "darwin".
    """
    import re

    return re.search(r"(?<![a-z])%s(?![a-z])" % re.escape(tag), name) is not None


def _select_asset(assets, os_name, arch):
    """Pick the release asset for *os_name*/*arch*, preferring Vulkan builds.

    Returns the first matching asset whose name contains "vulkan". Otherwise
    it returns the first matching asset, or None when nothing matches.
    """
    os_tags = _OS_ALIASES.get(os_name, (os_name,))
    arch_tags = _ARCH_ALIASES.get(arch, (arch,))
    matches = []
    for asset in assets:
        name = asset.get("name", "").lower()
        if (any(_name_has_tag(name, t) for t in os_tags)
                and any(_name_has_tag(name, t) for t in arch_tags)):
            matches.append(asset)
    for asset in matches:
        if "vulkan" in asset.get("name", "").lower():
            return asset
    return matches[0] if matches else None


def _install_asset(url, dest):
    """Download *url* and install the sd-server executable it provides.

    The download goes to a temporary file beside *dest* first, so a failed
    transfer never leaves a truncated file at *dest*.

    - A zip archive is extracted into *dest*'s directory. This keeps any
      bundled files next to the executable.
    - Any other file is moved to *dest* and made executable.

    Returns the path of the installed executable. Returns None if a zip
    archive contained no file named like *dest* or sd-server(.exe).
    """
    import tempfile
    import zipfile

    bin_dir = os.path.dirname(dest) or "."
    os.makedirs(bin_dir, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=bin_dir, suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        if not zipfile.is_zipfile(tmp_path):
            os.replace(tmp_path, dest)
            os.chmod(dest, 0o755)
            return dest

        wanted = {os.path.basename(dest)} | set(_SERVER_NAMES)
        with zipfile.ZipFile(tmp_path) as archive:
            members = archive.namelist()
            archive.extractall(bin_dir)
        # Zip member names always use "/" as separator.
        found = [os.path.join(bin_dir, *m.split("/"))
                 for m in members if m.rsplit("/", 1)[-1] in wanted]
        found = [p for p in found if os.path.isfile(p)]
        if not found:
            return None
        exact = [p for p in found if os.path.abspath(p) == os.path.abspath(dest)]
        path = (exact or found)[0]
        # zipfile's extraction does not restore Unix permission bits.
        os.chmod(path, 0o755)
        return path
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _download_sd_server():
    """Best-effort install of the latest sd-server release build at SD_BIN.

    Failures are reported on stderr instead of being raised, so `setup` can
    continue with the model checks.
    """
    try:
        req = urllib.request.Request(RELEASES_URL, headers={"User-Agent": "lfit-server/1.0"})
        with urllib.request.urlopen(req, timeout=30) as resp:
            release = json.load(resp)
    except (OSError, ValueError) as exc:  # network/HTTP errors, bad JSON
        sys.stderr.write("Could not check releases: %s\n" % exc)
        return
    if not isinstance(release, dict):
        sys.stderr.write("Could not check releases: unexpected response\n")
        return

    tag = release.get("tag_name", "unknown")
    os_name, arch = _detect_arch()
    asset = _select_asset(release.get("assets", []), os_name, arch)
    if asset is None:
        sys.stderr.write("No matching binary for %s/%s. Build from source:\n"
                         "  https://github.com/leejet/stable-diffusion.cpp\n" % (os_name, arch))
        return

    sys.stderr.write("Found release %s: %s\n" % (tag, asset["name"]))
    try:
        installed = _install_asset(asset["browser_download_url"], SD_BIN)
    except Exception as exc:  # best-effort: fall back to manual instructions
        sys.stderr.write("Download failed (%s). Get it manually from:\n"
                         "  https://github.com/leejet/stable-diffusion.cpp/releases\n" % exc)
        return

    if installed is None:
        sys.stderr.write("Archive %s contains no sd-server executable. Get it manually from:\n"
                         "  https://github.com/leejet/stable-diffusion.cpp/releases\n"
                         % asset["name"])
    elif os.path.abspath(installed) == os.path.abspath(SD_BIN):
        sys.stderr.write("Downloaded sd-server to %s\n" % SD_BIN)
    else:
        sys.stderr.write("Extracted sd-server to %s\n"
                         "Set LFIT_SD_BIN=%s to use it.\n" % (installed, installed))


def cmd_setup(args):
    """Ensure the sd-server binary exists, then report on the model files.

    A missing binary is downloaded from the latest stable-diffusion.cpp GitHub
    release. The SDXL base model and the 8-step Lightning LoRA are only
    checked; download hints are printed when they are absent.

    Always returns 0: missing pieces are reported as warnings, not errors.
    """
    import zipfile

    # An archive saved as SD_BIN (e.g. a release .zip stored without being
    # extracted) is not a runnable binary. Treat it as missing so it is
    # replaced.
    if (os.path.isfile(SD_BIN) and os.access(SD_BIN, os.X_OK)
            and not zipfile.is_zipfile(SD_BIN)):
        sys.stderr.write("sd-server binary found: %s\n" % SD_BIN)
    else:
        sys.stderr.write("sd-server binary not found at %s\n" % SD_BIN)
        sys.stderr.write("Attempting to download from github releases...\n")
        _download_sd_server()

    model_path = os.path.join(MODEL_DIR, SDXL_MODEL)
    if os.path.isfile(model_path):
        size_mb = os.path.getsize(model_path) / (1024 * 1024)
        sys.stderr.write("SDXL model found: %s (%.0f MB)\n" % (model_path, size_mb))
    else:
        sys.stderr.write("WARNING: SDXL model not found at %s\n" % model_path)
        sys.stderr.write("Download from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0\n")

    lora_path = os.path.join(LORA_DIR, LORA_8STEP)
    if os.path.isfile(lora_path):
        sys.stderr.write("8-step Lightning LoRA found: %s\n" % lora_path)
    else:
        sys.stderr.write("WARNING: 8-step Lightning LoRA not found at %s\n" % lora_path)
        sys.stderr.write("Download from: https://huggingface.co/ByteDance/SDXL-Lightning\n")

    sys.stderr.write("\nSetup check complete.\n")
    return 0


def cmd_status(args):
    """Report whether sd-server answers on SERVER_URL and list its LoRAs.

    Queries ``/sdapi/v1/loras``. Returns 0 when the server responds (even if
    the body cannot be parsed) and 1 when the request fails.
    """
    try:
        with urllib.request.urlopen(SERVER_URL + "/sdapi/v1/loras", timeout=5) as resp:
            body = resp.read()
    except OSError:
        # URLError (incl. HTTPError) plus raw socket errors and read timeouts,
        # which urllib does not wrap in URLError.
        sys.stderr.write("Server is NOT responding on %s\n" % SERVER_URL)
        return 1

    sys.stderr.write("Server is RUNNING on %s\n" % SERVER_URL)
    try:
        data = json.loads(body)
    except ValueError:  # not JSON: treat as "no LoRA list available"
        data = None

    loras = []
    if isinstance(data, list):
        for entry in data:
            if isinstance(entry, dict):
                loras.append(str(entry.get("path", entry.get("name", "?"))))
            else:
                loras.append(str(entry))

    if loras:
        sys.stderr.write("Available LoRAs: %s\n" % ", ".join(loras))
    else:
        sys.stderr.write("No LoRAs loaded (check --lora-model-dir)\n")
    return 0


def cmd_start(args):
    """Launch sd-server with LFIT's flags, in the foreground or as a daemon.

    Refuses to start (returning 1) in any of these cases:
    - a server already answers on SERVER_URL;
    - LISTEN_IP:LISTEN_PORT already accepts connections;
    - the binary or the SDXL model is missing.

    Foreground mode returns sd-server's exit status (0 after Ctrl-C, 1 if it
    cannot be executed). Daemon mode (``args.daemon``) starts the process in
    a new session and appends its output to ~/.lfit/logs/sd-server.log. It
    returns 0 once the process has been spawned.
    """
    # Refuse to start a second instance if something already answers the API.
    try:
        with urllib.request.urlopen(SERVER_URL + "/sdapi/v1/loras", timeout=3):
            pass
    except OSError:
        # URLError (incl. HTTPError), refused connections and read timeouts all
        # count as "not running"; the port probe below catches other listeners.
        pass
    else:
        sys.stderr.write("ERROR: server already responding on %s\nStop it first.\n" % SERVER_URL)
        return 1

    # Best-effort check that nothing else is bound to the port we are about to
    # use; if the probe itself fails, the check is skipped.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.settimeout(1)
            port_busy = probe.connect_ex((LISTEN_IP, int(LISTEN_PORT))) == 0
    except (OSError, ValueError, OverflowError):
        port_busy = False
    if port_busy:
        sys.stderr.write("ERROR: port %s is already in use.\n" % LISTEN_PORT)
        return 1

    if not os.path.isfile(SD_BIN):
        sys.stderr.write("ERROR: sd-server binary not found at %s\nRun 'lfit-server setup' first.\n" % SD_BIN)
        return 1

    model_path = os.path.join(MODEL_DIR, SDXL_MODEL)
    if not os.path.isfile(model_path):
        sys.stderr.write("ERROR: SDXL model not found at %s\n" % model_path)
        return 1

    cmd = [
        SD_BIN,
        "--listen-ip", LISTEN_IP,
        "--listen-port", LISTEN_PORT,
        "--model", model_path,
        "--lora-model-dir", LORA_DIR,
        "--type", "f16",
        "--vae-tiling",
        "--lora-apply-mode", "at_runtime",
    ]

    sys.stderr.write("Starting sd-server:\n  %s\n" % " ".join(cmd))
    if args.daemon:
        log_path = os.path.join(os.path.expanduser("~/.lfit/logs"), "sd-server.log")
        try:
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            # The child gets its own copy of the descriptor, so the parent's
            # handle is closed as soon as Popen returns. stdin is detached so
            # the background process never reads from the terminal.
            with open(log_path, "a") as log_fh:
                proc = subprocess.Popen(
                    cmd, stdin=subprocess.DEVNULL, stdout=log_fh, stderr=log_fh,
                    start_new_session=True)
        except (OSError, subprocess.SubprocessError) as exc:
            sys.stderr.write("Failed to start: %s\n" % exc)
            return 1
        sys.stderr.write("Server started (PID %d), logging to %s\n" % (proc.pid, log_path))
        return 0

    try:
        return subprocess.call(cmd)
    except KeyboardInterrupt:
        sys.stderr.write("\nServer stopped.\n")
        return 0
    except OSError as exc:  # e.g. SD_BIN exists but is not executable
        sys.stderr.write("Failed to start: %s\n" % exc)
        return 1


def _listening_pids(port):
    """Return the PIDs of processes with a TCP socket LISTENing on *port*.

    Uses ``lsof``; raises FileNotFoundError when lsof is not installed.
    Restricting to the LISTEN state avoids matching clients that only have a
    connection open to the port.
    """
    result = subprocess.run(
        ["lsof", "-t", "-iTCP:%s" % port, "-sTCP:LISTEN"],
        capture_output=True, text=True)
    # `lsof -t` prints bare PIDs; dedupe while keeping their order.
    return list(dict.fromkeys(int(tok) for tok in result.stdout.split() if tok.isdigit()))


def _pkill_sd_server(port):
    """Fallback for systems without lsof: pkill sd-server by its port argument.

    Returns 0 if pkill matched at least one process, otherwise 1.
    """
    # Anchor the port number so e.g. 7860 does not also match 78600.
    pattern = "sd-server.*--listen-port[ =]%s( |$)" % port
    try:
        result = subprocess.run(["pkill", "-f", pattern], capture_output=True, text=True)
    except FileNotFoundError:
        sys.stderr.write("Neither lsof nor pkill is available; stop sd-server manually.\n")
        return 1
    if result.returncode == 0:
        sys.stderr.write("Sent stop signal to sd-server on port %s.\n" % port)
        return 0
    sys.stderr.write("No matching sd-server process found.\n")
    return 1


def cmd_stop(args):
    """Stop the server listening on LISTEN_PORT by sending it SIGTERM.

    Processes are located with lsof, falling back to pkill when lsof is not
    installed. Returns 0 on success. Returns 1 when no process was found or
    a signal could not be delivered.
    """
    import signal

    try:
        pids = _listening_pids(LISTEN_PORT)
    except FileNotFoundError:
        return _pkill_sd_server(LISTEN_PORT)
    except OSError as exc:
        sys.stderr.write("Error stopping server: %s\n" % exc)
        return 1

    if not pids:
        sys.stderr.write("No process found on port %s.\n" % LISTEN_PORT)
        return 1

    status = 0
    for pid in pids:
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            continue  # exited between the lsof query and the kill
        except OSError as exc:  # e.g. PermissionError: owned by another user
            sys.stderr.write("Error stopping process %d: %s\n" % (pid, exc))
            status = 1
        else:
            sys.stderr.write("Stopped process %d on port %s.\n" % (pid, LISTEN_PORT))
    return status


def main():
    """Parse the command line and run the matching ``cmd_*`` handler.

    Returns the handler's exit status. With no subcommand, prints the help
    text and returns 1.
    """
    parser = argparse.ArgumentParser(description="LFIT server setup and launcher.")
    commands = parser.add_subparsers(dest="command")
    commands.add_parser("setup", help="Download sd-server binary and verify model files")
    commands.add_parser("status", help="Check if the server is running")
    start_parser = commands.add_parser("start", help="Start the sd-server")
    start_parser.add_argument("--daemon", action="store_true", help="Run in background")
    commands.add_parser("stop", help="Stop a running sd-server")
    parsed = parser.parse_args()

    handlers = {
        "setup": cmd_setup,
        "status": cmd_status,
        "start": cmd_start,
        "stop": cmd_stop,
    }
    handler = handlers.get(parsed.command)
    if handler is None:
        parser.print_help()
        return 1
    return handler(parsed)


if __name__ == "__main__":
    sys.exit(main())
