Skip to content

Commit 47d1a63

Browse files
committed
100% coverage
1 parent c68cb69 commit 47d1a63

2 files changed

Lines changed: 91 additions & 30 deletions

File tree

iter.lua

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ local random = math.random
1515
local tabcat = table.concat
1616
local sub = string.sub
1717
local format = string.format
18-
local pack = table.pack or
19-
function(...) return { n = select('#', ...), ... } end
18+
local pack = table.pack or function(...) return { n = select('#', ...), ... } end
2019

2120

2221
-- iterator creation
@@ -63,7 +62,6 @@ Iter.__name = "iterator"
6362
Iter.__index = Iter
6463

6564
-- internal callback defaults
66-
function Iter.iter() end
6765
function Iter:next(state, key) return self.iter(state, key) end
6866
function Iter:reset(other)
6967
if other then
@@ -76,19 +74,18 @@ end
7674

7775
local function collect(self, key, ...)
7876
self.current = key
79-
if self.current == nil then self.stopped = true end
77+
if key == nil then self.stopped = true end
8078
return key, ...
8179
end
8280

83-
function Iter:__call(state, key)
81+
function Iter:__call()
8482
if self.stopped then return end
85-
state = state or self.state
86-
return collect(self, self:next(state, key or self.current or self.init))
83+
return collect(self, self:next(self.state, self.current))
8784
end
8885

8986
function Iter:rewind()
9087
if self.stopped then
91-
self.stopped, self.current = nil, nil
88+
self.stopped, self.current = nil, self.init
9289
if self.next ~= Iter.next then
9390
return self:reset() or self
9491
end
@@ -113,17 +110,18 @@ function Iter:clone()
113110
end
114111

115112
local function new_stateless(func, state, init)
116-
local self = { iter = func, state = state, init = init }
113+
local self = { iter = func, state = state, init = init, current = init }
117114
if not state then self.state = self end
118115
return setmetatable(self, Iter)
119116
end
120117

121118
local function new_stateful(reset, func, state, init)
122-
local self = { reset = reset, next = func, state = state, init = init }
119+
local self = { reset = reset, next = func, state = state, init = init, current = init }
123120
if not state then self.state = self end
124121
return reset(setmetatable(self, Iter)) or self
125122
end
126123

124+
local function nil_iter() end
127125
local function string_iter(state, key)
128126
key = (key or 0) + 1
129127
local ch = sub(state, key, key)
@@ -134,19 +132,14 @@ end
134132
local function newiter(v, state, init)
135133
local t = type(v)
136134
if t == "table" then
137-
if getmetatable(v) ~= Iter then
138-
return new_stateless(pairs(v))
139-
elseif v.next == Iter.next and state then
140-
return new_stateless(v, state, init)
141-
else
142-
return v:clone()
143-
end
135+
if getmetatable(v) == Iter then return v:clone() end
136+
return new_stateless(pairs(v))
144137
elseif t == "function" then
145138
return new_stateless(v, state, init)
146139
elseif t == "string" then
147140
return new_stateless(string_iter, v, 0)
148141
elseif t == "nil" then
149-
return new_stateless(Iter.iter)
142+
return new_stateless(nil_iter)
150143
end
151144
error(format('attempt to iterate a %s value', t))
152145
end
@@ -351,10 +344,23 @@ local function takewhile(func, base)
351344
return self
352345
end
353346

347+
local function dropwhile_collect(self, state, key, ...)
348+
if key == nil then return end
349+
if self.remain ~= true then
350+
if state(key, ...) then return dropwhile_collect(self, state, self[1]()) end
351+
self.remain = true
352+
end
353+
return key, ...
354+
end
355+
356+
local function dropwhile_next(self, state)
357+
return dropwhile_collect(self, state, self[1]())
358+
end
359+
354360
local function dropwhile(func, base)
355-
assert(func, "function expected")
356-
func = function(...) return not func(...) end
357-
return takewhile(func, base)
361+
local self = new_stateful(takedrop_reset, dropwhile_next, func or id)
362+
self[1] = base
363+
return self
358364
end
359365

360366
export1(taken, "taken", "take_n", "takeN")
@@ -381,7 +387,7 @@ end, "split", "span", "splitAt", "split_at")
381387
-- transforms
382388

383389
local function map_collect(func, key, ...)
384-
if key then return func(key, ...) end
390+
if key ~= nil then return func(key, ...) end
385391
end
386392

387393
local function map_next(self, state)
@@ -426,6 +432,9 @@ end
426432

427433
local function scan_collect(self, state, key, ...)
428434
if key == nil then return end
435+
if not self.current and not state.acc then
436+
return state.func(key, self[1]())
437+
end
429438
return state.func(self.current or state.acc, key, ...)
430439
end
431440

