[yt-svn] commit/yt: ngoldbaum: Merged in brittonsmith/yt (pull request #2460)
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Thu Dec 8 07:29:07 PST 2016
1 new commit in yt:
https://bitbucket.org/yt_analysis/yt/commits/83d2f4c6b660/
Changeset: 83d2f4c6b660
Branch: yt
User: ngoldbaum
Date: 2016-12-08 15:28:40+00:00
Summary: Merged in brittonsmith/yt (pull request #2460)
Save light ray solution to dataset
Affected #: 5 files
diff -r 8b1d0f70261c0a562f0f1115ac9fc379b4ed22cf -r 83d2f4c6b6608ea701acf74d326c44125ae93b5e yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
--- a/yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
+++ b/yt/analysis_modules/cosmological_observation/light_ray/light_ray.py
@@ -673,6 +673,25 @@
ds["hubble_constant"] = \
ds["hubble_constant"].to("100*km/(Mpc*s)").d
extra_attrs = {"data_type": "yt_light_ray"}
+
+ # save the light ray solution
+ if len(self.light_ray_solution) > 0:
+ # Convert everything to base unit system now to avoid
+ # problems with different units for each ds.
+ for s in self.light_ray_solution:
+ for f in s:
+ if isinstance(s[f], YTArray):
+ s[f].convert_to_base()
+ for key in self.light_ray_solution[0]:
+ if key in ["next", "previous"]:
+ continue
+ lrsa = [sol[key] for sol in self.light_ray_solution]
+ if isinstance(lrsa[-1], YTArray):
+ to_arr = YTArray
+ else:
+ to_arr = np.array
+ extra_attrs["light_ray_solution_%s" % key] = to_arr(lrsa)
+
field_types = dict([(field, "grid") for field in data.keys()])
# Only return LightRay elements with non-zero density
diff -r 8b1d0f70261c0a562f0f1115ac9fc379b4ed22cf -r 83d2f4c6b6608ea701acf74d326c44125ae93b5e yt/analysis_modules/cosmological_observation/light_ray/tests/test_light_ray.py
--- a/yt/analysis_modules/cosmological_observation/light_ray/tests/test_light_ray.py
+++ b/yt/analysis_modules/cosmological_observation/light_ray/tests/test_light_ray.py
@@ -12,7 +12,10 @@
import numpy as np
+from yt.convenience import \
+ load
from yt.testing import \
+ assert_array_equal, \
requires_file
from yt.analysis_modules.cosmological_observation.api import LightRay
import os
@@ -23,6 +26,19 @@
COSMO_PLUS = "enzo_cosmology_plus/AMRCosmology.enzo"
COSMO_PLUS_SINGLE = "enzo_cosmology_plus/RD0009/RD0009"
+def compare_light_ray_solutions(lr1, lr2):
+    """Assert that lr1 and lr2 carry matching light ray solutions."""
+    sols1 = lr1.light_ray_solution
+    sols2 = lr2.light_ray_solution
+    assert len(sols1) == len(sols2)
+    for sol1, sol2 in zip(sols1, sols2):
+        # "next"/"previous" are not written to the saved dataset; skip them.
+        for name in (k for k in sol1 if k not in ("next", "previous")):
+            if not isinstance(sol1[name], np.ndarray):
+                assert sol1[name] == sol2[name]
+                continue
+            assert_array_equal(sol1[name], sol2[name])
+
@requires_file(COSMO_PLUS)
def test_light_ray_cosmo():
"""
@@ -39,6 +55,9 @@
fields=['temperature', 'density', 'H_number_density'],
data_filename='lightray.h5')
+ ds = load('lightray.h5')
+ compare_light_ray_solutions(lr, ds)
+
# clean up
os.chdir(curdir)
shutil.rmtree(tmpdir)
@@ -62,6 +81,9 @@
fields=['temperature', 'density', 'H_number_density'],
data_filename='lightray.h5')
+ ds = load('lightray.h5')
+ compare_light_ray_solutions(lr, ds)
+
# clean up
os.chdir(curdir)
shutil.rmtree(tmpdir)
@@ -82,6 +104,9 @@
fields=['temperature', 'density', 'H_number_density'],
data_filename='lightray.h5')
+ ds = load('lightray.h5')
+ compare_light_ray_solutions(lr, ds)
+
# clean up
os.chdir(curdir)
shutil.rmtree(tmpdir)
@@ -105,6 +130,9 @@
fields=['temperature', 'density', 'H_number_density'],
data_filename='lightray.h5')
+ ds = load('lightray.h5')
+ compare_light_ray_solutions(lr, ds)
+
# clean up
os.chdir(curdir)
shutil.rmtree(tmpdir)
@@ -130,6 +158,9 @@
fields=['temperature', 'density', 'H_number_density'],
data_filename='lightray.h5')
+ ds = load('lightray.h5')
+ compare_light_ray_solutions(lr, ds)
+
# clean up
os.chdir(curdir)
shutil.rmtree(tmpdir)
diff -r 8b1d0f70261c0a562f0f1115ac9fc379b4ed22cf -r 83d2f4c6b6608ea701acf74d326c44125ae93b5e yt/frontends/ytdata/data_structures.py
--- a/yt/frontends/ytdata/data_structures.py
+++ b/yt/frontends/ytdata/data_structures.py
@@ -218,13 +218,58 @@
cont_type = parse_h5_attr(f, "container_type")
if data_type is None:
return False
- if data_type in ["yt_light_ray"]:
- return True
if data_type == "yt_data_container" and \
cont_type not in _grid_data_containers:
return True
return False
+class YTDataLightRayDataset(YTDataContainerDataset):
+    """Dataset for saved LightRay objects."""
+
+    def _parse_parameter_file(self):
+        """Parse the base parameters, then rebuild light_ray_solution."""
+        super(YTDataLightRayDataset, self)._parse_parameter_file()
+        self._restore_light_ray_solution()
+
+    def _restore_light_ray_solution(self):
+        """
+        Restore all information associated with the light ray solution
+        to its original form: a list with one dict per ray segment, built
+        from the "light_ray_solution_<name>" entries in self.parameters.
+        Entries with a companion "<entry>_units" parameter are restored as
+        unit-ful arrays/quantities. Leaves an empty list if none are found.
+        """
+        key = "light_ray_solution"
+        prefix = key + "_"
+        self.light_ray_solution = []
+        # Match on the prefix (the field name is sliced off it below) and
+        # leave out the companion "<entry>_units" parameters.
+        lrs_fields = [par for par in self.parameters
+                      if par.startswith(prefix) and
+                      not par.endswith("_units")]
+        if not lrs_fields:
+            return
+        self.light_ray_solution = \
+            [{} for _ in self.parameters[lrs_fields[0]]]
+        # Unicode arrays are written as byte strings ('|S'); convert back.
+        for sp3 in ["unique_identifier", "filename"]:
+            ksp3 = prefix + sp3
+            if ksp3 in lrs_fields:
+                self.parameters[ksp3] = self.parameters[ksp3].astype(str)
+        for field in lrs_fields:
+            field_name = field[len(prefix):]
+            values = self.parameters[field]
+            units_key = "%s_units" % field
+            has_units = units_key in self.parameters
+            # Rows of a multi-dimensional entry are arrays; otherwise scalars.
+            to_val = self.arr if len(values.shape) > 1 else self.quan
+            for i in range(values.shape[0]):
+                val = values[i]
+                if has_units:
+                    val = to_val(val, self.parameters[units_key])
+                self.light_ray_solution[i][field_name] = val
+
+    @classmethod
+    def _is_valid(cls, *args, **kwargs):
+        """Return True for .h5 files whose data_type is "yt_light_ray"."""
+        if not args[0].endswith(".h5"):
+            return False
+        with h5py.File(args[0], "r") as f:
+            data_type = parse_h5_attr(f, "data_type")
+        return data_type in ["yt_light_ray"]
+
class YTSpatialPlotDataset(YTDataContainerDataset):
"""Dataset for saved slices and projections."""
_field_info_class = YTGridFieldInfo
diff -r 8b1d0f70261c0a562f0f1115ac9fc379b4ed22cf -r 83d2f4c6b6608ea701acf74d326c44125ae93b5e yt/frontends/ytdata/utilities.py
--- a/yt/frontends/ytdata/utilities.py
+++ b/yt/frontends/ytdata/utilities.py
@@ -232,5 +232,5 @@
if iterable(val):
val = np.array(val)
if val.dtype.kind == 'U':
- val = val.astype('|S40')
+ val = val.astype('|S')
fh.attrs[str(attr)] = val
diff -r 8b1d0f70261c0a562f0f1115ac9fc379b4ed22cf -r 83d2f4c6b6608ea701acf74d326c44125ae93b5e yt/visualization/plot_modifications.py
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -2286,8 +2286,8 @@
for ray_ds in self.ray.light_ray_solution:
if ray_ds['unique_identifier'] == plot.ds.unique_identifier:
- start_coord = ray_ds['start']
- end_coord = ray_ds['end']
+ start_coord = plot.ds.arr(ray_ds['start'])
+ end_coord = plot.ds.arr(ray_ds['end'])
return (start_coord, end_coord)
# if no intersection between the plotted dataset and the LightRay
# return a false tuple to pass to start_coord
@@ -2317,9 +2317,11 @@
# if possible, break periodic ray into non-periodic
# segments and add each of them individually
if any(plot.ds.periodicity):
- segments = periodic_ray(start_coord, end_coord,
- left=plot.ds.domain_left_edge,
- right=plot.ds.domain_right_edge)
+ segments = periodic_ray(
+ start_coord.to("code_length"),
+ end_coord.to("code_length"),
+ left=plot.ds.domain_left_edge.to("code_length"),
+ right=plot.ds.domain_right_edge.to("code_length"))
else:
segments = [[start_coord, end_coord]]
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled and this email is addressed
to you.
More information about the yt-svn
mailing list