[yt-svn] commit/yt: 7 new changesets

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Wed Dec 4 12:56:45 PST 2013


7 new commits in yt:

https://bitbucket.org/yt_analysis/yt/commits/96f4c8f352a8/
Changeset:   96f4c8f352a8
Branch:      yt-3.0
User:        MatthewTurk
Date:        2013-12-04 17:10:22
Summary:     Adding particle IO to FLASH.
Affected #:  1 file

diff -r 9537359f38e2c9f5245e7342bf0de2185e522f37 -r 96f4c8f352a8dfeb237af48da24b3bde89b2097a yt/frontends/flash/io.py
--- a/yt/frontends/flash/io.py
+++ b/yt/frontends/flash/io.py
@@ -42,6 +42,52 @@
             count_list, conv_factors):
         pass
 
+    def _read_particle_coords(self, chunks, ptf):
+        chunks = list(chunks)
+        f_part = self._particle_handle
+        p_ind = self.pf.h._particle_indices
+        px, py, pz = (self._particle_fields["particle_pos%s" % ax]
+                      for ax in 'xyz')
+        p_fields = f_part["/tracer particles"]
+        assert(len(ptf) == 1)
+        ptype = ptf.keys()[0]
+        for chunk in chunks:
+            for g in chunk.objs:
+                if g.NumberOfParticles == 0:
+                    continue
+                start = p_ind[g.id - g._id_offset]
+                end = p_ind[g.id - g._id_offset + 1]
+                x = p_fields[start:end, px]
+                y = p_fields[start:end, py]
+                z = p_fields[start:end, pz]
+                yield ptype, (x, y, z)
+
+    def _read_particle_fields(self, chunks, ptf, selector):
+        chunks = list(chunks)
+        f_part = self._particle_handle
+        p_ind = self.pf.h._particle_indices
+        px, py, pz = (self._particle_fields["particle_pos%s" % ax]
+                      for ax in 'xyz')
+        p_fields = f_part["/tracer particles"]
+        assert(len(ptf) == 1)
+        ptype = ptf.keys()[0]
+        field_list = ptf[ptype]
+        for chunk in chunks:
+            for g in chunk.objs:
+                if g.NumberOfParticles == 0:
+                    continue
+                start = p_ind[g.id - g._id_offset]
+                end = p_ind[g.id - g._id_offset + 1]
+                x = p_fields[start:end, px]
+                y = p_fields[start:end, py]
+                z = p_fields[start:end, pz]
+                mask = selector.select_points(x, y, z)
+                if mask is None: continue
+                for field in field_list:
+                    fi = self._particle_fields[field]
+                    data = p_fields[start:end, fi]
+                    yield (ptype, field), data[mask]
+
     def _read_data_set(self, grid, field):
         f = self._handle
         f_part = self._particle_handle


https://bitbucket.org/yt_analysis/yt/commits/d9ce2a081f9a/
Changeset:   d9ce2a081f9a
Branch:      yt-3.0
User:        MatthewTurk
Date:        2013-12-04 19:40:13
Summary:     Set up fields explicitly by type in FLASH frontend.
Affected #:  1 file

diff -r 96f4c8f352a8dfeb237af48da24b3bde89b2097a -r d9ce2a081f9acc450250e6648713e271893130f1 yt/frontends/flash/data_structures.py
--- a/yt/frontends/flash/data_structures.py
+++ b/yt/frontends/flash/data_structures.py
@@ -72,9 +72,9 @@
 
     def _detect_fields(self):
         ncomp = self._handle["/unknown names"].shape[0]
-        self.field_list = [s for s in self._handle["/unknown names"][:].flat]
+        self.field_list = [("gas", s) for s in self._handle["/unknown names"][:].flat]
         if ("/particle names" in self._particle_handle) :
