import numpy
import torch
import torch.nn as nn
import csv
import visualtorch
import matplotlib.pyplot as plt
from IPython.display import display, Markdown, Math
import sympy as sp
from sympy import symbols, expand, simplify, latex


def normalize(X, y, X_mean=None, X_std=None, y_mean=None, y_std=None):
    """
    Standardize features and target by subtracting a mean and dividing by a std.

    When statistics are not supplied they are computed from X and y. Pass
    previously computed statistics to apply the same transform to other data
    (e.g. a test split).

    Args:
        X: feature tensor; statistics are taken over dim 0 (one per column).
        y: target tensor; a single scalar mean/std is used.
        X_mean, X_std: feature statistics to reuse. If X_mean is None, both
            are recomputed from X.
        y_mean, y_std: target statistics to reuse. If y_mean is None, both
            are recomputed from y.

    Returns:
        (X_normalized, y_normalized, X_mean, X_std, y_mean, y_std)
    """
    # `is None` is an identity check; `== None` on an array-like can be
    # evaluated elementwise and make the `if` ambiguous.
    if X_mean is None:
        X_mean = X.mean(dim=0, keepdim=True)
        X_std = X.std(dim=0, keepdim=True)
    # Checked separately so supplying only the X statistics still works.
    if y_mean is None:
        y_mean = y.mean()
        y_std = y.std()

    X_normalized = (X - X_mean) / X_std
    y_normalized = (y - y_mean) / y_std

    return X_normalized, y_normalized, X_mean, X_std, y_mean, y_std

def prepare_data(X, y, X_test, y_test, use_normalization=True, feature_names=None):
    """
    Prepare data for training, optionally applying normalization.

    Statistics are always computed from the training split (X, y) and applied
    unchanged to the test split.

    Args:
        X, y: Training data
        X_test, y_test: Test data
        use_normalization: Either:
            - True: normalize all features and the target
            - False: don't normalize
            - list of feature names: normalize only those feature columns (the
              target is still normalized); names not present in
              feature_names are ignored
        feature_names: List of all feature names, in the column order of X
            (required if use_normalization is a list)

    Returns:
        X_train, y_train, X_test_data, y_test_data, norm_stats
        where norm_stats is a dict with normalization parameters or None.
        In the list case 'X_mean'/'X_std' cover every column (0 and 1 for the
        columns left untouched) and 'normalized_indices' lists the
        normalized column indices.

    Raises:
        ValueError: if use_normalization is a list and feature_names is None,
            or if use_normalization is not True, False, or a list.
    """
    if use_normalization is True:
        # Normalize all features
        X_train, y_train, X_mean, X_std, y_mean, y_std = normalize(X, y)
        X_test_data, y_test_data, _, _, _, _ = normalize(X_test, y_test, X_mean, X_std, y_mean, y_std)
        norm_stats = {'X_mean': X_mean, 'X_std': X_std, 'y_mean': y_mean, 'y_std': y_std}
    elif use_normalization is False:
        # Don't normalize
        X_train, y_train = X, y
        X_test_data, y_test_data = X_test, y_test
        norm_stats = None
    elif isinstance(use_normalization, list):
        # Normalize only specific features
        if feature_names is None:
            raise ValueError("feature_names must be provided when use_normalization is a list")

        # Find indices of features to normalize
        normalize_indices = [i for i, name in enumerate(feature_names) if name in use_normalization]

        # Clone so the caller's tensors are not modified in place
        X_train = X.clone()
        X_test_data = X_test.clone()

        # Compute normalization stats for selected features only
        X_mean = X[:, normalize_indices].mean(dim=0, keepdim=True)
        X_std = X[:, normalize_indices].std(dim=0, keepdim=True)

        # Normalize selected features
        X_train[:, normalize_indices] = (X[:, normalize_indices] - X_mean) / X_std
        X_test_data[:, normalize_indices] = (X_test[:, normalize_indices] - X_mean) / X_std

        # Normalize y
        y_mean = y.mean()
        y_std = y.std()
        y_train = (y - y_mean) / y_std
        y_test_data = (y_test - y_mean) / y_std

        # Store stats with full dimensionality for compatibility. Match X's
        # dtype/device so the stats are not silently downcast to float32 or
        # left on the CPU when X lives elsewhere.
        X_mean_full = torch.zeros((1, X.shape[1]), dtype=X.dtype, device=X.device)
        X_std_full = torch.ones((1, X.shape[1]), dtype=X.dtype, device=X.device)
        X_mean_full[:, normalize_indices] = X_mean
        X_std_full[:, normalize_indices] = X_std

        norm_stats = {
            'X_mean': X_mean_full,
            'X_std': X_std_full,
            'y_mean': y_mean,
            'y_std': y_std,
            'normalized_indices': normalize_indices
        }
    else:
        raise ValueError("use_normalization must be True, False, or a list of feature names")

    return X_train, y_train, X_test_data, y_test_data, norm_stats