@@ -450,10 +459,8 @@ local function group_reset(self, other)
450459
collects[k] = v
451460
end
452461
self.collects = collects
453-
else
454-
self.collects = nil
455462
end
456-
self.remain = other.remain
463+
self.remain = other.remain
457464
else
458465
self.collects = { n = 0 }
459466
self.remain = self.state
@@ -664,6 +671,7 @@ end
664671

665672
function Iter:prefix(...)
666673
local super = zip(...)
674+
super.notskip = true
667675
super[#super+1] = self:clone()
668676
return super
669677
end
@@ -745,7 +753,7 @@ local function foldl(func, init, base)
745753
end
746754
local key
747755
while true do
748-
key, init = foldl_collect(func, init, base())
756+
key, init = foldl_collect(func, init, base())
749757
if key == nil then return init end
750758
end
751759
end
@@ -757,7 +765,7 @@ local function index_collect(func, base, i, key, ...)
757765
end
758766

759767
local function index(func, base)
760-
return index_collect(func or id, base(), 0, base())
768+
return index_collect(func or id, base, 0, base())
761769
end
762770

763771
local function tcollect_collect(t, key, ...)
@@ -891,7 +899,8 @@ function Selector:__call(...)
891899
rawset(self, 'max', 0)
892900
rawset(self, 'dots', false)
893901
local expr = self.gen(self)
894-
local code = "return function(_, _"..range(self.max):concat ", _"
902+
local code = "return function(_"
903+
if self.max > 0 then code = code..", _"..range(self.max):concat ", _" end
895904
if self.dots then code = code .. ", ..." end
896905
code = code .. ") return "..expr.."; end"
897906
eval = assert(load(code, expr))()
@@ -1037,8 +1046,7 @@ end
10371046

10381047
setmetatable(Operator, {
10391048
__call = function(self, op)
1040-
op = assert(self[op], "not such operator")
1041-
return iter._(op)
1049+
return iter._(assert(self[op], "not such operator"))
10421050
end
10431051
})
10441052

test.impl.lua

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ if ... == nil then return dofile './test.lua' end
44

55
--! array
66

7+
resolve(1,2,3):each(print)
8+
--[[OUTPUT
9+
1,2,3
10+
--]]
11+
712
for v in array { 1,2,3 } do print(v) end
813
--[[OUTPUT
914
1
@@ -511,9 +516,40 @@ drop(2, ipairs {'a', 'b', 'c', 'd', 'e'}):each(print)
511516
5,e
512517
--]]
513518

519+
take(_"_1 < 5", array { 1,2,3,4,5,6,7,8,9,10 }):each(print)
520+
--[[OUTPUT
521+
1
522+
2
523+
3
524+
4
525+
--]]
526+
527+
drop(_"_1 < 5", array { 1,2,3,4,5,6,7,8,9,10 }):each(print)
528+
--[[OUTPUT
529+
5
530+
6
531+
7
532+
8
533+
9
534+
10
535+
--]]
536+
514537

515538
--! transforms
516539

540+
scan(_1+_2, nil, range(10)):each(print)
541+
--[[OUTPUT
542+
3
543+
6
544+
10
545+
15
546+
21
547+
28
548+
36
549+
45
550+
55
551+
--]]
552+
517553
fun = function(...) return 'map', ... end
518554
map(fun, range(0)):each(print)
519555
--[[OUTPUT
@@ -623,6 +659,14 @@ array {1,2,2,3,3,4,5} :packgroupby():flatmap(array):map(_G.unpack or table.unpac
623659

624660
--! compositions
625661

662+
array{"a", "b", "c", "d"} :prefix(range()):each(print)
663+
--[[OUTPUT
664+
1,a
665+
2,b
666+
3,c
667+
4,d
668+
--]]
669+
626670
zip(array{"a", "b", "c", "d"}, array{"one", "two", "three"}):each(print)
627671
--[[OUTPUT
628672
a,one
@@ -836,11 +880,20 @@ Emma
836880

837881
--! reducing
838882

883+
eq(_.self(range(10)).reduce(op.add)(), 55)
884+
885+
eq(range(10):map(-_1):reduce(_1+_2), -55)
886+
eq(foldl(_1 + _2, nil, range(5)), 15)
839887
eq(foldl(_1 + _2, 0, range(5)), 15)
840888
eq(foldl(op.add, 0, range(5)), 15)
841889
eq(foldl(_1+_2*_3, 0, zip(range(1, 5), array{4, 3, 2, 1, 0})), 20)
842890
eq(foldl, reduce)
843891

892+
eq(range(10):count(), 10)
893+
894+
eq(index(_"_1 > 5", range(10)), 6)
895+
eq(range(5):scan(_1+_2):collect(), {3, 6, 10, 15})
896+
844897
eq(length{"a", "b", "c", "d", "e"}, 5)
845898
eq(length{}, 0)
846899
eq(length(range(0)), 0)

0 commit comments

Comments
 (0)