#!/bin/bash
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# do/tune — SageMaker AI Managed Model Customization
# Wraps SageMaker managed fine-tuning for supported foundation models.
# Supports SFT, DPO, RLAIF, and RLVR techniques.

# Strict mode: stop on any failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -euo pipefail

# ── Project configuration ─────────────────────────────────────────────────────
# Resolve the directory that holds this script and load the settings file that
# sits next to it.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config"

# ── Fixed settings ────────────────────────────────────────────────────────────
# Catalog (JSON) and Python helper live beside this script.
CATALOG_FILE="${SCRIPT_DIR}/.tune_catalog.json"
HELPER_SCRIPT="${SCRIPT_DIR}/.tune_helper.py"
# Seconds between status polls.
POLL_INTERVAL=60

# ── Command-line state (filled in by _parse_args) ─────────────────────────────
# Value options: an empty string means the flag was not given.
ARG_TECHNIQUE=""
ARG_DATASET=""
# Training type is the only value option with a non-empty default.
ARG_TRAINING_TYPE="lora"
ARG_MODEL=""
ARG_EPOCHS=""
ARG_LEARNING_RATE=""
ARG_MAX_SEQ_LENGTH=""
ARG_LORA_RANK=""
ARG_LORA_ALPHA=""
ARG_BATCH_SIZE=""
ARG_REWARD_FUNCTION=""
ARG_REWARD_PROMPT=""
ARG_OUTPUT_BUCKET=""
ARG_ROLE=""
# Switches: "true" once the corresponding flag has been seen.
ARG_FORCE=false
ARG_NO_WAIT=false
ARG_STATUS=false
ARG_HELP=false
ARG_DRY_RUN=false
ARG_LIST_MODELS=false


# ── _parse_args() ─────────────────────────────────────────────────────────────
# Walk the command line and record every recognised flag in its ARG_* global.
#   Value flags  consume the next word; a missing or empty word prints an
#                error and exits with status 1.
#   Switches     set their ARG_* global to "true".
#   Anything else prints an "Unknown option" error and exits with status 1.
# Arguments: the script's command-line words ("$@").
_parse_args() {
    local flag dest needs
    while [ $# -gt 0 ]; do
        flag="$1"
        dest=""
        needs=""
        case "${flag}" in
            # Value flags: dest names the global to fill, needs completes the
            # "<flag> requires ..." error message.
            --technique)       dest=ARG_TECHNIQUE;       needs="a value (sft, dpo, rlaif, rlvr)" ;;
            --dataset)         dest=ARG_DATASET;         needs="a value (s3://... or hf://...)" ;;
            --training-type)   dest=ARG_TRAINING_TYPE;   needs="a value (lora, full-rank)" ;;
            --model)           dest=ARG_MODEL;           needs="a model ID" ;;
            --epochs)          dest=ARG_EPOCHS;          needs="an integer value" ;;
            --learning-rate)   dest=ARG_LEARNING_RATE;   needs="a float value" ;;
            --max-seq-length)  dest=ARG_MAX_SEQ_LENGTH;  needs="an integer value" ;;
            --lora-rank)       dest=ARG_LORA_RANK;       needs="an integer value" ;;
            --lora-alpha)      dest=ARG_LORA_ALPHA;      needs="an integer value" ;;
            --batch-size)      dest=ARG_BATCH_SIZE;      needs="an integer value" ;;
            --reward-function) dest=ARG_REWARD_FUNCTION; needs="a Lambda ARN" ;;
            --reward-prompt)   dest=ARG_REWARD_PROMPT;   needs="an S3 URI" ;;
            --output-bucket)   dest=ARG_OUTPUT_BUCKET;   needs="a bucket name" ;;
            --role)            dest=ARG_ROLE;            needs="an IAM role ARN" ;;
            # Switches take no value.
            --force)       ARG_FORCE=true ;;
            --no-wait)     ARG_NO_WAIT=true ;;
            --status)      ARG_STATUS=true ;;
            --help|-h)     ARG_HELP=true ;;
            --dry-run)     ARG_DRY_RUN=true ;;
            --list-models) ARG_LIST_MODELS=true ;;
            *)
                echo "❌ Unknown option: $1"
                echo "   Run ./do/tune --help for usage."
                exit 1
                ;;
        esac

        # A switch was handled above: drop it and move on.
        if [ -z "${dest}" ]; then
            shift
            continue
        fi

        # Value flag: the following word must exist and be non-empty.
        if [ -z "${2:-}" ]; then
            echo "❌ ${flag} requires ${needs}"
            exit 1
        fi
        printf -v "${dest}" '%s' "$2"
        shift 2
    done
}


# ── _show_help() ──────────────────────────────────────────────────────────────
# Write the full usage text to stdout, then exit 0. Never returns.
_show_help() {
    # Quoted delimiter: the text below is emitted literally, with no expansion.
    cat <<'HELP_TEXT'
Usage: ./do/tune --technique <technique> --dataset <source> [options]
       ./do/tune --status
       ./do/tune --list-models
       ./do/tune --help

SageMaker AI Managed Model Customization — fine-tune supported foundation
models using SFT, DPO, RLAIF, or RLVR without managing infrastructure.

Required:
  --technique <t>       Customization technique: sft, dpo, rlaif, rlvr
  --dataset <source>    Dataset: s3://bucket/path.jsonl or hf://org/name[/split]

Training type:
  --training-type <t>   lora (default) or full-rank

Hyperparameter overrides (optional):
  --epochs <n>          Number of training epochs
  --learning-rate <f>   Learning rate (e.g., 2e-4)
  --max-seq-length <n>  Maximum sequence length in tokens
  --lora-rank <n>       LoRA rank (e.g., 16, 32, 64)
  --lora-alpha <n>      LoRA alpha scaling factor
  --batch-size <n>      Global batch size

Evaluator (RLVR/RLAIF only):
  --reward-function <arn>  Lambda ARN for reward function
  --reward-prompt <uri>    S3 URI for reward prompt file

Overrides:
  --model <id>          Override model (defaults to MODEL_ID from do/config)
  --output-bucket <b>   Override output bucket (defaults to TUNE_S3_BUCKET)
  --role <arn>          Override execution role (defaults to ROLE_ARN)

Job control:
  --force               Force new job even if one exists for this technique
  --no-wait             Submit and exit without polling for completion
  --status              Show status of all tracked tune jobs

Informational:
  --help, -h            Show this help message
  --dry-run             Validate inputs and show what would be submitted
  --list-models         Print supported models, techniques, and training types

Examples:
  ./do/tune --technique sft --dataset s3://my-bucket/train.jsonl
  ./do/tune --technique dpo --dataset hf://my-org/pref-data --learning-rate 1e-5
  ./do/tune --technique sft --dataset s3://bucket/data.jsonl --training-type full-rank
  ./do/tune --status
  ./do/tune --technique sft --dataset s3://bucket/data.jsonl --dry-run
HELP_TEXT
    exit 0
}

