Skip to content

Commit f73caf3

Browse files
authored
Abort multi part download if the object is modified during download
1 parent 4a7aaf4 commit f73caf3

File tree

3 files changed

+177
-10
lines changed

3 files changed

+177
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "3ca07475-6a71-46be-83f9-123510baff00",
3+
"type": "bugfix",
4+
"description": "Abort multi part download if the object is modified during download",
5+
"modules": [
6+
"feature/s3/manager"
7+
]
8+
}

feature/s3/manager/download.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,15 @@ type downloader struct {
220220
in *s3.GetObjectInput
221221
w io.WriterAt
222222

223-
wg sync.WaitGroup
224-
m sync.Mutex
225-
226-
pos int64
227-
totalBytes int64
228-
written int64
229-
err error
230-
223+
wg sync.WaitGroup
224+
m sync.Mutex
225+
once sync.Once
226+
227+
pos int64
228+
totalBytes int64
229+
written int64
230+
err error
231+
etag string
231232
partBodyMaxRetries int
232233
}
233234

@@ -358,6 +359,9 @@ func (d *downloader) downloadChunk(chunk dlchunk) error {
358359

359360
// Get the next byte range of data
360361
params.Range = aws.String(chunk.ByteRange())
362+
if params.VersionId == nil && d.etag != "" {
363+
params.IfMatch = aws.String(d.etag)
364+
}
361365

362366
var n int64
363367
var err error
@@ -401,6 +405,9 @@ func (d *downloader) tryDownloadChunk(params *s3.GetObjectInput, w io.Writer) (i
401405
return 0, err
402406
}
403407
d.setTotalBytes(resp) // Set total if not yet set.
408+
d.once.Do(func() {
409+
d.etag = aws.ToString(resp.ETag)
410+
})
404411

405412
var src io.Reader = resp.Body
406413
if d.cfg.BufferProvider != nil {

feature/s3/manager/download_test.go

+154-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ type downloadCaptureClient struct {
2727
GetObjectFn func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
2828
GetObjectInvocations int
2929

30-
RetrievedRanges []string
30+
RetrievedRanges []string
31+
RetrievedETags []string
32+
RetrievedVersions []string
3133

3234
lock sync.Mutex
3335
}
@@ -41,11 +43,13 @@ func (c *downloadCaptureClient) GetObject(ctx context.Context, params *s3.GetObj
4143
if params.Range != nil {
4244
c.RetrievedRanges = append(c.RetrievedRanges, aws.ToString(params.Range))
4345
}
44-
46+
c.RetrievedETags = append(c.RetrievedETags, aws.ToString(params.IfMatch))
47+
c.RetrievedVersions = append(c.RetrievedVersions, aws.ToString(params.VersionId))
4548
return c.GetObjectFn(ctx, params, optFns...)
4649
}
4750

4851
var rangeValueRegex = regexp.MustCompile(`bytes=(\d+)-(\d+)`)
52+
var etag = "my-etag"
4953

5054
func parseRange(rangeValue string) (start, fin int64) {
5155
rng := rangeValueRegex.FindStringSubmatch(rangeValue)
@@ -90,6 +94,30 @@ func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) {
9094
return capture, &capture.GetObjectInvocations
9195
}
9296

97+
func newDownloadVersionClient(data []byte) (*downloadCaptureClient, *int, *[]string, *[]string) {
98+
capture := &downloadCaptureClient{}
99+
100+
capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
101+
start, fin := parseRange(aws.ToString(params.Range))
102+
fin++
103+
104+
if fin >= int64(len(data)) {
105+
fin = int64(len(data))
106+
}
107+
108+
bodyBytes := data[start:fin]
109+
110+
return &s3.GetObjectOutput{
111+
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
112+
ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start, fin-1, len(data))),
113+
ContentLength: aws.Int64(int64(len(bodyBytes))),
114+
ETag: aws.String(etag),
115+
}, nil
116+
}
117+
118+
return capture, &capture.GetObjectInvocations, &capture.RetrievedETags, &capture.RetrievedVersions
119+
}
120+
93121
type mockHTTPStatusError struct {
94122
StatusCode int
95123
}
@@ -522,6 +550,73 @@ func TestDownload_WithRange(t *testing.T) {
522550
}
523551
}
524552

