Commit f96452a9 authored by Erik Senn's avatar Erik Senn
Browse files

Replace utils.py

parent f375fce5
Loading
Loading
Loading
Loading
+0 −67
Original line number Diff line number Diff line
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch


# plot embedding vectors using PCA
@@ -34,69 +33,3 @@ def plot_embedding_pca(embedding_vectors, labels=None):

    plt.title("Embedding Visualization (using 2 Principal Components)")
    plt.show()


# plot torch model
# from https://github.com/Atcold/NYU-DLSP20/blob/master/res/plot_lib.py
def plot_model(X, y, model):
    model.cpu()
    mesh = np.arange(-1.1, 1.1, 0.01)
    xx, yy = np.meshgrid(mesh, mesh)
    with torch.no_grad():
        data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
        Z = model(data).detach()
    Z = np.argmax(Z, axis=1).reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y)


##############################################################################
## UNUSED
##############################################################################


# # plot embeddings as linechart
# def plot_embeddings_linechart(embedding_vectors, labels=None):
#     """
#     Plots the embeddings for the given indices.

#     Args:
#     - embeddings: The embeddings tensor.
#     - indices: List of indices to plot.

#     Example usage:
#     plot_embeddings_linechart(position_embeddings, [1, 2])

#     """
#     colors = [
#         "blue",
#         "orange",
#         "green",
#         "red",
#         "purple",
#         "brown",
#     ]  # Define colors to use for each plot

#     if labels is None:
#         labels = range(len(embedding_vectors))

#     plt.figure(figsize=(10, 6))  # Set the figure size

#     # Plot each embedding based on the index
#     for i, vector in enumerate(embedding_vectors):
#         plt.plot(
#             vector,
#             label=f"Position {i}",
#             color=colors[i % len(colors)],
#         )

#     # Add legend
#     plt.legend()

#     # Add title and axis labels for clarity
#     plt.title("Position Embeddings")
#     plt.xlabel("Embedding Dimension")
#     plt.ylabel("Embedding Value")

#     # Display the plot
#     plt.show()