# ── _show_status() ────────────────────────────────────────────────────────────
# Print the state of every tracked tune job, then exit 0. Never returns.
#
# For each technique (sft, dpo, rlaif, rlvr) the job name is read from
# TUNE_JOB_NAME_<TECHNIQUE>; techniques without one are skipped. The job is
# queried through the Python helper, and any query or parse failure is shown
# as status "Unknown" instead of aborting the listing.
# Globals read: HELPER_SCRIPT, AWS_REGION, TUNE_JOB_NAME_*,
#               TUNE_ADAPTER_PATH_*, TUNE_MODEL_PATH_*
_show_status() {
    echo "📊 Tune Job Status"
    echo ""

    local found_any=false
    local technique technique_upper job_var job_name
    local status_json status elapsed adapter_var model_var
    for technique in sft dpo rlaif rlvr; do
        # One tr-based conversion feeds both the heading and the variable
        # names; ${var^^} is avoided because it needs bash 4 or newer.
        technique_upper=$(echo "${technique}" | tr '[:lower:]' '[:upper:]')
        job_var="TUNE_JOB_NAME_${technique_upper}"
        job_name="${!job_var:-}"
        if [ -z "${job_name}" ]; then
            continue
        fi

        found_any=true
        echo "   ${technique_upper}:"
        echo "     Job: ${job_name}"

        # Query status via Python helper
        status_json=$(python3 "${HELPER_SCRIPT}" status \
            --job-name "${job_name}" \
            --region "${AWS_REGION}" 2>/dev/null) || status_json='{"status":"Unknown","error":"Failed to query"}'

        status=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('status','Unknown'))" 2>/dev/null) || status="Unknown"
        elapsed=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('elapsed_seconds',0))" 2>/dev/null) || elapsed="0"

        echo "     Status: ${status}"

        # Shell arithmetic is integer-only: drop any fractional part and treat
        # anything that is not a plain non-negative number (e.g. "None") as 0
        # so a surprising helper value cannot abort the script.
        elapsed="${elapsed%%.*}"
        case "${elapsed}" in
            ''|*[!0-9]*) elapsed=0 ;;
        esac
        elapsed=$((10#${elapsed}))
        if [ "${elapsed}" -gt 0 ]; then
            echo "     Elapsed: $((elapsed / 60))m $((elapsed % 60))s"
        fi

        # Show the recorded output path; an adapter path wins over a model path
        adapter_var="TUNE_ADAPTER_PATH_${technique_upper}"
        model_var="TUNE_MODEL_PATH_${technique_upper}"
        if [ -n "${!adapter_var:-}" ]; then
            echo "     Output (adapter): ${!adapter_var}"
        elif [ -n "${!model_var:-}" ]; then
            echo "     Output (model): ${!model_var}"
        fi
        echo ""
    done

    if [ "${found_any}" = false ]; then
        echo "   No tune jobs tracked. Run ./do/tune --technique <t> --dataset <d> to start."
    fi

    exit 0
}

# ── _list_models() ────────────────────────────────────────────────────────────
# Print the supported-model catalog grouped by model family, then exit 0.
# Exits 1 with a message when the catalog file is missing or cannot be parsed.
# Never returns.
# Globals read: CATALOG_FILE
_list_models() {
    if [ ! -f "${CATALOG_FILE}" ]; then
        echo "❌ Catalog file not found: ${CATALOG_FILE}"
        exit 1
    fi

    echo "📋 Supported Models for Managed Customization"
    echo ""

    # The catalog path travels as an argument (sys.argv[1]) instead of being
    # spliced into the Python source, so a path containing a quote can neither
    # break the program nor inject code into it.
    python3 -c '
import json, sys

with open(sys.argv[1]) as f:
    catalog = json.load(f)

# Group catalog entries by family, keeping catalog order inside each family
families = {}
for entry in catalog.get("models", {}).values():
    families.setdefault(entry.get("family", "unknown"), []).append(entry)

for family in sorted(families):
    entries = families[family]
    provider = entries[0].get("provider", "")
    print("  {} ({}):".format(family, provider))
    for entry in entries:
        print("    • {}".format(entry["displayName"]))
        print("      ID: {}".format(entry["modelId"]))
        for technique, spec in entry.get("techniques", {}).items():
            types = ", ".join(spec.get("trainingTypes", []))
            print("      {}: [{}]".format(technique, types))
    print()
' "${CATALOG_FILE}" 2>/dev/null || {
        echo "❌ Failed to parse catalog. Ensure python3 is available."
        exit 1
    }

    exit 0
}


# ── _update_config_var() ──────────────────────────────────────────────────────
# Write or update an `export NAME="value"` line in do/config.
# Usage: _update_config_var VAR_NAME "value"
#
# The value is escaped twice before it is written:
#   1. for the double-quoted shell string (\ " $ `), so that sourcing the
#      config yields exactly the value that was passed in;
#   2. for the sed replacement text (\ & |), so characters that are special
#      to `s|...|...|` are inserted literally.
# Every existing `export NAME=` line is rewritten in place; when there is none
# the assignment is appended on a line of its own.
# Globals read: SCRIPT_DIR
_update_config_var() {
    local var_name="$1"
    local var_value="$2"
    local config_file="${SCRIPT_DIR}/config"

    # 1. Escape for a double-quoted shell string.
    local escaped
    escaped=${var_value//\\/\\\\}
    escaped=${escaped//\"/\\\"}
    escaped=${escaped//\$/\\\$}
    escaped=${escaped//\`/\\\`}
    local line="export ${var_name}=\"${escaped}\""

    if grep -q "^export ${var_name}=" "${config_file}" 2>/dev/null; then
        # 2. Escape for the sed replacement (backslashes first).
        local sed_repl
        sed_repl=${line//\\/\\\\}
        sed_repl=${sed_repl//&/\\&}
        sed_repl=${sed_repl//|/\\|}
        # -i.bak is accepted by both GNU and BSD sed; the backup is removed.
        sed -i.bak "s|^export ${var_name}=.*|${sed_repl}|" "${config_file}"
        rm -f "${config_file}.bak"
    else
        # If the last line has no trailing newline, terminate it first so the
        # new assignment does not get glued onto it.
        if [ -s "${config_file}" ] && [ -n "$(tail -c 1 "${config_file}")" ]; then
            echo "" >> "${config_file}"
        fi
        printf '%s\n' "${line}" >> "${config_file}"
    fi
}

# ── _validate_model() ─────────────────────────────────────────────────────────
# Resolve the model to customize and check that the catalog lists it.
# Resolution order: --model (ARG_MODEL), then MODEL_ID, then MODEL_NAME.
# Sets:    RESOLVED_MODEL_ID
# Returns: 0 when the model is in the catalog.
# Exits 1 (with a message) when no model is configured, the catalog file is
# missing, the catalog cannot be read, or the model is not listed.
# Globals read: ARG_MODEL, MODEL_ID, MODEL_NAME, CATALOG_FILE
_validate_model() {
    if [ -n "${ARG_MODEL}" ]; then
        RESOLVED_MODEL_ID="${ARG_MODEL}"
    elif [ -n "${MODEL_ID:-}" ]; then
        RESOLVED_MODEL_ID="${MODEL_ID}"
    elif [ -n "${MODEL_NAME:-}" ]; then
        RESOLVED_MODEL_ID="${MODEL_NAME}"
    else
        echo "❌ No model configured"
        echo "   Set MODEL_ID in do/config or use --model <id>"
        exit 1
    fi

    if [ ! -f "${CATALOG_FILE}" ]; then
        echo "❌ Catalog file not found: ${CATALOG_FILE}"
        echo "   The tune catalog is required for model validation."
        exit 1
    fi

    # The catalog path and model ID are passed as arguments (sys.argv), never
    # pasted into the Python source: the model ID comes from the command line,
    # and a quote in it would otherwise break the check or inject code.
    local result
    result=$(python3 -c '
import json, sys

catalog_path, model_id = sys.argv[1], sys.argv[2]
with open(catalog_path) as f:
    catalog = json.load(f)

models = catalog.get("models", {})
if model_id in models:
    print("SUPPORTED")
else:
    # Unique, sorted family names for the error message
    families = sorted(set(e.get("family", "") for e in models.values() if e.get("family")))
    print("UNSUPPORTED|" + "|".join(families))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" 2>/dev/null) || {
        echo "❌ Failed to validate model against catalog"
        echo "   Ensure python3 is available."
        exit 1
    }

    if [ "${result}" = "SUPPORTED" ]; then
        return 0
    fi

    # Result is "UNSUPPORTED|fam1|fam2|...": drop the tag, then join the names
    # with ", " (tr cannot do this — it maps single characters only).
    local families="${result#UNSUPPORTED|}"
    families="${families//|/, }"

    echo "❌ Model \"${RESOLVED_MODEL_ID}\" is not yet supported for managed serverless customization."
    echo "   Supported model families: ${families}"
    echo ""
    echo "   Additional model support and custom training workflows are expected in future releases."
    echo "   For custom training workflows, see \`do/train\`."
    exit 1
}

# ── _validate_technique() ─────────────────────────────────────────────────────
# Check that ARG_TECHNIQUE is a known technique and that the catalog lists it
# for the resolved model.
# Returns: 0 when the model supports the technique.
# Exits 1 (with a message) when the technique is not one of sft/dpo/rlaif/rlvr,
# the catalog lookup fails, or the model does not support the technique.
# Globals read: ARG_TECHNIQUE, RESOLVED_MODEL_ID, CATALOG_FILE
_validate_technique() {
    local technique="${ARG_TECHNIQUE}"

    case "${technique}" in
        sft|dpo|rlaif|rlvr) ;;
        *)
            echo "❌ Invalid technique: ${technique}"
            echo "   Valid techniques: sft, dpo, rlaif, rlvr"
            exit 1
            ;;
    esac

    # Values are passed as arguments (sys.argv) rather than pasted into the
    # Python source, so quotes in them cannot break or alter the program.
    local result
    result=$(python3 -c '
import json, sys

catalog_path, model_id, technique = sys.argv[1:4]
with open(catalog_path) as f:
    catalog = json.load(f)

techniques = catalog["models"][model_id].get("techniques", {})
if technique in techniques:
    print("SUPPORTED")
else:
    print("UNSUPPORTED|" + "|".join(techniques.keys()))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${technique}" 2>/dev/null) || {
        echo "❌ Failed to validate technique against catalog"
        exit 1
    }

    if [ "${result}" = "SUPPORTED" ]; then
        return 0
    fi

    # Result is "UNSUPPORTED|t1|t2|...": drop the tag, then join the names
    # with ", " (tr cannot do this — it maps single characters only).
    local supported="${result#UNSUPPORTED|}"
    supported="${supported//|/, }"

    echo "❌ Technique \"${technique}\" is not supported for model \"${RESOLVED_MODEL_ID}\"."
    echo "   Supported techniques: ${supported}"
    exit 1
}

# ── _validate_training_type() ─────────────────────────────────────────────────
# Check that ARG_TRAINING_TYPE is a known training type and that the catalog
# lists it for the resolved model and technique.
# Returns: 0 when the combination is supported.
# Exits 1 (with a message) when the training type is not lora/full-rank, the
# catalog lookup fails, or the combination is not supported.
# Globals read: ARG_TECHNIQUE, ARG_TRAINING_TYPE, RESOLVED_MODEL_ID, CATALOG_FILE
_validate_training_type() {
    local technique="${ARG_TECHNIQUE}"
    local training_type="${ARG_TRAINING_TYPE}"

    case "${training_type}" in
        lora|full-rank) ;;
        *)
            echo "❌ Invalid training type: ${training_type}"
            echo "   Valid training types: lora, full-rank"
            exit 1
            ;;
    esac

    # Values are passed as arguments (sys.argv) rather than pasted into the
    # Python source, so quotes in them cannot break or alter the program.
    local result
    result=$(python3 -c '
import json, sys

catalog_path, model_id, technique, training_type = sys.argv[1:5]
with open(catalog_path) as f:
    catalog = json.load(f)

training_types = catalog["models"][model_id]["techniques"][technique].get("trainingTypes", [])
if training_type in training_types:
    print("SUPPORTED")
else:
    print("UNSUPPORTED|" + "|".join(training_types))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${technique}" "${training_type}" 2>/dev/null) || {
        echo "❌ Failed to validate training type against catalog"
        exit 1
    }

    if [ "${result}" = "SUPPORTED" ]; then
        return 0
    fi

    # Result is "UNSUPPORTED|t1|t2|...": drop the tag, then join the names
    # with ", " (tr cannot do this — it maps single characters only).
    local supported="${result#UNSUPPORTED|}"
    supported="${supported//|/, }"

    echo "❌ Training type \"${training_type}\" is not supported for model \"${RESOLVED_MODEL_ID}\" with technique \"${technique}\"."
    echo "   Supported training types: ${supported}"
    exit 1
}


# ── _validate_dataset() ───────────────────────────────────────────────────────
# Validate the dataset named by --dataset and resolve it to an S3 URI.
#   s3://...            The object must be listable; its first 10 lines are
#                       checked against the catalog's dataset schema by the
#                       Python helper.
#   hf://org/name[/split]  The reference is parsed and the Python helper stages
#                       the dataset to the output bucket.
# Sets:    RESOLVED_DATASET_S3_URI
# Exits 1 (with a message) when the dataset is missing, inaccessible,
# malformed, cannot be staged, or uses an unknown scheme.
# Globals read: ARG_DATASET, AWS_REGION, HELPER_SCRIPT, PROJECT_NAME,
#               HF_TOKEN_ARN (optional)
_validate_dataset() {
    local dataset="${ARG_DATASET}"

    if [ -z "${dataset}" ]; then
        echo "❌ --dataset is required"
        echo "   Provide an S3 URI (s3://bucket/path.jsonl) or HF reference (hf://org/name)"
        exit 1
    fi

    if [[ "${dataset}" == s3://* ]]; then
        # S3 dataset — verify existence
        if ! aws s3 ls "${dataset}" --region "${AWS_REGION}" >/dev/null 2>&1; then
            echo "❌ Dataset not found or not accessible: ${dataset}"
            echo "   Verify the S3 URI is correct and you have read permissions."
            echo "   Check: aws s3 ls ${dataset} --region ${AWS_REGION}"
            exit 1
        fi
        RESOLVED_DATASET_S3_URI="${dataset}"

        local schema_json
        schema_json=$(_get_dataset_schema) || {
            echo "❌ Failed to read the dataset schema from the catalog"
            exit 1
        }

        # Only the first 10 lines are needed. head closes the pipe once it has
        # them, so on a large object the download ends with a broken-pipe
        # status; under `set -eo pipefail` that status must not abort the
        # script, hence the explicit `|| true`.
        local sample_data
        sample_data=$(aws s3 cp "${dataset}" - --region "${AWS_REGION}" 2>/dev/null | head -10) || true

        if [ -z "${sample_data}" ]; then
            echo "   ⚠️  Could not read a sample from ${dataset} — skipping format validation"
        else
            local validate_result
            validate_result=$(echo "${sample_data}" | python3 "${HELPER_SCRIPT}" validate \
                --schema "${schema_json}" 2>/dev/null) || validate_result='{"valid":false,"error":"Validation failed"}'

            local is_valid
            is_valid=$(echo "${validate_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('valid', False))" 2>/dev/null) || is_valid="False"

            if [ "${is_valid}" != "True" ]; then
                local error_msg
                error_msg=$(echo "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('error','Unknown format error'))" 2>/dev/null) || error_msg="Unknown format error"
                local malformed_line
                malformed_line=$(echo "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('malformed_line','') or '')" 2>/dev/null) || malformed_line=""
                local expected_format
                expected_format=$(echo "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('expected_format','') or '')" 2>/dev/null) || expected_format=""

                echo "❌ Dataset format validation failed"
                echo "   ${error_msg}"
                if [ -n "${malformed_line}" ]; then
                    echo ""
                    echo "   Malformed line: ${malformed_line}"
                fi
                if [ -n "${expected_format}" ]; then
                    echo ""
                    echo "   ${expected_format}"
                fi
                exit 1
            fi
        fi

    elif [[ "${dataset}" == hf://* ]]; then
        # Hugging Face dataset — parse reference and stage to S3
        local hf_path="${dataset#hf://}"
        local hf_org hf_name hf_split

        # Split into org / name / rest. read leaves missing fields empty;
        # cut would instead return the whole string for every field when no
        # "/" is present, letting "hf://name" slip through the check below.
        IFS='/' read -r hf_org hf_name hf_split <<< "${hf_path}"

        if [ -z "${hf_org}" ] || [ -z "${hf_name}" ]; then
            echo "❌ Invalid HF dataset reference: ${dataset}"
            echo "   Expected format: hf://org/name or hf://org/name/split"
            exit 1
        fi

        local output_bucket
        output_bucket=$(_resolve_output_bucket)

        echo "📦 Staging Hugging Face dataset: ${hf_org}/${hf_name}"
        if [ -n "${hf_split}" ]; then
            echo "   Split: ${hf_split}"
        else
            echo "   Split: train (default)"
        fi

        # Build stage-hf arguments
        local stage_args=(
            --hf-org "${hf_org}"
            --hf-name "${hf_name}"
            --output-bucket "${output_bucket}"
            --project-name "${PROJECT_NAME}"
            --region "${AWS_REGION}"
        )
        if [ -n "${hf_split}" ]; then
            stage_args+=(--hf-split "${hf_split}")
        fi
        if [ -n "${HF_TOKEN_ARN:-}" ]; then
            stage_args+=(--hf-secret-name "${HF_TOKEN_ARN}")
        fi

        # On failure stage_result still holds whatever the helper printed, so
        # an error message is extracted from it when one is present.
        local stage_result
        stage_result=$(python3 "${HELPER_SCRIPT}" stage-hf "${stage_args[@]}" 2>/dev/null) || {
            local error_msg
            error_msg=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Failed to stage dataset'))" 2>/dev/null) || error_msg="Failed to stage HF dataset"
            echo "❌ ${error_msg}"
            exit 1
        }

        # Check for error in response
        local has_error
        has_error=$(echo "${stage_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

        if [ "${has_error}" = "yes" ]; then
            local error_msg
            error_msg=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
            echo "❌ ${error_msg}"
            exit 1
        fi

        # A response without an s3_uri is reported instead of letting set -e
        # end the script without any message.
        RESOLVED_DATASET_S3_URI=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin)['s3_uri'])" 2>/dev/null) || RESOLVED_DATASET_S3_URI=""
        if [ -z "${RESOLVED_DATASET_S3_URI}" ]; then
            echo "❌ Dataset staging did not return an S3 URI"
            exit 1
        fi
        local num_records
        num_records=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('num_records',0))" 2>/dev/null) || num_records="0"

        echo "   ✅ Staged to: ${RESOLVED_DATASET_S3_URI}"
        echo "   Records: ${num_records}"
        echo ""

    else
        echo "❌ Invalid dataset format: ${dataset}"
        echo "   Expected: s3://bucket/path.jsonl or hf://org/name[/split]"
        exit 1
    fi
}

# ── _get_dataset_schema() ─────────────────────────────────────────────────────
# Print the catalog's datasetSchema for the resolved model and the selected
# technique as a JSON document on stdout ("{}" when the entry defines none).
# Returns non-zero when the catalog cannot be read or lacks the model or
# technique; Python's error output is suppressed.
# Globals read: CATALOG_FILE, RESOLVED_MODEL_ID, ARG_TECHNIQUE
_get_dataset_schema() {
    # Values are passed as arguments (sys.argv) rather than pasted into the
    # Python source, so quotes in them cannot break or alter the program.
    python3 -c '
import json, sys

catalog_path, model_id, technique = sys.argv[1:4]
with open(catalog_path) as f:
    catalog = json.load(f)

schema = catalog["models"][model_id]["techniques"][technique].get("datasetSchema", {})
print(json.dumps(schema))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${ARG_TECHNIQUE}" 2>/dev/null
}

# ── _resolve_output_bucket() ──────────────────────────────────────────────────
# Print the S3 bucket that receives tune output. The first non-empty value of
# --output-bucket, TUNE_S3_BUCKET and ADAPTER_S3_BUCKET wins; otherwise a name
# is derived as mlcc-tune-<account>-<region>, with "UNKNOWN" standing in for
# the account when the caller identity cannot be looked up.
_resolve_output_bucket() {
    local candidate
    for candidate in "${ARG_OUTPUT_BUCKET}" "${TUNE_S3_BUCKET:-}" "${ADAPTER_S3_BUCKET:-}"; do
        if [ -n "${candidate}" ]; then
            echo "${candidate}"
            return
        fi
    done

    # Nothing configured: derive a name from the caller's account and region.
    local account_id
    account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null || echo 'UNKNOWN')
    echo "mlcc-tune-${account_id}-${AWS_REGION}"
}


# ── _check_idempotency() ──────────────────────────────────────────────────────
# Decide whether a new job should be submitted for the selected technique.
# Looks up TUNE_JOB_NAME_<TECHNIQUE>; when a job is recorded and --force was
# not given, its status is queried through the Python helper.
# Returns: 0  submit a new job (no recorded job, --force, status could not be
#             queried, or an unrecognised status)
#          1  the recorded job was handled here (in-progress jobs are polled
#             to completion; completed jobs go straight to completion handling)
# Exits 2 when the recorded job has failed or was stopped.
# NOTE: the script runs under `set -e`, so callers have to test the result
# (e.g. with `if`) rather than call this function as a bare command.
# Globals read: ARG_TECHNIQUE, ARG_FORCE, HELPER_SCRIPT, AWS_REGION,
#               TUNE_JOB_NAME_*
_check_idempotency() {
    local technique_upper
    technique_upper=$(echo "${ARG_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')
    local var_name="TUNE_JOB_NAME_${technique_upper}"
    local existing_job="${!var_name:-}"

    if [ -z "${existing_job}" ] || [ "${ARG_FORCE}" = true ]; then
        return 0  # No existing job or --force: proceed with new job
    fi

    # Reuse the tr-based upper-casing from above; ${var^^} needs bash 4+.
    echo "🔍 Found existing ${technique_upper} job: ${existing_job}"

    # Query status via Python helper
    local status_json
    status_json=$(python3 "${HELPER_SCRIPT}" status \
        --job-name "${existing_job}" \
        --region "${AWS_REGION}" 2>/dev/null) || {
        echo "   ⚠️  Could not query job status — proceeding with new job"
        return 0
    }

    # An "error" key in the response also means the status is unavailable
    local has_error
    has_error=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="no"

    if [ "${has_error}" = "yes" ]; then
        echo "   ⚠️  Could not query job status — proceeding with new job"
        return 0
    fi

    local status
    status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])" 2>/dev/null) || status="Unknown"

    case "${status}" in
        InProgress|Starting)
            echo "   Status: ${status}"
            echo "   (use --force to start a new job instead)"
            echo ""
            # Resume polling
            _poll_job "${existing_job}"
            _handle_completion "${existing_job}"
            return 1
            ;;
        Completed)
            echo "   Status: Completed"
            echo "   (use --force to start a new job)"
            echo ""
            _handle_completion "${existing_job}"
            return 1
            ;;
        Failed)
            local failure_reason
            failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'Unknown')" 2>/dev/null) || failure_reason="Unknown"
            echo "   Status: Failed"
            echo "   Reason: ${failure_reason}"
            echo ""
            echo "   Re-run with --force to start a new job."
            exit 2
            ;;
        Stopped)
            echo "   Status: Stopped"
            echo "   Re-run with --force to start a new job."
            exit 2
            ;;
        *)
            echo "   Status: ${status}"
            echo "   Proceeding with new job..."
            return 0
            ;;
    esac
}

