96 lines
3.2 KiB
Plaintext
96 lines
3.2 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"- 在上一讲中,我们展示了几种实现简易平均加权的方式\n",
|
||
" - 在这里,我们做一些准备工作,以便后续搭建自注意力模块"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class BLM(nn.Module):\n",
|
||
" def __init__(self,vocab_size):\n",
|
||
" super().__init__()\n",
|
||
" self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)\n",
|
||
" \n",
|
||
" def forward(self,idx,targets = None):\n",
|
||
"\n",
|
||
" logits = self.token_embedding_table(idx) # (B,T) -> (B,T,C) # 这里我们通过Embedding操作直接得到预测分数\n",
|
||
" # 这里的预测分数过程与二分类或者多分类的分数是大致相同的\n",
|
||
"\n",
|
||
" \n",
|
||
" if targets is None:\n",
|
||
" loss = None\n",
|
||
" else: \n",
|
||
" B, T, C = logits.shape\n",
|
||
" logits = logits.view(B*T, C)\n",
|
||
" targets = targets.view(B*T) # 这里我们调整一下形状,以符合torch的交叉熵损失函数对于输入的变量的要求\n",
|
||
" loss = F.cross_entropy(logits, targets)\n",
|
||
"\n",
|
||
" return logits, loss\n",
|
||
"\n",
|
||
" def generate(self, idx, max_new_tokens):\n",
|
||
" '''\n",
|
||
" idx 是现在的输入的(B, T)序列 ,这是之前我们提取的batch的下标\n",
|
||
" max_new_tokens 是产生的最大的tokens数量\n",
|
||
" '''\n",
|
||
"\n",
|
||
" for _ in range(max_new_tokens):\n",
|
||
" \n",
|
||
" # 得到预测的结果\n",
|
||
" logits,_ = self(idx) # _ 表示省略,用于不获取相对应的函数返回值\n",
|
||
" \n",
|
||
" # 只关注最后一个的预测 (B,T,C)\n",
|
||
" logits = logits[:, -1, :] # becomes (B, C)\n",
|
||
" # 对logits应用softmax,将其转换为概率分布\n",
|
||
" probs = F.softmax(logits, dim=-1) # (B, C)\n",
|
||
" # 注意:这里使用的是 torch.multinomial 按概率采样,而不是 torch.argmax 直接取最大值\n",
|
||
" # 对input的每一行做num_samples次采样,输出的张量是每一次采样得到的下标;概率越大的下标越容易被抽中,但结果不一定是概率最大的下标\n",
|
||
" idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
|
||
" # 将新产生的编码加入到之前的编码中,形成新的编码\n",
|
||
" idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
|
||
"\n",
|
||
" return idx "
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "pytorch",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.12"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|