-            self.field_list += ["particle_" + s[0].strip() for s
+            self.field_list += [("io", "particle_" + s[0].strip()) for s
                                 in self._particle_handle["/particle names"][:]]
     
     def _setup_classes(self):
@@ -176,26 +176,6 @@
                 g.dds[1] = DD
         self.max_level = self.grid_levels.max()
 
-    def _setup_derived_fields(self):
-        super(FLASHHierarchy, self)._setup_derived_fields()
-        [self.parameter_file.conversion_factors[field] 
-         for field in self.field_list]
-        for field in self.field_list:
-            if field not in self.derived_field_list:
-                self.derived_field_list.append(field)
-            if (field not in KnownFLASHFields and
-                field.startswith("particle")) :
-                self.parameter_file.field_info.add_field(
-                        field, function=NullFunc, take_log=False,
-                        validators = [ValidateDataField(field)],
-                        particle_type=True)
-
-        for field in self.derived_field_list:
-            f = self.parameter_file.field_info[field]
-            if f._function.func_name == "_TranslationFunc":
-                # Translating an already-converted field
-                self.parameter_file.conversion_factors[field] = 1.0 
-                
 class FLASHStaticOutput(StaticOutput):
     _hierarchy_class = FLASHHierarchy
     _fieldinfo_fallback = FLASHFieldInfo


https://bitbucket.org/yt_analysis/yt/commits/d6f9187706e0/
Changeset:   d6f9187706e0
Branch:      yt-3.0
User:        MatthewTurk
Date:        2013-12-04 19:40:54
Summary:     Find longest consecutive grid sequences for FLASH particle IO.
Affected #:  1 file

diff -r d9ce2a081f9acc450250e6648713e271893130f1 -r d6f9187706e0ce1133539e4b41f2c8da55479ae0 yt/frontends/flash/io.py
--- a/yt/frontends/flash/io.py
+++ b/yt/frontends/flash/io.py
@@ -15,11 +15,19 @@
 
 import numpy as np
 import h5py
+from itertools import groupby
 
 from yt.utilities.io_handler import \
     BaseIOHandler
 from yt.utilities.logger import ytLogger as mylog
 
+# http://stackoverflow.com/questions/2361945/detecting-consecutive-integers-in-a-list
+def particle_sequences(grids):
+    g_iter = sorted(grids, key = lambda g: g.id)
+    for k, g in groupby(enumerate(g_iter), lambda (i,x):i-x.id):
+        seq = list(v[1] for v in g)
+        yield seq[0], seq[-1]
+
 class IOHandlerFLASH(BaseIOHandler):
     _particle_reader = False
     _data_style = "flash_hdf5"
@@ -52,11 +60,10 @@
         assert(len(ptf) == 1)
         ptype = ptf.keys()[0]
         for chunk in chunks:
-            for g in chunk.objs:
-                if g.NumberOfParticles == 0:
-                    continue
-                start = p_ind[g.id - g._id_offset]
-                end = p_ind[g.id - g._id_offset + 1]
+            start = end = None
+            for g1, g2 in particle_sequences(chunk.objs):
+                start = p_ind[g1.id - g1._id_offset]
+                end = p_ind[g2.id - g2._id_offset + 1]
                 x = p_fields[start:end, px]
                 y = p_fields[start:end, py]
                 z = p_fields[start:end, pz]
@@ -73,11 +80,9 @@
         ptype = ptf.keys()[0]
         field_list = ptf[ptype]
         for chunk in chunks:
-            for g in chunk.objs:
-                if g.NumberOfParticles == 0:
-                    continue
-                start = p_ind[g.id - g._id_offset]
-                end = p_ind[g.id - g._id_offset + 1]
+            for g1, g2 in particle_sequences(chunk.objs):
+                start = p_ind[g1.id - g1._id_offset]
+                end = p_ind[g2.id - g2._id_offset + 1]
                 x = p_fields[start:end, px]
                 y = p_fields[start:end, py]
                 z = p_fields[start:end, pz]


https://bitbucket.org/yt_analysis/yt/commits/4520ed80046a/
Changeset:   4520ed80046a
Branch:      yt-3.0
User:        MatthewTurk
Date:        2013-12-04 19:49:38
Summary:     Use the same sequential grid trick for fluids.
Affected #:  1 file

diff -r d6f9187706e0ce1133539e4b41f2c8da55479ae0 -r 4520ed80046ae82360e3cd91fab9dc17fdcd554a yt/frontends/flash/io.py
--- a/yt/frontends/flash/io.py
+++ b/yt/frontends/flash/io.py
@@ -28,6 +28,12 @@
         seq = list(v[1] for v in g)
         yield seq[0], seq[-1]
 
+def grid_sequences(grids):
+    g_iter = sorted(grids, key = lambda g: g.id)
+    for k, g in groupby(enumerate(g_iter), lambda (i,x):i-x.id):
+        seq = list(v[1] for v in g)
+        yield seq
+
 class IOHandlerFLASH(BaseIOHandler):
     _particle_reader = False
     _data_style = "flash_hdf5"
@@ -132,9 +138,12 @@
             ds = f["/%s" % fname]
             ind = 0
             for chunk in chunks:
-                for g in chunk.objs:
-                    data = ds[g.id - g._id_offset,:,:,:].transpose()
-                    ind += g.select(selector, data, rv[field], ind) # caches
+                for gs in grid_sequences(chunk.objs):
+                    start = gs[0].id - gs[0]._id_offset
+                    end = gs[-1].id - gs[-1]._id_offset + 1
+                    data = ds[start:end,:,:,:].transpose()
+                    for i, g in enumerate(gs):
+                        ind += g.select(selector, data[...,i], rv[field], ind)
         return rv
 
     def _read_chunk_data(self, chunk, fields):
@@ -146,8 +155,11 @@
             ftype, fname = field
             ds = f["/%s" % fname]
             ind = 0
-            for g in chunk.objs:
-                data = ds[g.id - g._id_offset,:,:,:].transpose()
-                rv[g.id][field] = data
+            for gs in grid_sequences(chunk.objs):
+                start = gs[0].id - gs[0]._id_offset
+                end = gs[-1].id - gs[-1]._id_offset + 1
+                data = ds[start:end,:,:,:].transpose()
+                for i, g in enumerate(gs):
+                    rv[g.id][field] = data[...,i]
         return rv
 


https://bitbucket.org/yt_analysis/yt/commits/a8187e5a33c0/
Changeset:   a8187e5a33c0
Branch:      yt-3.0
User:        MatthewTurk
Date:        2013-12-04 19:53:01
Summary:     Remove ValidateDataField from FLASH fields.
Affected #:  1 file

diff -r 4520ed80046ae82360e3cd91fab9dc17fdcd554a -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad yt/frontends/flash/fields.py
--- a/yt/frontends/flash/fields.py
+++ b/yt/frontends/flash/fields.py
@@ -20,11 +20,7 @@
     NullFunc, \
     TranslationFunc, \
     FieldInfo, \
-    ValidateParameter, \
-    ValidateDataField, \
-    ValidateProperty, \
-    ValidateSpatial, \
-    ValidateGridType
+    ValidateSpatial
 import yt.fields.universal_fields
 from yt.utilities.physical_constants import \
     kboltz, mh, Na
@@ -241,7 +237,6 @@
     if v not in KnownFLASHFields:
         pfield = v.startswith("particle")
         add_flash_field(v, function=NullFunc, take_log=False,
-                  validators = [ValidateDataField(v)],
                   particle_type = pfield)
     if f.endswith("_Fraction") :
         dname = "%s\/Fraction" % f.split("_")[0]


https://bitbucket.org/yt_analysis/yt/commits/b064eeba51bd/
Changeset:   b064eeba51bd
Branch:      yt-3.0
User:        MatthewTurk
Date:        2013-12-04 21:55:56
Summary:     Merging from mainline yt-3.0
Affected #:  21 files

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/analysis_modules/photon_simulator/photon_simulator.py
--- a/yt/analysis_modules/photon_simulator/photon_simulator.py
+++ b/yt/analysis_modules/photon_simulator/photon_simulator.py
@@ -250,6 +250,7 @@
             hubble = getattr(pf, "hubble_constant", None)
             omega_m = getattr(pf, "omega_matter", None)
             omega_l = getattr(pf, "omega_lambda", None)
+            if hubble == 0: hubble = None
             if hubble is not None and \
                omega_m is not None and \
                omega_l is not None:
@@ -948,9 +949,9 @@
         col1 = pyfits.Column(name='ENERGY', format='E',
                              array=self["eobs"])
         col2 = pyfits.Column(name='DEC', format='D',
+                             array=self["ysky"])
+        col3 = pyfits.Column(name='RA', format='D',
                              array=self["xsky"])
-        col3 = pyfits.Column(name='RA', format='D',
-                             array=self["ysky"])
 
         coldefs = pyfits.ColDefs([col1, col2, col3])
 

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/analysis_modules/sunyaev_zeldovich/projection.py
--- a/yt/analysis_modules/sunyaev_zeldovich/projection.py
+++ b/yt/analysis_modules/sunyaev_zeldovich/projection.py
@@ -19,11 +19,11 @@
 #-----------------------------------------------------------------------------
 
 from yt.utilities.physical_constants import sigma_thompson, clight, hcgs, kboltz, mh, Tcmb
+from yt.utilities.fits_image import FITSImageBuffer
 from yt.data_objects.image_array import ImageArray
 from yt.data_objects.field_info_container import add_field
 from yt.funcs import fix_axis, mylog, iterable, get_pbar
 from yt.utilities.definitions import inv_axis_names
-from yt.visualization.image_writer import write_fits, write_projection
 from yt.visualization.volume_rendering.camera import off_axis_projection
 from yt.utilities.parallel_tools.parallel_analysis_interface import \
      communication_system, parallel_root_only
@@ -272,32 +272,52 @@
         self.data["TeSZ"] = ImageArray(Te)
 
     @parallel_root_only
-    def write_fits(self, filename, clobber=True):
+    def write_fits(self, filename, sky_center=None, sky_scale=None, clobber=True):
         r""" Export images to a FITS file. Writes the SZ distortion in all
         specified frequencies as well as the mass-weighted temperature and the
-        optical depth. Distance units are in kpc.
+        optical depth. Distance units are in kpc, unless *sky_center*
+        and *sky_scale* are specified.

         Parameters
         ----------
         filename : string
             The name of the FITS file to be written. 
+        sky_center : tuple of floats, optional
+            The center of the observation in (RA, Dec) in degrees. Only used if
+            converting to sky coordinates.
+        sky_scale : float, optional
+            Scale in degrees per kpc. Only used if converting to sky
+            coordinates; required whenever *sky_center* is given.
         clobber : boolean, optional
             If the file already exists, do we overwrite?

         Examples
         --------
+        >>> # This example just writes out a FITS file with kpc coords
         >>> szprj.write_fits("SZbullet.fits", clobber=False)
+        >>> # This example uses sky coords
+        >>> sky_scale = 1./3600. # One arcsec per kpc
+        >>> sky_center = (30., 45.) # In degrees
+        >>> szprj.write_fits("SZbullet.fits", sky_center=sky_center, sky_scale=sky_scale)
         """
-        coords = {}
-        coords["dx"] = self.dx*self.pf.units["kpc"]
-        coords["dy"] = self.dy*self.pf.units["kpc"]
-        coords["xctr"] = 0.0
-        coords["yctr"] = 0.0
-        coords["units"] = "kpc"
-        other_keys = {"Time" : self.pf.current_time}
-        write_fits(self.data, filename, clobber=clobber, coords=coords,
-                   other_keys=other_keys)

+        deltas = np.array([self.dx*self.pf.units["kpc"],
+                           self.dy*self.pf.units["kpc"]])
+        if sky_center is None:
+            center = [0.0]*2
+            units = "kpc"
+        else:
+            if sky_scale is None:
+                raise ValueError("sky_scale must be given along with sky_center")
+            center = sky_center
+            units = "deg"
+            deltas *= sky_scale
+        fib = FITSImageBuffer(self.data, fields=self.data.keys(),
+                              center=center, units=units,
+                              scale=deltas)
+        fib.update_all_headers("Time", self.pf.current_time)
+        fib.writeto(filename, clobber=clobber)
+        
     @parallel_root_only
     def write_png(self, filename_prefix, cmap_name="algae",
                   log_fields=None):

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/data_objects/data_containers.py
--- a/yt/data_objects/data_containers.py
+++ b/yt/data_objects/data_containers.py
@@ -353,6 +353,24 @@
         else:
             self.hierarchy.save_object(self, name)
 
+    def to_glue(self, fields, label="yt"):
+        """
+        Takes specific *fields* in the container and exports them to
+        Glue (http://www.glueviz.org) for interactive analysis, launching a
+        GlueApplication on one dataset labeled *label* (default "yt").
+        Requires the optional ``glue`` package; it is imported lazily here.
+        """
+        from glue.core import DataCollection, Data
+        from glue.qt.glue_application import GlueApplication
+
+        gdata = Data(label=label)
+        for component_name in fields:
+            gdata.add_component(self[component_name], component_name)
+        dc = DataCollection([gdata])
+
+        app = GlueApplication(dc)
+        app.start()
+
     def __reduce__(self):
         args = tuple([self.pf._hash(), self._type_name] +
                      [getattr(self, n) for n in self._con_args] +

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/frontends/fits/api.py
--- /dev/null
+++ b/yt/frontends/fits/api.py
@@ -0,0 +1,23 @@
+"""
+API for yt.frontends.fits
+"""
+
+#-----------------------------------------------------------------------------
+# Copyright (c) 2013, yt Development Team.
+#
+# Distributed under the terms of the Modified BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+
+from .data_structures import \
+      FITSGrid, \
+      FITSHierarchy, \
+      FITSStaticOutput
+
+from .fields import \
+      FITSFieldInfo, \
+      add_fits_field
+
+from .io import \
+      IOHandlerFITS

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/frontends/fits/data_structures.py
--- /dev/null
+++ b/yt/frontends/fits/data_structures.py
@@ -0,0 +1,298 @@
+"""
+FITS-specific data structures
+"""
+
+#-----------------------------------------------------------------------------
+# Copyright (c) 2013, yt Development Team.
+#
+# Distributed under the terms of the Modified BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+
+try:
+    import astropy.io.fits as pyfits
+    import astropy.wcs as pywcs
+except ImportError:
+    pass
+
+import stat
+import numpy as np
+import weakref
+
+from yt.config import ytcfg
+from yt.funcs import *
+from yt.data_objects.grid_patch import \
+    AMRGridPatch
+from yt.geometry.grid_geometry_handler import \
+    GridGeometryHandler
+from yt.geometry.geometry_handler import \
+    YTDataChunk
+from yt.data_objects.static_output import \
+    StaticOutput
+from yt.utilities.definitions import \
+    mpc_conversion, sec_conversion
+from yt.utilities.io_handler import \
+    io_registry
+from yt.utilities.physical_constants import cm_per_mpc
+from .fields import FITSFieldInfo, add_fits_field, KnownFITSFields
+from yt.data_objects.field_info_container import FieldInfoContainer, NullFunc, \
+     ValidateDataField, TranslationFunc
+
+class FITSGrid(AMRGridPatch):
+    _id_offset = 0
+    def __init__(self, id, hierarchy, level):
+        AMRGridPatch.__init__(self, id, filename = hierarchy.hierarchy_filename,
+                              hierarchy = hierarchy)
+        self.Parent = None
+        self.Children = []
+        self.Level = 0
+
+    def __repr__(self):
+        return "FITSGrid_%04i (%s)" % (self.id, self.ActiveDimensions)
+    
+class FITSHierarchy(GridGeometryHandler):
+
+    grid = FITSGrid
+    
+    def __init__(self,pf,data_style='fits'):
+        self.data_style = data_style
+        self.field_indexes = {}
+        self.parameter_file = weakref.proxy(pf)
+        # for now, the hierarchy file is the parameter file!
+        self.hierarchy_filename = self.parameter_file.parameter_filename
+        self.directory = os.path.dirname(self.hierarchy_filename)
+        self._handle = pf._handle
+        self.float_type = np.float64
+        GridGeometryHandler.__init__(self,pf,data_style)
+
+    def _initialize_data_storage(self):
+        pass
+
+    def _detect_fields(self):
+        self.field_list = []
+        for h in self._handle[self.parameter_file.first_image:]:
+            if h.is_image:
+                self.field_list.append(h.name.lower())
+                        
+    def _setup_classes(self):
+        dd = self._get_data_reader_dict()
+        GridGeometryHandler._setup_classes(self, dd)
+        self.object_types.sort()
+
+    def _count_grids(self):
+        self.num_grids = 1
+                
+    def _parse_hierarchy(self):
+        f = self._handle # shortcut
+        pf = self.parameter_file # shortcut
+        
+        # Initialize to the domain left / domain right
+        self.grid_left_edge[0,:] = pf.domain_left_edge
+        self.grid_right_edge[0,:] = pf.domain_right_edge
+        self.grid_dimensions[0] = pf.domain_dimensions
+        
+        # This will become redundant, as _prepare_grid will reset it to its
+        # current value.  Note that FLASH uses 1-based indexing for refinement
+        # levels, but we do not, so we reduce the level by 1.
+        self.grid_levels.flat[:] = 0
+        self.grids = np.empty(self.num_grids, dtype='object')
+        for i in xrange(self.num_grids):
+            self.grids[i] = self.grid(i, self, self.grid_levels[i,0])
+        
+    def _populate_grid_objects(self):
+        self.grids[0]._prepare_grid()
+        self.grids[0]._setup_dx()
+        self.max_level = 0 
+
+    def _setup_derived_fields(self):
+        super(FITSHierarchy, self)._setup_derived_fields()
+        [self.parameter_file.conversion_factors[field] 
+         for field in self.field_list]
+        for field in self.field_list:
+            if field not in self.derived_field_list:
+                self.derived_field_list.append(field)
+
+        for field in self.derived_field_list:
+            f = self.parameter_file.field_info[field]
+            if f._function.func_name == "_TranslationFunc":
+                # Translating an already-converted field
+                self.parameter_file.conversion_factors[field] = 1.0 
+                
+    def _setup_data_io(self):
+        self.io = io_registry[self.data_style](self.parameter_file)
+
+class FITSStaticOutput(StaticOutput):
+    _hierarchy_class = FITSHierarchy
+    _fieldinfo_fallback = FITSFieldInfo
+    _fieldinfo_known = KnownFITSFields
+    _handle = None
+    
+    def __init__(self, filename, data_style='fits',
+                 primary_header = None,
+                 sky_conversion = None,
+                 storage_filename = None,
+                 conversion_override = None):
+
+        if isinstance(filename, pyfits.HDUList):
+            self._handle = filename
+            fname = filename.filename()
+        else:
+            self._handle = pyfits.open(filename)
+            fname = filename
+        for i, h in enumerate(self._handle):
+            if h.is_image and h.data is not None:
+                self.first_image = i
+                break
+            
+        if primary_header is None:
+            self.primary_header = self._handle[self.first_image].header
+        else:
+            self.primary_header = primary_header
+        self.shape = self._handle[self.first_image].shape
+        if conversion_override is None: conversion_override = {}
+        self._conversion_override = conversion_override
+
+        self.wcs = pywcs.WCS(self.primary_header)
+
+        if self.wcs.wcs.cunit[0].name in ["deg","arcsec","arcmin","mas"]:
+            self.sky_wcs = self.wcs.deepcopy()
+            if sky_conversion is None:
+                self._set_minimalist_wcs()
+            else:
+                dims = np.array(self.shape)
+                ndims = len(self.shape)
+                new_unit = sky_conversion[1]
+                new_deltx = np.abs(self.wcs.wcs.cdelt[0])*sky_conversion[0]
+                new_delty = np.abs(self.wcs.wcs.cdelt[1])*sky_conversion[0]
+                self.wcs.wcs.cdelt = [new_deltx, new_delty]
+                self.wcs.wcs.crpix = 0.5*(dims+1)
+                self.wcs.wcs.crval = [0.0]*2
+                self.wcs.wcs.cunit = [new_unit]*2
+                self.wcs.wcs.ctype = ["LINEAR"]*2
+
+        if not all(key in self.primary_header for key in
+                   ["CRPIX1","CRVAL1","CDELT1","CUNIT1"]):
+            self._set_minimalist_wcs()
+
+        StaticOutput.__init__(self, fname, data_style)
+        self.storage_filename = storage_filename
+            
+        self.refine_by = 2
+        self._set_units()
+
+    def _set_minimalist_wcs(self):
+        mylog.warning("Could not determine WCS information. Using pixel units.")
+        dims = np.array(self.shape)
+        ndims = len(dims)
+        self.wcs.wcs.crpix = 0.5*(dims+1)
+        self.wcs.wcs.cdelt = [1.,1.]
+        self.wcs.wcs.crval = 0.5*(dims+1)
+        self.wcs.wcs.cunit = ["pixel"]*ndims
+        self.wcs.wcs.ctype = ["LINEAR"]*ndims
+
+    def _set_units(self):
+        """
+        Generates the conversion to various physical _units based on the parameter file
+        """
+        self.units = {}
+        self.time_units = {}
+        if len(self.parameters) == 0:
+            self._parse_parameter_file()
+        self.conversion_factors = defaultdict(lambda: 1.0)
+        file_unit = self.wcs.wcs.cunit[0].name.lower()
+        if file_unit in mpc_conversion:
+            self._setup_getunits_units()
+        else:
+            self._setup_nounits_units()
+        self.parameters["Time"] = self.conversion_factors["Time"]
+        self.time_units['1'] = 1
+        self.units['1'] = 1.0
+        self.units['unitary'] = 1.0 / (self.domain_right_edge - self.domain_left_edge).max()
+        for unit in sec_conversion.keys():
+            self.time_units[unit] = self.conversion_factors["Time"] / sec_conversion[unit]
+        for p, v in self._conversion_override.items():
+            self.conversion_factors[p] = v
+
+    def _setup_comoving_units(self):
+        pass
+
+    def _setup_getunits_units(self):
+        file_unit = self.wcs.wcs.cunit[0].name.lower()
+        for unit in mpc_conversion.keys():
+            self.units[unit] = mpc_conversion[unit]/mpc_conversion[file_unit]
+        self.conversion_factors["Time"] = 1.0
+                                            
+    def _setup_nounits_units(self):
+        for unit in mpc_conversion.keys():
+            self.units[unit] = mpc_conversion[unit] / mpc_conversion["cm"]
+        self.conversion_factors["Time"] = 1.0
+
+    def _parse_parameter_file(self):
+        self.unique_identifier = \
+            int(os.stat(self.parameter_filename)[stat.ST_CTIME])
+        for k, v in self.primary_header.items():
+            self.parameters[k] = v
+
+        # Determine dimensionality
+
+        self.dimensionality = self.primary_header["naxis"]
+        self.geometry = "cartesian"
+
+        self.domain_dimensions = np.array(self._handle[self.first_image].shape)
+        if self.dimensionality == 2:
+            self.domain_dimensions = np.append(self.domain_dimensions,
+                                               [int(1)])
+        ND = self.dimensionality
+        
+        le = [0.5]*ND
+        re = [float(dim)+0.5 for dim in self.domain_dimensions]
+        if ND == 2:
+            xe, ye = self.wcs.wcs_pix2world([le[0],re[0]],
+                                            [le[1],re[1]], 1)
+            self.domain_left_edge = np.array([xe[0], ye[0], 0.0])
+            self.domain_right_edge = np.array([xe[1], ye[1], 1.0]) 
+        elif ND == 3:
+            xe, ye, ze = world_edges = self.wcs.wcs_pix2world([le[0],re[0]],
+                                                              [le[1],re[1]],
+                                                              [le[2],re[2]], 1)
+            self.domain_left_edge = np.array([xe[0], ye[0], ze[0]])
+            self.domain_right_edge = np.array([xe[1], ye[1], ze[1]])
+
+        # Get the simulation time
+        try:
+            self.current_time = self.parameters["time"]
+        except:
+            mylog.warning("Cannot find time")
+            self.current_time = 0.0
+            pass
+        
+        # For now we'll ignore these
+        self.periodicity = (False,)*3
+        self.current_redshift = self.omega_lambda = self.omega_matter = \
+            self.hubble_constant = self.cosmological_simulation = 0.0
+
+    def __del__(self):
+        self._handle.close()
+
+    @classmethod
+    def _is_valid(self, *args, **kwargs):
+        try:
+            if isinstance(args[0], pyfits.HDUList):
+                for h in args[0]:
+                    if h.is_image and h.data is not None:
+                        return True
+        except:
+            pass
+        try:
+            fileh = pyfits.open(args[0])
+            for h in fileh:
+                if h.is_image and h.data is not None:
+                    fileh.close()
+                    return True
+            fileh.close()
+        except:
+            pass
+        return False
+
+

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/frontends/fits/fields.py
--- /dev/null
+++ b/yt/frontends/fits/fields.py
@@ -0,0 +1,31 @@
+"""
+FITS-specific fields
+"""
+
+#-----------------------------------------------------------------------------
+# Copyright (c) 2013, yt Development Team.
+#
+# Distributed under the terms of the Modified BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+
+import numpy as np
+from yt.utilities.exceptions import *
+from yt.data_objects.field_info_container import \
+    FieldInfoContainer, \
+    NullFunc, \
+    TranslationFunc, \
+    FieldInfo, \
+    ValidateParameter, \
+    ValidateDataField, \
+    ValidateProperty, \
+    ValidateSpatial, \
+    ValidateGridType
+import yt.fields.universal_fields
+KnownFITSFields = FieldInfoContainer()
+add_fits_field = KnownFITSFields.add_field
+
+FITSFieldInfo = FieldInfoContainer.create_with_fallback(FieldInfo)
+add_field = FITSFieldInfo.add_field
+

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/frontends/fits/io.py
--- /dev/null
+++ b/yt/frontends/fits/io.py
@@ -0,0 +1,82 @@
+"""
+FITS-specific IO functions
+"""
+
+#-----------------------------------------------------------------------------
+# Copyright (c) 2013, yt Development Team.
+#
+# Distributed under the terms of the Modified BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+
+import numpy as np
+try:
+    import astropy.io.fits as pyfits
+except ImportError:
+    pass
+
+from yt.utilities.math_utils import prec_accum
+
+from yt.utilities.io_handler import \
+    BaseIOHandler
+from yt.utilities.logger import ytLogger as mylog
+
+class IOHandlerFITS(BaseIOHandler):
+    """IO handler for the single-grid FITS frontend (reads from pf._handle)."""
+    _particle_reader = False
+    _data_style = "fits"
+
+    def __init__(self, pf):
+        super(IOHandlerFITS, self).__init__(pf)
+        self.pf = pf
+        self._handle = pf._handle
+
+    def _read_particles(self, fields_to_read, type, args, grid_list,
+            count_list, conv_factors):
+        # The FITS frontend reads no particles; this is a deliberate no-op.
+        pass
+
+    def _read_3d(self, field):
+        """Return HDU *field*'s data with axes reversed, always with 3 axes."""
+        data = self._handle[field].data.transpose()
+        # 2D images get a trailing length-1 axis so grids see (nx, ny, 1).
+        if self.pf.dimensionality == 2:
+            return data.reshape(data.shape + (1,))
+        elif self.pf.dimensionality == 3:
+            return data
+        raise NotImplementedError("Cannot read %s-dimensional FITS data"
+                                  % self.pf.dimensionality)
+
+    def _read_data_set(self, grid, field):
+        """Return the full *field* array as float64 (one grid per file)."""
+        return self._read_3d(field).astype("float64")
+
+    def _read_data_slice(self, grid, field, axis, coord):
+        """Return the one-cell-thick slab at index *coord* along *axis*."""
+        sl = [slice(None), slice(None), slice(None)]
+        sl[axis] = slice(coord, coord + 1)
+        # numpy requires a tuple (not a list) of slices for this indexing.
+        return self._read_3d(field)[tuple(sl)].astype("float64")
+
+    def _read_fluid_selection(self, chunks, selector, fields, size):
+        """Read the cells picked by *selector* into flat float64 arrays."""
+        chunks = list(chunks)
+        # Only "gas"-type fields are supported by this reader.
+        if any((ftype != "gas" for ftype, fname in fields)):
+            raise NotImplementedError
+        rv = {}
+        for field in fields:
+            rv[field] = np.empty(size, dtype="float64")
+        ng = sum(len(c.objs) for c in chunks)
+        mylog.debug("Reading %s cells of %s fields in %s blocks",
+                    size, [f2 for f1, f2 in fields], ng)
+        for field in fields:
+            ftype, fname = field
+            # Convert/transpose once per field rather than once per grid.
+            data = self._read_3d(fname).astype("float64")
+            ind = 0
+            for chunk in chunks:
+                for g in chunk.objs:
+                    ind += g.select(selector, data, rv[field], ind) # caches
+        return rv

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/frontends/flash/io.py
--- a/yt/frontends/flash/io.py
+++ b/yt/frontends/flash/io.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 import h5py
+from yt.utilities.math_utils import prec_accum
 from itertools import groupby
 
 from yt.utilities.io_handler import \

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/frontends/setup.py
--- a/yt/frontends/setup.py
+++ b/yt/frontends/setup.py
@@ -12,6 +12,7 @@
     config.add_subpackage("boxlib")
     config.add_subpackage("chombo")
     config.add_subpackage("enzo")
+    config.add_subpackage("fits")
     config.add_subpackage("flash")
     config.add_subpackage("gdf")
     config.add_subpackage("moab")

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/mods.py
--- a/yt/mods.py
+++ b/yt/mods.py
@@ -116,6 +116,9 @@
     GadgetFieldInfo, add_gadget_field, \
     TipsyStaticOutput, TipsyFieldInfo, add_tipsy_field
 
+#from yt.frontends.fits.api import \
+#    FITSStaticOutput, FITSFieldInfo, add_fits_field
+
 from yt.analysis_modules.list_modules import \
     get_available_modules, amods
 available_analysis_modules = get_available_modules()
@@ -132,7 +135,7 @@
     PlotCollection, PlotCollectionInteractive, \
     get_multi_plot, FixedResolutionBuffer, ObliqueFixedResolutionBuffer, \
     callback_registry, write_bitmap, write_image, \
-    apply_colormap, scale_image, write_projection, write_fits, \
+    apply_colormap, scale_image, write_projection, \
     SlicePlot, AxisAlignedSlicePlot, OffAxisSlicePlot, \
     ProjectionPlot, OffAxisProjectionPlot, \
     show_colormaps

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/utilities/fits_image.py
--- /dev/null
+++ b/yt/utilities/fits_image.py
@@ -0,0 +1,246 @@
+"""
+FITSImageBuffer Class
+"""
+
+#-----------------------------------------------------------------------------
+# Copyright (c) 2013, yt Development Team.
+#
+# Distributed under the terms of the Modified BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+
+import numpy as np
+from yt.funcs import mylog, iterable
+from yt.visualization.fixed_resolution import FixedResolutionBuffer
+from yt.data_objects.construction_data_containers import YTCoveringGridBase
+
+try:
+    from astropy.io.fits import HDUList, ImageHDU
+    from astropy import wcs as pywcs
+except ImportError:
+    pass
+
+class FITSImageBuffer(HDUList):
+
+    def __init__(self, data, fields=None, units="cm",
+                 center=None, scale=None):
+        r""" Initialize a FITSImageBuffer object.
+
+        FITSImageBuffer contains a list of FITS ImageHDU instances, and optionally includes
+        WCS information. It inherits from HDUList, so operations such as `writeto` are
+        enabled. Images can be constructed from ImageArrays, NumPy arrays, dicts of such
+        arrays, FixedResolutionBuffers, and YTCoveringGrids. The latter
+        two are the most powerful because WCS information can be constructed from their coordinates.
+
+        Parameters
+        ----------
+        data : FixedResolutionBuffer or a YTCoveringGrid. Or, an
+            ImageArray, an numpy.ndarray, or dict of such arrays
+            The data to be made into a FITS image or images.
+        fields : single string or list of strings, optional
+            The field names for the data. If *fields* is none and *data* has keys,
+            it will use these for the fields. If *data* is just a single array one field name
+            must be specified.
+        units : string
+            The units of the WCS coordinates, default "cm". 
+        center : array_like, optional
+            The coordinates [xctr,yctr] of the images in units
+            *units*. If *units* is not specified, defaults to the origin. 
+        scale : tuple of floats, optional
+            Pixel scale in unit *units*. Will be ignored if *data* is
+            a FixedResolutionBuffer or a YTCoveringGrid. Must be
+            specified otherwise, or if *units* is "deg".
+
+        Examples
+        --------
+
+        >>> ds = load("sloshing_nomag2_hdf5_plt_cnt_0150")
+        >>> prj = ds.h.proj(2, "TempkeV", weight_field="Density")
+        >>> frb = prj.to_frb((0.5, "mpc"), 800)
+        >>> # This example just uses the FRB and puts the coords in kpc.
+        >>> f_kpc = FITSImageBuffer(frb, fields="TempkeV", units="kpc")
+        >>> # This example specifies sky coordinates.
+        >>> scale = [1./3600.]*2 # One arcsec per pixel
+        >>> f_deg = FITSImageBuffer(frb, fields="TempkeV", units="deg",
+                                    scale=scale, center=(30., 45.))
+        >>> f_deg.writeto("temp.fits")
+        """
+        
+        super(HDUList, self).__init__()
+
+        if isinstance(fields, basestring): fields = [fields]
+            
+        exclude_fields = ['x','y','z','px','py','pz',
+                          'pdx','pdy','pdz','weight_field']
+        
+        if hasattr(data, 'keys'):
+            img_data = data
+        else:
+            img_data = {}
+            if fields is None:
+                mylog.error("Please specify a field name for this array.")
+                raise KeyError
+            img_data[fields[0]] = data
+
+        if fields is None: fields = img_data.keys()
+        if len(fields) == 0:
+            mylog.error("Please specify one or more fields to write.")
+            raise KeyError
+
+        first = False
+    
+        for key in fields:
+            if key not in exclude_fields:
+                mylog.info("Making a FITS image of field %s" % (key))
+                if first:
+                    hdu = PrimaryHDU(np.array(img_data[key]))
+                    hdu.name = key
+                else:
+                    hdu = ImageHDU(np.array(img_data[key]), name=key)
+                self.append(hdu)
+
+        self.dimensionality = len(self[0].data.shape)
+        
+        if self.dimensionality == 2:
+            self.nx, self.ny = self[0].data.shape
+        elif self.dimensionality == 3:
+            self.nx, self.ny, self.nz = self[0].data.shape
+
+        has_coords = (isinstance(img_data, FixedResolutionBuffer) or
+                      isinstance(img_data, YTCoveringGridBase))
+        
+        if center is None:
+            if units == "deg":
+                mylog.error("Please specify center=(RA, Dec) in degrees.")
+                raise ValueError
+            elif not has_coords:
+                mylog.warning("Setting center to the origin.")
+                center = [0.0]*self.dimensionality
+
+        if scale is None:
+            if units == "deg" or not has_coords:
+                mylog.error("Please specify scale=(dx,dy[,dz]) in %s." % (units))
+                raise ValueError
+
+        w = pywcs.WCS(header=self[0].header, naxis=self.dimensionality)
+        w.wcs.crpix = 0.5*(np.array(self.shape)+1)
+
+        proj_type = ["linear"]*self.dimensionality
+
+        if isinstance(img_data, FixedResolutionBuffer) and units != "deg":
+            # FRBs are a special case where we have coordinate
+            # information, so we take advantage of this and
+            # construct the WCS object
+            dx = (img_data.bounds[1]-img_data.bounds[0])/self.nx
+            dy = (img_data.bounds[3]-img_data.bounds[2])/self.ny
+            dx *= img_data.pf.units[units]
+            dy *= img_data.pf.units[units]
+            xctr = 0.5*(img_data.bounds[1]+img_data.bounds[0])
+            yctr = 0.5*(img_data.bounds[3]+img_data.bounds[2])
+            xctr *= img_data.pf.units[units]
+            yctr *= img_data.pf.units[units]
+            center = [xctr, yctr]
+        elif isinstance(img_data, YTCoveringGridBase):
+            dx, dy, dz = img_data.dds
+            dx *= img_data.pf.units[units]
+            dy *= img_data.pf.units[units]
+            dz *= img_data.pf.units[units]
+            center = 0.5*(img_data.left_edge+img_data.right_edge)
+            center *= img_data.pf.units[units]
+        elif units == "deg" and self.dimensionality == 2:
+            dx = -scale[0]
+            dy = scale[1]
+            proj_type = ["RA---TAN","DEC--TAN"]
+        else:
+            dx = scale[0]
+            dy = scale[1]
+            if self.dimensionality == 3: dz = scale[2]
+            
+        w.wcs.crval = center
+        w.wcs.cunit = [units]*self.dimensionality
+        w.wcs.ctype = proj_type
+        
+        if self.dimensionality == 2:
+            w.wcs.cdelt = [dx,dy]
+        elif self.dimensionality == 3:
+            w.wcs.cdelt = [dx,dy,dz]
+
+        self._set_wcs(w)
+            
+    def _set_wcs(self, wcs):
+        """
+        Set the WCS coordinate information for all images
+        with a WCS object *wcs*.
+        """
+        self.wcs = wcs
+        h = self.wcs.to_header()
+        for img in self:
+            for k, v in h.items():
+                img.header.update(k,v)
+
+    def update_all_headers(self, key, value):
+        """
+        Update the FITS headers for all images with the
+        same *key*, *value* pair.
+        """
+        for img in self: img.header.update(key,value)
+            
+    def keys(self):
+        return [f.name for f in self]
+
+    def has_key(self, key):
+        return key in self.keys()
+
+    def values(self):
+        return [self[k] for k in self.keys()]
+
+    def items(self):
+        return [(k, self[k]) for k in self.keys()]
+
+    def __add__(self, other):
+        if len(set(self.keys()).intersection(set(other.keys()))) > 0:
+            mylog.error("There are duplicate extension names! Don't know which ones you want to keep!")
+            raise KeyError
+        new_buffer = {}
+        for im1 in self:
+            new_buffer[im1.name] = im1.data
+        for im2 in other:
+            new_buffer[im2.name] = im2.data
+        new_wcs = self.wcs
+        return FITSImageBuffer(new_buffer, wcs=new_wcs)
+
+    def writeto(self, fileobj, **kwargs):
+        HDUList(self).writeto(fileobj, **kwargs)
+        
+    @property
+    def shape(self):
+        if self.dimensionality == 2:
+            return self.nx, self.ny
+        elif self.dimensionality == 3:
+            return self.nx, self.ny, self.nz
+
+    def to_glue(self, label="yt"):
+        """
+        Takes the data in the FITSImageBuffer and exports it to
+        Glue (http://www.glueviz.org) for interactive
+        analysis. Optionally add a *label*. 
+        """
+        from glue.core import DataCollection, Data
+        from glue.core.coordinates import coordinates_from_header
+        from glue.qt.glue_application import GlueApplication
+
+        field_dict = dict((key,self[key].data) for key in self.keys())
+        
+        image = Data(label=label)
+        image.coords = coordinates_from_header(self.wcs.to_header())
+        for k,v in field_dict.items():
+            image.add_component(v, k)
+        dc = DataCollection([image])
+
+        app = GlueApplication(dc)
+        app.start()
+
+        
+
+    

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/visualization/api.py
--- a/yt/visualization/api.py
+++ b/yt/visualization/api.py
@@ -34,8 +34,7 @@
     splat_points, \
     apply_colormap, \
     scale_image, \
-    write_projection, \
-    write_fits
+    write_projection
 
 from plot_modifications import \
     PlotCallback, \

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/visualization/color_maps.py
--- a/yt/visualization/color_maps.py
+++ b/yt/visualization/color_maps.py
@@ -120,7 +120,7 @@
 # Add colormaps in _colormap_data.py that weren't defined here
 _vs = np.linspace(0,1,255)
 for k,v in _cm.color_map_luts.iteritems():
-    if k not in yt_colormaps:
+    if k not in yt_colormaps and k not in mcm.cmap_d:
         cdict = { 'red': zip(_vs,v[0],v[0]),
                   'green': zip(_vs,v[1],v[1]),
                   'blue': zip(_vs,v[2],v[2]) }
@@ -140,7 +140,7 @@
     Displays the colormaps available to yt.  Note, most functions can use
     both the matplotlib and the native yt colormaps; however, there are 
     some special functions existing within image_writer.py (e.g. write_image()
-    write_fits(), write_bitmap(), etc.), which cannot access the matplotlib
+    write_bitmap(), etc.), which cannot access the matplotlib
     colormaps.
 
     In addition to the colormaps listed, one can access the reverse of each 

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/visualization/fixed_resolution.py
--- a/yt/visualization/fixed_resolution.py
+++ b/yt/visualization/fixed_resolution.py
@@ -19,7 +19,6 @@
     y_dict, \
     axis_names
 from .volume_rendering.api import off_axis_projection
-from image_writer import write_fits
 from yt.data_objects.image_array import ImageArray
 from yt.utilities.lib.misc_utilities import \
     pixelize_cylinder
@@ -271,7 +270,7 @@
         output.close()
 
     def export_fits(self, filename, fields=None, clobber=False,
-                    other_keys=None, units="cm", sky_center=(0.0,0.0), D_A=None):
+                    other_keys=None, units="cm"):
         r"""Export a set of pixelized fields to a FITS file.
 
         This will export a set of FITS images of either the fields specified
@@ -291,13 +290,6 @@
             the length units that the coordinates are written in, default 'cm'
             If units are set to "deg" then assume that sky coordinates are
             requested.
-        sky_center : array_like, optional
-            Center of the image in (ra,dec) in degrees if sky coordinates
-            (units="deg") are requested.
-        D_A : float or tuple, optional
-            Angular diameter distance, given in code units as a float or
-            a tuple containing the value and the length unit. Required if
-            using sky coordinates.
         """
 
         try:
