# Source code for nrel.hive.initialization.sample_requests

import random
from typing import Tuple, List

from nrel.hive.model.request import Request
from nrel.hive.model.roadnetwork.osm.osm_roadnetwork import OSMRoadNetwork
from nrel.hive.model.sim_time import SimTime
from nrel.hive.runner import Environment
from nrel.hive.state.simulation_state.simulation_state import SimulationState


def default_request_sampler(
    count: int,
    simulation_state: SimulationState,
    environment: Environment,
    allow_pooling: bool = False,
    random_seed: int = 0,
) -> Tuple[Request, ...]:
    """
    samples `count` requests uniformly across time and space

    Each request's origin is the start of one randomly chosen link and its
    destination is the end of another randomly chosen link; draws whose origin
    equals the destination are discarded and redrawn. Departure times are drawn
    from the simulation timesteps in [start_time, end_time) and the passenger
    count is drawn from 1 to 4.

    NOTE: this calls `random.seed(random_seed)`, i.e. it reseeds Python's
    module-level random generator, which affects any later draws made through
    the `random` module.

    :param count: the number of requests to sample
    :param simulation_state: the simulation state
    :param environment: the environment
    :param allow_pooling: value passed to every sampled request as `allows_pooling`
    :param random_seed: the random seed used for the random selections
    :return: a tuple of the sampled requests, sorted by departure time and then by id
    :raises NotImplementedError: if the road network is not an OSMRoadNetwork
    :raises Exception: if the OSMRoadNetwork has no link helper
    :raises ValueError: if `count` is positive but there are no links, no
        timesteps, or no origin/destination pair at distinct locations
    """
    road_network = simulation_state.road_network
    if not isinstance(road_network, OSMRoadNetwork):
        raise NotImplementedError("request sampling is only implemented for the OSMRoadNetwork")
    if road_network.link_helper is None:
        raise Exception("Expected link helper on OSMRoadNetwork but found None")

    random.seed(random_seed)

    sim_config = environment.config.sim
    possible_timesteps = list(
        range(
            int(sim_config.start_time),
            int(sim_config.end_time),
            sim_config.timestep_duration_seconds,
        )
    )
    possible_links = list(road_network.link_helper.links.values())

    if count > 0:
        # Fail early with a clear message instead of an IndexError raised from
        # inside random.choice on an empty sequence.
        if not possible_links:
            raise ValueError("cannot sample requests: the road network has no links")
        if not possible_timesteps:
            raise ValueError(
                "cannot sample requests: no timesteps between the simulation start and end time"
            )
        # If every link starts and ends at one single location, every draw below
        # would be rejected and the sampling loop would never terminate.
        anchor = possible_links[0].start
        if all(link.start == anchor and link.end == anchor for link in possible_links):
            raise ValueError(
                "cannot sample requests: all links start and end at the same location"
            )

    requests: List[Request] = []
    id_counter = 0
    while len(requests) < count:
        # The order of the random draws is significant: it determines which
        # requests a given seed produces.
        random_source_link = random.choice(possible_links)
        random_destination_link = random.choice(possible_links)

        if random_source_link.start == random_destination_link.end:
            # skip if the request starts and ends at the same location
            continue

        request = Request.build(
            request_id=f"r{id_counter}",
            origin=random_source_link.start,
            destination=random_destination_link.end,
            road_network=road_network,
            departure_time=SimTime(random.choice(possible_timesteps)),
            passengers=random.choice([1, 2, 3, 4]),
            allows_pooling=allow_pooling,
        )
        requests.append(request)
        id_counter += 1

    sorted_requests = sorted(requests, key=lambda r: (r.departure_time, r.id))
    return tuple(sorted_requests)