// Copyright (c) 2019 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package dig

import (
	"fmt"
	"io"
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewResultListErrors(t *testing.T) {
	tests := []struct {
		desc      string
		give      interface{}
		wantError string
	}{
		{
			desc:      "returns dig.In",
			give:      func() struct{ In } { panic("invalid") },
			wantError: "struct { dig.In } embeds a dig.In",
		},
		{
			desc: "returns dig.Out+dig.In",
			give: func() struct {
				Out
				In
			} {
				panic("invalid")
			},
			wantError: "struct { dig.Out; dig.In } embeds a dig.In",
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			_, err := newResultList(reflect.TypeOf(tt.give), resultOptions{})
			require.Error(t, err)
			AssertErrorMatches(t, err,
				"bad result 1:",
				`cannot provide parameter objects: `+tt.wantError)
		})
	}
}

func TestResultListExtractFails(t *testing.T) {
	rl, err := newResultList(reflect.TypeOf(func() (io.Writer, error) {
		panic("function should not be called")
	}), resultOptions{})
	require.NoError(t, err)
	assert.Panics(t, func() {
		rl.Extract(newStagingContainerWriter(), false, reflect.ValueOf("irrelevant"))
	})
}

func TestNewResultErrors(t *testing.T) {
	type outPtr struct{ *Out }
	type out struct{ Out }
	type in struct{ In }
	type inOut struct {
		In
		Out
	}

	tests := []struct {
		give interface{}
		err  string
	}{
		{
			give: outPtr{},
			err:  "cannot build a result object by embedding *dig.Out, embed dig.Out instead: dig.outPtr embeds *dig.Out",
		},
		{
			give: (*out)(nil),
			err:  "cannot return a pointer to a result object, use a value instead: *dig.out is a pointer to a struct that embeds dig.Out",
		},
		{
			give: in{},
			err:  "cannot provide parameter objects: dig.in embeds a dig.In",
		},
		{
			give: inOut{},
			err:  "cannot provide parameter objects: dig.inOut embeds a dig.In",
		},
	}

	for _, tt := range tests {
		give := reflect.TypeOf(tt.give)
		t.Run(fmt.Sprint(give), func(t *testing.T) {
			_, err := newResult(give, resultOptions{})
			require.Error(t, err)
			assert.Contains(t, err.Error(), tt.err)
		})
	}
}

func TestNewResultObject(t *testing.T) {
	typeOfReader := reflect.TypeOf((*io.Reader)(nil)).Elem()
	typeOfWriter := reflect.TypeOf((*io.Writer)(nil)).Elem()

	tests := []struct {
		desc string
		give interface{}
		opts resultOptions

		wantFields []resultObjectField
	}{
		{desc: "empty", give: struct{ Out }{}},
		{
			desc: "multiple values",
			give: struct {
				Out

				Reader io.Reader
				Writer io.Writer
			}{},
			wantFields: []resultObjectField{
				{
					FieldName:  "Reader",
					FieldIndex: 1,
					Result:     resultSingle{Type: typeOfReader},
				},
				{
					FieldName:  "Writer",
					FieldIndex: 2,
					Result:     resultSingle{Type: typeOfWriter},
				},
			},
		},
		{
			desc: "name tag",
			give: struct {
				Out

				A io.Writer `name:"stream-a"`
				B io.Writer `name:"stream-b" `
			}{},
			wantFields: []resultObjectField{
				{
					FieldName:  "A",
					FieldIndex: 1,
					Result:     resultSingle{Name: "stream-a", Type: typeOfWriter},
				},
				{
					FieldName:  "B",
					FieldIndex: 2,
					Result:     resultSingle{Name: "stream-b", Type: typeOfWriter},
				},
			},
		},
		{
			desc: "group tag",
			give: struct {
				Out

				Writer io.Writer `group:"writers"`
			}{},
			wantFields: []resultObjectField{
				{
					FieldName:  "Writer",
					FieldIndex: 1,
					Result:     resultGrouped{Group: "writers", Type: typeOfWriter},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts)
			require.NoError(t, err)
			assert.Equal(t, tt.wantFields, got.Fields)
		})
	}
}

func TestNewResultObjectErrors(t *testing.T) {
	tests := []struct {
		desc string
		give interface{}
		opts resultOptions
		err  string
	}{
		{
			desc: "unexported fields",
			give: struct {
				Out

				writer io.Writer
			}{},
			err: `unexported fields not allowed in dig.Out, did you mean to export "writer" (io.Writer)`,
		},
		{
			desc: "error field",
			give: struct {
				Out

				Error error
			}{},
			err: `bad field "Error" of struct { dig.Out; Error error }: cannot return an error here, return it from the constructor instead`,
		},
		{
			desc: "nested dig.In",
			give: struct {
				Out

				Nested struct{ In }
			}{},
			err: `bad field "Nested"`,
		},
		{
			desc: "group with name should fail",
			give: struct {
				Out

				Foo string `group:"foo" name:"bar"`
			}{},
			err: "cannot use named values with value groups: " +
				`name:"bar" provided with group:"foo"`,
		},
		{
			desc: "group marked as optional",
			give: struct {
				Out

				Foo string `group:"foo" optional:"true"`
			}{},
			err: "value groups cannot be optional",
		},
		{
			desc: "name option",
			give: struct {
				Out

				Reader io.Reader
			}{},
			opts: resultOptions{Name: "foo"},
			err:  `cannot specify a name for result objects`,
		},
		{
			desc: "name option with name tag",
			give: struct {
				Out

				A io.Writer `name:"stream-a"`
				B io.Writer
			}{},
			opts: resultOptions{Name: "stream"},
			err:  `cannot specify a name for result objects`,
		},
		{
			desc: "group tag with name option",
			give: struct {
				Out

				Reader io.Reader
				Writer io.Writer `group:"writers"`
			}{},
			opts: resultOptions{Name: "foo"},
			err:  `cannot specify a name for result objects`,
		},
		{
			desc: "flatten on non-slice",
			give: struct {
				Out

				Writer io.Writer `group:"writers,flatten"`
			}{},
			err: "flatten can be applied to slices only",
		},
		{
			desc: "soft on value group",
			give: struct {
				Out

				Fries []struct{} `group:"potato,flatten,soft"`
			}{},
			err: "cannot use soft with result value groups",
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			_, err := newResultObject(reflect.TypeOf(tt.give), tt.opts)
			require.Error(t, err)
			assert.Contains(t, err.Error(), tt.err)
		})
	}
}

