# auto-generated by @connexum/ai-governance — do not edit; regenerate via npx ai-governance init --agent-dir
# agent-id: {{AGENT_ID}}
# source:   {{SOURCE_FILE}}
"""
Governance shim for {{MODULE_NAME}}.
Wraps HuggingFace InferenceClient + transformers.pipeline with Connexum
governance enforcement at request time.

BAA status: HuggingFace does NOT offer a standard BAA. Enterprise Hub
plans may negotiate per-deployment terms; treat as no-BAA for HIPAA-
scoped data unless explicitly agreed (see HuggingFaceGovernance docs).

Change pack assignments in governance.json; regenerate this file to apply.
"""
from __future__ import annotations

import functools
import importlib
import importlib.util
import os
import sys
from pathlib import Path
from typing import Any, Callable

# ---------------------------------------------------------------------------
# Load original module
# ---------------------------------------------------------------------------
# Load the wrapped source file as its own module object. This follows the
# importlib "import a source file directly" recipe: the module is registered in
# sys.modules *before* execution. Code that resolves a class's defining module
# by name (dataclasses, pickling, typing.get_type_hints) relies on that entry.
_src_dir = Path(__file__).parent
_orig_path = _src_dir / "{{SOURCE_FILE}}"
_orig_spec = importlib.util.spec_from_file_location(
    "{{MODULE_NAME}}._original",
    _orig_path,
)
if _orig_spec is None or _orig_spec.loader is None:
    # Fail loudly with a clear message instead of an AttributeError on None.
    raise ImportError(
        f"governance shim: cannot create an import spec for {_orig_path}"
    )
_orig_mod = importlib.util.module_from_spec(_orig_spec)
sys.modules[_orig_spec.name] = _orig_mod
try:
    _orig_spec.loader.exec_module(_orig_mod)
except BaseException:
    # Mirror the regular import system: don't leave a half-initialised module
    # registered if executing the original source fails.
    sys.modules.pop(_orig_spec.name, None)
    raise

# ---------------------------------------------------------------------------
# Set up HuggingFace governance hook
# ---------------------------------------------------------------------------
from connexum_governance import GovernanceClient
from connexum_governance.integrations.huggingface import HuggingFaceGovernance
from connexum_governance.models import GovernanceViolation

# Governance API client. The endpoint and license key are read from the
# environment. If unset, the endpoint falls back to a local Connexum server
# and the license key falls back to an empty string.
_gov_client = GovernanceClient(
    license_key=os.getenv("CONNEXUM_LICENSE_KEY", ""),
    api_url=os.getenv("CONNEXUM_API_URL", "http://localhost:4201"),
)
# HuggingFace-specific enforcement hook, tagged with this agent's id.
_gov = HuggingFaceGovernance(agent_name="{{AGENT_ID}}", client=_gov_client)

# ---------------------------------------------------------------------------
# Patch HuggingFace entry points reachable from the original module's namespace
#
# Two HuggingFace patterns are wrapped:
#   1. huggingface_hub.InferenceClient — the text_generation / chat_completion
#      methods. The class's __init__ is patched so that every instance created
#      afterwards has those two methods replaced by governed wrappers.
#   2. transformers.pipeline() — a factory function. It is replaced with a
#      wrapper that installs a governed __call__ on each returned pipeline.
#
# Both paths route through HuggingFaceGovernance.intercept_request which
# applies pack policy + emits an audit chain entry. A DENY decision raises
# GovernanceViolation; an ALLOW lets the original call proceed.
# ---------------------------------------------------------------------------


def _first_present(kwargs: dict[str, Any], names: tuple[str, ...]) -> Any:
    """Return the first keyword value among *names* that is present and not None.

    Returns None when none of *names* carries a non-None value. Unlike an
    ``a or b`` chain, a present-but-falsy value (``""``, ``[]``, ``0``) is kept.
    """
    for name in names:
        value = kwargs.get(name)
        if value is not None:
            return value
    return None


