From 1ed02f2f0acb62d0c41fc8f8a25a59213fa3e076 Mon Sep 17 00:00:00 2001 From: Swarit Pandey Date: Thu, 14 May 2026 11:01:56 +0530 Subject: [PATCH] feat(aiagents): poll backend for hook enable/disable state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new internal/aiagents/state package and wires it into the scheduled telemetry tick so a UI toggle on the agent-api side converges to local install/uninstall on the next run. - state package owns the cache file ~/.stepsecurity/hooks-state.json, the HTTP fetcher against /developer-mdm-agent/features, and the Reconciler that ties fetch → cache write → idempotent install or uninstall together. - _hook hot path reads the cache before any work and short-circuits to the allow response when disabled. Missing or unparseable cache reads as enabled, so first-run after install keeps working. - main.go runs the reconciler after telemetry.Run in send-telemetry and install paths; community mode (no enterprise config) is a silent no-op. No agent-api changes needed: the existing feature key ai_agents_hooks_install and the GET /developer-mdm-agent/features endpoint already serve the resolved state. --- cmd/stepsecurity-dev-machine-guard/main.go | 52 ++++++++ internal/aiagents/hook/runtime.go | 11 ++ internal/aiagents/hook/runtime_test.go | 44 +++++++ internal/aiagents/state/cache.go | 126 ++++++++++++++++++++ internal/aiagents/state/cache_test.go | 122 +++++++++++++++++++ internal/aiagents/state/doc.go | 20 ++++ internal/aiagents/state/fetcher.go | 132 +++++++++++++++++++++ internal/aiagents/state/fetcher_test.go | 116 ++++++++++++++++++ internal/aiagents/state/reconciler.go | 88 ++++++++++++++ internal/aiagents/state/reconciler_test.go | 125 +++++++++++++++++++ internal/aiagents/state/state.go | 39 ++++++ internal/aiagents/state/state_test.go | 12 ++ 12 files changed, 887 insertions(+) create mode 100644 internal/aiagents/state/cache.go create mode 100644 internal/aiagents/state/cache_test.go create mode 100644 internal/aiagents/state/doc.go create mode 100644 internal/aiagents/state/fetcher.go create mode 100644 internal/aiagents/state/fetcher_test.go create mode 100644 internal/aiagents/state/reconciler.go create mode 100644 internal/aiagents/state/reconciler_test.go create mode 100644 internal/aiagents/state/state.go create mode 100644 internal/aiagents/state/state_test.go diff --git a/cmd/stepsecurity-dev-machine-guard/main.go b/cmd/stepsecurity-dev-machine-guard/main.go index 47f9848..91fd067 100644 --- a/cmd/stepsecurity-dev-machine-guard/main.go +++ b/cmd/stepsecurity-dev-machine-guard/main.go @@ -7,8 +7,11 @@ import ( "io" "os" "runtime" + "time" aiagentscli "github.com/step-security/dev-machine-guard/internal/aiagents/cli" + "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" + "github.com/step-security/dev-machine-guard/internal/aiagents/state" "github.com/step-security/dev-machine-guard/internal/buildinfo" "github.com/step-security/dev-machine-guard/internal/cli" "github.com/step-security/dev-machine-guard/internal/config" @@ -24,6 +27,12 @@ import ( "github.com/step-security/dev-machine-guard/internal/telemetry" ) +// hookReconcileTimeout caps the entire reconcile step (fetch + cache +// write + install/uninstall). Generous because install can chown a +// handful of files under root; the actual GET cost is bounded by +// state.DefaultFetchTimeout. +const hookReconcileTimeout = 30 * time.Second + func main() { // Hook hot path. Agents invoke `_hook` on every event and any non-zero // exit is treated as a hook failure / block — so we MUST exit 0 even on @@ -121,6 +130,7 @@ func main() { log.Error("%v", err) os.Exit(1) } + runHookStateReconcile(exec, log) case "install": _, _ = fmt.Fprintf(os.Stdout, "StepSecurity Dev Machine Guard v%s\n\n", buildinfo.Version) @@ -168,6 +178,7 @@ func main() { log.Error("%v", telemetryErr) os.Exit(1) } + runHookStateReconcile(exec, log) case "uninstall": _, _ = fmt.Fprintf(os.Stdout, "StepSecurity Dev Machine Guard v%s\n\n", buildinfo.Version) @@ -299,3 +310,44 @@ func scanJSONEncoder(w io.Writer) *json.Encoder { enc.SetEscapeHTML(false) return enc } + +// runHookStateReconcile polls agent-api for the desired AI-agent hook +// state and reconciles local hook installation to match. Silent no-op +// in community mode (enterprise config missing) — the existing scan +// path stays unaffected. Failures are logged but never crash main. +func runHookStateReconcile(exec executor.Executor, log *progress.Logger) { + cfg, ok := ingest.Snapshot() + if !ok { + log.Debug("hook-state reconcile: skipped (no enterprise config)") + return + } + fetcher, ok := state.NewHTTPFetcher(cfg, nil) + if !ok { + log.Debug("hook-state reconcile: skipped (fetcher init refused config)") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), hookReconcileTimeout) + defer cancel() + + dev := device.Gather(ctx, exec) + if dev.SerialNumber == "" || dev.SerialNumber == "unknown" { + log.Warn("hook-state reconcile: device serial unresolved; skipping") + return + } + + r := &state.Reconciler{ + Exec: exec, + Fetcher: fetcher, + CustomerID: cfg.CustomerID, + DeviceID: dev.SerialNumber, + Stdout: os.Stdout, + Stderr: os.Stderr, + InstallFn: aiagentscli.RunInstall, + UninstallFn: aiagentscli.RunUninstall, + } + if err := r.Reconcile(ctx); err != nil { + log.Warn("hook-state reconcile: %v", err) + aiagentscli.AppendError("reconcile", "reconcile_failed", err.Error(), "") + } +} diff --git a/internal/aiagents/hook/runtime.go b/internal/aiagents/hook/runtime.go index 3c07995..57ce2b0 100644 --- a/internal/aiagents/hook/runtime.go +++ b/internal/aiagents/hook/runtime.go @@ -27,6 +27,7 @@ import ( "github.com/step-security/dev-machine-guard/internal/aiagents/identity" "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" "github.com/step-security/dev-machine-guard/internal/aiagents/redact" + "github.com/step-security/dev-machine-guard/internal/aiagents/state" "github.com/step-security/dev-machine-guard/internal/executor" ) @@ -115,6 +116,16 @@ func (rt *Runtime) Run(parent context.Context, hookType event.HookEvent) error { var ev *event.Event defer func() { rt.emitDecidedResponse(ev, decision) }() + // Server-driven kill switch. The reconciler writes ~/.stepsecurity/ + // hooks-state.json on every telemetry tick; a UI flip propagates + // to disabled here in O(1) microseconds and short-circuits the + // hot path before any enrichment, identity probe, or upload runs. + // Missing/corrupt cache returns Default()=enabled so first-run + // after install continues to work. + if cur, _ := state.Read(); !cur.Hooks.Enabled { + return nil + } + cfg, _ := ingest.Snapshot() id := identity.Resolve(ctx, rt.Exec, cfg.CustomerID) upload := rt.resolveUpload() diff --git a/internal/aiagents/hook/runtime_test.go b/internal/aiagents/hook/runtime_test.go index 0e240bf..449e6f6 100644 --- a/internal/aiagents/hook/runtime_test.go +++ b/internal/aiagents/hook/runtime_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "io" + "path/filepath" "strings" "sync" "testing" @@ -13,6 +14,7 @@ import ( cc "github.com/step-security/dev-machine-guard/internal/aiagents/adapter/claudecode" "github.com/step-security/dev-machine-guard/internal/aiagents/event" + "github.com/step-security/dev-machine-guard/internal/aiagents/state" "github.com/step-security/dev-machine-guard/internal/executor" ) @@ -420,6 +422,48 @@ func TestRunUploadFailureFailsOpen(t *testing.T) { // When no UploadEvent is wired (any runtime without enterprise config), // the runtime must still complete — just with no upload attempt. +// withDisabledStateCache writes a disabled state file and restores +// the override on cleanup. Returns the path for assertions. +func withDisabledStateCache(t *testing.T) { + t.Helper() + dir := t.TempDir() + path := dir + string(filepath.Separator) + state.CacheFilename + restore := state.SetCachePathForTest(path) + t.Cleanup(restore) + s := state.Default() + s.Hooks.Enabled = false + s.Source = state.SourcePoll + if err := state.Write(s); err != nil { + t.Fatalf("seed disabled cache: %v", err) + } +} + +func TestRunHonorsDisabledStateCache(t *testing.T) { + withDisabledStateCache(t) + + stdin := strings.NewReader(`{ + "session_id":"abc","cwd":"/tmp","tool_name":"Bash", + "tool_input":{"command":"npm install lodash","cwd":"/tmp"} + }`) + var stdout, stderr bytes.Buffer + rt, cap := newRuntime(t, stdin, &stdout, &stderr) + + if err := rt.Run(context.Background(), event.HookPreToolUse); err != nil { + t.Fatalf("Run: %v", err) + } + // Allow response still emitted (fail-open contract). + if !strings.HasPrefix(strings.TrimSpace(stdout.String()), "{") { + t.Errorf("stdout should still be the allow JSON: %q", stdout.String()) + } + // No upload, no enrichment, no error log lines. + if len(cap.events) != 0 { + t.Fatalf("disabled cache must short-circuit upload; got %d events", len(cap.events)) + } + if len(cap.errs) != 0 { + t.Fatalf("disabled cache must not log errors; got %v", cap.errs) + } +} + func TestRunSkipsUploadWithoutSeam(t *testing.T) { stdin := strings.NewReader(`{ "session_id":"s","cwd":"/tmp","tool_name":"Bash","tool_input":{"command":"ls"} diff --git a/internal/aiagents/state/cache.go b/internal/aiagents/state/cache.go new file mode 100644 index 0000000..8e6e886 --- /dev/null +++ b/internal/aiagents/state/cache.go @@ -0,0 +1,126 @@ +package state + +import ( + "encoding/json" + "os" + "path/filepath" +) + +// CacheFilename is the basename of the cache file. Lives under +// ~/.stepsecurity/ alongside config.json and ai-agent-hook-errors.jsonl. +const CacheFilename = "hooks-state.json" + +const ( + cacheFileMode os.FileMode = 0o600 + cacheParentDirMode os.FileMode = 0o700 +) + +// cachePathOverride lets tests redirect reads/writes to a tempdir. +// Production leaves it empty. Mutating from outside this package is +// a test-only concern; same pattern as cli.errorLogPathOverride. +var cachePathOverride string + +// SetCachePathForTest redirects CachePath() to the given absolute path +// and returns a restore function. Test-only; production code never +// calls this. Living on the package surface (rather than as a +// build-tagged file) keeps cross-package tests in hook/* and main_test +// able to drive the override without an internal-import trick. +func SetCachePathForTest(p string) (restore func()) { + prev := cachePathOverride + cachePathOverride = p + return func() { cachePathOverride = prev } +} + +// CachePath returns the absolute cache path, honoring the test +// override when set. +func CachePath() string { + if cachePathOverride != "" { + return cachePathOverride + } + home, err := os.UserHomeDir() + if err != nil || home == "" { + return "" + } + return filepath.Join(home, ".stepsecurity", CacheFilename) +} + +// Read returns (state, true) on a successful parse. Any I/O or parse +// error returns (Default(), false) — Read never surfaces an error +// because the hot path must remain fail-open. +func Read() (State, bool) { + path := CachePath() + if path == "" { + return Default(), false + } + // #nosec G304 -- path is CachePath(): either a test override set by + // SetCachePathForTest, or os.UserHomeDir() joined with the package + // constant CacheFilename. Never derived from external input. + b, err := os.ReadFile(path) + if err != nil { + return Default(), false + } + var s State + if err := json.Unmarshal(b, &s); err != nil { + return Default(), false + } + if s.SchemaVersion == 0 { + // Forward-compat tolerance: missing schema_version reads as the + // current version. A future breaking change would gate on a + // specific value here. + s.SchemaVersion = SchemaVersion + } + return s, true +} + +// Write atomically replaces the cache file. No backups are kept — the +// cache is rewritten on every reconcile tick, and orphaned backups +// would accumulate trash. Parent dir is created with 0o700. +func Write(s State) error { + path := CachePath() + if path == "" { + return errNoHomeDir + } + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return err + } + data = append(data, '\n') + + parent := filepath.Dir(path) + if err := os.MkdirAll(parent, cacheParentDirMode); err != nil { + return err + } + + tmp, err := os.CreateTemp(parent, "."+CacheFilename+".tmp-*") + if err != nil { + return err + } + tmpPath := tmp.Name() + defer func() { + if _, statErr := os.Stat(tmpPath); statErr == nil { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + if err := os.Chmod(tmpPath, cacheFileMode); err != nil { + return err + } + return os.Rename(tmpPath, path) +} + +type cacheError string + +func (e cacheError) Error() string { return string(e) } + +const errNoHomeDir = cacheError("state: cannot resolve home directory") diff --git a/internal/aiagents/state/cache_test.go b/internal/aiagents/state/cache_test.go new file mode 100644 index 0000000..59d8760 --- /dev/null +++ b/internal/aiagents/state/cache_test.go @@ -0,0 +1,122 @@ +package state + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + "time" +) + +// withTempCache redirects CachePath to a tempdir for the duration of +// the test. Returns the absolute path the cache will be written to. +func withTempCache(t *testing.T) string { + t.Helper() + dir := t.TempDir() + p := filepath.Join(dir, CacheFilename) + prev := cachePathOverride + cachePathOverride = p + t.Cleanup(func() { cachePathOverride = prev }) + return p +} + +func TestReadMissingFileReturnsDefault(t *testing.T) { + withTempCache(t) + s, ok := Read() + if ok { + t.Fatal("Read of missing file should report ok=false") + } + if !s.Hooks.Enabled { + t.Fatal("missing-file Read must yield Default (enabled)") + } +} + +func TestWriteThenReadRoundTrip(t *testing.T) { + withTempCache(t) + in := State{ + SchemaVersion: SchemaVersion, + FetchedAt: time.Date(2026, 5, 14, 8, 0, 0, 0, time.UTC), + Source: SourcePoll, + Hooks: Hooks{Enabled: false}, + } + if err := Write(in); err != nil { + t.Fatalf("Write: %v", err) + } + out, ok := Read() + if !ok { + t.Fatal("Read after Write should report ok=true") + } + if out.Hooks.Enabled != false || out.Source != SourcePoll { + t.Fatalf("round-trip mismatch: %+v", out) + } + if !out.FetchedAt.Equal(in.FetchedAt) { + t.Fatalf("FetchedAt drift: got %v want %v", out.FetchedAt, in.FetchedAt) + } +} + +func TestReadMalformedReturnsDefault(t *testing.T) { + path := withTempCache(t) + if err := os.WriteFile(path, []byte("not json"), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + s, ok := Read() + if ok { + t.Fatal("malformed file should report ok=false") + } + if !s.Hooks.Enabled { + t.Fatal("malformed Read must yield Default (enabled)") + } +} + +func TestWriteFileMode(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("file mode bits not meaningful on Windows") + } + path := withTempCache(t) + if err := Write(Default()); err != nil { + t.Fatalf("Write: %v", err) + } + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if got := info.Mode().Perm(); got != cacheFileMode { + t.Fatalf("mode = %o, want %o", got, cacheFileMode) + } +} + +func TestWriteReplacesExistingFile(t *testing.T) { + path := withTempCache(t) + if err := os.WriteFile(path, []byte(`{"hooks":{"enabled":true}}`), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + next := Default() + next.Hooks.Enabled = false + if err := Write(next); err != nil { + t.Fatalf("Write: %v", err) + } + out, ok := Read() + if !ok || out.Hooks.Enabled { + t.Fatalf("expected disabled after rewrite, got %+v (ok=%v)", out, ok) + } +} + +func TestReadForwardCompatMissingSchemaVersion(t *testing.T) { + path := withTempCache(t) + raw := map[string]any{"hooks": map[string]any{"enabled": false}} + b, _ := json.Marshal(raw) + if err := os.WriteFile(path, b, 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + out, ok := Read() + if !ok { + t.Fatal("legacy-shape file should still parse") + } + if out.SchemaVersion != SchemaVersion { + t.Fatalf("schema_version should default to %d, got %d", SchemaVersion, out.SchemaVersion) + } + if out.Hooks.Enabled { + t.Fatal("legacy disabled value should round-trip") + } +} diff --git a/internal/aiagents/state/doc.go b/internal/aiagents/state/doc.go new file mode 100644 index 0000000..ae3fde6 --- /dev/null +++ b/internal/aiagents/state/doc.go @@ -0,0 +1,20 @@ +// Package state owns the server-driven hook enable/disable cache. +// +// Flow: +// +// scheduled tick / install ──▶ Reconciler.Reconcile +// │ +// ├─ Fetcher.Fetch (GET /developer-mdm-agent/features) +// ├─ cache.Write (~/.stepsecurity/hooks-state.json) +// └─ InstallFn / UninstallFn (idempotent) +// +// _hook hot path ──▶ cache.Read ──▶ short-circuit to allow if disabled +// +// The cache file is the single source of truth for the hot path. Both +// the polling reconciler (this package) and any future WebSocket +// transport are expected to converge on the same file, so the hot path +// never has to know which transport is active. +// +// Defaults: cache missing or unparseable ⇒ Default() (enabled). Hot path +// is fail-open by contract; a corrupt cache must not break the agent. +package state diff --git a/internal/aiagents/state/fetcher.go b/internal/aiagents/state/fetcher.go new file mode 100644 index 0000000..532f5e4 --- /dev/null +++ b/internal/aiagents/state/fetcher.go @@ -0,0 +1,132 @@ +package state + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" + "github.com/step-security/dev-machine-guard/internal/aiagents/redact" + "github.com/step-security/dev-machine-guard/internal/buildinfo" +) + +// FeatureKeyHooks mirrors agent-api's DeveloperMDMFeatureAIAgentsHooksInstall. +// Constant in both repos must match for the toggle to plumb through. +const FeatureKeyHooks = "ai_agents_hooks_install" + +// DefaultFetchTimeout caps a single Fetch round-trip. Matches +// ingest.DefaultHookUploadTimeout for consistency with the existing +// HTTP timeout budget; the reconciler runs off the scheduled tick, +// not the hot path, so a 5s ceiling is comfortable. +const DefaultFetchTimeout = 5 * time.Second + +// maxBodyBytes bounds the response read to avoid memory bloat from a +// pathological backend. The real payload is < 1 KiB; 64 KiB is plenty +// of slack. +const maxBodyBytes = 64 * 1024 + +// FetchResult is what one successful Fetch resolves to. Today there is +// only a single toggle; future fields land here as the contract grows. +type FetchResult struct { + Enabled bool +} + +// Fetcher returns the desired feature state for one device. +type Fetcher interface { + Fetch(ctx context.Context, customerID, deviceID string) (FetchResult, error) +} + +// HTTPFetcher is the production Fetcher. Safe for concurrent use; the +// underlying *http.Client owns connection state. +type HTTPFetcher struct { + endpoint string + apiKey string + http *http.Client +} + +// NewHTTPFetcher constructs a Fetcher from a strict enterprise config +// (the same gate the upload path uses). Returns ok=false when the +// config is incomplete — the caller treats nil as "skip reconcile", +// matching how upload disables itself in community mode. +func NewHTTPFetcher(cfg ingest.Config, h *http.Client) (*HTTPFetcher, bool) { + endpoint := strings.TrimSpace(cfg.APIEndpoint) + apiKey := strings.TrimSpace(cfg.APIKey) + if endpoint == "" || apiKey == "" { + return nil, false + } + if h == nil { + h = &http.Client{Timeout: DefaultFetchTimeout} + } + return &HTTPFetcher{ + endpoint: strings.TrimRight(endpoint, "/"), + apiKey: apiKey, + http: h, + }, true +} + +// featuresEnvelope is the agent-api response shape: +// +// {"features": {"ai_agents_hooks_install": {"enabled": bool}, ...}} +type featuresEnvelope struct { + Features map[string]featureState `json:"features"` +} + +type featureState struct { + Enabled bool `json:"enabled"` +} + +// Fetch issues the GET against /developer-mdm-agent/features. Missing +// feature key in the response ⇒ Enabled=false (matches server-side +// baseline-disabled default; backend keeps unset features out of the +// map rather than emitting a zero entry). +func (c *HTTPFetcher) Fetch(ctx context.Context, customerID, deviceID string) (FetchResult, error) { + if c == nil { + return FetchResult{}, errors.New("state: nil fetcher") + } + if strings.TrimSpace(customerID) == "" { + return FetchResult{}, errors.New("state: empty customer_id") + } + if strings.TrimSpace(deviceID) == "" { + return FetchResult{}, errors.New("state: empty device_id") + } + + endpoint := c.endpoint + + "/v1/" + url.PathEscape(customerID) + + "/developer-mdm-agent/features?device_id=" + url.QueryEscape(deviceID) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return FetchResult{}, fmt.Errorf("state: build request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "dmg/"+buildinfo.Version) + + resp, err := c.http.Do(req) + if err != nil { + return FetchResult{}, fmt.Errorf("state: transport: %s", redact.String(err.Error())) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes)) + return FetchResult{}, fmt.Errorf("state: unexpected status %d: %s", + resp.StatusCode, redact.String(strings.TrimSpace(string(snippet)))) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes)) + if err != nil { + return FetchResult{}, fmt.Errorf("state: read body: %w", err) + } + var env featuresEnvelope + if err := json.Unmarshal(body, &env); err != nil { + return FetchResult{}, fmt.Errorf("state: decode body: %w", err) + } + return FetchResult{Enabled: env.Features[FeatureKeyHooks].Enabled}, nil +} diff --git a/internal/aiagents/state/fetcher_test.go b/internal/aiagents/state/fetcher_test.go new file mode 100644 index 0000000..2e837b0 --- /dev/null +++ b/internal/aiagents/state/fetcher_test.go @@ -0,0 +1,116 @@ +package state + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/step-security/dev-machine-guard/internal/aiagents/ingest" +) + +func newTestServer(t *testing.T, status int, body string) (*httptest.Server, *HTTPFetcher) { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization = %q, want Bearer test-key", got) + } + if got := r.URL.Query().Get("device_id"); got != "dev-1" { + t.Errorf("device_id = %q, want dev-1", got) + } + if !strings.Contains(r.URL.Path, "/developer-mdm-agent/features") { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _, _ = w.Write([]byte(body)) + })) + t.Cleanup(srv.Close) + + f, ok := NewHTTPFetcher(ingest.Config{APIEndpoint: srv.URL, APIKey: "test-key"}, srv.Client()) + if !ok { + t.Fatal("NewHTTPFetcher returned ok=false on valid config") + } + return srv, f +} + +func TestFetcherEnabled(t *testing.T) { + _, f := newTestServer(t, 200, `{"features":{"ai_agents_hooks_install":{"enabled":true}}}`) + res, err := f.Fetch(context.Background(), "cust", "dev-1") + if err != nil { + t.Fatalf("Fetch: %v", err) + } + if !res.Enabled { + t.Fatal("Enabled should be true") + } +} + +func TestFetcherDisabled(t *testing.T) { + _, f := newTestServer(t, 200, `{"features":{"ai_agents_hooks_install":{"enabled":false}}}`) + res, err := f.Fetch(context.Background(), "cust", "dev-1") + if err != nil { + t.Fatalf("Fetch: %v", err) + } + if res.Enabled { + t.Fatal("Enabled should be false") + } +} + +func TestFetcherMissingKeyMeansDisabled(t *testing.T) { + // Server omits the key entirely: the agent-api baseline-disabled + // default keeps features out of the map until a customer-level + // override exists. + _, f := newTestServer(t, 200, `{"features":{}}`) + res, err := f.Fetch(context.Background(), "cust", "dev-1") + if err != nil { + t.Fatalf("Fetch: %v", err) + } + if res.Enabled { + t.Fatal("missing feature key must read as disabled") + } +} + +func TestFetcherNon200IsError(t *testing.T) { + _, f := newTestServer(t, 500, `{"error":"boom"}`) + if _, err := f.Fetch(context.Background(), "cust", "dev-1"); err == nil { + t.Fatal("5xx should propagate as error") + } +} + +func TestFetcherUnauthorizedIsError(t *testing.T) { + _, f := newTestServer(t, 401, `{"error":"unauth"}`) + if _, err := f.Fetch(context.Background(), "cust", "dev-1"); err == nil { + t.Fatal("401 should propagate as error") + } +} + +func TestFetcherMalformedBodyIsError(t *testing.T) { + _, f := newTestServer(t, 200, `not json`) + if _, err := f.Fetch(context.Background(), "cust", "dev-1"); err == nil { + t.Fatal("malformed body should propagate as error") + } +} + +func TestFetcherEmptyCustomerIDIsError(t *testing.T) { + _, f := newTestServer(t, 200, `{"features":{}}`) + if _, err := f.Fetch(context.Background(), "", "dev-1"); err == nil { + t.Fatal("empty customer should error") + } +} + +func TestFetcherEmptyDeviceIDIsError(t *testing.T) { + _, f := newTestServer(t, 200, `{"features":{}}`) + if _, err := f.Fetch(context.Background(), "cust", ""); err == nil { + t.Fatal("empty device should error") + } +} + +func TestNewHTTPFetcherRejectsIncompleteConfig(t *testing.T) { + if _, ok := NewHTTPFetcher(ingest.Config{APIEndpoint: "", APIKey: "k"}, nil); ok { + t.Fatal("missing endpoint should yield ok=false") + } + if _, ok := NewHTTPFetcher(ingest.Config{APIEndpoint: "https://x", APIKey: ""}, nil); ok { + t.Fatal("missing api key should yield ok=false") + } +} diff --git a/internal/aiagents/state/reconciler.go b/internal/aiagents/state/reconciler.go new file mode 100644 index 0000000..03cfaf5 --- /dev/null +++ b/internal/aiagents/state/reconciler.go @@ -0,0 +1,88 @@ +package state + +import ( + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/step-security/dev-machine-guard/internal/executor" +) + +// HookCommandFn is the install/uninstall seam shape. Production wires +// these to internal/aiagents/cli.RunInstall and .RunUninstall in +// main.go (state can't import cli without a cycle, so the seam stays +// a plain function type). +type HookCommandFn func(ctx context.Context, exec executor.Executor, agent string, stdout, stderr io.Writer) int + +// Reconciler turns a desired enable/disable into local actions. One +// instance per main.go call site; the struct holds the wiring and no +// per-call state. +type Reconciler struct { + Exec executor.Executor + Fetcher Fetcher + CustomerID string + DeviceID string + Agent string // "" = every detected agent + Stdout io.Writer + Stderr io.Writer + InstallFn HookCommandFn + UninstallFn HookCommandFn + Now func() time.Time +} + +// Reconcile fetches desired state, writes the cache, and converges +// settings to match by calling InstallFn / UninstallFn. Both are +// idempotent so we don't need to detect the current state — install +// is a no-op when entries are already in place, uninstall is a no-op +// when no DMG-owned entries exist. +// +// Order: cache write first, then settings reconciliation. If the +// settings reconciliation fails, the cache still reflects the desired +// state — the hot path honors the new value immediately, and the next +// tick retries the settings change. +// +// Errors are returned to the caller for logging via cli.AppendError; +// Reconcile itself never panics into the caller's hot path. +func (r *Reconciler) Reconcile(ctx context.Context) error { + if r.Fetcher == nil { + return errors.New("state: nil fetcher") + } + + res, err := r.Fetcher.Fetch(ctx, r.CustomerID, r.DeviceID) + if err != nil { + return fmt.Errorf("state: fetch: %w", err) + } + + now := time.Now().UTC + if r.Now != nil { + now = r.Now + } + next := Default() + next.FetchedAt = now() + next.Source = SourcePoll + next.Hooks.Enabled = res.Enabled + + if err := Write(next); err != nil { + return fmt.Errorf("state: write cache: %w", err) + } + + switch { + case res.Enabled: + if r.InstallFn == nil { + return errors.New("state: nil InstallFn") + } + if code := r.InstallFn(ctx, r.Exec, r.Agent, r.Stdout, r.Stderr); code != 0 { + return fmt.Errorf("state: install exited %d", code) + } + default: + if r.UninstallFn == nil { + return errors.New("state: nil UninstallFn") + } + if code := r.UninstallFn(ctx, r.Exec, r.Agent, r.Stdout, r.Stderr); code != 0 { + return fmt.Errorf("state: uninstall exited %d", code) + } + } + return nil +} diff --git a/internal/aiagents/state/reconciler_test.go b/internal/aiagents/state/reconciler_test.go new file mode 100644 index 0000000..7439c6e --- /dev/null +++ b/internal/aiagents/state/reconciler_test.go @@ -0,0 +1,125 @@ +package state + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "github.com/step-security/dev-machine-guard/internal/executor" +) + +type fakeFetcher struct { + res FetchResult + err error +} + +func (f *fakeFetcher) Fetch(_ context.Context, _, _ string) (FetchResult, error) { + return f.res, f.err +} + +type callRec struct { + calls []string + codes []int + exit int +} + +func (r *callRec) fn(name string) HookCommandFn { + return func(_ context.Context, _ executor.Executor, _ string, _, _ io.Writer) int { + r.calls = append(r.calls, name) + r.codes = append(r.codes, r.exit) + return r.exit + } +} + +func newReconciler(t *testing.T, fetch FetchResult, fetchErr error, exitCode int) (*Reconciler, *callRec) { + t.Helper() + withTempCache(t) + rec := &callRec{exit: exitCode} + return &Reconciler{ + Exec: executor.NewMock(), + Fetcher: &fakeFetcher{res: fetch, err: fetchErr}, + CustomerID: "cust", + DeviceID: "dev-1", + Stdout: io.Discard, + Stderr: io.Discard, + InstallFn: rec.fn("install"), + UninstallFn: rec.fn("uninstall"), + Now: func() time.Time { return time.Date(2026, 5, 14, 8, 0, 0, 0, time.UTC) }, + }, rec +} + +func TestReconcileEnabledCallsInstallAndWritesCache(t *testing.T) { + r, rec := newReconciler(t, FetchResult{Enabled: true}, nil, 0) + if err := r.Reconcile(context.Background()); err != nil { + t.Fatalf("Reconcile: %v", err) + } + if len(rec.calls) != 1 || rec.calls[0] != "install" { + t.Fatalf("calls = %v, want [install]", rec.calls) + } + s, ok := Read() + if !ok { + t.Fatal("cache should be written") + } + if !s.Hooks.Enabled || s.Source != SourcePoll { + t.Fatalf("cache = %+v", s) + } +} + +func TestReconcileDisabledCallsUninstall(t *testing.T) { + r, rec := newReconciler(t, FetchResult{Enabled: false}, nil, 0) + if err := r.Reconcile(context.Background()); err != nil { + t.Fatalf("Reconcile: %v", err) + } + if len(rec.calls) != 1 || rec.calls[0] != "uninstall" { + t.Fatalf("calls = %v, want [uninstall]", rec.calls) + } + s, _ := Read() + if s.Hooks.Enabled { + t.Fatal("cache should record disabled") + } +} + +func TestReconcileFetchErrorPreservesCache(t *testing.T) { + r, rec := newReconciler(t, FetchResult{}, errors.New("network down"), 0) + // Seed prior cache so we can verify it's untouched. + prior := Default() + prior.Hooks.Enabled = false + prior.Source = SourcePoll + if err := Write(prior); err != nil { + t.Fatalf("seed: %v", err) + } + + if err := r.Reconcile(context.Background()); err == nil { + t.Fatal("Reconcile should surface fetch error") + } + if len(rec.calls) != 0 { + t.Fatalf("no install/uninstall on fetch error; got %v", rec.calls) + } + s, ok := Read() + if !ok || s.Hooks.Enabled || s.Source != SourcePoll { + t.Fatalf("cache should be untouched, got %+v ok=%v", s, ok) + } +} + +func TestReconcileInstallFailureSurfacesError(t *testing.T) { + r, _ := newReconciler(t, FetchResult{Enabled: true}, nil, 1) + err := r.Reconcile(context.Background()) + if err == nil { + t.Fatal("non-zero install exit should surface as error") + } + // Cache should still be written — settings retry is the next tick's job. + s, ok := Read() + if !ok || !s.Hooks.Enabled { + t.Fatalf("cache should still reflect desired state, got %+v ok=%v", s, ok) + } +} + +func TestReconcileNilFetcherIsError(t *testing.T) { + withTempCache(t) + r := &Reconciler{} + if err := r.Reconcile(context.Background()); err == nil { + t.Fatal("nil fetcher should error") + } +} diff --git a/internal/aiagents/state/state.go b/internal/aiagents/state/state.go new file mode 100644 index 0000000..4781c5c --- /dev/null +++ b/internal/aiagents/state/state.go @@ -0,0 +1,39 @@ +package state + +import "time" + +// SchemaVersion is the wire/disk version of the cache file. Bump only +// on a breaking shape change; older daemons keep parsing v1. +const SchemaVersion = 1 + +// Source values record where the cache write came from. Diagnostic +// only — the hot path never branches on Source. +const ( + SourcePoll = "poll" + SourceManual = "manual" + SourceInstall = "install" + SourceWebsocket = "websocket" // reserved +) + +// State is the on-disk cache shape. JSON keys are the wire format. +type State struct { + SchemaVersion int `json:"schema_version"` + FetchedAt time.Time `json:"fetched_at"` + Source string `json:"source,omitempty"` + Hooks Hooks `json:"hooks"` +} + +// Hooks carries the feature toggles the hot path reads. Today there's +// only Enabled; per-agent granularity goes here when the contract +// grows. +type Hooks struct { + Enabled bool `json:"enabled"` +} + +// Default is the in-memory fallback for any read failure. Enabled is +// true: the hot path only runs because settings carry a DMG entry, so +// defaulting to disabled would silently turn off first-run after +// install (before the reconciler has had a chance to write the cache). +func Default() State { + return State{SchemaVersion: SchemaVersion, Hooks: Hooks{Enabled: true}} +} diff --git a/internal/aiagents/state/state_test.go b/internal/aiagents/state/state_test.go new file mode 100644 index 0000000..cacc30b --- /dev/null +++ b/internal/aiagents/state/state_test.go @@ -0,0 +1,12 @@ +package state + +import "testing" + +func TestDefaultIsEnabled(t *testing.T) { + if !Default().Hooks.Enabled { + t.Fatal("Default() must be enabled; otherwise first-run after install breaks") + } + if Default().SchemaVersion != SchemaVersion { + t.Fatalf("Default schema_version = %d, want %d", Default().SchemaVersion, SchemaVersion) + } +}