Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,6 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
path1 = path2
)
or
n2 =
any(PrefixExpr pe |
pe.getOperatorName() = "*" and
pe.getExpr() = n1 and
path1 = TypePath::cons(TRefTypeParameter(), path2)
)
or
n1 = n2.(ParenExpr).getExpr() and
path1 = path2
or
Expand All @@ -239,12 +232,36 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
)
}

bindingset[path1]
private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
typeEquality(n1, path1, n2, path2)
or
n2 =
any(PrefixExpr pe |
pe.getOperatorName() = "*" and
pe.getExpr() = n1 and
path1 = TypePath::consInverse(TRefTypeParameter(), path2)
)
}

bindingset[path2]
private predicate typeEqualityRight(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
typeEquality(n1, path1, n2, path2)
or
n2 =
any(PrefixExpr pe |
pe.getOperatorName() = "*" and
pe.getExpr() = n1 and
path1 = TypePath::cons(TRefTypeParameter(), path2)
)
}

pragma[nomagic]
private Type inferTypeEquality(AstNode n, TypePath path) {
exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) |
typeEquality(n, path, n2, path2)
typeEqualityRight(n, path, n2, path2)
or
typeEquality(n2, path2, n, path)
typeEqualityLeft(n2, path2, n, path)
)
}

Expand Down Expand Up @@ -909,7 +926,7 @@ private Type inferRefExprType(Expr e, TypePath path) {
e = re.getExpr() and
exists(TypePath exprPath, TypePath refPath, Type exprType |
result = inferType(re, exprPath) and
exprPath = TypePath::cons(TRefTypeParameter(), refPath) and
exprPath = TypePath::consInverse(TRefTypeParameter(), refPath) and
exprType = inferType(e)
|
if exprType = TRefType()
Expand All @@ -924,7 +941,7 @@ private Type inferRefExprType(Expr e, TypePath path) {
pragma[nomagic]
private Type inferTryExprType(TryExpr te, TypePath path) {
exists(TypeParam tp |
result = inferType(te.getExpr(), TypePath::cons(TTypeParamTypeParameter(tp), path))
result = inferType(te.getExpr(), TypePath::consInverse(TTypeParamTypeParameter(tp), path))
|
tp = any(ResultEnum r).getGenericParamList().getGenericParam(0)
or
Expand Down Expand Up @@ -1000,7 +1017,7 @@ private module Cached {
pragma[nomagic]
Type getTypeAt(TypePath path) {
exists(TypePath path0 | result = inferType(this, path0) |
path0 = TypePath::cons(TRefTypeParameter(), path)
path0 = TypePath::consInverse(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,29 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
/** Holds if this type path is empty. */
predicate isEmpty() { this = "" }

/** Gets the length of this path, assuming the length is at least 2. */
bindingset[this]
pragma[inline_late]
private int length2() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe depthFrom2?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or lengthFrom2.

// Same as
// `result = strictcount(this.indexOf(".")) + 1`
// but performs better because it doesn't use an aggregate
result = this.regexpReplaceAll("[0-9]+", "").length() + 1
}

/** Gets the length of this path. */
bindingset[this]
pragma[inline_late]
int length() {
this.isEmpty() and result = 0
or
result = strictcount(this.indexOf(".")) + 1
if this.isEmpty()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps depth would actually be a better name than length?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't depth lead to the misconception that it is a tree instead of a list?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe? To me it feel natural to say that the depth of foo/bar/baz is 3, but let's just keep it as-if if you feel it's not as clear.

then result = 0
else
if exists(TypeParameter::decode(this))
then result = 1
else result = this.length2()
}

/** Gets the path obtained by appending `suffix` onto this path. */
bindingset[suffix, result]
bindingset[this, result]
bindingset[this, suffix]
TypePath append(TypePath suffix) {
if this.isEmpty()
Expand All @@ -202,22 +213,37 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
then result = this
else (
result = this + "." + suffix and
not result.length() > getTypePathLimit()
(
not exists(getTypePathLimit())
or
result.length2() <= getTypePathLimit()
)
)
}

/**
* Gets the path obtained by appending `suffix` onto this path.
*
* Unlike `append`, this predicate has `result` in the binding set,
* so there is no need to check the length of `result`.
Comment on lines +225 to +228
Copy link

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation for appendInverse is copied from append but this predicate actually deconstructs a full path into this and suffix. Please update the comment to describe the inverse operation.

Suggested change
* Gets the path obtained by appending `suffix` onto this path.
*
* Unlike `append`, this predicate has `result` in the binding set,
* so there is no need to check the length of `result`.
* Deconstructs a full path `result` into `this` and `suffix`.
*
* This predicate performs the inverse operation of `append`. It holds if
* `result` is a path that can be split into `this` as the prefix and
* `suffix` as the remainder. For example, if `result` is "a.b.c" and
* `this` is "a.b", then `suffix` would be "c".

Copilot uses AI. Check for mistakes.
*/
bindingset[this, result]
TypePath appendInverse(TypePath suffix) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A variant of this where the result is the output could be:

    bindingset[this, prefix]
    TypePath stripPrefix(TypePath prefix) { this = prefix.appendInverse(result) }

Similar to my comment for consInverse I find that more natural. However, unlike for isCons this variant is less convenient in all the places where appendInverse is used. But perhaps we could introduce stripPrefix, define appendInverse in terms of stripPrefix, and use stripPrefix in the (few if any) cases where it's not inconvenient?

if result.isEmpty()
then this.isEmpty() and suffix.isEmpty()
else
if this.isEmpty()
then suffix = result
else (
result = this and suffix.isEmpty()
or
result = this + "." + suffix
)
}

/** Holds if this path starts with `tp`, followed by `suffix`. */
bindingset[this]
predicate isCons(TypeParameter tp, TypePath suffix) {
tp = TypeParameter::decode(this) and
suffix.isEmpty()
or
exists(int first |
first = min(this.indexOf(".")) and
suffix = this.suffix(first + 1) and
tp = TypeParameter::decode(this.prefix(first))
)
}
predicate isCons(TypeParameter tp, TypePath suffix) { this = TypePath::consInverse(tp, suffix) }
}

/** Provides predicates for constructing `TypePath`s. */
Expand All @@ -232,9 +258,17 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
* Gets the type path obtained by appending the singleton type path `tp`
* onto `suffix`.
*/
bindingset[result]
bindingset[suffix]
TypePath cons(TypeParameter tp, TypePath suffix) { result = singleton(tp).append(suffix) }

/**
* Gets the type path obtained by appending the singleton type path `tp`
* onto `suffix`.
*/
bindingset[result]
TypePath consInverse(TypeParameter tp, TypePath suffix) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consInverse is equivalent to isCons but with a different parameter order.

I find isCons easer to read as it's the result that is the output and since the name is easier to understand. Looking at the places where consInverse is used there's only one spot where it's not trivial to use isCons instead. Given that, I think it would make sense to remove consInverse and use isCons instead.

result = singleton(tp).appendInverse(suffix)
}
}

/**
Expand Down Expand Up @@ -556,7 +590,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
TypeMention tm1, TypeMention tm2, TypeParameter tp, TypePath path, Type t
) {
exists(TypePath prefix |
tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.append(path))
tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.appendInverse(path))
)
}

Expand Down Expand Up @@ -899,7 +933,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
tp = target.getDeclaredType(dpos, pathToTypeParam) and
accessDeclarationPositionMatch(apos, dpos) and
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t)
adjustedAccessType(a, apos, target, pathToTypeParam.appendInverse(path), t)
)
}

Expand Down Expand Up @@ -998,7 +1032,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {

RelevantAccess() { this = MkRelevantAccess(a, apos, path) }

Type getTypeAt(TypePath suffix) { a.getInferredType(apos, path.append(suffix)) = result }
Type getTypeAt(TypePath suffix) {
a.getInferredType(apos, path.appendInverse(suffix)) = result
}

/** Holds if this relevant access has the type `type` and should satisfy `constraint`. */
predicate hasTypeConstraint(Type type, Type constraint) {
Expand Down Expand Up @@ -1077,7 +1113,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
t0 = abs.getATypeParameter() and
exists(TypePath path3, TypePath suffix |
sub.resolveTypeAt(path3) = t0 and
at.getTypeAt(path3.append(suffix)) = t and
at.getTypeAt(path3.appendInverse(suffix)) = t and
path = prefix0.append(suffix)
)
)
Expand Down Expand Up @@ -1149,7 +1185,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
not exists(getTypeArgument(a, target, tp, _)) and
target = a.getTarget() and
exists(AccessPosition apos, DeclarationPosition dpos, Type base, TypePath pathToTypeParam |
accessBaseType(a, apos, base, pathToTypeParam.append(path), t) and
accessBaseType(a, apos, base, pathToTypeParam.appendInverse(path), t) and
declarationBaseType(target, dpos, base, pathToTypeParam, tp) and
accessDeclarationPositionMatch(apos, dpos)
)
Expand Down Expand Up @@ -1217,7 +1253,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _, constraint, pathToTp,
tp) and
AccessConstraint::satisfiesConstraintTypeMention(a, apos, pathToTp2, constraint,
pathToTp.append(path), t)
pathToTp.appendInverse(path), t)
)
}

Expand Down
Loading