[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt/yt-3.0 (pull request #1060)

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Wed Jul 23 15:04:40 PDT 2014


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/e97899b53081/
Changeset:   e97899b53081
Branch:      yt-3.0
User:        MatthewTurk
Date:        2014-07-24 00:04:30
Summary:     Merged in ngoldbaum/yt/yt-3.0 (pull request #1060)

Merging from the yt branch.
Affected #:  74 files

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/analysis_modules/halo_finding/halo_objects.py
--- a/yt/analysis_modules/halo_finding/halo_objects.py
+++ b/yt/analysis_modules/halo_finding/halo_objects.py
@@ -549,22 +549,23 @@
             temp_e2[:,dim] = e2_vector[dim]
         length = np.abs(np.sum(rr * temp_e2, axis = 1) * (1 - \
             np.sum(rr * temp_e0, axis = 1)**2. * mag_A**-2. - \
-            np.sum(rr * temp_e1, axis = 1)**2. * mag_B**-2)**(-0.5))
+            np.sum(rr * temp_e1, axis = 1)**2. * mag_B**-2.)**(-0.5))
         length[length == np.inf] = 0.
         tC_index = np.nanargmax(length)
         mag_C = length[tC_index]
         # tilt is calculated from the rotation about x axis
         # needed to align e1 vector with the y axis
         # after e0 is aligned with x axis
-        # find the t1 angle needed to rotate about z axis to align e0 to x
-        t1 = np.arctan(e0_vector[1] / e0_vector[0])
-        RZ = get_rotation_matrix(-t1, (0, 0, 1)).transpose()
-        r1 = (e0_vector * RZ).sum(axis = 1)
+        # find the t1 angle needed to rotate about z axis to align e0 onto x-z plane
+        t1 = np.arctan(-e0_vector[1] / e0_vector[0])
+        RZ = get_rotation_matrix(t1, (0, 0, 1))
+        r1 = np.dot(RZ, e0_vector)
         # find the t2 angle needed to rotate about y axis to align e0 to x
-        t2 = np.arctan(-r1[2] / r1[0])
-        RY = get_rotation_matrix(-t2, (0, 1, 0)).transpose()
+        t2 = np.arctan(r1[2] / r1[0])
+        RY = get_rotation_matrix(t2, (0, 1, 0))
         r2 = np.dot(RY, np.dot(RZ, e1_vector))
-        tilt = np.arctan(r2[2]/r2[1])
+        # find the tilt angle needed to rotate about x axis to align e1 to y and e2 to z
+        tilt = np.arctan(-r2[2] / r2[1])
         return (mag_A, mag_B, mag_C, e0_vector[0], e0_vector[1],
             e0_vector[2], tilt)
 
@@ -782,13 +783,13 @@
         
         Returns
         -------
-        tuple : (cm, mag_A, mag_B, mag_C, e1_vector, tilt)
+        tuple : (cm, mag_A, mag_B, mag_C, e0_vector, tilt)
             The 6-tuple has in order:
               #. The center of mass as an array.
               #. mag_A as a float.
               #. mag_B as a float.
               #. mag_C as a float.
-              #. e1_vector as an array.
+              #. e0_vector as an array.
               #. tilt as a float.
         
         Examples
@@ -819,7 +820,7 @@
     def __init__(self, ds, id, size=None, CoM=None,
         max_dens_point=None, group_total_mass=None, max_radius=None, bulk_vel=None,
         rms_vel=None, fnames=None, mag_A=None, mag_B=None, mag_C=None,
-        e1_vec=None, tilt=None, supp=None):
+        e0_vec=None, tilt=None, supp=None):
 
         self.ds = ds
         self.gridsize = (self.ds.domain_right_edge - \
@@ -835,7 +836,7 @@
         self.mag_A = mag_A
         self.mag_B = mag_B
         self.mag_C = mag_C
-        self.e1_vec = e1_vec
+        self.e0_vec = e0_vec
         self.tilt = tilt
         # locs=the names of the h5 files that have particle data for this halo
         self.fnames = fnames
@@ -928,8 +929,8 @@
 
     def _get_ellipsoid_parameters_basic_loadedhalo(self):
         if self.mag_A is not None:
-            return (self.mag_A, self.mag_B, self.mag_C, self.e1_vec[0],
-                self.e1_vec[1], self.e1_vec[2], self.tilt)
+            return (self.mag_A, self.mag_B, self.mag_C, self.e0_vec[0],
+                self.e0_vec[1], self.e0_vec[2], self.tilt)
         else:
             return self._get_ellipsoid_parameters_basic()
 
@@ -943,13 +944,13 @@
 
         Returns
         -------
-        tuple : (cm, mag_A, mag_B, mag_C, e1_vector, tilt)
+        tuple : (cm, mag_A, mag_B, mag_C, e0_vector, tilt)
             The 6-tuple has in order:
               #. The center of mass as an array.
               #. mag_A as a float.
               #. mag_B as a float.
               #. mag_C as a float.
-              #. e1_vector as an array.
+              #. e0_vector as an array.
               #. tilt as a float.
 
         Examples
@@ -1021,7 +1022,7 @@
 
         max_dens_point=None, group_total_mass=None, max_radius=None, bulk_vel=None,
         rms_vel=None, fnames=None, mag_A=None, mag_B=None, mag_C=None,
-        e1_vec=None, tilt=None, supp=None):
+        e0_vec=None, tilt=None, supp=None):
 
         self.ds = ds
         self.gridsize = (self.ds.domain_right_edge - \
@@ -1037,7 +1038,7 @@
         self.mag_A = mag_A
         self.mag_B = mag_B
         self.mag_C = mag_C
-        self.e1_vec = e1_vec
+        self.e0_vec = e0_vec
         self.tilt = tilt
         self.bin_count = None
         self.overdensity = None
@@ -1181,8 +1182,8 @@
                                "x","y","z", "center-of-mass",
                                "x","y","z",
                                "vx","vy","vz","max_r","rms_v",
-                               "mag_A", "mag_B", "mag_C", "e1_vec0",
-                               "e1_vec1", "e1_vec2", "tilt", "\n"]))
+                               "mag_A", "mag_B", "mag_C", "e0_vec0",
+                               "e0_vec1", "e0_vec2", "tilt", "\n"]))
 
         for group in self:
             f.write("%10i\t" % group.id)
