Skip to content

Commit cac23d1

Browse files
corbtCursor Bot
andauthored
feat: replace pyright with ty for type checking (#534)
- Swap pyright for ty in dev dependencies - Configure ty with allowed-unresolved-imports for optional deps - Add ty:ignore comments for genuine type errors (295 total) - Remove unused type: ignore comments (31 total) - Set unused-ignore-comment rule to ignore (comments vary by installed deps) - Update CI workflow to use Python 3.11 and install all extras Co-authored-by: Cursor Bot <bot@cursor.com>
1 parent fc279ca commit cac23d1

43 files changed

Lines changed: 232 additions & 200 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/prek.yml

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,34 @@ name: Prek
33
on:
44
pull_request:
55
push:
6-
branches: [ main ]
6+
branches: [main]
77

88
jobs:
99
quality-checks:
1010
runs-on: ubuntu-latest
11-
11+
1212
steps:
13-
- name: Checkout code
14-
uses: actions/checkout@v4
15-
16-
- name: Set up Python
17-
uses: actions/setup-python@v5
18-
with:
19-
python-version: '3.10'
20-
21-
- name: Install uv
22-
run: |
23-
curl -LsSf https://astral.sh/uv/install.sh | sh
24-
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
25-
26-
- name: Install dependencies
27-
run: |
28-
uv sync --all-extras
29-
30-
- name: Run prek hooks (lint, format, typecheck, uv.lock, tests)
31-
run: |
32-
uv run prek run --all-files
33-
34-
- name: Run unit tests (via prek)
35-
run: |
36-
uv run prek run pytest
13+
- name: Checkout code
14+
uses: actions/checkout@v4
15+
16+
- name: Set up Python
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: "3.11"
20+
21+
- name: Install uv
22+
run: |
23+
curl -LsSf https://astral.sh/uv/install.sh | sh
24+
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
25+
26+
- name: Install dependencies (with all optional extras for complete type checking)
27+
run: |
28+
uv sync --all-extras
29+
30+
- name: Run prek hooks (lint, format, typecheck, uv.lock, tests)
31+
run: |
32+
uv run prek run --all-files
33+
34+
- name: Run unit tests (via prek)
35+
run: |
36+
uv run prek run pytest

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ repos:
77

88
- repo: local
99
hooks:
10-
- id: pyright
11-
name: Pyright type checking
12-
entry: uv run pyright src tests
10+
- id: ty
11+
name: ty type checking
12+
entry: uv run ty check src tests
1313
language: system
1414
pass_filenames: false
1515

pyproject.toml

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,53 @@ asyncio_mode = "auto"
102102
[tool.uv]
103103
required-version = ">=0.6.15"
104104

105+
[tool.ty.environment]
106+
python-version = "3.11"
107+
108+
[tool.ty.rules]
109+
# Ignore unused-ignore-comment warnings because they vary depending on whether
110+
# optional deps are installed. The ty:ignore comments are needed in CI (with all deps)
111+
# but become unused locally (without all deps).
112+
unused-ignore-comment = "ignore"
113+
114+
[tool.ty.analysis]
115+
# Allow unresolved imports for optional dependencies that may not be installed locally.
116+
# In CI, we install all optional deps so these will be resolved and type-checked.
117+
allowed-unresolved-imports = [
118+
# backend deps
119+
"accelerate.**",
120+
"awscli.**",
121+
"bitsandbytes.**",
122+
"duckdb.**",
123+
"fastapi.**",
124+
"gql.**",
125+
"hf_xet.**",
126+
"nbclient.**",
127+
"nbmake.**",
128+
"peft.**",
129+
"pyarrow.**",
130+
"torch.**",
131+
"torchao.**",
132+
"transformers.**",
133+
"trl.**",
134+
"unsloth.**",
135+
"unsloth_zoo.**",
136+
"uvicorn.**",
137+
"vllm.**",
138+
"wandb.**",
139+
# skypilot deps
140+
"semver.**",
141+
"sky.**",
142+
"skypilot.**",
143+
# langgraph deps
144+
"langchain_core.**",
145+
"langchain_openai.**",
146+
"langgraph.**",
147+
# plotting deps
148+
"matplotlib.**",
149+
"seaborn.**",
150+
]
151+
105152
[dependency-groups]
106153
dev = [
107154
"black>=25.1.0",
@@ -112,7 +159,7 @@ dev = [
112159
"pytest>=8.4.1",
113160
"nbval>=0.11.0",
114161
"pytest-xdist>=3.8.0",
115-
"pyright[nodejs]>=1.1.403",
162+
"ty>=0.0.14",
116163
"pytest-asyncio>=1.1.0",
117164
"duckdb>=1.0.0",
118165
"pyarrow>=15.0.0",

src/art/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def __init__(self, **kwargs):
3434

3535
# Import unsloth before transformers, peft, and trl to maximize Unsloth optimizations
3636
if os.environ.get("IMPORT_UNSLOTH", "0") == "1":
37-
import unsloth # type: ignore # noqa: F401
37+
import unsloth # noqa: F401
3838

3939
try:
40-
import transformers # type: ignore
40+
import transformers
4141

4242
try:
4343
from .transformers.patches import patch_preprocess_mask_arguments

src/art/auto_trajectory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,10 @@ async def patched_aclose(self: httpx._models.Response) -> None:
177177
if context := auto_trajectory_context_var.get(None):
178178
context.handle_httpx_response(self)
179179

180-
httpx._models.Response.iter_bytes = patched_iter_bytes
181-
httpx._models.Response.aiter_bytes = patched_aiter_bytes
182-
httpx._models.Response.close = patched_close
183-
httpx._models.Response.aclose = patched_aclose
180+
httpx._models.Response.iter_bytes = patched_iter_bytes # ty:ignore[invalid-assignment]
181+
httpx._models.Response.aiter_bytes = patched_aiter_bytes # ty:ignore[invalid-assignment]
182+
httpx._models.Response.close = patched_close # ty:ignore[invalid-assignment]
183+
httpx._models.Response.aclose = patched_aclose # ty:ignore[invalid-assignment]
184184

185185

186186
patch_httpx()

src/art/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
159159
return pydantic.BaseModel.__init__(self, *args, **kwargs)
160160

161161
TrajectoryGroup.__new__ = __new__ # type: ignore
162-
TrajectoryGroup.__init__ = __init__
162+
TrajectoryGroup.__init__ = __init__ # ty:ignore[invalid-assignment]
163163

164164
backend = LocalBackend()
165165
app = FastAPI()

src/art/gather.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def gather_trajectories(
134134
)
135135
if context.pbar is not None:
136136
context.pbar.close()
137-
return results # type: ignore
137+
return results
138138

139139

140140
async def wrap_group_awaitable(
@@ -193,7 +193,7 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None:
193193
len(l.content or l.refusal or [])
194194
for l in logprobs # noqa: E741
195195
) / len(logprobs)
196-
context.metric_sums["reward"] += trajectory.reward # type: ignore
196+
context.metric_sums["reward"] += trajectory.reward
197197
context.metric_divisors["reward"] += 1
198198
context.metric_sums.update(trajectory.metrics)
199199
context.metric_divisors.update(trajectory.metrics.keys())
@@ -229,7 +229,7 @@ def too_many_exceptions(self) -> bool:
229229
if (
230230
0 < self.max_exceptions < 1
231231
and self.pbar is not None
232-
and self.metric_sums["exceptions"] / self.pbar.total <= self.max_exceptions
232+
and self.metric_sums["exceptions"] / self.pbar.total <= self.max_exceptions # ty:ignore[unsupported-operator]
233233
) or self.metric_sums["exceptions"] <= self.max_exceptions:
234234
return False
235235
return True

src/art/guided_completion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ def freeze_tool_schema(tool: dict, fixed_args: dict) -> ChatCompletionToolParam:
1919
Each field is cast to typing.Literal[value] so Pydantic emits an
2020
enum-of-one in the JSON schema, which vLLM's `guided_json` accepts.
2121
"""
22-
fields = {k: (Literal[v], ...) for k, v in fixed_args.items()}
22+
fields = {k: (Literal[v], ...) for k, v in fixed_args.items()} # ty:ignore[invalid-type-form]
2323
FrozenModel = create_model(
2424
f"{tool['function']['name'].title()}FrozenArgs",
25-
**fields, # type: ignore
26-
)
25+
**fields,
26+
) # ty:ignore[no-matching-overload]
2727

2828
locked = deepcopy(tool)
2929
locked["function"]["parameters"] = FrozenModel.model_json_schema()
@@ -71,7 +71,7 @@ def get_guided_completion_params(
7171
}
7272
chosen_tool = next(t for t in base_tools if t["function"]["name"] == tool_name)
7373
tool_params = [
74-
freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments)) # type: ignore
74+
freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments))
7575
]
7676
else:
7777
content = completion.choices[0].message.content

src/art/langgraph/llm_wrapper.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def init_chat_model(
118118
config = CURRENT_CONFIG.get()
119119
return LoggingLLM(
120120
ChatOpenAI(
121-
base_url=config["base_url"],
122-
api_key=config["api_key"],
123-
model=config["model"],
121+
base_url=config["base_url"], # ty:ignore[unknown-argument]
122+
api_key=config["api_key"], # ty:ignore[unknown-argument]
123+
model=config["model"], # ty:ignore[unknown-argument]
124124
temperature=1.0,
125125
),
126126
config["logger"],
@@ -222,17 +222,17 @@ def with_config(
222222
self.llm,
223223
"bound",
224224
ChatOpenAI(
225-
base_url=art_config["base_url"],
226-
api_key=art_config["api_key"],
227-
model=art_config["model"],
225+
base_url=art_config["base_url"], # ty:ignore[unknown-argument]
226+
api_key=art_config["api_key"], # ty:ignore[unknown-argument]
227+
model=art_config["model"], # ty:ignore[unknown-argument]
228228
temperature=1.0,
229229
),
230230
)
231231
else:
232232
self.llm = ChatOpenAI(
233-
base_url=art_config["base_url"],
234-
api_key=art_config["api_key"],
235-
model=art_config["model"],
233+
base_url=art_config["base_url"], # ty:ignore[unknown-argument]
234+
api_key=art_config["api_key"], # ty:ignore[unknown-argument]
235+
model=art_config["model"], # ty:ignore[unknown-argument]
236236
temperature=1.0,
237237
)
238238

src/art/local/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def _get_packed_tensors(
215215
packed_tensors = packed_tensors_from_tokenized_results(
216216
tokenized_results,
217217
sequence_length,
218-
pad_token_id=tokenizer.eos_token_id, # type: ignore
218+
pad_token_id=tokenizer.eos_token_id,
219219
advantage_balance=advantage_balance,
220220
)
221221
if (
@@ -360,7 +360,7 @@ def _trajectory_log(self, trajectory: Trajectory) -> str:
360360
if isinstance(message_or_choice, dict):
361361
message = message_or_choice
362362
else:
363-
message = cast(Message, message_or_choice.message.model_dump())
363+
message = cast(Message, message_or_choice.message.model_dump()) # ty:ignore[possibly-missing-attribute]
364364
formatted_messages.append(format_message(message))
365365
return header + "\n".join(formatted_messages)
366366

0 commit comments

Comments
 (0)