[yt-svn] commit/yt: 10 new changesets
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Mon Jul 31 08:37:22 PDT 2017
10 new commits in yt:
https://bitbucket.org/yt_analysis/yt/commits/f05eb312ccef/
Changeset: f05eb312ccef
User: Alex Lindsay
Date: 2017-07-25 15:59:03+00:00
Summary: Atomize functions (#1505)
Affected #: 1 file
diff -r 85cb5b40cd70e9772c4f581e8133692d5f73ed9e -r f05eb312ccef2c51e578de80517dc056c82c0115 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -116,39 +116,169 @@
"""
self.start_point = _validate_point(start_point, ds, start=True)
self.end_point = _validate_point(end_point, ds)
- self.npoints = npoints
- self._x_unit = None
- self._y_units = {}
- self._titles = {}
- data_source = ds.all_data()
+ self._initialize_instance(self, ds, fields, npoints, figure_size, fontsize)
- self.fields = data_source._determine_fields(fields)
- self.plots = LinePlotDictionary(data_source)
- self.include_legend = defaultdict(bool)
if labels is None:
self.labels = {}
else:
self.labels = labels
-
- super(LinePlot, self).__init__(data_source, figure_size, fontsize)
-
for f in self.fields:
if f not in self.labels:
self.labels[f] = f[1]
- finfo = self.data_source.ds._get_field_info(*f)
- if finfo.take_log:
- self._field_transform[f] = log_transform
- else:
- self._field_transform[f] = linear_transform
self._setup_plots()
+ @classmethod
+ def _initialize_instance(cls, obj, ds, fields, npoints, figure_size=5.,
+ fontsize=14.):
+ obj.npoints = npoints
+ obj._x_unit = None
+ obj._y_units = {}
+ obj._titles = {}
+
+ data_source = ds.all_data()
+
+ obj.fields = data_source._determine_fields(fields)
+ obj.plots = LinePlotDictionary(data_source)
+ obj.include_legend = defaultdict(bool)
+ super(LinePlot, obj).__init__(data_source, figure_size, fontsize)
+ for f in obj.fields:
+ finfo = obj.data_source.ds._get_field_info(*f)
+ if finfo.take_log:
+ obj._field_transform[f] = log_transform
+ else:
+ obj._field_transform[f] = linear_transform
+
+
@invalidate_plot
def add_legend(self, field):
"""Adds a legend to the `LinePlot` instance"""
self.include_legend[field] = True
+ @classmethod
+ def from_lines(cls, ds, fields, start_points, end_points, npoints,
+ figure_size=5., font_size=14., labels=None):
+ """
+ A class method for constructing a line plot from multiple sampling lines
+
+ Parameters
+ ----------
+
+ ds : :class:`yt.data_objects.static_output.Dataset`
+ This is the dataset object corresponding to the
+ simulation output to be plotted.
+ fields : string / tuple, or list of strings / tuples
+ The name(s) of the field(s) to be plotted.
+ start_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
+ Each element of the outer iterable contains the coordinates of a starting
+ point for constructing a line.
+ end_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
+ Each element of the outer iterable contains the coordinates of an ending
+ point for constructing a line.
+ npoints : int
+ How many points to sample between start_point and end_point for
+ constructing the line plot
+ figure_size : int or two-element iterable of ints
+ Size in inches of the image.
+ Default: 5 (5x5)
+ fontsize : int
+ Font size for all text in the plot.
+ Default: 14
+ labels : dictionary
+ Keys should be the field names. Values should be latex-formattable
+ strings used in the LinePlot legend
+ Default: None
+ """
+ return 0
+
+ def _get_plot_instance(self, field):
+ fontscale = self._font_properties._size / 14.
+ top_buff_size = 0.35*fontscale
+
+ x_axis_size = 1.35*fontscale
+ y_axis_size = 0.7*fontscale
+ right_buff_size = 0.2*fontscale
+
+ if iterable(self.figure_size):
+ figure_size = self.figure_size
+ else:
+ figure_size = (self.figure_size, self.figure_size)
+
+ xbins = np.array([x_axis_size, figure_size[0],
+ right_buff_size])
+ ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
+
+ size = [xbins.sum(), ybins.sum()]
+
+ x_frac_widths = xbins/size[0]
+ y_frac_widths = ybins/size[1]
+
+ axrect = (
+ x_frac_widths[0],
+ y_frac_widths[0],
+ x_frac_widths[1],
+ y_frac_widths[1],
+ )
+
+ try:
+ plot = self.plots[field]
+ except KeyError:
+ plot = PlotMPL(self.figure_size, axrect, None, None)
+ self.plots[field] = plot
+ return plot
+
+ def _plot_xy(self, field, plot, x, y, dimensions_counter):
+ if self._x_unit is None:
+ unit_x = x.units
+ else:
+ unit_x = self._x_unit
+
+ if field in self._y_units:
+ unit_y = self._y_units[field]
+ else:
+ unit_y = y.units
+
+ x = x.to(unit_x)
+ y = y.to(unit_y)
+
+ plot.axes.plot(x, y, label=self.labels[field])
+
+ if self._field_transform[field] != linear_transform:
+ if (y < 0).any():
+ plot.axes.set_yscale('symlog')
+ else:
+ plot.axes.set_yscale('log')
+
+ plot._set_font_properties(self._font_properties, None)
+
+ axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
+
+ finfo = self.ds.field_info[field]
+
+ x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
+
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ dimensions_counter[dimensions] += 1
+ if dimensions_counter[dimensions] > 1:
+ y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
+ else:
+ y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
+
+ plot.axes.set_xlabel(x_label)
+ plot.axes.set_ylabel(y_label)
+
+ if field in self._titles:
+ plot.axes.set_title(self._titles[field])
+
+ if self.include_legend[field]:
+ plot.axes.legend()
+
+
def _setup_plots(self):
if self._plot_valid is True:
return
@@ -156,91 +286,13 @@
plot.axes.cla()
dimensions_counter = defaultdict(int)
for field in self.fields:
- fontscale = self._font_properties._size / 14.
- top_buff_size = 0.35*fontscale
-
- x_axis_size = 1.35*fontscale
- y_axis_size = 0.7*fontscale
- right_buff_size = 0.2*fontscale
-
- if iterable(self.figure_size):
- figure_size = self.figure_size
- else:
- figure_size = (self.figure_size, self.figure_size)
-
- xbins = np.array([x_axis_size, figure_size[0],
- right_buff_size])
- ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
-
- size = [xbins.sum(), ybins.sum()]
-
- x_frac_widths = xbins/size[0]
- y_frac_widths = ybins/size[1]
-
- axrect = (
- x_frac_widths[0],
- y_frac_widths[0],
- x_frac_widths[1],
- y_frac_widths[1],
- )
-
- try:
- plot = self.plots[field]
- except KeyError:
- plot = PlotMPL(self.figure_size, axrect, None, None)
- self.plots[field] = plot
+ plot = self._get_plot_instance(field)
x, y = self.ds.coordinates.pixelize_line(
field, self.start_point, self.end_point, self.npoints)
- if self._x_unit is None:
- unit_x = x.units
- else:
- unit_x = self._x_unit
-
- if field in self._y_units:
- unit_y = self._y_units[field]
- else:
- unit_y = y.units
-
- x = x.to(unit_x)
- y = y.to(unit_y)
-
- plot.axes.plot(x, y, label=self.labels[field])
-
- if self._field_transform[field] != linear_transform:
- if (y < 0).any():
- plot.axes.set_yscale('symlog')
- else:
- plot.axes.set_yscale('log')
-
- plot._set_font_properties(self._font_properties, None)
+ self._plot_xy(field, plot, x, y, dimensions_counter)
- axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
-
- finfo = self.ds.field_info[field]
-
- x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
-
- finfo = self.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=self.ds.unit_registry).dimensions
- dimensions_counter[dimensions] += 1
- if dimensions_counter[dimensions] > 1:
- y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
- axes_unit_labels[1]+'}$')
- else:
- y_label = (finfo.get_latex_display_name() + r'$\rm{' +
- axes_unit_labels[1]+'}$')
-
- plot.axes.set_xlabel(x_label)
- plot.axes.set_ylabel(y_label)
-
- if field in self._titles:
- plot.axes.set_title(self._titles[field])
-
- if self.include_legend[field]:
- plot.axes.legend()
@invalidate_plot
def set_x_unit(self, unit_name):
https://bitbucket.org/yt_analysis/yt/commits/f6e0afb30183/
Changeset: f6e0afb30183
User: Alex Lindsay
Date: 2017-07-25 22:00:34+00:00
Summary: Finish proof of concept demo. Need to clean up labeling. (#1505)
Affected #: 3 files
diff -r f05eb312ccef2c51e578de80517dc056c82c0115 -r f6e0afb30183587e21b5953290e304ad64eb1719 yt/__init__.py
--- a/yt/__init__.py
+++ b/yt/__init__.py
@@ -104,7 +104,7 @@
write_bitmap, write_image, \
apply_colormap, scale_image, write_projection, \
SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, LinePlot, \
- ProjectionPlot, OffAxisProjectionPlot, \
+ LineObject, ProjectionPlot, OffAxisProjectionPlot, \
show_colormaps, add_cmap, make_colormap, \
ProfilePlot, PhasePlot, ParticlePhasePlot, \
ParticleProjectionPlot, ParticleImageBuffer, ParticlePlot, \
diff -r f05eb312ccef2c51e578de80517dc056c82c0115 -r f6e0afb30183587e21b5953290e304ad64eb1719 yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -55,7 +55,8 @@
plot_2d
from .line_plot import \
- LinePlot
+ LinePlot, \
+ LineObject
from .profile_plotter import \
ProfilePlot, \
diff -r f05eb312ccef2c51e578de80517dc056c82c0115 -r f6e0afb30183587e21b5953290e304ad64eb1719 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -31,6 +31,12 @@
linear_transform, \
invalidate_plot
+class LineObject(object):
+ def __init__(self, start_point, end_point, ds, label=None):
+ self.start_point = _validate_point(start_point, ds, start=True)
+ self.end_point = _validate_point(end_point, ds)
+ self.label = label
+
class LinePlotDictionary(PlotDictionary):
def __init__(self, data_source):
super(LinePlotDictionary, self).__init__(data_source)
@@ -150,15 +156,9 @@
else:
obj._field_transform[f] = linear_transform
-
- @invalidate_plot
- def add_legend(self, field):
- """Adds a legend to the `LinePlot` instance"""
- self.include_legend[field] = True
-
@classmethod
- def from_lines(cls, ds, fields, start_points, end_points, npoints,
- figure_size=5., font_size=14., labels=None):
+ def from_lines(cls, ds, fields, lines, npoints,
+ figure_size=5., font_size=14.):
"""
A class method for constructing a line plot from multiple sampling lines
@@ -170,12 +170,6 @@
simulation output to be plotted.
fields : string / tuple, or list of strings / tuples
The name(s) of the field(s) to be plotted.
- start_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
- Each element of the outer iterable contains the coordinates of a starting
- point for constructing a line.
- end_points : iterable of n-element lists, tuples, ndarrays, or YTArrays
- Each element of the outer iterable contains the coordinates of an ending
- point for constructing a line.
npoints : int
How many points to sample between start_point and end_point for
constructing the line plot
@@ -185,12 +179,20 @@
fontsize : int
Font size for all text in the plot.
Default: 14
- labels : dictionary
- Keys should be the field names. Values should be latex-formattable
- strings used in the LinePlot legend
- Default: None
"""
- return 0
+ obj = cls.__new__(cls)
+ cls._initialize_instance(obj, ds, fields, npoints, figure_size, font_size)
+
+ dimensions_counter = defaultdict(int)
+ for field in obj.fields:
+ plot = obj._get_plot_instance(field)
+ for line in lines:
+ x, y = obj.ds.coordinates.pixelize_line(
+ field, line.start_point, line.end_point, npoints)
+ obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=line.label)
+ plot.axes.legend()
+ obj._plot_valid = True
+ return obj
def _get_plot_instance(self, field):
fontscale = self._font_properties._size / 14.
@@ -228,7 +230,7 @@
self.plots[field] = plot
return plot
- def _plot_xy(self, field, plot, x, y, dimensions_counter):
+ def _plot_xy(self, field, plot, x, y, dimensions_counter, legend_label=None):
if self._x_unit is None:
unit_x = x.units
else:
@@ -242,7 +244,7 @@
x = x.to(unit_x)
y = y.to(unit_y)
- plot.axes.plot(x, y, label=self.labels[field])
+ plot.axes.plot(x, y, label=legend_label)
if self._field_transform[field] != linear_transform:
if (y < 0).any():
@@ -275,9 +277,6 @@
if field in self._titles:
plot.axes.set_title(self._titles[field])
- if self.include_legend[field]:
- plot.axes.legend()
-
def _setup_plots(self):
if self._plot_valid is True:
@@ -291,10 +290,19 @@
x, y = self.ds.coordinates.pixelize_line(
field, self.start_point, self.end_point, self.npoints)
- self._plot_xy(field, plot, x, y, dimensions_counter)
+ self._plot_xy(field, plot, x, y, dimensions_counter,
+ legend_label=self.labels[field])
+
+ if self.include_legend[field]:
+ plot.axes.legend()
@invalidate_plot
+ def add_field_legend(self, field):
+ """Adds a legend to the `LinePlot` instance"""
+ self.include_legend[field] = True
+
+ @invalidate_plot
def set_x_unit(self, unit_name):
"""Set the unit to use along the x-axis
https://bitbucket.org/yt_analysis/yt/commits/94be1d31896c/
Changeset: 94be1d31896c
User: Alex Lindsay
Date: 2017-07-25 22:49:46+00:00
Summary: Make labels work as desired (#1505)
Affected #: 1 file
diff -r f6e0afb30183587e21b5953290e304ad64eb1719 -r 94be1d31896c83ad47bafee84815b9d925347b02 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -183,14 +183,27 @@
obj = cls.__new__(cls)
cls._initialize_instance(obj, ds, fields, npoints, figure_size, font_size)
- dimensions_counter = defaultdict(int)
- for field in obj.fields:
- plot = obj._get_plot_instance(field)
- for line in lines:
+ for line in lines:
+ dimensions_counter = defaultdict(int)
+ for field in obj.fields:
+ finfo = obj.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=obj.ds.unit_registry).dimensions
+ dimensions_counter[dimensions] += 1
+
+ for field in obj.fields:
+ plot = obj._get_plot_instance(field)
x, y = obj.ds.coordinates.pixelize_line(
field, line.start_point, line.end_point, npoints)
- obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=line.label)
- plot.axes.legend()
+ finfo = obj.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=obj.ds.unit_registry).dimensions
+ if dimensions_counter[dimensions] > 1:
+ legend_label = r"$%s;$ %s" % (line.label, finfo.get_latex_display_name())
+ else:
+ legend_label = r"$%s$" % line.label
+ obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=legend_label)
+ plot.axes.legend()
obj._plot_valid = True
return obj
@@ -263,7 +276,6 @@
finfo = self.ds.field_info[field]
dimensions = Unit(finfo.units,
registry=self.ds.unit_registry).dimensions
- dimensions_counter[dimensions] += 1
if dimensions_counter[dimensions] > 1:
y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
axes_unit_labels[1]+'}$')
@@ -285,6 +297,11 @@
plot.axes.cla()
dimensions_counter = defaultdict(int)
for field in self.fields:
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ dimensions_counter[dimensions] += 1
+ for field in self.fields:
plot = self._get_plot_instance(field)
x, y = self.ds.coordinates.pixelize_line(
https://bitbucket.org/yt_analysis/yt/commits/6b2f2727cfe4/
Changeset: 6b2f2727cfe4
User: Alex Lindsay
Date: 2017-07-26 18:49:50+00:00
Summary: remove LineObject and create LineBuffer (#1505)
Affected #: 3 files
diff -r 94be1d31896c83ad47bafee84815b9d925347b02 -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 yt/__init__.py
--- a/yt/__init__.py
+++ b/yt/__init__.py
@@ -104,7 +104,7 @@
write_bitmap, write_image, \
apply_colormap, scale_image, write_projection, \
SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, LinePlot, \
- LineObject, ProjectionPlot, OffAxisProjectionPlot, \
+ LineBuffer, ProjectionPlot, OffAxisProjectionPlot, \
show_colormaps, add_cmap, make_colormap, \
ProfilePlot, PhasePlot, ParticlePhasePlot, \
ParticleProjectionPlot, ParticleImageBuffer, ParticlePlot, \
diff -r 94be1d31896c83ad47bafee84815b9d925347b02 -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -56,7 +56,7 @@
from .line_plot import \
LinePlot, \
- LineObject
+ LineBuffer
from .profile_plotter import \
ProfilePlot, \
diff -r 94be1d31896c83ad47bafee84815b9d925347b02 -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -17,7 +17,8 @@
from collections import defaultdict
from yt.funcs import \
- iterable
+ iterable, \
+ mylog
from yt.units.unit_object import \
Unit
from yt.units.yt_array import \
@@ -31,11 +32,35 @@
linear_transform, \
invalidate_plot
-class LineObject(object):
- def __init__(self, start_point, end_point, ds, label=None):
+class LineBuffer(object):
+ def __init__(self, ds, start_point, end_point, npoints, label=None):
+ self.ds = ds
self.start_point = _validate_point(start_point, ds, start=True)
self.end_point = _validate_point(end_point, ds)
+ self.npoints = npoints
self.label = label
+ self.data = {}
+
+ def keys(self):
+ return self.data.keys()
+
+ def __setitem__(self, item, val):
+ self.data[item] = val
+
+ def __getitem__(self, item):
+ if item in self.data: return self.data[item]
+ mylog.info("Making a line buffer with %d points of %s" % \
+ (self.npoints, item))
+ self.points, self.data[item] = self.ds.coordinates.pixelize_line(item,
+ self.start_point,
+ self.end_point,
+ self.npoints)
+
+ return self.data[item]
+
+ def __delitem__(self, item):
+ del self.data[item]
+
class LinePlotDictionary(PlotDictionary):
def __init__(self, data_source):
https://bitbucket.org/yt_analysis/yt/commits/1ad615c21613/
Changeset: 1ad615c21613
User: Alex Lindsay
Date: 2017-07-26 19:52:48+00:00
Summary: Change to annotate_legend and add tests
Affected #: 4 files
diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 doc/source/cookbook/simple_1d_line_plot.py
--- a/doc/source/cookbook/simple_1d_line_plot.py
+++ b/doc/source/cookbook/simple_1d_line_plot.py
@@ -8,7 +8,7 @@
plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], (0, 0, 0), (0, 1, 0), 1000)
# Add a legend
-plot.add_legend(('all', 'v'))
+plot.annotate_legend(('all', 'v'))
# Save the line plot
plot.save()
diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 doc/source/visualizing/plots.rst
--- a/doc/source/visualizing/plots.rst
+++ b/doc/source/visualizing/plots.rst
@@ -208,7 +208,7 @@
Plots of 2D Datasets
~~~~~~~~~~~~~~~~~~~~
-If you have a two-dimensional cartesian, cylindrical, or polar dataset,
+If you have a two-dimensional cartesian, cylindrical, or polar dataset,
:func:`~yt.visualization.plot_window.plot_2d` is a way to make a plot
within the dataset's plane without having to specify the axis, which
in this case is redundant. Otherwise, ``plot_2d`` accepts the same
@@ -1114,7 +1114,7 @@
1. When instantiating the ``LinePlot``, pass a dictionary of
labels with keys corresponding to the field names
-2. Call the ``LinePlot`` ``add_legend`` method
+2. Call the ``LinePlot`` ``annotate_legend`` method
X- and Y- axis units can be set with ``set_x_unit`` and ``set_unit`` methods
respectively. The below code snippet combines all the features we've discussed:
@@ -1126,7 +1126,7 @@
ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
- plot.add_legend('density')
+ plot.annotate_legend('density')
plot.set_x_unit('cm')
plot.set_unit('density', 'kg/cm**3')
plot.save()
@@ -1136,7 +1136,7 @@
quantities. E.g. if ``LinePlot`` receives two fields with units of "length/time"
and a field with units of "temperature", two different figures will be created,
one with plots of the "length/time" fields and another with the plot of the
-"temperature" field. It is only necessary to call ``add_legend``
+"temperature" field. It is only necessary to call ``annotate_legend``
for one field of a multi-field plot to produce a legend containing all the
labels passed in the initial construction of the ``LinePlot`` instance. Example:
@@ -1147,7 +1147,7 @@
ds = yt.load("SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1)
plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], [0, 0, 0], [0, 1, 0],
100, labels={('all', 'u') : r"v$_x$", ('all', 'v') : r"v$_y$"})
- plot.add_legend(('all', 'u'))
+ plot.annotate_legend(('all', 'u'))
plot.save()
``LinePlot`` is a bit different from yt ray objects which are data
diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -147,8 +147,9 @@
"""
self.start_point = _validate_point(start_point, ds, start=True)
self.end_point = _validate_point(end_point, ds)
+ self.npoints = npoints
- self._initialize_instance(self, ds, fields, npoints, figure_size, fontsize)
+ self._initialize_instance(self, ds, fields, figure_size, fontsize)
if labels is None:
self.labels = {}
@@ -161,9 +162,7 @@
self._setup_plots()
@classmethod
- def _initialize_instance(cls, obj, ds, fields, npoints, figure_size=5.,
- fontsize=14.):
- obj.npoints = npoints
+ def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.):
obj._x_unit = None
obj._y_units = {}
obj._titles = {}
@@ -182,8 +181,7 @@
obj._field_transform[f] = linear_transform
@classmethod
- def from_lines(cls, ds, fields, lines, npoints,
- figure_size=5., font_size=14.):
+ def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14.):
"""
A class method for constructing a line plot from multiple sampling lines
@@ -195,9 +193,6 @@
simulation output to be plotted.
fields : string / tuple, or list of strings / tuples
The name(s) of the field(s) to be plotted.
- npoints : int
- How many points to sample between start_point and end_point for
- constructing the line plot
figure_size : int or two-element iterable of ints
Size in inches of the image.
Default: 5 (5x5)
@@ -206,7 +201,7 @@
Default: 14
"""
obj = cls.__new__(cls)
- cls._initialize_instance(obj, ds, fields, npoints, figure_size, font_size)
+ cls._initialize_instance(obj, ds, fields, figure_size, font_size)
for line in lines:
dimensions_counter = defaultdict(int)
@@ -219,7 +214,7 @@
for field in obj.fields:
plot = obj._get_plot_instance(field)
x, y = obj.ds.coordinates.pixelize_line(
- field, line.start_point, line.end_point, npoints)
+ field, line.start_point, line.end_point, line.npoints)
finfo = obj.ds.field_info[field]
dimensions = Unit(finfo.units,
registry=obj.ds.unit_registry).dimensions
@@ -340,7 +335,7 @@
@invalidate_plot
- def add_field_legend(self, field):
+ def annotate_legend(self, field):
"""Adds a legend to the `LinePlot` instance"""
self.include_legend[field] = True
diff -r 6b2f2727cfe44ff6928f404e4e66f30afd23c6f7 -r 1ad615c216135c41cea988e211097e34e3c56940 yt/visualization/tests/test_line_plots.py
--- a/yt/visualization/tests/test_line_plots.py
+++ b/yt/visualization/tests/test_line_plots.py
@@ -25,14 +25,12 @@
from yt.config import ytcfg
ytcfg["yt", "__withintesting"] = "True"
-def compare(ds, fields, point1, point2, resolution, test_prefix, decimals=12):
- def line_plot(filename_prefix):
- ln = yt.LinePlot(ds, fields, point1, point2, resolution)
- image_file = ln.save(filename_prefix)
- return image_file
+def compare(ds, plot, test_prefix, decimals=12):
+ def image_from_plot(filename_prefix):
+ return plot.save(filename_prefix)
- line_plot.__name__ = "line_{}".format(test_prefix)
- test = GenericImageTest(ds, line_plot, decimals)
+ image_from_plot.__name__ = "line_{}".format(test_prefix)
+ test = GenericImageTest(ds, image_from_plot, decimals)
test.prefix = test_prefix
return test
@@ -42,7 +40,18 @@
def test_line_plot():
ds = data_dir_load(tri2, kwargs={'step':-1})
fields = [field for field in ds.field_list if field[0] == 'all']
- yield compare(ds, fields, (0, 0, 0), (1, 1, 0), 1000, "answers_line_plot")
+ plot = yt.LinePlot(ds, fields, (0, 0, 0), (1, 1, 0), 1000)
+ yield compare(ds, plot, "answers_line_plot")
+
+@requires_ds(tri2)
+def test_multi_line_plot():
+ ds = data_dir_load(tri2, kwargs={'step':-1})
+ fields = [field for field in ds.field_list if field[0] == 'all']
+ lines = []
+ lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+ lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+ plot = yt.LinePlot.from_lines(ds, fields, lines)
+ yield compare(ds, plot, "answers_multi_line_plot")
def test_line_plot_methods():
# Perform I/O in safe place instead of yt main dir
@@ -53,7 +62,7 @@
ds = fake_random_ds(32)
plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
- plot.add_legend('density')
+ plot.annotate_legend('density')
plot.set_x_unit('cm')
plot.set_unit('density', 'kg/cm**3')
plot.save()
@@ -61,3 +70,12 @@
os.chdir(curdir)
# clean up
shutil.rmtree(tmpdir)
+
+def test_line_buffer():
+ ds = fake_random_ds(32)
+ lb = yt.LineBuffer(ds, (0, 0, 0), (1, 1, 1), 512, label='diag')
+ density = lb['density']
+ lb['density'] = 0
+ vx = lb['velocity_x']
+ keys = lb.keys()
+ del lb['velocity_x']
https://bitbucket.org/yt_analysis/yt/commits/f752f79ec42d/
Changeset: f752f79ec42d
User: Alex Lindsay
Date: 2017-07-26 20:42:22+00:00
Summary: Add documentation (#1505)
Affected #: 1 file
diff -r 1ad615c216135c41cea988e211097e34e3c56940 -r f752f79ec42d2b8b200b140b7463330c36ed1c14 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -33,6 +33,35 @@
invalidate_plot
class LineBuffer(object):
+ r"""
+ LineBuffer(ds, start_point, end_point, npoints, label = None)
+
+ This takes a dataset and implements a protocol for generating a
+ 'pixelized', fixed-resolution line buffer. In other words, LineBuffer
+ takes a starting point, ending point, and number of sampling points and
+ can subsequently generate YTArrays of field values along the sample points.
+
+ Parameters
+ ----------
+ ds : :class:`yt.data_objects.static_output.Dataset`
+ This is the dataset object holding the data that can be sampled by the
+ LineBuffer
+ start_point : n-element list, tuple, ndarray, or YTArray
+ Contains the coordinates of the first point for constructing the LineBuffer.
+ Must contain n elements where n is the dimensionality of the dataset.
+ end_point : n-element list, tuple, ndarray, or YTArray
+ Contains the coordinates of the last point for constructing the LineBuffer.
+ Must contain n elements where n is the dimensionality of the dataset.
+ npoints : int
+ How many points to sample between start_point and end_point
+
+ Examples
+ --------
+ >>> lb = yt.LineBuffer(ds, (.25, 0, 0), (.25, 1, 0), 100)
+ >>> lb[('all', 'u')].max()
+ 0.11562424257143075 dimensionless
+
+ """
def __init__(self, ds, start_point, end_point, npoints, label=None):
self.ds = ds
self.start_point = _validate_point(start_point, ds, start=True)
@@ -193,12 +222,25 @@
simulation output to be plotted.
fields : string / tuple, or list of strings / tuples
The name(s) of the field(s) to be plotted.
+ lines : a list of :class:`yt.visualization.line_plot.LineBuffer`s
+ The lines from which to sample data
figure_size : int or two-element iterable of ints
Size in inches of the image.
Default: 5 (5x5)
fontsize : int
Font size for all text in the plot.
Default: 14
+
+ Example
+ --------
+ >>> ds = yt.load('SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e', step=-1)
+ >>> fields = [field for field in ds.field_list if field[0] == 'all']
+ >>> lines = []
+ >>> lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+ >>> lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+ >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
+ >>> plot.save()
+
"""
obj = cls.__new__(cls)
cls._initialize_instance(obj, ds, fields, figure_size, font_size)
https://bitbucket.org/yt_analysis/yt/commits/ad9d85b2438f/
Changeset: ad9d85b2438f
User: Alex Lindsay
Date: 2017-07-26 22:27:17+00:00
Summary: Formulate so that LinePlots created from from_lines can also use annotate_legend, set_x_unit, etc. Also enables some code condensing (#1505)
Affected #: 1 file
diff -r f752f79ec42d2b8b200b140b7463330c36ed1c14 -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -147,7 +147,7 @@
fontsize : int
Font size for all text in the plot.
Default: 14
- labels : dictionary
+ field_labels : dictionary
Keys should be the field names. Values should be latex-formattable
strings used in the LinePlot legend
Default: None
@@ -170,28 +170,19 @@
_plot_type = 'line_plot'
def __init__(self, ds, fields, start_point, end_point, npoints,
- figure_size=5., fontsize=14., labels=None):
+ figure_size=5., fontsize=14., field_labels=None):
"""
Sets up figure and axes
"""
- self.start_point = _validate_point(start_point, ds, start=True)
- self.end_point = _validate_point(end_point, ds)
- self.npoints = npoints
-
- self._initialize_instance(self, ds, fields, figure_size, fontsize)
-
- if labels is None:
- self.labels = {}
- else:
- self.labels = labels
- for f in self.fields:
- if f not in self.labels:
- self.labels[f] = f[1]
-
+ line = LineBuffer(ds, start_point, end_point, npoints, label=None)
+ self.lines = [line]
+ self._initialize_instance(self, ds, fields, figure_size,
+ fontsize, field_labels)
self._setup_plots()
@classmethod
- def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.):
+ def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.,
+ field_labels=None):
obj._x_unit = None
obj._y_units = {}
obj._titles = {}
@@ -209,8 +200,16 @@
else:
obj._field_transform[f] = linear_transform
+ if field_labels is None:
+ obj.field_labels = {}
+ else:
+ obj.field_labels = field_labels
+ for f in obj.fields:
+ if f not in obj.field_labels:
+ obj.field_labels[f] = f[1]
+
@classmethod
- def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14.):
+ def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14., field_labels=None):
"""
A class method for constructing a line plot from multiple sampling lines
@@ -230,6 +229,10 @@
fontsize : int
Font size for all text in the plot.
Default: 14
+ field_labels : dictionary
+ Keys should be the field names. Values should be latex-formattable
+ strings used in the LinePlot legend
+ Default: None
Example
--------
@@ -243,30 +246,9 @@
"""
obj = cls.__new__(cls)
- cls._initialize_instance(obj, ds, fields, figure_size, font_size)
-
- for line in lines:
- dimensions_counter = defaultdict(int)
- for field in obj.fields:
- finfo = obj.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=obj.ds.unit_registry).dimensions
- dimensions_counter[dimensions] += 1
-
- for field in obj.fields:
- plot = obj._get_plot_instance(field)
- x, y = obj.ds.coordinates.pixelize_line(
- field, line.start_point, line.end_point, line.npoints)
- finfo = obj.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=obj.ds.unit_registry).dimensions
- if dimensions_counter[dimensions] > 1:
- legend_label = r"$%s;$ %s" % (line.label, finfo.get_latex_display_name())
- else:
- legend_label = r"$%s$" % line.label
- obj._plot_xy(field, plot, x, y, dimensions_counter, legend_label=legend_label)
- plot.axes.legend()
- obj._plot_valid = True
+ obj.lines = lines
+ cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
+ obj._setup_plots()
return obj
def _get_plot_instance(self, field):
@@ -305,75 +287,89 @@
self.plots[field] = plot
return plot
- def _plot_xy(self, field, plot, x, y, dimensions_counter, legend_label=None):
- if self._x_unit is None:
- unit_x = x.units
- else:
- unit_x = self._x_unit
-
- if field in self._y_units:
- unit_y = self._y_units[field]
- else:
- unit_y = y.units
-
- x = x.to(unit_x)
- y = y.to(unit_y)
-
- plot.axes.plot(x, y, label=legend_label)
-
- if self._field_transform[field] != linear_transform:
- if (y < 0).any():
- plot.axes.set_yscale('symlog')
- else:
- plot.axes.set_yscale('log')
-
- plot._set_font_properties(self._font_properties, None)
-
- axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
-
- finfo = self.ds.field_info[field]
-
- x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
-
- finfo = self.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=self.ds.unit_registry).dimensions
- if dimensions_counter[dimensions] > 1:
- y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
- axes_unit_labels[1]+'}$')
- else:
- y_label = (finfo.get_latex_display_name() + r'$\rm{' +
- axes_unit_labels[1]+'}$')
-
- plot.axes.set_xlabel(x_label)
- plot.axes.set_ylabel(y_label)
-
- if field in self._titles:
- plot.axes.set_title(self._titles[field])
-
-
def _setup_plots(self):
if self._plot_valid is True:
return
for plot in self.plots.values():
plot.axes.cla()
- dimensions_counter = defaultdict(int)
- for field in self.fields:
- finfo = self.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=self.ds.unit_registry).dimensions
- dimensions_counter[dimensions] += 1
- for field in self.fields:
- plot = self._get_plot_instance(field)
+ for line in self.lines:
+ dimensions_counter = defaultdict(int)
+ for field in self.fields:
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ dimensions_counter[dimensions] += 1
+ for field in self.fields:
+ # get plot instance
+ plot = self._get_plot_instance(field)
+
+ # calculate x and y
+ x, y = self.ds.coordinates.pixelize_line(
+ field, line.start_point, line.end_point, line.npoints)
+
+ # scale x and y to proper units
+ if self._x_unit is None:
+ unit_x = x.units
+ else:
+ unit_x = self._x_unit
+
+ if field in self._y_units:
+ unit_y = self._y_units[field]
+ else:
+ unit_y = y.units
+
+ x = x.to(unit_x)
+ y = y.to(unit_y)
+
+ # determine legend label
+ str_seq = []
+ str_seq.append(line.label)
+ str_seq.append(self.field_labels[field])
+ delim = "; "
+ legend_label = delim.join(filter(None, str_seq))
+
+ # apply plot to matplotlib axes
+ plot.axes.plot(x, y, label=legend_label)
- x, y = self.ds.coordinates.pixelize_line(
- field, self.start_point, self.end_point, self.npoints)
+ # apply log transforms if requested
+ if self._field_transform[field] != linear_transform:
+ if (y < 0).any():
+ plot.axes.set_yscale('symlog')
+ else:
+ plot.axes.set_yscale('log')
+
+ # set font properties
+ plot._set_font_properties(self._font_properties, None)
+
+ # set x and y axis labels
+ axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
+
+ finfo = self.ds.field_info[field]
+
+ x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
- self._plot_xy(field, plot, x, y, dimensions_counter,
- legend_label=self.labels[field])
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ if dimensions_counter[dimensions] > 1:
+ y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
+ else:
+ y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
- if self.include_legend[field]:
- plot.axes.legend()
+ plot.axes.set_xlabel(x_label)
+ plot.axes.set_ylabel(y_label)
+
+ # apply title
+ if field in self._titles:
+ plot.axes.set_title(self._titles[field])
+
+ # apply legend
+ if self.include_legend[field]:
+ plot.axes.legend()
+
+ self._plot_valid = True
@invalidate_plot
https://bitbucket.org/yt_analysis/yt/commits/eac50268ed01/
Changeset: eac50268ed01
User: Alex Lindsay
Date: 2017-07-27 13:58:16+00:00
Summary: Increment answers. Use x and y labels. Fix pep8
Affected #: 3 files
diff -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab -r eac50268ed013076d933352ec7918c9a9962e032 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -67,12 +67,13 @@
- yt/analysis_modules/photon_simulator/tests/test_spectra.py
- yt/analysis_modules/photon_simulator/tests/test_sloshing.py
- local_unstructured_008:
+ local_unstructured_009:
- yt/visualization/volume_rendering/tests/test_mesh_render.py
- yt/visualization/tests/test_mesh_slices.py:test_tri2
- yt/visualization/tests/test_mesh_slices.py:test_quad2
- yt/visualization/tests/test_mesh_slices.py:test_multi_region
- yt/visualization/tests/test_line_plots.py:test_line_plot
+ - yt/visualization/tests/test_line_plots.py:test_multi_line_plot
local_boxlib_004:
- yt/frontends/boxlib/tests/test_outputs.py:test_radadvect
diff -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab -r eac50268ed013076d933352ec7918c9a9962e032 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -344,19 +344,23 @@
# set x and y axis labels
axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
- finfo = self.ds.field_info[field]
-
- x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
+ if self._xlabel is not None:
+ x_label = self._xlabel
+ else:
+ x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
- finfo = self.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=self.ds.unit_registry).dimensions
- if dimensions_counter[dimensions] > 1:
- y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
- axes_unit_labels[1]+'}$')
+ if self._ylabel is not None:
+ y_label = self._ylabel
else:
- y_label = (finfo.get_latex_display_name() + r'$\rm{' +
- axes_unit_labels[1]+'}$')
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ if dimensions_counter[dimensions] > 1:
+ y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
+ else:
+ y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
plot.axes.set_xlabel(x_label)
plot.axes.set_ylabel(y_label)
diff -r ad9d85b2438fc013815f2658ac54dcbcc1a4ebab -r eac50268ed013076d933352ec7918c9a9962e032 yt/visualization/tests/test_line_plots.py
--- a/yt/visualization/tests/test_line_plots.py
+++ b/yt/visualization/tests/test_line_plots.py
@@ -74,8 +74,8 @@
def test_line_buffer():
ds = fake_random_ds(32)
lb = yt.LineBuffer(ds, (0, 0, 0), (1, 1, 1), 512, label='diag')
- density = lb['density']
+ lb['density']
lb['density'] = 0
- vx = lb['velocity_x']
- keys = lb.keys()
+ lb['velocity_x']
+ lb.keys()
del lb['velocity_x']
https://bitbucket.org/yt_analysis/yt/commits/50bddc8e9027/
Changeset: 50bddc8e9027
User: Alex Lindsay
Date: 2017-07-28 16:37:52+00:00
Summary: Make sure legend label is added for every field of multi-field plot (#1505)
Affected #: 1 file
diff -r eac50268ed013076d933352ec7918c9a9962e032 -r 50bddc8e90276a78ed2db4b48118b46a2105cae1 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -370,16 +370,22 @@
plot.axes.set_title(self._titles[field])
# apply legend
- if self.include_legend[field]:
+ dim_field = self.plots._sanitize_dimensions(field)
+ if self.include_legend[dim_field]:
plot.axes.legend()
- self._plot_valid = True
+ self._plot_valid = True
@invalidate_plot
def annotate_legend(self, field):
- """Adds a legend to the `LinePlot` instance"""
- self.include_legend[field] = True
+ """
+ Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
+ call ensures that a legend label will be added for every field of
+ a multi-field plot
+ """
+ dim_field = self.plots._sanitize_dimensions(field)
+ self.include_legend[dim_field] = True
@invalidate_plot
def set_x_unit(self, unit_name):
https://bitbucket.org/yt_analysis/yt/commits/eb258efcbe83/
Changeset: eb258efcbe83
User: ngoldbaum
Date: 2017-07-31 15:37:07+00:00
Summary: Merge pull request #1509 from lindsayad/multiple_lines
Allow multiple sampling lines for LinePlot
Affected #: 7 files
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 doc/source/cookbook/simple_1d_line_plot.py
--- a/doc/source/cookbook/simple_1d_line_plot.py
+++ b/doc/source/cookbook/simple_1d_line_plot.py
@@ -8,7 +8,7 @@
plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], (0, 0, 0), (0, 1, 0), 1000)
# Add a legend
-plot.add_legend(('all', 'v'))
+plot.annotate_legend(('all', 'v'))
# Save the line plot
plot.save()
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 doc/source/visualizing/plots.rst
--- a/doc/source/visualizing/plots.rst
+++ b/doc/source/visualizing/plots.rst
@@ -208,7 +208,7 @@
Plots of 2D Datasets
~~~~~~~~~~~~~~~~~~~~
-If you have a two-dimensional cartesian, cylindrical, or polar dataset,
+If you have a two-dimensional cartesian, cylindrical, or polar dataset,
:func:`~yt.visualization.plot_window.plot_2d` is a way to make a plot
within the dataset's plane without having to specify the axis, which
in this case is redundant. Otherwise, ``plot_2d`` accepts the same
@@ -1114,7 +1114,7 @@
1. When instantiating the ``LinePlot``, pass a dictionary of
labels with keys corresponding to the field names
-2. Call the ``LinePlot`` ``add_legend`` method
+2. Call the ``LinePlot`` ``annotate_legend`` method
X- and Y- axis units can be set with ``set_x_unit`` and ``set_unit`` methods
respectively. The below code snippet combines all the features we've discussed:
@@ -1126,7 +1126,7 @@
ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
- plot.add_legend('density')
+ plot.annotate_legend('density')
plot.set_x_unit('cm')
plot.set_unit('density', 'kg/cm**3')
plot.save()
@@ -1136,7 +1136,7 @@
quantities. E.g. if ``LinePlot`` receives two fields with units of "length/time"
and a field with units of "temperature", two different figures will be created,
one with plots of the "length/time" fields and another with the plot of the
-"temperature" field. It is only necessary to call ``add_legend``
+"temperature" field. It is only necessary to call ``annotate_legend``
for one field of a multi-field plot to produce a legend containing all the
labels passed in the initial construction of the ``LinePlot`` instance. Example:
@@ -1147,7 +1147,7 @@
ds = yt.load("SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1)
plot = yt.LinePlot(ds, [('all', 'v'), ('all', 'u')], [0, 0, 0], [0, 1, 0],
100, labels={('all', 'u') : r"v$_x$", ('all', 'v') : r"v$_y$"})
- plot.add_legend(('all', 'u'))
+ plot.annotate_legend(('all', 'u'))
plot.save()
``LinePlot`` is a bit different from yt ray objects which are data
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -67,12 +67,13 @@
- yt/analysis_modules/photon_simulator/tests/test_spectra.py
- yt/analysis_modules/photon_simulator/tests/test_sloshing.py
- local_unstructured_008:
+ local_unstructured_009:
- yt/visualization/volume_rendering/tests/test_mesh_render.py
- yt/visualization/tests/test_mesh_slices.py:test_tri2
- yt/visualization/tests/test_mesh_slices.py:test_quad2
- yt/visualization/tests/test_mesh_slices.py:test_multi_region
- yt/visualization/tests/test_line_plots.py:test_line_plot
+ - yt/visualization/tests/test_line_plots.py:test_multi_line_plot
local_boxlib_004:
- yt/frontends/boxlib/tests/test_outputs.py:test_radadvect
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/__init__.py
--- a/yt/__init__.py
+++ b/yt/__init__.py
@@ -104,7 +104,7 @@
write_bitmap, write_image, \
apply_colormap, scale_image, write_projection, \
SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, LinePlot, \
- ProjectionPlot, OffAxisProjectionPlot, \
+ LineBuffer, ProjectionPlot, OffAxisProjectionPlot, \
show_colormaps, add_cmap, make_colormap, \
ProfilePlot, PhasePlot, ParticlePhasePlot, \
ParticleProjectionPlot, ParticleImageBuffer, ParticlePlot, \
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -55,7 +55,8 @@
plot_2d
from .line_plot import \
- LinePlot
+ LinePlot, \
+ LineBuffer
from .profile_plotter import \
ProfilePlot, \
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/visualization/line_plot.py
--- a/yt/visualization/line_plot.py
+++ b/yt/visualization/line_plot.py
@@ -17,7 +17,8 @@
from collections import defaultdict
from yt.funcs import \
- iterable
+ iterable, \
+ mylog
from yt.units.unit_object import \
Unit
from yt.units.yt_array import \
@@ -31,6 +32,65 @@
linear_transform, \
invalidate_plot
+class LineBuffer(object):
+ r"""
+ LineBuffer(ds, start_point, end_point, npoints, label = None)
+
+ This takes a data source and implements a protocol for generating a
+ 'pixelized', fixed-resolution line buffer. In other words, LineBuffer
+ takes a starting point, ending point, and number of sampling points and
+ can subsequently generate YTArrays of field values along the sample points.
+
+ Parameters
+ ----------
+ ds : :class:`yt.data_objects.static_output.Dataset`
+ This is the dataset object holding the data that can be sampled by the
+ LineBuffer
+ start_point : n-element list, tuple, ndarray, or YTArray
+ Contains the coordinates of the first point for constructing the LineBuffer.
+ Must contain n elements where n is the dimensionality of the dataset.
+ end_point : n-element list, tuple, ndarray, or YTArray
+ Contains the coordinates of the first point for constructing the LineBuffer.
+ Must contain n elements where n is the dimensionality of the dataset.
+ npoints : int
+ How many points to sample between start_point and end_point
+
+ Examples
+ --------
+ >>> lb = yt.LineBuffer(ds, (.25, 0, 0), (.25, 1, 0), 100)
+ >>> lb[('all', 'u')].max()
+ 0.11562424257143075 dimensionless
+
+ """
+ def __init__(self, ds, start_point, end_point, npoints, label=None):
+ self.ds = ds
+ self.start_point = _validate_point(start_point, ds, start=True)
+ self.end_point = _validate_point(end_point, ds)
+ self.npoints = npoints
+ self.label = label
+ self.data = {}
+
+ def keys(self):
+ return self.data.keys()
+
+ def __setitem__(self, item, val):
+ self.data[item] = val
+
+ def __getitem__(self, item):
+ if item in self.data: return self.data[item]
+ mylog.info("Making a line buffer with %d points of %s" % \
+ (self.npoints, item))
+ self.points, self.data[item] = self.ds.coordinates.pixelize_line(item,
+ self.start_point,
+ self.end_point,
+ self.npoints)
+
+ return self.data[item]
+
+ def __delitem__(self, item):
+ del self.data[item]
+
+
class LinePlotDictionary(PlotDictionary):
def __init__(self, data_source):
super(LinePlotDictionary, self).__init__(data_source)
@@ -87,7 +147,7 @@
fontsize : int
Font size for all text in the plot.
Default: 14
- labels : dictionary
+ field_labels : dictionary
Keys should be the field names. Values should be latex-formattable
strings used in the LinePlot legend
Default: None
@@ -110,137 +170,222 @@
_plot_type = 'line_plot'
def __init__(self, ds, fields, start_point, end_point, npoints,
- figure_size=5., fontsize=14., labels=None):
+ figure_size=5., fontsize=14., field_labels=None):
"""
Sets up figure and axes
"""
- self.start_point = _validate_point(start_point, ds, start=True)
- self.end_point = _validate_point(end_point, ds)
- self.npoints = npoints
- self._x_unit = None
- self._y_units = {}
- self._titles = {}
+ line = LineBuffer(ds, start_point, end_point, npoints, label=None)
+ self.lines = [line]
+ self._initialize_instance(self, ds, fields, figure_size,
+ fontsize, field_labels)
+ self._setup_plots()
+
+ @classmethod
+ def _initialize_instance(cls, obj, ds, fields, figure_size=5., fontsize=14.,
+ field_labels=None):
+ obj._x_unit = None
+ obj._y_units = {}
+ obj._titles = {}
data_source = ds.all_data()
- self.fields = data_source._determine_fields(fields)
- self.plots = LinePlotDictionary(data_source)
- self.include_legend = defaultdict(bool)
- if labels is None:
- self.labels = {}
+ obj.fields = data_source._determine_fields(fields)
+ obj.plots = LinePlotDictionary(data_source)
+ obj.include_legend = defaultdict(bool)
+ super(LinePlot, obj).__init__(data_source, figure_size, fontsize)
+ for f in obj.fields:
+ finfo = obj.data_source.ds._get_field_info(*f)
+ if finfo.take_log:
+ obj._field_transform[f] = log_transform
+ else:
+ obj._field_transform[f] = linear_transform
+
+ if field_labels is None:
+ obj.field_labels = {}
else:
- self.labels = labels
+ obj.field_labels = field_labels
+ for f in obj.fields:
+ if f not in obj.field_labels:
+ obj.field_labels[f] = f[1]
+
+ @classmethod
+ def from_lines(cls, ds, fields, lines, figure_size=5., font_size=14., field_labels=None):
+ """
+ A class method for constructing a line plot from multiple sampling lines
+
+ Parameters
+ ----------
- super(LinePlot, self).__init__(data_source, figure_size, fontsize)
+ ds : :class:`yt.data_objects.static_output.Dataset`
+ This is the dataset object corresponding to the
+ simulation output to be plotted.
+ fields : string / tuple, or list of strings / tuples
+ The name(s) of the field(s) to be plotted.
+ lines : a list of :class:`yt.visualization.line_plot.LineBuffer`s
+ The lines from which to sample data
+ figure_size : int or two-element iterable of ints
+ Size in inches of the image.
+ Default: 5 (5x5)
+ fontsize : int
+ Font size for all text in the plot.
+ Default: 14
+ field_labels : dictionary
+ Keys should be the field names. Values should be latex-formattable
+ strings used in the LinePlot legend
+ Default: None
- for f in self.fields:
- if f not in self.labels:
- self.labels[f] = f[1]
- finfo = self.data_source.ds._get_field_info(*f)
- if finfo.take_log:
- self._field_transform[f] = log_transform
- else:
- self._field_transform[f] = linear_transform
+ Example
+ --------
+ >>> ds = yt.load('SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e', step=-1)
+ >>> fields = [field for field in ds.field_list if field[0] == 'all']
+ >>> lines = []
+ >>> lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+ >>> lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+ >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
+ >>> plot.save()
+
+ """
+ obj = cls.__new__(cls)
+ obj.lines = lines
+ cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
+ obj._setup_plots()
+ return obj
+
+ def _get_plot_instance(self, field):
+ fontscale = self._font_properties._size / 14.
+ top_buff_size = 0.35*fontscale
+
+ x_axis_size = 1.35*fontscale
+ y_axis_size = 0.7*fontscale
+ right_buff_size = 0.2*fontscale
- self._setup_plots()
+ if iterable(self.figure_size):
+ figure_size = self.figure_size
+ else:
+ figure_size = (self.figure_size, self.figure_size)
+
+ xbins = np.array([x_axis_size, figure_size[0],
+ right_buff_size])
+ ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
+
+ size = [xbins.sum(), ybins.sum()]
+
+ x_frac_widths = xbins/size[0]
+ y_frac_widths = ybins/size[1]
- @invalidate_plot
- def add_legend(self, field):
- """Adds a legend to the `LinePlot` instance"""
- self.include_legend[field] = True
+ axrect = (
+ x_frac_widths[0],
+ y_frac_widths[0],
+ x_frac_widths[1],
+ y_frac_widths[1],
+ )
+
+ try:
+ plot = self.plots[field]
+ except KeyError:
+ plot = PlotMPL(self.figure_size, axrect, None, None)
+ self.plots[field] = plot
+ return plot
def _setup_plots(self):
if self._plot_valid is True:
return
for plot in self.plots.values():
plot.axes.cla()
- dimensions_counter = defaultdict(int)
- for field in self.fields:
- fontscale = self._font_properties._size / 14.
- top_buff_size = 0.35*fontscale
-
- x_axis_size = 1.35*fontscale
- y_axis_size = 0.7*fontscale
- right_buff_size = 0.2*fontscale
+ for line in self.lines:
+ dimensions_counter = defaultdict(int)
+ for field in self.fields:
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ dimensions_counter[dimensions] += 1
+ for field in self.fields:
+ # get plot instance
+ plot = self._get_plot_instance(field)
- if iterable(self.figure_size):
- figure_size = self.figure_size
- else:
- figure_size = (self.figure_size, self.figure_size)
+ # calculate x and y
+ x, y = self.ds.coordinates.pixelize_line(
+ field, line.start_point, line.end_point, line.npoints)
- xbins = np.array([x_axis_size, figure_size[0],
- right_buff_size])
- ybins = np.array([y_axis_size, figure_size[1], top_buff_size])
-
- size = [xbins.sum(), ybins.sum()]
+ # scale x and y to proper units
+ if self._x_unit is None:
+ unit_x = x.units
+ else:
+ unit_x = self._x_unit
- x_frac_widths = xbins/size[0]
- y_frac_widths = ybins/size[1]
+ if field in self._y_units:
+ unit_y = self._y_units[field]
+ else:
+ unit_y = y.units
- axrect = (
- x_frac_widths[0],
- y_frac_widths[0],
- x_frac_widths[1],
- y_frac_widths[1],
- )
+ x = x.to(unit_x)
+ y = y.to(unit_y)
- try:
- plot = self.plots[field]
- except KeyError:
- plot = PlotMPL(self.figure_size, axrect, None, None)
- self.plots[field] = plot
+ # determine legend label
+ str_seq = []
+ str_seq.append(line.label)
+ str_seq.append(self.field_labels[field])
+ delim = "; "
+ legend_label = delim.join(filter(None, str_seq))
- x, y = self.ds.coordinates.pixelize_line(
- field, self.start_point, self.end_point, self.npoints)
+ # apply plot to matplotlib axes
+ plot.axes.plot(x, y, label=legend_label)
- if self._x_unit is None:
- unit_x = x.units
- else:
- unit_x = self._x_unit
+ # apply log transforms if requested
+ if self._field_transform[field] != linear_transform:
+ if (y < 0).any():
+ plot.axes.set_yscale('symlog')
+ else:
+ plot.axes.set_yscale('log')
- if field in self._y_units:
- unit_y = self._y_units[field]
- else:
- unit_y = y.units
+ # set font properties
+ plot._set_font_properties(self._font_properties, None)
- x = x.to(unit_x)
- y = y.to(unit_y)
+ # set x and y axis labels
+ axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
- plot.axes.plot(x, y, label=self.labels[field])
+ if self._xlabel is not None:
+ x_label = self._xlabel
+ else:
+ x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
- if self._field_transform[field] != linear_transform:
- if (y < 0).any():
- plot.axes.set_yscale('symlog')
+ if self._ylabel is not None:
+ y_label = self._ylabel
else:
- plot.axes.set_yscale('log')
-
- plot._set_font_properties(self._font_properties, None)
-
- axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)
-
- finfo = self.ds.field_info[field]
+ finfo = self.ds.field_info[field]
+ dimensions = Unit(finfo.units,
+ registry=self.ds.unit_registry).dimensions
+ if dimensions_counter[dimensions] > 1:
+ y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
+ else:
+ y_label = (finfo.get_latex_display_name() + r'$\rm{' +
+ axes_unit_labels[1]+'}$')
- x_label = r'$\rm{Path\ Length' + axes_unit_labels[0]+'}$'
+ plot.axes.set_xlabel(x_label)
+ plot.axes.set_ylabel(y_label)
+
+ # apply title
+ if field in self._titles:
+ plot.axes.set_title(self._titles[field])
+
+ # apply legend
+ dim_field = self.plots._sanitize_dimensions(field)
+ if self.include_legend[dim_field]:
+ plot.axes.legend()
- finfo = self.ds.field_info[field]
- dimensions = Unit(finfo.units,
- registry=self.ds.unit_registry).dimensions
- dimensions_counter[dimensions] += 1
- if dimensions_counter[dimensions] > 1:
- y_label = (r'$\rm{Multiple\ Fields}$' + r'$\rm{' +
- axes_unit_labels[1]+'}$')
- else:
- y_label = (finfo.get_latex_display_name() + r'$\rm{' +
- axes_unit_labels[1]+'}$')
+ self._plot_valid = True
+
- plot.axes.set_xlabel(x_label)
- plot.axes.set_ylabel(y_label)
-
- if field in self._titles:
- plot.axes.set_title(self._titles[field])
-
- if self.include_legend[field]:
- plot.axes.legend()
+ @invalidate_plot
+ def annotate_legend(self, field):
+ """
+ Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
+ call ensures that a legend label will be added for every field of
+ a multi-field plot
+ """
+ dim_field = self.plots._sanitize_dimensions(field)
+ self.include_legend[dim_field] = True
@invalidate_plot
def set_x_unit(self, unit_name):
diff -r c4e4d396976743a3a1f9736e148c296f0d572ed7 -r eb258efcbe831799f623d9cc9737763aab28a712 yt/visualization/tests/test_line_plots.py
--- a/yt/visualization/tests/test_line_plots.py
+++ b/yt/visualization/tests/test_line_plots.py
@@ -25,14 +25,12 @@
from yt.config import ytcfg
ytcfg["yt", "__withintesting"] = "True"
-def compare(ds, fields, point1, point2, resolution, test_prefix, decimals=12):
- def line_plot(filename_prefix):
- ln = yt.LinePlot(ds, fields, point1, point2, resolution)
- image_file = ln.save(filename_prefix)
- return image_file
+def compare(ds, plot, test_prefix, decimals=12):
+ def image_from_plot(filename_prefix):
+ return plot.save(filename_prefix)
- line_plot.__name__ = "line_{}".format(test_prefix)
- test = GenericImageTest(ds, line_plot, decimals)
+ image_from_plot.__name__ = "line_{}".format(test_prefix)
+ test = GenericImageTest(ds, image_from_plot, decimals)
test.prefix = test_prefix
return test
@@ -42,7 +40,18 @@
def test_line_plot():
ds = data_dir_load(tri2, kwargs={'step':-1})
fields = [field for field in ds.field_list if field[0] == 'all']
- yield compare(ds, fields, (0, 0, 0), (1, 1, 0), 1000, "answers_line_plot")
+ plot = yt.LinePlot(ds, fields, (0, 0, 0), (1, 1, 0), 1000)
+ yield compare(ds, plot, "answers_line_plot")
+
+@requires_ds(tri2)
+def test_multi_line_plot():
+ ds = data_dir_load(tri2, kwargs={'step':-1})
+ fields = [field for field in ds.field_list if field[0] == 'all']
+ lines = []
+ lines.append(yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label='x = 0.25'))
+ lines.append(yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label='x = 0.5'))
+ plot = yt.LinePlot.from_lines(ds, fields, lines)
+ yield compare(ds, plot, "answers_multi_line_plot")
def test_line_plot_methods():
# Perform I/O in safe place instead of yt main dir
@@ -53,7 +62,7 @@
ds = fake_random_ds(32)
plot = yt.LinePlot(ds, 'density', [0, 0, 0], [1, 1, 1], 512)
- plot.add_legend('density')
+ plot.annotate_legend('density')
plot.set_x_unit('cm')
plot.set_unit('density', 'kg/cm**3')
plot.save()
@@ -61,3 +70,12 @@
os.chdir(curdir)
# clean up
shutil.rmtree(tmpdir)
+
+def test_line_buffer():
+ ds = fake_random_ds(32)
+ lb = yt.LineBuffer(ds, (0, 0, 0), (1, 1, 1), 512, label='diag')
+ lb['density']
+ lb['density'] = 0
+ lb['velocity_x']
+ lb.keys()
+ del lb['velocity_x']
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list