-
Notifications
You must be signed in to change notification settings - Fork 52
fix(configurator): make env_params first-class to fix the trajectory cache key #901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
95c192b
0dabd4e
d59572e
ed42f9d
9412c7b
af3b6a8
dead658
69dd7d4
7fec7f5
8f91664
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -97,6 +97,7 @@ class TestRun: | |||||
| reports: Set[Type[ReportGenerationStrategy]] = field(default_factory=set) | ||||||
| extra_srun_args: str | None = None | ||||||
| num_nodes_explicit: bool = False | ||||||
| current_env_params: dict[str, Any] = field(default_factory=dict) | ||||||
|
|
||||||
| def __hash__(self) -> int: | ||||||
| return hash(self.name + self.test.name + str(self.iterations) + str(self.current_iteration)) | ||||||
|
|
@@ -156,7 +157,9 @@ def param_space(self) -> dict[str, Any]: | |||||
| **{ | ||||||
| key: value | ||||||
| for key, value in cmd_args_dict.items() | ||||||
| if isinstance(value, list) and not self.test.is_dse_excluded_arg(key) | ||||||
| if isinstance(value, list) | ||||||
| and not self.test.is_dse_excluded_arg(key) | ||||||
| and not self.test.is_env_sampled(key) | ||||||
| }, | ||||||
| **{f"extra_env_vars.{key}": value for key, value in extra_env_vars_dict.items() if isinstance(value, list)}, | ||||||
| } | ||||||
|
|
@@ -184,27 +187,40 @@ def all_combinations(self) -> list[dict[str, Any]]: | |||||
|
|
||||||
| return all_combinations | ||||||
|
|
||||||
| def apply_params_set(self, action: dict[str, Any]) -> "TestRun": | ||||||
| def apply_params_set(self, action: dict[str, Any], env_params: dict[str, Any] | None = None) -> "TestRun": | ||||||
| tdef = self.test.model_copy(deep=True) | ||||||
| for key, value in action.items(): | ||||||
|
|
||||||
| def _apply(key: str, value: Any) -> None: | ||||||
| if key.startswith("extra_env_vars."): | ||||||
| tdef.extra_env_vars[key[len("extra_env_vars.") :]] = value | ||||||
| return | ||||||
| attrs = key.split(".") | ||||||
| obj = tdef.cmd_args | ||||||
| for attr in attrs[:-1]: | ||||||
| obj = obj[attr] if isinstance(obj, dict) else getattr(obj, attr) | ||||||
| if isinstance(obj, dict): | ||||||
| obj[attrs[-1]] = value | ||||||
| else: | ||||||
| attrs = key.split(".") | ||||||
| obj = tdef.cmd_args | ||||||
| for attr in attrs[:-1]: | ||||||
| obj = obj[attr] if isinstance(obj, dict) else getattr(obj, attr) | ||||||
| if isinstance(obj, dict): | ||||||
| obj[attrs[-1]] = value | ||||||
| else: | ||||||
| setattr(obj, attrs[-1], value) | ||||||
| setattr(obj, attrs[-1], value) | ||||||
|
|
||||||
| # RNG runs in the env before this call; applying only concrete values keeps this deterministic. | ||||||
| for key, value in action.items(): | ||||||
| _apply(key, value) | ||||||
| for key, value in (env_params or {}).items(): | ||||||
| _apply(key, value) | ||||||
|
|
||||||
| type(tdef)(**tdef.model_dump()) # trigger validation | ||||||
| # env_params is validated at parse time; after the overlay its target cmd_args fields hold | ||||||
| # concrete scalar draws, so re-validating it here would reject weighted specs. Drop it for | ||||||
| # this validation-only pass, which exists to validate the applied action values. | ||||||
| validation_args = tdef.model_dump() | ||||||
| validation_args.pop("env_params", None) | ||||||
| type(tdef)(**validation_args) # trigger validation | ||||||
|
|
||||||
| new_tr = copy.deepcopy(self) | ||||||
| new_tr.test = tdef | ||||||
| if "NUM_NODES" in action: | ||||||
| new_tr.num_nodes = action["NUM_NODES"] | ||||||
|
rutayan-nv marked this conversation as resolved.
|
||||||
| new_tr.current_env_params = dict(env_params or {}) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit
Suggested change
|
||||||
| return new_tr | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| from .base_agent import RewardOverrides | ||||||||||||||||||
| from .base_gym import BaseGym | ||||||||||||||||||
| from .env_params import EnvParams, EnvParamsSink | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @dataclasses.dataclass(frozen=True) | ||||||||||||||||||
|
|
@@ -36,6 +37,7 @@ class TrajectoryEntry: | |||||||||||||||||
| action: dict[str, Any] | ||||||||||||||||||
| reward: float | ||||||||||||||||||
| observation: list | ||||||||||||||||||
| env_params: dict[str, Any] = dataclasses.field(default_factory=dict) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class CloudAIGymEnv(BaseGym): | ||||||||||||||||||
|
|
@@ -61,8 +63,15 @@ def __init__(self, test_run: TestRun, runner: BaseRunner, rewards: RewardOverrid | |||||||||||||||||
| self.max_steps = test_run.test.agent_steps | ||||||||||||||||||
| self.reward_function = Registry().get_reward_function(test_run.test.agent_reward_function) | ||||||||||||||||||
| self.trajectory: dict[int, list[TrajectoryEntry]] = {} | ||||||||||||||||||
| self.params: EnvParams | None = EnvParams.from_test(test_run.test) | ||||||||||||||||||
| self.env_params_sink = EnvParamsSink() | ||||||||||||||||||
| super().__init__() | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def env_params_record_path(self) -> Path: | ||||||||||||||||||
| """``env.csv`` lives alongside ``trajectory.csv`` so a plain ``merge`` joins them.""" | ||||||||||||||||||
| return self.iteration_dir / "env.csv" | ||||||||||||||||||
|
|
||||||||||||||||||
| def define_action_space(self) -> Dict[str, list[Any]]: | ||||||||||||||||||
| return self.test_run.param_space | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -119,7 +128,9 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]: | |||||||||||||||||
| - info (dict): Additional info for debugging. | ||||||||||||||||||
| """ | ||||||||||||||||||
| self.test_run.increment_step() | ||||||||||||||||||
| self.test_run = self.test_run.apply_params_set(action) | ||||||||||||||||||
| # RNG lives in the env: sample here, then apply action + sample so the run and cache key see them. | ||||||||||||||||||
| sampled_env_params = self.params.sample(self.test_run.step) if self.params else {} | ||||||||||||||||||
| self.test_run = self.test_run.apply_params_set(action, env_params=sampled_env_params) | ||||||||||||||||||
|
|
||||||||||||||||||
| cached_result = self.get_cached_trajectory_result(action) | ||||||||||||||||||
| if cached_result is not None: | ||||||||||||||||||
|
|
@@ -134,6 +145,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]: | |||||||||||||||||
| action=action, | ||||||||||||||||||
| reward=cached_result.reward, | ||||||||||||||||||
| observation=cached_result.observation, | ||||||||||||||||||
| env_params=dict(self.test_run.current_env_params), | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit and I see this (to my understanding) redundant
Suggested change
|
||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
| return cached_result.observation, cached_result.reward, False, {} | ||||||||||||||||||
|
|
@@ -162,6 +174,9 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]: | |||||||||||||||||
| self.test_run.step = new_tr.step | ||||||||||||||||||
| self.test_run.output_path = new_tr.output_path | ||||||||||||||||||
|
|
||||||||||||||||||
| # The test_run rebuild above drops the sample; restore it so the entry, cache key, and env.csv match. | ||||||||||||||||||
| self.test_run.current_env_params = new_tr.current_env_params | ||||||||||||||||||
|
Comment on lines
+177
to
+178
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see the same line in |
||||||||||||||||||
|
|
||||||||||||||||||
| observation = self.get_observation(action) | ||||||||||||||||||
| reward = self.compute_reward(observation) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -171,6 +186,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]: | |||||||||||||||||
| action=action, | ||||||||||||||||||
| reward=reward, | ||||||||||||||||||
| observation=observation, | ||||||||||||||||||
| env_params=dict(self.test_run.current_env_params), | ||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -230,7 +246,14 @@ def get_observation(self, action: Any) -> list: | |||||||||||||||||
| return observation | ||||||||||||||||||
|
|
||||||||||||||||||
| def write_trajectory(self, entry: TrajectoryEntry): | ||||||||||||||||||
| """Append the trajectory to the CSV file and to the local attribute.""" | ||||||||||||||||||
| """ | ||||||||||||||||||
| Append the entry to the in-memory cache and trajectory.csv (plus env.csv when declared). | ||||||||||||||||||
|
|
||||||||||||||||||
| ``trajectory.csv`` and the ``env.csv`` projection are sunk from the same | ||||||||||||||||||
| ``TrajectoryEntry`` here, so a trial that never produces an entry (e.g. a | ||||||||||||||||||
| constraint failure returns before this call) lands in neither file and the | ||||||||||||||||||
| two stay 1:1 step-aligned. | ||||||||||||||||||
| """ | ||||||||||||||||||
| self.current_trajectory.append(entry) | ||||||||||||||||||
|
|
||||||||||||||||||
| file_exists = self.trajectory_file_path.exists() | ||||||||||||||||||
|
|
@@ -243,17 +266,36 @@ def write_trajectory(self, entry: TrajectoryEntry): | |||||||||||||||||
| writer.writerow(["step", "action", "reward", "observation"]) | ||||||||||||||||||
| writer.writerow([entry.step, entry.action, entry.reward, entry.observation]) | ||||||||||||||||||
|
|
||||||||||||||||||
| self.env_params_sink.write(self.env_params_record_path, entry.step, entry.env_params) | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def iteration_dir(self) -> Path: | ||||||||||||||||||
| """Per-iteration output dir; trajectory.csv and env.csv both live here, step-aligned.""" | ||||||||||||||||||
| return self.runner.scenario_root / self.test_run.name / f"{self.test_run.current_iteration}" | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def trajectory_file_path(self) -> Path: | ||||||||||||||||||
| return self.runner.scenario_root / self.test_run.name / f"{self.test_run.current_iteration}" / "trajectory.csv" | ||||||||||||||||||
| return self.iteration_dir / "trajectory.csv" | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def current_trajectory(self) -> list[TrajectoryEntry]: | ||||||||||||||||||
| return self.trajectory.setdefault(self.test_run.current_iteration, []) | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_cached_trajectory_result(self, action: Any) -> TrajectoryEntry | None: | ||||||||||||||||||
| """ | ||||||||||||||||||
| Return a cached entry only when the full trial identity matches. | ||||||||||||||||||
|
|
||||||||||||||||||
| Trial identity is ``(action, env_params)``: env-randomized parameters | ||||||||||||||||||
| change the workload's behaviour, so a trial repeating the same action | ||||||||||||||||||
| under a different ``env_params`` sample must miss and re-run. Empty | ||||||||||||||||||
| env_params on both sides is the back-compat path for workloads that | ||||||||||||||||||
| do not declare any ``[env_params.*]`` block. | ||||||||||||||||||
| """ | ||||||||||||||||||
| current_env_params = self.test_run.current_env_params | ||||||||||||||||||
| for entry in self.current_trajectory: | ||||||||||||||||||
| if self._values_match_exact(entry.action, action): | ||||||||||||||||||
| if not self._values_match_exact(entry.action, action): | ||||||||||||||||||
| continue | ||||||||||||||||||
| if self._values_match_exact(entry.env_params, current_env_params): | ||||||||||||||||||
| return entry | ||||||||||||||||||
|
Comment on lines
+296
to
299
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit again, but I'd like to push a bit for readibility. we sacrifice a few ms for validation but I believe it's just mush easier to read
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'd suggest reverting the changes with the
_applyfunction (local functions are a troublesome thing during debugging) and rather replace the previous rootwith
it makes the delta of the PR quite shorter + no inner function