#!/usr/bin/env bash
# Auto-tune Elo settings: rewrite KEY=VALUE lines in an env file, recompute
# ratings in the `worker` compose service, evaluate log_loss, and keep the
# value with the lowest log_loss for each parameter in turn.
#
# Environment (all optional):
#   ENV_FILE      env file edited in place (default /root/tipsharks/.env)
#   FROM_DATE     recompute window start passed to --from (default 2000-01-01)
#   TO_DATE       recompute window end passed to --to (default 2100-01-01)
#   RESULTS_DIR   where per-candidate JSON and results.tsv are written
#                 (default /tmp/elo_auto_<timestamp>)
#   SCALE_VALUES, RD_MODES, PAIRWISE_VALUES    candidate lists (space separated)
#   K_START, K_MAX, K_FACTOR, K_REFINE_COUNT   ELO_K_BASE sweep controls
#   ADJ_SWEEP=1, ADJ_LR_VALUES, ADJ_SCALE_VALUES  optional adjustment grid
set -euo pipefail
# Bash turns `set -e` off inside $(...) unless inherit_errexit (bash >= 4.4)
# is set; without it a failing step inside `loss="$(run_candidate ...)"` is
# ignored and a bogus loss is ranked. Older bash lacks the option.
shopt -s inherit_errexit 2>/dev/null || true

env_file="${ENV_FILE:-/root/tipsharks/.env}"
from_date="${FROM_DATE:-2000-01-01}"
to_date="${TO_DATE:-2100-01-01}"
results_dir="${RESULTS_DIR:-/tmp/elo_auto_$(date +%Y%m%d_%H%M%S)}"
mkdir -p -- "$results_dir"

# Header row; every evaluated candidate later appends one row to this file.
results_tsv="${results_dir}/results.tsv"
printf 'step\tparam\tvalue\tlog_loss\n' >"$results_tsv"

#######################################
# Print a timestamped progress message.
# Writes to stderr so functions whose stdout is captured with $(...)
# (run_candidate's stdout is parsed as a log_loss) can log without
# corrupting their result.
# Arguments: message words
#######################################
log() {
  printf '[%s] %s\n' "$(date +%H:%M:%S)" "$*" >&2
}

#######################################
# Set KEY=VALUE in $env_file.
# Every line starting with "KEY=" is replaced. If no such line exists, the pair
# is appended, so a missing key cannot make a sweep silently evaluate the same
# configuration over and over.
# KEY and VALUE are handled literally (no regex or Perl interpolation).
# Globals:   env_file (read; the file it names is modified)
# Arguments: $1 - key, $2 - value
# Returns:   non-zero if the key is empty, the file is missing, or a write fails
#######################################
set_env() {
  local key="$1"
  local val="$2"
  local tmp

  if [[ -z "$key" ]]; then
    printf 'set_env: empty key\n' >&2
    return 1
  fi
  if [[ ! -f "$env_file" ]]; then
    printf 'set_env: env file not found: %s\n' "$env_file" >&2
    return 1
  fi

  tmp="$(mktemp)" || return 1
  # Pass key/value through ENVIRON: `awk -v` would expand backslash escapes.
  if ! KEY="$key" VAL="$val" awk '
    BEGIN { prefix = ENVIRON["KEY"] "="; line = prefix ENVIRON["VAL"] }
    index($0, prefix) == 1 { print line; found = 1; next }
    { print }
    END { if (!found) print line }
  ' "$env_file" >"$tmp"; then
    rm -f -- "$tmp"
    return 1
  fi
  # Copy back rather than mv so the env file keeps its inode and permissions.
  if ! cat -- "$tmp" >"$env_file"; then
    rm -f -- "$tmp"
    return 1
  fi
  rm -f -- "$tmp"
}

