Skip to content

Commit 851d60d

Browse files
authored
changed interface to use instance methods rather than static ones (#58)
* changed interface to use instance mthods rather than static ones * changed interface to use instance mthods rather than static ones
1 parent 45734a1 commit 851d60d

8 files changed

Lines changed: 246 additions & 84 deletions

File tree

benches/common/tasks.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,23 @@ use std::borrow::Cow;
77
// ============================================================================
88

99
#[allow(dead_code)]
10+
#[derive(Default)]
1011
pub struct NoOpTask;
1112

1213
#[async_trait]
1314
impl Task<()> for NoOpTask {
14-
fn name() -> Cow<'static, str> {
15+
fn name(&self) -> Cow<'static, str> {
1516
Cow::Borrowed("bench-noop")
1617
}
1718
type Params = ();
1819
type Output = ();
1920

20-
async fn run(_params: Self::Params, _ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
21+
async fn run(
22+
&self,
23+
_params: Self::Params,
24+
_ctx: TaskContext,
25+
_state: (),
26+
) -> TaskResult<Self::Output> {
2127
Ok(())
2228
}
2329
}
@@ -27,6 +33,7 @@ impl Task<()> for NoOpTask {
2733
// ============================================================================
2834

2935
#[allow(dead_code)]
36+
#[derive(Default)]
3037
pub struct QuickTask;
3138

3239
#[allow(dead_code)]
@@ -37,13 +44,18 @@ pub struct QuickParams {
3744

3845
#[async_trait]
3946
impl Task<()> for QuickTask {
40-
fn name() -> Cow<'static, str> {
47+
fn name(&self) -> Cow<'static, str> {
4148
Cow::Borrowed("bench-quick")
4249
}
4350
type Params = QuickParams;
4451
type Output = u32;
4552

46-
async fn run(params: Self::Params, _ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
53+
async fn run(
54+
&self,
55+
params: Self::Params,
56+
_ctx: TaskContext,
57+
_state: (),
58+
) -> TaskResult<Self::Output> {
4759
Ok(params.task_num)
4860
}
4961
}
@@ -53,6 +65,7 @@ impl Task<()> for QuickTask {
5365
// ============================================================================
5466

5567
#[allow(dead_code)]
68+
#[derive(Default)]
5669
pub struct MultiStepBenchTask;
5770

5871
#[allow(dead_code)]
@@ -63,13 +76,14 @@ pub struct MultiStepParams {
6376

6477
#[async_trait]
6578
impl Task<()> for MultiStepBenchTask {
66-
fn name() -> Cow<'static, str> {
79+
fn name(&self) -> Cow<'static, str> {
6780
Cow::Borrowed("bench-multi-step")
6881
}
6982
type Params = MultiStepParams;
7083
type Output = u32;
7184

7285
async fn run(
86+
&self,
7387
params: Self::Params,
7488
mut ctx: TaskContext,
7589
_state: (),
@@ -88,6 +102,7 @@ impl Task<()> for MultiStepBenchTask {
88102
// ============================================================================
89103

90104
#[allow(dead_code)]
105+
#[derive(Default)]
91106
pub struct LargePayloadBenchTask;
92107

93108
#[allow(dead_code)]
@@ -98,13 +113,14 @@ pub struct LargePayloadParams {
98113

99114
#[async_trait]
100115
impl Task<()> for LargePayloadBenchTask {
101-
fn name() -> Cow<'static, str> {
116+
fn name(&self) -> Cow<'static, str> {
102117
Cow::Borrowed("bench-large-payload")
103118
}
104119
type Params = LargePayloadParams;
105120
type Output = usize;
106121

107122
async fn run(
123+
&self,
108124
params: Self::Params,
109125
mut ctx: TaskContext,
110126
_state: (),

src/client.rs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ use serde::Serialize;
22
use serde_json::Value as JsonValue;
33
use sqlx::{Executor, PgPool, Postgres};
44
use std::collections::HashMap;
5-
use std::marker::PhantomData;
65
use std::sync::Arc;
76
use std::time::Duration;
87
use tokio::sync::RwLock;
98
use uuid::Uuid;
109

1110
use crate::error::{DurableError, DurableResult};
12-
use crate::task::{Task, TaskRegistry};
11+
use crate::task::{Task, TaskRegistry, TaskWrapper};
1312
use crate::types::{
1413
CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult, SpawnResultRow,
1514
WorkerOptions,
@@ -305,31 +304,45 @@ where
305304
/// Register a task type. Required before spawning or processing.
306305
///
307306
/// Returns an error if a task with the same name is already registered.
308-
pub async fn register<T: Task<State>>(&self) -> DurableResult<&Self> {
307+
pub async fn register<T: Task<State> + Default>(&self) -> DurableResult<&Self> {
308+
self.register_instance(T::default()).await
309+
}
310+
311+
/// Register a task instance. Required before spawning or processing.
312+
///
313+
/// Use this when you need to register a task with runtime-determined metadata
314+
/// (e.g., a TypeScript tool loaded from a config file).
315+
///
316+
/// Returns an error if a task with the same name is already registered.
317+
pub async fn register_instance<T: Task<State>>(&self, task: T) -> DurableResult<&Self> {
309318
let mut registry = self.registry.write().await;
310-
let name = T::name();
311-
if registry.contains_key(&name) {
319+
let name = task.name();
320+
if registry.contains_key(name.as_ref()) {
312321
return Err(DurableError::TaskAlreadyRegistered {
313322
task_name: name.to_string(),
314323
});
315324
}
316-
registry.insert(name, &PhantomData::<T>);
325+
registry.insert(name, Arc::new(TaskWrapper::new(task)));
317326
Ok(self)
318327
}
319328

320329
/// Spawn a task (type-safe version)
321-
pub async fn spawn<T: Task<State>>(&self, params: T::Params) -> DurableResult<SpawnResult> {
330+
pub async fn spawn<T: Task<State> + Default>(
331+
&self,
332+
params: T::Params,
333+
) -> DurableResult<SpawnResult> {
322334
self.spawn_with_options::<T>(params, SpawnOptions::default())
323335
.await
324336
}
325337

326338
/// Spawn a task with options (type-safe version)
327-
pub async fn spawn_with_options<T: Task<State>>(
339+
pub async fn spawn_with_options<T: Task<State> + Default>(
328340
&self,
329341
params: T::Params,
330342
options: SpawnOptions,
331343
) -> DurableResult<SpawnResult> {
332-
self.spawn_by_name(&T::name(), serde_json::to_value(&params)?, options)
344+
let task = T::default();
345+
self.spawn_by_name(&task.name(), serde_json::to_value(&params)?, options)
333346
.await
334347
}
335348

@@ -370,7 +383,7 @@ where
370383
params: T::Params,
371384
) -> DurableResult<SpawnResult>
372385
where
373-
T: Task<State>,
386+
T: Task<State> + Default,
374387
E: Executor<'e, Database = Postgres>,
375388
{
376389
self.spawn_with_options_with::<T, E>(executor, params, SpawnOptions::default())
@@ -385,13 +398,13 @@ where
385398
options: SpawnOptions,
386399
) -> DurableResult<SpawnResult>
387400
where
388-
T: Task<State>,
401+
T: Task<State> + Default,
389402
E: Executor<'e, Database = Postgres>,
390403
{
391-
// Type-safe spawn uses T::name() which is already registered
404+
let task = T::default();
392405
self.spawn_by_name_internal(
393406
executor,
394-
&T::name(),
407+
&task.name(),
395408
serde_json::to_value(&params)?,
396409
options,
397410
)

src/context.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ where
575575
tracing::instrument(
576576
name = "durable.task.spawn",
577577
skip(self, params, options),
578-
fields(task_id = %self.task_id, subtask_name = %T::name())
578+
fields(task_id = %self.task_id)
579579
)
580580
)]
581581
pub async fn spawn<T>(
@@ -585,10 +585,11 @@ where
585585
options: crate::SpawnOptions,
586586
) -> TaskResult<TaskHandle<T::Output>>
587587
where
588-
T: Task<State>,
588+
T: Task<State> + Default,
589589
{
590+
let task = T::default();
590591
let params_json = serde_json::to_value(&params)?;
591-
self.spawn_by_name(name, &T::name(), params_json, options)
592+
self.spawn_by_name(name, &task.name(), params_json, options)
592593
.await
593594
}
594595

src/lib.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@
1919
//! #[derive(Serialize, Deserialize)]
2020
//! struct MyOutput { result: i32 }
2121
//!
22+
//! #[derive(Default)]
2223
//! struct MyTask;
2324
//!
2425
//! #[async_trait]
2526
//! impl Task<()> for MyTask {
26-
//! fn name() -> Cow<'static, str> { Cow::Borrowed("my-task") }
27+
//! fn name(&self) -> Cow<'static, str> { Cow::Borrowed("my-task") }
2728
//! type Params = MyParams;
2829
//! type Output = MyOutput;
2930
//!
30-
//! async fn run(params: Self::Params, mut ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
31+
//! async fn run(&self, params: Self::Params, mut ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
3132
//! let doubled = ctx.step("double", || async {
3233
//! Ok(params.value * 2)
3334
//! }).await?;
@@ -61,15 +62,16 @@
6162
//! http_client: reqwest::Client,
6263
//! }
6364
//!
65+
//! #[derive(Default)]
6466
//! struct FetchTask;
6567
//!
6668
//! #[async_trait]
6769
//! impl Task<AppState> for FetchTask {
68-
//! fn name() -> Cow<'static, str> { Cow::Borrowed("fetch") }
70+
//! fn name(&self) -> Cow<'static, str> { Cow::Borrowed("fetch") }
6971
//! type Params = String;
7072
//! type Output = String;
7173
//!
72-
//! async fn run(url: Self::Params, mut ctx: TaskContext, state: AppState) -> TaskResult<Self::Output> {
74+
//! async fn run(&self, url: Self::Params, mut ctx: TaskContext, state: AppState) -> TaskResult<Self::Output> {
7375
//! ctx.step("fetch", || async {
7476
//! state.http_client.get(&url).send().await?.text().await
7577
//! .map_err(|e| anyhow::anyhow!(e))
@@ -105,7 +107,7 @@ mod worker;
105107
pub use client::{Durable, DurableBuilder};
106108
pub use context::TaskContext;
107109
pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
108-
pub use task::Task;
110+
pub use task::{ErasedTask, Task, TaskWrapper};
109111
pub use types::{
110112
CancellationPolicy, ClaimedTask, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult,
111113
TaskHandle, WorkerOptions,

src/task.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use async_trait::async_trait;
22
use serde::{Serialize, de::DeserializeOwned};
33
use serde_json::Value as JsonValue;
44
use std::borrow::Cow;
5-
use std::marker::PhantomData;
5+
use std::sync::Arc;
66

77
use crate::context::TaskContext;
88
use crate::error::{TaskError, TaskResult};
@@ -25,11 +25,11 @@ use crate::error::{TaskError, TaskResult};
2525
///
2626
/// #[async_trait]
2727
/// impl Task<()> for SendEmailTask {
28-
/// fn name() -> Cow<'static, str> { Cow::Borrowed("send-email") }
28+
/// fn name(&self) -> Cow<'static, str> { Cow::Borrowed("send-email") }
2929
/// type Params = SendEmailParams;
3030
/// type Output = SendEmailResult;
3131
///
32-
/// async fn run(params: Self::Params, mut ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
32+
/// async fn run(&self, params: Self::Params, mut ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
3333
/// let result = ctx.step("send", || async {
3434
/// email_service::send(&params.to, &params.subject, &params.body).await
3535
/// }).await?;
@@ -48,11 +48,11 @@ use crate::error::{TaskError, TaskResult};
4848
///
4949
/// #[async_trait]
5050
/// impl Task<AppState> for FetchUrlTask {
51-
/// fn name() -> Cow<'static, str> { Cow::Borrowed("fetch-url") }
51+
/// fn name(&self) -> Cow<'static, str> { Cow::Borrowed("fetch-url") }
5252
/// type Params = String;
5353
/// type Output = String;
5454
///
55-
/// async fn run(url: Self::Params, mut ctx: TaskContext, state: AppState) -> TaskResult<Self::Output> {
55+
/// async fn run(&self, url: Self::Params, mut ctx: TaskContext, state: AppState) -> TaskResult<Self::Output> {
5656
/// let body = ctx.step("fetch", || async {
5757
/// state.http_client.get(&url).send().await
5858
/// .map_err(|e| anyhow::anyhow!("HTTP error: {}", e))?
@@ -70,7 +70,7 @@ where
7070
{
7171
/// Task name as stored in the database.
7272
/// Should be unique across your application.
73-
fn name() -> Cow<'static, str>;
73+
fn name(&self) -> Cow<'static, str>;
7474

7575
/// Parameter type (must be JSON-serializable)
7676
type Params: Serialize + DeserializeOwned + Send;
@@ -94,6 +94,7 @@ where
9494
/// The `state` parameter provides access to application-level resources
9595
/// like HTTP clients, database pools, etc.
9696
async fn run(
97+
&self,
9798
params: Self::Params,
9899
ctx: TaskContext<State>,
99100
state: State,
@@ -119,14 +120,27 @@ where
119120
) -> Result<JsonValue, TaskError>;
120121
}
121122

123+
/// Wrapper that implements [`ErasedTask`] for any [`Task`] type.
124+
///
125+
/// This allows storing heterogeneous tasks in a registry while preserving
126+
/// their ability to execute.
127+
pub struct TaskWrapper<T>(pub Arc<T>);
128+
129+
impl<T> TaskWrapper<T> {
130+
/// Create a new TaskWrapper from a task instance.
131+
pub fn new(task: T) -> Self {
132+
Self(Arc::new(task))
133+
}
134+
}
135+
122136
#[async_trait]
123-
impl<T, State> ErasedTask<State> for PhantomData<T>
137+
impl<T, State> ErasedTask<State> for TaskWrapper<T>
124138
where
125139
T: Task<State>,
126140
State: Clone + Send + Sync + 'static,
127141
{
128142
fn name(&self) -> Cow<'static, str> {
129-
T::name()
143+
self.0.name()
130144
}
131145

132146
fn validate_params(&self, params: JsonValue) -> Result<(), TaskError> {
@@ -142,11 +156,11 @@ where
142156
state: State,
143157
) -> Result<JsonValue, TaskError> {
144158
let typed_params: T::Params = serde_json::from_value(params)?;
145-
let result = T::run(typed_params, ctx, state).await?;
159+
let result = self.0.run(typed_params, ctx, state).await?;
146160
Ok(serde_json::to_value(&result)?)
147161
}
148162
}
149163

150164
/// Type alias for the task registry
151165
pub type TaskRegistry<State> =
152-
std::collections::HashMap<Cow<'static, str>, &'static dyn ErasedTask<State>>;
166+
std::collections::HashMap<Cow<'static, str>, Arc<dyn ErasedTask<State>>>;

src/worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ impl Worker {
364364
// Look up handler
365365
let registry = registry.read().await;
366366
let handler = match registry.get(task.task_name.as_str()) {
367-
Some(h) => *h,
367+
Some(h) => h.clone(),
368368
None => {
369369
tracing::error!("Unknown task: {}", task.task_name);
370370
Self::fail_run(

0 commit comments

Comments
 (0)