[Yt-svn] yt-commit r1062 - in branches/yt-object-serialization: . scripts yt yt/fido yt/lagos yt/lagos/hop yt/raven
mturk at wrangler.dreamhost.com
mturk at wrangler.dreamhost.com
Tue Dec 30 06:31:32 PST 2008
Author: mturk
Date: Tue Dec 30 06:31:31 2008
New Revision: 1062
URL: http://yt.spacepope.org/changeset/1062
Log:
Merge from trunk down to object-serialization
Added:
branches/yt-object-serialization/scripts/yt_lodgeit.py
branches/yt-object-serialization/yt/cmdln.py
branches/yt-object-serialization/yt/convenience.py
Removed:
branches/yt-object-serialization/scripts/yt
Modified:
branches/yt-object-serialization/ (props changed)
branches/yt-object-serialization/setup.py
branches/yt-object-serialization/yt/commands.py
branches/yt-object-serialization/yt/fido/ParameterFileStorage.py
branches/yt-object-serialization/yt/funcs.py
branches/yt-object-serialization/yt/lagos/BaseDataTypes.py
branches/yt-object-serialization/yt/lagos/DerivedQuantities.py
branches/yt-object-serialization/yt/lagos/HierarchyType.py
branches/yt-object-serialization/yt/lagos/OutputTypes.py
branches/yt-object-serialization/yt/lagos/ParallelTools.py
branches/yt-object-serialization/yt/lagos/Profiles.py
branches/yt-object-serialization/yt/lagos/hop/SS_HopOutput.py
branches/yt-object-serialization/yt/mods.py
branches/yt-object-serialization/yt/raven/Callbacks.py
Added: branches/yt-object-serialization/scripts/yt_lodgeit.py
==============================================================================
--- (empty file)
+++ branches/yt-object-serialization/scripts/yt_lodgeit.py Tue Dec 30 06:31:31 2008
@@ -0,0 +1,319 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+ LodgeIt!
+ ~~~~~~~~
+
+ A script that pastes stuff into the enzotools pastebin on
+ paste.enzotools.org.
+
+ Modified (very, very slightly) from the original script by the authors
+ below.
+
+ .lodgeitrc / _lodgeitrc
+ -----------------------
+
+ Under UNIX create a file called ``~/.lodgeitrc``, under Windows
+ create a file ``%APPDATA%/_lodgeitrc`` to override defaults::
+
+ language=default_language
+ clipboard=true/false
+ open_browser=true/false
+ encoding=fallback_charset
+
+ :authors: 2007-2008 Georg Brandl <georg at python.org>,
+ 2006 Armin Ronacher <armin.ronacher at active-4.com>,
+ 2006 Matt Good <matt at matt-good.net>,
+ 2005 Raphael Slinckx <raphael at slinckx.net>
+"""
+import os
+import sys
+from optparse import OptionParser
+
+
+# Name the script was invoked as; shown in the --version output.
+SCRIPT_NAME = os.path.basename(sys.argv[0])
+VERSION = '0.3'
+# Base URL of the pastebin; the XML-RPC endpoint and the paste URLs are
+# derived from it.
+SERVICE_URL = 'http://paste.enzotools.org/'
+SETTING_KEYS = ['author', 'title', 'language', 'private', 'clipboard',
+ 'open_browser']
+
+# global server proxy, created on first use by get_xmlrpc_service()
+_xmlrpc_service = None
+
+
+def fail(msg, code):
+ """Bail out with an error message."""
+ print >> sys.stderr, 'ERROR: %s' % msg
+ sys.exit(code)
+
+
+def load_default_settings():
+ """Load the defaults from the lodgeitrc file."""
+ settings = {
+ 'language': None,
+ 'clipboard': True,
+ 'open_browser': False,
+ 'encoding': 'iso-8859-15'
+ }
+ rcfile = None
+ if os.name == 'posix':
+ rcfile = os.path.expanduser('~/.lodgeitrc')
+ elif os.name == 'nt' and 'APPDATA' in os.environ:
+ rcfile = os.path.expandvars(r'$APPDATA\_lodgeitrc')
+ if rcfile:
+ try:
+ f = open(rcfile)
+ for line in f:
+ if line.strip()[:1] in '#;':
+ continue
+ p = line.split('=', 1)
+ if len(p) == 2:
+ key = p[0].strip().lower()
+ if key in settings:
+ if key in ('clipboard', 'open_browser'):
+ settings[key] = p[1].strip().lower() in \
+ ('true', '1', 'on', 'yes')
+ else:
+ settings[key] = p[1].strip()
+ f.close()
+ except IOError:
+ pass
+ settings['tags'] = []
+ settings['title'] = None
+ return settings
+
+
+def make_utf8(text, encoding):
+ """Convert a text to UTF-8, brute-force."""
+ try:
+ u = unicode(text, 'utf-8')
+ uenc = 'utf-8'
+ except UnicodeError:
+ try:
+ u = unicode(text, encoding)
+ uenc = 'utf-8'
+ except UnicodeError:
+ u = unicode(text, 'iso-8859-15', 'ignore')
+ uenc = 'iso-8859-15'
+ try:
+ import chardet
+ except ImportError:
+ return u.encode('utf-8')
+ d = chardet.detect(text)
+ if d['encoding'] == uenc:
+ return u.encode('utf-8')
+ return unicode(text, d['encoding'], 'ignore').encode('utf-8')
+
+
+def get_xmlrpc_service():
+ """Create the XMLRPC server proxy and cache it."""
+ global _xmlrpc_service
+ import xmlrpclib
+ if _xmlrpc_service is None:
+ try:
+ _xmlrpc_service = xmlrpclib.ServerProxy(SERVICE_URL + 'xmlrpc/',
+ allow_none=True)
+ except Exception, err:
+ fail('Could not connect to Pastebin: %s' % err, -1)
+ return _xmlrpc_service
+
+
+def copy_url(url):
+ """Copy the url into the clipboard.
+
+ Backends are tried in order: win32clipboard, the ``pbcopy`` command,
+ then the GTK 2 clipboard. If none of them is available the function
+ returns without doing anything.
+ """
+ # try windows first
+ try:
+ import win32clipboard
+ except ImportError:
+ # then give pbcopy a try. do that before gtk because
+ # gtk might be installed on os x but nobody is interested
+ # in the X11 clipboard there.
+ from subprocess import Popen, PIPE
+ try:
+ client = Popen(['pbcopy'], stdin=PIPE)
+ except OSError:
+ # pbcopy could not be started; fall back to GTK
+ try:
+ import pygtk
+ pygtk.require('2.0')
+ import gtk
+ import gobject
+ except ImportError:
+ # no clipboard backend at all: give up silently
+ return
+ gtk.clipboard_get(gtk.gdk.SELECTION_CLIPBOARD).set_text(url)
+ # schedule main_quit as an idle callback, then enter the main loop
+ gobject.idle_add(gtk.main_quit)
+ gtk.main()
+ else:
+ # pbcopy started: feed it the url on stdin and wait for it to exit
+ client.stdin.write(url)
+ client.stdin.close()
+ client.wait()
+ else:
+ # win32clipboard imported fine
+ win32clipboard.OpenClipboard()
+ win32clipboard.EmptyClipboard()
+ win32clipboard.SetClipboardText(url)
+ win32clipboard.CloseClipboard()
+
+
+def open_webbrowser(url):
+ """Open a new browser window."""
+ import webbrowser
+ webbrowser.open(url)
+
+
+def language_exists(language):
+ """Check if a language alias exists."""
+ xmlrpc = get_xmlrpc_service()
+ langs = xmlrpc.pastes.getLanguages()
+ return language in langs
+
+
+def get_mimetype(data, filename):
+ """Try to get MIME type from data."""
+ try:
+ import gnomevfs
+ except ImportError:
+ from mimetypes import guess_type
+ if filename:
+ return guess_type(filename)[0]
+ else:
+ if filename:
+ return gnomevfs.get_mime_type(os.path.abspath(filename))
+ return gnomevfs.get_mime_type_for_data(data)
+
+
+def print_languages():
+ """Print a list of all supported languages, with description."""
+ xmlrpc = get_xmlrpc_service()
+ languages = xmlrpc.pastes.getLanguages().items()
+ languages.sort(lambda a, b: cmp(a[1].lower(), b[1].lower()))
+ print 'Supported Languages:'
+ for alias, name in languages:
+ print ' %-30s%s' % (alias, name)
+
+
+def download_paste(uid):
+ """Download a paste given by ID."""
+ xmlrpc = get_xmlrpc_service()
+ paste = xmlrpc.pastes.getPaste(uid)
+ if not paste:
+ fail('Paste "%s" does not exist.' % uid, 5)
+ print paste['code'].encode('utf-8')
+
+
+def create_paste(code, language, filename, mimetype, private):
+ """Create a new paste."""
+ xmlrpc = get_xmlrpc_service()
+ rv = xmlrpc.pastes.newPaste(language, code, None, filename, mimetype,
+ private)
+ if not rv:
+ fail('Could not create paste. Something went wrong '
+ 'on the server side.', 4)
+ return rv
+
+
+def compile_paste(filenames, langopt):
+ """Create a single paste out of zero, one or multiple files."""
+ def read_file(f):
+ try:
+ return f.read()
+ finally:
+ f.close()
+ mime = ''
+ lang = langopt or ''
+ if not filenames:
+ data = read_file(sys.stdin)
+ if not langopt:
+ mime = get_mimetype(data, '') or ''
+ elif len(filenames) == 1:
+ fname = filenames[0]
+ data = read_file(open(filenames[0], 'rb'))
+ if not langopt:
+ mime = get_mimetype(data, filenames[0]) or ''
+ else:
+ result = []
+ for fname in filenames:
+ data = read_file(open(fname, 'rb'))
+ if langopt:
+ result.append('### %s [%s]\n\n' % (fname, langopt))
+ else:
+ result.append('### %s\n\n' % fname)
+ result.append(data)
+ result.append('\n\n')
+ data = ''.join(result)
+ lang = 'multi'
+ return data, lang, fname, mime
+
+
+def main():
+ """Main script entry point.
+
+ Parses the command line (defaults come from the lodgeitrc file), then
+ either prints the version, lists the supported languages, downloads a
+ paste, or uploads the given files / standard input as a new paste and
+ prints the resulting URL.
+ """
+
+ usage = ('Usage: %%prog [options] [FILE ...]\n\n'
+ 'Read the files and paste their contents to %s.\n'
+ 'If no file is given, read from standard input.\n'
+ 'If multiple files are given, they are put into a single paste.'
+ % SERVICE_URL)
+ parser = OptionParser(usage=usage)
+
+ # rc-file values become the option defaults below
+ settings = load_default_settings()
+
+ parser.add_option('-v', '--version', action='store_true',
+ help='Print script version')
+ parser.add_option('-L', '--languages', action='store_true', default=False,
+ help='Retrieve a list of supported languages')
+ parser.add_option('-l', '--language', default=settings['language'],
+ help='Used syntax highlighter for the file')
+ parser.add_option('-e', '--encoding', default=settings['encoding'],
+ help='Specify the encoding of a file (default is '
+ 'utf-8 or guessing if available)')
+ parser.add_option('-b', '--open-browser', dest='open_browser',
+ action='store_true',
+ default=settings['open_browser'],
+ help='Open the paste in a web browser')
+ parser.add_option('-p', '--private', action='store_true', default=False,
+ help='Paste as private')
+ parser.add_option('--no-clipboard', dest='clipboard',
+ action='store_false',
+ default=settings['clipboard'],
+ help="Don't copy the url into the clipboard")
+ parser.add_option('--download', metavar='UID',
+ help='Download a given paste')
+
+ opts, args = parser.parse_args()
+
+ # special modes of operation:
+ # - paste script version
+ if opts.version:
+ print '%s: version %s' % (SCRIPT_NAME, VERSION)
+ sys.exit()
+ # - print list of languages
+ elif opts.languages:
+ print_languages()
+ sys.exit()
+ # - download Paste
+ elif opts.download:
+ download_paste(opts.download)
+ sys.exit()
+
+ # check language if given
+ if opts.language and not language_exists(opts.language):
+ fail('Language %s is not supported.' % opts.language, 3)
+
+ # load file(s)
+ try:
+ data, language, filename, mimetype = compile_paste(args, opts.language)
+ except Exception, err:
+ fail('Error while reading the file(s): %s' % err, 2)
+ if not data:
+ fail('Aborted, no content to paste.', 4)
+
+ # create paste
+ # the content is normalised to UTF-8 before it is sent
+ code = make_utf8(data, opts.encoding)
+ pid = create_paste(code, language, filename, mimetype, opts.private)
+ url = '%sshow/%s/' % (SERVICE_URL, pid)
+ print url
+ if opts.open_browser:
+ open_webbrowser(url)
+ if opts.clipboard:
+ copy_url(url)
+
+
+# Run only when executed as a script. main() has no return statement, so a
+# normal run exits through sys.exit(None).
+if __name__ == '__main__':
+ sys.exit(main())
Modified: branches/yt-object-serialization/setup.py
==============================================================================
--- branches/yt-object-serialization/setup.py (original)
+++ branches/yt-object-serialization/setup.py Tue Dec 30 06:31:31 2008
@@ -56,10 +56,8 @@
'storage' : ['tables'],
'pdf' : ['pypdf']},
entry_points = { 'console_scripts' : [
- 'yt_timeseries = yt.commands:timeseries',
- 'yt_zoomin = yt.commands:zoomin',
- 'yt_hop = yt.commands:hop_single']
- },
+ 'yt = yt.commands:run_main',
+ ]},
author="Matthew J. Turk",
author_email="matthewturk at gmail.com",
url = "http://yt.enzotools.org/",
Added: branches/yt-object-serialization/yt/cmdln.py
==============================================================================
--- (empty file)
+++ branches/yt-object-serialization/yt/cmdln.py Tue Dec 30 06:31:31 2008
@@ -0,0 +1,1586 @@
+#!/usr/bin/env python
+# Copyright (c) 2002-2007 ActiveState Software Inc.
+# License: MIT (see LICENSE.txt for license details)
+# Author: Trent Mick
+# Home: http://trentm.com/projects/cmdln/
+
+"""An improvement on Python's standard cmd.py module.
+
+As with cmd.py, this module provides "a simple framework for writing
+line-oriented command interpreters." This module provides a 'RawCmdln'
+class that fixes some design flaws in cmd.Cmd, making it more scalable
+and nicer to use for good 'cvs'- or 'svn'-style command line interfaces
+or simple shells. And it provides a 'Cmdln' class that adds
+optparse-based option processing. Basically you use it like this:
+
+ import cmdln
+
+ class MySVN(cmdln.Cmdln):
+ name = "svn"
+
+ @cmdln.alias('stat', 'st')
+ @cmdln.option('-v', '--verbose', action='store_true',
+ help='print verbose information')
+ def do_status(self, subcmd, opts, *paths):
+ print "handle 'svn status' command"
+
+ #...
+
+ if __name__ == "__main__":
+ shell = MySVN()
+ retval = shell.main()
+ sys.exit(retval)
+
+See the README.txt or <http://trentm.com/projects/cmdln/> for more
+details.
+"""
+
+# Module version as a tuple and as the matching dotted string.
+__version_info__ = (1, 1, 1)
+__version__ = '.'.join(map(str, __version_info__))
+
+import os
+import sys
+import re
+import cmd
+import optparse
+from pprint import pprint
+import sys
+
+
+
+
+#---- globals
+
+# Values accepted by the 'loop' argument of RawCmdln.main().
+LOOP_ALWAYS, LOOP_NEVER, LOOP_IF_EMPTY = range(3)
+
+# An unspecified optional argument when None is a meaningful value.
+_NOT_SPECIFIED = ("Not", "Specified")
+
+# Pattern to match a TypeError message from a call that
+# failed because of incorrect number of arguments (see
+# Python/getargs.c).
+# Groups 2 and 4 capture the two argument counts in the message: the
+# number the callable takes and the number that was given.
+_INCORRECT_NUM_ARGS_RE = re.compile(
+ r"(takes [\w ]+ )(\d+)( arguments? \()(\d+)( given\))")
+
+
+
+#---- exceptions
+
+class CmdlnError(Exception):
+ """A cmdln.py usage error."""
+ def __init__(self, msg):
+ self.msg = msg
+ def __str__(self):
+ return self.msg
+
+class CmdlnUserError(Exception):
+ """An error by a user of a cmdln-based tool/shell."""
+ pass
+
+
+
+#---- public methods and classes
+
+def alias(*aliases):
+ """Decorator to add aliases for Cmdln.do_* command handlers.
+
+ Example:
+ class MyShell(cmdln.Cmdln):
+ @cmdln.alias("!", "sh")
+ def do_shell(self, argv):
+ #...implement 'shell' command
+ """
+ def decorate(f):
+ if not hasattr(f, "aliases"):
+ f.aliases = []
+ f.aliases += aliases
+ return f
+ return decorate
+
+
+class RawCmdln(cmd.Cmd):
+ """An improved (on cmd.Cmd) framework for building multi-subcommand
+ scripts (think "svn" & "cvs") and simple shells (think "pdb" and
+ "gdb").
+
+ A simple example:
+
+ import cmdln
+
+ class MySVN(cmdln.RawCmdln):
+ name = "svn"
+
+ @cmdln.alias('stat', 'st')
+ def do_status(self, argv):
+ print "handle 'svn status' command"
+
+ if __name__ == "__main__":
+ shell = MySVN()
+ retval = shell.main()
+ sys.exit(retval)
+
+ See <http://trentm.com/projects/cmdln> for more information.
+ """
+ name = None # if unset, defaults basename(sys.argv[0])
+ prompt = None # if unset, defaults to self.name+"> "
+ version = None # if set, default top-level options include --version
+
+ # Default messages for some 'help' command error cases.
+ # They are interpolated with one arg: the command.
+ nohelp = "no help on '%s'"
+ unknowncmd = "unknown command: '%s'"
+
+ helpindent = '' # string with which to indent help output
+
+ def __init__(self, completekey='tab',
+ stdin=None, stdout=None, stderr=None):
+ """Cmdln(completekey='tab', stdin=None, stdout=None, stderr=None)
+
+ The optional argument 'completekey' is the readline name of a
+ completion key; it defaults to the Tab key. If completekey is
+ not None and the readline module is available, command completion
+ is done automatically.
+
+ The optional arguments 'stdin', 'stdout' and 'stderr' specify
+ alternate input, output and error output file objects; if not
+ specified, sys.* are used.
+
+ If 'stdout' but not 'stderr' is specified, stdout is used for
+ error output. This is to provide least surprise for users used
+ to only the 'stdin' and 'stdout' options with cmd.Cmd.
+ """
+ import sys
+ if self.name is None:
+ self.name = os.path.basename(sys.argv[0])
+ if self.prompt is None:
+ self.prompt = self.name+"> "
+ self._name_str = self._str(self.name)
+ self._prompt_str = self._str(self.prompt)
+ if stdin is not None:
+ self.stdin = stdin
+ else:
+ self.stdin = sys.stdin
+ if stdout is not None:
+ self.stdout = stdout
+ else:
+ self.stdout = sys.stdout
+ if stderr is not None:
+ self.stderr = stderr
+ elif stdout is not None:
+ self.stderr = stdout
+ else:
+ self.stderr = sys.stderr
+ self.cmdqueue = []
+ self.completekey = completekey
+ self.cmdlooping = False
+
+ def get_optparser(self):
+ """Hook for subclasses to set the option parser for the
+ top-level command/shell.
+
+ This option parser is used retrieved and used by `.main()' to
+ handle top-level options.
+
+ The default implements a single '-h|--help' option. Sub-classes
+ can return None to have no options at the top-level. Typically
+ an instance of CmdlnOptionParser should be returned.
+ """
+ version = (self.version is not None
+ and "%s %s" % (self._name_str, self.version)
+ or None)
+ return CmdlnOptionParser(self, version=version)
+
+ def postoptparse(self):
+ """Hook method executed just after `.main()' parses top-level
+ options.
+
+ When called `self.options' holds the results of the option parse.
+ """
+ pass
+
+ def main(self, argv=None, loop=LOOP_NEVER):
+ """A possible mainline handler for a script, like so:
+
+ import cmdln
+ class MyCmd(cmdln.Cmdln):
+ name = "mycmd"
+ ...
+
+ if __name__ == "__main__":
+ MyCmd().main()
+
+ By default this will use sys.argv to issue a single command to
+ 'MyCmd', then exit. The 'loop' argument can be use to control
+ interactive shell behaviour.
+
+ Arguments:
+ "argv" (optional, default sys.argv) is the command to run.
+ It must be a sequence, where the first element is the
+ command name and subsequent elements the args for that
+ command.
+ "loop" (optional, default LOOP_NEVER) is a constant
+ indicating if a command loop should be started (i.e. an
+ interactive shell). Valid values (constants on this module):
+ LOOP_ALWAYS start loop and run "argv", if any
+ LOOP_NEVER run "argv" (or .emptyline()) and exit
+ LOOP_IF_EMPTY run "argv", if given, and exit;
+ otherwise, start loop
+ """
+ if argv is None:
+ import sys
+ argv = sys.argv
+ else:
+ argv = argv[:] # don't modify caller's list
+
+ self.optparser = self.get_optparser()
+ if self.optparser: # i.e. optparser=None means don't process for opts
+ try:
+ self.options, args = self.optparser.parse_args(argv[1:])
+ except CmdlnUserError, ex:
+ msg = "%s: %s\nTry '%s help' for info.\n"\
+ % (self.name, ex, self.name)
+ self.stderr.write(self._str(msg))
+ self.stderr.flush()
+ return 1
+ except StopOptionProcessing, ex:
+ return 0
+ else:
+ self.options, args = None, argv[1:]
+ self.postoptparse()
+
+ if loop == LOOP_ALWAYS:
+ if args:
+ self.cmdqueue.append(args)
+ return self.cmdloop()
+ elif loop == LOOP_NEVER:
+ if args:
+ return self.cmd(args)
+ else:
+ return self.emptyline()
+ elif loop == LOOP_IF_EMPTY:
+ if args:
+ return self.cmd(args)
+ else:
+ return self.cmdloop()
+
+ def cmd(self, argv):
+ """Run one command and exit.
+
+ "argv" is the arglist for the command to run. argv[0] is the
+ command to run. If argv is an empty list then the
+ 'emptyline' handler is run.
+
+ Returns the return value from the command handler.
+ """
+ assert (isinstance(argv, (list, tuple)),
+ "'argv' is not a sequence: %r" % argv)
+ retval = None
+ try:
+ argv = self.precmd(argv)
+ retval = self.onecmd(argv)
+ self.postcmd(argv)
+ except:
+ if not self.cmdexc(argv):
+ raise
+ retval = 1
+ return retval
+
+ def _str(self, s):
+ """Safely convert the given str/unicode to a string for printing."""
+ try:
+ return str(s)
+ except UnicodeError:
+ #XXX What is the proper encoding to use here? 'utf-8' seems
+ # to work better than "getdefaultencoding" (usually
+ # 'ascii'), on OS X at least.
+ #import sys
+ #return s.encode(sys.getdefaultencoding(), "replace")
+ return s.encode("utf-8", "replace")
+
+ def cmdloop(self, intro=None):
+ """Repeatedly issue a prompt, accept input, parse into an argv, and
+ dispatch (via .precmd(), .onecmd() and .postcmd()), passing them
+ the argv. In other words, start a shell.
+
+ "intro" (optional) is a introductory message to print when
+ starting the command loop. This overrides the class
+ "intro" attribute, if any.
+ """
+ self.cmdlooping = True
+ self.preloop()
+ if self.use_rawinput and self.completekey:
+ try:
+ import readline
+ self.old_completer = readline.get_completer()
+ readline.set_completer(self.complete)
+ readline.parse_and_bind(self.completekey+": complete")
+ except ImportError:
+ pass
+ try:
+ if intro is None:
+ intro = self.intro
+ if intro:
+ intro_str = self._str(intro)
+ self.stdout.write(intro_str+'\n')
+ self.stop = False
+ retval = None
+ while not self.stop:
+ if self.cmdqueue:
+ argv = self.cmdqueue.pop(0)
+ assert (isinstance(argv, (list, tuple)),
+ "item on 'cmdqueue' is not a sequence: %r" % argv)
+ else:
+ if self.use_rawinput:
+ try:
+ line = raw_input(self._prompt_str)
+ except EOFError:
+ line = 'EOF'
+ else:
+ self.stdout.write(self._prompt_str)
+ self.stdout.flush()
+ line = self.stdin.readline()
+ if not len(line):
+ line = 'EOF'
+ else:
+ line = line[:-1] # chop '\n'
+ argv = line2argv(line)
+ try:
+ argv = self.precmd(argv)
+ retval = self.onecmd(argv)
+ self.postcmd(argv)
+ except:
+ if not self.cmdexc(argv):
+ raise
+ retval = 1
+ self.lastretval = retval
+ self.postloop()
+ finally:
+ if self.use_rawinput and self.completekey:
+ try:
+ import readline
+ readline.set_completer(self.old_completer)
+ except ImportError:
+ pass
+ self.cmdlooping = False
+ return retval
+
+ def precmd(self, argv):
+ """Hook method executed just before the command argv is
+ interpreted, but after the input prompt is generated and issued.
+
+ "argv" is the cmd to run.
+
+ Returns an argv to run (i.e. this method can modify the command
+ to run).
+ """
+ return argv
+
+ def postcmd(self, argv):
+ """Hook method executed just after a command dispatch is finished.
+
+ "argv" is the command that was run.
+ """
+ pass
+
+ def cmdexc(self, argv):
+ """Called if an exception is raised in any of precmd(), onecmd(),
+ or postcmd(). If True is returned, the exception is deemed to have
+ been dealt with. Otherwise, the exception is re-raised.
+
+ The default implementation handles CmdlnUserError's, which
+ typically correspond to user error in calling commands (as
+ opposed to programmer error in the design of the script using
+ cmdln.py).
+ """
+ import sys
+ type, exc, traceback = sys.exc_info()
+ if isinstance(exc, CmdlnUserError):
+ msg = "%s %s: %s\nTry '%s help %s' for info.\n"\
+ % (self.name, argv[0], exc, self.name, argv[0])
+ self.stderr.write(self._str(msg))
+ self.stderr.flush()
+ return True
+
+ def onecmd(self, argv):
+ if not argv:
+ return self.emptyline()
+ self.lastcmd = argv
+ cmdname = self._get_canonical_cmd_name(argv[0])
+ if cmdname:
+ handler = self._get_cmd_handler(cmdname)
+ if handler:
+ return self._dispatch_cmd(handler, argv)
+ return self.default(argv)
+
+ def _dispatch_cmd(self, handler, argv):
+ return handler(argv)
+
+ def default(self, argv):
+ """Hook called to handle a command for which there is no handler.
+
+ "argv" is the command and arguments to run.
+
+ The default implementation writes and error message to stderr
+ and returns an error exit status.
+
+ Returns a numeric command exit status.
+ """
+ errmsg = self._str(self.unknowncmd % (argv[0],))
+ if self.cmdlooping:
+ self.stderr.write(errmsg+"\n")
+ else:
+ self.stderr.write("%s: %s\nTry '%s help' for info.\n"
+ % (self._name_str, errmsg, self._name_str))
+ self.stderr.flush()
+ return 1
+
+ def parseline(self, line):
+ # This is used by Cmd.complete (readline completer function) to
+ # massage the current line buffer before completion processing.
+ # We override to drop special '!' handling.
+ line = line.strip()
+ if not line:
+ return None, None, line
+ elif line[0] == '?':
+ line = 'help ' + line[1:]
+ i, n = 0, len(line)
+ while i < n and line[i] in self.identchars: i = i+1
+ cmd, arg = line[:i], line[i:].strip()
+ return cmd, arg, line
+
+ def helpdefault(self, cmd, known):
+ """Hook called to handle help on a command for which there is no
+ help handler.
+
+ "cmd" is the command name on which help was requested.
+ "known" is a boolean indicating if this command is known
+ (i.e. if there is a handler for it).
+
+ Returns a return code.
+ """
+ if known:
+ msg = self._str(self.nohelp % (cmd,))
+ if self.cmdlooping:
+ self.stderr.write(msg + '\n')
+ else:
+ self.stderr.write("%s: %s\n" % (self.name, msg))
+ else:
+ msg = self.unknowncmd % (cmd,)
+ if self.cmdlooping:
+ self.stderr.write(msg + '\n')
+ else:
+ self.stderr.write("%s: %s\n"
+ "Try '%s help' for info.\n"
+ % (self.name, msg, self.name))
+ self.stderr.flush()
+ return 1
+
+ def do_help(self, argv):
+ """${cmd_name}: give detailed help on a specific sub-command
+
+ Usage:
+ ${name} help [COMMAND]
+ """
+ if len(argv) > 1: # asking for help on a particular command
+ doc = None
+ cmdname = self._get_canonical_cmd_name(argv[1]) or argv[1]
+ if not cmdname:
+ return self.helpdefault(argv[1], False)
+ else:
+ helpfunc = getattr(self, "help_"+cmdname, None)
+ if helpfunc:
+ doc = helpfunc()
+ else:
+ handler = self._get_cmd_handler(cmdname)
+ if handler:
+ doc = handler.__doc__
+ if doc is None:
+ return self.helpdefault(argv[1], handler != None)
+ else: # bare "help" command
+ doc = self.__class__.__doc__ # try class docstring
+ if doc is None:
+ # Try to provide some reasonable useful default help.
+ if self.cmdlooping: prefix = ""
+ else: prefix = self.name+' '
+ doc = """Usage:
+ %sCOMMAND [ARGS...]
+ %shelp [COMMAND]
+
+ ${option_list}
+ ${command_list}
+ ${help_list}
+ """ % (prefix, prefix)
+ cmdname = None
+
+ if doc: # *do* have help content, massage and print that
+ doc = self._help_reindent(doc)
+ doc = self._help_preprocess(doc, cmdname)
+ doc = doc.rstrip() + '\n' # trim down trailing space
+ self.stdout.write(self._str(doc))
+ self.stdout.flush()
+ do_help.aliases = ["?"]
+
+ def _help_reindent(self, help, indent=None):
+ """Hook to re-indent help strings before writing to stdout.
+
+ "help" is the help content to re-indent
+ "indent" is a string with which to indent each line of the
+ help content after normalizing. If unspecified or None
+ then the default is use: the 'self.helpindent' class
+ attribute. By default this is the empty string, i.e.
+ no indentation.
+
+ By default, all common leading whitespace is removed and then
+ the lot is indented by 'self.helpindent'. When calculating the
+ common leading whitespace the first line is ignored -- hence
+ help content for Conan can be written as follows and have the
+ expected indentation:
+
+ def do_crush(self, ...):
+ '''${cmd_name}: crush your enemies, see them driven before you...
+
+ c.f. Conan the Barbarian'''
+ """
+ if indent is None:
+ indent = self.helpindent
+ lines = help.splitlines(0)
+ _dedentlines(lines, skip_first_line=True)
+ lines = [(indent+line).rstrip() for line in lines]
+ return '\n'.join(lines)
+
+ def _help_preprocess(self, help, cmdname):
+ """Hook to preprocess a help string before writing to stdout.
+
+ "help" is the help string to process.
+ "cmdname" is the canonical sub-command name for which help
+ is being given, or None if the help is not specific to a
+ command.
+
+ By default the following template variables are interpolated in
+ help content. (Note: these are similar to Python 2.4's
+ string.Template interpolation but not quite.)
+
+ ${name}
+ The tool's/shell's name, i.e. 'self.name'.
+ ${option_list}
+ A formatted table of options for this shell/tool.
+ ${command_list}
+ A formatted table of available sub-commands.
+ ${help_list}
+ A formatted table of additional help topics (i.e. 'help_*'
+ methods with no matching 'do_*' method).
+ ${cmd_name}
+ The name (and aliases) for this sub-command formatted as:
+ "NAME (ALIAS1, ALIAS2, ...)".
+ ${cmd_usage}
+ A formatted usage block inferred from the command function
+ signature.
+ ${cmd_option_list}
+ A formatted table of options for this sub-command. (This is
+ only available for commands using the optparse integration,
+ i.e. using @cmdln.option decorators or manually setting the
+ 'optparser' attribute on the 'do_*' method.)
+
+ Returns the processed help.
+ """
+ preprocessors = {
+ "${name}": self._help_preprocess_name,
+ "${option_list}": self._help_preprocess_option_list,
+ "${command_list}": self._help_preprocess_command_list,
+ "${help_list}": self._help_preprocess_help_list,
+ "${cmd_name}": self._help_preprocess_cmd_name,
+ "${cmd_usage}": self._help_preprocess_cmd_usage,
+ "${cmd_option_list}": self._help_preprocess_cmd_option_list,
+ }
+
+ for marker, preprocessor in preprocessors.items():
+ if marker in help:
+ help = preprocessor(help, cmdname)
+ return help
+
+ def _help_preprocess_name(self, help, cmdname=None):
+ return help.replace("${name}", self.name)
+
+ def _help_preprocess_option_list(self, help, cmdname=None):
+ marker = "${option_list}"
+ indent, indent_width = _get_indent(marker, help)
+ suffix = _get_trailing_whitespace(marker, help)
+
+ if self.optparser:
+ # Setup formatting options and format.
+ # - Indentation of 4 is better than optparse default of 2.
+ # C.f. Damian Conway's discussion of this in Perl Best
+ # Practices.
+ self.optparser.formatter.indent_increment = 4
+ self.optparser.formatter.current_indent = indent_width
+ block = self.optparser.format_option_help() + '\n'
+ else:
+ block = ""
+
+ help = help.replace(indent+marker+suffix, block, 1)
+ return help
+
+
+ def _help_preprocess_command_list(self, help, cmdname=None):
+ marker = "${command_list}"
+ indent, indent_width = _get_indent(marker, help)
+ suffix = _get_trailing_whitespace(marker, help)
+
+ # Find any aliases for commands.
+ token2canonical = self._get_canonical_map()
+ aliases = {}
+ for token, cmdname in token2canonical.items():
+ if token == cmdname: continue
+ aliases.setdefault(cmdname, []).append(token)
+
+ # Get the list of (non-hidden) commands and their
+ # documentation, if any.
+ cmdnames = {} # use a dict to strip duplicates
+ for attr in self.get_names():
+ if attr.startswith("do_"):
+ cmdnames[attr[3:]] = True
+ cmdnames = cmdnames.keys()
+ cmdnames.sort()
+ linedata = []
+ for cmdname in cmdnames:
+ if aliases.get(cmdname):
+ a = aliases[cmdname]
+ a.sort()
+ cmdstr = "%s (%s)" % (cmdname, ", ".join(a))
+ else:
+ cmdstr = cmdname
+ doc = None
+ try:
+ helpfunc = getattr(self, 'help_'+cmdname)
+ except AttributeError:
+ handler = self._get_cmd_handler(cmdname)
+ if handler:
+ doc = handler.__doc__
+ else:
+ doc = helpfunc()
+
+ # Strip "${cmd_name}: " from the start of a command's doc. Best
+ # practice dictates that command help strings begin with this, but
+ # it isn't at all wanted for the command list.
+ to_strip = "${cmd_name}:"
+ if doc and doc.startswith(to_strip):
+ #log.debug("stripping %r from start of %s's help string",
+ # to_strip, cmdname)
+ doc = doc[len(to_strip):].lstrip()
+ linedata.append( (cmdstr, doc) )
+
+ if linedata:
+ subindent = indent + ' '*4
+ lines = _format_linedata(linedata, subindent, indent_width+4)
+ block = indent + "Commands:\n" \
+ + '\n'.join(lines) + "\n\n"
+ help = help.replace(indent+marker+suffix, block, 1)
+ return help
+
+ def _gen_names_and_attrs(self):
+ # Inheritance says we have to look in class and
+ # base classes; order is not important.
+ names = []
+ classes = [self.__class__]
+ while classes:
+ aclass = classes.pop(0)
+ if aclass.__bases__:
+ classes = classes + list(aclass.__bases__)
+ for name in dir(aclass):
+ yield (name, getattr(aclass, name))
+
+ def _help_preprocess_help_list(self, help, cmdname=None):
+ marker = "${help_list}"
+ indent, indent_width = _get_indent(marker, help)
+ suffix = _get_trailing_whitespace(marker, help)
+
+ # Determine the additional help topics, if any.
+ helpnames = {}
+ token2cmdname = self._get_canonical_map()
+ for attrname, attr in self._gen_names_and_attrs():
+ if not attrname.startswith("help_"): continue
+ helpname = attrname[5:]
+ if helpname not in token2cmdname:
+ helpnames[helpname] = attr
+
+ if helpnames:
+ linedata = [(n, a.__doc__ or "") for n, a in helpnames.items()]
+ linedata.sort()
+
+ subindent = indent + ' '*4
+ lines = _format_linedata(linedata, subindent, indent_width+4)
+ block = (indent
+ + "Additional help topics (run `%s help TOPIC'):\n" % self.name
+ + '\n'.join(lines)
+ + "\n\n")
+ else:
+ block = ''
+ help = help.replace(indent+marker+suffix, block, 1)
+ return help
+
+ def _help_preprocess_cmd_name(self, help, cmdname=None):
+ marker = "${cmd_name}"
+ handler = self._get_cmd_handler(cmdname)
+ if not handler:
+ raise CmdlnError("cannot preprocess '%s' into help string: "
+ "could not find command handler for %r"
+ % (marker, cmdname))
+ s = cmdname
+ if hasattr(handler, "aliases"):
+ s += " (%s)" % (", ".join(handler.aliases))
+ help = help.replace(marker, s)
+ return help
+
+ #TODO: this only makes sense as part of the Cmdln class.
+ # Add hooks to add help preprocessing template vars and put
+ # this one on that class.
+ def _help_preprocess_cmd_usage(self, help, cmdname=None):
+ marker = "${cmd_usage}"
+ handler = self._get_cmd_handler(cmdname)
+ if not handler:
+ raise CmdlnError("cannot preprocess '%s' into help string: "
+ "could not find command handler for %r"
+ % (marker, cmdname))
+ indent, indent_width = _get_indent(marker, help)
+ suffix = _get_trailing_whitespace(marker, help)
+
+ # Extract the introspection bits we need.
+ func = handler.im_func
+ if func.func_defaults:
+ func_defaults = list(func.func_defaults)
+ else:
+ func_defaults = []
+ co_argcount = func.func_code.co_argcount
+ co_varnames = func.func_code.co_varnames
+ co_flags = func.func_code.co_flags
+ CO_FLAGS_ARGS = 4
+ CO_FLAGS_KWARGS = 8
+
+ # Adjust argcount for possible *args and **kwargs arguments.
+ argcount = co_argcount
+ if co_flags & CO_FLAGS_ARGS: argcount += 1
+ if co_flags & CO_FLAGS_KWARGS: argcount += 1
+
+ # Determine the usage string.
+ usage = "%s %s" % (self.name, cmdname)
+ if argcount <= 2: # handler ::= do_FOO(self, argv)
+ usage += " [ARGS...]"
+ elif argcount >= 3: # handler ::= do_FOO(self, subcmd, opts, ...)
+ argnames = list(co_varnames[3:argcount])
+ tail = ""
+ if co_flags & CO_FLAGS_KWARGS:
+ name = argnames.pop(-1)
+ import warnings
+ # There is no generally accepted mechanism for passing
+ # keyword arguments from the command line. Could
+ # *perhaps* consider: arg=value arg2=value2 ...
+ warnings.warn("argument '**%s' on '%s.%s' command "
+ "handler will never get values"
+ % (name, self.__class__.__name__,
+ func.func_name))
+ if co_flags & CO_FLAGS_ARGS:
+ name = argnames.pop(-1)
+ tail = "[%s...]" % name.upper()
+ while func_defaults:
+ func_defaults.pop(-1)
+ name = argnames.pop(-1)
+ tail = "[%s%s%s]" % (name.upper(), (tail and ' ' or ''), tail)
+ while argnames:
+ name = argnames.pop(-1)
+ tail = "%s %s" % (name.upper(), tail)
+ usage += ' ' + tail
+
+ block_lines = [
+ self.helpindent + "Usage:",
+ self.helpindent + ' '*4 + usage
+ ]
+ block = '\n'.join(block_lines) + '\n\n'
+
+ help = help.replace(indent+marker+suffix, block, 1)
+ return help
+
+ #TODO: this only makes sense as part of the Cmdln class.
+ # Add hooks to add help preprocessing template vars and put
+ # this one on that class.
+ def _help_preprocess_cmd_option_list(self, help, cmdname=None):
+ marker = "${cmd_option_list}"
+ handler = self._get_cmd_handler(cmdname)
+ if not handler:
+ raise CmdlnError("cannot preprocess '%s' into help string: "
+ "could not find command handler for %r"
+ % (marker, cmdname))
+ indent, indent_width = _get_indent(marker, help)
+ suffix = _get_trailing_whitespace(marker, help)
+ if hasattr(handler, "optparser"):
+ # Setup formatting options and format.
+ # - Indentation of 4 is better than optparse default of 2.
+ # C.f. Damian Conway's discussion of this in Perl Best
+ # Practices.
+ handler.optparser.formatter.indent_increment = 4
+ handler.optparser.formatter.current_indent = indent_width
+ block = handler.optparser.format_option_help() + '\n'
+ else:
+ block = ""
+
+ help = help.replace(indent+marker+suffix, block, 1)
+ return help
+
+ def _get_canonical_cmd_name(self, token):
+ map = self._get_canonical_map()
+ return map.get(token, None)
+
+ def _get_canonical_map(self):
+ """Return a mapping of available command names and aliases to
+ their canonical command name.
+ """
+ cacheattr = "_token2canonical"
+ if not hasattr(self, cacheattr):
+ # Get the list of commands and their aliases, if any.
+ token2canonical = {}
+ cmd2funcname = {} # use a dict to strip duplicates
+ for attr in self.get_names():
+ if attr.startswith("do_"): cmdname = attr[3:]
+ elif attr.startswith("_do_"): cmdname = attr[4:]
+ else:
+ continue
+ cmd2funcname[cmdname] = attr
+ token2canonical[cmdname] = cmdname
+ for cmdname, funcname in cmd2funcname.items(): # add aliases
+ func = getattr(self, funcname)
+ aliases = getattr(func, "aliases", [])
+ for alias in aliases:
+ if alias in cmd2funcname:
+ import warnings
+ warnings.warn("'%s' alias for '%s' command conflicts "
+ "with '%s' handler"
+ % (alias, cmdname, cmd2funcname[alias]))
+ continue
+ token2canonical[alias] = cmdname
+ setattr(self, cacheattr, token2canonical)
+ return getattr(self, cacheattr)
+
+ def _get_cmd_handler(self, cmdname):
+ handler = None
+ try:
+ handler = getattr(self, 'do_' + cmdname)
+ except AttributeError:
+ try:
+ # Private command handlers begin with "_do_".
+ handler = getattr(self, '_do_' + cmdname)
+ except AttributeError:
+ pass
+ return handler
+
+ def _do_EOF(self, argv):
+ # Default EOF handler
+ # Note: an actual EOF is redirected to this command.
+ #TODO: separate name for this. Currently it is available from
+ # command-line. Is that okay?
+ self.stdout.write('\n')
+ self.stdout.flush()
+ self.stop = True
+
+ def emptyline(self):
+ # Different from cmd.Cmd: don't repeat the last command for an
+ # emptyline.
+ if self.cmdlooping:
+ pass
+ else:
+ return self.do_help(["help"])
+
+
+#---- optparse.py extension to fix (IMO) some deficiencies
+#
+# See the class _OptionParserEx docstring for details.
+#
+
+class StopOptionProcessing(Exception):
+ """Indicate that option *and argument* processing should stop
+ cleanly. This is not an error condition. It is similar in spirit to
+ StopIteration. This is raised by _OptionParserEx's default "help"
+ and "version" option actions and can be raised by custom option
+ callbacks too.
+
+ Hence the typical CmdlnOptionParser (a subclass of _OptionParserEx)
+ usage is:
+
+ parser = CmdlnOptionParser(mycmd)
+ parser.add_option("-f", "--force", dest="force")
+ ...
+ try:
+ opts, args = parser.parse_args()
+ except StopOptionProcessing:
+ # normal termination, "--help" was probably given
+ sys.exit(0)
+ """
+
+class _OptionParserEx(optparse.OptionParser):
+ """An optparse.OptionParser that uses exceptions instead of sys.exit.
+
+ This class is an extension of optparse.OptionParser that differs
+ as follows:
+ - Correct (IMO) the default OptionParser error handling to never
+ sys.exit(). Instead OptParseError exceptions are passed through.
+ - Add the StopOptionProcessing exception (a la StopIteration) to
+ indicate normal termination of option processing.
+ See StopOptionProcessing's docstring for details.
+
+ I'd also like to see the following in the core optparse.py, perhaps
+ as a RawOptionParser which would serve as a base class for the more
+ generally used OptionParser (that works as current):
+ - Remove the implicit addition of the -h|--help and --version
+ options. They can get in the way (e.g. if you want '-?' and '-V' for
+ these as well) and it is not hard to do:
+ optparser.add_option("-h", "--help", action="help")
+ optparser.add_option("--version", action="version")
+ These are good practices, just not valid defaults if they can
+ get in the way.
+ """
+ def error(self, msg):
+ raise optparse.OptParseError(msg)
+
+ def exit(self, status=0, msg=None):
+ if status == 0:
+ raise StopOptionProcessing(msg)
+ else:
+ #TODO: don't lose status info here
+ raise optparse.OptParseError(msg)
+
+
+
+#---- optparse.py-based option processing support
+
+class CmdlnOptionParser(_OptionParserEx):
+ """An optparse.OptionParser class more appropriate for top-level
+ Cmdln options. For parsing of sub-command options, see
+ SubCmdOptionParser.
+
+ Changes:
+ - disable_interspersed_args() by default, because a Cmdln instance
+ has sub-commands which may themselves have options.
+ - Redirect print_help() to the Cmdln.do_help() which is better
+ equipped to handle the "help" action.
+ - error() will raise a CmdlnUserError: OptionParser.error() is meant
+ to be called for user errors. Raising a well-known error here can
+ make error handling clearer.
+ - Also see the changes in _OptionParserEx.
+ """
+ def __init__(self, cmdln, **kwargs):
+ self.cmdln = cmdln
+ kwargs["prog"] = self.cmdln.name
+ _OptionParserEx.__init__(self, **kwargs)
+ self.disable_interspersed_args()
+
+ def print_help(self, file=None):
+ self.cmdln.onecmd(["help"])
+
+ def error(self, msg):
+ raise CmdlnUserError(msg)
+
+
+class SubCmdOptionParser(_OptionParserEx):
+ def set_cmdln_info(self, cmdln, subcmd):
+ """Called by Cmdln to pass relevant info about itself needed
+ for print_help().
+ """
+ self.cmdln = cmdln
+ self.subcmd = subcmd
+
+ def print_help(self, file=None):
+ self.cmdln.onecmd(["help", self.subcmd])
+
+ def error(self, msg):
+ raise CmdlnUserError(msg)
+
+
+def option(*args, **kwargs):
+ """Decorator to add an option to the optparser argument of a Cmdln
+ subcommand.
+
+ Example:
+ class MyShell(cmdln.Cmdln):
+ @cmdln.option("-f", "--force", help="force removal")
+ def do_remove(self, subcmd, opts, *args):
+ #...
+ """
+ #XXX Is there a possible optimization for many options to not have a
+ # large stack depth here?
+ def decorate(f):
+ if not hasattr(f, "optparser"):
+ f.optparser = SubCmdOptionParser()
+ f.optparser.add_option(*args, **kwargs)
+ return f
+ return decorate
+
+
+class Cmdln(RawCmdln):
+ """An improved (on cmd.Cmd) framework for building multi-subcommand
+ scripts (think "svn" & "cvs") and simple shells (think "pdb" and
+ "gdb").
+
+ A simple example:
+
+ import cmdln
+
+ class MySVN(cmdln.Cmdln):
+ name = "svn"
+
+ @cmdln.aliases('stat', 'st')
+ @cmdln.option('-v', '--verbose', action='store_true',
+ help='print verbose information')
+ def do_status(self, subcmd, opts, *paths):
+ print "handle 'svn status' command"
+
+ #...
+
+ if __name__ == "__main__":
+ shell = MySVN()
+ retval = shell.main()
+ sys.exit(retval)
+
+ 'Cmdln' extends 'RawCmdln' by providing optparse option processing
+ integration. See this class' _dispatch_cmd() docstring and
+ <http://trentm.com/projects/cmdln> for more information.
+ """
+ def _dispatch_cmd(self, handler, argv):
+ """Introspect sub-command handler signature to determine how to
+ dispatch the command. The raw handler provided by the base
+ 'RawCmdln' class is still supported:
+
+ def do_foo(self, argv):
+ # 'argv' is the vector of command line args, argv[0] is
+ # the command name itself (i.e. "foo" or an alias)
+ pass
+
+ In addition, if the handler has more than 2 arguments option
+ processing is automatically done (using optparse):
+
+ @cmdln.option('-v', '--verbose', action='store_true')
+ def do_bar(self, subcmd, opts, *args):
+ # subcmd = <"bar" or an alias>
+ # opts = <an optparse.Values instance>
+ if opts.verbose:
+ print "lots of debugging output..."
+ # args = <tuple of arguments>
+ for arg in args:
+ bar(arg)
+
+ TODO: explain that "*args" can be other signatures as well.
+
+ The `cmdln.option` decorator corresponds to an `add_option()`
+ method call on an `optparse.OptionParser` instance.
+
+ You can declare a specific number of arguments:
+
+ @cmdln.option('-v', '--verbose', action='store_true')
+ def do_bar2(self, subcmd, opts, bar_one, bar_two):
+ #...
+
+ and an appropriate error message will be raised/printed if the
+ command is called with a different number of args.
+ """
+ co_argcount = handler.im_func.func_code.co_argcount
+ if co_argcount == 2: # handler ::= do_foo(self, argv)
+ return handler(argv)
+ elif co_argcount >= 3: # handler ::= do_foo(self, subcmd, opts, ...)
+ try:
+ optparser = handler.optparser
+ except AttributeError:
+ optparser = handler.im_func.optparser = SubCmdOptionParser()
+ assert isinstance(optparser, SubCmdOptionParser)
+ optparser.set_cmdln_info(self, argv[0])
+ try:
+ opts, args = optparser.parse_args(argv[1:])
+ except StopOptionProcessing:
+ #TODO: this doesn't really fly for a replacement of
+ # optparse.py behaviour, does it?
+ return 0 # Normal command termination
+
+ try:
+ return handler(argv[0], opts, *args)
+ except TypeError, ex:
+ # Some TypeError's are user errors:
+ # do_foo() takes at least 4 arguments (3 given)
+ # do_foo() takes at most 5 arguments (6 given)
+ # do_foo() takes exactly 5 arguments (6 given)
+ # Raise CmdlnUserError for these with a suitably
+ # massaged error message.
+ import sys
+ tb = sys.exc_info()[2] # the traceback object
+ if tb.tb_next is not None:
+ # If the traceback is more than one level deep, then the
+ # TypeError did *not* happen on the "handler(...)" call
+ # above. In that case we don't want to handle it specially
+ # here: it would falsely mask deeper code errors.
+ raise
+ msg = ex.args[0]
+ match = _INCORRECT_NUM_ARGS_RE.search(msg)
+ if match:
+ msg = list(match.groups())
+ msg[1] = int(msg[1]) - 3
+ if msg[1] == 1:
+ msg[2] = msg[2].replace("arguments", "argument")
+ msg[3] = int(msg[3]) - 3
+ msg = ''.join(map(str, msg))
+ raise CmdlnUserError(msg)
+ else:
+ raise
+ else:
+ raise CmdlnError("incorrect argcount for %s(): takes %d, must "
+ "take 2 for 'argv' signature or 3+ for 'opts' "
+ "signature" % (handler.__name__, co_argcount))
+
+
+
+#---- internal support functions
+
+def _format_linedata(linedata, indent, indent_width):
+ """Format specific linedata into a pleasant layout.
+
+ "linedata" is a list of 2-tuples of the form:
+ (<item-display-string>, <item-docstring>)
+ "indent" is a string to use for one level of indentation
+ "indent_width" is a number of columns by which the
+ formatted data will be indented when printed.
+
+ The <item-display-string> column is held to 15 columns.
+ """
+ lines = []
+ WIDTH = 78 - indent_width
+ SPACING = 2
+ NAME_WIDTH_LOWER_BOUND = 13
+ NAME_WIDTH_UPPER_BOUND = 16
+ NAME_WIDTH = max([len(s) for s,d in linedata])
+ if NAME_WIDTH < NAME_WIDTH_LOWER_BOUND:
+ NAME_WIDTH = NAME_WIDTH_LOWER_BOUND
+ else:
+ NAME_WIDTH = NAME_WIDTH_UPPER_BOUND
+
+ DOC_WIDTH = WIDTH - NAME_WIDTH - SPACING
+ for namestr, doc in linedata:
+ line = indent + namestr
+ if len(namestr) <= NAME_WIDTH:
+ line += ' ' * (NAME_WIDTH + SPACING - len(namestr))
+ else:
+ lines.append(line)
+ line = indent + ' ' * (NAME_WIDTH + SPACING)
+ line += _summarize_doc(doc, DOC_WIDTH)
+ lines.append(line.rstrip())
+ return lines
+
+def _summarize_doc(doc, length=60):
+ r"""Parse out a short one line summary from the given doclines.
+
+ "doc" is the doc string to summarize.
+ "length" is the max length for the summary
+
+ >>> _summarize_doc("this function does this")
+ 'this function does this'
+ >>> _summarize_doc("this function does this", 10)
+ 'this fu...'
+ >>> _summarize_doc("this function does this\nand that")
+ 'this function does this and that'
+ >>> _summarize_doc("this function does this\n\nand that")
+ 'this function does this'
+ """
+ import re
+ if doc is None:
+ return ""
+ assert length > 3, "length <= 3 is absurdly short for a doc summary"
+ doclines = doc.strip().splitlines(0)
+ if not doclines:
+ return ""
+
+ summlines = []
+ for i, line in enumerate(doclines):
+ stripped = line.strip()
+ if not stripped:
+ break
+ summlines.append(stripped)
+ if len(''.join(summlines)) >= length:
+ break
+
+ summary = ' '.join(summlines)
+ if len(summary) > length:
+ summary = summary[:length-3] + "..."
+ return summary
+
+
+def line2argv(line):
+ r"""Parse the given line into an argument vector.
+
+ "line" is the line of input to parse.
+
+ This may get niggly when dealing with quoting and escaping. The
+ current state of this parsing may not be completely thorough/correct
+ in this respect.
+
+ >>> from cmdln import line2argv
+ >>> line2argv("foo")
+ ['foo']
+ >>> line2argv("foo bar")
+ ['foo', 'bar']
+ >>> line2argv("foo bar ")
+ ['foo', 'bar']
+ >>> line2argv(" foo bar")
+ ['foo', 'bar']
+
+ Quote handling:
+
+ >>> line2argv("'foo bar'")
+ ['foo bar']
+ >>> line2argv('"foo bar"')
+ ['foo bar']
+ >>> line2argv(r'"foo\"bar"')
+ ['foo"bar']
+ >>> line2argv("'foo bar' spam")
+ ['foo bar', 'spam']
+ >>> line2argv("'foo 'bar spam")
+ ['foo bar', 'spam']
+
+ >>> line2argv('some\tsimple\ttests')
+ ['some', 'simple', 'tests']
+ >>> line2argv('a "more complex" test')
+ ['a', 'more complex', 'test']
+ >>> line2argv('a more="complex test of " quotes')
+ ['a', 'more=complex test of ', 'quotes']
+ >>> line2argv('a more" complex test of " quotes')
+ ['a', 'more complex test of ', 'quotes']
+ >>> line2argv('an "embedded \\"quote\\""')
+ ['an', 'embedded "quote"']
+
+ # Komodo bug 48027
+ >>> line2argv('foo bar C:\\')
+ ['foo', 'bar', 'C:\\']
+
+ # Komodo change 127581
+ >>> line2argv(r'"\test\slash" "foo bar" "foo\"bar"')
+ ['\\test\\slash', 'foo bar', 'foo"bar']
+
+ # Komodo change 127629
+ >>> if sys.platform == "win32":
+ ... line2argv(r'\foo\bar') == ['\\foo\\bar']
+ ... line2argv(r'\\foo\\bar') == ['\\\\foo\\\\bar']
+ ... line2argv('"foo') == ['foo']
+ ... else:
+ ... line2argv(r'\foo\bar') == ['foobar']
+ ... line2argv(r'\\foo\\bar') == ['\\foo\\bar']
+ ... try:
+ ... line2argv('"foo')
+ ... except ValueError, ex:
+ ... "not terminated" in str(ex)
+ True
+ True
+ True
+ """
+ import string
+ line = line.strip()
+ argv = []
+ state = "default"
+ arg = None # the current argument being parsed
+ i = -1
+ while 1:
+ i += 1
+ if i >= len(line): break
+ ch = line[i]
+
+ if ch == "\\" and i+1 < len(line):
+ # escaped char always added to arg, regardless of state
+ if arg is None: arg = ""
+ if (sys.platform == "win32"
+ or state in ("double-quoted", "single-quoted")
+ ) and line[i+1] not in tuple('"\''):
+ arg += ch
+ i += 1
+ arg += line[i]
+ continue
+
+ if state == "single-quoted":
+ if ch == "'":
+ state = "default"
+ else:
+ arg += ch
+ elif state == "double-quoted":
+ if ch == '"':
+ state = "default"
+ else:
+ arg += ch
+ elif state == "default":
+ if ch == '"':
+ if arg is None: arg = ""
+ state = "double-quoted"
+ elif ch == "'":
+ if arg is None: arg = ""
+ state = "single-quoted"
+ elif ch in string.whitespace:
+ if arg is not None:
+ argv.append(arg)
+ arg = None
+ else:
+ if arg is None: arg = ""
+ arg += ch
+ if arg is not None:
+ argv.append(arg)
+ if not sys.platform == "win32" and state != "default":
+ raise ValueError("command line is not terminated: unfinished %s "
+ "segment" % state)
+ return argv
+
+
+def argv2line(argv):
+ r"""Put together the given argument vector into a command line.
+
+ "argv" is the argument vector to process.
+
+ >>> from cmdln import argv2line
+ >>> argv2line(['foo'])
+ 'foo'
+ >>> argv2line(['foo', 'bar'])
+ 'foo bar'
+ >>> argv2line(['foo', 'bar baz'])
+ 'foo "bar baz"'
+ >>> argv2line(['foo"bar'])
+ 'foo"bar'
+ >>> print argv2line(['foo" bar'])
+ 'foo" bar'
+ >>> print argv2line(["foo' bar"])
+ "foo' bar"
+ >>> argv2line(["foo'bar"])
+ "foo'bar"
+ """
+ escapedArgs = []
+ for arg in argv:
+ if ' ' in arg and '"' not in arg:
+ arg = '"'+arg+'"'
+ elif ' ' in arg and "'" not in arg:
+ arg = "'"+arg+"'"
+ elif ' ' in arg:
+ arg = arg.replace('"', r'\"')
+ arg = '"'+arg+'"'
+ escapedArgs.append(arg)
+ return ' '.join(escapedArgs)
+
+
+# Recipe: dedent (0.1) in /Users/trentm/tm/recipes/cookbook
+def _dedentlines(lines, tabsize=8, skip_first_line=False):
+ """_dedentlines(lines, tabsize=8, skip_first_line=False) -> dedented lines
+
+ "lines" is a list of lines to dedent.
+ "tabsize" is the tab width to use for indent width calculations.
+ "skip_first_line" is a boolean indicating if the first line should
+ be skipped for calculating the indent width and for dedenting.
+ This is sometimes useful for docstrings and similar.
+
+ Same as dedent() except operates on a sequence of lines. Note: the
+ lines list is modified **in-place**.
+ """
+ DEBUG = False
+ if DEBUG:
+ print "dedent: dedent(..., tabsize=%d, skip_first_line=%r)"\
+ % (tabsize, skip_first_line)
+ indents = []
+ margin = None
+ for i, line in enumerate(lines):
+ if i == 0 and skip_first_line: continue
+ indent = 0
+ for ch in line:
+ if ch == ' ':
+ indent += 1
+ elif ch == '\t':
+ indent += tabsize - (indent % tabsize)
+ elif ch in '\r\n':
+ continue # skip all-whitespace lines
+ else:
+ break
+ else:
+ continue # skip all-whitespace lines
+ if DEBUG: print "dedent: indent=%d: %r" % (indent, line)
+ if margin is None:
+ margin = indent
+ else:
+ margin = min(margin, indent)
+ if DEBUG: print "dedent: margin=%r" % margin
+
+ if margin is not None and margin > 0:
+ for i, line in enumerate(lines):
+ if i == 0 and skip_first_line: continue
+ removed = 0
+ for j, ch in enumerate(line):
+ if ch == ' ':
+ removed += 1
+ elif ch == '\t':
+ removed += tabsize - (removed % tabsize)
+ elif ch in '\r\n':
+ if DEBUG: print "dedent: %r: EOL -> strip up to EOL" % line
+ lines[i] = lines[i][j:]
+ break
+ else:
+ raise ValueError("unexpected non-whitespace char %r in "
+ "line %r while removing %d-space margin"
+ % (ch, line, margin))
+ if DEBUG:
+ print "dedent: %r: %r -> removed %d/%d"\
+ % (line, ch, removed, margin)
+ if removed == margin:
+ lines[i] = lines[i][j+1:]
+ break
+ elif removed > margin:
+ lines[i] = ' '*(removed-margin) + lines[i][j+1:]
+ break
+ return lines
+
+def _dedent(text, tabsize=8, skip_first_line=False):
+ """_dedent(text, tabsize=8, skip_first_line=False) -> dedented text
+
+ "text" is the text to dedent.
+ "tabsize" is the tab width to use for indent width calculations.
+ "skip_first_line" is a boolean indicating if the first line should
+ be skipped for calculating the indent width and for dedenting.
+ This is sometimes useful for docstrings and similar.
+
+ textwrap.dedent(s), but don't expand tabs to spaces
+ """
+ lines = text.splitlines(1)
+ _dedentlines(lines, tabsize=tabsize, skip_first_line=skip_first_line)
+ return ''.join(lines)
+
+
+def _get_indent(marker, s, tab_width=8):
+ """_get_indent(marker, s, tab_width=8) ->
+ (<indentation-of-'marker'>, <indentation-width>)"""
+ # Figure out how much the marker is indented.
+ INDENT_CHARS = tuple(' \t')
+ start = s.index(marker)
+ i = start
+ while i > 0:
+ if s[i-1] not in INDENT_CHARS:
+ break
+ i -= 1
+ indent = s[i:start]
+ indent_width = 0
+ for ch in indent:
+ if ch == ' ':
+ indent_width += 1
+ elif ch == '\t':
+ indent_width += tab_width - (indent_width % tab_width)
+ return indent, indent_width
+
+def _get_trailing_whitespace(marker, s):
+ """Return the whitespace content trailing the given 'marker' in string 's',
+ up to and including a newline.
+ """
+ suffix = ''
+ start = s.index(marker) + len(marker)
+ i = start
+ while i < len(s):
+ if s[i] in ' \t':
+ suffix += s[i]
+ elif s[i] in '\r\n':
+ suffix += s[i]
+ if s[i] == '\r' and i+1 < len(s) and s[i+1] == '\n':
+ suffix += s[i+1]
+ break
+ else:
+ break
+ i += 1
+ return suffix
+
+
+
+#---- bash completion support
+# Note: This is still experimental. I expect to change this
+# significantly.
+#
+# To get Bash completion for a cmdln.Cmdln class, run the following
+# bash command:
+# $ complete -C 'python -m cmdln /path/to/script.py CmdlnClass' cmdname
+# For example:
+# $ complete -C 'python -m cmdln ~/bin/svn.py SVN' svn
+#
+#TODO: Simplify the above so you don't have to give the path to the script
+# (try to find it on PATH, if possible). Could also make the class name
+# optional if there is only one in the module (common case).
+
+if __name__ == "__main__" and len(sys.argv) == 6:
+ def _log(s):
+ return # no-op, comment out for debugging
+ from os.path import expanduser
+ fout = open(expanduser("~/tmp/bashcpln.log"), 'a')
+ fout.write(str(s) + '\n')
+ fout.close()
+
+ # Recipe: module_from_path (1.0.1+)
+ def _module_from_path(path):
+ import imp, os, sys
+ path = os.path.expanduser(path)
+ dir = os.path.dirname(path) or os.curdir
+ name = os.path.splitext(os.path.basename(path))[0]
+ sys.path.insert(0, dir)
+ try:
+ iinfo = imp.find_module(name, [dir])
+ return imp.load_module(name, *iinfo)
+ finally:
+ sys.path.remove(dir)
+
+ def _get_bash_cplns(script_path, class_name, cmd_name,
+ token, preceding_token):
+ _log('--')
+ _log('get_cplns(%r, %r, %r, %r, %r)'
+ % (script_path, class_name, cmd_name, token, preceding_token))
+ comp_line = os.environ["COMP_LINE"]
+ comp_point = int(os.environ["COMP_POINT"])
+ _log("COMP_LINE: %r" % comp_line)
+ _log("COMP_POINT: %r" % comp_point)
+
+ try:
+ script = _module_from_path(script_path)
+ except ImportError, ex:
+ _log("error importing `%s': %s" % (script_path, ex))
+ return []
+ shell = getattr(script, class_name)()
+ cmd_map = shell._get_canonical_map()
+ del cmd_map["EOF"]
+
+ # Determine if completing the sub-command name.
+ parts = comp_line[:comp_point].split(None, 1)
+ _log(parts)
+ if len(parts) == 1 or not (' ' in parts[1] or '\t' in parts[1]):
+ #TODO: if parts[1].startswith('-'): handle top-level opts
+ _log("complete sub-command names")
+ matches = {}
+ for name, canon_name in cmd_map.items():
+ if name.startswith(token):
+ matches[name] = canon_name
+ if not matches:
+ return []
+ elif len(matches) == 1:
+ return matches.keys()
+ elif len(set(matches.values())) == 1:
+ return [matches.values()[0]]
+ else:
+ return matches.keys()
+
+ # Otherwise, complete options for the given sub-command.
+ #TODO: refine this so it does the right thing with option args
+ if token.startswith('-'):
+ cmd_name = comp_line.split(None, 2)[1]
+ try:
+ cmd_canon_name = cmd_map[cmd_name]
+ except KeyError:
+ return []
+ handler = shell._get_cmd_handler(cmd_canon_name)
+ optparser = getattr(handler, "optparser", None)
+ if optparser is None:
+ optparser = SubCmdOptionParser()
+ opt_strs = []
+ for option in optparser.option_list:
+ for opt_str in option._short_opts + option._long_opts:
+ if opt_str.startswith(token):
+ opt_strs.append(opt_str)
+ return opt_strs
+
+ return []
+
+ for cpln in _get_bash_cplns(*sys.argv[1:]):
+ print cpln
+
Modified: branches/yt-object-serialization/yt/commands.py
==============================================================================
--- branches/yt-object-serialization/yt/commands.py (original)
+++ branches/yt-object-serialization/yt/commands.py Tue Dec 30 06:31:31 2008
@@ -24,7 +24,9 @@
"""
from yt.mods import *
+from yt.funcs import *
from yt.recipes import _fix_pf
+import yt.cmdln as cmdln
import optparse, os, os.path, math
_common_options = dict(
@@ -68,7 +70,7 @@
help="Center (-1,-1,-1 for max)"),
bn = dict(short="-b", long="--basename",
action="store", type="string",
- dest="basename", default="galaxy",
+ dest="basename", default=None,
help="Basename of parameter files"),
output = dict(short="-o", long="--output",
action="store", type="string",
@@ -139,16 +141,73 @@
_add_options(parser, *options)
return parser
-# Now we define our functions, each of which will be an 'entry_point' when
-# passed to setuptools.
+def add_cmd_options(options):
+ opts = []
+ for option in options:
+ vals = _common_options[option].copy()
+ opts.append(([vals.pop("short"), vals.pop("long")],
+ vals))
+ def apply_options(func):
+ for args, kwargs in opts:
+ func = cmdln.option(*args, **kwargs)(func)
+ return func
+ return apply_options
+
+def check_args(func):
+ @wraps(func)
+ def arg_iterate(self, subcmd, opts, *args):
+ if len(args) == 1:
+ pfs = args
+ elif len(args) == 2 and opts.basename is not None:
+ pfs = ["%s%04i" % (opts.basename, r)
+ for r in range(int(args[0]), int(args[1]), opts.skip) ]
+ else: pfs = args
+ for arg in args:
+ func(self, subcmd, opts, arg)
+ return arg_iterate
+
+class YTCommands(cmdln.Cmdln):
+ name="yt"
+
+ def __init__(self, *args, **kwargs):
+ cmdln.Cmdln.__init__(self, *args, **kwargs)
+ cmdln.Cmdln.do_help.aliases.append("h")
+
+ def do_loop(self, subcmd, opts, *args):
+ """
+ Interactive loop
+
+ ${cmd_option_list}
+ """
+ self.cmdloop()
+
+ @add_cmd_options(['outputfn','bn','thresh','dm_only'])
+ @check_args
+ def do_hop(self, subcmd, opts, arg):
+ """
+ Run HOP on one or more datasets
-def zoomin():
- parser = _get_parser("maxw", "minw", "proj", "axis", "field", "weight",
- "zlim", "nframes", "output", "cmap", "uboxes")
- opts, args = parser.parse_args()
-
- for arg in args:
+ ${cmd_option_list}
+ """
pf = _fix_pf(arg)
+ sp = pf.h.sphere((pf["DomainLeftEdge"] + pf["DomainRightEdge"])/2.0,
+ pf['unitary'])
+ kwargs = {'dm_only' : opts.dm_only}
+ if opts.threshold is not None: kwargs['threshold'] = opts.threshold
+ hop_list = hop.HopList(sp, **kwargs)
+ if opts.output is None: fn = "%s.hop" % pf
+ else: fn = opts.output
+ hop_list.write_out(fn)
+
+ @add_cmd_options(["maxw", "minw", "proj", "axis", "field", "weight",
+ "zlim", "nframes", "output", "cmap", "uboxes"])
+ def do_zoomin(self, subcmd, opts, args):
+ """
+ Create a set of zoomin frames
+
+ ${cmd_option_list}
+ """
+ pf = _fix_pf(args[-1])
min_width = opts.min_width * pf.h.get_smallest_dx()
if opts.axis == 4:
axes = range(3)
@@ -182,29 +241,21 @@
pc.save(os.path.join(opts.output,"%s_frame%06i" % (pf,i)))
w *= factor
-def timeseries():
- parser = _get_parser("width", "unit", "bn", "proj", "center",
- "zlim", "axis", "field", "weight", "skip",
- "cmap", "output")
- opts, args = parser.parse_args()
-
- try:
- first = int(args[0])
- last = int(args[1])
- except:
- mylog.error("Hey, sorry, but you need to specify the indices of the first and last outputs you want to look at.")
- sys.exit()
-
- for n in range(first,last+1,opts.skip):
- # Now we figure out where this file is
- bn_try = "%s%04i" % (opts.basename, n)
- try:
- pf = _fix_pf(bn_try)
- except IOError:
- pf = _fix_pf("%s.dir/%s" % (bn_try, bn_try))
+ @add_cmd_options(["width", "unit", "bn", "proj", "center",
+ "zlim", "axis", "field", "weight", "skip",
+ "cmap", "output"])
+ @check_args
+ def do_plot(self, subcmd, opts, arg):
+ """
+ Create a set of images
+
+ ${cmd_usage}
+ ${cmd_option_list}
+ """
+ pf = _fix_pf(arg)
pc=raven.PlotCollection(pf)
center = opts.center
- if center is None or opts.center == (-1,-1,-1):
+ if opts.center == (-1,-1,-1):
mylog.info("No center fed in; seeking.")
v, center = pf.h.find_max("Density")
center = na.array(center)
@@ -222,16 +273,10 @@
if opts.zlim: pc.set_zlim(*opts.zlim)
pc.save(os.path.join(opts.output,"%s" % (pf)))
-def hop_single():
- parser = _get_parser("outputfn", "thresh", 'dm_only')
- opts, args = parser.parse_args()
-
- pf = _fix_pf(args[-1])
- sp = pf.h.sphere((pf["DomainLeftEdge"] + pf["DomainRightEdge"])/2.0,
- pf['unitary'])
- kwargs = {'dm_only' : opts.dm_only}
- if opts.threshold is not None: kwargs['threshold'] = opts.threshold
- hop_list = hop.HopList(sp, **kwargs)
- if opts.output is None: fn = "%s.hop" % pf
- else: fn = opts.output
- hop_list.write_out(fn)
+def run_main():
+ for co in ["--parallel", "--paste"]:
+ if co in sys.argv: del sys.argv[sys.argv.index(co)]
+ YT = YTCommands()
+ sys.exit(YT.main())
+
+if __name__ == "__main__": run_main()
Added: branches/yt-object-serialization/yt/convenience.py
==============================================================================
--- (empty file)
+++ branches/yt-object-serialization/yt/convenience.py Tue Dec 30 06:31:31 2008
@@ -0,0 +1,49 @@
+"""
+Some convenience functions, objects, and iterators
+
+Author: Matthew Turk <matthewturk at gmail.com>
+Affiliation: KIPAC/SLAC/Stanford
+Homepage: http://yt.enzotools.org/
+License:
+ Copyright (C) 2007-2008 Matthew Turk. All Rights Reserved.
+
+ This file is part of yt.
+
+ yt is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import glob
+
+# Named imports
+import yt.lagos as lagos
+import yt.raven as raven
+from yt.funcs import *
+import numpy as na
+import os.path, inspect, types
+from functools import wraps
+from yt.logger import ytLogger as mylog
+
+def all_pfs(max_depth=1, name_spec="*.hierarchy"):
+ list_of_names = []
+ for i in range(max_depth):
+ bb = list('*' * i) + [name_spec]
+ list_of_names += glob.glob(os.path.join(*bb))
+ list_of_names.sort(key=lambda b: os.path.basename(b))
+ for fn in list_of_names:
+ yield lagos.EnzoStaticOutput(fn[:-10])
+
+def max_spheres(width, unit, **kwargs):
+ for pf in all_pfs(**kwargs):
+ v, c = pf.h.find_max("Density")
+ yield pf.h.sphere(c, width/pf[unit])
Modified: branches/yt-object-serialization/yt/fido/ParameterFileStorage.py
==============================================================================
--- branches/yt-object-serialization/yt/fido/ParameterFileStorage.py (original)
+++ branches/yt-object-serialization/yt/fido/ParameterFileStorage.py Tue Dec 30 06:31:31 2008
@@ -23,6 +23,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
+from yt.config import ytcfg
from yt.fido import *
from yt.funcs import *
import shelve
Modified: branches/yt-object-serialization/yt/funcs.py
==============================================================================
--- branches/yt-object-serialization/yt/funcs.py (original)
+++ branches/yt-object-serialization/yt/funcs.py Tue Dec 30 06:31:31 2008
@@ -23,7 +23,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
-import time, types, signal, traceback
+import time, types, signal, traceback, sys
import progressbar as pb
from math import floor, ceil
@@ -35,6 +35,23 @@
except ValueError: # Not in main thread
pass
+def paste_traceback(exc_type, exc, tb):
+ sys.__excepthook__(exc_type, exc, tb)
+ import xmlrpclib, cStringIO
+ p = xmlrpclib.ServerProxy(
+ "http://paste.enzotools.org/xmlrpc/",
+ allow_none=True)
+ s = cStringIO.StringIO()
+ traceback.print_exception(exc_type, exc, tb, file=s)
+ s = s.getvalue()
+ ret = p.pastes.newPaste('pytb', s, None, '', '', True)
+ print
+ print "Traceback pasted to http://paste.enzotools.org/show/%s" % (ret)
+ print
+
+if "--paste" in sys.argv:
+ sys.excepthook = paste_traceback
+
def blank_wrapper(f):
return lambda a: a
Modified: branches/yt-object-serialization/yt/lagos/BaseDataTypes.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/BaseDataTypes.py (original)
+++ branches/yt-object-serialization/yt/lagos/BaseDataTypes.py Tue Dec 30 06:31:31 2008
@@ -620,7 +620,7 @@
_con_args = ['axis', 'coord']
#@time_execution
def __init__(self, axis, coord, fields = None, center=None, pf=None,
- node_name = False, **kwargs):
+ node_name = False, source = None, **kwargs):
"""
Slice along *axis*:ref:`axis-specification`, at the coordinate *coord*.
Optionally supply fields.
@@ -628,12 +628,21 @@
AMR2DData.__init__(self, axis, fields, pf, **kwargs)
self.center = center
self.coord = coord
+ self._initialize_source(source)
if node_name is False:
self._refresh_data()
else:
if node_name is True: self._deserialize()
else: self._deserialize(node_name)
+ def _initialize_source(self, source = None):
+ if source is None:
+ check, source = self._partition_hierarchy_2d(self.axis)
+ self._check_region = check
+ else:
+ self._check_region = True
+ self.source = source
+
def reslice(self, coord):
"""
Change the entire dataset, clearing out the current data and slicing at
@@ -681,7 +690,18 @@
self.ActiveDimensions = (t.shape[0], 1, 1)
def _get_list_of_grids(self):
- self._grids, ind = self.hierarchy.find_slice_grids(self.coord, self.axis)
+ goodI = ((self.source.gridRightEdge[:,self.axis] > self.coord)
+ & (self.source.gridLeftEdge[:,self.axis] < self.coord ))
+ self._grids = self.source._grids[goodI] # Using sources not hierarchy
+
+ def __cut_mask_child_mask(self, grid):
+ mask = grid.child_mask.copy()
+ if self._check_region:
+ cut_mask = self.source._get_cut_mask(grid)
+ if mask is False: mask *= False
+ elif mask is True: pass
+ else: mask &= cut_mask
+ return mask
def _generate_grid_coords(self, grid):
xaxis = x_dict[self.axis]
@@ -694,7 +714,8 @@
sl = tuple(sl)
nx = grid.child_mask.shape[xaxis]
ny = grid.child_mask.shape[yaxis]
- cm = na.where(grid.child_mask[sl].ravel() == 1)
+ mask = self.__cut_mask_child_mask(grid)[sl]
+ cm = na.where(mask.ravel()== 1)
cmI = na.indices((nx,ny))
xind = cmI[0,:].ravel()
xpoints = na.ones(cm[0].shape, 'float64')
@@ -730,7 +751,8 @@
dv = grid[field]
if dv.size == 1: dv = na.ones(grid.ActiveDimensions)*dv
dv = dv[sl]
- dataVals = dv.ravel()[grid.child_mask[sl].ravel() == 1]
+ mask = self.__cut_mask_child_mask(grid)[sl]
+ dataVals = dv.ravel()[mask.ravel() == 1]
return dataVals
def _gen_node_name(self):
@@ -1324,6 +1346,13 @@
k = (k | self._get_cut_particle_mask(grid))
return na.where(k)
+ def cut_region(self, field_cuts):
+ """
+ Return an InLineExtractedRegion, where the grid cells are cut on the
+ fly with a set of field_cuts.
+ """
+ return InLineExtractedRegionBase(self, field_cuts)
+
def extract_region(self, indices):
"""
Return an ExtractedRegion where the points contained in it are defined
@@ -1441,6 +1470,33 @@
# Yeah, if it's not true, we don't care.
return self._indices.get(grid.id-grid._id_offset, ())
+class InLineExtractedRegionBase(AMR3DData):
+ """
+ In-line extracted regions accept a base region and a set of field_cuts to
+ determine which points in a grid should be included.
+ """
+ def __init__(self, base_region, field_cuts, **kwargs):
+ cen = base_region.get_field_parameter("center")
+ AMR3DData.__init__(self, center=cen,
+ fields=None, pf=base_region.pf, **kwargs)
+ self._base_region = base_region # We don't weakly reference because
+ # It is not cyclic
+ self._field_cuts = ensure_list(field_cuts)[:]
+ self._refresh_data()
+
+ def _get_list_of_grids(self):
+ self._grids = self._base_region._grids
+
+ def _is_fully_enclosed(self, grid):
+ return False
+
+ @cache_mask
+ def _get_cut_mask(self, grid):
+ point_mask = self._base_region._get_cut_mask(grid)
+ for cut in self._field_cuts:
+ point_mask *= eval(cut)
+ return point_mask
+
class AMRCylinderBase(AMR3DData):
"""
We can define a cylinder (or disk) to act as a data object.
@@ -1535,12 +1591,12 @@
if self._is_fully_enclosed(grid):
return True
else:
- cm = ( (grid['x'] - 0.5*grid['dx'] < self.right_edge[0])
- & (grid['x'] + 0.5*grid['dx'] >= self.left_edge[0])
- & (grid['y'] - 0.5*grid['dy'] < self.right_edge[1])
- & (grid['y'] + 0.5*grid['dy'] >= self.left_edge[1])
- & (grid['z'] - 0.5*grid['dz'] < self.right_edge[2])
- & (grid['z'] + 0.5*grid['dz'] >= self.left_edge[2]) )
+ cm = ( (grid['x'] - grid['dx'] < self.right_edge[0])
+ & (grid['x'] + grid['dx'] > self.left_edge[0])
+ & (grid['y'] - grid['dy'] < self.right_edge[1])
+ & (grid['y'] + grid['dy'] > self.left_edge[1])
+ & (grid['z'] - grid['dz'] < self.right_edge[2])
+ & (grid['z'] + grid['dz'] > self.left_edge[2]) )
return cm
class AMRPeriodicRegionBase(AMR3DData):
@@ -1590,11 +1646,11 @@
cm = na.zeros(grid.ActiveDimensions,dtype='bool')
for off_x, off_y, off_z in self.offsets:
cm = cm | ( (grid['x'] - grid['dx'] + off_x < self.right_edge[0])
- & (grid['x'] + grid['dx'] + off_x >= self.left_edge[0])
+ & (grid['x'] + grid['dx'] + off_x > self.left_edge[0])
& (grid['y'] - grid['dy'] + off_y < self.right_edge[1])
- & (grid['y'] + grid['dy'] + off_y >= self.left_edge[1])
+ & (grid['y'] + grid['dy'] + off_y > self.left_edge[1])
& (grid['z'] - grid['dz'] + off_z < self.right_edge[2])
- & (grid['z'] + grid['dz'] + off_z >= self.left_edge[2]) )
+ & (grid['z'] + grid['dz'] + off_z > self.left_edge[2]) )
return cm
class AMRGridCollection(AMR3DData):
@@ -1671,7 +1727,7 @@
self._cut_masks[grid.id] = cm
return cm
-class AMRCoveringGrid(AMR3DData):
+class AMRCoveringGridBase(AMR3DData):
"""
Covering grids represent fixed-resolution data over a given region.
In order to achieve this goal -- for instance in order to obtain ghost
@@ -1798,7 +1854,7 @@
self.left_edge, self.right_edge, c_dx, c_fields,
ll, self.pf["DomainLeftEdge"], self.pf["DomainRightEdge"])
-class AMRSmoothedCoveringGrid(AMRCoveringGrid):
+class AMRSmoothedCoveringGridBase(AMRCoveringGridBase):
_type_name = "smoothed_covering_grid"
def __init__(self, *args, **kwargs):
dlog2 = na.log10(kwargs['dims'])/na.log10(2)
@@ -1806,7 +1862,7 @@
mylog.warning("Must be power of two dimensions")
#raise ValueError
kwargs['num_ghost_zones'] = 0
- AMRCoveringGrid.__init__(self, *args, **kwargs)
+ AMRCoveringGridBase.__init__(self, *args, **kwargs)
if na.any(self.left_edge == self.pf["DomainLeftEdge"]):
self.left_edge += self.dx
self.ActiveDimensions -= 1
@@ -1917,8 +1973,8 @@
class EnzoPeriodicRegionBase(AMRPeriodicRegionBase): pass
class EnzoGridCollection(AMRGridCollection): pass
class EnzoSphereBase(AMRSphereBase): pass
-class EnzoCoveringGrid(AMRCoveringGrid): pass
-class EnzoSmoothedCoveringGrid(AMRSmoothedCoveringGrid): pass
+class EnzoCoveringGrid(AMRCoveringGridBase): pass
+class EnzoSmoothedCoveringGrid(AMRSmoothedCoveringGridBase): pass
def _reconstruct_object(*args, **kwargs):
pfid = args[0]
Modified: branches/yt-object-serialization/yt/lagos/DerivedQuantities.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/DerivedQuantities.py (original)
+++ branches/yt-object-serialization/yt/lagos/DerivedQuantities.py Tue Dec 30 06:31:31 2008
@@ -102,7 +102,7 @@
def keys(self):
return self.functions.keys()
-def _TotalMass(self, data):
+def _TotalMass(data):
"""
This function takes no arguments and returns the sum of cell masses and
particle masses in the object.
@@ -227,15 +227,111 @@
# We only divide once here because we have velocity in cgs, but radius is
# in code.
G = 6.67e-8 / data.convert("cm") # cm^3 g^-1 s^-2
- pot = 2*G*PointCombine.FindBindingEnergy(data["CellMass"],
- data['x'],data['y'],data['z'],
- truncate, kinetic/(2*G))
+ import time
+ t1 = time.time()
+ try:
+ pot = 2*G*_cudaIsBound(data, truncate, kinetic/(2*G))
+ except (ImportError, AssertionError):
+ pot = 2*G*PointCombine.FindBindingEnergy(data["CellMass"],
+ data['x'],data['y'],data['z'],
+ False, kinetic/(2*G))
+ mylog.info("Boundedness check took %0.3e seconds", time.time()-t1)
return [(pot / kinetic)]
def _combIsBound(data, bound):
return bound
add_quantity("IsBound",function=_IsBound,combine_function=_combIsBound,n_ret=1,
force_unlazy=True)
+def _cudaIsBound(data, truncate, ratio):
+ import pycuda.driver as cuda
+ import pycuda.autoinit
+ import pycuda.gpuarray as gpuarray
+ my_stream = cuda.Stream()
+ cuda.init()
+ assert cuda.Device.count() >= 1
+
+ # Now the tedious process of rescaling our values...
+ length_scale_factor = data['dx'].max()/data['dx'].min()
+ mass_scale_factor = 1.0/(data['CellMass'].max())
+ x = ((data['x'] - data['x'].min()) * length_scale_factor).astype('float32')
+ y = ((data['y'] - data['y'].min()) * length_scale_factor).astype('float32')
+ z = ((data['z'] - data['z'].min()) * length_scale_factor).astype('float32')
+ m = (data['CellMass'] * mass_scale_factor).astype('float32')
+ p = na.zeros(z.shape, dtype='float32')
+ x_gpu = cuda.mem_alloc(x.size * x.dtype.itemsize)
+ y_gpu = cuda.mem_alloc(y.size * y.dtype.itemsize)
+ z_gpu = cuda.mem_alloc(z.size * z.dtype.itemsize)
+ m_gpu = cuda.mem_alloc(m.size * m.dtype.itemsize)
+ p_gpu = cuda.mem_alloc(p.size * p.dtype.itemsize)
+ for ag, a in [(x_gpu, x), (y_gpu, y), (z_gpu, z), (m_gpu, m), (p_gpu, p)]:
+ cuda.memcpy_htod(ag, a)
+ source = """
+
+ extern __shared__ float array[];
+
+ __global__ void isbound(float *x, float *y, float *z, float *m,
+ float *p)
+ {
+
+ /* My index in the array */
+ int idx1 = blockIdx.x * blockDim.x + threadIdx.x;
+ /* Note we are setting a start index */
+ int idx2 = blockIdx.y * blockDim.x;
+ int offset = threadIdx.x;
+
+ float* x_data1 = (float*) array;
+ float* y_data1 = (float*) &x_data1[blockDim.x];
+ float* z_data1 = (float*) &y_data1[blockDim.x];
+ float* m_data1 = (float*) &z_data1[blockDim.x];
+
+ float* x_data2 = (float*) &m_data1[blockDim.x];
+ float* y_data2 = (float*) &x_data2[blockDim.x];
+ float* z_data2 = (float*) &y_data2[blockDim.x];
+ float* m_data2 = (float*) &z_data2[blockDim.x];
+
+ x_data1[offset] = x[idx1];
+ y_data1[offset] = y[idx1];
+ z_data1[offset] = z[idx1];
+ m_data1[offset] = m[idx1];
+
+ x_data2[offset] = x[idx2 + offset];
+ y_data2[offset] = y[idx2 + offset];
+ z_data2[offset] = z[idx2 + offset];
+ m_data2[offset] = m[idx2 + offset];
+
+ __syncthreads();
+
+ float tx, ty, tz;
+
+ float my_p = 0.0;
+
+ for (int i = 0; i < blockDim.x; i++){
+ if(i + idx2 < idx1 + 1) continue;
+ tx = (x_data1[offset]-x_data2[i]);
+ ty = (y_data1[offset]-y_data2[i]);
+ tz = (z_data1[offset]-z_data2[i]);
+ my_p += m_data1[offset]*m_data2[i] /
+ sqrt(tx*tx+ty*ty+tz*tz);
+ }
+ p[idx1] += my_p;
+ __syncthreads();
+ }
+ """
+ bsize = 256
+ mod = cuda.SourceModule(source % dict(p=m.size, b=bsize), keep=True)
+ func = mod.get_function('isbound')
+ import math
+ gsize=int(math.ceil(float(m.size)/bsize))
+ mylog.info("Running CUDA functions. May take a while. (%0.5e, %s)",
+ x.size, gsize)
+ import pycuda.tools as ct
+ t1 = time.time()
+ ret = func(x_gpu, y_gpu, z_gpu, m_gpu, p_gpu,
+ shared=8*bsize*m.dtype.itemsize,
+ block=(bsize,1,1), grid=(gsize, gsize), time_kernel=True)
+ cuda.memcpy_dtoh(p, p_gpu)
+ p1 = p.sum()
+ return p1 * (length_scale_factor / (mass_scale_factor**2.0))
def _Extrema(data, fields):
"""
Modified: branches/yt-object-serialization/yt/lagos/HierarchyType.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/HierarchyType.py (original)
+++ branches/yt-object-serialization/yt/lagos/HierarchyType.py Tue Dec 30 06:31:31 2008
@@ -108,7 +108,12 @@
def _initialize_data_file(self):
if not ytcfg.getboolean('lagos','serialize'): return
- fn = os.path.join(self.directory,"%s.yt" % self["CurrentTimeIdentifier"])
+ if os.path.isfile(os.path.join(self.directory,
+ "%s.yt" % self["CurrentTimeIdentifier"])):
+ fn = os.path.join(self.directory,"%s.yt" % self["CurrentTimeIdentifier"])
+ else:
+ fn = os.path.join(self.directory,
+ "%s.yt" % self.parameter_file.basename)
if ytcfg.getboolean('lagos','onlydeserialize'):
self._data_mode = mode = 'r'
else:
@@ -150,7 +155,7 @@
if name in node_loc and force:
mylog.info("Overwriting node %s/%s", node, name)
self._data_file.removeNode(node, name, recursive=True)
- if name in node_loc and passthrough:
+ elif name in node_loc and passthrough:
return
except tables.exceptions.NoSuchNodeError:
pass
@@ -187,20 +192,39 @@
del self._data_file
self._data_file = None
+ def _add_object_class(self, name, obj):
+ self.object_types.append(name)
+ setattr(self, name, obj)
+
def _setup_classes(self, dd):
- self.proj = classobj("AMRProj",(AMRProjBase,), dd)
- self.slice = classobj("AMRSlice",(AMRSliceBase,), dd)
- self.region = classobj("AMRRegion",(AMRRegionBase,), dd)
- self.periodic_region = classobj("AMRPeriodicRegion",(AMRPeriodicRegionBase,), dd)
- self.covering_grid = classobj("AMRCoveringGrid",(AMRCoveringGrid,), dd)
- self.smoothed_covering_grid = classobj("AMRSmoothedCoveringGrid",(AMRSmoothedCoveringGrid,), dd)
- self.sphere = classobj("AMRSphere",(AMRSphereBase,), dd)
- self.cutting = classobj("AMRCuttingPlane",(AMRCuttingPlaneBase,), dd)
- self.ray = classobj("AMRRay",(AMRRayBase,), dd)
- self.ortho_ray = classobj("AMROrthoRay",(AMROrthoRayBase,), dd)
- self.disk = classobj("AMRCylinder",(AMRCylinderBase,), dd)
- self.grid_collection = classobj("AMRGridCollection",(AMRGridCollection,), dd)
- self.extracted_region = classobj("ExtractedRegion",(ExtractedRegionBase,), dd)
+ self.object_types = []
+ self._add_object_class('proj',
+ classobj("AMRProj",(AMRProjBase,), dd))
+ self._add_object_class('slice',
+ classobj("AMRSlice",(AMRSliceBase,), dd))
+ self._add_object_class('region',
+ classobj("AMRRegion",(AMRRegionBase,), dd))
+ self._add_object_class('periodic_region',
+ classobj("AMRPeriodicRegion",(AMRPeriodicRegionBase,), dd))
+ self._add_object_class('covering_grid',
+ classobj("AMRCoveringGrid",(AMRCoveringGridBase,), dd))
+ self._add_object_class('smoothed_covering_grid',
+ classobj("AMRSmoothedCoveringGrid",(AMRSmoothedCoveringGridBase,), dd))
+ self._add_object_class('sphere',
+ classobj("AMRSphere",(AMRSphereBase,), dd))
+ self._add_object_class('cutting',
+ classobj("AMRCuttingPlane",(AMRCuttingPlaneBase,), dd))
+ self._add_object_class('ray',
+ classobj("AMRRay",(AMRRayBase,), dd))
+ self._add_object_class('ortho_ray',
+ classobj("AMROrthoRay",(AMROrthoRayBase,), dd))
+ self._add_object_class('disk',
+ classobj("AMRCylinder",(AMRCylinderBase,), dd))
+ self._add_object_class('grid_collection',
+ classobj("AMRGridCollection",(AMRGridCollection,), dd))
+ self._add_object_class('extracted_region',
+ classobj("ExtractedRegion",(ExtractedRegionBase,), dd))
+ self.object_types.sort()
def _deserialize_hierarchy(self, harray):
mylog.debug("Cached entry found.")
@@ -617,8 +641,10 @@
def _setup_classes(self):
dd = self._get_data_reader_dict()
- self.grid = classobj("EnzoGrid",(EnzoGridBase,), dd)
AMRHierarchy._setup_classes(self, dd)
+ self._add_object_class('grid',
+ classobj("EnzoGrid",(EnzoGridBase,), dd))
+ self.object_types.sort()
def __guess_data_style(self, rank, testGrid, testGridID):
if self.data_style: return
@@ -1317,8 +1343,10 @@
def _setup_classes(self):
dd = self._get_data_reader_dict()
dd["field_indexes"] = self.field_indexes
- self.grid = classobj("OrionGrid",(OrionGridBase,), dd)
AMRHierarchy._setup_classes(self, dd)
+ self._add_object_class('grid',
+ classobj("OrionGrid",(OrionGridBase,), dd))
+ self.object_types.sort()
def _get_grid_children(self, grid):
mask = na.zeros(self.num_grids, dtype='bool')
Modified: branches/yt-object-serialization/yt/lagos/OutputTypes.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/OutputTypes.py (original)
+++ branches/yt-object-serialization/yt/lagos/OutputTypes.py Tue Dec 30 06:31:31 2008
@@ -48,7 +48,8 @@
obj = object.__new__(cls)
obj.__init__(filename, *args, **kwargs)
_cached_pfs[apath] = obj
- _pf_store.check_pf(obj)
+ if ytcfg.getboolean('lagos','serialize'):
+ _pf_store.check_pf(obj)
return _cached_pfs[apath]
def __init__(self, filename, data_style=None):
Modified: branches/yt-object-serialization/yt/lagos/ParallelTools.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/ParallelTools.py (original)
+++ branches/yt-object-serialization/yt/lagos/ParallelTools.py Tue Dec 30 06:31:31 2008
@@ -26,7 +26,7 @@
from yt.lagos import *
from yt.funcs import *
import yt.logger
-import itertools, sys
+import itertools, sys, cStringIO
if os.path.basename(sys.executable) in ["mpi4py", "embed_enzo"] \
or "--parallel" in sys.argv or '_parallel' in dir(sys):
@@ -114,11 +114,14 @@
@wraps(func)
def single_proc_results(self, *args, **kwargs):
retval = None
- if not self._distributed:
+ if self._processing or not self._distributed:
return func(self, *args, **kwargs)
- if self._owned:
+ if self._owner == MPI.COMM_WORLD.rank:
+ self._processing = True
retval = func(self, *args, **kwargs)
- retval = MPI.COMM_WORLD.Bcast(retval, root=MPI.COMM_WORLD.rank)
+ self._processing = False
+ retval = MPI.COMM_WORLD.Bcast(retval, root=self._owner)
+ MPI.COMM_WORLD.Barrier()
return retval
return single_proc_results
@@ -313,3 +316,15 @@
for field in fields:
deps += ensure_list(fi[field].get_dependencies().requested)
return list(set(deps))
+
+ def _claim_object(self, obj):
+ if not parallel_capable: return
+ obj._owner = MPI.COMM_WORLD.rank
+ obj._distributed = True
+
+ def _write_on_root(self, fn):
+ if not parallel_capable: return open(fn, "w")
+ if MPI.COMM_WORLD.rank == 0:
+ return open(fn, "w")
+ else:
+ return cStringIO.StringIO()
Modified: branches/yt-object-serialization/yt/lagos/Profiles.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/Profiles.py (original)
+++ branches/yt-object-serialization/yt/lagos/Profiles.py Tue Dec 30 06:31:31 2008
@@ -108,10 +108,10 @@
for field in fields:
f, w, u = self._bin_field(self._data_source, field, weight,
accumulation, self._args, check_cut = False)
- ub = na.where(u)
if weight:
- f[ub] /= w[ub]
+ f[u] /= w[u]
self[field] = f
+ self["myweight"] = w
self["UsedBins"] = u
def add_fields(self, fields, weight = "CellMassMsun", accumulation = False):
@@ -138,20 +138,19 @@
def __setitem__(self, key, value):
self._data[key] = value
- def _get_field(self, source, field, check_cut):
+ def _get_field(self, source, this_field, check_cut):
# This is where we will iterate to get all contributions to a field
# which is how we will implement hybrid particle/cell fields
# but... we default to just the field.
data = []
- for field in _field_mapping.get(field, (field,)):
+ for field in _field_mapping.get(this_field, (this_field,)):
+ pointI = None
if check_cut:
if field in self.pf.field_info \
and self.pf.field_info[field].particle_type:
pointI = self._data_source._get_particle_indices(source)
else:
pointI = self._data_source._get_point_indices(source)
- else:
- pointI = slice(None)
data.append(source[field][pointI].ravel().astype('float64'))
return na.concatenate(data, axis=0)
@@ -159,7 +158,8 @@
class BinnedProfile1D(BinnedProfile):
def __init__(self, data_source, n_bins, bin_field,
lower_bound, upper_bound,
- log_space = True, lazy_reader=False):
+ log_space = True, lazy_reader=False,
+ left_collect = False):
"""
A 'Profile' produces either a weighted (or unweighted) average or a
straight sum of a field in a bin defined by another field. In the case
@@ -174,6 +174,7 @@
BinnedProfile.__init__(self, data_source, lazy_reader)
self.bin_field = bin_field
self._x_log = log_space
+ self.left_collect = left_collect
# Get our bins
if log_space:
func = na.logspace
@@ -224,8 +225,11 @@
if source_data.size == 0: # Nothing for us here.
return
# Truncate at boundaries.
- mi = na.where( (source_data > self[self.bin_field].min())
- & (source_data < self[self.bin_field].max()))
+ if self.left_collect:
+ mi = na.where(source_data < self[self.bin_field].max())
+ else:
+ mi = na.where( (source_data > self[self.bin_field].min())
+ & (source_data < self[self.bin_field].max()))
sd = source_data[mi]
if sd.size == 0:
return
@@ -255,7 +259,7 @@
def __init__(self, data_source,
x_n_bins, x_bin_field, x_lower_bound, x_upper_bound, x_log,
y_n_bins, y_bin_field, y_lower_bound, y_upper_bound, y_log,
- lazy_reader=False):
+ lazy_reader=False, left_collect=False):
"""
A 'Profile' produces either a weighted (or unweighted) average or a
straight sum of a field in a bin defined by two other fields. In the case
@@ -274,6 +278,7 @@
self.y_bin_field = y_bin_field
self._x_log = x_log
self._y_log = y_log
+ self.left_collect = left_collect
if x_log: self[x_bin_field] = na.logspace(na.log10(x_lower_bound*0.99),
na.log10(x_upper_bound*1.01),
x_n_bins)
@@ -330,10 +335,14 @@
source_data_y = self._get_field(source, self.y_bin_field, check_cut)
if source_data_x.size == 0:
return
- mi = na.where( (source_data_x > self[self.x_bin_field].min())
- & (source_data_x < self[self.x_bin_field].max())
- & (source_data_y > self[self.y_bin_field].min())
- & (source_data_y < self[self.y_bin_field].max()))
+ if self.left_collect:
+ mi = na.where( (source_data_x < self[self.x_bin_field].max())
+ & (source_data_y < self[self.y_bin_field].max()))
+ else:
+ mi = na.where( (source_data_x > self[self.x_bin_field].min())
+ & (source_data_x < self[self.x_bin_field].max())
+ & (source_data_y > self[self.y_bin_field].min())
+ & (source_data_y < self[self.y_bin_field].max()))
sd_x = source_data_x[mi]
sd_y = source_data_y[mi]
if sd_x.size == 0 or sd_y.size == 0:
Modified: branches/yt-object-serialization/yt/lagos/hop/SS_HopOutput.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/hop/SS_HopOutput.py (original)
+++ branches/yt-object-serialization/yt/lagos/hop/SS_HopOutput.py Tue Dec 30 06:31:31 2008
@@ -112,7 +112,10 @@
"""
Write out standard HOP information to *filename*.
"""
- f = open(filename,"w")
+ if hasattr(filename, 'write'):
+ f = filename
+ else:
+ f = open(filename,"w")
f.write("\t".join(["# Group","Mass","# part","max dens"
"x","y","z", "center-of-mass",
"x","y","z",
@@ -120,7 +123,7 @@
for group in self:
f.write("%10i\t" % group.id)
f.write("%0.9e\t" % group.total_mass())
- f.write("%10i\t" % group.indices.size)
+ f.write("%10i\t" % group.get_size())
f.write("%0.9e\t" % group.maximum_density())
f.write("\t".join(["%0.9e" % v for v in group.maximum_density_location()]))
f.write("\t")
@@ -130,6 +133,7 @@
f.write("\t")
f.write("%0.9e\t" % group.maximum_radius())
f.write("\n")
+ f.flush()
f.close()
class HopIterator(object):
@@ -149,7 +153,9 @@
"""
__metaclass__ = ParallelDummy # This will proxy up our methods
_distributed = False
- _owned = True
+ _processing = False
+ _owner = 0
+ indices = None
dont_wrap = ["get_sphere"]
def __init__(self, hop_output, id, indices = None):
@@ -159,6 +165,7 @@
if indices is not None: self.indices = hop_output._base_indices[indices]
# We assume that if indices = None, the instantiator has OTHER plans
# for us -- i.e., setting it somehow else
+
def center_of_mass(self):
"""
Calculate and return the center of mass.
@@ -234,6 +241,9 @@
center, radius=radius)
return sphere
+ def get_size(self):
+ return self.indices.size
+
class HaloFinder(HopList, ParallelAnalysisInterface):
def __init__(self, pf, threshold=160.0, dm_only=True):
self.pf = pf
@@ -253,6 +263,9 @@
self.bounds = (LE, RE)
# reflect particles around the periodic boundary
self._reposition_particles((LE, RE))
+ self.data_source.get_data(["ParticleMassMsun"] +
+ ["particle_velocity_%s" % ax for ax in 'xyz'] +
+ ["particle_position_%s" % ax for ax in 'xyz'])
# MJT: This is the point where HOP is run, and we have halos for every
# single sub-region
super(HaloFinder, self).__init__(self.data_source, threshold, dm_only)
@@ -262,7 +275,6 @@
def _parse_hoplist(self):
groups, max_dens, hi = [], {}, 0
LE, RE = self.bounds
- print LE, RE
for halo in self._groups:
this_max_dens = halo.maximum_density_location()
# if the most dense particle is in the box, keep it
@@ -271,11 +283,10 @@
# self.hop_list
# We need to mock up the HopList thingie, so we need to set:
# self._max_dens
- #
max_dens[hi] = self._max_dens[halo.id]
groups.append(HopGroup(self, hi))
groups[-1].indices = halo.indices
- groups[-1]._owned = True
+ self._claim_object(groups[-1])
hi += 1
del self._groups, self._max_dens # explicit >> implicit
self._groups = groups
@@ -299,19 +310,28 @@
# sort the list by the size of the groups
# Now we add ghost halos and reassign all the IDs
# Note: we already know which halos we own!
- after = nhalos - (my_first_id + len(self._groups))
+ after = my_first_id + len(self._groups)
# One single fake halo, not owned, does the trick
- fake_halo = HopGroup(self, 0)
- fake_halo._owned = False
- self._groups = [fake_halo] * my_first_id + \
+ self._groups = [HopGroup(self, i) for i in range(my_first_id)] + \
self._groups + \
- [fake_halo] * after
+ [HopGroup(self, i) for i in range(after, nhalos)]
# MJT: Sorting doesn't work yet. They need to be sorted.
#haloes.sort(lambda x, y: cmp(len(x.indices),len(y.indices)))
# Unfortunately, we can't sort *just yet*.
+ id = 0
+ for proc in sorted(halo_info.keys()):
+ for halo in self._groups[id:id+halo_info[proc]]:
+ halo.id = id
+ halo._distributed = True
+ halo._owner = proc
+ id += 1
+ self._groups.sort(key = lambda h: -1 * h.get_size())
+ sorted_max_dens = {}
for i, halo in enumerate(self._groups):
- self._distributed = True
+ if halo.id in self._max_dens:
+ sorted_max_dens[i] = self._max_dens[halo.id]
halo.id = i
+ self._max_dens = sorted_max_dens
def _reposition_particles(self, bounds):
# This only does periodicity. We do NOT want to deal with anything
@@ -322,3 +342,7 @@
arr = self.data_source["particle_position_%s" % ax]
arr[arr < LE[i]-self.padding] += dw[i]
arr[arr > RE[i]+self.padding] -= dw[i]
+
+ def write_out(self, filename):
+ f = self._write_on_root(filename)
+ HopList.write_out(self, f)
Modified: branches/yt-object-serialization/yt/mods.py
==============================================================================
--- branches/yt-object-serialization/yt/mods.py (original)
+++ branches/yt-object-serialization/yt/mods.py Tue Dec 30 06:31:31 2008
@@ -50,11 +50,10 @@
fieldInfo = EnzoFieldInfo
# Now individual component imports from raven
-from yt.raven import PlotCollection, PlotCollectionInteractive, \
- QuiverCallback, ParticleCallback, ContourCallback, \
- GridBoundaryCallback, UnitBoundaryCallback, \
- LinePlotCallback, CuttingQuiverCallback, ClumpContourCallback, \
- HopCircleCallback
+from yt.raven import PlotCollection, PlotCollectionInteractive
+from yt.raven.Callbacks import callback_registry
+for name, cls in callback_registry:
+ exec("from yt.raven import %s" % name)
# Optional component imports from raven
try:
Modified: branches/yt-object-serialization/yt/raven/Callbacks.py
==============================================================================
--- branches/yt-object-serialization/yt/raven/Callbacks.py (original)
+++ branches/yt-object-serialization/yt/raven/Callbacks.py Tue Dec 30 06:31:31 2008
@@ -32,7 +32,14 @@
import _MPL
+callback_registry = []
+
class PlotCallback(object):
+ class __metaclass__(type):
+ def __init__(cls, name, b, d):
+ type.__init__(name, b, d)
+ callback_registry.append((name, cls))
+
def __init__(self, *args, **kwargs):
pass
@@ -226,30 +233,30 @@
yy0, yy1 = plot._axes.get_ylim()
dx = (xx1-xx0)/(x1-x0)
dy = (yy1-yy0)/(y1-y0)
- GLE = plot.data.gridLeftEdge
- GRE = plot.data.gridRightEdge
px_index = lagos.x_dict[plot.data.axis]
py_index = lagos.y_dict[plot.data.axis]
- left_edge_px = na.maximum((GLE[:,px_index]-x0)*dx, xx0)
- left_edge_py = na.maximum((GLE[:,py_index]-y0)*dy, yy0)
- right_edge_px = na.minimum((GRE[:,px_index]-x0)*dx, xx1)
- right_edge_py = na.minimum((GRE[:,py_index]-y0)*dy, yy1)
- print left_edge_px.min(), left_edge_px.max(), \
- right_edge_px.min(), right_edge_px.max(), \
- x0, x1, y0, y1
- verts = na.array(
- [(left_edge_px, left_edge_px, right_edge_px, right_edge_px),
- (left_edge_py, right_edge_py, right_edge_py, left_edge_py)])
- visible = ( right_edge_px - left_edge_px > self.min_pix ) & \
- ( right_edge_px - left_edge_px > self.min_pix )
- verts=verts.transpose()[visible,:,:]
- edgecolors = (0.0,0.0,0.0,self.alpha)
- grid_collection = matplotlib.collections.PolyCollection(
- verts, facecolors=(0.0,0.0,0.0,0.0),
- edgecolors=edgecolors)
- plot._axes.hold(True)
- plot._axes.add_collection(grid_collection)
- plot._axes.hold(False)
+ dom = plot.data.pf["DomainRightEdge"] - plot.data.pf["DomainLeftEdge"]
+ for px_off, py_off in na.mgrid[-1:1:3j,-1:1:3j]:
+ GLE = plot.data.gridLeftEdge + px_off * dom[px_index]
+ GRE = plot.data.gridRightEdge + py_off * dom[py_index]
+ left_edge_px = na.maximum((GLE[:,px_index]-x0)*dx, xx0)
+ left_edge_py = na.maximum((GLE[:,py_index]-y0)*dy, yy0)
+ right_edge_px = na.minimum((GRE[:,px_index]-x0)*dx, xx1)
+ right_edge_py = na.minimum((GRE[:,py_index]-y0)*dy, yy1)
+ verts = na.array(
+ [(left_edge_px, left_edge_px, right_edge_px, right_edge_px),
+ (left_edge_py, right_edge_py, right_edge_py, left_edge_py)])
+ visible = ( right_edge_px - left_edge_px > self.min_pix ) & \
+ ( right_edge_px - left_edge_px > self.min_pix )
+ verts=verts.transpose()[visible,:,:]
+ if verts.size == 0: continue
+ edgecolors = (0.0,0.0,0.0,self.alpha)
+ grid_collection = matplotlib.collections.PolyCollection(
+ verts, facecolors=(0.0,0.0,0.0,0.0),
+ edgecolors=edgecolors)
+ plot._axes.hold(True)
+ plot._axes.add_collection(grid_collection)
+ plot._axes.hold(False)
class LabelCallback(PlotCallback):
def __init__(self, label):
@@ -507,7 +514,7 @@
plot._axes.add_patch(cir)
if self.annotate:
if self.print_halo_size:
- plot._axes.text(center_x, center_y, "%s" % len(halo.indices),
+ plot._axes.text(center_x, center_y, "%s" % halo.get_size(),
fontsize=self.font_size)
else:
plot._axes.text(center_x, center_y, "%s" % halo.id,
More information about the yt-svn mailing list