Skip to content

Commit f7929a8

Browse files
committed
feat: add search with channels inspired by go-ldap#319
1 parent cdb0754 commit f7929a8

6 files changed

+436
-0
lines changed

examples_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"crypto/x509"
67
"fmt"
@@ -50,6 +51,34 @@ func ExampleConn_Search() {
5051
}
5152
}
5253

54+
// This example demonstrates how to search with channel
55+
func ExampleConn_SearchWithChannel() {
56+
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
57+
if err != nil {
58+
log.Fatal(err)
59+
}
60+
defer l.Close()
61+
62+
searchRequest := NewSearchRequest(
63+
"dc=example,dc=com", // The base dn to search
64+
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
65+
"(&(objectClass=organizationalPerson))", // The filter to apply
66+
[]string{"dn", "cn"}, // A list attributes to retrieve
67+
nil,
68+
)
69+
70+
ctx, cancel := context.WithCancel(context.Background())
71+
defer cancel()
72+
73+
ch := l.SearchWithChannel(ctx, searchRequest)
74+
for res := range ch {
75+
if res.Error != nil {
76+
log.Fatalf("Error searching: %s", res.Error)
77+
}
78+
fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN)
79+
}
80+
}
81+
5382
// This example demonstrates how to start a TLS connection
5483
func ExampleConn_StartTLS() {
5584
l, err := DialURL("ldap://ldap.example.com:389")

ldap_test.go

+61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"testing"
67

@@ -344,3 +345,63 @@ func TestEscapeDN(t *testing.T) {
344345
})
345346
}
346347
}
348+
349+
func TestSearchWithChannel(t *testing.T) {
350+
l, err := DialURL(ldapServer)
351+
if err != nil {
352+
t.Fatal(err)
353+
}
354+
defer l.Close()
355+
356+
searchRequest := NewSearchRequest(
357+
baseDN,
358+
ScopeWholeSubtree, DerefAlways, 0, 0, false,
359+
filter[2],
360+
attributes,
361+
nil)
362+
363+
srs := make([]*Entry, 0)
364+
ctx := context.Background()
365+
for sr := range l.SearchWithChannel(ctx, searchRequest) {
366+
if sr.Error != nil {
367+
t.Fatal(err)
368+
}
369+
srs = append(srs, sr.Entry)
370+
}
371+
372+
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
373+
}
374+
375+
func TestSearchWithChannelAndCancel(t *testing.T) {
376+
l, err := DialURL(ldapServer)
377+
if err != nil {
378+
t.Fatal(err)
379+
}
380+
defer l.Close()
381+
382+
searchRequest := NewSearchRequest(
383+
baseDN,
384+
ScopeWholeSubtree, DerefAlways, 0, 0, false,
385+
filter[2],
386+
attributes,
387+
nil)
388+
389+
cancelNum := 10
390+
srs := make([]*Entry, 0)
391+
ctx, cancel := context.WithCancel(context.Background())
392+
for sr := range l.SearchWithChannel(ctx, searchRequest) {
393+
if sr.Error != nil {
394+
t.Fatal(err)
395+
}
396+
srs = append(srs, sr.Entry)
397+
if len(srs) == cancelNum {
398+
cancel()
399+
}
400+
}
401+
if len(srs) > cancelNum+2 {
402+
// The cancel process is asynchronous,
403+
// so a few entries after it canceled might be received
404+
t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2)
405+
}
406+
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
407+
}

search.go

+128
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"reflect"
@@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) {
375376
r.Controls = append(r.Controls, s.Controls...)
376377
}
377378

379+
// SearchSingleResult holds the server's single response to a search request
380+
type SearchSingleResult struct {
381+
// Entry is the returned entry
382+
Entry *Entry
383+
// Referral is the returned referral
384+
Referral string
385+
// Controls are the returned controls
386+
Controls []Control
387+
// Error is set when the search request was failed
388+
Error error
389+
}
390+
391+
// Print outputs a human-readable description
392+
func (s *SearchSingleResult) Print() {
393+
s.Entry.Print()
394+
}
395+
396+
// PrettyPrint outputs a human-readable description with indenting
397+
func (s *SearchSingleResult) PrettyPrint(indent int) {
398+
s.Entry.PrettyPrint(indent)
399+
}
400+
378401
// SearchRequest represents a search request to send to the server
379402
type SearchRequest struct {
380403
BaseDN string
@@ -559,6 +582,111 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
559582
}
560583
}
561584

