#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import csv
import json
import math
import os
import re
import subprocess
import sys
import uuid
from datetime import datetime

# Project root: one directory above the directory that contains this script.
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Food log that append_csv() appends to (one row per parsed item).
CSV_PATH = os.path.join(BASE_DIR, "data", "food_log.csv")
# Optional prompt file read by load_prompt(); when it does not exist a short
# built-in prompt is used instead.
PROMPT_PATH = os.path.join(BASE_DIR, "scripts", "agent_prompt.txt")

# Column order of the CSV log. ensure_csv() writes these as the header row and
# append_csv() uses them as the DictWriter field order.
CSV_COLUMNS = [
    "timestamp",
    "event_id",
    "raw_text",
    "food_name",
    "quantity",
    "unit",
    "calories",
    "protein_g",
    "fat_g",
    "confidence",
    "status",
    "parser_source",
]


def now_iso():
    """Return the current local time as an ISO 8601 string including its UTC offset."""
    local_moment = datetime.now().astimezone()
    return local_moment.isoformat()


def ensure_csv():
    """Make sure the CSV log exists and starts with a header row.

    The parent directory is created if needed. The header is written only when
    the file is missing or has zero bytes, so an existing log is never
    truncated.
    """
    log_dir = os.path.dirname(CSV_PATH)
    os.makedirs(log_dir, exist_ok=True)

    already_populated = os.path.exists(CSV_PATH) and os.path.getsize(CSV_PATH) > 0
    if already_populated:
        return

    with open(CSV_PATH, "w", encoding="utf-8", newline="") as handle:
        csv.DictWriter(handle, fieldnames=CSV_COLUMNS).writeheader()


def load_prompt():
    """Return the prompt text sent to the agent.

    When PROMPT_PATH exists, its contents are returned with surrounding
    whitespace stripped. Otherwise a minimal built-in instruction is returned.
    """
    if not os.path.exists(PROMPT_PATH):
        return (
            "You are a nutrition estimator.\n"
            "Return JSON only. No markdown. No explanations.\n"
        )

    with open(PROMPT_PATH, "r", encoding="utf-8") as prompt_file:
        contents = prompt_file.read()
    return contents.strip()


def parse_json_maybe(text):
    """Best-effort extraction of a JSON value from free-form model output.

    Steps:
      1. Strip surrounding whitespace and a Markdown code fence (``` or
         ```json, case-insensitive).
      2. Try to decode the whole remaining string.
      3. Otherwise, return the first value that decodes completely starting
         at some '{' or '['. Text after that value is ignored.

    Returns the decoded value, or None for empty input or when nothing
    decodes.
    """
    if not text:
        return None

    candidate = text.strip()
    if candidate.startswith("```"):
        candidate = re.sub(r"^```(?:json)?", "", candidate, flags=re.I).strip()
        candidate = re.sub(r"```$", "", candidate).strip()

    try:
        return json.loads(candidate)
    except Exception:
        pass

    decoder = json.JSONDecoder()
    openers = (pos for pos, char in enumerate(candidate) if char in "{[")
    for start in openers:
        try:
            value, _end = decoder.raw_decode(candidate[start:])
        except Exception:
            continue
        return value
    return None


def extract_openclaw_payload(stdout):
    """Pull the nutrition result object out of `openclaw agent --json` output.

    Two shapes are accepted:
      * a wrapper ``{"result": {"payloads": [{"text": "..."}, ...]}}`` whose
        first payload's ``text`` holds the result JSON (possibly fenced or
        surrounded by prose; decoded with parse_json_maybe), or
      * the result object itself, recognised by having both ``status`` and
        ``items`` keys.

    Returns the decoded (not yet validated) result, or None. Unexpected
    shapes, such as ``"result": null``, a non-object payload or a non-string
    ``text``, yield None instead of raising. That lets the caller fall back to
    the deterministic parser rather than crash.
    """
    wrapper = parse_json_maybe(stdout)
    if not isinstance(wrapper, dict):
        return None

    result = wrapper.get("result")
    payloads = result.get("payloads") if isinstance(result, dict) else None
    if isinstance(payloads, list) and payloads:
        first = payloads[0]
        text = first.get("text") if isinstance(first, dict) else None
        # parse_json_maybe() calls .strip(), so only hand it real strings.
        if isinstance(text, str):
            inner = parse_json_maybe(text)
            if inner is not None:
                return inner

    if "status" in wrapper and "items" in wrapper:
        return wrapper

    return None


