This post describes a module that performs common socket operations asynchronously and reports their results through events, in a reactive programming style.

We first define an extension of the F# Event module.

module Event =
  // Starts a background workflow that waits for the next occurrence of
  // [evt] and then calls [f] with the event argument. The call returns
  // immediately; [f] runs when the event fires.
  let listenOnce f evt =
    let waitThenHandle =
      async {
        let! args = Async.AwaitEvent evt
        f args
      }
    Async.Start waitThenHandle

We then define the signature of the Socket module.

module Socket =
  // Alias for the socket class; every function below takes or returns a [t].
  type t = Socket

  // The operation that was in progress when a Failure event was raised.
  type Error =
    | Connect
    | Accept
    | Send
    | Receive

  //
  // Events
  //

  //Generated when a [connect] operation completes successfully; carries
  //the socket that was connected.
  val Connected : IEvent<t>

  //Generated when the disconnection is done purposely by the programmer
  //(through [disconnect]); carries the socket and the reason given.
  val Disconnected : IEvent<t * string>

  //Generated when an incoming connection is detected (an [accept]
  //operation completes); carries the newly accepted socket.
  val Incoming : IEvent<t>

  //Generated when a [send] operation completes; carries the socket and the
  //number of bytes reported as sent.
  val SentData : IEvent<t * int>

  //Generated when a [receive] operation completes (with the number of
  //bytes read by that attempt), or when [receiveUntil] has read everything
  //it was asked for (with the total requested count).
  val ReceivedData : IEvent<t * int>

  //Generated when an operation fails; carries the socket, the kind of
  //operation, a socket error code and the exception.
  //For connection and incoming connection errors, the SocketError is
  //set to SocketError.SocketError, which is the default undefined error.
  val Failure : IEvent<t * Error * SocketError * Exception>

  //
  // Constructor shorthand
  //

  //Creates an unconnected IPv4 / stream / TCP socket.
  val create : unit -> t

  //Creates a socket bound to [ip] (any local address when [None]) and
  //[port], listening with a backlog of [nToListen] pending connections.
  val createListeningExtended : ip:option<string> -> port:int -> nToListen:int -> t

  //Same as [createListeningExtended] on any local address with a backlog of 5.
  val createListening : port:int -> t

  //
  // "Properties"
  //

  //Address of the remote end point, as a string.
  val ip : socket:t -> string

  //Port of the remote end point.
  val port : socket:t -> int

  //False when the socket reference is null.
  val isAlive : socket:t -> bool

  //Polls the socket to tell whether the connection still appears to be up.
  val isConnected : socket:t -> bool

  //
  // Operations running asynchronously (and trigger events upon completion)
  //

  //Triggers Connected on success, Failure (Error.Connect) otherwise.
  val connect : ip:string -> port:int -> socket:t -> unit

  //Shuts the connection down and triggers Disconnected with [reason].
  //The socket is closed unless [keepForLater] is true. Does nothing when
  //the socket is not connected.
  val disconnect : socket:t -> reason:string -> keepForLater:bool -> unit

  //Triggers SentData on success, Failure (Error.Send) otherwise.
  val send : data:byte[] -> offset:int -> count:int -> socket:t -> unit

  //Triggers ReceivedData on success, Failure (Error.Receive) otherwise.
  val receive : buffer:byte[] -> offset:int -> count:int -> socket:t -> unit

  //If we want to receive more bytes than the buffer size, or if the data
  //is sent progressively, we shall use this function which waits until
  //all the data is received, contrary to the [receive] function which is
  //a one-shot attempt.
  val receiveUntil : buffer:byte[] -> offset:int -> count:int -> socket:t -> unit

  //Starts accepting one incoming connection on a listening socket;
  //triggers Incoming on success, Failure (Error.Accept) otherwise.
  val accept : socket:t -> unit

…and the implementation.

