近年、Large Language Model (LLM) をはじめとする深層学習モデルは著しい進化を遂げているが、その根幹を支えるアーキテクチャには依然として根本的な課題が潜んでいる。本記事では、ネットワークの深さ方向における情報伝達の仕組みを根本から見直したアプローチであるMoonshotAIの論文「Attention Residuals (AttnRes)」について、その基本原理から大規模学習に向けたシステム実装、そして得られた知見までを詳細に解説する。
Residual Connectionsの光と影
現代の深層学習モデルにおいて、Residual connections(残差結合)は不可欠な構成要素である。各レイヤーの更新則は以下のように記述される。
\[h_l = h_{l-1} + f_{l-1}(h_{l-1})\]
この再帰的な構造を展開すると、任意のレイヤー \(l\) の隠れ状態 \(h_l\) は、入力エンベディング \(h_1\) とそれまでの全レイヤーの出力の単純な総和として表現される。
\[h_l = h_1 + \sum_{i=1}^{l-1} f_i(h_i)\]
この「恒等写像(Identity mapping)」の存在により、勾配がネットワークの深層から浅層へと消失することなく伝播し、非常に深いネットワークの学習が可能になった。さらに実用上は、PreNormと呼ばれる正規化手法との組み合わせが標準となっている。
しかし、この単純な「加算」によるアプローチには、モデルが深くなるにつれて顕在化する重大な欠点が存在する。
- Uniform Aggregation(一様な集約): 上記の式から明らかなように、すべての過去のレイヤーの出力は重み「1」で固定されて足し合わせられる。情報の重要度や文脈に依存した選択的な集約が行われない。
- Hidden-State GrowthとPreNorm Dilution: 出力が単純加算されるため、隠れ状態 \(h_l\) の大きさ(ノルム)はレイヤーの深さ \(O(L)\) に比例して線形に増大していく。この結果、PreNormを通した際に初期レイヤーの相対的な寄与が薄まる「PreNorm dilution」と呼ばれる現象が発生する。深いレイヤーが全体に影響を与えるためには、ますます大きな出力を学習せざるを得なくなり、学習の不安定化を招く。
- 情報の埋没: 過去の情報が一つのベクトル \(h_{l-1}\) に不可逆的に圧縮されてしまうため、深いレイヤーが必要とする特定の初期レイヤーの情報を正確に引き出すことが困難になる。
これは、初期のRecurrent Neural Network (RNN) が、長い系列情報を単一の隠れ状態ベクトルに圧縮しようとして文脈を喪失していた問題と極めてよく似ている。
Attention Residuals:時間と深さの双対性
系列モデリングにおけるRNNの限界は、Attention機構(そしてTransformer)の登場によって打破された。過去のすべてのトークンに対して動的な重みを計算し、必要な情報を選択的に引き出すアプローチである。研究チームは、この「系列方向の進化」と全く同じパラダイムシフトを「深さ方向」に適用するAttention Residuals (AttnRes)を提案した。
Full Attention Residuals
AttnResでは、固定された加算による集約を、学習可能で入力に依存するSoftmax Attentionメカニズムに置き換える。各レイヤーへの入力は次のように定義される。
\[h_l = \alpha_{0 \to l} \cdot h_1 + \sum_{i=1}^{l-1} \alpha_{i \to l} \cdot f_i(h_i)\]
ここで、\(\alpha_{i \to l}\) はレイヤー \(l\) がレイヤー \(i\) の出力に割り当てるAttentionの重みであり、\(\sum_{i=0}^{l-1} \alpha_{i \to l} = 1\) を満たす。この重みは以下のように計算される。
\[\alpha_{i \to l} = \frac{\phi(q_l, k_i)}{\sum_{j=0}^{l-1} \phi(q_l, k_j)}\] \[\phi(q, k) = \exp(q^\top \text{RMSNorm}(k))\]
ここで、クエリ \(q_l = w_l\) は各レイヤー固有の学習可能なパラメータベクトル(次元数 \(d\))であり、キーおよびバリューは過去のレイヤーの出力 \(k_i = v_i = f_i(h_i)\)(ただし \(i=0\) の場合は入力エンベディング \(h_1\))である。RMSNormの導入は、出力の大きな特定のレイヤーがAttentionの重みを支配してしまうのを防ぐために極めて重要である。
この手法により、各レイヤーは「過去のどのレイヤーの情報が今の処理に必要か」を動的に判断し、選択的に情報を抽出(Selective Aggregation)できるようになる。
Block Attention Residualsによるスケーラビリティの確保
Full AttnResは理論的に優れているが、実践上、特に大規模なLLMの学習においては課題が生じる。標準的な学習ではバックプロパゲーションのために各レイヤーの出力(アクティベーション)を保持する必要があるため、メモリのオーバーヘッドは問題にならない。しかし、Pipeline Parallelismのような分散学習環境下では、過去のすべてのレイヤー出力をパイプラインのステージ間で通信し続ける必要があり、メモリと通信のコストが \(O(Ld)\) (\(L\)はレイヤー数、\(d\)は次元数)で増大してしまう。
この課題を解決するため、計算および通信コストを劇的に削減するBlock Attention Residuals (Block AttnRes)が導入された。
- Layer Partitioning: 全 \(L\) レイヤーを \(N\) 個のブロックに分割する(各ブロックのサイズは \(S = L/N\))。
- Intra-Block Accumulation: ブロック内では、従来のResidual connectionsと同様にレイヤーの出力を単純加算し、ブロックを代表する1つの表現ベクトル \(b_n\) を作成する。
- Inter-Block Attention: ブロック間の情報伝達においてのみ、この \(N\) 個のブロックレベル表現に対してAttentionを適用する。
これにより、メモリと通信のオーバーヘッドは \(O(Ld)\) から \(O(Nd)\) へと大幅に削減される。経験的に、ブロック数 \(N \approx 8\) 程度でFull AttnResの恩恵の大部分を維持できることが確認されている。
大規模学習・推論のためのシステム最適化
AttnResを実用的なものにするためには、単なるアルゴリズムの提案にとどまらず、インフラストラクチャレベルでの最適化が不可欠である。
学習時の最適化:Cross-Stage Caching
Pipeline Parallelism下では、モデルが物理的なGPU(ステージ)に分割される。Block AttnResを素朴に実装すると、各マイクロバッチがパイプライン境界を越えるたびに、それまでに蓄積されたすべてのブロック表現を再送信することになり、冗長な通信が発生する。
これに対し、受信側のステージで過去に受け取ったブロック表現をローカルにキャッシュする「Cross-stage caching」が考案された。これにより、ステージ間の遷移時には「直近に更新された差分のブロック」のみを通信すればよくなり、通信コストのピークを大幅に押し下げることに成功している。結果として、Pipeline Parallelism環境下でも学習時のオーバーヘッドは4%未満に抑えられている。
推論時の最適化:Two-Phase Computation Strategy
推論時におけるBlock AttnResの計算は、自己回帰的なデコーディングと似た性質を持つ。各レイヤーで過去のブロック表現にアクセスすると、メモリアクセスの回数が膨大になる。これを回避するため、「Two-phase computation(2フェーズ計算)」戦略が採用された。
- Phase 1(並列ブロック間Attention): あるブロック内のすべてのレイヤーが持つ学習済みクエリ \(w_l\) をバッチ化し、キャッシュされた過去のブロック表現に対して1回の行列積でAttentionを計算する。これにより、HBM(High Bandwidth Memory)からの読み出しコストをレイヤーごとではなくブロック単位に償却する。
- Phase 2(シーケンシャルなブロック内Attention): ブロック内での逐次的な計算(パーシャルサム)を行い、Phase 1の結果とOnline Softmaxを用いてマージする。
これらの工夫により、推論レイテンシの増加は一般的なワークロードにおいてわずか2%未満に留まっている。
実験結果と学習ダイナミクスの分析
Kimi Teamは、AttnResを実際のLLMアーキテクチャ(48B total / 3B activated parameters の Mixture-of-Experts モデル)に組み込み、1.4Tトークンによる事前学習と詳細な評価を行った。
PreNorm Dilutionの克服と学習の安定化
学習ダイナミクスの分析により、AttnResがPreNorm dilutionを効果的に緩和していることが実証された。標準的なResidual connections(Baseline)では、ネットワークが深くなるにつれて隠れ状態の大きさが単調に増大し、勾配の大部分が浅いレイヤーに偏る現象が見られた。
対照的にBlock AttnResでは、ブロック境界での選択的な集約によって隠れ状態の増大がリセットされ、値が一定の範囲に収まる(Bounded periodic pattern)。これにより、ネットワーク全体にわたって勾配がより均一に分布し、安定した学習と最終的なLossの大幅な改善(1.25倍の計算量削減に相当)が達成された。
ベンチマーク性能の向上
ダウンストリームタスクにおける評価では、Block AttnResはすべてのベンチマークでBaselineを上回った。特に、GPQA-Diamond(+7.5)のような高度な推論タスクや、HumanEval(+3.1)のようなコード生成タスクにおいて顕著な改善が見られた。これは、深いレイヤーが特定の過去の表現を「選択的に引き出し、その上に論理を構築する」という、構成的なタスクにおけるAttnResの優位性を示唆している。
獲得されたAttentionパターンの可視化
学習後のAttention重み \(\alpha_{i \to l}\) を可視化すると、非常に興味深い振る舞いが確認できる。
- 局所性の維持とSkip Connectionsの創発: 基本的には直前のレイヤーへのAttentionが最も強いが、同時に「遠く離れた初期レイヤー」に対して強い重みが割り当てられる非対角成分(Skip connections)が自然発生的に学習されている。
- レイヤーの専門化: Multi-Layer Perceptron (MLP) の前段では局所的な情報(直近のレイヤー)を重視する一方、Attentionレイヤーの前段ではより広い視野で過去の情報を参照するといった、レイヤーの役割に応じた専門化が見られる。
まとめ
Attention Residuals (AttnRes) は、単なるアーキテクチャの微修正にとどまらない。RNNからTransformerへの進化が系列モデリングにおける情報流のボトルネックを解消したように、AttnResは深層学習モデルの「深さ方向」における固定的な情報伝達(Residual connections)を、動的で選択的なルーティングへと進化させるものである。
Block AttnResというスケーラブルなアプローチと、システムレベルでの緻密な最適化(Cross-stage cachingやTwo-phase computation)が組み合わさることで、この革新的なメカニズムは今日の巨大なLarge Language Model (LLM) に対しても実用的なドロップインリプレイスメントとして機能する。AttnResは、より深く、より表現豊かで、より解釈可能な次世代のAIアーキテクチャを設計するための強力なパラダイムとなるだろう。
出典・参考リンク: