
import streamlit as st
import pandas as pd
import json
import os
import numpy as np
from PIL import Image
from datetime import datetime
import base64
from io import BytesIO

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from transformers import pipeline

from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans

from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors

# ==================== STYLE ====================
st.set_page_config(page_title="SportShop IA – Championne 2025", layout="wide")

# Custom CSS classes used by the HTML fragments rendered later in this script
# (big-title, subtitle, gold-border, like-counter, champion-badge, dl-tag).
_APP_CSS = """
<style>
    .big-title {font-size: 62px !important; font-weight:bold; text-align:center; color:#FF3366; text-shadow: 0 0 15px gold;}
    .subtitle {text-align:center; color:#555; font-size:22px; margin-top:-15px;}
    .gold-border {border: 8px solid gold; border-radius: 30px; padding: 25px;
                  background: linear-gradient(45deg, #FFFDE7, #FFD700); box-shadow: 0 15px 40px rgba(255,215,0,0.5); margin: 30px 0;}
    .like-counter {font-size: 36px; color: #FF1493; font-weight: bold;}
    .champion-badge {background: gold; color: black; padding: 10px 20px; margin-bottom: 30px; border-radius: 30px; font-weight: bold; font-size:20px;}
    .dl-tag {background:#8A2BE2; color:white; padding:5px 12px; border-radius:12px; font-size:14px;}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)

# ==================== CHARGEMENT DONNÉES ====================
# Product catalogue. Columns read elsewhere in this script: "id", "name",
# "category", "image" and, optionally, "price".
products = pd.read_csv("data/products.csv", encoding="utf-8")

def load_json_safe(path, default=None):
    """Load a JSON file, returning *default* when it cannot be read.

    The file is first decoded as UTF-8. If that fails (undecodable bytes or
    invalid JSON), its raw bytes are re-decoded as UTF-8 with invalid
    sequences dropped, which salvages files containing a few stray bytes.

    Args:
        path: Path of the JSON file.
        default: Value returned when the file is missing or unreadable.
            ``None`` (the default) means "a fresh empty dict".

    Returns:
        The parsed JSON value, or *default*.
    """
    if default is None:
        # A fresh dict per call: a shared ``{}`` default would leak mutations
        # made by one caller into every later call.
        default = {}
    if not os.path.exists(path):
        return default
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        pass
    try:
        # latin-1 maps every byte 1:1, so re-encoding recovers the raw bytes;
        # they are then decoded as UTF-8 while ignoring invalid sequences.
        with open(path, "r", encoding="latin-1") as f:
            raw = f.read().encode("latin-1")
        return json.loads(raw.decode("utf-8", errors="ignore"))
    except (OSError, ValueError):
        return default

# Persisted counters, keyed by product id as a string:
# clicks -> like count, reviews -> list of {"stars", "comment"} dicts.
clicks = load_json_safe("data/clicks.json", {})
reviews = load_json_safe("data/reviews.json", {})

# Per-session shopping cart: product id -> quantity.
if "cart" not in st.session_state:
    st.session_state["cart"] = {}

# ==================== CHARGEMENT MODÈLES DEEP LEARNING ====================
with st.spinner("Chargement des modèles IA (ResNet50 + BERT + LSTM)… 8-15s max"):

    # 1. ResNet50 used as an image feature extractor.
    @st.cache_resource
    def get_resnet():
        """Return ImageNet-pretrained ResNet50 without its last child module, in eval mode.

        Dropping the last child (the classification layer in torchvision's
        ResNet) leaves a network whose output is used as an image embedding.
        """
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained=True``; IMAGENET1K_V1 is
            # the checkpoint that flag selected, so embeddings are unchanged.
            model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            model = models.resnet50(pretrained=True)
        return nn.Sequential(*list(model.children())[:-1]).eval()
    resnet = get_resnet()

    # 2. Multilingual BERT sentiment pipeline.
    @st.cache_resource
    def get_bert():
        """Return a cached Hugging Face sentiment-analysis pipeline (top label only)."""
        return pipeline("sentiment-analysis",
                        model="nlptown/bert-base-multilingual-uncased-sentiment",
                        return_all_scores=False)
    bert = get_bert()

    # 3. LSTM forecasting likes from a sliding window of past values.
    @st.cache_resource
    def get_lstm():
        """Train a small LSTM on synthetic per-product like histories.

        For each product, a 30-step random walk with drift is generated,
        starting at its current like count and clipped at 0. All values are
        min-max scaled with ONE shared scaler. The model learns to predict
        step ``t`` from the 10 preceding steps.

        Returns:
            (model, scaler, history): the trained model and fitted MinMaxScaler
            (both None when there is no training window), and a dict mapping
            product id (str) to its synthetic history (list of floats).
        """
        class LSTMModel(nn.Module):
            """2-layer LSTM (hidden size 64) with a linear head on the last step."""
            def __init__(self):
                super().__init__()
                self.lstm = nn.LSTM(1, 64, 2, batch_first=True)
                self.fc = nn.Linear(64, 1)
            def forward(self, x):
                out, _ = self.lstm(x)
                return self.fc(out[:, -1])

        window = 10
        # Private RNG seeded like the former ``np.random.seed(42)``: identical
        # draws, without reseeding numpy's global random state as a side effect.
        rng = np.random.RandomState(42)
        history = {}
        for pid in products["id"].astype(str):
            base = clicks.get(pid, 0)
            trend = base + np.cumsum(rng.normal(0.4, 0.3, 30))
            history[pid] = np.clip(trend, a_min=0, a_max=None).tolist()

        usable = [lst for lst in history.values() if len(lst) >= window]
        if not usable:
            return None, None, history

        # Fit the scaler ONCE on every value. Refitting it per product left the
        # returned scaler fitted on the last product only, so inference (which
        # reuses this scaler for every product) was mis-scaled.
        scaler = MinMaxScaler().fit(np.concatenate(usable).reshape(-1, 1))

        seqs, targets = [], []
        for lst in usable:
            scaled = scaler.transform(np.array(lst).reshape(-1, 1))
            for i in range(window, len(scaled)):
                seqs.append(scaled[i - window:i])
                targets.append(scaled[i])

        if not seqs:
            return None, None, history

        X = torch.tensor(np.array(seqs), dtype=torch.float32)
        y = torch.tensor(np.array(targets), dtype=torch.float32)

        model = LSTMModel()
        opt = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = nn.MSELoss()

        # Full-batch training for a fixed number of epochs.
        model.train()
        for _ in range(400):
            opt.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            opt.step()

        model.eval()
        return model, scaler, history

    lstm_model, scaler, likes_history = get_lstm()

# ==================== FEATURES VISUELLES ====================
# Standard ImageNet channel statistics, matching the pretrained ResNet50 weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image preprocessing: resize to 224x224, convert to a tensor, normalise.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

@st.cache_resource
def get_features():
    """Embed every product image with ResNet50 and index the embeddings.

    Returns:
        (embeddings, knn, ids, clusters):
        - embeddings: product id (str) -> L2-normalised feature vector, or None
          when the image is missing, unreadable, or yields a zero-norm feature;
        - knn: cosine NearestNeighbors fitted on the valid vectors (None if < 2);
        - ids: product ids aligned with the rows used to fit ``knn``;
        - clusters: product id -> KMeans label (at most 6 clusters).
    """
    embeddings = {}
    with torch.no_grad():
        for _, row in products.iterrows():
            pid = str(row["id"])
            path = os.path.join("images", row["image"])
            if not os.path.exists(path):
                embeddings[pid] = None
                continue
            try:
                # ``with`` releases the file handle PIL keeps open after opening.
                with Image.open(path) as img:
                    tensor = transform(img.convert("RGB")).unsqueeze(0)
            except OSError:
                # Corrupt/unsupported image (PIL's UnidentifiedImageError is an
                # OSError): skip it instead of crashing the whole app.
                embeddings[pid] = None
                continue
            feat = resnet(tensor).flatten().numpy()
            norm = np.linalg.norm(feat)
            embeddings[pid] = feat / norm if norm > 0 else None

    valid = [(k, v) for k, v in embeddings.items() if v is not None]
    knn, ids, clusters = None, [], {}
    if len(valid) >= 2:
        vecs = np.array([v for _, v in valid])
        ids = [k for k, _ in valid]
        # kneighbors() raises when n_neighbors exceeds the number of fitted
        # samples, so cap the default for small catalogues.
        knn = NearestNeighbors(n_neighbors=min(8, len(valid)), metric="cosine").fit(vecs)
        n = min(6, len(valid))
        labels = KMeans(n_clusters=n, random_state=42, n_init="auto").fit_predict(vecs)
        clusters = {pid: int(label) for pid, label in zip(ids, labels)}
    return embeddings, knn, ids, clusters

emb, knn_model, knn_ids, deep_clusters = get_features()

# ==================== FONCTIONS ====================
def recommend(pid: str, k=5):
    """Return up to *k* product ids whose image embeddings are closest to *pid*'s.

    Returns an empty list when ``k <= 0``, when no KNN index exists, or when
    *pid* has no embedding.
    """
    if k <= 0 or knn_model is None or emb.get(pid) is None:
        return []
    vec = emb[pid].reshape(1, -1)
    # Request one extra neighbour because the query product is normally its own
    # nearest match. Never request more than the index holds. Relying on the
    # fit-time default capped results at 7 even when k=8.
    n_neighbors = min(k + 1, knn_model.n_samples_fit_)
    _, idx = knn_model.kneighbors(vec, n_neighbors=n_neighbors)
    return [knn_ids[i] for i in idx[0] if knn_ids[i] != pid][:k]

def bert_sentiment(text: str):
    """Classify *text* with the BERT star-rating pipeline.

    Returns:
        (label, stars): "Positif" for 4+ stars, "Négatif" for 2 or fewer,
        otherwise "Neutre". Blank text short-circuits to ("Neutre", 3)
        without calling the model.
    """
    if text.strip() == "":
        return "Neutre", 3
    # Only the first 512 characters are sent to the model.
    top = bert(text[:512])[0]
    # The label's first whitespace-separated token is parsed as the rating.
    n_stars = int(top["label"].split()[0])
    if n_stars >= 4:
        label = "Positif"
    elif n_stars <= 2:
        label = "Négatif"
    else:
        label = "Neutre"
    return label, n_stars

def predict_future_likes(pid: str, days=7):
    """Forecast product *pid*'s like count *days* steps ahead with the LSTM.

    The last 10 values of the product's history are scaled. The model is then
    rolled forward autoregressively: each prediction is appended to the window
    and the oldest value is dropped.

    Returns:
        The forecast rounded to the nearest non-negative int. Returns the
        current like count from ``clicks`` when there is no model, not enough
        history, or ``days < 1``.
    """
    # days < 1 previously crashed with IndexError on an empty prediction list.
    if days < 1 or lstm_model is None or pid not in likes_history:
        return clicks.get(pid, 0)
    hist = np.array(likes_history[pid][-10:]).reshape(-1, 1)
    if len(hist) < 10:
        return clicks.get(pid, 0)
    scaled = scaler.transform(hist).astype(np.float32)
    window = torch.tensor(scaled).unsqueeze(0)  # shape: (1, 10, 1)

    with torch.no_grad():
        for _ in range(days):
            pred = lstm_model(window)
            # Slide the window: drop the oldest step, append the prediction.
            window = torch.cat([window[:, 1:, :], pred.reshape(1, 1, 1)], dim=1)
    final = scaler.inverse_transform(np.array([[pred.item()]]))[0][0]
    # Round rather than truncate: 9.97 predicted likes should read as 10, not 9.
    return max(0, int(round(final)))

def add_to_cart(pid):
    """Increase product *pid*'s quantity in the session cart by one (starting from 0)."""
    cart = st.session_state.cart
    cart[pid] = cart.get(pid, 0) + 1
