#!/bin/bash
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# CloudWatch log forwarder — workaround for IC platform log routing gap.
# From here on, both stdout and stderr of this script are piped into
# cw_log_forwarder.py through a process substitution, so this must stay the
# first command that produces output.
exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1

# Startup marker: UTC timestamp plus the PID of this shell.
echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"

# CUDA compatibility setup (required for newer SageMaker inference AMIs)
# Best-effort: stderr produced while sourcing is discarded, and a missing or
# failing script does not stop startup (|| true).
source /usr/bin/cuda_compat.sh 2>/dev/null || true

# Startup banner: at most one of the echo lines below ends up in the rendered
# script, selected by the template's modelServer value.
<% if (modelServer === 'vllm') { %>
echo "Starting vLLM server"
<% } else if (modelServer === 'sglang') { %>
echo "Starting SGLang server"
<% } else if (modelServer === 'tensorrt-llm') { %>
echo "Starting TensorRT-LLM server"
<% } else if (modelServer === 'lmi') { %>
echo "Starting LMI (Large Model Inference) server"
<% } else if (modelServer === 'djl') { %>
echo "Starting DJL Serving server"
<% } %>

<% if (modelServer === 'lmi' || modelServer === 'djl') { %>
# LMI/DJL containers are configured through serving.properties, expected at
# /opt/ml/model/serving.properties. This branch only verifies that the file
# exists, prints it to the log, and then exits successfully.

[[ -f /opt/ml/model/serving.properties ]] || {
    echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
    exit 1
}

printf '%s\n' "Using configuration from /opt/ml/model/serving.properties"
cat /opt/ml/model/serving.properties

# DJL Serving is already configured in the base image, and LMI/DJL ship their
# own entrypoint, so this script is not typically needed for them; it exists
# for consistency with the other model servers.
exit 0
<% } else { %>

<% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
# ---------------------------------------------------------------------------
# download_model_from_s3 — Download model artifacts from S3 to a local path
#
# Arguments:
#   $1 - S3 URI of a .tar.gz/.tgz archive, a prefix, or a single object
#   $2 - local destination directory (created if it does not exist)
# Outputs:
#   Nothing on stdout. Progress messages, AWS CLI / tar output and errors all
#   go to stderr, so a caller that captures stdout with $(...) never picks up
#   transfer chatter.
# Returns:
#   0 on success, 1 on missing arguments or any failed step
# ---------------------------------------------------------------------------
download_model_from_s3() {
    local s3_uri="$1"
    local dest_path="$2"
    local start_time
    start_time=$(date +%s)

    if [ -z "$s3_uri" ] || [ -z "$dest_path" ]; then
        echo "Error: download_model_from_s3 requires S3 URI and destination path" >&2
        return 1
    fi

    echo "Downloading model from ${s3_uri} to ${dest_path}..." >&2
    if ! mkdir -p "${dest_path}"; then
        echo "Error: Failed to create destination directory ${dest_path}" >&2
        return 1
    fi

    if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
        # Tarball: download into a private, uniquely named temp file, then
        # extract. The temp file is removed on every exit path.
        local archive
        if ! archive=$(mktemp /tmp/model_archive.XXXXXX); then
            echo "Error: Failed to create temporary file for ${s3_uri}" >&2
            return 1
        fi
        # The AWS CLI reports transfers on stdout; keep that off our stdout.
        if ! aws s3 cp "$s3_uri" "$archive" >&2; then
            echo "Error: Failed to download tarball from ${s3_uri}" >&2
            rm -f -- "$archive"
            return 1
        fi
        if ! tar -xzf "$archive" -C "$dest_path" >&2; then
            echo "Error: Failed to extract tarball from ${s3_uri}" >&2
            rm -f -- "$archive"
            return 1
        fi
        rm -f -- "$archive"
    elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
        # Directory prefix: sync. Chosen when the URI ends in "/" or when the
        # listing has no line starting with a digit (i.e. no dated object row).
        if ! aws s3 sync "$s3_uri" "$dest_path" >&2; then
            echo "Error: Failed to sync from ${s3_uri}" >&2
            return 1
        fi
    else
        # Single file: copy
        if ! aws s3 cp "$s3_uri" "$dest_path/" >&2; then
            echo "Error: Failed to copy ${s3_uri}" >&2
            return 1
        fi
    fi

    local duration
    duration=$(( $(date +%s) - start_time ))
    echo "Download complete: ${s3_uri} → ${dest_path} (${duration}s)" >&2
}
<% } %>

# ---------------------------------------------------------------------------
# Model Loading Adapter — settings read by resolve_model
#   MODEL_SOURCE        where the model comes from; defaults to huggingface
#   MODEL_ARTIFACT_URI  artifact location; defaults to empty
#   LOCAL_MODEL_PATH    directory checked for, and populated with, artifacts
# ---------------------------------------------------------------------------
: "${MODEL_SOURCE:=huggingface}"
: "${MODEL_ARTIFACT_URI:=}"
LOCAL_MODEL_PATH="/opt/ml/model"

