//===-------------- QueueChannel.h - RPC Queue Channel ------------*-c++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H

#include "llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <system_error>
#include <utility>

namespace llvm {

/// Common base for the error types declared in this header.
///
/// It adds nothing beyond the static ID member; QueueChannelClosedError derives
/// from it.
class QueueChannelError : public ErrorInfo<QueueChannelError> {
public:
  // Only declared here; the definition is not in this header.
  // NOTE(review): presumably the per-class tag ErrorInfo uses to identify the
  // error type -- confirm against llvm/Support/Error.h.
  static char ID;
};

/// Error produced by the read/write hooks that QueueChannel::close() installs
/// on its queues.
class QueueChannelClosedError
    : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
public:
  static char ID;

  /// Writes the fixed, human-readable description of this error to \p OS.
  void log(raw_ostream &OS) const override { OS << "Queue closed"; }

  /// This error is reported as having no std::error_code equivalent.
  std::error_code convertToErrorCode() const override {
    return inconvertibleErrorCode();
  }
};

class Queue : public std::queue<char> {
public:
  using ErrorInjector = std::function<Error()>;

  Queue()
    : ReadError([]() { return Error::success(); }),
      WriteError([]() { return Error::success(); }) {}

  Queue(const Queue&) = delete;
  Queue& operator=(const Queue&) = delete;
  Queue(Queue&&) = delete;
  Queue& operator=(Queue&&) = delete;

  std::mutex &getMutex() { return M; }
  std::condition_variable &getCondVar() { return CV; }
  Error checkReadError() { return ReadError(); }
  Error checkWriteError() { return WriteError(); }
  void setReadError(ErrorInjector NewReadError) {
    {
      std::lock_guard<std::mutex> Lock(M);
      ReadError = std::move(NewReadError);
    }
    CV.notify_one();
  }
  void setWriteError(ErrorInjector NewWriteError) {
    std::lock_guard<std::mutex> Lock(M);
    WriteError = std::move(NewWriteError);
  }
private:
  std::mutex M;
  std::condition_variable CV;
  std::function<Error()> ReadError, WriteError;
};

/// RawByteChannel implementation that moves bytes through two shared in-memory
/// Queues: readBytes() consumes from the inbound queue and appendBytes()
/// produces into the outbound queue.
///
/// The public counters record how many bytes and messages have passed through
/// the channel and how many times send() has been called.
class QueueChannel : public orc::rpc::RawByteChannel {
public:
  /// \param InQueue  queue this channel reads from.
  /// \param OutQueue queue this channel writes to.
  QueueChannel(std::shared_ptr<Queue> InQueue,
               std::shared_ptr<Queue> OutQueue)
      // The parameters are already our own copies; move them into the members
      // instead of bumping the reference counts a second time.
      : InQueue(std::move(InQueue)), OutQueue(std::move(OutQueue)) {}

  QueueChannel(const QueueChannel &) = delete;
  QueueChannel &operator=(const QueueChannel &) = delete;
  QueueChannel(QueueChannel &&) = delete;
  QueueChannel &operator=(QueueChannel &&) = delete;

  /// Counts a new in-flight outgoing message, then forwards to the
  /// RawByteChannel implementation.
  template <typename FunctionIdT, typename SequenceIdT>
  Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
    ++InFlightOutgoingMessages;
    return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo);
  }

  /// Moves one outgoing message from "in flight" to "completed", then forwards
  /// to the RawByteChannel implementation.
  Error endSendMessage() {
    --InFlightOutgoingMessages;
    ++CompletedOutgoingMessages;
    return orc::rpc::RawByteChannel::endSendMessage();
  }

  /// Counts a new in-flight incoming message, then forwards to the
  /// RawByteChannel implementation.
  template <typename FunctionIdT, typename SequenceNumberT>
  Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
    ++InFlightIncomingMessages;
    return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo);
  }

  /// Moves one incoming message from "in flight" to "completed", then forwards
  /// to the RawByteChannel implementation.
  Error endReceiveMessage() {
    --InFlightIncomingMessages;
    ++CompletedIncomingMessages;
    return orc::rpc::RawByteChannel::endReceiveMessage();
  }

  /// Copies \p Size bytes from the inbound queue into \p Dst, blocking on the
  /// queue's condition variable while it is empty.
  ///
  /// The read hook is consulted before every byte and after every wake-up, so
  /// a pending read error takes precedence over bytes that are still queued.
  /// \returns the hook's error if it reports one (bytes already copied stay
  /// consumed), otherwise success.
  Error readBytes(char *Dst, unsigned Size) override {
    std::unique_lock<std::mutex> Lock(InQueue->getMutex());
    while (Size) {
      {
        Error Err = InQueue->checkReadError();
        while (!Err && InQueue->empty()) {
          InQueue->getCondVar().wait(Lock);
          Err = InQueue->checkReadError();
        }
        if (Err)
          return Err;
      }
      *Dst++ = InQueue->front();
      --Size;
      ++NumRead;
      InQueue->pop();
    }
    return Error::success();
  }

  /// Pushes \p Size bytes from \p Src onto the outbound queue and signals its
  /// condition variable.
  ///
  /// The write hook is consulted before every byte.
  /// \returns the hook's error if it reports one (bytes pushed before that
  /// point stay queued), otherwise success.
  Error appendBytes(const char *Src, unsigned Size) override {
    std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
    for (unsigned I = 0; I != Size; ++I) {
      if (Error Err = OutQueue->checkWriteError()) {
        // Bytes queued before the failure are still readable; tell a waiting
        // reader about them rather than leaving it blocked.
        if (I != 0)
          OutQueue->getCondVar().notify_one();
        return Err;
      }
      OutQueue->push(Src[I]);
      ++NumWritten;
    }
    OutQueue->getCondVar().notify_one();
    return Error::success();
  }

  /// Records the call; bytes are already visible to the peer once appended, so
  /// there is nothing to flush.
  Error send() override {
    ++SendCalls;
    return Error::success();
  }

  /// Makes every subsequent read and write on both queues fail with
  /// QueueChannelClosedError. Installing the read hooks signals each queue's
  /// condition variable (see Queue::setReadError).
  void close() {
    auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
    InQueue->setReadError(ChannelClosed);
    InQueue->setWriteError(ChannelClosed);
    OutQueue->setReadError(ChannelClosed);
    OutQueue->setWriteError(ChannelClosed);
  }

  // Byte counters. These are plain integers: NumWritten is only updated while
  // the outbound queue's mutex is held, NumRead while the inbound one is.
  uint64_t NumWritten = 0;
  uint64_t NumRead = 0;

  // Message and send() counters, updated atomically.
  std::atomic<size_t> InFlightIncomingMessages{0};
  std::atomic<size_t> CompletedIncomingMessages{0};
  std::atomic<size_t> InFlightOutgoingMessages{0};
  std::atomic<size_t> CompletedOutgoingMessages{0};
  std::atomic<size_t> SendCalls{0};

private:

  std::shared_ptr<Queue> InQueue;  // Source for readBytes().
  std::shared_ptr<Queue> OutQueue; // Destination for appendBytes().
};

inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
createPairedQueueChannels() {
  auto Q1 = std::make_shared<Queue>();
  auto Q2 = std::make_shared<Queue>();
  auto C1 = std::make_unique<QueueChannel>(Q1, Q2);
  auto C2 = std::make_unique<QueueChannel>(Q2, Q1);
  return std::make_pair(std::move(C1), std::move(C2));
}

}

#endif