pw_rpc: Test raw client & bidirectional streaming

- Expand RawTestMethodClient and NanopbTestMethodClient to support
  client and bidirectional streaming methods. Move shared functionality
  to a common InvocationContext base class.
- Add tests for raw client & bidirectional streaming methods.

Change-Id: I26f5d0608f6215bc846e69d89b3c9735595a7930
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/52523
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
This commit is contained in:
Wyatt Hepler 2021-07-02 00:54:13 -07:00 committed by CQ Bot Account
parent d0bda2ad49
commit 07e3ba0098
12 changed files with 600 additions and 228 deletions

View File

@ -119,6 +119,7 @@ pw_cc_library(
srcs = ["fake_channel_output.cc"],
hdrs = [
"public/pw_rpc/internal/test_method.h",
"public/pw_rpc/internal/test_method_context.h",
"pw_rpc_private/fake_channel_output.h",
"pw_rpc_private/fake_server_reader_writer.h",
"pw_rpc_private/internal_test_utils.h",

View File

@ -131,6 +131,7 @@ pw_source_set("synchronized_channel_output") {
pw_source_set("test_utils") {
public = [
"public/pw_rpc/internal/test_method.h",
"public/pw_rpc/internal/test_method_context.h",
"pw_rpc_private/fake_channel_output.h",
"pw_rpc_private/fake_server_reader_writer.h",
"pw_rpc_private/internal_test_utils.h",

View File

@ -51,6 +51,8 @@ Status FakeChannelOutput::SendAndReleaseBuffer(
}
done_ = true;
break;
case PacketType::SERVER_ERROR:
PW_CRASH("Server error: %s", result.value().status().str());
case PacketType::SERVER_STREAM:
ProcessResponse(result.value().payload());
break;

View File

@ -41,6 +41,19 @@ class TestService final : public generated::TestService<TestService> {
writer.Finish(static_cast<Status::Code>(request.status_code));
}
void TestClientStreamRpc(
ServerContext&,
ServerReader<pw_rpc_test_TestRequest, pw_rpc_test_TestStreamResponse>&) {
// TODO(pwbug/428): Test Nanopb client streaming.
}
void TestBidirectionalStreamRpc(
ServerContext&,
ServerReaderWriter<pw_rpc_test_TestRequest,
pw_rpc_test_TestStreamResponse>&) {
// TODO(pwbug/428): Test Nanopb bidirectional streaming.
}
};
} // namespace test

View File

@ -29,10 +29,23 @@ class MixedService1 : public test::generated::TestService<MixedService1> {
void TestStreamRpc(ServerContext&,
const pw_rpc_test_TestRequest&,
ServerWriter<pw_rpc_test_TestStreamResponse>&) {
called_streaming_method = true;
called_server_streaming_method = true;
}
bool called_streaming_method = false;
void TestClientStreamRpc(ServerContext&, RawServerReader&) {
called_client_streaming_method = true;
}
void TestBidirectionalStreamRpc(
ServerContext&,
ServerReaderWriter<pw_rpc_test_TestRequest,
pw_rpc_test_TestStreamResponse>&) {
called_bidirectional_streaming_method = true;
}
bool called_server_streaming_method = false;
bool called_client_streaming_method = false;
bool called_bidirectional_streaming_method = false;
};
class MixedService2 : public test::generated::TestService<MixedService2> {
@ -44,37 +57,71 @@ class MixedService2 : public test::generated::TestService<MixedService2> {
}
void TestStreamRpc(ServerContext&, ConstByteSpan, RawServerWriter&) {
called_streaming_method = true;
called_server_streaming_method = true;
}
bool called_streaming_method = false;
void TestClientStreamRpc(
ServerContext&,
ServerReader<pw_rpc_test_TestRequest, pw_rpc_test_TestStreamResponse>&) {
called_client_streaming_method = true;
}
void TestBidirectionalStreamRpc(ServerContext&, RawServerReaderWriter&) {
called_bidirectional_streaming_method = true;
}
bool called_server_streaming_method = false;
bool called_client_streaming_method = false;
bool called_bidirectional_streaming_method = false;
};
TEST(MixedService1, CallRawMethod) {
TEST(MixedService1, CallRawMethod_Unary) {
PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestRpc) context;
StatusWithSize sws = context.call({});
EXPECT_TRUE(sws.ok());
EXPECT_EQ(123u, sws.size());
}
TEST(MixedService1, CallNanopbMethod) {
TEST(MixedService1, CallNanopbMethod_ServerStreaming) {
PW_NANOPB_TEST_METHOD_CONTEXT(MixedService1, TestStreamRpc) context;
ASSERT_FALSE(context.service().called_streaming_method);
ASSERT_FALSE(context.service().called_server_streaming_method);
context.call({});
EXPECT_TRUE(context.service().called_streaming_method);
EXPECT_TRUE(context.service().called_server_streaming_method);
}
TEST(MixedService2, CallNanopbMethod) {
TEST(MixedService1, CallRawMethod_ClientStreaming) {
PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestClientStreamRpc) context;
ASSERT_FALSE(context.service().called_client_streaming_method);
context.call();
EXPECT_TRUE(context.service().called_client_streaming_method);
}
TEST(MixedService1, CallNanopbMethod_BidirectionalStreaming) {
// TODO(pwbug/428): Test Nanopb bidirectional streaming when supported.
}
TEST(MixedService2, CallNanopbMethod_Unary) {
PW_NANOPB_TEST_METHOD_CONTEXT(MixedService2, TestRpc) context;
Status status = context.call({});
EXPECT_EQ(Status::Unauthenticated(), status);
}
TEST(MixedService2, CallRawMethod) {
TEST(MixedService2, CallRawMethod_ServerStreaming) {
PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestStreamRpc) context;
ASSERT_FALSE(context.service().called_streaming_method);
ASSERT_FALSE(context.service().called_server_streaming_method);
context.call({});
EXPECT_TRUE(context.service().called_streaming_method);
EXPECT_TRUE(context.service().called_server_streaming_method);
}
TEST(MixedService2, CallNanopbMethod_ClientStreaming) {
// TODO(pwbug/428): Test Nanopb client streaming when supported.
}
TEST(MixedService2, CallRawMethod_BidirectionalStreaming) {
PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestBidirectionalStreamRpc) context;
ASSERT_FALSE(context.service().called_bidirectional_streaming_method);
context.call();
EXPECT_TRUE(context.service().called_bidirectional_streaming_method);
}
} // namespace

View File

@ -304,5 +304,9 @@ TEST(NanopbMethod,
EXPECT_EQ(Status::Internal(), last_writer.Write({.value = 1})); // Too big
}
// TODO(pwbug/428): Test NanopbServerReader / NanopbServerReaderWriter. In
// particular, test that the client stream callback correctly decodes
// incoming messages with Nanopb.
} // namespace
} // namespace pw::rpc::internal

