|

  為什麼Gemma採取Decoder-Only Transformer架構呢?

   

作者:黃佳琪、高煥堂 / 神櫻團隊成員

使用Gemma源碼解說

從典型的Transformer模型說起

最初在 2017 年的論文「Attention is All You Need」中介紹,Transformer 具有解碼器(Decoder)和編碼器(Encoder)部分。如下圖:

這種典型Encoder-Decoder架構的使用情境,通常是將一種類型的序列(Sequence)對應到另一種類型的序列,例如將法語翻譯成英語,或在聊天機器人獲取對話上下文並產生回應的情況下。在這些情況下,輸入和輸出之間存在質的差異,因此對它們使用不同的權重是有意義的。

例如,GPT-2是在維基百科文章等連續文本上進行訓練的,如果我們想使用上圖的典型Encoder-Decoder架構,就必須進行任意截斷來確定Encoder將處理哪一部分,以及哪一部分由Decoder負責。因此,由於GPT-2在文字生成方面的出色性能,趨勢是轉向純解碼器(Decoder-Only Transformer)模型。僅解碼器模型的強大之處在於它們不僅能夠模仿類似人類的文本,而且還能夠創造性地做出回應。他們可以寫故事、回答問題,甚至進行自然流暢的對話。這種功能使它們在廣泛的應用中非常有用,從聊天機器人和數位助理到內容創建、抽象總結和講故事。

因此,GPT-2和Gemma並不需要 Transformer 架構的編碼器部分,因為模型使用只能查看先前標記的屏蔽自注意力。不需要編碼器,因為模型不需要學習輸入序列的表示。此外,Decoder-Only Transformer 架構在 GPT-3、ChatGPT、GPT-4、PaLM、LaMDa 和 Falcon 等流行的大型語言模型中受到關注。

當今流行的Decoder-Only架構

在上述典型的Transformer架構裡,解碼器(Decoder)從最初位於編碼器(Encoder)端的所有輸入Token開始。這意味著Decoder現在具有編碼器的文字輸入,解碼器現在必須預測下一個標記。本質上,解碼器從一開始就執行編碼部分。Decoder-Only Transformer架構是由多個具有相同結構的區塊(Block)按順序堆疊組成。在每個區塊中,都有兩個主要組件:

  • Masked, multi-headed self-attention。
  • A feed-forward transformation。

此外,我們通常會用殘差連接(Residual connection)和歸一化(Normalization)層來圍繞這些元件。如下圖所示,就是一個區塊:

    就如上圖所示,一個完整的Decoder-Only Trnsformer區塊,通常必須使用到這4種組件:

  • Masked, multi-headed self-attention
  • Layer normalization
  • Feed-forward transformation
  • Residual Connections

這個模型背後的真正驅動力是 Masked Self-Attention 機制。這種機制允許模型在預測每個標記時專注於輸入序列的不同部分,從而促進上下文相關文字的生成。

歸一化(Normalization)層透過計算輸入最終維度的平均值和標準差,消除了批量歸一化對批量維度的依賴。這意味著我們計算嵌入維度上的歸一化統計資料。

而Feed-Forward Layer 是一種前饋層,由兩個適用於最後一個維度的密集層組成,這意味著序列中的每個位置項都使用相同的密集層,因此稱為位置方式。

殘差連接(Residual connections)是一個通用概念,可以應用於任何不改變輸入維度的神經網路層。透過增加殘差連接,我們可以緩解梯度消失和爆炸的問題,並提高訓練過程的整體易用性和穩定性。

基於上述的解碼器區塊(Decoder Block),就可以依據需求的不同,而有多個彼此堆疊的解碼器區塊。如下圖所示:

由上所述可以知道了,Gemma不需要原始Transformer架構的編碼器(Encoder)部分,因為它只是解碼器(Decoder),並且沒有編碼器注意塊,因此解碼器相當於編碼器,除了多頭注意塊中的MASKING ,解碼器只允許從句子中前面的單字收集資訊。它的工作原理就像傳統的語言模型一樣,因為它以詞向量作為輸入,並產生下一個詞的機率作為輸出的估計,但它是自回歸的,因為句子中的每個標記都有前一個詞的上下文。因此,Gemma一次只運行一個令牌;不需要編碼器,因為模型不需要學習輸入序列的表示。

Gemma 和 ChatGPT 使用僅解碼器變壓器。由於 Gemma 和 ChatGPT 僅適用於解碼器,因此它們適用於文字到文字的 LLM,但不適用於圖像和影片。而Google Gemini 同時使用解碼器和編碼器架構。該架構實現了Gemini 的多模式功能,使其能夠在用戶提示及其回應中支援語音和圖像以及文字。

觀摩Gemma的源碼結構

請觀摩這個範例程式碼,是引自Gemma的開源程式碼,僅稍微調整其參數而已(如減改Attention heads的數量等)。這個範例程式可以在Python IDLE環境裡執行。其程式碼如下:


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
# gem_decoder_001.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from typing import Any, List, Optional, Sequence, Tuple, Union

class GemmaConfig:
    # The number of tokens in the vocabulary.
    vocab_size: int = 10
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 8192
    # The number of blocks in the model.
    num_hidden_layers: int = 2
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 3
    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 1
    # The hidden size of the model.
    hidden_size: int = 10
    # The dimension of the MLP representations.
    intermediate_size: int = 16
    # The number of head dimensions.
    head_dim: int = 8
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6
    # The dtype of the weights.
    dtype: str = 'bfloat16'
    # Whether a quantized version of the model is used.
    quant: bool = False
    # The path to the model tokenizer.
    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'

