# Copyright (C) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for Python Fire's tests."""

import contextlib
import io
import os
import re
import sys
import unittest
from unittest import mock

from fire import core
from fire import trace


class BaseTestCase(unittest.TestCase):
  """Shared test case for Python Fire tests."""

  @contextlib.contextmanager
  def assertOutputMatches(self, stdout='.*', stderr='.*', capture=True):
    """Asserts that the context generates stdout and stderr matching regexps.

    While the wrapped code runs, sys.stdout and sys.stderr are replaced with
    in-memory buffers. After the wrapped code finishes, each buffer is matched
    against its regexp with re.search using re.DOTALL | re.MULTILINE.

    Note: If wrapped code raises an exception, stdout and stderr will not be
      checked; the exception propagates instead.

    Args:
      stdout: (str) regexp to match against stdout. None asserts that nothing
        was written to stdout.
      stderr: (str) regexp to match against stderr. None asserts that nothing
        was written to stderr.
      capture: (bool, default True) do not bubble up stdout or stderr. When
        False, the captured text is written to the real sys.stdout and
        sys.stderr once the wrapped code exits, even if it raised.

    Yields:
      Yields to the wrapped context.

    Raises:
      AssertionError: If output was produced where none was expected, or if
        the captured output does not match the corresponding regexp.
    """
    stdout_fp = io.StringIO()
    stderr_fp = io.StringIO()
    try:
      with mock.patch.object(sys, 'stdout', stdout_fp):
        with mock.patch.object(sys, 'stderr', stderr_fp):
          yield
    finally:
      # The patches have been undone by this point, so these writes go to the
      # streams that were installed before the context was entered.
      if not capture:
        sys.stdout.write(stdout_fp.getvalue())
        sys.stderr.write(stderr_fp.getvalue())

    # Deliberately outside the `finally`: this is only reached when the wrapped
    # code did not raise, as described in the docstring note.
    for name, regexp, fp in [('stdout', stdout, stdout_fp),
                             ('stderr', stderr, stderr_fp)]:
      value = fp.getvalue()
      if regexp is None:
        if value:
          raise AssertionError('%s: Expected no output. Got: %r' %
                               (name, value))
      else:
        if not re.search(regexp, value, re.DOTALL | re.MULTILINE):
          raise AssertionError('%s: Expected %r to match %r' %
                               (name, value, regexp))

  @contextlib.contextmanager
  def assertRaisesFireExit(self, code, regexp='.*'):
    """Asserts that a FireExit error is raised in the context.

    Allows tests to check that Fire's wrapper around SystemExit is raised
    and that a regexp is matched in the output.

    Args:
      code: The status code that the FireExit should contain.
      regexp: stderr must match this regex.

    Yields:
      Yields to the wrapped context.

    Raises:
      AssertionError: If no FireExit is raised, if its code differs from
        `code`, if its trace is not a FireTrace, or if stderr does not match
        `regexp`.
    """
    with self.assertOutputMatches(stderr=regexp):
      with self.assertRaises(core.FireExit):
        try:
          yield
        except core.FireExit as exc:
          if exc.code != code:
            raise AssertionError('Incorrect exit code: %r != %r' %
                                 (exc.code, code))
          self.assertIsInstance(exc.trace, trace.FireTrace)
          # Re-raise so the enclosing assertRaises sees the FireExit.
          raise


@contextlib.contextmanager
def ChangeDirectory(directory):
  """Context manager that changes the working directory, reverting on exit.

  Args:
    directory: The directory to switch to for the duration of the context.

  Yields:
    The `directory` argument that was passed in.
  """
  cwdir = os.getcwd()
  os.chdir(directory)

  try:
    yield directory
  finally:
    # Restore the previous working directory even if the wrapped code raised.
    os.chdir(cwdir)


# Re-export the unittest entry points so test modules can use them from here.
# pylint: disable=invalid-name
main = unittest.main
skip = unittest.skip
skipIf = unittest.skipIf
# pylint: enable=invalid-name