def compute_loss_for_display(y_pred, y_true, norm_stats):
    """
    Return the mean absolute error between predictions and targets as a float.

    If norm_stats is given, both tensors are first mapped back to the original
    target scale via ``value * y_std + y_mean``. Otherwise they are assumed to
    be in original units already.
    """
    if norm_stats is None:
        # Nothing to undo: compare the values directly.
        return torch.abs(y_pred - y_true).mean().item()

    scale = norm_stats['y_std']
    shift = norm_stats['y_mean']
    pred_original = y_pred * scale + shift
    true_original = y_true * scale + shift
    return torch.abs(pred_original - true_original).mean().item()

def _add_input_nodes(dot, feature_names):
    """Add one green node per input feature on a shared rank and return the node ids."""
    node_ids = [f'x{i}' for i in range(len(feature_names))]
    with dot.subgraph() as s:
        s.attr(rank='same')
        for node_id, name in zip(node_ids, feature_names):
            s.node(node_id, name, fillcolor='lightgreen')
    return node_ids


def _draw_single_linear(dot, model, feature_names, numeric):
    """Draw a one-output nn.Linear: feature nodes -> 'Price' node, plus a dashed bias box."""
    # reshape(-1) instead of squeeze(): squeeze() turns a (1, 1) weight into a
    # 0-dim tensor, which cannot be indexed when there is a single feature.
    weights = model.weight.data.reshape(-1)

    input_nodes = _add_input_nodes(dot, feature_names)
    dot.node('output', 'Price', fillcolor='orange')

    # Edges with weight labels
    for i, node_id in enumerate(input_nodes):
        if numeric:
            label = f'{weights[i].item():+.3f}'
        else:
            label = f'<<I>w</I><SUB>{i+1}</SUB>>'
        dot.edge(node_id, 'output', label=label, fontsize='9')

    # A layer built with bias=False has model.bias == None; there is nothing to draw.
    if model.bias is not None:
        bias = model.bias.data.item()
        bias_label = f'b={bias:+.3f}' if numeric else '<<I>b</I>>'
        dot.node('bias', bias_label, shape='box', fillcolor='lightyellow')
        dot.edge('bias', 'output', style='dashed')


