Unflake tests, remove lock from chat/message.User
authorAndrey Petrov <andrey.petrov@shazow.net>
Sun, 24 Jul 2016 19:34:56 +0000 (15:34 -0400)
committerAndrey Petrov <andrey.petrov@shazow.net>
Sun, 24 Jul 2016 20:17:02 +0000 (16:17 -0400)
chat/message/user.go
chat/room.go
host_test.go
sshd/client.go
sshd/net_test.go

index d107a0d080800d3f0109f6c450b76f98a1285b52..a4f1adc498b0c518a74e880bf961a072ddf500b1 100644 (file)
@@ -24,7 +24,6 @@ type User struct {
        msg      chan Message
        done     chan struct{}
 
-       mu        sync.RWMutex
        replyTo   *User // Set when user gets a /msg, for replying.
        screen    io.WriteCloser
        closeOnce sync.Once
@@ -33,10 +32,10 @@ type User struct {
 func NewUser(identity Identifier) *User {
        u := User{
                Identifier: identity,
-               Config:     *DefaultUserConfig,
+               Config:     DefaultUserConfig,
                joined:     time.Now(),
                msg:        make(chan Message, messageBuffer),
-               done:       make(chan struct{}, 1),
+               done:       make(chan struct{}),
        }
        u.SetColorIdx(rand.Int())
 
@@ -85,23 +84,27 @@ func (u *User) Wait() {
 // Disconnect user, stop accepting messages
 func (u *User) Close() {
        u.closeOnce.Do(func() {
-               u.mu.Lock()
                if u.screen != nil {
                        u.screen.Close()
                }
-               close(u.msg)
+               // close(u.msg) TODO: Close?
                close(u.done)
-               u.msg = nil
-               u.mu.Unlock()
        })
 }
 
-// Consume message buffer into an io.Writer. Will block, should be called in a
+// Consume message buffer into the handler. Will block, should be called in a
 // goroutine.
-// TODO: Not sure if this is a great API.
 func (u *User) Consume() {
-       for m := range u.msg {
-               u.HandleMsg(m)
+       for {
+               select {
+               case <-u.done:
+                       return
+               case m, ok := <-u.msg:
+                       if !ok {
+                               return
+                       }
+                       u.HandleMsg(m)
+               }
        }
 }
 
@@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error {
 
 // Add message to consume by user
 func (u *User) Send(m Message) error {
-       u.mu.RLock()
-       defer u.mu.RUnlock()
        select {
        case u.msg <- m:
+       case <-u.done:
+               return ErrUserClosed
        default:
                logger.Printf("Msg buffer full, closing: %s", u.Name())
                u.Close()
@@ -166,10 +169,10 @@ type UserConfig struct {
 }
 
 // Default user configuration to use
-var DefaultUserConfig *UserConfig
+var DefaultUserConfig UserConfig
 
 func init() {
-       DefaultUserConfig = &UserConfig{
+       DefaultUserConfig = UserConfig{
                Bell:  true,
                Quiet: false,
        }
index bf2128c07bd4e75ee7dc75f4b211e7d196283ee5..7e73da6554e262bda057e451f20ffa1c0f7e3ec2 100644 (file)
@@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) {
 
 // Join the room as a user, will announce.
 func (r *Room) Join(u *message.User) (*Member, error) {
-       if r.closed {
-               return nil, ErrRoomClosed
-       }
+       // TODO: Check if closed
        if u.Id() == "" {
                return nil, ErrInvalidName
        }
index f2402ca311726ce866b1c1f65b09ea03bf94cee0..e30dd556d01e4ed2af726c057160e80b11486803 100644 (file)
@@ -4,6 +4,7 @@ import (
        "bufio"
        "crypto/rand"
        "crypto/rsa"
+       "errors"
        "io"
        "io/ioutil"
        "strings"
@@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
 
        // First client
        go func() {
-               err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+               err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
                        scanner := bufio.NewScanner(r)
 
                        // Consume the initial buffer
@@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) {
 
                        // Wrap it up.
                        close(done)
+                       return nil
                })
                if err != nil {
                        t.Fatal(err)
@@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) {
        <-done
 
        // Second client
-       err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+       err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
                scanner := bufio.NewScanner(r)
 
                // Consume the initial buffer
@@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) {
                if !strings.HasPrefix(actual, "[Guest1] ") {
                        t.Errorf("Second client did not get Guest1 name: %q", actual)
                }
+               return nil
        })
        if err != nil {
                t.Fatal(err)
@@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) {
 
        target := s.Addr().String()
 
-       err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
+       err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
        if err != nil {
                t.Error(err)
        }
@@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) {
        clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
        auth.Whitelist(clientpubkey, 0)
 
-       err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
+       err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
        if err == nil {
                t.Error("Failed to block unwhitelisted connection.")
        }
@@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) {
 
        go func() {
                // First client
-               err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
+               err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
                        // Make op
                        member, _ := host.Room.MemberById("foo")
                        if member == nil {
-                               t.Fatal("failed to load MemberById")
+                               return errors.New("failed to load MemberById")
                        }
                        host.Room.Ops.Add(member)
 
                        // Block until second client is here
                        connected <- struct{}{}
                        w.Write([]byte("/kick bar\r\n"))
+                       return nil
                })
                if err != nil {
+                       close(connected)
                        t.Fatal(err)
                }
        }()
 
        go func() {
                // Second client
-               err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
+               err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
                        <-connected
 
                        // Consume while we're connected. Should break when kicked.
-                       ioutil.ReadAll(r) // XXX?
+                       ioutil.ReadAll(r)
+                       return nil
                })
                if err != nil {
                        t.Fatal(err)
index 13d5dea97d55d17a85881257d43ce0ba23149059..47cbc5a4d8c1fa3b4e07de4c461064af95824053 100644 (file)
@@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig {
 }
 
 // ConnectShell makes a barebones SSH client session, used for testing.
-func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
+func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
        config := NewClientConfig(name)
        conn, err := ssh.Dial("tcp", host, config)
        if err != nil {
@@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
                return err
        }
 
-       /* FIXME: Do we want to request a PTY?
-       err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
-       if err != nil {
-               return err
-       }
+       /*
+               err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
+               if err != nil {
+                       return err
+               }
        */
 
        err = session.Shell()
@@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
                return err
        }
 
-       handler(out, in)
+       _, err = session.SendRequest("ping", true, nil)
+       if err != nil {
+               return err
+       }
 
-       return nil
+       return handler(out, in)
 }
index abbde70081a2a4ca9fe9875d5bb85e73b292c190..7c6f04fc9fb55215f377dc79e2cadb76ce4ea61f 100644 (file)
@@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) {
        host := s.Addr().String()
        name := "foo"
 
-       err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
+       err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
                // Consume if there is anything
                buf := new(bytes.Buffer)
                w.Write([]byte("hello\r\n"))
 
                buf.Reset()
                _, err := io.Copy(buf, r)
-               if err != nil {
-                       t.Error(err)
-               }
 
                expected := "> hello\r\necho: hello\r\n"
                actual := buf.String()
                if actual != expected {
+                       if err != nil {
+                               t.Error(err)
+                       }
                        t.Errorf("Got %q; expected %q", actual, expected)
                }
                s.Close()
+               return nil
        })
 
        if err != nil {