// Implementation of the Socket module. Results of the asynchronous
// operations are reported through the module-level events defined below.
module Socket =

  // Alias for the Socket class, so that the module follows the usual
  // "main type is called t" convention.
  type t = Socket  

  // Kind of operation that failed; carried by the Failure event so that a
  // listener can tell which operation went wrong.
  type Error =
    | Connect
    | Accept
    | Send
    | Receive

  //
  //  Events
  //

  // Each lower-case value is the trigger side of an event; the capitalised
  // value that follows it is the published view that listeners subscribe to.

  // Triggered when a connect operation succeeds.
  let connected = Event<t>()
  let Connected = connected.Publish

  // Triggered by disconnect, together with the reason given by the caller.
  let disconnected = Event<t * string>()
  let Disconnected = disconnected.Publish

  // Triggered when an accept operation succeeds, with the accepted socket.
  let incoming = Event<t>()
  let Incoming = incoming.Publish

  // Triggered when a send operation completes, with the byte count.
  let sentData = Event<t * int>()
  let SentData = sentData.Publish

  // Triggered when a receive operation completes, with the byte count.
  let receivedData = Event<t * int>()
  let ReceivedData = receivedData.Publish

  // Triggered when any operation fails: socket, operation kind, socket
  // error code and exception.
  let failure = Event<t * Error * SocketError * Exception>()
  let Failure = failure.Publish

  //
  // Constructor shorthand
  //

  // Builds a fresh, unconnected IPv4 stream socket that uses TCP.
  let create() =
    new Socket(
      addressFamily = AddressFamily.InterNetwork,
      socketType = SocketType.Stream,
      protocolType = ProtocolType.Tcp
    )

  // Creates a socket bound to [ip]:[port] and puts it in the listening
  // state with a backlog of [n] pending connections.
  //   ip   - textual address to bind to; None binds to IPAddress.Any.
  //   port - local port to bind to.
  //   n    - backlog passed to Listen.
  // Any exception raised while parsing the address, binding or listening is
  // propagated to the caller.
  let createListeningExtended (ip: string option) port n =
    // Parse first, so that an invalid address fails before a socket exists.
    let address =
      match ip with
      | Some text -> IPAddress.Parse text
      | None -> IPAddress.Any
    let s = create()
    try
      s.Bind(new IPEndPoint(address, port))
      s.Listen(n)
      s
    with _ ->
      // Do not leak the socket when it could not be bound or put in the
      // listening state; the original exception is rethrown unchanged.
      s.Close()
      reraise()

  // Listening socket on any local address for [port], with a backlog of 5.
  let createListening port =
    let defaultBacklog = 5
    createListeningExtended None port defaultBacklog

  //
  // "Properties"
  //

  // Textual form of the address of the socket's remote end point.
  // The remote end point is cast to IPEndPoint.
  let ip (socket:t) =
    (socket.RemoteEndPoint :?> IPEndPoint).Address.ToString()

  // Port number of the socket's remote end point.
  // The remote end point is cast to IPEndPoint.
  let port (socket:t) =
    (socket.RemoteEndPoint :?> IPEndPoint).Port

  // True when the socket reference is not null. This says nothing about
  // whether the socket is connected.
  let isAlive (socket:t) =
    match box socket with
    | null -> false
    | _ -> true

  // Tells whether [socket] still appears to be connected.
  // A socket that polls as readable while no data is available is
  // considered disconnected. Returns false for a null socket, and when
  // polling raises a SocketException or the socket is already disposed.
  // Any other exception propagates unchanged, keeping its type and stack.
  let isConnected socket =
    if isAlive socket then
      try
        let readable = socket.Poll(1, SelectMode.SelectRead)
        not (readable && socket.Available = 0)
      with
        | :? SocketException -> false
        // A closed socket cannot be connected; report that instead of raising.
        | :? ObjectDisposedException -> false
    else false

  //
  // Connection
  //
  // Wraps the BeginConnect/EndConnect pair of [socket] as an async workflow
  // that connects to [ip]:[port].
  let asyncConnect (ip:string) port (socket:Socket) =
    Async.FromBeginEnd(
      (fun (callback, state) -> socket.BeginConnect(ip, port, callback, state)),
      (fun asyncResult -> socket.EndConnect(asyncResult))
    )

  // Success continuation of connect: announces [socket] through Connected.
  let onConnectCompleted socket () =
    socket |> connected.Trigger

  // Failure continuation of connect: reports [exn] through Failure, tagged
  // as a Connect error with the generic SocketError.SocketError code.
  let onConnectFailed socket (exn:Exception) =
    let details = (socket, Error.Connect, SocketError.SocketError, exn)
    failure.Trigger details

  // Starts connecting [socket] to [ip]:[port] and returns immediately.
  // Connected is triggered on success, Failure on error; cancellation is
  // ignored.
  let connect ip port socket =
    let work = asyncConnect ip port socket
    let onSuccess = onConnectCompleted socket
    let onError = onConnectFailed socket
    Async.StartWithContinuations(work, onSuccess, onError, ignore)

  //
  // Disconnection
  //

  // Purposely disconnects [socket] and triggers Disconnected with [reason].
  //   keepForLater - when true the socket is left open so that it may be
  //                  reused; when false it is closed.
  // Does nothing (and triggers nothing) when the socket is not connected.
  let disconnect (socket:t) (reason:string) keepForLater =
    if isConnected socket then
      try
        socket.Shutdown(SocketShutdown.Both)
        // Only ask for a reusable socket when the caller wants to keep it:
        // requesting reuse doesn't work on all platforms, and is pointless
        // when the socket is closed right below.
        socket.Disconnect(keepForLater)
      with _ -> () // best effort: a failed shutdown must not prevent the close
      if not keepForLater then socket.Close()
      disconnected.Trigger(socket, reason)

  //
  // Sending data
  //

  // Wraps the BeginSend/EndSend pair of [socket] as an async workflow that
  // yields the value returned by EndSend. [count] bytes of [data] starting
  // at [offset] are submitted; [err] is handed to EndSend as its error-code
  // output.
  let asyncSend data offset count err (socket:Socket) =
    let flags = SocketFlags.None
    Async.FromBeginEnd(
      (fun (callback, state) -> socket.BeginSend(data, offset, count, flags, callback, state)),
      (fun asyncResult -> socket.EndSend(asyncResult, err))
    )

  // Success continuation of send: publishes the byte count [n] through SentData.
  let onSendCompleted socket n =
    (socket, n) |> sentData.Trigger

  // Failure continuation of send: reports [exn] and the socket error code
  // [err] through Failure, tagged as a Send error.
  let onSendFailed socket err (exn:Exception) =
    let details = (socket, Error.Send, err, exn)
    failure.Trigger details

  // Starts sending [count] bytes of [data] from [offset] on [socket] and
  // returns immediately. SentData is triggered on success, Failure
  // (Error.Send) on error; cancellation is ignored.
  let send data offset count socket =
    let err = ref <| Unchecked.defaultof<SocketError>
    // EndSend receives [err] as its error-code output. If it completes
    // without raising but left a non-success code in the cell, report a
    // failure instead of a successful send.
    // NOTE(review): whether EndSend raises or only sets the code depends on
    // the runtime; both paths are handled here.
    let onCompleted n =
      match !err with
      | SocketError.Success -> onSendCompleted socket n
      | code -> onSendFailed socket code (new SocketException(int code) :> Exception)
    // The cell must be read when the failure happens, not when the
    // operation is started (it would always hold the default value then).
    let onFailed exn = onSendFailed socket !err exn
    Async.StartWithContinuations(
      asyncSend data offset count err socket,
      onCompleted,
      onFailed,
      ignore
    )

  //
  // Receiving data
  //
  // Wraps the BeginReceive/EndReceive pair of [socket] as an async workflow
  // that yields the value returned by EndReceive. Up to [count] bytes are
  // requested into [buffer] starting at [offset]; [err] is handed to
  // EndReceive as its error-code output.
  let asyncReceive buffer offset count err (socket:Socket) =
    let flags = SocketFlags.None
    Async.FromBeginEnd(
      (fun (callback, state) -> socket.BeginReceive(buffer, offset, count, flags, callback, state)),
      (fun asyncResult -> socket.EndReceive(asyncResult, err))
    )

  // Success continuation of receive: publishes the byte count [n] through
  // ReceivedData.
  let onReceiveCompleted socket n =
    (socket, n) |> receivedData.Trigger

  // Failure continuation of the receive operations: reports [exn] and the
  // socket error code [err] through Failure, tagged as a Receive error.
  let onReceiveFailed socket err (exn:Exception) =
    let details = (socket, Error.Receive, err, exn)
    failure.Trigger details

  // Starts a one-shot receive of up to [count] bytes into [buffer] at
  // [offset] and returns immediately. ReceivedData is triggered on success
  // with the number of bytes read by this attempt, Failure (Error.Receive)
  // on error; cancellation is ignored.
  let receive buffer offset count socket =
    let err = ref <| Unchecked.defaultof<SocketError>
    // EndReceive receives [err] as its error-code output. If it completes
    // without raising but left a non-success code in the cell, report a
    // failure instead of a successful read.
    // NOTE(review): whether EndReceive raises or only sets the code depends
    // on the runtime; both paths are handled here.
    let onCompleted n =
      match !err with
      | SocketError.Success -> onReceiveCompleted socket n
      | code -> onReceiveFailed socket code (new SocketException(int code) :> Exception)
    // The cell must be read when the failure happens, not when the
    // operation is started (it would always hold the default value then).
    let onFailed exn = onReceiveFailed socket !err exn
    Async.StartWithContinuations(
      asyncReceive buffer offset count err socket,
      onCompleted,
      onFailed,
      ignore
    )

  // Continuation run after each partial read of a receiveUntil operation.
  //   offset         - position in [buffer] where the read just made started.
  //   countRemaining - bytes that were still missing before that read.
  //   countTotal     - bytes originally requested; reported in ReceivedData.
  //   n              - bytes obtained by the read just made.
  // Triggers ReceivedData once nothing is missing, otherwise issues another
  // read for the rest.
  let rec onReceiveUntilCompleted buffer offset countRemaining countTotal socket n =
    let nToReceive = countRemaining - n
    if nToReceive <= 0 then
      receivedData.Trigger(socket, countTotal)
    elif n <= 0 then
      // A read that yields no bytes while some are still expected means that
      // no more data will come (the peer closed the connection). Reading
      // again would loop forever, so report a failure instead.
      // ConnectionReset is used here to signal the premature close.
      let code = SocketError.ConnectionReset
      onReceiveFailed socket code (new SocketException(int code) :> Exception)
    else
      _receiveUntil buffer (offset+n) nToReceive countTotal socket

  // Issues one asynchronous read of at most [countRemaining] bytes into
  // [buffer] at [offset]; onReceiveUntilCompleted decides what happens next.
  // Failure (Error.Receive) is triggered on error; cancellation is ignored.
  and _receiveUntil buffer offset countRemaining countTotal socket =
    let err = ref <| Unchecked.defaultof<SocketError>
    // A non-success code left in the cell by EndReceive is treated as a
    // failure even when no exception was raised.
    let onCompleted n =
      match !err with
      | SocketError.Success ->
        onReceiveUntilCompleted buffer offset countRemaining countTotal socket n
      | code -> onReceiveFailed socket code (new SocketException(int code) :> Exception)
    // Read the cell when the failure happens, not when the read is started.
    let onFailed exn = onReceiveFailed socket !err exn
    Async.StartWithContinuations(
      asyncReceive buffer offset countRemaining err socket,
      onCompleted,
      onFailed,
      ignore
    )

  // Starts receiving exactly [count] bytes into [buffer] at [offset],
  // issuing as many reads as needed, and returns immediately.
  let receiveUntil buffer offset count socket =
    // Nothing has been read yet, so the remaining count equals the total.
    let total = count
    _receiveUntil buffer offset total total socket

  //
  // Accepting an incoming socket
  //
  // Wraps the BeginAccept/EndAccept pair of [socket] as an async workflow
  // that yields the value returned by EndAccept.
  let asyncAccept (socket:Socket) =
    Async.FromBeginEnd(
      (fun (callback, state) -> socket.BeginAccept(callback, state)),
      (fun asyncResult -> socket.EndAccept(asyncResult))
    )

  // Success continuation of accept: publishes the given socket through
  // Incoming.
  let onAcceptCompleted socket =
    socket |> incoming.Trigger

  // Failure continuation of accept: reports [exn] through Failure, tagged
  // as an Accept error with the generic SocketError.SocketError code.
  let onAcceptFailed socket (exn:Exception) =
    let details = (socket, Error.Accept, SocketError.SocketError, exn)
    failure.Trigger details

  // Starts accepting one connection on [socket] and returns immediately.
  // Incoming is triggered with the result of the accept on success;
  // Failure is triggered with [socket] itself on error. Cancellation is
  // ignored.
  let accept socket =
    let work = asyncAccept socket
    let onError = onAcceptFailed socket
    Async.StartWithContinuations(work, onAcceptCompleted, onError, ignore)

