[yt-svn] commit/yt: ngoldbaum: Merged in brittonsmith/yt (pull request #2482)
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Wed Jan 11 20:29:40 PST 2017
1 new commit in yt:
https://bitbucket.org/yt_analysis/yt/commits/ec92132823b0/
Changeset: ec92132823b0
Branch: yt
User: ngoldbaum
Date: 2017-01-12 04:29:14+00:00
Summary: Merged in brittonsmith/yt (pull request #2482)
Use parent unit_registry for YTData and HaloCatalog datasets
Affected #: 8 files
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
--- a/yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
+++ b/yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
@@ -655,33 +655,32 @@
Write light ray data to hdf5 file.
"""
+
+ extra_attrs = {"data_type": "yt_light_ray"}
if self.simulation_type is None:
ds = self.ds
else:
ds = {}
- ds["dimensionality"] = self.simulation.dimensionality
- ds["domain_left_edge"] = self.simulation.domain_left_edge
- ds["domain_right_edge"] = self.simulation.domain_right_edge
- ds["cosmological_simulation"] = self.simulation.cosmological_simulation
ds["periodicity"] = (True, True, True)
ds["current_redshift"] = self.near_redshift
- for attr in ["omega_lambda", "omega_matter", "hubble_constant"]:
- ds[attr] = getattr(self.cosmology, attr)
+ for attr in ["dimensionality", "cosmological_simulation",
+ "domain_left_edge", "domain_right_edge",
+ "length_unit", "time_unit"]:
+ ds[attr] = getattr(self.simulation, attr)
+ if self.simulation.cosmological_simulation:
+ for attr in ["omega_lambda", "omega_matter",
+ "hubble_constant"]:
+ ds[attr] = getattr(self.cosmology, attr)
ds["current_time"] = \
self.cosmology.t_from_z(ds["current_redshift"])
if isinstance(ds["hubble_constant"], YTArray):
ds["hubble_constant"] = \
ds["hubble_constant"].to("100*km/(Mpc*s)").d
- extra_attrs = {"data_type": "yt_light_ray"}
+ extra_attrs["unit_registry_json"] = \
+ self.simulation.unit_registry.to_json()
# save the light ray solution
if len(self.light_ray_solution) > 0:
- # Convert everything to base unit system now to avoid
- # problems with different units for each ds.
- for s in self.light_ray_solution:
- for f in s:
- if isinstance(s[f], YTArray):
- s[f].convert_to_base()
for key in self.light_ray_solution[0]:
if key in ["next", "previous", "index"]:
continue
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/data_objects/static_output.py
--- a/yt/data_objects/static_output.py
+++ b/yt/data_objects/static_output.py
@@ -299,14 +299,7 @@
self.no_cgs_equiv_length = False
self._create_unit_registry()
-
- create_code_unit_system(self)
- if unit_system == "code":
- unit_system = str(self)
- else:
- unit_system = str(unit_system).lower()
-
- self.unit_system = unit_system_registry[unit_system]
+ self._assign_unit_system(unit_system)
self._parse_parameter_file()
self.set_units()
@@ -336,8 +329,9 @@
for attr in ("center", "width", "left_edge", "right_edge"):
n = "domain_%s" % attr
v = getattr(self, n)
- v = self.arr(v, "code_length")
- setattr(self, n, v)
+ if not isinstance(v, YTArray):
+ v = self.arr(v, "code_length")
+ setattr(self, n, v)
def __reduce__(self):
args = (self._hash(),)
@@ -880,6 +874,14 @@
def relative_refinement(self, l0, l1):
return self.refine_by**(l1-l0)
+ def _assign_unit_system(self, unit_system):
+ create_code_unit_system(self)
+ if unit_system == "code":
+ unit_system = str(self)
+ else:
+ unit_system = str(unit_system).lower()
+ self.unit_system = unit_system_registry[unit_system]
+
def _create_unit_registry(self):
self.unit_registry = UnitRegistry()
import yt.units.dimensions as dimensions
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/halo_catalog/data_structures.py
--- a/yt/frontends/halo_catalog/data_structures.py
+++ b/yt/frontends/halo_catalog/data_structures.py
@@ -16,22 +16,20 @@
from yt.utilities.on_demand_imports import _h5py as h5py
import numpy as np
-import stat
import glob
-import os
from .fields import \
HaloCatalogFieldInfo
+from yt.frontends.ytdata.data_structures import \
+ SavedDataset
from yt.funcs import \
- parse_h5_attr, \
- setdefaultattr
+ parse_h5_attr
from yt.geometry.particle_geometry_handler import \
ParticleIndex
from yt.data_objects.static_output import \
- Dataset, \
ParticleFile
-
+
class HaloCatalogHDF5File(ParticleFile):
def __init__(self, ds, io, filename, file_id):
with h5py.File(filename, "r") as f:
@@ -40,11 +38,15 @@
super(HaloCatalogHDF5File, self).__init__(ds, io, filename, file_id)
-class HaloCatalogDataset(Dataset):
+class HaloCatalogDataset(SavedDataset):
_index_class = ParticleIndex
_file_class = HaloCatalogHDF5File
_field_info_class = HaloCatalogFieldInfo
_suffix = ".h5"
+ _con_attrs = ("cosmological_simulation",
+ "current_time", "current_redshift",
+ "hubble_constant", "omega_matter", "omega_lambda",
+ "domain_left_edge", "domain_right_edge")
def __init__(self, filename, dataset_type="halocatalog_hdf5",
n_ref = 16, over_refine_factor = 1, units_override=None,
@@ -56,33 +58,17 @@
unit_system=unit_system)
def _parse_parameter_file(self):
- with h5py.File(self.parameter_filename, "r") as f:
- hvals = dict((key, parse_h5_attr(f, key)) for key in f.attrs.keys())
+ self.refine_by = 2
self.dimensionality = 3
- self.refine_by = 2
- self.unique_identifier = \
- int(os.stat(self.parameter_filename)[stat.ST_CTIME])
+ nz = 1 << self.over_refine_factor
+ self.domain_dimensions = np.ones(self.dimensionality, "int32") * nz
+ self.periodicity = (True, True, True)
prefix = ".".join(self.parameter_filename.rsplit(".", 2)[:-2])
self.filename_template = "%s.%%(num)s%s" % (prefix, self._suffix)
self.file_count = len(glob.glob(prefix + "*" + self._suffix))
-
- for attr in ["cosmological_simulation", "current_time", "current_redshift",
- "hubble_constant", "omega_matter", "omega_lambda",
- "domain_left_edge", "domain_right_edge"]:
- setattr(self, attr, hvals[attr])
- self.periodicity = (True, True, True)
self.particle_types = ("halos")
self.particle_types_raw = ("halos")
-
- nz = 1 << self.over_refine_factor
- self.domain_dimensions = np.ones(3, "int32") * nz
- self.parameters.update(hvals)
-
- def _set_code_unit_attributes(self):
- setdefaultattr(self, 'length_unit', self.quan(1.0, "cm"))
- setdefaultattr(self, 'mass_unit', self.quan(1.0, "g"))
- setdefaultattr(self, 'velocity_unit', self.quan(1.0, "cm / s"))
- setdefaultattr(self, 'time_unit', self.quan(1.0, "s"))
+ super(HaloCatalogDataset, self)._parse_parameter_file()
@classmethod
def _is_valid(self, *args, **kwargs):
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/halo_catalog/io.py
--- a/yt/frontends/halo_catalog/io.py
+++ b/yt/frontends/halo_catalog/io.py
@@ -42,11 +42,13 @@
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
+ pn = "particle_position_%s"
for data_file in sorted(data_files):
with h5py.File(data_file.filename, "r") as f:
- x = f['particle_position_x'].value.astype("float64")
- y = f['particle_position_y'].value.astype("float64")
- z = f['particle_position_z'].value.astype("float64")
+ units = f[pn % "x"].attrs["units"]
+ x, y, z = \
+ (self.ds.arr(f[pn % ax].value.astype("float64"), units)
+ for ax in "xyz")
yield "halos", (x, y, z)
def _read_particle_fields(self, chunks, ptf, selector):
@@ -59,12 +61,14 @@
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
+ pn = "particle_position_%s"
for data_file in sorted(data_files):
with h5py.File(data_file.filename, "r") as f:
for ptype, field_list in sorted(ptf.items()):
- x = f['particle_position_x'].value.astype("float64")
- y = f['particle_position_y'].value.astype("float64")
- z = f['particle_position_z'].value.astype("float64")
+ units = f[pn % "x"].attrs["units"]
+ x, y, z = \
+ (self.ds.arr(f[pn % ax].value.astype("float64"), units)
+ for ax in "xyz")
mask = selector.select_points(x, y, z, 0.0)
del x, y, z
if mask is None: continue
@@ -82,28 +86,27 @@
with h5py.File(data_file.filename, "r") as f:
if not f.keys(): return None
pos = np.empty((pcount, 3), dtype="float64")
- pos = data_file.ds.arr(pos, "code_length")
+ units = f["particle_position_x"].attrs["units"]
dx = np.finfo(f['particle_position_x'].dtype).eps
- dx = 2.0*self.ds.quan(dx, "code_length")
+ dx = 2.0 * self.ds.quan(dx, units).to("code_length")
pos[:,0] = f["particle_position_x"].value
pos[:,1] = f["particle_position_y"].value
pos[:,2] = f["particle_position_z"].value
+ pos = data_file.ds.arr(pos, units). to("code_length")
+ dle = self.ds.domain_left_edge.to("code_length")
+ dre = self.ds.domain_right_edge.to("code_length")
# These are 32 bit numbers, so we give a little lee-way.
# Otherwise, for big sets of particles, we often will bump into the
# domain edges. This helps alleviate that.
- np.clip(pos, self.ds.domain_left_edge + dx,
- self.ds.domain_right_edge - dx, pos)
- if np.any(pos.min(axis=0) < self.ds.domain_left_edge) or \
- np.any(pos.max(axis=0) > self.ds.domain_right_edge):
+ np.clip(pos, dle + dx, dre - dx, pos)
+ if np.any(pos.min(axis=0) < dle) or \
+ np.any(pos.max(axis=0) > dre):
raise YTDomainOverflow(pos.min(axis=0),
pos.max(axis=0),
- self.ds.domain_left_edge,
- self.ds.domain_right_edge)
+ dle, dre)
regions.add_data_file(pos, data_file.file_id)
morton[ind:ind+pos.shape[0]] = compute_morton(
- pos[:,0], pos[:,1], pos[:,2],
- data_file.ds.domain_left_edge,
- data_file.ds.domain_right_edge)
+ pos[:,0], pos[:,1], pos[:,2], dle, dre)
return morton
def _count_particles(self, data_file):
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/data_structures.py
--- a/yt/frontends/ytdata/data_structures.py
+++ b/yt/frontends/ytdata/data_structures.py
@@ -47,6 +47,10 @@
GridIndex
from yt.geometry.particle_geometry_handler import \
ParticleIndex
+from yt.units import \
+ dimensions
+from yt.units.unit_registry import \
+ UnitRegistry
from yt.units.yt_array import \
YTQuantity
from yt.utilities.logger import \
@@ -68,8 +72,12 @@
"covering_grid",
"smoothed_covering_grid"]
-class YTDataset(Dataset):
- """Base dataset class for all ytdata datasets."""
+class SavedDataset(Dataset):
+ """
+ Base dataset class for products of calling save_as_dataset.
+ """
+ _con_attrs = ()
+
def _parse_parameter_file(self):
self.refine_by = 2
with h5py.File(self.parameter_filename, "r") as f:
@@ -78,18 +86,109 @@
if key == "con_args":
v = v.astype("str")
self.parameters[key] = v
- self.num_particles = \
- dict([(group, parse_h5_attr(f[group], "num_elements"))
- for group in f if group != self.default_fluid_type])
- for attr in ["cosmological_simulation", "current_time", "current_redshift",
- "hubble_constant", "omega_matter", "omega_lambda",
- "dimensionality", "domain_dimensions", "periodicity",
- "domain_left_edge", "domain_right_edge",
- "container_type", "data_type"]:
+ self._with_parameter_file_open(f)
+
+ # if saved, restore unit registry from the json string
+ if "unit_registry_json" in self.parameters:
+ self.unit_registry = UnitRegistry.from_json(
+ self.parameters["unit_registry_json"])
+ # reset self.arr and self.quan to use new unit_registry
+ self._arr = None
+ self._quan = None
+ for dim in ["length", "mass", "pressure",
+ "temperature", "time", "velocity"]:
+ cu = "code_" + dim
+ if cu not in self.unit_registry:
+ self.unit_registry.add(
+ cu, 1.0, getattr(dimensions, dim))
+ if "code_magnetic" not in self.unit_registry:
+ self.unit_registry.add("code_magnetic", 1.0,
+ dimensions.magnetic_field)
+
+ # if saved, set unit system
+ if "unit_system_name" in self.parameters:
+ unit_system = self.parameters["unit_system_name"]
+ del self.parameters["unit_system_name"]
+ else:
+ unit_system = "cgs"
+ # reset unit system since we may have a new unit registry
+ self._assign_unit_system(unit_system)
+
+ # assign units to parameters that have associated unit string
+ del_pars = []
+ for par in self.parameters:
+ ustr = "%s_units" % par
+ if ustr in self.parameters:
+ if isinstance(self.parameters[par], np.ndarray):
+ to_u = self.arr
+ else:
+ to_u = self.quan
+ self.parameters[par] = to_u(
+ self.parameters[par], self.parameters[ustr])
+ del_pars.append(ustr)
+ for par in del_pars:
+ del self.parameters[par]
+
+ for attr in self._con_attrs:
setattr(self, attr, self.parameters.get(attr))
self.unique_identifier = \
int(os.stat(self.parameter_filename)[stat.ST_CTIME])
+ def _with_parameter_file_open(self, f):
+ # This allows subclasses to access the parameter file
+ # while it's still open to get additional information.
+ pass
+
+ def set_units(self):
+ if "unit_registry_json" in self.parameters:
+ self._set_code_unit_attributes()
+ del self.parameters["unit_registry_json"]
+ else:
+ super(SavedDataset, self).set_units()
+
+ def _set_code_unit_attributes(self):
+ attrs = ('length_unit', 'mass_unit', 'time_unit',
+ 'velocity_unit', 'magnetic_unit')
+ cgs_units = ('cm', 'g', 's', 'cm/s', 'gauss')
+ base_units = np.ones(len(attrs))
+ for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
+ if attr in self.parameters and \
+ isinstance(self.parameters[attr], YTQuantity):
+ uq = self.parameters[attr]
+ elif attr in self.parameters and \
+ "%s_units" % attr in self.parameters:
+ uq = self.quan(self.parameters[attr],
+ self.parameters["%s_units" % attr])
+ del self.parameters[attr]
+ del self.parameters["%s_units" % attr]
+ elif isinstance(unit, string_types):
+ uq = self.quan(1.0, unit)
+ elif isinstance(unit, numeric_type):
+ uq = self.quan(unit, cgs_unit)
+ elif isinstance(unit, YTQuantity):
+ uq = unit
+ elif isinstance(unit, tuple):
+ uq = self.quan(unit[0], unit[1])
+ else:
+ raise RuntimeError("%s (%s) is invalid." % (attr, unit))
+ setattr(self, attr, uq)
+
+class YTDataset(SavedDataset):
+ """Base dataset class for all ytdata datasets."""
+
+ _con_attrs = ("cosmological_simulation", "current_time",
+ "current_redshift", "hubble_constant",
+ "omega_matter", "omega_lambda",
+ "dimensionality", "domain_dimensions",
+ "periodicity",
+ "domain_left_edge", "domain_right_edge",
+ "container_type", "data_type")
+
+ def _with_parameter_file_open(self, f):
+ self.num_particles = \
+ dict([(group, parse_h5_attr(f[group], "num_elements"))
+ for group in f if group != self.default_fluid_type])
+
def create_field_info(self):
self.field_dependencies = {}
self.derived_field_list = []
@@ -119,24 +218,6 @@
def _setup_override_fields(self):
pass
- def _set_code_unit_attributes(self):
- attrs = ('length_unit', 'mass_unit', 'time_unit',
- 'velocity_unit', 'magnetic_unit')
- cgs_units = ('cm', 'g', 's', 'cm/s', 'gauss')
- base_units = np.ones(len(attrs))
- for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
- if isinstance(unit, string_types):
- uq = self.quan(1.0, unit)
- elif isinstance(unit, numeric_type):
- uq = self.quan(unit, cgs_unit)
- elif isinstance(unit, YTQuantity):
- uq = unit
- elif isinstance(unit, tuple):
- uq = self.quan(unit[0], unit[1])
- else:
- raise RuntimeError("%s (%s) is invalid." % (attr, unit))
- setattr(self, attr, uq)
-
class YTDataHDF5File(ParticleFile):
def __init__(self, ds, io, filename, file_id):
with h5py.File(filename, "r") as f:
@@ -181,6 +262,7 @@
# cover the field_list.
self.field_info.alias(("gas", "cell_volume"), ("grid", "cell_volume"))
+ _data_obj = None
@property
def data(self):
"""
@@ -188,27 +270,21 @@
create this dataset.
"""
- # Some data containers can't be recontructed in the same way
- # since this is now particle-like data.
- data_type = self.parameters["data_type"]
- container_type = self.parameters["container_type"]
- ex_container_type = ["cutting", "proj", "ray", "slice"]
- if data_type == "yt_light_ray" or container_type in ex_container_type:
- mylog.info("Returning an all_data data container.")
- return self.all_data()
+ if self._data_obj is None:
+ # Some data containers can't be recontructed in the same way
+ # since this is now particle-like data.
+ data_type = self.parameters.get("data_type")
+ container_type = self.parameters.get("container_type")
+ ex_container_type = ["cutting", "proj", "ray", "slice"]
+ if data_type == "yt_light_ray" or container_type in ex_container_type:
+ mylog.info("Returning an all_data data container.")
+ return self.all_data()
- my_obj = getattr(self, self.parameters["container_type"])
- my_args = []
- for con_arg in self.parameters["con_args"]:
- my_arg = self.parameters[con_arg]
- my_units = self.parameters.get("%s_units" % con_arg)
- if my_units is not None:
- if isinstance(my_arg, np.ndarray):
- my_arg = self.arr(my_arg, my_units)
- else:
- my_arg = self.quan(my_arg, my_units)
- my_args.append(my_arg)
- return my_obj(*my_args)
+ my_obj = getattr(self, self.parameters["container_type"])
+ my_args = [self.parameters[con_arg]
+ for con_arg in self.parameters["con_args"]]
+ self._data_obj = my_obj(*my_args)
+ return self._data_obj
@classmethod
def _is_valid(self, *args, **kwargs):
@@ -251,15 +327,8 @@
for field in lrs_fields:
field_name = field[len(key)+1:]
for i in range(self.parameters[field].shape[0]):
- self.light_ray_solution[i][field_name] = self.parameters[field][i]
- if "%s_units" % field in self.parameters:
- if len(self.parameters[field].shape) > 1:
- to_val = self.arr
- else:
- to_val = self.quan
- self.light_ray_solution[i][field_name] = \
- to_val(self.light_ray_solution[i][field_name],
- self.parameters["%s_units" % field])
+ self.light_ray_solution[i][field_name] = \
+ self.parameters[field][i]
@classmethod
def _is_valid(self, *args, **kwargs):
@@ -693,9 +762,7 @@
my_range = np.log10(my_range)
self.domain_left_edge[i] = my_range[0]
self.domain_right_edge[i] = my_range[1]
- setattr(self, range_name,
- self.arr(self.parameters[range_name],
- self.parameters[range_name+"_units"]))
+ setattr(self, range_name, self.parameters[range_name])
bin_field = "%s_field" % ax
if isinstance(self.parameters[bin_field], string_types) and \
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/io.py
--- a/yt/frontends/ytdata/io.py
+++ b/yt/frontends/ytdata/io.py
@@ -130,6 +130,7 @@
return rv
def _read_particle_coords(self, chunks, ptf):
+ pn = "particle_position_%s"
chunks = list(chunks)
for chunk in chunks:
f = None
@@ -140,9 +141,10 @@
if g.NumberOfParticles == 0:
continue
for ptype, field_list in sorted(ptf.items()):
- pn = "particle_position_%s"
- x, y, z = (np.asarray(f[ptype][pn % ax].value, dtype="=f8")
- for ax in 'xyz')
+ units = f[ptype][pn % "x"].attrs["units"]
+ x, y, z = \
+ (self.ds.arr(f[ptype][pn % ax].value.astype("float64"), units)
+ for ax in "xyz")
for field in field_list:
if np.asarray(f[ptype][field]).ndim > 1:
self._array_fields[field] = f[ptype][field].shape
@@ -150,6 +152,7 @@
if f: f.close()
def _read_particle_fields(self, chunks, ptf, selector):
+ pn = "particle_position_%s"
chunks = list(chunks)
for chunk in chunks: # These should be organized by grid filename
f = None
@@ -160,9 +163,10 @@
if g.NumberOfParticles == 0:
continue
for ptype, field_list in sorted(ptf.items()):
- pn = "particle_position_%s"
- x, y, z = (np.asarray(f[ptype][pn % ax].value, dtype="=f8")
- for ax in 'xyz')
+ units = f[ptype][pn % "x"].attrs["units"]
+ x, y, z = \
+ (self.ds.arr(f[ptype][pn % ax].value.astype("float64"), units)
+ for ax in "xyz")
mask = selector.select_points(x, y, z, 0.0)
if mask is None: continue
for field in field_list:
@@ -188,9 +192,10 @@
for ptype, field_list in sorted(ptf.items()):
pcount = data_file.total_particles[ptype]
if pcount == 0: continue
- x = _get_position_array(ptype, f, "x")
- y = _get_position_array(ptype, f, "y")
- z = _get_position_array(ptype, f, "z")
+ units = _get_position_array_units(ptype, f, "x")
+ x, y, z = \
+ (self.ds.arr(_get_position_array(ptype, f, ax), units)
+ for ax in "xyz")
yield ptype, (x, y, z)
def _read_particle_fields(self, chunks, ptf, selector):
@@ -203,9 +208,10 @@
for data_file in sorted(data_files):
with h5py.File(data_file.filename, "r") as f:
for ptype, field_list in sorted(ptf.items()):
- x = _get_position_array(ptype, f, "x")
- y = _get_position_array(ptype, f, "y")
- z = _get_position_array(ptype, f, "z")
+ units = _get_position_array_units(ptype, f, "x")
+ x, y, z = \
+ (self.ds.arr(_get_position_array(ptype, f, ax), units)
+ for ax in "xyz")
mask = selector.select_points(x, y, z, 0.0)
del x, y, z
if mask is None: continue
@@ -224,31 +230,33 @@
for ptype in all_count:
if ptype not in f or all_count[ptype] == 0: continue
pos = np.empty((all_count[ptype], 3), dtype="float64")
- pos = data_file.ds.arr(pos, "code_length")
+ units = _get_position_array_units(ptype, f, "x")
if ptype == "grid":
dx = f["grid"]["dx"].value.min()
+ dx = self.ds.quan(
+ dx, f["grid"]["dx"].attrs["units"]).to("code_length")
else:
dx = 2. * np.finfo(f[ptype]["particle_position_x"].dtype).eps
- dx = self.ds.quan(dx, "code_length")
+ dx = self.ds.quan(dx, units).to("code_length")
pos[:,0] = _get_position_array(ptype, f, "x")
pos[:,1] = _get_position_array(ptype, f, "y")
pos[:,2] = _get_position_array(ptype, f, "z")
+ pos = self.ds.arr(pos, units).to("code_length")
+ dle = self.ds.domain_left_edge.to("code_length")
+ dre = self.ds.domain_right_edge.to("code_length")
+
# These are 32 bit numbers, so we give a little lee-way.
# Otherwise, for big sets of particles, we often will bump into the
# domain edges. This helps alleviate that.
- np.clip(pos, self.ds.domain_left_edge + dx,
- self.ds.domain_right_edge - dx, pos)
- if np.any(pos.min(axis=0) < self.ds.domain_left_edge) or \
- np.any(pos.max(axis=0) > self.ds.domain_right_edge):
+ np.clip(pos, dle + dx, dre - dx, pos)
+ if np.any(pos.min(axis=0) < dle) or \
+ np.any(pos.max(axis=0) > dre):
raise YTDomainOverflow(pos.min(axis=0),
pos.max(axis=0),
- self.ds.domain_left_edge,
- self.ds.domain_right_edge)
+ dle, dre)
regions.add_data_file(pos, data_file.file_id)
morton[ind:ind+pos.shape[0]] = compute_morton(
- pos[:,0], pos[:,1], pos[:,2],
- data_file.ds.domain_left_edge,
- data_file.ds.domain_right_edge)
+ pos[:,0], pos[:,1], pos[:,2], dle, dre)
ind += pos.shape[0]
return morton
@@ -284,7 +292,7 @@
x = _get_position_array(ptype, f, "px")
y = _get_position_array(ptype, f, "py")
z = np.zeros(x.size, dtype="float64") + \
- self.ds.domain_left_edge[2].in_cgs().d
+ self.ds.domain_left_edge[2].to("code_length").d
yield ptype, (x, y, z)
def _read_particle_fields(self, chunks, ptf, selector):
@@ -301,7 +309,7 @@
x = _get_position_array(ptype, f, "px")
y = _get_position_array(ptype, f, "py")
z = np.zeros(all_count[ptype], dtype="float64") + \
- self.ds.domain_left_edge[2].in_cgs().d
+ self.ds.domain_left_edge[2].to("code_length").d
mask = selector.select_points(x, y, z, 0.0)
del x, y, z
if mask is None: continue
@@ -320,32 +328,32 @@
for ptype in all_count:
if ptype not in f or all_count[ptype] == 0: continue
pos = np.empty((all_count[ptype], 3), dtype="float64")
- pos = data_file.ds.arr(pos, "code_length")
+ pos = self.ds.arr(pos, "code_length")
if ptype == "grid":
dx = f["grid"]["pdx"].value.min()
+ dx = self.ds.quan(
+ dx, f["grid"]["pdx"].attrs["units"]).to("code_length")
else:
raise NotImplementedError
- dx = self.ds.quan(dx, "code_length")
pos[:,0] = _get_position_array(ptype, f, "px")
pos[:,1] = _get_position_array(ptype, f, "py")
pos[:,2] = np.zeros(all_count[ptype], dtype="float64") + \
- self.ds.domain_left_edge[2].in_cgs().d
+ self.ds.domain_left_edge[2].to("code_length").d
+ dle = self.ds.domain_left_edge.to("code_length")
+ dre = self.ds.domain_right_edge.to("code_length")
+
# These are 32 bit numbers, so we give a little lee-way.
# Otherwise, for big sets of particles, we often will bump into the
# domain edges. This helps alleviate that.
- np.clip(pos, self.ds.domain_left_edge + dx,
- self.ds.domain_right_edge - dx, pos)
- if np.any(pos.min(axis=0) < self.ds.domain_left_edge) or \
- np.any(pos.max(axis=0) > self.ds.domain_right_edge):
+ np.clip(pos, dle + dx, dre - dx, pos)
+ if np.any(pos.min(axis=0) < dle) or \
+ np.any(pos.max(axis=0) > dre):
raise YTDomainOverflow(pos.min(axis=0),
pos.max(axis=0),
- self.ds.domain_left_edge,
- self.ds.domain_right_edge)
+ dre, dle)
regions.add_data_file(pos, data_file.file_id)
morton[ind:ind+pos.shape[0]] = compute_morton(
- pos[:,0], pos[:,1], pos[:,2],
- data_file.ds.domain_left_edge,
- data_file.ds.domain_right_edge)
+ pos[:,0], pos[:,1], pos[:,2], dle, dre)
ind += pos.shape[0]
return morton
@@ -355,3 +363,10 @@
else:
pos_name = "particle_position_"
return f[ptype][pos_name + ax].value.astype("float64")
+
+def _get_position_array_units(ptype, f, ax):
+ if ptype == "grid":
+ pos_name = ""
+ else:
+ pos_name = "particle_position_"
+ return f[ptype][pos_name + ax].attrs["units"]
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/tests/test_outputs.py
--- a/yt/frontends/ytdata/tests/test_outputs.py
+++ b/yt/frontends/ytdata/tests/test_outputs.py
@@ -42,6 +42,14 @@
import os
import shutil
+def compare_unit_attributes(ds1, ds2):
+ attrs = ('length_unit', 'mass_unit', 'time_unit',
+ 'velocity_unit', 'magnetic_unit')
+ for attr in attrs:
+ u1 = getattr(ds1, attr, None)
+ u2 = getattr(ds2, attr, None)
+ assert u1 == u2
+
class YTDataFieldTest(AnswerTestingTest):
_type_name = "YTDataTest"
_attrs = ("field_name", )
@@ -88,6 +96,7 @@
fn = sphere.save_as_dataset(fields=["density", "particle_mass"])
full_fn = os.path.join(tmpdir, fn)
sphere_ds = load(full_fn)
+ compare_unit_attributes(ds, sphere_ds)
assert isinstance(sphere_ds, YTDataContainerDataset)
yield YTDataFieldTest(full_fn, ("grid", "density"))
yield YTDataFieldTest(full_fn, ("all", "particle_mass"))
@@ -104,6 +113,7 @@
fn = cg.save_as_dataset(fields=["density", "particle_mass"])
full_fn = os.path.join(tmpdir, fn)
cg_ds = load(full_fn)
+ compare_unit_attributes(ds, cg_ds)
assert isinstance(cg_ds, YTGridDataset)
yield YTDataFieldTest(full_fn, ("grid", "density"))
@@ -112,6 +122,7 @@
frb = my_proj.to_frb(1.0, (800, 800))
fn = frb.save_as_dataset(fields=["density"])
frb_ds = load(fn)
+ compare_unit_attributes(ds, frb_ds)
assert isinstance(frb_ds, YTGridDataset)
yield YTDataFieldTest(full_fn, "density", geometric=False)
os.chdir(curdir)
@@ -127,6 +138,7 @@
fn = proj.save_as_dataset()
full_fn = os.path.join(tmpdir, fn)
proj_ds = load(full_fn)
+ compare_unit_attributes(ds, proj_ds)
assert isinstance(proj_ds, YTSpatialPlotDataset)
yield YTDataFieldTest(full_fn, ("grid", "density"), geometric=False)
os.chdir(curdir)
@@ -144,6 +156,7 @@
fn = profile_1d.save_as_dataset()
full_fn = os.path.join(tmpdir, fn)
prof_1d_ds = load(full_fn)
+ compare_unit_attributes(ds, prof_1d_ds)
assert isinstance(prof_1d_ds, YTProfileDataset)
p1 = ProfilePlot(prof_1d_ds.data, "density", "temperature",
@@ -159,6 +172,7 @@
fn = profile_2d.save_as_dataset()
full_fn = os.path.join(tmpdir, fn)
prof_2d_ds = load(full_fn)
+ compare_unit_attributes(ds, prof_2d_ds)
assert isinstance(prof_2d_ds, YTProfileDataset)
p2 = PhasePlot(prof_2d_ds.data, "density", "temperature",
@@ -188,6 +202,7 @@
save_as_dataset(ds, fn, my_data)
full_fn = os.path.join(tmpdir, fn)
array_ds = load(full_fn)
+ compare_unit_attributes(ds, array_ds)
assert isinstance(array_ds, YTNonspatialDataset)
yield YTDataFieldTest(full_fn, "region_density", geometric=False)
yield YTDataFieldTest(full_fn, "sphere_density", geometric=False)
diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/utilities.py
--- a/yt/frontends/ytdata/utilities.py
+++ b/yt/frontends/ytdata/utilities.py
@@ -95,7 +95,9 @@
"current_redshift", "current_time",
"domain_dimensions", "periodicity",
"cosmological_simulation", "omega_lambda",
- "omega_matter", "hubble_constant"]
+ "omega_matter", "hubble_constant",
+ "length_unit", "mass_unit", "time_unit",
+ "velocity_unit", "magnetic_unit"]
fh = h5py.File(filename, "w")
if ds is None: ds = {}
@@ -104,6 +106,14 @@
for attr, val in ds.parameters.items():
_yt_array_hdf5_attr(fh, attr, val)
+ if hasattr(ds, "unit_registry"):
+ _yt_array_hdf5_attr(fh, "unit_registry_json",
+ ds.unit_registry.to_json())
+
+ if hasattr(ds, "unit_system"):
+ _yt_array_hdf5_attr(fh, "unit_system_name",
+ ds.unit_system.name)
+
for attr in base_attrs:
if isinstance(ds, dict):
my_val = ds.get(attr, None)
@@ -126,19 +136,16 @@
field_type = field_types[field]
if field_type not in fh:
fh.create_group(field_type)
- # for now, let's avoid writing "code" units
- if hasattr(data[field], "units"):
- for atom in data[field].units.expr.atoms():
- if str(atom).startswith("code"):
- data[field].convert_to_base()
- break
+
if isinstance(field, tuple):
field_name = field[1]
else:
field_name = field
- # thanks, python3
+
+ # for python3
if data[field].dtype.kind == 'U':
- data[field] = data[field].astype('|S40')
+ data[field] = data[field].astype('|S')
+
_yt_array_hdf5(fh[field_type], field_name, data[field])
if "num_elements" not in fh[field_type].attrs:
fh[field_type].attrs["num_elements"] = data[field].size
@@ -225,7 +232,6 @@
if val is None: val = "None"
if hasattr(val, "units"):
- val = val.in_base()
fh.attrs["%s_units" % attr] = str(val.units)
# The following is a crappy workaround for getting
# Unicode strings into HDF5 attributes in Python 3
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because the commit notification service is enabled for this
repository and this address is listed as the recipient.
More information about the yt-svn mailing list