Skip to content

feat(api): New MaskString#1709

Merged
Fizzadar merged 5 commits into
pyinfra-dev:3.xfrom
gwelch-contegix:maskstring
May 28, 2026
Merged

feat(api): New MaskString#1709
Fizzadar merged 5 commits into
pyinfra-dev:3.xfrom
gwelch-contegix:maskstring

Conversation

@gwelch-contegix

Copy link
Copy Markdown
Contributor

This implements a new MaskString type which preserves the masking of the string in most contexts.
This does not attempt to fix an un-masked string showing up in 'diffed' output.

This makes it so that when MaskString is used in an operation argument it will show up masked instead of with the real value.
This also updates the masked value to be *MASKED* and updates all current usage to output valid commands.
Previously the unquoted *** value would show up in commands during verbose output with no context as to why it was there.

This MaskString implementation attemps to preserve normal string operations as much as possible and when it can't it only operates on the masked value (by default *MASKED*) see the tests for how it interacts with the str class on standard operations.

As this would be a breaking change for custom usage of MaskString it would may make sense to add this as another name and deprecate MaskString instead of replacing it directly.

  • Pull request is based on the default branch (3.x at this time)
  • Pull request includes tests for any new/updated operations/facts
  • Pull request includes documentation for any new/updated operations/facts
  • Tests pass (see scripts/dev-test.sh)
  • Type checking & code style passes (see scripts/dev-lint.sh)

Note that AI was used to generate part of the tests that's why some of the values chosen change between tests.

Fixes #1148

@wowi42 wowi42 added new feature API API mode specific issues. labels May 7, 2026

@Fizzadar Fizzadar left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments - I like the direction this is going in, much more concrete/safe implementation than just raw subclass.

As this would be a breaking change for custom usage of MaskString it would may make sense to add this as another name and deprecate MaskString instead of replacing it directly.

Agreed!

Comment thread src/pyinfra/api/maskstring.py Outdated
Comment thread src/pyinfra/api/maskstring.py Outdated
@gwelch-contegix

gwelch-contegix commented May 13, 2026

Copy link
Copy Markdown
Contributor Author

I have no actual preference on what it gets called so let me know if you have a particular name in mind.
I can also squash this down so it's just a single commit if that is wanted.

This actually works out fine if someone wants a full string implementation, they can use multiple inheritance and make their own

str, HiddenValue subclass
import sys
from pyinfra.api import HiddenValue

from typing_extensions import override


