機械学習/Pythonで決定木を使う のバックアップ(No.3)


はじめに

このページはまだ書きかけです。

データ

ここでは、irisデータをサンプルとして用います。

このデータセットは,アヤメの種類(class)を花びらの長さ(sepal length),幅(sepal width),がくの長さ(petal length),幅(petal width)によって分類する問題です. 長さと幅は連続値,種類はIris-setosa, Iris-versicolor, Iris-virginicaのいずれかをとる離散値です.

DeepAnalyticsのフォーマットに倣って、訓練データを train.tsv、テストデータを test_X.tsvとして、タブ区切りのCSVファイルで保存されているものとします。

今回のサンプルファイルはこれです。

train.tsv

train.tsvはこんな感じです。

idsepal lentghsepal widthpetal lengthpetal widthclass
24.931.40.2Iris-setosa
526.43.24.51.5Iris-versicolor
1016.33.362.5Iris-virginica
..................

test_X.tsv

test_X.tsvはこんな感じです。

idsepal lengthsepal widthpetal lengthpetal width
15.13.51.40.2
5173.24.71.4
1037.135.92.1
...............

データの読み込み

pandasのread_csvを使って、タブ区切りのCSVファイルを読み込みます。 タブ区切りなのでdelimiterオプションを、先頭の列がインデックスなのでindex_colオプションを指定します。

import pandas as pd
df_iris_train = pd.read_csv('train.tsv',  delimiter='\t', index_col=0)
df_iris_test  = pd.read_csv('test_X.tsv', delimiter='\t', index_col=0)

読み込んだデータは、pandasのDataFrameとなります。 次のようにすると、Jupyter Notebook上でDataFrameを確認できます。

df_iris_train
df_iris_test

決定木の学習

まず、訓練データから、入力Xと出力yをNumPy.Arrayで取り出します。

X = df_iris_train.drop('class', axis=1).values
y = df_iris_train['class'].values

dropはDataFrameから行または列を取り除きます(axisオプションで行か列かを指定します)。 valuesはDataFrameをNumPy.Arrayに変換します。

Scikit-learnで決定木学習を使うには、sklearn.tree.DecisionTreeClassifierクラスを使います。 fitでモデル(ここでは決定木)を学習します。

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
clf.fit(X, y)

学習したモデルのスコアを確認します。 sklearn.tree.DecisionTreeClassifierのスコアは、正解率 (mean accuracy) です。

clf.score(X, y)

学習した決定木による予測

まず、テストデータの入力XをNumPy.Array形式に変換します。

$ conda install pydotplus

ここでは、訓練データの入力Xを上書きしていますので、注意しましょう。

predictメソッドで学習したモデルに基づいてラベルを予測します。

import pydotplus
from IPython.display import Image
from sklearn.externals.six import StringIO
from sklearn.tree import export_graphviz
dot = StringIO()
export_graphviz(clf, out_file=dot,
                feature_names=df_iris_test.columns,
                class_names=df_iris_train['class'].unique(),
                filled=True, rounded=True,)
graph = pydotplus.graph_from_dot_data(dot.getvalue())
Image(graph.create_png())

予測したラベルの出力

予測したラベルを、タブ区切りのCSV形式で出力します。 このとき、テストデータを参照してインデックスとします。

X = df_iris_test.values
トップ   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS