Skip to content
Closed
7 changes: 4 additions & 3 deletions pkg/cli/mcp_server_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cli

import (
"fmt"
"log"
"net/http"
"os"
"strings"
Expand All @@ -13,6 +12,8 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)

var mcpServerHTTPLog = logger.New("cli:mcp_server_http")

// sanitizeForLog removes newline and carriage return characters from user input
// to prevent log injection attacks where malicious users could forge log entries.
func sanitizeForLog(input string) string {
Expand All @@ -39,7 +40,7 @@ func loggingHandler(handler http.Handler) http.Handler {
sanitizedPath := sanitizeForLog(r.URL.Path)

// Log request details.
log.Printf("[REQUEST] %s | %s | %s %s",
mcpServerHTTPLog.Printf("[REQUEST] %s | %s | %s %s",
start.Format(time.RFC3339),
r.RemoteAddr,
r.Method,
Expand All @@ -50,7 +51,7 @@ func loggingHandler(handler http.Handler) http.Handler {

// Log response details.
duration := time.Since(start)
log.Printf("[RESPONSE] %s | %s | %s %s | Status: %d | Duration: %v",
mcpServerHTTPLog.Printf("[RESPONSE] %s | %s | %s %s | Status: %d | Duration: %v",
time.Now().Format(time.RFC3339),
r.RemoteAddr,
r.Method,
Expand Down
10 changes: 5 additions & 5 deletions pkg/envutil/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This package centralizes the pattern of reading integer-valued environment varia

## Public API

### `GetIntFromEnv(envVar string, defaultValue, minValue, maxValue int, log *logger.Logger) int`
### `GetIntFromEnv(envVar string, defaultValue, minValue, maxValue int, debugLog *logger.Logger) int`

Reads an integer-valued environment variable, validates it against `[minValue, maxValue]`, and returns `defaultValue` when the variable is absent, unparseable, or out of range. A warning is emitted to `os.Stderr` when the value is invalid.

Expand All @@ -18,7 +18,7 @@ Reads an integer-valued environment variable, validates it against `[minValue, m
| `defaultValue` | `int` | Value returned when env var is absent or invalid |
| `minValue` | `int` | Minimum allowed value (inclusive) |
| `maxValue` | `int` | Maximum allowed value (inclusive) |
| `log` | `*logger.Logger` | Optional logger for debug output; pass `nil` to disable |
| `debugLog` | `*logger.Logger` | Optional logger for debug output; pass `nil` to disable |

## Usage Examples

Expand All @@ -28,10 +28,10 @@ import (
"github.com/github/gh-aw/pkg/logger"
)

var log = logger.New("mypackage:config")
var debugLog = logger.New("mypackage:config")

// Read GH_AW_MAX_CONCURRENT_DOWNLOADS, constrained to [1, 20], default 5
concurrency := envutil.GetIntFromEnv("GH_AW_MAX_CONCURRENT_DOWNLOADS", 5, 1, 20, log)
concurrency := envutil.GetIntFromEnv("GH_AW_MAX_CONCURRENT_DOWNLOADS", 5, 1, 20, debugLog)

// Suppress debug output by passing nil logger
timeout := envutil.GetIntFromEnv("GH_AW_TIMEOUT", 60, 1, 3600, nil)
Expand All @@ -41,7 +41,7 @@ timeout := envutil.GetIntFromEnv("GH_AW_TIMEOUT", 60, 1, 3600, nil)
- Returns `defaultValue` when the environment variable is not set.
- Returns `defaultValue` and emits a warning when the value cannot be parsed as an integer.
- Returns `defaultValue` and emits a warning when the value is outside `[minValue, maxValue]`.
- Logs the accepted value at debug level when `log` is non-nil.
- Logs the accepted value at debug level when `debugLog` is non-nil.

## Dependencies

Expand Down
12 changes: 6 additions & 6 deletions pkg/envutil/envutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ import (
// - defaultValue: The default value to return if env var is not set or invalid
// - minValue: Minimum allowed value (inclusive)
// - maxValue: Maximum allowed value (inclusive)
// - log: Optional logger for debug output
// - debugLog: Optional logger for debug output
//
// Returns the parsed integer value, or defaultValue if:
// - Environment variable is not set
// - Value cannot be parsed as an integer
// - Value is outside the [minValue, maxValue] range
//
// Invalid values trigger warning messages to stderr, or through the logger if provided.
func GetIntFromEnv(envVar string, defaultValue, minValue, maxValue int, log *logger.Logger) int {
func GetIntFromEnv(envVar string, defaultValue, minValue, maxValue int, debugLog *logger.Logger) int {
warn := func(msg string) {
if log != nil {
log.Printf("WARNING: %s", msg)
if debugLog != nil {
debugLog.Printf("WARNING: %s", msg)
} else {
fmt.Fprintln(os.Stderr, console.FormatWarningMessage(msg))
}
Expand All @@ -52,8 +52,8 @@ func GetIntFromEnv(envVar string, defaultValue, minValue, maxValue int, log *log
return defaultValue
}

if log != nil {
log.Printf("Using %s=%d", envVar, val)
if debugLog != nil {
debugLog.Printf("Using %s=%d", envVar, val)
}
return val
}
24 changes: 12 additions & 12 deletions pkg/fileutil/fileutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/github/gh-aw/pkg/logger"
)

var log = logger.New("fileutil:fileutil")
var fileutilLog = logger.New("fileutil:fileutil")

// ValidateAbsolutePath validates that a file path is absolute and safe to use.
// It performs the following security checks:
Expand All @@ -37,7 +37,7 @@ var log = logger.New("fileutil:fileutil")
func ValidateAbsolutePath(path string) (string, error) {
// Check for empty path
if path == "" {
log.Print("ValidateAbsolutePath: rejected empty path")
fileutilLog.Print("ValidateAbsolutePath: rejected empty path")
return "", errors.New("path cannot be empty")
}

Expand All @@ -46,11 +46,11 @@ func ValidateAbsolutePath(path string) (string, error) {

// Verify the path is absolute to prevent relative path traversal
if !filepath.IsAbs(cleanPath) {
log.Printf("ValidateAbsolutePath: rejected relative path: %s", path)
fileutilLog.Printf("ValidateAbsolutePath: rejected relative path: %s", path)
return "", fmt.Errorf("path must be absolute, got: %s", path)
}

log.Printf("ValidateAbsolutePath: validated path: %s", cleanPath)
fileutilLog.Printf("ValidateAbsolutePath: validated path: %s", cleanPath)
return cleanPath, nil
}

Expand All @@ -63,7 +63,7 @@ func ValidateAbsolutePath(path string) (string, error) {
// - Either path cannot be resolved to an absolute form.
// - The resolved candidate path starts outside the resolved base directory.
func ValidatePathWithinBase(base, candidate string) error {
log.Printf("ValidatePathWithinBase: checking candidate=%q within base=%q", candidate, base)
fileutilLog.Printf("ValidatePathWithinBase: checking candidate=%q within base=%q", candidate, base)
// EvalSymlinks resolves both symlinks and ".." components.
// Fall back to Abs when a path does not exist on disk yet.
absBase, err := filepath.EvalSymlinks(base)
Expand All @@ -82,10 +82,10 @@ func ValidatePathWithinBase(base, candidate string) error {
}
rel, err := filepath.Rel(absBase, absCand)
if err != nil || !filepath.IsLocal(rel) {
log.Printf("ValidatePathWithinBase: path escape detected: candidate=%q base=%q", candidate, base)
fileutilLog.Printf("ValidatePathWithinBase: path escape detected: candidate=%q base=%q", candidate, base)
return fmt.Errorf("path %q escapes base directory %q", candidate, base)
}
log.Printf("ValidatePathWithinBase: path is safe: candidate=%q (rel=%s) within base=%q", candidate, rel, base)
fileutilLog.Printf("ValidatePathWithinBase: path is safe: candidate=%q (rel=%s) within base=%q", candidate, rel, base)
return nil
}

Expand Down Expand Up @@ -131,7 +131,7 @@ func copyFileContents(in io.Reader, out syncWriteCloser, dst string) (err error)
}
if removePartial {
if removeErr := os.Remove(dst); removeErr != nil {
log.Printf("Failed to remove partial destination file during cleanup: %s", removeErr)
fileutilLog.Printf("Failed to remove partial destination file during cleanup: %s", removeErr)
}
}
}()
Expand All @@ -146,24 +146,24 @@ func copyFileContents(in io.Reader, out syncWriteCloser, dst string) (err error)

// CopyFile copies a file from src to dst using buffered IO.
func CopyFile(src, dst string) error {
log.Printf("Copying file: src=%s, dst=%s", src, dst)
fileutilLog.Printf("Copying file: src=%s, dst=%s", src, dst)
in, err := os.Open(src)
if err != nil {
log.Printf("Failed to open source file: %s", err)
fileutilLog.Printf("Failed to open source file: %s", err)
return err
}
defer in.Close()

out, err := os.Create(dst)
if err != nil {
log.Printf("Failed to create destination file: %s", err)
fileutilLog.Printf("Failed to create destination file: %s", err)
return err
}
err = copyFileContents(in, out, dst)
if err != nil {
return err
}

log.Printf("File copied successfully: src=%s, dst=%s", src, dst)
fileutilLog.Printf("File copied successfully: src=%s, dst=%s", src, dst)
return nil
}
18 changes: 9 additions & 9 deletions pkg/gitutil/gitutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/github/gh-aw/pkg/logger"
)

var log = logger.New("gitutil:gitutil")
var gitutilLog = logger.New("gitutil:gitutil")
var ErrNotGitRepository = errors.New("not in a git repository")

var fullSHARegex = regexp.MustCompile(`^[0-9a-f]{40}$`)
Expand All @@ -30,7 +30,7 @@ func IsRateLimitError(errMsg string) bool {
// IsAuthError checks if an error message indicates an authentication issue.
// This is used to detect when GitHub API calls fail due to missing or invalid credentials.
func IsAuthError(errMsg string) bool {
log.Printf("Checking if error is auth-related: %s", errMsg)
gitutilLog.Printf("Checking if error is auth-related: %s", errMsg)
lowerMsg := strings.ToLower(errMsg)
isAuth := strings.Contains(lowerMsg, "gh_token") ||
strings.Contains(lowerMsg, "github_token") ||
Expand All @@ -41,7 +41,7 @@ func IsAuthError(errMsg string) bool {
strings.Contains(lowerMsg, "permission denied") ||
strings.Contains(lowerMsg, "saml enforcement")
if isAuth {
log.Print("Detected authentication error")
gitutilLog.Print("Detected authentication error")
}
return isAuth
}
Expand Down Expand Up @@ -83,21 +83,21 @@ func ExtractBaseRepo(repoPath string) string {
// environments where git is not on PATH.
// Returns an error if not in a git repository.
func FindGitRoot() (string, error) {
log.Print("Finding git root directory")
gitutilLog.Print("Finding git root directory")

dir, err := os.Getwd()
if err != nil {
log.Printf("Failed to get current directory: %v", err)
gitutilLog.Printf("Failed to get current directory: %v", err)
return "", fmt.Errorf("failed to get current directory: %w", err)
}

root, err := FindGitRootFrom(dir)
if err != nil {
log.Printf("Failed to find git root: %v", err)
gitutilLog.Printf("Failed to find git root: %v", err)
return "", err
}

log.Printf("Found git root: %s", root)
gitutilLog.Printf("Found git root: %s", root)
return root, nil
}

Expand Down Expand Up @@ -171,12 +171,12 @@ func ReadFileFromHEAD(filePath, gitRoot string) (string, error) {

relPath = filepath.ToSlash(relPath)

log.Printf("Reading %q from git HEAD (relative path: %s)", filePath, relPath)
gitutilLog.Printf("Reading %q from git HEAD (relative path: %s)", filePath, relPath)

cmd := exec.Command("git", "-C", gitRoot, "show", "HEAD:"+relPath)
output, err := cmd.Output()
if err != nil {
log.Printf("File %q not found in HEAD commit: %v", filePath, err)
gitutilLog.Printf("File %q not found in HEAD commit: %v", filePath, err)
return "", fmt.Errorf("file %q not found in HEAD commit: %w", filePath, err)
}
return string(output), nil
Expand Down
8 changes: 4 additions & 4 deletions pkg/repoutil/repoutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import (
"github.com/github/gh-aw/pkg/logger"
)

var log = logger.New("repoutil:repoutil")
var repoutilLog = logger.New("repoutil:repoutil")

// SplitRepoSlug splits a repository slug (owner/repo) into owner and repo parts.
// Returns an error if the slug format is invalid.
func SplitRepoSlug(slug string) (owner, repo string, err error) {
log.Printf("Splitting repo slug: %s", slug)
repoutilLog.Printf("Splitting repo slug: %s", slug)
parts := strings.Split(slug, "/")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
log.Printf("Invalid repo slug format: %s", slug)
repoutilLog.Printf("Invalid repo slug format: %s", slug)
return "", "", fmt.Errorf("invalid repo format: %s", slug)
}
log.Printf("Split result: owner=%s, repo=%s", parts[0], parts[1])
repoutilLog.Printf("Split result: owner=%s, repo=%s", parts[0], parts[1])
return parts[0], parts[1], nil
}
14 changes: 7 additions & 7 deletions pkg/semverutil/semverutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"golang.org/x/mod/semver"
)

var log = logger.New("semverutil:semverutil")
var semverutilLog = logger.New("semverutil:semverutil")

// actionVersionTagRegex matches version tags: vmajor, vmajor.minor, or vmajor.minor.patch.
// It intentionally excludes prerelease and build-metadata suffixes because GitHub Actions
Expand Down Expand Up @@ -60,11 +60,11 @@ func IsValid(ref string) bool {
// ParseVersion parses v into a SemanticVersion.
// It returns nil if v is not a valid semantic version string.
func ParseVersion(v string) *SemanticVersion {
log.Printf("Parsing semantic version: %s", v)
semverutilLog.Printf("Parsing semantic version: %s", v)
v = EnsureVPrefix(v)

if !semver.IsValid(v) {
log.Printf("Invalid semantic version: %s", v)
semverutilLog.Printf("Invalid semantic version: %s", v)
return nil
}

Expand Down Expand Up @@ -107,11 +107,11 @@ func Compare(v1, v2 string) int {
result := semver.Compare(v1, v2)

if result > 0 {
log.Printf("Version comparison result: %s > %s", v1, v2)
semverutilLog.Printf("Version comparison result: %s > %s", v1, v2)
} else if result < 0 {
log.Printf("Version comparison result: %s < %s", v1, v2)
semverutilLog.Printf("Version comparison result: %s < %s", v1, v2)
} else {
log.Printf("Version comparison result: %s == %s", v1, v2)
semverutilLog.Printf("Version comparison result: %s == %s", v1, v2)
}

return result
Expand Down Expand Up @@ -146,7 +146,7 @@ func IsCompatible(pinVersion, requestedVersion string) bool {
requestedMajor := semver.Major(requestedVersion)

compatible := pinMajor == requestedMajor
log.Printf("Checking semver compatibility: pin=%s (major=%s), requested=%s (major=%s) -> %v",
semverutilLog.Printf("Checking semver compatibility: pin=%s (major=%s), requested=%s (major=%s) -> %v",
pinVersion, pinMajor, requestedVersion, requestedMajor, compatible)

return compatible
Expand Down
8 changes: 4 additions & 4 deletions pkg/typeutil/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
"github.com/github/gh-aw/pkg/logger"
)

var log = logger.New("typeutil:convert")
var typeutilConvertLog = logger.New("typeutil:convert")

// ParseIntValue strictly parses numeric types (int, int64, uint64, float64) to int,
// returning (value, true) on success and (0, false) for any unrecognized or
Expand All @@ -55,15 +55,15 @@ func ParseIntValue(value any) (int, bool) {
// Check for overflow before converting uint64 to int
const maxInt = int(^uint(0) >> 1)
if v > uint64(maxInt) {
log.Printf("uint64 value %d exceeds max int value, returning 0", v)
typeutilConvertLog.Printf("uint64 value %d exceeds max int value, returning 0", v)
return 0, false
}
return int(v), true
case float64:
intVal := int(v)
// Warn if truncation occurs (value has fractional part)
if v != float64(intVal) {
log.Printf("Float value %.2f truncated to integer %d", v, intVal)
typeutilConvertLog.Printf("Float value %.2f truncated to integer %d", v, intVal)
}
return intVal, true
default:
Expand Down Expand Up @@ -102,7 +102,7 @@ func ConvertToInt(val any) int {
intVal := int(v)
// Warn if truncation occurs (value has fractional part)
if v != float64(intVal) {
log.Printf("Float value %.2f truncated to integer %d", v, intVal)
typeutilConvertLog.Printf("Float value %.2f truncated to integer %d", v, intVal)
}
return intVal
case string:
Expand Down
Loading
Loading