"errors"
"fmt"
"net"
+ "sync"
"time"
"golang.org/x/crypto/ssh"
var keepaliveInterval = time.Second * 30
var keepaliveRequest = "keepalive@ssh-chat"
+var ErrNoSessionChannel = errors.New("no session channel")
+var ErrNotSessionChannel = errors.New("terminal requires session channel")
+
// Connection is an interface with fields necessary to operate an sshd host.
type Connection interface {
PublicKey() ssh.PublicKey
terminal.Terminal
Conn Connection
Channel ssh.Channel
+
+ done chan struct{}
+ closeOnce sync.Once
}
// Make new terminal from a session channel
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
if ch.ChannelType() != "session" {
- return nil, errors.New("terminal requires session channel")
+ return nil, ErrNotSessionChannel
}
channel, requests, err := ch.Accept()
if err != nil {
return nil, err
}
term := Terminal{
- *terminal.NewTerminal(channel, "Connecting..."),
- sshConn{conn},
- channel,
+ Terminal: *terminal.NewTerminal(channel, "Connecting..."),
+ Conn: sshConn{conn},
+ Channel: channel,
+
+ done: make(chan struct{}),
}
go term.listen(requests)
- go func() {
- // FIXME: Is this necessary?
- conn.Wait()
- channel.Close()
- }()
go func() {
- for range time.Tick(keepaliveInterval) {
- // TODO: Could break out earlier with a select if we want, rather than waiting for an error.
- _, err := channel.SendRequest(keepaliveRequest, true, nil)
- if err != nil {
- // Connection is gone
- conn.Close()
+ // Keep-Alive Ticker
+ ticker := time.Tick(keepaliveInterval)
+ for {
+ select {
+ case <-ticker:
+ _, err := channel.SendRequest(keepaliveRequest, true, nil)
+ if err != nil {
+ // Connection is gone
+ term.Close()
+ return
+ }
+ case <-term.done:
return
}
}
}
// Find session channel and make a Terminal from it
-func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
+func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) {
+ // Make a terminal from the first session found
for ch := range channels {
if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue
}
- term, err = NewTerminal(conn, ch)
- if err == nil {
- break
- }
- }
-
- if term != nil {
- // Reject the rest.
- // FIXME: Do we need this?
- go func() {
- for ch := range channels {
- ch.Reject(ssh.Prohibited, "only one session allowed")
- }
- }()
+ return NewTerminal(conn, ch)
}
- return term, err
+ return nil, ErrNoSessionChannel
}
// Close terminal and ssh connection
func (t *Terminal) Close() error {
- return t.Conn.Close()
+ var err error
+ t.closeOnce.Do(func() {
+ close(t.done)
+ t.Channel.Close()
+ err = t.Conn.Close()
+ })
+ return err
}
// Negotiate terminal type and settings