こんにちは,米国データサイエンティストのかめ(@usdatascientist)です
データサイエンスのためのPython入門第9回目です(講座の目次はこちら).今回は,NumPyでよく使う便利な関数を紹介したいと思います.(「データサイエンスのためのPython講座」動画版がでました!詳細はこちら)
行列を作る関数はまとめて前回の記事にあるのでこちらを読んでください.今回の記事では前回習った関数をいくつか使います.
ここまできたらだいぶNumPyに慣れてきたと思います!今回の記事で必要な知識はすべて入ります!もうひといき!
目次
要素の統計量を求める
統計量というのは,最大値,最小値,平均,中央値,標準偏差などデータの特徴を要約した値です.
なにも難しいことはありません.それぞれみていきましょう.
まず,なんでもいいので行列を作ってみます.ここでは標準正規分布の5×5の行列を作ります.
1 2 |
normal_dist_mat = np.random.randn(5, 5) print(normal_dist_mat) |
1 2 3 4 5 |
[[-1.45685591 0.60322563 0.31419512 -1.11173979 0.17662918] [-0.65074201 -0.48102014 0.71527314 0.2808805 2.18018355] [-0.66471332 0.03258803 0.04345204 1.11498607 0.69945222] [ 0.99647058 1.01231696 0.31846513 -0.46606175 1.68640274] [-0.32164133 -1.32313429 0.6725172 0.34840777 -0.48340848]] |
このnormal_dist_matの統計量をそれぞれ出してみましょう.
- .max()と.argmax()
最大値を取得します.()で関数をコールすることに注意しましょう.
1 2 3 |
#max print('max is') print(normal_dist_mat.max()) |
1 2 |
max is 2.1801835547459243 |
また, .argmax() で最大値のindexを求めることができます.こちらもよく使います.
1 2 3 |
#argmax print('argmax is') print(normal_dist_mat.argmax()) |
1 2 |
argmax is 9 |
5 x5 の行列を一列にしたときのindex=9に最大値の値が格納されています.ちなみに,flatten()で一列にできます.
1 |
normal_dist_mat.flatten()[9] |
1 |
2.1801835547459243 |
- .min()と.argmin()
最小値を取得します.
1 2 3 |
#min print('min is') print(normal_dist_mat.min()) |
1 2 |
min is -1.4568559125323208 |
.argmax() 同様に, .argmin() で最小値のindexを取得できます.
1 2 3 |
#argmin print('argmin is') print(normal_dist_mat.argmin()) |
1 2 |
argmin is 0 |
- .mean()
平均値を取得します.全ての要素を足して要素数で割った値です.
1 2 3 |
#mean print('mean is') print(normal_dist_mat.mean()) |
1 2 |
mean is 0.16944515500456525 |
- np.median(ndarray)
中央値を取得します.真ん中の値です.平均値とは意味が違うことに注意してください.また,ndarrayではなく,npから呼び出すことに注意しましょう.
平均値ではなくて中央値を見ることで外れ値(分布から突出した値)の影響を抑えることができます.外れ値があると平均値はそれに引っ張られますからね.
1 2 3 |
#median print('median is') print(np.median(normal_dist_mat)) |
1 2 |
median is 0.2808805000799674 |
- .std()
「各データの値と平均の差を2乗した合計を,データ数で割った正の平方根」で求められます.この公式も非常に重要ですが,詳しくはPythonで学ぶデータサイエンス統計編第5回の記事を参照ください!統計学を本で勉強したい方はこちらの記事でお勧めの本を紹介しておりますので,ご一読ください.
68-95-99.7ルールというのを感覚的に覚えてください.(正規分布では)平均から±1標準偏差には約68%のデータが,平均から±2標準偏差には約95%のデータが,平均から±3標準偏差には約99%のデータが,含まれるという経験則です.数字自体は覚えなくとも,なんとなく感覚値を覚えておきましょう.Pythonで学ぶデータサイエンス統計編第8回でも扱っています.
1 2 3 |
#std print('standard deviation is') print(normal_dist_mat.std()) |
1 2 |
standard deviation is 0.875259797640201 |
今回私の結果ではmean=0.169, std=0.875でした.25個程度のランダムサンプリングでは,元の分布の特徴を表すには少ないことがわかります.
では,どれだけの数をサンプリングすればいいのか?そういったことを考えるのもデータサイエンスの学問です.詳しくは今後の統計学講座で扱いますね!
今回は「標準正規分布から値を取ってきた」ことを知っていますが,実業務ではサンプリングしたデータの元の分布(母集団)がどういう分布なのか全くわかりません.このように統計量を見ることによって分布の特徴をつかむことができるのです.データサイエンス必須の関数と言えます.
なお,上記の関数は全てnp.関数名(ndarray)で呼ぶことも可能です.試してみてください.(↓こんな感じ)
1 2 |
print(np.max(normal_dist_mat)) print(normal_dist_mat.max()) |
また,特定の行,列での統計量を求めたい場合は引数axisを指定します.axis=0を指定すると各列の統計量,axis=1を指定すると各行の統計量を返します.このaxisの使い方は重要です.覚えておきましょう.
1 2 3 |
print(normal_dist_mat) print('axis=0> {}'.format(normal_dist_mat.max(axis=0))) print('axis=1> {}'.format(normal_dist_mat.max(axis=1))) |
1 2 3 4 5 6 7 |
[[-1.45685591 0.60322563 0.31419512 -1.11173979 0.17662918] [-0.65074201 -0.48102014 0.71527314 0.2808805 2.18018355] [-0.66471332 0.03258803 0.04345204 1.11498607 0.69945222] [ 0.99647058 1.01231696 0.31846513 -0.46606175 1.68640274] [-0.32164133 -1.32313429 0.6725172 0.34840777 -0.48340848]] axis=0> [0.99647058 1.01231696 0.71527314 1.11498607 2.18018355] axis=1> [0.60322563 2.18018355 1.11498607 1.68640274 0.6725172 ] |
数学でよく使う関数
NumPyに組み込まれている以下の関数をよく使います.「そういえばこれNumPyにあったな」程度に覚えておきましょう.
- np.exp(ndarray)
指数関数(exponential function)です.NumPyではネイピア数e(約2.718)の指数関数になります.統計学や機械学習でよく出てくるので知らない人は一度復習しておきましょう. 細かい話はしません.
指数関数というのは一般に
$$y = a^x$$
の形をした関数です.
この導関数,すなわち微分は
$$y’ = a^x \log_e a$$
であることは高校数学で習いました.(通常自然対数\(log_e a\)の底eは省略され,\(log a\) と表記します.)
\(a=e\)のとき,\(log_e e = 1\)なので\(y=e^x\)の導関数は\(y’=e^x\)になります.つまり,微分しても変わらないってこと.
データサイエンスでは時々でてくるので覚えておきましょう.
NumPyでは np.exp() 関数で底数eの指数関数計算が可能です.
1 2 3 4 |
ndarray = np.linspace(-3, 3, 10) expndarray = np.exp(ndarray) print(ndarray) print(expndarray) |
1 2 3 4 |
[-3. -2.33333333 -1.66666667 -1. -0.33333333 0.33333333 1. 1.66666667 2.33333333 3. ] [ 0.04978707 0.09697197 0.1888756 0.36787944 0.71653131 1.39561243 2.71828183 5.29449005 10.3122585 20.08553692] |
ちなみに,横軸にndarray,縦軸にexpndarrayをとると以下のようなグラフになります.
これがいわゆる「指数関数的に◯◯」とみんなが言っているやつですね・・・
*このグラフの表示の仕方は第20回で紹介しています.
- np.log(ndarray)
1 2 3 4 |
ndarray = np.linspace(-3, 3, 10) logndarray = np.log(ndarray) print(ndarray) print(logndarray) |
1 2 3 4 |
[-3. -2.33333333 -1.66666667 -1. -0.33333333 0.33333333 1. 1.66666667 2.33333333 3. ] [ nan nan nan nan nan -1.09861229 0. 0.51082562 0.84729786 1.09861229] |
- np.nan と np.isnan()
先ほど nan がでてきましたが,これは第4回の記事で出てきたNoneとは少し違います.
1 2 |
print(logndarray[0]) print('nan == None?:{}'.format(logndarray[0] is None)) |
1 2 |
nan nan == None?:False |
この「nan」というのはNot-A-Numberです.NumPy演算で例えば0で割ったりすると出てきます.PythonのNoneとは別物なので区別しましょう.数学以外でもデータサイエンスの操作全般で頻出です.
また,Noneと同じようにnanも「nanかどうかのチェック」をすることが非常に多いです.チェックには np.isnan() を使います.(a == np.nan や a is np.nanではチェックしきれないので注意です.)
np.nanとNoneの混合は初心者がよくやるバグの原因です.
1 |
np.isnan(logndarray[0]) |
1 |
True |
- np.e
ネイピア数(2.7182….)です.以上.
1 2 |
print(np.e) print(np.log(np.e)) |
1 2 |
2.718281828459045 1.0 |
ndarrayのshape操作
私は専門がコンピュータ・ビジョンでよく深層学習のアルゴリズムを開発しますが,以下の関数をよく使うので一応紹介します.画像系じゃなければそんなに使わないかも?
なお, .shape でndarrayのサイズを確認できることを第6回で紹介しました.
1 2 3 |
ndarray = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) print('array is:\n{}'.format(ndarray)) print("ndarray's shape is:\n{}".format(ndarray.shape)) |
1 2 3 4 5 6 |
array is: [[1 2 3] [4 5 6] [7 8 9]] ndarray's shape is: (3, 3) |
自分はいつも.NumPy Arrayを作ったらすぐ .shape してます.それくらい使ってます.例えば.shapeが(3, 3)となっている場合,3x3の行列であることはわかりますが,これって(3, 3, 1)または(1, 3, 3)と同じものを意味しますよね?(もちろんndarrayでは一つ入れ子が増えるので別物ですが)
この「次元の数」のことをrankと呼びます.(3, 3)であればrankは2, (3, 3, 1)はrankは3.
どちらも同じ行列を意味しますが,プログラミングで処理する際にどちらにも対応するコードを書くことは難しいですよね?なので,あらかじめ(3, 3, 1)のndarrayがきたら(3, 3)にrankを削減する処理(または逆)を書くことが多いです.画像処理のコードを書く時は必須です.それではrank削減/追加の関数をそれぞれ紹介します.
- np.expand_dims(ndarray, axis)
rankを一つ追加します.axis=0なら一つ目の次元を,axis=-1なら最後の次元を追加します.使うのは大抵axis=0かaxis=-1です.
1 2 |
expanded_ndarray = np.expand_dims(ndarray, axis=0) expanded_ndarray.shape |
1 |
(1, 3, 3) |
- np.squeeze(ndarray)
shapeで1の次元を無くします.例えばshapeが(3, 3, 1)であればshapeを(3, 3)にします.
squeeze(絞る)するイメージです.
1 2 |
squeezed_expanded_ndarray = np.squeeze(expanded_ndarray) squeezed_expanded_ndarray.shape |
1 |
(3, 3) |
最初は次元の操作に慣れないと思います.自分でndarray作って,expand_dimsしたりsqueezeして実際の行列がどう変わっているかみてみてください.
- .flatten()
ndarrayを一列にします.行列構造を持つ必要がなくなったりしたら使います.
1 2 3 |
flatten_array = ndarray.flatten() print('flatten_array:\n{}'.format(flatten_array)) print('ndarray:\n{}'.format(ndarray)) |
1 2 3 4 5 6 |
flatten_array: [1 2 3 4 5 6 7 8 9] ndarray: [[1 2 3] [4 5 6] [7 8 9]] |
もとのndarrayは変更されないことに注意しましょう.
NumPy Arrayを保存する&読み込む
作ったArrayを.npyファイルに保存したり,それを読み込んだりすることがよくあります.
本当にいろんなところで必要になりますが,特に画像データを加工して保存することが多いです.深層学習ではNumPy形式のデータを学習に使うことがほとんどなので,NumPy形式のままで保存します.
保存も読み込みも簡単で,以下のようにしてできます.保存はnp.save(), 読み込みはnp.load()です.
- np.save(‘ファイルパス’, ndarray)
1 2 3 4 5 6 |
ndarray = np.array([ [1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], ]) np.save('saved_numpy', ndarray) |
np.save()を実行すると,.npy形式で保存されているのがわかると思います.(Jupyter左側を確認ください.)
第一引数に’saved_numpy.npy’と入れても同じように保存されます.拡張子はつけてもつけなくてもOK!ってことです.
それでは,保存した.npyファイルを読み込みます.読み込むにはnp.load()を使います.
- np.load(‘ファイルパス’)
1 2 |
loaded_numpy = np.load('saved_numpy.npy') loaded_numpy |
こんな感じです.簡単ですね!np.load()にはファイル名(パス)を指定します..npyはつけて指定してくださいね!
まとめ
なんだか詰め込みすぎた感がありますが,,,一旦これでNumPy編は終わりです!!
今回の記事のまとめ↓
- .max()と.argmax() : ndarrayの最大値とそのindexを返す
- .min()と.argmin() : ndarrayの最小値とそのindexを返す
- .mean():ndarrayの平均値を返す
- np.median(ndarray):ndarrayの中央値を返す
- .std():ndarrayの標準偏差を返す
- np.exp(ndarray):ネイピア数の指数関数の結果を返す
- np.log(ndarray):自然対数の結果を返す
- np.nanとnp.isnan():0で割った時などにnanが格納され,nanチェックにはnp.isnan()を用いる
- np.e:ネイピア数(2.7182…)
- np.expand_dims(ndarray, axis):rankを特定のaxisに追加
- np.squeeze(ndarray):shapeで1になっている次元を削減
- .flatten():一列にする
- np.save(‘パス’, ndarray): ndarrayを指定したパスに保存
- np.load(‘パス’):指定したパスの.npyファイルを読み込む
..多い.
一気に全て覚えようとせず,Pythonでコーディングをするたびに戻ってきつつ覚えましょう.
自分も最初,全然わけわからなくて,特にndarrayの次元が意味不明でした.
しかもこれらの関数が何に使われるのかわからず,本などではただただ関数と使い方だけ説明されていて,楽しくなかった.
なので私のブログでは少しでも「どういうケースで使うのか」を例にしながら記事にしました.
もし,参考にしていただいてるようでしたらTwitterフォローお願いします!また,友人や同僚に是非オススメしてください!
それでは!次回はPandasというライブラリを扱います.一気にデータサイエンスっぽくなりますよ!
追記)次回書きました.こちらです↓