def get_avg_stars(pid):
    """Return product *pid*'s mean star rating rounded to 2 decimals, or 0.0 if unrated."""
    ratings = [review["stars"] for review in reviews.get(pid, [])]
    if not ratings:
        return 0.0
    return round(sum(ratings) / len(ratings), 2)

def get_champion():
    """Return the most-liked product that still exists in the catalogue, or None.

    Returns:
        A dict with keys "id", "name", "cat", "img", "likes", "stars", or None
        when there are no likes or no liked id matches a catalogue row.
    """
    known_ids = set(products["id"].astype(str))
    # Ignore likes recorded for ids absent from products.csv. Looking them up
    # raised IndexError at start-up and took the whole app down.
    candidates = {pid: n for pid, n in clicks.items() if pid in known_ids}
    if not candidates:
        return None
    best = max(candidates, key=candidates.get)
    row = products[products["id"].astype(str) == best].iloc[0]
    return {"id": best, "name": row["name"], "cat": row["category"], "img": row["image"],
            "likes": clicks[best], "stars": get_avg_stars(best)}
champion = get_champion()

def stars(n):
    """Render rating *n* as five star glyphs plus its value, e.g. "⭐⭐⭐⭐☆ 4.3/5".

    The value is clamped to [0, 5]. Glyphs use the value rounded to the
    nearest integer, while the label keeps one decimal. Previously the value
    was rounded before formatting, so an average of 4.33 showed as "4.0/5".
    """
    value = max(0.0, min(5.0, float(n)))
    filled = int(round(value))
    return "⭐" * filled + "☆" * (5 - filled) + f" {value:.1f}/5"

