Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
80dd71c
Add param/ret/sig refinement-type annotations via thrust_macros
claude May 26, 2026
8c3343d
Model refinement type positions with a structured TypePosition
claude May 26, 2026
1275205
Redesign TypePosition as a flat sequence of TypePositionStep
claude May 27, 2026
0e9f0e3
Rename thrust::refine attribute to thrust::refinement_path
claude May 27, 2026
bd08174
Extract refinement-type annotation macros into refine module
claude May 27, 2026
e11a748
Use $i / bare integer syntax for refinement_path type positions
claude May 27, 2026
8123bd8
Clarify refinement type-position naming and align Display with syntax
claude May 27, 2026
70a8070
Refine type-position terminology and drop non-empty invariant
claude May 27, 2026
2322d32
Fix mis-named conversion at refinement_path extraction site
claude May 30, 2026
076831d
Test refinement_path on a generic ADT's element type
claude May 30, 2026
19b8514
Dedupe TypePositionStep docs, fix leftover [i] syntax, separate annot…
claude May 30, 2026
37262c9
Loosen to_refinement doc to cover non-direct positions
claude Jun 2, 2026
415bbdd
Split refinement-type macro into three concrete expanders
claude Jun 2, 2026
71b65ba
Split Refinement into RefinedType and RefinedTypeAnnotation
claude Jun 2, 2026
03336f2
Replace scan_type with parse_refined_type_annotations
claude Jun 2, 2026
d115126
Inline emit_error helper, trim verbose comments
claude Jun 2, 2026
d65014e
Handle paths/refs in refinement-type parser; strip nested refinements…
claude Jun 2, 2026
1f35265
Simplify generic-prefix scan and fold strip into the single parser pass
claude Jun 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,85 @@ impl<'tcx> Analyzer<'tcx> {
ensure_annot
}

/// Collects every `#[thrust::refinement_path(..)]` path statement in the
/// function body, returning each `(type position, formula_fn DefId)`.
fn extract_refinement_paths(
&self,
local_def_id: LocalDefId,
) -> Vec<(rty::TypePosition, DefId)> {
let mut out = Vec::new();
let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else {
return out;
};
let rustc_hir::ExprKind::Block(block, _) = body.value.kind else {
return out;
};
let attr_path = analyze::annot::refinement_path_path();
let typeck = self.tcx.typeck(local_def_id);
for stmt in block.stmts {
let Some(attr) = self
.tcx
.hir_attrs(stmt.hir_id)
.iter()
.find(|attr| attr.path_matches(&attr_path))
else {
continue;
};
let ts = analyze::annot::extract_annot_tokens(attr.clone());
let position = analyze::annot::parse_type_position(&ts);

let rustc_hir::StmtKind::Semi(expr) = stmt.kind else {
self.tcx.dcx().span_err(
stmt.span,
"annotated path is expected to be a semi statement",
);
continue;
};
let rustc_hir::ExprKind::Path(qpath) = expr.kind else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to be a path expression",
);
continue;
};
let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to refer to a definition",
);
continue;
};
out.push((position, def_id));
}
out
}

/// Resolves every `#[thrust::refinement_path(..)]` annotation into a
/// positioned refinement, by translating the referenced formula function.
pub fn extract_refinement_annots(
&self,
local_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Vec<(rty::TypePosition, rty::Refinement<rty::FunctionParamIdx>)> {
let mut out = Vec::new();
for (position, def_id) in self.extract_refinement_paths(local_def_id) {
let Some(formula_def_id) = def_id.as_local() else {
panic!(
"refinement_path annotation is expected to refer to a local def, but found: {:?}",
def_id
);
};
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
panic!(
"refinement_path annotation {:?} is not a formula function",
formula_def_id
);
};
out.push((position, formula_fn.to_refinement()));
}
out
}

/// Whether the given `def_id` corresponds to a method of one of the `Fn` traits.
fn is_fn_trait_method(&self, def_id: DefId) -> bool {
self.tcx
Expand Down
57 changes: 57 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ pub fn ensures_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("ensures_path")]
}

pub fn refinement_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("refinement_path")]
}

pub fn model_ty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down Expand Up @@ -215,6 +219,59 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream {
d.tokens
}

