-
Notifications
You must be signed in to change notification settings - Fork 2.2k
[INS-104] Support units in S3 source #4560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
bdc0d26
21d9334
a25afff
ea21d02
9915187
c32e12c
5161090
ef324d1
84e0cda
b5a66d5
966007f
50e5a90
0faa70e
474172c
10f91ff
6bfbc14
1863659
5ca4151
45a133b
549e6be
b5cb928
1cee9af
6f06776
85e681b
53a91c8
3e8e6b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,7 @@ type Source struct { | |
| var _ sources.Source = (*Source)(nil) | ||
| var _ sources.SourceUnitUnmarshaller = (*Source)(nil) | ||
| var _ sources.Validator = (*Source)(nil) | ||
| var _ sources.SourceUnitEnumChunker = (*Source)(nil) | ||
|
|
||
| // Type returns the type of source | ||
| func (s *Source) Type() sourcespb.SourceType { return SourceType } | ||
|
|
@@ -294,7 +295,7 @@ func (s *Source) scanBuckets( | |
| if role != "" { | ||
| ctx = context.WithValue(ctx, "role", role) | ||
| } | ||
| var objectCount uint64 | ||
| var totalObjectCount uint64 | ||
|
|
||
| pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan) | ||
| switch { | ||
|
|
@@ -316,16 +317,7 @@ func (s *Source) scanBuckets( | |
|
|
||
| bucketsToScanCount := len(bucketsToScan) | ||
| for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ { | ||
| s.metricsCollector.RecordBucketForRole(role) | ||
| bucket := bucketsToScan[bucketIdx] | ||
| ctx := context.WithValue(ctx, "bucket", bucket) | ||
|
|
||
| if common.IsDone(ctx) { | ||
| ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket") | ||
| return | ||
| } | ||
|
|
||
| ctx.Logger().V(3).Info("Scanning bucket") | ||
|
|
||
| s.SetProgressComplete( | ||
| bucketIdx, | ||
|
|
@@ -334,63 +326,94 @@ func (s *Source) scanBuckets( | |
| s.Progress.EncodedResumeInfo, | ||
| ) | ||
|
|
||
| regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket) | ||
| if err != nil { | ||
| ctx.Logger().Error(err, "could not get regional client for bucket") | ||
| continue | ||
| } | ||
|
|
||
| errorCount := sync.Map{} | ||
|
|
||
| input := &s3.ListObjectsV2Input{Bucket: &bucket} | ||
| var startAfter *string | ||
| if bucket == pos.bucket && pos.startAfter != "" { | ||
| input.StartAfter = &pos.startAfter | ||
| startAfter = &pos.startAfter | ||
| ctx.Logger().V(3).Info( | ||
| "Resuming bucket scan", | ||
| "start_after", pos.startAfter, | ||
| "bucket", bucket, | ||
| ) | ||
| } | ||
|
|
||
| pageNumber := 1 | ||
| paginator := s3.NewListObjectsV2Paginator(regionalClient, input) | ||
| for paginator.HasMorePages() { | ||
| output, err := paginator.NextPage(ctx) | ||
| if err != nil { | ||
| if role == "" { | ||
| ctx.Logger().Error(err, "could not list objects in bucket") | ||
| } else { | ||
| // Our documentation blesses specifying a role to assume without specifying buckets to scan, which will | ||
| // often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the | ||
| // account, but the role probably doesn't have access to all of them). This makes it expected behavior | ||
| // and therefore not an error. | ||
| ctx.Logger().V(3).Info("could not list objects in bucket", "err", err) | ||
| } | ||
| break | ||
| } | ||
| pageMetadata := pageMetadata{ | ||
| bucket: bucket, | ||
| pageNumber: pageNumber, | ||
| client: regionalClient, | ||
| page: output, | ||
| } | ||
| processingState := processingState{ | ||
| errorCount: &errorCount, | ||
| objectCount: &objectCount, | ||
| } | ||
| s.pageChunker(ctx, pageMetadata, processingState, chunksChan) | ||
|
|
||
| pageNumber++ | ||
| } | ||
| objectCount := s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter) | ||
| totalObjectCount += objectCount | ||
| } | ||
|
|
||
| s.SetProgressComplete( | ||
| len(bucketsToScan), | ||
| len(bucketsToScan), | ||
| fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), | ||
| fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, totalObjectCount), | ||
| "", | ||
| ) | ||
| } | ||
|
|
||
// scanBucket lists and scans every object in a single S3 bucket, streaming
// chunks through the given reporter. It returns the number of objects
// successfully scanned. When startAfter is non-nil, listing resumes after
// that object key (sub-bucket resumption).
func (s *Source) scanBucket(
	ctx context.Context,
	client *s3.Client,
	role string,
	bucket string,
	reporter sources.ChunkReporter,
	startAfter *string,
) uint64 {
	s.metricsCollector.RecordBucketForRole(role)

	ctx = context.WithValue(ctx, "bucket", bucket)

	if common.IsDone(ctx) {
		ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket")
		return 0
	}

	ctx.Logger().V(3).Info("Scanning bucket")

	// Buckets can live in any region; resolve a client pinned to this
	// bucket's region before listing objects.
	regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
	if err != nil {
		ctx.Logger().Error(err, "could not get regional client for bucket")
		return 0
	}

	// Shared with pageChunker across all pages of this bucket; presumably
	// tracks consecutive errors per key prefix — TODO confirm semantics in
	// pageChunker.
	errorCount := sync.Map{}

	input := &s3.ListObjectsV2Input{Bucket: &bucket}
	if startAfter != nil {
		// Resume listing after the last-completed object key.
		input.StartAfter = startAfter
	}

	pageNumber := 1
	paginator := s3.NewListObjectsV2Paginator(regionalClient, input)
	var objectCount uint64
	for paginator.HasMorePages() {
		output, err := paginator.NextPage(ctx)
		if err != nil {
			if role == "" {
				ctx.Logger().Error(err, "could not list objects in bucket")
			} else {
				// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
				// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
				// account, but the role probably doesn't have access to all of them). This makes it expected behavior
				// and therefore not an error.
				ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
			}
			break
		}
		pageMetadata := pageMetadata{
			bucket:     bucket,
			pageNumber: pageNumber,
			client:     regionalClient,
			page:       output,
		}
		processingState := processingState{
			errorCount:  &errorCount,
			objectCount: &objectCount,
		}
		// pageChunker scans each object on the page and increments
		// objectCount (via atomic add) for every object it completes.
		s.pageChunker(ctx, pageMetadata, processingState, reporter)

		pageNumber++
	}
	return objectCount
}
|
|
||
| // Chunks emits chunks of bytes over a channel. | ||
| func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error { | ||
| visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error { | ||
|
|
@@ -429,14 +452,12 @@ func (s *Source) pageChunker( | |
| ctx context.Context, | ||
| metadata pageMetadata, | ||
| state processingState, | ||
| chunksChan chan *sources.Chunk, | ||
| reporter sources.ChunkReporter, | ||
| ) { | ||
| s.checkpointer.Reset() // Reset the checkpointer for each PAGE | ||
| ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber) | ||
|
|
||
| for objIdx, obj := range metadata.page.Contents { | ||
| ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size) | ||
|
|
||
| if common.IsDone(ctx) { | ||
| return | ||
| } | ||
|
|
@@ -572,12 +593,11 @@ func (s *Source) pageChunker( | |
| Verify: s.verify, | ||
| } | ||
|
|
||
| if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil { | ||
| if err := handlers.HandleFile(ctx, res.Body, chunkSkel, reporter); err != nil { | ||
| ctx.Logger().Error(err, "error handling file") | ||
| s.metricsCollector.RecordObjectError(metadata.bucket) | ||
| return nil | ||
| } | ||
|
|
||
| atomic.AddUint64(state.objectCount, 1) | ||
| ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount) | ||
| nErr, ok = state.errorCount.Load(prefix) | ||
|
|
@@ -587,17 +607,14 @@ func (s *Source) pageChunker( | |
| if nErr.(int) > 0 { | ||
| state.errorCount.Store(prefix, 0) | ||
| } | ||
|
|
||
| // Update progress after successful processing. | ||
| if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { | ||
| ctx.Logger().Error(err, "could not update progress for scanned object") | ||
| } | ||
| s.metricsCollector.RecordObjectScanned(metadata.bucket, float64(*obj.Size)) | ||
|
|
||
| return nil | ||
| }) | ||
| } | ||
|
|
||
| _ = s.jobPool.Wait() | ||
| } | ||
|
|
||
|
|
@@ -681,3 +698,76 @@ func (s *Source) visitRoles( | |
// makeS3Link builds the public HTTPS URL for an object in a regional S3
// bucket, e.g. https://<bucket>.s3.<region>.amazonaws.com/<key>.
func makeS3Link(bucket, region, key string) string {
	return "https://" + bucket + ".s3." + region + ".amazonaws.com/" + key
}
|
|
||
// S3SourceUnit is a single scannable unit of work for the S3 source: one
// bucket, optionally scanned under an assumed IAM role.
type S3SourceUnit struct {
	// Bucket is the S3 bucket name; it doubles as the unit's ID.
	Bucket string
	// Role is the IAM role ARN to assume when scanning this bucket
	// (empty when no role is assumed).
	Role string
}

// SourceUnitID implements sources.SourceUnit.
func (s S3SourceUnit) SourceUnitID() (string, sources.SourceUnitKind) {
	// The ID is the bucket name, and the kind is "s3_bucket".
	return s.Bucket, "s3_bucket"
}

// Display implements sources.SourceUnit, returning the human-readable name
// of the unit (the bucket name).
func (s S3SourceUnit) Display() string {
	return s.Bucket
}

// Compile-time check that S3SourceUnit satisfies sources.SourceUnit.
var _ sources.SourceUnit = S3SourceUnit{}
|
|
||
| // Enumerate implements SourceUnitEnumerator interface. This implementation visits | ||
| // each configured role and passes each s3 bucket as a source unit | ||
| func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error { | ||
| visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error { | ||
| for _, bucket := range buckets { | ||
| if common.IsDone(ctx) { | ||
| return ctx.Err() | ||
| } | ||
|
|
||
| unit := S3SourceUnit{ | ||
| Bucket: bucket, | ||
| Role: roleArn, | ||
| } | ||
|
|
||
| if err := reporter.UnitOk(ctx, unit); err != nil { | ||
| return err | ||
| } | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| return s.visitRoles(ctx, visitor) | ||
| } | ||
|
|
||
| // ChunkUnit implements SourceUnitChunker interface. This implementation scans | ||
| // the given S3 bucket source unit and emits chunks for each object found. | ||
| // It supports sub-unit resumption by utilizing the checkpointer to track progress. | ||
| func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error { | ||
|
|
||
| s3unit, ok := unit.(S3SourceUnit) | ||
| if !ok { | ||
| return fmt.Errorf("expected *S3SourceUnit, got %T", unit) | ||
| } | ||
| bucket := s3unit.Bucket | ||
|
|
||
| defaultClient, err := s.newClient(ctx, defaultAWSRegion, s3unit.Role) | ||
| if err != nil { | ||
| return fmt.Errorf("could not create s3 client for bucket %s and role %s: %w", bucket, s3unit.Role, err) | ||
| } | ||
|
|
||
| s.checkpointer.SetIsUnitScan(true) | ||
|
|
||
| var startAfterPtr *string | ||
| startAfter := s.Progress.GetEncodedResumeInfoFor(bucket) | ||
|
||
| if startAfter != "" { | ||
| ctx.Logger().V(3).Info( | ||
| "Resuming bucket scan", | ||
| "start_after", startAfter, | ||
| "bucket", bucket, | ||
| ) | ||
| startAfterPtr = &startAfter | ||
| } | ||
| defer s.Progress.ClearEncodedResumeInfoFor(bucket) | ||
| s.scanBucket(ctx, defaultClient, s3unit.Role, bucket, reporter, startAfterPtr) | ||
| return nil | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've defined this unit type, but you haven't modified the source to actually use it. The source type still embeds
CommonSourceUnitUnmarshaller, so it will still unmarshal source units to CommonSourceUnit instead of your new type. You'll need to define custom unmarshalling logic. (The git source has an example of custom unmarshalling logic you can look at.) Also, I recommend putting the unit struct and related code in a separate file, because we do that for several other sources, and I think it makes things more readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, thanks for pointing this out. I wasn't aware of this. I'll make the changes.