# Name of the environment variable that carries the model identifier for the
# selected server; resolve_model reads it indirectly via ${!_MODEL_VAR}.
<% if (modelServer === 'vllm') { %>
_MODEL_VAR="VLLM_MODEL"
<% } else if (modelServer === 'sglang') { %>
_MODEL_VAR="SGLANG_MODEL_PATH"
<% } else if (modelServer === 'tensorrt-llm') { %>
_MODEL_VAR="TRTLLM_MODEL"
<% } %>

#######################################
# Resolve the model reference the server should load.
# Globals:
#   MODEL_SOURCE        (read)  huggingface | s3 | registry; anything else is
#                               treated like huggingface
#   MODEL_ARTIFACT_URI  (read/written) S3 location of the artifacts; filled in
#                               from the model package for registry:// models
#   LOCAL_MODEL_PATH    (read)  directory holding / receiving the artifacts
#   _MODEL_VAR          (read)  name of the env var with the model identifier
# Outputs:
#   stdout: exactly one line — the model identifier or the local model path.
#   stderr: all progress and error messages.
# Returns:
#   0 on success. Failures call `exit 1`; the function is meant to be run in
#   a command substitution, where that ends the subshell with status 1.
#######################################
resolve_model() {
    case "$MODEL_SOURCE" in
        huggingface)
            # Pass model name directly — server fetches from HF Hub
            echo "${!_MODEL_VAR}"
            return
            ;;
        s3|registry)
            # Check for pre-mounted artifacts first (path quoted so that
            # directories containing whitespace are handled correctly)
            if [ -d "$LOCAL_MODEL_PATH" ] && [ -n "$(ls -A -- "$LOCAL_MODEL_PATH" 2>/dev/null)" ]; then
                echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
                echo "$LOCAL_MODEL_PATH"
                return
            fi

            # For registry:// models, resolve artifact URI at runtime via SageMaker API
            if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
                local model_uri="${!_MODEL_VAR}"
                local registry_prefix="registry://"
                if [[ "$model_uri" == "${registry_prefix}"* ]]; then
                    # registry://<group>/<version>
                    local registry_path="${model_uri#"${registry_prefix}"}"
                    local group_name="${registry_path%%/*}"
                    local version="${registry_path#*/}"
                    local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"

                    # Get account ID for ARN construction
                    local account_id
                    account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
                        echo "Error: Failed to get AWS account ID for model package ARN" >&2
                        exit 1
                    }

                    local package_arn="arn:aws:sagemaker:${region}:${account_id}:model-package/${group_name}/${version}"
                    echo "Resolving ${model_uri} via SageMaker DescribeModelPackage..." >&2
                    echo "   ARN: ${package_arn}" >&2

                    local describe_output
                    describe_output=$(aws sagemaker describe-model-package \
                        --model-package-name "$package_arn" \
                        --region "$region" \
                        --output json 2>/dev/null) || {
                        echo "Error: Failed to describe model package: ${package_arn}" >&2
                        exit 1
                    }

                    # Try ModelDataUrl first, then S3DataSource.S3Uri, then the
                    # package description, then the model card. The helper
                    # prints an empty line when nothing is found or parsing fails.
                    MODEL_ARTIFACT_URI=$(printf '%s\n' "$describe_output" | python3 -c "
import sys, json, re
try:
    pkg = json.load(sys.stdin)
    uri = ''
    # Check InferenceSpecification.Containers[0]
    containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
    if containers:
        c = containers[0]
        uri = c.get('ModelDataUrl', '')
        if not uri:
            uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
    # Fallback: extract S3 URI from ModelPackageDescription
    if not uri:
        desc = pkg.get('ModelPackageDescription', '')
        m = re.search(r's3://[^\s]+', desc)
        if m:
            uri = m.group(0)
    # Fallback: check ModelCard hyperparameters for model_artifacts_s3
    if not uri:
        try:
            card = pkg.get('ModelCard', {})
            content = card.get('ModelCardContent', '{}')
            card_data = json.loads(content) if isinstance(content, str) else content
            params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
            for p in params:
                if p.get('name') == 'model_artifacts_s3':
                    uri = p.get('value', '')
                    break
        except Exception:
            pass
    print(uri)
except Exception:
    print('')
" 2>/dev/null)

                    # "None" is what print() emits for a JSON null value
                    if [ -n "$MODEL_ARTIFACT_URI" ] && [ "$MODEL_ARTIFACT_URI" != "None" ]; then
                        echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
                    else
                        echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
                        echo "   Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
                        echo "   Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
                        exit 1
                    fi
                fi
            fi

            # Need artifact URI for download
            if [ -z "$MODEL_ARTIFACT_URI" ]; then
                echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
                exit 1
            fi
            # Download from S3. Anything the downloader writes to stdout is
            # sent to stderr so that our stdout carries only the model path.
            if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH" >&2; then
                echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
                exit 1
            fi
            echo "$LOCAL_MODEL_PATH"
            ;;
        *)
            # Unrecognized source — treat as huggingface
            echo "${!_MODEL_VAR}"
            return
            ;;
    esac
}

# Run the resolver in a subshell, stop the script if it fails, then publish
# the result under the server-specific variable name and drop the temporaries.
if ! _RESOLVED_MODEL=$(resolve_model); then
    exit 1
fi
export "${_MODEL_VAR}=${_RESOLVED_MODEL}"
printf '%s\n' "Resolved ${_MODEL_VAR}=${_RESOLVED_MODEL} (source: ${MODEL_SOURCE})"
unset _MODEL_VAR
unset _RESOLVED_MODEL

# Initialize server arguments
# SERVER_ARGS is a bash array seeded with the listen address and port; flags
# derived from environment variables are appended to it further down.
<% if (modelServer === 'tensorrt-llm') { %>
# port 8081 for internal TensorRT-LLM server (nginx proxies on 8080)
SERVER_ARGS=(--host 0.0.0.0 --port 8081)
<% } else { %>
# port 8080 required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
SERVER_ARGS=(--host 0.0.0.0 --port 8080)
<% } %>

# Define the prefix for environment variables to look for
# (every environment variable whose name starts with it becomes a CLI flag)
<% if (modelServer === 'vllm') { %>
PREFIX="VLLM_"
<% } else if (modelServer === 'sglang') { %>
PREFIX="SGLANG_"
<% } else if (modelServer === 'tensorrt-llm') { %>
PREFIX="TRTLLM_"
<% } %>
# String placed in front of each derived argument name to form the flag
ARG_PREFIX="--"

# Define environment variables to exclude (internal variables set by base images)
# Names listed here match PREFIX but are never turned into CLI flags.
<% if (modelServer === 'vllm') { %>
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
<% } else if (modelServer === 'sglang') { %>
EXCLUDE_VARS=()
<% } else if (modelServer === 'tensorrt-llm') { %>
# Exclude TRTLLM_MODEL as it's used as the positional MODEL argument
EXCLUDE_VARS=("TRTLLM_MODEL")
<% } %>

#######################################
# Convert PREFIX-matching environment variables into server CLI arguments.
#
# For each `NAME=value` line reported by `env` whose name starts with PREFIX
# and is not listed in EXCLUDE_VARS:
#   - the flag name is NAME without PREFIX, lower-cased, with `_` -> `-`,
#     prefixed by ARG_PREFIX;
#   - value `false` skips the variable entirely;
#   - value `true` or an empty value emits the flag alone;
#   - any other value emits the flag followed by the value.
# Globals:
#   PREFIX, ARG_PREFIX, EXCLUDE_VARS (read); SERVER_ARGS (appended to)
#######################################
_append_env_server_args() {
    local -a env_vars=()
    local var key value exclude skip arg_name

    # Declare and populate array of matching environment variables
    mapfile -t env_vars < <(env | grep "^${PREFIX}")

    # Loop through the array and convert to command-line arguments
    for var in "${env_vars[@]}"; do
        # Split on the FIRST '=' only. (`IFS='=' read key value` silently drops
        # a single trailing '=' from the value, which corrupts e.g. base64
        # padded strings.)
        key="${var%%=*}"
        if [[ "$var" == *=* ]]; then
            value="${var#*=}"
        else
            value=""
        fi

        # Skip excluded variables
        skip=false
        for exclude in "${EXCLUDE_VARS[@]}"; do
            if [ "$key" = "$exclude" ]; then
                skip=true
                break
            fi
        done
        if [ "$skip" = true ]; then
            continue
        fi

        # Boolean handling: true = flag only, false = skip entirely
        if [ "$value" = "false" ]; then
            continue
        fi

        # Remove prefix, convert to lowercase, and replace underscores with dashes
        arg_name="${key#"${PREFIX}"}"
        arg_name="${arg_name,,}"
        arg_name="${arg_name//_/-}"

        SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
        if [ -n "$value" ] && [ "$value" != "true" ]; then
            SERVER_ARGS+=("$value")
        fi
    done
}

_append_env_server_args

# Log the final argument list between separator lines before launching.
echo "-------------------------------------------------------------------"
<% if (modelServer === 'vllm') { %>
echo "vLLM engine args: [${SERVER_ARGS[@]}]"
<% } else if (modelServer === 'sglang') { %>
echo "SGLang engine args: [${SERVER_ARGS[@]}]"
<% } else if (modelServer === 'tensorrt-llm') { %>
echo "TensorRT-LLM engine args: [${SERVER_ARGS[@]}]"
<% } %>
echo "-------------------------------------------------------------------"

# Pass the collected arguments to the main entrypoint
# `exec` replaces this shell with the server process, so nothing after the
# exec line runs.
<% if (modelServer === 'vllm') { %>
exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
<% } else if (modelServer === 'sglang') { %>
exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
<% } else if (modelServer === 'tensorrt-llm') { %>
# TensorRT-LLM requires the model as a positional argument
# Syntax: trtllm-serve serve MODEL [OPTIONS]
# Refuse to start when the model variable is empty or unset.
if [ -z "$TRTLLM_MODEL" ]; then
    echo "Error: TRTLLM_MODEL environment variable is not set"
    exit 1
fi
exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
<% } %>
<% } %>
