Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"name": "dev-release",
"displayName": "Default development optimized build config",
"cacheVariables": {
"STRIP_BINARIES": "OFF"
"STRIP_BINARIES": "OFF",
"WFAIL": "OFF"
},
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/release"
Expand Down
143 changes: 141 additions & 2 deletions src/Lean/Compiler/LCNF/Simp/ConstantFold.lean
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def getUSizeLit (fvarId : FVarId) : CompilerM (Option UInt64) := do
let some (.lit (.usize n)) ← findLetValue? (pu := .pure) fvarId | return none
return n

def mkUSizeLit (x : UInt64) : CompilerM (LetValue .pure) := do
return .lit <| .usize x

end Literals

/--
Expand Down Expand Up @@ -246,7 +249,7 @@ def Folder.mkBinaryUSizeDecisionProcedure {r64 : UInt64 → UInt64 → Prop} {r3
let some arg₂ ← getUSizeLit fvarId₂ | return none
let res64 := (f64 arg₁ arg₂).decide
let res32 := (f32 arg₁.toUInt32 arg₂.toUInt32).decide
if res64 != res32 then return none
unless res64 == res32 do return none
if (← getPhase) < .mono then
if res64 then
return some <| .const ``Decidable.isTrue [] #[.erased, .erased]
Expand All @@ -255,6 +258,45 @@ def Folder.mkBinaryUSizeDecisionProcedure {r64 : UInt64 → UInt64 → Prop} {r3
else
mkLit res64

def Folder.mkBinaryUSize (f64 : UInt64 → UInt64 → UInt64) (f32 : UInt32 → UInt32 → UInt32) : Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₁ ← getUSizeLit fvarId₁ | return none
let some arg₂ ← getUSizeLit fvarId₂ | return none
let res64 := f64 arg₁ arg₂
let res32 := f32 arg₁.toUInt32 arg₂.toUInt32
unless res32.toUInt64 == res64 do return none
mkUSizeLit res64

def Folder.leftNeutralUSize (n64 : UInt64) (n32 : UInt32) :
Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₁ ← getUSizeLit fvarId₁ | return none
unless arg₁ == n64 && arg₁.toUInt32 == n32 do return none
return some <| .fvar fvarId₂ #[]

def Folder.rightNeutralUSize (n64 : UInt64) (n32 : UInt32) :
Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₂ ← getUSizeLit fvarId₂ | return none
unless arg₂ == n64 && arg₂.toUInt32 == n32 do return none
return some <| .fvar fvarId₁ #[]

def Folder.leftAnnihilatorUSize (a64 : UInt64) (a32 : UInt32)
(zero64 : UInt64) (zero32 : UInt32) : Folder := fun args => do
let #[.fvar fvarId, _] := args | return none
let some arg ← getUSizeLit fvarId | return none
unless arg == a64 && arg.toUInt32 == a32 do return none
assert! zero64.toUInt32 == zero32
mkUSizeLit zero64

def Folder.rightAnnihilatorUSize (a64 : UInt64) (a32 : UInt32)
(zero64 : UInt64) (zero32 : UInt32) : Folder := fun args => do
let #[_, .fvar fvarId] := args | return none
let some arg ← getUSizeLit fvarId | return none
unless arg == a64 && arg.toUInt32 == a32 do return none
assert! zero64.toUInt32 == zero32
mkUSizeLit zero64

/--
Provide a folder for an operation with a left neutral element.
-/
Expand Down Expand Up @@ -322,6 +364,43 @@ def Folder.mulLhsShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α →
let shiftLit ← mkAuxLit exponent
return some <| .const shiftLeft [] #[rhs, .fvar shiftLit]

/--
If `x` is a power of two with the same exponent in both its 64 and 32 bit
interpretation, return that exponent.
-/
private def getUSizePow2Exponent? (x : UInt64) : Option UInt64 :=
let exponent64 := x.log2
let exponent32 := x.toUInt32.log2
if UInt64.shiftLeft 1 exponent64 == x && UInt32.shiftLeft 1 exponent32 == x.toUInt32 &&
exponent64 == exponent32.toUInt64 then
some exponent64
else
none

def Folder.divShiftUSize : Folder := fun args => do
unless (← getDecl? ``USize.shiftRight).isSome do return none
let #[lhs, .fvar fvarId] := args | return none
let some rhs ← getUSizeLit fvarId | return none
let some exponent := getUSizePow2Exponent? rhs | return none
let shiftLit ← mkAuxLetDecl (← mkUSizeLit exponent)
return some <| .const ``USize.shiftRight [] #[lhs, .fvar shiftLit]

def Folder.mulRhsShiftUSize : Folder := fun args => do
unless (← getDecl? ``USize.shiftLeft).isSome do return none
let #[lhs, .fvar fvarId] := args | return none
let some rhs ← getUSizeLit fvarId | return none
let some exponent := getUSizePow2Exponent? rhs | return none
let shiftLit ← mkAuxLetDecl (← mkUSizeLit exponent)
return some <| .const ``USize.shiftLeft [] #[lhs, .fvar shiftLit]

def Folder.mulLhsShiftUSize : Folder := fun args => do
unless (← getDecl? ``USize.shiftLeft).isSome do return none
let #[.fvar fvarId, rhs] := args | return none
let some lhs ← getUSizeLit fvarId | return none
let some exponent := getUSizePow2Exponent? lhs | return none
let shiftLit ← mkAuxLetDecl (← mkUSizeLit exponent)
return some <| .const ``USize.shiftLeft [] #[rhs, .fvar shiftLit]

/--
Pick the first folder out of `folders` that succeeds.
-/
Expand All @@ -341,6 +420,12 @@ def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) (op : α → α
(_h1 : ∀ x, op neutral x = x := by simp) (_h2 : ∀ x, op x neutral = x := by simp) : Folder :=
Folder.first #[Folder.leftNeutral neutral op _h1, Folder.rightNeutral neutral op _h2]

/--
Provide a folder for a `USize` operation that has the same left and right neutral element.
-/
def Folder.leftRightNeutralUSize (n64 : UInt64) (n32 : UInt32) : Folder :=
Folder.first #[Folder.leftNeutralUSize n64 n32, Folder.rightNeutralUSize n64 n32]

/--
Provide a folder for an operation that has the same left and right annihilator.
-/
Expand All @@ -352,6 +437,16 @@ def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero :
Folder.rightAnnihilator annihilator zero op _h2
]

/--
Provide a folder for a `USize` operation that has the same left and right annihilator.
-/
def Folder.leftRightAnnihilatorUSize (annihilator64 : UInt64) (annihilator32 : UInt32)
(zero64 : UInt64) (zero32 : UInt32) : Folder :=
Folder.first #[
Folder.leftAnnihilatorUSize annihilator64 annihilator32 zero64 zero32,
Folder.rightAnnihilatorUSize annihilator64 annihilator32 zero64 zero32
]

