Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 98 additions & 87 deletions rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,14 @@ private predicate isPanicMacroCall(MacroExpr me) {
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
}

// Due to "binding modes" the type of the pattern is not necessarily the
// same as the type of the initializer. The pattern being an identifier
// pattern is sufficient to ensure that this is not the case.
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
let.getPat() = lhs and
let.getInitializer() = rhs
}

/** Module for inferring certain type information. */
module CertainTypeInference {
pragma[nomagic]
Expand Down Expand Up @@ -485,11 +493,7 @@ module CertainTypeInference {
// is not a certain type equality.
exists(LetStmt let |
not let.hasTypeRepr() and
// Due to "binding modes" the type of the pattern is not necessarily the
// same as the type of the initializer. The pattern being an identifier
// pattern is sufficient to ensure that this is not the case.
let.getPat().(IdentPat) = n1 and
let.getInitializer() = n2
identLetStmt(let, n1, n2)
)
or
exists(LetExpr let |
Expand All @@ -513,6 +517,25 @@ module CertainTypeInference {
)
else prefix2.isEmpty()
)
or
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
n1 = dce.getArgList() and
tt.getArity() = dce.getNumberOfSyntacticArguments() and
n2 = dce.getSyntacticPositionalArgument(i) and
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
prefix2.isEmpty()
)
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
prefix2.isEmpty()
)
or
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
prefix1 = closureReturnPath() and
prefix2.isEmpty()
}

pragma[nomagic]
Expand Down Expand Up @@ -775,17 +798,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
prefix2.isEmpty()
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
prefix2.isEmpty()
)
or
n1.(ClosureExpr).getClosureBody() = n2 and
prefix1 = closureReturnPath() and
prefix2.isEmpty()
}

