#!/usr/bin/env python3
"""
separate.py — AI stem separation for LinkinPark project

Uses Demucs htdemucs_6s to split a song into 6 stems:
  vocals, drums, bass, guitar, piano, other

Optionally uses MDX-Net (via audio-separator) for higher-quality vocal isolation.

Usage:
    python3 separate.py song.wav "Song Name"
    python3 separate.py song.wav "Song Name" --album "From Zero" --bpm 128
    python3 separate.py song.wav "Song Name" --model htdemucs_ft   # 4-stem variant
    python3 separate.py song.wav "Song Name" --mdx-vocals          # MDX-Net vocal boost
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# The "LinkinPark Project" folder sits beside the directory that holds this
# script; separated stems are written below its Samples/ subfolder.
PROJECT_DIR = Path(__file__).resolve().parent.parent / "LinkinPark Project"
SAMPLES_DIR = PROJECT_DIR.joinpath("Samples")

# Supported Demucs models, each mapped to the stem names it produces.
# Insertion order is also the order the CLI lists them in.
MODELS = dict(
    htdemucs_6s=["vocals", "drums", "bass", "guitar", "piano", "other"],
    htdemucs_ft=["vocals", "drums", "bass", "other"],
    htdemucs=["vocals", "drums", "bass", "other"],
)

DEFAULT_MODEL = "htdemucs_6s"

# Model file handed to audio-separator when --mdx-vocals is requested
MDX_VOCAL_MODEL = "UVR-MDX-NET-Voc_FT.onnx"

# Stem name -> Ableton track ID, used when printing the SONG_CONFIG block
# for add_stems.py (the target track's name is noted beside each ID).
STEM_TO_TRACK = dict(
    drums=105,   # Other Kit (full drum mix)
    bass=106,    # Bass
    guitar=107,  # Rhythm Guitar (full guitar mix)
    vocals=109,  # Vocals
    piano=112,   # Piano
    other=115,   # Other
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def run(cmd, desc=""):
    """Run an external command and report whether it succeeded.

    The command is an argument list executed without a shell. Its stdout and
    stderr are captured (not streamed), so nothing is shown while it runs.
    On failure the command line is printed together with the last 500
    characters of stderr, or the OS error when the program could not be
    started at all (e.g. it is not installed / not on PATH).

    Args:
        cmd: Argument list; elements are str()-ed only for the error message.
        desc: Optional progress label printed before the command runs.

    Returns:
        True if the command exited with status 0, False otherwise.
    """
    if desc:
        print(f"  → {desc}")
    try:
        # errors="replace": undecodable bytes in tool output (progress bars,
        # odd locales) must not turn into a UnicodeDecodeError here.
        result = subprocess.run(cmd, capture_output=True, text=True,
                                errors="replace")
    except OSError as exc:
        # Missing/non-executable program: report it like any other failure
        # so callers treating a step as optional can fall back gracefully.
        detail = str(exc)
    else:
        if result.returncode == 0:
            return True
        detail = result.stderr[-500:] if result.stderr else ""
    print(f"  ERROR: {' '.join(str(c) for c in cmd)}")
    if detail:
        print(detail)
    return False


def to_flac(src: Path, dst: Path):
    """Convert *src* to a FLAC file at *dst* with ffmpeg.

    An existing *dst* is overwritten (``-y``); FLAC compression level 8 is
    used.

    Returns:
        True if ffmpeg succeeded, False otherwise (as reported by run()).
    """
    # run() already prefixes the label with "→"; naming the source here
    # avoids the doubled "→ → x.flac" and shows what is being converted.
    return run(
        ["ffmpeg", "-y", "-i", str(src), "-c:a", "flac",
         "-compression_level", "8", str(dst)],
        f"{src.name} → {dst.name}")


def find_stem(directory: Path, name: str) -> Path | None:
    """Locate ``<name><ext>`` inside *directory* for a known audio extension.

    Extensions are tried in the order .wav, .flac, .mp3, .ogg and the first
    existing file is returned; None means no candidate exists.
    """
    candidates = (directory / (name + ext)
                  for ext in (".wav", ".flac", ".mp3", ".ogg"))
    return next((path for path in candidates if path.exists()), None)


# ---------------------------------------------------------------------------
# Demucs separation
# ---------------------------------------------------------------------------

def run_demucs(input_file: Path, work_dir: Path, model: str) -> dict[str, Path]:
    """Separate *input_file* with Demucs and collect the produced stems.

    Runs ``demucs --name <model>`` writing into ``work_dir/demucs_<model>``,
    then looks up every stem listed in ``MODELS[model]``.

    Args:
        input_file: Audio file to separate.
        work_dir: Existing directory for intermediate output.
        model: Key of MODELS naming the Demucs model to use.

    Returns:
        Mapping of stem name to the file Demucs wrote. Stems that cannot be
        found are reported and omitted.

    Exits the process with status 1 if Demucs fails or none of the expected
    stems can be found.
    """
    stem_names = MODELS[model]
    print(f"\n[Demucs {model}] Separating into {len(stem_names)} stems...")

    out_dir = work_dir / f"demucs_{model}"
    out_dir.mkdir(exist_ok=True)

    cmd = [
        "demucs",
        "--name", model,
        "--out", str(out_dir),
        "--filename", "{stem}.{ext}",
        str(input_file),
    ]
    if not run(cmd, f"Running demucs --name {model}..."):
        sys.exit(1)

    # Demucs writes to <out>/<model>/<--filename template>, so with
    # "{stem}.{ext}" the stems land directly in <out>/<model>/. Check there
    # first: a <out>/<model>/<input name>/ subdirectory is the layout of the
    # default "{track}/{stem}.{ext}" template and, because the work dir may
    # be reused between runs, could hold stale stems from an older run.
    stems_dir = out_dir / model
    if not any(find_stem(stems_dir, s) for s in stem_names):
        stems_dir = out_dir / model / input_file.stem  # legacy layout

    results: dict[str, Path] = {}
    for stem_name in stem_names:
        f = find_stem(stems_dir, stem_name)
        if f:
            results[stem_name] = f
            size_mb = f.stat().st_size / 1024 / 1024
            print(f"  ✓ {stem_name} ({size_mb:.1f} MB)")
        else:
            print(f"  ✗ {stem_name} not found")

    if not results:
        # Nothing usable: continuing would only produce an empty config.
        print(f"  ERROR: no Demucs stems found in {stems_dir}")
        sys.exit(1)

    return results


# ---------------------------------------------------------------------------
# Optional: MDX-Net vocal re-extraction
# ---------------------------------------------------------------------------

def run_mdx_vocals(input_file: Path, work_dir: Path) -> Path | None:
    """Re-extract vocals from *input_file* with MDX-Net via audio-separator.

    Runs ``audio-separator`` with MDX_VOCAL_MODEL, writing WAV output into
    ``work_dir/mdx_vocals``, and picks the vocals file from that directory.

    Returns:
        Path to the vocals WAV, or None if audio-separator failed or no
        vocals output could be identified (the caller then keeps the Demucs
        vocals).
    """
    print(f"\n[MDX-Net] Re-extracting vocals with {MDX_VOCAL_MODEL}...")

    out_dir = work_dir / "mdx_vocals"
    out_dir.mkdir(exist_ok=True)

    cmd = [
        "audio-separator",
        str(input_file),
        "--model_filename", MDX_VOCAL_MODEL,
        "--output_dir", str(out_dir),
        "--output_format", "WAV",
    ]
    if not run(cmd, f"Running audio-separator ({MDX_VOCAL_MODEL})..."):
        print("  WARNING: MDX vocal extraction failed")
        return None

    # audio-separator tags each output with its stem name in parentheses,
    # e.g. "..._(Vocals)_....wav" / "..._(Instrumental)_....wav" (exact
    # pattern may vary by version). The directory may also hold files from
    # earlier runs; files presumably named after this input (its stem as a
    # prefix) are listed first, and sorting keeps the choice deterministic.
    prefix = f"{input_file.stem}_"
    wav_files = sorted(
        (p for p in out_dir.iterdir()
         if p.is_file() and p.suffix.lower() == ".wav"),
        key=lambda p: (not p.name.startswith(prefix), p.name),
    )

    # Match the "(Vocals)" tag, not any "vocal" substring: an input named
    # e.g. "xyz_vocals.wav" would otherwise also match its instrumental.
    for f in wav_files:
        if "(vocals)" in f.name.lower():
            print(f"  ✓ MDX vocals: {f.name}")
            return f

    # Fallback: first output not explicitly tagged as the instrumental, so
    # the instrumental is never substituted for the vocals.
    for f in wav_files:
        if "(instrumental)" not in f.name.lower():
            print(f"  ✓ MDX vocals: {f.name}")
            return f

    print("  WARNING: no MDX vocals file found")
    return None


# ---------------------------------------------------------------------------
# Assembly
# ---------------------------------------------------------------------------

def assemble(stems: dict[str, Path], output_dir: Path):
    """Write every stem to ``output_dir/<name>.flac``.

    FLAC sources are copied with shutil.copy2; anything else is converted
    with ffmpeg via to_flac(). If a conversion fails, it is reported and any
    file left at its destination is removed. Later steps treat an existing
    ``<name>.flac`` as a valid stem, so a stale or partial file must not
    remain.

    Args:
        stems: Mapping of stem name to source audio file.
        output_dir: Destination directory; created (with parents) if missing.
    """
    print(f"\n[Assembly] Writing FLAC stems to {output_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)

    written = 0
    failed = []
    for name, src in sorted(stems.items()):
        dst = output_dir / f"{name}.flac"
        if src.suffix.lower() == ".flac":
            shutil.copy2(src, dst)
            print(f"  ✓ {name}.flac (copied)")
            written += 1
        elif to_flac(src, dst):
            written += 1
        else:
            failed.append(name)
            dst.unlink(missing_ok=True)

    if failed:
        print(f"  WARNING: failed to convert: {', '.join(failed)}")
    print(f"\n  Total: {written} stems in {output_dir}")


def print_config(song_name: str, bpm: int | None, album: str,
                 output_dir: Path, color: int):
    """Print a SONG_CONFIG block to paste into add_stems.py.

    Only stems whose ``<stem>.flac`` exists in *output_dir* are listed,
    ordered by their Ableton track ID from STEM_TO_TRACK. The song name and
    stems path are emitted as escaped double-quoted literals, so names with
    quotes or backslashes still yield valid Python. A falsy *bpm* produces
    a ``???`` placeholder followed by a TODO comment.

    Args:
        song_name: Song title; also the last component of ``stems_dir``.
        bpm: Song tempo, or None if unknown.
        album: Album folder under Samples/.
        output_dir: Directory checked for the assembled ``.flac`` stems.
        color: Ableton color index written verbatim.
    """
    def lit(text: str) -> str:
        # JSON string escapes are all valid Python escapes; keep non-ASCII.
        return json.dumps(text, ensure_ascii=False)

    stems_rel = f"Samples/{album}/{song_name}"

    entries = [
        f'        {track_id}: ("{stem_name}", "{stem_name}"),'
        for stem_name, track_id in sorted(STEM_TO_TRACK.items(),
                                          key=lambda x: x[1])
        if (output_dir / f"{stem_name}.flac").exists()
    ]
    entries_text = "\n".join(entries)

    # The comma must precede the TODO comment; otherwise the comment
    # swallows it and the dict is still broken once ??? is filled in.
    bpm_line = f'"bpm": {bpm},' if bpm else '"bpm": ???,  # TODO: look up BPM'

    rule = "=" * 60
    print(f"\n{rule}")
    print("SONG_CONFIG for add_stems.py:")
    print(rule)
    print(f'''
SONG_CONFIG = {{
    "name": {lit(song_name)},
    {bpm_line}
    "color": {color},
    "stems_dir": {lit(stems_rel)},
    "stems": {{
{entries_text}
    }}
}}''')


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point.

    Parses arguments and runs Demucs (plus MDX-Net vocals when --mdx-vocals
    is given). It then writes FLAC stems to ``SAMPLES_DIR/<album>/<song>``
    and prints a SONG_CONFIG block for add_stems.py.

    Exits with status 1 if a positional argument is missing or the input
    path does not name an existing regular file.
    """
    parser = argparse.ArgumentParser(
        description="AI stem separation for LinkinPark project")
    parser.add_argument("input", nargs="?", help="Input audio file (WAV/MP3/FLAC)")
    parser.add_argument("song_name", nargs="?", help="Song name")
    parser.add_argument("--album", default="From Zero",
                        help="Album name (default: From Zero)")
    parser.add_argument("--bpm", type=int, help="Song BPM")
    parser.add_argument("--color", type=int, default=16,
                        help="Ableton color index (default: 16/purple)")
    parser.add_argument("--model", default=DEFAULT_MODEL,
                        choices=list(MODELS),
                        help=f"Demucs model (default: {DEFAULT_MODEL})")
    parser.add_argument("--mdx-vocals", action="store_true",
                        help="Also run MDX-Net for cleaner vocals")
    parser.add_argument("--work-dir", type=str,
                        help="Working dir for intermediates (default: /tmp/...)")

    args = parser.parse_args()

    # Positionals are optional so a bare invocation shows the full help plus
    # the model/stem table instead of argparse's terse usage error.
    if not args.input or not args.song_name:
        parser.print_help()
        print("\nModels:")
        for name, stems in MODELS.items():
            print(f"  {name}: {', '.join(stems)}")
        sys.exit(1)

    input_file = Path(args.input).resolve()
    if not input_file.exists():
        print(f"ERROR: {input_file} not found")
        sys.exit(1)
    if not input_file.is_file():
        # e.g. a directory: reject it here rather than deep inside demucs.
        print(f"ERROR: {input_file} is not a file")
        sys.exit(1)

    song_name = args.song_name
    output_dir = SAMPLES_DIR / args.album / song_name

    if args.work_dir:
        work_dir = Path(args.work_dir)
    else:
        work_dir = Path(f"/tmp/stem_sep_{song_name.replace(' ', '_')}")
    work_dir.mkdir(parents=True, exist_ok=True)

    print(f"Input:  {input_file}")
    print(f"Output: {output_dir}")
    print(f"Model:  {args.model}")

    # Run Demucs
    stems = run_demucs(input_file, work_dir, args.model)

    # Optional: MDX-Net vocals replace the Demucs vocals when available
    if args.mdx_vocals:
        mdx_vox = run_mdx_vocals(input_file, work_dir)
        if mdx_vox:
            stems["vocals"] = mdx_vox

    # Assemble FLAC output
    assemble(stems, output_dir)

    # Print config for add_stems.py
    print_config(song_name, args.bpm, args.album, output_dir, args.color)

    print("\nNext steps:")
    print(f"  1. Review stems in {output_dir}")
    print("  2. Copy SONG_CONFIG into add_stems.py")
    print("  3. python3 add_stems.py && ./build.sh")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