/--
Literal folders for higher order datastructures.
-/
Expand All @@ -362,6 +457,9 @@ def higherOrderLiteralFolders : List (Name × Folder) := [
def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α) (log2 : α → α) : Folder :=
Folder.first #[Folder.mulLhsShift shiftLeft pow2 log2, Folder.mulRhsShift shiftLeft pow2 log2]

def Folder.mulShiftUSize : Folder :=
Folder.first #[Folder.mulLhsShiftUSize, Folder.mulRhsShiftUSize]

-- TODO: add option for controlling the limit
def natPowThreshold := 256

Expand All @@ -386,7 +484,8 @@ def Folder.toNat (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure))
let #[.fvar fvarId] := args | return none
let some (.lit lit) ← findLetValue? (pu := .pure) fvarId | return none
match lit with
| .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return some (.lit (.nat v.toNat))
| .uint8 v | .uint16 v | .uint32 v | .uint64 v => return some (.lit (.nat v.toNat))
| .usize v => if v.toUInt32.toUInt64 == v then return some (.lit (.nat v.toNat)) else return none
| .nat _ | .str _ => return none

/--
Expand All @@ -399,23 +498,63 @@ def arithmeticFolders : List (Name × Folder) := [
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16) (· + ·)]),
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32) (· + ·)]),
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64) (· + ·)]),
(``USize.add, Folder.first #[Folder.mkBinaryUSize UInt64.add UInt32.add, Folder.leftRightNeutralUSize 0 0]),
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftAnnihilator 0 0 (· - ·), Folder.rightNeutral 0 (· - ·)]),
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8) (· - ·)]),
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16) (· - ·)]),
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32) (· - ·)]),
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64) (· - ·)]),
(``USize.sub, Folder.first #[Folder.mkBinaryUSize UInt64.sub UInt32.sub, Folder.rightNeutralUSize 0 0]),
-- We don't convert Nat multiplication by a power of 2 into a left shift, because the fast path
-- for multiplication isn't any slower than a fast path for left shift that checks for overflow.
(``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral (1 : Nat) (· * ·), Folder.leftRightAnnihilator (0 : Nat) 0 (· * ·)]),
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8) (· * ·), Folder.leftRightAnnihilator (0 : UInt8) 0 (· * ·), Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16) (· * ·), Folder.leftRightAnnihilator (0 : UInt16) 0 (· * ·), Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32) (· * ·), Folder.leftRightAnnihilator (0 : UInt32) 0 (· * ·), Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64) (· * ·), Folder.leftRightAnnihilator (0 : UInt64) 0 (· * ·), Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``USize.mul, Folder.first #[Folder.mkBinaryUSize UInt64.mul UInt32.mul, Folder.leftRightNeutralUSize 1 1, Folder.leftRightAnnihilatorUSize 0 0 0 0, Folder.mulShiftUSize]),
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1 (· / ·), Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]),
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8) (· / ·), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16) (· / ·), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32) (· / ·), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64) (· / ·), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``USize.div, Folder.first #[Folder.mkBinaryUSize UInt64.div UInt32.div, Folder.rightNeutralUSize 1 1, Folder.divShiftUSize]),