def generate_pdf():
    """Build a PDF table of every review: product, rating, comment, BERT sentiment, date.

    Reviews whose product id is no longer in the catalogue are skipped.

    Returns:
        The PDF document as bytes.
    """
    # Paragraph parses XML-like markup, so user text must be escaped.
    from xml.sax.saxutils import escape

    buffer = BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=A4)
    styles = getSampleStyleSheet()
    cell_style = styles["BodyText"]
    story = [Paragraph("SportShop IA – Avis Clients", styles["Title"]), Spacer(1, 20)]
    # Reviews carry no timestamp, so the "Date" column holds the export date.
    today = datetime.now().strftime("%d/%m/%Y")
    names = {str(pid): name for pid, name in zip(products["id"], products["name"])}
    data = [["Produit", "Note", "Commentaire", "Sentiment", "Date"]]
    for pid, lst in reviews.items():
        name = names.get(str(pid))
        if name is None:
            continue  # product removed from the catalogue
        for r in lst:
            sent, _ = bert_sentiment(r["comment"])
            data.append([
                Paragraph(escape(str(name)), cell_style),
                f"{r['stars']}/5",
                # Wrapping in a Paragraph lets long comments wrap inside the
                # cell; escape() stops <, > or & from breaking the parser.
                Paragraph(escape(r["comment"]), cell_style),
                sent,
                today,
            ])
    # The nominal widths (550pt) exceed A4's printable width with default
    # margins, so scale them proportionally to the frame width.
    base_widths = [120, 50, 220, 80, 80]
    scale = doc.width / sum(base_widths)
    table = Table(data, colWidths=[w * scale for w in base_widths], repeatRows=1)
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.gold),
        ("GRID", (0, 0), (-1, -1), 0.5, colors.grey),
        ("VALIGN", (0, 0), (-1, -1), "TOP"),
    ]))
    story.append(table)
    doc.build(story)
    return buffer.getvalue()


