NumPyの使い方(6) 条件制御
NumPyの条件制御について。
if~elseで条件分岐をするとき
a = 5 if a >= 0: b = 1 else: b = -1
と書けますが、これを3項演算子を使い以下のようにも書けます。
b = 1 if a >= 0 else -1
似たような動作を組み込みのリスト型で行うとき、
例えば、Trueならlist1から、そうでなければlist2から要素を選ぶとき、
list1 = [1, 2, 3, 4, 5] list2 = [11, 12, 13, 14, 15] condi = [True, False, True, False, True] list3 = [x if c else y for (c, x, y) in zip(condi, list1, list2)] print(list3) # [1, 12, 3, 14, 5]
NumPyではnp.where()を使って
np.where(条件, x, y)のように書きます。
x,yは配列または数値。数値ならブロードキャストされる。
a3 = np.where(condi, list1, list2) print(a3) # [ 1 12 3 14 5]
他の例も
a1 = np.arange(9).reshape(3, 3) print(a1) # [[0 1 2] # [3 4 5] # [6 7 8]] # 3未満の要素を0で置き換える(1) print(np.where(a1 < 3, 0, a1)) # [[0 0 0] # [3 4 5] # [6 7 8]] # 3未満の要素を0で置き換える(2) print(np.where(a1 >= 3, a1, 0)) # [[0 0 0] # [3 4 5] # [6 7 8]] # 正なら1、正でないなら-1にする print(np.where(a1 > 0, 1, -1)) # [[-1 1 1] # [ 1 1 1] # [ 1 1 1]]
np.where()の引数を条件の配列だけにすると、
Trueとなる要素のインデックスを取得。
a1 = np.arange(9).reshape(3, 3) print(a1) # [[0 1 2] # [3 4 5] # [6 7 8]] # Trueとなる要素のインデックスを取得 print(np.where(a1 > 5)) # (array([2, 2, 2], dtype=int32), array([0, 1, 2], dtype=int32)) # そのインデックスを利用して、要素を取り出せる。 print(a1[np.where(a1 > 5)]) # [6 7 8] # Trueとなる要素のインデックスを取得 print(np.where(a1 > 2)) # (array([1, 1, 1, 2, 2, 2], dtype=int32), array([0, 1, 2, 0, 1, 2], dtype=int32)) # そのインデックスを利用して、要素を取り出せる。 print(a1[np.where(a1 > 2)]) # [3 4 5 6 7 8] # 値の検索 target = [3, 4, 7] ix = np.in1d(a1.ravel(), target).reshape(a1.shape) print(ix) # [[False False False] # [ True True False] # [False True False]] # そのインデックスを表示 print(np.where(ix)) # (array([1, 1, 2], dtype=int32), array([0, 1, 1], dtype=int32))
2つの条件を同時に判断
cond1がTrueでcond2もTrue | 3 |
cond1がTrueでcond2はFalse | 2 |
cond1がFalseでcond2はTrue | 1 |
cond1がFalseでcond2もFalse | 0 |
としたい場合は、python標準だとif, elifなどを組み合わせて書けるが、
np.whereを3回ネストしても書ける
cond1 = np.array([True, True, False, False]) cond2 = np.array([True, False, True, False]) result = np.where(cond1 & cond2, 3, np.where(cond1, 2, np.where(cond2, 1, 0))) print(result) # [3 2 1 0]
True=1, False=0であることを使い次のようにも書ける。
result = cond1 * 2 + cond2 print(result) # [3 2 1 0]