@@ -305,77 +297,19 @@
         except:
             mylog.error("You don't have AstroPy installed!")
             raise ImportError
-        
-        if units == "deg" and D_A is None:
-            mylog.error("Sky coordinates require an angular diameter distance. Please specify D_A.")    
-            raise ValueError
-    
-        if iterable(D_A):
-            dist = D_A[0]/self.pf.units[D_A[1]]
-        else:
-            dist = D_A
-
-        if other_keys is None:
-            hdu_keys = {}
-        else:
-            hdu_keys = other_keys
+        from yt.utilities.fits_image import FITSImageBuffer
 
         extra_fields = ['x','y','z','px','py','pz','pdx','pdy','pdz','weight_field']
         if fields is None: 
             fields = [field for field in self.data_source.fields 
                       if field not in extra_fields]
 
-        coords = {}
-        nx, ny = self.buff_size
-        dx = (self.bounds[1]-self.bounds[0])/nx
-        dy = (self.bounds[3]-self.bounds[2])/ny
-        if units == "deg":  
-            coords["dx"] = -np.rad2deg(dx/dist)
-            coords["dy"] = np.rad2deg(dy/dist)
-            coords["xctr"] = sky_center[0]
-            coords["yctr"] = sky_center[1]
-            hdu_keys["MTYPE1"] = "EQPOS"
-            hdu_keys["MFORM1"] = "RA,DEC"
-            hdu_keys["CTYPE1"] = "RA---TAN"
-            hdu_keys["CTYPE2"] = "DEC--TAN"
-        else:
-            coords["dx"] = dx*self.pf.units[units]
-            coords["dy"] = dy*self.pf.units[units]
-            coords["xctr"] = 0.5*(self.bounds[0]+self.bounds[1])*self.pf.units[units]
-            coords["yctr"] = 0.5*(self.bounds[2]+self.bounds[3])*self.pf.units[units]
-        coords["units"] = units
+        fib = FITSImageBuffer(self, fields=fields, units=units)
+        if other_keys is not None:
+            for k,v in other_keys.items():
+                fib.update_all_headers(k,v)
+        fib.writeto(filename, clobber=clobber)
         
