diff --git a/lib/hypervisor/firecracker/config.go b/lib/hypervisor/firecracker/config.go index e4ccefd8..5b6b48cc 100644 --- a/lib/hypervisor/firecracker/config.go +++ b/lib/hypervisor/firecracker/config.go @@ -213,12 +213,12 @@ func toSnapshotCreateParams(snapshotDir string) snapshotCreateParams { } } -func toSnapshotLoadParams(snapshotDir string, networkOverrides []networkOverride) snapshotLoadParams { +func toSnapshotLoadParams(snapshotDir string, networkOverrides []networkOverride, resumeVM bool) snapshotLoadParams { return snapshotLoadParams{ MemFilePath: snapshotMemoryPath(snapshotDir), SnapshotPath: snapshotStatePath(snapshotDir), EnableDiffSnapshots: true, - ResumeVM: false, + ResumeVM: resumeVM, NetworkOverrides: networkOverrides, } } diff --git a/lib/hypervisor/firecracker/config_test.go b/lib/hypervisor/firecracker/config_test.go index 6e912ee7..66adf83e 100644 --- a/lib/hypervisor/firecracker/config_test.go +++ b/lib/hypervisor/firecracker/config_test.go @@ -82,11 +82,11 @@ func TestSnapshotParamPaths(t *testing.T) { load := toSnapshotLoadParams("/tmp/snapshot-latest", []networkOverride{ {IfaceID: "eth0", HostDevName: "hype-abc123"}, - }) + }, true) assert.Equal(t, "/tmp/snapshot-latest/state", load.SnapshotPath) assert.Equal(t, "/tmp/snapshot-latest/memory", load.MemFilePath) assert.True(t, load.EnableDiffSnapshots) - assert.False(t, load.ResumeVM) + assert.True(t, load.ResumeVM) require.Len(t, load.NetworkOverrides, 1) } diff --git a/lib/hypervisor/firecracker/firecracker.go b/lib/hypervisor/firecracker/firecracker.go index e22e42f3..ff5c44d6 100644 --- a/lib/hypervisor/firecracker/firecracker.go +++ b/lib/hypervisor/firecracker/firecracker.go @@ -27,8 +27,9 @@ type apiError struct { // Firecracker implements hypervisor.Hypervisor for the Firecracker VMM. type Firecracker struct { - socketPath string - client *http.Client + socketPath string + client *http.Client + restoredResumed bool } func New(socketPath string) (*Firecracker, error) { @@ -50,6 +51,10 @@ func New(socketPath string) (*Firecracker, error) { var _ hypervisor.Hypervisor = (*Firecracker)(nil) +func (f *Firecracker) RestoredResumed() bool { + return f != nil && f.restoredResumed +} + func (f *Firecracker) Capabilities() hypervisor.Capabilities { return capabilities() } @@ -223,8 +228,8 @@ func (f *Firecracker) instanceStart(ctx context.Context) error { return f.postAction(ctx, "InstanceStart") } -func (f *Firecracker) loadSnapshot(ctx context.Context, snapshotDir string, networkOverrides []networkOverride) error { - params := toSnapshotLoadParams(snapshotDir, networkOverrides) +func (f *Firecracker) loadSnapshot(ctx context.Context, snapshotDir string, networkOverrides []networkOverride, resumeVM bool) error { + params := toSnapshotLoadParams(snapshotDir, networkOverrides, resumeVM) if _, err := f.do(ctx, http.MethodPut, "/snapshot/load", params, http.StatusNoContent); err != nil { return err } diff --git a/lib/hypervisor/firecracker/process.go b/lib/hypervisor/firecracker/process.go index b29539e9..5044fa2b 100644 --- a/lib/hypervisor/firecracker/process.go +++ b/lib/hypervisor/firecracker/process.go @@ -18,9 +18,11 @@ import ( ) const ( - socketWaitTimeout = 10 * time.Second - socketPollEvery = 50 * time.Millisecond - socketDialTimeout = 100 * time.Millisecond + socketWaitTimeout = 10 * time.Second + socketReadyRetryEvery = 1 * time.Millisecond + socketDialTimeout = 100 * time.Millisecond + restoreResumeOnLoadEnv = "HYPEMAN_FIRECRACKER_RESTORE_RESUME_ON_LOAD" + restoreDeepTraceEnvForLoad = "HYPEMAN_RESTORE_DEEP_TRACE" ) func init() { @@ -115,16 +117,18 @@ func (s *Starter) RestoreVM(ctx context.Context, p *paths.Paths, version string, if err != nil { return 0, nil, fmt.Errorf("load firecracker restore metadata: %w", err) } + resumeOnLoad := shouldResumeOnSnapshotLoad() err = func() error { snapshotSourceAliasMu.Lock() defer snapshotSourceAliasMu.Unlock() return withSnapshotSourceDirAlias(meta, filepath.Dir(socketPath), func() error { - return hv.loadSnapshot(ctx, snapshotPath, meta.NetworkOverrides) + return hv.loadSnapshot(ctx, snapshotPath, meta.NetworkOverrides, resumeOnLoad) }) }() if err != nil { return 0, nil, fmt.Errorf("load firecracker snapshot: %w", err) } + hv.restoredResumed = resumeOnLoad if meta.SnapshotSourceDataDir != "" && !meta.RetainSnapshotSourceDataDirAlias { meta.SnapshotSourceDataDir = "" if err := saveRestoreMetadataState(filepath.Dir(socketPath), meta); err != nil { @@ -244,23 +248,47 @@ func (s *Starter) startProcess(_ context.Context, p *paths.Paths, version string } func isSocketInUse(socketPath string) bool { - conn, err := net.DialTimeout("unix", socketPath, socketDialTimeout) - if err != nil { + return tryDialUnixSocket(socketPath) == nil +} + +func shouldResumeOnSnapshotLoad() bool { + if envBoolDisabled(os.Getenv(restoreResumeOnLoadEnv)) { + return false + } + // Deep restore tracing is anchored after RestoreVM returns with a PID. Keep + // the explicit host Resume call in that mode so the trace still captures the + // first resumed guest execution window. + if strings.TrimSpace(os.Getenv(restoreDeepTraceEnvForLoad)) == "1" { return false } - _ = conn.Close() return true } -func waitForSocket(path string, timeout time.Duration) error { +func envBoolDisabled(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "0", "false", "no", "off": + return true + default: + return false + } +} + +func tryDialUnixSocket(path string) error { + conn, err := net.DialTimeout("unix", path, socketDialTimeout) + if err != nil { + return err + } + _ = conn.Close() + return nil +} + +func waitForSocketByPolling(path string, timeout time.Duration) error { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { - conn, err := net.DialTimeout("unix", path, socketDialTimeout) - if err == nil { - _ = conn.Close() + if err := tryDialUnixSocket(path); err == nil { return nil } - time.Sleep(socketPollEvery) + time.Sleep(socketReadyRetryEvery) } return fmt.Errorf("timeout waiting for socket") } diff --git a/lib/hypervisor/firecracker/process_test.go b/lib/hypervisor/firecracker/process_test.go index df99180e..44a6b4d0 100644 --- a/lib/hypervisor/firecracker/process_test.go +++ b/lib/hypervisor/firecracker/process_test.go @@ -2,9 +2,11 @@ package firecracker import ( "errors" + "net" "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,3 +80,43 @@ func TestWithSnapshotSourceDirAlias_RejectsNestedPaths(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "must not be nested") } + +func TestShouldResumeOnSnapshotLoad(t *testing.T) { + t.Setenv(restoreResumeOnLoadEnv, "") + t.Setenv(restoreDeepTraceEnvForLoad, "") + assert.True(t, shouldResumeOnSnapshotLoad()) + + t.Setenv(restoreResumeOnLoadEnv, "0") + assert.False(t, shouldResumeOnSnapshotLoad()) + + t.Setenv(restoreResumeOnLoadEnv, "") + t.Setenv(restoreDeepTraceEnvForLoad, "1") + assert.False(t, shouldResumeOnSnapshotLoad()) +} + +func TestWaitForSocketReturnsWhenSocketAppears(t *testing.T) { + tmp, err := os.MkdirTemp("/tmp", "fcwait-") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(tmp) }) + socketPath := filepath.Join(tmp, "fc.sock") + done := make(chan struct{}) + errCh := make(chan error, 1) + go func() { + defer close(done) + time.Sleep(10 * time.Millisecond) + listener, err := net.Listen("unix", socketPath) + if err != nil { + errCh <- err + return + } + errCh <- nil + defer listener.Close() + <-time.After(50 * time.Millisecond) + }() + + start := time.Now() + require.NoError(t, waitForSocket(socketPath, time.Second)) + assert.Less(t, time.Since(start), 250*time.Millisecond) + require.NoError(t, <-errCh) + <-done +} diff --git a/lib/hypervisor/firecracker/socket_wait_linux.go b/lib/hypervisor/firecracker/socket_wait_linux.go new file mode 100644 index 00000000..8f5dcb47 --- /dev/null +++ b/lib/hypervisor/firecracker/socket_wait_linux.go @@ -0,0 +1,96 @@ +//go:build linux + +package firecracker + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "golang.org/x/sys/unix" +) + +func waitForSocket(path string, timeout time.Duration) error { + if err := tryDialUnixSocket(path); err == nil { + return nil + } + + parent := filepath.Dir(path) + fd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK) + if err != nil { + return waitForSocketByPolling(path, timeout) + } + defer unix.Close(fd) + + wd, err := unix.InotifyAddWatch(fd, parent, unix.IN_CREATE|unix.IN_MOVED_TO|unix.IN_ATTRIB) + if err != nil { + return waitForSocketByPolling(path, timeout) + } + defer unix.InotifyRmWatch(fd, uint32(wd)) + + deadline := time.Now().Add(timeout) + buf := make([]byte, 4096) + for { + if err := tryDialUnixSocket(path); err == nil { + return nil + } + remaining := time.Until(deadline) + if remaining <= 0 { + return fmt.Errorf("timeout waiting for socket") + } + + pollTimeout := remaining + if socketPathExists(path) { + pollTimeout = minDuration(pollTimeout, socketReadyRetryEvery) + } + events := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + n, err := unix.Poll(events, durationMillisCeil(pollTimeout)) + if err != nil { + if err == unix.EINTR { + continue + } + return waitForSocketByPolling(path, remaining) + } + if n > 0 { + for { + n, err := unix.Read(fd, buf) + if err != nil { + if err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + return waitForSocketByPolling(path, time.Until(deadline)) + } + if n == 0 { + break + } + } + } + } +} + +func socketPathExists(path string) bool { + _, err := os.Lstat(path) + return err == nil +} + +func durationMillisCeil(d time.Duration) int { + if d <= 0 { + return 0 + } + ms := d / time.Millisecond + if d%time.Millisecond != 0 { + ms++ + } + if int64(ms) > int64(^uint(0)>>1) { + return int(^uint(0) >> 1) + } + return int(ms) +} + +func minDuration(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b +} diff --git a/lib/hypervisor/firecracker/socket_wait_other.go b/lib/hypervisor/firecracker/socket_wait_other.go new file mode 100644 index 00000000..1de4be29 --- /dev/null +++ b/lib/hypervisor/firecracker/socket_wait_other.go @@ -0,0 +1,9 @@ +//go:build !linux + +package firecracker + +import "time" + +func waitForSocket(path string, timeout time.Duration) error { + return waitForSocketByPolling(path, timeout) +} diff --git a/lib/hypervisor/hypervisor.go b/lib/hypervisor/hypervisor.go index 7331f471..0e44aaf9 100644 --- a/lib/hypervisor/hypervisor.go +++ b/lib/hypervisor/hypervisor.go @@ -114,7 +114,8 @@ type VMStarter interface { // Each hypervisor implements its own restore flow: // - Cloud Hypervisor: starts process, calls Restore API // - QEMU: would start with -incoming or -loadvm flags (not yet implemented) - // Returns the process ID and a Hypervisor client. The VM is in paused state after restore. + // Returns the process ID and a Hypervisor client. The VM is usually paused + // after restore, unless the returned client reports RestoredResumed. RestoreVM(ctx context.Context, p *paths.Paths, version string, socketPath string, snapshotPath string) (pid int, hv Hypervisor, err error) // PrepareFork allows hypervisors to prepare forked instance state. @@ -202,6 +203,16 @@ type Hypervisor interface { Capabilities() Capabilities } +type restoredResumedHypervisor interface { + RestoredResumed() bool +} + +// RestoredResumed reports whether RestoreVM already resumed guest execution. +func RestoredResumed(hv Hypervisor) bool { + resumed, ok := hv.(restoredResumedHypervisor) + return ok && resumed.RestoredResumed() +} + // Capabilities indicates which optional features a hypervisor supports. // Callers should check these before calling optional methods. type Capabilities struct { diff --git a/lib/hypervisor/tracing.go b/lib/hypervisor/tracing.go index 79888302..96805240 100644 --- a/lib/hypervisor/tracing.go +++ b/lib/hypervisor/tracing.go @@ -172,6 +172,10 @@ func (h *tracingHypervisor) Capabilities() Capabilities { return h.next.Capabilities() } +func (h *tracingHypervisor) RestoredResumed() bool { + return RestoredResumed(h.next) +} + func (h *tracingHypervisor) spanAttrs(attrs ...attribute.KeyValue) []attribute.KeyValue { out := make([]attribute.KeyValue, 0, len(h.attrs)+len(attrs)) out = append(out, h.attrs...) diff --git a/lib/hypervisor/tracing_test.go b/lib/hypervisor/tracing_test.go index 8d5147c6..55a96892 100644 --- a/lib/hypervisor/tracing_test.go +++ b/lib/hypervisor/tracing_test.go @@ -18,6 +18,9 @@ import ( type fakeHypervisor struct{} type fakeHypervisorGetVMInfoError struct{} +type fakeRestoredResumedHypervisor struct { + fakeHypervisor +} func (fakeHypervisor) DeleteVM(context.Context) error { return nil } func (fakeHypervisor) Shutdown(context.Context) error { return nil } @@ -36,6 +39,7 @@ func (fakeHypervisor) GetTargetGuestMemoryBytes(context.Context) (int64, error) return 0, nil } func (fakeHypervisor) Capabilities() Capabilities { return Capabilities{} } +func (fakeRestoredResumedHypervisor) RestoredResumed() bool { return true } func (fakeHypervisorGetVMInfoError) DeleteVM(context.Context) error { return nil } func (fakeHypervisorGetVMInfoError) Shutdown(context.Context) error { return nil } func (fakeHypervisorGetVMInfoError) GetVMInfo(context.Context) (*VMInfo, error) { @@ -123,6 +127,11 @@ func TestWrapVMStarterWrapsReturnedHypervisor(t *testing.T) { assert.Equal(t, string(TypeCloudHypervisor), attrs["hypervisor"]) } +func TestWrapHypervisorPreservesRestoredResumed(t *testing.T) { + hv := WrapHypervisor(TypeFirecracker, fakeRestoredResumedHypervisor{}) + require.True(t, RestoredResumed(hv)) +} + func TestWrapHypervisorSkipsGetVMInfoTraceByDefault(t *testing.T) { recorder, _ := newTestTracerProvider(t) diff --git a/lib/instances/restore.go b/lib/instances/restore.go index 17e4726d..d169242a 100644 --- a/lib/instances/restore.go +++ b/lib/instances/restore.go @@ -309,10 +309,16 @@ func (m *manager) restoreInstance( ) log.InfoContext(ctx, "resuming VM", "instance_id", id) if deepTrace != nil { - deepTrace.Mark("resume_call_start", "") - deepTrace.Sample("resume_call_start") + stage := "resume_call_start" + if hypervisor.RestoredResumed(hv) { + stage = "resume_already_done" + } + deepTrace.Mark(stage, "") + deepTrace.Sample(stage) } - if err := hv.Resume(resumeCtx); err != nil { + if hypervisor.RestoredResumed(hv) { + log.InfoContext(ctx, "VM was resumed during snapshot load", "instance_id", id) + } else if err := hv.Resume(resumeCtx); err != nil { if deepTrace != nil { deepTrace.Mark("resume_error", err.Error()) deepTrace.Sample("resume_error")