host := NewHost(s)
host.auth = &auth
+ host.theme = &chat.Themes[0]
for _, fingerprint := range options.Admin {
auth.Op(fingerprint)
channel *chat.Channel
commands *chat.Commands
- motd string
- auth *Auth
+ motd string
+ auth *Auth
+ count int
+
+ // Default theme
+ theme *chat.Theme
}
// NewHost creates a Host on top of an existing listener.
term.AutoCompleteCallback = h.AutoCompleteFunction
user := chat.NewUserScreen(name, term)
- user.Config.Theme = &chat.Themes[0]
+ user.Config.Theme = h.theme
go func() {
// Close term once user is closed.
user.Wait()
}()
defer user.Close()
- term.SetPrompt(GetPrompt(user))
-
err := h.channel.Join(user)
+ if err == chat.ErrIdTaken {
+ // Try again...
+ user.SetName(fmt.Sprintf("Guest%d", h.count))
+ err = h.channel.Join(user)
+ }
if err != nil {
logger.Errorf("Failed to join: %s", err)
return
}
+ // Successfully joined.
+ term.SetPrompt(GetPrompt(user))
+ h.count++
+
for {
line, err := term.ReadLine()
if err == io.EOF {
package main
import (
+ "bufio"
+ "io"
+ "strings"
"testing"
"github.com/shazow/ssh-chat/chat"
+ "github.com/shazow/ssh-chat/sshd"
)
+func stripPrompt(s string) string {
+ pos := strings.LastIndex(s, "\033[K")
+ if pos < 0 {
+ return s
+ }
+ return s[pos+3:]
+}
+
func TestHostGetPrompt(t *testing.T) {
var expected, actual string
actual = GetPrompt(u)
expected = "[foo] "
if actual != expected {
- t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+ t.Errorf("Got: %q; Expected: %q", actual, expected)
}
u.Config.Theme = &chat.Themes[0]
actual = GetPrompt(u)
expected = "[\033[38;05;2mfoo\033[0m] "
if actual != expected {
- t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+ t.Errorf("Got: %q; Expected: %q", actual, expected)
+ }
+}
+
+func TestHostNameCollision(t *testing.T) {
+ key, err := sshd.NewRandomKey(512)
+ if err != nil {
+ t.Fatal(err)
+ }
+ config := sshd.MakeNoAuth()
+ config.AddHostKey(key)
+
+ s, err := sshd.ListenSSH(":0", config)
+ if err != nil {
+ t.Fatal(err)
+ }
+ host := NewHost(s)
+ go host.Serve()
+
+ done := make(chan struct{}, 1)
+
+ // First client
+ go func() {
+ err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+ scanner := bufio.NewScanner(r)
+
+ // Consume the initial buffer
+ scanner.Scan()
+ actual := scanner.Text()
+ if !strings.HasPrefix(actual, "[foo] ") {
+ t.Errorf("First client failed to get 'foo' name.")
+ }
+
+ actual = stripPrompt(actual)
+ expected := " * foo joined. (Connected: 1)"
+ if actual != expected {
+ t.Errorf("Got %q; expected %q", actual, expected)
+ }
+
+ // Ready for second client
+ done <- struct{}{}
+
+ scanner.Scan()
+ actual = stripPrompt(scanner.Text())
+ expected = " * Guest1 joined. (Connected: 2)"
+ if actual != expected {
+ t.Errorf("Got %q; expected %q", actual, expected)
+ }
+
+ // Wrap it up.
+ close(done)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ // Wait for first client
+ <-done
+
+ // Second client
+ err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+ scanner := bufio.NewScanner(r)
+
+ // Consume the initial buffer
+ scanner.Scan()
+ actual := scanner.Text()
+ if !strings.HasPrefix(actual, "[Guest1] ") {
+ t.Errorf("Second client did not get Guest1 name.")
+ }
+ })
+ if err != nil {
+ t.Fatal(err)
}
+
+ <-done
+ s.Close()
}
--- /dev/null
+package sshd
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "io"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// NewRandomKey generates a random key of a desired bit length.
+func NewRandomKey(bits int) (ssh.Signer, error) {
+ key, err := rsa.GenerateKey(rand.Reader, bits)
+ if err != nil {
+ return nil, err
+ }
+ return ssh.NewSignerFromKey(key)
+}
+
+// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
+func NewClientConfig(name string) *ssh.ClientConfig {
+ return &ssh.ClientConfig{
+ User: name,
+ Auth: []ssh.AuthMethod{
+ ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
+ return
+ }),
+ },
+ }
+}
+
+// NewClientSession makes a barebones SSH client session, used for testing.
+func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
+ config := NewClientConfig(name)
+ conn, err := ssh.Dial("tcp", host, config)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ session, err := conn.NewSession()
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+
+ in, err := session.StdinPipe()
+ if err != nil {
+ return err
+ }
+
+ out, err := session.StdoutPipe()
+ if err != nil {
+ return err
+ }
+
+ err = session.Shell()
+ if err != nil {
+ return err
+ }
+
+ handler(out, in)
+
+ return nil
+}
import (
"bytes"
- "crypto/rand"
- "crypto/rsa"
"io"
"testing"
-
- "golang.org/x/crypto/ssh"
)
// TODO: Move some of these into their own package?
-func MakeKey(bits int) (ssh.Signer, error) {
- key, err := rsa.GenerateKey(rand.Reader, bits)
- if err != nil {
- return nil, err
- }
- return ssh.NewSignerFromKey(key)
-}
-
-func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
- config := &ssh.ClientConfig{
- User: name,
- Auth: []ssh.AuthMethod{
- ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
- return
- }),
- },
- }
-
- conn, err := ssh.Dial("tcp", host, config)
- if err != nil {
- return err
- }
- defer conn.Close()
-
- session, err := conn.NewSession()
- if err != nil {
- return err
- }
- defer session.Close()
-
- in, err := session.StdinPipe()
- if err != nil {
- return err
- }
-
- out, err := session.StdoutPipe()
- if err != nil {
- return err
- }
-
- err = session.Shell()
- if err != nil {
- return err
- }
-
- handler(out, in)
-
- return nil
-}
-
func TestServerInit(t *testing.T) {
config := MakeNoAuth()
s, err := ListenSSH(":badport", config)
}
func TestServeTerminals(t *testing.T) {
- signer, err := MakeKey(512)
+ signer, err := NewRandomKey(512)
config := MakeNoAuth()
config.AddHostKey(signer)