Transformerを進化させた14の現代的テクニック

Stephen Diehlの記事に基づき、オリジナルの「Attention」論文以降にTransformerを劇的に進化させたGQAやFlash Attentionなど14の重要な現代的テクニックを解説する
LLM
Python
Transformer
Author

Junichiro Iwasawa

Published

October 22, 2025

2017年の画期的な論文「Attention Is All You Need」は、Transformerアーキテクチャを提案し、現代のAI、特に大規模言語モデル(LLM)の基盤を築いた。ChatGPTやClaudeのような驚異的なモデルの中心には、このAttentionメカニズムがある。

しかし、オリジナルの論文で提案されたTransformerは、いわば「バージョン1.0」に過ぎない。今日の最先端モデルが驚異的な性能を発揮できるのは、オリジナルの論文以降に開発された、数多くの洗練された技術と最適化のおかげである。

この記事では、Stephen Diehl氏の優れたブログ記事「Attention Wasn’t All We Needed」に基づき、オリジナルのTransformerアーキテクチャを劇的に進化させた、いくつかの重要な現代的テクニックについて見ていく。

各テクニックの核となるアイデアを、できるだけ簡潔なPyTorchコード例と共に紹介する。ただし、これらの例の多くは中核的な概念をスケッチしたものであり、完全な実装については元の論文やPyTorch、Jaxなどのフレームワークにおける本番コードを参照されたい。

  1. Group Query Attention (GQA)
  2. Multi-head Latent Attention
  3. Flash Attention
  4. Ring Attention
  5. Pre-normalization
  6. RMSNorm
  7. SwiGLU
  8. Rotary Positional Embedding (RoPE)
  9. Mixture of Experts (MoE)
  10. Learning Rate Warmup
  11. Cosine Schedule
  12. AdamW Optimizer
  13. Multi-token Prediction
  14. Speculative Decoding

1. Group Query Attention (GQA)

Grouped Query Attention (GQA) は、推論時のKVキャッシュ(Key-Valueキャッシュ)のメモリ使用量を削減するための技術である。これは標準のMulti-head Attention (MHA) に対するアーキテクチャの最適化である。

GQAの基本的なアイデアは、MHAの計算上およびメモリ上のボトルネックが、K(キー)とV(バリュー)の射影とそのキャッシュサイズに大きく影響されるという観察に基づいている。GQAは、複数のQ(クエリ)ヘッドで単一のKとVの射影セットを共有することにより、このコストを削減する。

MHAのように\(N_h\)個のQ, K, Vヘッドを持つ代わりに、GQAは\(N_h\)個のQヘッドを使用するが、K/Vヘッドは\(N_{kv}\)個のみ(ここで \(N_{kv} < N_h\))使用する。\(N_h\)個のQヘッドは\(N_{kv}\)個のグループに分割され、各グループ(\(N_h / N_{kv}\)個のQヘッド)が同じKヘッドとVヘッドにアテンドする。この構造により、KとVの射影行列のパラメータ数が大幅に削減され、さらに重要なことに、自己回帰デコーディング中に必要なK/Vキャッシュのサイズが縮小する。

実装では、コード例のrepeat_interleaveステップで示されるように、\(N_{kv}\)個のK/Vヘッドを計算し、それらを\(g = N_h / N_{kv}\)(グループサイズ)回繰り返すかインターリーブして、\(N_h\)個のQヘッドと一致させてからアテンションスコアを計算する。

GQAは主に、モデルのパフォーマンスを大幅に損なうことなく、推論速度を高速化し、メモリ要件を削減する技術として使用される。K/Vヘッドの数を減らすことで、GQAは各デコーディングステップでK/Vキャッシュをロードするために必要なメモリ帯域幅を劇的に削減する。これは推論時の主要なボトルネックである。\(N_{kv}=1\)とする極端な形式は、Multi-query attention (MQA) と呼ばれる。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import random

class GroupQueryAttention(nn.Module):
    def __init__(self, dim, num_heads, num_kv_heads=None, head_dim=64, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads else num_heads
        self.head_dim = head_dim

        # num_heads が num_kv_heads で割り切れることを確認
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        # 1つのK/Vヘッドあたりのクエリ数
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # 射影
        self.q_proj = nn.Linear(dim, num_heads * head_dim)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * head_dim)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Q, K, V への射影
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # アテンション計算のために転置
        q = q.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch_size, num_kv_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch_size, num_kv_heads, seq_len, head_dim]

        # グループ内の各クエリヘッドのために k, v をリピート
        k = k.repeat_interleave(self.num_queries_per_kv, dim=1)
        v = v.repeat_interleave(self.num_queries_per_kv, dim=1)

        # スケール化ドット積アテンション
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(2, 3)) * scale

        # マスク適用(もしあれば)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        # Softmax と ドロップアウト
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # アテンションをバリューに適用
        out = torch.matmul(attn, v)  # [batch_size, num_heads, seq_len, head_dim]
        out = out.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim)

        # 出力射影
        out = self.o_proj(out)

        return out

2. Multi-head Latent Attention

Multi-head Latent Attention は、学習可能な「潜在(latent)」ベクトルを導入し、これが入力シーケンス要素間の中間的なボトルネックとして機能する。

中核となるアイデアは、標準的なSelf-Attentionに固有の、シーケンス長\(L\)に対する二乗の計算コスト \(O(L^2)\) を緩和することである。すべての入力要素が他のすべての要素に直接アテンドする代わりに、入力はまず固定数の潜在ユニット(\(N_{\text{latents}}\))にアテンドし、次にこれらの潜在ユニットが入力(またはその変種)にアテンドし返す。

これにより、長い入力シーケンス内の直接的な相互作用が、はるかに小さな潜在セットを介した2つのCross-Attentionステップに置き換えられる。このアプローチは、入力シーケンスからの本質的な情報が、これらの潜在表現に効果的に要約または圧縮できるという仮定に基づいている。\(N_{\text{latents}} \ll L\) の場合、計算量を大幅に削減しつつ、表現力を維持する。

