Платформа ЦРНП "Мирокод" для разработки проектов
https://git.mirocod.ru
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
570 lines
15 KiB
570 lines
15 KiB
package ldap |
|
|
|
import ( |
|
"crypto/tls" |
|
"errors" |
|
"fmt" |
|
"log" |
|
"net" |
|
"net/url" |
|
"sync" |
|
"sync/atomic" |
|
"time" |
|
|
|
ber "github.com/go-asn1-ber/asn1-ber" |
|
) |
|
|
|
// Operation codes carried in messagePacket.Op and dispatched on by the
// processMessages loop.
const (
	// MessageQuit causes the processMessages loop to exit.
	MessageQuit = iota
	// MessageRequest sends a request to the server.
	MessageRequest
	// MessageResponse receives a response from the server.
	MessageResponse
	// MessageFinish indicates the client considers a particular message ID
	// to be finished.
	MessageFinish
	// MessageTimeout indicates the client-specified timeout for a
	// particular message ID has been reached.
	MessageTimeout
)
|
|
|
// DefaultLdapPort is the port used for plain TCP ("ldap") connections when
// the URL does not name one.
const DefaultLdapPort = "389"

// DefaultLdapsPort is the port used for SSL ("ldaps") connections when the
// URL does not name one.
const DefaultLdapsPort = "636"
|
|
|
// PacketResponse contains the packet or error encountered reading a response |
|
type PacketResponse struct { |
|
// Packet is the packet read from the server |
|
Packet *ber.Packet |
|
// Error is an error encountered while reading |
|
Error error |
|
} |
|
|
|
// ReadPacket returns the packet or an error |
|
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { |
|
if (pr == nil) || (pr.Packet == nil && pr.Error == nil) { |
|
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) |
|
} |
|
return pr.Packet, pr.Error |
|
} |
|
|
|
type messageContext struct { |
|
id int64 |
|
// close(done) should only be called from finishMessage() |
|
done chan struct{} |
|
// close(responses) should only be called from processMessages(), and only sent to from sendResponse() |
|
responses chan *PacketResponse |
|
} |
|
|
|
// sendResponse should only be called within the processMessages() loop which |
|
// is also responsible for closing the responses channel. |
|
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { |
|
select { |
|
case msgCtx.responses <- packet: |
|
// Successfully sent packet to message handler. |
|
case <-msgCtx.done: |
|
// The request handler is done and will not receive more |
|
// packets. |
|
} |
|
} |
|
|
|
type messagePacket struct { |
|
Op int |
|
MessageID int64 |
|
Packet *ber.Packet |
|
Context *messageContext |
|
} |
|
|
|
// sendMessageFlags is a bit set that modifies how sendMessageWithFlags
// treats a request.
type sendMessageFlags uint

