nanoGPT-Tutorial-CN/Lecture/l4.ipynb

96 lines
3.2 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 在上一讲中,我们展示了几种实现简单加权平均的方式\n",
"  - 在这里,我们先做一些准备工作,为后续搭建自注意力模块打下基础"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class BLM(nn.Module):\n",
" def __init__(self,vocab_size):\n",
" super().__init__()\n",
" self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)\n",
" \n",
" def forward(self,idx,targets = None):\n",
"\n",
" logits = self.token_embedding_table(idx) # (B,T) -> (B,T,C) # 这里我们通过Embedding操作直接得到预测分数\n",
" # 这里的预测分数过程与二分类或者多分类的分数是大致相同的\n",
"\n",
" \n",
" if targets is None:\n",
" loss = None\n",
" else: \n",
" B, T, C = logits.shape\n",
" logits = logits.view(B*T, C)\n",
" targets = targets.view(B*T) # 这里我们调整一下形状以符合torch的交叉熵损失函数对于输入的变量的要求\n",
" loss = F.cross_entropy(logits, targets)\n",
"\n",
" return logits, loss\n",
"\n",
" def generate(self, idx, max_new_tokens):\n",
" '''\n",
" idx 是现在的输入的(B, T)序列 这是之前我们提取的batch的下标\n",
" max_new_tokens 是产生的最大的tokens数量\n",
" '''\n",
"\n",
" for _ in range(max_new_tokens):\n",
" \n",
" # 得到预测的结果\n",
" logits,_ = self(idx) # _ 表示省略,用于不获取相对应的函数返回值\n",
" \n",
" # 只关注最后一个的预测 (B,T,C)\n",
" logits = logits[:, -1, :] # becomes (B, C)\n",
" # 对概率值应用softmax\n",
" probs = F.softmax(logits, dim=-1) # (B, C)\n",
" # nn.argmax\n",
" # 对input的每一行做n_samples次取值输出的张量是每一次取值时input张量对应行的下标也即找到概率值输出最大的下标也对应着最大的编码\n",
" idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
" # 将新产生的编码加入到之前的编码中,形成新的编码\n",
" idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
"\n",
" return idx "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}