MNISTをMLPで推論(Julia/Flux実装)
Juliaで機械学習をするための有名なライブラリにFluxがあります。Fluxを使ってMNISTの手書き数字の推論を行ったのでその方法をまとめておきます。 コードは次のようになります。これを参考に書きました。
パッケージ¶
基本的にFlux
さえあれば良いです。今回はMNISTデータを用いるのでMLDatasets
というパッケージを用いてデータを読み込みます。これらのパッケージは事前にインストールしておく必要があります。JuliaのREPLやnotebook上で次を入力してください。
using Flux
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using MLDatasets
データの読み込み¶
MNISTデータを読み込みます。
x_train, y_train = MLDatasets.MNIST.traindata(Float32)
x_test, y_test = MLDatasets.MNIST.testdata(Float32)
flatten
で各データを1次元に落とします。
また、各画像の数字(0~9)はone-hotにしておきたいのでそちらはonehotbatch
という関数で変換しておきます。
モデルの定義¶
いよいよモデルの定義です。今回は一番簡単なMLPで実装していきます。
img_size = (28,28,1)
input_size = prod(img_size) # 784
nclasses = 10 # 0~9
# Define model
model = Chain(
Dense(input_size, 32, relu),
Dense(32, nclasses)
)
Dense(input_size, output_size, f)
という関数\(F\colon\mathbb{R}^{\mathrm{inputsize}}\to\mathbb{R}^{\mathrm{outputsize}}\)は
$$
F(x) = f(Wx+b)
$$
になります。\(f\)は活性化関数です。\(W,b\)は内部で勝手に定義されます。デフォルトでは初期値\(W,b\)はGlorotの一様分布に従ってランダムに選ばれます。
また、\(f\)を指定しなければ活性化関数は恒等関数になります。すなわち非線形変換は行われません。
今回は活性化関数にReLU関数を用いました。
Chain
は合成関数を作ります。すなわち、Chain(F,G)
は\(G\circ F\)という関数に対応します。
今回定義したモデルは784次元の入力から10次元の出力を返します。出力の10次元の中で一番大きい要素のindexが推定される数字とします。
定義したモデルから学習すべきパラメータを取り出しておきます。
学習の準備¶
データが多いのでミニバッチ学習を行いましょう。バッチサイズとエポック数を定義します。
これをもとにtrainデータとtestデータをバッチに分けていきます。train_loader = DataLoader((x_train, y_train), batchsize=batch_size, shuffle=true)
test_loader = DataLoader((x_test, y_test), batchsize=batch_size, shuffle=true)
損失関数¶
損失関数を定義します。入力x
に対して出力ŷ=model(x)
は10次元のベクトルになりますが、これは特に正規化されていません。本来は出力の段階でsoftmax関数で正規化すべきかもしれませんが、推定の意味においては最大値を取りさえすれば良いので特に問題はありません。
また、softmaxを通した10次元の離散分布softmax(ŷ)
とone-hotの分布y
の間の交差エントロピーcrossentropy(softmax(ŷ), y)
を計算すると数値的な誤差が生まれやすいことが知られています。
数学的にこれと等価なlogitcrossentropy(ŷ, y)=crossentropy(softmax(ŷ), y)
を用いたほうが数値的にも安定します。
よって損失関数は次のように定義します。
function loss_accuracy(loader)
acc = 0.0
ls = 0.0
num = 0
for (x, y) in loader
ŷ = model(x)
ls += logitcrossentropy(ŷ, y, agg=sum)
acc += sum(onecold(ŷ) .== onecold(y))
num += size(x, 2)
end
return ls/num, acc/num
end
学習¶
いよいよ学習させます。各エポックごとにtrainデータとtestデータのlossと精度を出力する関数を定義しておきます。
function callback(epoch)
println("Epoch=$epoch")
train_loss, train_accuracy = loss_accuracy(train_loader)
test_loss, test_accuracy = loss_accuracy(test_loader)
println(" train_loss = $train_loss, train_accuracy = $train_accuracy")
println(" test_loss = $test_loss, test_accuracy = $test_accuracy")
end
Flux.train!
関数でミニバッチ学習をしてもらいます。
10エポックの学習で96%近くの学習精度を達成できます。
また、手元の環境(M1 mac mini)で10エポック回すのに6.03秒かかりました。GPUとかは使わなくてもここまでの速度と精度が出るのは素晴らしいですね。