class MaskString(str, HiddenValue):
    """
    A string subclass that contains the equivalent of "*MASKED*"
    when used as a normal string.
    To retrieve the real value use .unmask()

    Most methods are copied from UserString
    """

    # Note that adding or otherwise modifying a MaskString (or subclass)
    # will create this base class as a base class could take for example a secret ID instead of a plain value to use

    def unmask(self) -> str:
        return self.raw_value  # type: ignore[attr-defined]

    @override
    def __new__(cls, content="", masked_value="*MASKED*"):
        # Create a new string object with the value "*MASKED*"
        s = super().__new__(cls, masked_value)
        # Real value is stored here so that only those aware of the type can get the real value
        s.raw_value = content  # type: ignore[attr-defined]
        return s

    def __init__(self, *args, **kwargs) -> None:
        ...

    @override
    def __hash__(self):
        # We want hashing to work correctly so we hash the unmasked string
        return hash(self.unmask())

    # Transparently allow operations with other MaskString
    # Use the masked value if it's not a MaskString
    @override
    def __eq__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            return self.unmask() == other.unmask()
        return super().__eq__(other)

    @override
    def __ne__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            return self.unmask() != other.unmask()
        return super().__ne__(other)

    @override
    def __lt__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            return self.unmask() < other.unmask()
        return super().__lt__(other)

    @override
    def __le__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            return self.unmask() <= other.unmask()
        return super().__le__(other)

    @override
    def __gt__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            return self.unmask() > other.unmask()
        return super().__gt__(other)

    @override
    def __ge__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            return self.unmask() >= other.unmask()
        return super().__ge__(other)

    @override
    def __contains__(self, other) -> bool:
        if isinstance(other, HiddenValue):
            other = other.unmask()
            return other in self.unmask()
        return super().__contains__(other)

    # Explicitly don't support len and index access to the real string
    # @override
    # def __len__(self) -> int:
    #     return len(self.unmask())
    # @override
    # def __getitem__(self, index) -> str:
    #     return MaskString(self.unmask()[index])

    # MaskString is viral if you do an operation with it it becomes a mask string.
    # This will probably break things like pathlib.Path
    @override
    def __add__(self, other) -> "MaskString":
        other_s = other
        if isinstance(other, HiddenValue):
            other_s = other.unmask()
            other = ""
        return MaskString(self.unmask() + other_s, masked_value=super().__add__(other))

    def __radd__(self, other) -> "MaskString":
        other_s = other
        if isinstance(other, HiddenValue):
            other_s = other.unmask()
            other = ""
        return MaskString(other_s + self.unmask(), masked_value=other + str(self))

    # strings don't multiply together. Let default exception propagate
    @override
    def __mul__(self, n) -> "MaskString":
        return MaskString(self.unmask() * n, masked_value=super().__mul__(n))

    __rmul__ = __mul__

    # % formatting will fail if there are unused format directives so don't allow formatting a MaskString
    # @override
    # def __mod__(self, args):
    #     return MaskString(self.unmask() % args)

    # % formatting will fail if there are unused format directives so don't allow formatting a MaskString
    # @override
    # def __rmod__(self, template):
    #     return MaskString(str(template) % self)

    @override
    # the following methods are defined in alphabetical order:
    def capitalize(self) -> "MaskString":
        return MaskString(self.unmask().capitalize(), masked_value=super().capitalize())

    @override
    def casefold(self) -> "MaskString":
        return MaskString(self.unmask().casefold(), masked_value=super().casefold())

    @override
    def center(self, width, *args) -> "MaskString":
        return MaskString(
            self.unmask().center(width, *args), masked_value=super().center(width, *args)
        )

    @override
    def count(self, sub, start=0, end=sys.maxsize) -> int:
        if isinstance(sub, HiddenValue):
            sub = sub.unmask()
            return self.unmask().count(sub, start, end)
        return super().count(sub, start, end)

    @override
    def removeprefix(self, prefix, /) -> "MaskString":
        prefix_s = prefix
        if isinstance(prefix, HiddenValue):
            prefix_s = prefix.unmask()
        return MaskString(
            self.unmask().removeprefix(prefix_s), masked_value=super().removeprefix(prefix)
        )

    @override
    def removesuffix(self, suffix, /) -> "MaskString":
        suffix_s = suffix
        if isinstance(suffix, HiddenValue):
            suffix_s = suffix.unmask()
        return MaskString(
            self.unmask().removesuffix(suffix_s), masked_value=super().removesuffix(suffix)
        )

    # @override
    # def encode(self, encoding="utf-8", errors="strict") -> bytes:
    #     encoding = "utf-8" if encoding is None else encoding
    #     errors = "strict" if errors is None else errors
    #     return super().encode(encoding, errors)

    @override
    def endswith(self, suffix, start=0, end=sys.maxsize) -> bool:
        if isinstance(suffix, HiddenValue):
            return self.unmask().endswith(suffix.unmask(), start, end)
        return super().endswith(suffix, start, end)

    @override
    def expandtabs(self, tabsize=8):
        return MaskString(
            self.unmask().expandtabs(tabsize), masked_value=super().expandtabs(tabsize)
        )

    @override
    def find(self, sub, start=0, end=sys.maxsize) -> int:
        if isinstance(sub, HiddenValue):
            return self.unmask().find(sub.unmask(), start, end)
        return super().find(sub, start, end)

    # format will fail if there are unused format directives so don't allow formatting a MaskString
    # @override
    # def format(self, /, *args, **kwds) -> str:
    #     unmask = kwds.pop("unmask", False)
    #     return str(self)

    # format will fail if there are unused format directives so don't allow formatting a MaskString
    # @override
    # def format_map(self, mapping) -> str:
    #     return str(self)

    @override
    def index(self, sub, start=0, end=sys.maxsize) -> int:
        if isinstance(sub, HiddenValue):
            return self.unmask().index(sub.unmask(), start, end)
        return super().index(sub, start, end)

    # Not sure if this should return the unmasked result or the masked result
    # @override
    # def isalpha(self) -> bool:
    #     return self.unmask().isalpha()

    # @override
    # def isalnum(self) -> bool:
    #     return self.unmask().isalnum()

    # @override
    # def isascii(self) -> bool:
    #     return self.unmask().isascii()

    # @override
    # def isdecimal(self) -> bool:
    #     return self.unmask().isdecimal()

    # @override
    # def isdigit(self) -> bool:
    #     return self.unmask().isdigit()

    # @override
    # def isidentifier(self) -> bool:
    #     return self.unmask().isidentifier()

    # @override
    # def islower(self) -> bool:
    #     return self.unmask().islower()

    # @override
    # def isnumeric(self) -> bool:
    #     return self.unmask().isnumeric()

    # @override
    # def isprintable(self) -> bool:
    #     return self.unmask().isprintable()

    # @override
    # def isspace(self) -> bool:
    #     return self.unmask().isspace()

    # @override
    # def istitle(self) -> bool:
    #     return self.unmask().istitle()

    # @override
    # def isupper(self) -> bool:
    #     return self.unmask().isupper()

    @override
    def join(self, seq) -> "MaskString":
        seq = list(seq)
        seq_s = [x.unmask() if isinstance(x, HiddenValue) else x for x in seq]
        # We just make the masked_value the default. Otherwise we end up with '*MASKED*' repeated a bunch which doesn't really help anything
        return MaskString(self.unmask().join(seq_s))

    @override
    def ljust(self, width, *args) -> "MaskString":
        return MaskString(
            self.unmask().ljust(width, *args), masked_value=super().ljust(width, *args)
        )

    @override
    def lower(self) -> "MaskString":
        return MaskString(self.unmask().lower(), masked_value=super().lower())

    @override
    def lstrip(self, chars=None) -> "MaskString":
        chars_s = chars
        if isinstance(chars, HiddenValue):
            # we want the resulting masked string to stay the same as otherwise it would most likely just strip all the characters
            chars_s = chars.unmask()
            return MaskString(self.unmask().lstrip(chars_s), masked_value=str(self))
        return MaskString(self.unmask().lstrip(chars), masked_value=super().lstrip(chars))

    # I don't know how maketrans works... default to using the masked value
    # maketrans = str.maketrans

    @override
    def partition(self, sep) -> tuple["MaskString", "MaskString", "MaskString"]:
        sep_s = sep
        if isinstance(sep, HiddenValue):
            sep_s = sep.unmask()

        s = self.unmask().partition(sep_s)
        u = super().partition(sep)
        return (
            MaskString(s[0], masked_value=u[0]),
            MaskString(s[1], masked_value=u[1]),
            MaskString(s[2], masked_value=u[2]),
        )

    @override
    def replace(self, old, new, maxsplit=-1) -> "MaskString":
        old_s = old
        new_s = new
        if isinstance(old, HiddenValue):
            old_s = old.unmask()
        if isinstance(new, HiddenValue):
            new_s = new.unmask()
        return MaskString(
            self.unmask().replace(old_s, new_s, maxsplit), masked_value=super().replace(old, new)
        )

    @override
    def rfind(self, sub, start=0, end=sys.maxsize) -> int:
        if isinstance(sub, HiddenValue):
            return self.unmask().rfind(sub.unmask(), start, end)
        return super().rfind(sub, start, end)

    @override
    def rindex(self, sub, start=0, end=sys.maxsize) -> int:
        if isinstance(sub, HiddenValue):
            return self.unmask().rindex(sub.unmask(), start, end)
        return super().rindex(sub, start, end)

    @override
    def rjust(self, width, *args) -> "MaskString":
        return MaskString(
            self.unmask().rjust(width, *args), masked_value=super().rjust(width, *args)
        )

    @override
    def rpartition(self, sep) -> tuple["MaskString", "MaskString", "MaskString"]:
        sep_s = sep
        if isinstance(sep, HiddenValue):
            sep_s = sep.unmask()

        s = self.unmask().rpartition(sep_s)
        u = super().rpartition(sep)
        return (
            MaskString(s[0], masked_value=u[0]),
            MaskString(s[1], masked_value=u[1]),
            MaskString(s[2], masked_value=u[2]),
        )

    @override
    def rstrip(self, chars=None) -> "MaskString":
        chars_s = chars
        if isinstance(chars, HiddenValue):
            # we want the resulting masked string to stay the same as otherwise it would most likely just strip all the characters
            chars_s = chars.unmask()
            return MaskString(self.unmask().rstrip(chars_s), masked_value=str(self))
        return MaskString(self.unmask().rstrip(chars_s), masked_value=super().rstrip(chars))

    # There is no way to keep the split functions "in-sync" with the masked string and the raw_value
    # @override
    # def split(self, sep=None, maxsplit=-1) -> list[MaskString]:
    #     if isinstance(sep, HiddenValue):
    #         sep = sep.unmask()
    #         return [MaskString(x) for x in self.unmask().split(sep, maxsplit)]
    #     return super().split(sep, maxsplit)

    # @override
    # def rsplit(self, sep=None, maxsplit=-1) -> list[MaskString]:
    #     if isinstance(sep, HiddenValue):
    #         sep = sep.unmask()
    #         return [MaskString(x) for x in self.unmask().rsplit(sep, maxsplit)]
    #     return super().rsplit(sep, maxsplit)

    # @override
    # def splitlines(self, keepends=False) -> list[MaskString]:
    #     return [MaskString(x) for x in self.unmask().splitlines(keepends)]

    @override
    def startswith(self, prefix, start=0, end=sys.maxsize) -> bool:
        if isinstance(prefix, HiddenValue):
            prefix = prefix.unmask()
            return self.unmask().startswith(prefix, start, end)
        return super().startswith(prefix, start, end)

    @override
    def strip(self, chars=None) -> "MaskString":
        chars_s = chars
        if isinstance(chars, HiddenValue):
            # we want the resulting masked string to stay the same as otherwise it would most likely just strip all the characters
            chars_s = chars.unmask()
            return MaskString(self.unmask().strip(chars_s), masked_value=str(self))
        return MaskString(self.unmask().strip(chars_s), masked_value=super().strip(chars))

    @override
    def swapcase(self) -> "MaskString":
        return MaskString(self.unmask().swapcase(), masked_value=super().swapcase())

    @override
    def title(self) -> "MaskString":
        return MaskString(self.unmask().title(), masked_value=super().title())

    # I don't know how translate works... default to using the masked value
    # @override
    # def translate(self, *args) -> 'MaskString':
    #     return MaskString(self.unmask().translate(*args))

    @override
    def upper(self) -> "MaskString":
        return MaskString(self.unmask().upper(), masked_value=super().upper())

    @override
    def zfill(self, width) -> "MaskString":
        return MaskString(self.unmask().zfill(width), masked_value=super().zfill(width))
