[yt-svn] commit/yt: 10 new changesets

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Mon Jul 31 08:37:22 PDT 2017


10 new commits in yt:

https://bitbucket.org/yt_analysis/yt/commits/f05eb312ccef/
Changeset:   f05eb312ccef
User:        Alex Lindsay
Date:        2017-07-25 15:59:03+00:00
Summary:     Atomize functions (#1505)
Affected #:  1 file

diff -r 85cb5b40cd70e9772c4f581e8133692d5f73ed9e -r f05eb312ccef2c51e578de80517dc056c82c0115 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -116,39 +116,169 @@
         """
         self.start_point = _validate_point(start_point, ds, start=True)
         self.end_point = _validate_point(end_point, ds)
-        self.npoints = npoints
-        self._x_unit = None
-        self._y_units = {}
-        self._titles = {}
 
-        data_source = ds.all_data()
+        self._initialize_instance(self, ds, fields, npoints, figure_size, fontsize)
 
-        self.fields = data_source._determine_fields(fields)
-        self.plots = LinePlotDictionary(data_source)
-        self.include_legend = defaultdict(bool)
         if labels is None:
             self.labels = {}
         else:
             self.labels = labels
-
-        super(LinePlot, self).__init__(data_source, figure_size, fontsize)
-
         for f in self.fields:
             if f not in self.labels:
                 self.labels[f] = f[1]
-            finfo = self.data_source.ds._get_field_info(*f)
-            if finfo.take_log:
-                self._field_transform[f] = log_transform
-            else:
-                self._field_transform[f] = linear_transform
 
         self._setup_plots()
 
+    @classmethod
+    def _initialize_instance(cls, obj, ds, fields, npoints, figure_size=5.,
+                             fontsize=14.):
+        obj.npoints = npoints
+        obj._x_unit = None
+        obj._y_units = {}
+        obj._titles = {}
+
+        data_source = ds.all_data()
+
+        obj.fields = data_source._determine_fields(fields)
+        obj.plots = LinePlotDictionary(data_source)
+        obj.include_legend = defaultdict(bool)
+        super(LinePlot, obj).__init__(data_source, figure_size, fontsize)
+        for f in obj.fields:
+            finfo = obj.data_source.ds._get_field_info(*f)
+            if finfo.take_log:
+                obj._field_transform[f] = log_transform
+            else:
+                obj._field_transform[f] = linear_transform
+
+
     @invalidate_plot
     def add_legend(self, field):
         """Adds a legend to the `LinePlot` instance"""
         self.include_legend[field] = True
 
+    @classmethod
+    def from_lines(cls, ds, fields, start_points, end_points, npoints,
+                   figure_size=5., font_size=14., labels=None):
+        """
+        A class method for constructing a line plot from multiple sampling lines
+
+        Parameters
+        ----------
+
+        ds : :class:`yt.data_objects.static_output.Dataset`
+            This is the dataset object corresponding to the
+            simulation output to be plotted.
+        fields : string / tuple, or list of strings / tuples
+            The name(s) of the field(s) to be plotted.
+        start_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
+            Each element of the outer iterable contains the coordinates of a starting
+            point for constructing a line.
+        end_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
+            Each element of the outer iterable contains the coordinates of an ending
+            point for constructing a line.
+        npoints : int
+            How many points to sample between start_point and end_point for
+            constructing the line plot
+        figure_size : int or two-element iterable of ints
+            Size in inches of the image.
+            Default: 5 (5x5)
+        fontsize : int
+            Font size for all text in the plot.
+            Default: 14
+        labels : dictionary
+            Keys should be the field names. Values should be latex-formattable
+            strings used in the LinePlot legend
+            Default: None
+        """
+        return 0
+
+    def _get_plot_instance(self, field):
+        fontscale = self._font_properties._size / 14.
+        top_buff_size = 0.35*fontscale
+
+        x_axis_size = 1.35*fontscale
+        y_axis_size = 0.7*fontscale
+        right_buff_size = 0.2*fontscale
+
+        if iterable(self.figure_size):
+            figure_size = self.figure_size
+        else:
+            figure_size = (self.figure_size, self.figure_size)
+
+        xbins = np.array([x_axis_size, figure_size[0],
+                          right_buff_size])
+        ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
+
+        size = [xbins.sum(), ybins.sum()]
+
+        x_frac_widths = xbins/size[0]
+        y_frac_widths = ybins/size[1]
+
+        axrect = (
+            x_frac_widths[0],
+            y_frac_widths[0],
+            x_frac_widths[1],
+            y_frac_widths[1],
+        )
+
+        try:
+            plot = self.plots[field]
+        except KeyError:
+            plot = PlotMPL(self.figure_size, axrect, None, None)
+            self.plots[field] = plot
+        return plot
+
+    def _plot_xy(self, field, plot, x, y, dimensions_counter):
+        if self._x_unit is None:
+            unit_x = x.units
+        else:
+            unit_x = self._x_unit
+
+        if field in self._y_units:
+            unit_y = self._y_units[field]
+        else:
+            unit_y = y.units
+
+        x = x.to(unit_x)
+        y = y.to(unit_y)
+
+        plot.axes.plot(x, y, label=self.labels[field])
+
+        if self._field_transform[field] != linear_transform:
+            if (y < 0).any():
+                plot.axes.set_yscale('symlog')
+            else:
+                plot.axes.set_yscale('log')
+
+        plot._set_font_properties(self._font_properties, None)
+
+        axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
+
+        finfo = self.ds.field_info[field]
+
+        x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
+
+        finfo = self.ds.field_info[field]
+        dimensions = Unit(finfo.units,
+                          registry=self.ds.unit_registry).dimensions
+        dimensions_counter[dimensions] += 1
+        if dimensions_counter[dimensions] > 1:
+            y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+                       axes_unit_labels[1]+'}$')
+        else:
+            y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+                       axes_unit_labels[1]+'}$')
+
+        plot.axes.set_xlabel(x_label)
+        plot.axes.set_ylabel(y_label)
+
+        if field in self._titles:
+            plot.axes.set_title(self._titles[field])
+
+        if self.include_legend[field]:
+            plot.axes.legend()
+
+
     def _setup_plots(self):
         if self._plot_valid is True:
             return
@@ -156,91 +286,13 @@
             plot.axes.cla()
         dimensions_counter = defaultdict(int)
         for field in self.fields:
-            fontscale = self._font_properties._size / 14.
-            top_buff_size = 0.35*fontscale
-
-            x_axis_size = 1.35*fontscale
-            y_axis_size = 0.7*fontscale
-            right_buff_size = 0.2*fontscale
-
-            if iterable(self.figure_size):
-                figure_size = self.figure_size
-            else:
-                figure_size = (self.figure_size, self.figure_size)
-
-            xbins = np.array([x_axis_size, figure_size[0],
-                              right_buff_size])
-            ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
-
-            size = [xbins.sum(), ybins.sum()]
-
-            x_frac_widths = xbins/size[0]
-            y_frac_widths = ybins/size[1]
-
-            axrect = (
-                x_frac_widths[0],
-                y_frac_widths[0],
-                x_frac_widths[1],
-                y_frac_widths[1],
-            )
-
-            try:
-                plot = self.plots[field]
-            except KeyError:
-                plot = PlotMPL(self.figure_size, axrect, None, None)
-                self.plots[field] = plot
+            plot = self._get_plot_instance(field)
 
             x, y = self.ds.coordinates.pixelize_line(
                 field, self.start_point, self.end_point, self.npoints)
 
-            if self._x_unit is None:
-                unit_x = x.units
-            else:
-                unit_x = self._x_unit
-
-            if field in self._y_units:
-                unit_y = self._y_units[field]
-            else:
-                unit_y = y.units
-
-            x = x.to(unit_x)
-            y = y.to(unit_y)
-
-            plot.axes.plot(x, y, label=self.labels[field])
-
-            if self._field_transform[field] != linear_transform:
-                if (y < 0).any():
-                    plot.axes.set_yscale('symlog')
-                else:
-                    plot.axes.set_yscale('log')
-
-            plot._set_font_properties(self._font_properties, None)
+            self._plot_xy(field, plot, x, y, dimensions_counter)
 
-            axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
-
-            finfo = self.ds.field_info[field]
-
-            x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
-
-            finfo = self.ds.field_info[field]
-            dimensions = Unit(finfo.units,
-                              registry=self.ds.unit_registry).dimensions
-            dimensions_counter[dimensions] += 1
-            if dimensions_counter[dimensions] > 1:
-                y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
-                           axes_unit_labels[1]+'}$')
-            else:
-                y_label = (finfo.get_latex_display_name() + r'$\rm{' +
-                           axes_unit_labels[1]+'}$')
-
-            plot.axes.set_xlabel(x_label)
-            plot.axes.set_ylabel(y_label)
-
-            if field in self._titles:
-                plot.axes.set_title(self._titles[field])
-
-            if self.include_legend[field]:
-                plot.axes.legend()
 
     @invalidate_plot
     def set_x_unit(self, unit_name):


https://bitbucket.org/yt_analysis/yt/commits/f6e0afb30183/
Changeset:   f6e0afb30183
User:        Alex Lindsay
Date:        2017-07-25 22:00:34+00:00
Summary:     Finish proof of concept demo. Need to clean up labeling. (#1505)
Affected #:  3 files

diff -r f05eb312ccef2c51e578de80517dc056c82c0115 -r f6e0afb30183587e21b5953290e304ad64eb1719 yt/__init__.py
--- a/yt/__init__.py
+++ b/yt/__init__.py
@@ -104,7 +104,7 @@
     write_bitmap, write_image, \
     apply_colormap, scale_image, write_projection, \
     SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, LinePlot, \
-    ProjectionPlot, OffAxisProjectionPlot, \
+    LineObject, ProjectionPlot, OffAxisProjectionPlot, \
     show_colormaps, add_cmap, make_colormap, \
     ProfilePlot, PhasePlot, ParticlePhasePlot, \
     ParticleProjectionPlot, ParticleImageBuffer, ParticlePlot, \

diff -r f05eb312ccef2c51e578de80517dc056c82c0115 -r f6e0afb30183587e21b5953290e304ad64eb1719 yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -55,7 +55,8 @@
     plot_2d
 
 from .line_plot import \
-    LinePlot
+    LinePlot, \
+    LineObject
 
 from .profile_plotter import \
     ProfilePlot, \

diff -r f05eb312ccef2c51e578de80517dc056c82c0115 -r f6e0afb30183587e21b5953290e304ad64eb1719 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -31,6 +31,12 @@
     linear_transform, \
     invalidate_plot
 
+class LineObject(object):
+    def __init__(self, start_point, end_point, ds, label=None):
+        self.start_point = _validate_point(start_point, ds, start=True)
+        self.end_point = _validate_point(end_point, ds)
+        self.label = label
+
 class LinePlotDictionary(PlotDictionary):
     def __init__(self, data_source):
         super(LinePlotDictionary, self).__init__(data_source)
@@ -150,15 +156,9 @@
             else:
                 obj._field_transform[f] = linear_transform
 
-
-    @invalidate_plot
-    def add_legend(self, field):
-        """Adds a legend to the `LinePlot` instance"""
-        self.include_legend[field] = True
-
     @classmethod
-    def from_lines(cls, ds, fields, start_points, end_points, npoints,
-                   figure_size=5., font_size=14., labels=None):
+    def from_lines(cls, ds, fields, lines, npoints,
+                   figure_size=5., font_size=14.):
         """
         A class method for constructing a line plot from multiple sampling lines
 
@@ -170,12 +170,6 @@
             simulation output to be plotted.
         fields : string / tuple, or list of strings / tuples
             The name(s) of the field(s) to be plotted.
-        start_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
-            Each element of the outer iterable contains the coordinates of a starting
-            point for constructing a line.
-        end_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
-            Each element of the outer iterable contains the coordinates of an ending
-            point for constructing a line.
         npoints : int
             How many points to sample between start_point and end_point for
             constructing the line plot
@@ -185,12 +179,20 @@
         fontsize : int
             Font size for all text in the plot.
             Default: 14
-        labels : dictionary
-            Keys should be the field names. Values should be latex-formattable
-            strings used in the LinePlot legend
-            Default: None
         """
-        return 0
+        obj = cls.__new__(cls)
+        cls._initialize_instance(obj, ds, fields, npoints, figure_size, font_size)
+
+        dimensions_counter = defaultdict(int)
+        for field in obj.fields:
+            plot = obj._get_plot_instance(field)
+            for line in lines:
+                x, y = obj.ds.coordinates.pixelize_line(
+                    field, line.start_point, line.end_point, npoints)
+                obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=line.label)
+            plot.axes.legend()
+        obj._plot_valid = True
+        return obj
 
     def _get_plot_instance(self, field):
         fontscale = self._font_properties._size / 14.
@@ -228,7 +230,7 @@
             self.plots[field] = plot
         return plot
 
-    def _plot_xy(self, field, plot, x, y, dimensions_counter):
+    def _plot_xy(self, field, plot, x, y, dimensions_counter, legend_label=None):
         if self._x_unit is None:
             unit_x = x.units
         else:
@@ -242,7 +244,7 @@
         x = x.to(unit_x)
         y = y.to(unit_y)
 
-        plot.axes.plot(x, y, label=self.labels[field])
+        plot.axes.plot(x, y, label=legend_label)
 
         if self._field_transform[field] != linear_transform:
             if (y < 0).any():
@@ -275,9 +277,6 @@
         if field in self._titles:
             plot.axes.set_title(self._titles[field])
 
-        if self.include_legend[field]:
-            plot.axes.legend()
-
 
     def _setup_plots(self):
         if self._plot_valid is True:
@@ -291,10 +290,19 @@
             x, y = self.ds.coordinates.pixelize_line(
                 field, self.start_point, self.end_point, self.npoints)
 
-            self._plot_xy(field, plot, x, y, dimensions_counter)
+            self._plot_xy(field, plot, x, y, dimensions_counter,
+                          legend_label=self.labels[field])
+
+            if self.include_legend[field]:
+                plot.axes.legend()
 
 
     @invalidate_plot
+    def add_field_legend(self, field):
+        """Adds a legend to the `LinePlot` instance"""
+        self.include_legend[field] = True
+
+    @invalidate_plot
     def set_x_unit(self, unit_name):
         """Set the unit to use along the x-axis
 


https://bitbucket.org/yt_analysis/yt/commits/94be1d31896c/
Changeset:   94be1d31896c
User:        Alex Lindsay
Date:        2017-07-25 22:49:46+00:00
Summary:     Make labels work as desired (#1505)
Affected #:  1 file

diff -r f6e0afb30183587e21b5953290e304ad64eb1719 -r 94be1d31896c83ad47bafee84815b9d925347b02 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -183,14 +183,27 @@
         obj = cls.__new__(cls)
         cls._initialize_instance(obj, ds, fields, npoints, figure_size, font_size)
 
-        dimensions_counter = defaultdict(int)
-        for field in obj.fields:
-            plot = obj._get_plot_instance(field)
-            for line in lines:
+        for line in lines:
+            dimensions_counter = defaultdict(int)
+            for field in obj.fields:
+                finfo = obj.ds.field_info[field]
+                dimensions = Unit(finfo.units,
+                                  registry=obj.ds.unit_registry).dimensions
+                dimensions_counter[dimensions] += 1
+
+            for field in obj.fields:
+                plot = obj._get_plot_instance(field)
                 x, y = obj.ds.coordinates.pixelize_line(
                     field, line.start_point, line.end_point, npoints)
-                obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=line.label)
-            plot.axes.legend()
+                finfo = obj.ds.field_info[field]
+                dimensions = Unit(finfo.units,
+                                  registry=obj.ds.unit_registry).dimensions
+                if dimensions_counter[dimensions] > 1:
+                    legend_label = r"$%s;$ %s" % (line.label, finfo.get_latex_display_name())
+                else:
+                    legend_label = r"$%s$" % line.label
+                obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=legend_label)
+                plot.axes.legend()
         obj._plot_valid = True
         return obj
 
