#!/usr/bin/env python3
"""Ensure devices exist for vendor inventory and backfill compatibility links."""

from __future__ import annotations

import csv
import sys
from pathlib import Path
from typing import Dict, Tuple

import psycopg2

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(REPO_ROOT))

from imports.normalize_mobile_parts import detect_manufacturer, infer_part_template

# Input CSV of vendor parts; it lives in the same directory as this script.
CSV_PATH = Path(__file__).resolve().with_name("mobile_medic_parts_pricing_full.csv")


def load_manufacturers(cur) -> Dict[str, int]:
    """Return every manufacturer as a ``{lowercased name: id}`` mapping.

    Names that differ only by case collapse to one key; the row fetched last wins.
    """
    cur.execute("SELECT id, name FROM manufacturers")
    by_name: Dict[str, int] = {}
    for manufacturer_id, manufacturer_name in cur.fetchall():
        by_name[manufacturer_name.lower()] = manufacturer_id
    return by_name


def ensure_manufacturer(cur, manufacturers: Dict[str, int], name: str) -> int:
    """Return the id of manufacturer *name*, inserting it on a cache miss.

    *manufacturers* is the ``{lowercased name: id}`` cache built by
    ``load_manufacturers``; it is updated in place with any new id.
    """
    lookup = name.lower()
    try:
        return manufacturers[lookup]
    except KeyError:
        pass

    # DO UPDATE (rather than DO NOTHING) so RETURNING still yields the existing
    # id when a row with exactly this name is already present.
    upsert_sql = "INSERT INTO manufacturers (name) VALUES (%s) ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name RETURNING id"
    cur.execute(upsert_sql, (name,))
    row = cur.fetchone()
    manufacturers[lookup] = row[0]
    return row[0]


def load_device_types(cur) -> Dict[str, int]:
    """Return every device type as a ``{lowercased name: id}`` mapping."""
    cur.execute("SELECT id, name FROM device_types")
    rows = cur.fetchall()
    return dict((type_name.lower(), type_id) for type_id, type_name in rows)


def ensure_device_type(cur, device_types: Dict[str, int], name: str) -> int:
    """Return the id of device type *name*, inserting it on a cache miss.

    The row is stored as ``name.capitalize()`` and cached under ``name.lower()``.
    Unlike the other ``ensure_*`` helpers this is a plain INSERT with no
    ``ON CONFLICT`` clause, so it relies on the in-memory cache to avoid
    duplicates.
    """
    # NOTE(review): a device type created outside this run (after the cache was
    # loaded) could be duplicated or raise, depending on the table's
    # constraints -- confirm the schema.
    lookup = name.lower()
    if lookup not in device_types:
        cur.execute(
            "INSERT INTO device_types (name) VALUES (%s) RETURNING id",
            (name.capitalize(),),
        )
        device_types[lookup] = cur.fetchone()[0]
    return device_types[lookup]


def load_families(cur) -> Dict[Tuple[int, str], int]:
    """Index every device family by ``(manufacturer_id, lowercased name)``."""
    cur.execute("SELECT id, manufacturer_id, name FROM device_families")
    index: Dict[Tuple[int, str], int] = {}
    for family_id, manufacturer_id, family_name in cur.fetchall():
        index[manufacturer_id, family_name.lower()] = family_id
    return index


def ensure_family(cur, families: Dict[Tuple[int, str], int], manufacturer_id: int, device_type_id: int, name: str) -> int:
    """Return the id of family *name* under *manufacturer_id*, upserting on a cache miss.

    *families* is the ``(manufacturer_id, lowercased name)`` cache from
    ``load_families`` and is updated in place. *device_type_id* is only written
    for a newly inserted row; the conflict branch updates ``name`` alone.
    """
    family_key = (manufacturer_id, name.lower())
    if family_key not in families:
        upsert_sql = """
        INSERT INTO device_families (manufacturer_id, device_type_id, name)
        VALUES (%s, %s, %s)
        ON CONFLICT (manufacturer_id, name) DO UPDATE SET name = EXCLUDED.name
        RETURNING id
        """
        cur.execute(upsert_sql, (manufacturer_id, device_type_id, name))
        families[family_key] = cur.fetchone()[0]
    return families[family_key]


def load_devices(cur) -> Dict[Tuple[int, str], int]:
    """Index every device by ``(family_id, lowercased model_name)``."""
    cur.execute("SELECT id, family_id, model_name FROM devices")
    return dict(
        ((family_id, model_name.lower()), device_id)
        for device_id, family_id, model_name in cur.fetchall()
    )


