golang smux库源码分析

    科技2022-07-17  139

    1、简介

    一个socket连接复用的包https://github.com/xtaci/smux

    如图所示,多个channel输入通过smux合并在一个连接中,后端服务将连接中的channel分离出来进行处理。简介就不多写了。

    1.1 先放客户端和服务端代码, 了解使用方式

    package main import ( "bytes" "encoding/binary" "fmt" "github.com/rs/zerolog/log" "math/rand" "zero/smux" "net" "time" ) func init() { rand.Seed(time.Now().UnixNano()) } /** 一个生成随机数的tcp服务 客户端发送'R', 'A', 'N', 'D',服务返回一个随机数 */ func main() { listener, err := net.Listen("tcp", ":9000") if err != nil { panic(err) } log.Info().Msg("随机数服务启动,监听9000端口") defer listener.Close() for { conn, err := listener.Accept() if err != nil { fmt.Println(err.Error()) continue } go SessionHandler(conn) } } /** 处理会话 每个tcp连接生成一个会话session */ func SessionHandler(conn net.Conn) { session, err := smux.Server(conn, nil) if err != nil { panic(err) } log.Info().Msgf("收到客户端连接,创建新会话,对端地址:%s", session.RemoteAddr().String()) for !session.IsClosed() { stream, err := session.AcceptStream() if err != nil { fmt.Println(err.Error()) break } go StreamHandler(stream) } log.Info().Msgf("客户端连接断开,销毁会话,对端地址:%s", session.RemoteAddr().String()) } /** 流数据处理 */ func StreamHandler(stream *smux.Stream) { buffer := make([]byte, 1024) n, err := stream.Read(buffer) if err != nil { log.Error().Msgf("流id:%d,异常信息:%s", stream.ID(), err.Error()) stream.Close() return } cmd := buffer[:n] if bytes.Equal(cmd, []byte{'R', 'A', 'N', 'D'}) { rand := rand.Uint64() response := make([]byte, 8) binary.BigEndian.PutUint64(response, rand) stream.Write(response) log.Debug().Msgf("收到客户端数据,流id:%d,随机数:%d, 响应数据:%v", stream.ID(), rand, response) } else { log.Warn().Msgf("收到未知请求命令,流id:%d,请求命令:%v", stream.ID(), cmd) } } package main import ( "encoding/binary" "fmt" "github.com/rs/zerolog/log" "net" "net/http" "zero/smux" ) /** 随机数服务客户端连接 */ var randClient *smux.Session func init() { //连接后端随机数服务 conn, err := net.Dial("tcp", ":9000") if err != nil { log.Warn().Msg("随机数服务未启动") panic(err) } session, err := smux.Client(conn, nil) if err != nil { log.Error().Msg("打开会话失败") panic(err) } randClient = session } /** 一个api网关,对外提供api接口 调用随机数服务来获取随机数 */ func main() { defer randClient.Close() http.HandleFunc("/rand", RandHandler) http.ListenAndServe(":8080", nil) } /** 随机数接口 */ func RandHandler(w http.ResponseWriter, r *http.Request) { stream, err := randClient.OpenStream() if err != nil { w.WriteHeader(500) fmt.Fprint(w, err.Error()) } else { log.Debug().Msgf("收到请求,打开流成功,流id:%d", stream.ID()) defer stream.Close() stream.Write([]byte{'R', 'A', 'N', 'D'}) buffer := make([]byte, 1024) n, err := stream.Read(buffer) if err != nil { w.WriteHeader(500) fmt.Fprint(w, err.Error()) } else { response := buffer[:n] var rand = binary.BigEndian.Uint64(response) log.Debug().Msgf("收到服务端数据,流id:%d,随机数:%d, 响应数据:%v", stream.ID(), rand, response) fmt.Fprintf(w, "%d", rand) } } }

    2、源码分析。

    2.1 使用conn封装成session。

    // newSession 的时候开启三个协程进行数据交互 func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { ... // 客户端和服务端的区别就是这里, 为了防止冲突。 if client { s.nextStreamID = 1 } else { s.nextStreamID = 0 } go s.shaperLoop() go s.recvLoop() go s.sendLoop() if !config.KeepAliveDisabled { go s.keepalive() } return s } // 进行request的弹出与写入(使用堆),写入的话是在writeFrameInternal中写入,通过s.shaper通道传入这里, 再弹出到s.writes。 func (s *Session) shaperLoop() { var reqs shaperHeap var next writeRequest var chWrite chan writeRequest for { if len(reqs) > 0 { chWrite = s.writes next = heap.Pop(&reqs).(writeRequest) } else { chWrite = nil } select { case <-s.die: return case r := <-s.shaper: if chWrite != nil { // next is valid, reshape heap.Push(&reqs, next) } heap.Push(&reqs, r) case chWrite <- next: } } } //在s.write中拿到request(shaperLoop函数中弹入), 然后通过s.conn.write发送出去。 // 相当于前面定义头, 然后后面定义数据。再传输。 func (s *Session) sendLoop() { var buf []byte var n int var err error var vec [][]byte // vector for writeBuffers bw, ok := s.conn.(buffersWriter) if ok { buf = make([]byte, headerSize) vec = make([][]byte, 2) } else { buf = make([]byte, (1<<16)+headerSize) } for { select { case <-s.die: return case request := <-s.writes: // 这里是自定义规则。VERSION(1B) | CMD(1B) | LENGTH(2B) | STREAMID(4B) | DATA(LENGTH) buf[0] = request.frame.ver buf[1] = request.frame.cmd binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) if len(vec) > 0 { vec[0] = buf[:headerSize] vec[1] = request.frame.data n, err = bw.WriteBuffers(vec) } else { copy(buf[headerSize:], request.frame.data) // 最终使用tcp的conn进行传输。 n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) } ... } } } // recvLoop keeps on reading from underlying connection if tokens are available // 从conn 中读取头信息,然后使用sid 查看是否需要新创建流。 // cmdSYN命令创建打开流, s.chAccepts <- stream, 将stream流传入, 然后进行接收。 // cmdPSH命令是进行push流。将数据推入缓冲池进行读取。 func (s *Session) recvLoop() { var hdr rawHeader var updHdr updHeader for { for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { select { case <-s.bucketNotify: case <-s.die: return } } // 首先进行消息头读取。获取流id、消息长度之类的消息(sendLoop)中定义 if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { atomic.StoreInt32(&s.dataReady, 1) if hdr.Version() != byte(s.config.Version) { s.notifyProtoError(ErrInvalidProtocol) return } sid := hdr.StreamID() switch hdr.Cmd() { case cmdNOP: // 创建打开流, (OpenStream函数中进行创建流) case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[sid]; !ok { // 如果是创建新的流, 则将其传入s.chAccepts, 用于服务端获取流链接(AcceptStream函数) stream := newStream(sid, s.config.MaxFrameSize, s) s.streams[sid] = stream select { case s.chAccepts <- stream: case <-s.die: } } s.streamLock.Unlock() case cmdFIN: s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { stream.fin() stream.notifyReadEvent() } s.streamLock.Unlock() // 进行推流、将获取的数据丢入缓冲池 case cmdPSH: if hdr.Length() > 0 { newbuf := defaultAllocator.Get(int(hdr.Length())) if written, err := io.ReadFull(s.conn, newbuf); err == nil { s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { // 将获取的数据推入缓冲池 stream.pushBytes(newbuf) atomic.AddInt32(&s.bucket, -int32(written)) stream.notifyReadEvent() } s.streamLock.Unlock() } else { s.notifyReadError(err) return } } ... } } // pushBytes append buf to buffers // 推入缓冲池, 缓冲池就是一个切片。 func (s *Stream) pushBytes(buf []byte) (written int, err error) { s.bufferLock.Lock() s.buffers = append(s.buffers, buf) s.heads = append(s.heads, buf) s.bufferLock.Unlock() return }

    2.2 客户端/服务短使用session生成stream。在stream中进行通信。

    // 客户端 // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { ... s.nextStreamID += 2 sid := s.nextStreamID if sid == sid%2 { // stream-id overflows s.goAway = 1 s.nextStreamIDLock.Unlock() return nil, ErrGoAway } s.nextStreamIDLock.Unlock() // 新建流进行交互。 stream := newStream(sid, s.config.MaxFrameSize, s) // 这里发个空信息是为了什么? if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { return nil, err } s.streamLock.Lock() defer s.streamLock.Unlock() select { case <-s.chSocketReadError: return nil, s.socketReadError.Load().(error) case <-s.chSocketWriteError: return nil, s.socketWriteError.Load().(error) case <-s.die: return nil, io.ErrClosedPipe default: s.streams[sid] = stream return stream, nil } } // 服务端 func (s *Session) AcceptStream() (*Stream, error) { var deadline <-chan time.Time if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(time.Until(d)) defer timer.Stop() deadline = timer.C } select { // recvLoop函数中,如果新生成的流,则让服务端获取。 case stream := <-s.chAccepts: return stream, nil case <-deadline: return nil, ErrTimeout case <-s.chSocketReadError: return nil, s.socketReadError.Load().(error) case <-s.chProtoError: return nil, s.protoError.Load().(error) case <-s.die: return nil, io.ErrClosedPipe } }

    2.3 客户端/服务端进行读取和获取帧(数据)

    // 写入帧 // 生成帧(帧携带数据), 然后传出。 func (s *Stream) Write(b []byte) (n int, err error) { ... // frame split and transmit sent := 0 // new帧, frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id) bts := b for len(bts) > 0 { sz := len(bts) if sz > s.frameSize { sz = s.frameSize } frame.data = bts[:sz] bts = bts[sz:] // 写入帧。 n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten)) s.numWritten++ sent += n if err != nil { return sent, err } } return sent, nil } // internal writeFrame version to support deadline used in keepalive func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) { req := writeRequest{ prio: prio, frame: f, result: make(chan writeResult, 1), } // 将frame 用req进行包装传入s.shaper(shaperLoop函数), 再弹入s.writer进行conn.write select { case s.shaper <- req: case <-s.die: return 0, io.ErrClosedPipe case <-s.chSocketWriteError: return 0, s.socketWriteError.Load().(error) case <-deadline: return 0, ErrTimeout } select { case result := <-req.result: return result.n, result.err case <-s.die: return 0, io.ErrClosedPipe case <-s.chSocketWriteError: return 0, s.socketWriteError.Load().(error) case <-deadline: return 0, ErrTimeout } } // 获取帧(数据) // Read implements net.Conn func (s *Stream) Read(b []byte) (n int, err error) { for { // 循环进行数据读取。直至缓冲池为空。 n, err = s.tryRead(b) if err == ErrWouldBlock { if ew := s.waitRead(); ew != nil { return 0, ew } } else { return n, err } } } // tryRead is the nonblocking version of Read // 读取数据 func (s *Stream) tryRead(b []byte) (n int, err error) { if s.sess.config.Version == 2 { return s.tryReadv2(b) } if len(b) == 0 { return 0, nil } s.bufferLock.Lock() // 从缓冲池中读取数据,每读一帧则去掉。 if len(s.buffers) > 0 { n = copy(b, s.buffers[0]) s.buffers[0] = s.buffers[0][n:] if len(s.buffers[0]) == 0 { s.buffers[0] = nil s.buffers = s.buffers[1:] // full recycle defaultAllocator.Put(s.heads[0]) s.heads = s.heads[1:] } } ... }
    Processed: 0.010, SQL: 8