Skip to content

Commit 5587466

Browse files
committed
Rust: Improve type inference for closures
1 parent e6385ac commit 5587466

File tree

5 files changed

+525
-141
lines changed

5 files changed

+525
-141
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 96 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,14 @@ private predicate isPanicMacroCall(MacroExpr me) {
407407
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
408408
}
409409

410+
// Due to "binding modes" the type of the pattern is not necessarily the
411+
// same as the type of the initializer. The pattern being an identifier
412+
// pattern is sufficient to ensure that this is not the case.
413+
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
414+
let.getPat() = lhs and
415+
let.getInitializer() = rhs
416+
}
417+
410418
/** Module for inferring certain type information. */
411419
module CertainTypeInference {
412420
pragma[nomagic]
@@ -484,11 +492,7 @@ module CertainTypeInference {
484492
// is not a certain type equality.
485493
exists(LetStmt let |
486494
not let.hasTypeRepr() and
487-
// Due to "binding modes" the type of the pattern is not necessarily the
488-
// same as the type of the initializer. The pattern being an identifier
489-
// pattern is sufficient to ensure that this is not the case.
490-
let.getPat().(IdentPat) = n1 and
491-
let.getInitializer() = n2
495+
identLetStmt(let, n1, n2)
492496
)
493497
or
494498
exists(LetExpr let |
@@ -512,6 +516,25 @@ module CertainTypeInference {
512516
)
513517
else prefix2.isEmpty()
514518
)
519+
or
520+
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
521+
n1 = dce.getArgList() and
522+
tt.getArity() = dce.getNumberOfSyntacticArguments() and
523+
n2 = dce.getSyntacticPositionalArgument(i) and
524+
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
525+
prefix2.isEmpty()
526+
)
527+
or
528+
exists(ClosureExpr ce, int index |
529+
n1 = ce and
530+
n2 = ce.getParam(index).getPat() and
531+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
532+
prefix2.isEmpty()
533+
)
534+
or
535+
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
536+
prefix1 = closureReturnPath() and
537+
prefix2.isEmpty()
515538
}
516539

517540
pragma[nomagic]
@@ -781,17 +804,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
781804
prefix2.isEmpty() and
782805
s = getRangeType(n1)
783806
)
784-
or
785-
exists(ClosureExpr ce, int index |
786-
n1 = ce and
787-
n2 = ce.getParam(index).getPat() and
788-
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
789-
prefix2.isEmpty()
790-
)
791-
or
792-
n1.(ClosureExpr).getClosureBody() = n2 and
793-
prefix1 = closureReturnPath() and
794-
prefix2.isEmpty()
795807
}
796808

