Merged
Changes from 17 commits (of 26 total)

Commits
bdc0d26
implemented Source unit for S3.
mustansir14 Jun 26, 2025
21d9334
use bucket as source unit
mustansir14 Jul 7, 2025
a25afff
remove code duplication, reuse from Chunks
mustansir14 Nov 19, 2025
ea21d02
remove unnecessary change
mustansir14 Nov 19, 2025
9915187
remove unused functions
mustansir14 Nov 19, 2025
c32e12c
revisit tests
mustansir14 Nov 19, 2025
5161090
revert unnecessary change
mustansir14 Nov 19, 2025
ef324d1
change SourceUnitKind to s3_bucket
mustansir14 Nov 20, 2025
84e0cda
handle nil objectCount inside scanBucket
mustansir14 Nov 20, 2025
b5a66d5
handle nil objectCount outside loop
mustansir14 Nov 20, 2025
966007f
add bucket to resume log
mustansir14 Nov 20, 2025
50e5a90
Merge branch 'main' into INS-104-Support-units-in-S3-source
amanfcp Nov 20, 2025
0faa70e
add bucket and role to error log, remove enumerating log
mustansir14 Nov 21, 2025
474172c
Merge branch 'INS-104-Support-units-in-S3-source' of mustansir:mustan…
mustansir14 Nov 21, 2025
10f91ff
implement sub unit resumption
mustansir14 Nov 24, 2025
6bfbc14
add comment to checkpointer for unit scans
mustansir14 Nov 24, 2025
1863659
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Nov 25, 2025
5ca4151
implement SourceUnitUnmarshaller on source with the new S3SourceUnit,…
mustansir14 Nov 26, 2025
45a133b
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Dec 1, 2025
549e6be
add role to SourceUnitID
mustansir14 Dec 2, 2025
b5cb928
Merge branch 'INS-104-Support-units-in-S3-source' of mustansir:mustan…
mustansir14 Dec 2, 2025
1cee9af
Revert "add role to SourceUnitID"
mustansir14 Dec 2, 2025
6f06776
add role to source unit ID, keep track of resumption using source uni…
mustansir14 Dec 3, 2025
85e681b
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Dec 3, 2025
53a91c8
rename bucket -> unitID in UnmarshalSourceUnit
mustansir14 Dec 4, 2025
3e8e6b9
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Dec 4, 2025
20 changes: 20 additions & 0 deletions pkg/sources/s3/checkpointer.go
@@ -33,6 +33,10 @@ import (
// resuming from the correct bucket. The scan will continue from the last checkpointed object
// in that bucket.
//
// Unit scans are also supported. The encoded resume info in this case tracks the last processed object
// for each bucket (unit) separately by using the SetEncodedResumeInfoFor method on Progress. To use the
// checkpointer for unit scans, call SetIsUnitScan(true) before starting the scan.
//
// For example, if scanning is interrupted after processing 1500 objects across 2 pages:
// Page 1 (objects 0-999): Fully processed, checkpoint saved at object 999
// Page 2 (objects 1000-1999): Partially processed through 1600, but only consecutive through 1499
@@ -56,6 +60,8 @@ type Checkpointer struct {
// progress holds the scan's overall progress state and enables persistence.
// The EncodedResumeInfo field stores the JSON-encoded ResumeInfo checkpoint.
progress *sources.Progress // Reference to source's Progress

isUnitScan bool // Indicates if scanning is done in unit scan mode
}

const defaultMaxObjectsPerPage = 1000
@@ -199,6 +205,12 @@ func (p *Checkpointer) advanceLowestIncompleteIdx() {
// updateCheckpoint persists the current resumption state.
// Must be called with lock held.
func (p *Checkpointer) updateCheckpoint(bucket string, lastKey string) error {
if p.isUnitScan {
// track sub-unit resumption state
p.progress.SetEncodedResumeInfoFor(bucket, lastKey)
return nil
}

encoded, err := json.Marshal(&ResumeInfo{CurrentBucket: bucket, StartAfter: lastKey})
if err != nil {
return fmt.Errorf("failed to encode resume info: %w", err)
@@ -212,3 +224,11 @@ func (p *Checkpointer) updateCheckpoint(bucket string, lastKey string) error {
)
return nil
}

// SetIsUnitScan sets whether the checkpointer is operating in unit scan mode.
func (p *Checkpointer) SetIsUnitScan(isUnitScan bool) {
p.mu.Lock()
defer p.mu.Unlock()

p.isUnitScan = isUnitScan
}
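
An aside on how the new unit-scan mode is driven: the sketch below strings together the pieces added in this file (SetIsUnitScan and the per-bucket branch in updateCheckpoint) with the Progress helpers used elsewhere in the PR. It is illustrative only; the bucket and key values are placeholders, and the imports are assumed to match this package's test file.

```go
// Sketch only: per-unit (per-bucket) checkpointing, mirroring the test below.
ctx := context.Background()
progress := new(sources.Progress)
cp := NewCheckpointer(ctx, progress)
cp.SetIsUnitScan(true) // store resume info per bucket rather than one global checkpoint

key := "key-0"
contents := []s3types.Object{{Key: &key}}

// After an object on a page has been fully processed:
if err := cp.UpdateObjectCompletion(ctx, 0, "example-bucket", contents); err != nil {
	// log and continue; a failed checkpoint update should not abort the scan
}

// In unit-scan mode, Progress.EncodedResumeInfo holds a JSON object keyed by
// bucket, e.g. {"example-bucket": "key-0"}, so the last processed key can be
// read back when the same bucket (unit) is scanned again:
startAfter := progress.GetEncodedResumeInfoFor("example-bucket")
_ = startAfter
```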
30 changes: 30 additions & 0 deletions pkg/sources/s3/checkpointer_test.go
@@ -258,6 +258,36 @@ func TestCheckpointerUpdate(t *testing.T) {
}
}

func TestCheckpointerUpdateUnitScan(t *testing.T) {
ctx := context.Background()
progress := new(sources.Progress)
tracker := NewCheckpointer(ctx, progress)
tracker.SetIsUnitScan(true)

page := &s3.ListObjectsV2Output{
Contents: make([]s3types.Object, 3),
}
for i := range 3 {
key := fmt.Sprintf("key-%d", i)
page.Contents[i] = s3types.Object{Key: &key}
}

// Complete first object.
err := tracker.UpdateObjectCompletion(ctx, 0, "test-bucket", page.Contents)
assert.NoError(t, err, "Unexpected error updating progress")

var info map[string]string
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &info)
assert.NoError(t, err, "Failed to decode resume info")

var gotBucket, gotStartAfter string
for k, v := range info {
gotBucket = k
gotStartAfter = v
}
assert.Equal(t, "test-bucket", gotBucket, "Incorrect bucket")
assert.Equal(t, "key-0", gotStartAfter, "Incorrect resume point")
}

func TestComplete(t *testing.T) {
tests := []struct {
name string
208 changes: 149 additions & 59 deletions pkg/sources/s3/s3.go
@@ -62,6 +62,7 @@ type Source struct {
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.Validator = (*Source)(nil)
var _ sources.SourceUnitEnumChunker = (*Source)(nil)

// Type returns the type of source
func (s *Source) Type() sourcespb.SourceType { return SourceType }
@@ -294,7 +295,7 @@ func (s *Source) scanBuckets(
if role != "" {
ctx = context.WithValue(ctx, "role", role)
}
var objectCount uint64
var totalObjectCount uint64

pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan)
switch {
@@ -316,16 +317,7 @@

bucketsToScanCount := len(bucketsToScan)
for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
s.metricsCollector.RecordBucketForRole(role)
bucket := bucketsToScan[bucketIdx]
ctx := context.WithValue(ctx, "bucket", bucket)

if common.IsDone(ctx) {
ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket")
return
}

ctx.Logger().V(3).Info("Scanning bucket")

s.SetProgressComplete(
bucketIdx,
@@ -334,63 +326,94 @@
s.Progress.EncodedResumeInfo,
)

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
ctx.Logger().Error(err, "could not get regional client for bucket")
continue
}

errorCount := sync.Map{}

input := &s3.ListObjectsV2Input{Bucket: &bucket}
var startAfter *string
if bucket == pos.bucket && pos.startAfter != "" {
input.StartAfter = &pos.startAfter
startAfter = &pos.startAfter
ctx.Logger().V(3).Info(
"Resuming bucket scan",
"start_after", pos.startAfter,
"bucket", bucket,
)
}

pageNumber := 1
paginator := s3.NewListObjectsV2Paginator(regionalClient, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
if role == "" {
ctx.Logger().Error(err, "could not list objects in bucket")
} else {
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
// and therefore not an error.
ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
}
break
}
pageMetadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: output,
}
processingState := processingState{
errorCount: &errorCount,
objectCount: &objectCount,
}
s.pageChunker(ctx, pageMetadata, processingState, chunksChan)

pageNumber++
}
objectCount := s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter)
totalObjectCount += objectCount
}

s.SetProgressComplete(
len(bucketsToScan),
len(bucketsToScan),
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, totalObjectCount),
"",
)
}

func (s *Source) scanBucket(
ctx context.Context,
client *s3.Client,
role string,
bucket string,
reporter sources.ChunkReporter,
startAfter *string,
) uint64 {
s.metricsCollector.RecordBucketForRole(role)

ctx = context.WithValue(ctx, "bucket", bucket)

if common.IsDone(ctx) {
ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket")
return 0
}

ctx.Logger().V(3).Info("Scanning bucket")

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
ctx.Logger().Error(err, "could not get regional client for bucket")
return 0
}

errorCount := sync.Map{}

input := &s3.ListObjectsV2Input{Bucket: &bucket}
if startAfter != nil {
input.StartAfter = startAfter
}

pageNumber := 1
paginator := s3.NewListObjectsV2Paginator(regionalClient, input)
var objectCount uint64
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
if role == "" {
ctx.Logger().Error(err, "could not list objects in bucket")
} else {
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
// and therefore not an error.
ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
}
break
}
pageMetadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: output,
}
processingState := processingState{
errorCount: &errorCount,
objectCount: &objectCount,
}
s.pageChunker(ctx, pageMetadata, processingState, reporter)

pageNumber++
}
return objectCount
}

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error {
@@ -429,14 +452,12 @@ func (s *Source) pageChunker(
ctx context.Context,
metadata pageMetadata,
state processingState,
chunksChan chan *sources.Chunk,
reporter sources.ChunkReporter,
) {
s.checkpointer.Reset() // Reset the checkpointer for each PAGE
ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)

for objIdx, obj := range metadata.page.Contents {
ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)

if common.IsDone(ctx) {
return
}
@@ -572,12 +593,11 @@
Verify: s.verify,
}

if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
if err := handlers.HandleFile(ctx, res.Body, chunkSkel, reporter); err != nil {
ctx.Logger().Error(err, "error handling file")
s.metricsCollector.RecordObjectError(metadata.bucket)
return nil
}

atomic.AddUint64(state.objectCount, 1)
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount)
nErr, ok = state.errorCount.Load(prefix)
Expand All @@ -587,17 +607,14 @@ func (s *Source) pageChunker(
if nErr.(int) > 0 {
state.errorCount.Store(prefix, 0)
}

// Update progress after successful processing.
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for scanned object")
}
s.metricsCollector.RecordObjectScanned(metadata.bucket, float64(*obj.Size))

return nil
})
}

_ = s.jobPool.Wait()
}

@@ -681,3 +698,76 @@ func (s *Source) visitRoles(
func makeS3Link(bucket, region, key string) string {
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucket, region, key)
}

type S3SourceUnit struct {
Contributor:

You've defined this unit type, but you haven't modified the source to actually use it. The source type still embeds CommonSourceUnitUnmarshaller, so it will still unmarshal source units to CommonSourceUnit instead of your new type. You'll need to define custom unmarshalling logic. (The git source has an example of custom unmarshalling logic you can look at.)

Also, I recommend putting the unit struct and related code in a separate file, because we do that for several other sources, and I think it makes things more readable.

Contributor Author:

Oh, thanks for pointing this out. I wasn't aware of this. I'll make the changes.

Bucket string
Role string
}

func (s S3SourceUnit) SourceUnitID() (string, sources.SourceUnitKind) {
// The ID is the bucket name, and the kind is "s3_bucket".
return s.Bucket, "s3_bucket"
}
Contributor:

@mcastorina I forget - is it a problem if SourceUnitID can't be used to round-trip a unit? (In this case, we lose the Role field.)

Contributor Author:

I'll wait for @mcastorina's answer before making changes here, but here's what the description comment says for SourceUnitID():

// SourceUnitID uniquely identifies a source unit. It does not need to
// be human readable or two-way, however, it should be canonical and
// stable across runs.

The bucket name is a globally unique value, so with that aspect we should be good.

Contributor:

Oh, good catch. I guess the round-trip-ability happens in the source manager somewhere? (@mcastorina?)

Contributor:

We take the full unit object and JSON marshal it, so the fields need to be public. Idk if I documented that anywhere though, but that's why a source needs to implement unmarshalling but not marshalling.

Contributor Author:

All the fields are public, so we're good there. But based on our discussion in the thread below regarding having the role in resumption info, it seems like a good idea to have the role in the SourceUnitID as well. I'll add it.

Contributor Author:

I added it and realized it might not be best to add this yet as this also affects sub-unit resumption because the resumption info is supposed to be saved against the SourceUnitID, and our current checkpointer only works with buckets, not roles. I have reverted the changes and will wait for your responses to decide if we want to go with roles being part of resumption info or not.


func (s S3SourceUnit) Display() string {
return s.Bucket
}

var _ sources.SourceUnit = S3SourceUnit{}
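
Note for readers: the custom unmarshalling requested in the first thread above is not in this 17-commit snapshot (it lands in later commits, per the commit list: "implement SourceUnitUnmarshaller on source with the new S3SourceUnit" and "rename bucket -> unitID in UnmarshalSourceUnit"). The sketch below shows one way such an unmarshaller could look, assuming the unit is JSON-marshalled by the source manager as @mcastorina describes and that encoding/json is imported; it is not the merged implementation.

```go
// Illustrative sketch only, not the merged code. Assumes the unit arrives
// JSON-encoded, with a bare bucket name accepted as a fallback.
func (s *Source) UnmarshalSourceUnit(data []byte) (sources.SourceUnit, error) {
	var unit S3SourceUnit
	if err := json.Unmarshal(data, &unit); err == nil && unit.Bucket != "" {
		return unit, nil
	}
	// Fall back to treating the payload as a plain bucket name.
	return S3SourceUnit{Bucket: string(data)}, nil
}
```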

// Enumerate implements the SourceUnitEnumerator interface. This implementation visits
// each configured role and reports each S3 bucket as a source unit.
func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error {
for _, bucket := range buckets {
if common.IsDone(ctx) {
return ctx.Err()
}

unit := S3SourceUnit{
Bucket: bucket,
Role: roleArn,
}

if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
}
return nil
}

return s.visitRoles(ctx, visitor)
}

// ChunkUnit implements the SourceUnitChunker interface. This implementation scans
// the given S3 bucket source unit and emits chunks for each object found.
// It supports sub-unit resumption by utilizing the checkpointer to track progress.
func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
s3unit, ok := unit.(S3SourceUnit)
if !ok {
return fmt.Errorf("expected S3SourceUnit, got %T", unit)
}
bucket := s3unit.Bucket

defaultClient, err := s.newClient(ctx, defaultAWSRegion, s3unit.Role)
if err != nil {
return fmt.Errorf("could not create s3 client for bucket %s and role %s: %w", bucket, s3unit.Role, err)
}

s.checkpointer.SetIsUnitScan(true)

var startAfterPtr *string
startAfter := s.Progress.GetEncodedResumeInfoFor(bucket)
Contributor:

I'm concerned about the way this gets resume info using only a bucket name, but then scans that bucket also using the role stored in the unit. It seems like the information used to retrieve resumption information should be the same information that's used to scan using the retrieved information, but that's not how you've implemented this.

I can't think of any concrete, immediate problems this would cause, but that doesn't mean there aren't any - and my bigger concern is that this will impede maintainability. What do you think?

Contributor Author (@mustansir14, Dec 2, 2025):

Thanks for raising this. I had an internal discussion with @amanfcp on this, and he raised a great point that there could be a case where two roles have access to the same bucket but have different object-level access. I researched this, and it turns out to be true. So it seems like a good idea to store resume info for a particular bucket AND a particular role.

My only concern here is that legacy scan resumption does not have this: resume info there is stored using only the bucket, and this particular case seems applicable there as well.

I might be totally off here and there might be some place where we are already handling this particular case, so please correct me if I am wrong.

Contributor:

Yep, this looks like an oversight in the existing implementation. It's been around for a while, so I don't think we need to urgently fix it (and I wouldn't fix it in this PR), but please add a TODO somewhere flagging the problem. We can clean it up as a later step.

Contributor Author:

I've included the role in the SourceUnitID now, and resumption info is being tracked using the SourceUnitID, so resumption info is now stored against both the role and the bucket.

if startAfter != "" {
ctx.Logger().V(3).Info(
"Resuming bucket scan",
"start_after", startAfter,
"bucket", bucket,
)
startAfterPtr = &startAfter
}
defer s.Progress.ClearEncodedResumeInfoFor(bucket)
s.scanBucket(ctx, defaultClient, s3unit.Role, bucket, reporter, startAfterPtr)
return nil
}
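
Following up on the resume-info thread above: per the review discussion and the later commits, the role is folded into SourceUnitID and sub-unit resumption is keyed on the unit ID rather than the bucket alone, but that is not part of this snapshot. A sketch of the idea, using a hypothetical helper name, plus the TODO the reviewer asked for:

```go
// Sketch only; unitResumeKey is a hypothetical helper, not code from this PR.
// Once the role is folded into SourceUnitID (as in the later commits), keying
// resume info on the unit ID stores a separate checkpoint per (role, bucket),
// so two roles with different object-level access do not share one.
func unitResumeKey(unit S3SourceUnit) string {
	id, _ := unit.SourceUnitID()
	return id
}

// TODO (flagged in review): legacy, non-unit scans still key resume info by
// bucket only, so the same role/bucket ambiguity exists there as well.
```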