sraft

simple raft implementation
git clone https://git.parazyd.org/sraft
Log | Files | Refs | README

lib.rs (16370B)


      1 use std::{collections::HashMap, io, net::SocketAddr, time::Duration};
      2 
      3 use async_channel::{Receiver, Sender};
      4 use async_std::{
      5     io::{ReadExt, WriteExt},
      6     net::{TcpListener, TcpStream},
      7     stream::StreamExt,
      8     sync::Mutex,
      9     task,
     10 };
     11 use borsh::{BorshDeserialize, BorshSerialize};
     12 use futures::{select, FutureExt};
     13 use lazy_static::lazy_static;
     14 use log::{debug, error};
     15 use rand::Rng;
     16 
     17 mod method;
     18 use crate::method::{HeartbeatArgs, HeartbeatReply, RaftMethod, VoteArgs, VoteReply};
     19 
     20 #[derive(BorshSerialize, BorshDeserialize, Clone, Debug)]
     21 pub struct LogEntry {
     22     log_term: u64,
     23     log_index: u64,
     24     log_data: Vec<u8>,
     25 }
     26 
     27 pub struct LogStore(pub Vec<LogEntry>);
     28 
     29 impl LogStore {
     30     fn get_last_index(&self) -> u64 {
     31         let rlen = self.0.len();
     32         if rlen == 0 {
     33             return 0
     34         }
     35 
     36         self.0[rlen - 1].log_index
     37     }
     38 }
     39 
     40 lazy_static! {
     41     pub static ref LOG_STORE: Mutex<LogStore> = Mutex::new(LogStore(vec![]));
     42     // This is used for heartbeats
     43     pub static ref HEARTBEAT_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded();
     44     // This is used to let our node know when it has become a leader
     45     pub static ref TOLEADER_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded();
     46 
     47     pub static ref STATE: Mutex<State> = Mutex::new(State::new());
     48 }
     49 
     50 #[derive(Default)]
     51 pub struct State {
     52     pub current_term: u64,
     53     pub voted_for: u64,
     54     pub vote_count: u64,
     55 
     56     pub commit_index: u64,
     57     pub _last_applied: u64,
     58 
     59     pub next_index: Vec<u64>,
     60     pub match_index: Vec<u64>,
     61 }
     62 
     63 impl State {
     64     pub fn new() -> Self {
     65         Self {
     66             current_term: 0,
     67             voted_for: 0,
     68             vote_count: 0,
     69             commit_index: 0,
     70             _last_applied: 0,
     71             next_index: vec![],
     72             match_index: vec![],
     73         }
     74     }
     75 }
     76 
     77 pub enum Role {
     78     Follower,
     79     Candidate,
     80     Leader,
     81 }
     82 
     83 pub struct Raft {
     84     pub peers: HashMap<u64, SocketAddr>,
     85     node_id: u64,
     86     role: Role,
     87 }
     88 
     89 impl Raft {
     90     pub fn new(node_id: u64) -> Self {
     91         Self { peers: Default::default(), node_id, role: Role::Follower }
     92     }
     93 
     94     pub async fn start(&mut self) {
     95         debug!("Raft::start()");
     96         self.role = Role::Follower;
     97 
     98         let mut state = STATE.lock().await;
     99         state.current_term = 0;
    100         state.voted_for = 0;
    101         drop(state);
    102 
    103         let mut rng = rand::thread_rng();
    104 
    105         loop {
    106             let delay = Duration::from_millis(rng.gen_range(0..200) + 300);
    107 
    108             match self.role {
    109                 Role::Follower => {
    110                     select! {
    111                         _ = HEARTBEAT_CHAN.1.recv().fuse() => {
    112                             debug!("[FOLLOWER] Raft::start(): follower_{} got heartbeat", self.node_id);
    113                         }
    114                         _ = task::sleep(delay).fuse() => {
    115                             debug!("[FOLLOWER] Raft::start(): follower_{} timeout", self.node_id);
    116                             self.role = Role::Candidate;
    117                         }
    118                     }
    119                 }
    120 
    121                 Role::Candidate => {
    122                     debug!("[CANDIDATE] Raft::start(): peer_{} is now a candidate", self.node_id);
    123                     let mut state = STATE.lock().await;
    124                     state.current_term += 1;
    125                     state.voted_for = self.node_id;
    126                     state.vote_count = 1;
    127                     drop(state);
    128 
    129                     // TODO: In background
    130                     debug!("[CANDIDATE] Raft::start(): broadcasting request_vote");
    131                     self.broadcast_request_vote().await;
    132 
    133                     select! {
    134                         _ = task::sleep(delay).fuse() => {
    135                             debug!("[CANDIDATE] Raft::start(): Timeout as candidate, becoming a follower");
    136                             self.role = Role::Follower;
    137                         }
    138                         _ = TOLEADER_CHAN.1.recv().fuse() => {
    139                             debug!("[CANDIDATE] Raft::start(): We are now the leader");
    140                             self.role = Role::Leader;
    141 
    142                             let mut state = STATE.lock().await;
    143                             state.next_index = vec![1_u64; self.peers.len()];
    144                             state.match_index = vec![0_u64; self.peers.len()];
    145                             drop(state);
    146 
    147                             // TODO: In background
    148                             let t = task::spawn(async {
    149                                 let mut i = 0;
    150                                 loop {
    151                                     debug!("[CANDIDATE] Raft::start(): Appending data in bg loop");
    152                                     i += 1;
    153                                     let state = STATE.lock().await;
    154                                     let logentry = LogEntry {
    155                                         log_term: state.current_term,
    156                                         log_index: i,
    157                                         log_data: format!("user send: {}", i).as_bytes().to_vec(),
    158                                     };
    159                                     drop(state);
    160 
    161                                     debug!("[CANDIDATE] Raft::start(): Acquiring logstore lock in bg loop");
    162                                     let mut logstore = LOG_STORE.lock().await;
    163                                     logstore.0.push(logentry);
    164                                     drop(logstore);
    165                                     debug!("[CANDIDATE] Raft::start(): Dropped logstore lock in bg loop");
    166                                     task::sleep(Duration::from_secs(3)).await;
    167                                 }
    168                             });
    169                         }
    170                     }
    171                 }
    172 
    173                 Role::Leader => {
    174                     debug!("[LEADER] Raft::start(): Broadcasting heartbeat as leader");
    175                     self.broadcast_heartbeat().await;
    176                     task::sleep(Duration::from_millis(100)).await;
    177                 }
    178             }
    179         }
    180     }
    181 
    182     async fn broadcast_request_vote(&mut self) {
    183         debug!("Raft::broadcast_request_vote()");
    184         let state = STATE.lock().await;
    185         let args = VoteArgs { term: state.current_term, candidate_id: self.node_id };
    186         drop(state);
    187 
    188         // TODO: Do this concurrently
    189         for i in self.peers.clone() {
    190             debug!("Raft::broadcast_request_vote(): Sending req to peer {}", i.1);
    191             match self.send_request_vote(i.0, args.clone()).await {
    192                 Ok(v) => debug!("Raft::broadcast_request_vote(): Got reply: {:?}", v),
    193                 Err(e) => {
    194                     error!("Raft::broadcast_request_vote(): Failed vote to peer {}, ({})", i.1, e);
    195                     continue
    196                 }
    197             };
    198         }
    199     }
    200 
    201     async fn send_request_vote(
    202         &mut self,
    203         node_id: u64,
    204         args: VoteArgs,
    205     ) -> Result<VoteReply, io::Error> {
    206         debug!("Raft::send_request_vote()");
    207         let addr = self.peers[&node_id];
    208 
    209         let method = RaftMethod::Vote(args);
    210         let payload = method.try_to_vec().unwrap();
    211 
    212         debug!("Raft::send_request_vote(): Connecting to peer_{}", node_id);
    213         let mut stream = TcpStream::connect(addr).await?;
    214         debug!("Raft::send_request_vote(): Writing to stream");
    215         stream.write_all(&payload).await?;
    216         debug!("Raft::send_request_vote(): Wrote to stream");
    217 
    218         debug!("Raft::send_request_vote(): Reading from stream");
    219         let mut buf = vec![0_u8; 4096];
    220         stream.read(&mut buf).await?;
    221         debug!("Raft::send_request_vote(): Read from stream");
    222 
    223         let reply = try_from_slice_unchecked::<VoteReply>(&buf)?;
    224         let mut state = STATE.lock().await;
    225         if reply.term > state.current_term {
    226             debug!("Raft::send_request_vote(): reply.term > state.current_term");
    227             state.current_term = reply.term;
    228             state.voted_for = 0;
    229             drop(state);
    230             self.role = Role::Follower;
    231             return Ok(reply)
    232         }
    233         drop(state);
    234 
    235         if reply.vote_granted {
    236             debug!("Raft::send_request_vote(): reply.vote_granted == true");
    237             let mut state = STATE.lock().await;
    238             state.vote_count += 1;
    239             drop(state);
    240         }
    241 
    242         let state = STATE.lock().await;
    243         if state.vote_count >= (self.peers.len() / 2 + 1).try_into().unwrap() {
    244             debug!("Raft::send_request_vote(): Elected for leader");
    245             TOLEADER_CHAN.0.send(true).await.unwrap();
    246         }
    247         drop(state);
    248 
    249         Ok(reply)
    250     }
    251 
    252     async fn broadcast_heartbeat(&mut self) {
    253         debug!("[LEADER] Raft::broadcast_heartbeat()");
    254 
    255         for i in self.peers.clone() {
    256             let state = STATE.lock().await;
    257             let mut args = HeartbeatArgs {
    258                 term: state.current_term,
    259                 leader_id: self.node_id,
    260                 prev_log_index: 0,
    261                 prev_log_term: 0,
    262                 entries: vec![],
    263                 leader_commit: state.commit_index,
    264             };
    265 
    266             let prev_log_index = state.next_index[i.0 as usize] - 1;
    267             drop(state);
    268 
    269             debug!("[LEADER] Raft::broadcast_heartbeat(): Acquiring lock on LOG_STORE");
    270             let logstore = LOG_STORE.lock().await;
    271             if logstore.get_last_index() > prev_log_index {
    272                 args.prev_log_index = prev_log_index;
    273                 args.prev_log_term = logstore.0[prev_log_index as usize].log_term;
    274                 args.entries = logstore.0[prev_log_index as usize..].to_vec();
    275                 drop(logstore);
    276                 debug!("[LEADER] Raft::broadcast_heartbeat(): Dropped lock on LOG_STORE");
    277                 debug!("[LEADER] Raft::broadcast_heartbeat(): Send entries: {:?}", args.entries);
    278             }
    279 
    280             // TODO: Run in background
    281             match self.send_heartbeat(i.0, args).await {
    282                 Ok(v) => debug!("[LEADER] Raft::broadcast_heartbeat(): Got reply: {:?}", v),
    283                 Err(e) => {
    284                     error!(
    285                         "[LEADER] Raft::broadcast_heartbeat(): Failed heartbeat to peer_{} ({})",
    286                         i.0, e
    287                     );
    288                     continue
    289                 }
    290             };
    291         }
    292     }
    293 
    294     async fn send_heartbeat(
    295         &mut self,
    296         node_id: u64,
    297         args: HeartbeatArgs,
    298     ) -> Result<HeartbeatReply, io::Error> {
    299         debug!("Raft::send_heartbeat({}, {:?}", node_id, args);
    300         let addr = self.peers[&node_id];
    301 
    302         let method = RaftMethod::Heartbeat(args);
    303         let payload = method.try_to_vec()?;
    304 
    305         debug!("Raft::send_heartbeat(): Connecting to peer_{}", node_id);
    306         let mut stream = TcpStream::connect(addr).await?;
    307         debug!("Raft::send_heartbeat(): Writing to stream");
    308         stream.write_all(&payload).await?;
    309         debug!("Raft::send_heartbeat(): Wrote to stream");
    310 
    311         debug!("Raft::send_heartbeat(): Reading from stream");
    312         let mut buf = vec![0_u8; 4096];
    313         stream.read(&mut buf).await?;
    314         debug!("Raft::send_heartbeat(): Read from stream");
    315 
    316         let reply = try_from_slice_unchecked::<HeartbeatReply>(&buf)?;
    317 
    318         let mut state = STATE.lock().await;
    319         if reply.success {
    320             debug!("Raft::send_heartbeat(): Got success reply");
    321             if reply.next_index > 0 {
    322                 state.next_index[node_id as usize] = reply.next_index;
    323                 state.match_index[node_id as usize] = reply.next_index - 1;
    324             }
    325         } else if reply.term > state.current_term {
    326             debug!("Raft::send_heartbeat(): reply.term > state.current_term");
    327             state.current_term = reply.term;
    328             state.voted_for = 0;
    329             self.role = Role::Follower;
    330         }
    331         drop(state);
    332 
    333         Ok(reply)
    334     }
    335 }
    336 
    337 pub struct RaftRpc(pub SocketAddr);
    338 
    339 impl RaftRpc {
    340     pub async fn start(&self) {
    341         debug!("RaftRpc::start()");
    342 
    343         debug!("RaftRpc::start(): Binding to {}", self.0);
    344         let listener = TcpListener::bind(self.0).await.unwrap();
    345         let mut incoming = listener.incoming();
    346 
    347         while let Some(stream) = incoming.next().await {
    348             debug!("RaftRpc::start(): Got RPC request");
    349             let stream = stream.unwrap();
    350             let (reader, writer) = &mut (&stream, &stream);
    351 
    352             debug!("RaftRpc::start(): Reading from reader...");
    353             let mut buf = vec![0_u8; 4096];
    354             reader.read(&mut buf).await.unwrap();
    355             debug!("RaftRpc::start(): Read from reader");
    356 
    357             match try_from_slice_unchecked::<RaftMethod>(&buf).unwrap() {
    358                 RaftMethod::Vote(args) => {
    359                     debug!("RaftRpc::start(): Got RaftMethod::Vote");
    360                     let reply = self.request_vote(args).await;
    361                     let payload = reply.try_to_vec().unwrap();
    362 
    363                     debug!("RaftRpc::start(): Vote: Writing to writer...");
    364                     writer.write_all(&payload).await.unwrap();
    365                     debug!("RaftRpc::start(): Vote: Wrote to writer");
    366                 }
    367 
    368                 RaftMethod::Heartbeat(args) => {
    369                     debug!("RaftRpc::start(): Got RaftMethod::Heartbeat");
    370                     let reply = self.heartbeat(args).await;
    371                     let payload = reply.try_to_vec().unwrap();
    372 
    373                     debug!("RaftRpc::start(): Heartbeat: Writing to writer...");
    374                     writer.write_all(&payload).await.unwrap();
    375                     debug!("RaftRpc::start(): Heartbeat: Wrote to writer");
    376                 }
    377             }
    378         }
    379     }
    380 
    381     async fn request_vote(&self, args: VoteArgs) -> VoteReply {
    382         debug!("RaftRpc::request_vote()");
    383         let mut reply = VoteReply { term: 0, vote_granted: false };
    384 
    385         debug!("RaftRpc::request_vote(): Acquiring state lock");
    386         let mut state = STATE.lock().await;
    387         debug!("RaftRpc::request_vote(): Got lock");
    388 
    389         if args.term < state.current_term {
    390             reply.term = state.current_term;
    391             drop(state);
    392             reply.vote_granted = false;
    393             return reply
    394         }
    395 
    396         if state.voted_for == 0 {
    397             state.current_term = args.term;
    398             state.voted_for = args.candidate_id;
    399             drop(state);
    400             reply.term = args.term;
    401             reply.vote_granted = true;
    402             return reply
    403         }
    404 
    405         drop(state);
    406         reply
    407     }
    408 
    409     async fn heartbeat(&self, args: HeartbeatArgs) -> HeartbeatReply {
    410         debug!("RaftRpc::heartbeat()");
    411         let mut reply = HeartbeatReply { success: false, term: 0, next_index: 0 };
    412 
    413         debug!("RaftRpc::heartbeat(): Acquiring state lock");
    414         let state = STATE.lock().await;
    415         debug!("RaftRpc::heartbeat(): Got state lock");
    416         let current_term = state.current_term;
    417         drop(state);
    418         debug!("RaftRpc::heartbeat(): Dropped state lock");
    419 
    420         if args.term < current_term {
    421             reply.success = false;
    422             reply.term = current_term;
    423             return reply
    424         }
    425 
    426         debug!("RaftRpc::heartbeat(): Sending to channel");
    427         HEARTBEAT_CHAN.0.send(true).await.unwrap();
    428         debug!("RaftRpc::heartbeat(): Sent to channel");
    429 
    430         if args.entries.is_empty() {
    431             reply.success = true;
    432             reply.term = current_term;
    433             return reply
    434         }
    435 
    436         debug!("RaftRpc::heartbeat(): Acquiring logstore lock");
    437         let mut logstore = LOG_STORE.lock().await;
    438         debug!("RaftRpc::heartbeat(): Got logstore lock");
    439         if args.prev_log_index > logstore.get_last_index() {
    440             reply.success = false;
    441             reply.term = current_term;
    442             reply.next_index = logstore.get_last_index() + 1;
    443             drop(logstore);
    444             return reply
    445         }
    446 
    447         logstore.0.extend_from_slice(&args.entries);
    448         reply.next_index = logstore.get_last_index() + 1;
    449         drop(logstore);
    450         debug!("RaftRpc::heartbeat(): Dropped logstore lock");
    451 
    452         reply.success = true;
    453         reply.term = current_term;
    454 
    455         reply
    456     }
    457 }
    458 
    459 fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, io::Error> {
    460     let mut data_mut = data;
    461     let result = T::deserialize(&mut data_mut)?;
    462     Ok(result)
    463 }