Skip to content

Commit aa82547

Browse files
authored
chore: Introduce WithContext() for engine invocations (#590)
1 parent 2c2e922 commit aa82547

3 files changed

Lines changed: 97 additions & 9 deletions

File tree

pkg/workflow/engine_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package workflow
22

33
import (
4+
"context"
45
"fmt"
56
"net/url"
67
"testing"
@@ -524,3 +525,62 @@ func Test_EngineInvocationConcurrent(t *testing.T) {
524525
}
525526
}
526527
}
528+
529+
func Test_EngineImpl_InvokeWithContext_CustomContext(t *testing.T) {
530+
config := configuration.NewInMemory()
531+
engine := NewWorkFlowEngine(config)
532+
533+
wfId := NewWorkflowIdentifier("ctxtest")
534+
flagset := pflag.NewFlagSet("ctx", pflag.ContinueOnError)
535+
536+
var receivedCtx context.Context
537+
_, err := engine.Register(wfId, ConfigurationOptionsFromFlagset(flagset), func(invocation InvocationContext, input []Data) ([]Data, error) {
538+
receivedCtx = invocation.Context()
539+
return nil, nil
540+
})
541+
assert.NoError(t, err)
542+
543+
err = engine.Init()
544+
assert.NoError(t, err)
545+
546+
// Create a context with a specific value and deadline to verify it's passed through
547+
type ctxKey string
548+
testKey := ctxKey("test-key")
549+
testValue := "test-value"
550+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
551+
defer cancel()
552+
ctx = context.WithValue(ctx, testKey, testValue)
553+
554+
_, err = engine.Invoke(wfId, WithContext(ctx))
555+
assert.NoError(t, err)
556+
assert.NotNil(t, receivedCtx)
557+
assert.Equal(t, testValue, receivedCtx.Value(testKey))
558+
559+
// Verify deadline is propagated
560+
deadline, hasDeadline := receivedCtx.Deadline()
561+
assert.True(t, hasDeadline, "context should have a deadline")
562+
assert.False(t, deadline.IsZero(), "deadline should not be zero")
563+
}
564+
565+
func Test_EngineImpl_InvokeWithContext_DefaultContext(t *testing.T) {
566+
config := configuration.NewInMemory()
567+
engine := NewWorkFlowEngine(config)
568+
569+
wfId := NewWorkflowIdentifier("ctxdefault")
570+
flagset := pflag.NewFlagSet("cd", pflag.ContinueOnError)
571+
572+
var receivedCtx context.Context
573+
_, err := engine.Register(wfId, ConfigurationOptionsFromFlagset(flagset), func(invocation InvocationContext, input []Data) ([]Data, error) {
574+
receivedCtx = invocation.Context()
575+
return nil, nil
576+
})
577+
assert.NoError(t, err)
578+
579+
err = engine.Init()
580+
assert.NoError(t, err)
581+
582+
// Invoke without WithContext - should get a non-nil default context
583+
_, err = engine.Invoke(wfId)
584+
assert.NoError(t, err)
585+
assert.NotNil(t, receivedCtx)
586+
}

pkg/workflow/engineimpl.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package workflow
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"net/url"
@@ -37,9 +38,10 @@ type EngineImpl struct {
3738
var _ Engine = (*EngineImpl)(nil)
3839

3940
type engineRuntimeConfig struct {
40-
config configuration.Configuration
41-
input []Data
42-
ic analytics.InstrumentationCollector
41+
config configuration.Configuration
42+
input []Data
43+
ic analytics.InstrumentationCollector
44+
ctxFunc func() context.Context
4345
}
4446

4547
type EngineInvokeOption func(*engineRuntimeConfig)
@@ -62,6 +64,14 @@ func WithInstrumentationCollector(ic analytics.InstrumentationCollector) EngineI
6264
}
6365
}
6466

67+
func WithContext(ctx context.Context) EngineInvokeOption {
68+
return func(e *engineRuntimeConfig) {
69+
if ctx != nil {
70+
e.ctxFunc = func() context.Context { return ctx }
71+
}
72+
}
73+
}
74+
6575
func (e *EngineImpl) GetLogger() *zerolog.Logger {
6676
return e.logger
6777
}
@@ -321,11 +331,11 @@ func (e *EngineImpl) Invoke(
321331
e.mu.Unlock()
322332

323333
// create a context object for the invocation
324-
context := NewInvocationContext(id, options.config, localEngine, localNetworkAccess, localLogger, localAnalytics, localUi)
334+
invocationCtx := newInvocationContext(options.ctxFunc, id, options.config, localEngine, localNetworkAccess, localLogger, localAnalytics, localUi)
325335

326336
// invoke workflow through its callback
327337
localLogger.Printf("Workflow Start")
328-
output, err = callback(context, options.input)
338+
output, err = callback(invocationCtx, options.input)
329339
localLogger.Printf("Workflow End")
330340
}
331341
} else {

pkg/workflow/invocationcontextimpl.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"log"
66

77
"github.com/rs/zerolog"
8+
89
"github.com/snyk/go-application-framework/pkg/analytics"
910
"github.com/snyk/go-application-framework/pkg/configuration"
1011
"github.com/snyk/go-application-framework/pkg/networking"
@@ -13,6 +14,7 @@ import (
1314
"github.com/snyk/go-application-framework/pkg/utils"
1415
)
1516

17+
// Deprecated: NewInvocationContext creates a new invocation context.
1618
func NewInvocationContext(
1719
id Identifier,
1820
config configuration.Configuration,
@@ -22,7 +24,24 @@ func NewInvocationContext(
2224
analyticsImpl analytics.Analytics,
2325
ui ui.UserInterface,
2426
) InvocationContext {
27+
return newInvocationContext(nil, id, config, engine, network, logger, analyticsImpl, ui)
28+
}
29+
30+
func newInvocationContext(
31+
ctxFunc func() context.Context,
32+
id Identifier,
33+
config configuration.Configuration,
34+
engine Engine,
35+
network networking.NetworkAccess,
36+
logger zerolog.Logger,
37+
analyticsImpl analytics.Analytics,
38+
ui ui.UserInterface,
39+
) InvocationContext {
40+
if ctxFunc == nil {
41+
ctxFunc = context.Background
42+
}
2543
return &invocationContextImpl{
44+
ctxFunc: ctxFunc,
2645
WorkflowID: id,
2746
Configuration: config,
2847
WorkflowEngine: engine,
@@ -35,6 +54,7 @@ func NewInvocationContext(
3554

3655
// invocationContextImpl is the default implementation of the InvocationContext interface.
3756
type invocationContextImpl struct {
57+
ctxFunc func() context.Context
3858
WorkflowID Identifier
3959
WorkflowEngine Engine
4060
Configuration configuration.Configuration
@@ -47,10 +67,8 @@ type invocationContextImpl struct {
4767
var _ InvocationContext = (*invocationContextImpl)(nil)
4868

4969
// Context returns the context of the workflow that is being invoked.
50-
func (*invocationContextImpl) Context() context.Context {
51-
// TODO: This is using context.Background() as a placeholder. Ideally this returns
52-
// the context representing the lifecycle of the workflow that is being invoked.
53-
return context.Background()
70+
func (ici *invocationContextImpl) Context() context.Context {
71+
return ici.ctxFunc()
5472
}
5573

5674
// GetWorkflowIdentifier returns the identifier of the workflow that is being invoked.

0 commit comments

Comments
 (0)