@@ -1494,17 +1495,17 @@
                 mag_A = float(line[15])
                 mag_B = float(line[16])
                 mag_C = float(line[17])
-                e1_vec0 = float(line[18])
-                e1_vec1 = float(line[19])
-                e1_vec2 = float(line[20])
-                e1_vec = np.array([e1_vec0, e1_vec1, e1_vec2])
+                e0_vec0 = float(line[18])
+                e0_vec1 = float(line[19])
+                e0_vec2 = float(line[20])
+                e0_vec = np.array([e0_vec0, e0_vec1, e0_vec2])
                 tilt = float(line[21])
                 self._groups.append(LoadedHalo(self.ds, halo, size = size,
                     CoM = CoM,
                     max_dens_point = max_dens_point,
                     group_total_mass = group_total_mass, max_radius = max_radius,
                     bulk_vel = bulk_vel, rms_vel = rms_vel, fnames = fnames,
-                    mag_A = mag_A, mag_B = mag_B, mag_C = mag_C, e1_vec = e1_vec,
+                    mag_A = mag_A, mag_B = mag_B, mag_C = mag_C, e0_vec = e0_vec,
                     tilt = tilt))
             else:
                 mylog.error("I am unable to parse this line. Too many or too few items. %s" % orig)

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/construction_data_containers.py
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -225,6 +225,9 @@
         self.weight_field = weight_field
         self._set_center(center)
         if data_source is None: data_source = self.ds.all_data()
