2017-02-25 35 views
0

如何将numpy数组二进制化为阈值以上的行中对应的最大值。如果数组行的最大值小于阈值,那么列1应该等于1。如何用阈值对数组进行二进制化?

a=np.array([[ 0.01, 0.3 , 0.6 ],                              
      [ 0.2 , 0.1 , 0.4 ],                              
      [ 0.7 , 0.1 , 0.3 ],                              
      [ 0.2 , 0.3 , 0.5 ],                              
      [ 0.1 , 0.7 , 0.3 ],                              
      [ 0.1 , 0.5 , 0.8 ]])  


#required output with thresold row.max>=0.6 else column 1 should be 1 
    np.array([[ 0.0 , 0.0 , 1.0 ],                              
      [ 0.0 , 1.0 , 0.0 ],                              
      [ 1.0 , 0.0 , 0.0 ],                              
      [ 0.0 , 1.0 , 0.0 ],                              
      [ 0.0 , 1.0 , 0.0 ],                              
      [ 0.0 , 0.0 , 1.0 ]])  

回答

1

一个解决方案:使用argmax和高级索引

am = a.argmax(axis=-1) 
am[a[np.arange(len(a)), am] < 0.6] = 1 
out = np.zeros_like(a) 
out[np.arange(len(a)), am] = 1 
out 
array([[ 0., 0., 1.], 
     [ 0., 1., 0.], 
     [ 1., 0., 0.], 
     [ 0., 1., 0.], 
     [ 0., 1., 0.], 
     [ 0., 0., 1.]]) 
相关问题