このメカニズムは、主に2段階のアテンション計算を含む。

  1. 潜在-入力アテンション: 潜在ベクトル(\(Q_L\))が入力(\(K_X, V_X\))にアテンドし、入力から情報を集約する。 \[H_L = \text{Attention}(Q_L, K_X, V_X)\]
  2. 入力-潜在アテンション: 入力(\(Q_X\))が潜在ベクトル(\(K_L\))にアテンドし、更新された潜在表現(\(H_L\))から情報を集約する。 \[O = \text{Attention}(Q_X, K_L, H_L)\]

この技術は、標準的なSelf-Attentionが計算的に実行不可能な、非常に長いシーケンスや高次元の入力を扱うアーキテクチャ(例:Perceiver)で主に使用される。計算量は\(O(L^2)\)から\(O(L \cdot N_{\text{latents}})\)に削減される。

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, dim, num_heads, num_latents=64, head_dim=64, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_latents = num_latents
        self.head_dim = head_dim

        # 射影
        self.q_proj = nn.Linear(dim, num_heads * head_dim)
        self.k_proj = nn.Linear(dim, num_heads * head_dim)
        self.v_proj = nn.Linear(dim, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, dim)

        # 潜在ベクトル(学習可能)
        self.latents = nn.Parameter(torch.randn(1, num_latents, dim))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # このバッチのための潜在ベクトルを取得
        latents = self.latents.expand(batch_size, -1, -1)

        # 入力を Q, K, V に射影
        q_x = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k_x = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v_x = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        # 潜在ベクトルを Q, K, V に射影
        q_latents = self.q_proj(latents).reshape(batch_size, self.num_latents, self.num_heads, self.head_dim)
        k_latents = self.k_proj(latents).reshape(batch_size, self.num_latents, self.num_heads, self.head_dim)
        v_latents = self.v_proj(latents).reshape(batch_size, self.num_latents, self.num_heads, self.head_dim)

        # アテンション計算のために転置
        q_x = q_x.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        k_x = k_x.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        v_x = v_x.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]

        q_latents = q_latents.transpose(1, 2)  # [batch_size, num_heads, num_latents, head_dim]
        k_latents = k_latents.transpose(1, 2)  # [batch_size, num_heads, num_latents, head_dim]
        v_latents = v_latents.transpose(1, 2)  # [batch_size, num_heads, num_latents, head_dim]

        # アテンションのためのスケール係数
        scale = 1.0 / math.sqrt(self.head_dim)

        # 1. 潜在 -> 入力 アテンション
        attn_latent_to_input = torch.matmul(q_latents, k_x.transpose(2, 3)) * scale

        if mask is not None:
            latent_mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_latent_to_input = attn_latent_to_input.masked_fill(latent_mask == 0, -1e9)

        attn_latent_to_input = torch.softmax(attn_latent_to_input, dim=-1)
        attn_latent_to_input = self.dropout(attn_latent_to_input)

        # アテンション重みを入力バリューに適用
        latent_output = torch.matmul(attn_latent_to_input, v_x)  # [batch_size, num_heads, num_latents, head_dim]

        # 2. 入力 -> 潜在 アテンション
        attn_input_to_latent = torch.matmul(q_x, k_latents.transpose(2, 3)) * scale
        attn_input_to_latent = torch.softmax(attn_input_to_latent, dim=-1)
        attn_input_to_latent = self.dropout(attn_input_to_latent)

        # 更新された潜在バリューをバリューとして使用
        output = torch.matmul(attn_input_to_latent, latent_output)  # [batch_size, num_heads, seq_len, head_dim]

        # リシェイプして出力射影
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        output = self.o_proj(output)

        return output

3. Flash Attention

Flash Attention は、特に長いシーケンスにおいて、標準的なSelf-Attentionメカニズムに固有の重大なメモリボトルネックに対処する技術である。

従来のアプローチでは、アテンションスコア行列 \(S = QK^T\)(ここで \(Q, K \in \mathbb{R}^{N \times d}\))全体を計算する。これには、\(N \times N\) 行列 \(S\) を格納する必要があり、シーケンス長 \(N\) に対して \(O(N^2)\) のメモリ複雑性を持つ。

Flash Attentionは、この巨大な \(N \times N\)\(S\) 行列をGPUの低速な高帯域幅メモリ(HBM)に実体化(materialize)して保存することを避ける。代わりに、**タイリング(tiling)再計算(recomputation)**の技術を活用し、アテンション計算をはるかに高速なオンチップSRAMに収まる小さなブロックで処理する。

Flash Attentionは、KとVの行列をブロックに分割し、SRAMに反復的にロードする。そして、Qの各ブロックに対して、SRAM上にある現在のKブロックとのアテンションスコアを計算する。重要なのは、**オンラインソフトマックス(online softmax)**アルゴリズムを採用している点である。これにより、完全な \(N \times N\) 行列を必要とせずに、ブロックごとに正しくスケーリングされたアテンション出力を計算できる。

中間結果を高速なSRAM内に保持し、HBMへのデータ転送を最小限に抑えることで、Flash Attentionはシーケンス長に関連するメモリフットプリントを \(O(N^2)\) から \(O(N)\)(Q, K, V自体の格納が支配的)に削減し、メモリアクセスパターンの改善により大幅なスピードアップを実現する。

実際には、FlashAttentionは高度に最適化されたCUDAカーネルのファミリーである。以下は、PyTorchでの最小限のトイ実装と、flash-attnライブラリの実際の使用例である。

