Source code for opennode.oms.endpoint.ssh.protocol

import fnmatch
import itertools
import os
import re
import sys
import traceback

from grokcore.component import Subscription, implements, context
from twisted.conch.insults.insults import ServerProtocol
from twisted.internet import defer
from twisted.python import log
from zope.security.interfaces import ForbiddenAttribute, Unauthorized

from opennode.oms.config import get_config
from opennode.oms.endpoint.ssh import cmdline
from opennode.oms.endpoint.ssh.cmd import registry, completion, commands
from opennode.oms.endpoint.ssh.colored_columnize import columnize
from opennode.oms.endpoint.ssh.terminal import InteractiveTerminal, BLUE, CYAN, GREEN, CTRL_C
from opennode.oms.endpoint.ssh.tokenizer import CommandLineTokenizer, CommandLineSyntaxError
from opennode.oms.model.model.base import IContainer
from opennode.oms.model.model.bin import ICommand
from opennode.oms.model.model.proc import Proc
from opennode.oms.security.interaction import new_interaction
from opennode.oms.zodb import db
from opennode.oms.zodb.extractors import IContextExtractor


def protocolInlineCallbacks(fun):
    """Executes protocol async callbacks while buffering next keystrokes
    during the execution of the callback.

    After the callback finishes the buffered keystrokes will be replayed back,
    but currently only non-special characters are replayed, since special
    characters could trigger a reentrant here which would inject keystrokes
    out of order.

    `fun` is a generator function taking the protocol instance as its first
    argument; it is run through `defer.inlineCallbacks`. Exceptions raised by
    it are logged and swallowed, so the returned wrapper's deferred does not
    fail because of them.

    Keystrokes are only buffered (and later replayed) when no sub protocol was
    active at call time. If one was already installed, it keeps receiving the
    keystrokes itself and nothing is replayed.

    """
    # Local import keeps this block self-contained.
    import functools

    @functools.wraps(fun)
    @defer.inlineCallbacks
    def wrapper(self, *args, **kwargs):
        old_sub_protocol = self.sub_protocol

        # Remember the buffer object we install ourselves. Reading
        # `self.sub_protocol.buffer` afterwards is unsafe: the attribute may
        # by then point to a foreign sub protocol, or to None.
        key_buffer = None
        if not old_sub_protocol:
            key_buffer = CallbackExecutionSubProtocol()
            self.sub_protocol = key_buffer

        try:
            yield defer.inlineCallbacks(fun)(self, *args, **kwargs)
        except Exception as e:
            log.msg("got exception while %s: %s" % (fun, e), system='protocol')
            if get_config().getboolean('debug', 'print_exceptions'):
                log.err(system='protocol')

        self.sub_protocol = old_sub_protocol

        if key_buffer is not None:
            for key, mod in key_buffer.buffer:
                # Keys that have a dedicated handler are dropped on replay.
                if key not in self.keyHandlers:
                    self.keystrokeReceived(key, mod)

    return wrapper
class CallbackExecutionSubProtocol(object):
    """Sub protocol that does nothing but record keystrokes.

    Each keystroke is stored in `buffer`, in arrival order, as a
    `(keyID, mod)` tuple.

    """

    def __init__(self):
        self.buffer = list()

    def keystrokeReceived(self, keyID, mod):
        """Store the keystroke and its modifier at the end of `buffer`."""
        keystroke = (keyID, mod)
        self.buffer.append(keystroke)
