#!/bin/bash
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# Best-effort CUDA compatibility setup, required for newer SageMaker inference
# AMIs. A missing script or a non-zero status from sourcing it is ignored, and
# its stderr is discarded while it runs.
. /usr/bin/cuda_compat.sh 2>/dev/null || :

printf '%s\n' "Starting vLLM-Omni server (diffusion model serving)"

# Resolve model URI prefixes that engines cannot handle natively.
# The generator's model-picker may store provider-specific URIs
# (e.g. registry://my-model-group/1) as the model identifier.
# vLLM expects a HuggingFace repo ID or local path.
#######################################
# Rewrite a registry:// VLLM_MODEL into something vLLM can load.
# Prefers local artifacts in /opt/ml/model. Otherwise it strips the
# "<scheme>://" prefix and leaves the rest as a hub model id.
# Globals:   VLLM_MODEL (read; re-exported when rewritten)
# Outputs:   info line on stdout, warnings/errors on stderr
# Returns:   0 on success (including "nothing to rewrite");
#            1 if the URI has no identifier after the prefix
#######################################
resolve_model_uri() {
    local raw_model="${VLLM_MODEL:-}"
    local bare_id

    # Only registry:// URIs are rewritten; everything else passes through.
    [[ "$raw_model" == registry://* ]] || return 0

    # Local artifacts win over fetching from a hub.
    if [[ -d /opt/ml/model && -n "$(ls -A /opt/ml/model 2>/dev/null)" ]]; then
        echo "Resolved VLLM_MODEL='${raw_model}' → /opt/ml/model (local artifacts found)"
        export VLLM_MODEL="/opt/ml/model"
        return 0
    fi

    bare_id="${raw_model#*://}"
    # "registry://" alone would otherwise yield an empty model id and a
    # misleading "not set" error later on.
    if [[ -z "$bare_id" ]]; then
        echo "Error: VLLM_MODEL='${raw_model}' has no model identifier after the provider prefix and /opt/ml/model is empty." >&2
        return 1
    fi

    echo "Warning: VLLM_MODEL='${raw_model}' has a provider prefix but /opt/ml/model is empty." >&2
    echo "Stripping prefix → '${bare_id}' (engine will attempt to fetch from model hub)" >&2
    export VLLM_MODEL="$bare_id"
}

resolve_model_uri || exit 1

# Validate that the model name is set and non-empty; abort startup otherwise.
# ${VLLM_MODEL:-} keeps the check safe if `set -u` is ever enabled.
if [[ -z "${VLLM_MODEL:-}" ]]; then
    echo "Error: VLLM_MODEL environment variable is not set" >&2
    exit 1
fi

# Initialize server arguments with --omni flag
# --omni activates vLLM-Omni diffusion/multi-stage support
# port 8081 is the internal port; nginx on 8080 handles SageMaker routing
#   /invocations -> /v1/images/generations
#   /ping -> /health
SERVER_ARGS=(--omni --host 0.0.0.0 --port 8081)

# Exported environment variables starting with PREFIX become CLI flags.
# Uses VLLM_OMNI_ prefix to avoid conflicts with base vLLM env vars
PREFIX="VLLM_OMNI_"
ARG_PREFIX="--"

# Environment variables to exclude from CLI flag conversion.
# VLLM_MODEL is used as the positional model argument, not a --flag.
# NOTE(review): VLLM_MODEL does not start with PREFIX, so this entry can never
# match; it is kept only as a defensive guard.
EXCLUDE_VARS=("VLLM_MODEL")

#######################################
# Append one CLI flag per exported PREFIX* variable to SERVER_ARGS.
#   VLLM_OMNI_TENSOR_PARALLEL_SIZE=2  ->  --tensor-parallel-size 2
#   VLLM_OMNI_ENFORCE_EAGER=          ->  --enforce-eager   (bare flag)
# Globals:   PREFIX, ARG_PREFIX, EXCLUDE_VARS (read); SERVER_ARGS (appended)
#######################################
append_env_server_args() {
    local key value arg_name exclude
    local -a env_keys

    # compgen -e yields exported variable NAMES only, one per line. Values are
    # read via indirect expansion, so values containing newlines or '=' are
    # passed through intact (unlike parsing `env` output line by line).
    mapfile -t env_keys < <(compgen -e "$PREFIX")

    for key in "${env_keys[@]}"; do
        for exclude in "${EXCLUDE_VARS[@]}"; do
            [[ "$key" == "$exclude" ]] && continue 2
        done

        value="${!key}"

        # Strip prefix, lowercase, underscores -> dashes (no subprocesses).
        arg_name="${key#"$PREFIX"}"
        arg_name="${arg_name,,}"
        arg_name="${arg_name//_/-}"

        SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
        if [[ -n "$value" ]]; then
            SERVER_ARGS+=("$value")
        fi
    done
}

append_env_server_args

# Log the exact engine invocation between two rule lines. printf reuses its
# format for every argument; ${SERVER_ARGS[*]} joins elements with spaces.
printf '%s\n' \
    "-------------------------------------------------------------------" \
    "vLLM-Omni engine args: vllm serve $VLLM_MODEL [${SERVER_ARGS[*]}]" \
    "-------------------------------------------------------------------"

# Launch vLLM-Omni on internal port (8081), then nginx on SageMaker port (8080)
vllm serve "$VLLM_MODEL" "${SERVER_ARGS[@]}" &
VLLM_PID=$!

# Wait for vLLM-Omni to be ready before starting nginx.
# The wait is bounded by STARTUP_TIMEOUT_SECONDS (default 300). Loading a large
# diffusion checkpoint, especially one fetched from a model hub, may need more.
# The name intentionally avoids the VLLM_OMNI_ prefix, because those variables
# are forwarded to `vllm serve` as CLI flags.
STARTUP_TIMEOUT="${STARTUP_TIMEOUT_SECONDS:-300}"
if ! [[ "$STARTUP_TIMEOUT" =~ ^[1-9][0-9]*$ ]]; then
    echo "Warning: invalid STARTUP_TIMEOUT_SECONDS='${STARTUP_TIMEOUT}', using 300" >&2
    STARTUP_TIMEOUT=300
fi
STARTUP_DEADLINE=$(( SECONDS + STARTUP_TIMEOUT ))

echo "Waiting for vLLM-Omni server to start..."
# -f makes curl fail on HTTP error statuses, so only a successful /health
# response counts as "ready"; --max-time bounds each probe so one hung
# connection cannot stall the loop.
until curl -sf --max-time 5 http://localhost:8081/health > /dev/null 2>&1; do
    if ! kill -0 "$VLLM_PID" 2>/dev/null; then
        # Reap the child to report its real status (e.g. 137 when SIGKILLed).
        VLLM_EXIT=0
        wait "$VLLM_PID" || VLLM_EXIT=$?
        echo "Error: vLLM-Omni process exited unexpectedly (exit code ${VLLM_EXIT})" >&2
        # An unexpected exit is never reported as success.
        exit $(( VLLM_EXIT == 0 ? 1 : VLLM_EXIT ))
    fi
    if (( SECONDS >= STARTUP_DEADLINE )); then
        echo "Error: vLLM-Omni server failed to start within ${STARTUP_TIMEOUT} seconds" >&2
        kill "$VLLM_PID" 2>/dev/null || true
        exit 1
    fi
    sleep 1
done
echo "vLLM-Omni server is ready!"

echo "Starting nginx reverse proxy on port 8080..."
# NOTE(review): $! only tracks nginx if it stays in the foreground, i.e.
# /etc/nginx/nginx.conf sets `daemon off;`. Otherwise the launched process
# exits right after forking and `wait -n` below returns immediately — confirm.
nginx -c /etc/nginx/nginx.conf &
NGINX_PID=$!

# Forward SIGTERM/SIGINT (e.g. from `docker stop`) to both children so they
# can shut down cleanly. If this script is the container's PID 1 (the
# ENTRYPOINT is not visible here), bash would otherwise ignore SIGTERM.
forward_shutdown_signal() {
    echo "Received shutdown signal; stopping vLLM-Omni and nginx..."
    kill -TERM "$VLLM_PID" "$NGINX_PID" 2>/dev/null || true
}
trap forward_shutdown_signal TERM INT

# Wait for either process to exit (this keeps the container running).
# A trapped signal also interrupts this wait, returning 128+signal.
# NOTE: `wait -n` with explicit PIDs needs bash >= 5.1.
wait -n "$VLLM_PID" "$NGINX_PID"
# Capture immediately: any command in between would overwrite $?.
EXIT_CODE=$?

# If we get here, one process exited (or a shutdown signal arrived).
echo "Error: Process exited with code $EXIT_CODE" >&2

# Stop whatever is still running and let it finish exiting before we leave;
# exiting first would tear the children down without a clean shutdown.
kill "$VLLM_PID" "$NGINX_PID" 2>/dev/null || true
wait "$VLLM_PID" "$NGINX_PID" 2>/dev/null

exit "$EXIT_CODE"
