11package com .wavesplatform .lang .v1 .estimator .v3
22
3- import cats .implicits .*
3+ import cats .implicits .{ toBifunctorOps , toFoldableOps , toTraverseOps }
44import cats .{Id , Monad }
55import com .wavesplatform .lang .v1 .FunctionHeader
66import com .wavesplatform .lang .v1 .FunctionHeader .User
@@ -13,7 +13,7 @@ import monix.eval.Coeval
1313
1414import scala .util .Try
1515
16- case class ScriptEstimatorV3 (fixOverflow : Boolean , overhead : Boolean ) extends ScriptEstimator {
16+ case class ScriptEstimatorV3 (fixOverflow : Boolean , overhead : Boolean , letFixes : Boolean ) extends ScriptEstimator {
1717 private val overheadCost : Long = if (overhead) 1 else 0
1818
1919 override val version : Int = 3
@@ -39,55 +39,45 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
3939 globalDeclarationsMode : Boolean
4040 ): (EstimatorContext , Either [EstimationError , Long ]) = {
4141 val ctxFuncs = funcs.view.mapValues((_, Set [String ]())).toMap
42- evalExpr(expr, globalDeclarationsMode).run(EstimatorContext (ctxFuncs)).value
42+ evalExpr(expr, Set (), globalDeclarationsMode).run(EstimatorContext (ctxFuncs)).value
4343 }
4444
45- private def evalExpr (t : EXPR , globalDeclarationsMode : Boolean = false ): EvalM [Long ] =
45+ private def evalExpr (t : EXPR , activeFuncArgs : Set [ String ], globalDeclarationsMode : Boolean = false ): EvalM [Long ] =
4646 if (Thread .currentThread().isInterrupted)
4747 raiseError(" Script estimation was interrupted" )
4848 else
4949 t match {
50- case LET_BLOCK (let, inner) => evalLetBlock(let, inner, globalDeclarationsMode)
51- case BLOCK (let : LET , inner) => evalLetBlock(let, inner, globalDeclarationsMode)
52- case BLOCK (f : FUNC , inner) => evalFuncBlock(f, inner, globalDeclarationsMode)
50+ case LET_BLOCK (let, inner) => evalLetBlock(let, inner, activeFuncArgs, globalDeclarationsMode)
51+ case BLOCK (let : LET , inner) => evalLetBlock(let, inner, activeFuncArgs, globalDeclarationsMode)
52+ case BLOCK (f : FUNC , inner) => evalFuncBlock(f, inner, activeFuncArgs, globalDeclarationsMode)
5353 case BLOCK (_ : FAILED_DEC , _) => zero
54- case REF (str) => markRef (str)
54+ case REF (str) => evalRef (str, activeFuncArgs )
5555 case _ : EVALUATED => const(overheadCost)
56- case IF (cond, t1, t2) => evalIF(cond, t1, t2)
57- case GETTER (expr, _) => evalGetter(expr)
58- case FUNCTION_CALL (header, args) => evalFuncCall(header, args)
56+ case IF (cond, t1, t2) => evalIF(cond, t1, t2, activeFuncArgs )
57+ case GETTER (expr, _) => evalGetter(expr, activeFuncArgs )
58+ case FUNCTION_CALL (header, args) => evalFuncCall(header, args, activeFuncArgs )
5959 case _ : FAILED_EXPR => zero
6060 }
6161
62- private def evalHoldingFuncs ( expr : EXPR ): EvalM [Long ] =
62+ private def evalLetBlock ( let : LET , nextExpr : EXPR , activeFuncArgs : Set [ String ], globalDeclarationsMode : Boolean ): EvalM [Long ] =
6363 for {
64+ _ <- if (globalDeclarationsMode) saveGlobalLetCost(let, activeFuncArgs) else doNothing
6465 startCtx <- get[Id , EstimatorContext , EstimationError ]
65- cost <- evalExpr(expr)
66- _ <- update(funcs.set(_)(startCtx.funcs))
67- } yield cost
68-
69- private def evalLetBlock (let : LET , inner : EXPR , globalDeclarationsMode : Boolean ): EvalM [Long ] =
70- for {
71- startCtx <- get[Id , EstimatorContext , EstimationError ]
72- overlap = startCtx.usedRefs.contains(let.name)
73- _ <- update(usedRefs.modify(_)(_ - let.name))
74- letEval = evalHoldingFuncs(let.value)
75- _ <- if (globalDeclarationsMode) saveGlobalLetCost(let) else doNothing
76- nextCost <- evalExpr(inner, globalDeclarationsMode)
77- ctx <- get[Id , EstimatorContext , EstimationError ]
78- letCost <- if (ctx.usedRefs.contains(let.name)) letEval else zero
79- _ <- update(usedRefs.modify(_)(r => if (overlap) r + let.name else r - let.name))
80- result <- sum(nextCost, letCost)
66+ letEval = evalHoldingFuncs(let.value, activeFuncArgs)
67+ _ <- beforeNextExprEval(let, letEval)
68+ nextExprCost <- evalExpr(nextExpr, activeFuncArgs, globalDeclarationsMode)
69+ nextExprCtx <- get[Id , EstimatorContext , EstimationError ]
70+ _ <- afterNextExprEval(let, startCtx)
71+ letCost <- if (nextExprCtx.usedRefs.contains(let.name)) letEval else const(0L )
72+ result <- sum(nextExprCost, letCost)
8173 } yield result
8274
83- private def saveGlobalLetCost (let : LET ): EvalM [Unit ] = {
75+ private def saveGlobalLetCost (let : LET , activeFuncArgs : Set [ String ] ): EvalM [Unit ] = {
8476 val costEvaluation =
8577 for {
86- startCtx <- get[Id , EstimatorContext , EstimationError ]
87- bodyCost <- evalExpr(let.value)
88- bodyEvalCtx <- get[Id , EstimatorContext , EstimationError ]
89- usedRefs = bodyEvalCtx.usedRefs diff startCtx.usedRefs
90- letCosts <- usedRefs.toSeq.traverse(bodyEvalCtx.globalLetEvals.getOrElse(_, zero))
78+ (bodyCost, usedRefs) <- withUsedRefs(evalExpr(let.value, activeFuncArgs))
79+ ctx <- get[Id , EstimatorContext , EstimationError ]
80+ letCosts <- usedRefs.toSeq.traverse(ctx.globalLetEvals.getOrElse(_, zero))
9181 } yield bodyCost + letCosts.sum
9282 for {
9383 cost <- local(costEvaluation)
@@ -100,26 +90,47 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
10090 } yield ()
10191 }
10292
103- private def evalFuncBlock (func : FUNC , inner : EXPR , globalDeclarationsMode : Boolean ): EvalM [Long ] =
93+ private def beforeNextExprEval (let : LET , eval : EvalM [Long ]): EvalM [Unit ] =
94+ for {
95+ cost <- local(eval)
96+ _ <- update(ctx =>
97+ usedRefs
98+ .modify(ctx)(_ - let.name)
99+ .copy(refsCosts = ctx.refsCosts + (let.name -> cost))
100+ )
101+ } yield ()
102+
103+ private def afterNextExprEval (let : LET , startCtx : EstimatorContext ): EvalM [Unit ] =
104+ update(ctx =>
105+ usedRefs
106+ .modify(ctx)(r => if (startCtx.usedRefs.contains(let.name)) r + let.name else r - let.name)
107+ .copy(refsCosts =
108+ if (startCtx.refsCosts.contains(let.name))
109+ ctx.refsCosts + (let.name -> startCtx.refsCosts(let.name))
110+ else
111+ ctx.refsCosts - let.name
112+ )
113+ )
114+
115+ private def evalFuncBlock (func : FUNC , nextExpr : EXPR , activeFuncArgs : Set [String ], globalDeclarationsMode : Boolean ): EvalM [Long ] =
104116 for {
105- startCtx <- get[Id , EstimatorContext , EstimationError ]
106- _ <- checkShadowing(func, startCtx)
107- funcCost <- evalHoldingFuncs(func.body)
108- bodyEvalCtx <- get[Id , EstimatorContext , EstimationError ]
109- refsUsedInBody = bodyEvalCtx.usedRefs diff startCtx.usedRefs
110- _ <- if (globalDeclarationsMode) saveGlobalFuncCost(func.name, funcCost, bodyEvalCtx, refsUsedInBody) else doNothing
111- _ <- handleUsedRefs(func.name, funcCost, startCtx, refsUsedInBody)
112- nextCost <- evalExpr(inner, globalDeclarationsMode)
113- } yield nextCost
117+ startCtx <- get[Id , EstimatorContext , EstimationError ]
118+ _ <- checkShadowing(func, startCtx)
119+ (funcCost, refsUsedInBody) <- withUsedRefs(evalHoldingFuncs(func.body, activeFuncArgs ++ func.args))
120+ _ <- if (globalDeclarationsMode) saveGlobalFuncCost(func.name, funcCost, refsUsedInBody) else doNothing
121+ _ <- handleUsedRefs(func.name, funcCost, startCtx, refsUsedInBody)
122+ nextExprCost <- evalExpr(nextExpr, activeFuncArgs, globalDeclarationsMode)
123+ } yield nextExprCost
114124
115125 private def checkShadowing (func : FUNC , startCtx : EstimatorContext ): EvalM [Any ] =
116126 if (fixOverflow && startCtx.funcs.contains(FunctionHeader .User (func.name)))
117127 raiseError(s " Function ' ${func.name}${func.args.mkString(" (" , " , " , " )" )}' shadows preceding declaration " )
118128 else
119129 doNothing
120130
121- private def saveGlobalFuncCost (name : String , funcCost : Long , ctx : EstimatorContext , refsUsedInBody : Set [String ]): EvalM [Unit ] =
131+ private def saveGlobalFuncCost (name : String , funcCost : Long , refsUsedInBody : Set [String ]): EvalM [Unit ] =
122132 for {
133+ ctx <- get[Id , EstimatorContext , EstimationError ]
123134 letCosts <- local(refsUsedInBody.toSeq.traverse(ctx.globalLetEvals.getOrElse(_, zero)))
124135 totalCost = math.max(1 , funcCost + letCosts.sum)
125136 _ <- set[Id , EstimatorContext , EstimationError ](ctx.copy(globalFunctionsCosts = ctx.globalFunctionsCosts + (name -> totalCost)))
@@ -135,46 +146,75 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
135146 }
136147 )
137148
138- private def evalIF (cond : EXPR , ifTrue : EXPR , ifFalse : EXPR ): EvalM [Long ] =
149+ private def evalIF (cond : EXPR , ifTrue : EXPR , ifFalse : EXPR , activeFuncArgs : Set [ String ] ): EvalM [Long ] =
139150 for {
140- cond <- evalHoldingFuncs(cond)
141- right <- evalHoldingFuncs(ifTrue)
142- left <- evalHoldingFuncs(ifFalse)
151+ cond <- evalHoldingFuncs(cond, activeFuncArgs )
152+ right <- evalHoldingFuncs(ifTrue, activeFuncArgs )
153+ left <- evalHoldingFuncs(ifFalse, activeFuncArgs )
143154 r1 <- sum(cond, Math .max(right, left))
144155 r2 <- sum(r1, overheadCost)
145156 } yield r2
146157
147- private def markRef (key : String ): EvalM [Long ] =
148- update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
158+ private def evalRef (key : String , activeFuncArgs : Set [String ]): EvalM [Long ] =
159+ if (activeFuncArgs.contains(key) && letFixes)
160+ const(overheadCost)
161+ else
162+ update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
149163
150- private def evalGetter (expr : EXPR ): EvalM [Long ] =
151- evalExpr(expr).flatMap(sum(_, overheadCost))
164+ private def evalGetter (expr : EXPR , activeFuncArgs : Set [ String ] ): EvalM [Long ] =
165+ evalExpr(expr, activeFuncArgs ).flatMap(sum(_, overheadCost))
152166
153- private def evalFuncCall (header : FunctionHeader , args : List [EXPR ]): EvalM [Long ] =
167+ private def evalFuncCall (header : FunctionHeader , args : List [EXPR ], activeFuncArgs : Set [ String ] ): EvalM [Long ] =
154168 for {
155- ctx <- get[Id , EstimatorContext , EstimationError ]
156- (bodyCost, bodyUsedRefs) <- funcs
157- .get(ctx)
158- .get(header)
159- .map(const)
160- .getOrElse(
161- raiseError[Id , EstimatorContext , EstimationError , (Coeval [Long ], Set [String ])](s " function ' $header' not found " )
162- )
163- _ <- update(
164- (funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
165- (
166- funcs + ((header, (bodyCost, Set [String ]()))),
167- usedRefs ++ bodyUsedRefs
168- )
169- }
170- )
171- argsCosts <- args.traverse(evalHoldingFuncs)
172- argsCostsSum <- argsCosts.foldM(0L )(sum)
173- bodyCostV = bodyCost.value()
174- correctedBodyCost = if (! overhead && bodyCostV == 0 ) 1 else bodyCostV
169+ ctx <- get[Id , EstimatorContext , EstimationError ]
170+ (bodyCost, bodyUsedRefs) <- getFuncCost(header, ctx)
171+ _ <- setFuncToCtx(header, bodyCost, bodyUsedRefs)
172+ (argsCosts, argsUsedRefs) <- withUsedRefs(args.traverse(evalHoldingFuncs(_, activeFuncArgs)))
173+ argsCostsSum <- argsCosts.foldM(0L )(sum)
174+ bodyCostV = bodyCost.value()
175+ correctedBodyCost =
176+ if (! overhead && ! letFixes && bodyCostV == 0 ) 1
177+ else if (letFixes && bodyCostV == 0 && isBlankFunc(bodyUsedRefs ++ argsUsedRefs, ctx.refsCosts)) 1
178+ else bodyCostV
175179 result <- sum(argsCostsSum, correctedBodyCost)
176180 } yield result
177181
182+ private def setFuncToCtx (header : FunctionHeader , bodyCost : Coeval [Long ], bodyUsedRefs : Set [EstimationError ]): EvalM [Unit ] =
183+ update(
184+ (funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
185+ (
186+ funcs + (header -> (bodyCost, Set ())),
187+ usedRefs ++ bodyUsedRefs
188+ )
189+ }
190+ )
191+
192+ private def getFuncCost (header : FunctionHeader , ctx : EstimatorContext ): EvalM [(Coeval [Long ], Set [EstimationError ])] =
193+ funcs
194+ .get(ctx)
195+ .get(header)
196+ .map(const)
197+ .getOrElse(
198+ raiseError[Id , EstimatorContext , EstimationError , (Coeval [Long ], Set [EstimationError ])](s " function ' $header' not found " )
199+ )
200+
201+ private def isBlankFunc (usedRefs : Set [String ], refsCosts : Map [String , Long ]): Boolean =
202+ ! usedRefs.exists(refsCosts.get(_).exists(_ > 0 ))
203+
204+ private def evalHoldingFuncs (expr : EXPR , activeFuncArgs : Set [String ]): EvalM [Long ] =
205+ for {
206+ startCtx <- get[Id , EstimatorContext , EstimationError ]
207+ cost <- evalExpr(expr, activeFuncArgs)
208+ _ <- update(funcs.set(_)(startCtx.funcs))
209+ } yield cost
210+
211+ private def withUsedRefs [A ](eval : EvalM [A ]): EvalM [(A , Set [String ])] =
212+ for {
213+ ctxBefore <- get[Id , EstimatorContext , EstimationError ]
214+ result <- eval
215+ ctxAfter <- get[Id , EstimatorContext , EstimationError ]
216+ } yield (result, ctxAfter.usedRefs diff ctxBefore.usedRefs)
217+
178218 private def update (f : EstimatorContext => EstimatorContext ): EvalM [Unit ] =
179219 modify[Id , EstimatorContext , EstimationError ](f)
180220
@@ -192,3 +232,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
192232 liftEither(Try (r).toEither.leftMap(_ => " Illegal script" ))
193233 }
194234}
235+
236+ object ScriptEstimatorV3 {
237+ val latest = ScriptEstimatorV3 (fixOverflow = true , overhead = false , letFixes = true )
238+ }
0 commit comments