@@ -263,7 +276,6 @@
         finfo = self.ds.field_info[field]
         dimensions = Unit(finfo.units,
                           registry=self.ds.unit_registry).dimensions
-        dimensions_counter[dimensions] += 1
         if dimensions_counter[dimensions] > 1:
             y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
                        axes_unit_labels[1]+'}$')
@@ -285,6 +297,11 @@
             plot.axes.cla()
         dimensions_counter = defaultdict(int)
         for field in self.fields:
+            finfo = self.ds.field_info[field]
+            dimensions = Unit(finfo.units,
+                              registry=self.ds.unit_registry).dimensions
+            dimensions_counter[dimensions] += 1
+        for field in self.fields:
             plot = self._get_plot_instance(field)
 
             x, y = self.ds.coordinates.pixelize_line(


https://bitbucket.org/yt_analysis/yt/commits/6b2f2727cfe4/
Changeset:   6b2f2727cfe4
User:        Alex Lindsay
Date:        2017-07-26 18:49:50+00:00
Summary:     remove LineObject and create LineBuffer (#1505)
Affected #:  3 files

diff -r 94be1d31896c83ad47bafee84815b9d925347b02 -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 yt/__init__.py
--- a/yt/__init__.py
+++ b/yt/__init__.py
@@ -104,7 +104,7 @@
     write_bitmap, write_image, \
     apply_colormap, scale_image, write_projection, \
     SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, LinePlot, \
-    LineObject, ProjectionPlot, OffAxisProjectionPlot, \
+    LineBuffer, ProjectionPlot, OffAxisProjectionPlot, \
     show_colormaps, add_cmap, make_colormap, \
     ProfilePlot, PhasePlot, ParticlePhasePlot, \
     ParticleProjectionPlot, ParticleImageBuffer, ParticlePlot, \

diff -r 94be1d31896c83ad47bafee84815b9d925347b02 -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -56,7 +56,7 @@
 
 from .line_plot import \
     LinePlot, \
-    LineObject
+    LineBuffer
 
 from .profile_plotter import \
     ProfilePlot, \

diff -r 94be1d31896c83ad47bafee84815b9d925347b02 -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -17,7 +17,8 @@
 
 from collections import defaultdict
 from yt.funcs import \
-    iterable
+    iterable, \
+    mylog
 from yt.units.unit_object import \
     Unit
 from yt.units.yt_array import \
@@ -31,11 +32,35 @@
     linear_transform, \
     invalidate_plot
 
-class LineObject(object):
-    def __init__(self, start_point, end_point, ds, label=None):
+class LineBuffer(object):
+    def __init__(self, ds, start_point, end_point, npoints, label=None):
+        self.ds = ds
         self.start_point = _validate_point(start_point, ds, start=True)
         self.end_point = _validate_point(end_point, ds)
+        self.npoints = npoints
         self.label = label
+        self.data = {}
+
+    def keys(self):
+        return self.data.keys()
+
+    def __setitem__(self, item, val):
+        self.data[item] = val
+
+    def __getitem__(self, item):
+        if item in self.data: return self.data[item]
+        mylog.info("Making a line buffer with %d points of %s" % \
+            (self.npoints, item))
+        self.points, self.data[item] = self.ds.coordinates.pixelize_line(item,
+                                                               self.start_point,
+                                                               self.end_point,
+                                                               self.npoints)
+
+        return self.data[item]
+
+    def __delitem__(self, item):
+        del self.data[item]
+
 
 class LinePlotDictionary(PlotDictionary):
     def __init__(self, data_source):


https://bitbucket.org/yt_analysis/yt/commits/1ad615c21613/
Changeset:   1ad615c21613
User:        Alex Lindsay
Date:        2017-07-26 19:52:48+00:00
Summary:     Change to annotate_legend and add tests
Affected #:  4 files

diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 doc/source/cookbook/simple_1d_line_plot.py
--- a/doc/source/cookbook/simple_1d_line_plot.py
+++ b/doc/source/cookbook/simple_1d_line_plot.py
@@ -8,7 +8,7 @@
 plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], (0, 0, 0), (0, 1, 0), 1000)
 
 # Add a legend
