Skip to content

Commit 9bf4b56

Browse files
committed
Support for mrjob mrjobs. Kinda hacky but it works.
1 parent 2ce6fa4 commit 9bf4b56

2 files changed

Lines changed: 326 additions & 0 deletions

File tree

luigi/contrib/mrjob.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Some awesome license
2+
3+
# import mrjob
4+
from __future__ import absolute_import
5+
6+
from cStringIO import StringIO
7+
from luigi.task import flatten
8+
import logging
9+
import luigi
10+
import luigi.s3
11+
from datetime import datetime
12+
13+
14+
logger = logging.getLogger('luigi-interface')
15+
16+
17+
class MRJobTask(luigi.Task):
18+
job_class = None
19+
job_args = luigi.Parameter(default=None)
20+
job_options = {}
21+
job_results = None
22+
s3_output_path = None
23+
misc_options = {}
24+
25+
def run_inline(self, job):
26+
stdin = StringIO()
27+
inlines = self.input()
28+
inlines = flatten(inlines)
29+
inlines = ('\n'.join(inlines)).replace('\n\n', '\n')
30+
stdin.write(inlines)
31+
stdin.seek(0)
32+
job.sandbox(stdin=stdin)
33+
results = []
34+
with job.make_runner() as runner:
35+
runner.run()
36+
for line in runner.stream_output():
37+
if self.misc_options.get('parse_output', False):
38+
key, value = job.parse_output_line(line)
39+
results.append((key, value))
40+
else:
41+
results.append(line)
42+
43+
return results
44+
45+
def run_emr(self, job):
46+
results = []
47+
with job.make_runner() as runner:
48+
runner.run()
49+
for line in runner.stream_output():
50+
key, value = job.parse_output_line(line)
51+
results.append((key, value))
52+
53+
def run(self):
54+
self.job_args = list(self.job_args or []) or ['-r', 'inline', '--no-conf', '--strict-protocols', '-']
55+
56+
logger.info('Running mrjob task with arguments: %s' % ', '.join(self.job_args))
57+
58+
job = self.job_class(self.job_args)
59+
self.job_options = vars(vars(job)['options'])
60+
61+
runner = self.job_options['runner']
62+
63+
if runner == 'inline':
64+
self.job_results = self.run_inline(job)
65+
66+
elif runner == 'emr':
67+
68+
for arg in self.job_args:
69+
if '--output-dir' in arg:
70+
self.s3_output_path = arg.split('=', 1)[-1]
71+
72+
if not self.s3_output_path:
73+
self.s3_output_path = str(self.task_id) + '_' + datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
74+
logger.info('You did not specify an output s3 path. It will be automatically assigned as: %s' %
75+
self.s3_output_path)
76+
77+
self.run_emr(job)
78+
else:
79+
raise NotImplementedError
80+
logger.info('Finished running mrjob')
81+
82+
def complete(self):
83+
return (self.job_results or self.s3_output_path) is not None
84+
85+
def output(self):
86+
runner = self.job_options['runner']
87+
if runner == 'inline':
88+
return self.job_results
89+
elif runner == 'emr':
90+
return luigi.s3.S3Target(self.s3_output_path)
91+