# ── _submit_job() ─────────────────────────────────────────────────────────────
# Submit the customization job through the Python helper.
# The execution role is taken from --role, then ROLE_ARN, then
# SAGEMAKER_ROLE_ARN. The job name is <project>-tune-<technique>-<timestamp>
# unless the helper's response carries its own job_name.
# Sets:    JOB_NAME
# Records: TUNE_JOB_NAME_<TECHNIQUE>, TUNE_TECHNIQUE, TUNE_TRAINING_TYPE and
#          TUNE_DATASET_PATH in do/config.
# Exits 1 (with a message) when no role is configured, the helper fails, or
# its response contains an error.
# Globals read: ARG_*, ROLE_ARN, SAGEMAKER_ROLE_ARN, PROJECT_NAME,
#               HELPER_SCRIPT, RESOLVED_MODEL_ID, RESOLVED_DATASET_S3_URI
_submit_job() {
    local output_bucket
    output_bucket=$(_resolve_output_bucket)

    local role_arn
    if [ -n "${ARG_ROLE}" ]; then
        role_arn="${ARG_ROLE}"
    elif [ -n "${ROLE_ARN:-}" ]; then
        role_arn="${ROLE_ARN}"
    else
        # Try to resolve from SageMaker execution role
        role_arn="${SAGEMAKER_ROLE_ARN:-}"
        if [ -z "${role_arn}" ]; then
            echo "❌ No execution role configured"
            echo "   Set ROLE_ARN in do/config or use --role <arn>"
            exit 1
        fi
    fi

    # Upper-case once with tr (${var^^} needs bash 4+); used for the banner
    # and for the config variable name below.
    local technique_upper
    technique_upper=$(echo "${ARG_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')

    # Generate unique job name
    local timestamp
    timestamp=$(date +%Y%m%d-%H%M%S)
    JOB_NAME="${PROJECT_NAME}-tune-${ARG_TECHNIQUE}-${timestamp}"

    echo "🚀 Submitting ${technique_upper} customization job"
    echo "   Job name: ${JOB_NAME}"
    echo "   Model: ${RESOLVED_MODEL_ID}"
    echo "   Technique: ${ARG_TECHNIQUE}"
    echo "   Training type: ${ARG_TRAINING_TYPE}"
    echo "   Dataset: ${RESOLVED_DATASET_S3_URI}"
    echo "   Output bucket: ${output_bucket}"
    echo ""

    # Arguments that are always sent
    local submit_args=(
        --model-id "${RESOLVED_MODEL_ID}"
        --technique "${ARG_TECHNIQUE}"
        --training-type "${ARG_TRAINING_TYPE}"
        --dataset-s3-uri "${RESOLVED_DATASET_S3_URI}"
        --output-bucket "${output_bucket}"
        --role-arn "${role_arn}"
        --job-name "${JOB_NAME}"
        --project-name "${PROJECT_NAME}"
        --model-package-group "${PROJECT_NAME}-tune-models"
    )

    # Optional hyperparameters and evaluator settings: a flag is forwarded
    # only when its value was given. The two arrays are index-aligned.
    local optional_flags=(
        --epochs --learning-rate --max-seq-length --lora-rank
        --lora-alpha --batch-size --reward-function --reward-prompt
    )
    local optional_values=(
        "${ARG_EPOCHS}" "${ARG_LEARNING_RATE}" "${ARG_MAX_SEQ_LENGTH}" "${ARG_LORA_RANK}"
        "${ARG_LORA_ALPHA}" "${ARG_BATCH_SIZE}" "${ARG_REWARD_FUNCTION}" "${ARG_REWARD_PROMPT}"
    )
    local i
    for i in "${!optional_flags[@]}"; do
        if [ -n "${optional_values[$i]}" ]; then
            submit_args+=("${optional_flags[$i]}" "${optional_values[$i]}")
        fi
    done

    # Invoke Python helper
    local submit_result
    submit_result=$(python3 "${HELPER_SCRIPT}" submit "${submit_args[@]}" 2>/dev/null) || {
        echo "❌ Failed to submit customization job"
        echo "   Ensure the SageMaker Python SDK is installed: pip install 'sagemaker>=2.232.0'"
        exit 1
    }

    # An unparseable response is treated as an error
    local has_error
    has_error=$(echo "${submit_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

    if [ "${has_error}" = "yes" ]; then
        local error_msg
        error_msg=$(echo "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
        echo "❌ ${error_msg}"
        exit 1
    fi

    # Extract job name from response (may differ from our generated name)
    local returned_job_name
    returned_job_name=$(echo "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('job_name',''))" 2>/dev/null) || returned_job_name=""
    if [ -n "${returned_job_name}" ]; then
        JOB_NAME="${returned_job_name}"
    fi

    # Display MLflow URL if available
    local mlflow_url
    mlflow_url=$(echo "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('mlflow_url','') or '')" 2>/dev/null) || mlflow_url=""
    if [ -n "${mlflow_url}" ]; then
        echo "   📈 MLflow tracking: ${mlflow_url}"
    fi

    echo "✅ Job submitted: ${JOB_NAME}"
    echo ""

    # Store state in do/config
    _update_config_var "TUNE_JOB_NAME_${technique_upper}" "${JOB_NAME}"
    _update_config_var "TUNE_TECHNIQUE" "${ARG_TECHNIQUE}"
    _update_config_var "TUNE_TRAINING_TYPE" "${ARG_TRAINING_TYPE}"
    _update_config_var "TUNE_DATASET_PATH" "${ARG_DATASET}"
}


# ── _poll_job() ───────────────────────────────────────────────────────────────
# Poll a job through the Python helper every POLL_INTERVAL seconds, printing
# one status line per poll.
# Arguments: $1 - job name (defaults to JOB_NAME)
# Returns:   0 once the job reports Completed.
# Exits 2 when the job reports Failed or Stopped. Ctrl+C is routed to
# _handle_interrupt for the duration of the polling. A failed status query is
# reported and retried after the usual interval.
# Globals read: JOB_NAME, POLL_INTERVAL, HELPER_SCRIPT, AWS_REGION
_poll_job() {
    local job_name="${1:-${JOB_NAME}}"

    echo "⏳ Polling job status every ${POLL_INTERVAL}s..."
    echo "   (Ctrl+C to exit — job continues in background)"
    echo ""

    # Single quotes defer expansion: job_name is looked up when the trap fires
    trap '_handle_interrupt "${job_name}"' INT

    local status_json status elapsed mins secs failure_reason
    while true; do
        status_json=$(python3 "${HELPER_SCRIPT}" status \
            --job-name "${job_name}" \
            --region "${AWS_REGION}" 2>/dev/null) || {
            echo "   ⚠️  Failed to query status (will retry)"
            sleep "${POLL_INTERVAL}"
            continue
        }

        status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || status="Unknown"
        elapsed=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('elapsed_seconds',0))" 2>/dev/null) || elapsed="0"

        # Shell arithmetic is integer-only: drop any fractional part and treat
        # anything that is not a plain non-negative number (e.g. "None") as 0
        # so a surprising helper value cannot abort the polling loop.
        elapsed="${elapsed%%.*}"
        case "${elapsed}" in
            ''|*[!0-9]*) elapsed=0 ;;
        esac
        elapsed=$((10#${elapsed}))
        mins=$((elapsed / 60))
        secs=$((elapsed % 60))

        case "${status}" in
            Completed)
                echo "   ✅ $(date +%H:%M:%S) Status: Completed (${mins}m ${secs}s)"
                # Restore default signal handling
                trap - INT
                return 0
                ;;
            Failed)
                failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'Unknown')" 2>/dev/null) || failure_reason="Unknown"
                echo "   ❌ $(date +%H:%M:%S) Status: Failed (${mins}m ${secs}s)"
                echo "      Reason: ${failure_reason}"
                trap - INT
                exit 2
                ;;
            Stopped)
                echo "   ⚠️  $(date +%H:%M:%S) Status: Stopped (${mins}m ${secs}s)"
                trap - INT
                exit 2
                ;;
            *)
                # Any other status: report it and poll again
                echo "   $(date +%H:%M:%S) Status: ${status} (${mins}m ${secs}s)"
                sleep "${POLL_INTERVAL}"
                ;;
        esac
    done
}