-plot.add_legend(('all', 'v'))
+plot.annotate_legend(('all', 'v'))
 
 # Save the line plot
 plot.save()

diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 doc/source/visualizing/plots.rst
--- a/doc/source/visualizing/plots.rst
+++ b/doc/source/visualizing/plots.rst
@@ -208,7 +208,7 @@
 Plots of 2D Datasets
 ~~~~~~~~~~~~~~~~~~~~
 
-If you have a two-dimensional cartesian, cylindrical, or polar dataset, 
+If you have a two-dimensional cartesian, cylindrical, or polar dataset,
 :func:`~yt.visualization.plot_window.plot_2d` is a way to make a plot
 within the dataset's plane without having to specify the axis, which
 in this case is redundant. Otherwise, ``plot_2d`` accepts the same
@@ -1114,7 +1114,7 @@
 
 1. When instantiating the ``LinePlot``, pass a dictionary of
    labels with keys corresponding to the field names
-2. Call the ``LinePlot`` ``add_legend`` method
+2. Call the ``LinePlot`` ``annotate_legend`` method
 
 X- and Y- axis units can be set with ``set_x_unit`` and ``set_unit`` methods
 respectively. The below code snippet combines all the features we've discussed:
@@ -1126,7 +1126,7 @@
    ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
 
    plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
-   plot.add_legend('density')
+   plot.annotate_legend('density')
    plot.set_x_unit('cm')
    plot.set_unit('density', 'kg/cm**3')
    plot.save()
@@ -1136,7 +1136,7 @@
 quantities. E.g. if ``LinePlot`` receives two fields with units of "length/time"
 and a field with units of "temperature", two different figures will be created,
 one with plots of the "length/time" fields and another with the plot of the
-"temperature" field. It is only necessary to call ``add_legend``
+"temperature" field. It is only necessary to call ``annotate_legend``
 for one field of a multi-field plot to produce a legend containing all the
 labels passed in the initial construction of the ``LinePlot`` instance. Example:
 
@@ -1147,7 +1147,7 @@
    ds = yt.load("SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1)
    plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], [0, 0, 0], [0, 1, 0],
                       100, labels={('all', 'u') : r"v$_x$", ('all', 'v') : r"v$_y$"})
-   plot.add_legend(('all', 'u'))
+   plot.annotate_legend(('all', 'u'))
    plot.save()
 
 ``LinePlot`` is a bit different from yt ray objects which are data

diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -147,8 +147,9 @@
         """
         self.start_point = _validate_point(start_point, ds, start=True)
         self.end_point = _validate_point(end_point, ds)
+        self.npoints = npoints
 
-        self._initialize_instance(self, ds, fields, npoints, figure_size, fontsize)
+        self._initialize_instance(self, ds, fields, figure_size, fontsize)
 
         if labels is None:
             self.labels = {}
@@ -161,9 +162,7 @@
         self._setup_plots()
 
     @classmethod
-    def _initialize_instance(cls, obj, ds, fields, npoints, figure_size=5.,
-                             fontsize=14.):
-        obj.npoints = npoints
+    def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.):
         obj._x_unit = None
         obj._y_units = {}
         obj._titles = {}
@@ -182,8 +181,7 @@
                 obj._field_transform[f] = linear_transform
 
     @classmethod
-    def from_lines(cls, ds, fields, lines, npoints,
-                   figure_size=5., font_size=14.):
+    def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14.):
         """
         A class method for constructing a line plot from multiple sampling lines
 
@@ -195,9 +193,6 @@
             simulation output to be plotted.
         fields : string / tuple, or list of strings / tuples
             The name(s) of the field(s) to be plotted.
