# 本文代码在 Python 3.7 + PyTorch 1.5.1 环境下运行通过，可直接执行。
#coding:utf8
# @Time : 2020/10/21 下午3:27
# @Author : xxx
# @File : cbow.py
# @Software: PyCharm,python3.7,pytorch1.5.1
import torch
from torch import nn
# from torch.autograd import Variable
from torch.nn import functional as F
# import numpy as np
import torch.optim as opt
# Small toy English corpus, already tokenised by whitespace: punctuation is
# separated by spaces, so str.split() yields one token per word or mark.
raw_text = """We are about to study the idea of a computational process .
Computational processes are abstract beings that inhabit computers .
As they evolve , processes manipulate other abstract things called data .
The evolution of a process is directed by a pattern of rules
called a program . People create programs to direct processes . In effect ,
we conjure the spirits of the computer with our spells .""".split()
import codecs

# Sorted (rather than raw set iteration order, which depends on per-process
# string hash randomisation) so every run assigns the same id to each word.
vocab = sorted(set(raw_text))
vocab_size = len(vocab)
word2idx = {word: i for i, word in enumerate(vocab)}

# Training pairs: for each position with two words on either side, the
# context is [w-2, w-1, w+1, w+2] and the target is the centre word (wrapped
# in a one-element list so it can go through the same id-lookup helper).
data = [
    [raw_text[i - 2:i] + raw_text[i + 1:i + 3], [raw_text[i]]]
    for i in range(2, len(raw_text) - 2)
]
class CBOW(nn.Module):
    """Continuous bag-of-words model: predict a centre word from its context.

    The context word embeddings are concatenated (flattened into one row),
    passed through a single ReLU hidden layer, and projected to
    log-probabilities over the vocabulary.

    Args:
        vocab_size: number of distinct tokens (rows of the embedding table
            and width of the output layer).
        embed_size: dimensionality of each word embedding.
        window_size: number of context word ids fed in per example.
        hidden_size: width of the hidden layer (default 128).
    """

    def __init__(self, vocab_size, embed_size, window_size, hidden_size=128):
        super(CBOW, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.l1 = nn.Linear(embed_size * window_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for context ``x``.

        Args:
            x: integer tensor of word ids, either shape ``(window_size,)``
                for a single example or ``(batch, window_size)`` for a batch.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` (``batch == 1`` for a
            1-D ``x``) holding log-softmax scores.
        """
        # Flatten each example's embeddings into one row. The row width is
        # read from l1 instead of being stored on self, so models pickled
        # with an earlier version of this class still load and run.
        embeds = self.embed(x).view(-1, self.l1.in_features)
        hidden = F.relu(self.l1(embeds))
        return F.log_softmax(self.l2(hidden), dim=-1)
def make_context_vec(context, word2id):
    """Map a sequence of words to a 1-D tensor of their vocabulary ids.

    Args:
        context: iterable of words; every word must be a key of ``word2id``.
        word2id: dict mapping word -> integer id.

    Returns:
        ``torch.long`` tensor of shape ``(len(context),)``.

    Raises:
        KeyError: if a word is not present in ``word2id``.
    """
    # Explicit dtype keeps the result a LongTensor even for an empty context
    # (torch.tensor([]) would otherwise default to float32); nn.Embedding
    # inputs and nn.NLLLoss targets both need integer ids.
    return torch.tensor([word2id[word] for word in context], dtype=torch.long)
# Training configuration consumed by train(): epoch count, loss function,
# network and optimiser. NLLLoss pairs with the log_softmax output of CBOW.
epoch = 40
loss_f = nn.NLLLoss()
# 32-dimensional embeddings; 4 context word ids per example.
model = CBOW(vocab_size, embed_size=32, window_size=4)
op = opt.Adam(model.parameters(), lr=0.001)
def train():
    """Fit the global CBOW ``model`` on ``data`` and pickle it to disk.

    Makes ``epoch`` passes over every (context, [target]) pair in ``data``,
    taking one optimiser step per example, and prints the summed NLL loss
    after each epoch. The whole module is then saved with ``torch.save`` to
    ``./models_y/cbow.pkl`` (the directory is expected to exist already).
    """
    model.train()
    for _ in range(epoch):
        total_loss = 0.0
        for context, target in data:
            v_src = make_context_vec(context, word2idx)
            v_tgt = make_context_vec(target, word2idx)
            model.zero_grad()
            log_probs = model(v_src)
            loss = loss_f(log_probs, v_tgt)
            loss.backward()
            op.step()
            # .item() returns a plain Python float instead of going through
            # the discouraged .data attribute, so total_loss stays a float.
            total_loss += loss.item()
        print(total_loss)
    torch.save(model, "./models_y/cbow.pkl")
import os

# Script entry: if a trained model is already on disk, run one prediction
# with it; otherwise write the vocabulary file and train a new model.
if os.path.exists("./models_y/cbow.pkl"):
    # NOTE: torch.load unpickles arbitrary objects - only load files this
    # script produced. On newer PyTorch releases (2.6+) loading a whole
    # pickled module also requires weights_only=False - TODO confirm per
    # installed version.
    model = torch.load("./models_y/cbow.pkl")
    model.eval()
    # Line i of the vocab file is the word whose id is i.
    with codecs.open("./models_y/cbow.vocab", "r", "utf8") as vocab_file:
        words = [word.strip() for word in vocab_file]
    # Context of "process is [directed] by a"; the model should fill the gap.
    inputs = [words.index(word) for word in "process is by a".split(" ")]
    with torch.no_grad():
        out = model(torch.tensor(inputs))
    predicted_id = out.argmax(dim=1)[0].item()
    print(predicted_id)
    print(words[predicted_id])
else:
    os.makedirs("./models_y", exist_ok=True)
    # Persist the id -> word mapping (one word per line, in id order) and
    # close the file before training so it is fully flushed to disk.
    with codecs.open("./models_y/cbow.vocab", "w", "utf8") as vocab_file:
        for word in vocab:
            vocab_file.write(word.strip() + "\n")
    train()