Skip to content

Commit 4f7f2f2

Browse files
Cost library for ext/network
1 parent 258e7c8 commit 4f7f2f2

3 files changed

Lines changed: 835 additions & 19 deletions

File tree

ext/network.go

Lines changed: 197 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@ package ext
1616

1717
import (
1818
"fmt"
19+
"math"
1920
"net/netip"
2021
"reflect"
2122

2223
"github.com/google/cel-go/cel"
24+
"github.com/google/cel-go/checker"
2325
"github.com/google/cel-go/common/ast"
2426
"github.com/google/cel-go/common/types"
2527
"github.com/google/cel-go/common/types/ref"
28+
"github.com/google/cel-go/interpreter"
2629
)
2730

2831
const (
@@ -182,7 +185,11 @@ const (
182185

183186
var (
184187
// Definitions for the Opaque Types
185-
IPType = types.NewOpaqueType("net.IP")
188+
189+
// IPType represents a network IP address.
190+
IPType = types.NewOpaqueType("net.IP")
191+
192+
// CIDRType represents a CIDR-format network range.
186193
CIDRType = types.NewOpaqueType("net.CIDR")
187194
)
188195

@@ -196,13 +203,11 @@ func (*networkLib) LibraryName() string {
196203

197204
func (*networkLib) CompileOptions() []cel.EnvOption {
198205
return []cel.EnvOption{
199-
// 1. Register Types
200206
cel.Types(
201207
IPType,
202208
CIDRType,
203209
),
204210

205-
// 2. Register Functions
206211
cel.Function(cidrFunc,
207212
// K8s Parity: Following the pattern, this is "string_to_cidr"
208213
cel.Overload("string_to_cidr", []*cel.Type{cel.StringType}, CIDRType,
@@ -288,11 +293,58 @@ func (*networkLib) CompileOptions() []cel.EnvOption {
288293
networkFormatValidator{funcName: ipFunc, argNum: 0, check: checkIP},
289294
networkFormatValidator{funcName: cidrFunc, argNum: 0, check: checkCIDR},
290295
),
296+
cel.CostEstimatorOptions(
297+
checker.OverloadCostEstimate("string_to_cidr", estimateNetworkParseCost),
298+
checker.OverloadCostEstimate("cidr_to_string", estimateNetworkNominalStringCost),
299+
checker.OverloadCostEstimate("cidr_contains_cidr", estimateNetworkContainsCIDRCIDRCost),
300+
checker.OverloadCostEstimate("cidr_contains_cidr_string", estimateNetworkContainsCIDRStringCost),
301+
checker.OverloadCostEstimate("cidr_contains_ip_ip", estimateNetworkContainsIPIPCost),
302+
checker.OverloadCostEstimate("cidr_contains_ip_string", estimateNetworkContainsIPStringCost),
303+
checker.OverloadCostEstimate("ip_family", estimateNetworkNominalCost),
304+
checker.OverloadCostEstimate("string_to_ip", estimateNetworkParseCost),
305+
checker.OverloadCostEstimate("cidr_ip", estimateNetworkNominalOpaqueCost),
306+
checker.OverloadCostEstimate("ip_to_string", estimateNetworkNominalStringCost),
307+
checker.OverloadCostEstimate("ip_is_canonical", estimateIPIsCanonicalCost),
308+
checker.OverloadCostEstimate("is_cidr", estimateNetworkParseBoolCost),
309+
checker.OverloadCostEstimate("ip_is_global_unicast", estimateNetworkNominalCost),
310+
checker.OverloadCostEstimate("is_ip", estimateNetworkParseBoolCost),
311+
checker.OverloadCostEstimate("ip_is_link_local_multicast", estimateNetworkNominalCost),
312+
checker.OverloadCostEstimate("ip_is_link_local_unicast", estimateNetworkNominalCost),
313+
checker.OverloadCostEstimate("ip_is_loopback", estimateNetworkNominalCost),
314+
checker.OverloadCostEstimate("cidr_is_mask", estimateNetworkNominalCost),
315+
checker.OverloadCostEstimate("ip_is_unspecified", estimateNetworkNominalCost),
316+
checker.OverloadCostEstimate("cidr_masked", estimateNetworkNominalOpaqueCost),
317+
checker.OverloadCostEstimate("cidr_prefix_length", estimateNetworkNominalCost),
318+
),
291319
}
292320
}
293321

294322
func (*networkLib) ProgramOptions() []cel.ProgramOption {
295-
return []cel.ProgramOption{}
323+
return []cel.ProgramOption{
324+
cel.CostTrackerOptions(
325+
interpreter.OverloadCostTracker("string_to_cidr", trackNetworkParseCost),
326+
interpreter.OverloadCostTracker("cidr_to_string", trackNetworkNominalCost),
327+
interpreter.OverloadCostTracker("cidr_contains_cidr", trackNetworkContainsCIDRCIDRCost),
328+
interpreter.OverloadCostTracker("cidr_contains_cidr_string", trackNetworkContainsCIDRStringCost),
329+
interpreter.OverloadCostTracker("cidr_contains_ip_ip", trackNetworkContainsIPIPCost),
330+
interpreter.OverloadCostTracker("cidr_contains_ip_string", trackNetworkContainsIPStringCost),
331+
interpreter.OverloadCostTracker("ip_family", trackNetworkNominalCost),
332+
interpreter.OverloadCostTracker("string_to_ip", trackNetworkParseCost),
333+
interpreter.OverloadCostTracker("cidr_ip", trackNetworkNominalCost),
334+
interpreter.OverloadCostTracker("ip_to_string", trackNetworkNominalCost),
335+
interpreter.OverloadCostTracker("ip_is_canonical", trackIPIsCanonicalCost),
336+
interpreter.OverloadCostTracker("is_cidr", trackNetworkParseCost),
337+
interpreter.OverloadCostTracker("ip_is_global_unicast", trackNetworkNominalCost),
338+
interpreter.OverloadCostTracker("is_ip", trackNetworkParseCost),
339+
interpreter.OverloadCostTracker("ip_is_link_local_multicast", trackNetworkNominalCost),
340+
interpreter.OverloadCostTracker("ip_is_link_local_unicast", trackNetworkNominalCost),
341+
interpreter.OverloadCostTracker("ip_is_loopback", trackNetworkNominalCost),
342+
interpreter.OverloadCostTracker("cidr_is_mask", trackNetworkNominalCost),
343+
interpreter.OverloadCostTracker("ip_is_unspecified", trackNetworkNominalCost),
344+
interpreter.OverloadCostTracker("cidr_masked", trackNetworkNominalCost),
345+
interpreter.OverloadCostTracker("cidr_prefix_length", trackNetworkNominalCost),
346+
),
347+
}
296348
}
297349

298350
// networkAdapter adapts netip types while preserving existing adapters.
@@ -478,8 +530,7 @@ func parseIPAddr(raw string) (netip.Addr, error) {
478530
return addr, nil
479531
}
480532

481-
// --- Opaque Type Wrappers ---
482-
533+
// IP represents an IP address type.
483534
type IP struct {
484535
netip.Addr
485536
}
@@ -527,6 +578,13 @@ func (i IP) Value() any {
527578
return i.Addr
528579
}
529580

581+
// Size returns the size of the IP address in bytes.
582+
// /Used in the size estimation of the runtime cost.
583+
func (i IP) Size() ref.Val {
584+
return types.Int(int64(math.Ceil(float64(i.Addr.BitLen()) / 8)))
585+
}
586+
587+
// CIDR represents the CIDR network mask format.
530588
type CIDR struct {
531589
netip.Prefix
532590
}
@@ -574,6 +632,12 @@ func (c CIDR) Value() any {
574632
return c.Prefix
575633
}
576634

635+
// Size returns the size of the CIDR prefix address in bytes.
636+
// Used in the size estimation of the runtime cost.
637+
func (c CIDR) Size() ref.Val {
638+
return types.Int(int64(math.Ceil(float64(c.Prefix.Bits()) / 8)))
639+
}
640+
577641
// --- Static Validators ---
578642

579643
type argChecker func(e *cel.Env, call, arg ast.Expr) error
@@ -617,3 +681,130 @@ func checkCIDR(e *cel.Env, call, arg ast.Expr) error {
617681
_, err := parseCIDR(pattern)
618682
return err
619683
}
684+
685+
// Cost estimation functions for network extensions.
686+
687+
func estimateNetworkParseCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
688+
if len(args) < 1 {
689+
return nil
690+
}
691+
sz := estimateSize(estimator, args[0])
692+
resultSize := rangedSizeEstimate(4, 16)
693+
return callEstimate(sz.MultiplyByCostFactor(stringCostFactor), &resultSize)
694+
}
695+
696+
func estimateNetworkParseBoolCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
697+
if len(args) < 1 {
698+
return nil
699+
}
700+
sz := estimateSize(estimator, args[0])
701+
return callEstimate(sz.MultiplyByCostFactor(stringCostFactor), nil)
702+
}
703+
704+
func estimateIPIsCanonicalCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
705+
if len(args) < 1 {
706+
return nil
707+
}
708+
sz := estimateSize(estimator, args[0])
709+
return callEstimate(sz.MultiplyByCostFactor(2*stringCostFactor), nil)
710+
}
711+
712+
func estimateNetworkNominalCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
713+
return callEstimate(callCostEstimate, nil)
714+
}
715+
716+
func estimateNetworkNominalOpaqueCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
717+
resultSize := rangedSizeEstimate(4, 16)
718+
return callEstimate(callCostEstimate, &resultSize)
719+
}
720+
721+
func estimateNetworkNominalStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
722+
resultSize := rangedSizeEstimate(3, 45)
723+
return callEstimate(callCostEstimate, &resultSize)
724+
}
725+
726+
func estimateNetworkContainsIPIPCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
727+
sz := rangedSizeEstimate(4, 16)
728+
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
729+
return callEstimate(ipCompCost, nil)
730+
}
731+
732+
func estimateNetworkContainsIPStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
733+
if len(args) < 1 {
734+
return nil
735+
}
736+
sz := rangedSizeEstimate(4, 16)
737+
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
738+
argSz := estimateSize(estimator, args[0])
739+
ipCompCost = ipCompCost.Add(argSz.MultiplyByCostFactor(stringCostFactor))
740+
return callEstimate(ipCompCost, nil)
741+
}
742+
743+
func estimateNetworkContainsCIDRCIDRCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
744+
sz := rangedSizeEstimate(4, 16)
745+
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
746+
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(stringCostFactor))
747+
// K8s adds one for the extra IP traversal
748+
ipCompCost = ipCompCost.Add(callCostEstimate)
749+
return callEstimate(ipCompCost, nil)
750+
}
751+
752+
func estimateNetworkContainsCIDRStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
753+
if len(args) < 1 {
754+
return nil
755+
}
756+
sz := rangedSizeEstimate(4, 16)
757+
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
758+
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(stringCostFactor))
759+
argSz := estimateSize(estimator, args[0])
760+
ipCompCost = ipCompCost.Add(argSz.MultiplyByCostFactor(stringCostFactor))
761+
// K8s adds one for the extra IP traversal
762+
ipCompCost = ipCompCost.Add(callCostEstimate)
763+
return callEstimate(ipCompCost, nil)
764+
}
765+
766+
// Runtime cost tracking functions for network extensions.
767+
768+
func trackNetworkParseCost(args []ref.Val, result ref.Val) *uint64 {
769+
cost := uint64(math.Ceil(float64(actualSize(args[0])) * stringCostFactor))
770+
return &cost
771+
}
772+
773+
func trackIPIsCanonicalCost(args []ref.Val, result ref.Val) *uint64 {
774+
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * stringCostFactor))
775+
return &cost
776+
}
777+
778+
func trackNetworkNominalCost(args []ref.Val, result ref.Val) *uint64 {
779+
return &callCost
780+
}
781+
782+
func trackNetworkContainsIPIPCost(args []ref.Val, result ref.Val) *uint64 {
783+
cidrSize := actualSize(args[0])
784+
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
785+
return &cost
786+
}
787+
788+
func trackNetworkContainsIPStringCost(args []ref.Val, result ref.Val) *uint64 {
789+
cidrSize := actualSize(args[0])
790+
otherSize := actualSize(args[1])
791+
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
792+
cost = safeAdd(cost, uint64(math.Ceil(float64(otherSize)*stringCostFactor)))
793+
return &cost
794+
}
795+
796+
func trackNetworkContainsCIDRCIDRCost(args []ref.Val, result ref.Val) *uint64 {
797+
cidrSize := actualSize(args[0])
798+
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
799+
cost = safeAdd(cost, uint64(math.Ceil(float64(cidrSize)*stringCostFactor)), 1)
800+
return &cost
801+
}
802+
803+
func trackNetworkContainsCIDRStringCost(args []ref.Val, result ref.Val) *uint64 {
804+
cidrSize := actualSize(args[0])
805+
otherSize := actualSize(args[1])
806+
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
807+
cost = safeAdd(cost, uint64(math.Ceil(float64(cidrSize)*stringCostFactor)), 1)
808+
cost = safeAdd(cost, uint64(math.Ceil(float64(otherSize)*stringCostFactor)))
809+
return &cost
810+
}

0 commit comments

Comments
 (0)