2015-01-06 101 views
0

我有两个变量re这两个变量都是字典,字符串作为键和csr_matrices作为值。现在我想断言他们是平等的。我该怎么做呢?Python比较字典与csr_matrices作为值

尝试1:

from scipy.sparse.csr import csr_matrix 
import numpy as np 

def test_dict_equals(self): 
    r = {'a': csr_matrix([[0, 0 ,1], [0, 1, 0], [1, 0, 0]])} 
    e = {'a': csr_matrix([[0, 0 ,1], [0, 1, 0], [1, 0, 0]])} 
    self.assertDictEqual(r, e) 

这不起作用:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all(). 

尝试2:

def test_dict_equals(self): 
    r = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])} 
    e = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])} 
    self.assertListEqual(r.keys(), e.keys()) 
    for k in r.keys(): 
     np.testing.assert_allclose(r[k], e[k]) 

但这也不起作用:

AssertionError: First sequence is not a list: dict_keys(['a']) 

尝试3:

def test_dict_equals(self): 
    r = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])} 
    e = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])} 
    self.assertListEqual(list(r.keys()), list(e.keys())) 
    for k in r.keys(): 
     np.testing.assert_allclose(r[k], e[k]) 

但这也不起作用:

TypeError: ufunc 'isinf' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe'' 

回答

1

assertDictEqual功能将调用对象的方法__eq__。在源代码csr_matrix中,可以看到没有__eq__方法。

你必须写的csr_matrix一个子类,然后做断言。以下是numpy.ndarray的一个例子。代码必须相似。

import copy 
import numpy 
import unittest 

class SaneEqualityArray(numpy.ndarray): 
    def __eq__(self, other): 
     return (isinstance(other, SaneEqualityArray) and 
       self.shape == other.shape and 
       numpy.ndarray.__eq__(self, other).all()) 

class TestAsserts(unittest.TestCase): 

    def testAssert(self): 
     tests = [ 
      [1, 2], 
      {'foo': 2}, 
      [2, 'foo', {'d': 4}], 
      SaneEqualityArray([1, 2]), 
      {'foo': {'hey': SaneEqualityArray([2, 3])}}, 
      [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}}, 
      SaneEqualityArray([5, 6]), 34] 
     ] 
     for t in tests: 
      self.assertEqual(t, copy.deepcopy(t)) 

if __name__ == '__main__': 
    unittest.main() 

希望它能帮助。:)

1

忘掉字典的瞬间,并专注于比较2点sparse矩阵。它们不是numpy数组,所以不能直接使用np方法。这就是为什么你的第三次尝试不起作用。

有一个scipy.sparse单元测试目录。我没有检查它,但它可能会给你超出我在下面建议的想法。

https://github.com/scipy/scipy/tree/master/scipy/sparse/tests

A=sparse.csr_matrix(np.arange(9).reshape(3,3)) 
B=sparse.csr_matrix(np.arange(9).reshape(3,3)) 

它们是不同的对象

id(A)==id(B) # False 

它们具有

A.nnz == B.nnz # True - just a comparison of 2 numbers 

此稀疏格式的数据被包含在3个阵列,A.data相同数量的非零元素,A.indicesA.indptr。所以,你可以使用np方法来测试一个或一个以上的那些

np.allclose(A.data, B.data) # this would also compare dtype 

您也可以比较形状等。

较新版本的scipy实现了元素比较器的稀疏矩阵元素。 ==实现,但可能给你一个警告:

SparseEfficiencyWarning:比较使用==是低效的稀疏矩阵,请尝试使用替代=。

如果形状匹配,这可能是比较稀疏矩阵的一种有效的方法:

(A!=B).nnz==0 

如果形状不匹配,A!=C回报True

如果他们是小,你可以比较他们的密集等价物:

np.allclose(A.A, B.A) 
+0

谢谢。我确实现在正在测试A.data,A.indices和A.indptr。您引用的测试目录是用于执行csr_matrix的测试代码,而不是用于测试的帮助函数。 – physicalattraction

+0

注意:有可能有两个相同的矩阵,但具有不同的indptr/indices/data数组。这两个非零元素的顺序不同。这导致错误检测到不平等。 – physicalattraction