File size: 933 Bytes
872aa5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gym
import torch

from abc import ABC, abstractmethod
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from torch.utils.tensorboard.writer import SummaryWriter
from typing import List, Optional, TypeVar

from shared.callbacks.callback import Callback
from shared.policy.policy import Policy
from shared.stats import EpisodesStats

# Type variable bound to ``Algorithm``. ``Algorithm.learn`` annotates both
# ``self`` and its return value with it, so type checkers see ``learn`` on a
# subclass instance as returning that same subclass type.
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")

class Algorithm(ABC):
    """Abstract base class for training algorithms.

    Stores the components every concrete algorithm shares: the policy, the
    vectorized environment, the torch device and the TensorBoard writer.
    Both ``__init__`` and ``learn`` are abstract, so ``Algorithm`` itself
    cannot be instantiated; subclasses must override both (a subclass
    ``__init__`` can delegate here via ``super().__init__(...)``).
    """

    # Attributes populated by ``__init__``.
    policy: Policy
    env: VecEnv
    device: torch.device
    tb_writer: SummaryWriter

    @abstractmethod
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Attach the shared training components to the instance.

        Args:
            policy: Policy object the algorithm operates on.
            env: Vectorized environment (stable-baselines3 ``VecEnv``).
            device: Torch device; stored only, nothing is moved here.
            tb_writer: TensorBoard ``SummaryWriter`` stored for later use.
            **kwargs: Extra keyword arguments; accepted so subclasses can
                forward their full kwargs, but ignored by this base class.
        """
        super().__init__()
        # Attributes are assigned left to right: policy, env, device, tb_writer.
        self.policy, self.env, self.device, self.tb_writer = (
            policy,
            env,
            device,
            tb_writer,
        )

    @abstractmethod
    def learn(
        self: AlgorithmSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> AlgorithmSelf:
        """Train the algorithm; must be implemented by subclasses.

        Args:
            total_timesteps: Training budget. Presumably this counts
                environment steps; confirm against the implementations.
            callback: Optional ``Callback``. How and when it is invoked is
                defined by each subclass.

        Returns:
            An object of the same type as ``self``. Implementations
            presumably return ``self`` for chaining; verify per subclass.
        """