class OmsShellProtocol(InteractiveTerminal):
    """The OMS virtual console over SSH.

    Each received line is split on ';' into commands. Every command is
    resolved through the directories listed in the `PATH` entry of
    `environment` (falling back to the command registry), its arguments are
    tokenized and glob-expanded, and it is then executed.

    While a command runs, a `CommandExecutionSubProtocol` is installed as
    `sub_protocol` and receives the keystrokes. Whatever it buffered is
    replayed once the line has completed.

    """

    def __init__(self):
        super(OmsShellProtocol, self).__init__()
        self.path = ['']
        self.last_error = None
        self.environment = {'PATH': '.:./actions:/bin'}
        self.path_stack = []
        self.sub_protocol = None
        self.principal = None
        self.use_security_proxy = get_config().getboolean('auth', 'security_proxy_omsh')

        @defer.inlineCallbacks
        def _get_obj_path():
            self.obj_path = yield db.ro_transact(lambda: [db.ref(db.get_root()['oms_root'])])()

        # Started right away; `ensure_initialized` waits on this deferred.
        self.get_object_path_deferred = _get_obj_path()

        self.tokenizer = CommandLineTokenizer()

    @defer.inlineCallbacks
    def ensure_initialized(self):
        """Waits for the object path lookup started in `__init__`."""
        yield self.get_object_path_deferred

    def logged_in(self, principal):
        """Invoked when the principal which opened this session is known"""
        self.principal = principal
        self.interaction = new_interaction(principal.id)
        self.tid = Proc.register(None, self, '/bin/omsh', principal=principal)
        self.terminalSizeAfterLogin()

    def connectionMade(self):
        super(OmsShellProtocol, self).connectionMade()

    def close_connection(self):
        """Unregisters the `Proc` entry created in `logged_in`, then closes."""
        # NOTE(review): `self.tid` is only assigned in `logged_in`; presumably
        # this is never reached before login -- confirm.
        Proc.unregister(self.tid)
        super(OmsShellProtocol, self).close_connection()

    def dataReceived(self, data):
        # some sub protocols need raw data, because `keystrokeReceived`
        # reinterprets all special chars (like arrows etc) and there is no way
        # to get back to the original escape sequences.
        if self.sub_protocol and hasattr(self.sub_protocol, 'dataReceived'):
            return self.sub_protocol.dataReceived(data)

        self.terminal._orig_dataReceived(data)

    def keystrokeReceived(self, keyID, modifier):
        """Routes the keystroke to the active sub protocol, if any."""
        (self.sub_protocol or super(OmsShellProtocol, self)).keystrokeReceived(keyID, modifier)

    def exit_sub_protocol(self):
        """Drops the active sub protocol and prints the prompt."""
        self.sub_protocol = None
        self.print_prompt()

    @defer.inlineCallbacks
    def lineReceived(self, line):
        """Runs all commands in `line`; always finishes via `_command_completed`."""
        try:
            yield self.spawn_commands(line)
        finally:
            self._command_completed()

    @defer.inlineCallbacks
    def spawn_commands(self, line):
        """Runs the ';'-separated commands of `line` one after another."""
        yield self.ensure_initialized()

        # XXX: handle ; chars in quotes and comments
        for command in line.split(';'):
            yield self.spawn_command(command)

    @defer.inlineCallbacks
    def spawn_command(self, line):
        """Parses and executes a single command line.

        Parse and execution errors are reported on the terminal rather than
        propagated. The exception info of an unexpected execution error is
        kept in `last_error` as `(line, sys.exc_info())`.

        """
        line = line.strip()

        try:
            command, cmd_args = yield self.parse_line(line)
        except CommandLineSyntaxError as e:
            self.terminal.write("Syntax error: %s\n" % (e.message))
            # NOTE(review): `lineReceived` calls `_command_completed`
            # afterwards, which prints the prompt as well -- the prompt may
            # show up twice on this path; confirm before changing.
            self.print_prompt()
            return
        except Exception:
            log.msg("Got exception parsing '%s'" % (line), system='protocol')
            self.terminal.write(''.join(traceback.format_exception(*sys.exc_info())))
            return

        try:
            # Keystrokes typed while the command runs go to this sub protocol.
            self.sub_protocol = CommandExecutionSubProtocol(self)

            deferred = defer.Deferred()
            yield command.register(deferred, cmd_args, line, self.tid)
            cmdd = defer.maybeDeferred(command, *cmd_args)
            cmdd.chainDeferred(deferred)
            yield deferred
        except cmdline.ArgumentParsingError:
            return
        except Unauthorized as e:
            msg = e
            log.err(system='ssh')
            if isinstance(e.message, tuple) and len(e.message) == 3:
                msg = "accessing %s's attribute '%s' requires @%s right" % e.message
            self.terminal.write("Permission denied: %s\n" % msg)
        except Exception as e:
            self.last_error = (line, sys.exc_info())
            log.msg("Got exception executing '%s': %s" % self.last_error, system='protocol')
            if get_config().getboolean('debug', 'print_exceptions'):
                traceback.print_tb(self.last_error[1][2])
            self.terminal.write("Command returned an unhandled error: %s\n" % e)
            self.terminal.write("type last_error for more details\n")

    def _command_completed(self, *args):
        """Prints the prompt and replays keystrokes buffered by the sub protocol.

        Keys that have an entry in `keyHandlers` are not replayed.

        """
        self.print_prompt()
        if self.sub_protocol:
            # Tolerate sub protocols that do not keep a keystroke buffer.
            buffered = getattr(self.sub_protocol, 'buffer', None)
            self.sub_protocol = None
            for key, mod in buffered or ():
                if key not in self.keyHandlers:
                    self.keystrokeReceived(key, mod)

    @db.ro_transact
    def parse_line(self, line):
        """Returns a command instance and parsed cmdline argument list.

        The first space-separated word of `line` names the command. The rest
        is tokenized and each token is passed through `expand`.

        TODO: Shell expansion should be handled here.

        """
        cmd_name, cmd_args = line.partition(' ')[::2]
        command_cls = self.get_command_class(cmd_name)
        command = command_cls(self)

        tokenized_cmd_args = self.expand(command, self.tokenizer.tokenize(cmd_args.strip()))

        return command, tokenized_cmd_args

    def get_command_class(self, name):
        """Looks `name` up in each `PATH` directory, then in the registry.

        Returns the `cmd` attribute of the first traversed object that
        provides `ICommand`; otherwise whatever `registry.get_command(name)`
        returns.

        """
        # NOTE: used to leverage the 'traverse()' method which takes into consideration
        # path handling quirks for relative paths
        dummy = commands.NoCommand(self)

        for d in self.environment['PATH'].split(':'):
            effective_dir = name if os.path.isabs(name) else os.path.join(d, name)
            try:
                command = dummy.traverse(effective_dir)
                if ICommand.providedBy(command):
                    return command.cmd
            except ForbiddenAttribute:
                # skip command paths where we don't have access
                pass

        # NOTE: retained temporarily because it contains inner class
        return registry.get_command(name)

    def expand(self, command, tokens):
        """Returns the flat list of all tokens after glob expansion."""
        return list(itertools.chain.from_iterable([self.expand_token(command, i) for i in tokens]))

    def expand_token(self, command, token):
        """Expands one token containing `*`, `[` or `]` in its last component.

        The directory part is traversed relative to `command`. If it resolves
        to a container, the names it lists are matched with `fnmatch` against
        the last path component. Returns the list of matching paths, or
        `[token]` when the token is not a pattern or nothing matched.

        """
        pattern = os.path.basename(token)
        if re.match(r'.*[*[\]].*', pattern):
            base = os.path.dirname(token)

            current_obj = command.traverse(base)

            # Only if intermediate path resolves.
            if current_obj and IContainer.providedBy(current_obj):
                filtered = [os.path.join(base, i) for i in
                            fnmatch.filter(current_obj.listnames(), pattern)]
                # Mimic Bash behavior: if expansion doesn't provide results then pass the glob pattern to
                # the command.
                if filtered:
                    return filtered

        return [token]

    @protocolInlineCallbacks
    def handle_TAB(self):
        """Handles tab completion."""
        partial, rest, completions = yield completion.complete(self, self.lineBuffer, self.lineBufferIndex)

        if len(completions) == 1:
            space = '' if rest else ' '

            # handle quote closing
            # The character just before the word being completed. When the
            # word starts at column 0 there is no such character, and a
            # negative index must not wrap around to the end of the buffer.
            quote_pos = self.lineBufferIndex - len(partial) - 1
            if quote_pos >= 0 and self.lineBuffer[quote_pos] == '"':
                space = '" '

            # Avoid space after '=' just for aestetics.
            # Avoid space after '/' for functionality.
            for i in ('=', '/'):
                if completions[0].endswith(i):
                    space = ''

            patch = completions[0][len(partial):] + space

            # Drop @, *, half hack
            for i in ('@', '*'):
                if patch.endswith(i + ' '):
                    patch = patch.rstrip(i + ' ') + ' '

            self.insert_text(patch)
        elif len(completions) > 1:
            common_prefix = os.path.commonprefix(completions)
            patch = common_prefix[len(partial):]
            self.insert_text(patch)

            # postpone showing list of possible completions until next tab
            if not patch:
                self.terminal.nextLine()
                _, _, completions = yield completion.complete(self, self.lineBuffer,
                                                              self.lineBufferIndex, display=True)

                # reorder optional values at end for readability
                required = []
                optional = []
                for comp in completions:
                    (optional if comp.startswith('[') else required).append(comp)
                completions = required + optional

                completions = [self.colorize(self._completion_color(item), item) for item in completions]
                self.terminal.write(columnize(completions, self.width))
                self.drawInputLine()
                if rest:
                    self.terminal.cursorBackward(len(rest))

    def _completion_color(self, completion):
        """Returns the color for a completion, chosen by its last character.

        '/' maps to BLUE, '@' to CYAN, '*' to GREEN; anything else to None.

        """
        if completion.endswith('/'):
            return BLUE
        elif completion.endswith('@'):
            return CYAN
        elif completion.endswith('*'):
            return GREEN
        else:
            return None

    @property
    def hist_file_name(self):
        """Path of `~/.oms_history` with the home directory expanded."""
        return os.path.expanduser('~/.oms_history')

    @property
    def ps(self):
        """Prompt strings as a two-item list.

        The first is built from the principal id (or 'user' when no principal
        is set), the literal host 'oms', the current path and '#'. The second
        is '... ' (presumably the continuation prompt).

        """
        user = self.principal.id if self.principal else 'user'
        ps1 = '%s@%s:%s%s ' % (user, 'oms', self._cwd(), '#')
        return [ps1, '... ']

    def _cwd(self):
        """Returns the current path as a string."""
        return self.make_path(self.path)

    @staticmethod
    def make_path(path):
        """Joins path components with '/'; an empty result becomes '/'."""
        return '/'.join(path) or '/'

    def handle_EOF(self):
        """Routes EOF to the active sub protocol, if any."""
        (self.sub_protocol or super(OmsShellProtocol, self)).handle_EOF()
