Blog
MN-Core Compiler Core チームの諸戸です。
情報処理推進機構(IPA)が主催するセキュリティ・キャンプ2025ネクストにて『低レベル MN-Core プログラミング』という講座を担当いたしました。
今回は本講座で扱った『グラフコンパイラ自作入門』について、セキュリティ・キャンプに参加されていない一般の方も体験できるよう、活動報告と解説を行っていきたいと思います。
まず、「グラフコンパイラ」と言われても、何をするものかイメージしづらいかもしれません。
また、「MN-Core」と聞いて「PFN が独自開発している AI チップ」ということは知っていたり、「MN-Core Challenge を通じてどんな命令セットを備えているかは知っている」という方は一定数いらっしゃるかもしれませんが、実際に MN-Core がどのように AI モデルの計算を行っているのか、実感されている方は少ないのではないでしょうか。
この記事を通じて、以下の2点を目指します。
- MN-Core がどのように AI で使われる計算を行うのかを知る
- Python で記述された AI モデルを、グラフコンパイラがどのように処理して AI チップ向けのアセンブリを生成するのかを知る
そのための題材として、オープンアクセスで公開されている MN-Core エミュレータを用いつつ、MNIST データセットの分類器の訓練を体験していただきます。
以下のコンテンツを用意しています。この記事と合わせてご参照下さい。
・MN-Core Simple Graph Compiler for Education (GitHub)
・MN-Core Challenge「MNIST Operator 実装」問題セット
・講義スライド
この記事では、この GitHub に公開された学習用のグラフコンパイラを用い、MN-Core 用のアセンブリ (vsm) を出力できるように改造してもらいます。
改造と言っても、MNIST のトレーニングに必要な演算 (Operator) の実装部分を埋めるだけで、オープンアクセスのエミュレータを使用して実際のトレーニングまで行えるようになっています。
「演算 (Operator) の実装部分」に関しては、MN-Core Challenge に特設問題セットを用意しました。オンラインジャッジ形式で、入出力アドレスなどが固定の状態で実装してもらうことで、ML 用途に特化された MN-Core の命令セットでどのようにトレーニングを行うのかを体験することができます。
特設問題セットの問題を 1 問解いたら、学習用のグラフコンパイラに実装を移すことで、コンパイラの対応演算が増えていき、最終的に MNIST のトレーニングが行えるようになる、という流れになります。
ブログを読みつつ、ぜひ自作グラフコンパイラを改造し、MN-Core 上での MNIST の訓練にトライしてください!
グラフコンパイラ
近年、AI技術の発展は目覚ましく、それを支えるハードウェアもCPU、GPU、さらには MN-Core をはじめとする専用のAIアクセラレーターと多様化しています。AI開発者がこれらの多様なハードウェアで最高の性能を引き出すための鍵となるのが、「グラフコンパイラ」です。
コンパイラと言うと、C++ コンパイラに限定しても GCC や Clang、Microsoft Visual C++ や Intel C++ Compiler など様々なものがありますが、グラフコンパイラにも数多くの種類があります。例えば PyTorch の標準コンパイラである TorchInductor や、Google の OpenXLA、Apache の Apache TVM や Microsoft の ONNX Runtime、LLVM の MLIRなどがあります。また PFN でも PFVM と呼ばれるグラフコンパイラを自社開発しており、MN-Core や GPU 向けのコード生成に使用されています。
グラフコンパイラとは、プログラムの計算手順を「計算グラフ」という形式で一度表現し、そのグラフ全体を分析して、特定のハードウェア向けに最適化するソフトウェアです。
計算グラフとは、数式やプログラムの処理を、ノード(演算)とフロー(データの流れ)で表現したものです。値(テンソル)と演算そのものを中心に、計算の流れを表現します。
一般的なコンパイラ(GCCやLLVMなど)は、命令を load, mul, add, store など、細かい低レベルな命令の集合として処理を捉えるため、最適化は命令レベルにとどまります。一方、グラフコンパイラは、畳み込み(Conv)や活性化関数(ReLU)といったAI特有の高レベルなノードとして計算を理解できます。これにより、「Conv → ReLU」といった典型的なパターンを一つのGPUカーネルにまとめる(Fusion)など、より高度な最適化が可能になります。
また、近年の深層学習では、マルチノード環境やノード内の高速なメモリを有効活用し、DRAMとのデータ往復を減らすことが性能向上の鍵となります。そのためには、グラフ全体で値の流れを最適化し、データ転送と計算の効率を最大化する必要もあります。
自作Cコンパイラなどを作ったことがある方なら、ソースコードを解析してAST(抽象構文木)を作る工程をご存知でしょう。グラフコンパイラでは、PyTorchなどのフレームワークがその役割を担います。本演習で作成するのは、その中間表現(計算グラフ)を受け取り、ターゲットマシンのメモリ配置を決め、アセンブリを生成する「コンパイラ・バックエンド」の部分に相当します。
AI分野でグラフコンパイラが不可欠とされる主な理由は以下の3点です。
- ハードウェアの多様化への対応: ハードウェアごとに得意な計算が異なるため、開発者が全てのHWで最高性能のコードを書くのは困難です。グラフコンパイラがその差を埋めます。
- 開発者の生産性向上: 開発者はハードウェアの低レベルな詳細を気にすることなく、モデルの設計という本質的な作業に集中できます。
- AI計算の性質: AI計算のほとんどは行列積や畳み込みなどの密な線形代数演算で、非常に規則的です。静的で明示的な計算グラフを用いることで、機械的な並列化や最適化がしやすくなります。
グラフコンパイラは、主に以下の3つのステップを経て実行可能コードを生成します。
- 計算グラフのトレース: PyTorchなどのAIフレームワークから、AIモデルの計算グラフ(例:ONNX形式、torch.fx 形式など)を抽出します。
- グラフ最適化: グラフ全体を分析し、より高速に実行できるグラフへと変換します。主な最適化手法には、複数の演算を一つにまとめる「演算子融合(Operator Fusion)」や、テンソルのメモリ上の並び順を効率的な形式に変換する「レイアウト変換(Layout Transformation)」、計算精度を調整する「自動混合精度(Automatic Mixed Precision)」などがあります。
- コード生成: 最適化されたグラフをもとに、ターゲットハードウェア(GPUなど)に特化した高速な機械語コード(カーネルコード)を生成します。
MNIST と多層パーセプトロン(MLP)
今回の MN-Core 向けグラフコンパイラ自作では、以下の多層パーセプトロン(MLP)モデルを使用し、MNIST の分類を行います。
class SimpleNN1024(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1024, 16) # 入力 32x32 = 1024、隠れ層 16 次元
self.fc2 = nn.Linear(16, 16) # 出力 16 次元 (10クラス問題だけど、16次元にしておく)
self.relu = nn.ReLU()
def forward(self, x: torch.Tensor):
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
return self.fc2(x)
MLPは、最も基本的なニューラルネットワーク(NN)の一つで、複数の層に配置された「ニューロン」が情報を伝達し学習を行います。簡単には、以下のような構成をしています。
- 入力層 (Input Layer): データを受け取ります。本来 MNIST データセットは 28×28ピクセル画像なのですが、簡略化のため32×32ピクセル(1024次元)にパディングします。
- 隠れ層 (Hidden Layer): 入力からの情報を受け取り、複雑なパターンを学習・抽出します。今回のモデルでは16次元の隠れ層を持ちます。
- 出力層 (Output Layer): 最終的な予測や分類結果を出力します。MNISTは0~9の10クラス分類ですが、実装の都合上16次元にパディングしています。本来であればコンパイラが自動でパディングを挿入し、余った部分を計算結果に影響しない無害な値(加算なら0、Maxなら-infなど)で埋めて処理を行うべきですが、今回はその工程を省略する。
各層のニューロンは、前の層からの入力に重みを掛け合わせ、バイアスを加算し、活性化関数(今回はReLU)を通して次の層へ出力を伝えるという、線形代数演算(\(y = Wx+b\))を基本としています。
MLP の学習は、大きく分けて「順伝播」「逆伝播」「パラメータ更新」のステップを繰り返し行うことで実現されます。
1. 順伝播 (Forward Propagation)
入力層(1024次元)に画像データを入力し、隠れ層(16次元)、そして出力層(16次元)へと計算を進め、最終的な予測結果(各クラスのスコア)を算出します。
基本的な計算: 各層のニューロンでは、前の層の出力に対して「重み行列(\(W\))」を乗算し、「バイアス(\(b\))」を加算する線形代数演算 (\(y = Wx+b\)) が行われます。
活性化関数: 線形演算の結果は、非線形性を導入するための活性化関数(今回は ReLU )を通過します。
2. 逆伝播 (Backward Propagation) と勾配計算
順伝播で得られた予測結果と、正解ラベルとの差を「損失(Loss)」として計算します。この損失を最小化するために、各層の重みとバイアスをどのように調整すべきかを示す「勾配(grad)」を計算するのが逆伝播です。
PyTorchなどのフレームワークでは、順伝播を定義するだけで、自動的に勾配計算のための逆伝播グラフが内部で構築されます。
計算の手順としては、モデルの出力と正解ラベルの差分を元に計算された勾配を起点に、微分演算に従って入力側(前の層)へと逆向きに伝播させていくことで各パラメータの勾配を求めます。例えば、ReLU の微分は Step 関数、行列積(Gemm)の微分は転置を使った行列積として計算されます。
3. パラメータ更新 (Optimizer Step)
計算された勾配に「学習率(Learning Rate)」を乗算し、現在の重みとバイアスから差し引くことで、パラメータを更新します。これにより、損失が小さくなる方向へモデルのパラメータが調整されます。
今回の自作グラフコンパイラでは、独自アーキテクチャ用の計算コードを出力することに焦点を当てているため、モデルの自動微分パートは PyTorch の Autograd(自動微分)機能を使い、計算グラフを構築しています。
逆伝播グラフの構築方法も様々であり、PyTorch の Autograd では順伝播の Torch コードを実行しつつ演算のたびに裏で逆伝播のグラフを構築します。他にも、このような方法で作成された逆伝播付きのグラフを実行しトレースを行うことで、順伝播と逆伝播のグラフを同時に作成する方法もあり、順伝播と逆伝播の垣根を超えた演算融合(Fusion)や、メモリ使用量を抑えるための再計算の最適化が可能になるなどの特徴もあります。
演習用自作グラフコンパイラの説明
今回のグラフコンパイラ自作入門では、あらかじめ用意された C++ コードを出力するグラフコンパイラを拡張し、多層パーセプトロン(MLP)モデルを MN-Core2 向けのコードとして出力・実行し、MNIST データセットの学習を行っていただきます。
今回 MN-Core で動かすモデルは、fx_export/train.py に以下のように定義されています。
class SimpleNN1024(nn.Module):
# PADDED_MNIST用: 入力1024次元(32x32)、出力16次元のモデル"""
def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(1024, 16) # 32x32 = 1024
self.fc2 = nn.Linear(16, 16) # 出力を16次元に拡張
self.relu = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
return self.fc2(x)
学習用自作グラフコンパイラには、既に以下の機能が実装されています。
- 定義された NN をトレースし、自動微分による逆伝播付き計算グラフを作成する
- その計算グラフから C++ コードを出力する
- 出力された C++ コードをコンパイル・実行して MLP のトレーニングを行う
この「C++ コード出力機能」を参考に、MN-Core2 向けのアセンブリを出力する機能を作成していただきます。
演習用コンパイラは、モデルの定義から最終的な学習実行までを、以下のようなステップで実行します。
1. C++コード生成と検証、実行
演習用コンパイラは、主に以下の3つのコマンドを通じて、PyTorchで定義されたモデルをターゲット向けのコードに変換します。
| コマンド | 目的 | 内容 |
|---|---|---|
export (Step 1) |
計算グラフの抽出 (ONNX) | PyTorchの自動微分機能を利用して、順伝播、逆伝播、パラメータ更新を含む学習ステップ全体の計算グラフを、中間表現であるONNX形式のファイルとしてエクスポートします。 |
test (Step 2) |
C++コードの検証 | エクスポートされたONNXグラフからC++コードを生成し、そのコードを実際に実行します。参照実装(例:PyTorch)との計算結果の誤差(Error)を比較し、生成されたC++コードが期待通りに動作するかを検証します。 |
train_cpp (Step 3) |
実際の学習実行 | 生成・コンパイルしたC++コードを用いて、MNISTの学習を最初から最後まで実行します。エポックごとの損失(Loss)と精度(Accuracy)が表示され、C++コードが正しく学習を進められるかを確認します。 |
この一連のステップを通じて、AIフレームワークでのモデル定義から、ターゲットハードウェアで実行可能な低レベルコードへの変換、そして実データによる学習実行という、グラフコンパイラが担う役割全体を体験できます。
export では、計算が定義された関数のほかに、ダミー入力を一緒に渡します。ダミー入力を使用して関数を実行し、「どのような Tensor Shape で演算が行われるのか」を含めて計算のトレースを行い、ONNX を出力します。
動的型付け言語であるPythonのコードから、コンパイル時にメモリ配置を決定するために、ダミー入力を流して型推論(Shape推論)を行っていると捉えると分かりやすいでしょう。
実際に `./haribote_graph_compiler.py export` を実行してみると、以下のような結果が得られるはずです。
=== ONNXエクスポート === モデル付き関数を検出しました。 自動微分ベースのグラフエクスポートを使用します。 ONNXモデルを /tmp/train_step/model.onnx に保存しました エクスポート完了: /tmp/train_step - model.onnx - input_*.npy (6 files) - output_*.npy (10 files)
model.onnx は、トレースした関数を計算グラフとして ONNX で表現されたものになります。
input_*.npy は、テスト用のサンプル入力です。ダミー入力のほか、重みパラメーターの初期値なども保存されています。
output_*.npy はテスト用のサンプル出力で、input_*.npy` を入力したときの関数の出力値や勾配 (grad)、逆伝播で更新されたあとのパラメーター値が保存されています。計算の関数としては、勾配は出力に含める必要性はないのですが、グラフコンパイラ開発のデバッグ用に出力しています。
test では、ONNX を入力にコンパイルとテストを行います。C++ 生成モードでは C++ コードを生成したあと、Python から呼べる形でコンパイルを行い、Python で計算した結果と同様の結果が得られるかをテストします。
実際に `./haribote_graph_compiler.py test /tmp/train_step` を実行してみると、以下のような結果が得られるはずです。
(省略) === C++ check検証 === Var | Error | Ref Min | Ref Max | Ref Avg | Shape | v -------------------------------------------------------------------------------- output | 5.44e-07 | -1.11e+00 | 9.06e-01 | -1.38e-03 |(256, 16) | ✓ loss | 9.54e-07 | 2.82e+00 | 2.82e+00 | 2.82e+00 |() | ✓ grad_fc1_weight | 9.31e-09 | -2.75e-02 | 2.93e-02 | 4.64e-05 |(16, 1024)| ✓ grad_fc1_bias | 1.12e-08 | -2.80e-02 | 3.06e-02 | 7.29e-03 |(16,) | ✓ grad_fc2_weight | 1.49e-08 | -2.74e-02 | 2.34e-02 | -1.46e-10 |(16, 16) | ✓ grad_fc2_bias | 3.73e-08 | -6.61e-02 | 8.17e-02 | -4.66e-10 |(16,) | ✓ updated_fc1_weight | 3.73e-09 | -3.14e-02 | 3.13e-02 | -1.65e-05 |(16, 1024)| ✓ updated_fc1_bias | 1.86e-09 | -3.01e-02 | 2.15e-02 | -3.55e-03 |(16,) | ✓ updated_fc2_weight | 1.49e-08 | -2.49e-01 | 2.49e-01 | 8.15e-03 |(16, 16) | ✓ updated_fc2_bias | 7.45e-09 | -2.33e-01 | 2.38e-01 | -3.15e-02 |(16,) | ✓ ✓ テスト成功: /tmp/train_step
各出力の誤差や shape、正誤判定(verify) の表示が行われています。
fx_export/operators/ 以下に、各種Operatorの実装が書かれています。
例えば fx_export/operators/add.py の `generate_cpp` には、Add Operatorが来たときに出力すべき C++ コードを返す定義が書かれています。例えば以下のような形で、文字列の配列として C++ コードを返しています。
lines.append(f" const Matrix<{shape0[0]}, {shape0[1]}> {out_var} = add_colvec<{shape0[0]}, {shape0[1]}>({in0}, {in1});")
今回は、Tensor の計算は全て matrix_operations.hpp に書かれたライブラリを呼び出す形で実装されています。
生成される C++ コードは、`/tmp/test__tmp_train_step/forward_backward.cpp` に保存され、以下のようなコードになります。
#include "matrix_operations.hpp"
...
void forward_backward(
const float* input_ptr,
...
float* output_ptr,
...
) {
const Matrix<256, 1024> input = load<256, 1024>(input_ptr);
...
const Matrix<256, 16> fc1_pre_matmul = matmul<256, 1024, 16>(input, trans<16, 1024>(fc1_weight));
const Matrix<256, 16> fc1_pre = add_colvec<256, 16>(fc1_pre_matmul, fc1_bias);
...
save(fc2_pre, output_ptr);
}
生成される C++ コードを意図的に壊してみると、テストに失敗する事が確認できます。
例えば fx_export/operators/div.py で `div_rowvec` のところを `sub_rowvec` に置き換えると、以下のような結果になります。
=== C++ check検証 === Var | Error | Ref Min | Ref Max | Ref Avg | Shape | v -------------------------------------------------------------------------------- output | 5.44e-07 | -1.11e+00 | 9.06e-01 | -1.38e-03 |(256, 16) | ✓ loss | nan | 2.82e+00 | 2.82e+00 | 2.82e+00 |() | ✗ grad_fc1_weight | 1.56e+00 | -2.75e-02 | 2.93e-02 | 4.64e-05 |(16, 1024)| ✗ grad_fc1_bias | 5.02e+00 | -2.80e-02 | 3.06e-02 | 7.29e-03 |(16,) | ✗ grad_fc2_weight | 2.89e+00 | -2.74e-02 | 2.34e-02 | -1.46e-10 |(16, 16) | ✗ grad_fc2_bias | 1.05e+01 | -6.61e-02 | 8.17e-02 | -4.66e-10 |(16,) | ✗ updated_fc1_weight | 1.56e-02 | -3.14e-02 | 3.13e-02 | -1.65e-05 |(16, 1024)| ✗ updated_fc1_bias | 5.02e-02 | -3.01e-02 | 2.15e-02 | -3.55e-03 |(16,) | ✗ updated_fc2_weight | 2.89e-02 | -2.49e-01 | 2.49e-01 | 8.15e-03 |(16, 16) | ✗ updated_fc2_bias | 1.05e-01 | -2.33e-01 | 2.38e-01 | -3.15e-02 |(16,) | ✗ ERROR: C++ check検証失敗 - いくつかの値が誤差しきい値(1e-05)を超えています ✗ テスト失敗: /tmp/train_step
このように、loss 以降の計算が壊れる事が確認できます。Softmax 関数の中で div が使われているため、output までは正しく計算できて、それ以降の計算がズレてしまった事が分かります。
train_cpp では、コンパイルされたコードに実際の MNIST データセットを用いてトレーニングを行います。
実際に `./haribote_graph_compiler.py train_cpp` を実行してみると、以下のような結果が得られるはずです。
(省略) Epoch 1: Loss: 1.1040, Accuracy: 86.43% Epoch 2: Loss: 0.4928, Accuracy: 89.00% Epoch 3: Loss: 0.4104, Accuracy: 89.50% Epoch 4: Loss: 0.3746, Accuracy: 90.18% (省略) Epoch 20: Loss: 0.2453, Accuracy: 93.00%
20 epoch ほど実行すると、精度 93% 程度で収束します。MNIST の分類問題は比較的簡単な部類なので本来はもっと精度を上げられるものなのですが、今回は実装の手間を軽減するために隠れ層が 16 次元の 1 層だけとしているので、この程度の値になります。MNIST は 10クラス分類問題であり、ランダムに答えを出力すると精度 10% になることを考えると、90% を超えればそこそこ分類できていると実感できると思います。
2. ユニットテスト
コンパイラ開発において、生成されたコードが正しいことを確認することは非常に重要です。本演習では、`test` コマンドとは別に、計算グラフの各ノード(Operator)単体での動作を確認できるユニットテストの仕組みが用意されています。
- ユニットテストの生成: `make build_mn_unittest` コマンドにより、学習グラフに含まれる `Gemm` や `Relu` など、個々のOperatorごとにONNXファイルとテストケースが分離されて生成されます。
- ユニットテストの実行: `test unit_tests/Operator名/` のようにディレクトリを指定して実行することで、特定の Operator だけを対象にC++コード生成と検証を行うことができます。ディレクトリを複数指定したり、`*` でワイルドカード指定することで、テスト範囲を指定することができます。
これにより、MN-Core向けの実装を進める際、他の部分の実装に影響されることなく、担当するOperatorの実装が正しいかを独立して検証することが可能になります。
`./haribote_graph_compiler.py build_unit_tests /tmp/train_step` と行うことで、C++ 向けのユニットテストが生成できます。
テストケースを ./unit_tests/train_step に生成... 生成中: Gemm_256x1024_16x1024_256x16_transB_a 生成中: Add_256x16_16_256x16_a ... 生成中: Mul_16_16_b 生成中: Sub_16_16_16_b ✓ 32 ノード中、32 個のテストケースを生成しました: ./unit_tests/train_step
生成したテストが `unit_tests/train_step/` 以下に保存されるので、`./haribote_graph_compiler.py test unit_tests/train_step/*` でテストを実行できます。
=== C++ check検証 === Var | Error | Ref Min | Ref Max | Ref Avg | Shape | v -------------------------------------------------------------------- fc1_pre | 0.00e+00 | -2.07e+00 | 2.38e+00 | -2.19e-03 |(256, 16)| ✓ ✓ テスト成功: unit_tests/train_step/Add_256x16_16_256x16_a ... === テストサマリー === ✓ PASS: unit_tests/train_step/Add_256x16_16_256x16_a ✓ PASS: unit_tests/train_step/Add_256x16_16_256x16_b ... 合計: 32/32 テスト成功
先程のように、`div_rowvec` のところを `sub_rowvec` に変えてみるなどで実装を壊してみると、以下のようにエラーで失敗します。
=== テストサマリー === ✗ FAIL: unit_tests/train_step/Div_256x16_256_256x16_a ... 合計: 31/32 テスト成功 失敗: 1 テスト
また、ユニットテスト作成時に、`./haribote_graph_compiler.py build_unit_tests /tmp/train_step –extra` のように `–extra` を追加すると、先頭の計算ノードから、各種途中の計算ノードまでの部分グラフを作成し、途中までの通しテストを作成できます。テストの実行は同じコマンドで行えます。
3. 演習用グラフコンパイラの方針と工夫
演習用グラフコンパイラは、MN-Coreアーキテクチャの特性に合わせるため、いくつかの追加処理を行っています。
メモリ転送Operatorの挿入
MN-Coreでは、PE(Processing Element)で計算するためには、PEに付随する容量の小さい LM(Local Memory)などの SRAM にデータがなければなりません。このため、DRAM(メインメモリ)から LM へのデータ転送が非常に重要になり、演習用コンパイラは、演算の前後に DRAM から LM への転送を行う「DL (Download)」Operatorと、LM から DRAM への転送を行う「UL (Upload)」Operatorを使用します。
これらのOperatorは、MN-Core 向けグラフエクスポート時に自動的に各演算に追加されます。具体的には、fx_export/operators の各種定義の `get_memory_layout_tag` を呼び出し、各種演算ごとにどのようなメモリ配置で来て欲しいのかをタグで返しています。
コンパイラは、Shape とこのタグに応じて、DRAM から LM に値を分配する命令を生成します。fx_export/operators/dl.py と ul.py に実際の分岐があり、そこから fx_export/operators/dlul_impls/dl_256x16.py などに分割された実装を呼び出しています。
Operatorの分解
例えば`softmax` などの複雑な演算は、`exp`、`sub`、`div`、`sum`などのより基本的な演算に分解することにします。これにより、個々の Operator の実装に集中しやすくなっています。
fx_export/operators の各種定義に `decompose` メソッドがあると、この分解処理を行います。fx_export/operators/softmax.py を見ると、Softmax を複数の計算ノードに分解している事が分かります。
4. 実装の進め方
まず、MN-Core 用ユニットテストを作成しましょう。以下のコマンドで生成できます。
./haribote_graph_compiler.py export --ignore=loss ./haribote_graph_compiler.py build_unit_tests /tmp/train_step --mntest
もしくは、make コマンドで `make build_mn_unittest` でも作成できます。
これで、`./unit_tests/train_step` に、各種 MN-Core 用 Operator のユニットテストが作成されます。
次に、MN-Core Challenge「MNIST Op 実装」問題セットを解き、各種 Operator の具体的な命令列がどのようになるのかを確認します。
順位表ページから各提出の行数部をクリックすると、その提出の VSM (MN-Core アセンブリ) を確認することもできます。
例えば問題「A-B」では、以下のようなコードで Accept を得られます。
lpassa $lm16v $nowrite fvadd $lm0v -$aluf $ln0v lpassa $lm24v $nowrite fvadd $lm8v -$aluf $ln8v
これをもとに、コンパイラの `generate_vsm` を埋めていきます。この問題に対応する Operator は fx_export/operators/sub.py です。
`generate_vsm` が定義されていますが、この問題は `shape0 == shape1` の場合に相当しますが、実装が無く、 `raise NotImplementedError` となっています。
`./haribote_graph_compiler.py test unit_tests/train_step/Sub_256x16_256x16_256x16_a`
でこのユニットテストが実行できますが、ちょうど NotImplementedError と言われるはずです。
`raise NotImplementedError` を消して、以下の実装が使われるように書き換えます。
for i in range(self.memory_len_in_div_ceil(8, 0)):
lines.append(f"ipassa $l{in1_prefix}{in1_offset + i * 8}v $nowrite")
lines.append(f"fvadd $lm{8 * i}v -$aluf $ln{8 * i}v")
詳しくはソースコード中のコメントをご覧ください。
この実装を有効にしてもう一度ユニットテストを実行すると、以下のように成功するはずです。
=== テスト実行: unit_tests/train_step/Sub_256x16_256x16_256x16_a === === MN-Core用テスト: unit_tests/train_step/Sub_256x16_256x16_256x16_a === VSMコード生成中: unit_tests/train_step/Sub_256x16_256x16_256x16_a/model.onnx 生成されたVSM: /tmp/mncore_vsm_test/Sub_256x16_256x16_256x16_a/generated.vsm 生成されたVSM(全4行、コメントと空行を除く): ipassa $lm16v $nowrite fvadd $lm0v -$aluf $ln0v ipassa $lm24v $nowrite fvadd $lm8v -$aluf $ln8v ACCEPTED!! score=4 j=4 m=0 bytes=95 ✓ offseted.vsm: テスト成功 ============================================================ === テストサマリー === ✓ PASS: unit_tests/train_step/Sub_256x16_256x16_256x16_a 合計: 1/1 テスト成功
無事に VSM が生成され、テストが成功します。
実装を間違えると、テストに失敗します。例えば `fvadd` の行をコメントアウトすると、以下のように表示されます。
生成されたVSM(全2行、コメントと空行を除く): ipassa $lm16v $nowrite ipassa $lm24v $nowrite 失敗: 再現コマンド: python3 ~/src/mn_mnist/judge/judge-py/judge.py unit_tests/train_step/Sub_256x16_256x16_256x16_a/offseted.vsm /tmp/mncore_vsm_test/Sub_256x16_256x16_256x16_a/generated.vsm テストケースファイル: unit_tests/train_step/Sub_256x16_256x16_256x16_a/offseted.vsm 生成された VSM: /tmp/mncore_vsm_test/Sub_256x16_256x16_256x16_a/generated.vsm エラー: RESULT MISMATCH: pos=(np.int64(0), np.int64(0)) actual=0.0 expected=369.5182800292969 error=-369.5182800292969 エラー: RESULT MISMATCH: pos=(np.int64(0), np.int64(1)) actual=0.0 expected=889.974365234375 error=-889.974365234375 エラー: RESULT MISMATCH: pos=(np.int64(0), np.int64(2)) actual=0.0 expected=484.7640075683594 error=-484.7640075683594 エラー: RESULT MISMATCH: pos=(np.int64(0), np.int64(3)) actual=0.0 expected=584.505615234375 error=-584.505615234375 エラー: RESULT MISMATCH: pos=(np.int64(0), np.int64(4)) actual=0.0 expected=-486.7652587890625 error=486.7652587890625 ✗ offseted.vsm: テストが失敗 ============================================================ === テストサマリー === ✗ FAIL: unit_tests/train_step/Sub_256x16_256x16_256x16_a 合計: 0/1 テスト成功 失敗: 1 テスト
生成された VSM や再現コマンドを使用しつつ、目的の VSM が生成されるように実装しましょう。
注意点として、Gemm(行列積問題) と、DL/UL (DRAM-LM メモリ転送問題) は種類が多いため、fx_export/operators/gemm_impls、fx_export/operators/dlul_impls 以下に実装を分けて収録しています。
実装については、問題集の先頭から順に実装を埋めていく形でも良いですし、`./haribote_graph_compiler.py test ./unit_tests/train_step/z_nodes_002` など、`z_` から始まる先頭からの順番に連結されたテストが順番に通るように実装していくのも良いでしょう。なお、連結テストを行うには LM 起点ではなく DRAM 起点でテストを行うため、対応する DL/UL も実装する必要があります。
全ての実装が終わったら、`./haribote_graph_compiler.py test ./unit_tests/train_step/z_nodes_093` で計算グラフ全体のテストを行ってみましょう。以下のようになれば成功です。
=== テスト実行: ./unit_tests/train_step/z_nodes_093 === === MN-Core用テスト: ./unit_tests/train_step/z_nodes_093 === VSMコード生成中: ./unit_tests/train_step/z_nodes_093/model.onnx 生成されたVSM: /tmp/mncore_vsm_test/z_nodes_093/generated.vsm 生成されたVSM(先頭4行と末尾4行、全2483行中): mvp/n16384i9f $d0 $lc0@.0 nop; wait i9f # Auto-tagged mvp/n16384i9f $d16384 $lc0@.1 nop; wait i9f # Auto-tagged ... mvp/n64i9f $lc0@0.0 $d56000@2 nop; wait i9f # Auto-tagged mvp/n64i9f $lc0@0.0 $d56000@3 nop; wait i9f # Auto-tagged ACCEPTED!! score=2700 j=2540 m=160 bytes=50558 ✓ offseted.vsm: テスト成功 ============================================================ === テストサマリー === ✓ PASS: ./unit_tests/train_step/z_nodes_093 合計: 1/1 テスト成功
テストに成功しました。`score=2700 j=2540 m=160` の表示は実際の VSM の行数で、MN-Core2 での実行時間を高精度に予測することができます。j というのが PE 命令数、m というのが mv 命令数です。
PE 命令に注目すると、このとき 2540 命令となり、1 命令で 4 cycle ぶんの命令が含まれるので 10160 cycle の時間がかかります。MN-Core2 の動作周波数は 750MHz なので、10160/750_000_000 [Hz] = 1.354667e-05 [sec] = 13.54667 [us] 程度の時間で実行できることが分かります。
成功することを確認したら、`./haribote_graph_compiler.py train_emu` で、MN-Core2 エミュレータを使用しつつ、MNIST のトレーニングが行えます。
MN-Core用ONNXに変換中... MN-Coreモデルを保存しました: /tmp/train_step_emu_batch256/mn_model.onnx 元のノード数: 30 新しいノード数: 103 DLノード: 43 ULノード: 23 計算ノード: 37 VSMコード生成中... VSMをアセンブル中... アドレス情報を抽出中... === エミュレータ訓練開始 === Epoch 1: 10/234 - Accuracy: 25.23% 20/234 - Accuracy: 45.03% 25/234 [3.8s/iter, avg:3.9s, ETA:13.5m]
エミュレーター実行なので非常に時間がかかりますが、10 iteration 目で Accuracy が 25% と、10% より有意に大きいので学習自体は行えていそうな事が読み取れます。
しばらく待つと
Epoch 1: ... 234/234 completed in 15.4mAccuracy: 86.43% Epoch 2: 234/234 completed in 15.5mAccuracy: 89.00%
のように 1, 2 epoch が終わり、C++ CPU 版の 1, 2 epoch 目と同様の Accuracy であることから、正しくトレーニングができている事が読み取れます。
もし、MN-Core2 の実機環境にアクセスできるのであれば、`./haribote_graph_compiler.py train_mncore2` で実機で動作させることができます。
今回の問題設定は、実装のしやすさ・人の手書き VSM でなんとか実装できる程度の小さい問題サイズなので、実機の速度をより体感するには問題サイズをもっと大きくする必要があります。
fx_export/train.py で定義した SimpleNN1024 モデルを差し替えるなどで、また別のモデルをコンパイルすることができます。例えば
class SimpleNN1024(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(1024, 16) # 32x32 = 1024
self.fc2 = nn.Linear(16, 16)
self.fc3 = nn.Linear(16, 16)
self.relu = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return self.fc3(x)
のように、16 次元の隠れ層を一つ増やしてみると、コンパイルが通り、学習が進みます。
また、16 次元ではなく、より高次元にすることでも、行列積などの生成コードをその次数に対応させることでコンパイル・実行ができるようになります。興味があればぜひ挑戦してみて下さい。
最後に
いかがだったでしょうか。今回は特定のモデルを動作させるための最低限の機能に絞り込んだグラフコンパイラを使って、MN-Core 用のアセンブリを生成し、エミュレータ上で実行してみる体験をしていただきました。
どのようにグラフコンパイラが動作するのか、その大まかな流れは体験できたのではないかと思います。
今回は最適化系の機能はスキップしました。次に行うとしたら、以下のような改善が考えられます。
- 同じ Shape の UL,DL が続いていたらスキップする
- Gemm と Add など特定の演算を Fuse した独自 Operator を追加する
- 今回は Operator が Shape に対して固定のメモリレイアウトを返したが、全体を見た最適なレイアウトプランナーを作る
- DL/UL と演算をそれぞれ単純連結するアセンブリを出力しているが、出力後に DL/UL と演算をオーバーラップできるように、出力アセンブリに対してパンチホール最適化を行う
- MN-Core 上の DRAM に配置するデータレイアウトについて、CPU のメモリで扱われるのと同じようにデータを順番に連続に配置したが、MN-Core の DL/UL 命令はハードウェアをシンプルするためにインターリーブ動作をするため、素直に PE に転送すると LM のレイアウトが直感と異なり複雑なことになったり、性能の出にくいレイアウトになってしまうことがある。そこで、DRAM へのレイアウトは DL/UL 命令の都合に合わせてシャッフルして配置することで性能向上を行う
…
など、MNIST を目的にした MLP に絞っても、多くの改善点があると思います。
更に実用的な大規模で複雑な問題を扱おうとしたら、別の最適化も考えられます。例えば出力が2回以上使われる計算について、一旦 DRAM に値を戻すよりも、軽い計算であれば再計算をしてしまった方が速いケース(Recomputation)もあります。
このように、グラフコンパイラには、いわゆる普通の CPU 向けコンパイラとはまた違った世界が広がっています。興味を持たれた方は、ぜひオリジナルのグラフコンパイラづくりに挑戦してみてください!
また、もし弊チームでの仕事に興味を持たれましたら、ぜひ、求人ページの「AI半導体 (MN-Core) コンパイラエンジニア」などからエントリーお願いします。




