Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rust/ql/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion rust/ql/.gitattributes

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// generated by codegen, remove this comment if you wish to edit this file
/**
* This module provides a hand-modifiable wrapper around the generated class `DynTraitTypeRepr`.
*
Expand All @@ -12,6 +11,10 @@ private import codeql.rust.elements.internal.generated.DynTraitTypeRepr
* be referenced directly.
*/
module Impl {
private import rust
private import codeql.rust.internal.PathResolution as PathResolution

// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A dynamic trait object type.
*
Expand All @@ -21,5 +24,16 @@ module Impl {
* // ^^^^^^^^^
* ```
*/
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr { }
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr {
/** Gets the trait that this trait object refers to. */
pragma[nomagic]
Trait getTrait() {
result =
PathResolution::resolvePath(this.getTypeBoundList()
.getBound(0)
.getTypeRepr()
.(PathTypeRepr)
.getPath())
}
}
}
43 changes: 43 additions & 0 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ newtype TType =
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TImplTraitType(ImplTraitTypeRepr impl) or
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
TSliceType() or
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TArrayTypeParameter() or
TDynTraitTypeParameter(TypeParam tp) {
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getAGenericParam()
} or
TRefTypeParameter() or
TSelfTypeParameter(Trait t) or
TSliceTypeParameter()
Expand Down Expand Up @@ -247,6 +251,26 @@ class ImplTraitType extends Type, TImplTraitType {
override Location getLocation() { result = impl.getLocation() }
}

class DynTraitType extends Type, TDynTraitType {
Trait trait;

DynTraitType() { this = TDynTraitType(trait) }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override DynTraitTypeParameter getTypeParameter(int i) {
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
Comment thread
geoffw0 marked this conversation as resolved.
}

Trait getTrait() { result = trait }

override string toString() { result = "dyn " + trait.getName().toString() }

override Location getLocation() { result = trait.getLocation() }
}

/**
* An [impl Trait in return position][1] type, for example:
*
Expand Down Expand Up @@ -381,6 +405,18 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
override Location getLocation() { result instanceof EmptyLocation }
}

class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
private TypeParam typeParam;

DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }

TypeParam getTypeParam() { result = typeParam }

override string toString() { result = "dyn(" + typeParam.toString() + ")" }

override Location getLocation() { result = typeParam.getLocation() }
}

/** An implicit reference type parameter. */
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
override string toString() { result = "&T" }
Expand Down Expand Up @@ -465,6 +501,13 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
}
}

final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() =
this.getTrait().getGenericParamList().getATypeParam()
}
}

final class TraitTypeAbstraction extends TypeAbstraction, Trait {
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()
Expand Down
33 changes: 30 additions & 3 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ private module Input1 implements InputSig1<Location> {
id = 2
or
kind = 1 and
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
Comment thread
geoffw0 marked this conversation as resolved.
or
kind = 2 and
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
node = tp0.(TypeParamTypeParameter).getTypeParam() or
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
Expand All @@ -107,7 +110,7 @@ private module Input1 implements InputSig1<Location> {
exists(TupleTypeParameter ttp, int maxArity |
maxArity = max(int i | i = any(TupleType tt).getArity()) and
tp0 = ttp and
kind = 2 and
kind = 3 and
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
)
|
Expand Down Expand Up @@ -189,6 +192,14 @@ private module Input2 implements InputSig2 {
condition = impl and
constraint = impl.getTypeBoundList().getABound().getTypeRepr()
)
or
// a `dyn Trait` type implements `Trait`. See the comment on
// `DynTypeBoundListMention` for further details.
exists(DynTraitTypeRepr object |
abs = object and
condition = object.getTypeBoundList() and
constraint = object.getTrait()
)
}
}

Expand Down Expand Up @@ -1715,10 +1726,16 @@ private Function getMethodFromImpl(MethodCall mc) {

bindingset[trait, name]
pragma[inline_late]
private Function getTraitMethod(ImplTraitReturnType trait, string name) {
private Function getImplTraitMethod(ImplTraitReturnType trait, string name) {
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
}

bindingset[traitObject, name]
pragma[inline_late]
private Function getDynTraitMethod(DynTraitType traitObject, string name) {
result = getMethodSuccessor(traitObject.getTrait(), name)
}

pragma[nomagic]
private Function resolveMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
Expand All @@ -1729,7 +1746,10 @@ private Function resolveMethodCallTarget(MethodCall mc) {
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
result = getImplTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is a trait object `dyn Trait` type.
result = getDynTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

pragma[nomagic]
Expand Down Expand Up @@ -2073,6 +2093,13 @@ private module Debug {
result = resolveCallTarget(c)
}

predicate debugConditionSatisfiesConstraint(
TypeAbstraction abs, TypeMention condition, TypeMention constraint
) {
abs = getRelevantLocatable() and
Input2::conditionSatisfiesConstraint(abs, condition, constraint)
}

predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
self = getRelevantLocatable() and
t = inferImplicitSelfType(self, path)
Expand Down
61 changes: 61 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,64 @@ class SelfTypeParameterMention extends TypeMention instanceof Name {
result = TSelfTypeParameter(trait)
}
}

class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
private DynTraitType dynType;

DynTraitTypeReprMention() {
// This excludes `DynTraitTypeRepr` elements where `getTrait` is not
// defined, i.e., where path resolution can't find a trait.
dynType.getTrait() = super.getTrait()
}

override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result = dynType
or
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
tp = dynType.getTypeParameter(_) and
path = TypePath::cons(tp, suffix) and
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
)
}
}

// We want a type of the form `dyn Trait` to implement `Trait`. If `Trait` has
// type parameters then `dyn Trait` has equivalent type parameters and the
// implementation should be abstracted over them.
//
// Intuitively we want something to the effect of:
// ```
// impl<A, B, ..> Trait<A, B, ..> for (dyn Trait)<A, B, ..>
// ```
// To achieve this:
// - `DynTypeAbstraction` is an abstraction over type parameters of the trait.
// - `DynTypeBoundListMention` (this class) is a type mention which has `dyn
// Trait` at the root and which for every type parameter of `dyn Trait` has the
// corresponding type parameter of the trait.
// - `TraitMention` (which is used for other things as well) is a type mention
// for the trait applied to its own type parameters.
//
// We arbitrarily use the `TypeBoundList` inside `DynTraitTypeRepr` to encode
// this type mention, since it doesn't syntactically appear in the AST. This
// works because there is a one-to-one correspondence between a trait object and
// its list of type bounds.
class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
private Trait trait;

DynTypeBoundListMention() {
exists(DynTraitTypeRepr dyn | this = dyn.getTypeBoundList() and trait = dyn.getTrait())
}

override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result.(DynTraitType).getTrait() = trait
or
exists(TypeParam param |
param = trait.getGenericParamList().getATypeParam() and
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
result = TTypeParamTypeParameter(param)
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
category: minorAnalysis
---
* Type inference now supports trait objects, i.e., `dyn Trait` types.
67 changes: 67 additions & 0 deletions rust/ql/test/library-tests/type-inference/dyn_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Test cases for type inference and method resolution with `dyn` types

use std::fmt::Debug;

trait MyTrait1 {
// MyTrait1::m
fn m(&self) -> String;
}

trait GenericGet<A> {
// GenericGet::get
fn get(&self) -> A;
}

#[derive(Clone, Debug)]
struct MyStruct {
value: i32,
}

impl MyTrait1 for MyStruct {
// MyStruct1::m
fn m(&self) -> String {
format!("MyTrait1: {}", self.value) // $ fieldof=MyStruct
}
}

#[derive(Clone, Debug)]
struct GenStruct<A: Clone + Debug> {
value: A,
}

impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
// GenStruct<A>::get
fn get(&self) -> A {
self.value.clone() // $ fieldof=GenStruct target=clone
}
}

fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
a.get() // $ target=GenericGet::get
}

fn get_box_trait<A: Clone + Debug + 'static>(a: A) -> Box<dyn GenericGet<A>> {
Box::new(GenStruct { value: a }) // $ target=new
}

