[yt-svn] commit/yt: ngoldbaum: Merged in brittonsmith/yt (pull request #2482)

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Wed Jan 11 20:29:40 PST 2017


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/ec92132823b0/
Changeset:   ec92132823b0
Branch:      yt
User:        ngoldbaum
Date:        2017-01-12 04:29:14+00:00
Summary:     Merged in brittonsmith/yt (pull request #2482)

Use parent unit_registry for YTData and HaloCatalog datasets
Affected #:  8 files

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
--- a/yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
+++ b/yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
@@ -655,33 +655,32 @@
 
         Write light ray data to hdf5 file.
         """
+
+        extra_attrs = {"data_type": "yt_light_ray"}
         if self.simulation_type is None:
             ds = self.ds
         else:
             ds = {}
-            ds["dimensionality"] = self.simulation.dimensionality
-            ds["domain_left_edge"] = self.simulation.domain_left_edge
-            ds["domain_right_edge"] = self.simulation.domain_right_edge
-            ds["cosmological_simulation"] = self.simulation.cosmological_simulation
             ds["periodicity"] = (True, True, True)
             ds["current_redshift"] = self.near_redshift
-            for attr in ["omega_lambda", "omega_matter", "hubble_constant"]:
-                ds[attr] = getattr(self.cosmology, attr)
+            for attr in ["dimensionality", "cosmological_simulation",
+                         "domain_left_edge", "domain_right_edge",
+                         "length_unit", "time_unit"]:
+                ds[attr] = getattr(self.simulation, attr)
+            if self.simulation.cosmological_simulation:
+                for attr in ["omega_lambda", "omega_matter",
+                             "hubble_constant"]:
+                    ds[attr] = getattr(self.cosmology, attr)
             ds["current_time"] = \
               self.cosmology.t_from_z(ds["current_redshift"])
             if isinstance(ds["hubble_constant"], YTArray):
                 ds["hubble_constant"] = \
                   ds["hubble_constant"].to("100*km/(Mpc*s)").d
-        extra_attrs = {"data_type": "yt_light_ray"}
+            extra_attrs["unit_registry_json"] = \
+              self.simulation.unit_registry.to_json()
 
         # save the light ray solution
         if len(self.light_ray_solution) > 0:
-            # Convert everything to base unit system now to avoid
-            # problems with different units for each ds.
-            for s in self.light_ray_solution:
-                for f in s:
-                    if isinstance(s[f], YTArray):
-                        s[f].convert_to_base()
             for key in self.light_ray_solution[0]:
                 if key in ["next", "previous", "index"]:
                     continue

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/data_objects/static_output.py
--- a/yt/data_objects/static_output.py
+++ b/yt/data_objects/static_output.py
@@ -299,14 +299,7 @@
         self.no_cgs_equiv_length = False
 
         self._create_unit_registry()
-
-        create_code_unit_system(self)
-        if unit_system == "code":
-            unit_system = str(self)
-        else:
-            unit_system = str(unit_system).lower()
-
-        self.unit_system = unit_system_registry[unit_system]
+        self._assign_unit_system(unit_system)
 
         self._parse_parameter_file()
         self.set_units()
@@ -336,8 +329,9 @@
         for attr in ("center", "width", "left_edge", "right_edge"):
             n = "domain_%s" % attr
             v = getattr(self, n)
-            v = self.arr(v, "code_length")
-            setattr(self, n, v)
+            if not isinstance(v, YTArray):
+                v = self.arr(v, "code_length")
+                setattr(self, n, v)
 
     def __reduce__(self):
         args = (self._hash(),)
@@ -880,6 +874,14 @@
     def relative_refinement(self, l0, l1):
         return self.refine_by**(l1-l0)
 
+    def _assign_unit_system(self, unit_system):
+        create_code_unit_system(self)
+        if unit_system == "code":
+            unit_system = str(self)
+        else:
+            unit_system = str(unit_system).lower()
+        self.unit_system = unit_system_registry[unit_system]
+
     def _create_unit_registry(self):
         self.unit_registry = UnitRegistry()
         import yt.units.dimensions as dimensions

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/halo_catalog/data_structures.py
--- a/yt/frontends/halo_catalog/data_structures.py
+++ b/yt/frontends/halo_catalog/data_structures.py
@@ -16,22 +16,20 @@
 
 from yt.utilities.on_demand_imports import _h5py as h5py
 import numpy as np
-import stat
 import glob
-import os
 
 from .fields import \
     HaloCatalogFieldInfo
 
+from yt.frontends.ytdata.data_structures import \
+    SavedDataset
 from yt.funcs import \
-    parse_h5_attr, \
-    setdefaultattr
+    parse_h5_attr
 from yt.geometry.particle_geometry_handler import \
     ParticleIndex
 from yt.data_objects.static_output import \
-    Dataset, \
     ParticleFile
-    
+
 class HaloCatalogHDF5File(ParticleFile):
     def __init__(self, ds, io, filename, file_id):
         with h5py.File(filename, "r") as f:
@@ -40,11 +38,15 @@
 
         super(HaloCatalogHDF5File, self).__init__(ds, io, filename, file_id)
     
-class HaloCatalogDataset(Dataset):
+class HaloCatalogDataset(SavedDataset):
     _index_class = ParticleIndex
     _file_class = HaloCatalogHDF5File
     _field_info_class = HaloCatalogFieldInfo
     _suffix = ".h5"
+    _con_attrs = ("cosmological_simulation",
+                  "current_time", "current_redshift",
+                  "hubble_constant", "omega_matter", "omega_lambda",
+                  "domain_left_edge", "domain_right_edge")
 
     def __init__(self, filename, dataset_type="halocatalog_hdf5",
                  n_ref = 16, over_refine_factor = 1, units_override=None,
@@ -56,33 +58,17 @@
                                                  unit_system=unit_system)
 
     def _parse_parameter_file(self):
-        with h5py.File(self.parameter_filename, "r") as f:
-            hvals = dict((key, parse_h5_attr(f, key)) for key in f.attrs.keys())
+        self.refine_by = 2
         self.dimensionality = 3
-        self.refine_by = 2
-        self.unique_identifier = \
-            int(os.stat(self.parameter_filename)[stat.ST_CTIME])
+        nz = 1 << self.over_refine_factor
+        self.domain_dimensions = np.ones(self.dimensionality, "int32") * nz
+        self.periodicity = (True, True, True)
         prefix = ".".join(self.parameter_filename.rsplit(".", 2)[:-2])
         self.filename_template = "%s.%%(num)s%s" % (prefix, self._suffix)
         self.file_count = len(glob.glob(prefix + "*" + self._suffix))
-
-        for attr in ["cosmological_simulation", "current_time", "current_redshift",
-                     "hubble_constant", "omega_matter", "omega_lambda",
-                     "domain_left_edge", "domain_right_edge"]:
-            setattr(self, attr, hvals[attr])
-        self.periodicity = (True, True, True)
         self.particle_types = ("halos")
         self.particle_types_raw = ("halos")
-
-        nz = 1 << self.over_refine_factor
-        self.domain_dimensions = np.ones(3, "int32") * nz
-        self.parameters.update(hvals)
-
-    def _set_code_unit_attributes(self):
-        setdefaultattr(self, 'length_unit', self.quan(1.0, "cm"))
-        setdefaultattr(self, 'mass_unit', self.quan(1.0, "g"))
-        setdefaultattr(self, 'velocity_unit', self.quan(1.0, "cm / s"))
-        setdefaultattr(self, 'time_unit', self.quan(1.0, "s"))
+        super(HaloCatalogDataset, self)._parse_parameter_file()
 
     @classmethod
     def _is_valid(self, *args, **kwargs):

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/halo_catalog/io.py
--- a/yt/frontends/halo_catalog/io.py
+++ b/yt/frontends/halo_catalog/io.py
@@ -42,11 +42,13 @@
         for chunk in chunks:
             for obj in chunk.objs:
                 data_files.update(obj.data_files)
+        pn = "particle_position_%s"
         for data_file in sorted(data_files):
             with h5py.File(data_file.filename, "r") as f:
-                x = f['particle_position_x'].value.astype("float64")
-                y = f['particle_position_y'].value.astype("float64")
-                z = f['particle_position_z'].value.astype("float64")
+                units = f[pn % "x"].attrs["units"]
+                x, y, z = \
+                  (self.ds.arr(f[pn % ax].value.astype("float64"), units)
+                   for ax in "xyz")
                 yield "halos", (x, y, z)
 
     def _read_particle_fields(self, chunks, ptf, selector):
@@ -59,12 +61,14 @@
         for chunk in chunks:
             for obj in chunk.objs:
                 data_files.update(obj.data_files)
+        pn = "particle_position_%s"
         for data_file in sorted(data_files):
             with h5py.File(data_file.filename, "r") as f:
                 for ptype, field_list in sorted(ptf.items()):
-                    x = f['particle_position_x'].value.astype("float64")
-                    y = f['particle_position_y'].value.astype("float64")
-                    z = f['particle_position_z'].value.astype("float64")
+                    units = f[pn % "x"].attrs["units"]
+                    x, y, z = \
+                      (self.ds.arr(f[pn % ax].value.astype("float64"), units)
+                       for ax in "xyz")
                     mask = selector.select_points(x, y, z, 0.0)
                     del x, y, z
                     if mask is None: continue
@@ -82,28 +86,27 @@
         with h5py.File(data_file.filename, "r") as f:
             if not f.keys(): return None
             pos = np.empty((pcount, 3), dtype="float64")
-            pos = data_file.ds.arr(pos, "code_length")
+            units = f["particle_position_x"].attrs["units"]
             dx = np.finfo(f['particle_position_x'].dtype).eps
-            dx = 2.0*self.ds.quan(dx, "code_length")
+            dx = 2.0 * self.ds.quan(dx, units).to("code_length")
             pos[:,0] = f["particle_position_x"].value
             pos[:,1] = f["particle_position_y"].value
             pos[:,2] = f["particle_position_z"].value
+            pos = data_file.ds.arr(pos, units). to("code_length")
+            dle = self.ds.domain_left_edge.to("code_length")
+            dre = self.ds.domain_right_edge.to("code_length")
             # These are 32 bit numbers, so we give a little lee-way.
             # Otherwise, for big sets of particles, we often will bump into the
             # domain edges.  This helps alleviate that.
-            np.clip(pos, self.ds.domain_left_edge + dx,
-                         self.ds.domain_right_edge - dx, pos)
-            if np.any(pos.min(axis=0) < self.ds.domain_left_edge) or \
-               np.any(pos.max(axis=0) > self.ds.domain_right_edge):
+            np.clip(pos, dle + dx, dre - dx, pos)
+            if np.any(pos.min(axis=0) < dle) or \
+               np.any(pos.max(axis=0) > dre):
                 raise YTDomainOverflow(pos.min(axis=0),
                                        pos.max(axis=0),
-                                       self.ds.domain_left_edge,
-                                       self.ds.domain_right_edge)
+                                       dle, dre)
             regions.add_data_file(pos, data_file.file_id)
             morton[ind:ind+pos.shape[0]] = compute_morton(
-                pos[:,0], pos[:,1], pos[:,2],
-                data_file.ds.domain_left_edge,
-                data_file.ds.domain_right_edge)
+                pos[:,0], pos[:,1], pos[:,2], dle, dre)
         return morton
 
     def _count_particles(self, data_file):

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/data_structures.py
--- a/yt/frontends/ytdata/data_structures.py
+++ b/yt/frontends/ytdata/data_structures.py
@@ -47,6 +47,10 @@
     GridIndex
 from yt.geometry.particle_geometry_handler import \
     ParticleIndex
+from yt.units import \
+    dimensions
+from yt.units.unit_registry import \
+    UnitRegistry
 from yt.units.yt_array import \
     YTQuantity
 from yt.utilities.logger import \
@@ -68,8 +72,12 @@
                          "covering_grid",
                          "smoothed_covering_grid"]
 
-class YTDataset(Dataset):
-    """Base dataset class for all ytdata datasets."""
+class SavedDataset(Dataset):
+    """
+    Base dataset class for products of calling save_as_dataset.
+    """
+    _con_attrs = ()
+
     def _parse_parameter_file(self):
         self.refine_by = 2
         with h5py.File(self.parameter_filename, "r") as f:
@@ -78,18 +86,109 @@
                 if key == "con_args":
                     v = v.astype("str")
                 self.parameters[key] = v
-            self.num_particles = \
-              dict([(group, parse_h5_attr(f[group], "num_elements"))
-                    for group in f if group != self.default_fluid_type])
-        for attr in ["cosmological_simulation", "current_time", "current_redshift",
-                     "hubble_constant", "omega_matter", "omega_lambda",
-                     "dimensionality", "domain_dimensions", "periodicity",
-                     "domain_left_edge", "domain_right_edge",
-                     "container_type", "data_type"]:
+            self._with_parameter_file_open(f)
+
+        # if saved, restore unit registry from the json string
+        if "unit_registry_json" in self.parameters:
+            self.unit_registry = UnitRegistry.from_json(
+                self.parameters["unit_registry_json"])
+            # reset self.arr and self.quan to use new unit_registry
+            self._arr = None
+            self._quan = None
+            for dim in ["length", "mass", "pressure",
+                        "temperature", "time", "velocity"]:
+                cu = "code_" + dim
+                if cu not in self.unit_registry:
+                    self.unit_registry.add(
+                        cu, 1.0, getattr(dimensions, dim))
+            if "code_magnetic" not in self.unit_registry:
+                self.unit_registry.add("code_magnetic", 1.0,
+                                       dimensions.magnetic_field)
+
+        # if saved, set unit system
+        if "unit_system_name" in self.parameters:
+            unit_system = self.parameters["unit_system_name"]
+            del self.parameters["unit_system_name"]
+        else:
+            unit_system = "cgs"
+        # reset unit system since we may have a new unit registry
+        self._assign_unit_system(unit_system)
+
+        # assign units to parameters that have associated unit string
+        del_pars = []
+        for par in self.parameters:
+            ustr = "%s_units" % par
+            if ustr in self.parameters:
+                if isinstance(self.parameters[par], np.ndarray):
+                    to_u = self.arr
+                else:
+                    to_u = self.quan
+                self.parameters[par] = to_u(
+                    self.parameters[par], self.parameters[ustr])
+                del_pars.append(ustr)
+        for par in del_pars:
+            del self.parameters[par]
+
+        for attr in self._con_attrs:
             setattr(self, attr, self.parameters.get(attr))
         self.unique_identifier = \
           int(os.stat(self.parameter_filename)[stat.ST_CTIME])
 
+    def _with_parameter_file_open(self, f):
+        # This allows subclasses to access the parameter file
+        # while it's still open to get additional information.
+        pass
+
+    def set_units(self):
+        if "unit_registry_json" in self.parameters:
+            self._set_code_unit_attributes()
+            del self.parameters["unit_registry_json"]
+        else:
+            super(SavedDataset, self).set_units()
+
+    def _set_code_unit_attributes(self):
+        attrs = ('length_unit', 'mass_unit', 'time_unit',
+                 'velocity_unit', 'magnetic_unit')
+        cgs_units = ('cm', 'g', 's', 'cm/s', 'gauss')
+        base_units = np.ones(len(attrs))
+        for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
+            if attr in self.parameters and \
+              isinstance(self.parameters[attr], YTQuantity):
+                uq = self.parameters[attr]
+            elif attr in self.parameters and \
+              "%s_units" % attr in self.parameters:
+                uq = self.quan(self.parameters[attr],
+                               self.parameters["%s_units" % attr])
+                del self.parameters[attr]
+                del self.parameters["%s_units" % attr]
+            elif isinstance(unit, string_types):
+                uq = self.quan(1.0, unit)
+            elif isinstance(unit, numeric_type):
+                uq = self.quan(unit, cgs_unit)
+            elif isinstance(unit, YTQuantity):
+                uq = unit
+            elif isinstance(unit, tuple):
+                uq = self.quan(unit[0], unit[1])
+            else:
+                raise RuntimeError("%s (%s) is invalid." % (attr, unit))
+            setattr(self, attr, uq)
+
+class YTDataset(SavedDataset):
+    """Base dataset class for all ytdata datasets."""
+
+    _con_attrs = ("cosmological_simulation", "current_time",
+                  "current_redshift", "hubble_constant",
+                  "omega_matter", "omega_lambda",
+                  "dimensionality", "domain_dimensions",
+                  "periodicity",
+                  "domain_left_edge", "domain_right_edge",
+                  "container_type", "data_type")
+
+    def _with_parameter_file_open(self, f):
+        self.num_particles = \
+          dict([(group, parse_h5_attr(f[group], "num_elements"))
+                for group in f if group != self.default_fluid_type])
+
     def create_field_info(self):
         self.field_dependencies = {}
         self.derived_field_list = []
@@ -119,24 +218,6 @@
     def _setup_override_fields(self):
         pass
 
-    def _set_code_unit_attributes(self):
-        attrs = ('length_unit', 'mass_unit', 'time_unit',
-                 'velocity_unit', 'magnetic_unit')
-        cgs_units = ('cm', 'g', 's', 'cm/s', 'gauss')
-        base_units = np.ones(len(attrs))
-        for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
-            if isinstance(unit, string_types):
-                uq = self.quan(1.0, unit)
-            elif isinstance(unit, numeric_type):
-                uq = self.quan(unit, cgs_unit)
-            elif isinstance(unit, YTQuantity):
-                uq = unit
-            elif isinstance(unit, tuple):
-                uq = self.quan(unit[0], unit[1])
-            else:
-                raise RuntimeError("%s (%s) is invalid." % (attr, unit))
-            setattr(self, attr, uq)
-
 class YTDataHDF5File(ParticleFile):
     def __init__(self, ds, io, filename, file_id):
         with h5py.File(filename, "r") as f:
@@ -181,6 +262,7 @@
         # cover the field_list.
         self.field_info.alias(("gas", "cell_volume"), ("grid", "cell_volume"))
 
+    _data_obj = None
     @property
     def data(self):
         """
@@ -188,27 +270,21 @@
         create this dataset.
         """
 
-        # Some data containers can't be recontructed in the same way
-        # since this is now particle-like data.
-        data_type = self.parameters["data_type"]
-        container_type = self.parameters["container_type"]
-        ex_container_type = ["cutting", "proj", "ray", "slice"]
-        if data_type == "yt_light_ray" or container_type in ex_container_type:
-            mylog.info("Returning an all_data data container.")
-            return self.all_data()
+        if self._data_obj is None:
+            # Some data containers can't be reconstructed in the same way
+            # since this is now particle-like data.
+            data_type = self.parameters.get("data_type")
+            container_type = self.parameters.get("container_type")
+            ex_container_type = ["cutting", "proj", "ray", "slice"]
+            if data_type == "yt_light_ray" or container_type in ex_container_type:
+                mylog.info("Returning an all_data data container.")
+                return self.all_data()
 
-        my_obj = getattr(self, self.parameters["container_type"])
-        my_args = []
-        for con_arg in self.parameters["con_args"]:
-            my_arg = self.parameters[con_arg]
-            my_units = self.parameters.get("%s_units" % con_arg)
-            if my_units is not None:
-                if isinstance(my_arg, np.ndarray):
-                    my_arg = self.arr(my_arg, my_units)
-                else:
-                    my_arg = self.quan(my_arg, my_units)
-            my_args.append(my_arg)
-        return my_obj(*my_args)
+            my_obj = getattr(self, self.parameters["container_type"])
+            my_args = [self.parameters[con_arg]
+                       for con_arg in self.parameters["con_args"]]
+            self._data_obj = my_obj(*my_args)
+        return self._data_obj
 
     @classmethod
     def _is_valid(self, *args, **kwargs):
@@ -251,15 +327,8 @@
         for field in lrs_fields:
             field_name = field[len(key)+1:]
             for i in range(self.parameters[field].shape[0]):
-                self.light_ray_solution[i][field_name] = self.parameters[field][i]
-                if "%s_units" % field in self.parameters:
-                    if len(self.parameters[field].shape) > 1:
-                        to_val = self.arr
-                    else:
-                        to_val = self.quan
-                    self.light_ray_solution[i][field_name] = \
-                      to_val(self.light_ray_solution[i][field_name],
-                             self.parameters["%s_units" % field])
+                self.light_ray_solution[i][field_name] = \
+                  self.parameters[field][i]
 
     @classmethod
     def _is_valid(self, *args, **kwargs):
@@ -693,9 +762,7 @@
                 my_range = np.log10(my_range)
             self.domain_left_edge[i] = my_range[0]
             self.domain_right_edge[i] = my_range[1]
-            setattr(self, range_name,
-                    self.arr(self.parameters[range_name],
-                             self.parameters[range_name+"_units"]))
+            setattr(self, range_name, self.parameters[range_name])
 
             bin_field = "%s_field" % ax
             if isinstance(self.parameters[bin_field], string_types) and \

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/io.py
--- a/yt/frontends/ytdata/io.py
+++ b/yt/frontends/ytdata/io.py
@@ -130,6 +130,7 @@
         return rv
 
     def _read_particle_coords(self, chunks, ptf):
+        pn = "particle_position_%s"
         chunks = list(chunks)
         for chunk in chunks:
             f = None
@@ -140,9 +141,10 @@
                 if g.NumberOfParticles == 0:
                     continue
                 for ptype, field_list in sorted(ptf.items()):
-                    pn = "particle_position_%s"
-                    x, y, z = (np.asarray(f[ptype][pn % ax].value, dtype="=f8")
-                               for ax in 'xyz')
+                    units = f[ptype][pn % "x"].attrs["units"]
+                    x, y, z = \
+                      (self.ds.arr(f[ptype][pn % ax].value.astype("float64"), units)
+                       for ax in "xyz")
                     for field in field_list:
                         if np.asarray(f[ptype][field]).ndim > 1:
                             self._array_fields[field] = f[ptype][field].shape
@@ -150,6 +152,7 @@
             if f: f.close()
 
     def _read_particle_fields(self, chunks, ptf, selector):
+        pn = "particle_position_%s"
         chunks = list(chunks)
         for chunk in chunks: # These should be organized by grid filename
             f = None
@@ -160,9 +163,10 @@
                 if g.NumberOfParticles == 0:
                     continue
                 for ptype, field_list in sorted(ptf.items()):
-                    pn = "particle_position_%s"
-                    x, y, z = (np.asarray(f[ptype][pn % ax].value, dtype="=f8")
-                               for ax in 'xyz')
+                    units = f[ptype][pn % "x"].attrs["units"]
+                    x, y, z = \
+                      (self.ds.arr(f[ptype][pn % ax].value.astype("float64"), units)
+                       for ax in "xyz")
                     mask = selector.select_points(x, y, z, 0.0)
                     if mask is None: continue
                     for field in field_list:
@@ -188,9 +192,10 @@
                 for ptype, field_list in sorted(ptf.items()):
                     pcount = data_file.total_particles[ptype]
                     if pcount == 0: continue
-                    x = _get_position_array(ptype, f, "x")
-                    y = _get_position_array(ptype, f, "y")
-                    z = _get_position_array(ptype, f, "z")
+                    units = _get_position_array_units(ptype, f, "x")
+                    x, y, z = \
+                      (self.ds.arr(_get_position_array(ptype, f, ax), units)
+                       for ax in "xyz")
                     yield ptype, (x, y, z)
 
     def _read_particle_fields(self, chunks, ptf, selector):
@@ -203,9 +208,10 @@
         for data_file in sorted(data_files):
             with h5py.File(data_file.filename, "r") as f:
                 for ptype, field_list in sorted(ptf.items()):
-                    x = _get_position_array(ptype, f, "x")
-                    y = _get_position_array(ptype, f, "y")
-                    z = _get_position_array(ptype, f, "z")
+                    units = _get_position_array_units(ptype, f, "x")
+                    x, y, z = \
+                      (self.ds.arr(_get_position_array(ptype, f, ax), units)
+                       for ax in "xyz")
                     mask = selector.select_points(x, y, z, 0.0)
                     del x, y, z
                     if mask is None: continue
@@ -224,31 +230,33 @@
             for ptype in all_count:
                 if ptype not in f or all_count[ptype] == 0: continue
                 pos = np.empty((all_count[ptype], 3), dtype="float64")
-                pos = data_file.ds.arr(pos, "code_length")
+                units = _get_position_array_units(ptype, f, "x")
                 if ptype == "grid":
                     dx = f["grid"]["dx"].value.min()
+                    dx = self.ds.quan(
+                        dx, f["grid"]["dx"].attrs["units"]).to("code_length")
                 else:
                     dx = 2. * np.finfo(f[ptype]["particle_position_x"].dtype).eps
-                dx = self.ds.quan(dx, "code_length")
+                    dx = self.ds.quan(dx, units).to("code_length")
                 pos[:,0] = _get_position_array(ptype, f, "x")
                 pos[:,1] = _get_position_array(ptype, f, "y")
                 pos[:,2] = _get_position_array(ptype, f, "z")
+                pos = self.ds.arr(pos, units).to("code_length")
+                dle = self.ds.domain_left_edge.to("code_length")
+                dre = self.ds.domain_right_edge.to("code_length")
+
                 # These are 32 bit numbers, so we give a little lee-way.
                 # Otherwise, for big sets of particles, we often will bump into the
                 # domain edges.  This helps alleviate that.
-                np.clip(pos, self.ds.domain_left_edge + dx,
-                             self.ds.domain_right_edge - dx, pos)
-                if np.any(pos.min(axis=0) < self.ds.domain_left_edge) or \
-                   np.any(pos.max(axis=0) > self.ds.domain_right_edge):
+                np.clip(pos, dle + dx, dre - dx, pos)
+                if np.any(pos.min(axis=0) < dle) or \
+                   np.any(pos.max(axis=0) > dre):
                     raise YTDomainOverflow(pos.min(axis=0),
                                            pos.max(axis=0),
-                                           self.ds.domain_left_edge,
-                                           self.ds.domain_right_edge)
+                                           dle, dre)
                 regions.add_data_file(pos, data_file.file_id)
                 morton[ind:ind+pos.shape[0]] = compute_morton(
-                    pos[:,0], pos[:,1], pos[:,2],
-                    data_file.ds.domain_left_edge,
-                    data_file.ds.domain_right_edge)
+                    pos[:,0], pos[:,1], pos[:,2], dle, dre)
                 ind += pos.shape[0]
         return morton
 
@@ -284,7 +292,7 @@
                     x = _get_position_array(ptype, f, "px")
                     y = _get_position_array(ptype, f, "py")
                     z = np.zeros(x.size, dtype="float64") + \
-                      self.ds.domain_left_edge[2].in_cgs().d
+                      self.ds.domain_left_edge[2].to("code_length").d
                     yield ptype, (x, y, z)
 
     def _read_particle_fields(self, chunks, ptf, selector):
@@ -301,7 +309,7 @@
                     x = _get_position_array(ptype, f, "px")
                     y = _get_position_array(ptype, f, "py")
                     z = np.zeros(all_count[ptype], dtype="float64") + \
-                      self.ds.domain_left_edge[2].in_cgs().d
+                      self.ds.domain_left_edge[2].to("code_length").d
                     mask = selector.select_points(x, y, z, 0.0)
                     del x, y, z
                     if mask is None: continue
@@ -320,32 +328,32 @@
             for ptype in all_count:
                 if ptype not in f or all_count[ptype] == 0: continue
                 pos = np.empty((all_count[ptype], 3), dtype="float64")
-                pos = data_file.ds.arr(pos, "code_length")
+                pos = self.ds.arr(pos, "code_length")
                 if ptype == "grid":
                     dx = f["grid"]["pdx"].value.min()
+                    dx = self.ds.quan(
+                        dx, f["grid"]["pdx"].attrs["units"]).to("code_length")
                 else:
                     raise NotImplementedError
-                dx = self.ds.quan(dx, "code_length")
                 pos[:,0] = _get_position_array(ptype, f, "px")
                 pos[:,1] = _get_position_array(ptype, f, "py")
                 pos[:,2] = np.zeros(all_count[ptype], dtype="float64") + \
-                  self.ds.domain_left_edge[2].in_cgs().d
+                  self.ds.domain_left_edge[2].to("code_length").d
+                dle = self.ds.domain_left_edge.to("code_length")
+                dre = self.ds.domain_right_edge.to("code_length")
+
                 # These are 32 bit numbers, so we give a little lee-way.
                 # Otherwise, for big sets of particles, we often will bump into the
                 # domain edges.  This helps alleviate that.
-                np.clip(pos, self.ds.domain_left_edge + dx,
-                             self.ds.domain_right_edge - dx, pos)
-                if np.any(pos.min(axis=0) < self.ds.domain_left_edge) or \
-                   np.any(pos.max(axis=0) > self.ds.domain_right_edge):
+                np.clip(pos, dle + dx, dre - dx, pos)
+                if np.any(pos.min(axis=0) < dle) or \
+                   np.any(pos.max(axis=0) > dre):
                     raise YTDomainOverflow(pos.min(axis=0),
                                            pos.max(axis=0),
-                                           self.ds.domain_left_edge,
-                                           self.ds.domain_right_edge)
+                                           dre, dle)
                 regions.add_data_file(pos, data_file.file_id)
                 morton[ind:ind+pos.shape[0]] = compute_morton(
-                    pos[:,0], pos[:,1], pos[:,2],
-                    data_file.ds.domain_left_edge,
-                    data_file.ds.domain_right_edge)
+                    pos[:,0], pos[:,1], pos[:,2], dle, dre)
                 ind += pos.shape[0]
         return morton
 
@@ -355,3 +363,10 @@
     else:
         pos_name = "particle_position_"
     return f[ptype][pos_name + ax].value.astype("float64")
+
+def _get_position_array_units(ptype, f, ax):
+    if ptype == "grid":
+        pos_name = ""
+    else:
+        pos_name = "particle_position_"
+    return f[ptype][pos_name + ax].attrs["units"]

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/tests/test_outputs.py
--- a/yt/frontends/ytdata/tests/test_outputs.py
+++ b/yt/frontends/ytdata/tests/test_outputs.py
@@ -42,6 +42,14 @@
 import os
 import shutil
 
+def compare_unit_attributes(ds1, ds2):
+    attrs = ('length_unit', 'mass_unit', 'time_unit',
+             'velocity_unit', 'magnetic_unit')
+    for attr in attrs:
+        u1 = getattr(ds1, attr, None)
+        u2 = getattr(ds2, attr, None)
+        assert u1 == u2
+
 class YTDataFieldTest(AnswerTestingTest):
     _type_name = "YTDataTest"
     _attrs = ("field_name", )
@@ -88,6 +96,7 @@
     fn = sphere.save_as_dataset(fields=["density", "particle_mass"])
     full_fn = os.path.join(tmpdir, fn)
     sphere_ds = load(full_fn)
+    compare_unit_attributes(ds, sphere_ds)
     assert isinstance(sphere_ds, YTDataContainerDataset)
     yield YTDataFieldTest(full_fn, ("grid", "density"))
     yield YTDataFieldTest(full_fn, ("all", "particle_mass"))
@@ -104,6 +113,7 @@
     fn = cg.save_as_dataset(fields=["density", "particle_mass"])
     full_fn = os.path.join(tmpdir, fn)
     cg_ds = load(full_fn)
+    compare_unit_attributes(ds, cg_ds)
     assert isinstance(cg_ds, YTGridDataset)
 
     yield YTDataFieldTest(full_fn, ("grid", "density"))
@@ -112,6 +122,7 @@
     frb = my_proj.to_frb(1.0, (800, 800))
     fn = frb.save_as_dataset(fields=["density"])
     frb_ds = load(fn)
+    compare_unit_attributes(ds, frb_ds)
     assert isinstance(frb_ds, YTGridDataset)
     yield YTDataFieldTest(full_fn, "density", geometric=False)
     os.chdir(curdir)
@@ -127,6 +138,7 @@
     fn = proj.save_as_dataset()
     full_fn = os.path.join(tmpdir, fn)
     proj_ds = load(full_fn)
+    compare_unit_attributes(ds, proj_ds)
     assert isinstance(proj_ds, YTSpatialPlotDataset)
     yield YTDataFieldTest(full_fn, ("grid", "density"), geometric=False)
     os.chdir(curdir)
@@ -144,6 +156,7 @@
     fn = profile_1d.save_as_dataset()
     full_fn = os.path.join(tmpdir, fn)
     prof_1d_ds = load(full_fn)
+    compare_unit_attributes(ds, prof_1d_ds)
     assert isinstance(prof_1d_ds, YTProfileDataset)
 
     p1 = ProfilePlot(prof_1d_ds.data, "density", "temperature",
@@ -159,6 +172,7 @@
     fn = profile_2d.save_as_dataset()
     full_fn = os.path.join(tmpdir, fn)
     prof_2d_ds = load(full_fn)
+    compare_unit_attributes(ds, prof_2d_ds)
     assert isinstance(prof_2d_ds, YTProfileDataset)
 
     p2 = PhasePlot(prof_2d_ds.data, "density", "temperature",
@@ -188,6 +202,7 @@
     save_as_dataset(ds, fn, my_data)
     full_fn = os.path.join(tmpdir, fn)
     array_ds = load(full_fn)
+    compare_unit_attributes(ds, array_ds)
     assert isinstance(array_ds, YTNonspatialDataset)
     yield YTDataFieldTest(full_fn, "region_density", geometric=False)
     yield YTDataFieldTest(full_fn, "sphere_density", geometric=False)

diff -r 003a2e024c30e293af9fdd58419bf0d777e41afc -r ec92132823b016f9056f53d166e7262bca2275a7 yt/frontends/ytdata/utilities.py
--- a/yt/frontends/ytdata/utilities.py
+++ b/yt/frontends/ytdata/utilities.py
@@ -95,7 +95,9 @@
                    "current_redshift", "current_time",
                    "domain_dimensions", "periodicity",
                    "cosmological_simulation", "omega_lambda",
-                   "omega_matter", "hubble_constant"]
+                   "omega_matter", "hubble_constant",
+                   "length_unit", "mass_unit", "time_unit",
+                   "velocity_unit", "magnetic_unit"]
 
     fh = h5py.File(filename, "w")
     if ds is None: ds = {}
@@ -104,6 +106,14 @@
         for attr, val in ds.parameters.items():
             _yt_array_hdf5_attr(fh, attr, val)
 
+    if hasattr(ds, "unit_registry"):
+        _yt_array_hdf5_attr(fh, "unit_registry_json",
+                            ds.unit_registry.to_json())
+
+    if hasattr(ds, "unit_system"):
+        _yt_array_hdf5_attr(fh, "unit_system_name",
+                            ds.unit_system.name)
+
     for attr in base_attrs:
         if isinstance(ds, dict):
             my_val = ds.get(attr, None)
@@ -126,19 +136,16 @@
             field_type = field_types[field]
         if field_type not in fh:
             fh.create_group(field_type)
-        # for now, let's avoid writing "code" units
-        if hasattr(data[field], "units"):
-            for atom in data[field].units.expr.atoms():
-                if str(atom).startswith("code"):
-                    data[field].convert_to_base()
-                    break
+
         if isinstance(field, tuple):
             field_name = field[1]
         else:
             field_name = field
-        # thanks, python3
+
+        # for python3
         if data[field].dtype.kind == 'U':
-            data[field] = data[field].astype('|S40')
+            data[field] = data[field].astype('|S')
+
         _yt_array_hdf5(fh[field_type], field_name, data[field])
         if "num_elements" not in fh[field_type].attrs:
             fh[field_type].attrs["num_elements"] = data[field].size
@@ -225,7 +232,6 @@
 
     if val is None: val = "None"
     if hasattr(val, "units"):
-        val = val.in_base()
         fh.attrs["%s_units" % attr] = str(val.units)
     # The following is a crappy workaround for getting
     # Unicode strings into HDF5 attributes in Python 3

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because the commit notification service is enabled for the
recipient of this email.



More information about the yt-svn mailing list