type fakeResultVisit struct {
	Visit                result
	AnnotateWithField    *resultObjectField
	AnnotateWithPosition int
	Return               fakeResultVisits
}

func (fv fakeResultVisit) String() string {
	switch {
	case fv.Visit != nil:
		return fmt.Sprintf("Visit(%#v) -> %v", fv.Visit, fv.Return)
	case fv.AnnotateWithField != nil:
		return fmt.Sprintf("AnnotateWithField(%#v) -> %v", *fv.AnnotateWithField, fv.Return)
	default:
		return fmt.Sprintf("AnnotateWithPosition(%v) -> %v", fv.AnnotateWithPosition, fv.Return)
	}
}

type fakeResultVisits []fakeResultVisit

func (vs fakeResultVisits) Visitor(t *testing.T) resultVisitor {
	return &fakeResultVisitor{t: t, visits: vs}
}

type fakeResultVisitor struct {
	t      *testing.T
	visits fakeResultVisits
}

func (fv *fakeResultVisitor) popNext(call string) fakeResultVisit {
	if len(fv.visits) == 0 {
		fv.t.Fatalf("received unexpected call %v: no more calls were expected", call)
	}

	visit := fv.visits[0]
	fv.visits = fv.visits[1:]
	return visit
}

func (fv *fakeResultVisitor) Visit(r result) resultVisitor {
	v := fv.popNext(fmt.Sprintf("Visit(%#v)", r))
	if !reflect.DeepEqual(r, v.Visit) {
		fv.t.Fatalf("received unexpected call Visit(%#v)\nexpected %v", r, v)
	}
	return &fakeResultVisitor{t: fv.t, visits: v.Return}
}

func (fv *fakeResultVisitor) AnnotateWithField(f resultObjectField) resultVisitor {
	v := fv.popNext(fmt.Sprintf("AnnotateWithField(%#v)", f))
	if v.AnnotateWithField == nil || !reflect.DeepEqual(f, *v.AnnotateWithField) {
		fv.t.Fatalf("received unexpected call AnnotateWithField(%#v)\nexpected %v", f, v)
	}
	return &fakeResultVisitor{t: fv.t, visits: v.Return}
}

func (fv *fakeResultVisitor) AnnotateWithPosition(i int) resultVisitor {
	v := fv.popNext(fmt.Sprintf("AnnotateWithPosition(%v)", i))
	if i != v.AnnotateWithPosition {
		fv.t.Fatalf("received unexpected call AnnotateWithPosition(%v)\nexpected %v", i, v)
	}
	return &fakeResultVisitor{t: fv.t, visits: v.Return}
}

func TestWalkResult(t *testing.T) {
	t.Run("invalid result type", func(t *testing.T) {
		type badResult struct{ result }
		visitor := fakeResultVisits{
			{Visit: badResult{}, Return: fakeResultVisits{}},
		}.Visitor(t)
		assert.Panics(t,
			func() {
				walkResult(badResult{}, visitor)
			})
	})

	t.Run("resultObject ordering", func(t *testing.T) {
		type type1 struct{}
		type type2 struct{}
		type type3 struct{}
		type type4 struct{}

		typ := reflect.TypeOf(struct {
			Out

			T1 type1
			T2 type2

			Nested struct {
				Out

				T3 type3
				T4 type4
			}
		}{})

		ro, err := newResultObject(typ, resultOptions{})
		require.NoError(t, err)

		v := fakeResultVisits{
			{
				Visit: ro,
				Return: fakeResultVisits{
					{
						AnnotateWithField: &ro.Fields[0],
						Return: fakeResultVisits{
							{Visit: ro.Fields[0].Result},
						},
					},
					{
						AnnotateWithField: &ro.Fields[1],
						Return: fakeResultVisits{
							{Visit: ro.Fields[1].Result},
						},
					},
					{
						AnnotateWithField: &ro.Fields[2],
						Return: fakeResultVisits{
							{
								Visit: ro.Fields[2].Result,
								Return: fakeResultVisits{
									{
										AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[0],
										Return: fakeResultVisits{
											{Visit: ro.Fields[2].Result.(resultObject).Fields[0].Result},
										},
									},
									{
										AnnotateWithField: &ro.Fields[2].Result.(resultObject).Fields[1],
										Return: fakeResultVisits{
											{Visit: ro.Fields[2].Result.(resultObject).Fields[1].Result},
										},
									},
								},
							},
						},
					},
				},
			},
		}.Visitor(t)

		walkResult(ro, v)
	})
}