-        hdu_keys["Time"] = self.pf.current_time
-
-        data = dict([(field,self[field]) for field in fields])
-        write_fits(data, filename, clobber=clobber, coords=coords,
-                   other_keys=hdu_keys)
-
-    def open_in_ds9(self, field, take_log=True):
-        """
-        This will open a given field in the DS9 viewer.
-
-        Displaying fields can often be much easier in an interactive viewer,
-        particularly one as versatile as DS9.  This function will pixelize a
-        field and export it to an interactive DS9 package.  This requires the
-        *numdisplay* package, which is a simple download from STSci.
-        Furthermore, it presupposed that it can connect to DS9 -- that is, that
-        DS9 is already open.
-
-        Parameters
-        ----------
-        field : strings
-            This field will be pixelized and displayed.
-        take_log : boolean
-            DS9 seems to have issues with logging fields in-memory.  This will
-            pre-log the field before sending it to DS9.
-        """
-        import numdisplay
-        numdisplay.open()
-        if take_log: data=np.log10(self[field])
-        else: data=self[field]
-        numdisplay.display(data)    
-
     @property
     def limits(self):
         rv = dict(x = None, y = None, z = None)

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/visualization/image_writer.py
--- a/yt/visualization/image_writer.py
+++ b/yt/visualization/image_writer.py
@@ -386,76 +386,6 @@
     canvas.print_figure(filename, dpi=dpi)
     return filename
 
