import html
import os

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline

# ================= CONFIG =================
# Page setup, kept as the first Streamlit call of the script; the app name is
# shared by the browser tab title and the on-page heading.
_APP_NAME = "DeepShop Duo 2025"
st.set_page_config(page_title=_APP_NAME, layout="wide")
st.title(_APP_NAME)
st.markdown("## Recommandation visuelle & analyse d’avis (Deep Learning)")

# ================= PATHS =================
_SCRIPT_PATH = os.path.abspath(__file__)
BASE_DIR = os.path.dirname(_SCRIPT_PATH)
# Inputs live alongside this script: the product table and its image folder.
CSV_PATH, IMG_DIR = (
    os.path.join(BASE_DIR, name) for name in ("products.csv", "images")
)

# ================= DATA =================
@st.cache_data
def load_products():
    """Load the product catalogue from ``CSV_PATH`` into a DataFrame.

    UTF-8 is tried first (``utf-8-sig`` also strips a leading byte-order
    mark), falling back to Latin-1 only when the file is not valid UTF-8.
    Latin-1 decodes every possible byte sequence, so using it unconditionally
    never fails but silently turns UTF-8 accents into mojibake (e.g. "é"
    becomes "Ã©") and glues a BOM onto the first column name, which would
    then no longer be found as ``"id"``.

    If the file is missing, an error is displayed and the script run is
    halted with ``st.stop()``.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    if not os.path.exists(CSV_PATH):
        st.error(f"CSV introuvable : {CSV_PATH}")
        st.stop()
    try:
        return pd.read_csv(CSV_PATH, encoding="utf-8-sig")
    except UnicodeDecodeError:
        # Not valid UTF-8: keep the historical Latin-1 behaviour.
        return pd.read_csv(CSV_PATH, encoding="latin-1")


products = load_products()

# ================= RESNET =================
@st.cache_resource
def get_resnet():
    """Build the ImageNet-pretrained ResNet-50 image feature extractor.

    The final child module (the ``fc`` classification layer) is dropped, so
    the returned module outputs the pooled backbone features instead of class
    logits. The result is cached by Streamlit and reused across reruns.

    Returns:
        An ``nn.Sequential`` in eval mode.
    """
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13: explicit weights API. IMAGENET1K_V1 is the
        # checkpoint the deprecated ``pretrained=True`` flag maps to.
        backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights enums.
        backbone = models.resnet50(pretrained=True)
    extractor = nn.Sequential(*list(backbone.children())[:-1])
    # eval() on the wrapper itself (not only the wrapped backbone), so that
    # BatchNorm uses its running statistics and ``extractor.training`` is False.
    extractor.eval()
    return extractor


resnet = get_resnet()

# Per-channel normalisation statistics used by torchvision's
# ImageNet-pretrained models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised float tensor ready for the ResNet extractor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

@st.cache_resource
def build_embeddings(df):
    """Compute a unit-norm ResNet feature vector for each product image.

    For every row, the file ``IMG_DIR/<image>`` is embedded with the
    module-level ``transform`` and ``resnet``. Rows whose file does not exist
    are skipped. Files that exist but cannot be opened or embedded are also
    skipped, but they are counted and reported in the sidebar instead of
    being silently swallowed.

    Args:
        df: Product table with at least ``id`` and ``image`` columns.

    Returns:
        Dict mapping ``str(id)`` to a 1-D, L2-normalised ``torch.Tensor``.
    """
    emb = {}
    indexed = 0
    skipped = 0
    for pid, image_name in zip(df["id"], df["image"]):
        # str() guards against NaN cells (they map to a non-existent path).
        img_path = os.path.join(IMG_DIR, str(image_name))
        if not os.path.exists(img_path):
            continue
        try:
            # ``with`` closes the file handle that Image.open keeps open;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                rgb = img.convert("RGB")
            x = transform(rgb).unsqueeze(0)
            with torch.no_grad():
                feat = resnet(x).flatten()
        except Exception:
            # Best-effort indexing: one bad file must not abort the whole app,
            # but it is counted and reported below rather than hidden.
            skipped += 1
            continue
        emb[str(pid)] = F.normalize(feat, dim=0)
        indexed += 1
    st.sidebar.success(f"Images indexées : {indexed}")
    if skipped:
        st.sidebar.warning(f"Images illisibles ignorées : {skipped}")
    return emb


embeddings = build_embeddings(products)

# ================= RECOMMANDATION =================
def recommend(pid, k=6):
    """Return up to ``k`` ids of products that look like product ``pid``.

    Candidates are the other products in the same ``category`` as ``pid``
    that have an entry in the module-level ``embeddings`` dict. They are
    ranked by cosine similarity between their image embedding and that of
    ``pid``, most similar first.

    Args:
        pid: Product id; anything ``int()`` accepts (the UI passes a str).
        k: Maximum number of ids to return.

    Returns:
        List of candidate ids as strings (the ``embeddings`` key format).
        The list is empty when ``pid`` is not in ``products``, when no
        candidate has an embedding, or when ``k <= 0``.

        When ``pid`` itself has no embedding (e.g. its image file is missing
        or unreadable), the candidates cannot be ranked. In that case the
        first ``k`` are returned in catalogue order instead of raising
        ``KeyError``.
    """
    pid_i = int(pid)
    if k <= 0:
        return []

    # Category of the query product (empty Series when the id is unknown).
    categories = products.loc[products["id"] == pid_i, "category"]
    if categories.empty:
        return []
    cat = categories.iloc[0]

    # Same-category products other than the query itself, restricted to
    # those that have an image embedding.
    same_cat = products[(products["category"] == cat) & (products["id"] != pid_i)]
    valid_ids = [
        key
        for key in (str(i) for i in same_cat["id"].tolist())
        if key in embeddings
    ]
    if not valid_ids:
        return []

    query = embeddings.get(str(pid_i))
    if query is None:
        # No visual signature for the query product: fall back to unranked
        # same-category suggestions rather than crashing.
        return valid_ids[:k]

    vecs = torch.stack([embeddings[i] for i in valid_ids])
    sims = F.cosine_similarity(query.unsqueeze(0), vecs)
    order = torch.argsort(sims, descending=True)[:k].tolist()
    return [valid_ids[i] for i in order]

# ================= BERT =================
@st.cache_resource
def get_bert():
    """Create the review sentiment classifier; Streamlit caches it across reruns.

    Returns:
        A Hugging Face ``sentiment-analysis`` pipeline backed by the
        ``nlptown/bert-base-multilingual-uncased-sentiment`` checkpoint.
    """
    checkpoint = "nlptown/bert-base-multilingual-uncased-sentiment"
    classifier = pipeline("sentiment-analysis", model=checkpoint)
    return classifier


bert = get_bert()

def analyze_sentiment(text):
    """Classify a product review as positive, neutral or negative.

    Runs the module-level ``bert`` pipeline and parses the leading integer of
    the predicted label as the star rating. Labels are expected to look like
    ``"4 stars"``.

    Args:
        text: Review text.

    Returns:
        ``(label, stars)``. ``label`` is "Positif" for 4 stars or more,
        "Négatif" for 2 stars or fewer, and "Neutre" otherwise. ``stars`` is
        the parsed int.
    """
    # The character slice keeps tokenisation cheap for very long input, but
    # 512 characters can still exceed 512 *tokens* (CJK, emoji, punctuation
    # runs, plus the special tokens). truncation=True makes the tokenizer
    # enforce the model's maximum length instead of failing at inference.
    res = bert(text[:512], truncation=True)[0]
    stars = int(res["label"].split()[0])
    if stars >= 4:
        return "Positif", stars
    if stars <= 2:
        return "Négatif", stars
    return "Neutre", stars

# ================= UI =================
def _show_product_image(filename, width, placeholder_url):
    """Render ``IMG_DIR/<filename>`` at ``width`` px, or the placeholder URL.

    ``filename`` goes through ``str()``, so a missing CSV cell (NaN) yields a
    non-existent path instead of a TypeError from ``os.path.join``. Files
    that exist but cannot be decoded also fall back to the placeholder
    rather than aborting the page.
    """
    path = os.path.join(IMG_DIR, str(filename))
    picture = None
    if os.path.exists(path):
        try:
            # copy() loads the pixels into a new image, so the file handle
            # opened by Image.open can be closed right away.
            with Image.open(path) as img:
                picture = img.copy()
        except OSError:
            picture = None
    if picture is not None:
        st.image(picture, width=width)
    else:
        st.image(placeholder_url, width=width)


def _render_recommendations_tab():
    """Product picker, product card, and the 'similar products' grid."""
    if products.empty:
        # st.selectbox would return None and int(None) would crash.
        st.info("Aucun produit dans le catalogue.")
        return

    pid = int(st.selectbox("Sélectionne un produit", products["id"]))
    row = products[products["id"] == pid].iloc[0]

    st.markdown(f"### {row['name']}")
    _show_product_image(row["image"], 320, "https://via.placeholder.com/320x240")
    st.caption(f"Catégorie : {row['category']}")

    if not st.button("Produits similaires"):
        return
    recs = recommend(str(pid))
    if not recs:
        st.info("Aucun produit réellement similaire dans cette catégorie.")
        return

    cols = st.columns(3)
    for i, rid in enumerate(recs):
        r = products[products["id"] == int(rid)].iloc[0]
        with cols[i % 3]:
            _show_product_image(r["image"], 200, "https://via.placeholder.com/200x150")
            # Names come from the CSV and this markup is rendered with
            # unsafe_allow_html=True, so escape them before interpolation.
            safe_name = html.escape(str(r["name"]))
            st.markdown(
                f"<p style='text-align:center; font-weight:bold'>{safe_name}</p>",
                unsafe_allow_html=True,
            )


def _render_review_tab():
    """Free-text review box with the predicted sentiment underneath."""
    txt = st.text_area("Écris un avis produit")
    if txt.strip():
        label, stars = analyze_sentiment(txt)
        st.success(f"{label} – {stars} étoiles")


tab1, tab2 = st.tabs(["Recommandations visuelles", "Analyse d’avis"])

with tab1:
    _render_recommendations_tab()

with tab2:
    _render_review_tab()

st.caption("DeepShop Duo 2025 — APP FINALE COMPLÈTE")
