Skip to content

Commit abbd2fc

Browse files
committed
allow transports to see and manipulate client and server contexts.
1 parent ac29717 commit abbd2fc

File tree

20 files changed

+269
-152
lines changed

20 files changed

+269
-152
lines changed

example-service/src/client.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
use clap::Parser;
88
use service::{WorldClient, init_tracing};
99
use std::{net::SocketAddr, time::Duration};
10-
use tarpc::{client, context, tokio_serde::formats::Json};
10+
use futures::{future, SinkExt};
11+
use tarpc::{client, tokio_serde::formats::Json};
1112
use tokio::time::sleep;
1213
use tracing::Instrument;
1314
use tarpc::context::ClientContext;
@@ -30,9 +31,11 @@ async fn main() -> anyhow::Result<()> {
3031
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
3132
transport.config_mut().max_frame_length(usize::MAX);
3233

34+
let transport = transport.await?.with(|msg: tarpc::ClientMessage<ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)));
35+
3336
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
3437
// config and any Transport as input.
35-
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
38+
let client = WorldClient::new(client::Config::default(), transport).spawn();
3639

3740
let hello = async move {
3841
let mut context = ClientContext::current();

example-service/src/server.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@ use std::{
1515
net::{IpAddr, Ipv6Addr, SocketAddr},
1616
time::Duration,
1717
};
18-
use tarpc::{
19-
context,
20-
server::{self, Channel, incoming::Incoming},
21-
tokio_serde::formats::Json,
22-
};
18+
use tarpc::{context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, ClientMessage};
2319
use tokio::time;
20+
use tarpc::context::{ServerContext, SharedContext};
2421

2522
#[derive(Parser)]
2623
struct Flags {
@@ -62,13 +59,14 @@ async fn main() -> anyhow::Result<()> {
6259
listener
6360
// Ignore accept errors.
6461
.filter_map(|r| future::ready(r.ok()))
62+
.map(|t| t.map_ok(|msg: ClientMessage<SharedContext, _>| msg.map_context(|ctx| ServerContext::new(ctx))))
6563
.map(server::BaseChannel::with_defaults)
6664
// Limit channels to 1 per IP.
67-
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
65+
.max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip())
6866
// serve is generated by the service attribute. It takes as input any type implementing
6967
// the generated World trait.
7068
.map(|channel| {
71-
let server = HelloServer(channel.transport().peer_addr().unwrap());
69+
let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap());
7270
channel.execute(server.serve()).for_each(spawn)
7371
})
7472
// Max 10 channels.

plugins/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ proc-macro = true
3030
[dev-dependencies]
3131
assert-type-eq = "0.1.0"
3232
futures = "0.3"
33+
futures-util = "0.3.31"
3334
serde = { version = "1.0", features = ["derive"] }
3435
tarpc = { path = "../tarpc", features = ["serde1"] }

plugins/src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
376376
///
377377
/// ```no_run
378378
/// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext};
379+
/// use futures_util::{TryStreamExt, sink::SinkExt};
379380
///
380381
/// #[service]
381382
/// pub trait Calculator {
@@ -394,6 +395,13 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
394395
/// // This could be any transport.
395396
/// let (client_side, server_side) = transport::channel::unbounded();
396397
///
398+
/// let client_side = client_side.with(|msg: tarpc::ClientMessage<tarpc::context::ClientContext, _>| async move {
399+
/// Ok(msg.map_context(|ctx| ctx.shared_context))
400+
/// });
401+
/// let server_side = server_side.map_ok(|msg: tarpc::ClientMessage<tarpc::context::SharedContext, _>|
402+
/// msg.map_context(tarpc::context::ServerContext::new)
403+
/// );
404+
///
397405
/// // A client can be made like so:
398406
/// let client = CalculatorClient::new(client::Config::default(), client_side);
399407
///
@@ -738,7 +746,7 @@ impl ServiceGenerator<'_> {
738746
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
739747
>
740748
where
741-
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
749+
T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<#response_ident>>
742750
{
743751
let new_client = ::tarpc::client::new(config, transport);
744752
::tarpc::client::NewClient {

tarpc/examples/compression.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,8 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*};
99
use serde::{Deserialize, Serialize};
1010
use serde_bytes::ByteBuf;
1111
use std::{io, io::Read, io::Write};
12-
use tarpc::{
13-
client, context,
14-
serde_transport::tcp,
15-
server::{BaseChannel, Channel},
16-
tokio_serde::formats::Bincode,
17-
};
12+
use tarpc::{client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, ClientMessage};
13+
use tarpc::context::{ClientContext, ServerContext, SharedContext};
1814

1915
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
2016
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
@@ -120,17 +116,22 @@ async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
120116
#[tokio::main]
121117
async fn main() -> anyhow::Result<()> {
122118
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
119+
123120
let addr = incoming.local_addr();
124121
tokio::spawn(async move {
125122
let transport = incoming.next().await.unwrap().unwrap();
126-
BaseChannel::with_defaults(add_compression(transport))
123+
let transport = add_compression(transport);
124+
let transport = transport.map_ok(|msg: ClientMessage<SharedContext, _>| msg.map_context(|ctx| ServerContext::new(ctx)));
125+
BaseChannel::with_defaults(transport)
127126
.execute(HelloServer.serve())
128127
.for_each(spawn)
129128
.await;
130129
});
131130

132131
let transport = tcp::connect(addr, Bincode::default).await?;
133-
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
132+
let transport = add_compression(transport);
133+
let transport = transport.with(|msg: ClientMessage<ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)));
134+
let client = WorldClient::new(client::Config::default(), transport).spawn();
134135

135136
println!(
136137
"{}",

tarpc/examples/custom_transport.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
// https://opensource.org/licenses/MIT.
66

77
use futures::prelude::*;
8-
use tarpc::context::{ClientContext, ServerContext};
9-
use tarpc::serde_transport as transport;
8+
use tarpc::context::{ClientContext, ServerContext, SharedContext};
9+
use tarpc::{serde_transport as transport, ClientMessage};
1010
use tarpc::server::{BaseChannel, Channel};
1111
use tarpc::tokio_serde::formats::Bincode;
1212
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
@@ -23,7 +23,6 @@ struct Service;
2323
impl PingService for Service {
2424
async fn ping(self, _: &mut ServerContext) {}
2525
}
26-
2726
#[tokio::main]
2827
async fn main() -> anyhow::Result<()> {
2928
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
@@ -40,6 +39,7 @@ async fn main() -> anyhow::Result<()> {
4039
let (conn, _addr) = listener.accept().await.unwrap();
4140
let framed = codec_builder.new_framed(conn);
4241
let transport = transport::new(framed, Bincode::default());
42+
let transport = transport.map_ok(|c: ClientMessage<SharedContext, _>| c.map_context(ServerContext::new));
4343

4444
let fut = BaseChannel::with_defaults(transport)
4545
.execute(Service.serve())
@@ -50,6 +50,7 @@ async fn main() -> anyhow::Result<()> {
5050

5151
let conn = UnixStream::connect(bind_addr).await?;
5252
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
53+
let transport = transport.with(|msg: ClientMessage<ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)));
5354
PingServiceClient::new(Default::default(), transport)
5455
.spawn()
5556
.ping(&mut ClientContext::current())

tarpc/examples/pubsub.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,11 @@ use std::{
4848
sync::{Arc, Mutex, RwLock},
4949
};
5050
use subscriber::Subscriber as _;
51-
use tarpc::{
52-
client, context,
53-
serde_transport::tcp,
54-
server::{self, Channel},
55-
tokio_serde::formats::Json,
56-
};
51+
use tarpc::{client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, ClientMessage};
5752
use tokio::net::ToSocketAddrs;
5853
use tracing::info;
5954
use tracing_subscriber::prelude::*;
55+
use tarpc::context::{ServerContext, SharedContext};
6056

6157
pub mod subscriber {
6258
#[tarpc::service]
@@ -104,6 +100,7 @@ impl Subscriber {
104100
) -> anyhow::Result<SubscriberHandle> {
105101
let publisher = tcp::connect(publisher_addr, Json::default).await?;
106102
let local_addr = publisher.local_addr()?;
103+
let publisher = publisher.map_ok(|msg: ClientMessage<SharedContext, _>| msg.map_context(|ctx| ServerContext::new(ctx)));
107104
let mut handler = server::BaseChannel::with_defaults(publisher).requests();
108105
let subscriber = Subscriber { local_addr, topics };
109106
// The first request is for the topics being subscribed to.
@@ -164,6 +161,8 @@ impl Publisher {
164161
let publisher = connecting_publishers.next().await.unwrap().unwrap();
165162
info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
166163

164+
let publisher = publisher.map_ok(|msg: ClientMessage<SharedContext, _>| msg.map_context(|ctx| ServerContext::new(ctx)));
165+
167166
server::BaseChannel::with_defaults(publisher)
168167
.execute(self.serve())
169168
.for_each(spawn)
@@ -183,6 +182,7 @@ impl Publisher {
183182
tokio::spawn(async move {
184183
while let Some(conn) = connecting_subscribers.next().await {
185184
let subscriber_addr = conn.peer_addr().unwrap();
185+
let conn = conn.with(|msg: tarpc::ClientMessage<tarpc::context::ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)));
186186

187187
let tarpc::client::NewClient {
188188
client: subscriber,
@@ -341,7 +341,7 @@ async fn main() -> anyhow::Result<()> {
341341

342342
let publisher = publisher::PublisherClient::new(
343343
client::Config::default(),
344-
tcp::connect(addrs.publisher, Json::default).await?,
344+
tcp::connect(addrs.publisher, Json::default).await?.with(|msg: tarpc::ClientMessage<tarpc::context::ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)))
345345
)
346346
.spawn();
347347

tarpc/examples/readme.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
// https://opensource.org/licenses/MIT.
66

77
use futures::prelude::*;
8-
use tarpc::{
9-
client, context,
10-
server::{self, Channel},
11-
};
8+
use tarpc::{client, context, server::{self, Channel}, transport, ClientMessage};
9+
use tarpc::context::{ClientContext, ServerContext, SharedContext};
1210

1311
/// This is the service definition. It looks a lot like a trait definition.
1412
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -34,7 +32,10 @@ async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
3432

3533
#[tokio::main]
3634
async fn main() -> anyhow::Result<()> {
37-
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
35+
let (client_transport, server_transport) = transport::channel::unbounded_mapped(
36+
|msg: ClientMessage<ClientContext, _>| msg.map_context(|ctx| ctx.shared_context),
37+
|msg: ClientMessage<SharedContext, _>| msg.map_context(ServerContext::new),
38+
);
3839

3940
let server = server::BaseChannel::with_defaults(server_transport);
4041
tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));

tarpc/examples/tls_over_tcp.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ use tokio_rustls::rustls::{
1717
server::{WebPkiClientVerifier, danger::ClientCertVerifier},
1818
};
1919
use tokio_rustls::{TlsAcceptor, TlsConnector};
20-
21-
use tarpc::context::{ClientContext, ServerContext};
20+
use tarpc::context::{ClientContext, ServerContext, SharedContext};
2221
use tarpc::serde_transport as transport;
2322
use tarpc::server::{BaseChannel, Channel};
2423
use tarpc::tokio_serde::formats::Bincode;
@@ -115,6 +114,7 @@ async fn main() -> anyhow::Result<()> {
115114
let framed = codec_builder.new_framed(tls_stream);
116115

117116
let transport = transport::new(framed, Bincode::default());
117+
let transport = transport.map_ok(|c: tarpc::ClientMessage<SharedContext, _>| c.map_context(|ctx| ServerContext::new(ctx)));
118118

119119
let fut = BaseChannel::with_defaults(transport)
120120
.execute(Service.serve())
@@ -144,6 +144,7 @@ async fn main() -> anyhow::Result<()> {
144144
let stream = connector.connect(domain, stream).await?;
145145

146146
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
147+
let transport = transport.with(|msg: tarpc::ClientMessage<ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)));
147148
let answer = PingServiceClient::new(Default::default(), transport)
148149
.spawn()
149150
.ping(&mut ClientContext::current())

tarpc/examples/tracing.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use tarpc::{
3535
};
3636
use tokio::net::TcpStream;
3737
use tracing_subscriber::prelude::*;
38+
use tarpc::context::{ClientContext, ServerContext, SharedContext};
3839

3940
pub mod add {
4041
#[tarpc::service]
@@ -124,7 +125,7 @@ where
124125
}
125126

126127
fn make_stub<Req, Resp, const N: usize>(
127-
backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
128+
backends: [impl Transport<ClientMessage<ClientContext, Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
128129
) -> retry::Retry<
129130
impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
130131
load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
@@ -173,23 +174,28 @@ async fn main() -> anyhow::Result<()> {
173174
.serving(AddServer.serve());
174175
let add_server = add_listener1
175176
.chain(add_listener2)
177+
.map(|t| t.map_ok(|msg: ClientMessage<SharedContext, _>| msg.map_context(|ctx| ServerContext::new(ctx))))
176178
.map(BaseChannel::with_defaults);
177179
tokio::spawn(spawn_incoming(add_server.execute(server)));
178180

181+
let map_context = |msg: ClientMessage<ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context));
182+
179183
let add_client = add::AddClient::from(make_stub([
180-
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
181-
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
184+
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?.with(map_context),
185+
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?.with(map_context),
182186
]));
183187

184188
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
185189
.await?
186-
.filter_map(|r| future::ready(r.ok()));
187-
let addr = double_listener.get_ref().local_addr();
190+
.filter_map(|r| future::ready(r.ok()))
191+
.map(|t| t.map_ok(|msg: ClientMessage<SharedContext, _>| msg.map_context(|ctx| ServerContext::new(ctx))));
192+
let addr = double_listener.get_ref().get_ref().local_addr();
188193
let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
189194
let server = DoubleServer { add_client }.serve();
190195
tokio::spawn(spawn_incoming(double_server.execute(server)));
191196

192197
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
198+
let to_double_server = to_double_server.with(|msg: ClientMessage<ClientContext, _>| future::ok(msg.map_context(|ctx| ctx.shared_context)));
193199
let double_client =
194200
double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
195201

0 commit comments

Comments
 (0)