# トイ実装
class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads, head_dim=64, dropout=0.0, block_size=1024):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size # トイ実装でのシミュレーション用

        # 射影
        self.q_proj = nn.Linear(dim, num_heads * head_dim)
        self.k_proj = nn.Linear(dim, num_heads * head_dim)
        self.v_proj = nn.Linear(dim, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def _flash_attention_forward(self, q, k, v, mask=None):
        # これはFlash Attentionの簡略化された近似です
        # 実際にはカスタムCUDAカーネルを使用します

        batch_size, num_heads, seq_len, head_dim = q.shape
        scale = 1.0 / math.sqrt(head_dim)

        # 出力とアテンション統計を初期化
        output = torch.zeros_like(q)
        # オンラインソフトマックスのための正規化項
        normalizer = torch.zeros((batch_size, num_heads, seq_len, 1), device=q.device)

        # キーとバリューのブロックを処理
        for block_start in range(0, seq_len, self.block_size):
            block_end = min(block_start + self.block_size, seq_len)

            # キーとバリューのブロックを抽出(SRAMにロードするイメージ)
            k_block = k[:, :, block_start:block_end]
            v_block = v[:, :, block_start:block_end]

            # このブロックのアテンションスコアを計算
            attn_scores = torch.matmul(q, k_block.transpose(2, 3)) * scale

            if mask is not None:
                block_mask = mask[:, :, :, block_start:block_end]
                attn_scores = attn_scores.masked_fill(block_mask == 0, -1e9)

            # Softmaxを適用(ただし、これはまだ「オンライン」ではない)
            attn_probs = torch.softmax(attn_scores, dim=-1)
            attn_probs = self.dropout(attn_probs)

            # このブロックのアテンション結果で出力を更新
            output += torch.matmul(attn_probs, v_block)
            normalizer += attn_probs.sum(dim=-1, keepdim=True) # 正規化項を蓄積

        # 出力を正規化(簡略化されたバージョン)
        output = output / (normalizer + 1e-6)

        return output

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        output = self._flash_attention_forward(q, k, v, mask)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        output = self.o_proj(output)
        return output
# `flash-attn` ライブラリの実際の使用例
# (pip install flash-attn が必要)

# import torch
# from flash_attn import flash_attn_qkvpacked_func

# # 最小限の構成
# BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM = 2, 64, 4, 32
# CAUSAL = False
# DTYPE = torch.float16
# DEVICE = "cuda" # Flash AttentionはCUDAを必要とします

# # ダミーのパックされたQKVテンソルを作成
# # Shape: (batch_size, seq_len, 3, num_heads, head_dim)
# qkv = torch.randn(
#     BATCH_SIZE,
#     SEQ_LEN,
#     3,
#     NUM_HEADS,
#     HEAD_DIM,
#     dtype=DTYPE,
#     device=DEVICE,
# )

# print(f"Input qkv shape: {qkv.shape}")

# # FlashAttentionのパックされたQKV関数を呼び出し
# output = flash_attn_qkvpacked_func(
#     qkv,
#     dropout_p=0.0,
#     causal=CAUSAL,
#     softmax_scale=None # デフォルトのスケーリングを使用
# )

# # Output shape: (batch_size, seq_len, num_heads, head_dim)
# print(f"Output shape: {output.shape}")
# print("FlashAttention call successful.")

4. Ring Attention

Ring Attention は、Self-Attentionのブロックワイズ計算を複数のGPUで使用し、単一のデバイスには収まらないような非常に長いシーケンスの学習と推論を可能にする。

中核となるアイデアは、複数のプロセッシングユニット(GPUなど)を概念的なリングトポロジに配置し、計算を分散させることである。このアプローチでは、単一のデバイスがKおよびVテンソル全体を保持する必要がない。代わりに、これらのテンソルはシーケンス長の次元に沿ってシャーディング(分割)され、デバイスごとのピークメモリ要件を劇的に削減する。

実用的な分散実装では、各デバイスはアテンション計算を同期ステップで展開する。各ステップで、デバイスはローカルのQシャードと現在所有しているKシャードを使用して、部分的なアテンションスコアを計算する。重要な要素はその後の通信である。KVのシャードがリング内の次のデバイスに渡される。このローテーションは、すべてのQシャードがすべてのK/Vシャードと相互作用するまで繰り返される。

以下のPython例は、実際のマルチGPUハードウェアを必要とせずに、単一デバイス上でRing Attentionのロジックをシミュレートしたものである。_simulate_ring_attention関数は、KVテンソルのスライスを選択することで、分散プロセスを模倣する。

class RingAttention(nn.Module):
    def __init__(self, dim, num_heads, head_dim=64, dropout=0.0, num_shards=4):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_shards = num_shards # GPUの数(シミュレーション用)

        # 射影
        self.q_proj = nn.Linear(dim, num_heads * head_dim)
        self.k_proj = nn.Linear(dim, num_heads * head_dim)
        self.v_proj = nn.Linear(dim, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def _simulate_ring_attention(self, q, k, v, mask=None):
        # 実際のマルチGPUサポートなしでリングアテンションをシミュレート
        batch_size, num_heads, seq_len, head_dim = q.shape
        scale = 1.0 / math.sqrt(head_dim)

        # シャードサイズを計算
        shard_size = (seq_len + self.num_shards - 1) // self.num_shards

        # 出力を初期化
        output = torch.zeros_like(q)
        normalizer = torch.zeros((batch_size, num_heads, seq_len, 1), device=q.device)

        # シャード処理をシミュレート
        for shard_idx in range(self.num_shards):
            start_idx = shard_idx * shard_size
            end_idx = min(start_idx + shard_size, seq_len)

            # このシャードの K と V を処理(実際にはこれがデバイス間で渡される)
            if start_idx < seq_len:
                k_shard = k[:, :, start_idx:end_idx]
                v_shard = v[:, :, start_idx:end_idx]

                # アテンションスコアを計算(Qは全デバイスで同じと仮定)
                attn_scores = torch.matmul(q, k_shard.transpose(2, 3)) * scale

                if mask is not None:
                    shard_mask = mask[:, :, :, start_idx:end_idx]
                    attn_scores = attn_scores.masked_fill(shard_mask == 0, -1e9)
                
                # Softmax(シャード全体で蓄積される)
                attn_probs = torch.softmax(attn_scores, dim=-1)
                attn_probs = self.dropout(attn_probs)

                # 出力と正規化項を更新
                output += torch.matmul(attn_probs, v_shard)
                normalizer += attn_probs.sum(dim=-1, keepdim=True)

        # 出力を正規化
        output = output / (normalizer + 1e-6)

        return output

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        output = self._simulate_ring_attention(q, k, v, mask)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        output = self.o_proj(output)
        return output

5. Pre-normalization (Pre-LN)

Pre-normalization(またはPre-LN)は、Transformerの残差ブロックの設計における重要な変更点である。

従来のPost-normalization(Post-LN)では、正規化レイヤー(LayerNorm)がメインの操作(Self-AttentionやFFN)のに適用されていた。Pre-normalizationは、それをに適用するように変更する。

この小さな変更が、学習のダイナミクスに大きな影響を与える。入力を計算量の多いサブレイヤー(AttentionやFFN)に通すに正規化することで、ネットワークを流れる活性化と勾配を安定させる。この安定化効果は、特に非常に深いネットワークにおいて顕著であり、勾配消失・爆発の問題を軽減し、より高い学習率の使用や、より速く信頼性の高い収束を可能にすることが多い。

典型的な実装は \(x + f(\text{norm}(x))\) という構造に従う。ここで、\(x\)はブロックへの入力、\(\text{norm}(\cdot)\)は正規化関数(LayerNormやRMSNorm)、\(f(\cdot)\)はメインの変換関数(MHAやFFN)である。正規化された出力が関数fnによって処理され、その出力が元の正規化されていない入力 \(x\) に残差接続を介して足し戻される。

# RMSNormクラスはセクション6で定義
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x_normalized = x / rms
        if self.elementwise_affine:
            x_normalized = x_normalized * self.weight
        return x_normalized

class PreNorm(nn.Module):
    def __init__(self, dim, fn, norm_type='layer'):
        super().__init__()
        self.fn = fn

        if norm_type == 'layer':
            self.norm = nn.LayerNorm(dim)
        elif norm_type == 'rms':
            self.norm = RMSNorm(dim) # 次のセクションで定義
        else:
            raise ValueError(f"Unknown normalization type: {norm_type}")

    def forward(self, x, *args, **kwargs):
        # 最初に正規化を適用し、次に関数を適用
        # そして元の x を足し合わせる(残差接続)
        return self.fn(self.norm(x), *args, **kwargs) + x

6. RMSNorm

RMSNorm (Root Mean Square Normalization) は、広く使われているLayerNormを簡略化したもので、計算オーバーヘッドを削減しつつ、同等のパフォーマンスを維持し、しばしば学習の安定性を向上させるように設計されている。

LayerNormが平均を減算して活性化を中央揃えにし、標準偏差でスケーリングするのとは異なり、RMSNormは平均の中央揃えステップを完全に省略する。この簡略化の背景には、LayerNormの再センタリング操作が計算コストの大部分を占めており、それを取り除いてもモデルのパフォーマンスに大きな害はない(時には利益さえある)という経験的な観察がある。

RMSNormは、入力の二乗平均平方根(RMS)の大きさに基づいて入力を再スケーリングするだけである。入力ベクトル \(x = (x_1, \dots, x_n)\) に対して、RMS値は \(\text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2}\) として計算される。正規化された出力 \(\bar{x}_i\)\(\bar{x}_i = \frac{x_i}{\text{RMS}(x) + \epsilon}\) となる。

LayerNormと同様に、RMSNormも学習可能なスケーリングパラメータ \(g\) を含み、最終的な出力は \(y_i = g_i \bar{x}_i\) となる(バイアス \(b\) は省略されることが多い)。平均計算を省くことで、RMSNormは計算量とメモリ使用量を削減し、特に大規模モデルにおいて魅力的な代替手段となっている。

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            # 学習可能なスケーリングパラメータ(weight)
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def forward(self, x):
        # 最後の次元(特徴量次元)に沿って二乗平均平方根を計算
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # RMSで正規化
        x_normalized = x / rms

        # 学習可能なスケーリングを適用
        if self.elementwise_affine:
            x_normalized = x_normalized * self.weight

        return x_normalized

7. SwiGLU

SwiGLUは、Gated Linear Unit (GLU) ファミリーから派生した活性化関数で、ニューラルネットワークの性能を向上させるために特別に調整された。

GLUベースの活性化の基本的な概念は、ネットワークを流れる情報の流れを適応的に制御するゲーティング(gating)メカニズムを導入することである。SwiGLUは、ゲート部分に適用する特定の非線形関数としてSiLU (Sigmoid-weighted Linear Unit)、別名Swish\(\text{SiLU}(x) = x \cdot \sigma(x)\)\(\sigma\)はシグモイド関数)を採用している点で区別される。