class CommandExecutionSubProtocol(object):
    """Sub protocol active while a shell command is executing.

    It echoes what the user types, stores the keystrokes in `buffer` as
    `(keyID, mod)` tuples, and on CTRL_C asks the parent protocol to leave
    the sub protocol.

    """

    def __init__(self, parent):
        self.parent = parent
        self.buffer = []

    def handle_EOF(self):
        """EOF is ignored while a command is running."""
        pass

    def _echo(self, keyID, mod):
        """Echoes characters on terminal like on unix (special chars etc)"""
        # Only string keystrokes are written out. Anything that is not a
        # `str` (presumably the terminal's special-key objects) has no textual
        # form to echo.
        if isinstance(keyID, str):
            ch = keyID
            # `ord` only accepts single characters.
            if len(keyID) == 1:
                code = ord(keyID)
                if code == 127:
                    ch = '^H'
                if code < 32 and keyID != '\r':
                    # Caret notation for control characters, e.g. 0x03 -> ^C.
                    ch = '^' + chr(ord('A') + code - 1)
            self.parent.terminal.write(ch)

        if keyID in ('\r', CTRL_C):
            self.parent.terminal.write('\n')

    def keystrokeReceived(self, keyID, mod):
        """Echoes the keystroke, then either interrupts or buffers it."""
        self._echo(keyID, mod)

        # HACK: poor man's interrupt
        if keyID == CTRL_C:
            return self.parent.exit_sub_protocol()

        self.buffer.append((keyID, mod))
