Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/hypervisor/firecracker/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ func toSnapshotCreateParams(snapshotDir string) snapshotCreateParams {
}
}

func toSnapshotLoadParams(snapshotDir string, networkOverrides []networkOverride) snapshotLoadParams {
func toSnapshotLoadParams(snapshotDir string, networkOverrides []networkOverride, resumeVM bool) snapshotLoadParams {
return snapshotLoadParams{
MemFilePath: snapshotMemoryPath(snapshotDir),
SnapshotPath: snapshotStatePath(snapshotDir),
EnableDiffSnapshots: true,
ResumeVM: false,
ResumeVM: resumeVM,
NetworkOverrides: networkOverrides,
}
}
Expand Down
4 changes: 2 additions & 2 deletions lib/hypervisor/firecracker/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ func TestSnapshotParamPaths(t *testing.T) {

load := toSnapshotLoadParams("/tmp/snapshot-latest", []networkOverride{
{IfaceID: "eth0", HostDevName: "hype-abc123"},
})
}, true)
assert.Equal(t, "/tmp/snapshot-latest/state", load.SnapshotPath)
assert.Equal(t, "/tmp/snapshot-latest/memory", load.MemFilePath)
assert.True(t, load.EnableDiffSnapshots)
assert.False(t, load.ResumeVM)
assert.True(t, load.ResumeVM)
require.Len(t, load.NetworkOverrides, 1)
}