/// Parses a [`rty::TypePosition`] from the tokens of a
/// `#[thrust::refinement_path(..)]` attribute.
///
/// Tokens are comma-separated [`rty::TypePositionStep`]s, each encoded as
/// `result` (→ `Return`), `$i` (→ `Param(i)`), or a bare integer `i` (→
/// `TypeArg(i)`).
pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition {
use rustc_ast::token::{LitKind, TokenKind};
use rustc_ast::tokenstream::TokenTree;

let parse_int = |lit: &rustc_ast::token::Lit| -> usize {
assert_eq!(
lit.kind,
LitKind::Integer,
"expected an integer in type position"
);
lit.symbol
.as_str()
.parse()
.expect("invalid integer in type position")
};

let mut steps = Vec::new();
let mut iter = ts.iter();
while let Some(tt) = iter.next() {
let TokenTree::Token(t, _) = tt else {
panic!("unexpected token tree in type position");
};
match &t.kind {
TokenKind::Comma => {}
TokenKind::Ident(sym, _) if sym.as_str() == "result" => {
steps.push(rty::TypePositionStep::Return);
}
TokenKind::Dollar => {
let i = match iter.next() {
Some(TokenTree::Token(t, _)) => match &t.kind {
TokenKind::Literal(lit) => parse_int(lit),
_ => panic!("expected integer after `$` in type position: {:?}", t),
},
_ => panic!("expected integer after `$` in type position"),
};
steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i)));
}
TokenKind::Literal(lit) => {
steps.push(rty::TypePositionStep::TypeArg(parse_int(lit)));
}
_ => panic!("unexpected token in type position: {:?}", t),
}
}

rty::TypePosition::new(steps)
}

pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) {
use rustc_ast::token::TokenKind;
use rustc_ast::tokenstream::TokenTree;
Expand Down
26 changes: 26 additions & 0 deletions src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ impl<'tcx> FormulaFn<'tcx> {
AnnotFormula::Formula(self.formula.clone())
}

/// Lowers an `ensures` formula function into a postcondition annotation.
///
/// Relies on the layout produced by `thrust_macros::ensures`: parameter `0`
/// is the function's return value (bound to [`rty::RefinedTypeVar::Value`])
/// and parameters `1..n` are the enclosing function's parameters in order
/// (mapped to [`rty::RefinedTypeVar::Free`]).
pub fn to_ensure_annot(&self) -> AnnotFormula<rty::RefinedTypeVar<rty::FunctionParamIdx>> {
AnnotFormula::Formula(self.formula.clone().map_var(|v| {
if v.as_usize() == 0 {
Expand All @@ -57,6 +63,26 @@ impl<'tcx> FormulaFn<'tcx> {
}
}))
}

/// Lowers a refinement-type formula function (generated by the `param` /
/// `ret` / `sig` macros) into a [`rty::Refinement`].
///
/// Relies on the layout produced by those macros: parameter `0` is the
/// refinement's value binder (mapped to [`rty::RefinedTypeVar::Value`])
/// and the remaining parameters carry the formula's free variables,
/// mapped to [`rty::RefinedTypeVar::Free`].
pub fn to_refinement(&self) -> rty::Refinement<rty::FunctionParamIdx> {
self.formula
.clone()
.map_var(|v| {
if v.as_usize() == 0 {
rty::RefinedTypeVar::Value
} else {
rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1))
}
})
.into()
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
7 changes: 7 additions & 0 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
assert!(require_annot.is_none() || param_annots.is_empty());
assert!(ensure_annot.is_none() || ret_annot.is_none());

let refinement_annots = self
.ctx
.extract_refinement_annots(self.local_def_id, self.generic_args);

let trait_item_ty = self.trait_item_ty();
let is_fully_annotated = self.is_fully_annotated();

Expand All @@ -431,6 +435,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
if let Some(ret_rty) = ret_annot {
builder.ret_rty(ret_rty);
}
for (position, refinement) in refinement_annots {
builder.refinement_at(&position, refinement);
}

if is_fully_annotated {
rty::RefinedType::unrefined(builder.build().into())
Expand Down
43 changes: 43 additions & 0 deletions src/refine/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,49 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
self.ret_rty = Some(rty);
self
}

/// Records a refinement to install at a [`rty::TypePosition`].
///
/// The first step must be [`rty::TypePositionStep::Param`] or
/// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to
/// [`rty::RefinedType::install_refinement_at`].
pub fn refinement_at(
&mut self,
position: &rty::TypePosition,
refinement: rty::Refinement<rty::FunctionParamIdx>,
) -> &mut Self {
let (first, rest) = match position.steps().split_first() {
Some(pair) => pair,
None => panic!("type position applied to a function type must not be empty"),
};
match first {
rty::TypePositionStep::Param(idx) => {
if !self.param_rtys.contains_key(idx) {
let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous();
self.param_rtys
.insert(*idx, rty::RefinedType::unrefined(ty));
}
self.param_rtys
.get_mut(idx)
.unwrap()
.install_refinement_at(rest, refinement);
}
rty::TypePositionStep::Return => {
if self.ret_rty.is_none() {
let ty = self.inner.build(self.ret_ty).vacuous();
self.ret_rty = Some(rty::RefinedType::unrefined(ty));
}
self.ret_rty
.as_mut()
.unwrap()
.install_refinement_at(rest, refinement);
}
rty::TypePositionStep::TypeArg(_) => {
panic!("type position applied to a function type must start with a param or result step, not a type argument");
}
}
self
}
}

impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R>
Expand Down
118 changes: 118 additions & 0 deletions src/rty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,80 @@ where
}
}

/// One step in a [`TypePosition`] path.
///
/// A path is a sequence of steps that addresses a sub-type within a
/// (potentially nested) function signature:
/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a
/// function type's parameter or return slot respectively.
/// - [`TypeArg`](Self::TypeArg) navigates into the `i`-th type argument of a
/// generic type (enum, `Box`, etc.).
///
/// Using distinct variants for function navigation ([`Param`](Self::Param),
/// [`Return`](Self::Return)) and generic-arg navigation
/// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address
/// positions inside higher-order function types. For example, `$0, result`
/// addresses the return type of a function-typed first parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypePositionStep {
/// Navigate to the `i`-th parameter of a function type.
Param(FunctionParamIdx),
/// Navigate to the return type of a function type.
Return,
/// Navigate to the `i`-th type argument of a generic type (enum, `Box`, …).
TypeArg(usize),
}

impl std::fmt::Display for TypePositionStep {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
TypePositionStep::Param(idx) => write!(f, "{}", idx),
TypePositionStep::Return => f.write_str("result"),
TypePositionStep::TypeArg(i) => write!(f, "{}", i),
}
}
}

/// A path of [`TypePositionStep`]s addressing a sub-type within a type, used
/// to attach a refinement.
///
/// An empty path addresses the type itself; a path applied to a function type
/// is therefore non-empty, beginning with a `Param` or `Return` step.
///
/// Examples (function `fn f(x: List<T>) -> Box<T>`):
/// - `$0` — parameter `x`.
/// - `result` — the return type.
/// - `$0, 0` — the first type arg of `x`.
/// - `result, 0` — the pointee of the `Box` return.
/// - `$0, result` — the return of a function-typed param `x`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypePosition {
steps: Vec<TypePositionStep>,
}

impl TypePosition {
pub fn new(steps: Vec<TypePositionStep>) -> Self {
TypePosition { steps }
}

pub fn steps(&self) -> &[TypePositionStep] {
&self.steps
}
}

impl std::fmt::Display for TypePosition {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let mut iter = self.steps.iter();
if let Some(first) = iter.next() {
write!(f, "{}", first)?;
}
for s in iter {
write!(f, ", {}", s)?;
}
Ok(())
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FunctionAbi {
#[default]
Expand Down Expand Up @@ -1475,6 +1549,50 @@ where
}
}

impl RefinedType<FunctionParamIdx> {
/// Installs `refinement` at the sub-type addressed by `steps`.
///
/// An empty `steps` slice replaces the refinement at this node; otherwise
/// each step navigates one level deeper per [`TypePositionStep`].
pub fn install_refinement_at(
&mut self,
steps: &[TypePositionStep],
refinement: Refinement<FunctionParamIdx>,
) {
let Some((step, rest)) = steps.split_first() else {
self.refinement = refinement;
return;
};
match step {
TypePositionStep::TypeArg(i) => match &mut self.ty {
Type::Enum(e) => {
let arg = e.args.get_mut(TypeParamIdx::from(*i)).unwrap_or_else(|| {
panic!("refine step [{}] out of range for enum type", i)
});
arg.install_refinement_at(rest, refinement);
}
Type::Pointer(p) => {
assert_eq!(*i, 0, "Box type position must be 0");
p.elem.install_refinement_at(rest, refinement);
}
ty => panic!("TypeArg step on unsupported type: {:?}", ty),
},
TypePositionStep::Param(idx) => match &mut self.ty {
Type::Function(func) => {
func.params[*idx].install_refinement_at(rest, refinement);
}
ty => panic!("Param step on non-function type: {:?}", ty),
},
TypePositionStep::Return => match &mut self.ty {
Type::Function(func) => {
func.ret.install_refinement_at(rest, refinement);
}
ty => panic!("Return step on non-function type: {:?}", ty),
},
}
}
}

impl<FV> RefinedType<FV> {
fn pretty_atom<'a, 'b, D>(
&'b self,
Expand Down
Loading