585+
// SearchWithChannel performs a search request and returns all search results
586+
// via the returned channel as soon as they are received. This means you get
587+
// all results until an error happens (or the search successfully finished),
588+
// e.g. for size / time limited requests all are recieved via the channel
589+
// until the limit is reached.
590+
func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) chan *SearchSingleResult {
591+
ch := make(chan *SearchSingleResult)
592+
go func() {
593+
defer close(ch)
594+
if l.IsClosing() {
595+
return
596+
}
597+
598+
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
599+
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
600+
// encode search request
601+
err := searchRequest.appendTo(packet)
602+
if err != nil {
603+
ch <- &SearchSingleResult{Error: err}
604+
return
605+
}
606+
l.Debug.PrintPacket(packet)
607+
608+
msgCtx, err := l.sendMessage(packet)
609+
if err != nil {
610+
ch <- &SearchSingleResult{Error: err}
611+
return
612+
}
613+
defer l.finishMessage(msgCtx)
614+
615+
foundSearchSingleResultDone := false
616+
for !foundSearchSingleResultDone {
617+
select {
618+
case <-ctx.Done():
619+
l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
620+
return
621+
default:
622+
l.Debug.Printf("%d: waiting for response", msgCtx.id)
623+
packetResponse, ok := <-msgCtx.responses
624+
if !ok {
625+
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
626+
ch <- &SearchSingleResult{Error: err}
627+
return
628+
}
629+
packet, err = packetResponse.ReadPacket()
630+
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
631+
if err != nil {
632+
ch <- &SearchSingleResult{Error: err}
633+
return
634+
}
635+
636+
if l.Debug {
637+
if err := addLDAPDescriptions(packet); err != nil {
638+
ch <- &SearchSingleResult{Error: err}
639+
return
640+
}
641+
ber.PrintPacket(packet)
642+
}
643+
644+
switch packet.Children[1].Tag {
645+
case ApplicationSearchResultEntry:
646+
entry := new(Entry)
647+
entry.DN = packet.Children[1].Children[0].Value.(string)
648+
for _, child := range packet.Children[1].Children[1].Children {
649+
attr := new(EntryAttribute)
650+
attr.Name = child.Children[0].Value.(string)
651+
for _, value := range child.Children[1].Children {
652+
attr.Values = append(attr.Values, value.Value.(string))
653+
attr.ByteValues = append(attr.ByteValues, value.ByteValue)
654+
}
655+
entry.Attributes = append(entry.Attributes, attr)
656+
}
657+
ch <- &SearchSingleResult{Entry: entry}
658+
659+
case ApplicationSearchResultDone:
660+
if err := GetLDAPError(packet); err != nil {
661+
ch <- &SearchSingleResult{Error: err}
662+
return
663+
}
664+
if len(packet.Children) == 3 {
665+
result := &SearchSingleResult{}
666+
for _, child := range packet.Children[2].Children {
667+
decodedChild, err := DecodeControl(child)
668+
if err != nil {
669+
werr := fmt.Errorf("failed to decode child control: %w", err)
670+
ch <- &SearchSingleResult{Error: werr}
671+
return
672+
}
673+
result.Controls = append(result.Controls, decodedChild)
674+
}
675+
ch <- result
676+
}
677+
foundSearchSingleResultDone = true
678+
679+
case ApplicationSearchResultReference:
680+
ref := packet.Children[1].Children[0].Value.(string)
681+
ch <- &SearchSingleResult{Referral: ref}
682+
}
683+
}
684+
}
685+
l.Debug.Printf("%d: returning", msgCtx.id)
686+
}()
687+
return ch
688+
}
689+
562690
// unpackAttributes will extract all given LDAP attributes and it's values
563691
// from the ber.Packet
564692
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {

v3/examples_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"crypto/x509"
67
"fmt"
@@ -50,6 +51,34 @@ func ExampleConn_Search() {
5051
}
5152
}
5253

54+
// This example demonstrates how to search with channel
55+
func ExampleConn_SearchWithChannel() {
56+
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
57+
if err != nil {
58+
log.Fatal(err)
59+
}
60+
defer l.Close()
61+
62+
searchRequest := NewSearchRequest(
63+
"dc=example,dc=com", // The base dn to search
64+
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
65+
"(&(objectClass=organizationalPerson))", // The filter to apply
66+
[]string{"dn", "cn"}, // A list attributes to retrieve
67+
nil,
68+
)
69+
70+
ctx, cancel := context.WithCancel(context.Background())
71+
defer cancel()
72+
73+
ch := l.SearchWithChannel(ctx, searchRequest)
74+
for res := range ch {
75+
if res.Error != nil {
76+
log.Fatalf("Error searching: %s", res.Error)
77+
}
78+
fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN)
79+
}
80+
}
81+
5382
// This example demonstrates how to start a TLS connection
5483
func ExampleConn_StartTLS() {
5584
l, err := DialURL("ldap://ldap.example.com:389")

v3/ldap_test.go

+61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"testing"
67

@@ -344,3 +345,63 @@ func TestEscapeDN(t *testing.T) {
344345
})
345346
}
346347
}
348+
349+
func TestSearchWithChannel(t *testing.T) {
350+
l, err := DialURL(ldapServer)
351+
if err != nil {
352+
t.Fatal(err)
353+
}
354+
defer l.Close()
355+
356+
searchRequest := NewSearchRequest(
357+
baseDN,
358+
ScopeWholeSubtree, DerefAlways, 0, 0, false,
359+
filter[2],
360+
attributes,
361+
nil)
362+
363+
srs := make([]*Entry, 0)
364+
ctx := context.Background()
365+
for sr := range l.SearchWithChannel(ctx, searchRequest) {
366+
if sr.Error != nil {
367+
t.Fatal(err)
368+
}
369+
srs = append(srs, sr.Entry)
370+
}
371+
372+
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
373+
}
374+
375+
func TestSearchWithChannelAndCancel(t *testing.T) {
376+
l, err := DialURL(ldapServer)
377+
if err != nil {
378+
t.Fatal(err)
379+
}
380+
defer l.Close()
381+
382+
searchRequest := NewSearchRequest(
383+
baseDN,
384+
ScopeWholeSubtree, DerefAlways, 0, 0, false,
385+
filter[2],
386+
attributes,
387+
nil)
388+
389+
cancelNum := 10
390+
srs := make([]*Entry, 0)
391+
ctx, cancel := context.WithCancel(context.Background())
392+
for sr := range l.SearchWithChannel(ctx, searchRequest) {
393+
if sr.Error != nil {
394+
t.Fatal(err)
395+
}
396+
srs = append(srs, sr.Entry)
397+
if len(srs) == cancelNum {
398+
cancel()
399+
}
400+
}
401+
if len(srs) > cancelNum+2 {
402+
// The cancel process is asynchronous,
403+
// so a few entries after it canceled might be received
404+
t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2)
405+
}
406+
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
407+
}

0 commit comments

Comments
 (0)