Reading the Source Code: Understanding PyTorch


Table of Contents

tokenization.py
load_vocab(vocab_file)
whitespace_tokenize(text)
class BertTokenizer(object)
WordpieceTokenizer(object)
tokenize(self, text)
Test

When using a trained model to make predictions on text, the text first has to be tokenized, the resulting tokens are then converted into an index sequence, and finally that sequence is passed to the model.

tokenization.py

load_vocab(vocab_file)

import collections  # imported at the top of tokenization.py

def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = reader.readline()
            # readline() returns an empty string at end of file
            if not token:
                break
            token = token.strip()
            # Each token is mapped to its line number in the file
            vocab[token] = index
            index += 1
    return vocab
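As an illustration (the file name and its contents below are made up for the example), load_vocab maps each line of the vocabulary file to its line index:

# Write a tiny vocabulary file, one token per line
with open("toy_vocab.txt", "w", encoding="utf-8") as f:
    f.write("[PAD]\n[UNK]\nun\n##aff\n##able\n")

print(load_vocab("toy_vocab.txt"))
# OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('un', 2), ('##aff', 3), ('##able', 4)])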

A plain dict() is a hash table that only stores the key-value pairs themselves; before Python 3.7 it made no guarantee about iteration order, so the pairs could come out in a different order each run. OrderedDict() additionally records the order in which key-value pairs were inserted, so iterating over it always yields them in insertion order.

Understanding collections.OrderedDict(): an OrderedDict is a dictionary that remembers the order in which keys were first inserted. If a new entry overwrites an existing one, the original insertion position is kept. Reference: https://zhuanlan.zhihu.com/p/110407087
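A small example of that overwrite behavior:

import collections

d = collections.OrderedDict()
d["b"] = 1
d["a"] = 2
d["b"] = 99                 # overwriting an existing key keeps its original position
print(list(d.items()))      # [('b', 99), ('a', 2)]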

whitespace_tokenize(text)

Runs basic whitespace cleaning and splitting on a piece of text.

def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    # Strip leading and trailing whitespace
    text = text.strip()
    if not text:
        return []
    # split() without arguments splits on any run of whitespace
    tokens = text.split()
    return tokens
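For example, calling the function defined above (the input strings are illustrative):

print(whitespace_tokenize("  hello   world\n"))   # ['hello', 'world']
print(whitespace_tokenize("   "))                 # [] -- nothing left after stripping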

class BertTokenizer(object)

class BertTokenizer(object):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        # Load the vocabulary file vocab.txt and get back an ordered dict: token -> index
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping: index -> token
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        # The object whose tokenize method performs the wordpiece splitting
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
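A minimal sketch of building the tokenizer from a local vocabulary file and checking the two mappings constructed in __init__ (the vocab path below is a placeholder; as the error message above notes, BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME) can be used instead to fetch a Google pretrained vocabulary):

# Placeholder path; replace with a real vocab.txt from a pretrained BERT model.
tokenizer = BertTokenizer(vocab_file="bert-base-chinese/vocab.txt")

idx = tokenizer.vocab["[CLS]"]                    # token -> index, built by load_vocab
assert tokenizer.ids_to_tokens[idx] == "[CLS]"    # index -> token, the reversed OrderedDict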

WordpieceTokenizer(object)

def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

tokenize(self, text)

Tokenizes a piece of text into its word pieces.

It uses a greedy longest-match-first algorithm.

Suppose the longest entry in vocab.txt contains n characters. Take a span of characters from the text being processed as the matching field and look it up in vocab.txt. If vocab.txt contains an entry equal to that n-character field, the match succeeds and the field is split off as one complete word; if no such entry is found, the match fails. The last character is then dropped from the matching field, the remaining characters become the new matching field, and the lookup is repeated, looping until a match succeeds. That completes one round of matching and splits off one word; the same steps are then applied to the remaining characters until every word has been split off. (In the implementation below, the initial matching field is simply the whole remaining token, and continuation pieces are prefixed with "##".)

Tokenization is performed using the given vocabulary.

def tokenize(self, text):
    """Tokenizes a piece of text into its word pieces.

    This uses a greedy longest-match-first algorithm to perform
    tokenization using the given vocabulary.

    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]

    Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer`.

    Returns:
      A list of wordpiece tokens.
    """
    output_tokens = []
    # See the whitespace_tokenize(text) section above
    for token in whitespace_tokenize(text):
        chars = list(token)
        # Tokens longer than max_input_chars_per_word are mapped to [UNK]
        if len(chars) > self.max_input_chars_per_word:
            output_tokens.append(self.unk_token)
            continue

        is_bad = False
        start = 0
        sub_tokens = []
        while start < len(chars):
            end = len(chars)
            cur_substr = None
            # Shrink the candidate substring from the right until it is in the vocab
            while start < end:
                substr = "".join(chars[start:end])
                if start > 0:
                    substr = "##" + substr  # continuation pieces are prefixed with ##
                if substr in self.vocab:
                    cur_substr = substr
                    break
                end -= 1
            # No prefix of the remaining characters is in the vocab: give up on this token
            if cur_substr is None:
                is_bad = True
                break
            sub_tokens.append(cur_substr)
            start = end
        if is_bad:
            output_tokens.append(self.unk_token)
        else:
            output_tokens.extend(sub_tokens)
    return output_tokens
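To watch the longest-match-first loop work, here is a sketch with a hand-built toy vocabulary, reproducing the "unaffable" example from the docstring (the import path assumes the pytorch_pretrained_bert package; the class shown above can be used directly instead):

from pytorch_pretrained_bert.tokenization import WordpieceTokenizer  # assumed package layout

toy_vocab = {"[UNK]": 0, "un": 1, "##aff": 2, "##able": 3}
wp = WordpieceTokenizer(vocab=toy_vocab)

print(wp.tokenize("unaffable"))   # ['un', '##aff', '##able']
print(wp.tokenize("xyz"))         # ['[UNK]'] -- no prefix of "xyz" is in the vocabulary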

Test
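A minimal end-to-end sketch of the flow described at the top of the article (tokenize the text, convert the tokens to an index sequence, and pass that sequence to the model), assuming the pytorch_pretrained_bert package is installed and the 'bert-base-chinese' pretrained files can be downloaded or are cached locally:

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

text = "阅读源码, 理解 pytorch"
tokens = ["[CLS]"] + tokenizer.tokenize(text) + ["[SEP]"]
ids = tokenizer.convert_tokens_to_ids(tokens)      # tokens -> index sequence via the vocab

model = BertModel.from_pretrained("bert-base-chinese")
model.eval()
with torch.no_grad():
    encoded_layers, pooled_output = model(torch.tensor([ids]))

print(tokens)
print(ids)
print(encoded_layers[-1].shape)   # (1, sequence_length, hidden_size)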
