Skip to content

Commit 5ca4151

Browse files
committed
Implement SourceUnitUnmarshaller on the S3 source using the new S3SourceUnit, and add a test covering resumption across multiple buckets with concurrent ChunkUnit processing
1 parent 1863659 commit 5ca4151

File tree

5 files changed

+153
-19
lines changed

5 files changed

+153
-19
lines changed

pkg/sources/s3/s3.go

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package s3
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"slices"
67
"strings"
@@ -54,8 +55,6 @@ type Source struct {
5455
errorCount *sync.Map
5556
jobPool *errgroup.Group
5657
maxObjectSize int64
57-
58-
sources.CommonSourceUnitUnmarshaller
5958
}
6059

6160
// Ensure the Source satisfies the interfaces at compile time
@@ -699,22 +698,6 @@ func makeS3Link(bucket, region, key string) string {
699698
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucket, region, key)
700699
}
701700

702-
type S3SourceUnit struct {
703-
Bucket string
704-
Role string
705-
}
706-
707-
func (s S3SourceUnit) SourceUnitID() (string, sources.SourceUnitKind) {
708-
// The ID is the bucket name, and the kind is "s3_bucket".
709-
return s.Bucket, "s3_bucket"
710-
}
711-
712-
func (s S3SourceUnit) Display() string {
713-
return s.Bucket
714-
}
715-
716-
var _ sources.SourceUnit = S3SourceUnit{}
717-
718701
// Enumerate implements SourceUnitEnumerator interface. This implementation visits
719702
// each configured role and passes each s3 bucket as a source unit
720703
func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
@@ -771,3 +754,15 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
771754
s.scanBucket(ctx, defaultClient, s3unit.Role, bucket, reporter, startAfterPtr)
772755
return nil
773756
}
757+
758+
func (s *Source) UnmarshalSourceUnit(data []byte) (sources.SourceUnit, error) {
759+
var unit S3SourceUnit
760+
if err := json.Unmarshal(data, &unit); err != nil {
761+
return nil, err
762+
}
763+
bucket, kind := unit.SourceUnitID()
764+
if bucket == "" || kind != SourceUnitKindBucket {
765+
return nil, fmt.Errorf("not an S3SourceUnit")
766+
}
767+
return unit, nil
768+
}

pkg/sources/s3/s3_integration_test.go

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
1717
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
1818
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
19+
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
1920
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
2021
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
2122
"github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest"
@@ -394,6 +395,8 @@ func TestSourceChunksResumptionMultipleBucketsIgnoredBucket(t *testing.T) {
394395
}
395396

396397
func TestSource_Enumerate(t *testing.T) {
398+
t.Parallel()
399+
397400
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
398401
defer cancel()
399402

@@ -438,6 +441,8 @@ func TestSource_Enumerate(t *testing.T) {
438441
}
439442

440443
func TestSource_ChunkUnit(t *testing.T) {
444+
t.Parallel()
445+
441446
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
442447
defer cancel()
443448

@@ -484,7 +489,7 @@ func TestSource_ChunkUnit(t *testing.T) {
484489
func TestSource_ChunkUnit_Resumption(t *testing.T) {
485490
t.Parallel()
486491

487-
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Second)
492+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
488493
defer cancel()
489494

490495
s := new(Source)
@@ -517,3 +522,65 @@ func TestSource_ChunkUnit_Resumption(t *testing.T) {
517522
// Verify that we processed all remaining data on resume.
518523
assert.Equal(t, 9638, len(reporter.Chunks), "Should have processed all remaining data on resume")
519524
}
525+
526+
// TestSource_ChunkUnit_Resumption_MultipleBucketsConcurrent tests resumption across multiple buckets
527+
// with concurrent ChunkUnit processing.
528+
func TestSource_ChunkUnit_Resumption_MultipleBucketsConcurrent(t *testing.T) {
529+
t.Parallel()
530+
531+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
532+
defer cancel()
533+
534+
src := new(Source)
535+
connection := &sourcespb.S3{
536+
Credential: &sourcespb.S3_Unauthenticated{},
537+
Buckets: []string{"integration-resumption-tests", "trufflesec-ahrav-test-2"},
538+
EnableResumption: true,
539+
}
540+
conn, err := anypb.New(connection)
541+
require.NoError(t, err)
542+
err = src.Init(ctx, "test name", 0, 0, false, conn, 2)
543+
require.NoError(t, err)
544+
545+
src.Progress = sources.Progress{
546+
Message: "Buckets: [integration-resumption-tests trufflesec-ahrav-test-2]",
547+
EncodedResumeInfo: "{\"integration-resumption-tests\":\"test-dir\", \"trufflesec-ahrav-test-2\":\"test-dir/smmed_random_data.json.zip\"}",
548+
SectionsCompleted: 0,
549+
SectionsRemaining: 2,
550+
}
551+
552+
reporter := sourcestest.SafeTestReporter{}
553+
err = src.Enumerate(ctx, &reporter)
554+
require.NoError(t, err)
555+
require.Equal(t, 2, len(reporter.Units), "Expected two source units from enumeration")
556+
557+
var wg sync.WaitGroup
558+
559+
for _, unit := range reporter.Units {
560+
wg.Add(1)
561+
go func() {
562+
defer wg.Done()
563+
err = src.ChunkUnit(ctx, unit, &reporter)
564+
assert.NoError(t, err, "Expected no error during ChunkUnit")
565+
}()
566+
}
567+
568+
wg.Wait()
569+
570+
bucketChunkCounts := map[string]int{
571+
"integration-resumption-tests": 9638,
572+
"trufflesec-ahrav-test-2": 2,
573+
}
574+
actualBucketChunkCounts := make(map[string]int)
575+
576+
for _, chunk := range reporter.Chunks {
577+
metadata, _ := chunk.SourceMetadata.Data.(*source_metadatapb.MetaData_S3)
578+
actualBucketChunkCounts[metadata.S3.Bucket]++
579+
}
580+
581+
for bucket, wantCount := range bucketChunkCounts {
582+
gotCount, ok := actualBucketChunkCounts[bucket]
583+
require.True(t, ok, "Expected chunks for bucket %s", bucket)
584+
assert.Equal(t, wantCount, gotCount, "Chunk count mismatch for bucket %s", bucket)
585+
}
586+
}

pkg/sources/s3/s3_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/kylelemons/godebug/pretty"
1111
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
1213
"google.golang.org/protobuf/types/known/anypb"
1314

1415
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
@@ -163,3 +164,21 @@ func TestSource_Chunks(t *testing.T) {
163164
})
164165
}
165166
}
167+
168+
func TestSource_UnmarshalSourceUnit(t *testing.T) {
169+
s := Source{}
170+
171+
unitJSON := `{
172+
"Bucket": "my-test-bucket",
173+
"Role": "my-test-role"
174+
}`
175+
176+
unit, err := s.UnmarshalSourceUnit([]byte(unitJSON))
177+
require.NoError(t, err, "UnmarshalSourceUnit should not return an error")
178+
179+
s3Unit, ok := unit.(S3SourceUnit)
180+
require.True(t, ok, "Unmarshaled unit should be of type S3SourceUnit")
181+
182+
assert.Equal(t, "my-test-bucket", s3Unit.Bucket, "Bucket field should match")
183+
assert.Equal(t, "my-test-role", s3Unit.Role, "Role field should match")
184+
}

