1+ from typing import Any , Dict , Callable , TypeVar , Generic , Optional , Awaitable
2+ from firebase_functions .firestore_fn import Event , Change , DocumentSnapshot
3+ from .common import ChangeType , State , now , get_change_type , FirestoreField , safe_get
4+
5+ T = TypeVar ('T' )
6+ TOutput = TypeVar ('TOutput' , bound = Dict [str , Any ])
7+
8+ class ProcessConfig :
9+ def __init__ (
10+ self ,
11+ input_field : str ,
12+ process_fn : Callable [[Any , Event ], Awaitable [Dict [str , FirestoreField ]]],
13+ error_fn : Callable [[Any ], str ],
14+ status_field : Optional [str ] = None ,
15+ order_field : Optional [str ] = None
16+ ):
17+ self .input_field = input_field
18+ self .process_fn = process_fn
19+ self .error_fn = error_fn
20+ self .status_field = status_field
21+ self .order_field = order_field
22+
23+ class FirestoreOnWriteProcessor (Generic [T , TOutput ]):
24+ def __init__ (self , options : ProcessConfig ):
25+ self .input_field = options .input_field or 'prompt'
26+ self .order_field = options .order_field or 'createTime'
27+ self .status_field = options .status_field or 'status'
28+ self .process_fn = options .process_fn
29+ self .process_updates = True
30+ self .error_fn = options .error_fn
31+
32+ def should_process (self , change : Change ) -> bool :
33+ change_type = get_change_type (change )
34+ if change_type == ChangeType .DELETE :
35+ return False
36+
37+ # Extract status if it exists
38+ status = safe_get (change .after , self .status_field )
39+ state : State = status .get ("state" ) if isinstance (status , dict ) else None
40+ new_value = safe_get (change .after , self .input_field )
41+ old_value = safe_get (change .before , self .input_field )
42+
43+ has_changed = (
44+ change_type == ChangeType .CREATE or
45+ (self .process_updates and
46+ change_type == ChangeType .UPDATE and
47+ old_value != new_value )
48+ )
49+
50+ if (
51+ not new_value or
52+ state in [State .PROCESSING .value , State .COMPLETED .value , State .ERROR .value ] or
53+ not has_changed or
54+ not isinstance (new_value , str )
55+ ):
56+ return False
57+
58+ return True
59+
60+ async def write_start_event (self , event : Event [Change [DocumentSnapshot ]]) -> None :
61+ create_time = event .data .after .create_time
62+ update_time = now ()
63+
64+ status = {
65+ "state" : State .PROCESSING .value ,
66+ "startTime" : update_time ,
67+ "updateTime" : update_time
68+ }
69+
70+ start_data = safe_get (event .data .after , self .order_field )
71+ # Prepare update data
72+ if start_data :
73+ update_data = {self .status_field : status }
74+ else :
75+ update_data = {
76+ self .order_field : create_time ,
77+ self .status_field : status
78+ }
79+
80+ event .data .after .reference .update (update_data )
81+
82+ async def write_completion_event (self , change : Change , output : Dict [str , Any ]) -> None :
83+ update_time = now ()
84+
85+ # In Firebase Python, we need to use dot notation as strings
86+ update_data = dict (output ) # Create a copy to avoid modifying the original
87+ update_data [f"{ self .status_field } .state" ] = State .COMPLETED .value
88+ update_data [f"{ self .status_field } .updateTime" ] = update_time
89+ update_data [f"{ self .status_field } .completeTime" ] = update_time
90+
91+ change .after .reference .update (update_data )
92+
93+ async def write_error_event (self , change : Change , e : Any ) -> None :
94+ event_timestamp = now ()
95+ error_message = self .error_fn (e )
96+
97+ change .after .reference .update ({
98+ self .status_field : {
99+ "state" : State .ERROR .value ,
100+ "updateTime" : event_timestamp ,
101+ "error" : error_message
102+ }
103+ })
104+
105+ async def run (self , event : Event ) -> None :
106+ if not event :
107+ print ("No event data" )
108+ return
109+
110+ if not event .data :
111+ print ("No document event.data" )
112+ return
113+
114+ if not self .should_process (event .data ):
115+ return
116+
117+ try :
118+ await self .write_start_event (event )
119+
120+ input_data = safe_get (event .data .after , self .input_field )
121+ output = await self .process_fn (input_data , event )
122+
123+ await self .write_completion_event (event .data , output )
124+ except Exception as e :
125+ print (f"Message processing error: { e } " )
126+ await self .write_error_event (event .data , e )
0 commit comments