def ensure_device(cur, devices: Dict[Tuple[int, str], int], family_id: int, model_name: str) -> int:
    """Return the id of *model_name* within *family_id*, upserting on a cache miss.

    *devices* is the ``(family_id, lowercased model_name)`` cache from
    ``load_devices`` and is updated in place with any new id.
    """
    device_key = (family_id, model_name.lower())
    try:
        return devices[device_key]
    except KeyError:
        pass

    # The no-op DO UPDATE makes RETURNING yield the existing id on conflict.
    upsert_sql = """
        INSERT INTO devices (family_id, model_name)
        VALUES (%s, %s)
        ON CONFLICT (family_id, model_name) DO UPDATE SET model_name = EXCLUDED.model_name
        RETURNING id
        """
    cur.execute(upsert_sql, (family_id, model_name))
    devices[device_key] = cur.fetchone()[0]
    return devices[device_key]


def load_parts(cur) -> Dict[Tuple[str, str], int]:
    """Index every part by ``(lowercased service_type, lowercased name)``."""
    cur.execute("SELECT id, service_type, name FROM parts")
    index: Dict[Tuple[str, str], int] = {}
    for part_id, service_type, part_name in cur.fetchall():
        index[service_type.lower(), part_name.lower()] = part_id
    return index


def _field(row: Dict, column: str) -> str:
    """Return the stripped text of *column* in *row*, or "" if absent, None or blank."""
    return (row.get(column) or "").strip()


def _backfill(cur, rows) -> Tuple[int, int]:
    """Ensure devices for each CSV row and link known parts; return (inserted, missing)."""
    manufacturers = load_manufacturers(cur)
    device_types = load_device_types(cur)
    families = load_families(cur)
    devices = load_devices(cur)
    parts = load_parts(cur)

    compat_inserted = 0
    missing_parts = 0
    for row in rows:
        model = _field(row, "model_name")
        # Strip *before* applying fallbacks: a whitespace-only cell must fall
        # back to the default instead of producing a blank family/device type.
        family_name = _field(row, "family_name") or model
        service_type = _field(row, "service_type").lower()
        part_label = _field(row, "name")
        if not model or not service_type or not part_label:
            continue

        manufacturer_name = detect_manufacturer(family_name, model)
        manufacturer_id = ensure_manufacturer(cur, manufacturers, manufacturer_name)

        device_type_name = _field(row, "device_type_name").lower() or "phone"
        device_type_id = ensure_device_type(cur, device_types, device_type_name)

        # The device is created even when no matching part is found below.
        family_id = ensure_family(cur, families, manufacturer_id, device_type_id, family_name)
        device_id = ensure_device(cur, devices, family_id, model)

        _part_type, suffix = infer_part_template(service_type, part_label)
        part_name = f"{model} {suffix}".strip()
        # ``parts`` is keyed by (lowercased service_type, lowercased name).
        part_id = parts.get((service_type, part_name.lower()))
        if part_id is None:
            missing_parts += 1
            continue

        # Skip pairs that are already linked, including links inserted earlier
        # in this same transaction.
        cur.execute(
            "SELECT 1 FROM part_compatibility WHERE part_id = %s AND device_id = %s",
            (part_id, device_id),
        )
        if cur.fetchone():
            continue
        cur.execute(
            "INSERT INTO part_compatibility (part_id, device_id) VALUES (%s, %s)",
            (part_id, device_id),
        )
        compat_inserted += 1

    return compat_inserted, missing_parts


def main() -> None:
    """Backfill devices and part/device compatibility links from ``CSV_PATH``.

    All writes happen in a single transaction that is committed only after the
    whole CSV has been processed; any exception rolls it back. Prints the number
    of compatibility rows inserted and of rows whose part was not found.

    Raises:
        SystemExit: if ``CSV_PATH`` does not exist.
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"{CSV_PATH} not found")

    conn = psycopg2.connect(dbname="medicore", user="postgres")
    try:
        conn.autocommit = False
        # ``with conn`` commits on normal exit and rolls back on an exception,
        # but does not close the connection -- the ``finally`` does that.
        with conn, conn.cursor() as cur:
            # newline="" is what the csv module requires; utf-8-sig strips a BOM
            # that would otherwise corrupt the first header ("model_name") and
            # make every row look empty.
            with CSV_PATH.open(newline="", encoding="utf-8-sig") as handle:
                compat_inserted, missing_parts = _backfill(cur, csv.DictReader(handle))
    finally:
        conn.close()

    print(f"Compatibility rows inserted: {compat_inserted}")
    print(f"Rows skipped (part missing): {missing_parts}")


if __name__ == "__main__":
    # Run the backfill only when executed as a script, not when imported.
    main()