SwiGLUの動作メカニズムは、通常、FFNブロック内で入力を2つの独立した線形変換(\(Wx+b\)\(Vx+c\))に射影する。SwiGLU活性化は \(\text{SwiGLU}(x) = \text{SiLU}(Wx + b) \odot (Vx + c)\) として計算される(\(\odot\)は要素ごとの乗算)。

効果として、一方のパスがSiLU活性化を経てゲート値となり、それがもう一方のパスの出力をスケーリングする。このゲーティングにより、ネットワークは入力コンテキストに基づいてどの特徴を前方に渡すかを動的に制御でき、ReLUのような単純な活性化関数よりも高い表現力とより良い勾配フローをもたらす。PaLMやLLaMAのようなモデルでの成功がその有効性を強調している。

class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_hidden=None, dim_out=None, bias=True):
        super().__init__()
        # 中間層の次元。指定がなければ入力の4倍(標準的なFFN設計)
        dim_hidden = dim_hidden or 4 * dim_in
        dim_out = dim_out or dim_in

        # GLUの2つの並列な線形層
        self.w1 = nn.Linear(dim_in, dim_hidden, bias=bias) # ゲート用
        self.w2 = nn.Linear(dim_in, dim_hidden, bias=bias) # 値用

        # 出力射影
        self.w3 = nn.Linear(dim_hidden, dim_out, bias=bias)

    def forward(self, x):
        # 2つの並列パス
        hidden1 = self.w1(x)
        hidden2 = self.w2(x)

        # SiLU (Swish) 活性化: x * sigmoid(x)
        hidden1_act = hidden1 * torch.sigmoid(hidden1)

        # ゲートを適用(要素ごとの乗算)
        hidden = hidden1_act * hidden2

        # 出力射影
        return self.w3(hidden)

