[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt/yt-3.0 (pull request #1060)
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Wed Jul 23 15:04:40 PDT 2014
1 new commit in yt:
https://bitbucket.org/yt_analysis/yt/commits/e97899b53081/
Changeset: e97899b53081
Branch: yt-3.0
User: MatthewTurk
Date: 2014-07-24 00:04:30
Summary: Merged in ngoldbaum/yt/yt-3.0 (pull request #1060)
Merging from the yt branch.
Affected #: 74 files
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/analysis_modules/halo_finding/halo_objects.py
--- a/yt/analysis_modules/halo_finding/halo_objects.py
+++ b/yt/analysis_modules/halo_finding/halo_objects.py
@@ -549,22 +549,23 @@
temp_e2[:,dim] = e2_vector[dim]
length = np.abs(np.sum(rr * temp_e2, axis = 1) * (1 - \
np.sum(rr * temp_e0, axis = 1)**2. * mag_A**-2. - \
- np.sum(rr * temp_e1, axis = 1)**2. * mag_B**-2)**(-0.5))
+ np.sum(rr * temp_e1, axis = 1)**2. * mag_B**-2.)**(-0.5))
length[length == np.inf] = 0.
tC_index = np.nanargmax(length)
mag_C = length[tC_index]
# tilt is calculated from the rotation about x axis
# needed to align e1 vector with the y axis
# after e0 is aligned with x axis
- # find the t1 angle needed to rotate about z axis to align e0 to x
- t1 = np.arctan(e0_vector[1] / e0_vector[0])
- RZ = get_rotation_matrix(-t1, (0, 0, 1)).transpose()
- r1 = (e0_vector * RZ).sum(axis = 1)
+ # find the t1 angle needed to rotate about z axis to align e0 onto x-z plane
+ t1 = np.arctan(-e0_vector[1] / e0_vector[0])
+ RZ = get_rotation_matrix(t1, (0, 0, 1))
+ r1 = np.dot(RZ, e0_vector)
# find the t2 angle needed to rotate about y axis to align e0 to x
- t2 = np.arctan(-r1[2] / r1[0])
- RY = get_rotation_matrix(-t2, (0, 1, 0)).transpose()
+ t2 = np.arctan(r1[2] / r1[0])
+ RY = get_rotation_matrix(t2, (0, 1, 0))
r2 = np.dot(RY, np.dot(RZ, e1_vector))
- tilt = np.arctan(r2[2]/r2[1])
+ # find the tilt angle needed to rotate about x axis to align e1 to y and e2 to z
+ tilt = np.arctan(-r2[2] / r2[1])
return (mag_A, mag_B, mag_C, e0_vector[0], e0_vector[1],
e0_vector[2], tilt)
@@ -782,13 +783,13 @@
Returns
-------
- tuple : (cm, mag_A, mag_B, mag_C, e1_vector, tilt)
+ tuple : (cm, mag_A, mag_B, mag_C, e0_vector, tilt)
The 6-tuple has in order:
#. The center of mass as an array.
#. mag_A as a float.
#. mag_B as a float.
#. mag_C as a float.
- #. e1_vector as an array.
+ #. e0_vector as an array.
#. tilt as a float.
Examples
@@ -819,7 +820,7 @@
def __init__(self, ds, id, size=None, CoM=None,
max_dens_point=None, group_total_mass=None, max_radius=None, bulk_vel=None,
rms_vel=None, fnames=None, mag_A=None, mag_B=None, mag_C=None,
- e1_vec=None, tilt=None, supp=None):
+ e0_vec=None, tilt=None, supp=None):
self.ds = ds
self.gridsize = (self.ds.domain_right_edge - \
@@ -835,7 +836,7 @@
self.mag_A = mag_A
self.mag_B = mag_B
self.mag_C = mag_C
- self.e1_vec = e1_vec
+ self.e0_vec = e0_vec
self.tilt = tilt
# locs=the names of the h5 files that have particle data for this halo
self.fnames = fnames
@@ -928,8 +929,8 @@
def _get_ellipsoid_parameters_basic_loadedhalo(self):
if self.mag_A is not None:
- return (self.mag_A, self.mag_B, self.mag_C, self.e1_vec[0],
- self.e1_vec[1], self.e1_vec[2], self.tilt)
+ return (self.mag_A, self.mag_B, self.mag_C, self.e0_vec[0],
+ self.e0_vec[1], self.e0_vec[2], self.tilt)
else:
return self._get_ellipsoid_parameters_basic()
@@ -943,13 +944,13 @@
Returns
-------
- tuple : (cm, mag_A, mag_B, mag_C, e1_vector, tilt)
+ tuple : (cm, mag_A, mag_B, mag_C, e0_vector, tilt)
The 6-tuple has in order:
#. The center of mass as an array.
#. mag_A as a float.
#. mag_B as a float.
#. mag_C as a float.
- #. e1_vector as an array.
+ #. e0_vector as an array.
#. tilt as a float.
Examples
@@ -1021,7 +1022,7 @@
max_dens_point=None, group_total_mass=None, max_radius=None, bulk_vel=None,
rms_vel=None, fnames=None, mag_A=None, mag_B=None, mag_C=None,
- e1_vec=None, tilt=None, supp=None):
+ e0_vec=None, tilt=None, supp=None):
self.ds = ds
self.gridsize = (self.ds.domain_right_edge - \
@@ -1037,7 +1038,7 @@
self.mag_A = mag_A
self.mag_B = mag_B
self.mag_C = mag_C
- self.e1_vec = e1_vec
+ self.e0_vec = e0_vec
self.tilt = tilt
self.bin_count = None
self.overdensity = None
@@ -1181,8 +1182,8 @@
"x","y","z", "center-of-mass",
"x","y","z",
"vx","vy","vz","max_r","rms_v",
- "mag_A", "mag_B", "mag_C", "e1_vec0",
- "e1_vec1", "e1_vec2", "tilt", "\n"]))
+ "mag_A", "mag_B", "mag_C", "e0_vec0",
+ "e0_vec1", "e0_vec2", "tilt", "\n"]))
for group in self:
f.write("%10i\t" % group.id)
@@ -1494,17 +1495,17 @@
mag_A = float(line[15])
mag_B = float(line[16])
mag_C = float(line[17])
- e1_vec0 = float(line[18])
- e1_vec1 = float(line[19])
- e1_vec2 = float(line[20])
- e1_vec = np.array([e1_vec0, e1_vec1, e1_vec2])
+ e0_vec0 = float(line[18])
+ e0_vec1 = float(line[19])
+ e0_vec2 = float(line[20])
+ e0_vec = np.array([e0_vec0, e0_vec1, e0_vec2])
tilt = float(line[21])
self._groups.append(LoadedHalo(self.ds, halo, size = size,
CoM = CoM,
max_dens_point = max_dens_point,
group_total_mass = group_total_mass, max_radius = max_radius,
bulk_vel = bulk_vel, rms_vel = rms_vel, fnames = fnames,
- mag_A = mag_A, mag_B = mag_B, mag_C = mag_C, e1_vec = e1_vec,
+ mag_A = mag_A, mag_B = mag_B, mag_C = mag_C, e0_vec = e0_vec,
tilt = tilt))
else:
mylog.error("I am unable to parse this line. Too many or too few items. %s" % orig)
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/construction_data_containers.py
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -225,6 +225,9 @@
self.weight_field = weight_field
self._set_center(center)
if data_source is None: data_source = self.ds.all_data()
+ for k, v in data_source.field_parameters.items():
+ if k not in self.field_parameters or self._is_default_field_parameter(k):
+ self.set_field_parameter(k, v)
self.data_source = data_source
self.weight_field = weight_field
self.get_data(field)
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/data_containers.py
--- a/yt/data_objects/data_containers.py
+++ b/yt/data_objects/data_containers.py
@@ -107,10 +107,19 @@
self.ds.objects.append(weakref.proxy(self))
mylog.debug("Appending object to %s (type: %s)", self.ds, type(self))
self.field_data = YTFieldData()
- if field_parameters is None: field_parameters = {}
+ # Built-in defaults. Keep the units and the +z normal that the previously
+ # hard-coded defaults carried; _is_default_field_parameter compares
+ # against these exact objects by identity.
+ self._default_field_parameters = {
+ 'center': self.ds.arr(np.zeros(3, dtype='float64'), 'cm'),
+ 'bulk_velocity': self.ds.arr(np.zeros(3, dtype='float64'), 'cm/s'),
+ 'normal': np.array([0, 0, 1], dtype='float64'),
+ }
+ if field_parameters is None: field_parameters = {}
+ # Install defaults into a fresh dict first, then apply the caller's values
+ # on top, so user-supplied parameters are never overwritten by defaults
+ # and the caller's mapping is never aliased or mutated.
self._set_default_field_parameters()
for key, val in field_parameters.items():
mylog.debug("Setting %s to %s", key, val)
self.set_field_parameter(key, val)
@property
@@ -125,13 +134,13 @@
return self._index
def _set_default_field_parameters(self):
+ # Reset to a new dict (never the caller's) and populate the defaults.
self.field_parameters = {}
- self.set_field_parameter(
- "center",self.ds.arr(np.zeros(3,dtype='float64'),'cm'))
- self.set_field_parameter(
- "bulk_velocity",self.ds.arr(np.zeros(3,dtype='float64'),'cm/s'))
- self.set_field_parameter(
- "normal",np.array([0,0,1],dtype='float64'))
+ for k,v in self._default_field_parameters.items():
+ self.set_field_parameter(k,v)
+
+ def _is_default_field_parameter(self, parameter):
+ """Return True if *parameter* still holds its built-in default object.
+
+ The comparison is by identity (``is``), so a user-supplied value that
+ merely equals the default is not treated as the default. Returns False
+ for parameters that have no built-in default.
+ """
+ if parameter not in self._default_field_parameters:
+ return False
+ return self._default_field_parameters[parameter] is \
+ self.field_parameters.get(parameter)
def apply_units(self, arr, units):
return self.ds.arr(arr, input_units = units)
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/tests/test_boolean_regions.py
--- a/yt/data_objects/tests/test_boolean_regions.py
+++ b/yt/data_objects/tests/test_boolean_regions.py
@@ -256,10 +256,8 @@
for n in [1, 2, 4, 8]:
ds = fake_random_ds(64, nprocs=n)
ds.index
- ell1 = ds.ellipsoid([0.25]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
- np.array([0.1]*3))
- ell2 = ds.ellipsoid([0.75]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
- np.array([0.1]*3))
+ ell1 = ds.ellipsoid([0.25]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
+ ell2 = ds.ellipsoid([0.75]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
# Store the original indices
i1 = ell1['ID']
i1.sort()
@@ -298,10 +296,8 @@
for n in [1, 2, 4, 8]:
ds = fake_random_ds(64, nprocs=n)
ds.index
- ell1 = ds.ellipsoid([0.45]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
- np.array([0.1]*3))
- ell2 = ds.ellipsoid([0.55]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
- np.array([0.1]*3))
+ ell1 = ds.ellipsoid([0.45]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
+ ell2 = ds.ellipsoid([0.55]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
# Get indices of both.
i1 = ell1['ID']
i2 = ell2['ID']
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/tests/test_projection.py
--- a/yt/data_objects/tests/test_projection.py
+++ b/yt/data_objects/tests/test_projection.py
@@ -35,6 +35,12 @@
rho_tot = dd.quantities["TotalQuantity"]("density")
coords = np.mgrid[xi:xf:xn*1j, yi:yf:yn*1j, zi:zf:zn*1j]
uc = [np.unique(c) for c in coords]
+ # test if projections inherit the field parameters of their data sources
+ dd.set_field_parameter("bulk_velocity", np.array([0,1,2]))
+ proj = ds.proj(0, "density", data_source=dd)
+ yield assert_equal, dd.field_parameters["bulk_velocity"], \
+ proj.field_parameters["bulk_velocity"]
+
# Some simple projection tests with single grids
for ax, an in enumerate("xyz"):
xax = ds.coordinates.x_axis[ax]
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/plot_container.py
--- a/yt/visualization/plot_container.py
+++ b/yt/visualization/plot_container.py
@@ -102,18 +102,22 @@
log_transform = FieldTransform('log10', np.log10, LogLocator())
linear_transform = FieldTransform('linear', lambda x: x, LinearLocator())
-class PlotDictionary(dict):
+class PlotDictionary(defaultdict):
+ """Field-keyed dictionary whose keys are normalized on every access.
+
+ Each key passed to lookup, assignment, or membership tests is replaced by
+ ``data_source._determine_fields(key)[0]`` before reaching the underlying
+ defaultdict.
+ """
def __getitem__(self, item):
- item = self.data_source._determine_fields(item)[0]
- return dict.__getitem__(self, item)
+ return defaultdict.__getitem__(
+ self, self.data_source._determine_fields(item)[0])
+
+ def __setitem__(self, item, value):
+ defaultdict.__setitem__(
+ self, self.data_source._determine_fields(item)[0], value)
def __contains__(self, item):
- item = self.data_source._determine_fields(item)[0]
- return dict.__contains__(self, item)
+ return defaultdict.__contains__(
+ self, self.data_source._determine_fields(item)[0])
- def __init__(self, data_source, *args):
+ def __init__(self, data_source, default_factory=None):
+ # default_factory behaves as in collections.defaultdict: with None, a
+ # missing key raises KeyError exactly like a plain dict.
self.data_source = data_source
- return dict.__init__(self, args)
+ defaultdict.__init__(self, default_factory)
class ImagePlotContainer(object):
"""A countainer for plots with colorbars.
@@ -136,6 +140,10 @@
font_path = matplotlib.get_data_path() + '/fonts/ttf/STIXGeneral.ttf'
self._font_properties = FontProperties(size=fontsize, fname=font_path)
self._font_color = None
+ self._xlabel = None
+ self._ylabel = None
+ self._colorbar_label = PlotDictionary(
+ self.data_source, lambda: None)
@invalidate_plot
def set_log(self, field, log):
@@ -184,7 +192,7 @@
@invalidate_plot
def set_transform(self, field, name):
field = self.data_source._determine_fields(field)[0]
- if name not in field_transforms:
+ if name not in field_transforms:
raise KeyError(name)
self._field_transform[field] = field_transforms[name]
return self
@@ -529,3 +537,59 @@
img = base64.b64encode(self.plots[field]._repr_png_())
ret += '<img src="data:image/png;base64,%s"><br>' % img
return ret
+
+ @invalidate_plot
+ def set_xlabel(self, label):
+ r"""
+ Allow the user to modify the X-axis title.
+ While no label has been set (None), the automatically generated
+ axis label is used.
+
+ Parameters
+ ----------
+ label: str
+ The new string for the x-axis.
+
+ >>> plot.set_xlabel("H2I Number Density (cm$^{-3}$)")
+
+ """
+ self._xlabel = label
+ return self
+
+ @invalidate_plot
+ def set_ylabel(self, label):
+ r"""
+ Allow the user to modify the Y-axis title.
+ While no label has been set (None), the automatically generated
+ axis label is used.
+
+ Parameters
+ ----------
+ label: str
+ The new string for the y-axis.
+
+ >>> plot.set_ylabel("Temperature (K)")
+
+ """
+ self._ylabel = label
+ return self
+
+ @invalidate_plot
+ def set_colorbar_label(self, field, label):
+ r"""
+ Sets the colorbar label for one field.
+
+ Parameters
+ ----------
+ field: str or tuple
+ The name of the field to modify the label for.
+ label: str
+ The new label. It is used as given, replacing the automatically
+ generated label (including any units appended to it).
+
+ >>> plot.set_colorbar_label("density", "Gas Density (g cm$^{-3}$)")
+
+ """
+ self._colorbar_label[field] = label
+ return self
+
+ def _get_axes_labels(self, field):
+ # Returns (x label, y label, colorbar label for *field*). Any entry the
+ # user never set is None, which callers treat as "use the automatic label".
+ return(self._xlabel, self._ylabel, self._colorbar_label[field])
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -20,6 +20,7 @@
from distutils.version import LooseVersion
from matplotlib.patches import Circle
+from matplotlib.colors import colorConverter
from yt.funcs import *
from yt.extern.six import add_metaclass
@@ -369,20 +370,26 @@
class GridBoundaryCallback(PlotCallback):
"""
annotate_grids(alpha=0.7, min_pix=1, min_pix_ids=20, draw_ids=False, periodic=True,
- min_level=None, max_level=None, cmap='B-W LINEAR_r'):
+ min_level=None, max_level=None, cmap='B-W LINEAR_r', edgecolors=None,
+ linewidth=1.0):
Draws grids on an existing PlotWindow object.
Adds grid boundaries to a plot, optionally with alpha-blending. By default,
colors different levels of grids with different colors going from white to
- black, but you can change to any arbitrary colormap with cmap keyword
- (or all black cells for all levels with cmap=None). Cuttoff for display is at
- min_pix wide. draw_ids puts the grid id in the corner of the grid.
+ black, but you can change to any arbitrary colormap with cmap keyword, to all black
+ grid edges for all levels with cmap=None and edgecolors=None, or to an arbitrary single
+ color for grid edges with edgecolors='YourChosenColor' defined in any of the standard ways
(e.g., edgecolors='white', edgecolors='r', edgecolors='#00FFFF', or edgecolors='0.3', where
+ the last is a float in 0-1 scale indicating gray).
+ Note that setting edgecolors overrides cmap if you have both set to non-None values.
+ Cutoff for display is at min_pix wide. draw_ids puts the grid id in the corner of the grid.
(Not so great in projections...). One can set min and maximum level of
- grids to display.
+ grids to display, and can change the linewidth of the displayed grids.
"""
_type_name = "grids"
def __init__(self, alpha=0.7, min_pix=1, min_pix_ids=20, draw_ids=False, periodic=True,
- min_level=None, max_level=None, cmap='B-W LINEAR_r'):
+ min_level=None, max_level=None, cmap='B-W LINEAR_r', edgecolors=None,
+ linewidth=1.0):
PlotCallback.__init__(self)
self.alpha = alpha
self.min_pix = min_pix
@@ -391,7 +398,9 @@
self.periodic = periodic
self.min_level = min_level
self.max_level = max_level
+ self.linewidth = linewidth
self.cmap = cmap
+ self.edgecolors = edgecolors
def __call__(self, plot):
x0, x1 = plot.xlim
@@ -433,13 +442,18 @@
( levels >= min_level) & \
( levels <= max_level)
- if self.cmap is not None:
- edgecolors = apply_colormap(levels[(levels <= max_level) & (levels >= min_level)]*1.0,
- color_bounds=[0,plot.data.ds.index.max_level],
- cmap_name=self.cmap)[0,:,:]*1.0/255.
- edgecolors[:,3] = self.alpha
- else:
- edgecolors = (0.0,0.0,0.0,self.alpha)
+ # Grids can either be set by edgecolors OR a colormap.
+ if self.edgecolors is not None:
+ # A single user-chosen color (any matplotlib color spec) for all grids.
+ edgecolors = colorConverter.to_rgba(self.edgecolors, alpha=self.alpha)
+ else: # use colormap if not explicitly overridden by edgecolors
+ if self.cmap is not None:
+ sample_levels = levels[(levels <= max_level) & (levels >= min_level)]
+ # Same dataset/index accessor the replaced code used; the old
+ # pf.h spelling is not used elsewhere on this branch.
+ color_bounds = [0,plot.data.ds.index.max_level]
+ edgecolors = apply_colormap(sample_levels*1.0, color_bounds=color_bounds,
+ cmap_name=self.cmap)[0,:,:]*1.0/255.
+ edgecolors[:,3] = self.alpha
+ else:
+ edgecolors = (0.0,0.0,0.0,self.alpha)
if visible.nonzero()[0].size == 0: continue
verts = np.array(
@@ -447,8 +461,7 @@
(left_edge_y, right_edge_y, right_edge_y, left_edge_y)])
verts=verts.transpose()[visible,:,:]
grid_collection = matplotlib.collections.PolyCollection(
- verts, facecolors="none",
- edgecolors=edgecolors)
+ verts, facecolors="none", edgecolors=edgecolors, linewidth=self.linewidth)
plot._axes.hold(True)
plot._axes.add_collection(grid_collection)
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/plot_window.py
--- a/yt/visualization/plot_window.py
+++ b/yt/visualization/plot_window.py
@@ -885,25 +885,33 @@
yax = self.ds.coordinates.y_axis[axis_index]
if hasattr(self.ds.coordinates, "axis_default_unit_label"):
- axes_unit_labels = [self.ds.coordinates.axis_default_unit_name[xax],
- self.ds.coordinates.axis_default_unit_name[yax]]
+ axes_unit_labels = \
+ [self.ds.coordinates.axis_default_unit_name[xax],
+ self.ds.coordinates.axis_default_unit_name[yax]]
labels = [r'$\rm{'+axis_names[xax]+axes_unit_labels[0] + r'}$',
r'$\rm{'+axis_names[yax]+axes_unit_labels[1] + r'}$']
if hasattr(self.ds.coordinates, "axis_field"):
if xax in self.ds.coordinates.axis_field:
- xmin, xmax = self.ds.coordinates.axis_field[xax](0,
- self.xlim, self.ylim)
+ xmin, xmax = self.ds.coordinates.axis_field[xax](
+ 0, self.xlim, self.ylim)
else:
xmin, xmax = [float(x) for x in extentx]
if yax in self.ds.coordinates.axis_field:
- ymin, ymax = self.ds.coordinates.axis_field[yax](1,
- self.xlim, self.ylim)
+ ymin, ymax = self.ds.coordinates.axis_field[yax](
+ 1, self.xlim, self.ylim)
else:
ymin, ymax = [float(y) for y in extenty]
self.plots[f].image.set_extent((xmin,xmax,ymin,ymax))
self.plots[f].axes.set_aspect("auto")
+ x_label, y_label, colorbar_label = self._get_axes_labels(f)
+
+ if x_label is not None:
+ labels[0] = x_label
+ if y_label is not None:
+ labels[1] = y_label
+
self.plots[f].axes.set_xlabel(labels[0],fontproperties=fp)
self.plots[f].axes.set_ylabel(labels[1],fontproperties=fp)
@@ -913,21 +921,18 @@
self.plots[f].axes.yaxis.get_offset_text()]):
label.set_fontproperties(fp)
- colorbar_label = image.info['label']
-
- # If we're creating a plot of a projection, change the displayed
- # field name accordingly.
- if hasattr(self, 'projected'):
- colorbar_label = "$\\rm{Projected }$ %s" % colorbar_label
-
# Determine the units of the data
units = Unit(self.frb[f].units, registry=self.ds.unit_registry)
units = units.latex_representation()
- if units is None or units == '':
- pass
- else:
- colorbar_label += r'$\/\/('+units+r')$'
+ if colorbar_label is None:
+ colorbar_label = image.info['label']
+ if hasattr(self, 'projected'):
+ colorbar_label = "$\\rm{Projected }$ %s" % colorbar_label
+ if units is None or units == '':
+ pass
+ else:
+ colorbar_label += r'$\/\/('+units+r')$'
parser = MathTextParser('Agg')
try:
diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/profile_plotter.py
--- a/yt/visualization/profile_plotter.py
+++ b/yt/visualization/profile_plotter.py
@@ -679,13 +679,11 @@
"""
x_log = None
y_log = None
- x_title = None
- y_title = None
- z_title = None
plot_title = None
_plot_valid = False
_plot_type = 'Phase'
+
def __init__(self, data_source, x_field, y_field, z_fields,
weight_field="cell_mass", x_bins=128, y_bins=128,
accumulation=False, fractional=False,
@@ -700,16 +698,22 @@
accumulation=accumulation,
fractional=fractional)
- type(self)._initialize_instance(self, data_source, profile, fontsize, figure_size)
+ type(self)._initialize_instance(self, data_source, profile, fontsize,
+ figure_size)
@classmethod
- def _initialize_instance(cls, obj, data_source, profile, fontsize, figure_size):
+ def _initialize_instance(cls, obj, data_source, profile, fontsize,
+ figure_size):
obj.plot_title = {}
obj.z_log = {}
obj.z_title = {}
obj._initfinished = False
obj.x_log = None
obj.y_log = None
+ obj._plot_text = {}
+ obj._text_xpos = {}
+ obj._text_ypos = {}
+ obj._text_kwargs = {}
obj.profile = profile
super(PhasePlot, obj).__init__(data_source, figure_size, fontsize)
obj._setup_plots()
@@ -729,10 +733,11 @@
y_unit = profile.y.units
z_unit = profile.field_units[field_z]
fractional = profile.fractional
- x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit)
- y_title = self.y_title or self._get_field_label(field_y, yfi, y_unit)
- z_title = self.z_title.get(field_z, None) or \
- self._get_field_label(field_z, zfi, z_unit, fractional)
+ x_label, y_label, z_label = self._get_axes_labels(field_z)
+ x_title = x_label or self._get_field_label(field_x, xfi, x_unit)
+ y_title = y_label or self._get_field_label(field_y, yfi, y_unit)
+ z_title = z_label or self._get_field_label(field_z, zfi, z_unit,
+ fractional)
return (x_title, y_title, z_title)
def _get_field_label(self, field, field_info, field_unit, fractional=False):
@@ -827,6 +832,12 @@
self.plots[f].axes.yaxis.set_label_text(y_title, fontproperties=fp)
self.plots[f].cax.yaxis.set_label_text(z_title, fontproperties=fp)
+ if f in self._plot_text:
+ self.plots[f].axes.text(self._text_xpos[f], self._text_ypos[f],
+ self._plot_text[f],
+ fontproperties=self._font_properties,
+ **self._text_kwargs[f])
+
if f in self.plot_title:
self.plots[f].axes.set_title(self.plot_title[f])
@@ -877,6 +888,41 @@
return cls._initialize_instance(obj, data_source, profile, fontsize,
figure_size)
+
+ def annotate_text(self, xpos=0.0, ypos=0.0, text=None, **text_kwargs):
+ r"""
+ Allow the user to insert text onto the plot.
+ Add *text* to every field's plot at location *xpos*, *ypos* in plot
+ coordinates. The annotation is remembered per field and drawn again
+ whenever the plots are regenerated.
+
+ Parameters
+ ----------
+ xpos: float
+ Position on plot in x-coordinates.
+ ypos: float
+ Position on plot in y-coordinates.
+ text: str
+ The text to insert onto the plot. If None, nothing is drawn or
+ stored.
+ text_kwargs: dict
+ Dictionary of text keyword arguments to be passed to matplotlib
+
+ >>> plot.annotate_text(1e-15, 5e4, "Hello YT")
+
+ """
+ # Without text there is nothing to draw; storing None would make the
+ # plot-setup code pass None to axes.text on every redraw.
+ if text is None:
+ return self
+ for f in self.data_source._determine_fields(self.plots.keys()):
+ if self.plots[f].figure is not None:
+ self.plots[f].axes.text(xpos, ypos, text,
+ fontproperties=self._font_properties,
+ **text_kwargs)
+ self._plot_text[f] = text
+ self._text_xpos[f] = xpos
+ self._text_ypos[f] = ypos
+ self._text_kwargs[f] = text_kwargs
+ return self
+
def save(self, name=None, mpl_kwargs=None):
r"""
Saves a 2d profile plot.
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list