Last, we give an example.

// Example: a listening socket and two clients on the local machine, with
// every step reported on the console through the module's events.
// NOTE(review): connect, accept, send and receive only start their
// operation and return immediately, so the statements below presumably
// have to be run one at a time (e.g. in F# Interactive) for [income] to be
// set before it is used - confirm.
let PORT = 65000

let serverSocket = Socket.createListening PORT

// First socket accepted by the server; null until the Incoming handler has run.
let income = ref null

Socket.Incoming |> Observable.add(fun s ->
  printfn "socket accepted"
  //ignore incoming sockets, except for the first one
  if box !income = null then income := s else printfn "forget it, we have one already"
  //accept more sockets...
  Socket.accept serverSocket
)
Socket.Connected |> Observable.add (fun s ->
  printfn "socket connected [%s:%d]" (Socket.ip s) (Socket.port s)
)
// NOTE(review): when disconnect is called with keepForLater = false, the
// socket is closed before this handler runs; reading its remote end point
// here may then fail - verify.
Socket.Disconnected |> Observable.add (fun (s, reason) ->
  printfn "socket purposely disconnected [%s:%d] because %s" (Socket.ip s) (Socket.port s) reason
)
Socket.SentData|> Observable.add (fun (s, n) ->
  printfn "socket [%s:%d] sent %d bytes" (Socket.ip s) (Socket.port s) n
)
Socket.ReceivedData |> Observable.add (fun (s, n) ->
  printfn "socket [%s:%d] received %d bytes" (Socket.ip s) (Socket.port s) n
)
// Socket.Failure has no subscriber in this example, so failed operations
// are not reported on the console.

// Start accepting; the Incoming handler above re-arms the accept each time.
Socket.accept serverSocket

// First client: the one whose server-side counterpart ends up in [income].
let socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)
socket |> Socket.connect "127.0.0.1" PORT

// Second client: accepted by the server but otherwise ignored ("forget it").
let socket2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)
socket2 |> Socket.connect "127.0.0.1" PORT

// 100 bytes are sent at a time; the client waits for 500 bytes in total.
let buffer = Array.create 100 0uy
let buffer2 = Array.create 500 1uy

!income |> Socket.send buffer 0 buffer.Length
//nothing should happen on the receiving side yet: ReceivedData is triggered
//only once all the requested data has been read
socket |> Socket.receiveUntil buffer2 0 buffer2.Length
!income |> Socket.send buffer 0 buffer.Length
!income |> Socket.send buffer 0 buffer.Length
!income |> Socket.send buffer 0 buffer.Length
!income |> Socket.send buffer 0 buffer.Length //the ReceivedData is emitted here

// Clean up: purposely disconnect (and close) the first client, then close
// the second client and the listening socket.
do
  Socket.disconnect socket "end of the example" false
  socket2.Close()
  serverSocket.Close()

Comments are closed.