# ── _handle_interrupt() ───────────────────────────────────────────────────────
# SIGINT handler installed by _poll_job. Prints a notice that the job keeps
# running plus the commands for checking on it, then exits with status 130.
# Nothing here asks the remote job to stop.
#
# Arguments: $1 - job name to display (optional; shown empty when omitted)
# Globals:   ARG_TECHNIQUE, ARG_DATASET (read, for the resume command)
_handle_interrupt() {
    local interrupted_job="${1:-}"

    # Two leading blank lines separate the notice from the polling output
    printf '%s\n' \
        "" \
        "" \
        "⚠️  Interrupted — job continues running in background" \
        "   Job: ${interrupted_job}" \
        "" \
        "   Resume monitoring: ./do/tune --technique ${ARG_TECHNIQUE} --dataset ${ARG_DATASET}" \
        "   Check status:      ./do/tune --status"

    exit 130
}

# ── _handle_completion() ──────────────────────────────────────────────────────
# Resolve the finished job's output artifact through the Python helper, record
# the result via _update_config_var, and print next-step deployment commands.
#
# Arguments: $1 - job name (optional; defaults to the global JOB_NAME)
# Globals:   ARG_TECHNIQUE, ARG_TRAINING_TYPE, AWS_REGION, PROJECT_NAME,
#            HELPER_SCRIPT (read)
# Returns:   always 0 — a resolution problem is reported as a warning and does
#            not fail the run.
_handle_completion() {
    local job_name="${1:-${JOB_NAME}}"
    local technique_upper
    technique_upper=$(echo "${ARG_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')

    # Resolve artifact path via Python helper
    local resolve_args=(
        --job-name "${job_name}"
        --region "${AWS_REGION}"
        --training-type "${ARG_TRAINING_TYPE}"
    )
    resolve_args+=(--model-package-group "${PROJECT_NAME}-tune-models")

    local resolve_result
    resolve_result=$(python3 "${HELPER_SCRIPT}" resolve "${resolve_args[@]}" 2>/dev/null) || {
        echo "⚠️  Could not resolve output artifacts"
        echo "   Check job output manually:"
        echo "   python3 ${HELPER_SCRIPT} resolve --job-name ${job_name} --region ${AWS_REGION} --training-type ${ARG_TRAINING_TYPE}"
        return 0
    }

    # Check for error. A response that cannot be parsed is treated as an
    # error too (the fallback sets has_error to "yes").
    local has_error
    has_error=$(echo "${resolve_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

    if [ "${has_error}" = "yes" ]; then
        local error_msg
        error_msg=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error',''))" 2>/dev/null) || error_msg=""
        if [ -n "${error_msg}" ]; then
            echo "   ⚠️  ${error_msg}"
        else
            # No usable message (unparseable response or empty error field):
            # report it instead of returning silently, so the user knows why
            # no artifact details or next steps were printed.
            echo "⚠️  Could not resolve output artifacts (unreadable helper response)"
            echo "   Check job output manually:"
            echo "   python3 ${HELPER_SCRIPT} resolve --job-name ${job_name} --region ${AWS_REGION} --training-type ${ARG_TRAINING_TYPE}"
        fi
        return 0
    fi

    # Extract results
    local artifact_path
    artifact_path=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('artifact_path',''))" 2>/dev/null) || artifact_path=""

    local output_type
    output_type=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('output_type',''))" 2>/dev/null) || output_type=""

    local model_package_arn
    model_package_arn=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('model_package_arn','') or '')" 2>/dev/null) || model_package_arn=""

    # Display results
    echo "🎉 Customization complete!"
    echo ""
    echo "   Output type: ${output_type}"
    echo "   Artifact: ${artifact_path}"
    if [ -n "${model_package_arn}" ]; then
        echo "   Model Package: ${model_package_arn}"
    fi
    echo ""

    # Store output paths in config. The per-technique key depends on the
    # output type: "adapter" goes to TUNE_ADAPTER_PATH_*, anything else to
    # TUNE_MODEL_PATH_*.
    if [ "${output_type}" = "adapter" ]; then
        _update_config_var "TUNE_ADAPTER_PATH_${technique_upper}" "${artifact_path}"
    else
        _update_config_var "TUNE_MODEL_PATH_${technique_upper}" "${artifact_path}"
    fi
    _update_config_var "TUNE_OUTPUT_PATH_LATEST" "${artifact_path}"
    _update_config_var "TUNE_OUTPUT_TYPE_LATEST" "${output_type}"

    # Print next-step commands
    echo "📋 Next steps:"
    echo ""
    if [ "${output_type}" = "adapter" ]; then
        echo "   Deploy as LoRA adapter:"
        echo "     ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune"
        echo "     ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune ${ARG_TECHNIQUE}"
        echo "     ./do/adapter add tuned-${ARG_TECHNIQUE} --weights ${artifact_path}"
    else
        echo "   Deploy as new inference component:"
        echo "     ./do/add-ic tuned-v1 --from-tune"
        echo "     ./do/add-ic tuned-v1 --model-data ${artifact_path}"
        echo "     ./do/deploy --force-ic --model-data ${artifact_path}"
    fi
    echo ""
}