797809
/**
@@ -828,6 +840,19 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
828840
prefix.isEmpty()
829841
}
830842

843+
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
844+
inferType(n, path) = TUnknownType() and
845+
// Normally, these are coercion sites, but in case a type is unknown we
846+
// allow for type information to flow from the type annotation.
847+
exists(TypeMention tm | result = tm.resolveTypeAt(path) |
848+
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
849+
or
850+
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
851+
or
852+
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
853+
)
854+
}
855+
831856
/**
832857
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
833858
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
@@ -1509,6 +1534,8 @@ private module MethodResolution {
15091534
* or
15101535
* 4. `MethodCallOperation`: an operation expression, `x + y`, which is syntactic sugar
15111536
* for `Add::add(x, y)`.
1537+
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
1538+
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
15121539
*
15131540
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
15141541
*
@@ -1520,7 +1547,7 @@ private module MethodResolution {
15201547
abstract class MethodCall extends Expr {
15211548
abstract predicate hasNameAndArity(string name, int arity);
15221549

1523-
abstract Expr getArg(ArgumentPosition pos);
1550+
abstract AstNode getArg(ArgumentPosition pos);
15241551

15251552
abstract predicate supportsAutoDerefAndBorrow();
15261553

@@ -2050,6 +2077,26 @@ private module MethodResolution {
20502077
override Trait getTrait() { super.isOverloaded(result, _, _) }
20512078
}
20522079

2080+
private class ClosureMethodCall extends MethodCall instanceof CallExprImpl::DynamicCallExpr {
2081+
pragma[nomagic]
2082+
override predicate hasNameAndArity(string name, int arity) {
2083+
name = "call_once" and // todo: handle call_mut and call
2084+
arity = 1 // args are passed in a tuple
2085+
}
2086+
2087+
override AstNode getArg(ArgumentPosition pos) {
2088+
pos.isSelf() and
2089+
result = super.getFunction()
2090+
or
2091+
pos.asPosition() = 0 and
2092+
result = super.getArgList()
2093+
}
2094+
2095+
override predicate supportsAutoDerefAndBorrow() { any() }
2096+
2097+
override Trait getTrait() { result instanceof AnyFnTrait }
2098+
}
2099+
20532100
pragma[nomagic]
20542101
private Method getMethodSuccessor(ImplOrTraitItemNode i, string name, int arity) {
20552102
result = i.getASuccessor(name) and
@@ -3860,12 +3907,6 @@ private module InvokedClosureSatisfiesConstraintInput implements
38603907
}
38613908
}
38623909

3863-
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
3864-
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
3865-
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
3866-
_, path, result)
3867-
}
3868-
38693910
/**
38703911
* Gets the root type of a closure.
38713912
*
@@ -3892,73 +3933,39 @@ private TypePath closureParameterPath(int arity, int index) {
38923933
TypePath::singleton(getTupleTypeParameter(arity, index)))
38933934
}
38943935

3895-
/** Gets the path to the return type of the `FnOnce` trait. */
3896-
private TypePath fnReturnPath() {
3897-
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
3898-
}
3899-
3900-
/**
3901-
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
3902-
* and index `index`.
3903-
*/
39043936
pragma[nomagic]
3905-
private TypePath fnParameterPath(int arity, int index) {
3906-
result =
3907-
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
3908-
TypePath::singleton(getTupleTypeParameter(arity, index)))
3909-
}
3910-
3911-
pragma[nomagic]
3912-
private Type inferDynamicCallExprType(Expr n, TypePath path) {
3913-
exists(InvokedClosureExpr ce |
3914-
// Propagate the function's return type to the call expression
3915-
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
3916-
n = ce.getCall() and
3917-
path = path0.stripPrefix(fnReturnPath())
3937+
private Type inferClosureExprType(AstNode n, TypePath path) {
3938+
exists(ClosureExpr ce |
3939+
n = ce and
3940+
(
3941+
path.isEmpty() and
3942+
result = closureRootType()
39183943
or
3919-
// Propagate the function's parameter type to the arguments
3920-
exists(int index |
3921-
n = ce.getCall().getSyntacticPositionalArgument(index) and
3922-
path =
3923-
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
3944+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3945+
result.(TupleType).getArity() = ce.getNumberOfParams()
3946+
or
3947+
exists(TypePath path0 |
3948+
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path0) and
3949+
path = closureReturnPath().append(path0)
39243950
)
39253951
)
39263952
or
3927-
// _If_ the invoked expression has the type of a closure, then we propagate
3928-
// the surrounding types into the closure.
3929-
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
3930-
// Propagate the type of arguments to the parameter types of closure
3931-
exists(int index, ArgList args |
3932-
n = ce and
3933-
args = ce.getCall().getArgList() and
3934-
arity = args.getNumberOfArgs() and
3935-
result = inferType(args.getArg(index), path0) and
3936-
path = closureParameterPath(arity, index).append(path0)
3937-
)
3938-
or
3939-
// Propagate the type of the call expression to the return type of the closure
3940-
n = ce and
3941-
arity = ce.getCall().getArgList().getNumberOfArgs() and
3942-
result = inferType(ce.getCall(), path0) and
3943-
path = closureReturnPath().append(path0)
3953+
exists(Param p |
3954+
p = ce.getAParam() and
3955+
not p.hasTypeRepr() and
3956+
n = p.getPat() and
3957+
result = TUnknownType() and
3958+
path.isEmpty()
39443959
)
39453960
)
39463961
}
39473962

39483963
pragma[nomagic]
3949-
private Type inferClosureExprType(AstNode n, TypePath path) {
3950-
exists(ClosureExpr ce |
3951-
n = ce and
3952-
path.isEmpty() and
3953-
result = closureRootType()
3954-
or
3955-
n = ce and
3956-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3957-
result.(TupleType).getArity() = ce.getNumberOfParams()
3958-
or
3959-
// Propagate return type annotation to body
3960-
n = ce.getClosureBody() and
3961-
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
3964+
private TupleType inferArgList(ArgList args, TypePath path) {
3965+
exists(CallExprImpl::DynamicCallExpr dce |
3966+
args = dce.getArgList() and
3967+
result.getArity() = dce.getNumberOfSyntacticArguments() and
3968+
path.isEmpty()
39623969
)
39633970
}
39643971

@@ -4005,7 +4012,9 @@ private module Cached {
40054012
or
40064013
i instanceof ImplItemNode and dispatch = false
40074014
|
4008-
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) or
4015+
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) and
4016+
not call instanceof CallExprImpl::DynamicCallExpr
4017+
or
40094018
result = call.(NonMethodResolution::NonMethodCall).resolveCallTargetViaTypeInference(i)
40104019
)
40114020
}
@@ -4115,13 +4124,15 @@ private module Cached {
41154124
or
41164125
result = inferForLoopExprType(n, path)
41174126
or
4118-
result = inferDynamicCallExprType(n, path)
4119-
or
41204127
result = inferClosureExprType(n, path)
41214128
or
4129+
result = inferArgList(n, path)
4130+
or
41224131
result = inferStructPatType(n, path)
41234132
or
41244133
result = inferTupleStructPatType(n, path)
4134+
or
4135+
result = inferUnknownTypeFromAnnotation(n, path)
41254136
)
41264137
}
41274138
}
@@ -4138,8 +4149,8 @@ private module Debug {
41384149
Locatable getRelevantLocatable() {
41394150
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
41404151
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
4141-
filepath.matches("%/sqlx.rs") and
4142-
startline = [56 .. 60]
4152+
filepath.matches("%/closure.rs") and
4153+
startline = [10]
41434154
)
41444155
}
41454156

