# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Compiler Pass Ordering Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import CompilerOptAction, CompilerOptObservation


class CompilerOptEnv(
    EnvClient[CompilerOptAction, CompilerOptObservation, State]
):
    """
    Client for the Compiler Pass Ordering Environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each client instance has its own dedicated environment session.

    Example (sync):
        >>> with CompilerOptEnv(base_url="http://localhost:8000").sync() as env:
        ...     obs = env.reset()
        ...     result = env.step(CompilerOptAction(pass_id=13, task_id=1))
        ...     print(result.observation.improvement_pct)

    Example (async):
        >>> async with CompilerOptEnv(base_url="http://localhost:8000") as env:
        ...     obs = await env.reset()
        ...     result = await env.step(CompilerOptAction(pass_id=13, task_id=1))
    """

    def _step_payload(self, action: CompilerOptAction) -> Dict:
        """Convert an action into the payload dict sent to the server.

        Args:
            action: The action to serialize; only its ``pass_id`` and
                ``task_id`` are forwarded.

        Returns:
            ``{"pass_id": ..., "task_id": ...}`` taken from ``action``.
        """
        return {
            "pass_id": action.pass_id,
            "task_id": action.task_id,
        }

    def _parse_result(self, payload: Dict) -> StepResult[CompilerOptObservation]:
        """Build a ``StepResult`` from a server response payload.

        Every observation field falls back to a default when the server
        omits it, so a partial response still yields an observation.

        Args:
            payload: Server response. Read keys: ``"observation"`` (a dict of
                observation fields), ``"reward"`` and ``"done"``.

        Returns:
            A ``StepResult`` wrapping the parsed ``CompilerOptObservation``.
        """
        # ``or {}`` also covers an explicit ``"observation": null``; with
        # ``payload.get("observation", {})`` that case yields ``None`` and the
        # ``obs_data.get`` calls below would raise AttributeError.
        obs_data = payload.get("observation") or {}
        # Read once so the observation and the StepResult always agree.
        done = payload.get("done", False)

        observation = CompilerOptObservation(
            estimated_cost=obs_data.get("estimated_cost", 0.0),
            baseline_cost=obs_data.get("baseline_cost", 0.0),
            num_instructions=obs_data.get("num_instructions", 0),
            num_loops=obs_data.get("num_loops", 0),
            num_branches=obs_data.get("num_branches", 0),
            num_functions=obs_data.get("num_functions", 0),
            loop_depth=obs_data.get("loop_depth", 0),
            program_type=obs_data.get("program_type", ""),
            passes_applied=obs_data.get("passes_applied", []),
            passes_available=obs_data.get("passes_available", []),
            step_count=obs_data.get("step_count", 0),
            # NOTE(review): 10 presumably mirrors the server's episode
            # length — confirm it stays in sync with the server config.
            max_steps=obs_data.get("max_steps", 10),
            # NOTE(review): 15 is presumably the number of available passes
            # (one neutral 1.0 multiplier each) — TODO confirm against models.
            synergy_state=obs_data.get("synergy_state", [1.0] * 15),
            # NOTE(review): default task 3 is not explained here — confirm.
            task_id=obs_data.get("task_id", 3),
            task_description=obs_data.get("task_description", ""),
            done=done,
            # NOTE(review): a missing reward becomes 0.0 here, but None on
            # StepResult.reward below. Kept as-is because callers may rely on
            # either value; confirm which default is intended.
            reward=payload.get("reward", 0.0),
            improvement_pct=obs_data.get("improvement_pct", 0.0),
            last_pass_name=obs_data.get("last_pass_name"),
            grader_score=obs_data.get("grader_score"),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from a server state payload.

        Args:
            payload: Dict read for ``"episode_id"`` (``None`` when absent) and
                ``"step_count"`` (``0`` when absent).

        Returns:
            A ``State`` carrying those two values.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )