From 07e3ba0098e969ee0c109972aafe21a7b7926302 Mon Sep 17 00:00:00 2001 From: Wyatt Hepler Date: Fri, 2 Jul 2021 00:54:13 -0700 Subject: [PATCH] pw_rpc: Test raw client & bidirectional streaming - Expand RawTestMethodClient and NanopbTestMethodClient to support client and bidirectional streaming methods. Move shared functionality to a common InvocationContext base class. - Add tests for raw client & bidirectional streaming methods. Change-Id: I26f5d0608f6215bc846e69d89b3c9735595a7930 Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/52523 Pigweed-Auto-Submit: Wyatt Hepler Commit-Queue: Auto-Submit Reviewed-by: Alexei Frolov --- pw_rpc/BUILD.bazel | 1 + pw_rpc/BUILD.gn | 1 + pw_rpc/fake_channel_output.cc | 2 + pw_rpc/nanopb/codegen_test.cc | 13 + pw_rpc/nanopb/method_lookup_test.cc | 71 ++++- pw_rpc/nanopb/nanopb_method_test.cc | 4 + .../pw_rpc/nanopb_test_method_context.h | 186 ++++++------ .../pw_rpc/internal/test_method_context.h | 122 ++++++++ pw_rpc/pw_rpc_test_protos/test.proto | 7 +- pw_rpc/raw/codegen_test.cc | 129 +++++++- .../public/pw_rpc/raw/server_reader_writer.h | 10 + .../public/pw_rpc/raw_test_method_context.h | 282 +++++++++++------- 12 files changed, 600 insertions(+), 228 deletions(-) create mode 100644 pw_rpc/public/pw_rpc/internal/test_method_context.h diff --git a/pw_rpc/BUILD.bazel b/pw_rpc/BUILD.bazel index f0b596b49..14677b70b 100644 --- a/pw_rpc/BUILD.bazel +++ b/pw_rpc/BUILD.bazel @@ -119,6 +119,7 @@ pw_cc_library( srcs = ["fake_channel_output.cc"], hdrs = [ "public/pw_rpc/internal/test_method.h", + "public/pw_rpc/internal/test_method_context.h", "pw_rpc_private/fake_channel_output.h", "pw_rpc_private/fake_server_reader_writer.h", "pw_rpc_private/internal_test_utils.h", diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn index 42f04c4a4..367a83d24 100644 --- a/pw_rpc/BUILD.gn +++ b/pw_rpc/BUILD.gn @@ -131,6 +131,7 @@ pw_source_set("synchronized_channel_output") { pw_source_set("test_utils") { public = [ "public/pw_rpc/internal/test_method.h", + "public/pw_rpc/internal/test_method_context.h", "pw_rpc_private/fake_channel_output.h", "pw_rpc_private/fake_server_reader_writer.h", "pw_rpc_private/internal_test_utils.h", diff --git a/pw_rpc/fake_channel_output.cc b/pw_rpc/fake_channel_output.cc index c3ee687ac..1ed1f24df 100644 --- a/pw_rpc/fake_channel_output.cc +++ b/pw_rpc/fake_channel_output.cc @@ -51,6 +51,8 @@ Status FakeChannelOutput::SendAndReleaseBuffer( } done_ = true; break; + case PacketType::SERVER_ERROR: + PW_CRASH("Server error: %s", result.value().status().str()); case PacketType::SERVER_STREAM: ProcessResponse(result.value().payload()); break; diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc index 36f0a4c16..cbc2619f3 100644 --- a/pw_rpc/nanopb/codegen_test.cc +++ b/pw_rpc/nanopb/codegen_test.cc @@ -41,6 +41,19 @@ class TestService final : public generated::TestService { writer.Finish(static_cast(request.status_code)); } + + void TestClientStreamRpc( + ServerContext&, + ServerReader&) { + // TODO(pwbug/428): Test Nanopb client streaming. + } + + void TestBidirectionalStreamRpc( + ServerContext&, + ServerReaderWriter&) { + // TODO(pwbug/428): Test Nanopb bidirectional streaming. + } }; } // namespace test diff --git a/pw_rpc/nanopb/method_lookup_test.cc b/pw_rpc/nanopb/method_lookup_test.cc index 02137c8ff..58a81090d 100644 --- a/pw_rpc/nanopb/method_lookup_test.cc +++ b/pw_rpc/nanopb/method_lookup_test.cc @@ -29,10 +29,23 @@ class MixedService1 : public test::generated::TestService { void TestStreamRpc(ServerContext&, const pw_rpc_test_TestRequest&, ServerWriter&) { - called_streaming_method = true; + called_server_streaming_method = true; } - bool called_streaming_method = false; + void TestClientStreamRpc(ServerContext&, RawServerReader&) { + called_client_streaming_method = true; + } + + void TestBidirectionalStreamRpc( + ServerContext&, + ServerReaderWriter&) { + called_bidirectional_streaming_method = true; + } + + bool called_server_streaming_method = false; + bool called_client_streaming_method = false; + bool called_bidirectional_streaming_method = false; }; class MixedService2 : public test::generated::TestService { @@ -44,37 +57,71 @@ class MixedService2 : public test::generated::TestService { } void TestStreamRpc(ServerContext&, ConstByteSpan, RawServerWriter&) { - called_streaming_method = true; + called_server_streaming_method = true; } - bool called_streaming_method = false; + void TestClientStreamRpc( + ServerContext&, + ServerReader&) { + called_client_streaming_method = true; + } + + void TestBidirectionalStreamRpc(ServerContext&, RawServerReaderWriter&) { + called_bidirectional_streaming_method = true; + } + + bool called_server_streaming_method = false; + bool called_client_streaming_method = false; + bool called_bidirectional_streaming_method = false; }; -TEST(MixedService1, CallRawMethod) { +TEST(MixedService1, CallRawMethod_Unary) { PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestRpc) context; StatusWithSize sws = context.call({}); EXPECT_TRUE(sws.ok()); EXPECT_EQ(123u, sws.size()); } -TEST(MixedService1, CallNanopbMethod) { +TEST(MixedService1, CallNanopbMethod_ServerStreaming) { PW_NANOPB_TEST_METHOD_CONTEXT(MixedService1, TestStreamRpc) context; - ASSERT_FALSE(context.service().called_streaming_method); + ASSERT_FALSE(context.service().called_server_streaming_method); context.call({}); - EXPECT_TRUE(context.service().called_streaming_method); + EXPECT_TRUE(context.service().called_server_streaming_method); } -TEST(MixedService2, CallNanopbMethod) { +TEST(MixedService1, CallRawMethod_ClientStreaming) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestClientStreamRpc) context; + ASSERT_FALSE(context.service().called_client_streaming_method); + context.call(); + EXPECT_TRUE(context.service().called_client_streaming_method); +} + +TEST(MixedService1, CallNanopbMethod_BidirectionalStreaming) { + // TODO(pwbug/428): Test Nanopb bidirectional streaming when supported. +} + +TEST(MixedService2, CallNanopbMethod_Unary) { PW_NANOPB_TEST_METHOD_CONTEXT(MixedService2, TestRpc) context; Status status = context.call({}); EXPECT_EQ(Status::Unauthenticated(), status); } -TEST(MixedService2, CallRawMethod) { +TEST(MixedService2, CallRawMethod_ServerStreaming) { PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestStreamRpc) context; - ASSERT_FALSE(context.service().called_streaming_method); + ASSERT_FALSE(context.service().called_server_streaming_method); context.call({}); - EXPECT_TRUE(context.service().called_streaming_method); + EXPECT_TRUE(context.service().called_server_streaming_method); +} + +TEST(MixedService2, CallNanopbMethod_ClientStreaming) { + // TODO(pwbug/428): Test Nanopb client streaming when supported. +} + +TEST(MixedService2, CallRawMethod_BidirectionalStreaming) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestBidirectionalStreamRpc) context; + ASSERT_FALSE(context.service().called_bidirectional_streaming_method); + context.call(); + EXPECT_TRUE(context.service().called_bidirectional_streaming_method); } } // namespace diff --git a/pw_rpc/nanopb/nanopb_method_test.cc b/pw_rpc/nanopb/nanopb_method_test.cc index a2c22f2bb..8964ac417 100644 --- a/pw_rpc/nanopb/nanopb_method_test.cc +++ b/pw_rpc/nanopb/nanopb_method_test.cc @@ -304,5 +304,9 @@ TEST(NanopbMethod, EXPECT_EQ(Status::Internal(), last_writer.Write({.value = 1})); // Too big } +// TODO(pwbug/428): Test NanopbServerReader / NanopbServerReaderWriter. In +// particular, test that the client stream callback correctly decodes +// incoming messages with Nanopb. + } // namespace } // namespace pw::rpc::internal diff --git a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h index 721689a6f..0214b6a27 100644 --- a/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h +++ b/pw_rpc/nanopb/public/pw_rpc/nanopb_test_method_context.h @@ -20,12 +20,10 @@ #include "pw_bytes/span.h" #include "pw_containers/vector.h" #include "pw_preprocessor/arguments.h" -#include "pw_rpc/channel.h" #include "pw_rpc/internal/hash.h" #include "pw_rpc/internal/method_lookup.h" #include "pw_rpc/internal/nanopb_method.h" -#include "pw_rpc/internal/packet.h" -#include "pw_rpc/internal/server.h" +#include "pw_rpc/internal/test_method_context.h" #include "pw_rpc_private/fake_channel_output.h" namespace pw::rpc { @@ -65,7 +63,7 @@ namespace pw::rpc { // // PW_NANOPB_TEST_METHOD_CONTEXT takes two optional arguments: // -// size_t kMaxResponse: maximum responses to store; ignored unless streaming +// size_t kMaxResponses: maximum responses to store; ignored unless streaming // size_t kOutputSizeBytes: buffer size; must be large enough for a packet // // Example: @@ -81,7 +79,7 @@ namespace pw::rpc { template class NanopbTestMethodContext; @@ -89,98 +87,101 @@ class NanopbTestMethodContext; namespace internal::test::nanopb { // A ChannelOutput implementation that stores the outgoing payloads and status. -template +template class MessageOutput final : public FakeChannelOutput { public: - MessageOutput(const internal::NanopbMethod& kMethod, - Vector& responses, - ByteSpan packet_buffer, - bool server_streaming) - : FakeChannelOutput(packet_buffer, server_streaming), - method_(kMethod), - responses_(responses) {} + MessageOutput(const internal::NanopbMethod& kMethod, bool server_streaming) + : FakeChannelOutput(packet_buffer_, server_streaming), method_(kMethod) {} - private: - void AppendResponse(ConstByteSpan response) override { + const Vector& responses() const { return responses_; } + + Response& AllocateResponse() { // If we run out of space, the back message is always the most recent. responses_.emplace_back(); responses_.back() = {}; - PW_ASSERT(method_.serde().DecodeResponse(response, &responses_.back())); + return responses_.back(); + } + + private: + void AppendResponse(ConstByteSpan response) override { + Response& response_struct = AllocateResponse(); + PW_ASSERT(method_.serde().DecodeResponse(response, &response_struct)); } void ClearResponses() override { responses_.clear(); } const internal::NanopbMethod& method_; - Vector& responses_; + Vector responses_; + std::array packet_buffer_; }; // Collects everything needed to invoke a particular RPC. template -struct InvocationContext { +struct NanopbInvocationContext : public InvocationContext { + public: using Request = internal::Request; using Response = internal::Response; + // Returns the responses that have been recorded. The maximum number of + // responses is responses().max_size(). responses().back() is always the most + // recent response, even if total_responses() > responses().max_size(). + const Vector& responses() const { return output_.responses(); } + + // Gives access to the RPC's response. + const Response& response() const { + PW_ASSERT(!responses().empty()); + return responses().back(); + } + + protected: template - InvocationContext(Args&&... args) - : output(MethodLookup::GetNanopbMethod(), - responses, - buffer, - MethodTraits::kServerStreaming), - channel(Channel::Create<123>(&output)), - server(std::span(&channel, 1)), - service(std::forward(args)...), - call(static_cast(server), - static_cast(channel), - service, - MethodLookup::GetNanopbMethod()) {} + NanopbInvocationContext(Args&&... args) + : InvocationContext( + MethodLookup::GetNanopbMethod(), + output_, + std::forward(args)...), + output_(MethodLookup::GetNanopbMethod(), + MethodTraits::kServerStreaming) {} - MessageOutput output; + MessageOutput& output() { + return output_; + } - rpc::Channel channel; - rpc::Server server; - Service service; - Vector responses; - std::array buffer = {}; - - internal::ServerCall call; + private: + MessageOutput output_; }; -// Method invocation context for a unary RPC. Returns the status in call() and -// provides the response through the response() method. +// Method invocation context for a unary RPC. Returns the status in +// server_call() and provides the response through the response() method. template -class UnaryContext { - private: - InvocationContext ctx_; +class UnaryContext : public NanopbInvocationContext { + using Base = + NanopbInvocationContext; public: - using Request = typename decltype(ctx_)::Request; - using Response = typename decltype(ctx_)::Response; + using Request = typename Base::Request; + using Response = typename Base::Response; template - UnaryContext(Args&&... args) : ctx_(std::forward(args)...) {} - - Service& service() { return ctx_.service; } + UnaryContext(Args&&... args) : Base(std::forward(args)...) {} // Invokes the RPC with the provided request. Returns the status. Status call(const Request& request) { - ctx_.output.clear(); - ctx_.responses.emplace_back(); - ctx_.responses.back() = {}; + Base::output().clear(); + Response& response = Base::output().AllocateResponse(); return CallMethodImplFunction( - ctx_.call, request, ctx_.responses.back()); - } - - // Gives access to the RPC's response. - const Response& response() const { - PW_ASSERT(ctx_.responses.size() > 0u); - return ctx_.responses.back(); + Base::server_call(), request, response); } }; @@ -188,52 +189,40 @@ class UnaryContext { template -class ServerStreamingContext { +class ServerStreamingContext : public NanopbInvocationContext { private: - InvocationContext - ctx_; + using Base = NanopbInvocationContext; public: - using Request = typename decltype(ctx_)::Request; - using Response = typename decltype(ctx_)::Response; + using Request = typename Base::Request; + using Response = typename Base::Response; template - ServerStreamingContext(Args&&... args) : ctx_(std::forward(args)...) {} - - Service& service() { return ctx_.service; } + ServerStreamingContext(Args&&... args) : Base(std::forward(args)...) {} // Invokes the RPC with the provided request. void call(const Request& request) { - ctx_.output.clear(); - NanopbServerWriter writer(ctx_.call); - return CallMethodImplFunction(ctx_.call, request, writer); + Base::output().clear(); + NanopbServerWriter writer(Base::server_call()); + return CallMethodImplFunction( + Base::server_call(), request, writer); } // Returns a server writer which writes responses into the context's buffer. // This should not be called alongside call(); use one or the other. NanopbServerWriter writer() { - ctx_.output.clear(); - return NanopbServerWriter(ctx_.call); - } - - // Returns the responses that have been recorded. The maximum number of - // responses is responses().max_size(). responses().back() is always the most - // recent response, even if total_responses() > responses().max_size(). - const Vector& responses() const { return ctx_.responses; } - - // The total number of responses sent, which may be larger than - // responses.max_size(). - size_t total_responses() const { return ctx_.output.total_responses(); } - - // True if the stream has terminated. - bool done() const { return ctx_.output.done(); } - - // The status of the stream. Only valid if done() is true. - Status status() const { - PW_ASSERT(done()); - return ctx_.output.last_status(); + Base::output().clear(); + return NanopbServerWriter(Base::server_call()); } }; @@ -242,7 +231,7 @@ class ServerStreamingContext { template using Context = std::tuple_element_t< static_cast(internal::MethodTraits::kType), @@ -250,7 +239,7 @@ using Context = std::tuple_element_t< ServerStreamingContext // TODO(hepler): Support client and bidi streaming >>; @@ -260,11 +249,14 @@ using Context = std::tuple_element_t< template class NanopbTestMethodContext - : public internal::test::nanopb:: - Context { + : public internal::test::nanopb::Context { public: // Forwards constructor arguments to the service class. template @@ -272,7 +264,7 @@ class NanopbTestMethodContext : internal::test::nanopb::Context( std::forward(service_args)...) {} }; diff --git a/pw_rpc/public/pw_rpc/internal/test_method_context.h b/pw_rpc/public/pw_rpc/internal/test_method_context.h new file mode 100644 index 000000000..b741abae2 --- /dev/null +++ b/pw_rpc/public/pw_rpc/internal/test_method_context.h @@ -0,0 +1,122 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include +#include + +#include "pw_assert/assert.h" +#include "pw_rpc/channel.h" +#include "pw_rpc/internal/packet.h" +#include "pw_rpc/internal/server.h" +#include "pw_rpc_private/fake_channel_output.h" + +namespace pw::rpc::internal::test { + +// Collects everything needed to invoke a particular RPC. +template +class InvocationContext { + public: + Service& service() { return service_; } + + // The total number of responses sent, which may be larger than + // responses.max_size(). + size_t total_responses() const { return output_.total_responses(); } + + // True if the RPC has completed. + bool done() const { return output_.done(); } + + // The status of the stream. Only valid if done() is true. + Status status() const { + PW_ASSERT(done()); + return output_.last_status(); + } + + void SendClientError(Status error) { + std::byte packet[kNoPayloadPacketSizeBytes]; + server_.ProcessPacket(Packet(PacketType::CLIENT_ERROR, + channel_.id(), + service_.id(), + kMethodId, + {}, + error) + .Encode(packet) + .value(), + output_); + } + + void SendCancel() { + std::byte packet[kNoPayloadPacketSizeBytes]; + server_.ProcessPacket( + Packet(PacketType::CANCEL, channel_.id(), service_.id(), kMethodId) + .Encode(packet) + .value(), + output_); + } + + protected: + template + void SendClientStream(ConstByteSpan payload) { + std::byte packet[kNoPayloadPacketSizeBytes + 3 + kMaxPayloadSize]; + server_.ProcessPacket(Packet(PacketType::CLIENT_STREAM, + channel_.id(), + service_.id(), + kMethodId, + payload) + .Encode(packet) + .value(), + output_); + } + + void SendClientStreamEnd() { + std::byte packet[kNoPayloadPacketSizeBytes]; + server_.ProcessPacket(Packet(PacketType::CLIENT_STREAM_END, + channel_.id(), + service_.id(), + kMethodId) + .Encode(packet) + .value(), + output_); + } + + template + InvocationContext(const Method& method, + FakeChannelOutput& output, + Args&&... service_args) + : output_(output), + channel_(Channel::Create<123>(&output_)), + server_(std::span(&channel_, 1)), + service_(std::forward(service_args)...), + server_call_(static_cast(server_), + static_cast(channel_), + service_, + method) { + server_.RegisterService(service_); + } + + internal::ServerCall& server_call() { return server_call_; } + + private: + static constexpr size_t kNoPayloadPacketSizeBytes = + 2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ + + 2 /* status */; + + FakeChannelOutput& output_; + rpc::Channel channel_; + rpc::Server server_; + Service service_; + internal::ServerCall server_call_; +}; + +} // namespace pw::rpc::internal::test diff --git a/pw_rpc/pw_rpc_test_protos/test.proto b/pw_rpc/pw_rpc_test_protos/test.proto index 20ab3a2c0..bd31fb9a1 100644 --- a/pw_rpc/pw_rpc_test_protos/test.proto +++ b/pw_rpc/pw_rpc_test_protos/test.proto @@ -32,6 +32,9 @@ message TestStreamResponse { message Empty {} service TestService { - rpc TestRpc(TestRequest) returns (TestResponse) {} - rpc TestStreamRpc(TestRequest) returns (stream TestStreamResponse) {} + rpc TestRpc(TestRequest) returns (TestResponse); + rpc TestStreamRpc(TestRequest) returns (stream TestStreamResponse); + rpc TestClientStreamRpc(stream TestRequest) returns (TestStreamResponse); + rpc TestBidirectionalStreamRpc(stream TestRequest) + returns (stream TestStreamResponse); } diff --git a/pw_rpc/raw/codegen_test.cc b/pw_rpc/raw/codegen_test.cc index 520ae4eea..635139e63 100644 --- a/pw_rpc/raw/codegen_test.cc +++ b/pw_rpc/raw/codegen_test.cc @@ -20,6 +20,33 @@ #include "pw_rpc_test_protos/test.raw_rpc.pb.h" namespace pw::rpc { +namespace { + +Vector EncodeRequest(int integer, Status status) { + Vector buffer(64); + test::TestRequest::RamEncoder test_request(buffer); + + test_request.WriteInteger(integer); + test_request.WriteStatusCode(status.code()); + + EXPECT_EQ(OkStatus(), test_request.status()); + buffer.resize(test_request.size()); + return buffer; +} + +Vector EncodeResponse(int number) { + Vector buffer(64); + test::TestStreamResponse::RamEncoder test_response(buffer); + + test_response.WriteNumber(number); + + EXPECT_EQ(OkStatus(), test_response.status()); + buffer.resize(test_response.size()); + return buffer; +} + +} // namespace + namespace test { class TestService final : public generated::TestService { @@ -59,7 +86,49 @@ class TestService final : public generated::TestService { writer.Finish(status); } + void TestClientStreamRpc(ServerContext&, RawServerReader& reader) { + last_reader_ = std::move(reader); + + last_reader_.set_on_next([this](ConstByteSpan payload) { + last_reader_.Finish(EncodeResponse(ReadInteger(payload)), + Status::Unauthenticated()); + }); + } + + void TestBidirectionalStreamRpc(ServerContext&, + RawServerReaderWriter& reader_writer) { + last_reader_writer_ = std::move(reader_writer); + + last_reader_writer_.set_on_next([this](ConstByteSpan payload) { + last_reader_writer_.Write(EncodeResponse(ReadInteger(payload))); + last_reader_writer_.Finish(Status::NotFound()); + }); + } + + protected: + RawServerReader last_reader_; + RawServerReaderWriter last_reader_writer_; + private: + static uint32_t ReadInteger(ConstByteSpan request) { + uint32_t integer = 0; + + protobuf::Decoder decoder(request); + while (decoder.Next().ok()) { + switch (static_cast(decoder.FieldNumber())) { + case TestRequest::Fields::INTEGER: + EXPECT_EQ(OkStatus(), decoder.ReadUint32(&integer)); + break; + case TestRequest::Fields::STATUS_CODE: + break; + default: + ADD_FAILURE(); + } + } + + return integer; + } + static bool DecodeRequest(ConstByteSpan request, int64_t& integer, Status& status) { @@ -104,13 +173,7 @@ TEST(RawCodegen, CompilesProperly) { TEST(RawCodegen, Server_InvokeUnaryRpc) { PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestRpc) context; - std::byte buffer[64]; - protobuf::NestedEncoder encoder(buffer); - test::TestRequest::Encoder test_request(&encoder); - test_request.WriteInteger(123); - test_request.WriteStatusCode(OkStatus().code()); - - auto sws = context.call(encoder.Encode().value()); + auto sws = context.call(EncodeRequest(123, OkStatus())); EXPECT_EQ(OkStatus(), sws.status()); protobuf::Decoder decoder(context.response()); @@ -130,13 +193,7 @@ TEST(RawCodegen, Server_InvokeUnaryRpc) { TEST(RawCodegen, Server_InvokeServerStreamingRpc) { PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc) context; - std::byte buffer[64]; - protobuf::NestedEncoder encoder(buffer); - test::TestRequest::Encoder test_request(&encoder); - test_request.WriteInteger(5); - test_request.WriteStatusCode(Status::Unauthenticated().code()); - - context.call(encoder.Encode().value()); + context.call(EncodeRequest(5, Status::Unauthenticated())); EXPECT_TRUE(context.done()); EXPECT_EQ(Status::Unauthenticated(), context.status()); EXPECT_EQ(context.total_responses(), 5u); @@ -158,5 +215,49 @@ TEST(RawCodegen, Server_InvokeServerStreamingRpc) { } } +int32_t ReadResponseNumber(ConstByteSpan data) { + int32_t value = -1; + protobuf::Decoder decoder(data); + while (decoder.Next().ok()) { + switch ( + static_cast(decoder.FieldNumber())) { + case test::TestStreamResponse::Fields::NUMBER: { + decoder.ReadInt32(&value); + break; + } + default: + ADD_FAILURE(); + break; + } + } + + return value; +} + +TEST(RawCodegen, Server_InvokeClientStreamingRpc) { + PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestClientStreamRpc) ctx; + + ctx.call(); + ctx.SendClientStream(EncodeRequest(123, OkStatus())); + + ASSERT_TRUE(ctx.done()); + EXPECT_EQ(Status::Unauthenticated(), ctx.status()); + EXPECT_EQ(ctx.total_responses(), 1u); + EXPECT_EQ(ReadResponseNumber(ctx.responses().back()), 123); +} + +TEST(RawCodegen, Server_InvokeBidirectionalStreamingRpc) { + PW_RAW_TEST_METHOD_CONTEXT(test::TestService, TestBidirectionalStreamRpc) + ctx; + + ctx.call(); + ctx.SendClientStream(EncodeRequest(456, OkStatus())); + + ASSERT_TRUE(ctx.done()); + EXPECT_EQ(Status::NotFound(), ctx.status()); + ASSERT_EQ(ctx.total_responses(), 1u); + EXPECT_EQ(ReadResponseNumber(ctx.responses().back()), 456); +} + } // namespace } // namespace pw::rpc diff --git a/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h b/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h index 7edcd3331..c9552cd0b 100644 --- a/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h +++ b/pw_rpc/raw/public/pw_rpc/raw/server_reader_writer.h @@ -30,6 +30,10 @@ namespace test::raw { template class ServerStreamingContext; +template +class ClientStreamingContext; +template +class BidirectionalStreamingContext; } // namespace test::raw } // namespace internal @@ -81,6 +85,9 @@ class RawServerReaderWriter : private internal::Responder { private: friend class internal::RawMethod; + + template + friend class internal::test::raw::BidirectionalStreamingContext; }; // The RawServerReader is used to receive messages and send a response in a @@ -107,6 +114,9 @@ class RawServerReader : private RawServerReaderWriter { private: friend class internal::RawMethod; // Needed for conversions from ReaderWriter + template + friend class internal::test::raw::ClientStreamingContext; + using RawServerReaderWriter::HasClientStream; RawServerReader(internal::ServerCall& call) diff --git a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h index 0302ac95f..6e0d66f50 100644 --- a/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h +++ b/pw_rpc/raw/public/pw_rpc/raw_test_method_context.h @@ -23,6 +23,7 @@ #include "pw_rpc/internal/packet.h" #include "pw_rpc/internal/raw_method.h" #include "pw_rpc/internal/server.h" +#include "pw_rpc/internal/test_method_context.h" #include "pw_rpc_private/fake_channel_output.h" namespace pw::rpc { @@ -65,7 +66,7 @@ namespace pw::rpc { // // PW_RAW_TEST_METHOD_CONTEXT takes two optional arguments: // -// size_t kMaxResponse: maximum responses to store; ignored unless streaming +// size_t kMaxResponses: maximum responses to store; ignored unless streaming // size_t kOutputSizeBytes: buffer size; must be large enough for a packet // // Example: @@ -81,7 +82,7 @@ namespace pw::rpc { template class RawTestMethodContext; @@ -89,69 +90,78 @@ class RawTestMethodContext; namespace internal::test::raw { // A ChannelOutput implementation that stores the outgoing payloads and status. -template +template class MessageOutput final : public FakeChannelOutput { public: - using ResponseBuffer = std::array; + MessageOutput(bool server_streaming) + : FakeChannelOutput(packet_buffer_, server_streaming) {} - MessageOutput(Vector& responses, - Vector& buffers, - ByteSpan packet_buffer, - bool server_streaming) - : FakeChannelOutput(packet_buffer, server_streaming), - responses_(responses), - buffers_(buffers) {} + const Vector& responses() const { return responses_; } + + // Allocates a response buffer and returns a reference to the response span + // for it. + ByteSpan& AllocateResponse() { + // If we run out of space, the back message is always the most recent. + response_buffers_.emplace_back(); + response_buffers_.back() = {}; + + responses_.emplace_back(); + responses_.back() = {response_buffers_.back().data(), + response_buffers_.back().size()}; + return responses_.back(); + } private: void AppendResponse(ConstByteSpan response) override { - // If we run out of space, the back message is always the most recent. - buffers_.emplace_back(); - buffers_.back() = {}; - std::memcpy(&buffers_.back(), response.data(), response.size()); - responses_.emplace_back(); - responses_.back() = {buffers_.back().data(), response.size()}; + ByteSpan& response_span = AllocateResponse(); + PW_ASSERT(response.size() <= response_span.size()); + + std::memcpy(response_span.data(), response.data(), response.size()); + response_span = response_span.first(response.size()); } void ClearResponses() override { responses_.clear(); - buffers_.clear(); + response_buffers_.clear(); } - Vector& responses_; - Vector& buffers_; + std::array packet_buffer_; + Vector responses_; + Vector, kMaxResponses> response_buffers_; }; // Collects everything needed to invoke a particular RPC. template -struct InvocationContext { +class RawInvocationContext : public InvocationContext { + public: + // Returns the responses that have been recorded. The maximum number of + // responses is responses().max_size(). responses().back() is always the most + // recent response, even if total_responses() > responses().max_size(). + const Vector& responses() const { return output_.responses(); } + + // Gives access to the RPC's most recent response. + ConstByteSpan response() const { + PW_ASSERT(!responses().empty()); + return responses().back(); + } + + protected: template - InvocationContext(Args&&... args) - : output(responses, - buffers, - packet_buffer, - MethodTraits::kServerStreaming), - channel(Channel::Create<123>(&output)), - server(std::span(&channel, 1)), - service(std::forward(args)...), - call(static_cast(server), - static_cast(channel), - service, - MethodLookup::GetRawMethod()) {} + RawInvocationContext(Args&&... args) + : InvocationContext( + MethodLookup::GetRawMethod(), + output_, + std::forward(args)...), + output_(MethodTraits::kServerStreaming) {} - using ResponseBuffer = std::array; + MessageOutput& output() { return output_; } - MessageOutput output; - rpc::Channel channel; - rpc::Server server; - Service service; - Vector responses; - Vector buffers; - std::array packet_buffer = {}; - internal::ServerCall call; + private: + MessageOutput output_; }; // Method invocation context for a unary RPC. Returns the status in call() and @@ -160,87 +170,142 @@ template -class UnaryContext { - private: - using Context = - InvocationContext; - Context ctx_; +class UnaryContext + : public RawInvocationContext { + using Base = + RawInvocationContext; public: template - UnaryContext(Args&&... args) : ctx_(std::forward(args)...) {} - - Service& service() { return ctx_.service; } + UnaryContext(Args&&... args) : Base(std::forward(args)...) {} // Invokes the RPC with the provided request. Returns RPC's StatusWithSize. StatusWithSize call(ConstByteSpan request) { - ctx_.output.clear(); - ctx_.buffers.emplace_back(); - ctx_.buffers.back() = {}; - ctx_.responses.emplace_back(); - auto& response = ctx_.responses.back(); - response = {ctx_.buffers.back().data(), ctx_.buffers.back().size()}; - auto sws = CallMethodImplFunction(ctx_.call, request, response); + Base::output().clear(); + ByteSpan& response = Base::output().AllocateResponse(); + auto sws = + CallMethodImplFunction(Base::server_call(), request, response); response = response.first(sws.size()); return sws; } - - // Gives access to the RPC's response. - ConstByteSpan response() const { - PW_ASSERT(ctx_.responses.size() > 0u); - return ctx_.responses.back(); - } }; // Method invocation context for a server streaming RPC. template -class ServerStreamingContext { - private: - using Context = - InvocationContext; - Context ctx_; +class ServerStreamingContext : public RawInvocationContext { + using Base = RawInvocationContext; public: template - ServerStreamingContext(Args&&... args) : ctx_(std::forward(args)...) {} - - Service& service() { return ctx_.service; } + ServerStreamingContext(Args&&... args) : Base(std::forward(args)...) {} // Invokes the RPC with the provided request. void call(ConstByteSpan request) { - ctx_.output.clear(); - RawServerWriter server_writer(ctx_.call); - return CallMethodImplFunction(ctx_.call, request, server_writer); + Base::output().clear(); + RawServerWriter writer(Base::server_call()); + return CallMethodImplFunction( + Base::server_call(), request, writer); } // Returns a server writer which writes responses into the context's buffer. // This should not be called alongside call(); use one or the other. RawServerWriter writer() { - ctx_.output.clear(); - return RawServerWriter(ctx_.call); + Base::output().clear(); + return RawServerWriter(Base::server_call()); + } +}; + +// Method invocation context for a client streaming RPC. +template +class ClientStreamingContext : public RawInvocationContext { + using Base = RawInvocationContext; + + public: + template + ClientStreamingContext(Args&&... args) : Base(std::forward(args)...) {} + + // Invokes the RPC. + void call() { + Base::output().clear(); + RawServerReader reader_writer(Base::server_call()); + return CallMethodImplFunction(Base::server_call(), reader_writer); } - // Returns the responses that have been recorded. The maximum number of - // responses is responses().max_size(). responses().back() is always the most - // recent response, even if total_responses() > responses().max_size(). - const Vector& responses() const { return ctx_.responses; } - - // The total number of responses sent, which may be larger than - // responses.max_size(). - size_t total_responses() const { return ctx_.output.total_responses(); } - - // True if the stream has terminated. - bool done() const { return ctx_.output.done(); } - - // The status of the stream. Only valid if done() is true. - Status status() const { - PW_ASSERT(done()); - return ctx_.output.last_status(); + // Returns a reader/writer which writes responses into the context's buffer. + // This should not be called alongside call(); use one or the other. + RawServerReader reader() { + Base::output().clear(); + return RawServerReader(Base::server_call()); } + + // Allow sending client streaming packets. + using Base::SendClientStream; + using Base::SendClientStreamEnd; +}; + +// Method invocation context for a bidirectional streaming RPC. +template +class BidirectionalStreamingContext : public RawInvocationContext { + using Base = RawInvocationContext; + + public: + template + BidirectionalStreamingContext(Args&&... args) + : Base(std::forward(args)...) {} + + // Invokes the RPC. + void call() { + Base::output().clear(); + RawServerReaderWriter reader_writer(Base::server_call()); + return CallMethodImplFunction(Base::server_call(), reader_writer); + } + + // Returns a reader/writer which writes responses into the context's buffer. + // This should not be called alongside call(); use one or the other. + RawServerReaderWriter reader_writer() { + Base::output().clear(); + return RawServerReaderWriter(Base::server_call()); + } + + // Allow sending client streaming packets. + using Base::SendClientStream; + using Base::SendClientStreamEnd; }; // Alias to select the type of the context object to use based on which type of @@ -248,7 +313,7 @@ class ServerStreamingContext { template using Context = std::tuple_element_t< static_cast(MethodTraits::kType), @@ -256,21 +321,32 @@ using Context = std::tuple_element_t< ServerStreamingContext - // TODO(hepler): Support client and bidi streaming - >>; + kMaxResponses, + kOutputSize>, + ClientStreamingContext, + BidirectionalStreamingContext>>; } // namespace internal::test::raw template class RawTestMethodContext - : public internal::test::raw:: - Context { + : public internal::test::raw::Context { public: // Forwards constructor arguments to the service class. template @@ -278,7 +354,7 @@ class RawTestMethodContext : internal::test::raw::Context( std::forward(service_args)...) {} };