t-SNE 数据降维及可视化

it2024-05-12  50

# t-SNE dimensionality reduction of high-dimensional feature vectors,
# followed by a labelled 2-D or 3-D scatter plot of the embedding.
import torch
import torch.nn.functional as F  # noqa: F401  (not used by this demo)
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Target dimensionality of the embedding; set to 2 for a flat 2-D plot.
N_COMPONENTS = 3

# 19 classes, each represented by a 2048-dimensional feature vector.
features = torch.rand(19, 2048)
n_samples = features.shape[0]

# scikit-learn (>= 1.2) raises ValueError unless perplexity < n_samples, and
# its default perplexity of 30 exceeds our 19 samples. Keep the default when
# there is enough data, otherwise fall back to (n_samples - 1) / 3.
perplexity = min(30.0, (n_samples - 1) / 3)

tsne = TSNE(n_components=N_COMPONENTS, perplexity=perplexity)
# sklearn operates on NumPy arrays; convert the CPU tensor explicitly.
embedding = tsne.fit_transform(features.numpy())
print(embedding.shape)

# plot
x = embedding[:, 0]
y = embedding[:, 1]

fig = plt.figure()
if embedding.shape[1] == 2:
    ax = fig.add_subplot()
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.scatter(x, y, c='r', s=20, alpha=0.5)
    # Annotate each point with its row index (the class index).
    for i, (xi, yi) in enumerate(zip(x, y)):
        ax.text(xi, yi, str(i))
elif embedding.shape[1] == 3:
    z = embedding[:, 2]
    # plt.gca(projection='3d') no longer works (gca kwargs were removed in
    # Matplotlib 3.6); create the 3-D axes explicitly instead.
    ax = fig.add_subplot(projection='3d')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.scatter(x, y, z, c='r', s=20, alpha=0.5)
    # Annotate each point with its row index (the class index).
    for i, (xi, yi, zi) in enumerate(zip(x, y, z)):
        ax.text(xi, yi, zi, str(i))
plt.show()

 

最新回复(0)