class ProtocolContextExtractor(Subscription):
    """`IContextExtractor` subscription registered for `OmsShellProtocol`."""

    implements(IContextExtractor)
    context(OmsShellProtocol)

    def get_context(self):
        """Returns a dict holding the protocol's interaction under 'interaction'."""
        interaction = self.context.interaction
        return {'interaction': interaction}


# HACK: Monkey patch
# TODO: handle this with custom ServerProtocol
def dataReceived(self, data):
    """Replacement for `ServerProtocol.dataReceived`.

    Hands the raw data to the terminal protocol's own `dataReceived`, giving
    it the chance to see the data before any key interpretation happens.
    When the terminal protocol does not define `dataReceived`, the data goes
    to the saved original implementation (`_orig_dataReceived`). The
    replacement is installed on the class, so it must keep working for such
    terminal protocols too.

    """
    handler = getattr(self.terminalProtocol, 'dataReceived', None)
    if handler is None:
        return self._orig_dataReceived(data)
    return handler(data)
# Save the stock implementation only once. If this module is executed a second
# time (e.g. reloaded), `ServerProtocol.dataReceived` is already the
# replacement, and saving it again would make `_orig_dataReceived` call itself
# endlessly.
if not hasattr(ServerProtocol, '_orig_dataReceived'):
    ServerProtocol._orig_dataReceived = ServerProtocol.dataReceived
ServerProtocol.dataReceived = dataReceived

This Page