@@ -2,7 +2,7 @@ use async_trait::async_trait;
22use serde:: { Serialize , de:: DeserializeOwned } ;
33use serde_json:: Value as JsonValue ;
44use std:: borrow:: Cow ;
5- use std:: marker :: PhantomData ;
5+ use std:: sync :: Arc ;
66
77use crate :: context:: TaskContext ;
88use crate :: error:: { TaskError , TaskResult } ;
@@ -25,11 +25,11 @@ use crate::error::{TaskError, TaskResult};
2525///
2626/// #[async_trait]
2727/// impl Task<()> for SendEmailTask {
28- /// fn name() -> Cow<'static, str> { Cow::Borrowed("send-email") }
28+ /// fn name(&self ) -> Cow<'static, str> { Cow::Borrowed("send-email") }
2929/// type Params = SendEmailParams;
3030/// type Output = SendEmailResult;
3131///
32- /// async fn run(params: Self::Params, mut ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
32+ /// async fn run(&self, params: Self::Params, mut ctx: TaskContext, _state: ()) -> TaskResult<Self::Output> {
3333/// let result = ctx.step("send", || async {
3434/// email_service::send(¶ms.to, ¶ms.subject, ¶ms.body).await
3535/// }).await?;
@@ -48,11 +48,11 @@ use crate::error::{TaskError, TaskResult};
4848///
4949/// #[async_trait]
5050/// impl Task<AppState> for FetchUrlTask {
51- /// fn name() -> Cow<'static, str> { Cow::Borrowed("fetch-url") }
51+ /// fn name(&self ) -> Cow<'static, str> { Cow::Borrowed("fetch-url") }
5252/// type Params = String;
5353/// type Output = String;
5454///
55- /// async fn run(url: Self::Params, mut ctx: TaskContext, state: AppState) -> TaskResult<Self::Output> {
55+ /// async fn run(&self, url: Self::Params, mut ctx: TaskContext, state: AppState) -> TaskResult<Self::Output> {
5656/// let body = ctx.step("fetch", || async {
5757/// state.http_client.get(&url).send().await
5858/// .map_err(|e| anyhow::anyhow!("HTTP error: {}", e))?
7070{
7171 /// Task name as stored in the database.
7272 /// Should be unique across your application.
73- fn name ( ) -> Cow < ' static , str > ;
73+ fn name ( & self ) -> Cow < ' static , str > ;
7474
7575 /// Parameter type (must be JSON-serializable)
7676 type Params : Serialize + DeserializeOwned + Send ;
9494 /// The `state` parameter provides access to application-level resources
9595 /// like HTTP clients, database pools, etc.
9696 async fn run (
97+ & self ,
9798 params : Self :: Params ,
9899 ctx : TaskContext < State > ,
99100 state : State ,
@@ -119,14 +120,27 @@ where
119120 ) -> Result < JsonValue , TaskError > ;
120121}
121122
123+ /// Wrapper that implements [`ErasedTask`] for any [`Task`] type.
124+ ///
125+ /// This allows storing heterogeneous tasks in a registry while preserving
126+ /// their ability to execute.
127+ pub struct TaskWrapper < T > ( pub Arc < T > ) ;
128+
129+ impl < T > TaskWrapper < T > {
130+ /// Create a new TaskWrapper from a task instance.
131+ pub fn new ( task : T ) -> Self {
132+ Self ( Arc :: new ( task) )
133+ }
134+ }
135+
122136#[ async_trait]
123- impl < T , State > ErasedTask < State > for PhantomData < T >
137+ impl < T , State > ErasedTask < State > for TaskWrapper < T >
124138where
125139 T : Task < State > ,
126140 State : Clone + Send + Sync + ' static ,
127141{
128142 fn name ( & self ) -> Cow < ' static , str > {
129- T :: name ( )
143+ self . 0 . name ( )
130144 }
131145
132146 fn validate_params ( & self , params : JsonValue ) -> Result < ( ) , TaskError > {
@@ -142,11 +156,11 @@ where
142156 state : State ,
143157 ) -> Result < JsonValue , TaskError > {
144158 let typed_params: T :: Params = serde_json:: from_value ( params) ?;
145- let result = T :: run ( typed_params, ctx, state) . await ?;
159+ let result = self . 0 . run ( typed_params, ctx, state) . await ?;
146160 Ok ( serde_json:: to_value ( & result) ?)
147161 }
148162}
149163
150164/// Type alias for the task registry
151165pub type TaskRegistry < State > =
152- std:: collections:: HashMap < Cow < ' static , str > , & ' static dyn ErasedTask < State > > ;
166+ std:: collections:: HashMap < Cow < ' static , str > , Arc < dyn ErasedTask < State > > > ;
0 commit comments