# ── _dry_run() ────────────────────────────────────────────────────────────────
# Print a summary of the job that would be submitted, then exit 0 without
# creating anything. Optional hyperparameters appear only when they were set.
#
# Globals: PROJECT_NAME, RESOLVED_MODEL_ID, RESOLVED_DATASET_S3_URI and the
#          ARG_* CLI variables (read); ROLE_ARN / SAGEMAKER_ROLE_ARN (optional)
_dry_run() {
    local bucket
    bucket=$(_resolve_output_bucket)

    # Role precedence: --role flag, then ROLE_ARN, then SAGEMAKER_ROLE_ARN,
    # otherwise a placeholder
    local role="${ARG_ROLE:-${ROLE_ARN:-${SAGEMAKER_ROLE_ARN:-<not configured>}}}"

    local stamp
    stamp=$(date +%Y%m%d-%H%M%S)
    local planned_job="${PROJECT_NAME}-tune-${ARG_TECHNIQUE}-${stamp}"

    printf '%s\n' "🔍 Dry run — validation passed, would submit:" ""
    printf '   %s%s\n' "Job name:      " "${planned_job}"
    printf '   %s%s\n' "Model:         " "${RESOLVED_MODEL_ID}"
    printf '   %s%s\n' "Technique:     " "${ARG_TECHNIQUE}"
    printf '   %s%s\n' "Training type: " "${ARG_TRAINING_TYPE}"
    printf '   %s%s\n' "Dataset:       " "${RESOLVED_DATASET_S3_URI}"
    printf '   %s%s\n' "Output bucket: " "${bucket}"
    printf '   %s%s\n' "Role:          " "${role}"
    printf '   %s%s\n' "Package group: " "${PROJECT_NAME}-tune-models"

    # Optional settings: parallel label/value tables, printed in this order
    # and skipped when the value is empty
    local -a opt_labels=(
        "Epochs:        "
        "Learning rate: "
        "Max seq len:   "
        "LoRA rank:     "
        "LoRA alpha:    "
        "Batch size:    "
        "Reward fn:     "
        "Reward prompt: "
    )
    local -a opt_values=(
        "${ARG_EPOCHS}"
        "${ARG_LEARNING_RATE}"
        "${ARG_MAX_SEQ_LENGTH}"
        "${ARG_LORA_RANK}"
        "${ARG_LORA_ALPHA}"
        "${ARG_BATCH_SIZE}"
        "${ARG_REWARD_FUNCTION}"
        "${ARG_REWARD_PROMPT}"
    )
    local idx
    for idx in "${!opt_labels[@]}"; do
        if [ -z "${opt_values[idx]}" ]; then
            continue
        fi
        printf '   %s%s\n' "${opt_labels[idx]}" "${opt_values[idx]}"
    done

    printf '%s\n' "" "   ✅ All validations passed. Remove --dry-run to submit."
    exit 0
}

