|
1 | 1 | import { describe, expect, test, beforeEach, afterEach } from "bun:test" |
| 2 | +import fs from "fs/promises" |
2 | 3 | import path from "path" |
3 | 4 | import { Session } from "../../src/session" |
4 | 5 | import { ModelID, ProviderID } from "../../src/provider/schema" |
5 | 6 | import { SessionRevert } from "../../src/session/revert" |
6 | 7 | import { SessionCompaction } from "../../src/session/compaction" |
7 | 8 | import { MessageV2 } from "../../src/session/message-v2" |
| 9 | +import { Snapshot } from "../../src/snapshot" |
8 | 10 | import { Log } from "../../src/util/log" |
9 | 11 | import { Instance } from "../../src/project/instance" |
10 | 12 | import { MessageID, PartID } from "../../src/session/schema" |
@@ -70,6 +72,13 @@ function tool(sessionID: string, messageID: string) { |
70 | 72 | }) |
71 | 73 | } |
72 | 74 |
|
| 75 | +const tokens = { |
| 76 | + input: 0, |
| 77 | + output: 0, |
| 78 | + reasoning: 0, |
| 79 | + cache: { read: 0, write: 0 }, |
| 80 | +} |
| 81 | + |
73 | 82 | describe("revert + compact workflow", () => { |
74 | 83 | test("should properly handle compact command after revert", async () => { |
75 | 84 | await using tmp = await tmpdir({ git: true }) |
@@ -434,4 +443,179 @@ describe("revert + compact workflow", () => { |
434 | 443 | }, |
435 | 444 | }) |
436 | 445 | }) |
| 446 | + |
| 447 | + test("restore messages in sequential order", async () => { |
| 448 | + await using tmp = await tmpdir({ git: true }) |
| 449 | + await Instance.provide({ |
| 450 | + directory: tmp.path, |
| 451 | + fn: async () => { |
| 452 | + await fs.writeFile(path.join(tmp.path, "a.txt"), "a0") |
| 453 | + await fs.writeFile(path.join(tmp.path, "b.txt"), "b0") |
| 454 | + await fs.writeFile(path.join(tmp.path, "c.txt"), "c0") |
| 455 | + |
| 456 | + const session = await Session.create({}) |
| 457 | + const sid = session.id |
| 458 | + |
| 459 | + const turn = async (file: string, next: string) => { |
| 460 | + const u = await user(sid) |
| 461 | + await text(sid, u.id, `${file}:${next}`) |
| 462 | + const a = await assistant(sid, u.id, tmp.path) |
| 463 | + const before = await Snapshot.track() |
| 464 | + if (!before) throw new Error("expected snapshot") |
| 465 | + await fs.writeFile(path.join(tmp.path, file), next) |
| 466 | + const after = await Snapshot.track() |
| 467 | + if (!after) throw new Error("expected snapshot") |
| 468 | + const patch = await Snapshot.patch(before) |
| 469 | + await Session.updatePart({ |
| 470 | + id: PartID.ascending(), |
| 471 | + messageID: a.id, |
| 472 | + sessionID: sid, |
| 473 | + type: "step-start", |
| 474 | + snapshot: before, |
| 475 | + }) |
| 476 | + await Session.updatePart({ |
| 477 | + id: PartID.ascending(), |
| 478 | + messageID: a.id, |
| 479 | + sessionID: sid, |
| 480 | + type: "step-finish", |
| 481 | + reason: "stop", |
| 482 | + snapshot: after, |
| 483 | + cost: 0, |
| 484 | + tokens, |
| 485 | + }) |
| 486 | + await Session.updatePart({ |
| 487 | + id: PartID.ascending(), |
| 488 | + messageID: a.id, |
| 489 | + sessionID: sid, |
| 490 | + type: "patch", |
| 491 | + hash: patch.hash, |
| 492 | + files: patch.files, |
| 493 | + }) |
| 494 | + return u.id |
| 495 | + } |
| 496 | + |
| 497 | + const first = await turn("a.txt", "a1") |
| 498 | + const second = await turn("b.txt", "b2") |
| 499 | + const third = await turn("c.txt", "c3") |
| 500 | + |
| 501 | + await SessionRevert.revert({ |
| 502 | + sessionID: sid, |
| 503 | + messageID: first, |
| 504 | + }) |
| 505 | + expect((await Session.get(sid)).revert?.messageID).toBe(first) |
| 506 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0") |
| 507 | + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0") |
| 508 | + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0") |
| 509 | + |
| 510 | + await SessionRevert.revert({ |
| 511 | + sessionID: sid, |
| 512 | + messageID: second, |
| 513 | + }) |
| 514 | + expect((await Session.get(sid)).revert?.messageID).toBe(second) |
| 515 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") |
| 516 | + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0") |
| 517 | + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0") |
| 518 | + |
| 519 | + await SessionRevert.revert({ |
| 520 | + sessionID: sid, |
| 521 | + messageID: third, |
| 522 | + }) |
| 523 | + expect((await Session.get(sid)).revert?.messageID).toBe(third) |
| 524 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") |
| 525 | + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2") |
| 526 | + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0") |
| 527 | + |
| 528 | + await SessionRevert.unrevert({ |
| 529 | + sessionID: sid, |
| 530 | + }) |
| 531 | + expect((await Session.get(sid)).revert).toBeUndefined() |
| 532 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") |
| 533 | + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2") |
| 534 | + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c3") |
| 535 | + }, |
| 536 | + }) |
| 537 | + }) |
| 538 | + |
| 539 | + test("restore same file in sequential order", async () => { |
| 540 | + await using tmp = await tmpdir({ git: true }) |
| 541 | + await Instance.provide({ |
| 542 | + directory: tmp.path, |
| 543 | + fn: async () => { |
| 544 | + await fs.writeFile(path.join(tmp.path, "a.txt"), "a0") |
| 545 | + |
| 546 | + const session = await Session.create({}) |
| 547 | + const sid = session.id |
| 548 | + |
| 549 | + const turn = async (next: string) => { |
| 550 | + const u = await user(sid) |
| 551 | + await text(sid, u.id, `a.txt:${next}`) |
| 552 | + const a = await assistant(sid, u.id, tmp.path) |
| 553 | + const before = await Snapshot.track() |
| 554 | + if (!before) throw new Error("expected snapshot") |
| 555 | + await fs.writeFile(path.join(tmp.path, "a.txt"), next) |
| 556 | + const after = await Snapshot.track() |
| 557 | + if (!after) throw new Error("expected snapshot") |
| 558 | + const patch = await Snapshot.patch(before) |
| 559 | + await Session.updatePart({ |
| 560 | + id: PartID.ascending(), |
| 561 | + messageID: a.id, |
| 562 | + sessionID: sid, |
| 563 | + type: "step-start", |
| 564 | + snapshot: before, |
| 565 | + }) |
| 566 | + await Session.updatePart({ |
| 567 | + id: PartID.ascending(), |
| 568 | + messageID: a.id, |
| 569 | + sessionID: sid, |
| 570 | + type: "step-finish", |
| 571 | + reason: "stop", |
| 572 | + snapshot: after, |
| 573 | + cost: 0, |
| 574 | + tokens, |
| 575 | + }) |
| 576 | + await Session.updatePart({ |
| 577 | + id: PartID.ascending(), |
| 578 | + messageID: a.id, |
| 579 | + sessionID: sid, |
| 580 | + type: "patch", |
| 581 | + hash: patch.hash, |
| 582 | + files: patch.files, |
| 583 | + }) |
| 584 | + return u.id |
| 585 | + } |
| 586 | + |
| 587 | + const first = await turn("a1") |
| 588 | + const second = await turn("a2") |
| 589 | + const third = await turn("a3") |
| 590 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3") |
| 591 | + |
| 592 | + await SessionRevert.revert({ |
| 593 | + sessionID: sid, |
| 594 | + messageID: first, |
| 595 | + }) |
| 596 | + expect((await Session.get(sid)).revert?.messageID).toBe(first) |
| 597 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0") |
| 598 | + |
| 599 | + await SessionRevert.revert({ |
| 600 | + sessionID: sid, |
| 601 | + messageID: second, |
| 602 | + }) |
| 603 | + expect((await Session.get(sid)).revert?.messageID).toBe(second) |
| 604 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") |
| 605 | + |
| 606 | + await SessionRevert.revert({ |
| 607 | + sessionID: sid, |
| 608 | + messageID: third, |
| 609 | + }) |
| 610 | + expect((await Session.get(sid)).revert?.messageID).toBe(third) |
| 611 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a2") |
| 612 | + |
| 613 | + await SessionRevert.unrevert({ |
| 614 | + sessionID: sid, |
| 615 | + }) |
| 616 | + expect((await Session.get(sid)).revert).toBeUndefined() |
| 617 | + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3") |
| 618 | + }, |
| 619 | + }) |
| 620 | + }) |
437 | 621 | }) |
0 commit comments