#!/bin/bash
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# CloudWatch log forwarder — workaround for IC platform log routing gap
# Redirects this script's stdout and stderr into a cw_log_forwarder.py process
# (via process substitution). Everything printed after this line passes
# through that process, so the redirect has to stay ahead of any other output.
exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1

# Startup marker: UTC timestamp in ISO-8601 form plus the PID of this shell.
echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"

# CUDA compatibility setup (required for newer SageMaker inference AMIs)
# Best-effort: stderr is discarded while sourcing and a failure (for example
# a missing file) is ignored so startup continues.
source /usr/bin/cuda_compat.sh 2>/dev/null || true

# The server name in this message is filled in when the template is rendered.
echo "Starting <%= modelServer %> server"

<% if (modelServer === 'lmi' || modelServer === 'djl') { %>
<%- include('serve.d/lmi') %>
<% } else { %>

<% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
# ---------------------------------------------------------------------------
# download_model_from_s3 — Download model artifacts from S3 to a local path
#
# Arguments:
#   $1 - S3 URI: a .tar.gz/.tgz archive, a prefix ("directory"), or one object
#   $2 - local destination directory (created if missing)
# Outputs:
#   Nothing on stdout. All progress — including the aws CLI's own transfer
#   lines — goes to stderr, so the function is safe to call from inside a
#   command substitution whose stdout is the caller's result.
# Returns:
#   0 on success, 1 on invalid arguments or any download/extract failure
# ---------------------------------------------------------------------------
download_model_from_s3() {
    local s3_uri="${1:-}"
    local dest_path="${2:-}"
    local start_time
    start_time=$(date +%s)

    if [ -z "$s3_uri" ] || [ -z "$dest_path" ]; then
        echo "Error: download_model_from_s3 requires S3 URI and destination path" >&2
        return 1
    fi

    echo "Downloading model from ${s3_uri} to ${dest_path}..." >&2
    if ! mkdir -p -- "${dest_path}"; then
        echo "Error: Failed to create destination directory ${dest_path}" >&2
        return 1
    fi

    if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
        # Stage the archive in a unique temp file instead of a fixed /tmp name,
        # so concurrent runs cannot clobber each other and the path is not
        # predictable.
        local archive
        archive=$(mktemp "${TMPDIR:-/tmp}/model_archive.XXXXXX") || {
            echo "Error: Failed to create temporary file for ${s3_uri}" >&2
            return 1
        }
        if ! aws s3 cp "$s3_uri" "$archive" >&2; then
            echo "Error: Failed to download tarball from ${s3_uri}" >&2
            rm -f -- "$archive"
            return 1
        fi
        if ! tar -xzf "$archive" -C "$dest_path" >&2; then
            echo "Error: Failed to extract tarball from ${s3_uri}" >&2
            rm -f -- "$archive"
            return 1
        fi
        rm -f -- "$archive"
    # Treat the URI as a prefix when it ends in "/" or when `aws s3 ls` prints
    # no line starting with a digit (object rows start with a date; prefix
    # rows start with "PRE").
    elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
        if ! aws s3 sync "$s3_uri" "$dest_path" >&2; then
            echo "Error: Failed to sync from ${s3_uri}" >&2
            return 1
        fi
    else
        if ! aws s3 cp "$s3_uri" "$dest_path/" >&2; then
            echo "Error: Failed to copy ${s3_uri}" >&2
            return 1
        fi
    fi

    local duration
    duration=$(( $(date +%s) - start_time ))
    echo "Download complete: ${s3_uri} → ${dest_path} (${duration}s)" >&2
}
<% } %>

