techno_memo

個人用の技術メモ。python・ROS・AI系のソフトウェア・ツールなどの情報を記載

matplotlibのアニメーション描画

この記事の目的

 pythonのmatplotlibを使ってアニメーションデータを表示/gif形式で保存する機能についてまとめる。

1. 時系列データの可視化について

 組込み機器(特に自動車や移動ロボット)の開発で、時系列データを解析する際に位置座標と時系列情報を一緒に確認したいことがある。 時系列毎にX,Y座標とセンサーデータが取得できているという前提でシークバーを用いて時間軸をずらしながら位置座標を確認したり、位置座標が時間ごとに どのように変化したかをgifファイルに保存するpythonスクリプトを実装する。

本記事では下記のようなデータを想定してアニメーションによる可視化を検討する。

f:id:sd08419ttic:20200301222637p:plain

2. matplotlib可視化ライブラリ

 Matplotlibにはanimation用のAPIが用意されており、一定時間ごとに異なるプロット内容を表示することでアニメーションを描画できる。 この機能については下記サイトなどで解説されている。

qiita.com

 単純にアニメーション描画をしたい場合には上記機能でも十分だが組込み機器の時系列データ確認では時系列を変化させながら位置座標やセンサー値の変化を確認したいことが多い。 そのために、描画したグラフの時間軸を変化させて波形を確認できる animatplotというライブラリを導入する。animatplotは下記のpipコマンドでインストールできる。

pip install animatplot

3. animatplotを使ったデータの可視化

animatplotでは、描画するデータの2次元座標X,Yデータを時間軸毎の配列として付与することでアニメーションデータを描画する。 例えば、0.1sec間隔で5秒間取得したデータを1秒づつ描画したい場合は、X軸のデータを10点ずつまとめたデータを開始時点をずらして0~4.0秒目分用意する。 pythonコードでは下記のように実装できる。

    ref_df = pd.read_csv(csv_file_path, encoding="utf-8-sig")    #日本語データ(Shift-Jis)を含む場合を想定
    x_np = np.array(ref_df["x"])
    y_np = np.array(ref_df["y"])
    Xs_log =np.asarray([x_np[t:t+10] for t in range(len(time_data_np)-10)]) #X軸データ × 時間軸 分の配列
    Ys_log =[y_np[t:t+10] for t in range(len(time_data_np)-10)]             #Y軸データ × 時間軸 分の配列

上記のデータを用意したあとで、animatplotのblocksというクラスの描画関数Scatter(散布図)もしくはLines(線グラフ)を使ってデータを描画する。 描画したデータを関数Animationでアニメーションの対象として設定して、plot.showするとアニメーション表示される。 また、anim.controls()で時間軸の制御コントロールバーを利用することができ、save_gifで結果をgifファイルに保存する。

    anim = amp.Animation([block,block2,block3,block4])
    anim.controls()
    anim.save_gif("result")
    plt.show()

上記で基本的なグラフの描画はできるが、animatplotはmatplotのsubplot機能が利用できるので複数のグラフを1つの画面で表示できる。 (ただし、すべてのデータの時間軸を合わせておく必要があるので注意する。) subplotを使ってXYのデータとセンサーの時間軸データを1画面に表示する例を示す。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import animatplot as amp

#グラフの描画
def plot_animation(ref_df):
    #X軸・Y軸のデータ取得
    X_data = 0
    Y_data = 0
    #refの経路描画
    time_data_np = np.array(ref_df["time"])
    x_np = np.array(ref_df["x"])
    y_np = np.array(ref_df["y"])
    sensor_1_np = np.array(ref_df["sensor1"])
    sensor_2_np = np.array(ref_df["sensor2"])
    sensor_3_np = np.array(ref_df["sensor3"])

    Xs_log =np.asarray([x_np[t:t+10] for t in range(len(time_data_np)-10)]) #X軸データ × 時間軸 分の配列
    Ys_log =[y_np[t:t+10] for t in range(len(time_data_np)-10)]             #Y軸データ × 時間軸 分の配列
    sensor_1_log =[sensor_1_np[t:t+10] for t in range(len(time_data_np)-10)]
    sensor_2_log =[sensor_2_np[t:t+10] for t in range(len(time_data_np)-10)]
    sensor_3_log =[sensor_3_np[t:t+10] for t in range(len(time_data_np)-10)]
    Time_log =np.asarray([time_data_np[t:t+10] for t in range(len(time_data_np)-10)])

    #subplotの描画 (X-Yの情報を3行分の画面で表示)
    ax1 = plt.subplot2grid((3,2), (0,0), rowspan=3)
    ax2 = plt.subplot2grid((3,2), (0,1))
    ax3 = plt.subplot2grid((3,2), (1,1))
    ax4 = plt.subplot2grid((3,2), (2,1))

    ax1.set_xlim([x_np.min(), x_np.max()])      #描画範囲の設定
    ax1.set_ylim([y_np.min(),y_np.max()])       #描画範囲の設定
    block = amp.blocks.Scatter(Xs_log, Ys_log,label="X_Y",ax=ax1)

    block2 = amp.blocks.Line(Time_log, sensor_1_log, label="sensor1",ax=ax2)
    block3 = amp.blocks.Line(Time_log, sensor_2_log, label="sensor2",ax=ax3)
    block4 = amp.blocks.Line(Time_log, sensor_3_log, label="sensor3",ax=ax4)

    ax2.set_xlim([time_data_np.min(), time_data_np.max()])    #描画範囲の設定
    ax2.set_ylim([sensor_1_np.min(),sensor_1_np.max()])       #描画範囲の設定
    ax3.set_xlim([time_data_np.min(), time_data_np.max()])    #描画範囲の設定
    ax3.set_ylim([sensor_1_np.min(),sensor_1_np.max()])       #描画範囲の設定
    ax4.set_xlim([time_data_np.min(), time_data_np.max()])    #描画範囲の設定
    ax4.set_ylim([sensor_1_np.min(),sensor_1_np.max()])       #描画範囲の設定
    ax1.legend()
    ax2.legend()
    ax3.legend()
    ax4.legend()
    plt.subplots_adjust(wspace=0.4, hspace=0.6)

    anim = amp.Animation([block,block2,block3,block4])
    anim.controls()
    anim.save_gif("result")
    plt.show()


if __name__ == '__main__':
    
    csv_file_path = "data\\plotdata.csv"

    #CSVの読み込み
    ref_df = pd.read_csv(csv_file_path, encoding="utf-8-sig")    #日本語データ(Shift-Jis)を含む場合を想定
    plot_animation(ref_df)

    print("finished!")

f:id:sd08419ttic:20200302235752g:plain