Skip to content

Commit 3108c2a

Browse files
committed
Implemented bulk replacing in files.
1 parent d643a66 commit 3108c2a

File tree

10 files changed

+502
-11
lines changed

10 files changed

+502
-11
lines changed

https_everywhere/__main__.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import asyncio
2+
import sys
3+
import typing
4+
from concurrent.futures import ThreadPoolExecutor
5+
from functools import partial
6+
from os import cpu_count
7+
from pathlib import Path
8+
9+
from binaryornot.check import is_binary
10+
from plumbum import cli
11+
12+
from .core import CombinedReplacerFactory, ReplaceContext
13+
from .core.InBufferReplacer import InBufferReplacer
14+
from .core.InFileReplacer import InFileReplacer
15+
from .replacers.HEReplacer import HEReplacer
16+
from .replacers.HSTSPreloadReplacer import HSTSPreloadReplacer
17+
18+
19+
class OurInBufferReplacer(InBufferReplacer):
20+
__slots__ = ()
21+
FACS = CombinedReplacerFactory(
22+
{
23+
"preloads": HSTSPreloadReplacer,
24+
"heRulesets": HEReplacer,
25+
}
26+
)
27+
28+
def __init__(self, preloads=None, heRulesets=None):
29+
super().__init__(preloads=preloads, heRulesets=heRulesets)
30+
31+
32+
class OurInFileReplacer(InFileReplacer):
33+
def __init__(self, preloads=None, heRulesets=None):
34+
super().__init__(OurInBufferReplacer(preloads=preloads, heRulesets=heRulesets))
35+
36+
37+
class CLI(cli.Application):
38+
"""HTTPSEverywhere-like URI rewriter"""
39+
40+
41+
class FileClassifier:
42+
__slots__ = ("noSkipDot", "noSkipBinary")
43+
44+
def __init__(self, noSkipDot: bool, noSkipBinary: bool):
45+
self.noSkipDot = noSkipDot
46+
self.noSkipBinary = noSkipBinary
47+
48+
def __call__(self, p: Path) -> str:
49+
for pa in p.parts:
50+
if not self.noSkipDot and pa[0] == ".":
51+
return "dotfile"
52+
53+
if not p.is_dir():
54+
if p.is_file():
55+
if self.noSkipBinary or not is_binary(p):
56+
return ""
57+
else:
58+
return "binary"
59+
else:
60+
return "not regular file"
61+
62+
63+
class FilesEnumerator:
64+
__slots__ = ("classifier", "disallowedReportingCallback")
65+
66+
def __init__(self, classifier, disallowedReportingCallback):
67+
self.classifier = classifier
68+
self.disallowedReportingCallback = disallowedReportingCallback
69+
70+
def __call__(self, fileOrDir: Path):
71+
reasonOfDisallowal = self.classifier(fileOrDir)
72+
if not reasonOfDisallowal:
73+
if fileOrDir.is_dir():
74+
for f in fileOrDir.iterdir():
75+
yield from self(f)
76+
else:
77+
yield fileOrDir
78+
else:
79+
self.disallowedReportingCallback(fileOrDir, reasonOfDisallowal)
80+
81+
82+
@CLI.subcommand("bulk")
83+
class FileRewriteCLI(cli.Application):
84+
"""Rewrites URIs in files. Use - to consume list of files from stdin. Don't use `find`, it is a piece of shit which is impossible to configure to skip .git dirs."""
85+
86+
__slots__ = ("_repl",)
87+
88+
@property
89+
def repl(self) -> InFileReplacer:
90+
if self._repl is None:
91+
self._repl = OurInFileReplacer()
92+
print(
93+
len(self._repl.inBufferReplacer.singleURIReplacer.children[0].preloads),
94+
"HSTS preloads",
95+
)
96+
print(len(self._repl.inBufferReplacer.singleURIReplacer.children[1].rulesets), "HE rules")
97+
return self._repl
98+
99+
def processEachFileName(self, ctx: ReplaceContext, l: str) -> Path:
100+
l = l.strip()
101+
if l:
102+
l = l.decode("utf-8")
103+
p = Path(l).resolve().absolute()
104+
self.processEachFilePath(ctx, p)
105+
106+
def processEachFilePath(self, ctx: ReplaceContext, p: Path) -> None:
107+
for pp in self.fe(p):
108+
if self.trace:
109+
print("Processing", pp, file=sys.stderr)
110+
self.repl(ctx, pp)
111+
if self.trace:
112+
print("Processed", pp, file=sys.stderr)
113+
114+
@asyncio.coroutine
115+
def asyncMainPathsFromStdIn(self):
116+
conc = []
117+
asyncStdin = asyncio.StreamReader(loop=self.loop)
118+
yield from self.loop.connect_read_pipe(
119+
lambda: asyncio.StreamReaderProtocol(asyncStdin, loop=self.loop), sys.stdin
120+
)
121+
with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
122+
while not asyncStdin.at_eof():
123+
l = yield from asyncStdin.readline()
124+
yield from self.loop.run_in_executor(pool, partial(self.processEachFileName, l))
125+
126+
@asyncio.coroutine
127+
def asyncMainPathsFromCLI(self, filesOrDirs: typing.Iterable[typing.Union[Path, str]]):
128+
try:
129+
from tqdm import tqdm
130+
except ImportError:
131+
132+
def tqdm(x):
133+
return x
134+
135+
ctx = ReplaceContext(None)
136+
replaceInEachFileWithContext = partial(self.repl, ctx)
137+
138+
with tqdm(filesOrDirs) as pb:
139+
for fileOrDir in pb:
140+
fileOrDir = Path(fileOrDir).resolve().absolute()
141+
142+
files = tuple(self.fe(fileOrDir))
143+
144+
if files:
145+
with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
146+
for f in files:
147+
if self.trace:
148+
print("Processing", f, file=pb)
149+
yield from self.loop.run_in_executor(pool, partial(replaceInEachFileWithContext, f))
150+
if self.trace:
151+
print("Processed", f, file=pb)
152+
153+
noSkipBinary = cli.Flag(
154+
["--no-skip-binary", "-n"],
155+
help="Don't skip binary files. Allows usage without `binaryornot`",
156+
default=False,
157+
)
158+
noSkipDot = cli.Flag(
159+
["--no-skip-dotfiles", "-d"],
160+
help="Don't skip files and dirs which name stem begins from dot.",
161+
default=False,
162+
)
163+
trace = cli.Flag(
164+
["--trace", "-t"],
165+
help="Print info about processing of regular files",
166+
default=False,
167+
)
168+
noReportSkipped = cli.Flag(
169+
["--no-report-skipped", "-s"],
170+
help="Don't report about skipped files",
171+
default=False,
172+
)
173+
174+
def disallowedReportingCallback(self, fileOrDir: Path, reasonOfDisallowal: str) -> None:
175+
if not self.noReportSkipped:
176+
print("Skipping ", fileOrDir, ":", reasonOfDisallowal)
177+
178+
def main(self, *filesOrDirs):
179+
self._repl = None # type: OurInFileReplacer
180+
self.loop = asyncio.get_event_loop()
181+
182+
self.fc = FileClassifier(self.noSkipDot, self.noSkipBinary)
183+
self.fe = FilesEnumerator(self.fc, self.disallowedReportingCallback)
184+
185+
if len(filesOrDirs) == 1 and filesOrDirs[0] == "0":
186+
t = self.loop.create_task(self.asyncMainPathsFromStdIn())
187+
else:
188+
t = self.loop.create_task(self.asyncMainPathsFromCLI(filesOrDirs))
189+
self.loop.run_until_complete(t)
190+
191+
192+
if __name__ == "__main__":
193+
CLI.run()

