[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt/yt-3.0 (pull request #996)

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Mon Jul 14 10:57:42 PDT 2014


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/8f2015a0da27/
Changeset:   8f2015a0da27
Branch:      yt-3.0
User:        MatthewTurk
Date:        2014-07-14 19:57:34
Summary:     Merged in ngoldbaum/yt/yt-3.0 (pull request #996)

Streamlining ProfilePlot and PhasePlot creation, adding tests
Affected #:  6 files

diff -r 8676f068d48fae5fcc84d69f000686e27c74167c -r 8f2015a0da2717edbb82f2dbd4f2073b0d95b63c doc/source/visualizing/plots.rst
--- a/doc/source/visualizing/plots.rst
+++ b/doc/source/visualizing/plots.rst
@@ -772,8 +772,8 @@
    ds = yt.load("sizmbhloz-clref04SNth-rs9_a0.9011/sizmbhloz-clref04SNth-rs9_a0.9011.art")
    center = ds.arr([64.0, 64.0, 64.0], 'code_length')
    rvir = ds.quan(1e-1, "Mpccm/h")
+   sph = ds.sphere(center, rvir)
 
-   sph = ds.sphere(center, rvir)
    plot = yt.PhasePlot(sph, "density", "temperature", "cell_mass",
                        weight_field=None)
    plot.set_unit('density', 'Msun/pc**3')
@@ -782,6 +782,29 @@
    plot.set_ylim(1,1e7)
    plot.save()
 
+It is also possible to construct a custom 2D profile object and then use the
+``from_profile`` method to create a ``PhasePlot`` from that profile object.
+This can sometimes be faster, especially if you need custom x and y axis
+limits.  The following example illustrates this workflow:
+
+.. python-script::
+
+   import yt
+   ds = yt.load("sizmbhloz-clref04SNth-rs9_a0.9011/sizmbhloz-clref04SNth-rs9_a0.9011.art")
+   center = ds.arr([64.0, 64.0, 64.0], 'code_length')
+   rvir = ds.quan(1e-1, "Mpccm/h")
+   sph = ds.sphere(center, rvir)
+   units = dict(density='Msun/pc**3', cell_mass='Msun')
+   extrema = dict(density=(1e-5, 1e1), temperature=(1, 1e7))
+
+   profile = yt.create_profile(sph, ['density', 'temperature'],
+                               n_bins=[128, 128], fields=['cell_mass'],
+                               weight_field=None, units=units, extrema=extrema)
+
+   plot = yt.PhasePlot.from_profile(profile)
+
+   plot.save()
+
 Probability Distribution Functions and Accumulation
 ---------------------------------------------------
 

diff -r 8676f068d48fae5fcc84d69f000686e27c74167c -r 8f2015a0da2717edbb82f2dbd4f2073b0d95b63c yt/data_objects/profiles.py
--- a/yt/data_objects/profiles.py
+++ b/yt/data_objects/profiles.py
@@ -1196,7 +1196,8 @@
     extrema : dict of min, max tuples
         Minimum and maximum values of the bin_fields for the profiles.
         The keys correspond to the field names. Defaults to the extrema
-        of the bin_fields of the dataset.
+        of the bin_fields of the dataset. If a units dict is provided, extrema
+        are understood to be in the units specified in the dictionary.
     logs : dict of boolean values
         Whether or not to log the bin_fields for the profiles.
         The keys correspond to the field names. Defaults to the take_log
@@ -1245,6 +1246,16 @@
         raise NotImplementedError
     bin_fields = data_source._determine_fields(bin_fields)
     fields = data_source._determine_fields(fields)
+    if units is not None:
+        dummy = dict(units)  # copy so the caller's dict is not mutated
+        for item in units:
+            dummy[data_source._determine_fields(item)[0]] = units[item]
+        units = dummy
+    if extrema is not None:
+        dummy = dict(extrema)  # copy so the caller's dict is not mutated
+        for item in extrema:
+            dummy[data_source._determine_fields(item)[0]] = extrema[item]
+        extrema = dummy
     if weight_field is not None:
         weight_field, = data_source._determine_fields([weight_field])
     if not iterable(n_bins):
@@ -1262,12 +1273,16 @@
     else:
         ex = []
         for bin_field in bin_fields:
-            bf_units = data_source.pf._get_field_info(bin_field[0],
-                                                      bin_field[1]).units
+            bf_units = data_source.pf._get_field_info(
+                bin_field[0], bin_field[1]).units
             try:
                 field_ex = list(extrema[bin_field[-1]])
             except KeyError:
                 field_ex = list(extrema[bin_field])
+            if units is not None and bin_field in units:
+                fe = data_source.pf.arr(field_ex, units[bin_field])
+                fe.convert_to_units(bf_units)  # back to the field's own units
+                field_ex = [fe[0].v, fe[1].v]
             if iterable(field_ex[0]):
                 field_ex[0] = data_source.pf.quan(field_ex[0][0], field_ex[0][1])
                 field_ex[0] = field_ex[0].in_units(bf_units)

diff -r 8676f068d48fae5fcc84d69f000686e27c74167c -r 8f2015a0da2717edbb82f2dbd4f2073b0d95b63c yt/testing.py
--- a/yt/testing.py
+++ b/yt/testing.py
@@ -1,5 +1,5 @@
 """
-
+Utilities to aid testing.
 
 
 """

diff -r 8676f068d48fae5fcc84d69f000686e27c74167c -r 8f2015a0da2717edbb82f2dbd4f2073b0d95b63c yt/visualization/profile_plotter.py
--- a/yt/visualization/profile_plotter.py
+++ b/yt/visualization/profile_plotter.py
@@ -31,15 +31,15 @@
     ImagePlotContainer, \
     log_transform, linear_transform
 from yt.data_objects.profiles import \
-     create_profile
+    create_profile
 from yt.utilities.exceptions import \
-     YTNotInsideNotebook
+    YTNotInsideNotebook
 from yt.utilities.logger import ytLogger as mylog
 import _mpl_imports as mpl
 from yt.funcs import \
-     ensure_list, \
-     get_image_suffix, \
-     get_ipython_api_version
+    ensure_list, \
+    get_image_suffix, \
+    get_ipython_api_version
 
 def get_canvas(name):
     suffix = get_image_suffix(name)
@@ -150,10 +150,6 @@
         A dictionary or list of dictionaries containing plot keyword 
         arguments.  For example, dict(color="red", linestyle=":").
         Default: None.
-    profiles : list of profiles
-        If not None, a list of profile objects created with 
-        `yt.data_objects.profiles.create_profile`.
-        Default: None.
 
     Examples
     --------
@@ -196,48 +192,37 @@
     z_log = None
     x_title = None
     y_title = None
-
     _plot_valid = False
 
-    def __init__(self, data_source, x_field, y_fields, 
+    def __init__(self, data_source, x_field, y_fields,
                  weight_field="cell_mass", n_bins=64,
                  accumulation=False, fractional=False,
-                 label=None, plot_spec=None, profiles=None):
-        self.y_log = {}
-        self.y_title = {}
-        self.x_log = None
-        if profiles is None:
-            self.profiles = [create_profile(data_source, [x_field],
-                                            n_bins=[n_bins],
-                                            fields=ensure_list(y_fields),
-                                            weight_field=weight_field,
-                                            accumulation=accumulation,
-                                            fractional=fractional)]
-        else:
-            self.profiles = ensure_list(profiles)
-        
-        self.label = sanitize_label(label, len(self.profiles))
+                 label=None, plot_spec=None):
 
-        self.plot_spec = plot_spec
-        if self.plot_spec is None:
-            self.plot_spec = [dict() for p in self.profiles]
-        if not isinstance(self.plot_spec, list):
-            self.plot_spec = [self.plot_spec.copy() for p in self.profiles]
+        profiles = [create_profile(data_source, [x_field],
+                                   n_bins=[n_bins],
+                                   fields=ensure_list(y_fields),
+                                   weight_field=weight_field,
+                                   accumulation=accumulation,
+                                   fractional=fractional)]
 
-        self.figures = FigureContainer()
-        self.axes = AxesContainer(self.figures)
-        self._setup_plots()
-        
+        if plot_spec is None:
+            plot_spec = [dict() for p in profiles]
+        if not isinstance(plot_spec, list):
+            plot_spec = [plot_spec.copy() for p in profiles]
+
+        ProfilePlot._initialize_instance(self, profiles, label, plot_spec)
+
     def save(self, name=None):
         r"""
-        Saves a 1d profile plot.
+         Saves a 1d profile plot.
 
-        Parameters
-        ----------
-        name : str
-            The output file keyword.
-        
-        """
+         Parameters
+         ----------
+         name : str
+             The output file keyword.
+
+         """
         if not self._plot_valid:
             self._setup_plots()
         unique = set(self.figures.values())
@@ -259,14 +244,15 @@
         if not suffix:
             suffix = ".png"
         canvas_cls = get_canvas(name)
+        fns = []
         for uid, fig in iters:
             if isinstance(uid, types.TupleType):
                 uid = uid[1]
             canvas = canvas_cls(fig)
-            fn = "%s_1d-Profile_%s_%s%s" % (prefix, xfn, uid, suffix)
-            mylog.info("Saving %s", fn)
-            canvas.print_figure(fn)
-        return self
+            fns.append("%s_1d-Profile_%s_%s%s" % (prefix, xfn, uid, suffix))
+            mylog.info("Saving %s", fns[-1])
+            canvas.print_figure(fns[-1])
+        return fns
 
     def show(self):
         r"""This will send any existing plots to the IPython notebook.
@@ -323,7 +309,7 @@
             for field, field_data in profile.items():
                 self.axes[field].plot(np.array(profile.x), np.array(field_data),
                                       label=self.label[i], **self.plot_spec[i])
-        
+
         # This relies on 'profile' leaking
         for fname, axes in self.axes.items():
             xscale, yscale = self._get_field_log(fname, profile)
@@ -338,27 +324,43 @@
         self._plot_valid = True
 
     @classmethod
+    def _initialize_instance(cls, obj, profiles, labels, plot_specs):
+        obj.y_log = {}
+        obj.y_title = {}
+        obj.x_log = None
+        obj.profiles = ensure_list(profiles)
+        obj.label = sanitize_label(labels, len(obj.profiles))
+        if plot_specs is None:
+            plot_specs = [dict() for p in obj.profiles]
+        obj.plot_spec = plot_specs
+        obj.figures = FigureContainer()
+        obj.axes = AxesContainer(obj.figures)
+        obj._setup_plots()
+        return obj
+
+    @classmethod
     def from_profiles(cls, profiles, labels=None, plot_specs=None):
         r"""
-        Instantiate a ProfilePlot object from a list of profiles 
-        created with `yt.data_objects.profiles.create_profile`.
+        Instantiate a ProfilePlot object from a list of profiles
+        created with :func:`~yt.data_objects.profiles.create_profile`.
 
         Parameters
         ----------
-        profiles : list of profiles
-            If not None, a list of profile objects created with 
-            `yt.data_objects.profiles.create_profile`.
+        profiles : a profile or list of profiles
+            A single profile or list of profile objects created with
+            :func:`~yt.data_objects.profiles.create_profile`.
         labels : list of strings
             A list of labels for each profile to be overplotted.
             Default: None.
         plot_specs : list of dicts
-            A list of dictionaries containing plot keyword 
+            A list of dictionaries containing plot keyword
             arguments.  For example, [dict(color="red", linestyle=":")].
             Default: None.
 
         Examples
         --------
 
+        >>> from yt import simulation
         >>> es = simulation("AMRCosmology.enzo", "Enzo")
         >>> es.get_time_series()
 
@@ -382,9 +384,8 @@
             raise RuntimeError("Profiles list and labels list must be the same size.")
         if plot_specs is not None and len(plot_specs) != len(profiles):
             raise RuntimeError("Profiles list and plot_specs list must be the same size.")
-        obj = cls(None, None, None, profiles=profiles, label=labels,
-                  plot_spec=plot_specs)
-        return obj
+        obj = cls.__new__(cls)
+        return cls._initialize_instance(obj, profiles, labels, plot_specs)
 
     @invalidate_plot
     def set_line_property(self, property, value, index=None):
@@ -602,7 +603,7 @@
         fractional = profile.fractional
         x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit)
         y_title = self.y_title.get(field_y, None) or \
-                    self._get_field_label(field_y, yfi, y_unit, fractional)
+            self._get_field_label(field_y, yfi, y_unit, fractional)
 
         return (x_title, y_title)
 
@@ -656,9 +657,6 @@
     fontsize: int
         Font size for all text in the plot.
         Default: 18.
-    font_color : str
-        Color for all text in the plot.
-        Default: "black"
     figure_size : int
         Size in inches of the image.
         Default: 8 (8x8)
@@ -691,27 +689,32 @@
     def __init__(self, data_source, x_field, y_field, z_fields,
                  weight_field="cell_mass", x_bins=128, y_bins=128,
                  accumulation=False, fractional=False,
-                 profile=None, fontsize=18, font_color="black", figure_size=8.0):
-        self.plot_title = {}
-        self.z_log = {}
-        self.z_title = {}
-        self._initfinished = False
-        self.x_log = None
-        self.y_log = None
+                 fontsize=18, figure_size=8.0):
 
-        if profile is None:
-            profile = create_profile(data_source,
-               [x_field, y_field],
-               ensure_list(z_fields),
-               n_bins = [x_bins, y_bins],
-               weight_field = weight_field,
-               accumulation=accumulation,
-               fractional=fractional)
-        self.profile = profile
-        super(PhasePlot, self).__init__(data_source, figure_size, fontsize)
-        # This is a fallback, in case we forget.
-        self._setup_plots()
-        self._initfinished = True
+        profile = create_profile(
+            data_source,
+            [x_field, y_field],
+            ensure_list(z_fields),
+            n_bins=[x_bins, y_bins],
+            weight_field=weight_field,
+            accumulation=accumulation,
+            fractional=fractional)
+
+        type(self)._initialize_instance(self, data_source, profile, fontsize, figure_size)
+
+    @classmethod
+    def _initialize_instance(cls, obj, data_source, profile, fontsize, figure_size):
+        obj.plot_title = {}
+        obj.z_log = {}
+        obj.z_title = {}
+        obj._initfinished = False
+        obj.x_log = None
+        obj.y_log = None
+        obj.profile = profile
+        super(PhasePlot, obj).__init__(data_source, figure_size, fontsize)
+        obj._setup_plots()
+        obj._initfinished = True
+        return obj
 
     def _get_field_title(self, field_z, profile):
         pf = profile.data_source.pf
@@ -729,7 +732,7 @@
         x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit)
         y_title = self.y_title or self._get_field_label(field_y, yfi, y_unit)
         z_title = self.z_title.get(field_z, None) or \
-                    self._get_field_label(field_z, zfi, z_unit, fractional)
+            self._get_field_label(field_z, zfi, z_unit, fractional)
         return (x_title, y_title, z_title)
 
     def _get_field_label(self, field, field_info, field_unit, fractional=False):
@@ -837,6 +840,42 @@
                     label.set_color(self._font_color)
         self._plot_valid = True
 
+    @classmethod
+    def from_profile(cls, profile, fontsize=18, figure_size=8.0):
+        r"""
+        Instantiate a PhasePlot object from a profile object created
+        with :func:`~yt.data_objects.profiles.create_profile`.
+
+        Parameters
+        ----------
+        profile : An instance of :class:`~yt.data_objects.profiles.ProfileND`
+            A single profile object.
+        fontsize : float
+            The fontsize to use, in points.
+        figure_size : float
+            The figure size to use, in inches.
+
+        Examples
+        --------
+
+        >>> import yt
+        >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
+        >>> extrema = {
+        ... 'density': (1e-31, 1e-24),
+        ... 'temperature': (1e1, 1e8),
+        ... 'cell_mass': (1e-6, 1e-1),
+        ... }
+        >>> profile = yt.create_profile(ds.all_data(), ['density', 'temperature'],
+        ...                             fields=['cell_mass'], extrema=extrema,
+        ...                             fractional=True)
+        >>> ph = yt.PhasePlot.from_profile(profile)
+        >>> ph.save()
+        """
+        obj = cls.__new__(cls)
+        data_source = profile.data_source
+        return cls._initialize_instance(obj, data_source, profile, fontsize,
+                                        figure_size)
+
     def save(self, name=None, mpl_kwargs=None):
         r"""
         Saves a 2d profile plot.

diff -r 8676f068d48fae5fcc84d69f000686e27c74167c -r 8f2015a0da2717edbb82f2dbd4f2073b0d95b63c yt/visualization/tests/test_plotwindow.py
--- a/yt/visualization/tests/test_plotwindow.py
+++ b/yt/visualization/tests/test_plotwindow.py
@@ -61,7 +61,7 @@
     return image_type == os.path.splitext(fname)[1]
 
 
-TEST_FLNMS = [None, 'test.png', 'test.eps',
+TEST_FLNMS = [None, 'test', 'test.png', 'test.eps',
               'test.ps', 'test.pdf']
 M7 = "DD0010/moving7_0010"
 WT = "WindTunnel/windtunnel_4lev_hdf5_plt_cnt_0030"

diff -r 8676f068d48fae5fcc84d69f000686e27c74167c -r 8f2015a0da2717edbb82f2dbd4f2073b0d95b63c yt/visualization/tests/test_profile_plots.py
--- /dev/null
+++ b/yt/visualization/tests/test_profile_plots.py
@@ -0,0 +1,85 @@
+"""
+Testsuite for ProfilePlot and PhasePlot
+
+
+
+"""
+
+#-----------------------------------------------------------------------------
+# Copyright (c) 2013, yt Development Team.
+#
+# Distributed under the terms of the Modified BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+import itertools
+import os
+import tempfile
+import shutil
+import unittest
+from yt.data_objects.profiles import create_profile
+from yt.extern.parameterized import\
+    parameterized, param
+from yt.testing import fake_random_pf
+from yt.visualization.profile_plotter import \
+    ProfilePlot, PhasePlot
+from yt.visualization.tests.test_plotwindow import \
+    assert_fname, TEST_FLNMS
+
+class TestProfilePlotSave(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        fields = ('density', 'temperature', 'velocity_x', 'velocity_y',
+                  'velocity_z')
+        units = ('g/cm**3', 'K', 'cm/s', 'cm/s', 'cm/s')
+        test_ds = fake_random_pf(64, fields=fields, units=units)
+        regions = [test_ds.region([0.5]*3, [0.4]*3, [0.6]*3), test_ds.all_data()]
+        profiles = []
+        phases = []
+        pr_fields = [('density', 'temperature'), ('density', 'velocity_x'),
+                     ('temperature', 'cell_mass'), ('density', 'radius'),
+                     ('velocity_magnitude', 'cell_mass')]
+        ph_fields = [('density', 'temperature', 'cell_mass'),
+                     ('density', 'velocity_x', 'cell_mass'),
+                     ('radius', 'temperature', 'velocity_magnitude')]
+        for reg in regions:
+            for x_field, y_field in pr_fields:
+                profiles.append(ProfilePlot(reg, x_field, y_field))
+                profiles.append(ProfilePlot(reg, x_field, y_field,
+                                            fractional=True, accumulation=True))
+                p1d = create_profile(reg, x_field, y_field)
+                profiles.append(ProfilePlot.from_profiles(p1d))
+            for x_field, y_field, z_field in ph_fields:
+                # set n_bins to [16, 16] since matplotlib's postscript
+                # renderer is slow when it has to write a lot of polygons
+                phases.append(PhasePlot(reg, x_field, y_field, z_field,
+                                        x_bins=16, y_bins=16))
+                phases.append(PhasePlot(reg, x_field, y_field, z_field,
+                                        fractional=True, accumulation=True,
+                                        x_bins=16, y_bins=16))
+                p2d = create_profile(reg, [x_field, y_field], z_field,
+                                     n_bins=[16, 16])
+                phases.append(PhasePlot.from_profile(p2d))
+        cls.profiles = profiles
+        cls.phases = phases
+        cls.ds = test_ds
+
+    def setUp(self):
+        self.tmpdir = tempfile.mkdtemp()
+        self.curdir = os.getcwd()
+        os.chdir(self.tmpdir)
+
+    def tearDown(self):
+        os.chdir(self.curdir)
+        shutil.rmtree(self.tmpdir)
+
+    @parameterized.expand(param.explicit((fname, )) for fname in TEST_FLNMS)
+    def test_profile_plot(self, fname):
+        for p in self.profiles:
+            assert assert_fname(p.save(fname)[0])  # not yield: must execute
+
+    @parameterized.expand(param.explicit((fname, )) for fname in TEST_FLNMS)
+    def test_phase_plot(self, fname):
+        for p in self.phases:
+            assert assert_fname(p.save(fname)[0])

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.



More information about the yt-svn mailing list