# ==================== HEADER ====================
import html  # escape catalogue text before embedding it in raw HTML

st.markdown("<h1 class='big-title'>SportShop IA</h1>", unsafe_allow_html=True)
st.markdown("<p class='subtitle'>La boutique la plus intelligente de l'univers</p>", unsafe_allow_html=True)

if champion:
    st.markdown("<div class='champion-badge'>PRODUIT LE PLUS AIMÉ</div>", unsafe_allow_html=True)

    c1, c2 = st.columns([1, 4])

    with c1:
        champion_img = os.path.join("images", champion["img"])
        # Same placeholder fallback as the catalogue, so a missing file does
        # not raise on every page.
        st.image(
            champion_img if os.path.exists(champion_img) else "https://via.placeholder.com/300",
            width=None,  # no explicit pixel width
        )
        st.markdown(
            f"<p class='like-counter'>❤️ {champion['likes']} likes</p>",
            unsafe_allow_html=True
        )

    with c2:
        # Each st.markdown call is rendered as its own element. An opening
        # <div> and a closing </div> sent in separate calls therefore never
        # wrap the content between them. Emit the bordered card as one fragment.
        st.markdown(
            "<div class='gold-border'>"
            f"<h3>{html.escape(str(champion['name']))}</h3>"
            f"<p><b>Catégorie :</b> {html.escape(str(champion['cat']))}</p>"
            f"<p>{stars(champion['stars'])}</p>"
            "</div>",
            unsafe_allow_html=True,
        )

# ==================== MENU ====================
# Sidebar navigation: the selected label drives the page dispatch below.
_PAGES = [
    "Catalogue",
    "Panier",
    "Avis",
    "Reco CNN",
    "Sentiment BERT",
    "Clusters",
    "Prédiction Likes (LSTM)",
    "Exporter",
    "Admin",
]
page = st.sidebar.radio("Navigation", _PAGES)