# ══════════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════════

_parse_args "$@"

# Informational flags are dispatched first, in this order, before any of the
# job-submission checks below.
if [[ "${ARG_HELP}" == true ]]; then
    _show_help
fi

if [[ "${ARG_LIST_MODELS}" == true ]]; then
    _list_models
fi

if [[ "${ARG_STATUS}" == true ]]; then
    _show_status
fi

# Job submission needs both --technique and --dataset
if [[ -z "${ARG_TECHNIQUE}" ]]; then
    printf '%s\n' \
        "❌ --technique is required" \
        "   Usage: ./do/tune --technique <sft|dpo|rlaif|rlvr> --dataset <source>" \
        "   Run ./do/tune --help for full usage."
    exit 1
fi

if [[ -z "${ARG_DATASET}" ]]; then
    printf '%s\n' \
        "❌ --dataset is required" \
        "   Usage: ./do/tune --technique ${ARG_TECHNIQUE} --dataset <s3://... or hf://...>" \
        "   Run ./do/tune --help for full usage."
    exit 1
fi

# TUNE_SUPPORTED=false in the config only produces a warning; the validations
# below still run.
if [[ "${TUNE_SUPPORTED:-}" == "false" ]]; then
    printf '%s\n' \
        "⚠️  Managed customization is not supported for the configured model." \
        "   Checking catalog for current support..." \
        ""
fi

# python3 must be on PATH from here on
command -v python3 >/dev/null 2>&1 || {
    printf '%s\n' \
        "❌ python3 is required but not found" \
        "   Install Python 3 to use managed model customization."
    exit 1
}

printf '%s\n' "🔧 SageMaker AI Managed Model Customization" ""

# Run the validators in a fixed order: model, technique, training type, dataset
for validation_step in \
    _validate_model \
    _validate_technique \
    _validate_training_type \
    _validate_dataset; do
    "${validation_step}"
done

# Idempotency check (may exit early if an existing job is handled). Submission
# proceeds only when it returns success.
if _check_idempotency; then
    # No existing job or --force: proceed with submission

    if [[ "${ARG_DRY_RUN}" == true ]]; then
        _dry_run
    fi

    _submit_job

    if [[ "${ARG_NO_WAIT}" == true ]]; then
        printf '%s\n' \
            "   --no-wait specified. Job running in background." \
            "   Check status: ./do/tune --status"
        exit 0
    fi

    _poll_job
    _handle_completion
fi