-
-def write_fits(image, filename, clobber=True, coords=None,
-               other_keys=None):
-    r"""Write out floating point arrays directly to a FITS file, optionally
-    adding coordinates and header keywords.
-        
-    Parameters
-    ----------
-    image : array_like, or dict of array_like objects
-        This is either an (unscaled) array of floating point values, or a dict of
-        such arrays, shape (N,N,) to save in a FITS file. 
-    filename : string
-        This name of the FITS file to be written.
-    clobber : boolean
-        If the file exists, this governs whether we will overwrite.
-    coords : dictionary, optional
-        A set of header keys and values to write to the FITS header to set up
-        a coordinate system, which is assumed to be linear unless specified otherwise
-        in *other_keys*
-        "units": the length units
-        "xctr","yctr": the center of the image
-        "dx","dy": the pixel width in each direction                                                
-    other_keys : dictionary, optional
-        A set of header keys and values to write into the FITS header.    
-    """
-
-    try:
-        import astropy.io.fits as pyfits
-    except:
-        mylog.error("You don't have AstroPy installed!")
-        raise ImportError
-    
-    try:
-        image.keys()
-        image_dict = image
-    except:
-        image_dict = dict(yt_data=image)
-
-    hdulist = [pyfits.PrimaryHDU()]
-
-    for key in image_dict.keys():
-
-        mylog.info("Writing image block \"%s\"" % (key))
-        hdu = pyfits.ImageHDU(image_dict[key])
-        hdu.update_ext_name(key)
-        
-        if coords is not None:
-            nx, ny = image_dict[key].shape
-            hdu.header.update('CUNIT1', coords["units"])
-            hdu.header.update('CUNIT2', coords["units"])
-            hdu.header.update('CRPIX1', 0.5*(nx+1))
-            hdu.header.update('CRPIX2', 0.5*(ny+1))
-            hdu.header.update('CRVAL1', coords["xctr"])
-            hdu.header.update('CRVAL2', coords["yctr"])
-            hdu.header.update('CDELT1', coords["dx"])
-            hdu.header.update('CDELT2', coords["dy"])
-            # These are the defaults, but will get overwritten if
-            # the caller has specified them
-            hdu.header.update('CTYPE1', "LINEAR")
-            hdu.header.update('CTYPE2', "LINEAR")
-                                    
-        if other_keys is not None:
-            for k,v in other_keys.items():
-                hdu.header.update(k,v)
-
-        hdulist.append(hdu)
-
-    hdulist = pyfits.HDUList(hdulist)
-    hdulist.writeto(filename, clobber=clobber)                    
-
 def display_in_notebook(image, max_val=None):
     """
     A helper function to display images in an IPython notebook

diff -r a8187e5a33c033d3d85c6d24a5e4ba1433119fad -r b064eeba51bd197eff59ba0939e9210cf182940d yt/visualization/volume_rendering/image_handling.py
--- a/yt/visualization/volume_rendering/image_handling.py
+++ b/yt/visualization/volume_rendering/image_handling.py
@@ -23,6 +23,8 @@
     and saves to *fn*.  If *h5* is True, then it will save in hdf5 format.  If
     *fits* is True, it will save in fits format.
     """
