2014年6月17日
PythonでPostgresデータから決定木を構築する
本記事は、原著者の許諾のもとに翻訳・掲載しております。
今回は、任意の人物の所得を人口統計データを使って予測する手法をご紹介します。使用するのは 20年前の人口統計データ です。
この例を用いて、関係データベースの情報から予測モデルを導き出す方法と、その途中で起こり得るトラブルについて触れたいと思います。
このデータの優れた点は、データの作成者が下記のようなアルゴリズムの精度をデータに添付している点です。こうした数値はスモークテストの結果評価に役立ちます。
Algorithm Error
-- ---------------- -----
1 C4.5 15.54
2 C4.5-auto 14.46
3 C4.5 rules 14.94
4 Voted ID3 (0.6) 15.64
5 Voted ID3 (0.8) 16.47
6 T2 16.84
7 1R 19.54
8 NBTree 14.10
9 CN2 16.00
10 HOODG 14.82
11 FSS Naive Bayes 14.05
12 IDTM (Decision table) 14.46
13 Naive-Bayes 16.12
14 Nearest-neighbor (1) 21.42
15 Nearest-neighbor (3) 20.35
16 OC1 15.04
このデータセットをPostgresのデータベースに読み込むには、まずデータ作成者のファイルの一番下にある空白行と「1×0 Cross Validator」と書かれた行を削除する必要があります。
次に、下記の手法でデータを読み込みます。テストデータも同じPostgresデータベースにロードしていますが、そこは気にしないでください。
ご覧の通り、PostgresではUNCパスを指定することができます。特筆するほどのことではないと思われるかもしれませんが、VisualStudioではUNCパスを読み込むことができないことを考えると、うれしい機能ではないでしょうか。
DROP TABLE income_trn;
CREATE TABLE income_trn
(age INTEGER,
workclass text,
fnlwgt INTEGER,
education text,
education_num INTEGER,
marital_status text,
occupation text,
relationship text,
race text,
sex text,
capital_gain INTEGER,
capital_loss INTEGER,
hours_per_week INTEGER,
native_country text,
category text);
COPY income_trn
FROM '\\\\nas\\Files\\Data\\income\\adult.data' DELIMITER ',' CSV;
DROP TABLE income_test;
CREATE TABLE income_test
(age INTEGER,
workclass text,
fnlwgt INTEGER,
education text,
education_num INTEGER,
marital_status text,
occupation text,
relationship text,
race text,
sex text,
capital_gain INTEGER,
capital_loss INTEGER,
hours_per_week INTEGER,
native_country text,
category text);
COPY income_test
FROM '\\\\nas\\Files\\Data\\income\\adult.test' DELIMITER ',' CSV;
Pythonの場合、こうしたデータならSQLAlchemyを使っても読み込むことができます。ただしPostgresドライバ(”pg8000″)は不安定なため、たまに次のようなエラーが起こる場合があります。
ProgrammingError: (ProgrammingError)
('ERROR', '34000',
'portal "pg8000_portal_12" does not exist')
None None
エラーの原因はさまざまですが、そのひとつとしてPostgresの旧バージョンを使用しているケースが考えられます (著者は9.3を使用しました)。旧バージョンには閉じているカーソルのデータが読み込まれてしまうという問題もあるようです。
from sqlalchemy import *
engine = create_engine(
"postgresql+pg8000://postgres:postgres@localhost/pacer",
isolation_level="READ UNCOMMITTED"
)
c = engine.connect()
meta = MetaData()
income_trn = Table('income_trn', meta, autoload=True, autoload_with=engine)
income_test = Table('income_test', meta, autoload=True, autoload_with=engine)
大量のデータを処理する場合、クエリの結果をモデルの中にストリーム処理する手法はとても有効です。ただし今回はデータサイズが小さいので、ストリーム処理は行いませんでした。また、ひとつの表に全データが入っている場合には、効率よくデータ処理ができるよう、任意にデータを半分に分割する方法を考えなければならないでしょう。
from sqlalchemy.sql import select
def get_data(table):
s = select([table])
result = c.execute(s)
return [row for row in result]
test_data = get_data(income_trn)
trn_data = get_data(income_test)
このデータには、もともとテキスト型の列と整数型の列が混ざって入っていました(職業と年齢など)。意外にもPythonの機械学習ライブラリは、このような混合データの認識が苦手なようです(少なくともディシジョンツリーは苦手です)。こうしたデータは一連のvalue値のみで構成されたデータとはまったく別物なので、特別な配慮が必要になります。
問題は、ライブラリがvalue値のリストを期待しているにも関わらず元データが数値型である、というケースです。この問題を解決するには、次のようなマッピングを行うグローバルな辞書の構築が必要になるでしょう。
maxVal = 0
vals = dict()
rev_vals = dict()
def f(x):
global maxVal
global vals
if (not x in vals):
maxVal = maxVal + 1
vals[x] = maxVal
rev_vals[maxVal] = x
return vals[x]
ここで、属性を2つに分割しなければなりません。ひとつは出力に、もうひとつは出力を予測するための属性に分割します。
def get_selectors(data):
return [ [f(x) for x in t[0:-1]] for t in data]
def get_predictors(data):
return [0 if "<" in t[14] else 1 for t in data]
trn = get_selectors(trn_data)
trn_v = get_predictors(trn_data)
この事例で最も注目すべきは、なんといってもモデルの作成が驚くほど簡単なことです。例を見てみましょう。
from sklearn import tree
clf = tree.DecisionTreeRegressor()
clf = clf.fit(trn, trn_v)
結局、テストメソッドは自前で実装することになりました。混同行列は、クラスに定義されたデータの計算が得意ではないようですね。
test = get_selectors(test_data)
test_v = get_predictors(test_data)
testsRun = 0
testsPassed = 0
for t in test:
if clf.predict(t) == test_v[testsRun]:
testsPassed = testsPassed + 1
testsRun = testsRun + 1
100 * testsPassed / testsRun
DecisionTreeClassifier: 78%
DecisionTreeRegressor: 79%
最後に、scikit-learnのドキュメントをチェックしてみてください。すべての事例に ステキな図表 がついていますね。ただ、ドキュメントを読めば分かりますが、ディシジョンツリーはかなり長くなる可能性があります。何千というルールが適用されることもあるので、よほどシンプルなケースでない限り図表化には向かないでしょう。
株式会社リクルート プロダクト統括本部 プロダクト開発統括室 グループマネジャー 株式会社ニジボックス デベロップメント室 室長 Node.js 日本ユーザーグループ代表
- Twitter: @yosuke_furukawa
- Github: yosuke-furukawa