8. Rotary Positional Embedding (RoPE)

Rotary Positional Embedding (RoPE) は、TransformerのSelf-Attentionメカニズムに相対的な位置依存性を効果的に組み込むためのエレガントな手法である。

従来のアプローチ(絶対位置エンコーディングの加算など)とは異なり、RoPEは位置エンコーディングを、Query (Q) と Key (K) ベクトルのドット積が計算されるに適用される回転操作として捉える。

重要な洞察は、位置\(m\)のQベクトルと位置\(n\)のKベクトルを、それぞれ\(m\)\(n\)に比例する角度で回転させることにより、結果として得られるドット積が、ベクトルの大きさ(ノルム)を変えることなく、相対位置 \(m - n\) のみに依存する形でエンコードされるという点である。

この回転は、埋め込み次元をペアに分割し、三角関数(cos, sin)を用いて効率的に実装される。RotaryEmbeddingクラスは、シーケンス長と次元に基づいて必要なcossinの値を事前計算する。apply_rotary_pos_emb関数は、これらの値を使用してQとKベクトルを変換する。

RoPEは、QとKの射影が計算された、アテンションスコアが計算されるに適用される。この方法は、相対位置を自然にエンコードし、学習中に見たことのない長いシーケンスへの汎化性能が高いことを示しており、LLaMAなどの現代のLLMで広く採用されている。

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, interleaved=False):
        super().__init__()
        self.dim = dim
        self.base = base
        self.interleaved = interleaved

        # 逆周波数バンドを生成
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, device=None):
        if device is None:
            device = self.inv_freq.device

        # 位置インデックスを生成
        positions = torch.arange(seq_len, device=device).float()

        # 周波数パターンを計算 (seq_len, dim/2)
        freqs = torch.outer(positions, self.inv_freq)

        # サインとコサイン埋め込みを取得
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = torch.cos(emb)[:, :self.dim]
        sin = torch.sin(emb)[:, :self.dim]

        return cos, sin

def apply_rotary_pos_emb(q, k, cos, sin, interleaved=False):
    # Q と K にロータリー埋め込みを適用
    batch_size, num_heads, seq_len, head_dim = q.shape
    
    # cos/sinの形状をブロードキャスト可能にする
    # [1, 1, seq_len, dim]
    cos = cos.reshape(1, 1, seq_len, cos.shape[-1])
    sin = sin.reshape(1, 1, seq_len, sin.shape[-1])

    # QとKを回転のために半分に分割
    half_dim = head_dim // 2
    q1, q2 = q[..., :half_dim], q[..., half_dim:]
    k1, k2 = k[..., :half_dim], k[..., half_dim:]

    # 回転を適用(複素数乗算に相当)
    q_rotated = torch.cat([
        q1 * cos[..., :half_dim] - q2 * sin[..., :half_dim],
        q2 * cos[..., :half_dim] + q1 * sin[..., :half_dim]
    ], dim=-1)

    k_rotated = torch.cat([
        k1 * cos[..., :half_dim] - k2 * sin[..., :half_dim],
        k2 * cos[..., :half_dim] + k1 * sin[..., :half_dim]
    ], dim=-1)

    return q_rotated, k_rotated

9. Mixture of Experts (MoE)

Mixture of Experts (MoE) は、推論や学習中の計算コストを比例して大幅に増加させることなく、パラメータ数を(潜在的に数兆まで)大幅に増やすために設計されたモデルアーキテクチャである。

中核となるアイデアは、TransformerのFeed-Forward (FFN) ブロックのような計算集約的なコンポーネントを、複数の小さな「エキスパート(expert)」ネットワークに置き換えることである。重要なのは、すべてのエキスパートがすべての入力トークンを処理するわけではない点である。

代わりに、軽量な「ルーター(router)」または「ゲーティング(gating)」ネットワークが、各入力トークンを処理するのに最も適していると見なされるエキスパートの小さなサブセット(通常は1つまたは2つ、top-kルーティングと呼ばれる)を動的に選択する。この条件付き計算により、MoEモデルは膨大なパラメータを持ちながら、特定の入力に対してはそのごく一部のみをアクティブ化するため、同等のサイズの密な(dense)モデルと比較して管理可能なFLOPsを維持できる。

ルーターネットワークは、入力トークンの表現を受け取り、各エキスパートの適合性を示すスコア(ロジット)を生成する。これらのスコアはtop-k関数で処理され、選択されたエキスパートの重みが(通常はSoftmaxで)正規化される。

トークンは選択されたエキスパートにのみディスパッチされる。各エキスパート(通常は標準的なFFN)はトークンを独立して処理する。これらのアクティブなエキスパートによって生成された出力は、ルーターによって計算されたルーティング重みに基づいて重み付け和として結合される。

MoEの学習における課題は、すべてのエキスパートが効果的に利用されるようにすることである。そうでなければ、ルーターが特定のエキスパートに過負荷をかけ、他のエキスパートが未発達になる可能性がある。これに対抗するため、_compute_balance_lossメソッドで示されるように、**補助的な負荷分散損失(load balancing loss)**が通常、学習目的関数に組み込まれる。

