diff --git a/tests/fixtures/self_assert/assertEqual_in.py b/tests/fixtures/self_assert/assertEqual_in.py index eaace0e..bfe2bfc 100644 --- a/tests/fixtures/self_assert/assertEqual_in.py +++ b/tests/fixtures/self_assert/assertEqual_in.py @@ -32,6 +32,11 @@ def test_line_wrapping(self): msg='Wrap %s' % 'everything') + def test_implicit_string_with_comment(self): + self.assertEqual('a' + # keep this comment + 'b', actual) + def test_expression_as_argument(self): self.assertEqual(abc not in self.data, True) self.assertEqual(abc in self.data, not contains) diff --git a/tests/fixtures/self_assert/assertEqual_out.py b/tests/fixtures/self_assert/assertEqual_out.py index c197544..bfa18ca 100644 --- a/tests/fixtures/self_assert/assertEqual_out.py +++ b/tests/fixtures/self_assert/assertEqual_out.py @@ -30,6 +30,11 @@ def test_line_wrapping(self): 'Wrap %s' % \ 'everything' + def test_implicit_string_with_comment(self): + assert ('a' + # keep this comment + 'b' == actual) + def test_expression_as_argument(self): assert (abc not in self.data) == True assert (abc in self.data) == (not contains) diff --git a/unittest2pytest/fixes/fix_self_assert.py b/unittest2pytest/fixes/fix_self_assert.py index e4c9992..d6f4b75 100644 --- a/unittest2pytest/fixes/fix_self_assert.py +++ b/unittest2pytest/fixes/fix_self_assert.py @@ -100,6 +100,21 @@ def parenthesize_expression(value): return value +def has_comment_prefix(value): + return any("#" in leaf.prefix for leaf in value.leaves()) + + +def parenthesize_assert_expression_with_comments(assert_stmt): + if assert_stmt.type != syms.assert_stmt: + return + expression = assert_stmt.children[1] + if has_comment_prefix(expression): + parenthesized = parenthesize(expression.clone()) + parenthesized.prefix = expression.prefix + parenthesized.children[1].prefix = "" + assert_stmt.set_child(1, parenthesized) + + def fill_template(template, *args): parts = TEMPLATE_PATTERN.findall(template) kids = [] @@ -476,6 +491,8 @@ def process_arg(arg): if argsdict.get("msg", None) is not None and method != "fail": n_stmt.children.extend((Name(","), argsdict["msg"])) + parenthesize_assert_expression_with_comments(n_stmt) + def fix_line_wrapping(x): for c in x.children: # no need to worry about wrapping of "[", "{" and "("