#!/bin/bash
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# do/stage — Pre-stage model weights from HuggingFace to S3
# Downloads the model using huggingface-cli and syncs to S3 so that
# vLLM can load directly from S3 at deploy time (fast cold-start).
#
# Idempotent: if the model is already staged (config.json exists at
# the target S3 path), the script exits early.
#
# Usage:
#   ./do/stage                       Stage model to S3
#   ./do/stage --force               Re-stage even if already present in S3
#   ./do/stage --update-config       Stage and update MODEL_NAME in do/config
#   ./do/stage --submit              Submit as SageMaker Processing Job (for models >500GB)
#   ./do/stage --submit --no-wait    Submit and exit without polling

# Strict mode: abort on a failing command (-e), on expansion of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# ── Source project configuration ──────────────────────────────────────────────
# Resolve the directory that holds this script so the sibling `config` file is
# found regardless of the caller's working directory, then load it into the
# current shell. (Presumably config supplies MODEL_NAME, PROJECT_NAME and
# AWS_REGION, which the rest of this script reads — confirm against do/config.)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
. "${SCRIPT_DIR}/config"

# ── Parse flags ───────────────────────────────────────────────────────────────

# Print the --help text to stdout.
_print_stage_usage() {
    cat <<'USAGE_EOF'
Usage: ./do/stage [--force] [--update-config] [--submit] [--no-wait]

Pre-stage model weights from HuggingFace to S3.

Modes:
  (default)    Download locally then sync to S3
  --submit     Submit as SageMaker Processing Job (for models >500GB)

Options:
  --force          Re-stage even if model already exists in S3
  --update-config  Update MODEL_NAME in do/config to the staged S3 URI
  --no-wait        (with --submit) Exit without polling for completion

Environment:
  HF_TOKEN   HuggingFace token (for gated models)

The staged S3 URI will be printed on completion.
Pass --update-config to automatically update do/config for S3-backed deploys.

The --submit mode uses a SageMaker Processing Job with 2TB attached
storage, suitable for very large models that exceed local disk capacity.
USAGE_EOF
}

#######################################
# Parse command-line flags into the global mode switches.
# Globals:
#   FORCE, UPDATE_CONFIG, SUBMIT_MODE, NO_WAIT (written; "true" or "false")
# Arguments:
#   The script's command-line arguments.
# Outputs:
#   --help/-h prints usage to stdout and exits 0.
#   Unrecognized arguments are still tolerated, but each one now produces a
#   warning on stderr so a mistyped flag is not silently dropped.
#######################################
_parse_flags() {
    FORCE=false
    UPDATE_CONFIG=false
    SUBMIT_MODE=false
    NO_WAIT=false

    while [ $# -gt 0 ]; do
        case "$1" in
            --force) FORCE=true ;;
            --update-config) UPDATE_CONFIG=true ;;
            --submit) SUBMIT_MODE=true ;;
            --no-wait) NO_WAIT=true ;;
            --help|-h)
                _print_stage_usage
                exit 0
                ;;
            *)
                echo "⚠️  Ignoring unrecognized argument: $1" >&2
                ;;
        esac
        shift
    done

    # --no-wait only affects the Processing Job path; say so instead of
    # silently doing nothing.
    if [ "${NO_WAIT}" = true ] && [ "${SUBMIT_MODE}" = false ]; then
        echo "⚠️  --no-wait has no effect without --submit" >&2
    fi
}

_parse_flags "$@"

# ── Processing Job submission function ────────────────────────────────────────
# Submits a SageMaker Processing Job that downloads model weights from HuggingFace
# and syncs them to S3. Uses 2TB attached storage to handle any model size.
# Tunables for the Processing Job path. Each keeps its previous value by
# default and can be overridden from the environment without editing this file:
#   STAGE_POLL_INTERVAL       seconds between status polls
#   STAGE_JOB_INSTANCE_TYPE   instance type used for the Processing Job
#   STAGE_JOB_VOLUME_GB       attached storage size in GB (integer)
POLL_INTERVAL="${STAGE_POLL_INTERVAL:-30}"
PROCESSING_JOB_INSTANCE_TYPE="${STAGE_JOB_INSTANCE_TYPE:-ml.m5.xlarge}"
PROCESSING_JOB_VOLUME_GB="${STAGE_JOB_VOLUME_GB:-2048}"