#-----------------------------------------
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out

#--------------------------------------
class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output
   
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.rand((out_features, in_features), dtype=torch.int8),
                requires_grad=True,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
        else:
            self.weight = nn.Parameter(
                torch.rand((out_features, in_features)),
                requires_grad=True,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.linear(x, weight)
        return output

#-----------------------------------------
class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs
   
#---------------------------------------
# 定義模型
class GemmaAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        quant: bool,
        ):
        super().__init__()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5
        #--------------------------------
 
        self.qkv_proj = Linear(self.hidden_size,
                   (self.num_heads + 2*self.num_kv_heads)*self.head_dim, quant=False )

        self.o_proj = Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            quant=False )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
      ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        batch_size, input_len, _ = hidden_states_shape

        #----------------------------
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)
         
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
 
        # Write new kv cache.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)

        key = k_cache
        value = v_cache

        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        # [batch_size, n_local_heads, max_seq_len, head_dim]
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)
       
        # [batch_size, n_local_heads, input_len, max_seq_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)

        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # [batch_size, input_len, hidden_dim]
        output = (output.transpose(1, 2).contiguous().view(
            batch_size, input_len, -1))
        output = self.o_proj(output)
        return output

#-------------------------------------
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
        ):
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
       
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
       
        return hidden_states

#---------------------------------------
class GemmaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = 10
        self.num_hidden_layers = 2
        self.rms_norm_eps = 1e-6

        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(GemmaDecoderLayer(config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        freqs_cis,
        kv_write_indices,
        kv_caches,
        mask,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states

#--------- 建立模型 --------------------
model = GemmaModel(GemmaConfig)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.004)

# 設定Target
T = torch.FloatTensor( [[0,0,0,0,0, 1,0,0,0,0],
                        [0,0,0,0,0, 0,1,0,0,0],
                        [0,0,0,0,0, 0,0,1,0,0],
                        [0,0,0,0,0, 0,0,0,1,0]] )

# Input資料
X = torch.FloatTensor( [[1,0,0,0,0, 0,0,0,0,0],
                        [0,1,0,0,0, 0,0,0,0,0],
                        [0,0,1,0,0, 0,0,0,0,0],
                        [0,0,0,1,0, 0,0,0,0,0]] )

X = X.unsqueeze(0)
T = T.unsqueeze(0)

batch_size = 1
seq_len = 4
num_kv_heads = 1
head_dim = 8

kc_cache = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
vc_cache = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
kv_cache = (kc_cache, vc_cache)
kv_caches = (kv_cache, kv_cache)

mask = torch.tensor([1, 1, 1, 1], dtype=torch.float32)
freqs_cis =  torch.tensor([[1, 1, 1, 1],
                           [1, 1, 1, 1],
                           [1, 1, 1, 1],
                           [1, 1, 1, 1]
                           ], dtype=torch.float32)

kv_write_indices = torch.LongTensor([0, 1, 2, 3])
 
#----------- 訓練1000回合 -----------------
print('展開訓練...')
for epoch in range(1000):
    # 正向傳播
    Z = model(X, freqs_cis, kv_write_indices, kv_caches, mask)
   
    # 計算損失
    loss = criterion(Z, T)
    # 反向傳播和優化
    optimizer.zero_grad()
    loss.backward() # retain_graph=True )
    optimizer.step()

    # 每1個迴圈印出損失
    if(epoch%100 == 0):
        print('ep=', epoch, 'loss=', loss.item())

#------------------------------
Z = model(X, freqs_cis, kv_write_indices, kv_caches, mask)

print('\n----- Z -----')
print(np.round(Z.detach().numpy()))
#------------------
#END

此程式執行時,就展開訓練1000回合,並且輸出如下:

在這範例程式裡,各類別的組成一個區塊(Block),如下圖:

這個模型背後的真正驅動力是 GemmaAttention機制。這種機制允許模型在預測每個標記時專注於輸入序列的不同部分,從而促進上下文相關文字的生成。    RMSNorm層透過計算輸入最終維度的平均值和標準差,消除了批量歸一化對批量維度的依賴。而GemmaMLP是一種前饋層,由兩個適用於最後一個維度的密集層組成。還有殘差連接(Residual connections)可以緩解梯度消失和爆炸的問題,並提高訓練過程的整體易用性和穩定性。

基於上述的解碼器區塊(GemmaDecoderLayer),就可以依據需求的不同,而有多個彼此堆疊的解碼器區塊,組成了GemmaModel模型。如下圖所示:

結語

對於 Gemma模型,輸出是提示後的下一個標記/單字的機率分配。它輸出完整輸入的一個預測。然而,儘管取得了這些快速的進步,LLM 的一個組成部分仍然保持不變——僅解碼器的變壓器(Decoder-Only Transformer)架構。

(責任編輯:謝嘉洵。)

高煥堂

訂閱MakerPRO知識充電報

與40000位開發者一同掌握科技創新的技術資訊!

Author: 高煥堂

擁有40多年軟硬體整合設計經驗,專精Android終端平台、AI、Docker容器、VR、AI(人工智慧)、IC軟硬整合技術與大數據應用。 近5年來,從事於AI普及化教育工作,目前擔任銘傳大學AI課程、長庚智慧醫療研究所AI課程授課老師。也擔任永春國小、東園國小、立志中學、君毅中學、永春高中等學校的AI師資培育工作。

Share This Post On

Submit a Comment

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *