From 9259fc3884d2ba7d633dcb0abf29a3de0811bd3c Mon Sep 17 00:00:00 2001 From: Stefan McShane Date: Fri, 2 Apr 2021 18:11:41 -0400 Subject: [PATCH] Added search with channels first draft --- v3/error.go | 1 + v3/examples_test.go | 39 ++++++++++++++++++++ v3/go.mod | 2 + v3/go.sum | 4 ++ v3/search.go | 89 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+) diff --git a/v3/error.go b/v3/error.go index 3cdb7b31..3d76f0e0 100644 --- a/v3/error.go +++ b/v3/error.go @@ -89,6 +89,7 @@ const ( ErrorUnexpectedMessage = 204 ErrorUnexpectedResponse = 205 ErrorEmptyPassword = 206 + ErrorUsage = 207 ) // LDAPResultCodeMap contains string descriptions for LDAP error codes diff --git a/v3/examples_test.go b/v3/examples_test.go index da7531fc..a30075fd 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "log" + "sync" ) // This example demonstrates how to bind a connection to an ldap user @@ -49,6 +50,44 @@ func ExampleConn_Search() { } } +func ExampleConn_SearchWithChannel() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + // this is basically how Search() does it: + ch := make(chan *SearchResult) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + for res := range ch { + if len(res.Entries) != 0 { + fmt.Printf("%s has DN %s\n", res.Entries[0].GetAttributeValue("cn"), res.Entries[0].DN) + } + } + wg.Done() + }() + + err = l.SearchWithChannel(searchRequest, ch) + + wg.Wait() + + if err != nil { + log.Fatalf("Error while searching: %s", err) + } +} + // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/v3/go.mod b/v3/go.mod index 931e5967..df5a1c18 100644 --- a/v3/go.mod +++ b/v3/go.mod @@ -5,5 +5,7 @@ go 1.13 require ( github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c github.com/go-asn1-ber/asn1-ber v1.5.1 + github.com/go-ldap/ldap v3.0.3+incompatible golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect + gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect ) diff --git a/v3/go.sum b/v3/go.sum index 0d8a4f68..af372d9d 100644 --- a/v3/go.sum +++ b/v3/go.sum @@ -2,6 +2,8 @@ github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzU github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/go-asn1-ber/asn1-ber v1.5.1 h1:pDbRAunXzIUXfx4CB2QJFv5IuPiuoW+sWvr/Us009o8= github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-ldap/ldap v3.0.3+incompatible h1:HTeSZO8hWMS1Rgb2Ziku6b8a7qRIZZMHjsvuZyatzwk= +github.com/go-ldap/ldap v3.0.3+incompatible/go.mod h1:qfd9rJvER9Q0/D/Sqn1DfHRoBp40uXYvFoEVrNEPqRc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 h1:vEg9joUBmeBcK9iSJftGNf3coIG4HqZElCPehJsfAYM= golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -9,3 +11,5 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d h1:TxyelI5cVkbREznMhfzycHdkp5cLA7DpE+GKjSslYhM= +gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d/go.mod h1:cuepJuh7vyXfUyUwEgHQXw849cJrilpS5NeIjOWESAw= diff --git a/v3/search.go b/v3/search.go index 4fcc794a..36052ba5 100644 --- a/v3/search.go +++ b/v3/search.go @@ -408,3 +408,92 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } } + +// SearchWithChannel performs a search request and returns all search results via the given +// channel as soon as they are received. This means you get all results until an error +// happens (or the search successfully finished), e.g. for size / time limited requests all +// are recieved via the channel until the limit is reached. +func (l *Conn) SearchWithChannel(searchRequest *SearchRequest, ch chan *SearchResult) error { + if ch == nil { + return NewError(ErrorUsage, errors.New("ldap: SearchWithChannel got nil channel")) + } + defer close(ch) + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + return err + } + + l.Debug.PrintPacket(packet) + + msgCtx, err := l.sendMessage(packet) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + foundSearchResultDone := false + for !foundSearchResultDone { + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return err + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return err + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + entry := new(Entry) + entry.DN = packet.Children[1].Children[0].Value.(string) + for _, child := range packet.Children[1].Children[1].Children { + attr := new(EntryAttribute) + attr.Name = child.Children[0].Value.(string) + for _, value := range child.Children[1].Children { + attr.Values = append(attr.Values, value.Value.(string)) + attr.ByteValues = append(attr.ByteValues, value.ByteValue) + } + entry.Attributes = append(entry.Attributes, attr) + } + ch <- &SearchResult{Entries: []*Entry{entry}} + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + return err + } + if len(packet.Children) == 3 { + result := &SearchResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + return fmt.Errorf("failed to decode child control: %s", err) + } + result.Controls = append(result.Controls, decodedChild) + } + ch <- result + } + foundSearchResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + ch <- &SearchResult{Referrals: []string{ref}} + } + } + + l.Debug.Printf("%d: returning", msgCtx.id) + return nil + +}