rust/ql/test/library-tests/type-inference/closure.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ mod fn_once_trait {
6363
};
6464
let _r = apply(f, true); // $ target=apply type=_r:i64
6565

66-
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
66+
let f = |x| x + 1; // $ type=x:i64 $ MISSING: target=add
6767
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
6868
}
6969
}
@@ -100,7 +100,7 @@ mod fn_mut_trait {
100100
};
101101
let _r = apply(f, true); // $ target=apply type=_r:i64
102102

103-
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
103+
let f = |x| x + 1; // $ type=x:i64 $ MISSING: target=add
104104
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
105105
}
106106
}
@@ -137,7 +137,7 @@ mod fn_trait {
137137
};
138138
let _r = apply(f, true); // $ target=apply type=_r:i64
139139

140-
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
140+
let f = |x| x + 1; // $ type=x:i64 $ MISSING: target=add
141141
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
142142
}
143143
}
@@ -183,25 +183,25 @@ mod closure_infer_param {
183183
}
184184

185185
fn test() {
186-
let f = |x| x; // $ MISSING: type=x:i64
186+
let f = |x| x; // $ type=x:i64
187187
let _r = apply1(f, 1i64); // $ target=apply1
188188

189-
let f = |x| x; // $ MISSING: type=x:i64
189+
let f = |x| x; // $ type=x:i64
190190
let _r = apply2(f, 2i64); // $ target=apply2
191191

192-
let f = |x| x; // $ MISSING: type=x:i64
192+
let f = |x| x; // $ type=x:i64
193193
let _r = apply3(&f, 3i64); // $ target=apply3
194194

195-
let f = |x| x; // $ MISSING: type=x:i64
195+
let f = |x| x; // $ type=x:i64
196196
let _r = apply4(f, 4i64); // $ target=apply4
197197

198198
let mut f = |x| x; // $ MISSING: type=x:i64
199199
let _r = apply5(&mut f, 5i64); // $ target=apply5
200200

201-
let f = |x| x; // $ MISSING: type=x:i64
201+
let f = |x| x; // $ type=x:i64
202202
let _r = apply6(f, 6i64); // $ target=apply6
203203

204-
let f = |x| x; // $ MISSING: type=x:i64
204+
let f = |x| x; // $ type=x:i64
205205
let _r = apply7(f, 7i64); // $ target=apply7
206206
}
207207
}
@@ -221,15 +221,15 @@ mod implicit_deref {
221221

222222
pub fn test() {
223223
let x = 0i64;
224-
let v = Default::default(); // $ MISSING: type=v:i64 target=default
224+
let v = Default::default(); // $ type=v:i64 target=default
225225
let s = S(v);
226-
let _ret = s(x); // $ MISSING: type=_ret:bool
226+
let _ret = s(x); // $ type=_ret:bool
227227

228228
let x = 0i32;
229-
let v = Default::default(); // $ MISSING: type=v:i32 target=default
229+
let v = Default::default(); // $ type=v:i32 target=default
230230
let s = S(v);
231231
let s_ref = &s;
232-
let _ret = s_ref(x); // $ MISSING: type=_ret:bool
232+
let _ret = s_ref(x); // $ type=_ret:bool
233233

234234
// The call below is not an implicit deref, instead it will target
235235
// `impl<A, F> FnOnce<A> for &F` from

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,7 +2396,7 @@ mod loops {
23962396
// for loops with arrays
23972397

23982398
for i in [1, 2, 3] {} // $ type=i:i32
2399-
for i in [1, 2, 3].map(|x| x + 1) {} // $ target=map MISSING: type=i:i32
2399+
for i in [1, 2, 3].map(|x| x + 1) {} // $ target=map target=add type=i:i32
24002400
for i in [1, 2, 3].into_iter() {} // $ target=into_iter type=i:i32
24012401

24022402
let vals1 = [1u8, 2, 3]; // $ type=vals1:TArray.u8
@@ -2896,7 +2896,7 @@ mod arg_trait_bounds {
28962896
}
28972897

28982898
fn test() {
2899-
let v = Default::default(); // $ MISSING: type=v:i64 target=default
2899+
let v = Default::default(); // $ type=v:i64 target=default
29002900
let g = Gen(v);
29012901
let _ = my_get(&g); // $ target=my_get
29022902
}

0 commit comments

Comments
 (0)