[yt-svn] commit/yt: 5 new changesets

commits-noreply at bitbucket.org
Tue Mar 14 06:27:04 PDT 2017


5 new commits in yt:

https://bitbucket.org/yt_analysis/yt/commits/22f54591804e/
Changeset:   22f54591804e
Branch:      yt
User:        al007
Date:        2017-03-08 20:35:05+00:00
Summary:     Close netCDF4 datasets when not being used.
Affected #:  3 files

diff -r 1088ca4ac8d84017d1809a6f114c852297286350 -r 22f54591804e88d0cc3842ea8d1c19f1612aba1f yt/frontends/exodus_ii/data_structures.py
--- a/yt/frontends/exodus_ii/data_structures.py
+++ b/yt/frontends/exodus_ii/data_structures.py
@@ -180,18 +180,18 @@
 
     def _parse_parameter_file(self):
         self._handle = NetCDF4FileHandler(self.parameter_filename)
-        self._vars = self._handle.dataset.variables
-        self._read_glo_var()
-        self.dimensionality = self._vars['coor_names'].shape[0]
-        self.parameters['info_records'] = self._load_info_records()
-        self.unique_identifier = self._get_unique_identifier()
-        self.num_steps = len(self._vars['time_whole'])
-        self.current_time = self._get_current_time()
-        self.parameters['num_meshes'] = self._vars['eb_status'].shape[0]
-        self.parameters['elem_names'] = self._get_elem_names()
-        self.parameters['nod_names'] = self._get_nod_names()
-        self.domain_left_edge, self.domain_right_edge = self._load_domain_edge()
-        self.periodicity = (False, False, False)
+        with self._handle.open_ds() as ds:
+            self._read_glo_var()
+            self.dimensionality = ds.variables['coor_names'].shape[0]
+            self.parameters['info_records'] = self._load_info_records()
+            self.unique_identifier = self._get_unique_identifier()
+            self.num_steps = len(ds.variables['time_whole'])
+            self.current_time = self._get_current_time()
+            self.parameters['num_meshes'] = ds.variables['eb_status'].shape[0]
+            self.parameters['elem_names'] = self._get_elem_names()
+            self.parameters['nod_names'] = self._get_nod_names()
+            self.domain_left_edge, self.domain_right_edge = self._load_domain_edge()
+            self.periodicity = (False, False, False)
 
         # These attributes don't really make sense for unstructured
         # mesh data, but yt warns if they are not present, so we set
@@ -205,18 +205,18 @@
         self.refine_by = 0
 
     def _get_fluid_types(self):
-        handle = NetCDF4FileHandler(self.parameter_filename).dataset
-        fluid_types = ()
-        i = 1
-        while True:
-            ftype = 'connect%d' % i
-            if ftype in handle.variables:
-                fluid_types += (ftype,)
-                i += 1
-            else:
-                break
-        fluid_types += ('all',)
-        return fluid_types
+        with NetCDF4FileHandler(self.parameter_filename).open_ds() as ds:
+            fluid_types = ()
+            i = 1
+            while True:
+                ftype = 'connect%d' % i
+                if ftype in ds.variables:
+                    fluid_types += (ftype,)
+                    i += 1
+                else:
+                    break
+            fluid_types += ('all',)
+            return fluid_types
 
     def _read_glo_var(self):
         """
@@ -226,31 +226,34 @@
         names = self._get_glo_names()
         if not names:
             return
-        values = self._vars['vals_glo_var'][:].transpose()
-        for name, value in zip(names, values):
-            self.parameters[name] = value
+        with self._handle.open_ds() as ds:
+            values = ds.variables['vals_glo_var'][:].transpose()
+            for name, value in zip(names, values):
+                self.parameters[name] = value
 
     def _load_info_records(self):
         """
         Returns parsed version of the info_records.
         """
-        try:
-            return load_info_records(self._vars['info_records'])
-        except (KeyError, TypeError):
-            mylog.warning("No info_records found")
-            return []
+        with self._handle.open_ds() as ds:
+            try:
+                return load_info_records(ds.variables['info_records'])
+            except (KeyError, TypeError):
+                mylog.warning("No info_records found")
+                return []
 
     def _get_unique_identifier(self):
         return self.parameter_filename
 
     def _get_current_time(self):
-        try:
-            return self._vars['time_whole'][self.step]
-        except IndexError:
-            raise RuntimeError("Invalid step number, max is %d" \
-                               % (self.num_steps - 1))
-        except (KeyError, TypeError):
-            return 0.0
+        with self._handle.open_ds() as ds:
+            try:
+                return ds.variables['time_whole'][self.step]
+            except IndexError:
+                raise RuntimeError("Invalid step number, max is %d" \
+                                   % (self.num_steps - 1))
+            except (KeyError, TypeError):
+                return 0.0
 
     def _get_glo_names(self):
         """
@@ -259,12 +262,13 @@
 
         """
 
-        if "name_glo_var" not in self._vars:
-            mylog.warning("name_glo_var not found")
-            return []
-        else:
-            return [sanitize_string(v.tostring()) for v in
-                    self._vars["name_glo_var"]]
+        with self._handle.open_ds() as ds:
+            if "name_glo_var" not in ds.variables:
+                mylog.warning("name_glo_var not found")
+                return []
+            else:
+                return [sanitize_string(v.tostring()) for v in
+                        ds.variables["name_glo_var"]]
 
     def _get_elem_names(self):
         """
@@ -273,12 +277,13 @@
 
         """
 
-        if "name_elem_var" not in self._vars:
-            mylog.warning("name_elem_var not found")
-            return []
-        else:
-            return [sanitize_string(v.tostring()) for v in
-                    self._vars["name_elem_var"]]
+        with self._handle.open_ds() as ds:
+            if "name_elem_var" not in ds.variables:
+                mylog.warning("name_elem_var not found")
+                return []
+            else:
+                return [sanitize_string(v.tostring()) for v in
+                        ds.variables["name_elem_var"]]
 
     def _get_nod_names(self):
         """
@@ -287,12 +292,13 @@
 
         """
 
-        if "name_nod_var" not in self._vars:
-            mylog.warning("name_nod_var not found")
-            return []
-        else:
-            return [sanitize_string(v.tostring()) for v in
-                    self._vars["name_nod_var"]]
+        with self._handle.open_ds() as ds:
+            if "name_nod_var" not in ds.variables:
+                mylog.warning("name_nod_var not found")
+                return []
+            else:
+                return [sanitize_string(v.tostring()) for v in
+                        ds.variables["name_nod_var"]]
 
     def _read_coordinates(self):
         """
@@ -304,13 +310,14 @@
         coord_axes = 'xyz'[:self.dimensionality]
 
         mylog.info("Loading coordinates")
-        if "coord" not in self._vars:
-            coords = np.array([self._vars["coord%s" % ax][:]
-                               for ax in coord_axes]).transpose().copy()
-        else:
-            coords = np.array([coord for coord in
-                               self._vars["coord"][:]]).transpose().copy()
-        return coords
+        with self._handle.open_ds() as ds:
+            if "coord" not in ds.variables:
+                coords = np.array([ds.variables["coord%s" % ax][:]
+                                   for ax in coord_axes]).transpose().copy()
+            else:
+                coords = np.array([coord for coord in
+                                   ds.variables["coord"][:]]).transpose().copy()
+            return coords
 
     def _apply_displacement(self, coords, mesh_id):
 
@@ -324,13 +331,14 @@
         offset = self.displacements[mesh_name][1]
 
         coord_axes = 'xyz'[:self.dimensionality]
-        for i, ax in enumerate(coord_axes):
-            if "disp_%s" % ax in self.parameters['nod_names']:
-                ind = self.parameters['nod_names'].index("disp_%s" % ax)
-                disp = self._vars['vals_nod_var%d' % (ind + 1)][self.step]
-                new_coords[:, i] = coords[:, i] + fac*disp + offset[i]
+        with self._handle.open_ds() as ds:
+            for i, ax in enumerate(coord_axes):
+                if "disp_%s" % ax in self.parameters['nod_names']:
+                    ind = self.parameters['nod_names'].index("disp_%s" % ax)
+                    disp = ds.variables['vals_nod_var%d' % (ind + 1)][self.step]
+                    new_coords[:, i] = coords[:, i] + fac*disp + offset[i]
 
-        return new_coords
+            return new_coords
 
     def _read_connectivity(self):
         """
@@ -338,9 +346,10 @@
         """
         mylog.info("Loading connectivity")
         connectivity = []
-        for i in range(self.parameters['num_meshes']):
-            connectivity.append(self._vars["connect%d" % (i+1)][:].astype("i8"))
-        return connectivity
+        with self._handle.open_ds() as ds:
+            for i in range(self.parameters['num_meshes']):
+                connectivity.append(ds.variables["connect%d" % (i+1)][:].astype("i8"))
+            return connectivity
 
     def _load_domain_edge(self):
         """
@@ -373,7 +382,7 @@
         for i in range(self.dimensionality, 3):
             mi[i] = 0.0
             ma[i] = 1.0
-        
+
         return mi, ma
 
     @classmethod
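
A note on the recurring pattern in the hunks above: every read now happens
inside "with self._handle.open_ds() as ds:", so the netCDF file is held open
only for the duration of the access. A minimal usage sketch mirroring
_read_connectivity (assumes the netCDF4 package is installed; "mesh.e" is a
made-up filename, not part of the commit):

    from yt.utilities.file_handler import NetCDF4FileHandler

    handler = NetCDF4FileHandler("mesh.e")  # hypothetical Exodus II file
    with handler.open_ds() as ds:
        n_meshes = ds.variables['eb_status'].shape[0]
        connectivity = [ds.variables['connect%d' % (i + 1)][:].astype("i8")
                        for i in range(n_meshes)]
    # The file is closed here; the arrays stay usable in memory.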

diff -r 1088ca4ac8d84017d1809a6f114c852297286350 -r 22f54591804e88d0cc3842ea8d1c19f1612aba1f yt/frontends/exodus_ii/io.py
--- a/yt/frontends/exodus_ii/io.py
+++ b/yt/frontends/exodus_ii/io.py
@@ -28,7 +28,7 @@
     def __init__(self, ds):
         self.filename = ds.index_filename
         exodus_ii_handler = NetCDF4FileHandler(self.filename)
-        self.handler = exodus_ii_handler.dataset
+        self.handler = exodus_ii_handler
         super(IOHandlerExodusII, self).__init__(ds)
         self.node_fields = ds._get_nod_names()
         self.elem_fields = ds._get_elem_names()
@@ -46,46 +46,47 @@
         # dict gets returned at the end and it should be flat, with selected
         # data.  Note that if you're reading grid data, you might need to
         # special-case a grid selector object.
-        chunks = list(chunks)
-        rv = {}
-        for field in fields:
-            ftype, fname = field
-            if ftype == "all":
-                ci = np.concatenate([mesh.connectivity_indices - self._INDEX_OFFSET \
-                                     for mesh in self.ds.index.mesh_union])
-            else:
-                ci = self.handler.variables[ftype][:] - self._INDEX_OFFSET
-            num_elem = ci.shape[0]
-            if fname in self.node_fields:
-                nodes_per_element = ci.shape[1]
-                rv[field] = np.zeros((num_elem, nodes_per_element), dtype="float64")
-            elif fname in self.elem_fields:
-                rv[field] = np.zeros(num_elem, dtype="float64")
-        for field in fields:
-            ind = 0
-            ftype, fname = field
-            if ftype == "all":
-                mesh_ids = [mesh.mesh_id + 1 for mesh in self.ds.index.mesh_union]
-                objs = [mesh for mesh in self.ds.index.mesh_union]
-            else:
-                mesh_ids = [int(ftype[-1])]
-                chunk = chunks[mesh_ids[0] - 1]
-                objs = chunk.objs
-            if fname in self.node_fields:
-                field_ind = self.node_fields.index(fname)
-                fdata = self.handler.variables['vals_nod_var%d' % (field_ind + 1)]
-                for g in objs:
-                    ci = g.connectivity_indices - self._INDEX_OFFSET
-                    data = fdata[self.ds.step][ci]
-                    ind += g.select(selector, data, rv[field], ind)  # caches
-            if fname in self.elem_fields:
-                field_ind = self.elem_fields.index(fname)
-                for g, mesh_id in zip(objs, mesh_ids):
-                    fdata = self.handler.variables['vals_elem_var%deb%s' %
-                                                   (field_ind + 1, mesh_id)][:]
-                    data = fdata[self.ds.step, :]
-                    ind += g.select(selector, data, rv[field], ind)  # caches
-        return rv
+        with self.handler.open_ds() as ds:
+            chunks = list(chunks)
+            rv = {}
+            for field in fields:
+                ftype, fname = field
+                if ftype == "all":
+                    ci = np.concatenate([mesh.connectivity_indices - self._INDEX_OFFSET \
+                                         for mesh in self.ds.index.mesh_union])
+                else:
+                    ci = ds.variables[ftype][:] - self._INDEX_OFFSET
+                num_elem = ci.shape[0]
+                if fname in self.node_fields:
+                    nodes_per_element = ci.shape[1]
+                    rv[field] = np.zeros((num_elem, nodes_per_element), dtype="float64")
+                elif fname in self.elem_fields:
+                    rv[field] = np.zeros(num_elem, dtype="float64")
+            for field in fields:
+                ind = 0
+                ftype, fname = field
+                if ftype == "all":
+                    mesh_ids = [mesh.mesh_id + 1 for mesh in self.ds.index.mesh_union]
+                    objs = [mesh for mesh in self.ds.index.mesh_union]
+                else:
+                    mesh_ids = [int(ftype[-1])]
+                    chunk = chunks[mesh_ids[0] - 1]
+                    objs = chunk.objs
+                if fname in self.node_fields:
+                    field_ind = self.node_fields.index(fname)
+                    fdata = ds.variables['vals_nod_var%d' % (field_ind + 1)]
+                    for g in objs:
+                        ci = g.connectivity_indices - self._INDEX_OFFSET
+                        data = fdata[self.ds.step][ci]
+                        ind += g.select(selector, data, rv[field], ind)  # caches
+                if fname in self.elem_fields:
+                    field_ind = self.elem_fields.index(fname)
+                    for g, mesh_id in zip(objs, mesh_ids):
+                        fdata = ds.variables['vals_elem_var%deb%s' %
+                                                       (field_ind + 1, mesh_id)][:]
+                        data = fdata[self.ds.step, :]
+                        ind += g.select(selector, data, rv[field], ind)  # caches
+            return rv
 
     def _read_chunk_data(self, chunk, fields):
         # This reads the data from a single chunk, and is only used for

diff -r 1088ca4ac8d84017d1809a6f114c852297286350 -r 22f54591804e88d0cc3842ea8d1c19f1612aba1f yt/utilities/file_handler.py
--- a/yt/utilities/file_handler.py
+++ b/yt/utilities/file_handler.py
@@ -14,6 +14,7 @@
 #-----------------------------------------------------------------------------
 
 from yt.utilities.on_demand_imports import _h5py as h5py
+from contextlib import contextmanager
 
 class HDF5FileHandler(object):
     handle = None
@@ -67,8 +68,13 @@
     def close(self):
         self.handle.close()
 
-class NetCDF4FileHandler(object):
+class NetCDF4FileHandler():
     def __init__(self, filename):
+        self.filename = filename
+
+    @contextmanager
+    def open_ds(self):
         from netCDF4 import Dataset
-        ds = Dataset(filename)
-        self.dataset = ds
+        ds = Dataset(self.filename)
+        yield ds
+        ds.close()
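
As committed, open_ds closes the dataset only when the with-body exits
normally. An exception-safe variant (a sketch, not part of this commit) would
wrap the yield in try/finally so the file is also closed when the caller's
with-block raises:

    from contextlib import contextmanager

    class NetCDF4FileHandlerSketch:
        """Illustration only: exception-safe variant of open_ds."""
        def __init__(self, filename):
            self.filename = filename

        @contextmanager
        def open_ds(self):
            from netCDF4 import Dataset  # requires the netCDF4 package
            ds = Dataset(self.filename)
            try:
                yield ds
            finally:
                ds.close()  # runs even if the with-body raises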


https://bitbucket.org/yt_analysis/yt/commits/6c8928456ca3/
Changeset:   6c8928456ca3
Branch:      yt
User:        al007
Date:        2017-03-13 20:42:19+00:00
Summary:     Override __iter__ and __getitem__
Affected #:  1 file

diff -r 22f54591804e88d0cc3842ea8d1c19f1612aba1f -r 6c8928456ca3d426191748abd8e91166c5325353 yt/frontends/exodus_ii/simulation_handling.py
--- a/yt/frontends/exodus_ii/simulation_handling.py
+++ b/yt/frontends/exodus_ii/simulation_handling.py
@@ -45,6 +45,25 @@
         self.all_outputs = self._check_for_outputs(potential_outputs)
         self.all_outputs.sort(key=lambda obj: obj["filename"])
 
+    def __iter__(self):
+        for o in self._pre_outputs:
+            fn, step = o
+            ds = load(fn, step=step)
+            self._setup_function(ds)
+            yield ds
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            if isinstance(key.start, float):
+                return self.get_range(key.start, key.stop)
+            # This will return a sliced up object!
+            return DatasetSeries(self._pre_outputs[key], self.parallel)
+        o = self._pre_outputs[key]
+        fn, step = o
+        o = load(fn, step=step)
+        self._setup_function(o)
+        return o
+
     def get_time_series(self, parallel=False, setup_function=None):
         r"""
         Instantiate a DatasetSeries object for a set of outputs.
@@ -55,15 +74,15 @@
         Fine-level filtering is currently not implemented.
         
         """
-        
+
         all_outputs = self.all_outputs
         ds_list = []
         for output in all_outputs:
             num_steps = output['num_steps']
             fn = output['filename']
             for step in range(num_steps):
-                ds = ExodusIIDataset(fn, step=step)
-                ds_list.append(ds)
+                # ds = ExodusIIDataset(fn, step=step)
+                ds_list.append((fn, step))
         super(ExodusIISimulation, self).__init__(ds_list, 
                                                  parallel=parallel, 
                                                  setup_function=setup_function)
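
The net effect of this changeset: the series stores (filename, step) tuples
and builds datasets lazily on iteration or indexing, instead of instantiating
every ExodusIIDataset up front. A self-contained sketch of the same pattern
(LazySeries and the load callable are hypothetical stand-ins for
DatasetSeries and yt's load; the float-slice branch handled by get_range
above is omitted):

    class LazySeries:
        """Stores (filename, step) pairs; loads objects only on access."""
        def __init__(self, pre_outputs, load):
            self._pre_outputs = list(pre_outputs)
            self._load = load

        def __iter__(self):
            for fn, step in self._pre_outputs:
                yield self._load(fn, step=step)  # one dataset at a time

        def __getitem__(self, key):
            if isinstance(key, slice):
                # Slicing returns another lazy series, still unloaded.
                return LazySeries(self._pre_outputs[key], self._load)
            fn, step = self._pre_outputs[key]
            return self._load(fn, step=step)

    # Trivial stand-in loader, just to show the call shape:
    series = LazySeries([("out.e", s) for s in range(3)],
                        load=lambda fn, step: (fn, step))
    assert list(series) == [("out.e", 0), ("out.e", 1), ("out.e", 2)]
    assert series[1] == ("out.e", 1)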


https://bitbucket.org/yt_analysis/yt/commits/9ac06a9f2eb4/
Changeset:   9ac06a9f2eb4
Branch:      yt
User:        al007
Date:        2017-03-14 00:55:31+00:00
Summary:     Fix flake8 error.
Affected #:  1 file

diff -r 6c8928456ca3d426191748abd8e91166c5325353 -r 9ac06a9f2eb4d2493ab4a5361bc24ec59ea88540 yt/frontends/exodus_ii/simulation_handling.py
--- a/yt/frontends/exodus_ii/simulation_handling.py
+++ b/yt/frontends/exodus_ii/simulation_handling.py
@@ -14,7 +14,6 @@
 from yt.data_objects.time_series import \
     DatasetSeries, \
     RegisteredSimulationTimeSeries
-from yt.frontends.exodus_ii.api import ExodusIIDataset
 
 
 @add_metaclass(RegisteredSimulationTimeSeries)
@@ -27,7 +26,7 @@
 
     simulation_directory : str
         The directory that contain the simulation data.
-    
+
     Examples
     --------
     >>> import yt
@@ -37,7 +36,7 @@
     ...     print ds.current_time
 
     """
-    
+
     def __init__(self, simulation_directory, find_outputs=False):
         self.simulation_directory = simulation_directory
         fn_pattern = "%s/*" % self.simulation_directory
@@ -72,7 +71,7 @@
         created with all potential datasets created by the simulation.
 
         Fine-level filtering is currently not implemented.
-        
+
         """
 
         all_outputs = self.all_outputs
@@ -81,12 +80,11 @@
             num_steps = output['num_steps']
             fn = output['filename']
             for step in range(num_steps):
-                # ds = ExodusIIDataset(fn, step=step)
                 ds_list.append((fn, step))
-        super(ExodusIISimulation, self).__init__(ds_list, 
-                                                 parallel=parallel, 
+        super(ExodusIISimulation, self).__init__(ds_list,
+                                                 parallel=parallel,
                                                  setup_function=setup_function)
-        
+
     def _check_for_outputs(self, potential_outputs):
         r"""
         Check a list of files to see if they are valid datasets.


https://bitbucket.org/yt_analysis/yt/commits/24a55c267d85/
Changeset:   24a55c267d85
Branch:      yt
User:        al007
Date:        2017-03-14 01:00:50+00:00
Summary:     Add back trailing whitespace for ease of review.
Affected #:  1 file

diff -r 9ac06a9f2eb4d2493ab4a5361bc24ec59ea88540 -r 24a55c267d853c5665e6cb65fc1c0bdb62288abf yt/frontends/exodus_ii/simulation_handling.py
--- a/yt/frontends/exodus_ii/simulation_handling.py
+++ b/yt/frontends/exodus_ii/simulation_handling.py
@@ -26,7 +26,7 @@
 
     simulation_directory : str
         The directory that contain the simulation data.
-
+    
     Examples
     --------
     >>> import yt
@@ -36,7 +36,7 @@
     ...     print ds.current_time
 
     """
-
+    
     def __init__(self, simulation_directory, find_outputs=False):
         self.simulation_directory = simulation_directory
         fn_pattern = "%s/*" % self.simulation_directory
@@ -71,7 +71,7 @@
         created with all potential datasets created by the simulation.
 
         Fine-level filtering is currently not implemented.
-
+        
         """
 
         all_outputs = self.all_outputs
@@ -81,10 +81,10 @@
             fn = output['filename']
             for step in range(num_steps):
                 ds_list.append((fn, step))
-        super(ExodusIISimulation, self).__init__(ds_list,
-                                                 parallel=parallel,
+        super(ExodusIISimulation, self).__init__(ds_list, 
+                                                 parallel=parallel, 
                                                  setup_function=setup_function)
-
+        
     def _check_for_outputs(self, potential_outputs):
         r"""
         Check a list of files to see if they are valid datasets.


https://bitbucket.org/yt_analysis/yt/commits/95401107ce07/
Changeset:   95401107ce07
Branch:      yt
User:        xarthisius
Date:        2017-03-14 13:26:55+00:00
Summary:     Merged in al007/yt (pull request #2542)

Close netCDF4 datasets when not being used.

Approved-by: yt-fido
Approved-by: Nathan Goldbaum
Approved-by: Andrew Myers
Approved-by: Kacper Kowalik
Affected #:  4 files

Repository URL: https://bitbucket.org/yt_analysis/yt/