def _draw_sequential(dot, model, feature_names, numeric):
    """Draw the Linear/ReLU children of an nn.Sequential; other layer types are skipped."""
    layers = [l for l in model.children() if isinstance(l, (nn.Linear, nn.ReLU))]
    # Computed once: the Linear layer with this index is drawn as the 'Price' output.
    last_linear_idx = sum(isinstance(l, nn.Linear) for l in layers) - 1

    prev_nodes = _add_input_nodes(dot, feature_names)
    layer_idx = 0

    for layer in layers:
        if isinstance(layer, nn.ReLU):
            # Mark the units feeding into the ReLU by recolouring them.
            for node in prev_nodes:
                dot.node(node, fillcolor='lightyellow')
            continue

        weights = layer.weight.data
        is_output = layer_idx == last_linear_idx

        # Create nodes for this layer
        current_nodes = [f'h{layer_idx}_{i}' for i in range(layer.out_features)]
        with dot.subgraph() as s:
            s.attr(rank='same')
            for i, node_name in enumerate(current_nodes):
                if is_output:
                    s.node(node_name, 'Price', fillcolor='orange')
                else:
                    s.node(node_name, f'<<I>h</I><SUB>{i+1}</SUB>>', fillcolor='lightblue')

        # Only label edges while the layer has at most 50 connections (readability).
        show_labels = len(prev_nodes) * len(current_nodes) <= 50
        for i, curr_node in enumerate(current_nodes):
            for j, prev_node in enumerate(prev_nodes):
                if not show_labels:
                    dot.edge(prev_node, curr_node)
                    continue
                if numeric:
                    label = f'{weights[i, j].item():+.2f}'
                else:
                    # Subscripts are (layer, output unit, input unit).
                    label = f'<<I>w</I><SUB>{layer_idx+1},{i+1},{j+1}</SUB>>'
                dot.edge(prev_node, curr_node, label=label, fontsize='8')

        # Bias boxes, skipped for layers built with bias=False.
        if layer.bias is not None:
            bias = layer.bias.data
            for i, curr_node in enumerate(current_nodes):
                bias_node_name = f'bias_{layer_idx}_{i}'
                if numeric:
                    bias_label = f'b={bias[i].item():+.3f}'
                else:
                    bias_label = f'<<I>b</I><SUB>{layer_idx+1},{i+1}</SUB>>'
                dot.node(bias_node_name, bias_label, shape='box', fillcolor='lightyellow')
                dot.edge(bias_node_name, curr_node, style='dashed')

        prev_nodes = current_nodes
        layer_idx += 1


def _render_graph(dot):
    """Show the graph inline via IPython, falling back to a PNG file plotted with matplotlib."""
    try:
        from IPython.display import display, Image as IPImage
        display(IPImage(dot.pipe(format='png')))
    except Exception:
        # `except Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate. Fallback: save to file and show it.
        dot.render('network_visualization', view=False, cleanup=True)
        img = plt.imread('network_visualization.png')
        plt.figure(figsize=(12, 8))
        plt.axis('off')
        plt.imshow(img)
        plt.tight_layout()
        plt.show()


def visualize(model, feature_names, numeric=False):
    """
    Visualize neural network architecture with labeled weights.
    Works for both single layer and multi-layer networks.

    The final layer is drawn as output node(s) labelled 'Price'. Other model
    types produce an empty graph.

    Args:
        model: PyTorch model (nn.Linear or nn.Sequential)
        feature_names: list of input feature names
        numeric: If True, show actual weight values. If False, show symbolic labels.
    """
    from graphviz import Digraph

    dot = Digraph(comment='Neural Network', format='png')
    dot.attr(rankdir='LR', size='20,3', dpi='300')
    dot.attr(ranksep='2.0', nodesep='0.15')  # Increase horizontal spacing between layers
    dot.attr('node', shape='circle', style='filled', fillcolor='lightblue', fontsize='9', width='0.4', height='0.4')

    if isinstance(model, nn.Linear):
        _draw_single_linear(dot, model, feature_names, numeric)
    elif isinstance(model, nn.Sequential):
        _draw_sequential(dot, model, feature_names, numeric)

    _render_graph(dot)

def summarize_statistics(X, features, y):
    """
    Print the mean and standard deviation (torch's default ``.std()``) of every
    feature column, followed by the same summary for the target.

    Args:
        X: tensor whose column i holds the values of ``features[i]``.
        features: names of the columns of X, printed left-aligned in 10 characters.
        y: target tensor, reported on a final line labelled 'Target'.
    """
    def _labelled_series():
        # Yield lazily so each line is printed before the next column is read.
        for col, name in enumerate(features):
            yield name, X[:, col]
        yield 'Target', y

    for label, values in _labelled_series():
        avg = values.mean().item()
        spread = values.std().item()
        print(f"{label:10s}: Mean = {avg:10.2f},   Standard Deviation = {spread:8.2f}")

