diff --git a/.mockery.yaml b/.mockery.yaml index 91e44a850e..16ee6356dd 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -38,13 +38,12 @@ packages: github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host: interfaces: ModuleV1: {} - ModuleV2: {} + github.com/smartcontractkit/chainlink-common/pkg/workflows/host: + interfaces: + Module: {} ExecutionHelper: config: - inpackage: true - filename: "mock_{{.InterfaceName | snakecase}}_test.go" - mockname: "Mock{{.InterfaceName}}" - dir: "{{.InterfaceDir}}" + mockname: "Mock{{.InterfaceName}}" github.com/smartcontractkit/chainlink-common/pkg/custmsg: interfaces: MessageEmitter: @@ -64,4 +63,4 @@ packages: dir: "{{.InterfaceDir}}/limits" outpkg: limits interfaces: - Getter: \ No newline at end of file + Getter: diff --git a/pkg/capabilities/v2/protoc/pkg/template_generator.go b/pkg/capabilities/v2/protoc/pkg/template_generator.go index eeb674c7f7..d64435c4e3 100644 --- a/pkg/capabilities/v2/protoc/pkg/template_generator.go +++ b/pkg/capabilities/v2/protoc/pkg/template_generator.go @@ -139,9 +139,9 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial if md == nil { return false, nil - } else { - return md.MapToUntypedApi, nil } + + return md.MapToUntypedApi, nil }, "addImport": func(importPath protogen.GoImportPath, ignore string) string { importName := importPath.String() @@ -259,6 +259,13 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial return line } }, + "TeeEnabled": func(s *protogen.Service) (bool, error) { + md, err := getCapabilityMetadata(s) + if err != nil { + return false, err + } + return slices.Contains(md.AdditionalEnvironments, generator.AdditionalEnvironments_ADDITIONAL_ENVIRONMENTS_TEE), nil + }, }).Funcs(t.ExtraFns) // Register partials diff --git a/pkg/workflows/artifacts/artifacts_test.go b/pkg/workflows/artifacts/artifacts_test.go index b9240ab274..6c14d6b71c 100644 --- a/pkg/workflows/artifacts/artifacts_test.go +++ b/pkg/workflows/artifacts/artifacts_test.go @@ -45,8 +45,9 @@ func (s *ArtifactsTestSuite) TestArtifacts() { s.lggr.Info("WorkflowCompiledBinary Size", "size", len(b64EncodedBinaryData)) - // Compare the keccak256 hash of the binary data with the keccak256 hash of the - // base64 encoded binary from CRE-CLI + // Compare the keccak256 hash of the binary data against a value produced by + // the pinned Go toolchain (see GetBuildCmd in utils.go). Because Compile sets + // GOTOOLCHAIN from the nearest go.mod, this hash is stable across machines. expKeccak256Hash, err := hex.DecodeString("a057a58ff8212122016515b2922b7c3893525f7f5afe95c8442e0cd629d68420") s.NoError(err, "failed to decode expected keccak256 hash") keccak256FromSha3Lib := sha3.NewLegacyKeccak256() @@ -58,8 +59,7 @@ func (s *ArtifactsTestSuite) TestArtifacts() { s.NoError(err, "failed to prepare artifacts") base64EncodedBinaryData := artifacts.GetBinaryData() - // Compare if the compiled WASM binary is the same as the CRE-CLI output - s.Len(base64EncodedBinaryData, 636684, "binary data size should be same as CRE-CLI output") + s.Len(base64EncodedBinaryData, 636684, "binary data size should match the pinned toolchain output") s.Equal("m1upG3s6AJQvOA8AAK295+EaARsHAADf/YcBgFURwPQAANDq5wFQVVVVVVVVVVVV3ZMQEI7ZtgMAAKqq", string(base64EncodedBinaryData[0:80])) s.Equal("gUEoFNoVRfyHGTsZmdg7wCJbGVibOhmYmsDAytgg92FTTmiddpI/x8SYzdANBkPGhtLoj/Hn7jvK26YE", diff --git a/pkg/workflows/artifacts/utils.go b/pkg/workflows/artifacts/utils.go index f022f61b93..cd7fe97f2a 100644 --- a/pkg/workflows/artifacts/utils.go +++ b/pkg/workflows/artifacts/utils.go @@ -1,6 +1,7 @@ package artifacts import ( + "bufio" "fmt" "os" "os/exec" @@ -78,10 +79,66 @@ func GetBuildCmd(inputFile string, outputFile string, rootFolder string) *exec.C "-buildvcs=false", inputFile, ) - buildCmd.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm", "CGO_ENABLED=0") + env := append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm", "CGO_ENABLED=0") + // Pin GOTOOLCHAIN so the compiled WASM is reproducible. Prefer the + // configured GOTOOLCHAIN; when it's unset, fall back to the version + // declared in the module's go.mod (the local Go version wouldn't pin). + if toolchain := goToolchain(rootFolder); toolchain != "" { + env = append(env, "GOTOOLCHAIN="+toolchain) + } + buildCmd.Env = env } buildCmd.Dir = rootFolder return buildCmd } + +// goToolchain returns a GOTOOLCHAIN value (e.g. "go1.26.2") to pin the build +// to. It prefers the configured `go env GOTOOLCHAIN`; when that is unset (e.g. +// "auto") it falls back to the go version declared in the module's go.mod. The +// local Go version is not used as a fallback because it would not pin a +// reproducible toolchain. Returns "" when nothing can be determined. +func goToolchain(dir string) string { + if v := goEnv(dir, "GOTOOLCHAIN"); v != "" && v != "auto" { + return v + } + return goToolchainFromMod(dir) +} + +// goEnv runs `go env ` in dir and returns the trimmed value, or "" on error. +func goEnv(dir, name string) string { + cmd := exec.Command("go", "env", name) + cmd.Dir = dir + out, err := cmd.Output() + if err != nil { + return "" + } + return strings.TrimSpace(string(out)) +} + +// goToolchainFromMod returns a GOTOOLCHAIN value derived from the go directive +// in the module's go.mod, located via `go env GOMOD`. Returns "" when no go.mod +// or go version can be determined. +func goToolchainFromMod(dir string) string { + goModPath := goEnv(dir, "GOMOD") + // `go env GOMOD` returns "" outside a module and os.DevNull when modules + // are disabled (GO111MODULE=off); neither is a real go.mod file. + if goModPath == "" || goModPath == os.DevNull { + return "" + } + f, err := os.Open(goModPath) + if err != nil { + return "" + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) == 2 && fields[0] == "go" { + return "go" + fields[1] + } + } + return "" +} diff --git a/pkg/workflows/wasm/host/mock_execution_helper_test.go b/pkg/workflows/host/mocks/execution_helper.go similarity index 99% rename from pkg/workflows/wasm/host/mock_execution_helper_test.go rename to pkg/workflows/host/mocks/execution_helper.go index 9f2a84d8a8..455fde3041 100644 --- a/pkg/workflows/wasm/host/mock_execution_helper_test.go +++ b/pkg/workflows/host/mocks/execution_helper.go @@ -1,9 +1,10 @@ // Code generated by mockery v2.53.3. DO NOT EDIT. -package host +package mocks import ( context "context" + time "time" sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" diff --git a/pkg/workflows/host/mocks/module.go b/pkg/workflows/host/mocks/module.go new file mode 100644 index 0000000000..8576d1238e --- /dev/null +++ b/pkg/workflows/host/mocks/module.go @@ -0,0 +1,207 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + + host "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + mock "github.com/stretchr/testify/mock" +) + +// Module is an autogenerated mock type for the Module type +type Module struct { + mock.Mock +} + +type Module_Expecter struct { + mock *mock.Mock +} + +func (_m *Module) EXPECT() *Module_Expecter { + return &Module_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with no fields +func (_m *Module) Close() { + _m.Called() +} + +// Module_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type Module_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *Module_Expecter) Close() *Module_Close_Call { + return &Module_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *Module_Close_Call) Run(run func()) *Module_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_Close_Call) Return() *Module_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *Module_Close_Call) RunAndReturn(run func()) *Module_Close_Call { + _c.Run(run) + return _c +} + +// Execute provides a mock function with given fields: ctx, request, handler +func (_m *Module) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper) (*sdk.ExecutionResult, error) { + ret := _m.Called(ctx, request, handler) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 *sdk.ExecutionResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)); ok { + return rf(ctx, request, handler) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) *sdk.ExecutionResult); ok { + r0 = rf(ctx, request, handler) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sdk.ExecutionResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) error); ok { + r1 = rf(ctx, request, handler) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Module_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type Module_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.ExecuteRequest +// - handler host.ExecutionHelper +func (_e *Module_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *Module_Execute_Call { + return &Module_Execute_Call{Call: _e.mock.On("Execute", ctx, request, handler)} +} + +func (_c *Module_Execute_Call) Run(run func(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper)) *Module_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.ExecuteRequest), args[2].(host.ExecutionHelper)) + }) + return _c +} + +func (_c *Module_Execute_Call) Return(_a0 *sdk.ExecutionResult, _a1 error) *Module_Execute_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Module_Execute_Call) RunAndReturn(run func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)) *Module_Execute_Call { + _c.Call.Return(run) + return _c +} + +// IsLegacyDAG provides a mock function with no fields +func (_m *Module) IsLegacyDAG() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsLegacyDAG") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Module_IsLegacyDAG_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsLegacyDAG' +type Module_IsLegacyDAG_Call struct { + *mock.Call +} + +// IsLegacyDAG is a helper method to define mock.On call +func (_e *Module_Expecter) IsLegacyDAG() *Module_IsLegacyDAG_Call { + return &Module_IsLegacyDAG_Call{Call: _e.mock.On("IsLegacyDAG")} +} + +func (_c *Module_IsLegacyDAG_Call) Run(run func()) *Module_IsLegacyDAG_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_IsLegacyDAG_Call) Return(_a0 bool) *Module_IsLegacyDAG_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Module_IsLegacyDAG_Call) RunAndReturn(run func() bool) *Module_IsLegacyDAG_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with no fields +func (_m *Module) Start() { + _m.Called() +} + +// Module_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type Module_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +func (_e *Module_Expecter) Start() *Module_Start_Call { + return &Module_Start_Call{Call: _e.mock.On("Start")} +} + +func (_c *Module_Start_Call) Run(run func()) *Module_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_Start_Call) Return() *Module_Start_Call { + _c.Call.Return() + return _c +} + +func (_c *Module_Start_Call) RunAndReturn(run func()) *Module_Start_Call { + _c.Run(run) + return _c +} + +// NewModule creates a new instance of Module. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewModule(t interface { + mock.TestingT + Cleanup(func()) +}) *Module { + mock := &Module{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/host/module.go b/pkg/workflows/host/module.go new file mode 100644 index 0000000000..f4debbb922 --- /dev/null +++ b/pkg/workflows/host/module.go @@ -0,0 +1,48 @@ +//go:generate go run ./requirements_gen + +package host + +import ( + "context" + "time" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" +) + +type ModuleBase interface { + Start() + Close() + IsLegacyDAG() bool +} + +type Module interface { + ModuleBase + + // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution + Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) +} + +type RequirementEnforcingModule interface { + Module + + // SetRequirements must respect the requirements for the execution until it completes + SetRequirements(executionId string, requirements *sdkpb.Requirements) +} + +// ExecutionHelper Implemented by those running the host, for example the Workflow Engine +type ExecutionHelper interface { + // CallCapability blocking call to the Workflow Engine + CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) + GetSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) + + GetWorkflowExecutionID() string + + GetNodeTime() time.Time + + GetDONTime() (time.Time, error) + + EmitUserLog(log string) error + + EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error +} diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go new file mode 100644 index 0000000000..24d6d2b4a3 --- /dev/null +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -0,0 +1,131 @@ +package host + +import ( + "context" + "fmt" + "sync" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type ModuleAndHandler struct { + Module + RequirementsHandler +} + +// lazyModule wraps a ModuleAndHandler so that Start is called at most once +// and Close only fires for modules that were actually started. The mutex +// serializes start/close so a concurrent Close cannot race past an in-flight +// Start (leaving a started module unclosed) and vice versa. +type lazyModule struct { + ModuleAndHandler + mu sync.Mutex + started bool + closed bool +} + +func (l *lazyModule) ensureStarted() { + l.mu.Lock() + defer l.mu.Unlock() + if l.started || l.closed { + return + } + l.Module.Start() + l.started = true +} + +func (l *lazyModule) ensureClosed() { + l.mu.Lock() + defer l.mu.Unlock() + if l.closed { + return + } + l.closed = true + if l.started { + l.Module.Close() + } +} + +// NewRequirementSelectingModule creates a module that routes trigger executions +// based on subscription requirements. main is prepended as modules[0]; additional +// modules follow. Subscribe always runs on modules[0]. +func NewRequirementSelectingModule(main ModuleAndHandler, additional []ModuleAndHandler) Module { + modules := make([]*lazyModule, 1+len(additional)) + modules[0] = &lazyModule{ModuleAndHandler: main} + for i, a := range additional { + modules[1+i] = &lazyModule{ModuleAndHandler: a} + } + return &requirementSelectingModule{modules: modules} +} + +type triggerInfo struct { + moduleIdx int + requirements *sdk.Requirements +} + +type requirementSelectingModule struct { + modules []*lazyModule + // triggerID → triggerInfo + cache sync.Map +} + +func (r *requirementSelectingModule) Start() { + r.modules[0].ensureStarted() +} + +func (r *requirementSelectingModule) Close() { + for _, m := range r.modules { + m.ensureClosed() + } +} + +func (r *requirementSelectingModule) IsLegacyDAG() bool { + return r.modules[0].IsLegacyDAG() +} + +func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + if request.GetTrigger() == nil { + return r.subscribe(ctx, request, handler) + } + return r.trigger(ctx, request, handler) +} + +func (r *requirementSelectingModule) subscribe(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + result, err := r.modules[0].Execute(ctx, request, handler) + if err != nil { + return nil, err + } + + for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() { + matched := false + for j, m := range r.modules { + if CheckRequirements(ctx, m.RequirementsHandler, sub.Requirements) { + m.ensureStarted() + r.cache.Store(uint64(i), triggerInfo{moduleIdx: j, requirements: sub.Requirements}) + matched = true + break + } + } + if !matched { + return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements for trigger %d", i) + } + } + + return result, nil +} + +func (r *requirementSelectingModule) trigger(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + trigger := request.GetTrigger() + if val, cached := r.cache.Load(trigger.Id); cached { + info := val.(triggerInfo) + m := r.modules[info.moduleIdx] + if rem, ok := m.Module.(RequirementEnforcingModule); ok && info.requirements != nil { + rem.SetRequirements(handler.GetWorkflowExecutionID(), info.requirements) + } + + return m.Execute(ctx, request, handler) + } + return r.modules[0].Execute(ctx, request, handler) +} + +var _ Module = &requirementSelectingModule{} diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go new file mode 100644 index 0000000000..4124a05672 --- /dev/null +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -0,0 +1,578 @@ +package host_test + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type stubModule struct { + startFn func() + closeFn func() + legacyFn func() bool + executeFn func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) +} + +func (s *stubModule) Start() { s.startFn() } +func (s *stubModule) Close() { s.closeFn() } +func (s *stubModule) IsLegacyDAG() bool { return s.legacyFn() } +func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return s.executeFn(ctx, req, h) +} + +type requirementEnforcingStub struct { + *stubModule + setRequirementsFn func(string, *sdk.Requirements) +} + +func (s *requirementEnforcingStub) SetRequirements(executionID string, requirements *sdk.Requirements) { + s.setRequirementsFn(executionID, requirements) +} + +func noop() {} +func noopClose() {} + +func triggerRequest(id uint64) *sdk.ExecuteRequest { + return &sdk.ExecuteRequest{ + Request: &sdk.ExecuteRequest_Trigger{ + Trigger: &sdk.Trigger{Id: id}, + }, + } +} + +func subscribeRequest() *sdk.ExecuteRequest { + return &sdk.ExecuteRequest{ + Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}, + } +} + +func subscribeResult(subs ...*sdk.TriggerSubscription) *sdk.ExecutionResult { + return &sdk.ExecutionResult{ + Result: &sdk.ExecutionResult_TriggerSubscriptions{ + TriggerSubscriptions: &sdk.TriggerSubscriptionRequest{ + Subscriptions: subs, + }, + }, + } +} + +func subWithReqs(reqs *sdk.Requirements) *sdk.TriggerSubscription { + return &sdk.TriggerSubscription{Requirements: reqs} +} + +func TestRequirementSelectingModule_Start(t *testing.T) { + t.Run("starts only main module", func(t *testing.T) { + var mainStarted, additionalStarted bool + main := host.ModuleAndHandler{Module: &stubModule{startFn: func() { mainStarted = true }}} + add := host.ModuleAndHandler{Module: &stubModule{startFn: func() { additionalStarted = true }}} + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + assert.True(t, mainStarted) + assert.False(t, additionalStarted) + }) +} + +func TestRequirementSelectingModule_Close(t *testing.T) { + t.Run("closes main and no additional when none started", func(t *testing.T) { + var mainClosed, addClosed bool + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, closeFn: func() { mainClosed = true }, + }} + add := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, closeFn: func() { addClosed = true }, + }} + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + m.Close() + + assert.True(t, mainClosed) + assert.False(t, addClosed) + }) + + t.Run("closes main and all started additional modules", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + var mainClosed, add0Closed, add1Closed bool + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + closeFn: func() { mainClosed = true }, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add0 := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: func() { add0Closed = true }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + add1 := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: func() { add1Closed = true }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add0, add1}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + m.Close() + + assert.True(t, mainClosed, "main should be closed") + assert.True(t, add0Closed, "started additional should be closed") + assert.False(t, add1Closed, "never-started additional should not be closed") + }) +} + +func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { + main := host.ModuleAndHandler{Module: &stubModule{legacyFn: func() bool { return true }}} + m := host.NewRequirementSelectingModule(main, nil) + assert.True(t, m.IsLegacyDAG()) +} + +func TestRequirementSelectingModule_Execute(t *testing.T) { + t.Run("trigger with no cached entry goes to main", func(t *testing.T) { + want := &sdk.ExecutionResult{} + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + return want, nil + } + return subscribeResult(), nil + }, + }} + + m := host.NewRequirementSelectingModule(main, nil) + m.Start() + + got, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("main error on subscribe propagates", func(t *testing.T) { + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, assert.AnError + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called") + return nil, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + assert.ErrorIs(t, err, assert.AnError) + }) + + t.Run("subscribe with requirements routes trigger to additional", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("subscribe with unmatched requirements returns error", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{startFn: noop}, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner that can satisfy the requirements") + }) + + t.Run("subscribe skips non-matching and selects later additional", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add0 := host.ModuleAndHandler{ + Module: &stubModule{startFn: noop}, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + add1 := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add0, add1}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("additional module started lazily during subscribe", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var addStartCount int32 + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: func() { atomic.AddInt32(&addStartCount, 1) }, + closeFn: noopClose, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + assert.Equal(t, int32(0), atomic.LoadInt32(&addStartCount)) + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + + // Second subscribe does not start additional again (sync.Once). + _, err = m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + }) + + t.Run("subscribe with no requirements returns main result", func(t *testing.T) { + want := subscribeResult() + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }} + + m := host.NewRequirementSelectingModule(main, nil) + m.Start() + + got, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("main module satisfying requirements keeps trigger on main", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + var mainTriggerCalls int32 + main := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return want, nil + } + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called when main satisfies requirements") + return nil, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls), "trigger should run on main") + }) + + t.Run("cached trigger sets requirements before execute", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + executionID := "wf-exec-1" + + main := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + var calls []string + var gotReqs *sdk.Requirements + var gotExecutionID string + enforcingAdd := &requirementEnforcingStub{ + stubModule: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + calls = append(calls, "execute") + return want, nil + }, + }, + setRequirementsFn: func(id string, requirements *sdk.Requirements) { + calls = append(calls, "set") + gotExecutionID = id + gotReqs = requirements + }, + } + add := host.ModuleAndHandler{ + Module: enforcingAdd, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + helper := &mocks.MockExecutionHelper{} + helper.On("GetWorkflowExecutionID").Return(executionID).Once() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), helper) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, []string{"set", "execute"}, calls) + assert.Equal(t, executionID, gotExecutionID) + assert.Same(t, teeReqs, gotReqs) + helper.AssertExpectations(t) + }) +} + +func TestRequirementSelectingModule_TriggerCache(t *testing.T) { + t.Run("cached trigger skips main on subsequent calls", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var mainTriggerCalls int32 + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + } + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls), "cached trigger should skip main") + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls), "cached trigger should skip main on repeat") + }) + + t.Run("trigger not in cache goes to main", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var mainTriggerCalls int32 + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return &sdk.ExecutionResult{}, nil + } + // subscription 0 has requirements; subscription 1 does not + return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + // trigger 1 has no requirements → goes to main + _, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls)) + }) + + t.Run("different triggers route to different modules", func(t *testing.T) { + // subscription 0: TEE required → additional; subscription 1: no requirements → main + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{ + Item: &sdk.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdk.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}, + }}, + }} + var mainTriggerCalls int32 + wantAdditional := &sdk.ExecutionResult{} + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return &sdk.ExecutionResult{}, nil + } + return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil + }, + }} + add := host.ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return wantAdditional, nil + }, + }, + RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + // trigger 0 has TEE requirements → additional + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, wantAdditional, got) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls)) + + // trigger 1 has no requirements → main + _, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls)) + }) + + t.Run("no additional modules when subscribe has requirements returns error", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + main := host.ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + + m := host.NewRequirementSelectingModule(main, nil) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner") + }) +} diff --git a/pkg/workflows/host/requirements_gen/main.go b/pkg/workflows/host/requirements_gen/main.go new file mode 100644 index 0000000000..b4b4bd3e6f --- /dev/null +++ b/pkg/workflows/host/requirements_gen/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "bytes" + _ "embed" + "log" + "os" + "reflect" + "text/template" + + "github.com/smartcontractkit/chainlink-common/pkg/utils/codegen" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +//go:embed requirements_helper.go.tmpl +var tmplSrc string + +type fieldInfo struct { + Name string + Type string +} + +type templateData struct { + Fields []fieldInfo +} + +func main() { + requirementsType := reflect.TypeOf(sdk.Requirements{}) + + var fields []fieldInfo + for i := 0; i < requirementsType.NumField(); i++ { + f := requirementsType.Field(i) + if !f.IsExported() { + continue + } + fields = append(fields, fieldInfo{ + Name: f.Name, + Type: f.Type.String(), + }) + } + + tmpl, err := template.New("requirements_helper").Parse(tmplSrc) + if err != nil { + log.Fatalf("failed to parse template: %v", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, templateData{Fields: fields}); err != nil { + log.Fatalf("failed to execute template: %v", err) + } + + const outFile = "requirements_helper_gen.go" + settings := codegen.PrettySettings{ + Tool: "requirements_gen", + GoPrettySettings: codegen.GoPrettySettings{ + LocalPrefix: "github.com/smartcontractkit/chainlink-common", + }, + } + + content, err := codegen.PrettyFile(outFile, buf.String(), settings) + if err != nil { + log.Fatalf("failed to format generated code: %v\n%s", err, buf.String()) + } + + if err := os.WriteFile(outFile, []byte(content), 0644); err != nil { + log.Fatalf("failed to write output: %v", err) + } +} diff --git a/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl new file mode 100644 index 0000000000..da55611a42 --- /dev/null +++ b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl @@ -0,0 +1,35 @@ +package host + +import "ctx" + +// RequirementsHandler contains a callback for each public field in sdk.Requirements. +// Each callback receives the field value and returns a boolean indicating whether the requirement is satisfied. +type RequirementsHandler struct { +{{- range .Fields}} + {{.Name}} func(context.Context, {{.Type}}) bool +{{- end}} +} + +// CheckRequirements calls each non-nil callback in the handler for the corresponding +// non-nil field in req, returning false if any are false, or if the handler is nil. +// Unknown fields on the proto also result in a false return value. +func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { + if req == nil { + return true + } + + if len(req.ProtoReflect().GetUnknown()) != 0 { + return false + } + +{{range .Fields}} + if req.{{.Name}} != nil { + if handler.{{.Name}} == nil || !handler.{{.Name}}(ctx, req.{{.Name}}) { + return false + } + + } +{{end}} + + return true +} diff --git a/pkg/workflows/host/requirements_helper_gen.go b/pkg/workflows/host/requirements_helper_gen.go new file mode 100644 index 0000000000..d9e5140f34 --- /dev/null +++ b/pkg/workflows/host/requirements_helper_gen.go @@ -0,0 +1,37 @@ +// Code generated by requirements_gen, DO NOT EDIT. + +package host + +import ( + "context" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +// RequirementsHandler contains a callback for each public field in sdk.Requirements. +// Each callback receives the field value and returns a boolean indicating whether the requirement is satisfied. +type RequirementsHandler struct { + Tee func(context.Context, *sdk.Tee) bool +} + +// CheckRequirements calls each non-nil callback in the handler for the corresponding +// non-nil field in req, returning false if any are false, or if the handler is nil. +// Unknown fields on the proto also result in a false return value. +func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { + if req == nil { + return true + } + + if len(req.ProtoReflect().GetUnknown()) != 0 { + return false + } + + if req.Tee != nil { + if handler.Tee == nil || !handler.Tee(ctx, req.Tee) { + return false + } + + } + + return true +} diff --git a/pkg/workflows/host/requirements_helper_gen_test.go b/pkg/workflows/host/requirements_helper_gen_test.go new file mode 100644 index 0000000000..68bb23b3d7 --- /dev/null +++ b/pkg/workflows/host/requirements_helper_gen_test.go @@ -0,0 +1,48 @@ +package host + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func Test_CheckRequirements(t *testing.T) { + t.Parallel() + t.Run("unknown proto fields", func(t *testing.T) { + // Encode a field number (99) unknown to Requirements so proto.Unmarshal + // preserves it as unknown bytes. + b := protowire.AppendTag(nil, 99, protowire.VarintType) + b = protowire.AppendVarint(b, 1) + req := &sdk.Requirements{} + require.NoError(t, proto.Unmarshal(b, req)) + + assert.False(t, CheckRequirements(context.Background(), RequirementsHandler{}, req)) + }) + + t.Run("no fields always passes", func(t *testing.T) { + assert.True(t, CheckRequirements(context.Background(), RequirementsHandler{}, &sdk.Requirements{})) + }) + + t.Run("handler not set returns false", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + assert.False(t, CheckRequirements(context.Background(), RequirementsHandler{}, req)) + }) + + t.Run("handler returns false causes false return value", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + handler := RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }} + assert.False(t, CheckRequirements(context.Background(), handler, req)) + }) + + t.Run("handler returns true causes true return value", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + handler := RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }} + assert.True(t, CheckRequirements(context.Background(), handler, req)) + }) +} diff --git a/pkg/workflows/host/tee_provider.go b/pkg/workflows/host/tee_provider.go new file mode 100644 index 0000000000..90f2e839cd --- /dev/null +++ b/pkg/workflows/host/tee_provider.go @@ -0,0 +1,59 @@ +package host + +import sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + +type teeProvider struct { + sdkpb.TeeType + regions map[string]bool +} + +func NewTeeProvider(tpe sdkpb.TeeType, regions []string) func(tee *sdkpb.Tee) bool { + supportedRegions := map[string]bool{} + for _, region := range regions { + supportedRegions[region] = true + } + return (&teeProvider{TeeType: tpe, regions: supportedRegions}).Provides +} + +func (t *teeProvider) Provides(tee *sdkpb.Tee) bool { + if tee == nil { + return true + } + + var regions []string + switch teet := tee.Item.(type) { + case *sdkpb.Tee_AnyRegions: + regions = teet.AnyRegions.Regions + case *sdkpb.Tee_TeeTypesAndRegions: + if teet.TeeTypesAndRegions == nil { + return false + } + + found := false + for _, tr := range teet.TeeTypesAndRegions.TeeTypeAndRegions { + if tr.Type == t.TeeType { + found = true + regions = tr.Regions + break + } + } + + if !found { + return false + } + default: + return false + } + + if len(regions) == 0 { + return true + } + + for _, region := range regions { + if t.regions[region] { + return true + } + } + + return false +} diff --git a/pkg/workflows/host/tee_provider_test.go b/pkg/workflows/host/tee_provider_test.go new file mode 100644 index 0000000000..569bf138c9 --- /dev/null +++ b/pkg/workflows/host/tee_provider_test.go @@ -0,0 +1,171 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func TestNewTeeProvider(t *testing.T) { + t.Parallel() + t.Run("matches any", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, p.Provides(tee)) + }) + + t.Run("matches type selection with matching region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(99)}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("does not match different types", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType(99)} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches type and region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("matches type but not region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches one of multiple requested regions", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"eu-west-1": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2", "eu-west-1"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("provider has multiple regions and one matches", func(t *testing.T) { + p := teeProvider{ + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + } + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-east-1"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("no matching region across multiple provider regions", func(t *testing.T) { + p := teeProvider{ + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + } + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"ap-southeast-1"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("type mismatch ignores region match", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType(99), regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches any tee", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, provides(tee)) + }) + + t.Run("returns a function that checks regions", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }, + }, + }} + assert.False(t, provides(tee)) + }) + + t.Run("returns false when tee item is nil", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{} + assert.False(t, provides(tee)) + }) + + t.Run("AnyRegions with empty region list returns false", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{}}} + assert.True(t, provides(tee)) + }) + + t.Run("TeeTypesAndRegions with empty region list returns true", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, + }}} + assert.True(t, provides(tee)) + }) + + t.Run("TeeTypesAndRegions with nil regions returns true", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: nil}, + }, + }}} + assert.True(t, provides(tee)) + }) +} diff --git a/pkg/workflows/host/tee_selection_provider.go b/pkg/workflows/host/tee_selection_provider.go new file mode 100644 index 0000000000..80b675f425 --- /dev/null +++ b/pkg/workflows/host/tee_selection_provider.go @@ -0,0 +1,53 @@ +package host + +import ( + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func NewProviderFromSelection(types []*sdkpb.TeeTypeAndRegions) func(tee *sdkpb.Tee) bool { + if len(types) == 1 { + return NewTeeProvider(types[0].Type, types[0].Regions) + } + + supplies := make(map[sdkpb.TeeType][]string) + for _, t := range types { + supplies[t.Type] = append(supplies[t.Type], t.Regions...) + } + + providers := make(map[sdkpb.TeeType]func(tee *sdkpb.Tee) bool) + for k, v := range supplies { + providers[k] = NewTeeProvider(k, v) + } + + return func(tee *sdkpb.Tee) bool { + if tee == nil { + return true + } + + switch teet := tee.Item.(type) { + case *sdkpb.Tee_AnyRegions: + for _, provider := range providers { + if provider(tee) { + return true + } + } + + return false + case *sdkpb.Tee_TeeTypesAndRegions: + for _, requestedType := range teet.TeeTypesAndRegions.TeeTypeAndRegions { + provider, ok := providers[requestedType.Type] + if !ok { + continue + } + + if provider(tee) { + return true + } + } + + return false + default: + return false + } + } +} diff --git a/pkg/workflows/host/tee_selection_provider_test.go b/pkg/workflows/host/tee_selection_provider_test.go new file mode 100644 index 0000000000..bb68ea55c6 --- /dev/null +++ b/pkg/workflows/host/tee_selection_provider_test.go @@ -0,0 +1,242 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func TestNewProviderFromSelection(t *testing.T) { + t.Parallel() + + t.Run("returns false for nil selection", func(t *testing.T) { + provider := NewProviderFromSelection(nil) + assert.False(t, provider(&sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}})) + }) + + t.Run("single type selection delegates to tee provider", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types support any tee", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types merges regions for same type", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }) + + regions := []string{"eu-west-1"} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: regions, + }}}}} + assert.True(t, provider(tee)) + regions[0] = "us-west-2" + assert.True(t, provider(tee)) + }) + + t.Run("multiple types returns false when requested type is not supplied", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType(999), + Regions: []string{"us-west-2"}, + }}}}} + assert.False(t, provider(tee)) + }) + + t.Run("returns false for unsupported tee shape", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + }}) + assert.False(t, provider(&sdkpb.Tee{})) + }) + + t.Run("multi-type AnyRegions returns false when no provider matches", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + // AnyRegions with a region that doesn't match any provider's regions + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"ap-southeast-1"}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with nil TeeTypesAndRegions returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: nil}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions returns false when all requested types not in providers", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + // Request types that don't exist in providers + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(999), Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(888), Regions: []string{"eu-west-1"}}, + }, + }}} + assert.False(t, provider(tee)) + }) + + t.Run("single type TeeTypesAndRegions with non-matching region returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"eu-west-1"}, + }}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions partial match skips non-providers", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + // Request both AWS_NITRO and an unknown type; AWS_NITRO should match + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("single type returns directly without closure for AnyRegions", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("single type returns false for non-matching AnyRegions", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with empty TeeTypeAndRegions array returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{}, + }}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with empty Regions and matching TEE type array returns true", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}}, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("multi-type AnyRegions with no region returns true", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types with no regions in first type", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with first type not matching then match on second", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(888), Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("unsupported item type in multi-type scenario", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions all types not in providers with continue path", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(555), Regions: []string{"us-west-2"}}, + }) + + // Request a type that is never in providers - forces continue on every iteration + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(777), Regions: []string{"us-west-2"}}, + }, + }}} + assert.False(t, provider(tee)) + }) +} diff --git a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go index dfdad8114c..e18a5fcd59 100644 --- a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go +++ b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go @@ -58,6 +58,7 @@ func SendError(err error) { func SendSubscription(subscriptions *sdk.TriggerSubscriptionRequest) { execResult := &sdk.ExecutionResult{Result: &sdk.ExecutionResult_TriggerSubscriptions{TriggerSubscriptions: subscriptions}} sendResponse(BufferToPointerLen(Must(proto.Marshal(execResult)))) + os.Exit(0) } func Now() time.Time { diff --git a/pkg/workflows/wasm/host/mocks/module_v2.go b/pkg/workflows/wasm/host/mocks/module_v2.go index 4c84a3b4ae..dbcb7aa0c1 100644 --- a/pkg/workflows/wasm/host/mocks/module_v2.go +++ b/pkg/workflows/wasm/host/mocks/module_v2.go @@ -1,207 +1,20 @@ -// Code generated by mockery v2.53.3. DO NOT EDIT. - package mocks import ( - context "context" + "github.com/stretchr/testify/mock" - host "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" - sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" - mock "github.com/stretchr/testify/mock" + hostmocks "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" ) -// ModuleV2 is an autogenerated mock type for the ModuleV2 type -type ModuleV2 struct { - mock.Mock -} - -type ModuleV2_Expecter struct { - mock *mock.Mock -} - -func (_m *ModuleV2) EXPECT() *ModuleV2_Expecter { - return &ModuleV2_Expecter{mock: &_m.Mock} -} - -// Close provides a mock function with no fields -func (_m *ModuleV2) Close() { - _m.Called() -} - -// ModuleV2_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' -type ModuleV2_Close_Call struct { - *mock.Call -} - -// Close is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) Close() *ModuleV2_Close_Call { - return &ModuleV2_Close_Call{Call: _e.mock.On("Close")} -} - -func (_c *ModuleV2_Close_Call) Run(run func()) *ModuleV2_Close_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_Close_Call) Return() *ModuleV2_Close_Call { - _c.Call.Return() - return _c -} - -func (_c *ModuleV2_Close_Call) RunAndReturn(run func()) *ModuleV2_Close_Call { - _c.Run(run) - return _c -} - -// Execute provides a mock function with given fields: ctx, request, handler -func (_m *ModuleV2) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper) (*sdk.ExecutionResult, error) { - ret := _m.Called(ctx, request, handler) - - if len(ret) == 0 { - panic("no return value specified for Execute") - } - - var r0 *sdk.ExecutionResult - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)); ok { - return rf(ctx, request, handler) - } - if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) *sdk.ExecutionResult); ok { - r0 = rf(ctx, request, handler) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*sdk.ExecutionResult) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) error); ok { - r1 = rf(ctx, request, handler) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ModuleV2_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' -type ModuleV2_Execute_Call struct { - *mock.Call -} - -// Execute is a helper method to define mock.On call -// - ctx context.Context -// - request *sdk.ExecuteRequest -// - handler host.ExecutionHelper -func (_e *ModuleV2_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *ModuleV2_Execute_Call { - return &ModuleV2_Execute_Call{Call: _e.mock.On("Execute", ctx, request, handler)} -} - -func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper)) *ModuleV2_Execute_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*sdk.ExecuteRequest), args[2].(host.ExecutionHelper)) - }) - return _c -} - -func (_c *ModuleV2_Execute_Call) Return(_a0 *sdk.ExecutionResult, _a1 error) *ModuleV2_Execute_Call { - _c.Call.Return(_a0, _a1) - return _c -} +// ModuleV2 is a backward-compatible alias for hostmocks.Module. +// The ModuleV2 interface now lives in pkg/workflows/host as Module; +// this alias keeps existing consumers compiling without changes. +type ModuleV2 = hostmocks.Module -func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)) *ModuleV2_Execute_Call { - _c.Call.Return(run) - return _c -} - -// IsLegacyDAG provides a mock function with no fields -func (_m *ModuleV2) IsLegacyDAG() bool { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for IsLegacyDAG") - } - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -// ModuleV2_IsLegacyDAG_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsLegacyDAG' -type ModuleV2_IsLegacyDAG_Call struct { - *mock.Call -} - -// IsLegacyDAG is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) IsLegacyDAG() *ModuleV2_IsLegacyDAG_Call { - return &ModuleV2_IsLegacyDAG_Call{Call: _e.mock.On("IsLegacyDAG")} -} - -func (_c *ModuleV2_IsLegacyDAG_Call) Run(run func()) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_IsLegacyDAG_Call) Return(_a0 bool) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *ModuleV2_IsLegacyDAG_Call) RunAndReturn(run func() bool) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Return(run) - return _c -} - -// Start provides a mock function with no fields -func (_m *ModuleV2) Start() { - _m.Called() -} - -// ModuleV2_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type ModuleV2_Start_Call struct { - *mock.Call -} - -// Start is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) Start() *ModuleV2_Start_Call { - return &ModuleV2_Start_Call{Call: _e.mock.On("Start")} -} - -func (_c *ModuleV2_Start_Call) Run(run func()) *ModuleV2_Start_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_Start_Call) Return() *ModuleV2_Start_Call { - _c.Call.Return() - return _c -} - -func (_c *ModuleV2_Start_Call) RunAndReturn(run func()) *ModuleV2_Start_Call { - _c.Run(run) - return _c -} - -// NewModuleV2 creates a new instance of ModuleV2. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. +// NewModuleV2 creates a new instance of ModuleV2 (alias for hostmocks.NewModule). func NewModuleV2(t interface { mock.TestingT Cleanup(func()) }) *ModuleV2 { - mock := &ModuleV2{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock + return hostmocks.NewModule(t) } diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index cce932515a..a896aee0b0 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -21,6 +21,8 @@ import ( "github.com/bytecodealliance/wasmtime-go/v28" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -32,7 +34,6 @@ import ( wasmdagpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" "github.com/smartcontractkit/chainlink-protos/cre/go/values" - wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) const v2ImportPrefix = "version_v2" @@ -109,11 +110,7 @@ type ModuleConfig struct { Determinism *DeterminismConfig } -type ModuleBase interface { - Start() - Close() - IsLegacyDAG() bool -} +type ModuleBase = host.ModuleBase type ModuleV1 interface { ModuleBase @@ -122,29 +119,9 @@ type ModuleV1 interface { Run(ctx context.Context, request *wasmdagpb.Request) (*wasmdagpb.Response, error) } -type ModuleV2 interface { - ModuleBase - - // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution - Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) -} - -// ExecutionHelper Implemented by those running the host, for example the Workflow Engine -type ExecutionHelper interface { - // CallCapability blocking call to the Workflow Engine - CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) - GetSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) +type ModuleV2 = host.Module - GetWorkflowExecutionID() string - - GetNodeTime() time.Time - - GetDONTime() (time.Time, error) - - EmitUserLog(log string) error - - EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error -} +type ExecutionHelper = host.ExecutionHelper type module struct { engine *wasmtime.Engine diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index 7aa4343621..bfa1818576 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" @@ -622,7 +624,7 @@ func Test_SdkLabeler(t *testing.T) { // CallAwaitRace validates that every call can be awaited. func Test_CallAwaitRace(t *testing.T) { ctx := t.Context() - mockExecHelper := NewMockExecutionHelper(t) + mockExecHelper := mocks.NewMockExecutionHelper(t) mockExecHelper.EXPECT(). CallCapability(matches.AnyContext, mock.Anything). Return(&sdkpb.CapabilityResponse{}, nil) diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index 426d343cdd..f67545821b 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -25,6 +25,8 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" @@ -47,7 +49,7 @@ func init() { func TestStandardConfig(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Some languages call time during initiation of the executable before the main is called. // This would be in unknown mode, which would call Node mode by default. @@ -63,7 +65,7 @@ func TestStandardConfig(t *testing.T) { func TestStandardErrors(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -77,7 +79,7 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { // To ensure the calls are actually async, the mock will block the first call until the second call is made. // The first call sets InputThing to true, the second to false. t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -120,7 +122,7 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() }).Maybe() @@ -152,7 +154,7 @@ func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { func TestStandardModeSwitch(t *testing.T) { t.Parallel() t.Run("successful mode switch", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Node calls may occur on initialization depending on the language. var donCall bool @@ -192,7 +194,7 @@ func TestStandardModeSwitch(t *testing.T) { }) t.Run("node runtime in don mode", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -217,7 +219,7 @@ func TestStandardModeSwitch(t *testing.T) { }) t.Run("don runtime in node mode", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -252,7 +254,7 @@ func TestStandardModeSwitch(t *testing.T) { func TestStandardLogging(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -272,7 +274,7 @@ func TestStandardMultipleTriggers(t *testing.T) { t.Parallel() m := makeTestModule(t) t.Run("test registration", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -324,7 +326,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("first callback", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -338,7 +340,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("same trigger as first one but different registration", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -351,7 +353,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("different capability callback", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -370,7 +372,7 @@ func TestStandardRandom(t *testing.T) { // Test binary executes node mode code conditionally based on the value >= 100 anyId := "Id" - gte100Exec := NewMockExecutionHelper(t) + gte100Exec := mocks.NewMockExecutionHelper(t) gte100Exec.EXPECT().GetWorkflowExecutionID().Return(anyId) gte100Exec.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -400,7 +402,7 @@ func TestStandardRandom(t *testing.T) { value1 := executeWithResult[any](t, m, anyRequest, gte100Exec) t.Run("Same execution id gives the same randoms even if random is called in node mode", func(t *testing.T) { - lt100Exec := NewMockExecutionHelper(t) + lt100Exec := mocks.NewMockExecutionHelper(t) lt100Exec.EXPECT().GetWorkflowExecutionID().Return(anyId) lt100Exec.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -424,7 +426,7 @@ func TestStandardRandom(t *testing.T) { t.Run("Different execution id give different randoms", func(t *testing.T) { require.NoError(t, err) - gte100Exec2 := NewMockExecutionHelper(t) + gte100Exec2 := mocks.NewMockExecutionHelper(t) gte100Exec2.EXPECT().GetWorkflowExecutionID().Return("differentId") gte100Exec2.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -470,7 +472,7 @@ func TestStandardSecrets(t *testing.T) { } func TestStandardSecretsFailInNodeMode(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -504,7 +506,7 @@ func TestStandardSecretsFailInNodeMode(t *testing.T) { func TestStandardTimeInterpretation(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Inject fixed timestamp: 1577934245000 milliseconds = 2020-01-02T03:04:05Z fixedTime := time.UnixMilli(1577934245000) @@ -523,6 +525,62 @@ func TestStandardTimeInterpretation(t *testing.T) { require.Equal(t, "2020-01-02T03:04:05Z", result) } +func TestStandardTeeRuntime(t *testing.T) { + t.Parallel() + + cfg := defaultNoDAGModCfg(t) + m := makeTestModuleWithConfig(t, cfg) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() + + subscribe := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}} + actual, err := m.Execute(t.Context(), subscribe, mockExecutionHelper) + require.NoError(t, err) + + payload0, err := anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + }) + require.NoError(t, err) + + payload1, err := anypb.New(&basictrigger.Config{ + Name: "second-trigger", + Number: 200, + }) + require.NoError(t, err) + + expected := &sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: payload0, + Method: "Trigger", + Requirements: &sdk.Requirements{ + Tee: &sdk.Tee{ + Item: &sdk.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdk.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{ + {Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }, + }, + }, + }, + { + Id: "basic-test-trigger@1.0.0", + Payload: payload1, + Method: "Trigger", + }, + }, + } + + assertProto(t, expected, actual.GetTriggerSubscriptions()) +} + func triggerExecuteRequest(t *testing.T, id uint64, trigger proto.Message) *sdk.ExecuteRequest { wrappedTrigger, err := anypb.New(trigger) require.NoError(t, err) @@ -549,8 +607,12 @@ func runWithBasicTrigger(t *testing.T, executor ExecutionHelper) *sdk.ExecutionR // To re-use a binary, an outer test can create the module and use t.Run to run subtests using that module. // When subtests have their own binaries, those binaries are expected to be nested in a subfolder. func makeTestModule(t *testing.T) *module { + return makeTestModuleWithConfig(t, nil) +} + +func makeTestModuleWithConfig(t *testing.T, cfg *ModuleConfig) *module { testName := strcase.ToSnake(t.Name()[len("TestStandard"):]) - return makeTestModuleByName(t, testName, nil) + return makeTestModuleByName(t, testName, cfg) } func makeTestModuleByName(t *testing.T, testName string, cfg *ModuleConfig) *module { @@ -637,6 +699,7 @@ func wrapValue(t *testing.T, nodeResponse *nodeaction.NodeOutputs) *valuespb.Val func assertProto[T proto.Message](t *testing.T, expected, actual T) { t.Helper() + require.NotNil(t, actual) diff := cmp.Diff(expected, actual, protocmp.Transform()) var sb strings.Builder @@ -649,7 +712,7 @@ func assertProto[T proto.Message](t *testing.T, expected, actual T) { } func runSecretTest(t *testing.T, m *module, secretResponse *sdk.SecretResponse) *sdk.ExecutionResult { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() diff --git a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go new file mode 100644 index 0000000000..a6fdfec748 --- /dev/null +++ b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go @@ -0,0 +1,36 @@ +package main + +import ( + "google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func main() { + requirements := &sdk.Requirements{Tee: &sdk.Tee{Item: &sdk.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdk.TeeTypesAndRegions{TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}}} + subscription := &sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + })), + Method: "Trigger", + Requirements: requirements, + }, + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "second-trigger", + Number: 200, + })), + Method: "Trigger", + }, + }, + } + + rawsdk.SendSubscription(subscription) +} diff --git a/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go b/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go new file mode 100644 index 0000000000..bd31aa32c8 --- /dev/null +++ b/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go @@ -0,0 +1,13 @@ +package main + +import ( + "unsafe" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" +) + +func main() { + buf := make([]byte, 4) + rawsdk.Requirements(unsafe.Pointer(&buf[0]), 100) + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go b/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go new file mode 100644 index 0000000000..1df7e1da8f --- /dev/null +++ b/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" +) + +func main() { + rawsdk.Requirements(rawsdk.BufferToPointerLen([]byte{0x3E, 0x80, 0xFF, 0x0A, 0xFF, 0x01, 0x0C, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01})) + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/time_test.go b/pkg/workflows/wasm/host/time_test.go index b792588c99..7f7139b230 100644 --- a/pkg/workflows/wasm/host/time_test.go +++ b/pkg/workflows/wasm/host/time_test.go @@ -8,13 +8,14 @@ import ( "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) func TestTimeFetcher_GetTime_NODE(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) expected := time.Now() mockExec.EXPECT().GetNodeTime().Return(expected) @@ -29,7 +30,7 @@ func TestTimeFetcher_GetTime_NODE(t *testing.T) { func TestTimeFetcher_GetTime_DON(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) expected := time.Now() mockExec.EXPECT().GetDONTime().Return(expected, nil) @@ -44,7 +45,7 @@ func TestTimeFetcher_GetTime_DON(t *testing.T) { func TestTimeFetcher_GetTime_DON_Error(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Return(time.Time{}, errors.New("don error")) tf := newTimeFetcher(ctx, mockExec) @@ -58,7 +59,7 @@ func TestTimeFetcher_ContextCancelledBeforeRequest(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) cancel() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Return(time.Time{}, context.Canceled).Maybe() tf := newTimeFetcher(ctx, mockExec) @@ -81,7 +82,7 @@ func TestTimeFetcher_ContextCancelledDuringResponse(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Run(func() { time.Sleep(20 * time.Millisecond) // force timeout }).Return(time.Time{}, nil) diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index d7461b0e7a..5a2972671f 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -13,6 +13,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -44,7 +45,7 @@ func Test_Sleep_Timeout(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -77,7 +78,7 @@ func Test_Execute_CtxTimeout(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -109,7 +110,7 @@ func Test_Execute_CtxTimeout(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -139,7 +140,7 @@ func Test_Execute_CtxTimeout(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -198,7 +199,7 @@ func Test_NoDag_Run(t *testing.T) { func Test_NoDAG_LoggingWithLimits(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -240,7 +241,7 @@ func Test_NoDAG_LoggingWithLimits(t *testing.T) { func Test_NoDAG_EmitMetricWithLimits(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -286,7 +287,7 @@ func Test_NoDAG_EmitMetricWithLimits(t *testing.T) { func Test_NoDAG_EmitMetricDisabled(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -321,7 +322,7 @@ func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { } func getTriggersSpec(t *testing.T, m ModuleV2, config []byte) (*sdk.TriggerSubscriptionRequest, error) { - helper := NewMockExecutionHelper(t) + helper := mocks.NewMockExecutionHelper(t) helper.EXPECT().GetWorkflowExecutionID().Return("Id") helper.EXPECT().GetNodeTime().Return(time.Now()).Maybe() execResult, err := m.Execute(t.Context(), &sdk.ExecuteRequest{