-        npoints : int
-            How many points to sample between start_point and end_point for
-            constructing the line plot
         figure_size : int or two-element iterable of ints
             Size in inches of the image.
             Default: 5 (5x5)
@@ -206,7 +201,7 @@
             Default: 14
         """
         obj = cls.__new__(cls)
-        cls._initialize_instance(obj, ds, fields, npoints, figure_size, font_size)
+        cls._initialize_instance(obj, ds, fields, figure_size, font_size)
 
         for line in lines:
             dimensions_counter = defaultdict(int)
@@ -219,7 +214,7 @@
             for field in obj.fields:
                 plot = obj._get_plot_instance(field)
                 x, y = obj.ds.coordinates.pixelize_line(
-                    field, line.start_point, line.end_point, npoints)
+                    field, line.start_point, line.end_point, line.npoints)
                 finfo = obj.ds.field_info[field]
                 dimensions = Unit(finfo.units,
                                   registry=obj.ds.unit_registry).dimensions
@@ -340,7 +335,7 @@
 
 
     @invalidate_plot
-    def add_field_legend(self, field):
+    def annotate_legend(self, field):
         """Adds a legend to the `LinePlot` instance"""
         self.include_legend[field] = True
 

diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 yt/visualization/tests/test_line_plots.py
--- a/yt/visualization/tests/test_line_plots.py
+++ b/yt/visualization/tests/test_line_plots.py
@@ -25,14 +25,12 @@
     from yt.config import ytcfg
     ytcfg["yt", "__withintesting"] = "True"
 
-def compare(ds, fields, point1, point2, resolution, test_prefix, decimals=12):
-    def line_plot(filename_prefix):
-        ln = yt.LinePlot(ds, fields, point1, point2, resolution)
-        image_file = ln.save(filename_prefix)
-        return image_file
+def compare(ds, plot, test_prefix, decimals=12):
+    def image_from_plot(filename_prefix):
+        return plot.save(filename_prefix)
 
-    line_plot.__name__ = "line_{}".format(test_prefix)
-    test = GenericImageTest(ds, line_plot, decimals)
+    image_from_plot.__name__ = "line_{}".format(test_prefix)
+    test = GenericImageTest(ds, image_from_plot, decimals)
     test.prefix = test_prefix
     return test
 
@@ -42,7 +40,18 @@
 def test_line_plot():
     ds = data_dir_load(tri2, kwargs={'step':-1})
     fields = [field for field in ds.field_list if field[0] == 'all']
-    yield compare(ds, fields, (0, 0, 0), (1, 1, 0), 1000, "answers_line_plot")
+    plot = yt.LinePlot(ds, fields, (0, 0, 0), (1, 1, 0), 1000)
+    yield compare(ds, plot, "answers_line_plot")
+
+@requires_ds(tri2)
+def test_multi_line_plot():
+    ds = data_dir_load(tri2, kwargs={'step':-1})
+    fields = [field for field in ds.field_list if field[0] == 'all']
+    lines = []
+    lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+    lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+    plot = yt.LinePlot.from_lines(ds, fields, lines)
+    yield compare(ds, plot, "answers_multi_line_plot")
 
 def test_line_plot_methods():
     # Perform I/O in safe place instead of yt main dir
@@ -53,7 +62,7 @@
     ds = fake_random_ds(32)
 
     plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
-    plot.add_legend('density')
+    plot.annotate_legend('density')
     plot.set_x_unit('cm')
     plot.set_unit('density', 'kg/cm**3')
     plot.save()
@@ -61,3 +70,12 @@
     os.chdir(curdir)
     # clean up
     shutil.rmtree(tmpdir)
+
+def test_line_buffer():
+    ds = fake_random_ds(32)
+    lb = yt.LineBuffer(ds, (0, 0, 0), (1, 1, 1), 512, label='diag')
+    density = lb['density']
+    lb['density'] = 0
+    vx = lb['velocity_x']
+    keys = lb.keys()
+    del lb['velocity_x']


https://bitbucket.org/yt_analysis/yt/commits/f752f79ec42d/
Changeset:   f752f79ec42d
User:        Alex Lindsay
Date:        2017-07-26 20:42:22+00:00
Summary:     Add documentation (#1505)
Affected #:  1 file

diff -r 1ad615c216135c41cea988e211097e34e3c56940 -r f752f79ec42d2b8b200b140b7463330c36ed1c14 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -33,6 +33,35 @@
     invalidate_plot
 
 class LineBuffer(object):
+    r"""
+    LineBuffer(ds, start_point, end_point, npoints, label = None)
+
+    This takes a data source and implements a protocol for generating a
+    'pixelized', fixed-resolution line buffer. In other words, LineBuffer
+    takes a starting point, ending point, and number of sampling points and
+    can subsequently generate YTArrays of field values along the sample points.
+
+    Parameters
+    ----------
+    ds : :class:`yt.data_objects.static_output.Dataset`
+        This is the dataset object holding the data that can be sampled by the
+        LineBuffer
+    start_point : n-element list, tuple, ndarray, or YTArray
+        Contains the coordinates of the first point for constructing the LineBuffer.
+        Must contain n elements where n is the dimensionality of the dataset.
+    end_point : n-element list, tuple, ndarray, or YTArray
+        Contains the coordinates of the last point for constructing the LineBuffer.
+        Must contain n elements where n is the dimensionality of the dataset.
+    npoints : int
+        How many points to sample between start_point and end_point
+
+    Examples
+    --------
+    >>> lb = yt.LineBuffer(ds, (.25, 0, 0), (.25, 1, 0), 100)
+    >>> lb[('all', 'u')].max()
+    0.11562424257143075 dimensionless
+
+    """
     def __init__(self, ds, start_point, end_point, npoints, label=None):
         self.ds = ds
         self.start_point = _validate_point(start_point, ds, start=True)
@@ -193,12 +222,25 @@
             simulation output to be plotted.
         fields : string / tuple, or list of strings / tuples
             The name(s) of the field(s) to be plotted.
+        lines : a list of :class:`yt.visualization.line_plot.LineBuffer`s
+            The lines from which to sample data
         figure_size : int or two-element iterable of ints
             Size in inches of the image.
             Default: 5 (5x5)
         fontsize : int
             Font size for all text in the plot.
             Default: 14
+
+        Example
+        --------
+        >>> ds = yt.load('SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e', step=-1)
+        >>> fields = [field for field in ds.field_list if field[0] == 'all']
+        >>> lines = []
+        >>> lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+        >>> lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+        >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
+        >>> plot.save()
+
         """
         obj = cls.__new__(cls)
         cls._initialize_instance(obj, ds, fields, figure_size, font_size)


https://bitbucket.org/yt_analysis/yt/commits/ad9d85b2438f/
Changeset:   ad9d85b2438f
User:        Alex Lindsay
Date:        2017-07-26 22:27:17+00:00
Summary:     Formulate so that LinePlots created from from_lines can also use annotate_legend, set_x_unit, etc. Also enables some code condensing (#1505)
Affected #:  1 file

diff -r f752f79ec42d2b8b200b140b7463330c36ed1c14 -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -147,7 +147,7 @@
     fontsize : int
         Font size for all text in the plot.
         Default: 14
-    labels : dictionary
+    field_labels : dictionary
         Keys should be the field names. Values should be latex-formattable
         strings used in the LinePlot legend
         Default: None
@@ -170,28 +170,19 @@
     _plot_type = 'line_plot'
 
     def __init__(self, ds, fields, start_point, end_point, npoints,
-                 figure_size=5., fontsize=14., labels=None):
+                 figure_size=5., fontsize=14., field_labels=None):
         """
         Sets up figure and axes
         """