#######################################
# Submit a SageMaker Processing Job that downloads the model from HuggingFace
# and syncs it to S3, then (unless --no-wait was given) poll it to completion.
# Globals (read):
#   MODEL_NAME, MODEL_S3_URI, PROJECT_NAME, STAGE_S3_BUCKET, AWS_REGION,
#   _PROFILE_JSON, PROCESSING_JOB_INSTANCE_TYPE, PROCESSING_JOB_VOLUME_GB,
#   NO_WAIT, HF_TOKEN_ARN (optional), HF_TOKEN (optional)
# Outputs:
#   Progress and remediation hints on stdout.
# Returns:
#   0 once the job is submitted with --no-wait; otherwise the status of
#   _poll_processing_job. Exits 4 on unusable AWS credentials and 1 on a
#   missing execution role, a failed entrypoint upload, or a failed job
#   creation.
#######################################
_submit_processing_job() {
    echo "🚀 Submitting SageMaker Processing Job for model staging"
    echo "   Model: ${MODEL_NAME}"
    echo "   Target: ${MODEL_S3_URI}"
    echo "   Instance: ${PROCESSING_JOB_INSTANCE_TYPE}"
    echo "   Storage: ${PROCESSING_JOB_VOLUME_GB} GB"
    echo ""

    # Validate AWS credentials
    if ! aws sts get-caller-identity &>/dev/null; then
        echo "❌ AWS credentials not configured or expired."
        echo "   Run: aws configure"
        exit 4
    fi

    # Resolve execution role from profile
    local execution_role
    execution_role=$(printf '%s' "${_PROFILE_JSON}" | python3 -c '
import sys, json
p = json.load(sys.stdin)
print(p.get("executionRoleArn", ""))
' 2>/dev/null) || execution_role=""

    if [ -z "${execution_role}" ]; then
        echo "❌ No execution role configured."
        echo "   Run 'ml-container-creator bootstrap' to set up your profile."
        echo "   The role needs: SageMaker, S3, and Secrets Manager permissions."
        exit 1
    fi

    # Resolve HF token ARN for the processing job (optional — for gated models)
    local hf_token_secret_arn="${HF_TOKEN_ARN:-}"

    # Generate job name with timestamp.
    # SageMaker job names max 63 chars, must match [a-zA-Z0-9](-*[a-zA-Z0-9])*.
    # Only the project-derived prefix is sanitized and truncated; the timestamp
    # is appended afterwards so a long project name can never cut off the part
    # that makes the name unique.
    local timestamp
    timestamp=$(date +%Y%m%d-%H%M%S)
    local name_prefix
    name_prefix=$(printf '%s' "mlcc-stage-${PROJECT_NAME}" \
        | LC_ALL=C sed 's/[^a-zA-Z0-9-]/-/g')
    name_prefix="${name_prefix:0:$((63 - ${#timestamp} - 1))}"
    while [[ "${name_prefix}" == *- ]]; do
        name_prefix="${name_prefix%-}"
    done
    local job_name="${name_prefix}-${timestamp}"

    echo "   Job name: ${job_name}"
    echo ""

    # Build the entrypoint script that runs inside the processing container.
    # The heredoc delimiter is quoted, so nothing below is expanded here; every
    # ${...} is evaluated inside the container from the job's Environment.
    local entrypoint_script
    entrypoint_script=$(cat <<'ENTRYPOINT_EOF'
#!/bin/bash
set -e
set -o pipefail

echo "=== MCC Model Staging Processing Job ==="
echo "Model: ${MODEL_ID}"
echo "Target: ${S3_OUTPUT_URI}"
echo ""

# Install dependencies
echo "📦 Installing huggingface-cli and hf_transfer..."
pip install -q "huggingface_hub[cli]" hf_transfer

# Enable fast parallel downloads
export HF_HUB_ENABLE_HF_TRANSFER=1

# Set HF token if provided
if [ -n "${HF_TOKEN:-}" ]; then
    echo "🔐 Using provided HuggingFace token"
fi

# Download model from HuggingFace
echo ""
echo "⬇️  Downloading model: ${MODEL_ID}"
DOWNLOAD_ARGS=("${MODEL_ID}")
if [ -n "${HF_TOKEN:-}" ]; then
    DOWNLOAD_ARGS+=(--token "${HF_TOKEN}")
fi
huggingface-cli download "${DOWNLOAD_ARGS[@]}"

echo ""
echo "✅ Download complete"

# Locate downloaded files (the model id is read from the environment rather
# than spliced into the Python source)
CACHE_PATH=$(python3 -c "
import os
from huggingface_hub import snapshot_download
print(snapshot_download(os.environ['MODEL_ID'], local_files_only=True))
")

echo "📁 Cache path: ${CACHE_PATH}"

# Sync to S3
echo ""
echo "☁️  Syncing to S3: ${S3_OUTPUT_URI}"
aws s3 sync "${CACHE_PATH}" "${S3_OUTPUT_URI}" \
    --no-progress \
    --exclude "*.lock" \
    --exclude ".gitattributes"

echo ""
echo "✅ Model staged successfully to: ${S3_OUTPUT_URI}"
ENTRYPOINT_EOF
)

    # Resolve the HuggingFace token for the container, if any. A configured
    # secret ARN takes precedence over a local HF_TOKEN; when the secret cannot
    # be read the job is submitted without a token.
    # NOTE(review): the token ends up in the job's Environment in plain text,
    # which is presumably visible to anyone allowed to describe the job —
    # consider passing the secret ARN and resolving it inside the container.
    local hf_token_value=""
    if [ -n "${hf_token_secret_arn}" ]; then
        hf_token_value=$(aws secretsmanager get-secret-value \
            --secret-id "${hf_token_secret_arn}" \
            --query SecretString --output text 2>/dev/null) || hf_token_value=""
    elif [ -n "${HF_TOKEN:-}" ]; then
        hf_token_value="${HF_TOKEN}"
    fi

    # Upload the entrypoint so the container can fetch and run it
    local entrypoint_s3_key="staging-jobs/${job_name}/entrypoint.sh"
    local entrypoint_s3_uri="s3://${STAGE_S3_BUCKET}/${entrypoint_s3_key}"

    echo "📤 Uploading entrypoint script..."
    if ! printf '%s\n' "${entrypoint_script}" \
        | aws s3 cp - "${entrypoint_s3_uri}" --region "${AWS_REGION}"; then
        echo "❌ Failed to upload entrypoint script to ${entrypoint_s3_uri}"
        exit 1
    fi

    # Create the processing job.
    # Uses the AWS PyTorch training image; it is expected to ship Python, pip
    # and the AWS CLI, all of which the entrypoint relies on.
    local container_image="763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-cpu-py310-ubuntu20.04-sagemaker"

    # Build the request with json.dumps from values handed over through the
    # environment. Nothing is interpolated into the Python source, so commas,
    # quotes or '=' in a token or model id cannot corrupt the request.
    local processing_request
    processing_request=$(
        _MLCC_JOB_NAME="${job_name}" \
        _MLCC_INSTANCE_TYPE="${PROCESSING_JOB_INSTANCE_TYPE}" \
        _MLCC_VOLUME_GB="${PROCESSING_JOB_VOLUME_GB}" \
        _MLCC_IMAGE="${container_image}" \
        _MLCC_ENTRYPOINT_URI="${entrypoint_s3_uri}" \
        _MLCC_ROLE_ARN="${execution_role}" \
        _MLCC_MODEL_ID="${MODEL_NAME}" \
        _MLCC_S3_OUTPUT_URI="${MODEL_S3_URI}" \
        _MLCC_HF_TOKEN="${hf_token_value}" \
        python3 -c '
import json, os

e = os.environ
environment = {
    "MODEL_ID": e["_MLCC_MODEL_ID"],
    "S3_OUTPUT_URI": e["_MLCC_S3_OUTPUT_URI"],
}
if e.get("_MLCC_HF_TOKEN"):
    environment["HF_TOKEN"] = e["_MLCC_HF_TOKEN"]

job = {
    "ProcessingJobName": e["_MLCC_JOB_NAME"],
    "ProcessingResources": {
        "ClusterConfig": {
            "InstanceCount": 1,
            "InstanceType": e["_MLCC_INSTANCE_TYPE"],
            "VolumeSizeInGB": int(e["_MLCC_VOLUME_GB"]),
        }
    },
    "AppSpecification": {
        "ImageUri": e["_MLCC_IMAGE"],
        "ContainerEntrypoint": ["bash", "-c"],
        "ContainerArguments": [
            "aws s3 cp " + e["_MLCC_ENTRYPOINT_URI"] + " /tmp/entrypoint.sh"
            " && chmod +x /tmp/entrypoint.sh && /tmp/entrypoint.sh"
        ],
    },
    "Environment": environment,
    "RoleArn": e["_MLCC_ROLE_ARN"],
    "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
}

print(json.dumps(job, indent=2))
'
    ) || {
        echo "❌ Failed to build the Processing Job request"
        exit 1
    }

    # The request may contain the HF token, so write it to a private,
    # unpredictable temp file instead of a guessable /tmp path.
    local request_file
    request_file=$(mktemp "${TMPDIR:-/tmp}/mlcc-stage-request.XXXXXX") || {
        echo "❌ Failed to create a temporary file for the job request"
        exit 1
    }
    printf '%s\n' "${processing_request}" > "${request_file}"

    echo "🚀 Creating Processing Job: ${job_name}"
    echo ""

    local create_output
    local create_exit_code=0
    create_output=$(aws sagemaker create-processing-job \
        --cli-input-json "file://${request_file}" \
        --region "${AWS_REGION}" 2>&1) || create_exit_code=$?

    rm -f "${request_file}"

    if [ "${create_exit_code}" -ne 0 ]; then
        echo "❌ Failed to create Processing Job"
        echo "   ${create_output}"
        echo ""
        if grep -q "AccessDeniedException" <<<"${create_output}"; then
            echo "   Remediation: ensure the execution role has sagemaker:CreateProcessingJob permission"
        fi
        exit 1
    fi

    echo "   ✅ Processing Job submitted: ${job_name}"
    echo ""

    # Handle --no-wait
    if [ "${NO_WAIT}" = true ]; then
        echo "   --no-wait specified. Job submitted, exiting without polling."
        echo ""
        echo "   Check status:"
        echo "     aws sagemaker describe-processing-job --processing-job-name ${job_name} --region ${AWS_REGION}"
        echo ""
        echo "   On completion, the staged model will be at:"
        echo "     ${MODEL_S3_URI}"
        return 0
    fi

    # Poll for completion
    _poll_processing_job "${job_name}"
}

# ── Poll Processing Job status ────────────────────────────────────────────────
#######################################
# Poll a SageMaker Processing Job until it reaches a terminal state.
# Globals (read):
#   AWS_REGION, POLL_INTERVAL, MODEL_S3_URI, UPDATE_CONFIG, SCRIPT_DIR
# Arguments:
#   $1 - Processing Job name
# Outputs:
#   One status line per poll plus a final summary on stdout. With
#   UPDATE_CONFIG=true a completed job also rewrites the MODEL_NAME line of
#   ${SCRIPT_DIR}/config.
# Returns:
#   0 when the job Completed, 1 when it Failed, 2 when it was Stopped.
#   Any other status (including describe errors) keeps the loop polling.
#######################################
_poll_processing_job() {
    local job_name="$1"
    local describe_output describe_exit_code
    local job_status failure_reason now config_file

    echo "⏳ Polling Processing Job status (every ${POLL_INTERVAL}s)..."
    echo "   (Ctrl+C to stop polling — job continues in background)"
    echo ""

    while true; do
        # Must be reset on every pass: re-declaring a variable with `local`
        # does not clear its value, so a single transient describe failure
        # would otherwise be remembered and every later poll reported as
        # failed, forever.
        describe_exit_code=0
        describe_output=$(aws sagemaker describe-processing-job \
            --processing-job-name "${job_name}" \
            --region "${AWS_REGION}" 2>&1) || describe_exit_code=$?

        if [ "${describe_exit_code}" -ne 0 ]; then
            echo "   ⚠️  Failed to describe job (will retry): ${describe_output}"
            sleep "${POLL_INTERVAL}"
            continue
        fi

        # Parse status from response; unparseable output counts as "Unknown"
        # and is simply polled again.
        job_status=$(printf '%s' "${describe_output}" | python3 -c '
import sys, json
d = json.load(sys.stdin)
print(d.get("ProcessingJobStatus", "Unknown"))
' 2>/dev/null) || job_status="Unknown"

        failure_reason=$(printf '%s' "${describe_output}" | python3 -c '
import sys, json
d = json.load(sys.stdin)
print(d.get("FailureReason", ""))
' 2>/dev/null) || failure_reason=""

        # Print status
        now=$(date +%H:%M:%S)
        echo "   [${now}] Status: ${job_status}"

        # Handle terminal states
        case "${job_status}" in
            Completed)
                echo ""
                echo "✅ Processing Job completed: ${job_name}"
                echo ""
                echo "   S3 URI: ${MODEL_S3_URI}"
                echo ""
                if [ "${UPDATE_CONFIG}" = true ]; then
                    config_file="${SCRIPT_DIR}/config"
                    # -i.bak followed by rm keeps the in-place edit portable
                    # across GNU and BSD sed.
                    sed -i.bak "s|^export MODEL_NAME=.*|export MODEL_NAME=\"${MODEL_S3_URI}\"|" "${config_file}"
                    rm -f "${config_file}.bak"
                    echo "   ✅ Updated MODEL_NAME in do/config → ${MODEL_S3_URI}"
                    echo ""
                    echo "   Re-deploy with S3-backed model: ./do/deploy"
                else
                    echo "   To use this staged model, update do/config:"
                    echo "   export MODEL_NAME=\"${MODEL_S3_URI}\""
                    echo ""
                    echo "   Or re-run with --update-config:"
                    echo "   ./do/stage --submit --update-config"
                fi
                return 0
                ;;
            Failed)
                echo ""
                echo "❌ Processing Job failed: ${job_name}"
                if [ -n "${failure_reason}" ]; then
                    echo "   Reason: ${failure_reason}"
                fi
                echo ""
                echo "   Check CloudWatch logs:"
                echo "     /aws/sagemaker/ProcessingJobs/${job_name}"
                echo ""
                echo "   To retry: ./do/stage --submit --force"
                return 1
                ;;
            Stopped)
                echo ""
                echo "⏹️  Processing Job was stopped: ${job_name}"
                echo ""
                echo "   To retry: ./do/stage --submit --force"
                return 2
                ;;
        esac

        sleep "${POLL_INTERVAL}"
    done
}

# ── Check if model is already an S3 URI ──────────────────────────────────────
if [[ "${MODEL_NAME}" == s3://* ]]; then
    echo "✅ Model is already an S3 URI: ${MODEL_NAME}"
    echo "   Nothing to stage."
    exit 0
fi

echo "📦 Staging model: ${MODEL_NAME}"
echo "   Project: ${PROJECT_NAME}"
echo ""

# ── Resolve profile for S3 bucket ────────────────────────────────────────────
# The profile is read with python3; without it the bucket can never be
# resolved, so report that directly rather than pointing at the profile.
if ! command -v python3 &>/dev/null; then
    echo "❌ python3 is required to read the ml-container-creator profile."
    exit 1
fi

# Active profile from ~/.ml-container-creator/config.json as a JSON object;
# "{}" when the file is missing, unreadable, or has no active profile.
_PROFILE_JSON=$(python3 -c '
import json, os
config_path = os.path.expanduser("~/.ml-container-creator/config.json")
try:
    with open(config_path) as f:
        config = json.load(f)
    profile = config["profiles"][config["activeProfile"]]
    print(json.dumps(profile))
except Exception:
    print("{}")
' 2>/dev/null) || _PROFILE_JSON="{}"

# Extract the benchmark S3 bucket from profile (used for model staging).
# When the profile names no bucket, the default name is derived from region
# and account id. Without an account id nothing is derived: the result stays
# empty and the check below reports it, instead of inventing a bucket name
# that cannot belong to this account.
STAGE_S3_BUCKET=$(printf '%s' "${_PROFILE_JSON}" | python3 -c '
import sys, json
p = json.load(sys.stdin)
bucket = p.get("benchmarkS3Bucket", "")
if not bucket:
    acct = p.get("accountId", "")
    region = p.get("awsRegion", "us-east-1")
    if acct:
        bucket = f"ml-container-creator-benchmark-{region}-{acct}"
print(bucket)
' 2>/dev/null) || STAGE_S3_BUCKET=""

if [ -z "${STAGE_S3_BUCKET}" ]; then
    echo "❌ Could not determine S3 bucket for staging."
    echo "   Run 'ml-container-creator bootstrap' to set up your profile."
    exit 1
fi

# Target S3 path for staged model
MODEL_S3_URI="s3://${STAGE_S3_BUCKET}/models/${PROJECT_NAME}/"

echo "   Target: ${MODEL_S3_URI}"
echo ""

# ── Idempotency: check if model is already staged ────────────────────────────
# This runs before the submit-mode dispatch so that `--submit` honours the same
# "already staged" short-circuit as the local path. That is what gives --force
# a meaning in submit mode, matching the "To retry: ./do/stage --submit --force"
# hints printed after a failed or stopped job.
if [ "${FORCE}" = false ]; then
    if aws s3 ls "${MODEL_S3_URI}config.json" --region "${AWS_REGION}" &>/dev/null; then
        echo "✅ Model already staged at: ${MODEL_S3_URI}"
        echo "   Use --force to re-stage."
        echo ""
        if [ "${UPDATE_CONFIG}" = true ]; then
            CONFIG_FILE="${SCRIPT_DIR}/config"
            # -i.bak followed by rm keeps the in-place edit portable across
            # GNU and BSD sed.
            sed -i.bak "s|^export MODEL_NAME=.*|export MODEL_NAME=\"${MODEL_S3_URI}\"|" "${CONFIG_FILE}"
            rm -f "${CONFIG_FILE}.bak"
            echo "   ✅ Updated MODEL_NAME in do/config → ${MODEL_S3_URI}"
        else
            echo "   To use this staged model, set in do/config:"
            echo "   export MODEL_NAME=\"${MODEL_S3_URI}\""
        fi
        exit 0
    fi
fi

# ── Submit mode: SageMaker Processing Job ─────────────────────────────────────
# For very large models (>500GB) that exceed local disk, submit a Processing Job
# with 2TB attached storage. The job downloads from HuggingFace and syncs to S3.
if [ "${SUBMIT_MODE}" = true ]; then
    _submit_processing_job
    exit $?
fi

# ── Validate prerequisites ───────────────────────────────────────────────────
# The local path needs both CLIs on PATH; a missing tool exits with status 2.
command -v huggingface-cli >/dev/null 2>&1 || {
    echo "❌ huggingface-cli is not installed"
    echo "   Install: pip install huggingface_hub[cli] hf_transfer"
    exit 2
}

command -v aws >/dev/null 2>&1 || {
    echo "❌ AWS CLI is not installed"
    echo "   Install: https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html"
    exit 2
}

# Credentials must be usable before any download starts (exit status 4).
aws sts get-caller-identity >/dev/null 2>&1 || {
    echo "❌ AWS credentials not configured or expired."
    echo "   Run: aws configure"
    exit 4
}

# ── Resolve HuggingFace token (for gated models) ─────────────────────────────
# An HF_TOKEN already present in the environment wins; the Secrets Manager
# lookup only happens when an ARN is configured and no token was supplied.
# A failed lookup is not fatal: staging continues with an empty token.
if [[ -z "${HF_TOKEN:-}" && -n "${HF_TOKEN_ARN:-}" ]]; then
    echo "🔐 Resolving HuggingFace token from Secrets Manager..."
    if ! HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "${HF_TOKEN_ARN}" --query SecretString --output text); then
        echo "⚠️  Failed to resolve HF token from Secrets Manager (continuing without token)"
        HF_TOKEN=""
    fi
    export HF_TOKEN
fi

# ── Download model from HuggingFace ──────────────────────────────────────────
echo "⬇️  Downloading model from HuggingFace: ${MODEL_NAME}"
echo "   Using hf_transfer for fast parallel downloads..."
echo ""

# Enable fast parallel downloads via hf_transfer
export HF_HUB_ENABLE_HF_TRANSFER=1

# Download to HF cache (huggingface-cli manages cache location)
DOWNLOAD_ARGS=("${MODEL_NAME}")
if [ -n "${HF_TOKEN:-}" ]; then
    DOWNLOAD_ARGS+=("--token" "${HF_TOKEN}")
fi

if ! huggingface-cli download "${DOWNLOAD_ARGS[@]}"; then
    echo "❌ Failed to download model from HuggingFace: ${MODEL_NAME}"
    echo ""
    echo "Possible causes:"
    echo "  • Model name is incorrect"
    echo "  • Model is gated and requires HF_TOKEN"
    echo "  • Network connectivity issues"
    exit 3
fi

echo ""
echo "✅ Download complete"

# ── Locate downloaded files in HF cache ───────────────────────────────────────
# huggingface-cli downloads to <hub cache>/models--<org>--<name>/snapshots/<rev>/
# Ask huggingface_hub for the snapshot path first. The model name is passed as
# an argument rather than spliced into the Python source, so quotes in the name
# cannot break the interpreter.
HF_CACHE_DIR=$(python3 -c '
import sys
from huggingface_hub import snapshot_download
print(snapshot_download(sys.argv[1], local_files_only=True))
' "${MODEL_NAME}" 2>/dev/null) || HF_CACHE_DIR=""

# Fallback location, built by hand. Every "/" in the repo id becomes "--"
# (tr cannot do this: it maps one character to one character, giving a single
# "-"). The hub cache root honours HF_HUB_CACHE and HF_HOME when they are set.
MODEL_DIR_NAME="${MODEL_NAME//\//--}"
HF_HUB_ROOT="${HF_HUB_CACHE:-${HF_HOME:-${HOME}/.cache/huggingface}/hub}"
HF_SNAPSHOTS_DIR="${HF_HUB_ROOT}/models--${MODEL_DIR_NAME}/snapshots"

if [ -z "${HF_CACHE_DIR}" ] || [ ! -d "${HF_CACHE_DIR}" ]; then
    # Use the most recently modified snapshot. A glob with -nt replaces
    # `ls -t | head`, which aborted the script silently under `set -e` and
    # pipefail whenever no snapshot directory existed.
    HF_CACHE_DIR=""
    for _snapshot in "${HF_SNAPSHOTS_DIR}"/*/; do
        [ -d "${_snapshot}" ] || continue
        if [ -z "${HF_CACHE_DIR}" ] || [ "${_snapshot}" -nt "${HF_CACHE_DIR}" ]; then
            HF_CACHE_DIR="${_snapshot%/}"
        fi
    done
fi

if [ -z "${HF_CACHE_DIR}" ] || [ ! -d "${HF_CACHE_DIR}" ]; then
    echo "❌ Could not locate downloaded model files in HuggingFace cache"
    echo "   Expected location: ${HF_SNAPSHOTS_DIR}/"
    exit 3
fi

echo "📁 Model cache: ${HF_CACHE_DIR}"

# ── Sync to S3 ───────────────────────────────────────────────────────────────
printf '%s\n' \
    "" \
    "☁️  Syncing model to S3: ${MODEL_S3_URI}" \
    "   This may take a while for large models..." \
    ""

# Lock files and .gitattributes are cache bookkeeping, not model content.
_sync_options=(
    --region "${AWS_REGION}"
    --no-progress
    --exclude "*.lock"
    --exclude ".gitattributes"
)

aws s3 sync "${HF_CACHE_DIR}" "${MODEL_S3_URI}" "${_sync_options[@]}" || {
    printf '%s\n' \
        "❌ Failed to sync model to S3" \
        "" \
        "Possible causes:" \
        "  • Missing S3 write permissions (s3:PutObject)" \
        "  • Bucket does not exist (run 'ml-container-creator bootstrap')" \
        "  • Network connectivity issues"
    exit 4
}

printf '%s\n' \
    "" \
    "✅ Model staged successfully!" \
    "" \
    "   S3 URI: ${MODEL_S3_URI}" \
    ""

# Either point do/config at the staged copy, or tell the user how to do it.
if [ "${UPDATE_CONFIG}" != true ]; then
    printf '%s\n' \
        "   To use this staged model, update do/config:" \
        "   export MODEL_NAME=\"${MODEL_S3_URI}\"" \
        "" \
        "   Or re-run with --update-config to do it automatically:" \
        "   ./do/stage --update-config" \
        "" \
        "   Then re-deploy: ./do/deploy"
else
    CONFIG_FILE="${SCRIPT_DIR}/config"
    # Rewrite the MODEL_NAME export in place; the .bak copy that sed leaves
    # behind is removed straight away.
    sed -i.bak "s|^export MODEL_NAME=.*|export MODEL_NAME=\"${MODEL_S3_URI}\"|" "${CONFIG_FILE}"
    rm -f "${CONFIG_FILE}.bak"
    printf '%s\n' \
        "   ✅ Updated MODEL_NAME in do/config → ${MODEL_S3_URI}" \
        "" \
        "   Re-deploy with S3-backed model: ./do/deploy"
fi
