@@ -27,7 +27,9 @@ type downloadCaptureClient struct {
27
27
GetObjectFn func (context.Context , * s3.GetObjectInput , ... func (* s3.Options )) (* s3.GetObjectOutput , error )
28
28
GetObjectInvocations int
29
29
30
- RetrievedRanges []string
30
+ RetrievedRanges []string
31
+ RetrievedETags []string
32
+ RetrievedVersions []string
31
33
32
34
lock sync.Mutex
33
35
}
@@ -41,11 +43,13 @@ func (c *downloadCaptureClient) GetObject(ctx context.Context, params *s3.GetObj
41
43
if params .Range != nil {
42
44
c .RetrievedRanges = append (c .RetrievedRanges , aws .ToString (params .Range ))
43
45
}
44
-
46
+ c .RetrievedETags = append (c .RetrievedETags , aws .ToString (params .IfMatch ))
47
+ c .RetrievedVersions = append (c .RetrievedVersions , aws .ToString (params .VersionId ))
45
48
return c .GetObjectFn (ctx , params , optFns ... )
46
49
}
47
50
48
51
var rangeValueRegex = regexp .MustCompile (`bytes=(\d+)-(\d+)` )
52
+ var etag = "my-etag"
49
53
50
54
func parseRange (rangeValue string ) (start , fin int64 ) {
51
55
rng := rangeValueRegex .FindStringSubmatch (rangeValue )
@@ -90,6 +94,30 @@ func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) {
90
94
return capture , & capture .GetObjectInvocations
91
95
}
92
96
97
+ func newDownloadVersionClient (data []byte ) (* downloadCaptureClient , * int , * []string , * []string ) {
98
+ capture := & downloadCaptureClient {}
99
+
100
+ capture .GetObjectFn = func (_ context.Context , params * s3.GetObjectInput , _ ... func (* s3.Options )) (* s3.GetObjectOutput , error ) {
101
+ start , fin := parseRange (aws .ToString (params .Range ))
102
+ fin ++
103
+
104
+ if fin >= int64 (len (data )) {
105
+ fin = int64 (len (data ))
106
+ }
107
+
108
+ bodyBytes := data [start :fin ]
109
+
110
+ return & s3.GetObjectOutput {
111
+ Body : ioutil .NopCloser (bytes .NewReader (bodyBytes )),
112
+ ContentRange : aws .String (fmt .Sprintf ("bytes %d-%d/%d" , start , fin - 1 , len (data ))),
113
+ ContentLength : aws .Int64 (int64 (len (bodyBytes ))),
114
+ ETag : aws .String (etag ),
115
+ }, nil
116
+ }
117
+
118
+ return capture , & capture .GetObjectInvocations , & capture .RetrievedETags , & capture .RetrievedVersions
119
+ }
120
+
93
121
type mockHTTPStatusError struct {
94
122
StatusCode int
95
123
}
@@ -522,6 +550,73 @@ func TestDownload_WithRange(t *testing.T) {
522
550
}
523
551
}
524
552
553
+ func TestDownload_WithVersionID (t * testing.T ) {
554
+ c , invocations , etags , versions := newDownloadVersionClient (buf12MB )
555
+
556
+ d := manager .NewDownloader (c )
557
+
558
+ w := manager .NewWriteAtBuffer (make ([]byte , len (buf12MB )))
559
+ n , err := d .Download (context .Background (), w , & s3.GetObjectInput {
560
+ Bucket : aws .String ("bucket" ),
561
+ Key : aws .String ("key" ),
562
+ VersionId : aws .String ("vid" ),
563
+ })
564
+
565
+ if err != nil {
566
+ t .Fatalf ("expect no error, got %v" , err )
567
+ }
568
+ if e , a := int64 (len (buf12MB )), n ; e != a {
569
+ t .Errorf ("expect %d buffer length, got %d" , e , a )
570
+ }
571
+
572
+ if e , a := 3 , * invocations ; e != a {
573
+ t .Errorf ("expect %v API calls, got %v" , e , a )
574
+ }
575
+
576
+ expectVersions := []string {"vid" , "vid" , "vid" }
577
+ if e , a := expectVersions , * versions ; ! reflect .DeepEqual (e , a ) {
578
+ t .Errorf ("expect %v version ids, got %v" , e , a )
579
+ }
580
+
581
+ expectETags := []string {"" , "" , "" }
582
+ if e , a := expectETags , * etags ; ! reflect .DeepEqual (e , a ) {
583
+ t .Errorf ("expect %v ETags, got %v" , e , a )
584
+ }
585
+ }
586
+
587
+ func TestDownload_WithETag (t * testing.T ) {
588
+ c , invocations , etags , versions := newDownloadVersionClient (buf12MB )
589
+
590
+ d := manager .NewDownloader (c )
591
+
592
+ w := manager .NewWriteAtBuffer (make ([]byte , len (buf12MB )))
593
+ n , err := d .Download (context .Background (), w , & s3.GetObjectInput {
594
+ Bucket : aws .String ("bucket" ),
595
+ Key : aws .String ("key" ),
596
+ })
597
+
598
+ if err != nil {
599
+ t .Fatalf ("expect no error, got %v" , err )
600
+ }
601
+ if e , a := int64 (len (buf12MB )), n ; e != a {
602
+ t .Errorf ("expect %d buffer length, got %d" , e , a )
603
+ }
604
+
605
+ if e , a := 3 , * invocations ; e != a {
606
+ t .Errorf ("expect %v API calls, got %v" , e , a )
607
+ }
608
+
609
+ expectVersions := []string {"" , "" , "" }
610
+ if e , a := expectVersions , * versions ; ! reflect .DeepEqual (e , a ) {
611
+ t .Errorf ("expect %v version ids, got %v" , e , a )
612
+ }
613
+
614
+ expectETags := []string {"" , etag , etag }
615
+ if e , a := expectETags , * etags ; ! reflect .DeepEqual (e , a ) {
616
+ t .Errorf ("expect %v ETags, got %v" , e , a )
617
+ }
618
+ }
619
+
525
620
type mockDownloadCLient func (ctx context.Context , params * s3.GetObjectInput , optFns ... func (* s3.Options )) (* s3.GetObjectOutput , error )
526
621
527
622
func (m mockDownloadCLient ) GetObject (ctx context.Context , params * s3.GetObjectInput , optFns ... func (* s3.Options )) (* s3.GetObjectOutput , error ) {
@@ -577,6 +672,63 @@ func TestDownload_WithFailure(t *testing.T) {
577
672
}
578
673
}
579
674
675
+ func TestDownload_WithMismatch (t * testing.T ) {
676
+ reqCount := int64 (0 )
677
+ body := bytes .NewReader (make ([]byte , manager .DefaultDownloadPartSize ))
678
+
679
+ client := mockDownloadCLient (func (ctx context.Context , params * s3.GetObjectInput , optFns ... func (* s3.Options )) (out * s3.GetObjectOutput , err error ) {
680
+ switch atomic .LoadInt64 (& reqCount ) {
681
+ case 0 :
682
+ if params .IfMatch != nil {
683
+ t .Errorf ("expect no Etag in first request, got %s" , aws .ToString (params .IfMatch ))
684
+ err = fmt .Errorf ("invalid input error" )
685
+ } else {
686
+ out = & s3.GetObjectOutput {
687
+ Body : ioutil .NopCloser (body ),
688
+ ContentLength : aws .Int64 (int64 (body .Len ())),
689
+ ContentRange : aws .String (fmt .Sprintf ("bytes 0-%d/%d" , body .Len ()- 1 , body .Len ()* 10 )),
690
+ ETag : aws .String (etag ),
691
+ }
692
+ }
693
+ case 1 :
694
+ // Give a chance for the multipart chunks to be queued up
695
+ time .Sleep (1 * time .Second )
696
+ // mock the precondition error when object is synchronously updated
697
+ err = fmt .Errorf ("api error PreconditionFailed" )
698
+ default :
699
+ if a := aws .ToString (params .IfMatch ); a != etag {
700
+ t .Errorf ("expect subrequests' IfMatch header to be %s, got %s" , etag , a )
701
+ err = fmt .Errorf ("invalid input error" )
702
+ } else {
703
+ out = & s3.GetObjectOutput {
704
+ Body : ioutil .NopCloser (body ),
705
+ ContentLength : aws .Int64 (int64 (body .Len ())),
706
+ ContentRange : aws .String (fmt .Sprintf ("bytes 0-%d/%d" , body .Len ()- 1 , body .Len ()* 10 )),
707
+ }
708
+ }
709
+ }
710
+ atomic .AddInt64 (& reqCount , 1 )
711
+ return out , err
712
+ })
713
+
714
+ d := manager .NewDownloader (client , func (d * manager.Downloader ) {
715
+ d .Concurrency = 2
716
+ })
717
+
718
+ w := & manager.WriteAtBuffer {}
719
+ params := s3.GetObjectInput {
720
+ Bucket : aws .String ("Bucket" ),
721
+ Key : aws .String ("Key" ),
722
+ }
723
+
724
+ _ , err := d .Download (context .Background (), w , & params )
725
+ if err == nil {
726
+ t .Fatalf ("expect error, got none" )
727
+ } else if e , a := "PreconditionFailed" , err .Error (); ! strings .Contains (a , e ) {
728
+ t .Fatalf ("expect error message to contain %s, but did not %s" , e , a )
729
+ }
730
+ }
731
+
580
732
func TestDownloadBufferStrategy (t * testing.T ) {
581
733
cases := map [string ]struct {
582
734
partSize int64
0 commit comments