Unverified Commit 7fb7e177 authored by Burak's avatar Burak Committed by GitHub
Browse files

Simplify Python Pokemon service (#1773)



* Simplify Python Pokemon service

* Use `threading.Lock` instead of `multiprocessing.Lock`

Co-authored-by: default avatarMatteo Bigoi <1781140+crisidev@users.noreply.github.com>

* Explain why we need to pass `force=True`

* Add comment about synchronization of `Context` class

Co-authored-by: default avatarMatteo Bigoi <1781140+crisidev@users.noreply.github.com>
parent 89a6e589
Loading
Loading
Loading
Loading
+19 −15
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@
import itertools
import logging
import random
import threading
from threading import Lock
from dataclasses import dataclass
from typing import List, Optional

@@ -30,22 +30,18 @@ from libpokemon_service_server_sdk.types import ByteStream
# fast logging handler, Tracingandler based on Rust tracing crate.
logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()])


# A slightly more atomic counter using a threading lock.
class FastWriteCounter:
class SafeCounter:
    def __init__(self):
        self._number_of_read = 0
        self._counter = itertools.count()
        self._read_lock = threading.Lock()
        self._val = 0
        self._lock = Lock()

    def increment(self):
        next(self._counter)
        with self._lock:
            self._val += 1

    def value(self):
        with self._read_lock:
            value = next(self._counter) - self._number_of_read
            self._number_of_read += 1
        return value
        with self._lock:
            return self._val


###########################################################
@@ -65,7 +61,12 @@ class FastWriteCounter:
#   * def operation(input: OperationInput, state: State) -> OperationOutput
#   * async def operation(input: OperationInput, state: State) -> OperationOutput
#
# NOTE: protection of the data inside the context class is up to the developer
# Synchronization:
#   Instance of `Context` class will be cloned for every worker and all state kept in `Context`
#   will be specific to that process. There is no protection provided by default, 
#   it is up to you to have synchronization between processes. 
#   If you really want to share state between different processes you need to use `multiprocessing` primitives: 
#   https://docs.python.org/3/library/multiprocessing.html#sharing-state-between-processes
@dataclass
class Context:
    # In our case it simulates an in-memory database containing the description of Pikachu in multiple
@@ -90,7 +91,7 @@ class Context:
            ),
        ]
    }
    _calls_count = FastWriteCounter()
    _calls_count = SafeCounter()
    _radio_database = [
        "https://ia800107.us.archive.org/33/items/299SoundEffectCollection/102%20Palette%20Town%20Theme.mp3",
        "https://ia600408.us.archive.org/29/items/PocketMonstersGreenBetaLavenderTownMusicwwwFlvtoCom/Pocket%20Monsters%20Green%20Beta-%20Lavender%20Town%20Music-%5Bwww_flvto_com%5D.mp3",
@@ -228,4 +229,7 @@ async def stream_pokemon_radio(_: StreamPokemonRadioInput, context: Context):
###########################################################
# Run the server.
###########################################################
def main():
    app.run(workers=1)

main()
+7 −1
Original line number Diff line number Diff line
@@ -426,7 +426,13 @@ event_loop.add_signal_handler(signal.SIGINT,
        // Forcing the multiprocessing start method to fork is a workaround for it.
        // https://github.com/pytest-dev/pytest-flask/issues/104#issuecomment-577908228
        #[cfg(target_os = "macos")]
        mp.call_method1("set_start_method", ("fork",))?;
        mp.call_method(
            "set_start_method",
            ("fork",),
            // We need to pass `force=True` to prevent `context has already been set` exception,
            // see https://github.com/pytorch/pytorch/issues/3492
            Some(vec![("force", true)].into_py_dict(py)),
        )?;

        let address = address.unwrap_or_else(|| String::from("127.0.0.1"));
        let port = port.unwrap_or(13734);