def _resolve_model(method: Callable[..., Any], kwargs: dict[str, Any]) -> Any:
    """Return a best-effort model identifier for the governance request.

    Resolution order:

    1. An explicit, non-None ``model=`` keyword.
    2. A *string* ``model`` attribute on the object *method* is bound to. This
       presumably covers a client constructed with ``model=...`` (TODO: confirm
       against the huggingface_hub version in use).
    3. Otherwise ``"unknown"``.
    """
    explicit = kwargs.get("model")
    if explicit is not None:
        return explicit
    bound_to = getattr(method, "__self__", None)
    owner_model = getattr(bound_to, "model", None)
    if isinstance(owner_model, str):
        return owner_model
    return "unknown"


def _governed_call(method: Callable[..., Any], shape: str) -> Callable[..., Any]:
    """Wrap a HuggingFace callable so every invocation is governance-checked first.

    Args:
        method: The callable to guard. This is typically a bound
            ``text_generation`` / ``chat_completion`` method or a pipeline's
            bound ``__call__``.
        shape: ``"chat"`` builds a chat-style request (``messages`` /
            ``max_tokens``). Any other value builds a text-style request
            (``inputs`` / ``max_new_tokens``).

    Returns:
        A wrapper that is call-compatible with *method* and carries its metadata
        (``functools.wraps``). Each call builds a best-effort request dict, passes
        it to ``_gov.intercept_request``, and only then invokes *method*.

    Raises:
        GovernanceViolation: If the decision is ``denied``. *method* is not
            called in that case.
    """

    @functools.wraps(method)
    def _wrapped(*args: Any, **kwargs: Any) -> Any:
        # Best-effort request extraction. Known keyword arguments are preferred.
        # Otherwise the first positional argument is used, and if there is none,
        # an empty default.
        request: dict[str, Any] = {"model": _resolve_model(method, kwargs)}
        if shape == "chat":
            messages = _first_present(kwargs, ("messages",))
            if messages is None:
                messages = args[0] if args else []
            request["messages"] = messages
            request["max_tokens"] = kwargs.get("max_tokens")
        else:
            # Some transformers pipelines name their input ``text_inputs`` —
            # TODO confirm per pipeline task.
            inputs = _first_present(kwargs, ("prompt", "inputs", "text_inputs"))
            if inputs is None:
                inputs = args[0] if args else ""
            request["inputs"] = inputs
            request["max_new_tokens"] = _first_present(
                kwargs, ("max_new_tokens", "max_tokens")
            )

        decision = _gov.intercept_request(request)
        if decision.denied:
            raise GovernanceViolation(decision)
        return method(*args, **kwargs)

    return _wrapped


# Patch any huggingface_hub.InferenceClient references in the original module.
#
# Both ``import huggingface_hub`` and ``from huggingface_hub import
# InferenceClient`` are handled. The two usually resolve to the SAME class
# object, so patching is idempotent:
#   - a class whose __init__ is already governed is left alone, and
#   - an instance method that is already governed is not wrapped again.
# Without this, a module using both import styles would send every request
# through governance (and the audit chain) twice.
#
# NOTE: the class itself is patched. Every instance created in this process
# after the shim loads is therefore affected, not only those created by the
# original module.

# Attribute set on our wrappers so they can be recognised later.
_GOVERNED_MARKER = "_connexum_governed"

# (method name, request shape passed to _governed_call)
_INFERENCE_METHOD_SHAPES: tuple[tuple[str, str], ...] = (
    ("text_generation", "text"),
    ("chat_completion", "chat"),
)


def _govern_inference_instance(client: Any) -> None:
    """Replace *client*'s inference methods with governed wrappers, at most once."""
    for attr, shape in _INFERENCE_METHOD_SHAPES:
        current = getattr(client, attr, None)
        if current is None or getattr(current, _GOVERNED_MARKER, False):
            continue
        governed = _governed_call(current, shape)
        setattr(governed, _GOVERNED_MARKER, True)
        setattr(client, attr, governed)