+        for k, v in data_source.field_parameters.items():
+            if k not in self.field_parameters or self._is_default_field_parameter(k):
+                self.set_field_parameter(k, v)
         self.data_source = data_source
         self.weight_field = weight_field
         self.get_data(field)

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/data_containers.py
--- a/yt/data_objects/data_containers.py
+++ b/yt/data_objects/data_containers.py
@@ -107,10 +107,19 @@
         self.ds.objects.append(weakref.proxy(self))
         mylog.debug("Appending object to %s (type: %s)", self.ds, type(self))
         self.field_data = YTFieldData()
-        if field_parameters is None: field_parameters = {}
+        self._default_field_parameters = {
+            'center': np.zeros(3, dtype='float64'),
+            'bulk_velocity': np.zeros(3, dtype='float64'),
+            'normal': np.zeros(3, dtype='float64'),
+        }
+        if field_parameters is None:
+            self.field_parameters = {}
+        else:
+            self.field_parameters = field_parameters
         self._set_default_field_parameters()
-        for key, val in field_parameters.items():
-            mylog.debug("Setting %s to %s", key, val)
+        for key, val in self.field_parameters.items():
+            if not self._is_default_field_parameter(key):
+                mylog.debug("Setting %s to %s", key, val)
             self.set_field_parameter(key, val)
 
     @property
@@ -125,13 +134,13 @@
         return self._index
 
     def _set_default_field_parameters(self):
-        self.field_parameters = {}
-        self.set_field_parameter(
-            "center",self.ds.arr(np.zeros(3,dtype='float64'),'cm'))
-        self.set_field_parameter(
-            "bulk_velocity",self.ds.arr(np.zeros(3,dtype='float64'),'cm/s'))
-        self.set_field_parameter(
-            "normal",np.array([0,0,1],dtype='float64'))
+        for k,v in self._default_field_parameters.items():
+            self.set_field_parameter(k,v)
+
+    def _is_default_field_parameter(self, parameter):
+        if parameter not in self._default_field_parameters:
+            return False
+        return self._default_field_parameters[parameter] is self.field_parameters[parameter]
 
     def apply_units(self, arr, units):
         return self.ds.arr(arr, input_units = units)

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/tests/test_boolean_regions.py
--- a/yt/data_objects/tests/test_boolean_regions.py
+++ b/yt/data_objects/tests/test_boolean_regions.py
@@ -256,10 +256,8 @@
     for n in [1, 2, 4, 8]:
         ds = fake_random_ds(64, nprocs=n)
         ds.index