553+
func TestDownload_WithVersionID(t *testing.T) {
554+
c, invocations, etags, versions := newDownloadVersionClient(buf12MB)
555+
556+
d := manager.NewDownloader(c)
557+
558+
w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB)))
559+
n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
560+
Bucket: aws.String("bucket"),
561+
Key: aws.String("key"),
562+
VersionId: aws.String("vid"),
563+
})
564+
565+
if err != nil {
566+
t.Fatalf("expect no error, got %v", err)
567+
}
568+
if e, a := int64(len(buf12MB)), n; e != a {
569+
t.Errorf("expect %d buffer length, got %d", e, a)
570+
}
571+
572+
if e, a := 3, *invocations; e != a {
573+
t.Errorf("expect %v API calls, got %v", e, a)
574+
}
575+
576+
expectVersions := []string{"vid", "vid", "vid"}
577+
if e, a := expectVersions, *versions; !reflect.DeepEqual(e, a) {
578+
t.Errorf("expect %v version ids, got %v", e, a)
579+
}
580+
581+
expectETags := []string{"", "", ""}
582+
if e, a := expectETags, *etags; !reflect.DeepEqual(e, a) {
583+
t.Errorf("expect %v ETags, got %v", e, a)
584+
}
585+
}
586+
587+
func TestDownload_WithETag(t *testing.T) {
588+
c, invocations, etags, versions := newDownloadVersionClient(buf12MB)
589+
590+
d := manager.NewDownloader(c)
591+
592+
w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB)))
593+
n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
594+
Bucket: aws.String("bucket"),
595+
Key: aws.String("key"),
596+
})
597+
598+
if err != nil {
599+
t.Fatalf("expect no error, got %v", err)
600+
}
601+
if e, a := int64(len(buf12MB)), n; e != a {
602+
t.Errorf("expect %d buffer length, got %d", e, a)
603+
}
604+
605+
if e, a := 3, *invocations; e != a {
606+
t.Errorf("expect %v API calls, got %v", e, a)
607+
}
608+
609+
expectVersions := []string{"", "", ""}
610+
if e, a := expectVersions, *versions; !reflect.DeepEqual(e, a) {
611+
t.Errorf("expect %v version ids, got %v", e, a)
612+
}
613+
614+
expectETags := []string{"", etag, etag}
615+
if e, a := expectETags, *etags; !reflect.DeepEqual(e, a) {
616+
t.Errorf("expect %v ETags, got %v", e, a)
617+
}
618+
}
619+
525620
type mockDownloadCLient func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
526621

527622
func (m mockDownloadCLient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
@@ -577,6 +672,63 @@ func TestDownload_WithFailure(t *testing.T) {
577672
}
578673
}
579674

675+
func TestDownload_WithMismatch(t *testing.T) {
676+
reqCount := int64(0)
677+
body := bytes.NewReader(make([]byte, manager.DefaultDownloadPartSize))
678+
679+
client := mockDownloadCLient(func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (out *s3.GetObjectOutput, err error) {
680+
switch atomic.LoadInt64(&reqCount) {
681+
case 0:
682+
if params.IfMatch != nil {
683+
t.Errorf("expect no Etag in first request, got %s", aws.ToString(params.IfMatch))
684+
err = fmt.Errorf("invalid input error")
685+
} else {
686+
out = &s3.GetObjectOutput{
687+
Body: ioutil.NopCloser(body),
688+
ContentLength: aws.Int64(int64(body.Len())),
689+
ContentRange: aws.String(fmt.Sprintf("bytes 0-%d/%d", body.Len()-1, body.Len()*10)),
690+
ETag: aws.String(etag),
691+
}
692+
}
693+
case 1:
694+
// Give a chance for the multipart chunks to be queued up
695+
time.Sleep(1 * time.Second)
696+
// mock the precondition error when object is synchronously updated
697+
err = fmt.Errorf("api error PreconditionFailed")
698+
default:
699+
if a := aws.ToString(params.IfMatch); a != etag {
700+
t.Errorf("expect subrequests' IfMatch header to be %s, got %s", etag, a)
701+
err = fmt.Errorf("invalid input error")
702+
} else {
703+
out = &s3.GetObjectOutput{
704+
Body: ioutil.NopCloser(body),
705+
ContentLength: aws.Int64(int64(body.Len())),
706+
ContentRange: aws.String(fmt.Sprintf("bytes 0-%d/%d", body.Len()-1, body.Len()*10)),
707+
}
708+
}
709+
}
710+
atomic.AddInt64(&reqCount, 1)
711+
return out, err
712+
})
713+
714+
d := manager.NewDownloader(client, func(d *manager.Downloader) {
715+
d.Concurrency = 2
716+
})
717+
718+
w := &manager.WriteAtBuffer{}
719+
params := s3.GetObjectInput{
720+
Bucket: aws.String("Bucket"),
721+
Key: aws.String("Key"),
722+
}
723+
724+
_, err := d.Download(context.Background(), w, &params)
725+
if err == nil {
726+
t.Fatalf("expect error, got none")
727+
} else if e, a := "PreconditionFailed", err.Error(); !strings.Contains(a, e) {
728+
t.Fatalf("expect error message to contain %s, but did not %s", e, a)
729+
}
730+
}
731+
580732
func TestDownloadBufferStrategy(t *testing.T) {
581733
cases := map[string]struct {
582734
partSize int64

0 commit comments

Comments
 (0)