[yt-svn] commit/yt: xarthisius: Mask out zeros while comparing arrays in assert_rel_equal, also make sure that masks are identical

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Sun Jan 4 16:24:06 PST 2015


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/b90813248d01/
Changeset:   b90813248d01
Branch:      yt
User:        xarthisius
Date:        2015-01-04 22:56:12+00:00
Summary:     Mask out zeros while comparing arrays in assert_rel_equal, also make sure that masks are identical
Affected #:  1 file

diff -r d0d4c7a4e3c79a2c7471e98abf2bf164fb7cc024 -r b90813248d01135ba964847ae49317f2ba272c27 yt/testing.py
--- a/yt/testing.py
+++ b/yt/testing.py
@@ -28,14 +28,22 @@
 import yt.fields.api as field_api
 from yt.convenience import load
 
+
 def assert_rel_equal(a1, a2, decimals, err_msg='', verbose=True):
     # We have nan checks in here because occasionally we have fields that get
     # weighted without non-zero weights.  I'm looking at you, particle fields!
     if isinstance(a1, np.ndarray):
         assert(a1.size == a2.size)
         # Mask out NaNs
+        assert((np.isnan(a1) == np.isnan(a2)).all())
         a1[np.isnan(a1)] = 1.0
         a2[np.isnan(a2)] = 1.0
+        # Mask out 0
+        ind1 = np.array(np.abs(a1) < np.finfo(a1.dtype).eps)
+        ind2 = np.array(np.abs(a2) < np.finfo(a2.dtype).eps)
+        assert((ind1 == ind2).all())
+        a1[ind1] = 1.0
+        a2[ind2] = 1.0
     elif np.any(np.isnan(a1)) and np.any(np.isnan(a2)):
         return True
     return assert_almost_equal(np.array(a1)/np.array(a2), 1.0, decimals, err_msg=err_msg,
@@ -212,7 +220,7 @@
         fields = ("particle_position_x",
                   "particle_position_y",
                   "particle_position_z",
-                  "particle_mass", 
+                  "particle_mass",
                   "particle_velocity_x",
                   "particle_velocity_y",
                   "particle_velocity_z"),
@@ -249,12 +257,12 @@
 
     It will return a list of kwargs dicts containing combinations of
     the various kwarg values you passed it.  These can then be passed
-    to the appropriate function in nosetests. 
+    to the appropriate function in nosetests.
 
     If full=True, then every possible combination of keywords is produced,
     otherwise, every keyword option is included at least once in the output
     list.  Be careful, by using full=True, you may be in for an exponentially
-    larger number of tests! 
+    larger number of tests!
 
     keywords : dict
         a dictionary where the keys are the keywords for the function,
@@ -262,7 +270,7 @@
         can take in the function
 
    full : bool
-        if set to True, every possible combination of given keywords is 
+        if set to True, every possible combination of given keywords is
         returned
 
     Returns
@@ -279,18 +287,18 @@
     >>> list_of_kwargs = expand_keywords(keywords)
     >>> print list_of_kwargs
 
-    array([{'cmap': 'algae', 'dpi': 50}, 
+    array([{'cmap': 'algae', 'dpi': 50},
            {'cmap': 'jet', 'dpi': 100},
            {'cmap': 'algae', 'dpi': 200}], dtype=object)
 
     >>> list_of_kwargs = expand_keywords(keywords, full=True)
     >>> print list_of_kwargs
 
-    array([{'cmap': 'algae', 'dpi': 50}, 
+    array([{'cmap': 'algae', 'dpi': 50},
            {'cmap': 'algae', 'dpi': 100},
-           {'cmap': 'algae', 'dpi': 200}, 
+           {'cmap': 'algae', 'dpi': 200},
            {'cmap': 'jet', 'dpi': 50},
-           {'cmap': 'jet', 'dpi': 100}, 
+           {'cmap': 'jet', 'dpi': 100},
            {'cmap': 'jet', 'dpi': 200}], dtype=object)
 
     >>> for kwargs in list_of_kwargs:
@@ -302,8 +310,8 @@
         keys = sorted(keywords)
         list_of_kwarg_dicts = np.array([dict(zip(keys, prod)) for prod in \
                               it.product(*(keywords[key] for key in keys))])
-            
-    # if we just want to probe each keyword, but not necessarily every 
+
+    # if we just want to probe each keyword, but not necessarily every
     # combination
     else:
         # Determine the maximum number of values any of the keywords has
@@ -313,14 +321,14 @@
                 num_lists = max(1.0, num_lists)
             else:
                 num_lists = max(len(val), num_lists)
-    
+
         # Construct array of kwargs dicts, each element of the list is a different
         # **kwargs dict.  each kwargs dict gives a different combination of
         # the possible values of the kwargs
-    
+
         # initialize array
         list_of_kwarg_dicts = np.array([dict() for x in range(num_lists)])
-    
+
         # fill in array
         for i in np.arange(num_lists):
             list_of_kwarg_dicts[i] = {}
@@ -340,7 +348,7 @@
 def requires_module(module):
     """
     Decorator that takes a module name as an argument and tries to import it.
-    If the module imports without issue, the function is returned, but if not, 
+    If the module imports without issue, the function is returned, but if not,
     a null function is returned. This is so tests that depend on certain modules
     being imported will not fail if the module is not installed on the testing
     platform.
@@ -355,7 +363,7 @@
         return ffalse
     else:
         return ftrue
-    
+
 def requires_file(req_file):
     path = ytcfg.get("yt", "test_data_dir")
     def ffalse(func):
@@ -662,7 +670,7 @@
     from yt.mods import unparsed_args
     if "--answer-reference" in unparsed_args:
         return compute_results(func)
-    
+
     def compare_results(func):
         def _func(*args, **kwargs):
             name = kwargs.pop("result_basename", func.func_name)

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled and are listed as the
recipient of this email.



More information about the yt-svn mailing list