-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
29 lines (25 loc) · 763 Bytes
/
plot.py
File metadata and controls
29 lines (25 loc) · 763 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def plot_clusters_2d(k, clusters, filename="clusters.png"):
    """Scatter-plot the first two coordinates of each cluster and save the image.

    Args:
        k: Number of clusters to draw. ``clusters[0]`` .. ``clusters[k-1]`` are
            plotted, each in its own colour from the seaborn "viridis" palette.
        clusters: Indexable collection where ``clusters[i]`` is a sequence of
            points belonging to cluster ``i``. Only columns 0 and 1 of each
            point are plotted (presumably 2-D points, per the function name).
            An empty cluster is allowed and still gets a legend entry.
        filename: Path the image is written to. Defaults to ``"clusters.png"``,
            the name this function always used before it was configurable.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn. It is left open (as
        before) so callers may still ``plt.show()`` it; call
        ``plt.close(fig)`` when finished to avoid accumulating open figures.
    """
    # set_theme mutates global rcParams, so it must run before the figure is made.
    sns.set_theme()
    fig, ax = plt.subplots(figsize=(6, 6))
    colors = sns.color_palette("viridis", k)
    for i in range(k):
        cluster_points = np.array(clusters[i])
        if cluster_points.size == 0:
            # K-means can leave a cluster empty; np.array([]) is 1-D, so the
            # [:, 0] indexing below would raise IndexError. Plot an empty
            # series instead so the legend still lists every cluster.
            cluster_points = np.empty((0, 2))
        ax.scatter(
            cluster_points[:, 0],
            cluster_points[:, 1],
            color=colors[i],
            label=f"Cluster {i}",
            alpha=0.7,
        )
    ax.set_xlabel("X1")
    ax.set_ylabel("X2")
    ax.set_title("2D K-Means Clusters")
    ax.legend()
    ax.grid(True)
    fig.savefig(filename)
    return fig
def plot_loss(losses, filename="loss.png"):
    """Plot a loss curve against iteration index and save the image.

    Args:
        losses: Sequence of loss values, one per iteration; the x-axis is the
            position of each value in the sequence.
        filename: Path the image is written to. Defaults to ``"loss.png"``,
            the name this function always used before it was configurable.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn. It is left open (as
        before) so callers may still ``plt.show()`` it; call
        ``plt.close(fig)`` when finished to avoid accumulating open figures.
    """
    # set_theme mutates global rcParams, so it must run before the figure is made.
    sns.set_theme()
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(losses)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Loss function")
    ax.grid(True)
    fig.savefig(filename)
    return fig