(``Nat.shiftLeft, Folder.first #[Folder.mkBinary Nat.shiftLeft, Folder.rightNeutral 0 Nat.shiftLeft (by intros; rfl)]),
(``UInt8.shiftLeft, Folder.first #[Folder.mkBinary UInt8.shiftLeft, Folder.rightNeutral 0 UInt8.shiftLeft @UInt8.shiftLeft_zero]),
(``UInt16.shiftLeft, Folder.first #[Folder.mkBinary UInt16.shiftLeft, Folder.rightNeutral 0 UInt16.shiftLeft @UInt16.shiftLeft_zero]),
(``UInt32.shiftLeft, Folder.first #[Folder.mkBinary UInt32.shiftLeft, Folder.rightNeutral 0 UInt32.shiftLeft @UInt32.shiftLeft_zero]),
(``UInt64.shiftLeft, Folder.first #[Folder.mkBinary UInt64.shiftLeft, Folder.rightNeutral 0 UInt64.shiftLeft @UInt64.shiftLeft_zero]),
(``USize.shiftLeft, Folder.first #[Folder.mkBinaryUSize UInt64.shiftLeft UInt32.shiftLeft, Folder.rightNeutralUSize 0 0]),

(``Nat.shiftRight, Folder.first #[Folder.mkBinary Nat.shiftRight, Folder.rightNeutral 0 Nat.shiftRight (by intros; rfl)]),
(``UInt8.shiftRight, Folder.first #[Folder.mkBinary UInt8.shiftRight, Folder.rightNeutral 0 UInt8.shiftRight @UInt8.shiftRight_zero]),
(``UInt16.shiftRight, Folder.first #[Folder.mkBinary UInt16.shiftRight, Folder.rightNeutral 0 UInt16.shiftRight @UInt16.shiftRight_zero]),
(``UInt32.shiftRight, Folder.first #[Folder.mkBinary UInt32.shiftRight, Folder.rightNeutral 0 UInt32.shiftRight @UInt32.shiftRight_zero]),
(``UInt64.shiftRight, Folder.first #[Folder.mkBinary UInt64.shiftRight, Folder.rightNeutral 0 UInt64.shiftRight @UInt64.shiftRight_zero]),
(``USize.shiftRight, Folder.first #[Folder.mkBinaryUSize UInt64.shiftRight UInt32.shiftRight, Folder.rightNeutralUSize 0 0]),

(``Nat.land, Folder.first #[Folder.mkBinary Nat.land, Folder.leftRightAnnihilator 0 0 Nat.land]),
(``UInt8.land, Folder.first #[Folder.mkBinary UInt8.land, Folder.leftRightAnnihilator 0 0 UInt8.land sorry sorry]),
(``UInt16.land, Folder.first #[Folder.mkBinary UInt16.land, Folder.leftRightAnnihilator 0 0 UInt16.land sorry sorry]),
(``UInt32.land, Folder.first #[Folder.mkBinary UInt32.land, Folder.leftRightAnnihilator 0 0 UInt32.land sorry sorry]),
(``UInt64.land, Folder.first #[Folder.mkBinary UInt64.land, Folder.leftRightAnnihilator 0 0 UInt64.land sorry sorry]),
(``USize.land, Folder.first #[Folder.mkBinaryUSize UInt64.land UInt32.land, Folder.leftRightAnnihilatorUSize 0 0 0 0]),

(``Nat.lor, Folder.first #[Folder.mkBinary Nat.lor, Folder.leftRightNeutral 0 Nat.lor]),
(``UInt8.lor, Folder.first #[Folder.mkBinary UInt8.lor, Folder.leftRightNeutral 0 UInt8.lor sorry sorry]),
(``UInt16.lor, Folder.first #[Folder.mkBinary UInt16.lor, Folder.leftRightNeutral 0 UInt16.lor sorry sorry]),
(``UInt32.lor, Folder.first #[Folder.mkBinary UInt32.lor, Folder.leftRightNeutral 0 UInt32.lor sorry sorry]),
(``UInt64.lor, Folder.first #[Folder.mkBinary UInt64.lor, Folder.leftRightNeutral 0 UInt64.lor sorry sorry]),
(``USize.lor, Folder.first #[Folder.mkBinaryUSize UInt64.lor UInt32.lor, Folder.leftNeutralUSize 0 0]),

(``Nat.xor, Folder.first #[Folder.mkBinary Nat.xor, Folder.leftRightNeutral 0 Nat.xor]),
(``UInt8.xor, Folder.first #[Folder.mkBinary UInt8.xor, Folder.leftRightNeutral 0 UInt8.xor sorry sorry]),
(``UInt16.xor, Folder.first #[Folder.mkBinary UInt16.xor, Folder.leftRightNeutral 0 UInt16.xor sorry sorry]),
(``UInt32.xor, Folder.first #[Folder.mkBinary UInt32.xor, Folder.leftRightNeutral 0 UInt32.xor sorry sorry]),
(``UInt64.xor, Folder.first #[Folder.mkBinary UInt64.xor, Folder.leftRightNeutral 0 UInt64.xor sorry sorry]),
(``USize.xor, Folder.first #[Folder.mkBinaryUSize UInt64.xor UInt32.xor, Folder.leftNeutralUSize 0 0]),

(``Nat.pow, foldNatPow),
(``Nat.nextPowerOfTwo, Folder.mkUnary Nat.nextPowerOfTwo),
]
Expand Down
23 changes: 11 additions & 12 deletions src/Lean/Data/PersistentHashMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,17 @@ partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize
def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → β → PersistentHashMap α β
| { root }, k, v => { root := insertAux root (hash k |>.toUSize) 1 k v }

partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
if h : i < keys.size then
let k' := keys[i]
have : i < vals.size := by rw [←heq]; assumption
if k == k' then some vals[i]
partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : USize) (k : α) : Option β :=
if h : i < keys.usize then
let k' := keys[i]'sorry
if k == k' then some (vals[i]'sorry)
else findAtAux keys vals heq (i+1) k
else none

partial def findAux [BEq α] : @&Node α β → USize → α → Option β
| Node.entries entries, h, k =>
let j := (mod2Shift h shift).toNat
match entries[j]! with
let j := mod2Shift h shift
match entries[j]'sorry with
| Entry.null => none
| Entry.ref node => findAux node (div2Shift h shift) k
| Entry.entry k' v => if k == k' then some v else none
Expand Down Expand Up @@ -219,17 +218,17 @@ A more efficient `m.findEntry? a |>.map (·.1) |>.getD a₀`
@[inline] def findKeyD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (a₀ : α) : α :=
findKeyDAux m.root (hash a |>.toUSize) a a₀

partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool :=
if h : i < keys.size then
let k' := keys[i]
partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : USize) (k : α) : Bool :=
if h : i < keys.usize then
let k' := keys[i]'sorry
if k == k' then true
else containsAtAux keys vals heq (i+1) k
else false

partial def containsAux [BEq α] : Node α β → USize → α → Bool
| Node.entries entries, h, k =>
let j := (mod2Shift h shift).toNat
match entries[j]! with
let j := mod2Shift h shift
match entries[j]'sorry with
| Entry.null => false
| Entry.ref node => containsAux node (div2Shift h shift) k
| Entry.entry k' _ => k == k'
Expand Down
Loading
Loading