PytorchのPretrained Modelを使ってSegmentationを行う個人メモ
はじめに
本記事はあくまでML初心者の筆者の個人メモです。
pytorchの出来合いのモデルを使って画像認識タスクのうちSegmentationを行うことを目標にします。
実行環境
環境が汚れにくく、実行も高速なGoogleColabを使用します。
必要なデータはwgetなどでDLしてきても良いですし、下記のコマンドで簡単にGoogleDriveとも接続できるので簡単で便利です。
# mount drive from google.colab import drive drive.mount('/content/drive')
torchvisionのモデルを使ったsegmentation例
pytorchで使用できる既製のモデルはいくつかありますが、ひとまずtorchvisionで使えるモデルを使ってsegmentationを行っていきます。
先に作例を示すと 某所から借りてきた星野源氏の下記写真から人物の部分のみを抜き出すことができたりします。
1. モデルを選んでロード
はじめに、欲しい機能を実現するモデル(と学習済み重み)を選びます。
モデルはCNNのネットワーク構造、学習済みの重みはどのデータセットで学習したかを表します。
- torchvisionで使用できるモデル
- FCN ResNet50, ResNet101 FCNについて参考
- DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large DeepLabについて参考
- LR-ASPP MobileNetV3-Large
- モデルの学習済みパラメータ
- Pascal VOC on COCO
学習済み重みをどのデータセットで学習したかにはきちんと気を配る必要があり、例えばtorchvisionのモデルは人等を含む20クラス分類でしか学習していないのでそこにない物体を検知・抽出するには新たに転移学習をする必要があります。
下記では試しにFCNのresnet50を選んで試してみます。
import torchvision # 試しにresnet50を用いる model = torchvision.models.segmentation.fcn_resnet50(pretrained=True) # pretrained = Trueとすることで学習済みのモデルがセットされた状態になる。 model.eval() # モデルを評価用に切り替える。逆に学習するときはmodel.train()とする。おまじないと思って良い。
2. モデルのパラメータを確認
モデルを選んだら次は下記のパラメータを事前に確認しておきます。これは後述の画像を入力するときに必要になります。
- モデルの入力となる画像のサイズ
- 学習時の正規化項(mean, std)
今回のFCNのresnet50の場合、
- 入力の画像サイズ:224x224
- 正規化項:mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] (範囲が0〜1であることに注意)
となります。
3. 画像の読み込みとモデルへの入力
以上の情報をもとに画像をモデルへとmodel(img)
と入力すれば結果を得られるのですが、ここで前処理と型変換が問題になってきます。
よくあるNumpyの画像が[縦、横、RGBチャンネル数]のnp.arrayとなっているのに対し、今回使うsegmentationのモデルでは[バッチ数, カラーのチャンネル数, 横, 縦]という並びのTensorになっていなければいけません。
実際のコードでは同じ画像を下記の形式で行ったり来たり初学者にはとてもconfusingです
PILとtorchvision.transformsを用いた前処理
一番簡単かつ便利な手法で、torchvisionのtransformsを用いることで簡単に前処理を実装することができます。
具体的には下記のようなコードで前処理を書くことができます。
from torchvision import transforms from PIL import Image # 前処理用 preprocess = torchvision.transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # load image img = Image.open("img.png") # Get Normalized image img_tensor = preprocess(img) # バッチサイズにあたる次元を一つ追加 img_input = img_tensor.unsqueeze(0) # 推論 output = model(img_input)
途中のpreprocess
では画像のリサイズやTensor形式の変換、画像の正規化を定義しています。
そしてこのオブジェクトに直接PIL形式の画像を与えることで任意の変換を行うことができます。
先程述べた、下記のパラメータをきちんと反映させていることを確認してください。
- 入力の画像サイズ:224x224
- 正規化項:mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] (範囲が0〜1であることに注意)
他のtransformsに関しては公式か下記が参考になると思います。
numpyを用いた際の前処理
一応numpyを使っても前処理はできるのですがtorchvision.transformsが便利すぎて素直に変換したほうが良いです。
# PIL Image -> numpy array np_img = np.array(pil_img) # numpy array -> PIL Image pil_img = Image.fromarray(np_img))
なお、この場合でもuint8かfloat32かどうかや、RGBかBGRかは気を使う必要があります。
4. 結果の解釈
サクッと飛ばしましたがモデルへの入力はmodel(x)
のように計算できます。
model.forward()
やそのままmodel.predict()
でもできることがあるようですが違いは追々調べます…
出力結果がどの結果に属するかのマスクになるのですがこちらもTensor形式なのでnumpy arrayかPIL Image 形式にして図示する必要があります。
今回はなるべくPILへと変換します。何度も言いますがtransformsが楽なので。
可視化の際には公式のチュートリアルと同様にsoftmaxで正規化すると良いです。
- マスクの作成
from torch.nn.functional import softmax # torch.Size([1, 21, 224, 224]) -> torch.Size([21,224,224]) output_ = output['out'].squeeze() # normalize normalized_masks = softmax(output_, dim=0)
- 可視化
def visualize_tensor(tensors): n = len(tensors) plt.figure(figsize=(24, 5)) for i in range(n): img = transforms.ToPILImage()(tensors[i]) plt.subplot(1, n, i + 1) plt.xticks([]) plt.yticks([]) plt.imshow(img) plt.show() visualize_tensor(normalized_masks)
可視化した結果が下記のとおりです。VOCでは1番目が背景、16番目がPersonとなっていますがそれに該当する箇所がハイライトされていることがわかります。
draw_segmentation_masksを使った可視化
draw_segmentation_masksというTensorを引数にとる関数があるっぽいので試してみました。
一見便利そうですがTensorを引数にするのがちょっと癖があって難しいなと思いました。numpyならマスキングは非常に簡単だと思います。
import torch wid,hei = img.size reshape_tensor = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((hei,wid)), transforms.ToTensor(), ]) img_to_tensor = transforms.Compose([ transforms.ToTensor(), transforms.ConvertImageDtype(torch.uint8) ]) person_mask = reshape_tensor(normalized_masks[15]) > 0.5 bg_mask = reshape_tensor(normalized_masks[0]) > 0.5 person_img = torchvision.utils.draw_segmentation_masks(img_to_tensor(img), person_mask) bg_img = torchvision.utils.draw_segmentation_masks(img_to_tensor(img), bg_mask) visualize_tensor([bg_img, person_img])
numpyを使うパターン
numpyの方は変換さえできればstraightforwardなのでさっくり書くにとどめます。
# tensor to numpy out_np = reshape_tensor(normalized_masks[15]).detach().numpy().copy() mask = (out_np > 0.5) mask = cv2.cvtColor(mask.astype(np.uint8), cv2.COLOR_GRAY2RGB) masked = img_np * mask
参考・その他
とりあえず書き溜めておきます。汚ければ後で消すかもしれません。
参考になりそうな記事たち
大量の記事を斜め読みしたのでどれがどの参考になったかちょっと忘れてしまったのですがこれは確実に読んだというのを下記に記しておきます。
超初心者の抱えていた疑問と回答
とりあえず動かしていくにあたって感じたが疑問と現時点での自分の理解を書いておきます。