Expand Down
13 changes: 9 additions & 4 deletions lib/hypervisor/firecracker/firecracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ type apiError struct {

// Firecracker implements hypervisor.Hypervisor for the Firecracker VMM.
type Firecracker struct {
socketPath string
client *http.Client
socketPath string
client *http.Client
restoredResumed bool
}

func New(socketPath string) (*Firecracker, error) {
Expand All @@ -50,6 +51,10 @@ func New(socketPath string) (*Firecracker, error) {

var _ hypervisor.Hypervisor = (*Firecracker)(nil)

// RestoredResumed reports the value of the client's restoredResumed flag,
// which RestoreVM sets after a successful snapshot load to record whether the
// load request asked Firecracker to resume the guest. It is safe to call on a
// nil receiver, in which case it reports false.
func (f *Firecracker) RestoredResumed() bool {
	if f == nil {
		return false
	}
	return f.restoredResumed
}

// Capabilities reports which optional hypervisor features this client
// supports. It simply returns the result of the package-level capabilities
// helper and does not read any state from the receiver.
func (f *Firecracker) Capabilities() hypervisor.Capabilities {
return capabilities()
}
Expand Down Expand Up @@ -223,8 +228,8 @@ func (f *Firecracker) instanceStart(ctx context.Context) error {
return f.postAction(ctx, "InstanceStart")
}

func (f *Firecracker) loadSnapshot(ctx context.Context, snapshotDir string, networkOverrides []networkOverride) error {
params := toSnapshotLoadParams(snapshotDir, networkOverrides)
func (f *Firecracker) loadSnapshot(ctx context.Context, snapshotDir string, networkOverrides []networkOverride, resumeVM bool) error {
params := toSnapshotLoadParams(snapshotDir, networkOverrides, resumeVM)
if _, err := f.do(ctx, http.MethodPut, "/snapshot/load", params, http.StatusNoContent); err != nil {
return err
}
Expand Down
52 changes: 40 additions & 12 deletions lib/hypervisor/firecracker/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
)

const (
socketWaitTimeout = 10 * time.Second
socketPollEvery = 50 * time.Millisecond
socketDialTimeout = 100 * time.Millisecond
socketWaitTimeout = 10 * time.Second
socketReadyRetryEvery = 1 * time.Millisecond
socketDialTimeout = 100 * time.Millisecond
restoreResumeOnLoadEnv = "HYPEMAN_FIRECRACKER_RESTORE_RESUME_ON_LOAD"
restoreDeepTraceEnvForLoad = "HYPEMAN_RESTORE_DEEP_TRACE"
)

func init() {
Expand Down Expand Up @@ -115,16 +117,18 @@ func (s *Starter) RestoreVM(ctx context.Context, p *paths.Paths, version string,
if err != nil {
return 0, nil, fmt.Errorf("load firecracker restore metadata: %w", err)
}
resumeOnLoad := shouldResumeOnSnapshotLoad()
err = func() error {
snapshotSourceAliasMu.Lock()
defer snapshotSourceAliasMu.Unlock()
return withSnapshotSourceDirAlias(meta, filepath.Dir(socketPath), func() error {
return hv.loadSnapshot(ctx, snapshotPath, meta.NetworkOverrides)
return hv.loadSnapshot(ctx, snapshotPath, meta.NetworkOverrides, resumeOnLoad)
})
}()
if err != nil {
return 0, nil, fmt.Errorf("load firecracker snapshot: %w", err)
}
hv.restoredResumed = resumeOnLoad
if meta.SnapshotSourceDataDir != "" && !meta.RetainSnapshotSourceDataDirAlias {
meta.SnapshotSourceDataDir = ""
if err := saveRestoreMetadataState(filepath.Dir(socketPath), meta); err != nil {
Expand Down Expand Up @@ -244,23 +248,47 @@ func (s *Starter) startProcess(_ context.Context, p *paths.Paths, version string
}

func isSocketInUse(socketPath string) bool {
conn, err := net.DialTimeout("unix", socketPath, socketDialTimeout)
if err != nil {
return tryDialUnixSocket(socketPath) == nil
}

func shouldResumeOnSnapshotLoad() bool {
if envBoolDisabled(os.Getenv(restoreResumeOnLoadEnv)) {
return false
}
// Deep restore tracing is anchored after RestoreVM returns with a PID. Keep
// the explicit host Resume call in that mode so the trace still captures the
// first resumed guest execution window.
if strings.TrimSpace(os.Getenv(restoreDeepTraceEnvForLoad)) == "1" {
return false
}
_ = conn.Close()
return true
}

func waitForSocket(path string, timeout time.Duration) error {
// envBoolDisabled reports whether value spells an explicit "off" setting.
//
// Surrounding whitespace is trimmed and the comparison is case-insensitive.
// Only "0", "false", "no" and "off" count as disabled; every other input,
// including the empty string, reports false.
func envBoolDisabled(value string) bool {
	normalized := strings.ToLower(strings.TrimSpace(value))
	return normalized == "0" ||
		normalized == "false" ||
		normalized == "no" ||
		normalized == "off"
}

// tryDialUnixSocket makes one connection attempt to the unix socket at path,
// bounded by socketDialTimeout. A successful connection is closed immediately;
// the function exists only to probe whether something is accepting on path.
// It returns nil when the dial succeeds and the dial error otherwise.
func tryDialUnixSocket(path string) error {
	conn, err := net.DialTimeout("unix", path, socketDialTimeout)
	if err == nil {
		// The connection itself is not needed; a close error is irrelevant to
		// the probe result.
		_ = conn.Close()
	}
	return err
}

func waitForSocketByPolling(path string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("unix", path, socketDialTimeout)
if err == nil {
_ = conn.Close()
if err := tryDialUnixSocket(path); err == nil {
return nil
}
time.Sleep(socketPollEvery)
time.Sleep(socketReadyRetryEvery)
}
return fmt.Errorf("timeout waiting for socket")
}
42 changes: 42 additions & 0 deletions lib/hypervisor/firecracker/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package firecracker

import (
"errors"
"net"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -78,3 +80,43 @@ func TestWithSnapshotSourceDirAlias_RejectsNestedPaths(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "must not be nested")
}

// TestShouldResumeOnSnapshotLoad walks through the environment combinations
// that decide whether a snapshot load should resume the guest: the default
// (both variables empty) resumes, while disabling resume-on-load or enabling
// deep restore tracing each turn it off.
func TestShouldResumeOnSnapshotLoad(t *testing.T) {
	steps := []struct {
		resumeOnLoad string
		deepTrace    string
		want         bool
	}{
		{resumeOnLoad: "", deepTrace: "", want: true},
		{resumeOnLoad: "0", deepTrace: "", want: false},
		{resumeOnLoad: "", deepTrace: "1", want: false},
	}

	for _, step := range steps {
		t.Setenv(restoreResumeOnLoadEnv, step.resumeOnLoad)
		t.Setenv(restoreDeepTraceEnvForLoad, step.deepTrace)
		assert.Equal(t, step.want, shouldResumeOnSnapshotLoad())
	}
}

// TestWaitForSocketReturnsWhenSocketAppears verifies that waitForSocket
// returns promptly once a unix socket starts listening at the awaited path.
//
// A helper goroutine creates the listener shortly after the wait begins and
// keeps it open until the test releases it. Holding the listener for the whole
// wait (rather than for a fixed interval) keeps the test from failing merely
// because the waiting goroutine was descheduled past that interval.
func TestWaitForSocketReturnsWhenSocketAppears(t *testing.T) {
	// NOTE(review): the directory is created directly under /tmp, presumably to
	// keep the socket path short enough for the unix socket path length limit —
	// confirm before switching to t.TempDir().
	tmp, err := os.MkdirTemp("/tmp", "fcwait-")
	require.NoError(t, err)
	t.Cleanup(func() { _ = os.RemoveAll(tmp) })
	socketPath := filepath.Join(tmp, "fc.sock")

	release := make(chan struct{})
	done := make(chan struct{})
	errCh := make(chan error, 1)
	go func() {
		defer close(done)
		time.Sleep(10 * time.Millisecond)
		listener, err := net.Listen("unix", socketPath)
		// Exactly one value is always sent, so the receiver below never blocks
		// indefinitely.
		errCh <- err
		if err != nil {
			return
		}
		defer listener.Close()
		<-release
	}()
	// Registered after the RemoveAll cleanup so it runs first: the listener is
	// closed and the goroutine joined on every exit path, including FailNow.
	t.Cleanup(func() {
		close(release)
		<-done
	})

	start := time.Now()
	waitErr := waitForSocket(socketPath, time.Second)
	elapsed := time.Since(start)

	// Surface a listen failure first; otherwise it would be masked by the
	// resulting wait timeout.
	require.NoError(t, <-errCh)
	require.NoError(t, waitErr)
	assert.Less(t, elapsed, 250*time.Millisecond)
}
96 changes: 96 additions & 0 deletions lib/hypervisor/firecracker/socket_wait_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//go:build linux

package firecracker

import (
"fmt"
"os"
"path/filepath"
"time"

"golang.org/x/sys/unix"
)

// waitForSocket blocks until a unix socket at path accepts a connection or
// timeout elapses.
//
// It first tries a single dial. If that fails, it watches path's parent
// directory with inotify and re-dials each time the wait wakes up, so the
// socket is noticed as soon as a directory event arrives instead of after a
// fixed polling interval. If inotify cannot be set up, or polling/reading the
// inotify descriptor fails, it falls back to waitForSocketByPolling for the
// time that is left.
//
// It returns nil once a dial succeeds, a "timeout waiting for socket" error if
// the deadline passes first, or whatever waitForSocketByPolling returns when
// the fallback is taken.
func waitForSocket(path string, timeout time.Duration) error {
// Fast path: the socket may already be accepting connections.
if err := tryDialUnixSocket(path); err == nil {
return nil
}

parent := filepath.Dir(path)
// The descriptor is non-blocking so the drain loop below ends with EAGAIN
// instead of blocking once the queued events are consumed.
fd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK)
if err != nil {
return waitForSocketByPolling(path, timeout)
}
defer unix.Close(fd)

// Watch the parent directory for entries being created, moved in, or having
// their attributes changed.
wd, err := unix.InotifyAddWatch(fd, parent, unix.IN_CREATE|unix.IN_MOVED_TO|unix.IN_ATTRIB)
if err != nil {
return waitForSocketByPolling(path, timeout)
}
defer unix.InotifyRmWatch(fd, uint32(wd))

deadline := time.Now().Add(timeout)
buf := make([]byte, 4096)
for {
// Dial at the top of every iteration. The first pass also covers a socket
// that appeared between the initial dial and the watch being registered.
if err := tryDialUnixSocket(path); err == nil {
return nil
}
remaining := time.Until(deadline)
if remaining <= 0 {
return fmt.Errorf("timeout waiting for socket")
}

// When the path already exists but the dial failed, cap the wait at
// socketReadyRetryEvery so the dial is retried soon rather than only on the
// next directory event.
// NOTE(review): presumably because a socket file becoming connectable does
// not produce one of the watched events — confirm.
pollTimeout := remaining
if socketPathExists(path) {
pollTimeout = minDuration(pollTimeout, socketReadyRetryEvery)
}
events := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
n, err := unix.Poll(events, durationMillisCeil(pollTimeout))
if err != nil {
// An interrupted poll is retried; the loop recomputes the remaining time.
if err == unix.EINTR {
continue
}
return waitForSocketByPolling(path, remaining)
}
if n > 0 {
// Drain every queued inotify event. Their contents are ignored: any wake-up
// simply leads to another dial attempt at the top of the loop.
for {
n, err := unix.Read(fd, buf)
if err != nil {
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
break
}
return waitForSocketByPolling(path, time.Until(deadline))
}
if n == 0 {
break
}
}
}
}
}

// socketPathExists reports whether os.Lstat succeeds for path. Because Lstat
// does not follow symbolic links, a symlink at path counts as existing. Any
// Lstat error, not only "does not exist", yields false.
func socketPathExists(path string) bool {
	if _, err := os.Lstat(path); err != nil {
		return false
	}
	return true
}

// durationMillisCeil converts d to whole milliseconds, rounding any fractional
// millisecond up so that a short positive wait never collapses to zero. The
// result is used as the millisecond timeout passed to unix.Poll.
//
// Non-positive durations yield 0. The result is capped at 2^31-1, the largest
// value a 32-bit signed integer can hold: poll(2) takes its timeout as a C
// int, so capping at the platform's (possibly 64-bit) int maximum would let
// very large durations produce a timeout the system call cannot represent.
func durationMillisCeil(d time.Duration) int {
	const maxPollMillis = 1<<31 - 1

	if d <= 0 {
		return 0
	}
	ms := d / time.Millisecond
	if d%time.Millisecond != 0 {
		ms++
	}
	if ms > maxPollMillis {
		return maxPollMillis
	}
	return int(ms)
}

// minDuration returns the smaller of the two durations; when they are equal
// the shared value is returned.
func minDuration(a, b time.Duration) time.Duration {
	if b <= a {
		return b
	}
	return a
}
9 changes: 9 additions & 0 deletions lib/hypervisor/firecracker/socket_wait_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build !linux

package firecracker

import "time"

// waitForSocket is the variant built on non-Linux platforms (see the build
// constraint on this file). It delegates directly to waitForSocketByPolling
// with the same path and timeout and returns its result unchanged.
func waitForSocket(path string, timeout time.Duration) error {
return waitForSocketByPolling(path, timeout)
}
13 changes: 12 additions & 1 deletion lib/hypervisor/hypervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ type VMStarter interface {
// Each hypervisor implements its own restore flow:
// - Cloud Hypervisor: starts process, calls Restore API
// - QEMU: would start with -incoming or -loadvm flags (not yet implemented)
// Returns the process ID and a Hypervisor client. The VM is in paused state after restore.
// Returns the process ID and a Hypervisor client. The VM is usually paused
// after restore, unless the returned client reports RestoredResumed.
RestoreVM(ctx context.Context, p *paths.Paths, version string, socketPath string, snapshotPath string) (pid int, hv Hypervisor, err error)

// PrepareFork allows hypervisors to prepare forked instance state.
Expand Down Expand Up @@ -202,6 +203,16 @@ type Hypervisor interface {
Capabilities() Capabilities
}

// restoredResumedHypervisor is an optional interface that a Hypervisor value
// may additionally implement. The package-level RestoredResumed function
// checks for it with a type assertion, so implementations that lack the
// method are treated as not having resumed.
type restoredResumedHypervisor interface {
// RestoredResumed reports whether RestoreVM already resumed guest execution.
RestoredResumed() bool
}

// RestoredResumed reports whether RestoreVM already resumed guest execution.
//
// The answer comes from hv's own RestoredResumed method when hv provides one.
// A hypervisor without that method, or a nil hv, reports false.
func RestoredResumed(hv Hypervisor) bool {
	if reporter, ok := hv.(restoredResumedHypervisor); ok {
		return reporter.RestoredResumed()
	}
	return false
}

// Capabilities indicates which optional features a hypervisor supports.
// Callers should check these before calling optional methods.
type Capabilities struct {
Expand Down
4 changes: 4 additions & 0 deletions lib/hypervisor/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ func (h *tracingHypervisor) Capabilities() Capabilities {
return h.next.Capabilities()
}

// RestoredResumed forwards the query to the wrapped hypervisor (h.next) via
// the package-level RestoredResumed helper, so wrapping a hypervisor for
// tracing does not hide whether it reports itself as already resumed.
func (h *tracingHypervisor) RestoredResumed() bool {
return RestoredResumed(h.next)
}

func (h *tracingHypervisor) spanAttrs(attrs ...attribute.KeyValue) []attribute.KeyValue {
out := make([]attribute.KeyValue, 0, len(h.attrs)+len(attrs))
out = append(out, h.attrs...)
Expand Down
9 changes: 9 additions & 0 deletions lib/hypervisor/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (

// fakeHypervisor is a stateless test double whose methods in this file return
// fixed zero/nil results.
type fakeHypervisor struct{}
// fakeHypervisorGetVMInfoError is a stateless test double with its own method
// set defined in this file.
// NOTE(review): judging by the name, its GetVMInfo presumably returns an
// error — confirm against the method body.
type fakeHypervisorGetVMInfoError struct{}
// fakeRestoredResumedHypervisor embeds fakeHypervisor to inherit its methods
// and adds a RestoredResumed method that reports true.
type fakeRestoredResumedHypervisor struct {
fakeHypervisor
}

// DeleteVM is a no-op stub that always returns nil.
func (fakeHypervisor) DeleteVM(context.Context) error { return nil }
// Shutdown is a no-op stub that always returns nil.
func (fakeHypervisor) Shutdown(context.Context) error { return nil }
Expand All @@ -36,6 +39,7 @@ func (fakeHypervisor) GetTargetGuestMemoryBytes(context.Context) (int64, error)
return 0, nil
}
// Capabilities returns the zero-value Capabilities.
func (fakeHypervisor) Capabilities() Capabilities { return Capabilities{} }
// RestoredResumed always reports true.
func (fakeRestoredResumedHypervisor) RestoredResumed() bool { return true }
// DeleteVM is a no-op stub that always returns nil.
func (fakeHypervisorGetVMInfoError) DeleteVM(context.Context) error { return nil }
// Shutdown is a no-op stub that always returns nil.
func (fakeHypervisorGetVMInfoError) Shutdown(context.Context) error { return nil }
func (fakeHypervisorGetVMInfoError) GetVMInfo(context.Context) (*VMInfo, error) {
Expand Down Expand Up @@ -123,6 +127,11 @@ func TestWrapVMStarterWrapsReturnedHypervisor(t *testing.T) {
assert.Equal(t, string(TypeCloudHypervisor), attrs["hypervisor"])
}

// TestWrapHypervisorPreservesRestoredResumed checks that a hypervisor which
// reports RestoredResumed still does so after being passed through
// WrapHypervisor.
func TestWrapHypervisorPreservesRestoredResumed(t *testing.T) {
	wrapped := WrapHypervisor(TypeFirecracker, fakeRestoredResumedHypervisor{})
	resumed := RestoredResumed(wrapped)
	require.True(t, resumed)
}

func TestWrapHypervisorSkipsGetVMInfoTraceByDefault(t *testing.T) {
recorder, _ := newTestTracerProvider(t)

Expand Down
Loading
Loading