# ---------------------------------------------------------------------------
# Model Loading Adapter — resolve model based on MODEL_SOURCE env var
# ---------------------------------------------------------------------------
# MODEL_SOURCE:       selects the branch taken in resolve_model; falls back to
#                     "huggingface" when unset or empty.
# MODEL_ARTIFACT_URI: optional artifact location that resolve_model passes to
#                     download_model_from_s3; falls back to empty.
# LOCAL_MODEL_PATH:   fixed directory checked for pre-mounted artifacts and
#                     used as the download destination.
MODEL_SOURCE="${MODEL_SOURCE:-huggingface}"
MODEL_ARTIFACT_URI="${MODEL_ARTIFACT_URI:-}"
LOCAL_MODEL_PATH="/opt/ml/model"

# _MODEL_VAR holds the *name* of the server-specific environment variable that
# carries the model reference. It is read indirectly (${!_MODEL_VAR}) by
# resolve_model and that variable is re-exported with the resolved value.
# NOTE(review): there is no fallback branch — if this template is rendered for
# a server other than the three listed here, _MODEL_VAR stays unset. Confirm
# that cannot happen.
<% if (modelServer === 'vllm') { %>
_MODEL_VAR="VLLM_MODEL"
<% } else if (modelServer === 'sglang') { %>
_MODEL_VAR="SGLANG_MODEL_PATH"
<% } else if (modelServer === 'tensorrt-llm') { %>
_MODEL_VAR="TRTLLM_MODEL"
<% } %>

