diff --git a/tests/rules/test_git_checkout.py b/tests/rules/test_git_checkout.py index e4381cd..13212d7 100644 --- a/tests/rules/test_git_checkout.py +++ b/tests/rules/test_git_checkout.py @@ -1,5 +1,6 @@ import pytest -from thefuck.rules.git_checkout import match, get_new_command +from io import BytesIO +from thefuck.rules.git_checkout import match, get_branches, get_new_command from tests.utils import Command @@ -13,8 +14,10 @@ def did_not_match(target, did_you_forget=False): @pytest.fixture -def get_branches(mocker): - return mocker.patch('thefuck.rules.git_checkout.get_branches') +def git_branch(mocker, branches): + mock = mocker.patch('subprocess.Popen') + mock.return_value.stdout = BytesIO(branches) + return mock @pytest.mark.parametrize('command', [ @@ -33,21 +36,34 @@ def test_not_match(command): assert not match(command) +@pytest.mark.parametrize('branches, branch_list', [ + (b'', []), + (b'* master', ['master']), + (b' remotes/origin/master', ['master']), + (b' just-another-branch', ['just-another-branch']), + (b'* master\n just-another-branch', ['master', 'just-another-branch']), + (b'* master\n remotes/origin/master\n just-another-branch', + ['master', 'master', 'just-another-branch'])]) +def test_get_branches(branches, branch_list, git_branch): + git_branch(branches) + assert list(get_branches()) == branch_list + + @pytest.mark.parametrize('branches, command, new_command', [ - ([], + (b'', Command(script='git checkout unknown', stderr=did_not_match('unknown')), 'git branch unknown && git checkout unknown'), - ([], + (b'', Command('git commit unknown', stderr=did_not_match('unknown')), 'git branch unknown && git commit unknown'), - (['test-random-branch-123'], + (b' test-random-branch-123', Command(script='git checkout tst-rdm-brnch-123', stderr=did_not_match('tst-rdm-brnch-123')), 'git checkout test-random-branch-123'), - (['test-random-branch-123'], + (b' test-random-branch-123', Command(script='git commit tst-rdm-brnch-123', stderr=did_not_match('tst-rdm-brnch-123')), 'git commit test-random-branch-123')]) -def test_get_new_command(branches, command, new_command, get_branches): - get_branches.return_value = branches +def test_get_new_command(branches, command, new_command, git_branch): + git_branch(branches) assert get_new_command(command) == new_command