粗大メモ置き場

個人用,たまーに来訪者を意識する雑記メモ

PytorchのPretrained Modelを使ってSegmentationを行う個人メモ

はじめに

本記事はあくまでML初心者の筆者の個人メモです。

pytorchの出来合いのモデルを使って画像認識タスクのうちSegmentationを行うことを目標にします。

実行環境

環境が汚れにくく、実行も高速なGoogleColabを使用します。

必要なデータはwgetなどでDLしてきても良いですし、下記のコマンドで簡単にGoogleDriveとも接続できるので簡単で便利です。

# mount drive
from google.colab import drive
drive.mount('/content/drive')

torchvisionのモデルを使ったsegmentation例

pytorchで使用できる既製のモデルはいくつかありますが、ひとまずtorchvisionで使えるモデルを使ってsegmentationを行っていきます。

pytorch.org

先に作例を示すと 某所から借りてきた星野源氏の下記写真から人物の部分のみを抜き出すことができたりします。

f:id:ossyaritoori:20211205164710p:plain
星野源 氏の写真

f:id:ossyaritoori:20211205165400p:plain
人物の領域の抜き出し例(マスク未Normalize)

1. モデルを選んでロード

はじめに、欲しい機能を実現するモデル(と学習済み重み)を選びます。
モデルはCNNのネットワーク構造、学習済みの重みはどのデータセットで学習したかを表します。

学習済み重みをどのデータセットで学習したかにはきちんと気を配る必要があり、例えばtorchvisionのモデルは人等を含む20クラス分類でしか学習していないのでそこにない物体を検知・抽出するには新たに転移学習をする必要があります。

下記では試しにFCNのresnet50を選んで試してみます。

import torchvision
# 試しにresnet50を用いる
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True) # pretrained = Trueとすることで学習済みのモデルがセットされた状態になる。
model.eval() # モデルを評価用に切り替える。逆に学習するときはmodel.train()とする。おまじないと思って良い。

2. モデルのパラメータを確認

モデルを選んだら次は下記のパラメータを事前に確認しておきます。これは後述の画像を入力するときに必要になります。

  • モデルの入力となる画像のサイズ
  • 学習時の正規化項(mean, std)

f:id:ossyaritoori:20211205173540p:plain
詳しくはこのあたりを参照にしてください

今回のFCNのresnet50の場合、

  • 入力の画像サイズ:224x224
  • 正規化項:mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] (範囲が0〜1であることに注意)

となります。

3. 画像の読み込みとモデルへの入力

以上の情報をもとに画像をモデルへとmodel(img)と入力すれば結果を得られるのですが、ここで前処理と型変換が問題になってきます。

よくあるNumpyの画像が[縦、横、RGBチャンネル数]のnp.arrayとなっているのに対し、今回使うsegmentationのモデルでは[バッチ数, カラーのチャンネル数, 横, 縦]という並びのTensorになっていなければいけません。

実際のコードでは同じ画像を下記の形式で行ったり来たり初学者にはとてもconfusingです

  • numpyのarray(OpenCVと連携する用)
  • PILのimage (pytorchのTensorとの相性よし)
  • pytorchのTensor(modelに入力する用)

PILとtorchvision.transformsを用いた前処理

一番簡単かつ便利な手法で、torchvisionのtransformsを用いることで簡単に前処理を実装することができます。

具体的には下記のようなコードで前処理を書くことができます。

from torchvision import transforms
from PIL import Image

# 前処理用
preprocess = torchvision.transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# load image
img = Image.open("img.png")

# Get Normalized image
img_tensor = preprocess(img) 

# バッチサイズにあたる次元を一つ追加
img_input = img_tensor.unsqueeze(0)

# 推論
output = model(img_input)

途中のpreprocessでは画像のリサイズやTensor形式の変換、画像の正規化を定義しています。 そしてこのオブジェクトに直接PIL形式の画像を与えることで任意の変換を行うことができます。