+    if (not h5 and not fits) or (h5 and fits):
+        raise ValueError("Choose either HDF5 or FITS format!")
     if h5:
         f = h5py.File('%s.h5'%fn, "w")
         f.create_dataset("R", data=image[:,:,0])
@@ -31,17 +33,17 @@
         f.create_dataset("A", data=image[:,:,3])
         f.close()
     if fits:
-        try:
-            import pyfits
-        except ImportError:
-            mylog.error('You do not have pyfits, install before attempting to use fits exporter')
-            raise
-        hdur = pyfits.PrimaryHDU(image[:,:,0])
-        hdug = pyfits.ImageHDU(image[:,:,1])
-        hdub = pyfits.ImageHDU(image[:,:,2])
-        hdua = pyfits.ImageHDU(image[:,:,3])
-        hdulist = pyfits.HDUList([hdur,hdug,hdub,hdua])
-        hdulist.writeto('%s.fits'%fn,clobber=True)
+        from yt.utilities.fits_image import FITSImageBuffer
+        data = {}
+        data["r"] = image[:,:,0]
+        data["g"] = image[:,:,1]
+        data["b"] = image[:,:,2]
+        data["a"] = image[:,:,3]
+        nx, ny = data["r"].shape
+        fib = FITSImageBuffer(data, units="pixel",
+                              center=[0.5*(nx+1), 0.5*(ny+1)],
+                              scale=[1.]*2)
+        fib.writeto('%s.fits'%fn,clobber=True)
 
 def import_rgba(name, h5=True):
     """


https://bitbucket.org/yt_analysis/yt/commits/424f1dcfd229/
Changeset:   424f1dcfd229
Branch:      yt-3.0
User:        jzuhone
Date:        2013-12-04 21:56:41
Summary:     Merged in MatthewTurk/yt/yt-3.0 (pull request #667)

Adding particle IO to FLASH.
Affected #:  3 files

diff -r 184f8241a86a70fede034a9484e83eca05d0e9e3 -r 424f1dcfd229055851cb113b23742362f98f2885 yt/frontends/flash/data_structures.py
--- a/yt/frontends/flash/data_structures.py
+++ b/yt/frontends/flash/data_structures.py
@@ -72,9 +72,9 @@
 
     def _detect_fields(self):
         ncomp = self._handle["/unknown names"].shape[0]
-        self.field_list = [s for s in self._handle["/unknown names"][:].flat]
+        self.field_list = [("gas", s) for s in self._handle["/unknown names"][:].flat]
         if ("/particle names" in self._particle_handle) :
-            self.field_list += ["particle_" + s[0].strip() for s
+            self.field_list += [("io", "particle_" + s[0].strip()) for s
                                 in self._particle_handle["/particle names"][:]]
     
     def _setup_classes(self):
@@ -176,26 +176,6 @@
                 g.dds[1] = DD
         self.max_level = self.grid_levels.max()
 
-    def _setup_derived_fields(self):
-        super(FLASHHierarchy, self)._setup_derived_fields()
-        [self.parameter_file.conversion_factors[field] 
-         for field in self.field_list]
-        for field in self.field_list:
-            if field not in self.derived_field_list:
-                self.derived_field_list.append(field)
-            if (field not in KnownFLASHFields and
-                field.startswith("particle")) :
-                self.parameter_file.field_info.add_field(
-                        field, function=NullFunc, take_log=False,
-                        validators = [ValidateDataField(field)],
-                        particle_type=True)
-
-        for field in self.derived_field_list:
-            f = self.parameter_file.field_info[field]
-            if f._function.func_name == "_TranslationFunc":
-                # Translating an already-converted field
-                self.parameter_file.conversion_factors[field] = 1.0 
-                
 class FLASHStaticOutput(StaticOutput):
     _hierarchy_class = FLASHHierarchy
     _fieldinfo_fallback = FLASHFieldInfo

diff -r 184f8241a86a70fede034a9484e83eca05d0e9e3 -r 424f1dcfd229055851cb113b23742362f98f2885 yt/frontends/flash/fields.py
--- a/yt/frontends/flash/fields.py
+++ b/yt/frontends/flash/fields.py
@@ -20,11 +20,7 @@
     NullFunc, \
     TranslationFunc, \
     FieldInfo, \
-    ValidateParameter, \
-    ValidateDataField, \
-    ValidateProperty, \
-    ValidateSpatial, \
-    ValidateGridType
+    ValidateSpatial
 import yt.fields.universal_fields
 from yt.utilities.physical_constants import \
     kboltz, mh, Na
@@ -241,7 +237,6 @@
     if v not in KnownFLASHFields:
         pfield = v.startswith("particle")
         add_flash_field(v, function=NullFunc, take_log=False,
-                  validators = [ValidateDataField(v)],
                   particle_type = pfield)
     if f.endswith("_Fraction") :
         dname = "%s\/Fraction" % f.split("_")[0]

diff -r 184f8241a86a70fede034a9484e83eca05d0e9e3 -r 424f1dcfd229055851cb113b23742362f98f2885 yt/frontends/flash/io.py
--- a/yt/frontends/flash/io.py
+++ b/yt/frontends/flash/io.py
@@ -16,11 +16,25 @@
 import numpy as np
 import h5py
 from yt.utilities.math_utils import prec_accum
+from itertools import groupby
 
 from yt.utilities.io_handler import \
     BaseIOHandler
 from yt.utilities.logger import ytLogger as mylog
 
+# http://stackoverflow.com/questions/2361945/detecting-consecutive-integers-in-a-list
+def particle_sequences(grids):
+    g_iter = sorted(grids, key = lambda g: g.id)
+    for k, g in groupby(enumerate(g_iter), lambda (i,x):i-x.id):
+        seq = list(v[1] for v in g)
+        yield seq[0], seq[-1]
+
+def grid_sequences(grids):
+    g_iter = sorted(grids, key = lambda g: g.id)
+    for k, g in groupby(enumerate(g_iter), lambda (i,x):i-x.id):
+        seq = list(v[1] for v in g)
+        yield seq
+
 class IOHandlerFLASH(BaseIOHandler):
     _particle_reader = False
     _data_style = "flash_hdf5"
@@ -43,6 +57,49 @@
             count_list, conv_factors):
         pass
 
+    def _read_particle_coords(self, chunks, ptf):
+        chunks = list(chunks)
+        f_part = self._particle_handle
+        p_ind = self.pf.h._particle_indices
+        px, py, pz = (self._particle_fields["particle_pos%s" % ax]
+                      for ax in 'xyz')
+        p_fields = f_part["/tracer particles"]
+        assert(len(ptf) == 1)
+        ptype = ptf.keys()[0]
+        for chunk in chunks:
+            start = end = None
+            for g1, g2 in particle_sequences(chunk.objs):
+                start = p_ind[g1.id - g1._id_offset]
+                end = p_ind[g2.id - g2._id_offset + 1]
+                x = p_fields[start:end, px]
+                y = p_fields[start:end, py]
+                z = p_fields[start:end, pz]
+                yield ptype, (x, y, z)
+
+    def _read_particle_fields(self, chunks, ptf, selector):
+        chunks = list(chunks)
+        f_part = self._particle_handle
+        p_ind = self.pf.h._particle_indices
+        px, py, pz = (self._particle_fields["particle_pos%s" % ax]
+                      for ax in 'xyz')
+        p_fields = f_part["/tracer particles"]
+        assert(len(ptf) == 1)
+        ptype = ptf.keys()[0]
+        field_list = ptf[ptype]
+        for chunk in chunks:
+            for g1, g2 in particle_sequences(chunk.objs):
+                start = p_ind[g1.id - g1._id_offset]
+                end = p_ind[g2.id - g2._id_offset + 1]
+                x = p_fields[start:end, px]
+                y = p_fields[start:end, py]
+                z = p_fields[start:end, pz]
+                mask = selector.select_points(x, y, z)
+                if mask is None: continue
+                for field in field_list:
+                    fi = self._particle_fields[field]
+                    data = p_fields[start:end, fi]
+                    yield (ptype, field), data[mask]
+
     def _read_data_set(self, grid, field):
         f = self._handle
         f_part = self._particle_handle
@@ -72,7 +129,6 @@
         for field in fields:
             ftype, fname = field
             dt = f["/%s" % fname].dtype
-            dt = prec_accum[dt]
             if dt == "float32": dt = "float64"
             rv[field] = np.empty(size, dtype=dt)
         ng = sum(len(c.objs) for c in chunks)
@@ -83,9 +139,12 @@
             ds = f["/%s" % fname]
             ind = 0
             for chunk in chunks:
-                for g in chunk.objs:
-                    data = ds[g.id - g._id_offset,:,:,:].transpose()
-                    ind += g.select(selector, data, rv[field], ind) # caches
+                for gs in grid_sequences(chunk.objs):
+                    start = gs[0].id - gs[0]._id_offset
+                    end = gs[-1].id - gs[-1]._id_offset + 1
+                    data = ds[start:end,:,:,:].transpose()
+                    for i, g in enumerate(gs):
+                        ind += g.select(selector, data[...,i], rv[field], ind)
         return rv
 
     def _read_chunk_data(self, chunk, fields):
@@ -97,8 +156,11 @@
             ftype, fname = field
             ds = f["/%s" % fname]
             ind = 0
-            for g in chunk.objs:
-                data = ds[g.id - g._id_offset,:,:,:].transpose()
-                rv[g.id][field] = data
+            for gs in grid_sequences(chunk.objs):
+                start = gs[0].id - gs[0]._id_offset
+                end = gs[-1].id - gs[-1]._id_offset + 1
+                data = ds[start:end,:,:,:].transpose()
+                for i, g in enumerate(gs):
+                    rv[g.id][field] = data[...,i]
         return rv

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.



More information about the yt-svn mailing list