fn test_basic_dyn_trait(obj: &dyn MyTrait1) {
let _result = (*obj).m(); // $ target=deref target=MyTrait1::m type=_result:String
}

fn test_generic_dyn_trait(obj: &dyn GenericGet<String>) {
let _result1 = (*obj).get(); // $ target=deref target=GenericGet::get type=_result1:String
let _result2 = get_a(obj); // $ target=get_a type=_result2:String
}

fn test_poly_dyn_trait() {
let obj = get_box_trait(true); // $ target=get_box_trait
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
}

pub fn test() {
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
test_generic_dyn_trait(&GenStruct {
value: "".to_string(),
}); // $ target=test_generic_dyn_trait
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
}
8 changes: 5 additions & 3 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2292,8 +2292,6 @@ mod loops {
}
}

mod dereference;

mod explicit_type_args {
struct S1<T>(T);

Expand Down Expand Up @@ -2461,6 +2459,9 @@ mod closures {
}
}

mod dereference;
mod dyn_type;

fn main() {
field_access::f(); // $ target=f
method_impl::f(); // $ target=f
Expand Down Expand Up @@ -2491,5 +2492,6 @@ fn main() {
dereference::test(); // $ target=test
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
closures::f() // $ target=f
closures::f(); // $ target=f
dyn_type::test(); // $ target=test
}
Loading