/**
Expand Down Expand Up @@ -828,6 +840,19 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
)
}

private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
inferType(n, path) = TUnknownType() and
// Normally, these are coercion sites, but in case a type is unknown we
// allow for type information to flow from the type annotation.
exists(TypeMention tm | result = tm.getTypeAt(path) |
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
or
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
or
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
)
}

/**
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
Expand Down Expand Up @@ -1533,6 +1558,8 @@ private module MethodResolution {
* or
* 4. `MethodCallOperation`: an operation expression, `x + y`, which is syntactic sugar
* for `Add::add(x, y)`.
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
*
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
*
Expand All @@ -1544,7 +1571,7 @@ private module MethodResolution {
abstract class MethodCall extends Expr {
abstract predicate hasNameAndArity(string name, int arity);

abstract Expr getArg(ArgumentPosition pos);
abstract AstNode getArg(ArgumentPosition pos);

abstract predicate supportsAutoDerefAndBorrow();

Expand Down Expand Up @@ -2093,6 +2120,26 @@ private module MethodResolution {
override Trait getTrait() { super.isOverloaded(result, _, _) }
}

private class ClosureMethodCall extends MethodCall instanceof CallExprImpl::DynamicCallExpr {
pragma[nomagic]
override predicate hasNameAndArity(string name, int arity) {
name = "call_once" and // todo: handle call_mut and call
arity = 1 // args are passed in a tuple
}

override AstNode getArg(ArgumentPosition pos) {
pos.isSelf() and
result = super.getFunction()
or
pos.asPosition() = 0 and
result = super.getArgList()
}

override predicate supportsAutoDerefAndBorrow() { any() }

override Trait getTrait() { result instanceof AnyFnTrait }
}

pragma[nomagic]
private Method getMethodSuccessor(ImplOrTraitItemNode i, string name, int arity) {
result = i.getASuccessor(name) and
Expand Down Expand Up @@ -2600,7 +2647,9 @@ private Type inferMethodCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePat
* empty, at which point the inferred type can be applied back to `n`.
*/
pragma[nomagic]
private Type inferMethodCallTypeSelf(MethodCall mc, AstNode n, DerefChain derefChain, TypePath path) {
private Type inferMethodCallTypeSelf(
MethodCallMatchingInput::Access mc, AstNode n, DerefChain derefChain, TypePath path
) {
exists(
MethodCallMatchingInput::AccessPosition apos, string derefChainBorrow, BorrowKind borrow,
TypePath path0
Expand Down Expand Up @@ -2644,7 +2693,7 @@ private Type inferMethodCallTypeSelf(MethodCall mc, AstNode n, DerefChain derefC
private Type inferMethodCallTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
result = inferMethodCallTypeNonSelf(n, pos, path)
or
exists(MethodCall mc |
exists(MethodCallMatchingInput::Access mc |
result = inferMethodCallTypeSelf(mc, n, DerefChain::nil(), path) and
if mc instanceof CallExpr then pos.asPosition() = 0 else pos.isSelf()
)
Expand Down Expand Up @@ -3942,14 +3991,6 @@ private module InvokedClosureSatisfiesConstraintInput implements
}
}

private module InvokedClosureSatisfiesConstraint =
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>;

/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
InvokedClosureSatisfiesConstraint::satisfiesConstraintType(ce, _, path, result)
}

/**
* Gets the root type of a closure.
*
Expand All @@ -3976,73 +4017,39 @@ private TypePath closureParameterPath(int arity, int index) {
TypePath::singleton(getTupleTypeParameter(arity, index)))
}

/** Gets the path to the return type of the `FnOnce` trait. */
private TypePath fnReturnPath() {
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
}

/**
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
* and index `index`.
*/
pragma[nomagic]
private TypePath fnParameterPath(int arity, int index) {
result =
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
TypePath::singleton(getTupleTypeParameter(arity, index)))
}

pragma[nomagic]
private Type inferDynamicCallExprType(Expr n, TypePath path) {
exists(InvokedClosureExpr ce |
// Propagate the function's return type to the call expression
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
n = ce.getCall() and
path = path0.stripPrefix(fnReturnPath())
private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
(
path.isEmpty() and
result = closureRootType()
or
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
result.(TupleType).getArity() = ce.getNumberOfParams()
or
// Propagate the function's parameter type to the arguments
exists(int index |
n = ce.getCall().getSyntacticPositionalArgument(index) and
path =
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
exists(TypePath path0 |
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path0) and
path = closureReturnPath().append(path0)
)
)
or
// _If_ the invoked expression has the type of a closure, then we propagate
// the surrounding types into the closure.
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
// Propagate the type of arguments to the parameter types of closure
exists(int index, ArgList args |
n = ce and
args = ce.getCall().getArgList() and
arity = args.getNumberOfArgs() and
result = inferType(args.getArg(index), path0) and
path = closureParameterPath(arity, index).append(path0)
)
or
// Propagate the type of the call expression to the return type of the closure
n = ce and
arity = ce.getCall().getArgList().getNumberOfArgs() and
result = inferType(ce.getCall(), path0) and
path = closureReturnPath().append(path0)
exists(Param p |
p = ce.getAParam() and
not p.hasTypeRepr() and
n = p.getPat() and
result = TUnknownType() and
path.isEmpty()
)
)
}

pragma[nomagic]
private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = closureRootType()
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
result.(TupleType).getArity() = ce.getNumberOfParams()
or
// Propagate return type annotation to body
n = ce.getClosureBody() and
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path)
private TupleType inferArgList(ArgList args, TypePath path) {
exists(CallExprImpl::DynamicCallExpr dce |
args = dce.getArgList() and
result.getArity() = dce.getNumberOfSyntacticArguments() and
path.isEmpty()
)
}

Expand Down Expand Up @@ -4089,7 +4096,9 @@ private module Cached {
or
i instanceof ImplItemNode and dispatch = false
|
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) or
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) and
not call instanceof CallExprImpl::DynamicCallExpr
or
result = call.(NonMethodResolution::NonMethodCall).resolveCallTargetViaTypeInference(i)
)
}
Expand Down Expand Up @@ -4199,13 +4208,15 @@ private module Cached {
or
result = inferForLoopExprType(n, path)
or
result = inferDynamicCallExprType(n, path)
or
result = inferClosureExprType(n, path)
or
result = inferArgList(n, path)
or
result = inferStructPatType(n, path)
or
result = inferTupleStructPatType(n, path)
or
result = inferUnknownTypeFromAnnotation(n, path)
)
}
}
Expand Down
Loading
Loading