import streamlit as st
import os
import json
import pandas as pd
# (torch, plt and the other imports are already done at the top of the file.)

# The header and the "most liked product" banner are rendered once, above the
# navigation menu. They must not be rendered a second time here, otherwise
# they appear twice on every page.

# ==================== PAGES ====================
if page == "Catalogue":
    st.subheader("Catalogue")
    search = st.text_input("Rechercher")
    df = products.copy()
    if champion:
        # The champion already has its own banner at the top of the page.
        df = df[df["id"] != int(champion["id"])]
    if search:
        # regex=False: match the query literally, so input such as "(" or "+"
        # no longer raises a regular-expression error.
        df = df[df["name"].str.contains(search, case=False, na=False, regex=False) |
                df["category"].str.contains(search, case=False, na=False, regex=False)]

    cols = st.columns(4)

    # Place cards by display position, not DataFrame index. Once rows are
    # filtered out the index has gaps, and ``index % 4`` stacked cards unevenly.
    for pos, (_, row) in enumerate(df.iterrows()):
        with cols[pos % 4]:
            pid = str(row["id"])
            likes = clicks.get(pid, 0)
            avg = get_avg_stars(pid)
            img = os.path.join("images", row["image"])

            st.image(
                img if os.path.exists(img) else "https://via.placeholder.com/300",
                width=None  # no explicit pixel width
            )

            st.markdown(f"### {row['name']}")
            st.caption(f"ID: `{pid}` • {row['category']}")
            st.markdown(f"**{stars(avg)} ❤️ {likes} likes**")

            c1, c2 = st.columns(2)
            with c1:
                if st.button("Like", key=f"l{pid}"):
                    clicks[pid] = likes + 1
                    with open("data/clicks.json", "w", encoding="utf-8") as f:
                        json.dump(clicks, f, ensure_ascii=False, indent=2)
                    st.rerun()
            with c2:
                if st.button("Panier", key=f"c{pid}"):
                    add_to_cart(pid)
                    st.success("Ajouté !")

            with st.expander("Avis"):
                note = st.slider("Note", 1, 5, 5, key=f"n{pid}")
                com = st.text_area("Commentaire", key=f"t{pid}")
                if st.button("Envoyer", key=f"s{pid}"):
                    reviews.setdefault(pid, []).append({"stars": note, "comment": com})
                    with open("data/reviews.json", "w", encoding="utf-8") as f:
                        json.dump(reviews, f, ensure_ascii=False, indent=2)
                    st.success("Merci !")
                    st.rerun()

            with st.expander("Recommandations Deep"):
                st.markdown("<span class='dl-tag'>Partie 2 – CNN</span>", unsafe_allow_html=True)
                for rid in recommend(pid, 5):
                    r = products[products["id"] == int(rid)].iloc[0]
                    st.image(os.path.join("images", r["image"]), width=90)
                    st.caption(r["name"])

elif page == "Panier":
    st.subheader("Panier")
    if not st.session_state.cart:
        st.info("Panier vide")
    else:
        total = 0
        for pid, qty in st.session_state.cart.items():
            p = products[products["id"] == int(pid)].iloc[0]
            price = float(p.get("price", 59.90))
            subtotal = price * qty
            total += subtotal
            c1, c2, c3 = st.columns([1, 4, 2])
            with c1:
                st.image(os.path.join("images", p["image"]), width=80)
            with c2:
                st.write(f"**{p['name']}** × {qty}")
            with c3:
                st.write(f"{subtotal:.2f} €")
        st.markdown("---")
        st.markdown(f"### Total : **{total:.2f} €**")

elif page == "Avis":
    st.subheader("Tous les avis")
    for pid, lst in reviews.items():
        p = products[products["id"] == int(pid)].iloc[0]
        st.markdown(f"### {p['name']}")
        for r in lst:
            sent, _ = bert_sentiment(r["comment"])
            st.markdown(f"**{stars(r['stars'])}** {sent}")
            st.caption(r["comment"])
            st.markdown("---")

