44
55import asyncio
66from datetime import datetime
7+ import os
78from os .path import isfile
9+ import tempfile
810from typing import Any , ForwardRef , TypedDict , TypeVar , get_type_hints
911
1012import aiofiles
@@ -77,7 +79,7 @@ class StateDict(TypedDict, total=False):
7779T = TypeVar ("T" , bound = "StateDict" )
7880
7981
80- def _get_defaults (cls : type [T ]) -> dict [ str , Any ] :
82+ def _get_defaults (cls : type [T ]) -> T :
8183 """Generate a default dict based on typed dict
8284
8385 This function recursively creates a nested dictionary structure that mirrors
@@ -117,22 +119,71 @@ def _get_defaults(cls: type[T]) -> dict[str, Any]:
117119 return new_dict # type: ignore[return-value]
118120
119121
120- async def get_state (state_yaml : str ) -> StateDict :
def _write_state(state_yaml: str, state: dict[str, Any] | StateDict) -> None:
    """Write state to ``state_yaml`` atomically.

    The YAML is first written to a temporary file in the same directory as
    the target and then moved over the target with ``os.replace``, so a
    reader never observes a partially written state file.

    :param state_yaml: filename where to write the state
    :type state_yaml: ``str``
    :param state: state mapping to serialize with ``yaml.dump``
    :type state: ``dict[str, Any] | StateDict``
    :raises OSError: if the temporary file cannot be created, written,
        chmod-ed or moved into place. The temporary file is removed before
        the error propagates.
    :rtype: ``None``
    """
    dir_name = os.path.dirname(os.path.abspath(state_yaml))
    # Serialize first: if the dump fails, no temporary file has been created.
    content = yaml.dump(state)

    tmp_path: str | None = None
    try:
        # The temp file lives next to the target so os.replace stays on one
        # filesystem; delete=False because the file must outlive the context.
        with tempfile.NamedTemporaryFile(
            mode="w", dir=dir_name, delete=False, suffix=".tmp"
        ) as tmp:
            tmp_path = tmp.name
            tmp.write(content)
        # NamedTemporaryFile creates the file with restrictive permissions;
        # set it to 0o644 before it replaces the state file.
        os.chmod(tmp_path, 0o644)
        os.replace(tmp_path, state_yaml)
    except BaseException:
        # Do not leave stray "*.tmp" files behind when any step fails.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; the original error is the one to report.
                pass
        raise
134+
135+
async def _reset_state_to_defaults(state_yaml: str, already_locked: bool) -> StateDict:
    """Overwrite ``state_yaml`` with the default state and return it.

    :param state_yaml: filename where to write the default state
    :type state_yaml: ``str``
    :param already_locked: Whether the caller already holds ``lock``. When
        ``False`` the lock is acquired around the write.
    :type already_locked: ``bool``
    :rtype: ``StateDict``
    """
    defaults = _get_defaults(StateDict)
    if already_locked:
        _write_state(state_yaml, defaults)
    else:
        async with lock:
            _write_state(state_yaml, defaults)
    return defaults


async def get_state(state_yaml: str, _already_locked: bool = False) -> StateDict:
    """Read in state yaml.

    If the file does not exist, the defaults are returned without writing
    anything. If the file cannot be parsed as YAML, or its top-level value is
    not a mapping (this includes an empty file), the file is rewritten with
    the defaults and the defaults are returned.

    :param state_yaml: filename where to read the state
    :type state_yaml: ``str``
    :param _already_locked: Whether the lock is already held by the caller (e.g. set_state).
        Prevents deadlock when corruption recovery needs to write defaults.
    :type _already_locked: ``bool``
    :rtype: ``StateDict``
    """
    if not isfile(
        state_yaml
    ):  # noqa: PTH113 - isfile is fine and simpler in this case.
        return _get_defaults(StateDict)

    async with aiofiles.open(state_yaml, mode="r") as yaml_file:
        LOG.debug("Loading state from yaml")
        content = await yaml_file.read()

    # Only the parse is guarded: an error raised while writing the defaults
    # during recovery must propagate instead of starting a second recovery.
    try:
        state_yaml_payload: StateDict | None = yaml.safe_load(content)
    except yaml.YAMLError as e:
        LOG.error(
            "Failed to parse state file %s: %s. Reinitializing with defaults.",
            state_yaml,
            e,
        )
        return await _reset_state_to_defaults(state_yaml, _already_locked)

    # Handle corrupted/empty YAML files: safe_load returns None for an empty
    # document, and a scalar or list is not a usable state mapping either.
    if not isinstance(state_yaml_payload, dict):
        LOG.warning(
            "State file %s is corrupted or empty, reinitializing with defaults",
            state_yaml,
        )
        return await _reset_state_to_defaults(state_yaml, _already_locked)

    return state_yaml_payload
136187
137188
138189async def set_state (
@@ -143,6 +194,7 @@ async def set_state(
143194 ),
144195) -> None :
145196 """Save state yaml.
197+
146198 :param state_yaml: filename where to read the state
147199 :type state_yaml: ``str``
148200 :param key: Key name
@@ -152,14 +204,11 @@ async def set_state(
152204 :rtype: ``StateDict``
153205 """
154206 async with lock : # note ic-dev21: on lock le fichier pour être sûr de finir la job
155- current_state = await get_state (state_yaml ) or {}
207+ current_state = await get_state (state_yaml , _already_locked = True ) or {}
156208 merged_state : dict [str , Any ] = {key : {** current_state .get (key , {}), ** state }} # type: ignore[dict-item]
157209 new_state : dict [str , Any ] = {** current_state , ** merged_state }
158- async with aiofiles .open (state_yaml , mode = "w" ) as yaml_file :
159- LOG .debug ("Saving state to yaml file" )
160- # TODO: Use asyncio.get_running_loop() and run_in_executor to write
161- # to the file in a non blocking manner. Currently, the file writes
162- # are properly async but the yaml dump is done synchronously on the
163- # main event loop.
164- content = yaml .dump (new_state )
165- await yaml_file .write (content )
210+ LOG .debug ("Saving state to yaml file" )
211+ # TODO: Use asyncio.get_running_loop() and run_in_executor to write
212+ # to the file in a non blocking manner. Currently, yaml.dump is
213+ # synchronous on the main event loop.
214+ _write_state (state_yaml , new_state )
0 commit comments