# ---------------------------------------------------------------------------
# resolve_model — print the model reference the server should load
#
# Globals (read):
#   MODEL_SOURCE        huggingface | s3 | registry (anything else behaves
#                       like huggingface)
#   MODEL_ARTIFACT_URI  S3 URI to download when nothing is pre-mounted
#   LOCAL_MODEL_PATH    directory for pre-mounted / downloaded artifacts
#   _MODEL_VAR          name of the env var holding the model reference
#   AWS_REGION / AWS_DEFAULT_REGION  used for registry lookups only
# Outputs:
#   Exactly one line on stdout: either the value of the variable named by
#   _MODEL_VAR, or LOCAL_MODEL_PATH. All diagnostics go to stderr.
# Exits:
#   Calls `exit 1` on unrecoverable errors. The caller runs this inside a
#   command substitution, so only that subshell exits and the caller sees a
#   non-zero status.
# ---------------------------------------------------------------------------
resolve_model() {
    # Read the server-specific variable indirectly; tolerate _MODEL_VAR being
    # unset instead of failing with an "invalid indirect expansion" error.
    local model_ref=""
    if [ -n "${_MODEL_VAR:-}" ]; then
        model_ref="${!_MODEL_VAR:-}"
    fi

    case "$MODEL_SOURCE" in
        s3|registry)
            # Pre-mounted artifacts win over any download. The path is quoted
            # so directories containing whitespace or glob characters work.
            if [ -d "$LOCAL_MODEL_PATH" ] && [ -n "$(ls -A -- "$LOCAL_MODEL_PATH" 2>/dev/null)" ]; then
                echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
                printf '%s\n' "$LOCAL_MODEL_PATH"
                return
            fi

            if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
                local model_uri="$model_ref"
                local registry_prefix="registry://"
                if [[ "$model_uri" == "${registry_prefix}"* ]]; then
                    local registry_path="${model_uri#"${registry_prefix}"}"
                    local group_name="${registry_path%%/*}"
                    local version="${registry_path#*/}"
                    local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"

                    # Without a "/" both expansions above yield the whole
                    # string, which would silently build a bogus ARN.
                    if [[ "$registry_path" != */* ]] || [ -z "$group_name" ] || [ -z "$version" ]; then
                        echo "Error: Invalid registry URI '${model_uri}' (expected registry://<group>/<version>)" >&2
                        exit 1
                    fi

                    local account_id
                    account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
                        echo "Error: Failed to get AWS account ID for model package ARN" >&2
                        exit 1
                    }

                    local package_arn="arn:aws:sagemaker:${region}:${account_id}:model-package/${group_name}/${version}"
                    echo "Resolving ${model_uri} via SageMaker DescribeModelPackage..." >&2
                    echo "   ARN: ${package_arn}" >&2

                    local describe_output
                    describe_output=$(aws sagemaker describe-model-package \
                        --model-package-name "$package_arn" \
                        --region "$region" \
                        --output json 2>/dev/null) || {
                        echo "Error: Failed to describe model package: ${package_arn}" >&2
                        exit 1
                    }

                    # Lookup order: first container's ModelDataUrl, then its
                    # ModelDataSource.S3DataSource.S3Uri, then the first
                    # s3:// URI found in ModelPackageDescription. Any parse
                    # error yields an empty string, handled just below.
                    MODEL_ARTIFACT_URI=$(printf '%s' "$describe_output" | python3 -c "
import sys, json, re
try:
    pkg = json.load(sys.stdin)
    uri = ''
    containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
    if containers:
        c = containers[0]
        uri = c.get('ModelDataUrl', '')
        if not uri:
            uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
    if not uri:
        desc = pkg.get('ModelPackageDescription', '')
        m = re.search(r's3://[^\s]+', desc)
        if m:
            uri = m.group(0)
    print(uri)
except Exception:
    print('')
" 2>/dev/null)

                    if [ -n "$MODEL_ARTIFACT_URI" ] && [ "$MODEL_ARTIFACT_URI" != "None" ]; then
                        echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
                    else
                        echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
                        exit 1
                    fi
                fi
            fi

            if [ -z "$MODEL_ARTIFACT_URI" ]; then
                echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
                exit 1
            fi
            # The download helper is only rendered into the script for some
            # build configurations; fail with a clear message if it is absent.
            if ! declare -F download_model_from_s3 >/dev/null; then
                echo "Error: download_model_from_s3 is not available; cannot fetch ${MODEL_ARTIFACT_URI}" >&2
                exit 1
            fi
            # Send the helper's stdout to stderr: this function's stdout is
            # captured by the caller and must contain only the model path.
            if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH" >&2; then
                echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
                exit 1
            fi
            printf '%s\n' "$LOCAL_MODEL_PATH"
            ;;
        *)
            # huggingface and any unrecognised source: pass the configured
            # model reference through unchanged.
            printf '%s\n' "$model_ref"
            return
            ;;
    esac
}

# Run the resolver once in a subshell; abort startup if it reports failure.
if ! _RESOLVED_MODEL="$(resolve_model)"; then
    exit 1
fi
# Publish the result under the server-specific variable name, log it, and
# drop the two scratch variables so they do not linger in the environment.
export "${_MODEL_VAR}=${_RESOLVED_MODEL}"
printf '%s\n' "Resolved ${_MODEL_VAR}=${_RESOLVED_MODEL} (source: ${MODEL_SOURCE})"
unset _MODEL_VAR _RESOLVED_MODEL

# Initialize server arguments
# SERVER_ARGS is a bash array seeded with the bind address and port:
# tensorrt-llm is given port 8081, every other server port 8080.
# NOTE(review): the reason for 8081 is not visible in this file — presumably
# something else listens on 8080 in the tensorrt-llm setup; confirm.
<% if (modelServer === 'tensorrt-llm') { %>
SERVER_ARGS=(--host 0.0.0.0 --port 8081)
<% } else { %>
SERVER_ARGS=(--host 0.0.0.0 --port 8080)

# ---------------------------------------------------------------------------
# Adapter Sidecar — DISABLED
# vLLM runs on port 8080 and handles /ping, /invocations, /v1/* natively.
# The /adapters route will be injected directly into vLLM's FastAPI app.
# NOTE(review): this branch is rendered for every server other than
# tensorrt-llm, not only vLLM — confirm the statement above holds for them.
# ---------------------------------------------------------------------------
<% } %>

# --- Server-specific arg conversion and exec ---
<% if (['vllm', 'sglang', 'tensorrt-llm'].includes(modelServer)) { %>
<%- include('serve.d/' + modelServer) %>
<% } %>
<% } %>