#######################################
# Recompute ratings over [from_date, to_date] in a throwaway worker container.
# Globals:   from_date, to_date, results_dir (read)
# Outputs:   the container's stdout is APPENDED to ${results_dir}/recompute.log,
#            so every candidate's run is kept alongside the other results.
#            The old fixed /tmp path was predictable and was overwritten on each
#            call. stderr is not redirected.
# Returns:   the exit status of `docker compose run`
#######################################
recompute() {
  local log_path="${results_dir}/recompute.log"
  docker compose run --rm worker \
    python -m apps.backend.worker.cli recompute --from "$from_date" --to "$to_date" \
    >>"$log_path"
}

#######################################
# Run scripts/evaluate_accuracy.py --json in a throwaway worker container and
# stream its stdout (presumably the JSON report consumed by extract_log_loss).
# EVAL_DATABASE_URL overrides the DATABASE_URL given to the container. The
# default is the previously hard-coded URL. A dedicated variable name is used
# so an unrelated host-side DATABASE_URL is not picked up by accident.
# NOTE(review): the URL, credentials included, is visible in the process list.
#######################################
evaluate_json() {
  local db_url="${EVAL_DATABASE_URL:-postgresql+psycopg://tipsharks:tipsharks@db:5432/tipsharks}"
  docker compose run --rm \
    -e DATABASE_URL="$db_url" \
    -e PYTHONPATH=/app \
    worker python scripts/evaluate_accuracy.py --json
}

#######################################
# Print .all.log_loss from the evaluation JSON read on stdin.
#
# The program is passed with -c so stdin stays the piped JSON. With the former
# `python - <<'PY'` form the here-doc WAS Python's stdin, so json.load(sys.stdin)
# saw an empty stream and always failed.
# Exits non-zero (message on stderr) if the input is not JSON, or if the value
# is missing, non-numeric, or not finite. A NaN/inf would otherwise be
# mis-ranked by the sweeps' awk comparisons.
# Interpreter: $PYTHON if set, else python3 (a bare `python` is often absent).
#######################################
extract_log_loss() {
  "${PYTHON:-python3}" -c '
import json
import math
import sys

try:
    data = json.load(sys.stdin)
except json.JSONDecodeError as exc:
    raise SystemExit(f"invalid evaluation JSON: {exc}")

section = data.get("all") if isinstance(data, dict) else None
ll = section.get("log_loss") if isinstance(section, dict) else None
if ll is None:
    raise SystemExit("missing log_loss")
try:
    value = float(ll)
except (TypeError, ValueError):
    raise SystemExit(f"invalid log_loss: {ll!r}")
if isinstance(ll, bool) or not math.isfinite(value):
    raise SystemExit(f"invalid log_loss: {ll!r}")
print(ll)
'
}

#######################################
# Evaluate one candidate setting and print its log_loss.
#
# Applies PARAM=VALUE to the env file, recomputes, evaluates, saves the raw
# evaluation output as ${results_dir}/TAG.json and appends a row to
# results.tsv.
#
# Callers capture stdout with $(...), so stdout carries ONLY the log_loss.
# The progress line and the echoed TSV row go to stderr. Bash disables
# `set -e` inside command substitutions by default, so each step is checked
# explicitly.
#
# Globals:   results_dir, results_tsv (read; files written)
# Arguments: $1 - step label, $2 - env key, $3 - value, $4 - output file tag
# Outputs:   log_loss on stdout
# Returns:   0 on success, 1 if any step fails
#######################################
run_candidate() {
  local step="$1"
  local param="$2"
  local value="$3"
  local tag="$4"
  local json_path="${results_dir}/${tag}.json"
  local log_loss

  log "${step}: ${param}=${value}" >&2
  set_env "$param" "$value" || return 1
  recompute || return 1
  evaluate_json >"$json_path" || return 1
  log_loss="$(extract_log_loss <"$json_path")" || return 1
  printf '%s\t%s\t%s\t%s\n' "$step" "$param" "$value" "$log_loss" |
    tee -a "$results_tsv" >&2 || return 1
  printf '%s\n' "$log_loss"
}