-        self.start_point = _validate_point(start_point, ds, start=True)
-        self.end_point = _validate_point(end_point, ds)
-        self.npoints = npoints
-
-        self._initialize_instance(self, ds, fields, figure_size, fontsize)
-
-        if labels is None:
-            self.labels = {}
-        else:
-            self.labels = labels
-        for f in self.fields:
-            if f not in self.labels:
-                self.labels[f] = f[1]
-
+        line = LineBuffer(ds, start_point, end_point, npoints, label=None)
+        self.lines = [line]
+        self._initialize_instance(self, ds, fields, figure_size,
+                                  fontsize, field_labels)
         self._setup_plots()
 
     @classmethod
-    def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.):
+    def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.,
+                             field_labels=None):
         obj._x_unit = None
         obj._y_units = {}
         obj._titles = {}
@@ -209,8 +200,16 @@
             else:
                 obj._field_transform[f] = linear_transform
 
+        if field_labels is None:
+            obj.field_labels = {}
+        else:
+            obj.field_labels = field_labels
+        for f in obj.fields:
+            if f not in obj.field_labels:
+                obj.field_labels[f] = f[1]
+
     @classmethod
-    def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14.):
+    def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14., field_labels=None):
         """
         A class method for constructing a line plot from multiple sampling lines
 
@@ -230,6 +229,10 @@
         fontsize : int
             Font size for all text in the plot.
             Default: 14
+        field_labels : dictionary
+            Keys should be the field names. Values should be latex-formattable
+            strings used in the LinePlot legend
+            Default: None
 
         Example
         --------
@@ -243,30 +246,9 @@
 
         """
         obj = cls.__new__(cls)
-        cls._initialize_instance(obj, ds, fields, figure_size, font_size)
-
-        for line in lines:
-            dimensions_counter = defaultdict(int)
-            for field in obj.fields:
-                finfo = obj.ds.field_info[field]
-                dimensions = Unit(finfo.units,
-                                  registry=obj.ds.unit_registry).dimensions
-                dimensions_counter[dimensions] += 1
-
-            for field in obj.fields:
-                plot = obj._get_plot_instance(field)
-                x, y = obj.ds.coordinates.pixelize_line(
-                    field, line.start_point, line.end_point, line.npoints)
-                finfo = obj.ds.field_info[field]
-                dimensions = Unit(finfo.units,
-                                  registry=obj.ds.unit_registry).dimensions
-                if dimensions_counter[dimensions] > 1:
-                    legend_label = r"$%s;$ %s" % (line.label, finfo.get_latex_display_name())
-                else:
-                    legend_label = r"$%s$" % line.label
-                obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=legend_label)
-                plot.axes.legend()
-        obj._plot_valid = True
+        obj.lines = lines
+        cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
+        obj._setup_plots()
         return obj
 
     def _get_plot_instance(self, field):
@@ -305,75 +287,89 @@
             self.plots[field] = plot
         return plot
 
-    def _plot_xy(self, field, plot, x, y, dimensions_counter, legend_label=None):
-        if self._x_unit is None:
-            unit_x = x.units
-        else:
-            unit_x = self._x_unit
-
-        if field in self._y_units:
-            unit_y = self._y_units[field]
-        else:
-            unit_y = y.units
-
-        x = x.to(unit_x)
-        y = y.to(unit_y)
-
-        plot.axes.plot(x, y, label=legend_label)
-
-        if self._field_transform[field] != linear_transform:
-            if (y < 0).any():
-                plot.axes.set_yscale('symlog')
-            else:
-                plot.axes.set_yscale('log')
-
-        plot._set_font_properties(self._font_properties, None)
-
-        axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
-
-        finfo = self.ds.field_info[field]
-
-        x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
-
-        finfo = self.ds.field_info[field]
-        dimensions = Unit(finfo.units,
-                          registry=self.ds.unit_registry).dimensions
-        if dimensions_counter[dimensions] > 1:
-            y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
-                       axes_unit_labels[1]+'}$')
-        else:
-            y_label = (finfo.get_latex_display_name() + r'$\rm{' +
-                       axes_unit_labels[1]+'}$')
-
-        plot.axes.set_xlabel(x_label)
-        plot.axes.set_ylabel(y_label)
-
-        if field in self._titles:
-            plot.axes.set_title(self._titles[field])
-
-
     def _setup_plots(self):
         if self._plot_valid is True:
             return
         for plot in self.plots.values():
             plot.axes.cla()
-        dimensions_counter = defaultdict(int)
-        for field in self.fields:
-            finfo = self.ds.field_info[field]
-            dimensions = Unit(finfo.units,
-                              registry=self.ds.unit_registry).dimensions
-            dimensions_counter[dimensions] += 1
-        for field in self.fields:
-            plot = self._get_plot_instance(field)
+        for line in self.lines:
+            dimensions_counter = defaultdict(int)
+            for field in self.fields:
+                finfo = self.ds.field_info[field]
+                dimensions = Unit(finfo.units,
+                                  registry=self.ds.unit_registry).dimensions
+                dimensions_counter[dimensions] += 1
+            for field in self.fields:
+                # get plot instance
+                plot = self._get_plot_instance(field)
+
+                # calculate x and y
+                x, y = self.ds.coordinates.pixelize_line(
+                    field, line.start_point, line.end_point, line.npoints)
+
+                # scale x and y to proper units
+                if self._x_unit is None:
+                    unit_x = x.units
+                else:
+                    unit_x = self._x_unit
+
+                if field in self._y_units:
+                    unit_y = self._y_units[field]
+                else:
+                    unit_y = y.units
+
+                x = x.to(unit_x)
+                y = y.to(unit_y)
+
+                # determine legend label
+                str_seq = []
+                str_seq.append(line.label)
+                str_seq.append(self.field_labels[field])
+                delim = "; "
+                legend_label = delim.join(filter(None, str_seq))
+
+                # apply plot to matplotlib axes
+                plot.axes.plot(x, y, label=legend_label)
 
-            x, y = self.ds.coordinates.pixelize_line(
-                field, self.start_point, self.end_point, self.npoints)
+                # apply log transforms if requested
+                if self._field_transform[field] != linear_transform:
+                    if (y < 0).any():
+                        plot.axes.set_yscale('symlog')
+                    else:
+                        plot.axes.set_yscale('log')
+
+                # set font properties
+                plot._set_font_properties(self._font_properties, None)
+
+                # set x and y axis labels
+                axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
+
+                finfo = self.ds.field_info[field]
+
+                x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
 
-            self._plot_xy(field, plot, x, y, dimensions_counter,
-                          legend_label=self.labels[field])
+                finfo = self.ds.field_info[field]
+                dimensions = Unit(finfo.units,
+                                  registry=self.ds.unit_registry).dimensions
+                if dimensions_counter[dimensions] > 1:
+                    y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+                               axes_unit_labels[1]+'}$')
+                else:
+                    y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+                               axes_unit_labels[1]+'}$')
 
-            if self.include_legend[field]:
-                plot.axes.legend()
+                plot.axes.set_xlabel(x_label)
+                plot.axes.set_ylabel(y_label)
+
+                # apply title
+                if field in self._titles:
+                    plot.axes.set_title(self._titles[field])
+
+                # apply legend
+                if self.include_legend[field]:
+                    plot.axes.legend()
+
+                self._plot_valid = True
 
 
     @invalidate_plot


https://bitbucket.org/yt_analysis/yt/commits/eac50268ed01/
Changeset:   eac50268ed01
User:        Alex Lindsay
Date:        2017-07-27 13:58:16+00:00
Summary:     Increment answers. Use x and y labels. Fix pep8
Affected #:  3 files

diff -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab -r eac50268ed013076d933352ec7918c9a9962e032 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -67,12 +67,13 @@
     - yt/analysis_modules/photon_simulator/tests/test_spectra.py
     - yt/analysis_modules/photon_simulator/tests/test_sloshing.py
 
-  local_unstructured_008:
+  local_unstructured_009:
     - yt/visualization/volume_rendering/tests/test_mesh_render.py
     - yt/visualization/tests/test_mesh_slices.py:test_tri2
     - yt/visualization/tests/test_mesh_slices.py:test_quad2
     - yt/visualization/tests/test_mesh_slices.py:test_multi_region
     - yt/visualization/tests/test_line_plots.py:test_line_plot