def _patch_inference_client_cls(cls: Any) -> None:
    """Patch ``cls.__init__`` so new instances get governed inference methods.

    Does nothing if *cls* is not a class or its ``__init__`` is already governed.
    """
    if not isinstance(cls, type):
        return
    orig_init = cls.__init__
    if getattr(orig_init, _GOVERNED_MARKER, False):
        return

    @functools.wraps(orig_init)
    def _governed_init(self: Any, *args: Any, **kwargs: Any) -> None:
        orig_init(self, *args, **kwargs)
        _govern_inference_instance(self)

    setattr(_governed_init, _GOVERNED_MARKER, True)
    cls.__init__ = _governed_init  # type: ignore[method-assign]


if hasattr(_orig_mod, "huggingface_hub"):
    _patch_inference_client_cls(
        getattr(_orig_mod.huggingface_hub, "InferenceClient", None)
    )

# Patch any standalone InferenceClient references
if hasattr(_orig_mod, "InferenceClient"):
    _patch_inference_client_cls(_orig_mod.InferenceClient)

# Patch any transformers.pipeline references. pipeline() returns a Pipeline
# instance whose __call__ is what consumers invoke.
#
# Python looks up special methods such as __call__ on the TYPE, never on the
# instance. Assigning ``pipe.__call__ = ...`` therefore has no effect on
# ``pipe(...)`` and would let every pipeline call bypass governance. Instead,
# each returned pipeline is re-classed to a cached subclass of its own type.
# That subclass's __call__ runs the governance check and then delegates to the
# original implementation.
if hasattr(_orig_mod, "pipeline"):
    _orig_pipeline = _orig_mod.pipeline
    # One governed subclass per concrete pipeline class, so repeated pipeline()
    # calls don't create a new class each time.
    _governed_pipeline_classes: dict[type, type] = {}

    def _governed_pipeline_class(base_cls: type) -> type:
        """Return the cached subclass of *base_cls* whose __call__ is governed."""
        cached = _governed_pipeline_classes.get(base_cls)
        if cached is not None:
            return cached

        def _governed_dunder_call(self: Any, *call_args: Any, **call_kwargs: Any) -> Any:
            # super() yields base_cls.__call__ bound to this pipeline. The
            # governance wrapper therefore sees the caller's arguments without
            # ``self`` in front of them.
            base_call = super(governed_cls, self).__call__
            return _governed_call(base_call, "text")(*call_args, **call_kwargs)

        governed_cls = type(
            base_cls.__name__,
            (base_cls,),
            {
                "__call__": _governed_dunder_call,
                # No new per-instance storage: the layout stays identical to
                # base_cls, so the __class__ reassignment below is permitted.
                "__slots__": (),
                "__module__": base_cls.__module__,
                "__qualname__": base_cls.__qualname__,
            },
        )
        _governed_pipeline_classes[base_cls] = governed_cls
        return governed_cls

    def _governed_pipeline(*args: Any, **kwargs: Any) -> Any:
        """Call the original ``pipeline()`` and return a governed pipeline instance."""
        _pipe = _orig_pipeline(*args, **kwargs)
        # isinstance(_pipe, <original class>) still holds after re-classing.
        _pipe.__class__ = _governed_pipeline_class(type(_pipe))
        return _pipe

    _orig_mod.pipeline = _governed_pipeline  # type: ignore[attr-defined]

# ---------------------------------------------------------------------------
# Re-export original module's public API
# ---------------------------------------------------------------------------
# Decide the star-import surface. If the original module declares __all__, it
# is honoured, so ``from <shim> import *`` matches ``from <original> import *``.
# Otherwise every public (non-underscore) name is exported, as before. Names
# in a declared __all__ that don't exist are dropped rather than breaking
# star-import.
_orig_all = getattr(_orig_mod, "__all__", None)
_public_names = [n for n in dir(_orig_mod) if not n.startswith("_")]
if _orig_all is None:
    __all__ = list(_public_names)
else:
    __all__ = [str(n) for n in _orig_all if hasattr(_orig_mod, str(n))]

# Copy every public name (attribute access like ``shim.helper`` keeps working
# for everything it did before). Also copy any __all__-declared names, which
# may include underscore names, so that star-import can resolve them.
for _name in dict.fromkeys([*_public_names, *__all__]):
    globals()[_name] = getattr(_orig_mod, _name)