#######################################
# Recompute and evaluate the env file's current settings, unchanged.
# Saves the evaluation output as ${results_dir}/baseline.json and appends a
# "baseline - -" row to results.tsv, which is also echoed to stdout.
# Each step is checked explicitly so a failure propagates even when the
# caller has `set -e` suspended (e.g. `if baseline; ...`).
# Globals:   results_dir, results_tsv (read; files written)
# Returns:   0 on success, 1 if any step fails
#######################################
baseline() {
  local json_path="${results_dir}/baseline.json"
  local log_loss

  log "baseline: recompute and evaluate current settings"
  recompute || return 1
  evaluate_json >"$json_path" || return 1
  log_loss="$(extract_log_loss <"$json_path")" || return 1
  printf 'baseline\t-\t-\t%s\n' "$log_loss" | tee -a "$results_tsv"
}

#######################################
# Try each candidate value for one env key and keep the best.
#
# Each value is evaluated with run_candidate. The lowest log_loss wins (ties
# keep the earlier value) and is written back with set_env. With no candidate
# values, the key is left untouched instead of being blanked.
#
# Arguments: $1 - step label, $2 - env key, $3... - candidate values
# Returns:   1 if a candidate fails or yields a non-numeric log_loss
#######################################
sweep_param() {
  local step="$1"
  local param="$2"
  shift 2
  local -r number_re='^[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)([eE][-+]?[0-9]+)?$'
  local best_val=""
  local best_loss=""
  local tag_index=0
  local val loss

  if (($# == 0)); then
    log "${step}: no candidate values for ${param}; leaving it unchanged" >&2
    return 0
  fi

  for val in "$@"; do
    loss="$(run_candidate "$step" "$param" "$val" "${step}_${tag_index}")" || return 1
    tag_index=$((tag_index + 1))
    # Anything but a plain number means the evaluation went wrong; ranking it
    # would silently pick a bogus winner.
    if [[ ! "$loss" =~ $number_re ]]; then
      log "${step}: non-numeric log_loss for ${param}=${val}: '${loss}'" >&2
      return 1
    fi
    # Losses are passed as awk variables, not spliced into the program text.
    if [[ -z "$best_loss" ]] ||
      awk -v a="$loss" -v b="$best_loss" 'BEGIN { exit !(a + 0 < b + 0) }'; then
      best_loss="$loss"
      best_val="$val"
    fi
  done

  log "${step}: best ${param}=${best_val} (log_loss=${best_loss})"
  set_env "$param" "$best_val"
}

#######################################
# Print COUNT evenly spaced values strictly between LOW and HIGH on one
# space-separated line, each with four decimals (an empty line if COUNT is 0).
# Arguments: $1 - LOW, $2 - HIGH, $3 - COUNT (non-negative integer)
#######################################
_k_refine_values() {
  awk -v lo="$1" -v hi="$2" -v n="$3" 'BEGIN {
    step = (hi - lo) / (n + 1)
    for (i = 1; i <= n; i++) {
      printf "%s%.4f", (i > 1 ? " " : ""), lo + step * i
    }
    printf "\n"
  }'
}

#######################################
# Sweep ELO_K_BASE geometrically, then refine linearly above the best value.
#
# Starting at K_START, K is multiplied by K_FACTOR while K <= K_MAX. The first
# K whose log_loss does not beat the best so far ends the growth phase. Then
# K_REFINE_COUNT evenly spaced values between the best K and that K are
# evaluated (also saved to ${results_dir}/k_refine_values.txt). The winner is
# written back with set_env. If no K is evaluated (K_START > K_MAX),
# ELO_K_BASE is left untouched.
#
# Globals:   K_START, K_MAX, K_FACTOR, K_REFINE_COUNT (read, optional);
#            results_dir (read)
# Returns:   1 on invalid K_FACTOR/K_REFINE_COUNT, a failed candidate, or a
#            non-numeric log_loss
#######################################
sweep_k_exponential() {
  local step="k_exponential"
  local k_start="${K_START:-100}"
  local k_max="${K_MAX:-1600}"
  local k_factor="${K_FACTOR:-2}"
  local refine_count="${K_REFINE_COUNT:-5}"
  local -r number_re='^[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)([eE][-+]?[0-9]+)?$'

  # A factor <= 1 never grows K toward K_MAX, so the growth phase could only
  # stop on a loss plateau.
  if ! awk -v f="$k_factor" 'BEGIN { exit !(f + 0 > 1) }'; then
    log "${step}: K_FACTOR must be > 1 (got '${k_factor}')" >&2
    return 1
  fi
  if [[ ! "$refine_count" =~ ^[0-9]+$ ]]; then
    log "${step}: K_REFINE_COUNT must be a non-negative integer (got '${refine_count}')" >&2
    return 1
  fi

  local k="$k_start"
  local best_k=""
  local best_loss=""
  local k_drop=""
  local tag_index=0
  local loss

  while awk -v k="$k" -v m="$k_max" 'BEGIN { exit !(k + 0 <= m + 0) }'; do
    loss="$(run_candidate "$step" "ELO_K_BASE" "$k" "${step}_${tag_index}")" || return 1
    tag_index=$((tag_index + 1))
    if [[ ! "$loss" =~ $number_re ]]; then
      log "${step}: non-numeric log_loss for ELO_K_BASE=${k}: '${loss}'" >&2
      return 1
    fi
    if [[ -z "$best_loss" ]] ||
      awk -v a="$loss" -v b="$best_loss" 'BEGIN { exit !(a + 0 < b + 0) }'; then
      best_loss="$loss"
      best_k="$k"
    else
      # First non-improving K ends growth and is the refinement's upper bound.
      k_drop="$k"
      break
    fi
    k="$(awk -v k="$k" -v f="$k_factor" 'BEGIN { printf "%.4f", k * f }')"
  done

  if [[ -z "$best_k" ]]; then
    log "${step}: no K evaluated (K_START=${k_start}, K_MAX=${k_max}); leaving ELO_K_BASE unchanged" >&2
    return 0
  fi

  if [[ -n "$k_drop" ]]; then
    log "${step}: drop at ${k_drop}, refine between ${best_k} and ${k_drop}"
    local refine_file="${results_dir}/k_refine_values.txt"
    local -a refine_vals=()
    local val i
    _k_refine_values "$best_k" "$k_drop" "$refine_count" >"$refine_file"
    read -r -a refine_vals <"$refine_file"
    for ((i = 0; i < ${#refine_vals[@]}; i++)); do
      val="${refine_vals[i]}"
      loss="$(run_candidate "k_refine" "ELO_K_BASE" "$val" "k_refine_${tag_index}")" || return 1
      tag_index=$((tag_index + 1))
      if [[ ! "$loss" =~ $number_re ]]; then
        log "${step}: non-numeric log_loss for ELO_K_BASE=${val}: '${loss}'" >&2
        return 1
      fi
      if awk -v a="$loss" -v b="$best_loss" 'BEGIN { exit !(a + 0 < b + 0) }'; then
        best_loss="$loss"
        best_k="$val"
      fi
    done
  fi

  log "k_exponential: best ELO_K_BASE=${best_k} (log_loss=${best_loss})"
  set_env "ELO_K_BASE" "$best_k"
}

#######################################
# Grid-search ADJ_LEARNING_RATE x ADJ_UPDATE_SCALE and keep the best pair.
# Each grid point produces two results.tsv rows: run_candidate's
# ADJ_LEARNING_RATE row, plus an ADJ_UPDATE_SCALE row with the same loss.
# With an empty list, both keys are left untouched.
# Globals:   ADJ_LR_VALUES, ADJ_SCALE_VALUES (read, optional);
#            results_tsv (appended)
# Returns:   1 if a grid point fails or yields a non-numeric log_loss
#######################################
_sweep_adjustments() {
  local -a lrs=() scales=()
  # `read -d ''` reads to EOF and therefore returns 1 by design.
  read -r -d '' -a lrs <<<"${ADJ_LR_VALUES:-0.2 0.5 0.8}" || true
  read -r -d '' -a scales <<<"${ADJ_SCALE_VALUES:-0.5 1.0 1.5}" || true
  local -r number_re='^[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)([eE][-+]?[0-9]+)?$'
  local tag_index=0
  local best_loss="" best_lr="" best_scale=""
  local lr sc loss

  log "adjustments: grid search"
  # ${arr[@]+"${arr[@]}"} expands an empty array safely under `set -u` (bash < 4.4).
  for lr in ${lrs[@]+"${lrs[@]}"}; do
    for sc in ${scales[@]+"${scales[@]}"}; do
      # run_candidate only applies ADJ_LEARNING_RATE, so set the scale here too.
      set_env "ADJ_LEARNING_RATE" "$lr"
      set_env "ADJ_UPDATE_SCALE" "$sc"
      loss="$(run_candidate "adj_grid" "ADJ_LEARNING_RATE" "$lr" "adj_${tag_index}")" || return 1
      printf 'adj_grid\tADJ_UPDATE_SCALE\t%s\t%s\n' "$sc" "$loss" | tee -a "$results_tsv"
      tag_index=$((tag_index + 1))
      if [[ ! "$loss" =~ $number_re ]]; then
        log "adjustments: non-numeric log_loss for lr=${lr} scale=${sc}: '${loss}'" >&2
        return 1
      fi
      if [[ -z "$best_loss" ]] ||
        awk -v a="$loss" -v b="$best_loss" 'BEGIN { exit !(a + 0 < b + 0) }'; then
        best_loss="$loss"
        best_lr="$lr"
        best_scale="$sc"
      fi
    done
  done

  if [[ -z "$best_lr" ]]; then
    log "adjustments: no grid points evaluated; leaving ADJ_* settings unchanged" >&2
    return 0
  fi
  log "adjustments: best ADJ_LEARNING_RATE=${best_lr} ADJ_UPDATE_SCALE=${best_scale} (log_loss=${best_loss})"
  set_env "ADJ_LEARNING_RATE" "$best_lr"
  set_env "ADJ_UPDATE_SCALE" "$best_scale"
}

#######################################
# Run the full tuning pipeline: baseline, then the scale, K, RD-mode and
# pairwise sweeps, and (with ADJ_SWEEP=1) the adjustment grid.
# Globals:   SCALE_VALUES, RD_MODES, PAIRWISE_VALUES, ADJ_SWEEP (read, optional)
#######################################
main() {
  local -a scale_values=() rd_modes=() pairwise_values=()
  # Split the space-separated candidate lists into arrays without glob
  # expansion. `read -d ''` reads to EOF and therefore returns 1 by design.
  read -r -d '' -a scale_values <<<"${SCALE_VALUES:-200 250 300 350 400 500}" || true
  read -r -d '' -a rd_modes <<<"${RD_MODES:-linear sqrt none}" || true
  read -r -d '' -a pairwise_values <<<"${PAIRWISE_VALUES:-n_minus_1 n comparisons}" || true

  baseline

  # ${arr[@]+"${arr[@]}"} expands an empty array safely under `set -u` (bash < 4.4).
  sweep_param "scale" "ELO_SCALE_C" ${scale_values[@]+"${scale_values[@]}"}
  sweep_k_exponential
  sweep_param "rd_mode" "RD_SCALING_MODE" ${rd_modes[@]+"${rd_modes[@]}"}
  sweep_param "pairwise" "PAIRWISE_NORMALIZER" ${pairwise_values[@]+"${pairwise_values[@]}"}

  if [[ "${ADJ_SWEEP:-0}" == "1" ]]; then
    _sweep_adjustments
  fi

  log "done. results in ${results_tsv}"
}

# Entry point: forward the script's arguments (currently unused by main).
main "$@"
