[yt-svn] commit/yt: 8 new changesets
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Thu May 1 10:02:41 PDT 2014
8 new commits in yt:
https://bitbucket.org/yt_analysis/yt/commits/2e1e4b6d7641/
Changeset: 2e1e4b6d7641
Branch: yt-3.0
User: hegan
Date: 2014-04-23 20:55:42
Summary: annotating halos
Affected #: 1 file
diff -r e194bdae23c0d3ba246f4ea81d915cf8383e7fe3 -r 2e1e4b6d764131385c89d4d760915e2f017aeaaa yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -1,4 +1,5 @@
"""
+
Callbacks to add additional functionality on to plots.
@@ -950,6 +951,69 @@
kwargs["transform"] = plot._axes.transAxes
plot._axes.text(x, y, self.text, **kwargs)
+class HaloCatalogCallback(PlotCallback):
+
+ _type_name = 'halos'
+ region = None
+ _descriptor = None
+
+ def __init__(self, halo_catalog, col='white', alpha =1, width = None):
+ PlotCallback.__init__(self)
+ self.halo_catalog = halo_catalog
+ self.color = col
+ self.alpha = alpha
+ self.width = width
+
+ def __call__(self, plot):
+ data = plot.data
+ x0, x1 = plot.xlim
+ y0, y1 = plot.ylim
+ xx0, xx1 = plot._axes.get_xlim()
+ yy0, yy1 = plot._axes.get_ylim()
+
+ halo_data= self.halo_catalog.halos_pf.all_data()
+ field_x = "particle_position_%s" % axis_names[x_dict[data.axis]]
+ field_y = "particle_position_%s" % axis_names[y_dict[data.axis]]
+ field_z = "particle_position_%s" % axis_names[data.axis]
+ plot._axes.hold(True)
+
+ # Set up scales for pixel size and original data
+ units = 'Mpccm'
+ pixel_scale = self.pixel_scale(plot)[0]
+ data_scale = data.pf.length_unit
+
+ # Convert halo positions to code units of the plotted data
+ # and then to units of the plotted window
+ px = halo_data[field_x][:].in_units(units) / data_scale
+ py = halo_data[field_y][:].in_units(units) / data_scale
+ px, py = self.convert_to_plot(plot,[px,py])
+
+ # Convert halo radii to a radius in pixels
+ radius = halo_data['radius'][:].in_units(units)
+ radius = radius*pixel_scale/data_scale
+
+ if self.width:
+ pz = halo_data[field_z][:].in_units(units)/data_scale
+ pz = data.pf.arr(pz, 'code_length')
+ c = data.center[data.axis]
+
+ # I should catch an error here if width isn't in this form
+        # but I don't really want to reimplement get_sanitized_width...
+ width = data.pf.arr(self.width[0], self.width[1]).in_units('code_length')
+
+ indices = np.where((pz > c-width) & (pz < c+width))
+
+ px = px[indices]
+ py = py[indices]
+ radius = radius[indices]
+
+ plot._axes.scatter(px, py, edgecolors='None', marker='o',
+ s=radius, c=self.color,alpha=self.alpha)
+ plot._axes.set_xlim(xx0,xx1)
+ plot._axes.set_ylim(yy0,yy1)
+ plot._axes.hold(False)
+
+
class ParticleCallback(PlotCallback):
"""
annotate_particles(width, p_size=1.0, col='k', marker='o', stride=1.0,
https://bitbucket.org/yt_analysis/yt/commits/efb77fe1b2fd/
Changeset: efb77fe1b2fd
Branch: yt-3.0
User: hegan
Date: 2014-04-23 22:11:49
Summary: callback docs for annotate_halos
Affected #: 1 file
diff -r 2e1e4b6d764131385c89d4d760915e2f017aeaaa -r efb77fe1b2fd6ab0afd7e3af29240157f681feac doc/source/visualizing/_cb_docstrings.inc
--- a/doc/source/visualizing/_cb_docstrings.inc
+++ b/doc/source/visualizing/_cb_docstrings.inc
@@ -104,6 +104,34 @@
-------------
+.. function:: annotate_halos(self, halo_catalog, col='white', alpha =1, width= None):
+
+ (This is a proxy for :class:`~yt.visualization.plot_modifications.HaloCatalogCallback`.)
+
+ Accepts a :class:`yt.HaloCatalog` *HaloCatalog* and plots
+ a circle at the location of each halo with the radius of
+ the circle corresponding to the virial radius of the halo.
+ If *width* is set to None (default) all halos are plotted.
+	Otherwise, only halos that fall within a slab with width
+	*width* centered on the center of the plot data are plotted. The
+ color and transparency of the circles can be controlled with
+ *col* and *alpha* respectively.
+
+.. python-script::
+
+ from yt.mods import *
+ data_pf = load('Enzo_64/RD0006/RD0006')
+ halos_pf = load('rockstar_halos/halos_0.0.bin')
+
+ hc = HaloCatalog(halos_pf=halos_pf)
+ hc.create()
+
+ prj = ProjectionPlot(data_pf, 'z', 'density')
+ prj.annotate_halos(hc)
+ prj.save()
+
+-------------
+
.. function:: annotate_hop_circles(self, hop_output, max_number=None, annotate=False, min_size=20, max_size=10000000, font_size=8, print_halo_size=False, print_halo_mass=False, width=None):
(This is a proxy for :class:`~yt.visualization.plot_modifications.HopCircleCallback`.)
https://bitbucket.org/yt_analysis/yt/commits/cdc3b525086b/
Changeset: cdc3b525086b
Branch: yt-3.0
User: hegan
Date: 2014-04-28 20:33:25
Summary: Removed outdated annotate hop halos and annotate hop circles
Affected #: 1 file
diff -r efb77fe1b2fd6ab0afd7e3af29240157f681feac -r cdc3b525086b5623c8821b3a548f496ef69348ff yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -813,113 +813,6 @@
plot._axes.text(center_x, center_y, self.text,
**self.text_args)
-class HopCircleCallback(PlotCallback):
- """
- annotate_hop_circles(hop_output, max_number=None,
- annotate=False, min_size=20, max_size=10000000,
- font_size=8, print_halo_size=False,
- print_halo_mass=False, width=None)
-
- Accepts a :class:`yt.HopList` *hop_output* and plots up to
- *max_number* (None for unlimited) halos as circles.
- """
- _type_name = "hop_circles"
- def __init__(self, hop_output, max_number=None,
- annotate=False, min_size=20, max_size=10000000,
- font_size=8, print_halo_size=False,
- print_halo_mass=False, width=None):
- self.hop_output = hop_output
- self.max_number = max_number
- self.annotate = annotate
- self.min_size = min_size
- self.max_size = max_size
- self.font_size = font_size
- self.print_halo_size = print_halo_size
- self.print_halo_mass = print_halo_mass
- self.width = width
-
- def __call__(self, plot):
- from matplotlib.patches import Circle
- num = len(self.hop_output[:self.max_number])
- for halo in self.hop_output[:self.max_number]:
- size = halo.get_size()
- if size < self.min_size or size > self.max_size: continue
- # This could use halo.maximum_radius() instead of width
- if self.width is not None and \
- np.abs(halo.center_of_mass() -
- plot.data.center)[plot.data.axis] > \
- self.width:
- continue
-
- radius = halo.maximum_radius() * self.pixel_scale(plot)[0]
- center = halo.center_of_mass()
-
- (xi, yi) = (x_dict[plot.data.axis], y_dict[plot.data.axis])
-
- (center_x,center_y) = self.convert_to_plot(plot,(center[xi], center[yi]))
- color = np.ones(3) * (0.4 * (num - halo.id)/ num) + 0.6
- cir = Circle((center_x, center_y), radius, fill=False, color=color)
- plot._axes.add_patch(cir)
- if self.annotate:
- if self.print_halo_size:
- plot._axes.text(center_x+radius, center_y+radius, "%s" % size,
- fontsize=self.font_size, color=color)
- elif self.print_halo_mass:
- plot._axes.text(center_x+radius, center_y+radius, "%s" % halo.total_mass(),
- fontsize=self.font_size, color=color)
- else:
- plot._axes.text(center_x+radius, center_y+radius, "%s" % halo.id,
- fontsize=self.font_size, color=color)
-
-class HopParticleCallback(PlotCallback):
- """
- annotate_hop_particles(hop_output, max_number, p_size=1.0,
- min_size=20, alpha=0.2):
-
- Adds particle positions for the members of each halo as identified
- by HOP. Along *axis* up to *max_number* groups in *hop_output* that are
- larger than *min_size* are plotted with *p_size* pixels per particle;
- *alpha* determines the opacity of each particle.
- """
- _type_name = "hop_particles"
- def __init__(self, hop_output, max_number=None, p_size=1.0,
- min_size=20, alpha=0.2):
- self.hop_output = hop_output
- self.p_size = p_size
- self.max_number = max_number
- self.min_size = min_size
- self.alpha = alpha
-
- def __call__(self,plot):
- (dx,dy) = self.pixel_scale(plot)
-
- (xi, yi) = (x_names[plot.data.axis], y_names[plot.data.axis])
-
- # now we loop over the haloes
- for halo in self.hop_output[:self.max_number]:
- size = halo.get_size()
-
- if size < self.min_size: continue
-
- (px,py) = self.convert_to_plot(plot,(halo["particle_position_%s" % xi],
- halo["particle_position_%s" % yi]))
-
- # Need to get the plot limits and set the hold state before scatter
- # and then restore the limits and turn off the hold state afterwards
- # because scatter will automatically adjust the plot window which we
- # do not want
-
- xlim = plot._axes.get_xlim()
- ylim = plot._axes.get_ylim()
- plot._axes.hold(True)
-
- plot._axes.scatter(px, py, edgecolors="None",
- s=self.p_size, c='black', alpha=self.alpha)
-
- plot._axes.set_xlim(xlim)
- plot._axes.set_ylim(ylim)
- plot._axes.hold(False)
-
class TextLabelCallback(PlotCallback):
"""
https://bitbucket.org/yt_analysis/yt/commits/08c3dec25599/
Changeset: 08c3dec25599
Branch: yt-3.0
User: hegan
Date: 2014-04-28 22:05:32
Summary: merged with main yt-3.0
Affected #: 5 files
diff -r cdc3b525086b5623c8821b3a548f496ef69348ff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 yt/config.py
--- a/yt/config.py
+++ b/yt/config.py
@@ -53,6 +53,7 @@
answer_testing_bitwise = 'False',
gold_standard_filename = 'gold311',
local_standard_filename = 'local001',
+ answer_tests_url = 'http://answers.yt-project.org/%s_%s',
sketchfab_api_key = 'None',
thread_field_detection = 'False',
ignore_invalid_unit_operation_errors = 'False'
diff -r cdc3b525086b5623c8821b3a548f496ef69348ff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -592,3 +592,38 @@
yield assert_array_equal, arr.value, np.array(arr)
yield assert_array_equal, arr.v, np.array(arr)
+
+
+def test_registry_association():
+ ds = fake_random_pf(64, nprocs=1, length_unit=10)
+ a = ds.quan(3, 'cm')
+ b = YTQuantity(4, 'm')
+ c = ds.quan(6, '')
+ d = 5
+
+ yield assert_equal, id(a.units.registry), id(ds.unit_registry)
+
+ def binary_op_registry_comparison(op):
+ e = op(a, b)
+ f = op(b, a)
+ g = op(c, d)
+ h = op(d, c)
+
+ assert_equal(id(e.units.registry), id(ds.unit_registry))
+ assert_equal(id(f.units.registry), id(b.units.registry))
+ assert_equal(id(g.units.registry), id(h.units.registry))
+ assert_equal(id(g.units.registry), id(ds.unit_registry))
+
+ def unary_op_registry_comparison(op):
+ c = op(a)
+ d = op(b)
+
+ assert_equal(id(c.units.registry), id(ds.unit_registry))
+ assert_equal(id(d.units.registry), id(b.units.registry))
+
+ for op in [operator.add, operator.sub, operator.mul, operator.div,
+ operator.truediv]:
+ yield binary_op_registry_comparison, op
+
+ for op in [operator.abs, operator.neg, operator.pos]:
+ yield unary_op_registry_comparison, op
diff -r cdc3b525086b5623c8821b3a548f496ef69348ff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -497,12 +497,16 @@
def __isub__(self, other):
""" See __sub__. """
oth = sanitize_units_add(self, other, "subtraction")
- return np.subtract(self, other, out=self)
+ return np.subtract(self, oth, out=self)
def __neg__(self):
""" Negate the data. """
return YTArray(super(YTArray, self).__neg__())
+ def __pos__(self):
+ """ Posify the data. """
+ return YTArray(super(YTArray, self).__pos__(), self.units)
+
def __mul__(self, right_object):
"""
Multiply this YTArray by the object on the right of the `*` operator.
@@ -665,7 +669,7 @@
def __eq__(self, other):
""" Test if this is equal to the object on the right. """
# Check that other is a YTArray.
- if other == None:
+ if other is None:
# self is a YTArray, so it can't be None.
return False
if isinstance(other, YTArray):
@@ -679,7 +683,7 @@
def __ne__(self, other):
""" Test if this is not equal to the object on the right. """
# Check that the other is a YTArray.
- if other == None:
+ if other is None:
return True
if isinstance(other, YTArray):
if not self.units.same_dimensions_as(other.units):
@@ -763,7 +767,7 @@
return ret
elif context[0] in unary_operators:
u = getattr(context[1][0], 'units', None)
- if u == None:
+ if u is None:
u = Unit()
try:
unit = self._ufunc_registry[context[0]](u)
@@ -774,10 +778,10 @@
elif context[0] in binary_operators:
unit1 = getattr(context[1][0], 'units', None)
unit2 = getattr(context[1][1], 'units', None)
- if unit1 == None:
- unit1 = Unit()
- if unit2 == None and context[0] is not power:
- unit2 = Unit()
+ if unit1 is None:
+ unit1 = Unit(registry=getattr(unit2, 'registry', None))
+ if unit2 is None and context[0] is not power:
+ unit2 = Unit(registry=getattr(unit1, 'registry', None))
elif context[0] is power:
unit2 = context[1][1]
if isinstance(unit2, np.ndarray):
@@ -817,8 +821,8 @@
See the documentation for the standard library pickle module:
http://docs.python.org/2/library/pickle.html
- Unit metadata is encoded in the zeroth element of third element of the
- returned tuple, itself a tuple used to restore the state of the ndarray.
+ Unit metadata is encoded in the zeroth element of third element of the
+ returned tuple, itself a tuple used to restore the state of the ndarray.
This is always defined for numpy arrays.
"""
np_ret = super(YTArray, self).__reduce__()
diff -r cdc3b525086b5623c8821b3a548f496ef69348ff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 yt/utilities/answer_testing/framework.py
--- a/yt/utilities/answer_testing/framework.py
+++ b/yt/utilities/answer_testing/framework.py
@@ -45,7 +45,7 @@
# Set the latest gold and local standard filenames
_latest = ytcfg.get("yt", "gold_standard_filename")
_latest_local = ytcfg.get("yt", "local_standard_filename")
-_url_path = "http://yt-answer-tests.s3-website-us-east-1.amazonaws.com/%s_%s"
+_url_path = ytcfg.get("yt", "answer_tests_url")
class AnswerTesting(Plugin):
name = "answer-testing"
@@ -197,30 +197,20 @@
if self.answer_name is None: return
# This is where we dump our result storage up to Amazon, if we are able
# to.
- import boto
- from boto.s3.key import Key
- c = boto.connect_s3()
- bucket = c.get_bucket("yt-answer-tests")
- for pf_name in result_storage:
+ import pyrax
+ pyrax.set_credential_file(os.path.expanduser("~/.yt/rackspace"))
+ cf = pyrax.cloudfiles
+ c = cf.get_container("yt-answer-tests")
+ pb = get_pbar("Storing results ", len(result_storage))
+ for i, pf_name in enumerate(result_storage):
+ pb.update(i)
rs = cPickle.dumps(result_storage[pf_name])
- tk = bucket.get_key("%s_%s" % (self.answer_name, pf_name))
- if tk is not None: tk.delete()
- k = Key(bucket)
- k.key = "%s_%s" % (self.answer_name, pf_name)
-
- pb_widgets = [
- unicode(k.key, errors='ignore').encode('utf-8'), ' ',
- progressbar.FileTransferSpeed(),' <<<', progressbar.Bar(),
- '>>> ', progressbar.Percentage(), ' ', progressbar.ETA()
- ]
- self.pbar = progressbar.ProgressBar(widgets=pb_widgets,
- maxval=sys.getsizeof(rs))
-
- self.pbar.start()
- k.set_contents_from_string(rs, cb=self.progress_callback,
- num_cb=100000)
- k.set_acl("public-read")
- self.pbar.finish()
+ object_name = "%s_%s" % (self.answer_name, pf_name)
+ if object_name in c.get_object_names():
+ obj = c.get_object(object_name)
+ c.delete_object(obj)
+ c.store_object(object_name, rs)
+ pb.finish()
class AnswerTestLocalStorage(AnswerTestStorage):
def dump(self, result_storage):
diff -r cdc3b525086b5623c8821b3a548f496ef69348ff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 yt/utilities/parallel_tools/parallel_analysis_interface.py
--- a/yt/utilities/parallel_tools/parallel_analysis_interface.py
+++ b/yt/utilities/parallel_tools/parallel_analysis_interface.py
@@ -636,6 +636,9 @@
def __init__(self, comm=None):
self.comm = comm
self._distributed = comm is not None and self.comm.size > 1
+
+ def __del__(self):
+ self.comm.Free()
"""
This is an interface specification providing several useful utility
functions for analyzing something in parallel.
https://bitbucket.org/yt_analysis/yt/commits/0691ab75861a/
Changeset: 0691ab75861a
Branch: yt-3.0
User: hegan
Date: 2014-04-28 22:11:31
Summary: actually merged in yt-3.0
Affected #: 50 files
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 doc/source/reference/api/api.rst
--- a/doc/source/reference/api/api.rst
+++ b/doc/source/reference/api/api.rst
@@ -660,7 +660,6 @@
~yt.utilities.parallel_tools.parallel_analysis_interface.parallel_passthrough
~yt.utilities.parallel_tools.parallel_analysis_interface.parallel_root_only
~yt.utilities.parallel_tools.parallel_analysis_interface.parallel_simple_proxy
- ~yt.utilities.parallel_tools.parallel_analysis_interface.parallel_splitter
Math Utilities
--------------
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/analysis_modules/halo_analysis/halo_catalog.py
--- a/yt/analysis_modules/halo_analysis/halo_catalog.py
+++ b/yt/analysis_modules/halo_analysis/halo_catalog.py
@@ -351,6 +351,14 @@
if self.halos_pf is None:
# Find the halos and make a dataset of them
self.halos_pf = self.finder_method(self.data_pf)
+ if self.halos_pf is None:
+ mylog.warning('No halos were found for {0}'.format(\
+ self.data_pf.basename))
+ if save_catalog:
+ self.halos_pf = self.data_pf
+ self.save_catalog()
+ self.halos_pf = None
+ return
# Assign pf and data sources appropriately
self.data_source = self.halos_pf.all_data()
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/analysis_modules/halo_analysis/halo_finding_methods.py
--- a/yt/analysis_modules/halo_analysis/halo_finding_methods.py
+++ b/yt/analysis_modules/halo_analysis/halo_finding_methods.py
@@ -25,6 +25,7 @@
from .operator_registry import \
finding_method_registry
+
def add_finding_method(name, function):
finding_method_registry[name] = HaloFindingMethod(function)
@@ -75,8 +76,14 @@
rh = RockstarHaloFinder(pf)
rh.run()
+
+
halos_pf = RockstarDataset("rockstar_halos/halos_0.0.bin")
- halos_pf.create_field_info()
+ try:
+ halos_pf.create_field_info()
+ except ValueError:
+ return None
+
return halos_pf
add_finding_method("rockstar", _rockstar_method)
@@ -87,6 +94,8 @@
num_halos = len(halo_list)
+ if num_halos == 0: return None
+
# Set up fields that we want to pull from identified halos and their units
new_fields = ['particle_identifier', 'particle_mass', 'particle_position_x',
'particle_position_y','particle_position_z',
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/analysis_modules/halo_finding/rockstar/rockstar_groupies.pyx
--- a/yt/analysis_modules/halo_finding/rockstar/rockstar_groupies.pyx
+++ b/yt/analysis_modules/halo_finding/rockstar/rockstar_groupies.pyx
@@ -220,12 +220,11 @@
cdef np.int64_t last_fof_tag = 1
cdef np.int64_t k = 0
for i in range(num_particles):
- if fof_tags[i] == 0:
+ if fof_tags[i] < 0:
continue
if fof_tags[i] != last_fof_tag:
last_fof_tag = fof_tags[i]
if k > 16:
- print "Finding subs", k, i
fof_obj.num_p = k
find_subs(&fof_obj)
k = 0
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/analysis_modules/sunyaev_zeldovich/projection.py
--- a/yt/analysis_modules/sunyaev_zeldovich/projection.py
+++ b/yt/analysis_modules/sunyaev_zeldovich/projection.py
@@ -22,7 +22,6 @@
from yt.fields.local_fields import add_field, derived_field
from yt.data_objects.image_array import ImageArray
from yt.funcs import fix_axis, mylog, iterable, get_pbar
-from yt.utilities.definitions import inv_axis_names
from yt.visualization.volume_rendering.camera import off_axis_projection
from yt.utilities.parallel_tools.parallel_analysis_interface import \
communication_system, parallel_root_only
@@ -134,7 +133,7 @@
--------
>>> szprj.on_axis("y", center="max", width=(1.0, "Mpc"), source=my_sphere)
"""
- axis = fix_axis(axis)
+ axis = fix_axis(axis, self.pf)
if center == "c":
ctr = self.pf.domain_center
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/construction_data_containers.py
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -41,7 +41,6 @@
march_cubes_grid, march_cubes_grid_flux
from yt.utilities.data_point_utilities import CombineGrids,\
DataCubeRefine, DataCubeReplace, FillRegion, FillBuffer
-from yt.utilities.definitions import axis_names, x_dict, y_dict
from yt.utilities.minimal_representation import \
MinimalProjectionData
from yt.utilities.parallel_tools.parallel_analysis_interface import \
@@ -252,8 +251,8 @@
self._mrep.upload()
def _get_tree(self, nvals):
- xax = x_dict[self.axis]
- yax = y_dict[self.axis]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
xd = self.pf.domain_dimensions[xax]
yd = self.pf.domain_dimensions[yax]
bounds = (self.pf.domain_left_edge[xax],
@@ -292,18 +291,20 @@
else:
raise NotImplementedError
# TODO: Add the combine operation
- ox = self.pf.domain_left_edge[x_dict[self.axis]]
- oy = self.pf.domain_left_edge[y_dict[self.axis]]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
+ ox = self.pf.domain_left_edge[xax]
+ oy = self.pf.domain_left_edge[yax]
px, py, pdx, pdy, nvals, nwvals = tree.get_all(False, merge_style)
nvals = self.comm.mpi_allreduce(nvals, op=op)
nwvals = self.comm.mpi_allreduce(nwvals, op=op)
- np.multiply(px, self.pf.domain_width[x_dict[self.axis]], px)
+ np.multiply(px, self.pf.domain_width[xax], px)
np.add(px, ox, px)
- np.multiply(pdx, self.pf.domain_width[x_dict[self.axis]], pdx)
+ np.multiply(pdx, self.pf.domain_width[xax], pdx)
- np.multiply(py, self.pf.domain_width[y_dict[self.axis]], py)
+ np.multiply(py, self.pf.domain_width[yax], py)
np.add(py, oy, py)
- np.multiply(pdy, self.pf.domain_width[y_dict[self.axis]], pdy)
+ np.multiply(pdy, self.pf.domain_width[yax], pdy)
if self.weight_field is not None:
np.divide(nvals, nwvals[:,None], nvals)
# We now convert to half-widths and center-points
@@ -348,8 +349,10 @@
def _initialize_chunk(self, chunk, tree):
icoords = chunk.icoords
- i1 = icoords[:,x_dict[self.axis]]
- i2 = icoords[:,y_dict[self.axis]]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
+ i1 = icoords[:,xax]
+ i2 = icoords[:,yax]
ilevel = chunk.ires * self.pf.ires_factor
tree.initialize_chunk(i1, i2, ilevel)
@@ -370,8 +373,10 @@
else:
w = np.ones(chunk.ires.size, dtype="float64")
icoords = chunk.icoords
- i1 = icoords[:,x_dict[self.axis]]
- i2 = icoords[:,y_dict[self.axis]]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
+ i1 = icoords[:,xax]
+ i2 = icoords[:,yax]
ilevel = chunk.ires * self.pf.ires_factor
tree.add_chunk_to_tree(i1, i2, ilevel, v, w)
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/data_containers.py
--- a/yt/data_objects/data_containers.py
+++ b/yt/data_objects/data_containers.py
@@ -28,7 +28,6 @@
from yt.data_objects.particle_io import particle_handler_registry
from yt.utilities.lib.marching_cubes import \
march_cubes_grid, march_cubes_grid_flux
-from yt.utilities.definitions import x_dict, y_dict
from yt.utilities.parallel_tools.parallel_analysis_interface import \
ParallelAnalysisInterface
from yt.utilities.parameter_file_storage import \
@@ -726,9 +725,10 @@
_spatial = False
def __init__(self, axis, pf, field_parameters):
ParallelAnalysisInterface.__init__(self)
- self.axis = fix_axis(axis)
super(YTSelectionContainer2D, self).__init__(
pf, field_parameters)
+ # We need the pf, which will exist by now, for fix_axis.
+ self.axis = fix_axis(axis, self.pf)
self.set_field_parameter("axis", axis)
def _convert_field_name(self, field):
@@ -821,8 +821,8 @@
if not iterable(resolution):
resolution = (resolution, resolution)
from yt.visualization.fixed_resolution import FixedResolutionBuffer
- xax = x_dict[self.axis]
- yax = y_dict[self.axis]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
bounds = (center[xax] - width*0.5, center[xax] + width*0.5,
center[yax] - height*0.5, center[yax] + height*0.5)
frb = FixedResolutionBuffer(self, bounds, resolution,
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/grid_patch.py
--- a/yt/data_objects/grid_patch.py
+++ b/yt/data_objects/grid_patch.py
@@ -20,13 +20,11 @@
import numpy as np
from yt.funcs import *
-from yt.utilities.definitions import x_dict, y_dict
from yt.data_objects.data_containers import \
YTFieldData, \
YTDataContainer, \
YTSelectionContainer
-from yt.utilities.definitions import x_dict, y_dict
from yt.fields.field_exceptions import \
NeedsGridType, \
NeedsOriginalGrid, \
@@ -379,9 +377,9 @@
def count_particles(self, selector, x, y, z):
# We don't cache the selector results
- count = selector.count_points(x,y,z)
+ count = selector.count_points(x,y,z, 0.0)
return count
def select_particles(self, selector, x, y, z):
- mask = selector.select_points(x,y,z)
+ mask = selector.select_points(x,y,z, 0.0)
return mask
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/octree_subset.py
--- a/yt/data_objects/octree_subset.py
+++ b/yt/data_objects/octree_subset.py
@@ -249,11 +249,11 @@
def count_particles(self, selector, x, y, z):
# We don't cache the selector results
- count = selector.count_points(x,y,z)
+ count = selector.count_points(x,y,z, 0.0)
return count
def select_particles(self, selector, x, y, z):
- mask = selector.select_points(x,y,z)
+ mask = selector.select_points(x,y,z, 0.0)
return mask
class ParticleOctreeSubset(OctreeSubset):
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/selection_data_containers.py
--- a/yt/data_objects/selection_data_containers.py
+++ b/yt/data_objects/selection_data_containers.py
@@ -25,8 +25,6 @@
YTSelectionContainer1D, YTSelectionContainer2D, YTSelectionContainer3D
from yt.data_objects.derived_quantities import \
DerivedQuantityCollection
-from yt.utilities.definitions import \
- x_dict, y_dict, axis_names
from yt.utilities.exceptions import YTSphereTooSmall
from yt.utilities.linear_interpolators import TrilinearFieldInterpolator
from yt.utilities.minimal_representation import \
@@ -73,12 +71,15 @@
def __init__(self, axis, coords, pf=None, field_parameters=None):
super(YTOrthoRayBase, self).__init__(pf, field_parameters)
self.axis = axis
- self.px_ax = x_dict[self.axis]
- self.py_ax = y_dict[self.axis]
- self.px_dx = 'd%s'%(axis_names[self.px_ax])
- self.py_dx = 'd%s'%(axis_names[self.py_ax])
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
+ self.px_ax = xax
+ self.py_ax = yax
+ # Even though we may not be using x,y,z we use them here.
+ self.px_dx = 'd%s'%('xyz'[self.px_ax])
+ self.py_dx = 'd%s'%('xyz'[self.py_ax])
self.px, self.py = coords
- self.sort_by = axis_names[self.axis]
+ self.sort_by = 'xyz'[self.axis]
@property
def coords(self):
@@ -191,16 +192,18 @@
self.coord = coord
def _generate_container_field(self, field):
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
if self._current_chunk is None:
self.index._identify_base_chunk(self)
if field == "px":
- return self._current_chunk.fcoords[:,x_dict[self.axis]]
+ return self._current_chunk.fcoords[:,xax]
elif field == "py":
- return self._current_chunk.fcoords[:,y_dict[self.axis]]
+ return self._current_chunk.fcoords[:,yax]
elif field == "pdx":
- return self._current_chunk.fwidth[:,x_dict[self.axis]] * 0.5
+ return self._current_chunk.fwidth[:,xax] * 0.5
elif field == "pdy":
- return self._current_chunk.fwidth[:,y_dict[self.axis]] * 0.5
+ return self._current_chunk.fwidth[:,yax] * 0.5
else:
raise KeyError(field)
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/static_output.py
--- a/yt/data_objects/static_output.py
+++ b/yt/data_objects/static_output.py
@@ -574,6 +574,8 @@
self.unit_registry.add("code_magnetic", 1.0, dimensions.magnetic_field)
self.unit_registry.add("code_temperature", 1.0, dimensions.temperature)
self.unit_registry.add("code_velocity", 1.0, dimensions.velocity)
+ self.unit_registry.add("code_metallicity", 1.0,
+ dimensions.dimensionless)
def set_units(self):
"""
@@ -628,12 +630,10 @@
self.length_unit / self.time_unit)
self.unit_registry.modify("code_velocity", vel_unit)
# domain_width does not yet exist
- if self.domain_left_edge is None or self.domain_right_edge is None:
- DW = np.zeros(3)
- else:
+ if None not in (self.domain_left_edge, self.domain_right_edge):
DW = self.arr(self.domain_right_edge - self.domain_left_edge, "code_length")
- self.unit_registry.add("unitary", float(DW.max() * DW.units.cgs_value),
- DW.units.dimensions)
+ self.unit_registry.add("unitary", float(DW.max() * DW.units.cgs_value),
+ DW.units.dimensions)
_arr = None
@property
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/tests/test_chunking.py
--- a/yt/data_objects/tests/test_chunking.py
+++ b/yt/data_objects/tests/test_chunking.py
@@ -3,7 +3,7 @@
def _get_dobjs(c):
dobjs = [("sphere", ("center", (1.0, "unitary"))),
("sphere", ("center", (0.1, "unitary"))),
- ("ortho_ray", (0, (c[x_dict[0]], c[y_dict[0]]))),
+ ("ortho_ray", (0, (c[1], c[2]))),
("slice", (0, c[0])),
#("disk", ("center", [0.1, 0.3, 0.6],
# (0.2, 'unitary'), (0.1, 'unitary'))),
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/tests/test_projection.py
--- a/yt/data_objects/tests/test_projection.py
+++ b/yt/data_objects/tests/test_projection.py
@@ -29,8 +29,8 @@
uc = [np.unique(c) for c in coords]
# Some simple projection tests with single grids
for ax, an in enumerate("xyz"):
- xax = x_dict[ax]
- yax = y_dict[ax]
+ xax = pf.coordinates.x_axis[ax]
+ yax = pf.coordinates.y_axis[ax]
for wf in ["density", None]:
fns = []
proj = pf.proj(["ones", "density"], ax, weight_field = wf)
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/tests/test_slice.py
--- a/yt/data_objects/tests/test_slice.py
+++ b/yt/data_objects/tests/test_slice.py
@@ -17,8 +17,6 @@
from nose.tools import raises
from yt.testing import \
fake_random_pf, assert_equal, assert_array_equal, YTArray
-from yt.utilities.definitions import \
- x_dict, y_dict
from yt.utilities.exceptions import \
YTNoDataInObjectError
from yt.units.unit_object import Unit
@@ -50,8 +48,8 @@
slc_pos = 0.5
# Some simple slice tests with single grids
for ax, an in enumerate("xyz"):
- xax = x_dict[ax]
- yax = y_dict[ax]
+ xax = pf.coordinates.x_axis[ax]
+ yax = pf.coordinates.y_axis[ax]
for wf in ["density", None]:
fns = []
slc = pf.slice(ax, slc_pos)
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/data_objects/unstructured_mesh.py
--- a/yt/data_objects/unstructured_mesh.py
+++ b/yt/data_objects/unstructured_mesh.py
@@ -171,9 +171,9 @@
def count_particles(self, selector, x, y, z):
# We don't cache the selector results
- count = selector.count_points(x,y,z)
+ count = selector.count_points(x,y,z, 0.0)
return count
def select_particles(self, selector, x, y, z):
- mask = selector.select_points(x,y,z)
+ mask = selector.select_points(x,y,z, 0.0)
return mask
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/art/io.py
--- a/yt/frontends/art/io.py
+++ b/yt/frontends/art/io.py
@@ -74,7 +74,7 @@
pbool, idxa, idxb = _determine_field_size(pf, ftype, self.ls, ptmax)
pstr = 'particle_position_%s'
x,y,z = [self._get_field((ftype, pstr % ax)) for ax in 'xyz']
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
if self.caching:
self.masks[key] = mask
return self.masks[key]
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/artio/io.py
--- a/yt/frontends/artio/io.py
+++ b/yt/frontends/artio/io.py
@@ -64,7 +64,7 @@
for ptype, field_list in sorted(ptf.items()):
x, y, z = (np.asarray(rv[ptype][pn % ax], dtype="=f8")
for ax in 'XYZ')
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
if mask is None: continue
for field in field_list:
data = np.asarray(rv[ptype][field], "=f8")
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/enzo/io.py
--- a/yt/frontends/enzo/io.py
+++ b/yt/frontends/enzo/io.py
@@ -104,7 +104,7 @@
r"particle_position_%s")
x, y, z = (np.asarray(pds.get(pn % ax).value, dtype="=f8")
for ax in 'xyz')
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
if mask is None: continue
for field in field_list:
data = np.asarray(pds.get(field).value, "=f8")
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/flash/io.py
--- a/yt/frontends/flash/io.py
+++ b/yt/frontends/flash/io.py
@@ -93,7 +93,7 @@
x = p_fields[start:end, px]
y = p_fields[start:end, py]
z = p_fields[start:end, pz]
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
if mask is None: continue
for field in field_list:
fi = self._particle_fields[field]
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/halo_catalogs/halo_catalog/io.py
--- a/yt/frontends/halo_catalogs/halo_catalog/io.py
+++ b/yt/frontends/halo_catalogs/halo_catalog/io.py
@@ -68,7 +68,7 @@
x = f['particle_position_x'].value.astype("float64")
y = f['particle_position_y'].value.astype("float64")
z = f['particle_position_z'].value.astype("float64")
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
del x, y, z
if mask is None: continue
for field in field_list:
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/halo_catalogs/rockstar/io.py
--- a/yt/frontends/halo_catalogs/rockstar/io.py
+++ b/yt/frontends/halo_catalogs/rockstar/io.py
@@ -74,7 +74,7 @@
x = halos['particle_position_x'].astype("float64")
y = halos['particle_position_y'].astype("float64")
z = halos['particle_position_z'].astype("float64")
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
del x, y, z
if mask is None: continue
for field in field_list:
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/ramses/io.py
--- a/yt/frontends/ramses/io.py
+++ b/yt/frontends/ramses/io.py
@@ -77,7 +77,7 @@
for ptype, field_list in sorted(ptf.items()):
x, y, z = (np.asarray(rv[ptype, pn % ax], "=f8")
for ax in 'xyz')
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
for field in field_list:
data = np.asarray(rv.pop((ptype, field))[mask], "=f8")
yield (ptype, field), data
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/sph/data_structures.py
--- a/yt/frontends/sph/data_structures.py
+++ b/yt/frontends/sph/data_structures.py
@@ -471,16 +471,20 @@
self.current_time = hvals["time"]
nz = 1 << self.over_refine_factor
self.domain_dimensions = np.ones(3, "int32") * nz
- if self.parameters.get('bPeriodic', True):
- self.periodicity = (True, True, True)
+ periodic = self.parameters.get('bPeriodic', True)
+ period = self.parameters.get('dPeriod', None)
+ comoving = self.parameters.get('bComove', False)
+ self.periodicity = (periodic, periodic, periodic)
+ if comoving and period is None:
+ period = 1.0
+ if periodic and period is not None:
# If we are periodic, that sets our domain width to either 1 or dPeriod.
- self.domain_left_edge = np.zeros(3, "float64") - 0.5*self.parameters.get('dPeriod', 1)
- self.domain_right_edge = np.zeros(3, "float64") + 0.5*self.parameters.get('dPeriod', 1)
+ self.domain_left_edge = np.zeros(3, "float64") - 0.5*period
+ self.domain_right_edge = np.zeros(3, "float64") + 0.5*period
else:
- self.periodicity = (False, False, False)
self.domain_left_edge = None
self.domain_right_edge = None
- if self.parameters.get('bComove', False):
+ if comoving:
cosm = self._cosmology_parameters or {}
self.scale_factor = hvals["time"]#In comoving simulations, time stores the scale factor a
self.cosmological_simulation = 1
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/sph/io.py
--- a/yt/frontends/sph/io.py
+++ b/yt/frontends/sph/io.py
@@ -99,7 +99,7 @@
g = f["/%s" % ptype]
coords = g["Coordinates"][:].astype("float64")
mask = selector.select_points(
- coords[:,0], coords[:,1], coords[:,2])
+ coords[:,0], coords[:,1], coords[:,2], 0.0)
del coords
if mask is None: continue
for field in field_list:
@@ -281,7 +281,7 @@
pos = self._read_field_from_file(f,
tp[ptype], "Coordinates")
mask = selector.select_points(
- pos[:,0], pos[:,1], pos[:,2])
+ pos[:,0], pos[:,1], pos[:,2], 0.0)
del pos
if mask is None: continue
for field in field_list:
@@ -534,7 +534,7 @@
mask = selector.select_points(
p["Coordinates"]['x'].astype("float64"),
p["Coordinates"]['y'].astype("float64"),
- p["Coordinates"]['z'].astype("float64"))
+ p["Coordinates"]['z'].astype("float64"), 0.0)
if mask is None: continue
tf = self._fill_fields(field_list, p, mask, data_file)
for field in field_list:
@@ -557,6 +557,8 @@
pf.domain_left_edge = 0
pf.domain_right_edge = 0
f.seek(pf._header_offset)
+ mi = np.array([1e30, 1e30, 1e30], dtype="float64")
+ ma = -np.array([1e30, 1e30, 1e30], dtype="float64")
for iptype, ptype in enumerate(self._ptypes):
# We'll just add the individual types separately
count = data_file.total_particles[ptype]
@@ -566,19 +568,23 @@
c = min(CHUNKSIZE, stop - ind)
pp = np.fromfile(f, dtype = self._pdtypes[ptype],
count = c)
- for ax in 'xyz':
- mi = pp["Coordinates"][ax].min()
- ma = pp["Coordinates"][ax].max()
- outlier = self.arr(np.max(np.abs((mi,ma))), 'code_length')
- if outlier > pf.domain_right_edge or -outlier < pf.domain_left_edge:
- # scale these up so the domain is slightly
- # larger than the most distant particle position
- pf.domain_left_edge = -1.01*outlier
- pf.domain_right_edge = 1.01*outlier
+ eps = np.finfo(pp["Coordinates"]["x"].dtype).eps
+ np.minimum(mi, [pp["Coordinates"]["x"].min(),
+ pp["Coordinates"]["y"].min(),
+ pp["Coordinates"]["z"].min()], mi)
+ np.maximum(ma, [pp["Coordinates"]["x"].max(),
+ pp["Coordinates"]["y"].max(),
+ pp["Coordinates"]["z"].max()], ma)
ind += c
- pf.domain_left_edge = np.ones(3)*pf.domain_left_edge
- pf.domain_right_edge = np.ones(3)*pf.domain_right_edge
- pf.domain_width = np.ones(3)*2*pf.domain_right_edge
+ # We extend by 1%.
+ DW = ma - mi
+ mi -= 0.01 * DW
+ ma += 0.01 * DW
+ pf.domain_left_edge = pf.arr(mi, 'code_length')
+ pf.domain_right_edge = pf.arr(ma, 'code_length')
+ pf.domain_width = DW = pf.domain_right_edge - pf.domain_left_edge
+ pf.unit_registry.add("unitary", float(DW.max() * DW.units.cgs_value),
+ DW.units.dimensions)
def _initialize_index(self, data_file, regions):
pf = data_file.pf
@@ -745,7 +751,7 @@
c = np.frombuffer(s, dtype="float64")
c.shape = (c.shape[0]/3.0, 3)
mask = selector.select_points(
- c[:,0], c[:,1], c[:,2])
+ c[:,0], c[:,1], c[:,2], 0.0)
del c
if mask is None: continue
for field in field_list:
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/frontends/stream/io.py
--- a/yt/frontends/stream/io.py
+++ b/yt/frontends/stream/io.py
@@ -85,7 +85,7 @@
for ptype, field_list in sorted(ptf.items()):
x, y, z = (gf[ptype, "particle_position_%s" % ax]
for ax in 'xyz')
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
if mask is None: continue
for field in field_list:
data = np.asarray(gf[ptype, field])
@@ -127,7 +127,7 @@
for ptype, field_list in sorted(ptf.items()):
x, y, z = (f[ptype, "particle_position_%s" % ax]
for ax in 'xyz')
- mask = selector.select_points(x, y, z)
+ mask = selector.select_points(x, y, z, 0.0)
if mask is None: continue
for field in field_list:
data = f[ptype, field][mask]
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/funcs.py
--- a/yt/funcs.py
+++ b/yt/funcs.py
@@ -24,7 +24,6 @@
from yt.utilities.exceptions import *
from yt.utilities.logger import ytLogger as mylog
-from yt.utilities.definitions import inv_axis_names, axis_names, x_dict, y_dict
import yt.extern.progressbar as pb
import yt.utilities.rpdb as rpdb
from yt.units.yt_array import YTArray, YTQuantity
@@ -637,8 +636,8 @@
return os.environ.get("OMP_NUM_THREADS", 0)
return nt
-def fix_axis(axis):
- return inv_axis_names.get(axis, axis)
+def fix_axis(axis, pf):
+ return pf.coordinates.axis_id.get(axis, axis)
def get_image_suffix(name):
suffix = os.path.splitext(name)[1]
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/coordinate_handler.py
--- a/yt/geometry/coordinate_handler.py
+++ b/yt/geometry/coordinate_handler.py
@@ -24,7 +24,7 @@
from yt.utilities.io_handler import io_registry
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.parallel_tools.parallel_analysis_interface import \
- ParallelAnalysisInterface, parallel_splitter
+ ParallelAnalysisInterface
from yt.utilities.lib.misc_utilities import \
pixelize_cylinder
import yt.visualization._MPL as _MPL
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/geometry_handler.py
--- a/yt/geometry/geometry_handler.py
+++ b/yt/geometry/geometry_handler.py
@@ -36,7 +36,7 @@
from yt.utilities.io_handler import io_registry
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.parallel_tools.parallel_analysis_interface import \
- ParallelAnalysisInterface, parallel_splitter
+ ParallelAnalysisInterface, parallel_root_only
from yt.utilities.exceptions import YTFieldNotFound
class Index(ParallelAnalysisInterface):
@@ -126,7 +126,8 @@
if getattr(self, "io", None) is not None: return
self.io = io_registry[self.dataset_type](self.parameter_file)
- def _save_data(self, array, node, name, set_attr=None, force=False, passthrough = False):
+ @parallel_root_only
+ def save_data(self, array, node, name, set_attr=None, force=False, passthrough = False):
"""
Arbitrary numpy data will be saved to the region in the datafile
described by *node* and *name*. If data file does not exist, it throws
@@ -157,14 +158,6 @@
del self._data_file
self._data_file = h5py.File(self.__data_filename, self._data_mode)
- save_data = parallel_splitter(_save_data, _reload_data_file)
-
- def _reset_save_data(self,round_robin=False):
- if round_robin:
- self.save_data = self._save_data
- else:
- self.save_data = parallel_splitter(self._save_data, self._reload_data_file)
-
def save_object(self, obj, name):
"""
Save an object (*obj*) to the data_file using the Pickle protocol,
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/grid_geometry_handler.py
--- a/yt/geometry/grid_geometry_handler.py
+++ b/yt/geometry/grid_geometry_handler.py
@@ -31,7 +31,7 @@
from yt.utilities.physical_constants import sec_per_year
from yt.utilities.io_handler import io_registry
from yt.utilities.parallel_tools.parallel_analysis_interface import \
- ParallelAnalysisInterface, parallel_splitter
+ ParallelAnalysisInterface
from yt.utilities.lib.GridTree import GridTree, MatchPointsToGrids
from yt.data_objects.data_containers import data_object_registry
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/object_finding_mixin.py
--- a/yt/geometry/object_finding_mixin.py
+++ b/yt/geometry/object_finding_mixin.py
@@ -38,10 +38,12 @@
# So if gRE > coord, we get a mask, if not, we get a zero
# if gLE > coord, we get a zero, if not, mask
# Thus, if the coordinate is between the two edges, we win!
- np.choose(np.greater(self.grid_right_edge[:,x_dict[axis]],coord[0]),(0,mask),mask)
- np.choose(np.greater(self.grid_left_edge[:,x_dict[axis]],coord[0]),(mask,0),mask)
- np.choose(np.greater(self.grid_right_edge[:,y_dict[axis]],coord[1]),(0,mask),mask)
- np.choose(np.greater(self.grid_left_edge[:,y_dict[axis]],coord[1]),(mask,0),mask)
+ xax = self.pf.coordinates.x_axis[axis]
+ yax = self.pf.coordinates.y_axis[axis]
+ np.choose(np.greater(self.grid_right_edge[:,xax],coord[0]),(0,mask),mask)
+ np.choose(np.greater(self.grid_left_edge[:,xax],coord[0]),(mask,0),mask)
+ np.choose(np.greater(self.grid_right_edge[:,yax],coord[1]),(0,mask),mask)
+ np.choose(np.greater(self.grid_left_edge[:,yax],coord[1]),(mask,0),mask)
ind = np.where(mask == 1)
return self.grids[ind], ind
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/oct_geometry_handler.py
--- a/yt/geometry/oct_geometry_handler.py
+++ b/yt/geometry/oct_geometry_handler.py
@@ -30,7 +30,7 @@
from yt.utilities.definitions import MAXLEVEL
from yt.utilities.io_handler import io_registry
from yt.utilities.parallel_tools.parallel_analysis_interface import \
- ParallelAnalysisInterface, parallel_splitter
+ ParallelAnalysisInterface
from yt.data_objects.data_containers import data_object_registry
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/particle_geometry_handler.py
--- a/yt/geometry/particle_geometry_handler.py
+++ b/yt/geometry/particle_geometry_handler.py
@@ -32,7 +32,7 @@
from yt.utilities.definitions import MAXLEVEL
from yt.utilities.io_handler import io_registry
from yt.utilities.parallel_tools.parallel_analysis_interface import \
- ParallelAnalysisInterface, parallel_splitter
+ ParallelAnalysisInterface
from yt.data_objects.data_containers import data_object_registry
from yt.data_objects.octree_subset import ParticleOctreeSubset
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/geometry/selection_routines.pyx
--- a/yt/geometry/selection_routines.pyx
+++ b/yt/geometry/selection_routines.pyx
@@ -455,10 +455,10 @@
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- def count_points(self, np.ndarray[np.float64_t, ndim=1] x,
- np.ndarray[np.float64_t, ndim=1] y,
- np.ndarray[np.float64_t, ndim=1] z,
- np.float64_t radius = 0.0):
+ def count_points(self, np.ndarray[anyfloat, ndim=1] x,
+ np.ndarray[anyfloat, ndim=1] y,
+ np.ndarray[anyfloat, ndim=1] z,
+ np.float64_t radius):
cdef int count = 0
cdef int i
cdef np.float64_t pos[3]
@@ -483,10 +483,10 @@
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- def select_points(self, np.ndarray[np.float64_t, ndim=1] x,
- np.ndarray[np.float64_t, ndim=1] y,
- np.ndarray[np.float64_t, ndim=1] z,
- np.float64_t radius = 0.0):
+ def select_points(self, np.ndarray[anyfloat, ndim=1] x,
+ np.ndarray[anyfloat, ndim=1] y,
+ np.ndarray[anyfloat, ndim=1] z,
+ np.float64_t radius):
cdef int count = 0
cdef int i
cdef np.float64_t pos[3]
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/gui/reason/extdirect_repl.py
--- a/yt/gui/reason/extdirect_repl.py
+++ b/yt/gui/reason/extdirect_repl.py
@@ -39,7 +39,6 @@
from yt.funcs import *
from yt.utilities.logger import ytLogger, ufstring
-from yt.utilities.definitions import inv_axis_names
from yt.visualization.image_writer import apply_colormap
from yt.visualization.api import Streamlines
from .widget_store import WidgetStore
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/gui/reason/widget_store.py
--- a/yt/gui/reason/widget_store.py
+++ b/yt/gui/reason/widget_store.py
@@ -61,10 +61,11 @@
center = pf.h.find_max('Density')[1]
else:
center = np.array(center)
- axis = inv_axis_names[axis.lower()]
+ axis = pf.coordinates.axis_id[axis.lower()]
coord = center[axis]
sl = pf.slice(axis, coord, center = center)
- xax, yax = x_dict[axis], y_dict[axis]
+ xax = pf.coordinates.x_axis[axis]
+ yax = pf.coordinates.y_axis[axis]
DLE, DRE = pf.domain_left_edge, pf.domain_right_edge
pw = PWViewerExtJS(sl, (DLE[xax], DRE[xax], DLE[yax], DRE[yax]),
setup = False, plot_type='SlicePlot')
@@ -82,9 +83,10 @@
def create_proj(self, pf, axis, field, weight):
if weight == "None": weight = None
- axis = inv_axis_names[axis.lower()]
+ axis = pf.coordinates.axis_id[axis.lower()]
proj = pf.proj(field, axis, weight_field=weight)
- xax, yax = x_dict[axis], y_dict[axis]
+ xax = pf.coordinates.x_axis[axis]
+ yax = pf.coordinates.y_axis[axis]
DLE, DRE = pf.domain_left_edge, pf.domain_right_edge
pw = PWViewerExtJS(proj, (DLE[xax], DRE[xax], DLE[yax], DRE[yax]),
setup = False, plot_type='ProjectionPlot')
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -237,7 +237,9 @@
__array_priority__ = 2.0
- def __new__(cls, input_array, input_units=None, registry=None, dtype=np.float64):
+ def __new__(cls, input_array, input_units=None, registry=None, dtype=None):
+ if dtype is None:
+ dtype = getattr(input_array, 'dtype', np.float64)
if input_array is NotImplemented:
return input_array
if registry is None and isinstance(input_units, basestring):
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/utilities/answer_testing/output_tests.py
--- a/yt/utilities/answer_testing/output_tests.py
+++ b/yt/utilities/answer_testing/output_tests.py
@@ -177,8 +177,8 @@
This is a helper function that returns a 2D array of the specified
source, in the specified field, at the specified spatial extent.
"""
- xax = x_dict[self.axis]
- yax = y_dict[self.axis]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
if edges is None:
edges = (self.pf.domain_left_edge[xax],
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/utilities/definitions.py
--- a/yt/utilities/definitions.py
+++ b/yt/utilities/definitions.py
@@ -21,20 +21,6 @@
# The number of levels we expect to have at most
MAXLEVEL=48
-axis_labels = [('y','z'),('x','z'),('x','y')]
-axis_names = {0: 'x', 1: 'y', 2: 'z', 4:''}
-inv_axis_names = {'x':0,'y':1,'z':2,
- 'X':0,'Y':1,'Z':2}
-
-vm_axis_names = {0:'x', 1:'y', 2:'z', 3:'dx', 4:'dy'}
-
-# The appropriate axes for which way we are slicing
-x_dict = [1,0,0]
-y_dict = [2,2,1]
-
-x_names = ['y','x','x']
-y_names = ['z','z','y']
-
# How many of each thing are in an Mpc
mpc_conversion = {'Mpc' : mpc_per_mpc,
'mpc' : mpc_per_mpc,
@@ -56,5 +42,3 @@
'Myr' : sec_per_Myr,
'years' : sec_per_year,
'days' : sec_per_day}
-
-axis_labels = [('y','z'),('x','z'),('x','y')]
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/utilities/io_handler.py
--- a/yt/utilities/io_handler.py
+++ b/yt/utilities/io_handler.py
@@ -146,7 +146,7 @@
# Here, ptype_map means which particles contribute to a given type.
# And ptf is the actual fields from disk to read.
for ptype, (x, y, z) in self._read_particle_coords(chunks, ptf):
- psize[ptype] += selector.count_points(x, y, z)
+ psize[ptype] += selector.count_points(x, y, z, 0.0)
self._last_selector_counts = dict(**psize)
self._last_selector_id = hash(selector)
# Now we allocate
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/utilities/lib/ContourFinding.pxd
--- a/yt/utilities/lib/ContourFinding.pxd
+++ b/yt/utilities/lib/ContourFinding.pxd
@@ -39,6 +39,7 @@
ContourID *parent
ContourID *next
ContourID *prev
+ np.int64_t count
cdef struct CandidateContour
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/utilities/lib/ContourFinding.pyx
--- a/yt/utilities/lib/ContourFinding.pyx
+++ b/yt/utilities/lib/ContourFinding.pyx
@@ -36,6 +36,7 @@
node.contour_id = contour_id
node.next = node.parent = NULL
node.prev = prev
+ node.count = 0
if prev != NULL: prev.next = node
return node
@@ -631,7 +632,8 @@
np.ndarray[np.float64_t, ndim=2] positions,
np.ndarray[np.int64_t, ndim=1] particle_ids,
int domain_id = -1, int domain_offset = 0,
- periodicity = (True, True, True)):
+ periodicity = (True, True, True),
+ minimum_count = 8):
cdef np.ndarray[np.int64_t, ndim=1] pdoms, pcount, pind, doff
cdef np.float64_t pos[3]
cdef Oct *oct = NULL, **neighbors = NULL
@@ -728,6 +730,16 @@
c1 = container[poffset]
c0 = contour_find(c1)
contour_ids[pind[poffset]] = c0.contour_id
+ c0.count += 1
+ for i in range(doff.shape[0]):
+ if doff[i] < 0: continue
+ for j in range(pcount[i]):
+ poffset = doff[i] + j
+ c1 = container[poffset]
+ if c1 == NULL: continue
+ c0 = contour_find(c1)
+ if c0.count < minimum_count:
+ contour_ids[pind[poffset]] = -1
free(container)
return contour_ids
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/utilities/parallel_tools/parallel_analysis_interface.py
--- a/yt/utilities/parallel_tools/parallel_analysis_interface.py
+++ b/yt/utilities/parallel_tools/parallel_analysis_interface.py
@@ -27,8 +27,6 @@
ensure_list, iterable, traceback_writer_hook
from yt.config import ytcfg
-from yt.utilities.definitions import \
- x_dict, y_dict
import yt.utilities.logger
from yt.utilities.lib.QuadTree import \
QuadTree, merge_quadtrees
@@ -179,14 +177,13 @@
used on objects that subclass
:class:`~yt.utilities.parallel_tools.parallel_analysis_interface.ParallelAnalysisInterface`.
"""
- if not parallel_capable: return func
@wraps(func)
def single_proc_results(self, *args, **kwargs):
retval = None
if hasattr(self, "dont_wrap"):
if func.func_name in self.dont_wrap:
return func(self, *args, **kwargs)
- if self._processing or not self._distributed:
+ if not parallel_capable or self._processing or not self._distributed:
return func(self, *args, **kwargs)
comm = _get_comm((self,))
if self._owner == comm.rank:
@@ -243,6 +240,8 @@
"""
@wraps(func)
def barrierize(*args, **kwargs):
+ if not parallel_capable:
+ return func(*args, **kwargs)
mylog.debug("Entering barrier before %s", func.func_name)
comm = _get_comm(args)
comm.barrier()
@@ -250,26 +249,7 @@
mylog.debug("Entering barrier after %s", func.func_name)
comm.barrier()
return retval
- if parallel_capable:
- return barrierize
- else:
- return func
-
-def parallel_splitter(f1, f2):
- """
- This function returns either the function *f1* or *f2* depending on whether
- or not we're the root processor. Mainly used in class definitions.
- """
- @wraps(f1)
- def in_order(*args, **kwargs):
- comm = _get_comm(args)
- if comm.rank == 0:
- f1(*args, **kwargs)
- comm.barrier()
- if comm.rank != 0:
- f2(*args, **kwargs)
- if not parallel_capable: return f1
- return in_order
+ return barrierize
def parallel_root_only(func):
"""
@@ -278,6 +258,8 @@
"""
@wraps(func)
def root_only(*args, **kwargs):
+ if not parallel_capable:
+ return func(*args, **kwargs)
comm = _get_comm(args)
rv = None
if comm.rank == 0:
@@ -292,8 +274,7 @@
all_clear = comm.mpi_bcast(all_clear)
if not all_clear: raise RuntimeError
return rv
- if parallel_capable: return root_only
- return func
+ return root_only
class Workgroup(object):
def __init__(self, size, ranks, comm, name):
@@ -1083,7 +1064,8 @@
return False, self.index.grid_collection(self.center,
self.index.grids)
- xax, yax = x_dict[axis], y_dict[axis]
+ xax = self.pf.coordinates.x_axis[axis]
+ yax = self.pf.coordinates.y_axis[axis]
cc = MPI.Compute_dims(self.comm.size, 2)
mi = self.comm.rank
cx, cy = np.unravel_index(mi, cc)
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/base_plot_types.py
--- a/yt/visualization/base_plot_types.py
+++ b/yt/visualization/base_plot_types.py
@@ -17,7 +17,7 @@
from ._mpl_imports import \
FigureCanvasAgg, FigureCanvasPdf, FigureCanvasPS
from yt.funcs import \
- get_image_suffix, mylog, x_dict, y_dict
+ get_image_suffix, mylog
import numpy as np
class CallbackWrapper(object):
@@ -30,8 +30,8 @@
self.image = self._axes.images[0]
if frb.axis < 3:
DD = frb.pf.domain_width
- xax = x_dict[frb.axis]
- yax = y_dict[frb.axis]
+ xax = frb.pf.coordinates.x_axis[frb.axis]
+ yax = frb.pf.coordinates.y_axis[frb.axis]
self._period = (DD[xax], DD[yax])
self.pf = frb.pf
self.xlim = viewer.xlim
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/eps_writer.py
--- a/yt/visualization/eps_writer.py
+++ b/yt/visualization/eps_writer.py
@@ -19,11 +19,6 @@
from _mpl_imports import FigureCanvasAgg
from yt.utilities.logger import ytLogger as mylog
-from yt.utilities.definitions import \
- x_dict, x_names, \
- y_dict, y_names, \
- axis_names, \
- axis_labels
from .plot_window import PlotWindow
from .profile_plotter import PhasePlot
from .plot_modifications import get_smallest_appropriate_unit
@@ -296,6 +291,7 @@
_yrange = (0, width * plot.pf[units])
_xlog = False
_ylog = False
+ axis_names = plot.pf.coordinates.axis_name
if bare_axes:
_xlabel = ""
_ylabel = ""
@@ -305,14 +301,16 @@
_xlabel = xlabel
else:
if data.axis != 4:
- _xlabel = '%s (%s)' % (x_names[data.axis], units)
+ xax = plot.pf.coordinates.x_axis[data.axis]
+ _xlabel = '%s (%s)' % (axis_names[xax], units)
else:
_xlabel = 'Image x (%s)' % (units)
if ylabel != None:
_ylabel = ylabel
else:
if data.axis != 4:
- _ylabel = '%s (%s)' % (y_names[data.axis], units)
+ yax = plot.pf.coordinates.y_axis[data.axis]
+ _ylabel = '%s (%s)' % (axis_names[yax], units)
else:
_ylabel = 'Image y (%s)' % (units)
if tickcolor == None:
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/fixed_resolution.py
--- a/yt/visualization/fixed_resolution.py
+++ b/yt/visualization/fixed_resolution.py
@@ -14,10 +14,6 @@
#-----------------------------------------------------------------------------
from yt.funcs import *
-from yt.utilities.definitions import \
- x_dict, \
- y_dict, \
- axis_names
from .volume_rendering.api import off_axis_projection
from yt.data_objects.image_array import ImageArray
from yt.utilities.lib.misc_utilities import \
@@ -104,8 +100,8 @@
DRE = self.pf.domain_right_edge
DD = float(self.periodic)*(DRE - DLE)
axis = self.data_source.axis
- xax = x_dict[axis]
- yax = y_dict[axis]
+ xax = self.pf.coordinates.x_axis[axis]
+ yax = self.pf.coordinates.y_axis[axis]
self._period = (DD[xax], DD[yax])
self._edges = ( (DLE[xax], DRE[xax]), (DLE[yax], DRE[yax]) )
@@ -334,10 +330,10 @@
@property
def limits(self):
rv = dict(x = None, y = None, z = None)
- xax = x_dict[self.axis]
- yax = y_dict[self.axis]
- xn = axis_names[xax]
- yn = axis_names[yax]
+ xax = self.pf.coordinates.x_axis[self.axis]
+ yax = self.pf.coordinates.y_axis[self.axis]
+ xn = self.pf.coordinates.axis_name[xax]
+ yn = self.pf.coordinates.axis_name[yax]
rv[xn] = (self.bounds[0], self.bounds[1])
rv[yn] = (self.bounds[2], self.bounds[3])
return rv
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/image_panner/vm_panner.py
--- a/yt/visualization/image_panner/vm_panner.py
+++ b/yt/visualization/image_panner/vm_panner.py
@@ -19,8 +19,6 @@
FixedResolutionBuffer, ObliqueFixedResolutionBuffer
from yt.data_objects.data_containers import \
data_object_registry
-from yt.utilities.definitions import \
- x_dict, y_dict
from yt.funcs import *
class VariableMeshPanner(object):
@@ -62,7 +60,8 @@
if not hasattr(self, 'pf'): self.pf = self.source.pf
DLE, DRE = self.pf.domain_left_edge, self.pf.domain_right_edge
ax = self.source.axis
- xax, yax = x_dict[ax], y_dict[ax]
+ xax = self.pf.coordinates.x_axis[ax]
+ yax = self.pf.coordinates.y_axis[ax]
xbounds = DLE[xax], DRE[xax]
ybounds = DLE[yax], DRE[yax]
return (xbounds, ybounds)
@@ -183,8 +182,10 @@
if len(center) == 2:
centerx, centery = center
elif len(center) == 3:
- centerx = center[x_dict[self.source.axis]]
- centery = center[y_dict[self.source.axis]]
+ xax = self.pf.coordinates.x_axis[self.source.axis]
+ yax = self.pf.coordinates.y_axis[self.source.axis]
+ centerx = center[xax]
+ centery = center[yax]
else:
raise RuntimeError
Wx, Wy = self.width
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/plot_container.py
--- a/yt/visualization/plot_container.py
+++ b/yt/visualization/plot_container.py
@@ -17,7 +17,6 @@
from yt.funcs import \
defaultdict, get_image_suffix, \
get_ipython_api_version
-from yt.utilities.definitions import axis_names
from yt.utilities.exceptions import \
YTNotInsideNotebook
from ._mpl_imports import FigureCanvasAgg
@@ -424,7 +423,8 @@
for k, v in self.plots.iteritems():
names.append(v.save(name, mpl_kwargs))
return names
- axis = axis_names[self.data_source.axis]
+ axis = self.pf.coordinates.axis_name.get(
+ self.data_source.axis, '')
weight = None
type = self._plot_type
if type in ['Projection', 'OffAxisProjection']:
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -1,4 +1,4 @@
-"""
+"""
Callbacks to add additional functionality on to plots.
@@ -19,11 +19,6 @@
from yt.funcs import *
from _mpl_imports import *
-from yt.utilities.definitions import \
- x_dict, x_names, \
- y_dict, y_names, \
- axis_names, \
- axis_labels
from yt.utilities.physical_constants import \
sec_per_Gyr, sec_per_Myr, \
sec_per_kyr, sec_per_year, \
@@ -116,13 +111,17 @@
"cutting_plane_velocity_y",
self.factor)
else:
- xv = "velocity_%s" % (x_names[plot.data.axis])
- yv = "velocity_%s" % (y_names[plot.data.axis])
+ ax = plot.data.axis
+ (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
+ plot.data.pf.coordinates.y_axis[ax])
+ axis_names = plot.data.pf.coordinates.axis_name
+ xv = "velocity_%s" % (axis_names[xi])
+ yv = "velocity_%s" % (axis_names[yi])
bv = plot.data.get_field_parameter("bulk_velocity")
if bv is not None:
- bv_x = bv[x_dict[plot.data.axis]]
- bv_y = bv[y_dict[plot.data.axis]]
+ bv_x = bv[xi]
+ bv_y = bv[yi]
else: bv_x = bv_y = YTQuantity(0, 'cm/s')
qcb = QuiverCallback(xv, yv, self.factor, scale=self.scale,
@@ -157,8 +156,11 @@
"cutting_plane_by",
self.factor)
else:
- xv = "magnetic_field_%s" % (x_names[plot.data.axis])
- yv = "magnetic_field_%s" % (y_names[plot.data.axis])
+ xax = plot.data.pf.coordinates.x_axis[plot.data.axis]
+ yax = plot.data.pf.coordinates.y_axis[plot.data.axis]
+ axis_names = plot.data.pf.coordinates.axis_name
+ xv = "magnetic_field_%s" % (axis_names[xax])
+ yv = "magnetic_field_%s" % (axis_names[yax])
qcb = QuiverCallback(xv, yv, self.factor, scale=self.scale, scale_units=self.scale_units, normalize=self.normalize)
return qcb(plot)
@@ -195,8 +197,10 @@
# periodicity
ax = plot.data.axis
pf = plot.data.pf
- period_x = pf.domain_width[x_dict[ax]]
- period_y = pf.domain_width[y_dict[ax]]
+ (xi, yi) = (pf.coordinates.x_axis[ax],
+ pf.coordinates.y_axis[ax])
+ period_x = pf.domain_width[xi]
+ period_y = pf.domain_width[yi]
periodic = int(any(pf.periodicity))
fv_x = plot.data[self.field_x]
if self.bv_x != 0.0:
@@ -385,8 +389,9 @@
yy0, yy1 = plot._axes.get_ylim()
(dx, dy) = self.pixel_scale(plot)
(xpix, ypix) = plot.image._A.shape
- px_index = x_dict[plot.data.axis]
- py_index = y_dict[plot.data.axis]
+ ax = plot.data.axis
+ px_index = plot.data.pf.coordinates.x_axis[ax]
+ py_index = plot.data.pf.coordinates.y_axis[ax]
DW = plot.data.pf.domain_width
if self.periodic:
pxs, pys = np.mgrid[-1:1:3j,-1:1:3j]
@@ -651,11 +656,12 @@
plot._axes.hold(True)
- px_index = x_dict[plot.data.axis]
- py_index = y_dict[plot.data.axis]
+ ax = plot.data.axis
+ px_index = plot.data.pf.coordinates.x_axis[ax]
+ py_index = plot.data.pf.coordinates.y_axis[ax]
- xf = axis_names[px_index]
- yf = axis_names[py_index]
+ xf = plot.data.pf.coordinates.axis_name[px_index]
+ yf = plot.data.pf.coordinates.axis_name[py_index]
dxf = "d%s" % xf
dyf = "d%s" % yf
@@ -702,8 +708,10 @@
def __call__(self, plot):
if len(self.pos) == 3:
- pos = (self.pos[x_dict[plot.data.axis]],
- self.pos[y_dict[plot.data.axis]])
+ ax = plot.data.axis
+ (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
+ plot.data.pf.coordinates.y_axis[ax])
+ pos = self.pos[xi], self.pos[yi]
else: pos = self.pos
if isinstance(self.code_size[1], basestring):
code_size = plot.data.pf.quan(*self.code_size).value
@@ -733,8 +741,10 @@
def __call__(self, plot):
if len(self.pos) == 3:
- pos = (self.pos[x_dict[plot.data.axis]],
- self.pos[y_dict[plot.data.axis]])
+ ax = plot.data.axis
+ (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
+ plot.data.pf.coordinates.y_axis[ax])
+ pos = self.pos[xi], self.pos[yi]
else: pos = self.pos
width,height = plot.image._A.shape
x,y = self.convert_to_plot(plot, pos)
@@ -759,8 +769,10 @@
xx0, xx1 = plot._axes.get_xlim()
yy0, yy1 = plot._axes.get_ylim()
if len(self.pos) == 3:
- pos = (self.pos[x_dict[plot.data.axis]],
- self.pos[y_dict[plot.data.axis]])
+ ax = plot.data.axis
+ (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
+ plot.data.pf.coordinates.y_axis[ax])
+ pos = self.pos[xi], self.pos[yi]
elif len(self.pos) == 2:
pos = self.pos
x,y = self.convert_to_plot(plot, pos)
@@ -803,7 +815,9 @@
if plot.data.axis == 4:
(xi, yi) = (0, 1)
else:
- (xi, yi) = (x_dict[plot.data.axis], y_dict[plot.data.axis])
+ ax = plot.data.axis
+ (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
+ plot.data.pf.coordinates.y_axis[ax])
(center_x,center_y) = self.convert_to_plot(plot,(self.center[xi], self.center[yi]))
@@ -834,8 +848,10 @@
kwargs = self.text_args.copy()
if self.data_coords and len(plot.image._A.shape) == 2:
if len(self.pos) == 3:
- pos = (self.pos[x_dict[plot.data.axis]],
- self.pos[y_dict[plot.data.axis]])
+ ax = plot.data.axis
+ (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
+ plot.data.pf.coordinates.y_axis[ax])
+ pos = self.pos[xi], self.pos[yi]
else: pos = self.pos
x,y = self.convert_to_plot(plot, pos)
else:
@@ -949,8 +965,12 @@
xx0, xx1 = plot._axes.get_xlim()
yy0, yy1 = plot._axes.get_ylim()
reg = self._get_region((x0,x1), (y0,y1), plot.data.axis, data)
- field_x = "particle_position_%s" % axis_names[x_dict[data.axis]]
- field_y = "particle_position_%s" % axis_names[y_dict[data.axis]]
+ ax = data.axis
+ xax = plot.data.pf.coordinates.x_axis[ax]
+ yax = plot.data.pf.coordinates.y_axis[ax]
+ axis_names = plot.data.pf.coordinates.axis_name
+ field_x = "particle_position_%s" % axis_names[xax]
+ field_y = "particle_position_%s" % axis_names[yax]
gg = ( ( reg[field_x] >= x0 ) & ( reg[field_x] <= x1 )
& ( reg[field_y] >= y0 ) & ( reg[field_y] <= y1 ) )
if self.ptype is not None:
@@ -978,8 +998,9 @@
def _get_region(self, xlim, ylim, axis, data):
LE, RE = [None]*3, [None]*3
- xax = x_dict[axis]
- yax = y_dict[axis]
+ pf = data.pf
+ xax = pf.coordinates.x_axis[axis]
+ yax = pf.coordinates.y_axis[axis]
zax = axis
LE[xax], RE[xax] = xlim
LE[yax], RE[yax] = ylim
@@ -1197,7 +1218,9 @@
def __call__(self, plot):
plot._axes.hold(True)
- xax, yax = x_dict[plot.data.axis], y_dict[plot.data.axis]
+ ax = data.axis
+ xax = plot.data.pf.coordinates.x_axis[ax]
+ yax = plot.data.pf.coordinates.y_axis[ax]
l_cy = triangle_plane_intersect(plot.data.axis, plot.data.coord, self.vertices)[:,:,(xax, yax)]
lc = matplotlib.collections.LineCollection(l_cy, **self.plot_args)
plot._axes.add_collection(lc)
diff -r 08c3dec255997a13ec09f7079706d8d7bcf6e3c1 -r 0691ab75861a266288df82c72d226da52cd098a2 yt/visualization/plot_window.py
--- a/yt/visualization/plot_window.py
+++ b/yt/visualization/plot_window.py
@@ -43,8 +43,6 @@
from yt.utilities.png_writer import \
write_png_to_string
from yt.utilities.definitions import \
- x_dict, y_dict, \
- axis_names, axis_labels, \
formatted_length_unit_names
from yt.utilities.math_utils import \
ortho_find
@@ -68,8 +66,8 @@
from pyparsing import ParseFatalException
def fix_unitary(u):
- if u is '1':
- return 'code_length'
+ if u == '1':
+ return 'unitary'
else:
return u
@@ -108,7 +106,9 @@
if width is None:
# Default to code units
if not iterable(axis):
- w = pf.domain_width[[x_dict[axis], y_dict[axis]]]
+ xax = pf.coordinates.x_axis[axis]
+ yax = pf.coordinates.y_axis[axis]
+ w = pf.domain_width[[xax, yax]]
else:
# axis is actually the normal vector
# for an off-axis data object.
@@ -189,10 +189,12 @@
center[2] = 0.0
else:
raise NotImplementedError
- bounds = (center[x_dict[axis]]-width[0] / 2,
- center[x_dict[axis]]+width[0] / 2,
- center[y_dict[axis]]-width[1] / 2,
- center[y_dict[axis]]+width[1] / 2)
+ xax = pf.coordinates.x_axis[axis]
+ yax = pf.coordinates.y_axis[axis]
+ bounds = (center[xax]-width[0] / 2,
+ center[xax]+width[0] / 2,
+ center[yax]-width[1] / 2,
+ center[yax]+width[1] / 2)
return (bounds, center)
def get_oblique_window_parameters(normal, center, width, pf, depth=None):
@@ -377,10 +379,30 @@
Parameters
----------
- deltas : sequence of floats
- (delta_x, delta_y) in *absolute* code unit coordinates
+ deltas : Two-element sequence of floats, quantities, or (float, unit)
+ tuples.
+ (delta_x, delta_y). If a unit is not supplied the unit is assumed
+ to be code_length.
"""
+ if len(deltas) != 2:
+ raise RuntimeError(
+ "The pan function accepts a two-element sequence.\n"
+ "Received %s." % (deltas, )
+ )
+ if isinstance(deltas[0], Number) and isinstance(deltas[1], Number):
+ deltas = (self.pf.quan(deltas[0], 'code_length'),
+ self.pf.quan(deltas[1], 'code_length'))
+ elif isinstance(deltas[0], tuple) and isinstance(deltas[1], tuple):
+ deltas = (self.pf.quan(deltas[0][0], deltas[0][1]),
+ self.pf.quan(deltas[1][0], deltas[1][1]))
+ elif isinstance(deltas[0], YTQuantity) and isinstance(deltas[1], YTQuantity):
+ pass
+ else:
+ raise RuntimeError(
+ "The arguments of the pan function must be a sequence of floats,\n"
+ "quantities, or (float, unit) tuples. Received %s." % (deltas, )
+ )
self.xlim = (self.xlim[0] + deltas[0], self.xlim[1] + deltas[0])
self.ylim = (self.ylim[0] + deltas[1], self.ylim[1] + deltas[1])
return self
@@ -652,10 +674,12 @@
xllim, xrlim = self.xlim
yllim, yrlim = self.ylim
elif origin[2] == 'domain':
- xllim = self.pf.domain_left_edge[x_dict[axis_index]]
- xrlim = self.pf.domain_right_edge[x_dict[axis_index]]
- yllim = self.pf.domain_left_edge[y_dict[axis_index]]
- yrlim = self.pf.domain_right_edge[y_dict[axis_index]]
+ xax = pf.coordinates.x_axis[axis_index]
+ yax = pf.coordinates.y_axis[axis_index]
+ xllim = self.pf.domain_left_edge[xax]
+ xrlim = self.pf.domain_right_edge[xax]
+ yllim = self.pf.domain_left_edge[yax]
+ yrlim = self.pf.domain_right_edge[yax]
elif origin[2] == 'native':
return (self.pf.quan(0.0, 'code_length'),
self.pf.quan(0.0, 'code_length'))
@@ -786,8 +810,11 @@
labels = [r'$\rm{Image\/x'+axes_unit_labels[0]+'}$',
r'$\rm{Image\/y'+axes_unit_labels[1]+'}$']
else:
- labels = [r'$\rm{'+axis_labels[axis_index][i]+
- axes_unit_labels[i] + r'}$' for i in (0,1)]
+ axis_names = self.pf.coordinates.axis_name
+ xax = self.pf.coordinates.x_axis[axis_index]
+ yax = self.pf.coordinates.y_axis[axis_index]
+ labels = [r'$\rm{'+axis_names[xax]+axes_unit_labels[0] + r'}$',
+ r'$\rm{'+axis_names[yax]+axes_unit_labels[1] + r'}$']
self.plots[f].axes.set_xlabel(labels[0],fontproperties=fp)
self.plots[f].axes.set_ylabel(labels[1],fontproperties=fp)
@@ -968,7 +995,7 @@
ts = self._initialize_dataset(pf)
self.ts = ts
pf = self.pf = ts[0]
- axis = fix_axis(axis)
+ axis = fix_axis(axis, pf)
(bounds, center) = get_window_parameters(axis, center, width, pf)
if field_parameters is None: field_parameters = {}
slc = pf.slice(axis, center[axis],
@@ -1094,7 +1121,7 @@
ts = self._initialize_dataset(pf)
self.ts = ts
pf = self.pf = ts[0]
- axis = fix_axis(axis)
+ axis = fix_axis(axis, pf)
(bounds, center) = get_window_parameters(axis, center, width, pf)
if field_parameters is None: field_parameters = {}
proj = pf.proj(fields, axis, weight_field=weight_field,
@@ -1424,8 +1451,11 @@
self._frb.bounds, (nx,ny))
axis = self._frb.data_source.axis
- fx = "%s-velocity" % (axis_names[x_dict[axis]])
- fy = "%s-velocity" % (axis_names[y_dict[axis]])
+ xax = self._frb.data_source.pf.coordinates.x_axis[axis]
+ yax = self._frb.data_source.pf.coordinates.y_axis[axis]
+ axis_names = self._frb.data_source.pf.coordinates.axis_name
+ fx = "velocity_%s" % (axis_names[xax])
+ fy = "velocity_%x" % (axis_names[yax])
px = new_frb[fx][::-1,:]
py = new_frb[fy][::-1,:]
x = np.mgrid[0:vi-1:ny*1j]
@@ -1509,12 +1539,14 @@
unit = self._axes_unit_names
units = self.get_field_units(field, strip_mathml)
center = getattr(self._frb.data_source, "center", None)
+ xax = self.pf.coordinates.x_axis[self._frb.axis]
+ yax = self.pf.coordinates.y_axis[self._frb.axis]
if center is None or self._frb.axis == 4:
xc, yc, zc = -999, -999, -999
else:
- center[x_dict[self._frb.axis]] = 0.5 * (
+ center[xax] = 0.5 * (
self.xlim[0] + self.xlim[1])
- center[y_dict[self._frb.axis]] = 0.5 * (
+ center[yax] = 0.5 * (
self.ylim[0] + self.ylim[1])
xc, yc, zc = center
if return_string:
https://bitbucket.org/yt_analysis/yt/commits/091cd103855c/
Changeset: 091cd103855c
Branch: yt-3.0
User: hegan
Date: 2014-04-28 22:15:22
Summary: removed documentation for annotate hop halos
Affected #: 1 file
diff -r 0691ab75861a266288df82c72d226da52cd098a2 -r 091cd103855c6786c33f4fcbd01a2b0c16422991 doc/source/visualizing/_cb_docstrings.inc
--- a/doc/source/visualizing/_cb_docstrings.inc
+++ b/doc/source/visualizing/_cb_docstrings.inc
@@ -132,45 +132,6 @@
-------------
-.. function:: annotate_hop_circles(self, hop_output, max_number=None, annotate=False, min_size=20, max_size=10000000, font_size=8, print_halo_size=False, print_halo_mass=False, width=None):
-
- (This is a proxy for :class:`~yt.visualization.plot_modifications.HopCircleCallback`.)
-
- Accepts a :class:`yt.HopList` *hop_output* and plots up
- to *max_number* (None for unlimited) halos as circles.
-
-.. python-script::
-
- from yt.mods import *
- pf = load("Enzo_64/DD0043/data0043")
- halos = HaloFinder(pf)
- p = ProjectionPlot(pf, "z", "density")
- p.annotate_hop_circles(halos)
- p.save()
-
--------------
-
-.. function:: annotate_hop_particles(self, hop_output, max_number, p_size=1.0, min_size=20, alpha=0.2):
-
- (This is a proxy for :class:`~yt.visualization.plot_modifications.HopParticleCallback`.)
-
- Adds particle positions for the members of each halo as
- identified by HOP. Along *axis* up to *max_number* groups
- in *hop_output* that are larger than *min_size* are
- plotted with *p_size* pixels per particle; *alpha*
- determines the opacity of each particle.
-
-.. python-script::
-
- from yt.mods import *
- pf = load("Enzo_64/DD0043/data0043")
- halos = HaloFinder(pf)
- p = ProjectionPlot(pf, "x", "density", center='m', width=(10, 'Mpc'))
- p.annotate_hop_particles(halos, max_number=100, p_size=5.0)
- p.save()
-
--------------
-
.. function:: annotate_image_line(self, p1, p2, data_coords=False, plot_args=None):
(This is a proxy for :class:`~yt.visualization.plot_modifications.ImageLineCallback`.)
https://bitbucket.org/yt_analysis/yt/commits/3680bc9b2a61/
Changeset: 3680bc9b2a61
Branch: yt-3.0
User: hegan
Date: 2014-05-01 18:16:16
Summary: removed indent
Affected #: 1 file
diff -r 091cd103855c6786c33f4fcbd01a2b0c16422991 -r 3680bc9b2a6130215ea441e312bcb79a2e3b5e76 yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -1,4 +1,4 @@
- """
+"""
Callbacks to add additional functionality on to plots.
https://bitbucket.org/yt_analysis/yt/commits/893a689ff45f/
Changeset: 893a689ff45f
Branch: yt-3.0
User: MatthewTurk
Date: 2014-05-01 19:02:34
Summary: Merged in hegan/yt/yt-3.0 (pull request #850)
Annotate_halos
Affected #: 2 files
diff -r f228e5bea67d968ec1ca2d2cc755883481b43dce -r 893a689ff45f9d951b367be0f3b6a8ebb9a74b18 doc/source/visualizing/_cb_docstrings.inc
--- a/doc/source/visualizing/_cb_docstrings.inc
+++ b/doc/source/visualizing/_cb_docstrings.inc
@@ -104,42 +104,31 @@
-------------
-.. function:: annotate_hop_circles(self, hop_output, max_number=None, annotate=False, min_size=20, max_size=10000000, font_size=8, print_halo_size=False, print_halo_mass=False, width=None):
+.. function:: annotate_halos(self, halo_catalog, col='white', alpha =1, width= None):
- (This is a proxy for :class:`~yt.visualization.plot_modifications.HopCircleCallback`.)
+ (This is a proxy for :class:`~yt.visualization.plot_modifications.HaloCatalogCallback`.)
- Accepts a :class:`yt.HopList` *hop_output* and plots up
- to *max_number* (None for unlimited) halos as circles.
+ Accepts a :class:`yt.HaloCatalog` *HaloCatalog* and plots
+ a circle at the location of each halo with the radius of
+ the circle corresponding to the virial radius of the halo.
+ If *width* is set to None (default) all halos are plotted.
+ Otherwise, only halos that fall within a slab with width
+ *width* centered on the center of the plot data are plotted. The
+ color and transparency of the circles can be controlled with
+ *col* and *alpha* respectively.
.. python-script::
+
+ from yt.mods import *
+ data_pf = load('Enzo_64/RD0006/RD0006')
+ halos_pf = load('rockstar_halos/halos_0.0.bin')
- from yt.mods import *
- pf = load("Enzo_64/DD0043/data0043")
- halos = HaloFinder(pf)
- p = ProjectionPlot(pf, "z", "density")
- p.annotate_hop_circles(halos)
- p.save()
+ hc = HaloCatalog(halos_pf=halos_pf)
+ hc.create()
--------------
-
-.. function:: annotate_hop_particles(self, hop_output, max_number, p_size=1.0, min_size=20, alpha=0.2):
-
- (This is a proxy for :class:`~yt.visualization.plot_modifications.HopParticleCallback`.)
-
- Adds particle positions for the members of each halo as
- identified by HOP. Along *axis* up to *max_number* groups
- in *hop_output* that are larger than *min_size* are
- plotted with *p_size* pixels per particle; *alpha*
- determines the opacity of each particle.
-
-.. python-script::
-
- from yt.mods import *
- pf = load("Enzo_64/DD0043/data0043")
- halos = HaloFinder(pf)
- p = ProjectionPlot(pf, "x", "density", center='m', width=(10, 'Mpc'))
- p.annotate_hop_particles(halos, max_number=100, p_size=5.0)
- p.save()
+ prj = ProjectionPlot(data_pf, 'z', 'density')
+ prj.annotate_halos(hc)
+ prj.save()
-------------
diff -r f228e5bea67d968ec1ca2d2cc755883481b43dce -r 893a689ff45f9d951b367be0f3b6a8ebb9a74b18 yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -1,4 +1,5 @@
"""
+
Callbacks to add additional functionality on to plots.
@@ -826,118 +827,6 @@
plot._axes.text(center_x, center_y, self.text,
**self.text_args)
-class HopCircleCallback(PlotCallback):
- """
- annotate_hop_circles(hop_output, max_number=None,
- annotate=False, min_size=20, max_size=10000000,
- font_size=8, print_halo_size=False,
- print_halo_mass=False, width=None)
-
- Accepts a :class:`yt.HopList` *hop_output* and plots up to
- *max_number* (None for unlimited) halos as circles.
- """
- _type_name = "hop_circles"
- def __init__(self, hop_output, max_number=None,
- annotate=False, min_size=20, max_size=10000000,
- font_size=8, print_halo_size=False,
- print_halo_mass=False, width=None):
- self.hop_output = hop_output
- self.max_number = max_number
- self.annotate = annotate
- self.min_size = min_size
- self.max_size = max_size
- self.font_size = font_size
- self.print_halo_size = print_halo_size
- self.print_halo_mass = print_halo_mass
- self.width = width
-
- def __call__(self, plot):
- from matplotlib.patches import Circle
- num = len(self.hop_output[:self.max_number])
- for halo in self.hop_output[:self.max_number]:
- size = halo.get_size()
- if size < self.min_size or size > self.max_size: continue
- # This could use halo.maximum_radius() instead of width
- if self.width is not None and \
- np.abs(halo.center_of_mass() -
- plot.data.center)[plot.data.axis] > \
- self.width:
- continue
-
- radius = halo.maximum_radius() * self.pixel_scale(plot)[0]
- center = halo.center_of_mass()
-
- ax = plot.data.axis
- (xi, yi) = (plot.data.pf.coordinates.x_axis[ax],
- plot.data.pf.coordinates.y_axis[ax])
-
- (center_x,center_y) = self.convert_to_plot(plot,(center[xi], center[yi]))
- color = np.ones(3) * (0.4 * (num - halo.id)/ num) + 0.6
- cir = Circle((center_x, center_y), radius, fill=False, color=color)
- plot._axes.add_patch(cir)
- if self.annotate:
- if self.print_halo_size:
- plot._axes.text(center_x+radius, center_y+radius, "%s" % size,
- fontsize=self.font_size, color=color)
- elif self.print_halo_mass:
- plot._axes.text(center_x+radius, center_y+radius, "%s" % halo.total_mass(),
- fontsize=self.font_size, color=color)
- else:
- plot._axes.text(center_x+radius, center_y+radius, "%s" % halo.id,
- fontsize=self.font_size, color=color)
-
-class HopParticleCallback(PlotCallback):
- """
- annotate_hop_particles(hop_output, max_number, p_size=1.0,
- min_size=20, alpha=0.2):
-
- Adds particle positions for the members of each halo as identified
- by HOP. Along *axis* up to *max_number* groups in *hop_output* that are
- larger than *min_size* are plotted with *p_size* pixels per particle;
- *alpha* determines the opacity of each particle.
- """
- _type_name = "hop_particles"
- def __init__(self, hop_output, max_number=None, p_size=1.0,
- min_size=20, alpha=0.2):
- self.hop_output = hop_output
- self.p_size = p_size
- self.max_number = max_number
- self.min_size = min_size
- self.alpha = alpha
-
- def __call__(self,plot):
- (dx,dy) = self.pixel_scale(plot)
-
- xax = plot.data.pf.coordinates.x_axis[plot.data.axis]
- yax = plot.data.pf.coordinates.y_axis[plot.data.axis]
- axis_names = plot.data.pf.coordinates.axis_name
- (xi, yi) = (axis_names[xax], axis_names[yax])
-
- # now we loop over the haloes
- for halo in self.hop_output[:self.max_number]:
- size = halo.get_size()
-
- if size < self.min_size: continue
-
- (px,py) = self.convert_to_plot(plot,(halo["particle_position_%s" % xi],
- halo["particle_position_%s" % yi]))
-
- # Need to get the plot limits and set the hold state before scatter
- # and then restore the limits and turn off the hold state afterwards
- # because scatter will automatically adjust the plot window which we
- # do not want
-
- xlim = plot._axes.get_xlim()
- ylim = plot._axes.get_ylim()
- plot._axes.hold(True)
-
- plot._axes.scatter(px, py, edgecolors="None",
- s=self.p_size, c='black', alpha=self.alpha)
-
- plot._axes.set_xlim(xlim)
- plot._axes.set_ylim(ylim)
- plot._axes.hold(False)
-
class TextLabelCallback(PlotCallback):
"""
@@ -971,6 +860,69 @@
kwargs["transform"] = plot._axes.transAxes
plot._axes.text(x, y, self.text, **kwargs)
+class HaloCatalogCallback(PlotCallback):
+
+ _type_name = 'halos'
+ region = None
+ _descriptor = None
+
+ def __init__(self, halo_catalog, col='white', alpha =1, width = None):
+ PlotCallback.__init__(self)
+ self.halo_catalog = halo_catalog
+ self.color = col
+ self.alpha = alpha
+ self.width = width
+
+ def __call__(self, plot):
+ data = plot.data
+ x0, x1 = plot.xlim
+ y0, y1 = plot.ylim
+ xx0, xx1 = plot._axes.get_xlim()
+ yy0, yy1 = plot._axes.get_ylim()
+
+ halo_data= self.halo_catalog.halos_pf.all_data()
+ field_x = "particle_position_%s" % axis_names[x_dict[data.axis]]
+ field_y = "particle_position_%s" % axis_names[y_dict[data.axis]]
+ field_z = "particle_position_%s" % axis_names[data.axis]
+ plot._axes.hold(True)
+
+ # Set up scales for pixel size and original data
+ units = 'Mpccm'
+ pixel_scale = self.pixel_scale(plot)[0]
+ data_scale = data.pf.length_unit
+
+ # Convert halo positions to code units of the plotted data
+ # and then to units of the plotted window
+ px = halo_data[field_x][:].in_units(units) / data_scale
+ py = halo_data[field_y][:].in_units(units) / data_scale
+ px, py = self.convert_to_plot(plot,[px,py])
+
+ # Convert halo radii to a radius in pixels
+ radius = halo_data['radius'][:].in_units(units)
+ radius = radius*pixel_scale/data_scale
+
+ if self.width:
+ pz = halo_data[field_z][:].in_units(units)/data_scale
+ pz = data.pf.arr(pz, 'code_length')
+ c = data.center[data.axis]
+
+ # I should catch an error here if width isn't in this form
+ # but I dont really want to reimplement get_sanitized_width...
+ width = data.pf.arr(self.width[0], self.width[1]).in_units('code_length')
+
+ indices = np.where((pz > c-width) & (pz < c+width))
+
+ px = px[indices]
+ py = py[indices]
+ radius = radius[indices]
+
+ plot._axes.scatter(px, py, edgecolors='None', marker='o',
+ s=radius, c=self.color,alpha=self.alpha)
+ plot._axes.set_xlim(xx0,xx1)
+ plot._axes.set_ylim(yy0,yy1)
+ plot._axes.hold(False)
+
+
class ParticleCallback(PlotCallback):
"""
annotate_particles(width, p_size=1.0, col='k', marker='o', stride=1.0,
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list