https_everywhere/_rules.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -821,39 +821,52 @@ def _get_ruleset(hostname, rulesets=None):
821821

822822
logger.debug("no ruleset matches {}".format(hostname))
823823

824+
from icecream import ic
825+
826+
def _remove_trailing_slash(url):
827+
if url[-1] == "/":
828+
url = url[:-1]
829+
return url
824830

825831
def https_url_rewrite(url, rulesets=None):
832+
orig_url = url
826833
if isinstance(url, str):
827834
# In HTTPSEverywhere, URLs must contain a '/'.
828835
if url.replace("http://", "").find("/") == -1:
829836
url += "/"
837+
remove_trailing_slash_if_needed = _remove_trailing_slash
830838
parsed_url = urlparse(url)
831839
else:
840+
remove_trailing_slash_if_needed = lambda x: x
841+
832842
parsed_url = url
833843
if hasattr(parsed_url, "geturl"):
834844
url = parsed_url.geturl()
835845
else:
836846
url = str(parsed_url)
837847

848+
if parsed_url.scheme is None or parsed_url.host is None:
849+
return orig_url
850+
838851
try:
839852
ruleset = _get_ruleset(parsed_url.host, rulesets)
840853
except AttributeError:
841854
ruleset = _get_ruleset(parsed_url.netloc, rulesets)
842855

