|
10 | 10 | import java.nio.channels.AsynchronousSocketChannel;
|
11 | 11 | import java.nio.channels.CompletionHandler;
|
12 | 12 | import java.util.Map;
|
| 13 | +import java.util.Objects; |
| 14 | +import java.util.logging.Level; |
13 | 15 |
|
14 | 16 | import org.apache.sshd.common.FactoryManager;
|
15 | 17 | import org.apache.sshd.common.io.IoHandler;
|
16 | 18 | import org.apache.sshd.common.io.nio2.Nio2Acceptor;
|
17 | 19 | import org.apache.sshd.common.io.nio2.Nio2CompletionHandler;
|
18 | 20 | import org.apache.sshd.common.io.nio2.Nio2Session;
|
19 |
| -import org.apache.sshd.common.util.ValidateUtils; |
20 | 21 |
|
21 | 22 | /**
|
22 | 23 | * @author dariusz.bywalec
|
23 | 24 | *
|
24 | 25 | */
|
25 | 26 | public class SshAcceptor extends Nio2Acceptor {
|
26 | 27 |
|
27 |
| - public SshAcceptor(FactoryManager manager, IoHandler handler, AsynchronousChannelGroup group) { |
28 |
| - super(manager, handler, group); |
29 |
| - } |
30 |
| - |
31 |
| - protected CompletionHandler<AsynchronousSocketChannel, ? super SocketAddress> createSocketCompletionHandler( |
32 |
| - Map<SocketAddress, AsynchronousServerSocketChannel> channelsMap, AsynchronousServerSocketChannel socket) throws IOException { |
33 |
| - return new SshAcceptCompletionHandler(socket); |
34 |
| - } |
35 |
| - |
36 |
| - protected class SshAcceptCompletionHandler extends Nio2CompletionHandler<AsynchronousSocketChannel, SocketAddress> { |
37 |
| - protected final AsynchronousServerSocketChannel socket; |
38 |
| - |
39 |
| - SshAcceptCompletionHandler(AsynchronousServerSocketChannel socket) { |
40 |
| - this.socket = socket; |
41 |
| - } |
42 |
| - |
43 |
| - @Override |
44 |
| - @SuppressWarnings("synthetic-access") |
45 |
| - protected void onCompleted(AsynchronousSocketChannel result, SocketAddress address) { |
46 |
| - // Verify that the address has not been unbound |
47 |
| - if (!channels.containsKey(address)) { |
48 |
| - return; |
49 |
| - } |
50 |
| - |
51 |
| - Nio2Session session = null; |
52 |
| - try { |
53 |
| - // Create a session |
54 |
| - IoHandler handler = getIoHandler(); |
55 |
| - setSocketOptions(result); |
56 |
| - session = Objects.requireNonNull(createSession(SshAcceptor.this, address, result, handler), "No SSH session created"); |
57 |
| - handler.sessionCreated(session); |
58 |
| - sessions.put(session.getId(), session); |
59 |
| - session.startReading(); |
60 |
| - } catch (Throwable exc) { |
61 |
| - failed(exc, address); |
62 |
| - |
63 |
| - // fail fast the accepted connection |
64 |
| - if (session != null) { |
65 |
| - try { |
66 |
| - session.close(); |
67 |
| - } catch (Throwable t) { |
68 |
| - log.warn("Failed (" + t.getClass().getSimpleName() + ")" |
69 |
| - + " to close accepted connection from " + address |
70 |
| - + ": " + t.getMessage(), |
71 |
| - t); |
72 |
| - } |
73 |
| - } |
74 |
| - } |
75 |
| - |
76 |
| - try { |
77 |
| - // Accept new connections |
78 |
| - socket.accept(address, this); |
79 |
| - } catch (Throwable exc) { |
80 |
| - failed(exc, address); |
81 |
| - } |
82 |
| - } |
83 |
| - |
84 |
| - @SuppressWarnings("synthetic-access") |
85 |
| - protected Nio2Session createSession(Nio2Acceptor acceptor, SocketAddress address, AsynchronousSocketChannel channel, IoHandler handler) throws Throwable { |
86 |
| - if (log.isTraceEnabled()) { |
87 |
| - log.trace("createSshSession({}) address={}", acceptor, address); |
88 |
| - } |
89 |
| - return new Nio2Session(acceptor, getFactoryManager(), handler, channel); |
90 |
| - } |
91 |
| - |
92 |
| - @Override |
93 |
| - @SuppressWarnings("synthetic-access") |
94 |
| - protected void onFailed(final Throwable exc, final SocketAddress address) { |
95 |
| - if (channels.containsKey(address) && !disposing.get()) { |
96 |
| - log.warn("Caught " + exc.getClass().getSimpleName() |
97 |
| - + " while accepting incoming connection from " + address |
98 |
| - + ": " + exc.getMessage(), |
99 |
| - exc); |
100 |
| - } |
101 |
| - } |
102 |
| - } |
| 28 | + public SshAcceptor(FactoryManager manager, IoHandler handler, AsynchronousChannelGroup group) { |
| 29 | + super(manager, handler, group); |
| 30 | + } |
| 31 | + |
| 32 | + protected CompletionHandler<AsynchronousSocketChannel, ? super SocketAddress> createSocketCompletionHandler( |
| 33 | + Map<SocketAddress, AsynchronousServerSocketChannel> channelsMap, AsynchronousServerSocketChannel socket) |
| 34 | + throws IOException { |
| 35 | + return new SshAcceptCompletionHandler(socket); |
| 36 | + } |
| 37 | + |
| 38 | + @SuppressWarnings("synthetic-access") |
| 39 | + protected class SshAcceptCompletionHandler extends Nio2CompletionHandler<AsynchronousSocketChannel, SocketAddress> { |
| 40 | + protected final AsynchronousServerSocketChannel socket; |
| 41 | + |
| 42 | + SshAcceptCompletionHandler(AsynchronousServerSocketChannel socket) { |
| 43 | + this.socket = socket; |
| 44 | + } |
| 45 | + |
| 46 | + @Override |
| 47 | + protected void onCompleted(AsynchronousSocketChannel result, SocketAddress address) { |
| 48 | + // Verify that the address has not been unbound |
| 49 | + if (!channels.containsKey(address)) { |
| 50 | + if (log.isDebugEnabled()) { |
| 51 | + log.debug("onCompleted({}) unbound address", address); |
| 52 | + } |
| 53 | + return; |
| 54 | + } |
| 55 | + |
| 56 | + Nio2Session session = null; |
| 57 | + Long sessionId = null; |
| 58 | + boolean keepAccepting; |
| 59 | + try { |
| 60 | + // Create a session |
| 61 | + IoHandler handler = getIoHandler(); |
| 62 | + setSocketOptions(result); |
| 63 | + session = Objects.requireNonNull(createSession(SshAcceptor.this, address, result, handler), |
| 64 | + "No SSH session created"); |
| 65 | + sessionId = session.getId(); |
| 66 | + handler.sessionCreated(session); |
| 67 | + sessions.put(sessionId, session); |
| 68 | + if (session.isClosing()) { |
| 69 | + try { |
| 70 | + handler.sessionClosed(session); |
| 71 | + } finally { |
| 72 | + unmapSession(sessionId); |
| 73 | + } |
| 74 | + } else { |
| 75 | + session.startReading(); |
| 76 | + } |
| 77 | + |
| 78 | + keepAccepting = true; |
| 79 | + } catch (Throwable exc) { |
| 80 | + keepAccepting = okToReaccept(exc, address); |
| 81 | + |
| 82 | + // fail fast the accepted connection |
| 83 | + if (session != null) { |
| 84 | + try { |
| 85 | + session.close(); |
| 86 | + } catch (Throwable t) { |
| 87 | + log.warn("onCompleted(" + address + ") Failed (" + t.getClass().getSimpleName() + ")" |
| 88 | + + " to close accepted connection from " + address + ": " + t.getMessage(), t); |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + unmapSession(sessionId); |
| 93 | + } |
| 94 | + |
| 95 | + if (keepAccepting) { |
| 96 | + try { |
| 97 | + // Accept new connections |
| 98 | + socket.accept(address, this); |
| 99 | + } catch (Throwable exc) { |
| 100 | + failed(exc, address); |
| 101 | + } |
| 102 | + } else { |
| 103 | + log.error("=====> onCompleted({}) no longer accepting incoming connections <====", address); |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + protected Nio2Session createSession(Nio2Acceptor acceptor, SocketAddress address, |
| 108 | + AsynchronousSocketChannel channel, IoHandler handler) throws Throwable { |
| 109 | + if (log.isTraceEnabled()) { |
| 110 | + log.trace("createSshSession({}) address={}", acceptor, address); |
| 111 | + } |
| 112 | + return new Nio2Session(acceptor, getFactoryManager(), handler, channel); |
| 113 | + } |
| 114 | + |
| 115 | + @Override |
| 116 | + protected void onFailed(Throwable exc, SocketAddress address) { |
| 117 | + if (okToReaccept(exc, address)) { |
| 118 | + try { |
| 119 | + // Accept new connections |
| 120 | + socket.accept(address, this); |
| 121 | + } catch (Throwable t) { |
| 122 | + // Do not call failed(t, address) to avoid infinite |
| 123 | + // recursion |
| 124 | + log.error("Failed (" + t.getClass().getSimpleName() + " to re-accept new connections on " + address |
| 125 | + + ": " + t.getMessage(), t); |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + protected boolean okToReaccept(Throwable exc, SocketAddress address) { |
| 131 | + AsynchronousServerSocketChannel channel = channels.get(address); |
| 132 | + if (channel == null) { |
| 133 | + if (log.isDebugEnabled()) { |
| 134 | + log.debug("Caught {} for untracked channel of {}: {}", exc.getClass().getSimpleName(), address, |
| 135 | + exc.getMessage()); |
| 136 | + } |
| 137 | + return false; |
| 138 | + } |
| 139 | + |
| 140 | + if (disposing.get()) { |
| 141 | + if (log.isDebugEnabled()) { |
| 142 | + log.debug("Caught {} for tracked channel of {} while disposing: {}", exc.getClass().getSimpleName(), |
| 143 | + address, exc.getMessage()); |
| 144 | + } |
| 145 | + return false; |
| 146 | + } |
| 147 | + |
| 148 | + log.warn("Caught {} while accepting incoming connection from {}: {}", exc.getClass().getSimpleName(), |
| 149 | + address, exc.getMessage()); |
| 150 | + SshLoggingUtils.logExceptionStackTrace(log, Level.WARNING, exc); |
| 151 | + return true; |
| 152 | + } |
| 153 | + } |
103 | 154 | }
|
0 commit comments