def validate_result(obj):
    """Check that *obj* is a well-formed parser result.

    Returns ``(True, None)`` when valid, otherwise ``(False, reason)`` with a
    short machine-readable reason string.

    Rules:
      * ``status`` must be "success" or "error".
      * An "error" result only needs a truthy ``error_reason``.
      * A "success" result needs a non-empty ``items`` list and a ``totals``
        dict. Every item must be a dict containing food_name, quantity, unit,
        calories, protein_g, fat_g and confidence.
      * quantity/calories/protein_g/fat_g must convert to a finite,
        non-negative float. Booleans are rejected even though float(True)
        works, as are "NaN"/"inf" and out-of-range integers.
    """
    if not isinstance(obj, dict):
        return False, "result_not_dict"

    status = obj.get("status")
    if status not in ("success", "error"):
        return False, "missing_or_invalid_status"

    if status == "error":
        if not obj.get("error_reason"):
            return False, "error_missing_reason"
        return True, None

    items = obj.get("items")
    totals = obj.get("totals")

    if not isinstance(items, list) or not items:
        return False, "success_missing_items"

    if not isinstance(totals, dict):
        return False, "success_missing_totals"

    required = ("food_name", "quantity", "unit", "calories", "protein_g", "fat_g", "confidence")
    numeric = ("quantity", "calories", "protein_g", "fat_g")

    for item in items:
        if not isinstance(item, dict):
            return False, "item_not_dict"

        for key in required:
            if key not in item:
                return False, f"item_missing_{key}"

        for key in numeric:
            value = item[key]
            if isinstance(value, bool):
                return False, f"item_invalid_number_{key}"
            try:
                number = float(value)
            except (TypeError, ValueError, OverflowError):
                # OverflowError: JSON integers too large for a float.
                return False, f"item_invalid_number_{key}"
            # NaN/inf or negative amounts are never meaningful nutrition values
            # and would otherwise be written straight into the CSV log.
            if not math.isfinite(number) or number < 0:
                return False, f"item_invalid_number_{key}"

    return True, None


def call_openclaw(raw_text):
    """Ask the OpenClaw agent to parse *raw_text* into a nutrition result.

    The agent name comes from $OPENCLAW_AGENT (default
    ``health_food_parser_rev002``) and the prompt from load_prompt(). The CLI
    runs under ``timeout 45s`` with its own ``--timeout 30``, plus a 60 s
    subprocess backstop. Output is decoded as UTF-8; undecodable bytes are
    replaced rather than raising.

    Returns ``(result, None)`` for a validated result, otherwise
    ``(None, reason)``. The failure case also covers executables that cannot
    be started and an expired backstop timeout, so main() can fall back to
    the deterministic parser instead of crashing.
    """
    agent = os.environ.get("OPENCLAW_AGENT", "health_food_parser_rev002")
    message = f"{load_prompt()}\n\nInput:\n{raw_text}"

    cmd = [
        "timeout", "45s",
        "openclaw", "agent",
        "--agent", agent,
        "--message", message,
        "--json",
        "--timeout", "30",
    ]

    try:
        proc = subprocess.run(
            cmd,
            cwd=BASE_DIR,
            capture_output=True,
            text=True,
            # The agent replies in Hebrew; don't depend on the locale encoding.
            encoding="utf-8",
            errors="replace",
            timeout=60,
        )
    except subprocess.TimeoutExpired:
        return None, "openclaw_timeout"
    except OSError as exc:
        # e.g. the `timeout` executable is not on PATH (it is absent on stock
        # macOS), or BASE_DIR is not usable as a working directory.
        return None, f"openclaw_exec_error: {exc}"

    if proc.returncode != 0:
        return None, f"openclaw_rc_{proc.returncode}: {proc.stderr[:300]}"

    obj = extract_openclaw_payload(proc.stdout)
    if obj is None:
        return None, "openclaw_json_not_found"

    ok, reason = validate_result(obj)
    if not ok:
        return None, f"openclaw_invalid_json: {reason}"

    return obj, None