+    - yt/visualization/tests/test_line_plots.py:test_multi_line_plot
 
   local_boxlib_004:
     - yt/frontends/boxlib/tests/test_outputs.py:test_radadvect

diff -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab -r eac50268ed013076d933352ec7918c9a9962e032 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -344,19 +344,23 @@
                 # set x and y axis labels
                 axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
 
-                finfo = self.ds.field_info[field]
-
-                x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
+                if self._xlabel is not None:
+                    x_label = self._xlabel
+                else:
+                    x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
 
-                finfo = self.ds.field_info[field]
-                dimensions = Unit(finfo.units,
-                                  registry=self.ds.unit_registry).dimensions
-                if dimensions_counter[dimensions] > 1:
-                    y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
-                               axes_unit_labels[1]+'}$')
+                if self._ylabel is not None:
+                    y_label = self._ylabel
                 else:
-                    y_label = (finfo.get_latex_display_name() + r'$\rm{' +
-                               axes_unit_labels[1]+'}$')
+                    finfo = self.ds.field_info[field]
+                    dimensions = Unit(finfo.units,
+                                      registry=self.ds.unit_registry).dimensions
+                    if dimensions_counter[dimensions] > 1:
+                        y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+                                   axes_unit_labels[1]+'}$')
+                    else:
+                        y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+                                   axes_unit_labels[1]+'}$')
 
                 plot.axes.set_xlabel(x_label)
                 plot.axes.set_ylabel(y_label)

diff -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab -r eac50268ed013076d933352ec7918c9a9962e032 yt/visualization/tests/test_line_plots.py
--- a/yt/visualization/tests/test_line_plots.py
+++ b/yt/visualization/tests/test_line_plots.py
@@ -74,8 +74,8 @@
 def test_line_buffer():
     ds = fake_random_ds(32)
     lb = yt.LineBuffer(ds, (0, 0, 0), (1, 1, 1), 512, label='diag')
-    density = lb['density']
+    lb['density']
     lb['density'] = 0
-    vx = lb['velocity_x']
-    keys = lb.keys()
+    lb['velocity_x']
+    lb.keys()
     del lb['velocity_x']


https://bitbucket.org/yt_analysis/yt/commits/50bddc8e9027/
Changeset:   50bddc8e9027
User:        Alex Lindsay
Date:        2017-07-28 16:37:52+00:00
Summary:     Make sure legend label is added for every field of multi-field plot (#1505)
Affected #:  1 file

diff -r eac50268ed013076d933352ec7918c9a9962e032 -r 50bddc8e90276a78ed2db4b48118b46a2105cae1 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -370,16 +370,22 @@
                     plot.axes.set_title(self._titles[field])
 
                 # apply legend
-                if self.include_legend[field]:
+                dim_field = self.plots._sanitize_dimensions(field)
+                if self.include_legend[dim_field]:
                     plot.axes.legend()
 
-                self._plot_valid = True
+        self._plot_valid = True
 
 
     @invalidate_plot
     def annotate_legend(self, field):
-        """Adds a legend to the `LinePlot` instance"""
-        self.include_legend[field] = True
+        """
+        Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
+        call ensures that a legend label will be added for every field of
+        a multi-field plot
+        """
+        dim_field = self.plots._sanitize_dimensions(field)
+        self.include_legend[dim_field] = True
 
     @invalidate_plot
     def set_x_unit(self, unit_name):


https://bitbucket.org/yt_analysis/yt/commits/eb258efcbe83/
Changeset:   eb258efcbe83
User:        ngoldbaum
Date:        2017-07-31 15:37:07+00:00
Summary:     Merge pull request #1509 from lindsayad/multiple_lines

Allow multiple sampling lines for LinePlot
Affected #:  7 files

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 doc/source/cookbook/simple_1d_line_plot.py
--- a/doc/source/cookbook/simple_1d_line_plot.py
+++ b/doc/source/cookbook/simple_1d_line_plot.py
@@ -8,7 +8,7 @@
 plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], (0, 0, 0), (0, 1, 0), 1000)
 
 # Add a legend
-plot.add_legend(('all', 'v'))
+plot.annotate_legend(('all', 'v'))
 
 # Save the line plot
 plot.save()

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 doc/source/visualizing/plots.rst
--- a/doc/source/visualizing/plots.rst
+++ b/doc/source/visualizing/plots.rst
@@ -208,7 +208,7 @@
 Plots of 2D Datasets
 ~~~~~~~~~~~~~~~~~~~~
 
-If you have a two-dimensional cartesian, cylindrical, or polar dataset, 
+If you have a two-dimensional cartesian, cylindrical, or polar dataset,
 :func:`~yt.visualization.plot_window.plot_2d` is a way to make a plot
 within the dataset's plane without having to specify the axis, which
 in this case is redundant. Otherwise, ``plot_2d`` accepts the same
@@ -1114,7 +1114,7 @@
 
 1. When instantiating the ``LinePlot``, pass a dictionary of
    labels with keys corresponding to the field names
-2. Call the ``LinePlot`` ``add_legend`` method
+2. Call the ``LinePlot`` ``annotate_legend`` method
 
 X- and Y- axis units can be set with ``set_x_unit`` and ``set_unit`` methods
 respectively. The below code snippet combines all the features we've discussed:
@@ -1126,7 +1126,7 @@
    ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
 
    plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
-   plot.add_legend('density')
+   plot.annotate_legend('density')
    plot.set_x_unit('cm')
    plot.set_unit('density', 'kg/cm**3')
    plot.save()
@@ -1136,7 +1136,7 @@
 quantities. E.g. if ``LinePlot`` receives two fields with units of "length/time"
 and a field with units of "temperature", two different figures will be created,
 one with plots of the "length/time" fields and another with the plot of the
-"temperature" field. It is only necessary to call ``add_legend``
+"temperature" field. It is only necessary to call ``annotate_legend``
 for one field of a multi-field plot to produce a legend containing all the
 labels passed in the initial construction of the ``LinePlot`` instance. Example:
 
@@ -1147,7 +1147,7 @@
    ds = yt.load("SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1)
    plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], [0, 0, 0], [0, 1, 0],
                       100, labels={('all', 'u') : r"v$_x$", ('all', 'v') : r"v$_y$"})
-   plot.add_legend(('all', 'u'))
+   plot.annotate_legend(('all', 'u'))
    plot.save()
 
 ``LinePlot`` is a bit different from yt ray objects which are data

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -67,12 +67,13 @@
     - yt/analysis_modules/photon_simulator/tests/test_spectra.py
     - yt/analysis_modules/photon_simulator/tests/test_sloshing.py
 
-  local_unstructured_008:
+  local_unstructured_009:
     - yt/visualization/volume_rendering/tests/test_mesh_render.py
     - yt/visualization/tests/test_mesh_slices.py:test_tri2
     - yt/visualization/tests/test_mesh_slices.py:test_quad2
     - yt/visualization/tests/test_mesh_slices.py:test_multi_region
     - yt/visualization/tests/test_line_plots.py:test_line_plot
+    - yt/visualization/tests/test_line_plots.py:test_multi_line_plot
 
   local_boxlib_004:
     - yt/frontends/boxlib/tests/test_outputs.py:test_radadvect

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/__init__.py
--- a/yt/__init__.py
+++ b/yt/__init__.py
@@ -104,7 +104,7 @@
     write_bitmap, write_image, \
     apply_colormap, scale_image, write_projection, \
     SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, LinePlot, \
-    ProjectionPlot, OffAxisProjectionPlot, \
+    LineBuffer, ProjectionPlot, OffAxisProjectionPlot, \
     show_colormaps, add_cmap, make_colormap, \
     ProfilePlot, PhasePlot, ParticlePhasePlot, \
     ParticleProjectionPlot, ParticleImageBuffer, ParticlePlot, \

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -55,7 +55,8 @@
     plot_2d
 
 from .line_plot import \
