#!/usr/bin/env python3
"""Manage GSD thinking policy overrides.

Usage:
  gsd-thinking show
  gsd-thinking reset
  gsd-thinking set <target> <level> [--force]
  gsd-thinking remove <target>
  gsd-thinking validate
  gsd-thinking lint

Examples:
  gsd-thinking set plan xhigh
  gsd-thinking set execute high
  gsd-thinking set plan-slice high
  gsd-thinking set research- high
  gsd-thinking remove execute-task

Targets:
  - Shortcuts: default, research, discuss, plan, execute, replan
  - Prefixes: any value ending with '-' (must match KNOWN_PREFIXES, override with --force)
  - Unit types: any exact unit type (must be in KNOWN_UNIT_TYPES, override with --force)

Validation rules (ported from quangdo126/gsd-2 fork):
  - levels must be one of: off, minimal, low, medium, high, xhigh
  - unit types must be in KNOWN_UNIT_TYPES (typos rejected unless --force is given)
  - prefixes must match KNOWN_PREFIXES (rejected unless --force is given; forced
    entries are accepted with a warning)

Run `gsd-thinking validate` to check the current policy file for invalid keys
or values without modifying anything. Run `gsd-thinking lint` to also auto-prune
invalid entries from disk.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any

# Allow `import _thinking_constants` regardless of CWD.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from _thinking_constants import (  # noqa: E402
    KNOWN_LEVELS,
    KNOWN_PREFIXES,
    KNOWN_UNIT_TYPES,
)

# Policy file managed by this script (created on demand by ensure_policy).
POLICY_PATH = Path.home() / ".gsd" / "agent" / "thinking-policy.json"
# Opinionated preset: written when POLICY_PATH does not exist yet and restored
# by `gsd-thinking reset`. Sections:
#   default   - single level used when no more specific override applies
#               (NOTE(review): the matching order lives in the consumer of this
#               file, not in this script — confirm there)
#   prefixes  - keys end with '-' and are checked against KNOWN_PREFIXES
#   unitTypes - exact unit-type names, checked against KNOWN_UNIT_TYPES
DEFAULT_POLICY = {
    "default": "medium",
    "prefixes": {
        "research-": "medium",
        "discuss-": "high",
        "plan-": "xhigh",
    },
    "unitTypes": {
        "execute-task": "off",
        "execute-task-simple": "off",
        "reactive-execute": "off",
        "replan-slice": "high",
    },
}
# Friendly CLI target names -> (policy section, key). The key is None for the
# top-level "default" entry, which is stored directly on the policy object
# rather than inside a section.
SHORTCUTS = {
    "default": ("default", None),
    "research": ("prefixes", "research-"),
    "discuss": ("prefixes", "discuss-"),
    "plan": ("prefixes", "plan-"),
    "execute": ("unitTypes", "execute-task"),
    "replan": ("unitTypes", "replan-slice"),
}


def ensure_policy() -> dict[str, Any]:
    """Load the policy file, creating it from DEFAULT_POLICY if it does not exist.

    Missing top-level sections are filled in (``default`` from DEFAULT_POLICY,
    ``prefixes``/``unitTypes`` as empty objects) so callers can index them
    directly. A section stored as JSON ``null`` is treated as missing.

    Returns:
        The policy as a dict. When the file was just created this is a deep
        copy of DEFAULT_POLICY, so callers may mutate it freely.

    Raises:
        SystemExit: If the file is not valid UTF-8 JSON, its top level is not
            an object, or ``prefixes``/``unitTypes`` is not an object.
    """
    POLICY_PATH.parent.mkdir(parents=True, exist_ok=True)
    if not POLICY_PATH.exists():
        write_policy(DEFAULT_POLICY)
        # JSON round-trip is a cheap deep copy: edits never touch DEFAULT_POLICY.
        return json.loads(json.dumps(DEFAULT_POLICY))
    try:
        with POLICY_PATH.open("r", encoding="utf-8") as fh:
            data = json.load(fh)
    except ValueError as exc:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses): report a clean CLI error instead of a traceback.
        raise SystemExit(f"Invalid policy file: {POLICY_PATH} ({exc})") from exc
    if not isinstance(data, dict):
        raise SystemExit(f"Invalid policy file: {POLICY_PATH}")
    data.setdefault("default", DEFAULT_POLICY["default"])
    for section in ("prefixes", "unitTypes"):
        # setdefault alone would leave an explicit null in place, and callers
        # then crash when they index into it.
        if data.get(section) is None:
            data[section] = {}
        elif not isinstance(data[section], dict):
            raise SystemExit(
                f"Invalid policy file: {POLICY_PATH} ('{section}' must be an object)"
            )
    return data


def write_policy(policy: dict[str, Any]) -> None:
    """Persist ``policy`` to POLICY_PATH as 2-space-indented JSON plus a newline.

    The JSON is first written to a sibling ``.tmp`` file, which is then moved
    over POLICY_PATH. An interrupted write therefore never leaves a truncated
    policy file behind, which ensure_policy could no longer parse.

    Args:
        policy: JSON-serialisable policy mapping; it is not modified.
    """
    POLICY_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = POLICY_PATH.with_name(POLICY_PATH.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as fh:
            json.dump(policy, fh, indent=2)
            fh.write("\n")
        # Path.replace overwrites the destination; on POSIX it is an atomic
        # rename because both paths live in the same directory.
        tmp_path.replace(POLICY_PATH)
    except BaseException:
        # Don't leave a half-written temp file around; re-raise the original error.
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def resolve_target(target: str) -> tuple[str, str | None]:
    """Map a CLI target to the ``(section, key)`` it addresses in the policy.

    Resolution order:
      1. Shortcut names from SHORTCUTS, including ``default``, which maps to
         ``("default", None)``.
      2. Anything ending in ``'-'`` is a prefix key under ``prefixes``.
      3. Everything else is an exact key under ``unitTypes``.

    No validation happens here; unknown keys are checked by validate_target.
    """
    if target in SHORTCUTS:
        return SHORTCUTS[target]
    # No explicit "default" check is needed: it is a SHORTCUTS entry, so it
    # was already handled above.
    if target.endswith("-"):
        return ("prefixes", target)
    return ("unitTypes", target)


def validate_target(section: str, key: str | None, force: bool) -> list[str]:
    """Check a resolved target key against the known prefix/unit-type lists.

    Args:
        section: Policy section returned by resolve_target
            (``"default"``, ``"prefixes"`` or ``"unitTypes"``).
        key: Key inside ``section``; ``None`` for the top-level default.
        force: Accept keys that are not in the known lists.

    Returns:
        Warning messages for unknown keys accepted only because ``force`` is
        set; empty when nothing needed forcing. This function never prints;
        the caller decides what to do with the messages.

    Raises:
        SystemExit: If ``key`` is unknown for its section and ``force`` is False.
    """
    warnings: list[str] = []
    if key is None:
        return warnings
    if section == "prefixes" and key not in KNOWN_PREFIXES:
        if not force:
            samples = ", ".join(sorted(KNOWN_PREFIXES)[:8])
            raise SystemExit(
                f"Unknown prefix '{key}'. Known prefixes (subset): {samples}.\n"
                f"Use --force to add it anyway."
            )
        warnings.append(f"prefix '{key}' is not in KNOWN_PREFIXES (forced)")
    elif section == "unitTypes" and key not in KNOWN_UNIT_TYPES:
        if not force:
            # `gsd-thinking validate` only checks the policy file and never lists
            # KNOWN_UNIT_TYPES, so show a sample here as the prefix branch does.
            samples = ", ".join(sorted(KNOWN_UNIT_TYPES)[:8])
            raise SystemExit(
                f"Unknown unit type '{key}'. Known unit types (subset): {samples}.\n"
                f"Use --force to add it anyway."
            )
        warnings.append(f"unit type '{key}' is not in KNOWN_UNIT_TYPES (forced)")
    return warnings


def cmd_show() -> int:
    """Pretty-print the active policy, creating the default file if absent.

    Returns:
        Always 0.
    """
    rendered = json.dumps(ensure_policy(), indent=2)
    print(rendered)
    return 0


def cmd_reset() -> int:
    """Overwrite the policy file with a fresh copy of DEFAULT_POLICY.

    Returns:
        Always 0.
    """
    # JSON round-trip yields an independent deep copy of the preset.
    fresh_policy = json.loads(json.dumps(DEFAULT_POLICY))
    write_policy(fresh_policy)
    print(f"Reset policy: {POLICY_PATH}")
    return 0


def cmd_set(target: str, level: str, force: bool = False) -> int:
    """Record ``level`` as the thinking level for ``target``.

    The level is checked against KNOWN_LEVELS and the target against the
    known prefix/unit-type lists before the policy file is rewritten.
    Forced acceptances are reported on stderr after the write.

    Args:
        target: Shortcut, prefix (ending in '-') or exact unit type.
        level: Thinking level name.
        force: Accept targets that are not in the known lists.

    Returns:
        Always 0 on success.

    Raises:
        SystemExit: For an unknown level, or an unknown target without ``force``.
    """
    if level not in KNOWN_LEVELS:
        expected = ", ".join(sorted(KNOWN_LEVELS))
        raise SystemExit(f"Invalid level: {level}. Expected one of: {expected}")

    section, key = resolve_target(target)
    forced_warnings = validate_target(section, key, force)

    policy = ensure_policy()
    if section != "default":
        policy.setdefault(section, {})[key] = level
    else:
        policy["default"] = level
    write_policy(policy)

    for message in forced_warnings:
        print(f"warning: {message}", file=sys.stderr)
    print(f"Set {target} -> {level}")
    return 0


def cmd_remove(target: str) -> int:
    """Delete the override for ``target`` from the policy file.

    ``default`` cannot be deleted; it is reset to DEFAULT_POLICY's value
    instead. For prefix/unit-type targets the file is rewritten only when an
    override was actually present. Otherwise a "No override" message is
    printed, so the output never claims a removal that did not happen.

    Returns:
        Always 0; removing a missing override is not an error.
    """
    policy = ensure_policy()
    section, key = resolve_target(target)
    if section == "default":
        policy["default"] = DEFAULT_POLICY["default"]
    else:
        overrides = policy.setdefault(section, {})
        if key not in overrides:
            print(f"No override for {target}")
            return 0
        del overrides[key]
    write_policy(policy)
    print(f"Removed override for {target}")
    return 0


def _is_known_level(level: Any) -> bool:
    """Return True when ``level`` is a string listed in KNOWN_LEVELS."""
    # Policy JSON may hold any value type. Checking for str first stops
    # unhashable values (lists/objects) from raising TypeError on the
    # membership test if KNOWN_LEVELS is a set/frozenset.
    return isinstance(level, str) and level in KNOWN_LEVELS


def collect_invalid_entries(policy: dict[str, Any]) -> list[tuple[str, str, str]]:
    """Return ``(section, key, reason)`` triples for invalid entries in ``policy``.

    Checks performed:
      * ``default`` (when present and not null) must be a known level;
        it is reported with an empty key.
      * ``prefixes`` values must be known levels; keys must end with ``'-'``
        and be in KNOWN_PREFIXES.
      * ``unitTypes`` values must be known levels; keys must be in
        KNOWN_UNIT_TYPES.

    At most one issue is reported per entry. An invalid level takes precedence
    over key problems, since either way cmd_lint prunes the entry. Missing or
    null ``prefixes``/``unitTypes`` sections are treated as empty.
    """
    issues: list[tuple[str, str, str]] = []
    default_level = policy.get("default")
    if default_level is not None and not _is_known_level(default_level):
        issues.append(("default", "", f"invalid level '{default_level}'"))
    for prefix, level in (policy.get("prefixes") or {}).items():
        if not _is_known_level(level):
            issues.append(("prefixes", prefix, f"invalid level '{level}'"))
        elif not prefix.endswith("-"):
            issues.append(("prefixes", prefix, "prefix must end with '-'"))
        elif prefix not in KNOWN_PREFIXES:
            issues.append(("prefixes", prefix, "prefix not in KNOWN_PREFIXES"))
    for unit_type, level in (policy.get("unitTypes") or {}).items():
        if not _is_known_level(level):
            issues.append(("unitTypes", unit_type, f"invalid level '{level}'"))
        elif unit_type not in KNOWN_UNIT_TYPES:
            issues.append(("unitTypes", unit_type, "not in KNOWN_UNIT_TYPES"))
    return issues


def cmd_validate() -> int:
    """Print any invalid entries in the policy file without changing them.

    Returns:
        0 when the policy is clean, 1 when at least one issue was reported.
    """
    issues = collect_invalid_entries(ensure_policy())
    if issues:
        print(f"Found {len(issues)} issue(s) in {POLICY_PATH}:")
        for section, key, reason in issues:
            # The default entry has no key, so show just the section name.
            location = section if not key else f"{section}.{key}"
            print(f"  - {location}: {reason}")
        print("\nRun `gsd-thinking lint` to auto-prune invalid entries.")
        return 1
    print(f"OK: {POLICY_PATH} has no invalid entries")
    return 0


def cmd_lint() -> int:
    """Prune invalid entries from the policy file and report what changed.

    Invalid prefix/unit-type entries are deleted; an invalid ``default`` is
    reset to DEFAULT_POLICY's value. The file is rewritten only when at least
    one issue was found.

    Returns:
        Always 0.
    """
    policy = ensure_policy()
    issues = collect_invalid_entries(policy)
    if not issues:
        print(f"OK: {POLICY_PATH} already clean")
        return 0

    report: list[str] = []
    for section, key, reason in issues:
        if section == "default":
            policy["default"] = DEFAULT_POLICY["default"]
        else:
            policy.get(section, {}).pop(key, None)
        location = f"{section}.{key}" if key else section
        report.append(f"  - removed {location}: {reason}")

    write_policy(policy)
    print(f"Pruned {len(issues)} invalid entry/entries from {POLICY_PATH}")
    print("\n".join(report))
    return 0


def main(argv: list[str]) -> int:
    """Parse CLI arguments and dispatch the selected command.

    Args:
        argv: Full argument vector, program name first (as in ``sys.argv``).

    Returns:
        Exit code from the selected command; 0 when help is printed.

    Raises:
        SystemExit: On wrong argument counts or an unknown command (the
            subcommands may also raise it on validation errors).
    """
    if len(argv) < 2 or argv[1] in {"-h", "--help", "help"}:
        # __doc__ is None when docstrings are stripped (python -OO or
        # PYTHONOPTIMIZE=2); fall back to a one-line summary instead of
        # crashing with AttributeError.
        usage = __doc__ or "Usage: gsd-thinking {show|reset|set|remove|validate|lint}"
        print(usage.strip())
        return 0

    cmd = argv[1]
    if cmd == "show":
        return cmd_show()
    if cmd == "reset":
        return cmd_reset()
    if cmd == "validate":
        return cmd_validate()
    if cmd == "lint":
        return cmd_lint()
    if cmd == "set":
        rest = argv[2:]
        # --force may appear anywhere after "set"; strip it before counting
        # the positional <target> <level> arguments.
        force = "--force" in rest
        rest = [a for a in rest if a != "--force"]
        if len(rest) != 2:
            raise SystemExit("Usage: gsd-thinking set <target> <level> [--force]")
        return cmd_set(rest[0], rest[1], force=force)
    if cmd == "remove":
        if len(argv) != 3:
            raise SystemExit("Usage: gsd-thinking remove <target>")
        return cmd_remove(argv[2])

    raise SystemExit(f"Unknown command: {cmd}")


if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main's return value as the exit status.
    sys.exit(main(sys.argv))