elif page == "Reco CNN":
    st.header("Partie 2 – Réseaux de convolution (ResNet50)")
    st.markdown("### Recommandations visuelles basées sur les features extraites par ResNet50")

    pid_select = st.selectbox("Choisis un produit pour voir les recommandations visuelles", products["id"])
    pid = str(pid_select)

    selected_row = products[products["id"] == int(pid)].iloc[0]
    st.image(os.path.join("images", selected_row["image"]), width=200, caption=f"Produit sélectionné : {selected_row['name']}")

    if st.button("Lancer les recommandations visuelles (Deep Learning)"):
        with st.spinner("Calcul des similarités visuelles..."):
            recommendations = recommend(pid, k=8)

        if not recommendations:
            st.warning("Pas assez d'images valides pour faire des recommandations.")
        else:
            st.success("Voici les 8 produits les plus similaires visuellement !")
            cols = st.columns(4)
            for i, rec_pid in enumerate(recommendations):
                with cols[i % 4]:
                    rec_row = products[products["id"] == int(rec_pid)].iloc[0]
                    img_path = os.path.join("images", rec_row["image"])
                    st.image(img_path, width=None)  # Image adapte à la colonne
                    st.caption(f"**{rec_row['name']}**")
                    st.caption(f"ID: {rec_pid}")

elif page == "Sentiment BERT":
    st.header("Partie 3 – NLP Deep Learning")
    txt = st.text_area("Commentaire", height=120)
    if txt.strip():
        label, s = bert_sentiment(txt)
        if label == "Positif":
            st.success(f"{label} – {s} étoiles")
        elif label == "Négatif":
            st.error(f"{label} – {s} étoiles")
        else:
            st.info(f"{label} – {s} étoiles")

elif page == "Clusters":
    st.header("Partie 5 – Deep non supervisé")
    for c in sorted(set(deep_clusters.values())):
        st.subheader(f"Cluster {c+1}")
        pids = [p for p, l in deep_clusters.items() if l == c]
        cols = st.columns(6)
        for i, pid in enumerate(pids[:18]):
            with cols[i % 6]:
                r = products[products["id"] == int(pid)].iloc[0]
                st.image(os.path.join("images", r["image"]), width=None)  # Image adapte à la colonne
                st.caption(r["name"])

