Required imports
import collections
import math
import random
import sys
import time
import os
import torch
import torch.utils.data as Data
import torch.nn as nn
1. Read and preprocess the dataset
assert 'ptb.train.txt' in os.listdir("../data/ptb")
with open('../data/ptb/ptb.train.txt', 'r') as f:
    lines = f.readlines()
raw_dataset = [sentence.split() for sentence in lines]
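Each line of ptb.train.txt is one sentence with words separated by spaces, so a quick check of the sentence count confirms the file was read (the exact number depends on your copy of the PTB training split):

print('# sentences: %d' % len(raw_dataset))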
1.1 Build the word index
counter = collections.Counter([token for sentence in raw_dataset for token in sentence])
counter = dict(filter(lambda x: x[1] >= 5, counter.items()))  # keep only words that appear at least 5 times

idx_to_token = [token for token, _ in counter.items()]
token_to_idx = {token: idx for idx, token in enumerate(idx_to_token)}
dataset = [[token_to_idx[token] for token in sentence if token in token_to_idx]
           for sentence in raw_dataset]
num_tokens = sum([len(sentence) for sentence in dataset])
1.2 Subsampling
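Subsampling discards each occurrence of an indexed word w with probability 1 - sqrt(t / f(w)), where f(w) is the fraction of the corpus made up of w and t is a threshold (1e-4 in the code below). Very frequent words such as 'the' are therefore dropped often, while words rarer than the threshold are never dropped.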
def discard(idx):
    # Drop word idx with probability 1 - sqrt(t / f); here t = 1e-4 and
    # f = counter[idx_to_token[idx]] / num_tokens is the word's relative frequency.
    return random.uniform(0, 1) < 1 - math.sqrt(
        1e-4 / counter[idx_to_token[idx]] * num_tokens)

subsampled_dataset = [[token for token in sentence if not discard(token)]
                      for sentence in dataset]
num_tokens_2 = sum([len(sentence) for sentence in subsampled_dataset])
print('# tokens after subsampling:', num_tokens_2)
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token, sum(
        [st.count(token_to_idx[token]) for st in dataset]), sum(
        [st.count(token_to_idx[token]) for st in subsampled_dataset]))
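To see the effect, compare_counts can be called on a very frequent and a relatively rare word; 'the' and 'join' are just two illustrative PTB vocabulary items:

print(compare_counts('the'))   # a frequent word loses most of its occurrences
print(compare_counts('join'))  # a rare word is kept almost unchanged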
1.3 Extract center words and context words
def get_centers_and_contexts(dataset, max_window_size):
    centers, contexts = [], []
    for st in dataset:
        if len(st) < 2:  # a sentence needs at least 2 words to form a center-context pair
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, center_i - window_size),
                                 min(len(st), center_i + 1 + window_size)))
            indices.remove(center_i)  # exclude the center word from its own context
            contexts.append([st[idx] for idx in indices])
    return centers, contexts

all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)
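As a quick sanity check of the window logic, the function can be run on a tiny artificial dataset that is independent of PTB; with max_window_size=2 every word is a center word and its contexts are at most two words on each side:

tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)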
2. Negative sampling
def get_negatives(all_contexts, sampling_weights, K):
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                # Draw a large batch of candidate noise-word indices at once
                # according to sampling_weights, for efficiency.
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            # A noise word must not be one of the context words.
            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

sampling_weights = [counter[w]**0.75 for w in idx_to_token]
all_negatives = get_negatives(all_contexts, sampling_weights, 5)
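Following the word2vec paper, each word's sampling weight is its count raised to the power 0.75, which flattens the noise distribution relative to raw frequencies, and K = 5 noise words are drawn for every context word.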
3. Load the data in minibatches
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return (self.centers[index], self.contexts[index],
                self.negatives[index])

    def __len__(self):
        return len(self.centers)
def batchify(data):
    '''
    Used as the collate_fn argument of DataLoader: the input is a list of length
    batch_size, and each element is the result of calling __getitem__ on the Dataset.
    '''
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers += [center]
        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)]
        labels += [[1] * len(context) + [0] * (max_len - len(context))]
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))
batch_size = 512
num_workers = 0 if sys.platform.startswith('win32') else 4
dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,
                            collate_fn=batchify,
                            num_workers=num_workers)
for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks',
                           'labels'], batch):
        print(name, 'shape:', data.shape)
    break
4. The skip-gram model
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)

def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)                      # (batch_size, 1, embed_dim)
    u = embed_u(contexts_and_negatives)      # (batch_size, max_len, embed_dim)
    pred = torch.bmm(v, u.permute(0, 2, 1))  # (batch_size, 1, max_len)
    return pred
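Because embed_v(center) has shape (batch_size, 1, embed_dim) and embed_u(contexts_and_negatives) has shape (batch_size, max_len, embed_dim), the batched matrix multiplication yields one inner-product score per (center, context-or-noise) pair. A quick shape check using the toy embed layer defined above (the index tensors here are dummy inputs):

center = torch.ones((2, 1), dtype=torch.long)
contexts_and_negatives = torch.ones((2, 4), dtype=torch.long)
print(skip_gram(center, contexts_and_negatives, embed, embed).shape)  # torch.Size([2, 1, 4])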
5. Train the model
5.1 Define the loss function (binary cross-entropy loss)
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets, mask=None):
        '''
        :param inputs: Tensor of shape (batch_size, len)
        :param targets: Tensor of the same shape as inputs
        '''
        inputs, targets, mask = inputs.float(), targets.float(), mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(
            input=inputs, target=targets, reduction='none', weight=mask)
        return res.mean(dim=1)

loss = SigmoidBinaryCrossEntropyLoss()
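A small illustration of how the mask removes padded positions (the numbers here are made up): with reduction='none' and weight=mask, padded entries contribute zero, and dividing by mask.sum(dim=1) instead of the full length averages the loss over the non-padded positions only, which is exactly the rescaling applied during training below.

pred = torch.tensor([[1.5, 0.3, -1.0, 2.0], [1.1, -0.6, 2.2, 0.4]])
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # last position of the second example is padding
print(loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1))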
5.2 Initialize model parameters
embed_size = 100
net = nn.Sequential(
    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),
    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)
)
5.3 Define the training function
def train(net, lr, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]
            pred = skip_gram(center, context_negative, net[0], net[1])
            # Use the mask so that padded positions do not affect the loss.
            l = (loss(pred.view(label.shape), label, mask) *
                 mask.shape[1] / mask.float().sum(dim=1)).mean()  # mean loss over the batch
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start))
Start training:
train(net, 0.01, 10)
6. Apply the word embedding model
def get_similar_tokens(query_token, k, embed):
    W = embed.weight.data
    x = W[token_to_idx[query_token]]
    # Cosine similarity between x and every row of W; 1e-9 is added for numerical stability.
    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()
    _, topk = torch.topk(cos, k=k+1)
    topk = topk.cpu().numpy()
    for i in topk[1:]:  # skip the query word itself
        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))

get_similar_tokens('chip', 3, net[0])