-    LinePlot
+    LinePlot, \
+    LineBuffer
 
 from .profile_plotter import \
     ProfilePlot, \

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -17,7 +17,8 @@
 
 from collections import defaultdict
 from yt.funcs import \
-    iterable
+    iterable, \
+    mylog
 from yt.units.unit_object import \
     Unit
 from yt.units.yt_array import \
@@ -31,6 +32,65 @@
     linear_transform, \
     invalidate_plot
 
+class LineBuffer(object):
+    r"""
+    LineBuffer(ds, start_point, end_point, npoints, label = None)
+
+    This takes a data source and implements a protocol for generating a
+    'pixelized', fixed-resolution line buffer. In other words, LineBuffer
+    takes a starting point, ending point, and number of sampling points and
+    can subsequently generate YTArrays of field values along the sample points.
+
+    Parameters
+    ----------
+    ds : :class:`yt.data_objects.static_output.Dataset`
+        This is the dataset object holding the data that can be sampled by the
+        LineBuffer
+    start_point : n-element list, tuple, ndarray, or YTArray
+        Contains the coordinates of the first point for constructing the LineBuffer.
+        Must contain n elements where n is the dimensionality of the dataset.
+    end_point : n-element list, tuple, ndarray, or YTArray
+        Contains the coordinates of the last point for constructing the LineBuffer.
+        Must contain n elements where n is the dimensionality of the dataset.
+    npoints : int
+        How many points to sample between start_point and end_point
+
+    Examples
+    --------
+    >>> lb = yt.LineBuffer(ds, (.25, 0, 0), (.25, 1, 0), 100)
+    >>> lb[('all', 'u')].max()
+    0.11562424257143075 dimensionless
+
+    """
+    def __init__(self, ds, start_point, end_point, npoints, label=None):
+        self.ds = ds
+        self.start_point = _validate_point(start_point, ds, start=True)
+        self.end_point = _validate_point(end_point, ds)
+        self.npoints = npoints
+        self.label = label
+        self.data = {}
+
+    def keys(self):
+        return self.data.keys()
+
+    def __setitem__(self, item, val):
+        self.data[item] = val
+
+    def __getitem__(self, item):
+        if item in self.data: return self.data[item]
+        mylog.info("Making a line buffer with %d points of %s" % \
+            (self.npoints, item))
+        self.points, self.data[item] = self.ds.coordinates.pixelize_line(item,
+                                                               self.start_point,
+                                                               self.end_point,
+                                                               self.npoints)
+
+        return self.data[item]
+
+    def __delitem__(self, item):
+        del self.data[item]
+
+
 class LinePlotDictionary(PlotDictionary):
     def __init__(self, data_source):
         super(LinePlotDictionary, self).__init__(data_source)
@@ -87,7 +147,7 @@
     fontsize : int
         Font size for all text in the plot.
         Default: 14
-    labels : dictionary
+    field_labels : dictionary
         Keys should be the field names. Values should be latex-formattable
         strings used in the LinePlot legend
         Default: None
@@ -110,137 +170,222 @@
     _plot_type = 'line_plot'
 
     def __init__(self, ds, fields, start_point, end_point, npoints,
-                 figure_size=5., fontsize=14., labels=None):
+                 figure_size=5., fontsize=14., field_labels=None):
         """
         Sets up figure and axes
         """
-        self.start_point = _validate_point(start_point, ds, start=True)
-        self.end_point = _validate_point(end_point, ds)
-        self.npoints = npoints
-        self._x_unit = None
-        self._y_units = {}
-        self._titles = {}
+        line = LineBuffer(ds, start_point, end_point, npoints, label=None)
+        self.lines = [line]
+        self._initialize_instance(self, ds, fields, figure_size,
+                                  fontsize, field_labels)
+        self._setup_plots()
+
+    @classmethod
+    def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.,
+                             field_labels=None):
+        obj._x_unit = None
+        obj._y_units = {}
+        obj._titles = {}
 
         data_source = ds.all_data()
 
-        self.fields = data_source._determine_fields(fields)
-        self.plots = LinePlotDictionary(data_source)
-        self.include_legend = defaultdict(bool)
-        if labels is None:
-            self.labels = {}
+        obj.fields = data_source._determine_fields(fields)
+        obj.plots = LinePlotDictionary(data_source)
+        obj.include_legend = defaultdict(bool)
+        super(LinePlot, obj).__init__(data_source, figure_size, fontsize)
+        for f in obj.fields:
+            finfo = obj.data_source.ds._get_field_info(*f)
+            if finfo.take_log:
+                obj._field_transform[f] = log_transform
+            else:
+                obj._field_transform[f] = linear_transform
+
+        if field_labels is None:
+            obj.field_labels = {}
         else:
-            self.labels = labels
+            obj.field_labels = field_labels
+        for f in obj.fields:
+            if f not in obj.field_labels:
+                obj.field_labels[f] = f[1]
+
+    @classmethod
+    def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14., field_labels=None):
+        """
+        A class method for constructing a line plot from multiple sampling lines
+
+        Parameters
+        ----------
 
-        super(LinePlot, self).__init__(data_source, figure_size, fontsize)
+        ds : :class:`yt.data_objects.static_output.Dataset`
+            This is the dataset object corresponding to the
+            simulation output to be plotted.
+        fields : string / tuple, or list of strings / tuples
+            The name(s) of the field(s) to be plotted.
+        lines : a list of :class:`yt.visualization.line_plot.LineBuffer`s
+            The lines from which to sample data
+        figure_size : int or two-element iterable of ints
+            Size in inches of the image.
+            Default: 5 (5x5)
+        font_size : int
+            Font size for all text in the plot.
+            Default: 14
+        field_labels : dictionary
+            Keys should be the field names. Values should be latex-formattable
+            strings used in the LinePlot legend
+            Default: None
 
-        for f in self.fields:
-            if f not in self.labels:
-                self.labels[f] = f[1]
-            finfo = self.data_source.ds._get_field_info(*f)
-            if finfo.take_log:
-                self._field_transform[f] = log_transform
-            else:
-                self._field_transform[f] = linear_transform
+        Example
+        --------
+        >>> ds = yt.load('SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e', step=-1)
+        >>> fields = [field for field in ds.field_list if field[0] == 'all']
+        >>> lines = []
+        >>> lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+        >>> lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+        >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
+        >>> plot.save()
+
+        """
+        obj = cls.__new__(cls)
+        obj.lines = lines
+        cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
+        obj._setup_plots()
+        return obj
+
+    def _get_plot_instance(self, field):
+        fontscale = self._font_properties._size / 14.
+        top_buff_size = 0.35*fontscale
+
+        x_axis_size = 1.35*fontscale
+        y_axis_size = 0.7*fontscale
+        right_buff_size = 0.2*fontscale
 
-        self._setup_plots()
+        if iterable(self.figure_size):
+            figure_size = self.figure_size
+        else:
+            figure_size = (self.figure_size, self.figure_size)
+
+        xbins = np.array([x_axis_size, figure_size[0],
+                          right_buff_size])
+        ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
+
+        size = [xbins.sum(), ybins.sum()]
+
+        x_frac_widths = xbins/size[0]
+        y_frac_widths = ybins/size[1]
 
-    @invalidate_plot
-    def add_legend(self, field):
-        """Adds a legend to the `LinePlot` instance"""
-        self.include_legend[field] = True
+        axrect = (
+            x_frac_widths[0],
+            y_frac_widths[0],
+            x_frac_widths[1],
+            y_frac_widths[1],
+        )
+
+        try:
+            plot = self.plots[field]
+        except KeyError:
+            plot = PlotMPL(self.figure_size, axrect, None, None)
+            self.plots[field] = plot
+        return plot
 
     def _setup_plots(self):
         if self._plot_valid is True:
             return
         for plot in self.plots.values():
             plot.axes.cla()