class MixtureOfExperts(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts=4, top_k=2, noise_std=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)
        self.noise_std = noise_std

        # エキスパート(FFN)を作成
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_experts)
        ])

        # ルーターネットワーク
        self.router = nn.Linear(input_dim, num_experts)

    def _compute_routing_weights(self, x):
        # ルーティングロジットを計算
        routing_logits = self.router(x)  # [batch_size, seq_len, num_experts]

        # 学習中にノイズを加えて探索を促進
        if self.training and self.noise_std > 0:
            noise = torch.randn_like(routing_logits) * self.noise_std
            routing_logits = routing_logits + noise

        # 各トークンの top-k エキスパートを取得
        routing_weights, selected_experts = torch.topk(routing_logits, self.top_k, dim=-1)

        # ルーティング重みを softmax で正規化
        routing_weights = F.softmax(routing_weights, dim=-1)

        return routing_weights, selected_experts

    def _compute_balance_loss(self, selected_experts, routing_weights):
        # 補助的な負荷分散損失を計算
        batch_size, seq_len, _ = selected_experts.shape

        expert_mask = torch.zeros(batch_size, seq_len, self.num_experts, device=selected_experts.device)

        # 選択されたエキスパートの位置に重みを配置
        for k in range(self.top_k):
            expert_mask.scatter_(-1, selected_experts[..., k:k+1], routing_weights[..., k:k+1])

        # エキスパートごとの平均ルーティング確率
        expert_routing_probs = expert_mask.mean(dim=[0, 1])

        # 均等な確率をターゲットとするMSE損失
        target_probs = torch.ones_like(expert_routing_probs) / self.num_experts
        balance_loss = F.mse_loss(expert_routing_probs, target_probs) * self.num_experts

        return balance_loss

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # ルーティング重みと選択されたエキスパートを計算
        routing_weights, selected_experts = self._compute_routing_weights(x)

        # 出力を準備
        output = torch.zeros(batch_size, seq_len, self.output_dim, device=x.device)

        # 選択されたエキスパートにディスパッチ
        for k in range(self.top_k):
            expert_indices = selected_experts[..., k]  # [batch_size, seq_len]
            expert_weights = routing_weights[..., k].unsqueeze(-1)  # [batch_size, seq_len, 1]

            # 各エキスパートごとに処理
            for expert_idx in range(self.num_experts):
                # このエキスパートに割り当てられたトークンを見つける
                mask = (expert_indices == expert_idx)

                if mask.any():
                    # 該当する入力トークンを収集
                    expert_inputs = x[mask]

                    # エキスパートで処理
                    expert_outputs = self.experts[expert_idx](expert_inputs)

                    # 適切な重みで出力をスキャッター(書き戻し)
                    output[mask] += expert_outputs * expert_weights[mask]

        # 補助的な負荷分散損失を計算
        balance_loss = self._compute_balance_loss(selected_experts, routing_weights)

        # 出力と補助損失を返す
        return output, balance_loss

10. Learning Rate Warmup

Learning Rate Warmup(学習率ウォームアップ)は、ニューラルネットワークの学習初期段階で採用されるヒューリスティックで、安定性を高め、発散を防ぐ。

学習の開始時、モデルのパラメータはランダムに初期化されており、最適とはほど遠い状態にある。ここでいきなり大きな学習率(Learning Rate: LR)を使用すると、初期の勾配(これも大きく不安定な場合がある)が急激なパラメータ更新を引き起こし、モデルを損失ランドスケープの悪い領域に押しやったり、数値的不安定性(損失の発散)を引き起こす可能性がある。

ウォームアップは、ごく小さなLRで学習プロセスを開始し、事前に定義された初期の学習ステップ数(「ウォームアップステップ」)にわたってLRを徐々に増加させ、ターゲットとなるベース値に到達させることで、このリスクを軽減する。

一般的な戦略は線形ウォームアップである。ステップ \(t\) での学習率 \(\eta_t\) は、 \(t < T_{\text{warmup}}\) の間、 \(\eta_t = \eta_{\text{base}} \times \frac{t}{T_{\text{warmup}}}\) として計算される。get_lrメソッドで示されるように、スケーリングファクタscaleは、warmup_stepsにわたって0から1まで線形に増加する。この穏やかな立ち上がりにより、モデルは不安定になりがちな初期段階で徐々に適応でき、スムーズな収束につながる。

# PyTorchの _LRScheduler を継承
class LinearWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # ウォームアップ中: 0からベースLRまで線形に増加
            scale = float(self.last_epoch + 1) / float(max(1, self.warmup_steps))
            return [base_lr * scale for base_lr in self.base_lrs]
        else:
            # ウォームアップ後: ベースLRを使用
            return self.base_lrs

11. Cosine Schedule

Cosine Scheduling(コサインスケジューリング、またはコサインアニーリング)は、学習率のスケジュール手法である。その中核原理は、コサインカーブの形状に従って、学習の過程で学習率を徐々に減少させることである。

特定のステップでLRを急激に下げるステップディケイ(Step Decay)とは異なり、コサインアニーリングは滑らかで連続的な減少を提供する。通常、LRは初期の高い値から始まり、コサイン関数の最初の半サイクルに従って減少し、最終的な学習ステップまでに事前に定義された最小値(多くの場合ゼロに近い)に達する。

この滑らかな減衰は、学習初期には損失ランドスケープの広範な探索のために大きなステップを許可し、後半にはファインチューニングと良い最小値への収束のためにステップを徐々に小さくすることで、最適化プロセスを助けることが経験的に示されている。

以下のコード例のように、コサインスケジューリングはしばしば「ウォームアップ」フェーズ(セクション10)と組み合わされる。ウォームアップ後、コサイン減衰フェーズが始まり、LRをピーク値からターゲットの最小値(base_lr * min_lr_ratio)まで、残りのステップにわたって滑らかに減少させる。

class CosineAnnealingWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, min_lr_ratio=1e-4, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # ウォームアップ中: 線形増加
            scale = float(self.last_epoch + 1) / float(max(1, self.warmup_steps))
            return [base_lr * scale for base_lr in self.base_lrs]
        else:
            # ウォームアップ後: コサイン減衰
            progress = float(self.last_epoch - self.warmup_steps) / float(
                max(1, self.total_steps - self.warmup_steps)
            )
            # コサイン減衰の式
            scale = self.min_lr_ratio + 0.5 * (1.0 - self.min_lr_ratio) * (
                1.0 + math.cos(math.pi * progress)
            )
            return [base_lr * scale for base_lr in self.base_lrs]

12. AdamW Optimizer

**AdamW(Adam with Decoupled Weight Decay)**は、Adamのような適応的オプティマイザにおける重み減衰(L2正則化)の標準的な実装における微妙な問題に対処する。