-        ell1 = ds.ellipsoid([0.25]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
-            np.array([0.1]*3))
-        ell2 = ds.ellipsoid([0.75]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
-            np.array([0.1]*3))
+        ell1 = ds.ellipsoid([0.25]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
+        ell2 = ds.ellipsoid([0.75]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
         # Store the original indices
         i1 = ell1['ID']
         i1.sort()
@@ -298,10 +296,8 @@
     for n in [1, 2, 4, 8]:
         ds = fake_random_ds(64, nprocs=n)
         ds.index
-        ell1 = ds.ellipsoid([0.45]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
-            np.array([0.1]*3))
-        ell2 = ds.ellipsoid([0.55]*3, 0.05, 0.05, 0.05, np.array([0.1]*3),
-            np.array([0.1]*3))
+        ell1 = ds.ellipsoid([0.45]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
+        ell2 = ds.ellipsoid([0.55]*3, 0.05, 0.05, 0.05, np.array([0.1]*3), 0.1)
         # Get indices of both.
         i1 = ell1['ID']
         i2 = ell2['ID']

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/data_objects/tests/test_projection.py
--- a/yt/data_objects/tests/test_projection.py
+++ b/yt/data_objects/tests/test_projection.py
@@ -35,6 +35,12 @@
         rho_tot = dd.quantities["TotalQuantity"]("density")
         coords = np.mgrid[xi:xf:xn*1j, yi:yf:yn*1j, zi:zf:zn*1j]
         uc = [np.unique(c) for c in coords]
+        # test if projections inherit the field parameters of their data sources
+        dd.set_field_parameter("bulk_velocity", np.array([0,1,2]))
+        proj = ds.proj(0, "density", data_source=dd)
+        yield assert_equal, dd.field_parameters["bulk_velocity"], \
+          proj.field_parameters["bulk_velocity"]
+
         # Some simple projection tests with single grids
         for ax, an in enumerate("xyz"):
             xax = ds.coordinates.x_axis[ax]

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/plot_container.py
--- a/yt/visualization/plot_container.py
+++ b/yt/visualization/plot_container.py
@@ -102,18 +102,22 @@
 log_transform = FieldTransform('log10', np.log10, LogLocator())
 linear_transform = FieldTransform('linear', lambda x: x, LinearLocator())
 
-class PlotDictionary(dict):
+class PlotDictionary(defaultdict):
     def __getitem__(self, item):
-        item = self.data_source._determine_fields(item)[0]
-        return dict.__getitem__(self, item)
+        return defaultdict.__getitem__(
+            self, self.data_source._determine_fields(item)[0])
+
+    def __setitem__(self, item, value):
+        return defaultdict.__setitem__(
+            self, self.data_source._determine_fields(item)[0], value)
 
     def __contains__(self, item):
-        item = self.data_source._determine_fields(item)[0]
-        return dict.__contains__(self, item)
+        return defaultdict.__contains__(
+            self, self.data_source._determine_fields(item)[0])
 
-    def __init__(self, data_source, *args):
+    def __init__(self, data_source, default_factory=None):
         self.data_source = data_source
-        return dict.__init__(self, args)
+        return defaultdict.__init__(self, default_factory)
 
 class ImagePlotContainer(object):
     """A countainer for plots with colorbars.
@@ -136,6 +140,10 @@
         font_path = matplotlib.get_data_path() + '/fonts/ttf/STIXGeneral.ttf'
         self._font_properties = FontProperties(size=fontsize, fname=font_path)
         self._font_color = None
+        self._xlabel = None
+        self._ylabel = None
+        self._colorbar_label = PlotDictionary(
+            self.data_source, lambda: None)
 
     @invalidate_plot
     def set_log(self, field, log):
@@ -184,7 +192,7 @@
     @invalidate_plot
     def set_transform(self, field, name):
         field = self.data_source._determine_fields(field)[0]
-        if name not in field_transforms: 
+        if name not in field_transforms:
             raise KeyError(name)
         self._field_transform[field] = field_transforms[name]
         return self
@@ -529,3 +537,59 @@
             img = base64.b64encode(self.plots[field]._repr_png_())
             ret += '<img src="data:image/png;base64,%s"><br>' % img
         return ret
+
+    @invalidate_plot
+    def set_xlabel(self, label):
+        r"""
+        Allow the user to modify the X-axis title
+        Defaults to the global value. Fontsize defaults
+        to 18.
+
+        Parameters
+        ----------
+        label: str
+          The new string for the x-axis.
+
+        >>>  plot.set_xlabel("H2I Number Density (cm$^{-3}$)")
+
+        """
+        self._xlabel = label
+        return self
+
+    @invalidate_plot
+    def set_ylabel(self, label):
+        r"""
+        Allow the user to modify the Y-axis title
+        Defaults to the global value.
+
+        Parameters
+        ----------
+        label: str
+          The new string for the y-axis.
+
+        >>>  plot.set_ylabel("Temperature (K)")
+
+        """
+        self._ylabel = label
+        return self
+
+    @invalidate_plot
+    def set_colorbar_label(self, field, label):
+        r"""
+        Sets the colorbar label.
+
+        Parameters
+        ----------
+        field: str or tuple
+          The name of the field to modify the label for.
+        label: str
+          The new label
+
+        >>>  plot.set_colorbar_label("cell_mass", "Enclosed Gas Mass ($M_{\odot}$)")
+
+        """
+        self._colorbar_label[field] = label
+        return self
+
+    def _get_axes_labels(self, field):
+        return(self._xlabel, self._ylabel, self._colorbar_label[field])

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -20,6 +20,7 @@
 from distutils.version import LooseVersion
 
 from matplotlib.patches import Circle
+from matplotlib.colors import colorConverter
 
 from yt.funcs import *
 from yt.extern.six import add_metaclass
@@ -369,20 +370,26 @@
 class GridBoundaryCallback(PlotCallback):
     """
     annotate_grids(alpha=0.7, min_pix=1, min_pix_ids=20, draw_ids=False, periodic=True, 
-                 min_level=None, max_level=None, cmap='B-W LINEAR_r'):
+                 min_level=None, max_level=None, cmap='B-W LINEAR_r', edgecolors=None,
+                 linewidth=1.0):
 
     Draws grids on an existing PlotWindow object.
     Adds grid boundaries to a plot, optionally with alpha-blending. By default, 
     colors different levels of grids with different colors going from white to
-    black, but you can change to any arbitrary colormap with cmap keyword 
-    (or all black cells for all levels with cmap=None).  Cuttoff for display is at 
-    min_pix wide. draw_ids puts the grid id in the corner of the grid. 
+    black, but you can change to any arbitrary colormap with cmap keyword, to all black
+    grid edges for all levels with cmap=None and edgecolors=None, or to an arbitrary single
+    color for grid edges with edgecolors='YourChosenColor' defined in any of the standard ways
+    (e.g., edgecolors='white', edgecolors='r', edgecolors='#00FFFF', or edgecolors='0.3', where
+    the last is a float given as a string on a 0-1 scale, indicating a shade of gray).
+    Note that setting edgecolors overrides cmap if you have both set to non-None values.
+    Cutoff for display is at min_pix wide. draw_ids puts the grid id in the corner of the grid.
     (Not so great in projections...).  One can set min and maximum level of
-    grids to display.
+    grids to display, and can change the linewidth of the displayed grids.
     """
     _type_name = "grids"
     def __init__(self, alpha=0.7, min_pix=1, min_pix_ids=20, draw_ids=False, periodic=True, 
-                 min_level=None, max_level=None, cmap='B-W LINEAR_r'):
+                 min_level=None, max_level=None, cmap='B-W LINEAR_r', edgecolors=None,
+                 linewidth=1.0):
         PlotCallback.__init__(self)
         self.alpha = alpha
         self.min_pix = min_pix
@@ -391,7 +398,9 @@
         self.periodic = periodic
         self.min_level = min_level
         self.max_level = max_level
+        self.linewidth = linewidth
         self.cmap = cmap
+        self.edgecolors = edgecolors
 
     def __call__(self, plot):
         x0, x1 = plot.xlim
@@ -433,13 +442,18 @@
                        ( levels >= min_level) & \
                        ( levels <= max_level)
 
-            if self.cmap is not None: 
-                edgecolors = apply_colormap(levels[(levels <= max_level) & (levels >= min_level)]*1.0,
-                                  color_bounds=[0,plot.data.ds.index.max_level],
-                                  cmap_name=self.cmap)[0,:,:]*1.0/255.
-                edgecolors[:,3] = self.alpha
-            else:
-                edgecolors = (0.0,0.0,0.0,self.alpha)
+            # Grids can either be set by edgecolors OR a colormap.
+            if self.edgecolors is not None:
+                edgecolors = colorConverter.to_rgba(self.edgecolors, alpha=self.alpha)
+            else:  # use colormap if not explicity overridden by edgecolors
+                if self.cmap is not None:
+                    sample_levels = levels[(levels <= max_level) & (levels >= min_level)]
+                    color_bounds = [0,plot.data.pf.h.max_level]
+                    edgecolors = apply_colormap(sample_levels*1.0, color_bounds=color_bounds,
+                                                cmap_name=self.cmap)[0,:,:]*1.0/255.
+                    edgecolors[:,3] = self.alpha
+                else:
+                    edgecolors = (0.0,0.0,0.0,self.alpha)
 
             if visible.nonzero()[0].size == 0: continue
             verts = np.array(
@@ -447,8 +461,7 @@
                  (left_edge_y, right_edge_y, right_edge_y, left_edge_y)])
             verts=verts.transpose()[visible,:,:]
             grid_collection = matplotlib.collections.PolyCollection(
-                verts, facecolors="none",
-                edgecolors=edgecolors)
+                verts, facecolors="none", edgecolors=edgecolors, linewidth=self.linewidth)
             plot._axes.hold(True)
             plot._axes.add_collection(grid_collection)
 

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/plot_window.py
--- a/yt/visualization/plot_window.py
+++ b/yt/visualization/plot_window.py
@@ -885,25 +885,33 @@
                 yax = self.ds.coordinates.y_axis[axis_index]
 
                 if hasattr(self.ds.coordinates, "axis_default_unit_label"):
-                    axes_unit_labels = [self.ds.coordinates.axis_default_unit_name[xax],
-                                        self.ds.coordinates.axis_default_unit_name[yax]]
+                    axes_unit_labels = \
+                    [self.ds.coordinates.axis_default_unit_name[xax],
+                     self.ds.coordinates.axis_default_unit_name[yax]]
                 labels = [r'$\rm{'+axis_names[xax]+axes_unit_labels[0] + r'}$',
                           r'$\rm{'+axis_names[yax]+axes_unit_labels[1] + r'}$']
 
                 if hasattr(self.ds.coordinates, "axis_field"):
                     if xax in self.ds.coordinates.axis_field:
-                        xmin, xmax = self.ds.coordinates.axis_field[xax](0,
-                                                                         self.xlim, self.ylim)
+                        xmin, xmax = self.ds.coordinates.axis_field[xax](
+                            0, self.xlim, self.ylim)
                     else:
                         xmin, xmax = [float(x) for x in extentx]
                     if yax in self.ds.coordinates.axis_field:
-                        ymin, ymax = self.ds.coordinates.axis_field[yax](1,
-                                                                         self.xlim, self.ylim)
+                        ymin, ymax = self.ds.coordinates.axis_field[yax](
+                            1, self.xlim, self.ylim)
                     else:
                         ymin, ymax = [float(y) for y in extenty]
                     self.plots[f].image.set_extent((xmin,xmax,ymin,ymax))
                     self.plots[f].axes.set_aspect("auto")
 
+            x_label, y_label, colorbar_label = self._get_axes_labels(f)
+
+            if x_label is not None:
+                labels[0] = x_label
+            if y_label is not None:
+                labels[1] = y_label
+
             self.plots[f].axes.set_xlabel(labels[0],fontproperties=fp)
             self.plots[f].axes.set_ylabel(labels[1],fontproperties=fp)
 
@@ -913,21 +921,18 @@
                            self.plots[f].axes.yaxis.get_offset_text()]):
                 label.set_fontproperties(fp)
 
-            colorbar_label = image.info['label']
-
-            # If we're creating a plot of a projection, change the displayed
-            # field name accordingly.
-            if hasattr(self, 'projected'):
-                colorbar_label = "$\\rm{Projected }$ %s" % colorbar_label
-
             # Determine the units of the data
             units = Unit(self.frb[f].units, registry=self.ds.unit_registry)
             units = units.latex_representation()
 
-            if units is None or units == '':
-                pass
-            else:
-                colorbar_label += r'$\/\/('+units+r')$'
+            if colorbar_label is None:
+                colorbar_label = image.info['label']
+                if hasattr(self, 'projected'):
+                    colorbar_label = "$\\rm{Projected }$ %s" % colorbar_label
+                if units is None or units == '':
+                    pass
+                else:
+                    colorbar_label += r'$\/\/('+units+r')$'
 
             parser = MathTextParser('Agg')
             try:

diff -r 931e363adca844c1046fe6027481b538f1ba2576 -r e97899b530819bec400dbe2127aba867d227dc0b yt/visualization/profile_plotter.py
--- a/yt/visualization/profile_plotter.py
+++ b/yt/visualization/profile_plotter.py
@@ -679,13 +679,11 @@
     """
     x_log = None
     y_log = None
-    x_title = None
-    y_title = None
-    z_title = None
     plot_title = None
     _plot_valid = False
     _plot_type = 'Phase'
 
+
     def __init__(self, data_source, x_field, y_field, z_fields,
                  weight_field="cell_mass", x_bins=128, y_bins=128,
                  accumulation=False, fractional=False,
@@ -700,16 +698,22 @@
             accumulation=accumulation,
             fractional=fractional)
 
-        type(self)._initialize_instance(self, data_source, profile, fontsize, figure_size)
+        type(self)._initialize_instance(self, data_source, profile, fontsize,
+                                        figure_size)
 
     @classmethod
-    def _initialize_instance(cls, obj, data_source, profile, fontsize, figure_size):
+    def _initialize_instance(cls, obj, data_source, profile, fontsize,
+                             figure_size):
         obj.plot_title = {}
         obj.z_log = {}
         obj.z_title = {}
         obj._initfinished = False
         obj.x_log = None
         obj.y_log = None
+        obj._plot_text = {}
+        obj._text_xpos = {}
+        obj._text_ypos = {}
+        obj._text_kwargs = {}
         obj.profile = profile
         super(PhasePlot, obj).__init__(data_source, figure_size, fontsize)
         obj._setup_plots()
@@ -729,10 +733,11 @@
         y_unit = profile.y.units
         z_unit = profile.field_units[field_z]
         fractional = profile.fractional
-        x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit)
-        y_title = self.y_title or self._get_field_label(field_y, yfi, y_unit)
-        z_title = self.z_title.get(field_z, None) or \
-            self._get_field_label(field_z, zfi, z_unit, fractional)
+        x_label, y_label, z_label = self._get_axes_labels(field_z)
+        x_title = x_label or self._get_field_label(field_x, xfi, x_unit)
+        y_title = y_label or self._get_field_label(field_y, yfi, y_unit)
+        z_title = z_label or self._get_field_label(field_z, zfi, z_unit,
+                                                   fractional)
         return (x_title, y_title, z_title)
 
     def _get_field_label(self, field, field_info, field_unit, fractional=False):
@@ -827,6 +832,12 @@
             self.plots[f].axes.yaxis.set_label_text(y_title, fontproperties=fp)
             self.plots[f].cax.yaxis.set_label_text(z_title, fontproperties=fp)
 
+            if f in self._plot_text:
+                self.plots[f].axes.text(self._text_xpos[f], self._text_ypos[f],
+                                        self._plot_text[f],
+                                        fontproperties=self._font_properties,
+                                        **self._text_kwargs[f])
+
             if f in self.plot_title:
                 self.plots[f].axes.set_title(self.plot_title[f])
 
@@ -877,6 +888,41 @@
         return cls._initialize_instance(obj, data_source, profile, fontsize,
                                         figure_size)
 
+
+    def annotate_text(self, xpos=0.0, ypos=0.0, text=None, **text_kwargs):
+        r"""
+        Allow the user to insert text onto the plot.
+        The x-position and y-position must be given as well as the text string.
+        Add *text* to the plot at location *xpos*, *ypos* in the plot's data
+        coordinates (see example below).
+        The text is added to the plot of every field; there is no argument
+        for annotating a single field.
+
+        Parameters
+        ----------
+        xpos: float
+          Position on plot in x-coordinates.
+        ypos: float
+          Position on plot in y-coordinates.
+        text: str
+          The text to insert onto the plot.
+        text_kwargs: dict
+          Dictionary of text keyword arguments to be passed to matplotlib
+
+        >>>  plot.annotate_text(1e-15, 5e4, "Hello YT")
+
+        """
+        for f in self.data_source._determine_fields(self.plots.keys()):
+            if self.plots[f].figure is not None and text is not None:
+                self.plots[f].axes.text(xpos, ypos, text,
+                                        fontproperties=self._font_properties,
+                                        **text_kwargs)
+            self._plot_text[f] = text
+            self._text_xpos[f] = xpos
+            self._text_ypos[f] = ypos
+            self._text_kwargs[f] = text_kwargs
+        return self
+
     def save(self, name=None, mpl_kwargs=None):
         r"""
         Saves a 2d profile plot.

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.



More information about the yt-svn mailing list