先程述べた、下記のパラメータをきちんと反映させていることを確認してください。

  • 入力の画像サイズ:224x224
  • 正規化項:mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] (範囲が0〜1であることに注意)

他のtransformsに関しては公式か下記が参考になると思います。

qiita.com

numpyを用いた際の前処理

一応numpyを使っても前処理はできるのですがtorchvision.transformsが便利すぎて素直に変換したほうが良いです。

# PIL Image -> numpy array
np_img = np.array(pil_img)

# numpy array -> PIL Image
pil_img = Image.fromarray(np_img))

なお、この場合でもuint8かfloat32かどうかや、RGBかBGRかは気を使う必要があります。

nixeneko.hatenablog.com

4. 結果の解釈

サクッと飛ばしましたがモデルへの入力はmodel(x)のように計算できます。 model.forward()やそのままmodel.predict() でもできることがあるようですが違いは追々調べます…

出力結果がどの結果に属するかのマスクになるのですがこちらもTensor形式なのでnumpy arrayかPIL Image 形式にして図示する必要があります。

tzmi.hatenablog.com

今回はなるべくPILへと変換します。何度も言いますがtransformsが楽なので。

可視化の際には公式のチュートリアルと同様にsoftmaxで正規化すると良いです。

pytorch.org

  • マスクの作成
from torch.nn.functional import softmax

# torch.Size([1, 21, 224, 224]) -> torch.Size([21,224,224])
output_ = output['out'].squeeze()
# normalize
normalized_masks = softmax(output_, dim=0)
  • 可視化
def visualize_tensor(tensors):
  n = len(tensors)
  plt.figure(figsize=(24, 5))
  for i in range(n):
    img = transforms.ToPILImage()(tensors[i])
    plt.subplot(1, n, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img)
  plt.show()

visualize_tensor(normalized_masks)

可視化した結果が下記のとおりです。VOCでは1番目が背景、16番目がPersonとなっていますがそれに該当する箇所がハイライトされていることがわかります。

f:id:ossyaritoori:20211207233331p:plain
21クラスの分類結果

draw_segmentation_masksを使った可視化

draw_segmentation_masksというTensorを引数にとる関数があるっぽいので試してみました。

pytorch.org

一見便利そうですがTensorを引数にするのがちょっと癖があって難しいなと思いました。numpyならマスキングは非常に簡単だと思います。

import torch

wid,hei = img.size
reshape_tensor = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((hei,wid)),
    transforms.ToTensor(),
    ])  

img_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.uint8)
])

person_mask = reshape_tensor(normalized_masks[15]) > 0.5
bg_mask = reshape_tensor(normalized_masks[0]) > 0.5

person_img = torchvision.utils.draw_segmentation_masks(img_to_tensor(img), person_mask)
bg_img = torchvision.utils.draw_segmentation_masks(img_to_tensor(img), bg_mask)

visualize_tensor([bg_img, person_img])

f:id:ossyaritoori:20211208003947p:plain
閾値0.5での切り抜き結果

numpyを使うパターン

numpyの方は変換さえできればstraightforwardなのでさっくり書くにとどめます。

# tensor to numpy
out_np = reshape_tensor(normalized_masks[15]).detach().numpy().copy()

mask = (out_np > 0.5)
mask = cv2.cvtColor(mask.astype(np.uint8), cv2.COLOR_GRAY2RGB)

masked = img_np * mask

参考・その他

とりあえず書き溜めておきます。汚ければ後で消すかもしれません。

参考になりそうな記事たち

大量の記事を斜め読みしたのでどれがどの参考になったかちょっと忘れてしまったのですがこれは確実に読んだというのを下記に記しておきます。

超初心者の抱えていた疑問と回答

とりあえず動かしていくにあたって感じたが疑問と現時点での自分の理解を書いておきます。

  • modelの入力に入れるTensorのサイズがよくわからない。なぜ四次元?
    • segmentationに関して言えば[バッチ数, カラーのチャンネル数, 横, 縦]という次元になっている 参考