def deterministic_fallback(raw_text):
    """Parse a small fixed set of Hebrew food phrases without calling a model.

    Recognised phrases:
      * ``<food> <N> גרם`` for each food in the per-100 g table below. N may
        be a decimal. An apostrophe or Hebrew geresh may follow the food name,
        so the common spelling "קוטג' 200 גרם" matches.
      * ``<N> ביצ...`` (e.g. "2 ביצים") for whole eggs. N must be an integer
        that is not the tail of a decimal such as "1.5".
    Only the first occurrence of each food is used.

    Returns a result dict in the schema checked by validate_result():
    status "success" with items and rounded totals when anything matched,
    otherwise status "error" with error_reason "no_parse".
    """
    text = raw_text.strip()

    # (name, kcal, protein g, fat g) per 100 g.
    foods_100g = [
        ("חזה עוף", 165, 31.0, 3.6),
        ("בננה", 89, 1.1, 0.3),
        ("תפוח", 52, 0.3, 0.2),
        ("אורז", 130, 2.4, 0.3),
        ("קוטג", 95, 11.0, 5.0),
        ("יוגורט", 63, 5.0, 1.5),
        ("טונה במים", 116, 26.0, 1.0),
    ]

    items = []

    for name, cal100, prot100, fat100 in foods_100g:
        # Optional ' / ’ / ׳ (U+05F3 geresh) right after the name, e.g. "קוטג'".
        pattern = rf"{re.escape(name)}['\u2019\u05f3]?\s+(\d+(?:\.\d+)?)\s*גרם"
        m = re.search(pattern, text)
        if m:
            grams = float(m.group(1))
            factor = grams / 100.0
            items.append({
                "food_name": name,
                "quantity": grams,
                "unit": "g",
                "calories": round(cal100 * factor, 1),
                "protein_g": round(prot100 * factor, 1),
                "fat_g": round(fat100 * factor, 1),
                "confidence": "high",
            })

    # The lookbehind stops "1.5 ביצים" from being read as 5 eggs.
    m = re.search(r"(?<![\d.])(\d+)\s*ביצ", text)
    if m:
        qty = int(m.group(1))
        items.append({
            "food_name": "ביצה",
            "quantity": qty,
            "unit": "unit",
            "calories": round(72 * qty, 1),
            "protein_g": round(6.3 * qty, 1),
            "fat_g": round(4.8 * qty, 1),
            "confidence": "high",
        })

    if not items:
        return {
            "status": "error",
            "items": [],
            "totals": {"calories": 0, "protein_g": 0, "fat_g": 0},
            "error_reason": "no_parse",
        }

    return {
        "status": "success",
        "items": items,
        "totals": {
            "calories": round(sum(float(i["calories"]) for i in items), 1),
            "protein_g": round(sum(float(i["protein_g"]) for i in items), 1),
            "fat_g": round(sum(float(i["fat_g"]) for i in items), 1),
        },
        "error_reason": None,
    }


def append_csv(raw_text, result, parser_source):
    """Append one CSV row per item of a successful *result*.

    ensure_csv() runs first, so the log and its header exist even when nothing
    is written. Results whose status is not "success" write no rows. All rows
    from one call share a single timestamp and a fresh ``rev2-<hex>``
    event_id. Missing item fields default to "" (text) or 0 (numbers).

    Returns the number of rows appended.
    """
    ensure_csv()

    if result.get("status") != "success":
        return 0

    per_call = {
        "event_id": "rev2-" + uuid.uuid4().hex,
        "timestamp": now_iso(),
        "raw_text": raw_text,
        "status": "success",
        "parser_source": parser_source,
    }
    # (column, value used when the item lacks that key)
    item_fields = (
        ("food_name", ""),
        ("quantity", 0),
        ("unit", ""),
        ("calories", 0),
        ("protein_g", 0),
        ("fat_g", 0),
        ("confidence", ""),
    )

    count = 0
    with open(CSV_PATH, "a", encoding="utf-8", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=CSV_COLUMNS)
        for item in result["items"]:
            row = dict(per_call)
            row.update((column, item.get(column, fallback)) for column, fallback in item_fields)
            writer.writerow(row)
            count += 1
    return count


def main():
    """Command-line entry point: parse the argv text, log it, print PASS/FAIL.

    Exit codes: 2 for missing or blank input, 1 for an invalid or error
    result, 0 after appending rows. With USE_OPENCLAW=1 the OpenClaw agent is
    tried first; if it fails, deterministic_fallback() is used instead.
    """
    args = sys.argv[1:]
    if not args:
        print("FAIL: missing raw_text")
        return 2

    raw_text = " ".join(args).strip()
    if not raw_text:
        print("FAIL: empty raw_text")
        return 2

    result = None
    parser_source = "deterministic_fallback"

    if os.environ.get("USE_OPENCLAW") == "1":
        result, err = call_openclaw(raw_text)
        if result is None:
            print(f"OPENCLAW_FAIL: {err}")
        else:
            parser_source = "openclaw_agent"

    if result is None:
        result = deterministic_fallback(raw_text)

    valid, why = validate_result(result)
    if not valid:
        print(f"FAIL: invalid_result {why}")
        return 1

    if result.get("status") != "success":
        print(f"FAIL: agent_error {result.get('error_reason')}")
        return 1

    written = append_csv(raw_text, result, parser_source)
    print(f"PASS: wrote_rows={written} parser_source={parser_source}")
    return 0


if __name__ == "__main__":
    # main()'s integer return value becomes the process exit status.
    sys.exit(main())
