1、简介
smux是一个socket连接多路复用的Go包:https://github.com/xtaci/smux
如图所示,多个逻辑流(smux中称为stream)通过smux复用同一个底层连接,后端服务再把连接中的各个stream分离出来分别处理。简介就不多写了。
1.1 先放服务端和客户端代码,了解使用方式。服务端监听9000端口,收到"RAND"命令后返回一个随机数;客户端是一个监听8080端口的HTTP服务,每次请求/rand时在同一个smux会话上打开一个新的stream向服务端获取随机数。
package main
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/rs/zerolog/log"
"math/rand"
"zero/smux"
"net"
"time"
)
// init seeds the package-level math/rand generator from the wall clock so the
// service hands out a different sequence of numbers on every run.
//
// Since Go 1.20 the global generator is seeded randomly at startup and
// rand.Seed is deprecated; the explicit call only matters on older toolchains.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
func main() {
listener
, err
:= net
.Listen("tcp", ":9000")
if err
!= nil {
panic(err
)
}
log
.Info().Msg("随机数服务启动,监听9000端口")
defer listener
.Close()
for {
conn
, err
:= listener
.Accept()
if err
!= nil {
fmt
.Println(err
.Error())
continue
}
go SessionHandler(conn
)
}
}
func SessionHandler(conn net
.Conn
) {
session
, err
:= smux
.Server(conn
, nil)
if err
!= nil {
panic(err
)
}
log
.Info().Msgf("收到客户端连接,创建新会话,对端地址:%s", session
.RemoteAddr().String())
for !session
.IsClosed() {
stream
, err
:= session
.AcceptStream()
if err
!= nil {
fmt
.Println(err
.Error())
break
}
go StreamHandler(stream
)
}
log
.Info().Msgf("客户端连接断开,销毁会话,对端地址:%s", session
.RemoteAddr().String())
}
func StreamHandler(stream
*smux
.Stream
) {
buffer
:= make([]byte, 1024)
n
, err
:= stream
.Read(buffer
)
if err
!= nil {
log
.Error().Msgf("流id:%d,异常信息:%s", stream
.ID(), err
.Error())
stream
.Close()
return
}
cmd
:= buffer
[:n
]
if bytes
.Equal(cmd
, []byte{'R', 'A', 'N', 'D'}) {
rand
:= rand
.Uint64()
response
:= make([]byte, 8)
binary
.BigEndian
.PutUint64(response
, rand
)
stream
.Write(response
)
log
.Debug().Msgf("收到客户端数据,流id:%d,随机数:%d, 响应数据:%v", stream
.ID(), rand
, response
)
} else {
log
.Warn().Msgf("收到未知请求命令,流id:%d,请求命令:%v", stream
.ID(), cmd
)
}
}
package main
import (
"encoding/binary"
"fmt"
"github.com/rs/zerolog/log"
"net"
"net/http"
"zero/smux"
)
var randClient
*smux
.Session
func init() {
conn
, err
:= net
.Dial("tcp", ":9000")
if err
!= nil {
log
.Warn().Msg("随机数服务未启动")
panic(err
)
}
session
, err
:= smux
.Client(conn
, nil)
if err
!= nil {
log
.Error().Msg("打开会话失败")
panic(err
)
}
randClient
= session
}
func main() {
defer randClient
.Close()
http
.HandleFunc("/rand", RandHandler
)
http
.ListenAndServe(":8080", nil)
}
func RandHandler(w http
.ResponseWriter
, r
*http
.Request
) {
stream
, err
:= randClient
.OpenStream()
if err
!= nil {
w
.WriteHeader(500)
fmt
.Fprint(w
, err
.Error())
} else {
log
.Debug().Msgf("收到请求,打开流成功,流id:%d", stream
.ID())
defer stream
.Close()
stream
.Write([]byte{'R', 'A', 'N', 'D'})
buffer
:= make([]byte, 1024)
n
, err
:= stream
.Read(buffer
)
if err
!= nil {
w
.WriteHeader(500)
fmt
.Fprint(w
, err
.Error())
} else {
response
:= buffer
[:n
]
var rand
= binary
.BigEndian
.Uint64(response
)
log
.Debug().Msgf("收到服务端数据,流id:%d,随机数:%d, 响应数据:%v", stream
.ID(), rand
, response
)
fmt
.Fprintf(w
, "%d", rand
)
}
}
}
2、源码分析。
2.1 将conn封装成session。newSession根据client参数设置初始流ID,并启动shaperLoop、recvLoop、sendLoop以及(未禁用保活时的)keepalive等后台协程。
// newSession wraps conn in a Session and starts its background goroutines.
//
// client selects the stream-ID parity. nextStreamID starts at 1 on the client
// and at 0 on the server. OpenStream adds 2 before using the value, so
// client-opened streams get odd IDs and server-opened streams get even IDs,
// and the two sides never pick the same ID.
//
// (Excerpt: the elided part "..." builds s and its remaining fields.)
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
	...
	if client {
		s.nextStreamID = 1
	} else {
		s.nextStreamID = 0
	}

	// shaperLoop orders queued write requests, recvLoop reads and dispatches
	// incoming frames, and sendLoop writes frames to conn.
	go s.shaperLoop()
	go s.recvLoop()
	go s.sendLoop()
	// keepalive's body is not shown here; presumably it sends/checks periodic
	// liveness frames. Verify against the smux source.
	if !config.KeepAliveDisabled {
		go s.keepalive()
	}
	return s
}
// shaperLoop sits between writers and sendLoop. It collects writeRequests sent
// on s.shaper (by writeFrameInternal) into a shaperHeap and forwards them one
// at a time on s.writes, until the session dies. The order is whatever
// shaperHeap defines (not shown); Stream.Write passes the stream's frame
// counter as the request priority.
func (s *Session) shaperLoop() {
	var reqs shaperHeap
	var next writeRequest
	var chWrite chan writeRequest

	for {
		// Arm the send case only when there is something to send. A nil
		// channel disables its select case.
		if len(reqs) > 0 {
			chWrite = s.writes
			next = heap.Pop(&reqs).(writeRequest)
		} else {
			chWrite = nil
		}

		select {
		case <-s.die:
			return
		case r := <-s.shaper:
			// A new request arrived before `next` could be sent. Push `next`
			// back so it is re-ordered together with r on the next iteration.
			if chWrite != nil {
				heap.Push(&reqs, next)
			}
			heap.Push(&reqs, r)
		case chWrite <- next:
		}
	}
}
// sendLoop takes write requests from s.writes and writes each frame to s.conn
// until the session dies.
//
// Frame header layout as written below:
//   byte 0     version
//   byte 1     command
//   bytes 2-3  payload length, uint16 little-endian
//   bytes 4-7  stream id, uint32 little-endian
// headerSize is presumably 8 (not shown).
//
// (Excerpt: the elided part "..." handles n/err, presumably reporting them
// back on request.result. Verify against the smux source.)
func (s *Session) sendLoop() {
	var buf []byte
	var n int
	var err error
	// vec is allocated only when conn supports WriteBuffers, so len(vec) > 0
	// below selects the vectored-write path.
	var vec [][]byte

	bw, ok := s.conn.(buffersWriter)
	if ok {
		// Header-only buffer: the payload is passed as a second vector.
		buf = make([]byte, headerSize)
		vec = make([][]byte, 2)
	} else {
		// Header plus the largest payload a uint16 length field can describe,
		// so header and payload go out in a single Write.
		buf = make([]byte, (1<<16)+headerSize)
	}

	for {
		select {
		case <-s.die:
			return
		case request := <-s.writes:
			buf[0] = request.frame.ver
			buf[1] = request.frame.cmd
			// NOTE(review): uint16 truncates payloads > 65535 bytes. This
			// presumably relies on the frame size being capped elsewhere (MaxFrameSize).
			binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
			binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)

			if len(vec) > 0 {
				// Vectored write: header and payload without copying the payload.
				vec[0] = buf[:headerSize]
				vec[1] = request.frame.data
				n, err = bw.WriteBuffers(vec)
			} else {
				// Copy the payload behind the header and write both at once.
				copy(buf[headerSize:], request.frame.data)
				n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
			}
			...
		}
	}
}
// recvLoop reads frames from s.conn and dispatches them to streams. It stops
// when the connection fails, a protocol error is found, or the session dies.
//
// Flow control: before reading the next header it waits while s.bucket <= 0.
// Every cmdPSH payload delivered to a stream is charged against s.bucket. The
// bucket is presumably refilled (and bucketNotify signalled) as streams
// consume data; that code is not in this excerpt.
//
// (Excerpt: other commands, the handling of a failed header read and the end
// of the function are omitted. The braces at the end are unbalanced.)
func (s *Session) recvLoop() {
	var hdr rawHeader
	// updHdr is not used in the visible part; presumably it is used by a
	// command handled in the elided section.
	var updHdr updHeader

	for {
		// Block while the receive budget is exhausted, unless the session is closing.
		for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
			select {
			case <-s.bucketNotify:
			case <-s.die:
				return
			}
		}

		if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
			// Presumably read by keepalive to see that the peer is alive. Not shown.
			atomic.StoreInt32(&s.dataReady, 1)
			if hdr.Version() != byte(s.config.Version) {
				s.notifyProtoError(ErrInvalidProtocol)
				return
			}
			sid := hdr.StreamID()
			switch hdr.Cmd() {
			case cmdNOP:
				// No-op frame: nothing to do.
			case cmdSYN:
				// The peer opened a stream. Create it (a SYN for an existing
				// id is ignored) and queue it for AcceptStream.
				// NOTE(review): the send on chAccepts happens while streamLock
				// is held. If chAccepts is full, this loop (and anything else
				// needing streamLock) waits until AcceptStream drains it or
				// the session dies.
				s.streamLock.Lock()
				if _, ok := s.streams[sid]; !ok {
					stream := newStream(sid, s.config.MaxFrameSize, s)
					s.streams[sid] = stream
					select {
					case s.chAccepts <- stream:
					case <-s.die:
					}
				}
				s.streamLock.Unlock()
			case cmdFIN:
				// Mark the stream finished and wake any reader blocked on it.
				s.streamLock.Lock()
				if stream, ok := s.streams[sid]; ok {
					stream.fin()
					stream.notifyReadEvent()
				}
				s.streamLock.Unlock()
			case cmdPSH:
				// Data frame: read hdr.Length() payload bytes into a buffer
				// taken from defaultAllocator.
				if hdr.Length() > 0 {
					newbuf := defaultAllocator.Get(int(hdr.Length()))
					if written, err := io.ReadFull(s.conn, newbuf); err == nil {
						s.streamLock.Lock()
						// If no stream has this id, the payload is dropped. In
						// this excerpt it is not returned to defaultAllocator,
						// and the bucket is not charged.
						if stream, ok := s.streams[sid]; ok {
							stream.pushBytes(newbuf)
							atomic.AddInt32(&s.bucket, -int32(written))
							stream.notifyReadEvent()
						}
						s.streamLock.Unlock()
					} else {
						s.notifyReadError(err)
						return
					}
				}
				...
			}
		}
		// (excerpt ends here; the rest of recvLoop is omitted)