843856
if not ruleset:
844-
return url
857+
return orig_url
845858

846859
if not isinstance(ruleset, _Ruleset):
847860
ruleset = _Ruleset(ruleset[0], ruleset[1])
848861

849862
if ruleset.exclude_url(url):
850-
return url
863+
return orig_url
851864

852865
# process rules
853866
for rule in ruleset.rules:
854867
logger.debug("checking rule {} -> {}".format(rule[0], rule[1]))
855868
try:
856-
new_url = rule[0].sub(rule[1], url)
869+
count, new_url = rule[0].subn(rule[1], url)
857870
except Exception as e: # pragma: no cover
858871
logger.warning(
859872
"failed during rule {} -> {} , input {}: {}".format(
@@ -863,7 +876,7 @@ def https_url_rewrite(url, rulesets=None):
863876
raise
864877

865878
# stop if this rule was a hit
866-
if new_url != url:
867-
return new_url
879+
if count:
880+
return remove_trailing_slash_if_needed(new_url)
868881

869-
return url
882+
return orig_url

https_everywhere/adapter.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from logging_helper import setup_logging
44

55
import urllib3
6-
from urllib3.util.url import parse_url
76

87
import requests
98
from requests.adapters import HTTPAdapter
@@ -13,6 +12,7 @@
1312
from ._chrome_preload_hsts import _preload_including_subdomains
1413
from ._mozilla_preload_hsts import _preload_remove_negative
1514
from ._util import _check_in
15+
from .replacers.HSTSPreloadReplacer import apply_HSTS_preload
1616

1717
PY2 = str != "".__class__
1818
if PY2:
@@ -155,10 +155,7 @@ def __init__(self, *args, **kwargs):
155155

156156
def get_redirect(self, url):
157157
if url.startswith("http://"):
158-
p = parse_url(url)
159-
if _check_in(self._domains, p.host):
160-
new_url = "https:" + url[5:]
161-
return new_url
158+
return apply_HSTS_preload(url, self._domains)
162159

163160
return super(PreloadHSTSAdapter, self).get_redirect(url)
164161

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import re
2+
import typing
3+
4+
from urllib3.util.url import parse_url
5+
6+
from . import ReplaceContext, SingleURIReplacer
7+
8+
uri_re_source = "(?:http|ftp):\\/\\/?((?:[\\w-]+)(?::[\\w-]+)?@)?[\\w\\.:(-]+(?:\\/[\\w\\.:(/-]*)?"
9+
uri_re_text = re.compile(uri_re_source)
10+
uri_re_binary = re.compile(uri_re_source.encode("ascii"))
11+
12+
13+
class InBufferReplacer(SingleURIReplacer):
14+
__slots__ = ("singleURIReplacer",)
15+
FACS = None
16+
17+
def __init__(self, **kwargs):
18+
self.singleURIReplacer = self.__class__.FACS(**kwargs)
19+
20+
def _rePlaceFuncCore(self, uri):
21+
ctx = ReplaceContext(uri)
22+
self.singleURIReplacer(ctx)
23+
return ctx
24+
25+
def _rePlaceFuncText(self, m):
26+
uri = m.group(0)
27+
ctx = self._rePlaceFuncCore(uri)
28+
if ctx.count > 0:
29+
return ctx.res
30+
return uri
31+
32+
def _rePlaceFuncBinary(self, m):
33+
uri = m.group(0)
34+
ctx = self._rePlaceFuncCore(uri.decode("utf-8"))
35+
if ctx.count > 0:
36+
return ctx.res.encode("utf-8")
37+
return uri
38+
39+
def __call__(self, inputStr: typing.Union[str, bytes]) -> ReplaceContext:
40+
if isinstance(inputStr, str):
41+
return ReplaceContext(*uri_re_text.subn(self._rePlaceFuncText, inputStr))
42+
else:
43+
return ReplaceContext(*uri_re_binary.subn(self._rePlaceFuncBinary, inputStr))

0 commit comments

Comments
 (0)