-        dimensions_counter = defaultdict(int)
-        for field in self.fields:
-            fontscale = self._font_properties._size / 14.
-            top_buff_size = 0.35*fontscale
-
-            x_axis_size = 1.35*fontscale
-            y_axis_size = 0.7*fontscale
-            right_buff_size = 0.2*fontscale
+        for line in self.lines:
+            dimensions_counter = defaultdict(int)
+            for field in self.fields:
+                finfo = self.ds.field_info[field]
+                dimensions = Unit(finfo.units,
+                                  registry=self.ds.unit_registry).dimensions
+                dimensions_counter[dimensions] += 1
+            for field in self.fields:
+                # get plot instance
+                plot = self._get_plot_instance(field)
 
-            if iterable(self.figure_size):
-                figure_size = self.figure_size
-            else:
-                figure_size = (self.figure_size, self.figure_size)
+                # calculate x and y
+                x, y = self.ds.coordinates.pixelize_line(
+                    field, line.start_point, line.end_point, line.npoints)
 
-            xbins = np.array([x_axis_size, figure_size[0],
-                              right_buff_size])
-            ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
-
-            size = [xbins.sum(), ybins.sum()]
+                # scale x and y to proper units
+                if self._x_unit is None:
+                    unit_x = x.units
+                else:
+                    unit_x = self._x_unit
 
-            x_frac_widths = xbins/size[0]
-            y_frac_widths = ybins/size[1]
+                if field in self._y_units:
+                    unit_y = self._y_units[field]
+                else:
+                    unit_y = y.units
 
-            axrect = (
-                x_frac_widths[0],
-                y_frac_widths[0],
-                x_frac_widths[1],
-                y_frac_widths[1],
-            )
+                x = x.to(unit_x)
+                y = y.to(unit_y)
 
-            try:
-                plot = self.plots[field]
-            except KeyError:
-                plot = PlotMPL(self.figure_size, axrect, None, None)
-                self.plots[field] = plot
+                # determine legend label
+                str_seq = []
+                str_seq.append(line.label)
+                str_seq.append(self.field_labels[field])
+                delim = "; "
+                legend_label = delim.join(filter(None, str_seq))
 
-            x, y = self.ds.coordinates.pixelize_line(
-                field, self.start_point, self.end_point, self.npoints)
+                # apply plot to matplotlib axes
+                plot.axes.plot(x, y, label=legend_label)
 
-            if self._x_unit is None:
-                unit_x = x.units
-            else:
-                unit_x = self._x_unit
+                # apply log transforms if requested
+                if self._field_transform[field] != linear_transform:
+                    if (y < 0).any():
+                        plot.axes.set_yscale('symlog')
+                    else:
+                        plot.axes.set_yscale('log')
 
-            if field in self._y_units:
-                unit_y = self._y_units[field]
-            else:
-                unit_y = y.units
+                # set font properties
+                plot._set_font_properties(self._font_properties, None)
 
-            x = x.to(unit_x)
-            y = y.to(unit_y)
+                # set x and y axis labels
+                axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
 
-            plot.axes.plot(x, y, label=self.labels[field])
+                if self._xlabel is not None:
+                    x_label = self._xlabel
+                else:
+                    x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
 
-            if self._field_transform[field] != linear_transform:
-                if (y < 0).any():
-                    plot.axes.set_yscale('symlog')
+                if self._ylabel is not None:
+                    y_label = self._ylabel
                 else:
-                    plot.axes.set_yscale('log')
-
-            plot._set_font_properties(self._font_properties, None)
-
-            axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
-
-            finfo = self.ds.field_info[field]
+                    finfo = self.ds.field_info[field]
+                    dimensions = Unit(finfo.units,
+                                      registry=self.ds.unit_registry).dimensions
+                    if dimensions_counter[dimensions] > 1:
+                        y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+                                   axes_unit_labels[1]+'}$')
+                    else:
+                        y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+                                   axes_unit_labels[1]+'}$')
 
-            x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
+                plot.axes.set_xlabel(x_label)
+                plot.axes.set_ylabel(y_label)
+
+                # apply title
+                if field in self._titles:
+                    plot.axes.set_title(self._titles[field])
+
+                # apply legend
+                dim_field = self.plots._sanitize_dimensions(field)
+                if self.include_legend[dim_field]:
+                    plot.axes.legend()
 
-            finfo = self.ds.field_info[field]
-            dimensions = Unit(finfo.units,
-                              registry=self.ds.unit_registry).dimensions
-            dimensions_counter[dimensions] += 1
-            if dimensions_counter[dimensions] > 1:
-                y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
-                           axes_unit_labels[1]+'}$')
-            else:
-                y_label = (finfo.get_latex_display_name() + r'$\rm{' +
-                           axes_unit_labels[1]+'}$')
+        self._plot_valid = True
+
 
-            plot.axes.set_xlabel(x_label)
-            plot.axes.set_ylabel(y_label)
-
-            if field in self._titles:
-                plot.axes.set_title(self._titles[field])
-
-            if self.include_legend[field]:
-                plot.axes.legend()
+    @invalidate_plot
+    def annotate_legend(self, field):
+        """
+        Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
+        call ensures that a legend label will be added for every field of
+        a multi-field plot
+        """
+        dim_field = self.plots._sanitize_dimensions(field)
+        self.include_legend[dim_field] = True
 
     @invalidate_plot
     def set_x_unit(self, unit_name):

diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/visualization/tests/test_line_plots.py
--- a/yt/visualization/tests/test_line_plots.py
+++ b/yt/visualization/tests/test_line_plots.py
@@ -25,14 +25,12 @@
     from yt.config import ytcfg
     ytcfg["yt", "__withintesting"] = "True"
 
-def compare(ds, fields, point1, point2, resolution, test_prefix, decimals=12):
-    def line_plot(filename_prefix):
-        ln = yt.LinePlot(ds, fields, point1, point2, resolution)
-        image_file = ln.save(filename_prefix)
-        return image_file
+def compare(ds, plot, test_prefix, decimals=12):
+    def image_from_plot(filename_prefix):
+        return plot.save(filename_prefix)
 
-    line_plot.__name__ = "line_{}".format(test_prefix)
-    test = GenericImageTest(ds, line_plot, decimals)
+    image_from_plot.__name__ = "line_{}".format(test_prefix)
+    test = GenericImageTest(ds, image_from_plot, decimals)
     test.prefix = test_prefix
     return test
 
@@ -42,7 +40,18 @@
 def test_line_plot():
     ds = data_dir_load(tri2, kwargs={'step':-1})
     fields = [field for field in ds.field_list if field[0] == 'all']
-    yield compare(ds, fields, (0, 0, 0), (1, 1, 0), 1000, "answers_line_plot")
+    plot = yt.LinePlot(ds, fields, (0, 0, 0), (1, 1, 0), 1000)
+    yield compare(ds, plot, "answers_line_plot")
+
+@requires_ds(tri2)
+def test_multi_line_plot():
+    ds = data_dir_load(tri2, kwargs={'step':-1})
+    fields = [field for field in ds.field_list if field[0] == 'all']
+    lines = []
+    lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+    lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+    plot = yt.LinePlot.from_lines(ds, fields, lines)
+    yield compare(ds, plot, "answers_multi_line_plot")
 
 def test_line_plot_methods():
     # Perform I/O in safe place instead of yt main dir
@@ -53,7 +62,7 @@
     ds = fake_random_ds(32)
 
     plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
-    plot.add_legend('density')
+    plot.annotate_legend('density')
     plot.set_x_unit('cm')
     plot.set_unit('density', 'kg/cm**3')
     plot.save()
@@ -61,3 +70,12 @@
     os.chdir(curdir)
     # clean up
     shutil.rmtree(tmpdir)
+
+def test_line_buffer():
+    ds = fake_random_ds(32)
+    lb = yt.LineBuffer(ds, (0, 0, 0), (1, 1, 1), 512, label='diag')
+    lb['density']
+    lb['density'] = 0
+    lb['velocity_x']
+    lb.keys()
+    del lb['velocity_x']

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.



More information about the yt-svn mailing list