Skip to content

Commit 3aace41

Browse files
authored
Transfer Manager v2 downloader concurrency fix and version control (#3058)
* migrate parts GET to concurrent reader
1 parent a2e6c6f commit 3aace41

9 files changed

+726
-562
lines changed

feature/s3/transfermanager/api_op_DownloadObject.go

+19-12
Original file line numberDiff line numberDiff line change
@@ -541,19 +541,22 @@ type downloader struct {
541541
in *DownloadObjectInput
542542
out *DownloadObjectOutput
543543

544-
wg sync.WaitGroup
545-
m sync.Mutex
544+
wg sync.WaitGroup
545+
m sync.Mutex
546+
etagOnce sync.Once
547+
totalBytesOnce sync.Once
546548

547549
offset int64
548550
pos int64
549551
totalBytes int64
550552
written int64
553+
etag string
551554

552555
err error
553556
}
554557

555558
func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error) {
556-
if err := d.init(ctx); err != nil {
559+
if err := d.init(); err != nil {
557560
return nil, fmt.Errorf("unable to initialize download: %w", err)
558561
}
559562

@@ -600,12 +603,12 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
600603
d.wg.Wait()
601604
}
602605
} else {
603-
if d.in.Range == "" {
604-
output = d.getChunk(ctx, 0, d.byteRange(), clientOptions...)
605-
} else {
606+
if d.in.Range != "" {
606607
d.pos, d.totalBytes = d.getDownloadRange()
607608
d.offset = d.pos
608609
}
610+
611+
d.getChunk(ctx, 0, d.byteRange(), clientOptions...)
609612
total := d.totalBytes
610613

611614
ch := make(chan dlChunk, d.options.Concurrency)
@@ -639,7 +642,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
639642
return d.out, nil
640643
}
641644

642-
func (d *downloader) init(ctx context.Context) error {
645+
func (d *downloader) init() error {
643646
if d.options.PartSizeBytes < minPartSizeBytes {
644647
return fmt.Errorf("part size must be at least %d bytes", minPartSizeBytes)
645648
}
@@ -655,7 +658,6 @@ func (d *downloader) init(ctx context.Context) error {
655658

656659
func (d *downloader) singleDownload(ctx context.Context, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
657660
chunk := dlChunk{w: d.in.WriterAt}
658-
// d.in.PartNumber = 0
659661
output, err := d.downloadChunk(ctx, chunk, clientOptions...)
660662
if err != nil {
661663
return nil, err
@@ -708,6 +710,9 @@ func (d *downloader) downloadChunk(ctx context.Context, chunk dlChunk, clientOpt
708710
if chunk.withRange != "" {
709711
params.Range = aws.String(chunk.withRange)
710712
}
713+
if params.VersionId == nil && d.etag != "" {
714+
params.IfMatch = aws.String(d.etag)
715+
}
711716

712717
var out *s3.GetObjectOutput
713718
var n int64
@@ -737,6 +742,9 @@ func (d *downloader) downloadChunk(ctx context.Context, chunk dlChunk, clientOpt
737742
if out != nil {
738743
output = &DownloadObjectOutput{}
739744
output.mapFromGetObjectOutput(out, params.ChecksumMode)
745+
d.etagOnce.Do(func() {
746+
d.etag = aws.ToString(out.ETag)
747+
})
740748
}
741749
return output, err
742750
}
@@ -747,7 +755,9 @@ func (d *downloader) tryDownloadChunk(ctx context.Context, params *s3.GetObjectI
747755
return nil, 0, err
748756
}
749757

750-
d.setTotalBytes(out) // Set total if not yet set.
758+
d.totalBytesOnce.Do(func() {
759+
d.setTotalBytes(out)
760+
}) // Set total in first GET
751761

752762
var n int64
753763
defer out.Body.Close()
@@ -780,9 +790,6 @@ func (d *downloader) getTotalBytes() int64 {
780790
// does not include a Content-Range. Meaning the object was not chunked. This
781791
// occurs when the full file fits within the PartSize directive.
782792
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
783-
d.m.Lock()
784-
defer d.m.Unlock()
785-
786793
if d.totalBytes >= 0 {
787794
return
788795
}

feature/s3/transfermanager/api_op_DownloadObject_integ_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
func TestInteg_DownloadObject(t *testing.T) {
14-
cases := map[string]getObjectTestData{
14+
cases := map[string]downloadObjectTestData{
1515
"part get seekable body": {Body: strings.NewReader("hello world"), ExpectBody: []byte("hello world")},
1616
"part get empty string body": {Body: strings.NewReader(""), ExpectBody: []byte("")},
1717
"part get multipart body": {Body: bytes.NewReader(largeObjectBuf), ExpectBody: largeObjectBuf},

0 commit comments

Comments
 (0)