Commit d9d410bf authored by Erik Senn's avatar Erik Senn
Browse files

Upload New File

parent a3253f70
Loading
Loading
Loading
Loading

notebooks/utils.py

0 → 100644
+102 −0
Original line number Diff line number Diff line
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch


# plot embedding vectors using PCA
def plot_embedding_pca(embedding_vectors, labels=None):
    """Plot 2-D PCA projections of embedding vectors as arrows from the origin.

    Args:
        embedding_vectors: array-like of shape (n_vectors, n_dims); the
            embedding vectors to project.
        labels: optional sequence of per-vector labels (e.g. token names);
            defaults to the vector indices.
    """
    if labels is None:
        labels = range(len(embedding_vectors))
    # Reduce to 2 dimensions for plotting; fixed random_state keeps the
    # projection deterministic across calls.
    reduced_vectors = PCA(n_components=2, random_state=42).fit_transform(
        embedding_vectors
    )

    plt.figure(figsize=(10, 6))
    for i, label in enumerate(labels):
        # BUG FIX: the original passed the full (n, 2) origin array together
        # with scalar U/V components, so each iteration drew n identical
        # arrows (n^2 arrows total). Each vector needs exactly one arrow
        # starting at (0, 0).
        plt.quiver(
            0,
            0,
            reduced_vectors[i, 0],
            reduced_vectors[i, 1],
            angles="xy",
            scale_units="xy",
            scale=1,
            width=0.002,
        )
        plt.scatter(reduced_vectors[i, 0], reduced_vectors[i, 1])
        plt.annotate(label, (reduced_vectors[i, 0], reduced_vectors[i, 1]))

    plt.title("Embedding Visualization (using 2 Principal Components)")
    plt.show()


# plot torch model
# from https://github.com/Atcold/NYU-DLSP20/blob/master/res/plot_lib.py
def plot_model(X, y, model):
    """Shade the model's decision regions over a [-1.1, 1.1)^2 grid, then
    overlay the training data via plot_data.

    Args:
        X: the data points to overlay (forwarded to plot_data).
        y: the point labels (forwarded to plot_data).
        model: a torch model mapping (n, 2) inputs to per-class scores.
    """
    model.cpu()
    axis_points = np.arange(-1.1, 1.1, 0.01)
    xx, yy = np.meshgrid(axis_points, axis_points)
    # Flatten the grid into an (n, 2) batch of query points.
    grid = np.vstack((xx.reshape(-1), yy.reshape(-1))).T
    with torch.no_grad():
        scores = model(torch.from_numpy(grid).float()).detach()
    # Predicted class per grid point, reshaped back onto the mesh.
    predictions = np.argmax(scores, axis=1).reshape(xx.shape)
    plt.contourf(xx, yy, predictions, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y)


##############################################################################
## UNUSED
##############################################################################


# # plot embeddings as linechart
# def plot_embeddings_linechart(embedding_vectors, labels=None):
#     """
#     Plots the embeddings for the given indices.

#     Args:
#     - embeddings: The embeddings tensor.
#     - indices: List of indices to plot.

#     Example usage:
#     plot_embeddings_linechart(position_embeddings, [1, 2])

#     """
#     colors = [
#         "blue",
#         "orange",
#         "green",
#         "red",
#         "purple",
#         "brown",
#     ]  # Define colors to use for each plot

#     if labels is None:
#         labels = range(len(embedding_vectors))

#     plt.figure(figsize=(10, 6))  # Set the figure size

#     # Plot each embedding based on the index
#     for i, vector in enumerate(embedding_vectors):
#         plt.plot(
#             vector,
#             label=f"Position {i}",
#             color=colors[i % len(colors)],
#         )

#     # Add legend
#     plt.legend()

#     # Add title and axis labels for clarity
#     plt.title("Position Embeddings")
#     plt.xlabel("Embedding Dimension")
#     plt.ylabel("Embedding Value")

#     # Display the plot
#     plt.show()