従来のAdamでは、L2正則化は、移動平均(\(m_t\)\(v_t\))を計算するに、勾配に減衰項(\(\lambda \cdot \text{weight}\))を直接加えることで実装されることがよくあった。しかし、これにより、重み減衰の効果が適応的学習率と結びついてしまう。

AdamWはこれらのプロセスを分離(decouple)する。標準的なAdamの更新を勾配のみに基づいて行い、それとは別に、重み減衰ステップを重みに直接適用する。これにより、重みがその勾配履歴に関係なく、その大きさに比例して減衰するという、L2正則化の本来の振る舞いが回復される。

AdamWの更新メカニズムは、重み減衰の適用方法が異なる。

  1. 重み減衰を重みに直接適用する: \(\theta_{t-1}' = \theta_{t-1} \cdot (1 - \text{lr} \cdot \lambda)\)(コード内のp.data.mul_(...)
  2. 次に、標準的なAdamの更新(モーメントに基づく)を、この減衰後の重みに適用する: \(\theta_t = \theta_{t-1}' - \text{lr} \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)\)(コード内のp.data.addcdiv_(...)

このアプローチは、特にTransformerのような大規模モデルの学習において、正則化が重要な場合に、より良い汎化性能をもたらすことが示されている。PyTorchには最適化されたAdamWの実装が含まれているが、以下はその簡略化されたバージョンである。

class AdamW(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                amsgrad = group['amsgrad']
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data) # m_t
                    state['exp_avg_sq'] = torch.zeros_like(p.data) # v_t
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Adamのモーメント更新
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                if amsgrad:
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # ★★★ Decoupled Weight Decay ★★★
                # 最適化ステップの「前」に重み減衰を適用
                if group['weight_decay'] != 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # パラメータ更新
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

13. Multi-token Prediction

Multi-token Prediction(複数トークン予測)は、自己回帰型言語モデルの推論速度を向上させるために開発された技術である。

通常の自己回帰生成では、トークンを1つずつ予測する。モデルはシーケンスを受け取り、次に最も可能性の高い単一のトークンを予測し、それをシーケンスに追加してプロセスを繰り返す。この逐次的な性質は、レイテンシが重要なアプリケーションにとって大きなボトルネックとなる。

Multi-token Predictionは、モデルの予測ヘッドを変更し、現在の隠れ状態に基づいて複数の未来のトークン(例:\(t+1\), \(t+2\), …, \(t+N\))の確率を同時に出力することで、この問題を克服しようとする。

実装には、コード例のように、異なる未来のオフセット(\(t+1\)用、\(t+2\)用など)のトークンを予測するように学習された、複数の個別の予測ヘッド(lm_heads)を持たせるアプローチがある。

学習中、compute_lossメソッドで示されるように、モデルは入力シーケンスを受け取り、次の\(N\)トークンの予測が、訓練データの実際の\(N\)個のターゲットトークンと比較される。損失(通常はクロスエントロピー)が予測された各位置で計算され、集約されて逆伝播に使用される。

この方法は速度向上を示すことができるが、いくつかの欠点がある。遠い未来のトークンを予測する精度は低下する傾向があり、選択された\(N\)トークンのシーケンスは、単一トークン生成が取ったであろう最適パスから逸脱する可能性がある。したがって、これは多くの場合、生成速度と品質のトレードオフとなる。

class MultiTokenPredictor(nn.Module):
    def __init__(self, hidden_dim, vocab_size, num_predicted_tokens=2, shared_prediction_head=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.num_predicted_tokens = num_predicted_tokens
        self.shared_prediction_head = shared_prediction_head

        if shared_prediction_head:
            # すべての位置で同じ予測ヘッドを共有
            self.lm_head = nn.Linear(hidden_dim, vocab_size)
        else:
            # 位置ごとに個別の予測ヘッドを使用
            self.lm_heads = nn.ModuleList([
                nn.Linear(hidden_dim, vocab_size) 
                for _ in range(num_predicted_tokens)
            ])

    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.shape

        # 最後のトークンの隠れ状態を取得
        last_hidden = hidden_states[:, -1]

        if self.shared_prediction_head:
            # (共有ヘッドのロジックはデモ用に簡略化)
            multi_token_logits = []
            for i in range(self.num_predicted_tokens):
                projected_hidden = last_hidden # 実際にはより複雑な変換が必要
                multi_token_logits.append(self.lm_head(projected_hidden))
            multi_token_logits = torch.stack(multi_token_logits, dim=1)
            next_token_logits = multi_token_logits[:, 0:1]
        else:
            # 位置ごとに個別のヘッドを使用
            multi_token_logits = torch.stack([
                head(last_hidden) for head in self.lm_heads
            ], dim=1)
            next_token_logits = multi_token_logits[:, 0:1]

        # next_token_logits は標準的な推論用, multi_token_logits は学習用
        return next_token_logits, multi_token_logits

    def compute_loss(self, hidden_states, labels, ignore_index=-100):
        # 予測を取得
        _, multi_token_logits = self.forward(hidden_states)

        # ターゲットを準備: ラベルを予測と一致するようにシフト
        # (この損失計算は簡略化されたデモです)
        targets = []
        for i in range(self.num_predicted_tokens):
            # 実際には、(seq_len - num_predicted_tokens) の長さにわたって計算する必要がある
            targets.append(labels[:, 1+i:labels.shape[1]-self.num_predicted_tokens+1+i])
        
        # ... 損失計算ロジック (ここでは簡略化のため省略) ...
        # loss = F.cross_entropy(...)
        loss = 0.0 # ダミー
        
        #
        # 以下の損失計算は、元のブログのロジックに基づき、
        # `last_hidden` のみから予測された `multi_token_logits` (B, N, V) と
        # `labels` (B, L) の最後の N トークンを比較するように修正します。
        #
        
        # ターゲットは、入力シーケンスに続く N トークン
        # labels の形状が (B, seq_len + num_predicted_tokens - 1) と仮定
        
        # 簡略化:入力の最後の N トークンをターゲットと仮定(実際にはシフトが必要)
        if hidden_states.shape[1] > self.num_predicted_tokens:
            stacked_targets = labels[:, -self.num_predicted_tokens:] # (B, N)
        else:
            # シーケンスが短い場合の処理(デモ)
            stacked_targets = labels[:, 1:1+self.num_predicted_tokens]
            if stacked_targets.shape[1] < self.num_predicted_tokens:
                # パディング(ダミー)
                pad_size = self.num_predicted_tokens - stacked_targets.shape[1]
                stacked_targets = F.pad(stacked_targets, (0, pad_size), value=ignore_index)


        loss = 0
        for i in range(self.num_predicted_tokens):
            loss += F.cross_entropy(
                multi_token_logits[:, i].view(-1, self.vocab_size),
                stacked_targets[:, i].reshape(-1),
                ignore_index=ignore_index
            )
        
        return loss / self.num_predicted_tokens

14. Speculative Decoding

Speculative Decoding(投機的デコーディング)は、大規模言語モデルの推論プロセスを高速化するために設計された巧妙なテクニックである。

標準的な生成は、計算コストの高い大規模モデル(「ターゲット(target)」モデル)が、一度に1トークンだけを予測するために完全なフォワードパスを実行する必要があるため、ボトルネックとなっている。

Speculative Decodingは、はるかに小型で高速な「ドラフト(draft)」モデルを導入する。このドラフトモデルは、候補となる未来のトークンシーケンス(「ドラフト」)を迅速に生成する。中核となるアイデアは、大規模なターゲットモデルを使用して、このドラフトシーケンス全体を単一の並列フォワードパスで検証し、一度に複数のトークンを受け入れる可能性があるというものである。

メカニズムは、ドラフトモデルの予測とターゲットモデルの予測を比較することにかかっている。

  1. ドラフトモデルが \(k\) 個のトークン \(d_1, \dots, d_k\) を提案する。
  2. ターゲットモデルは、元の入力+ドラフトシーケンス全体に対して1回実行される。これにより、ドラフトシーケンス内の各位置におけるターゲットモデルの確率分布が得られる。
  3. 各ドラフトトークン \(d_i\) が検証される。ターゲットモデルがドラフトトークンに強く同意する場合(特定の採択ルールに基づく)、トークンは採択される。
  4. この検証は、ドラフトトークン \(d_j\) が棄却されるまで、またはすべてのドラフトが採択されるまで逐次的に進められる。
  5. 位置 \(j\) で棄却が発生した場合、\(d_1, \dots, d_{j-1}\) は保持される。重要なことに、位置 \(j\) で計算されたターゲットモデルの確率分布を使用して、修正されたトークンをサンプリングできる。

ターゲットモデルの推論ステップごとに平均して複数のトークンを採択することにより、Speculative Decodingは、生成されるテキストの品質に最小限の影響で、大幅なスピードアップ(例:2〜3倍)を達成できる。

# この例は、モデルとトークナイザが既にロードされていることを前提としています
class SimpleSpeculativeDecoding:
    def __init__(self, target_model, draft_model, tokenizer, max_draft_tokens=5):
        self.target_model = target_model
        self.draft_model = draft_model
        self.tokenizer = tokenizer
        self.max_draft_tokens = max_draft_tokens

    def generate(self, prompt, max_length=100):
        # プロンプトのトークンIDから開始
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.target_model.device)

        while input_ids.shape[1] < max_length:
            # ステップ1: 複数のドラフトトークンを生成
            draft_input_ids = input_ids.clone()
            draft_tokens = []

            with torch.no_grad():
                for _ in range(self.max_draft_tokens):
                    outputs = self.draft_model(draft_input_ids)
                    next_token_logits = outputs.logits[:, -1, :]
                    next_token = torch.argmax(next_token_logits, dim=-1)
                    
                    draft_tokens.append(next_token.item())
                    draft_input_ids = torch.cat([draft_input_ids, next_token.unsqueeze(0)], dim=1)
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

            # ステップ2: ターゲットモデルで検証
            with torch.no_grad():
                verification_ids = torch.cat([
                    input_ids, 
                    torch.tensor([draft_tokens]).to(input_ids.device)
                ], dim=1)
                
                target_outputs = self.target_model(verification_ids)
                # input_idsの最後からドラフトトークン分のロジットを取得
                target_logits = target_outputs.logits[:, input_ids.shape[1]-1:-1] # (B, K, V)
                
                target_probs = F.softmax(target_logits, dim=-1) # (B, K, V)

                # トークンを採択
                accepted_tokens = []
                for i, token_id in enumerate(draft_tokens):
                    target_prob_dist = target_probs[0, i] # ターゲットのi番目の予測
                    
                    # 簡略化された採択ルール(確率の比較)
                    # 実際には、ドラフトモデルの確率も考慮する必要がある
                    
                    target_token_id = torch.argmax(target_prob_dist).item()
                    
                    if token_id == target_token_id:
                        accepted_tokens.append(token_id)
                    else:
                        # 棄却: ターゲットモデルから新しいトークンを取得
                        accepted_tokens.append(target_token_id)
                        break
            
            # 採択されたトークンを input_ids に追加
            input_ids = torch.cat([
                input_ids, 
                torch.tensor([accepted_tokens]).to(input_ids.device)
            ], dim=1)

            if input_ids[0, -1].item() == self.tokenizer.eos_token_id:
                break

        # 生成されたトークンをデコード
        return self.tokenizer.decode(input_ids[0])

まとめ

「Attention Is All You Need」は、間違いなくAIの歴史における転換点であった。しかし、それは壮大な物語の序章に過ぎなかった。

今回紹介した14のテクニック(GQA、Flash Attention、RoPE、MoE、AdamWなど)は、オリジナルのTransformerアーキテクチャが抱えていた計算量、メモリ、安定性、効率といった多くの課題を解決するために考案された、無数のイノベーションのほんの一部である。

GPT-5、Claude 4のような今日の最先端モデルは、これらの洗練された技術を多く組み込むことで、その驚異的な能力を実現している。この分野のイノベーションの速さは驚異的であり、次にどのようなブレークスルーが登場するのか、目が離せない。