def histogram(data, feature_names=None):
    """
    Plot histograms for features or target.
    Works differently based on tensor shape:
    - 2D with multiple columns: plots histogram for each feature (3 per row)
    - 2D with 1 column or 1D: plots single histogram labelled as 'Price'

    Args:
        data: torch.Tensor (X for features, y for target). May require grad or
            live on a non-CPU device; it is detached and copied to CPU for plotting.
        feature_names: list of strings (required for multi-feature X; defaults
            to 'Feature 1', 'Feature 2', ... when omitted)
    """
    # .numpy() refuses tensors that require grad or are not on the CPU.
    data = data.detach().cpu()

    # Handle 1D tensors
    if data.dim() == 1:
        data = data.reshape(-1, 1)

    num_features = data.shape[1]

    if num_features == 1:
        # Single histogram for target (y)
        plt.figure(figsize=(6, 4))
        plt.hist(data.numpy().flatten(), bins=30, edgecolor='black', alpha=0.7, color='orange')
        plt.xlabel('Price ($)')
        plt.ylabel('Frequency')
        plt.title('Price Distribution')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
        return

    # Multiple histograms for features (X)
    if feature_names is None:
        feature_names = [f'Feature {i+1}' for i in range(num_features)]

    # Calculate grid size
    cols = 3
    rows = (num_features + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes, so axes[row, col]
    # works even when there is a single row.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 4 * rows), squeeze=False)

    for i in range(num_features):
        ax = axes[i // cols, i % cols]
        ax.hist(data[:, i].numpy(), bins=30, edgecolor='black', alpha=0.7)
        ax.set_xlabel(feature_names[i])
        ax.set_ylabel('Frequency')
        ax.set_title(f'{feature_names[i]} Distribution')
        ax.grid(True, alpha=0.3)

    # Remove empty subplots in the last row
    for i in range(num_features, rows * cols):
        fig.delaxes(axes[i // cols, i % cols])

    plt.tight_layout()
    plt.show()

def get_formula(model, features, numeric=False):
    """
    Generate a LaTeX formula showing the model's computation as a function of inputs.

    Recognised shapes (Dropout layers are ignored when matching a Sequential):
      * ``nn.Linear`` with a single output: one weighted term per feature plus bias.
      * ``Sequential(Linear, Linear)``: shown collapsed to one linear function
        (numeric) or with the weight products expanded (symbolic).
      * ``Sequential(Linear, ReLU, Linear)``: hidden units written out when the
        two layers have fewer than 50 parameters in total, otherwise a compact
        matrix form.
      * Any other ``Sequential``: the chain of layer types.

    Args:
        model: PyTorch model (nn.Linear or nn.Sequential)
        features: List of feature names, one per input column
        numeric: If True, show actual weight values. If False, show symbolic weights.

    Returns:
        LaTeX string representing the formula
    """
    # Handle simple linear model
    if isinstance(model, nn.Linear):
        # reshape(-1) instead of squeeze(): a (1, 1) weight would squeeze to a
        # 0-dim tensor that cannot be indexed when there is a single feature.
        weights = model.weight.data.reshape(-1)
        bias = model.bias.data.item()

        # Build align* environment with one term per row
        lines = ["\\begin{align*}"]
        lines.append("\\text{output} &= ")

        for i, feature in enumerate(features):
            if numeric:
                w = weights[i].item()
                sign = "+" if w >= 0 else "-"
                if i == 0:
                    lines.append(f"  {w:+.4f} \\cdot \\text{{{feature}}} \\\\")
                else:
                    lines.append(f"  &\\quad {sign} {abs(w):.4f} \\cdot \\text{{{feature}}} \\\\")
            else:
                if i == 0:
                    lines.append(f"  w_{{{i+1}}} \\cdot \\text{{{feature}}} \\\\")
                else:
                    lines.append(f"  &\\quad + w_{{{i+1}}} \\cdot \\text{{{feature}}} \\\\")

        # Add bias term
        if numeric:
            sign = "+" if bias >= 0 else "-"
            lines.append(f"  &\\quad {sign} {abs(bias):.4f}")
        else:
            lines.append("  &\\quad + b")
        lines.append("\\end{align*}")

        return "\n".join(lines)

    # Handle sequential model
    elif isinstance(model, nn.Sequential):
        layers = list(model.children())

        # Remove dropout layers for structure detection
        non_dropout_layers = [l for l in layers if not isinstance(l, nn.Dropout)]

        # Case 1: Two linear layers with NO activation (collapses to linear)
        if (len(non_dropout_layers) == 2 and
            isinstance(non_dropout_layers[0], nn.Linear) and
            isinstance(non_dropout_layers[1], nn.Linear)):

            layer1 = non_dropout_layers[0]
            layer2 = non_dropout_layers[1]

            W1 = layer1.weight.data  # Shape: (hidden_size, n_features)
            b1 = layer1.bias.data     # Shape: (hidden_size,)
            # reshape(-1) keeps W2 1-D even when hidden_size == 1, where
            # squeeze() would yield a 0-dim tensor that matmul rejects.
            W2 = layer2.weight.data.reshape(-1)  # Shape: (hidden_size,)
            b2 = layer2.bias.data.item()

            hidden_size = W1.shape[0]

            if numeric:
                # Compute the collapsed weights: W_combined = W2 @ W1
                W_combined = W2 @ W1  # Shape: (n_features,)
                b_combined = (W2 @ b1).item() + b2

                lines = ["\\begin{align*}"]
                lines.append("\\text{output} &= ")

                for i, feature in enumerate(features):
                    w = W_combined[i].item()
                    sign = "+" if w >= 0 else "-"
                    if i == 0:
                        lines.append(f"  {w:+.4f} \\cdot \\text{{{feature}}} \\\\")
                    else:
                        lines.append(f"  &\\quad {sign} {abs(w):.4f} \\cdot \\text{{{feature}}} \\\\")

                bias_sign = "+" if b_combined >= 0 else "-"
                lines.append(f"  &\\quad {bias_sign} {abs(b_combined):.4f}")
                lines.append("\\end{align*}")

                return "\n".join(lines)
            else:
                # Show symbolic expansion of matrix multiplication
                lines = ["\\begin{align*}"]
                lines.append("\\text{output} &= ")

                for i, feature in enumerate(features):
                    # Build the sum: (w2_1 * w1_1i + w2_2 * w1_2i + ...)
                    terms = []
                    for h in range(hidden_size):
                        terms.append(f"w^{{(2)}}_{{{h+1}}} w^{{(1)}}_{{{h+1},{i+1}}}")

                    term_str = " + ".join(terms)
                    if i == 0:
                        lines.append(f"  \\left({term_str}\\right) \\cdot \\text{{{feature}}} \\\\")
                    else:
                        lines.append(f"  &\\quad + \\left({term_str}\\right) \\cdot \\text{{{feature}}} \\\\")

                # Bias: (w2_1 * b1_1 + w2_2 * b1_2 + ...) + b2
                bias_terms = [f"w^{{(2)}}_{{{h+1}}} b^{{(1)}}_{{{h+1}}}" for h in range(hidden_size)]
                bias_str = " + ".join(bias_terms)
                lines.append(f"  &\\quad + \\left({bias_str}\\right) + b^{{(2)}}")
                lines.append("\\end{align*}")

                return "\n".join(lines)

        # Case 2: Linear -> ReLU -> Linear (true non-linear network)
        if (len(non_dropout_layers) == 3 and
            isinstance(non_dropout_layers[0], nn.Linear) and
            isinstance(non_dropout_layers[1], nn.ReLU) and
            isinstance(non_dropout_layers[2], nn.Linear)):

            layer1 = non_dropout_layers[0]
            layer2 = non_dropout_layers[2]

            # Count total weights
            total_params = sum(p.numel() for p in [layer1.weight, layer1.bias, layer2.weight, layer2.bias])

            # If small enough, expand the full algebra
            if total_params < 50:
                W1 = layer1.weight.data  # Shape: (hidden_size, n_features)
                b1 = layer1.bias.data     # Shape: (hidden_size,)
                # reshape(-1) keeps W2 indexable when hidden_size == 1.
                W2 = layer2.weight.data.reshape(-1)  # Shape: (hidden_size,)
                b2 = layer2.bias.data.item()

                hidden_size = W1.shape[0]

                lines = ["\\begin{align*}"]

                # Define hidden neurons
                for h in range(hidden_size):
                    lines.append(f"h_{{{h+1}}} &= \\text{{ReLU}}\\Bigg(")
                    for i, feature in enumerate(features):
                        if numeric:
                            w = W1[h, i].item()
                            sign = "+" if w >= 0 else "-"
                            if i == 0:
                                lines.append(f"  {w:+.4f} \\cdot \\text{{{feature}}}")
                            else:
                                lines.append(f" {sign} {abs(w):.4f} \\cdot \\text{{{feature}}}")
                        else:
                            if i == 0:
                                lines.append(f"  w^{{(1)}}_{{{h+1},{i+1}}} \\cdot \\text{{{feature}}}")
                            else:
                                lines.append(f" + w^{{(1)}}_{{{h+1},{i+1}}} \\cdot \\text{{{feature}}}")

                    if numeric:
                        bias_sign = "+" if b1[h].item() >= 0 else "-"
                        lines.append(f" {bias_sign} {abs(b1[h].item()):.4f}\\Bigg) \\\\[0.5em]")
                    else:
                        lines.append(f" + b^{{(1)}}_{{{h+1}}}\\Bigg) \\\\[0.5em]")

                # Output layer
                lines.append("\\text{output} &= ")
                for h in range(hidden_size):
                    if numeric:
                        w = W2[h].item()
                        sign = "+" if w >= 0 else "-"
                        if h == 0:
                            lines.append(f"{w:+.4f} \\cdot h_{{{h+1}}}")
                        else:
                            lines.append(f" {sign} {abs(w):.4f} \\cdot h_{{{h+1}}}")
                    else:
                        if h == 0:
                            lines.append(f"w^{{(2)}}_{{{h+1}}} \\cdot h_{{{h+1}}}")
                        else:
                            lines.append(f" + w^{{(2)}}_{{{h+1}}} \\cdot h_{{{h+1}}}")

                if numeric:
                    output_bias_sign = "+" if b2 >= 0 else "-"
                    lines.append(f" {output_bias_sign} {abs(b2):.4f}")
                else:
                    lines.append(" + b^{(2)}")

                lines.append("\\end{align*}")

                return "\n".join(lines)
            else:
                # Too many parameters, use symbolic notation
                return "\\text{output} = W^{(2)} \\cdot \\text{ReLU}(W^{(1)} \\mathbf{x} + \\mathbf{b}^{(1)}) + b^{(2)}"

        # More complex sequential or other architectures: list the layer chain
        layer_strs = []
        for layer in layers:
            if isinstance(layer, nn.Linear):
                in_f, out_f = layer.in_features, layer.out_features
                layer_strs.append(f"\\text{{Linear}}({in_f} \\to {out_f})")
            elif isinstance(layer, nn.ReLU):
                layer_strs.append("\\text{ReLU}")
            elif isinstance(layer, nn.Dropout):
                layer_strs.append(f"\\text{{Dropout}}({layer.p})")
            else:
                # Wrap in \text{} like the other entries so the name renders upright.
                layer_strs.append(f"\\text{{{type(layer).__name__}}}")

        return " \\to ".join(layer_strs)

    else:
        return "\\text{Complex model - see model.summary()}"


def show_formula(model, features, numeric=False):
    """Render the LaTeX produced by get_formula for this model as display math."""
    latex_source = get_formula(model, features, numeric=numeric)
    display(Math(latex_source))

def show_linear_collapse(model, features):
    """
    Use SymPy to algebraically show that stacked linear layers without activation
    collapse to a single linear layer.

    The network output is built symbolically, with weights ``w_{layer}_{out}_{in}``,
    biases ``b_{layer}_{out}`` and inputs ``x_i``. The expanded expression is
    displayed. The coefficient of each ``x_i`` and the constant term are then
    read off the first output unit of the final layer. Everything is rendered
    with IPython ``display``; nothing is returned.

    Args:
        model: PyTorch Sequential model with Linear layers (no activation).
            Dropout layers are skipped. Any other layer type prints a message and returns.
        features: List of feature names. Only its length is used, to size the input.
    """
    if not isinstance(model, nn.Sequential):
        print("Model must be Sequential")
        return

    # Dropout does not change the algebra of the forward pass, so it is skipped.
    layers = [l for l in model.children() if not isinstance(l, nn.Dropout)]

    # Without this guard an empty model would "collapse" to output = x_1.
    if not layers:
        print("Model has no Linear layers to analyse")
        return

    # Check if it's stacked linear layers
    if not all(isinstance(l, nn.Linear) for l in layers):
        print("The network does not collapse to a single layer because it contains activation functions (this is good)")
        return

    # Get dimensions
    n_features = len(features)
    layer_sizes = [n_features] + [l.out_features for l in layers]

    # Create symbolic input variables
    x_syms = [symbols(f'x_{i+1}') for i in range(n_features)]

    # Start with inputs
    current = x_syms

    lines = ["\\begin{align*}"]

    # Forward pass through each layer symbolically
    for layer_idx, layer in enumerate(layers):
        in_size = layer_sizes[layer_idx]
        out_size = layer_sizes[layer_idx + 1]

        # symbols() treats commas as separators, so indices are joined with underscores.
        W_syms = [[symbols(f'w_{layer_idx+1}_{i+1}_{j+1}') for j in range(in_size)]
                  for i in range(out_size)]

        # Layers built with bias=False contribute no bias term.
        if layer.bias is not None:
            b_syms = [symbols(f'b_{layer_idx+1}_{i+1}') for i in range(out_size)]
        else:
            b_syms = [0] * out_size

        # Weighted sum plus bias for every unit of this layer
        next_layer = [sum(W_syms[i][j] * current[j] for j in range(in_size)) + b_syms[i]
                      for i in range(out_size)]

        # Show intermediate layer (except for final output)
        if layer_idx < len(layers) - 1:
            for i in range(out_size):
                lines.append(f"h_{{{layer_idx+1}}}^{{({i+1})}} &= {latex(next_layer[i])} \\\\[0.3em]")

        current = next_layer

    # Show the final output (expanded form)
    lines.append("\\\\")
    lines.append("\\text{output} &= " + latex(current[0]) + " \\\\[0.5em]")

    lines.append("\\end{align*}")

    display(Markdown("### Simplification"))
    display(Markdown("Expanding the formula:"))
    display(Math("\n".join(lines)))

    # Expand so the output is a flat sum of terms, then read off the
    # coefficient of each input variable and the input-independent constant.
    simplified = expand(current[0])
    coeffs = [simplified.coeff(x_var) for x_var in x_syms]
    constant = simplified.as_coeff_add(*x_syms)[0]

    # Build the simplified form
    simplified_lines = ["\\begin{align*}"]
    simplified_lines.append("\\text{Simplifies to:} \\\\[0.5em]")
    simplified_lines.append("\\text{output} &= ")

    for i in range(n_features):
        if i == 0:
            simplified_lines.append(f"W_{{{i+1}}} \\cdot x_{{{i+1}}}")
        else:
            simplified_lines.append(f" + W_{{{i+1}}} \\cdot x_{{{i+1}}}")
    simplified_lines.append(" + B")

    simplified_lines.append("\\end{align*}")

    display(Markdown("### Formula simplifies to be equivalent to a single layer"))
    display(Math("\n".join(simplified_lines)))

    # Show what the effective weights are. Plain quotes: LaTeX-style ``...''
    # quotes render as stray backticks in Markdown.
    display(Markdown('**Where the "effective parameters" are:**'))
    coeff_lines = ["\\begin{align*}"]
    for i in range(n_features):
        coeff_lines.append(f"W_{{{i+1}}} &= {latex(coeffs[i])} \\\\")
    coeff_lines.append(f"B &= {latex(constant)}")
    coeff_lines.append("\\end{align*}")
    display(Math("\n".join(coeff_lines)))

    # Counts weights and any biases that exist (bias=False layers have none).
    total_params = sum(p.numel() for l in layers for p in l.parameters())
    display(Markdown(f"**Notice:** Despite having {total_params} parameters across {len(layers)} layers, " +
                    f"the effective function has only **{n_features + 1} parameters** (same as a single linear layer)! " +
                    "Without activation functions, multiple linear layers don't increase expressiveness."))