// startTLS marks the request as the StartTLS extended operation.
const startTLS sendMessageFlags = 1 << iota
|
|
|
// Conn represents an LDAP Connection |
|
type Conn struct { |
|
// requestTimeout is loaded atomically |
|
// so we need to ensure 64-bit alignment on 32-bit platforms. |
|
requestTimeout int64 |
|
conn net.Conn |
|
isTLS bool |
|
closing uint32 |
|
closeErr atomic.Value |
|
isStartingTLS bool |
|
Debug debugging |
|
chanConfirm chan struct{} |
|
messageContexts map[int64]*messageContext |
|
chanMessage chan *messagePacket |
|
chanMessageID chan int64 |
|
wgClose sync.WaitGroup |
|
outstandingRequests uint |
|
messageMutex sync.Mutex |
|
} |
|
|
|
var _ Client = &Conn{} |
|
|
|
// DefaultTimeout is the connect timeout applied by Dial and DialTLS, and by
// DialURL when no custom dialer is supplied.
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = time.Minute
|
|
|
// DialOpt configures DialContext. |
|
type DialOpt func(*DialContext) |
|
|
|
// DialWithDialer updates net.Dialer in DialContext. |
|
func DialWithDialer(d *net.Dialer) DialOpt { |
|
return func(dc *DialContext) { |
|
dc.d = d |
|
} |
|
} |
|
|
|
// DialWithTLSConfig updates tls.Config in DialContext. |
|
func DialWithTLSConfig(tc *tls.Config) DialOpt { |
|
return func(dc *DialContext) { |
|
dc.tc = tc |
|
} |
|
} |
|
|
|
// DialContext contains necessary parameters to dial the given ldap URL.
type DialContext struct {
	// d is the dialer used to open the connection.
	d *net.Dialer
	// tc is the TLS configuration passed to tls.DialWithDialer for ldaps
	// URLs.
	tc *tls.Config
}
|
|
|
func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { |
|
if u.Scheme == "ldapi" { |
|
if u.Path == "" || u.Path == "/" { |
|
u.Path = "/var/run/slapd/ldapi" |
|
} |
|
return dc.d.Dial("unix", u.Path) |
|
} |
|
|
|
host, port, err := net.SplitHostPort(u.Host) |
|
if err != nil { |
|
// we assume that error is due to missing port |
|
host = u.Host |
|
port = "" |
|
} |
|
|
|
switch u.Scheme { |
|
case "ldap": |
|
if port == "" { |
|
port = DefaultLdapPort |
|
} |
|
return dc.d.Dial("tcp", net.JoinHostPort(host, port)) |
|
case "ldaps": |
|
if port == "" { |
|
port = DefaultLdapsPort |
|
} |
|
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc) |
|
} |
|
|
|
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) |
|
} |
|
|
|
// Dial connects to the given address on the given network using net.Dial |
|
// and then returns a new Conn for the connection. |
|
// @deprecated Use DialURL instead. |
|
func Dial(network, addr string) (*Conn, error) { |
|
c, err := net.DialTimeout(network, addr, DefaultTimeout) |
|
if err != nil { |
|
return nil, NewError(ErrorNetwork, err) |
|
} |
|
conn := NewConn(c, false) |
|
conn.Start() |
|
return conn, nil |
|
} |
|
|
|
// DialTLS connects to the given address on the given network using tls.Dial |
|
// and then returns a new Conn for the connection. |
|
// @deprecated Use DialURL instead. |
|
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { |
|
c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config) |
|
if err != nil { |
|
return nil, NewError(ErrorNetwork, err) |
|
} |
|
conn := NewConn(c, true) |
|
conn.Start() |
|
return conn, nil |
|
} |
|
|
|
// DialURL connects to the given ldap URL. |
|
// The following schemas are supported: ldap://, ldaps://, ldapi://. |
|
// On success a new Conn for the connection is returned. |
|
func DialURL(addr string, opts ...DialOpt) (*Conn, error) { |
|
u, err := url.Parse(addr) |
|
if err != nil { |
|
return nil, NewError(ErrorNetwork, err) |
|
} |
|
|
|
var dc DialContext |
|
for _, opt := range opts { |
|
opt(&dc) |
|
} |
|
if dc.d == nil { |
|
dc.d = &net.Dialer{Timeout: DefaultTimeout} |
|
} |
|
|
|
c, err := dc.dial(u) |
|
if err != nil { |
|
return nil, NewError(ErrorNetwork, err) |
|
} |
|
|
|
conn := NewConn(c, u.Scheme == "ldaps") |
|
conn.Start() |
|
return conn, nil |
|
} |
|
|
|
// NewConn returns a new Conn using conn for network I/O. |
|
func NewConn(conn net.Conn, isTLS bool) *Conn { |
|
return &Conn{ |
|
conn: conn, |
|
chanConfirm: make(chan struct{}), |
|
chanMessageID: make(chan int64), |
|
chanMessage: make(chan *messagePacket, 10), |
|
messageContexts: map[int64]*messageContext{}, |
|
requestTimeout: 0, |
|
isTLS: isTLS, |
|
} |
|
} |
|
|
|
// Start initializes goroutines to read responses and process messages |
|
func (l *Conn) Start() { |
|
l.wgClose.Add(1) |
|
go l.reader() |
|
go l.processMessages() |
|
} |
|
|
|
// IsClosing returns whether or not we're currently closing. |
|
func (l *Conn) IsClosing() bool { |
|
return atomic.LoadUint32(&l.closing) == 1 |
|
} |
|
|
|
// setClosing sets the closing value to true |
|
func (l *Conn) setClosing() bool { |
|
return atomic.CompareAndSwapUint32(&l.closing, 0, 1) |
|
} |
|
|
|
// Close closes the connection. |
|
func (l *Conn) Close() { |
|
l.messageMutex.Lock() |
|
defer l.messageMutex.Unlock() |
|
|
|
if l.setClosing() { |
|
l.Debug.Printf("Sending quit message and waiting for confirmation") |
|
l.chanMessage <- &messagePacket{Op: MessageQuit} |
|
<-l.chanConfirm |
|
close(l.chanMessage) |
|
|
|
l.Debug.Printf("Closing network connection") |
|
if err := l.conn.Close(); err != nil { |
|
log.Println(err) |
|
} |
|
|
|
l.wgClose.Done() |
|
} |
|
l.wgClose.Wait() |
|
} |
|
|
|
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers |
|
func (l *Conn) SetTimeout(timeout time.Duration) { |
|
if timeout > 0 { |
|
atomic.StoreInt64(&l.requestTimeout, int64(timeout)) |
|
} |
|
} |
|
|
|
// Returns the next available messageID |
|
func (l *Conn) nextMessageID() int64 { |
|
if messageID, ok := <-l.chanMessageID; ok { |
|
return messageID |
|
} |
|
return 0 |
|
} |
|
|
|
// StartTLS sends the command to start a TLS session and then creates a new TLS Client |
|
func (l *Conn) StartTLS(config *tls.Config) error { |
|
if l.isTLS { |
|
return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) |
|
} |
|
|
|
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") |
|
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) |
|
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") |
|
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) |
|
packet.AppendChild(request) |
|
l.Debug.PrintPacket(packet) |
|
|
|
msgCtx, err := l.sendMessageWithFlags(packet, startTLS) |
|
if err != nil { |
|
return err |
|
} |
|
defer l.finishMessage(msgCtx) |
|
|
|
l.Debug.Printf("%d: waiting for response", msgCtx.id) |
|
|
|
packetResponse, ok := <-msgCtx.responses |
|
if !ok { |
|
return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) |
|
} |
|
packet, err = packetResponse.ReadPacket() |
|
l.Debug.Printf("%d: got response %p", msgCtx.id, packet) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if l.Debug { |
|
if err := addLDAPDescriptions(packet); err != nil { |
|
l.Close() |
|
return err |
|
} |
|
l.Debug.PrintPacket(packet) |
|
} |
|
|
|
if err := GetLDAPError(packet); err == nil { |
|
conn := tls.Client(l.conn, config) |
|
|
|
if connErr := conn.Handshake(); connErr != nil { |
|
l.Close() |
|
return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr)) |
|
} |
|
|
|
l.isTLS = true |
|
l.conn = conn |
|
} else { |
|
return err |
|
} |
|
go l.reader() |
|
|
|
return nil |
|
} |
|
|
|
// TLSConnectionState returns the client's TLS connection state. |
|
// The return values are their zero values if StartTLS did |
|
// not succeed. |
|
func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) { |
|
tc, ok := l.conn.(*tls.Conn) |
|
if !ok { |
|
return |
|
} |
|
return tc.ConnectionState(), true |
|
} |
|
|
|
func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { |
|
return l.sendMessageWithFlags(packet, 0) |
|
} |
|
|
|
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { |
|
if l.IsClosing() { |
|
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) |
|
} |
|
l.messageMutex.Lock() |
|
l.Debug.Printf("flags&startTLS = %d", flags&startTLS) |
|
if l.isStartingTLS { |
|
l.messageMutex.Unlock() |
|
return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase")) |
|
} |
|
if flags&startTLS != 0 { |
|
if l.outstandingRequests != 0 { |
|
l.messageMutex.Unlock() |
|
return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests")) |
|
} |
|
l.isStartingTLS = true |
|
} |
|
l.outstandingRequests++ |
|
|
|
l.messageMutex.Unlock() |
|
|
|
responses := make(chan *PacketResponse) |
|
messageID := packet.Children[0].Value.(int64) |
|
message := &messagePacket{ |
|
Op: MessageRequest, |
|
MessageID: messageID, |
|
Packet: packet, |
|
Context: &messageContext{ |
|
id: messageID, |
|
done: make(chan struct{}), |
|
responses: responses, |
|
}, |
|
} |
|
if !l.sendProcessMessage(message) { |
|
if l.IsClosing() { |
|
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) |
|
} |
|
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason")) |
|
} |
|
return message.Context, nil |
|
} |
|
|
|
func (l *Conn) finishMessage(msgCtx *messageContext) { |
|
close(msgCtx.done) |
|
|
|
if l.IsClosing() { |
|
return |
|
} |
|
|
|
l.messageMutex.Lock() |
|
l.outstandingRequests-- |
|
if l.isStartingTLS { |
|
l.isStartingTLS = false |
|
} |
|
l.messageMutex.Unlock() |
|
|
|
message := &messagePacket{ |
|
Op: MessageFinish, |
|
MessageID: msgCtx.id, |
|
} |
|
l.sendProcessMessage(message) |
|
} |
|
|
|
func (l *Conn) sendProcessMessage(message *messagePacket) bool { |
|
l.messageMutex.Lock() |
|
defer l.messageMutex.Unlock() |
|
if l.IsClosing() { |
|
return false |
|
} |
|
l.chanMessage <- message |
|
return true |
|
} |
|
|
|
func (l *Conn) processMessages() { |
|
defer func() { |
|
if err := recover(); err != nil { |
|
log.Printf("ldap: recovered panic in processMessages: %v", err) |
|
} |
|
for messageID, msgCtx := range l.messageContexts { |
|
// If we are closing due to an error, inform anyone who |
|
// is waiting about the error. |
|
if l.IsClosing() && l.closeErr.Load() != nil { |
|
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}) |
|
} |
|
l.Debug.Printf("Closing channel for MessageID %d", messageID) |
|
close(msgCtx.responses) |
|
delete(l.messageContexts, messageID) |
|
} |
|
close(l.chanMessageID) |
|
close(l.chanConfirm) |
|
}() |
|
|
|
var messageID int64 = 1 |
|
for { |
|
select { |
|
case l.chanMessageID <- messageID: |
|
messageID++ |
|
case message := <-l.chanMessage: |
|
switch message.Op { |
|
case MessageQuit: |
|
l.Debug.Printf("Shutting down - quit message received") |
|
return |
|
case MessageRequest: |
|
// Add to message list and write to network |
|
l.Debug.Printf("Sending message %d", message.MessageID) |
|
|
|
buf := message.Packet.Bytes() |
|
_, err := l.conn.Write(buf) |
|
if err != nil { |
|
l.Debug.Printf("Error Sending Message: %s", err.Error()) |
|
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) |
|
close(message.Context.responses) |
|
break |
|
} |
|
|
|
// Only add to messageContexts if we were able to |
|
// successfully write the message. |
|
l.messageContexts[message.MessageID] = message.Context |
|
|
|
// Add timeout if defined |
|
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) |
|
if requestTimeout > 0 { |
|
go func() { |
|
defer func() { |
|
if err := recover(); err != nil { |
|
log.Printf("ldap: recovered panic in RequestTimeout: %v", err) |
|
} |
|
}() |
|
time.Sleep(requestTimeout) |
|
timeoutMessage := &messagePacket{ |
|
Op: MessageTimeout, |
|
MessageID: message.MessageID, |
|
} |
|
l.sendProcessMessage(timeoutMessage) |
|
}() |
|
} |
|
case MessageResponse: |
|
l.Debug.Printf("Receiving message %d", message.MessageID) |
|
if msgCtx, ok := l.messageContexts[message.MessageID]; ok { |
|
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) |
|
} else { |
|
log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing()) |
|
l.Debug.PrintPacket(message.Packet) |
|
} |
|
case MessageTimeout: |
|
// Handle the timeout by closing the channel |
|
// All reads will return immediately |
|
if msgCtx, ok := l.messageContexts[message.MessageID]; ok { |
|
l.Debug.Printf("Receiving message timeout for %d", message.MessageID) |
|
msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) |
|
delete(l.messageContexts, message.MessageID) |
|
close(msgCtx.responses) |
|
} |
|
case MessageFinish: |
|
l.Debug.Printf("Finished message %d", message.MessageID) |
|
if msgCtx, ok := l.messageContexts[message.MessageID]; ok { |
|
delete(l.messageContexts, message.MessageID) |
|
close(msgCtx.responses) |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
func (l *Conn) reader() { |
|
cleanstop := false |
|
defer func() { |
|
if err := recover(); err != nil { |
|
log.Printf("ldap: recovered panic in reader: %v", err) |
|
} |
|
if !cleanstop { |
|
l.Close() |
|
} |
|
}() |
|
|
|
for { |
|
if cleanstop { |
|
l.Debug.Printf("reader clean stopping (without closing the connection)") |
|
return |
|
} |
|
packet, err := ber.ReadPacket(l.conn) |
|
if err != nil { |
|
// A read error is expected here if we are closing the connection... |
|
if !l.IsClosing() { |
|
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) |
|
l.Debug.Printf("reader error: %s", err) |
|
} |
|
return |
|
} |
|
if err := addLDAPDescriptions(packet); err != nil { |
|
l.Debug.Printf("descriptions error: %s", err) |
|
} |
|
if len(packet.Children) == 0 { |
|
l.Debug.Printf("Received bad ldap packet") |
|
continue |
|
} |
|
l.messageMutex.Lock() |
|
if l.isStartingTLS { |
|
cleanstop = true |
|
} |
|
l.messageMutex.Unlock() |
|
message := &messagePacket{ |
|
Op: MessageResponse, |
|
MessageID: packet.Children[0].Value.(int64), |
|
Packet: packet, |
|
} |
|
if !l.sendProcessMessage(message) { |
|
return |
|
} |
|
} |
|
}
|
|
|