test/contrib/mrjob_test.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Super exciting license
2+
from cStringIO import StringIO
3+
4+
from mrjob.job import MRJob
5+
from mrjob.protocol import RawValueProtocol, RawProtocol
6+
import mock
7+
from mock import patch
8+
9+
from luigi.s3 import S3Target, S3Client
10+
# from luigi.contrib import mrjob
11+
import luigi.contrib
12+
from luigi.contrib.mrjob import MRJobTask
13+
from luigi.task import flatten
14+
import luigi
15+
import luigi.interface
16+
import luigi.worker
17+
import luigi.scheduler
18+
import unittest
19+
import tempfile
20+
import os
21+
22+
# moto does not yet work with
23+
# python 2.6. Until it does,
24+
# disable these tests in python2.6
25+
try:
26+
from moto import mock_s3
27+
except ImportError:
28+
# https://github.com/spulec/moto/issues/29
29+
print('Skipping %s because moto does not install properly before '
30+
'python2.7' % __file__)
31+
from luigi.mock import skip
32+
mock_s3 = skip
33+
34+
AWS_ACCESS_KEY = "XXXXXXXXXXXXXXXXXXXX"
35+
AWS_SECRET_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
36+
37+
38+
class MRWordFrequencyCount(MRJob):
39+
40+
def mapper(self, _, line):
41+
yield "chars", len(line)
42+
yield "words", len(line.split())
43+
yield "lines", 1
44+
45+
def reducer(self, key, values):
46+
yield key, sum(values)
47+
48+
49+
class MRAddIncrementNumbers(MRJob):
50+
51+
# INPUT_PROTOCOL = RawValueProtocol
52+
OUTPUT_PROTOCOL = RawValueProtocol
53+
54+
def mapper(self, _, number):
55+
yield None, int(number) + 5
56+
57+
def reducer(self, _, values):
58+
59+
yield None, sum(map(int, values))
60+
61+
62+
class DebugTarget(luigi.Target):
63+
64+
def __init__(self, data):
65+
self.data = data
66+
67+
def __call__(self, *args, **kwargs):
68+
return self.data
69+
70+
71+
class DataDump(luigi.ExternalTask):
72+
73+
def run(self):
74+
pass
75+
76+
def complete(self):
77+
return True
78+
79+
def output(self):
80+
return None
81+
82+
83+
class WordCountTask(MRJobTask):
84+
job_class = MRWordFrequencyCount
85+
86+
def requires(self):
87+
return DataDump()
88+
89+
90+
class AddIncrementTask1a(MRJobTask):
91+
job_class = MRAddIncrementNumbers
92+
# job_args = luigi.Parameter()
93+
94+
def requires(self):
95+
return DataDump()
96+
97+
98+
class AddIncrementTask1b(MRJobTask):
99+
job_class = MRAddIncrementNumbers
100+
# job_args = luigi.Parameter()
101+
102+
def requires(self):
103+
return DataDump()
104+
105+
106+
class AddIncrementTask2(MRJobTask):
107+
job_class = MRAddIncrementNumbers
108+
misc_options = {'parse_output': True}
109+
110+
def requires(self):
111+
return AddIncrementTask1a(), AddIncrementTask1b()
112+
113+
114+
class AddIncrementTask3(MRJobTask):
115+
job_class = MRAddIncrementNumbers
116+
misc_options = {'parse_output': True}
117+
s3_bucket = ''
118+
119+
def requires(self):
120+
return AddIncrementTask1a(job_args=['-r', 'emr', '--no-conf', '--strict-protocols',
121+
'--output-dir=s3://mybucket/wc_out/']),\
122+
AddIncrementTask1b(job_args=['-r', 'emr', '--no-conf', '--strict-protocols',
123+
'--output-dir=s3://mybucket/wc_out/'])
124+
125+
126+
127+
128+
class MRJobTaskTest(unittest.TestCase):
129+
130+
def setUp(self):
131+
self.default_args = ['-r', 'inline', '--no-conf', '--strict-protocols', '-']
132+
self.m_data_output = patch.object(DataDump, 'output', autospec=True).start()
133+
self.m_run_emr = patch.object(MRJobTask, 'run_emr', autospec=True).start()
134+
self.m_run_emr.side_effect = lambda self_, _: self.fake_emr(self_)
135+
136+
self.client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
137+
138+
f = tempfile.NamedTemporaryFile(mode='wb', delete=False)
139+
self.tempFileContents = (
140+
'\n'.join(["hello", "hello", "hello"])
141+
)
142+
self.tempFilePath = f.name
143+
f.write(self.tempFileContents)
144+
f.close()
145+
146+
def tearDown(self):
147+
os.remove(self.tempFilePath)
148+
149+
def run_task_with_input(self, task, input_data):
150+
self.m_data_output.return_value = input_data
151+
task.run()
152+
153+
def test_run_task(self):
154+
task = WordCountTask()
155+
task.misc_options = {'parse_output': True}
156+
self.run_task_with_input(task, ["hello", "hello", "hello"])
157+
158+
result = task.output()
159+
comparison = [('chars', 15), ('lines', 3), ('words', 3)]
160+
self.assertEqual(result, comparison)
161+
162+
def fake_emr(self, task):
163+
args = task.job_args
164+
args[args.index('emr')] = 'inline'
165+
job = task.job_class(args)
166+
stdin = StringIO()
167+
if isinstance(task.input(), list):
168+
inlines = []
169+
for t in task.input():
170+
inlines.append(t.open().read().split())
171+
inlines = flatten(inlines)
172+
else:
173+
inlines = task.input().open().read().split()
174+
175+
inlines = '\n'.join(inlines)
176+
stdin.write(inlines)
177+
stdin.seek(0)
178+
job.sandbox(stdin=stdin)
179+
results = ''
180+
task.job_options['runner'] = 'inline'
181+
with job.make_runner() as runner:
182+
runner.run()
183+
for line in runner.stream_output():
184+
results += line
185+
self.client.put_string(results, task.s3_output_path)
186+
task.job_options['runner'] = 'emr'
187+
return results
188+
189+
@mock_s3
190+
def test_fake_emr(self):
191+
bucket = 'mybucket'
192+
output_path = 's3://mybucket/wc_out/'
193+
input_path = 's3://mybucket/wc_in/'
194+
self.client.s3.create_bucket(bucket)
195+
self.client.put(self.tempFilePath, input_path)
196+
197+
args = ['-r', 'emr', '--no-conf', '--strict-protocols', '--output-dir=' + output_path]
198+
task = WordCountTask()
199+
task.job_args = args
200+
self.run_task_with_input(task, S3Target(input_path, client=self.client))
201+
202+
read_file = task.output().open()
203+
file_str = read_file.read()
204+
self.assertEqual(file_str, '"chars"\t15\n"lines"\t3\n"words"\t3\n')
205+
206+
def test_multiple_tasks(self):
207+
self.m_data_output.return_value = ['1', '2', '3', '4']
208+
task = AddIncrementTask2()
209+
luigi.interface.setup_interface_logging()
210+
sch = luigi.scheduler.CentralPlannerScheduler()
211+
w = luigi.worker.Worker(scheduler=sch)
212+
w.add(task)
213+
w.run()
214+
self.assertEqual(task.output(), [(None, '70\n')])
215+
216+
@mock_s3
217+
def test_multiple_fake_emr(self):
218+
bucket = 'mybucket'
219+
output_path = 's3://mybucket/wc_out/'
220+
input_path = 's3://mybucket/wc_in/'
221+
self.client.s3.create_bucket(bucket)
222+
self.client.put_string('\n'.join(['1', '2', '3', '4']), input_path)
223+
224+
self.m_data_output.return_value = S3Target(input_path, client=self.client)
225+
task = AddIncrementTask3(job_args=['-r', 'emr', '--no-conf',
226+
'--strict-protocols', '--output-dir=' + output_path])
227+
228+
luigi.interface.setup_interface_logging()
229+
sch = luigi.scheduler.CentralPlannerScheduler()
230+
w = luigi.worker.Worker(scheduler=sch)
231+
w.add(task)
232+
w.run()
233+
read_file = task.output().open()
234+
file_str = read_file.read()
235+
self.assertEqual(file_str, '70\n')

0 commit comments

Comments
 (0)