func (s
*Stream
) pushBytes(buf
[]byte) (written
int, err
error) {
s
.bufferLock
.Lock()
s
.buffers
= append(s
.buffers
, buf
)
s
.heads
= append(s
.heads
, buf
)
s
.bufferLock
.Unlock()
return
}
2.2 客户端/服务端通过session创建stream(OpenStream)或接收对端打开的stream(AcceptStream),之后在stream中进行通信。
// OpenStream opens a new outgoing stream on the session.
//
// It reserves the next stream ID and sends a cmdSYN frame for it. It then
// registers the stream in s.streams, unless the session has already failed
// or closed. IDs step by 2, which keeps the odd/even parity chosen at session
// creation.
//
// Errors:
//   - ErrGoAway when the ID counter wraps around.
//   - The writeFrame error if the SYN cannot be sent.
//   - A stored socket read/write error.
//   - io.ErrClosedPipe if the session is closed.
//
// (Excerpt: the elided part "..." acquires s.nextStreamIDLock, which is
// released below, among other things.)
func (s *Session) OpenStream() (*Stream, error) {
	...
	s.nextStreamID += 2
	sid := s.nextStreamID
	// sid equals sid%2 only for 0 and 1. A counter that started at 0/1 and
	// steps by 2 can reach those values again only by wrapping around. Stop
	// opening streams (goAway) rather than reuse IDs.
	if sid == sid%2 {
		s.goAway = 1
		s.nextStreamIDLock.Unlock()
		return nil, ErrGoAway
	}
	s.nextStreamIDLock.Unlock()

	stream := newStream(sid, s.config.MaxFrameSize, s)

	if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
		return nil, err
	}

	// NOTE(review): the SYN is sent before the stream is added to s.streams,
	// and recvLoop drops cmdPSH data for unknown stream IDs. A peer reply that
	// arrives inside this window might be lost. Confirm whether this can
	// happen in practice.
	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	// Non-blocking check: register and return the stream only if the session
	// has not failed or closed in the meantime.
	select {
	case <-s.chSocketReadError:
		return nil, s.socketReadError.Load().(error)
	case <-s.chSocketWriteError:
		return nil, s.socketWriteError.Load().(error)
	case <-s.die:
		return nil, io.ErrClosedPipe
	default:
		s.streams[sid] = stream
		return stream, nil
	}
}
func (s
*Session
) AcceptStream() (*Stream
, error) {
var deadline
<-chan time
.Time
if d
, ok
:= s
.deadline
.Load().(time
.Time
); ok
&& !d
.IsZero() {
timer
:= time
.NewTimer(time
.Until(d
))
defer timer
.Stop()
deadline
= timer
.C
}
select {
case stream
:= <-s
.chAccepts
:
return stream
, nil
case <-deadline
:
return nil, ErrTimeout
case <-s
.chSocketReadError
:
return nil, s
.socketReadError
.Load().(error)
case <-s
.chProtoError
:
return nil, s
.protoError
.Load().(error)
case <-s
.die
:
return nil, io
.ErrClosedPipe
}
}
2.3 客户端/服务端通过stream写入和读取帧(数据)
// Write splits b into chunks of at most s.frameSize bytes and sends each chunk
// as a cmdPSH frame through the session's write shaper. It waits for each
// frame's write result before sending the next.
//
// It returns the sum of the byte counts reported by writeFrameInternal for the
// frames sent so far, and the first error encountered.
// NOTE(review): whether those counts include the frame header depends on the
// elided part of sendLoop. Verify.
//
// (Excerpt: `deadline` is declared in the elided part "...".)
func (s *Stream) Write(b []byte) (n int, err error) {
	...
	sent := 0
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
	bts := b
	for len(bts) > 0 {
		sz := len(bts)
		if sz > s.frameSize {
			sz = s.frameSize
		}
		// frame.data aliases the caller's slice; no copy is made here.
		frame.data = bts[:sz]
		bts = bts[sz:]
		// numWritten (this stream's frame counter) is used as the shaper
		// priority. `n, err :=` declares new locals that shadow the named
		// results; the results are returned explicitly below.
		n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten))
		s.numWritten++
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return sent, nil
}
func (s
*Session
) writeFrameInternal(f Frame
, deadline
<-chan time
.Time
, prio
uint64) (int, error) {
req
:= writeRequest
{
prio
: prio
,
frame
: f
,
result
: make(chan writeResult
, 1),
}
select {
case s
.shaper
<- req
:
case <-s
.die
:
return 0, io
.ErrClosedPipe
case <-s
.chSocketWriteError
:
return 0, s
.socketWriteError
.Load().(error)
case <-deadline
:
return 0, ErrTimeout
}
select {
case result
:= <-req
.result
:
return result
.n
, result
.err
case <-s
.die
:
return 0, io
.ErrClosedPipe
case <-s
.chSocketWriteError
:
return 0, s
.socketWriteError
.Load().(error)
case <-deadline
:
return 0, ErrTimeout
}
}
func (s
*Stream
) Read(b
[]byte) (n
int, err
error) {
for {
n
, err
= s
.tryRead(b
)
if err
== ErrWouldBlock
{
if ew
:= s
.waitRead(); ew
!= nil {
return 0, ew
}
} else {
return n
, err
}
}
}
// tryRead copies already-buffered stream data into b without waiting.
//
// Protocol version 2 is delegated to tryReadv2. An empty b returns (0, nil)
// immediately. In the visible part, at most one queued chunk (s.buffers[0])
// is consumed per call, so n can be smaller than len(b) even when more data
// is queued.
//
// (Excerpt: what happens after the copy, including the case where nothing
// was copied, is in the elided part "...". Read retries on ErrWouldBlock, so
// presumably that is returned when no data is available.)
func (s *Stream) tryRead(b []byte) (n int, err error) {
	if s.sess.config.Version == 2 {
		return s.tryReadv2(b)
	}

	if len(b) == 0 {
		return 0, nil
	}

	s.bufferLock.Lock()
	if len(s.buffers) > 0 {
		n = copy(b, s.buffers[0])
		s.buffers[0] = s.buffers[0][n:]
		if len(s.buffers[0]) == 0 {
			// The chunk is drained. Drop it from the queue (nil-ing the slot
			// so the backing array does not keep it reachable). Return the
			// original allocation, kept intact in s.heads[0] because
			// s.buffers[0] was resliced, to defaultAllocator.
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			defaultAllocator.Put(s.heads[0])
			s.heads = s.heads[1:]
		}
	}
	...
}