View File

@ -20,12 +20,10 @@
#include "pw_bytes/span.h"
#include "pw_containers/vector.h"
#include "pw_preprocessor/arguments.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/hash.h"
#include "pw_rpc/internal/method_lookup.h"
#include "pw_rpc/internal/nanopb_method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/server.h"
#include "pw_rpc/internal/test_method_context.h"
#include "pw_rpc_private/fake_channel_output.h"
namespace pw::rpc {
@ -65,7 +63,7 @@ namespace pw::rpc {
//
// PW_NANOPB_TEST_METHOD_CONTEXT takes two optional arguments:
//
// size_t kMaxResponse: maximum responses to store; ignored unless streaming
// size_t kMaxResponses: maximum responses to store; ignored unless streaming
// size_t kOutputSizeBytes: buffer size; must be large enough for a packet
//
// Example:
@ -81,7 +79,7 @@ namespace pw::rpc {
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse = 4,
size_t kMaxResponses = 4,
size_t kOutputSizeBytes = 128>
class NanopbTestMethodContext;
@ -89,98 +87,101 @@ class NanopbTestMethodContext;
namespace internal::test::nanopb {
// A ChannelOutput implementation that stores the outgoing payloads and status.
template <typename Response>
template <typename Response, size_t kMaxResponses, size_t kOutputSize>
class MessageOutput final : public FakeChannelOutput {
public:
MessageOutput(const internal::NanopbMethod& kMethod,
Vector<Response>& responses,
ByteSpan packet_buffer,
bool server_streaming)
: FakeChannelOutput(packet_buffer, server_streaming),
method_(kMethod),
responses_(responses) {}
MessageOutput(const internal::NanopbMethod& kMethod, bool server_streaming)
: FakeChannelOutput(packet_buffer_, server_streaming), method_(kMethod) {}
private:
void AppendResponse(ConstByteSpan response) override {
const Vector<Response>& responses() const { return responses_; }
Response& AllocateResponse() {
// If we run out of space, the back message is always the most recent.
responses_.emplace_back();
responses_.back() = {};
PW_ASSERT(method_.serde().DecodeResponse(response, &responses_.back()));
return responses_.back();
}
private:
void AppendResponse(ConstByteSpan response) override {
Response& response_struct = AllocateResponse();
PW_ASSERT(method_.serde().DecodeResponse(response, &response_struct));
}
void ClearResponses() override { responses_.clear(); }
const internal::NanopbMethod& method_;
Vector<Response>& responses_;
Vector<Response, kMaxResponses> responses_;
std::array<std::byte, kOutputSize> packet_buffer_;
};
// Collects everything needed to invoke a particular RPC.
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSize>
struct InvocationContext {
struct NanopbInvocationContext : public InvocationContext<Service, kMethodId> {
public:
using Request = internal::Request<kMethod>;
using Response = internal::Response<kMethod>;
// Returns the responses that have been recorded. The maximum number of
// responses is responses().max_size(). responses().back() is always the most
// recent response, even if total_responses() > responses().max_size().
const Vector<Response>& responses() const { return output_.responses(); }
// Gives access to the RPC's response.
const Response& response() const {
PW_ASSERT(!responses().empty());
return responses().back();
}
protected:
template <typename... Args>
InvocationContext(Args&&... args)
: output(MethodLookup::GetNanopbMethod<Service, kMethodId>(),
responses,
buffer,
MethodTraits<decltype(kMethod)>::kServerStreaming),
channel(Channel::Create<123>(&output)),
server(std::span(&channel, 1)),
service(std::forward<Args>(args)...),
call(static_cast<internal::Server&>(server),
static_cast<internal::Channel&>(channel),
service,
MethodLookup::GetNanopbMethod<Service, kMethodId>()) {}
NanopbInvocationContext(Args&&... args)
: InvocationContext<Service, kMethodId>(
MethodLookup::GetNanopbMethod<Service, kMethodId>(),
output_,
std::forward<Args>(args)...),
output_(MethodLookup::GetNanopbMethod<Service, kMethodId>(),
MethodTraits<decltype(kMethod)>::kServerStreaming) {}
MessageOutput<Response> output;
MessageOutput<Response, kMaxResponses, kOutputSize>& output() {
return output_;
}
rpc::Channel channel;
rpc::Server server;
Service service;
Vector<Response, kMaxResponse> responses;
std::array<std::byte, kOutputSize> buffer = {};
internal::ServerCall call;
private:
MessageOutput<Response, kMaxResponses, kOutputSize> output_;
};
// Method invocation context for a unary RPC. Returns the status in call() and
// provides the response through the response() method.
// Method invocation context for a unary RPC. Returns the status in
// server_call() and provides the response through the response() method.
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kOutputSize>
class UnaryContext {
private:
InvocationContext<Service, kMethod, kMethodId, 1, kOutputSize> ctx_;
class UnaryContext : public NanopbInvocationContext<Service,
kMethod,
kMethodId,
1,
kOutputSize> {
using Base =
NanopbInvocationContext<Service, kMethod, kMethodId, 1, kOutputSize>;
public:
using Request = typename decltype(ctx_)::Request;
using Response = typename decltype(ctx_)::Response;
using Request = typename Base::Request;
using Response = typename Base::Response;
template <typename... Args>
UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
Service& service() { return ctx_.service; }
UnaryContext(Args&&... args) : Base(std::forward<Args>(args)...) {}
// Invokes the RPC with the provided request. Returns the status.
Status call(const Request& request) {
ctx_.output.clear();
ctx_.responses.emplace_back();
ctx_.responses.back() = {};
Base::output().clear();
Response& response = Base::output().AllocateResponse();
return CallMethodImplFunction<kMethod>(
ctx_.call, request, ctx_.responses.back());
}
// Gives access to the RPC's response.
const Response& response() const {
PW_ASSERT(ctx_.responses.size() > 0u);
return ctx_.responses.back();
Base::server_call(), request, response);
}
};
@ -188,52 +189,40 @@ class UnaryContext {
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSize>
class ServerStreamingContext {
class ServerStreamingContext : public NanopbInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize> {
private:
InvocationContext<Service, kMethod, kMethodId, kMaxResponse, kOutputSize>
ctx_;
using Base = NanopbInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize>;
public:
using Request = typename decltype(ctx_)::Request;
using Response = typename decltype(ctx_)::Response;
using Request = typename Base::Request;
using Response = typename Base::Response;
template <typename... Args>
ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
Service& service() { return ctx_.service; }
ServerStreamingContext(Args&&... args) : Base(std::forward<Args>(args)...) {}
// Invokes the RPC with the provided request.
void call(const Request& request) {
ctx_.output.clear();
NanopbServerWriter<Response> writer(ctx_.call);
return CallMethodImplFunction<kMethod>(ctx_.call, request, writer);
Base::output().clear();
NanopbServerWriter<Response> writer(Base::server_call());
return CallMethodImplFunction<kMethod>(
Base::server_call(), request, writer);
}
// Returns a server writer which writes responses into the context's buffer.
// This should not be called alongside call(); use one or the other.
NanopbServerWriter<Response> writer() {
ctx_.output.clear();
return NanopbServerWriter<Response>(ctx_.call);
}
// Returns the responses that have been recorded. The maximum number of
// responses is responses().max_size(). responses().back() is always the most
// recent response, even if total_responses() > responses().max_size().
const Vector<Response>& responses() const { return ctx_.responses; }
// The total number of responses sent, which may be larger than
// responses.max_size().
size_t total_responses() const { return ctx_.output.total_responses(); }
// True if the stream has terminated.
bool done() const { return ctx_.output.done(); }
// The status of the stream. Only valid if done() is true.
Status status() const {
PW_ASSERT(done());
return ctx_.output.last_status();
Base::output().clear();
return NanopbServerWriter<Response>(Base::server_call());
}
};
@ -242,7 +231,7 @@ class ServerStreamingContext {
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSize>
using Context = std::tuple_element_t<
static_cast<size_t>(internal::MethodTraits<decltype(kMethod)>::kType),
@ -250,7 +239,7 @@ using Context = std::tuple_element_t<
ServerStreamingContext<Service,
kMethod,
kMethodId,
kMaxResponse,
kMaxResponses,
kOutputSize>
// TODO(hepler): Support client and bidi streaming
>>;
@ -260,11 +249,14 @@ using Context = std::tuple_element_t<
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSizeBytes>
class NanopbTestMethodContext
: public internal::test::nanopb::
Context<Service, kMethod, kMethodId, kMaxResponse, kOutputSizeBytes> {
: public internal::test::nanopb::Context<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSizeBytes> {
public:
// Forwards constructor arguments to the service class.
template <typename... ServiceArgs>
@ -272,7 +264,7 @@ class NanopbTestMethodContext
: internal::test::nanopb::Context<Service,
kMethod,
kMethodId,
kMaxResponse,
kMaxResponses,
kOutputSizeBytes>(
std::forward<ServiceArgs>(service_args)...) {}
};

View File

@ -0,0 +1,122 @@
// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#pragma once
#include <array>
#include <cstddef>
#include "pw_assert/assert.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/server.h"
#include "pw_rpc_private/fake_channel_output.h"
namespace pw::rpc::internal::test {
// Collects everything needed to invoke a particular RPC.
template <typename Service, uint32_t kMethodId>
class InvocationContext {
public:
Service& service() { return service_; }
// The total number of responses sent, which may be larger than
// responses.max_size().
size_t total_responses() const { return output_.total_responses(); }
// True if the RPC has completed.
bool done() const { return output_.done(); }
// The status of the stream. Only valid if done() is true.
Status status() const {
PW_ASSERT(done());
return output_.last_status();
}
void SendClientError(Status error) {
std::byte packet[kNoPayloadPacketSizeBytes];
server_.ProcessPacket(Packet(PacketType::CLIENT_ERROR,
channel_.id(),
service_.id(),
kMethodId,
{},
error)
.Encode(packet)
.value(),
output_);
}
void SendCancel() {
std::byte packet[kNoPayloadPacketSizeBytes];
server_.ProcessPacket(
Packet(PacketType::CANCEL, channel_.id(), service_.id(), kMethodId)
.Encode(packet)
.value(),
output_);
}
protected:
template <size_t kMaxPayloadSize = 32>
void SendClientStream(ConstByteSpan payload) {
std::byte packet[kNoPayloadPacketSizeBytes + 3 + kMaxPayloadSize];
server_.ProcessPacket(Packet(PacketType::CLIENT_STREAM,
channel_.id(),
service_.id(),
kMethodId,
payload)
.Encode(packet)
.value(),
output_);
}
void SendClientStreamEnd() {
std::byte packet[kNoPayloadPacketSizeBytes];
server_.ProcessPacket(Packet(PacketType::CLIENT_STREAM_END,
channel_.id(),
service_.id(),
kMethodId)
.Encode(packet)
.value(),
output_);
}
template <typename... Args>
InvocationContext(const Method& method,
FakeChannelOutput& output,
Args&&... service_args)
: output_(output),
channel_(Channel::Create<123>(&output_)),
server_(std::span(&channel_, 1)),
service_(std::forward<Args>(service_args)...),
server_call_(static_cast<internal::Server&>(server_),
static_cast<internal::Channel&>(channel_),
service_,
method) {
server_.RegisterService(service_);
}
internal::ServerCall& server_call() { return server_call_; }
private:
static constexpr size_t kNoPayloadPacketSizeBytes =
2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ +
2 /* status */;
FakeChannelOutput& output_;
rpc::Channel channel_;
rpc::Server server_;
Service service_;
internal::ServerCall server_call_;
};
} // namespace pw::rpc::internal::test

View File

@ -32,6 +32,9 @@ message TestStreamResponse {
message Empty {}
service TestService {
rpc TestRpc(TestRequest) returns (TestResponse) {}
rpc TestStreamRpc(TestRequest) returns (stream TestStreamResponse) {}
rpc TestRpc(TestRequest) returns (TestResponse);
rpc TestStreamRpc(TestRequest) returns (stream TestStreamResponse);
rpc TestClientStreamRpc(stream TestRequest) returns (TestStreamResponse);
rpc TestBidirectionalStreamRpc(stream TestRequest)
returns (stream TestStreamResponse);
}

View File

@ -20,6 +20,33 @@
#include "pw_rpc_test_protos/test.raw_rpc.pb.h"
namespace pw::rpc {
namespace {
Vector<std::byte, 64> EncodeRequest(int integer, Status status) {
Vector<std::byte, 64> buffer(64);
test::TestRequest::RamEncoder test_request(buffer);
test_request.WriteInteger(integer);
test_request.WriteStatusCode(status.code());
EXPECT_EQ(OkStatus(), test_request.status());
buffer.resize(test_request.size());
return buffer;
}
Vector<std::byte, 64> EncodeResponse(int number) {
Vector<std::byte, 64> buffer(64);
test::TestStreamResponse::RamEncoder test_response(buffer);
test_response.WriteNumber(number);
EXPECT_EQ(OkStatus(), test_response.status());
buffer.resize(test_response.size());
return buffer;
}
} // namespace
namespace test {
class TestService final : public generated::TestService<TestService> {
@ -59,7 +86,49 @@ class TestService final : public generated::TestService<TestService> {
writer.Finish(status);
}
void TestClientStreamRpc(ServerContext&, RawServerReader& reader) {
last_reader_ = std::move(reader);
last_reader_.set_on_next([this](ConstByteSpan payload) {
last_reader_.Finish(EncodeResponse(ReadInteger(payload)),
Status::Unauthenticated());
});
}
void TestBidirectionalStreamRpc(ServerContext&,
RawServerReaderWriter& reader_writer) {
last_reader_writer_ = std::move(reader_writer);
last_reader_writer_.set_on_next([this](ConstByteSpan payload) {
last_reader_writer_.Write(EncodeResponse(ReadInteger(payload)));
last_reader_writer_.Finish(Status::NotFound());
});
}
protected:
RawServerReader last_reader_;
RawServerReaderWriter last_reader_writer_;
private:
static uint32_t ReadInteger(ConstByteSpan request) {
uint32_t integer = 0;
protobuf::Decoder decoder(request);
while (decoder.Next().ok()) {
switch (static_cast<TestRequest::Fields>(decoder.FieldNumber())) {
case TestRequest::Fields::INTEGER:
EXPECT_EQ(OkStatus(), decoder.ReadUint32(&integer));
break;
case TestRequest::Fields::STATUS_CODE:
break;
default:
ADD_FAILURE();
}
}
return integer;
}
static bool DecodeRequest(ConstByteSpan request,
int64_t& integer,
Status& status) {
@ -104,13 +173,7 @@ TEST(RawCodegen, CompilesProperly) {
TEST(RawCodegen, Server_InvokeUnaryRpc) {
PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestRpc) context;
std::byte buffer[64];
protobuf::NestedEncoder encoder(buffer);
test::TestRequest::Encoder test_request(&encoder);
test_request.WriteInteger(123);
test_request.WriteStatusCode(OkStatus().code());
auto sws = context.call(encoder.Encode().value());
auto sws = context.call(EncodeRequest(123, OkStatus()));
EXPECT_EQ(OkStatus(), sws.status());
protobuf::Decoder decoder(context.response());
@ -130,13 +193,7 @@ TEST(RawCodegen, Server_InvokeUnaryRpc) {
TEST(RawCodegen, Server_InvokeServerStreamingRpc) {
PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc) context;
std::byte buffer[64];
protobuf::NestedEncoder encoder(buffer);
test::TestRequest::Encoder test_request(&encoder);
test_request.WriteInteger(5);
test_request.WriteStatusCode(Status::Unauthenticated().code());
context.call(encoder.Encode().value());
context.call(EncodeRequest(5, Status::Unauthenticated()));
EXPECT_TRUE(context.done());
EXPECT_EQ(Status::Unauthenticated(), context.status());
EXPECT_EQ(context.total_responses(), 5u);
@ -158,5 +215,49 @@ TEST(RawCodegen, Server_InvokeServerStreamingRpc) {
}
}
int32_t ReadResponseNumber(ConstByteSpan data) {
int32_t value = -1;
protobuf::Decoder decoder(data);
while (decoder.Next().ok()) {
switch (
static_cast<test::TestStreamResponse::Fields>(decoder.FieldNumber())) {
case test::TestStreamResponse::Fields::NUMBER: {
decoder.ReadInt32(&value);
break;
}
default:
ADD_FAILURE();
break;
}
}
return value;
}
TEST(RawCodegen, Server_InvokeClientStreamingRpc) {
PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestClientStreamRpc) ctx;
ctx.call();
ctx.SendClientStream(EncodeRequest(123, OkStatus()));
ASSERT_TRUE(ctx.done());
EXPECT_EQ(Status::Unauthenticated(), ctx.status());
EXPECT_EQ(ctx.total_responses(), 1u);
EXPECT_EQ(ReadResponseNumber(ctx.responses().back()), 123);
}
TEST(RawCodegen, Server_InvokeBidirectionalStreamingRpc) {
PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestBidirectionalStreamRpc)
ctx;
ctx.call();
ctx.SendClientStream(EncodeRequest(456, OkStatus()));
ASSERT_TRUE(ctx.done());
EXPECT_EQ(Status::NotFound(), ctx.status());
ASSERT_EQ(ctx.total_responses(), 1u);
EXPECT_EQ(ReadResponseNumber(ctx.responses().back()), 456);
}
} // namespace
} // namespace pw::rpc

View File

@ -30,6 +30,10 @@ namespace test::raw {
template <typename, auto, uint32_t, size_t, size_t>
class ServerStreamingContext;
template <typename, auto, uint32_t, size_t, size_t>
class ClientStreamingContext;
template <typename, auto, uint32_t, size_t, size_t>
class BidirectionalStreamingContext;
} // namespace test::raw
} // namespace internal
@ -81,6 +85,9 @@ class RawServerReaderWriter : private internal::Responder {
private:
friend class internal::RawMethod;
template <typename, auto, uint32_t, size_t, size_t>
friend class internal::test::raw::BidirectionalStreamingContext;
};
// The RawServerReader is used to receive messages and send a response in a
@ -107,6 +114,9 @@ class RawServerReader : private RawServerReaderWriter {
private:
friend class internal::RawMethod; // Needed for conversions from ReaderWriter
template <typename, auto, uint32_t, size_t, size_t>
friend class internal::test::raw::ClientStreamingContext;
using RawServerReaderWriter::HasClientStream;
RawServerReader(internal::ServerCall& call)

View File

@ -23,6 +23,7 @@
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/raw_method.h"
#include "pw_rpc/internal/server.h"
#include "pw_rpc/internal/test_method_context.h"
#include "pw_rpc_private/fake_channel_output.h"
namespace pw::rpc {
@ -65,7 +66,7 @@ namespace pw::rpc {
//
// PW_RAW_TEST_METHOD_CONTEXT takes two optional arguments:
//
// size_t kMaxResponse: maximum responses to store; ignored unless streaming
// size_t kMaxResponses: maximum responses to store; ignored unless streaming
// size_t kOutputSizeBytes: buffer size; must be large enough for a packet
//
// Example:
@ -81,7 +82,7 @@ namespace pw::rpc {
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse = 4,
size_t kMaxResponses = 4,
size_t kOutputSizeBytes = 128>
class RawTestMethodContext;
@ -89,69 +90,78 @@ class RawTestMethodContext;
namespace internal::test::raw {
// A ChannelOutput implementation that stores the outgoing payloads and status.
template <size_t kOutputSize>
template <size_t kOutputSize, size_t kMaxResponses>
class MessageOutput final : public FakeChannelOutput {
public:
using ResponseBuffer = std::array<std::byte, kOutputSize>;
MessageOutput(bool server_streaming)
: FakeChannelOutput(packet_buffer_, server_streaming) {}
MessageOutput(Vector<ByteSpan>& responses,
Vector<ResponseBuffer>& buffers,
ByteSpan packet_buffer,
bool server_streaming)
: FakeChannelOutput(packet_buffer, server_streaming),
responses_(responses),
buffers_(buffers) {}
const Vector<ByteSpan>& responses() const { return responses_; }
// Allocates a response buffer and returns a reference to the response span
// for it.
ByteSpan& AllocateResponse() {
// If we run out of space, the back message is always the most recent.
response_buffers_.emplace_back();
response_buffers_.back() = {};
responses_.emplace_back();
responses_.back() = {response_buffers_.back().data(),
response_buffers_.back().size()};
return responses_.back();
}
private:
void AppendResponse(ConstByteSpan response) override {
// If we run out of space, the back message is always the most recent.
buffers_.emplace_back();
buffers_.back() = {};
std::memcpy(&buffers_.back(), response.data(), response.size());
responses_.emplace_back();
responses_.back() = {buffers_.back().data(), response.size()};
ByteSpan& response_span = AllocateResponse();
PW_ASSERT(response.size() <= response_span.size());
std::memcpy(response_span.data(), response.data(), response.size());
response_span = response_span.first(response.size());
}
void ClearResponses() override {
responses_.clear();
buffers_.clear();
response_buffers_.clear();
}
Vector<ByteSpan>& responses_;
Vector<ResponseBuffer>& buffers_;
std::array<std::byte, kOutputSize> packet_buffer_;
Vector<ByteSpan, kMaxResponses> responses_;
Vector<std::array<std::byte, kOutputSize>, kMaxResponses> response_buffers_;
};
// Collects everything needed to invoke a particular RPC.
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSize>
struct InvocationContext {
class RawInvocationContext : public InvocationContext<Service, kMethodId> {
public:
// Returns the responses that have been recorded. The maximum number of
// responses is responses().max_size(). responses().back() is always the most
// recent response, even if total_responses() > responses().max_size().
const Vector<ByteSpan>& responses() const { return output_.responses(); }
// Gives access to the RPC's most recent response.
ConstByteSpan response() const {
PW_ASSERT(!responses().empty());
return responses().back();
}
protected:
template <typename... Args>
InvocationContext(Args&&... args)
: output(responses,
buffers,
packet_buffer,
MethodTraits<decltype(kMethod)>::kServerStreaming),
channel(Channel::Create<123>(&output)),
server(std::span(&channel, 1)),
service(std::forward<Args>(args)...),
call(static_cast<internal::Server&>(server),
static_cast<internal::Channel&>(channel),
service,
MethodLookup::GetRawMethod<Service, kMethodId>()) {}
RawInvocationContext(Args&&... args)
: InvocationContext<Service, kMethodId>(
MethodLookup::GetRawMethod<Service, kMethodId>(),
output_,
std::forward<Args>(args)...),
output_(MethodTraits<decltype(kMethod)>::kServerStreaming) {}
using ResponseBuffer = std::array<std::byte, kOutputSize>;
MessageOutput<kOutputSize, kMaxResponses>& output() { return output_; }
MessageOutput<kOutputSize> output;
rpc::Channel channel;
rpc::Server server;
Service service;
Vector<ByteSpan, kMaxResponse> responses;
Vector<ResponseBuffer, kMaxResponse> buffers;
std::array<std::byte, kOutputSize> packet_buffer = {};
internal::ServerCall call;
private:
MessageOutput<kOutputSize, kMaxResponses> output_;
};
// Method invocation context for a unary RPC. Returns the status in call() and
@ -160,87 +170,142 @@ template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kOutputSize>
class UnaryContext {
private:
using Context =
InvocationContext<Service, kMethod, kMethodId, 1, kOutputSize>;
Context ctx_;
class UnaryContext
: public RawInvocationContext<Service, kMethod, kMethodId, 1, kOutputSize> {
using Base =
RawInvocationContext<Service, kMethod, kMethodId, 1, kOutputSize>;
public:
template <typename... Args>
UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
Service& service() { return ctx_.service; }
UnaryContext(Args&&... args) : Base(std::forward<Args>(args)...) {}
// Invokes the RPC with the provided request. Returns RPC's StatusWithSize.
StatusWithSize call(ConstByteSpan request) {
ctx_.output.clear();
ctx_.buffers.emplace_back();
ctx_.buffers.back() = {};
ctx_.responses.emplace_back();
auto& response = ctx_.responses.back();
response = {ctx_.buffers.back().data(), ctx_.buffers.back().size()};
auto sws = CallMethodImplFunction<kMethod>(ctx_.call, request, response);
Base::output().clear();
ByteSpan& response = Base::output().AllocateResponse();
auto sws =
CallMethodImplFunction<kMethod>(Base::server_call(), request, response);
response = response.first(sws.size());
return sws;
}
// Gives access to the RPC's response.
ConstByteSpan response() const {
PW_ASSERT(ctx_.responses.size() > 0u);
return ctx_.responses.back();
}
};
// Method invocation context for a server streaming RPC.
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSize>
class ServerStreamingContext {
private:
using Context =
InvocationContext<Service, kMethod, kMethodId, kMaxResponse, kOutputSize>;
Context ctx_;
class ServerStreamingContext : public RawInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize> {
using Base = RawInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize>;
public:
template <typename... Args>
ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
Service& service() { return ctx_.service; }
ServerStreamingContext(Args&&... args) : Base(std::forward<Args>(args)...) {}
// Invokes the RPC with the provided request.
void call(ConstByteSpan request) {
ctx_.output.clear();
RawServerWriter server_writer(ctx_.call);
return CallMethodImplFunction<kMethod>(ctx_.call, request, server_writer);
Base::output().clear();
RawServerWriter writer(Base::server_call());
return CallMethodImplFunction<kMethod>(
Base::server_call(), request, writer);
}
// Returns a server writer which writes responses into the context's buffer.
// This should not be called alongside call(); use one or the other.
RawServerWriter writer() {
ctx_.output.clear();
return RawServerWriter(ctx_.call);
Base::output().clear();
return RawServerWriter(Base::server_call());
}
};
// Method invocation context for a client streaming RPC.
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponses,
size_t kOutputSize>
class ClientStreamingContext : public RawInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize> {
using Base = RawInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize>;
public:
template <typename... Args>
ClientStreamingContext(Args&&... args) : Base(std::forward<Args>(args)...) {}
// Invokes the RPC.
void call() {
Base::output().clear();
RawServerReader reader_writer(Base::server_call());
return CallMethodImplFunction<kMethod>(Base::server_call(), reader_writer);
}
// Returns the responses that have been recorded. The maximum number of
// responses is responses().max_size(). responses().back() is always the most
// recent response, even if total_responses() > responses().max_size().
const Vector<ByteSpan>& responses() const { return ctx_.responses; }
// The total number of responses sent, which may be larger than
// responses.max_size().
size_t total_responses() const { return ctx_.output.total_responses(); }
// True if the stream has terminated.
bool done() const { return ctx_.output.done(); }
// The status of the stream. Only valid if done() is true.
Status status() const {
PW_ASSERT(done());
return ctx_.output.last_status();
// Returns a reader/writer which writes responses into the context's buffer.
// This should not be called alongside call(); use one or the other.
RawServerReader reader() {
Base::output().clear();
return RawServerReader(Base::server_call());
}
// Allow sending client streaming packets.
using Base::SendClientStream;
using Base::SendClientStreamEnd;
};
// Method invocation context for a bidirectional streaming RPC.
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponses,
size_t kOutputSize>
class BidirectionalStreamingContext : public RawInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize> {
using Base = RawInvocationContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize>;
public:
template <typename... Args>
BidirectionalStreamingContext(Args&&... args)
: Base(std::forward<Args>(args)...) {}
// Invokes the RPC.
void call() {
Base::output().clear();
RawServerReaderWriter reader_writer(Base::server_call());
return CallMethodImplFunction<kMethod>(Base::server_call(), reader_writer);
}
// Returns a reader/writer which writes responses into the context's buffer.
// This should not be called alongside call(); use one or the other.
RawServerReaderWriter reader_writer() {
Base::output().clear();
return RawServerReaderWriter(Base::server_call());
}
// Allow sending client streaming packets.
using Base::SendClientStream;
using Base::SendClientStreamEnd;
};
// Alias to select the type of the context object to use based on which type of
@ -248,7 +313,7 @@ class ServerStreamingContext {
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSize>
using Context = std::tuple_element_t<
static_cast<size_t>(MethodTraits<decltype(kMethod)>::kType),
@ -256,21 +321,32 @@ using Context = std::tuple_element_t<
ServerStreamingContext<Service,
kMethod,
kMethodId,
kMaxResponse,
kOutputSize>
// TODO(hepler): Support client and bidi streaming
>>;
kMaxResponses,
kOutputSize>,
ClientStreamingContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize>,
BidirectionalStreamingContext<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSize>>>;
} // namespace internal::test::raw
template <typename Service,
auto kMethod,
uint32_t kMethodId,
size_t kMaxResponse,
size_t kMaxResponses,
size_t kOutputSizeBytes>
class RawTestMethodContext
: public internal::test::raw::
Context<Service, kMethod, kMethodId, kMaxResponse, kOutputSizeBytes> {
: public internal::test::raw::Context<Service,
kMethod,
kMethodId,
kMaxResponses,
kOutputSizeBytes> {
public:
// Forwards constructor arguments to the service class.
template <typename... ServiceArgs>
@ -278,7 +354,7 @@ class RawTestMethodContext
: internal::test::raw::Context<Service,
kMethod,
kMethodId,
kMaxResponse,
kMaxResponses,
kOutputSizeBytes>(
std::forward<ServiceArgs>(service_args)...) {}
};