str, HiddenValue tests
from unittest import TestCase
from .maskstring import MaskString
import copy


class OtherMaskString(MaskString):
    """Subclass of MaskString to validate basic usage"""

    def __new__(cls, service: str = "", username: str = "") -> "OtherMaskString":
        mask_string = super().__new__(cls, "fake_value", masked_value="FAKE NEWS")
        mask_string.__service = service
        mask_string.__username = username
        return mask_string

    def unmask(self) -> str:
        return self.__service + self.__username


class TestMaskString(TestCase):
    def test_capitalize(self):
        s = MaskString("secret value")
        new_s = s.capitalize()
        assert new_s.unmask() == "Secret value"
        assert new_s == "*masked*"

    def test_casefold(self):
        s = MaskString("SECRET")
        new_s = s.casefold()
        assert new_s.unmask() == "secret"
        assert new_s == "*masked*"

    def test_center(self):
        s = MaskString("hi")
        new_s = s.center(10)
        assert new_s.unmask() == "    hi    "

        new_s = s.center(10, "-")
        assert new_s.unmask() == "----hi----"

    def test_contains(self):
        s = MaskString("Secret Value")
        inner = MaskString("Secret")
        assert inner in s
        assert "Secret" not in s
        assert "MASKED" in s
        assert "*" in s

    def test_count(self):
        s = MaskString("hello world")
        sub = "l"
        assert s.count(sub) == 0  # counts in masked value '*MASKED*'
        sub = MaskString("l")
        assert s.count(sub) == 3  # counts in unmasked value
        assert s.count(sub, 0, 3) == 1
        assert s.count(sub, 3) == 2

    def test_endswith(self):
        s = MaskString("secret value")
        end = MaskString("value")
        assert s.endswith("*") is True
        assert s.endswith(end) is True
        assert s.endswith("SKED*")
        assert not s.endswith("value")

    def test_expandtabs(self):
        s = MaskString("a\tb")
        new_s = s.expandtabs(4)
        assert new_s.unmask() == "a   b"

        new_s = s.expandtabs()
        assert new_s.unmask() == "a       b"

    def test_find(self):
        s = MaskString("hello world")
        assert s.find("hello") == -1  # searches masked value

        sub = MaskString("missing")  # searches unmasked value
        assert s.find(sub) == -1

        sub = MaskString("l")
        assert s.find(sub) == 2

        assert s.find(sub, 4) == 9
        assert s.find(sub, 10) == -1

    def test_eq(self):
        s1 = MaskString("Secret Value")
        s2 = MaskString("Secret Value")
        s3 = MaskString("Other Value")

        assert s1 == s2
        assert s1 != s3

        assert s1 == "*MASKED*"
        assert s2 == "*MASKED*"
        assert s3 == "*MASKED*"

    def test_ne(self):
        s1 = MaskString("secret")
        s2 = MaskString("other")
        s3 = MaskString("secret")
        assert s1 != s2
        assert not (s1 != s3)
        # plain str compares against masked value
        assert s1 != "secret"
        assert s1 == "*MASKED*"

    def test_add_str(self):
        s = MaskString("Secret Value")
        new_s = s + " test"
        assert new_s == "*MASKED* test"
        assert new_s.unmask() == "Secret Value test"

        new_s = s + MaskString(" test")
        assert new_s == "*MASKED*"
        assert new_s.unmask() == "Secret Value test"

    def test_radd(self):
        s = MaskString("Secret Value")
        new_s = "test " + s
        assert new_s == "test *MASKED*"
        assert new_s.unmask() == "test Secret Value"

        s2 = MaskString("test ")
        new_s = s.__radd__(s2)
        assert new_s == "*MASKED*"
        assert new_s.unmask() == "test Secret Value"

    def test_ge(self):
        s1 = MaskString("b")
        s2 = MaskString("a")
        s3 = MaskString("b")
        assert s1 >= s2
        assert s1 >= s3
        assert s1 >= "!"  # '*' > '!'
        assert s1 >= "*"  # len(s1) > len('*')
        assert s1 == "*MASKED*"
        assert s2 == "*MASKED*"
        assert s3 == "*MASKED*"

    def test_gt(self):
        s1 = MaskString("b")
        s2 = MaskString("a")
        assert s1 > s2
        assert s1 > "!"  # '*' > '!'
        assert s1 > "*"  # len(s1) > len('*')
        assert s1 == "*MASKED*"
        assert s2 == "*MASKED*"

    def test_le(self):
        s1 = MaskString("a")
        s2 = MaskString("b")
        s3 = MaskString("a")
        assert s1 <= s2
        assert s1 <= s3
        assert s1 <= "0"  # '*' <= '0'
        assert not s1 <= "*"  # len(s1) > len('*')
        assert s1 == "*MASKED*"
        assert s2 == "*MASKED*"
        assert s3 == "*MASKED*"

    def test_lt(self):
        s = MaskString("a")
        s1 = MaskString("b")
        assert s < s1
        assert s1 < "0"  # '*' <= '0'
        assert not s1 < "*"  # len(s1) > len('*')
        assert s == "*MASKED*"
        assert s1 == "*MASKED*"

    def test_mul(self):
        s = MaskString("ab")
        new_s = s * 3
        assert new_s == "*MASKED**MASKED**MASKED*"
        assert new_s.unmask() == "ababab"

    def test_rmul(self):
        s = MaskString("ab")
        new_s = 3 * s
        assert new_s == "*MASKED**MASKED**MASKED*"
        assert new_s.unmask() == "ababab"

    def test_hash(self):
        s1 = MaskString("secret")
        s2 = MaskString("secret")
        assert hash(s1) == hash(s2)
        assert hash(s1) == hash("secret")

    def test_hash_dict_key(self):
        s = MaskString("key")
        d = {s: "value"}
        assert d[MaskString("key")] == "value"
        assert (
            d.get("key", None) != "value"
        )  # plain str won't match since dict takes the type into account

    def test_index_found(self):
        s = MaskString("secret value")
        sub = MaskString("value")
        assert s.index(sub) == 7

        sub = MaskString("missing")
        with self.assertRaises(ValueError):
            s.index(sub)

        # plain str searches the masked value '*MASKED*'
        assert s.index("*") == 0

    def test_join(self):
        s = MaskString("-")
        parts = [MaskString("a"), MaskString("b"), MaskString("c")]
        new_s = s.join(parts)
        assert new_s.unmask() == "a-b-c"
        assert (
            new_s == "*MASKED*"
        )  # The other option is to make it return '*MASKED**MASKED**MASKED**MASKED**MASKED*' which seems unnecessary

        parts = ["a", MaskString("b"), "c"]
        new_s = s.join(parts)
        assert new_s.unmask() == "a-b-c"
        assert new_s == "*MASKED*"

        parts = ["a", "b", "c"]
        new_s = s.join(parts)
        assert new_s.unmask() == "a-b-c"
        assert new_s == "*MASKED*"

    def test_ljust(self):
        s = MaskString("hi")
        new_s = s.ljust(10)
        assert new_s.unmask() == "hi        "
        assert new_s == "*MASKED*  "

        new_s = s.ljust(10, "-")
        assert new_s.unmask() == "hi--------"
        assert new_s == "*MASKED*--"

    def test_lower(self):
        s = MaskString("SECRET")
        new_s = s.lower()
        assert new_s.unmask() == "secret"
        assert new_s == "*masked*"

    def test_lstrip(self):
        s = MaskString("  secret  ")
        new_s = s.lstrip()
        assert new_s.unmask() == "secret  "

        s = MaskString("**secret**")
        new_s = s.lstrip("*")
        assert new_s.unmask() == "secret**"
        assert new_s == "MASKED*"

        s = MaskString("**secret**")
        new_s = s.lstrip(MaskString("*"))
        assert new_s.unmask() == "secret**"
        assert new_s == "*MASKED*"

        s = MaskString("  secret  ")
        new_s = s.lstrip(MaskString(" "))
        assert new_s.unmask() == "secret  "
        assert new_s == "*MASKED*"

    def test_partition(self):
        s = MaskString("secret:value")
        before, sep, after = s.partition(":")
        assert before.unmask() == "secret"
        assert sep.unmask() == ":"
        assert after.unmask() == "value"

        assert before == "*MASKED*"
        assert sep == ""  # maybe set these to '*MASKED*' since this output isn't helpful
        assert after == ""

        before, sep, after = s.partition(MaskString(":"))
        assert before.unmask() == "secret"
        assert sep.unmask() == ":"
        assert after.unmask() == "value"

        assert before == ""
        assert sep == "*MASKED*"
        assert after == ""

        before, sep, after = s.partition("|")
        assert before.unmask() == "secret:value"
        assert sep.unmask() == ""
        assert after.unmask() == ""

        assert before == "*MASKED*"
        assert sep == ""  # maybe set these to '*MASKED*' since this output isn't helpful
        assert after == ""

        before, sep, after = s.partition(MaskString("|"))
        assert before.unmask() == "secret:value"
        assert sep.unmask() == ""
        assert after.unmask() == ""

        assert before == ""
        assert sep == "*MASKED*"
        assert after == ""

    def test_pickle(self):
        import pickle

        s = MaskString("secret")
        pickled = pickle.dumps(s)
        restored = pickle.loads(pickled)
        assert isinstance(restored, MaskString)
        assert restored.unmask() == "secret"
        assert restored == "*MASKED*"

    def test_removeprefix(self):
        s = MaskString("secret value")
        new_s = s.removeprefix("secret ")
        assert new_s.unmask() == "value"

        prefix = MaskString("secret ")
        new_s = s.removeprefix(prefix)
        assert new_s.unmask() == "value"

        new_s = s.removeprefix("other")
        assert new_s.unmask() == "secret value"
        assert new_s == "*MASKED*"

    def test_removesuffix(self):
        s = MaskString("secret value")
        new_s = s.removesuffix(" value")
        assert new_s.unmask() == "secret"

        suffix = MaskString(" value")
        new_s = s.removesuffix(suffix)
        assert new_s.unmask() == "secret"

        new_s = s.removesuffix("other")
        assert new_s.unmask() == "secret value"
        assert new_s == "*MASKED*"

    def test_replace(self):
        s = MaskString("secret value")
        new_s = s.replace("value", "data")
        assert new_s.unmask() == "secret data"
        assert new_s == "*MASKED*"

        old = MaskString("value")
        new = MaskString("data")
        new_s = s.replace(old, new)
        assert new_s.unmask() == "secret data"

    def test_replace_maxsplit(self):
        s = MaskString("aaa")
        new_s = s.replace("a", "b", 2)
        assert new_s.unmask() == "bba"

    def test_result_is_maskstring(self):
        # Verify all transform methods return MaskString instances
        s = MaskString("secret")
        assert isinstance(s.upper(), MaskString)
        assert isinstance(s.lower(), MaskString)
        assert isinstance(s.capitalize(), MaskString)
        assert isinstance(s.swapcase(), MaskString)
        assert isinstance(s.title(), MaskString)
        assert isinstance(s.strip(), MaskString)
        assert isinstance(s.lstrip(), MaskString)
        assert isinstance(s.rstrip(), MaskString)
        assert isinstance(s + "", MaskString)
        assert isinstance("" + s, MaskString)
        assert isinstance(s * 2, MaskString)
        assert isinstance(s.center(10), MaskString)
        assert isinstance(s.ljust(10), MaskString)
        assert isinstance(s.rjust(10), MaskString)
        assert isinstance(s.zfill(10), MaskString)
        assert isinstance(s.replace("x", "y"), MaskString)
        assert isinstance(s.removeprefix("x"), MaskString)
        assert isinstance(s.removesuffix("x"), MaskString)
        assert isinstance(s.expandtabs(), MaskString)
        assert isinstance(s.casefold(), MaskString)

    def test_rfind(self):
        s = MaskString("abcabc")
        sub = MaskString("a")
        assert s.rfind(sub) == 3
        assert s.rfind("*") == 7  # last '*' in '*MASKED*'

    def test_rindex_found(self):
        s = MaskString("abcabc")
        sub = MaskString("a")
        assert s.rindex(sub) == 3

        sub = MaskString("missing")
        with self.assertRaises(ValueError):
            s.rindex(sub)
        # plain str searches the masked value '*MASKED*'
        assert s.rindex("*") == 7

    def test_rjust(self):
        s = MaskString("hi")
        new_s = s.rjust(10)
        assert new_s.unmask() == "        hi"

        new_s = s.rjust(10, "-")
        assert new_s.unmask() == "--------hi"

    def test_rpartition(self):
        s = MaskString("secret:value")
        before, sep, after = s.rpartition(":")
        assert before.unmask() == "secret"
        assert sep.unmask() == ":"
        assert after.unmask() == "value"

        assert before == ""  # maybe set these to '*MASKED*' since this output isn't helpful
        assert sep == ""
        assert after == "*MASKED*"

        before, sep, after = s.rpartition(MaskString(":"))
        assert before.unmask() == "secret"
        assert sep.unmask() == ":"
        assert after.unmask() == "value"

        assert before == ""
        assert sep == "*MASKED*"
        assert after == ""

        before, sep, after = s.rpartition("|")
        assert before.unmask() == ""
        assert sep.unmask() == ""
        assert after.unmask() == "secret:value"

        assert before == ""
        assert sep == ""
        assert after == "*MASKED*"

        before, sep, after = s.rpartition(MaskString("|"))
        assert before.unmask() == ""
        assert sep.unmask() == ""
        assert after.unmask() == "secret:value"

        assert before == ""
        assert sep == "*MASKED*"
        assert after == ""

    def test_rstrip(self):
        s = MaskString("  secret  ")
        new_s = s.rstrip()
        assert new_s.unmask() == "  secret"

        s = MaskString("**secret**")
        new_s = s.rstrip("*")
        assert new_s.unmask() == "**secret"
        assert new_s == "*MASKED"

        s = MaskString("**secret**")
        new_s = s.rstrip(MaskString("*"))
        assert new_s.unmask() == "**secret"
        assert new_s == "*MASKED*"

        s = MaskString("  secret  ")
        new_s = s.rstrip(MaskString(" "))
        assert new_s.unmask() == "  secret"
        assert new_s == "*MASKED*"

    def test_startswith(self):
        s = MaskString("secret value")
        assert s.startswith("*")
        assert not s.startswith("secret")
        prefix = MaskString("secret")
        assert s.startswith(prefix)

        prefix = MaskString("value")
        assert s.startswith(prefix, 7)

    def test_str_repr_masked(self):
        s = MaskString("top secret")
        assert str(s) == "*MASKED*"
        assert repr(s) == "'*MASKED*'"

    def test_strip(self):
        s = MaskString("  secret  ")
        new_s = s.strip()
        assert new_s.unmask() == "secret"

        s = MaskString("**secret**")
        new_s = s.strip("*")
        assert new_s.unmask() == "secret"
        assert new_s == "MASKED"

        s = MaskString("**secret**")
        new_s = s.strip(MaskString("*"))
        assert new_s.unmask() == "secret"
        assert new_s == "*MASKED*"

        s = MaskString("  secret  ")
        new_s = s.strip(MaskString(" "))
        assert new_s.unmask() == "secret"
        assert new_s == "*MASKED*"

    def test_swapcase(self):
        s = MaskString("Secret")
        new_s = s.swapcase()
        assert new_s.unmask() == "sECRET"
        assert new_s == "*masked*"

    def test_title(self):
        s = MaskString("secret value")
        new_s = s.title()
        assert new_s.unmask() == "Secret Value"
        assert new_s == "*Masked*"

    def test_upper(self):
        s = MaskString("secret")
        new_s = s.upper()
        assert new_s.unmask() == "SECRET"
        assert new_s == "*MASKED*"

    def test_viral_add_chain(self):
        s = MaskString("secret")
        result = s + " part1" + " part2"
        assert result.unmask() == "secret part1 part2"
        assert result == "*MASKED* part1 part2"

        result = "part1 " + s + " part2"
        assert result.unmask() == "part1 secret part2"
        assert result == "part1 *MASKED* part2"

    def test_zfill(self):
        s = MaskString("12345")
        new_s = s.zfill(6)
        assert new_s.unmask() == "012345"

        new_s = s.zfill(3)
        assert new_s.unmask() == "12345"
        assert new_s == "*MASKED*"

    def test_deepcopy(self):
        s = MaskString("12345")
        new_s = copy.deepcopy(s)
        assert new_s.unmask() == "12345"
        assert new_s == "*MASKED*"

    def test_deepcopy(self):
        s = OtherMaskString(service="ssh", username="pyinfra")
        new_s = copy.deepcopy(s)
        assert isinstance(new_s, OtherMaskString)
        assert new_s.unmask() == "sshpyinfra"
        assert new_s == "FAKE NEWS"

@Fizzadar Fizzadar left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few minor fixes, this is coming together nicely.

Comment thread src/pyinfra/connectors/util.py Outdated
Comment thread src/pyinfra/connectors/util.py Outdated
Comment thread src/pyinfra/api/command.py Outdated
Comment thread src/pyinfra/api/command.py Outdated

@Fizzadar Fizzadar left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you @gwelch-contegix, this is really nice

@Fizzadar Fizzadar merged commit 043aca5 into pyinfra-dev:3.x May 28, 2026
29 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

API API mode specific issues. new feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

mysql_password (and probably lots of others) is displayed all over the logs when using a single -v

3 participants