Skip to content

Commit 5b664d5

Browse files
leodidoona-agent
andcommitted
refactor: reuse extractTarToDir in extractTar
- Add filepath.Clean() to extractTarToDir for stricter path traversal check - Reimplement extractTar as a thin wrapper around extractTarToDir - Removes code duplication between the two functions Co-authored-by: Ona <[email protected]>
1 parent cba003b commit 5b664d5

File tree

1 file changed

+13
-56
lines changed

1 file changed

+13
-56
lines changed

pkg/leeway/container_image.go

Lines changed: 13 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -49,58 +49,58 @@ func extractImageWithOCILibsImpl(destDir, imgTag string) error {
4949
// The OCI tar is in the parent directory of destDir (the build directory)
5050
buildDir := filepath.Dir(filepath.Dir(destDir)) // destDir is buildDir/container/content
5151
ociTarPath := filepath.Join(buildDir, "image.tar")
52-
52+
5353
var img v1.Image
54-
54+
5555
if _, statErr := os.Stat(ociTarPath); statErr == nil {
5656
// OCI tar exists - extract and load from it
5757
log.WithField("ociTar", ociTarPath).Debug("Loading image from OCI tar file")
58-
58+
5959
// Create a temporary directory to extract the OCI layout
6060
ociLayoutDir, err := os.MkdirTemp(buildDir, "oci-layout-")
6161
if err != nil {
6262
return fmt.Errorf("creating temp dir for OCI layout: %w", err)
6363
}
6464
defer os.RemoveAll(ociLayoutDir)
65-
65+
6666
// Extract the OCI tar to the temporary directory
6767
if err := extractTar(ociTarPath, ociLayoutDir); err != nil {
6868
return fmt.Errorf("extracting OCI tar: %w", err)
6969
}
70-
70+
7171
// Load the image from the OCI layout directory
7272
layoutPath, err := layout.FromPath(ociLayoutDir)
7373
if err != nil {
7474
return fmt.Errorf("loading OCI layout from %s: %w", ociLayoutDir, err)
7575
}
76-
76+
7777
// Get the image index
7878
imageIndex, err := layoutPath.ImageIndex()
7979
if err != nil {
8080
return fmt.Errorf("getting image index from OCI layout: %w", err)
8181
}
82-
82+
8383
// Get the manifest
8484
indexManifest, err := imageIndex.IndexManifest()
8585
if err != nil {
8686
return fmt.Errorf("getting index manifest: %w", err)
8787
}
88-
88+
8989
if len(indexManifest.Manifests) == 0 {
9090
return fmt.Errorf("no manifests found in OCI layout")
9191
}
92-
92+
9393
// Get the first image (there should only be one for single-platform builds)
9494
img, err = layoutPath.Image(indexManifest.Manifests[0].Digest)
9595
if err != nil {
9696
return fmt.Errorf("getting image from OCI layout: %w", err)
9797
}
98-
98+
9999
log.Debug("Successfully loaded image from OCI tar")
100100
} else {
101101
// OCI tar doesn't exist - fall back to Docker daemon
102102
log.Debug("OCI tar not found, loading image from Docker daemon")
103-
103+
104104
ref, err := name.ParseReference(imgTag)
105105
if err != nil {
106106
return fmt.Errorf("parsing image reference: %w", err)
@@ -276,7 +276,7 @@ func extractTarToDir(r io.Reader, destDir string) error {
276276
target := filepath.Join(destDir, header.Name)
277277

278278
// Prevent directory traversal attacks
279-
if !strings.HasPrefix(target, destDir) {
279+
if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(destDir)) {
280280
continue
281281
}
282282

@@ -478,48 +478,5 @@ func extractTar(tarPath, destDir string) error {
478478
}
479479
defer file.Close()
480480

481-
tarReader := tar.NewReader(file)
482-
483-
for {
484-
header, err := tarReader.Next()
485-
if err == io.EOF {
486-
break
487-
}
488-
if err != nil {
489-
return fmt.Errorf("reading tar: %w", err)
490-
}
491-
492-
target := filepath.Join(destDir, header.Name)
493-
494-
// Ensure the target is within destDir (security check)
495-
if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(destDir)) {
496-
return fmt.Errorf("illegal file path in tar: %s", header.Name)
497-
}
498-
499-
switch header.Typeflag {
500-
case tar.TypeDir:
501-
if err := os.MkdirAll(target, 0755); err != nil {
502-
return fmt.Errorf("creating directory: %w", err)
503-
}
504-
case tar.TypeReg:
505-
// Create parent directories
506-
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
507-
return fmt.Errorf("creating parent directory: %w", err)
508-
}
509-
510-
// Create file
511-
outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
512-
if err != nil {
513-
return fmt.Errorf("creating file: %w", err)
514-
}
515-
516-
if _, err := io.Copy(outFile, tarReader); err != nil {
517-
outFile.Close()
518-
return fmt.Errorf("writing file: %w", err)
519-
}
520-
outFile.Close()
521-
}
522-
}
523-
524-
return nil
481+
return extractTarToDir(file, destDir)
525482
}

0 commit comments

Comments
 (0)