Remrinのpython攻略日記

python3に入門しました。python3についてあれこれとサンプルコードとか。

NumPyの使い方(6) 条件制御

NumPyの条件制御について。
 
if~elseで条件分岐をするとき

a = 5
if a >= 0:
    b = 1
else:
    b = -1

と書けますが、これを3項演算子を使い以下のようにも書けます。

b = 1 if a >= 0 else -1

 
似たような動作を組み込みのリスト型で行うとき、
例えば、Trueならlist1から、そうでなければlist2から要素を選ぶとき、

list1 = [1, 2, 3, 4, 5]
list2 = [11, 12, 13, 14, 15]
condi = [True, False, True, False, True]

list3 = [x if c else y for (c, x, y) in zip(condi, list1, list2)]
print(list3)   # [1, 12, 3, 14, 5]

 
NumPyではnp.where()を使って
np.where(条件, x, y)のように書きます。
x,yは配列または数値。数値ならブロードキャストされる。

a3 = np.where(condi, list1, list2)
print(a3)     # [ 1 12  3 14  5]

 
他の例も

a1 = np.arange(9).reshape(3, 3)
print(a1)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

# 3未満の要素を0で置き換える(1)
print(np.where(a1 < 3, 0, a1))
# [[0 0 0]
#  [3 4 5]
#  [6 7 8]]

# 3未満の要素を0で置き換える(2)
print(np.where(a1 >= 3, a1, 0))
# [[0 0 0]
#  [3 4 5]
#  [6 7 8]]

# 正なら1、正でないなら-1にする
print(np.where(a1 > 0, 1, -1))
# [[-1  1  1]
#  [ 1  1  1]
#  [ 1  1  1]]

 
np.where()の引数を条件の配列だけにすると、
Trueとなる要素のインデックスを取得。

a1 = np.arange(9).reshape(3, 3)
print(a1)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

# Trueとなる要素のインデックスを取得
print(np.where(a1 > 5))
# (array([2, 2, 2], dtype=int32), array([0, 1, 2], dtype=int32))

# そのインデックスを利用して、要素を取り出せる。
print(a1[np.where(a1 > 5)])
# [6 7 8]

# Trueとなる要素のインデックスを取得
print(np.where(a1 > 2))
# (array([1, 1, 1, 2, 2, 2], dtype=int32), array([0, 1, 2, 0, 1, 2], dtype=int32))

# そのインデックスを利用して、要素を取り出せる。
print(a1[np.where(a1 > 2)])
# [3 4 5 6 7 8]

# 値の検索
target = [3, 4, 7]
ix = np.in1d(a1.ravel(), target).reshape(a1.shape)
print(ix)
# [[False False False]
#  [ True  True False]
#  [False  True False]]

# そのインデックスを表示
print(np.where(ix))
# (array([1, 1, 2], dtype=int32), array([0, 1, 1], dtype=int32))

 

2つの条件を同時に判断

cond1がTrueでcond2もTrue 3
cond1がTrueでcond2はFalse 2
cond1がFalseでcond2はTrue 1
cond1がFalseでcond2もFalse 0

としたい場合は、python標準だとif, elifなどを組み合わせて書けるが、
np.whereを3回ネストしても書ける

cond1 = np.array([True, True, False, False])
cond2 = np.array([True, False, True, False])
result = np.where(cond1 & cond2, 3,
                  np.where(cond1, 2,
                           np.where(cond2, 1, 0)))
print(result)   # [3 2 1 0]

True=1, False=0であることを使い次のようにも書ける。

result = cond1 * 2 + cond2
print(result)   # [3 2 1 0]