elif page == "Prédiction Likes (LSTM)":
    st.header("Partie 4 – Réseaux récurrents (LSTM)")
    st.markdown("### Prédiction des likes dans 7 jours")
    pid = st.selectbox("Produit", products["id"])
    current = clicks.get(str(pid), 0)
    pred = predict_future_likes(str(pid), 7)
    c1, c2 = st.columns(2)
    with c1:
        st.metric("Likes actuels", current)
    with c2:
        st.metric("Prédiction J+7", pred, delta=pred - current)

    if str(pid) in likes_history and len(likes_history[str(pid)]) >= 10:
        past = likes_history[str(pid)]
        future = []
        seq = torch.tensor(scaler.transform(np.array(past[-10:]).reshape(-1, 1)), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            temp_seq = seq.clone()
            for _ in range(7):
                p = lstm_model(temp_seq)
                future.append(p.item())
                new_val = torch.tensor([[[p.item()]]])
                temp_seq = torch.cat([temp_seq[:, 1:, :], new_val], dim=1)
        future = scaler.inverse_transform(np.array(future).reshape(-1, 1)).flatten()
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(range(-len(past), 0), past, 'o-', label="Historique")
        ax.plot(range(1, 8), future, 'o-', color='red', label="Prédiction LSTM")
        ax.set_title(f"Produit {pid}")
        ax.legend()
        st.pyplot(fig)

elif page == "Exporter":
    c1, c2 = st.columns(2)
    with c1:
        if st.button("PDF"):
            b64 = base64.b64encode(generate_pdf()).decode()
            href = f'<a href="data:application/pdf;base64,{b64}" download="avis.pdf">Télécharger PDF</a>'
            st.markdown(href, unsafe_allow_html=True)
    with c2:
        if st.button("CSV"):
            data = [{"Produit": products[products["id"] == int(p)].iloc[0]["name"],
                     "Note": r["stars"], "Commentaire": r["comment"]} for p, l in reviews.items() for r in l]
            df = pd.DataFrame(data)
            csv = df.to_csv(index=False).encode()
            b64 = base64.b64encode(csv).decode()
            href = f'<a href="data:text/csv;base64,{b64}" download="avis.csv">Télécharger CSV</a>'
            st.markdown(href, unsafe_allow_html=True)

elif page == "Admin":
    st.title("Administration SportShop IA")
    st.markdown("Gérez ici les likes, avis et données internes de l’application.")

    # ==============================
    # 1️⃣ Gestion des LIKES
    # ==============================
    st.header("❤️ Gestion des Likes")

    st.write("### Total des likes :", sum(clicks.values()))

    # Liste des likes
    df_likes = pd.DataFrame([
        {"ID": pid, "Produit": products[products["id"] == int(pid)].iloc[0]["name"], "Likes": likes}
        for pid, likes in clicks.items()
    ])
    st.dataframe(df_likes)

    st.subheader("Modifier les likes d’un produit")
    pid_like = st.selectbox("Choisir un produit", products["id"])
    new_like = st.number_input("Nouveau nombre de likes", 0, 10000, clicks.get(str(pid_like), 0))

    if st.button(" Sauvegarder les likes"):
        clicks[str(pid_like)] = int(new_like)
        json.dump(clicks, open("data/clicks.json", "w"), indent=2, ensure_ascii=False)
        st.success("Likes mis à jour.")
        st.rerun()

    if st.button(" Réinitialiser TOUS les likes"):
        clicks.clear()
        json.dump(clicks, open("data/clicks.json", "w"), indent=2, ensure_ascii=False)
        st.success("Tous les likes ont été remis à zéro.")
        st.rerun()

    st.markdown("---")

    # ==============================
    # 2️⃣ Gestion des AVIS
    # ==============================
    st.header(" Gestion des Avis")

    total_reviews = sum(len(lst) for lst in reviews.values())
    st.write(f"### Total des avis : {total_reviews}")

    pid_review = st.selectbox("Choisir un produit pour gérer les avis", products["id"])
    pid_str = str(pid_review)
    lst = reviews.get(pid_str, [])

    if not lst:
        st.info("Aucun avis pour ce produit.")
    else:
        for idx, r in enumerate(lst):
            st.markdown(f"#### Avis #{idx + 1}")
            st.write(f"⭐ Note : {r['stars']}")
            st.write(f"💬 Commentaire : {r['comment']}")

            new_stars = st.slider(f"Modifier étoile #{idx}", 1, 5, r["stars"], key=f"mod_s_{idx}")
            new_comment = st.text_area(f"Modifier commentaire #{idx}", r["comment"], key=f"mod_c_{idx}")

            col1, col2 = st.columns(2)

            with col1:
                if st.button(f" Sauver avis #{idx}", key=f"save_rev_{idx}"):
                    r["stars"] = new_stars
                    r["comment"] = new_comment
                    json.dump(reviews, open("data/reviews.json", "w"), indent=2, ensure_ascii=False)
                    st.success("Avis modifié.")
                    st.rerun()

            with col2:
                if st.button(f"🗑 Supprimer avis #{idx}", key=f"del_rev_{idx}"):
                    lst.pop(idx)
                    reviews[pid_str] = lst
                    json.dump(reviews, open("data/reviews.json", "w"), indent=2, ensure_ascii=False)
                    st.success("Avis supprimé.")
                    st.rerun()

    if st.button(" Supprimer TOUS les avis de ce produit"):
        reviews[pid_str] = []
        json.dump(reviews, open("data/reviews.json", "w"), indent=2, ensure_ascii=False)
        st.success("Tous les avis de ce produit ont été supprimés.")
        st.rerun()

    if st.button(" Réinitialiser TOUTES les reviews"):
        reviews.clear()
        json.dump(reviews, open("data/reviews.json", "w"), indent=2, ensure_ascii=False)
        st.success("Toutes les reviews ont été supprimées.")
        st.rerun()

    st.markdown("---")

    # ==============================
    # 3️⃣ EXPORT DATA
    # ==============================
    st.header(" Export des données")

    # Export likes
    if st.button(" Exporter les likes en CSV"):
        df = pd.DataFrame([
            {"product_id": pid, "likes": like}
            for pid, like in clicks.items()
        ])
        st.download_button("Télécharger likes.csv", df.to_csv(index=False).encode(), "likes.csv")

    # Export reviews
    if st.button(" Exporter les avis en CSV"):
        rows = []
        for pid, lst in reviews.items():
            for r in lst:
                rows.append({"product_id": pid, "stars": r["stars"], "comment": r["comment"]})

        df = pd.DataFrame(rows)
        st.download_button("Télécharger reviews.csv", df.to_csv(index=False).encode(), "reviews.csv")


# Sidebar footer, rendered after the page dispatch on every script run.
st.sidebar.success("Projet Deep Learning")