pkg/sources/s3/unit.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package s3
2+
3+
import (
4+
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
5+
)
6+
7+
// SourceUnitKindBucket is the SourceUnitKind reported by S3SourceUnit, whose
// unit ID is a bucket name.
const SourceUnitKindBucket sources.SourceUnitKind = "bucket"
8+
9+
// S3SourceUnit is the unit of work for the S3 source: a single bucket together
// with the role associated with it.
type S3SourceUnit struct {
	// Bucket is the bucket name; it doubles as the unit's ID and display name.
	Bucket string
	// Role is the role associated with the bucket.
	// NOTE(review): presumably the role used to access the bucket - confirm
	// against the enumeration code.
	Role string
}
13+
14+
// Ensure S3SourceUnit satisfies the sources.SourceUnit interface at compile time.
var _ sources.SourceUnit = S3SourceUnit{}
15+
16+
func (s S3SourceUnit) SourceUnitID() (string, sources.SourceUnitKind) {
17+
// The ID is the bucket name, and the kind is "bucket".
18+
return s.Bucket, SourceUnitKindBucket
19+
}
20+
21+
func (s S3SourceUnit) Display() string {
22+
return s.Bucket
23+
}

pkg/sourcestest/sourcestest.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sourcestest
22

33
import (
44
"fmt"
5+
"sync"
56

67
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
78
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
@@ -59,3 +60,32 @@ func (ErrReporter) ChunkOk(context.Context, sources.Chunk) error {
5960
func (ErrReporter) ChunkErr(context.Context, error) error {
6061
return fmt.Errorf("ErrReporter: ChunkErr error")
6162
}
63+
64+
// SafeTestReporter is a helper struct that implements both UnitReporter and
65+
// ChunkReporter by recording the values passed in the methods with thread safety.
66+
type SafeTestReporter struct {
67+
TestReporter
68+
69+
mu sync.Mutex
70+
}
71+
72+
func (t *SafeTestReporter) UnitOk(_ context.Context, unit sources.SourceUnit) error {
73+
t.mu.Lock()
74+
defer t.mu.Unlock()
75+
return t.TestReporter.UnitOk(nil, unit)
76+
}
77+
func (t *SafeTestReporter) UnitErr(_ context.Context, err error) error {
78+
t.mu.Lock()
79+
defer t.mu.Unlock()
80+
return t.TestReporter.UnitErr(nil, err)
81+
}
82+
func (t *SafeTestReporter) ChunkOk(_ context.Context, chunk sources.Chunk) error {
83+
t.mu.Lock()
84+
defer t.mu.Unlock()
85+
return t.TestReporter.ChunkOk(nil, chunk)
86+
}
87+
func (t *SafeTestReporter) ChunkErr(_ context.Context, err error) error {
88+
t.mu.Lock()
89+
defer t.mu.Unlock()
90+
return t.TestReporter.ChunkErr(nil, err)
91+
}

0 commit comments

Comments
 (0)