sraft

simple raft implementation
git clone https://git.parazyd.org/sraft
Log | Files | Refs | README

commit db0dad77964a9efc1f9301f33aa185a4857dc216
Author: parazyd <parazyd@dyne.org>
Date:   Thu, 31 Mar 2022 02:00:00 +0200

Add code

Diffstat:
A.gitignore | 2++
ACargo.toml | 20++++++++++++++++++++
AREADME.md | 14++++++++++++++
Aexamples/peer.rs | 55+++++++++++++++++++++++++++++++++++++++++++++++++++++++
Arustfmt.toml | 9+++++++++
Asrc/lib.rs | 463+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Asrc/method.rs | 40++++++++++++++++++++++++++++++++++++++++
7 files changed, 603 insertions(+), 0 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "sraft" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-channel = "1.6.1" +async-std = {version = "1.11.0", features = ["attributes"]} +borsh = "0.9.3" +futures = "0.3.21" +lazy_static = "1.4.0" +log = "0.4.16" +rand = "0.8.5" + +[dev-dependencies] +async-executor = "1.4.1" +clap = {version = "3.1.6", features = ["derive"]} +easy-parallel = "3.2.0" +simplelog = "0.12.0-alpha1" +smol = "1.2.5" diff --git a/README.md b/README.md @@ -0,0 +1,14 @@ +sraft +===== + +Simple Raft consensus implementation. + +``` +$ cargo build --example peer +$ ./target/debug/examples/peer -p 127.0.0.1:13002 -p 127.0.0.1:13003 -i 1 -l 127.0.0.1:13001 +$ ./target/debug/examples/peer -p 127.0.0.1:13001 -p 127.0.0.1:13001 -i 2 -l 127.0.0.1:13002 +$ ./target/debug/examples/peer -p 127.0.0.1:13001 -p 127.0.0.1:13002 -i 3 -l 127.0.0.1:13003 +``` + +Try stopping and starting certain nodes to see how new leaders are +elected. diff --git a/examples/peer.rs b/examples/peer.rs @@ -0,0 +1,55 @@ +use std::net::SocketAddr; + +use async_executor::Executor; +use async_std::sync::Arc; +use clap::Parser; +use easy_parallel::Parallel; +use simplelog::{ColorChoice, Config, LevelFilter, TermLogger, TerminalMode}; + +use sraft::{Raft, RaftRpc}; + +#[derive(Parser)] +struct Args { + #[clap(long, short)] + peer: Vec<SocketAddr>, + + #[clap(long, short)] + id: u64, + + #[clap(long, short)] + listen: SocketAddr, +} + +#[async_std::main] +async fn main() { + let args = Args::parse(); + + TermLogger::init(LevelFilter::Debug, Config::default(), TerminalMode::Mixed, ColorChoice::Auto) + .unwrap(); + + let mut raft = Raft::new(args.id); + for (k, v) in args.peer.iter().enumerate() { + raft.peers.insert(k as u64, *v); + } + + let raft_rpc = RaftRpc(args.listen); + + let ex = Arc::new(Executor::new()); + let (_signal, shutdown) = async_channel::unbounded::<()>(); + + Parallel::new() + .each(0..4, |_| smol::future::block_on(ex.run(shutdown.recv()))) + // + .add(|| { + smol::future::block_on(async move { + raft_rpc.start().await; + }); + Ok(()) + }) + // + .finish(|| { + smol::future::block_on(async move { + raft.start().await; + }) + }); +} diff --git a/rustfmt.toml b/rustfmt.toml @@ -0,0 +1,9 @@ +reorder_imports = true +imports_granularity = "Crate" +use_small_heuristics = "Max" +comment_width = 100 +wrap_comments = false +binop_separator = "Back" +trailing_comma = "Vertical" +trailing_semicolon = false +use_field_init_shorthand = true diff --git a/src/lib.rs b/src/lib.rs @@ -0,0 +1,463 @@ +use std::{collections::HashMap, io, net::SocketAddr, time::Duration}; + +use async_channel::{Receiver, Sender}; +use async_std::{ + io::{ReadExt, WriteExt}, + net::{TcpListener, TcpStream}, + stream::StreamExt, + sync::Mutex, + task, +}; +use borsh::{BorshDeserialize, BorshSerialize}; +use futures::{select, FutureExt}; +use lazy_static::lazy_static; +use log::{debug, error}; +use rand::Rng; + +mod method; +use crate::method::{HeartbeatArgs, HeartbeatReply, RaftMethod, VoteArgs, VoteReply}; + +#[derive(BorshSerialize, BorshDeserialize, Clone, Debug)] +pub struct LogEntry { + log_term: u64, + log_index: u64, + log_data: Vec<u8>, +} + +pub struct LogStore(pub Vec<LogEntry>); + +impl LogStore { + fn get_last_index(&self) -> u64 { + let rlen = self.0.len(); + if rlen == 0 { + return 0 + } + + self.0[rlen - 1].log_index + } +} + +lazy_static! { + pub static ref LOG_STORE: Mutex<LogStore> = Mutex::new(LogStore(vec![])); + // This is used for heartbeats + pub static ref HEARTBEAT_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded(); + // This is used to let our node know when it has become a leader + pub static ref TOLEADER_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded(); + + pub static ref STATE: Mutex<State> = Mutex::new(State::new()); +} + +#[derive(Default)] +pub struct State { + pub current_term: u64, + pub voted_for: u64, + pub vote_count: u64, + + pub commit_index: u64, + pub _last_applied: u64, + + pub next_index: Vec<u64>, + pub match_index: Vec<u64>, +} + +impl State { + pub fn new() -> Self { + Self { + current_term: 0, + voted_for: 0, + vote_count: 0, + commit_index: 0, + _last_applied: 0, + next_index: vec![], + match_index: vec![], + } + } +} + +pub enum Role { + Follower, + Candidate, + Leader, +} + +pub struct Raft { + pub peers: HashMap<u64, SocketAddr>, + node_id: u64, + role: Role, +} + +impl Raft { + pub fn new(node_id: u64) -> Self { + Self { peers: Default::default(), node_id, role: Role::Follower } + } + + pub async fn start(&mut self) { + debug!("Raft::start()"); + self.role = Role::Follower; + + let mut state = STATE.lock().await; + state.current_term = 0; + state.voted_for = 0; + drop(state); + + let mut rng = rand::thread_rng(); + + loop { + let delay = Duration::from_millis(rng.gen_range(0..200) + 300); + + match self.role { + Role::Follower => { + select! { + _ = HEARTBEAT_CHAN.1.recv().fuse() => { + debug!("[FOLLOWER] Raft::start(): follower_{} got heartbeat", self.node_id); + } + _ = task::sleep(delay).fuse() => { + debug!("[FOLLOWER] Raft::start(): follower_{} timeout", self.node_id); + self.role = Role::Candidate; + } + } + } + + Role::Candidate => { + debug!("[CANDIDATE] Raft::start(): peer_{} is now a candidate", self.node_id); + let mut state = STATE.lock().await; + state.current_term += 1; + state.voted_for = self.node_id; + state.vote_count = 1; + drop(state); + + // TODO: In background + debug!("[CANDIDATE] Raft::start(): broadcasting request_vote"); + self.broadcast_request_vote().await; + + select! { + _ = task::sleep(delay).fuse() => { + debug!("[CANDIDATE] Raft::start(): Timeout as candidate, becoming a follower"); + self.role = Role::Follower; + } + _ = TOLEADER_CHAN.1.recv().fuse() => { + debug!("[CANDIDATE] Raft::start(): We are now the leader"); + self.role = Role::Leader; + + let mut state = STATE.lock().await; + state.next_index = vec![1_u64; self.peers.len()]; + state.match_index = vec![0_u64; self.peers.len()]; + drop(state); + + // TODO: In background + let t = task::spawn(async { + let mut i = 0; + loop { + debug!("[CANDIDATE] Raft::start(): Appending data in bg loop"); + i += 1; + let state = STATE.lock().await; + let logentry = LogEntry { + log_term: state.current_term, + log_index: i, + log_data: format!("user send: {}", i).as_bytes().to_vec(), + }; + drop(state); + + debug!("[CANDIDATE] Raft::start(): Acquiring logstore lock in bg loop"); + let mut logstore = LOG_STORE.lock().await; + logstore.0.push(logentry); + drop(logstore); + debug!("[CANDIDATE] Raft::start(): Dropped logstore lock in bg loop"); + task::sleep(Duration::from_secs(3)).await; + } + }); + } + } + } + + Role::Leader => { + debug!("[LEADER] Raft::start(): Broadcasting heartbeat as leader"); + self.broadcast_heartbeat().await; + task::sleep(Duration::from_millis(100)).await; + } + } + } + } + + async fn broadcast_request_vote(&mut self) { + debug!("Raft::broadcast_request_vote()"); + let state = STATE.lock().await; + let args = VoteArgs { term: state.current_term, candidate_id: self.node_id }; + drop(state); + + // TODO: Do this concurrently + for i in self.peers.clone() { + debug!("Raft::broadcast_request_vote(): Sending req to peer {}", i.1); + match self.send_request_vote(i.0, args.clone()).await { + Ok(v) => debug!("Raft::broadcast_request_vote(): Got reply: {:?}", v), + Err(e) => { + error!("Raft::broadcast_request_vote(): Failed vote to peer {}, ({})", i.1, e); + continue + } + }; + } + } + + async fn send_request_vote( + &mut self, + node_id: u64, + args: VoteArgs, + ) -> Result<VoteReply, io::Error> { + debug!("Raft::send_request_vote()"); + let addr = self.peers[&node_id]; + + let method = RaftMethod::Vote(args); + let payload = method.try_to_vec().unwrap(); + + debug!("Raft::send_request_vote(): Connecting to peer_{}", node_id); + let mut stream = TcpStream::connect(addr).await?; + debug!("Raft::send_request_vote(): Writing to stream"); + stream.write_all(&payload).await?; + debug!("Raft::send_request_vote(): Wrote to stream"); + + debug!("Raft::send_request_vote(): Reading from stream"); + let mut buf = vec![0_u8; 4096]; + stream.read(&mut buf).await?; + debug!("Raft::send_request_vote(): Read from stream"); + + let reply = try_from_slice_unchecked::<VoteReply>(&buf)?; + let mut state = STATE.lock().await; + if reply.term > state.current_term { + debug!("Raft::send_request_vote(): reply.term > state.current_term"); + state.current_term = reply.term; + state.voted_for = 0; + drop(state); + self.role = Role::Follower; + return Ok(reply) + } + drop(state); + + if reply.vote_granted { + debug!("Raft::send_request_vote(): reply.vote_granted == true"); + let mut state = STATE.lock().await; + state.vote_count += 1; + drop(state); + } + + let state = STATE.lock().await; + if state.vote_count >= (self.peers.len() / 2 + 1).try_into().unwrap() { + debug!("Raft::send_request_vote(): Elected for leader"); + TOLEADER_CHAN.0.send(true).await.unwrap(); + } + drop(state); + + Ok(reply) + } + + async fn broadcast_heartbeat(&mut self) { + debug!("[LEADER] Raft::broadcast_heartbeat()"); + + for i in self.peers.clone() { + let state = STATE.lock().await; + let mut args = HeartbeatArgs { + term: state.current_term, + leader_id: self.node_id, + prev_log_index: 0, + prev_log_term: 0, + entries: vec![], + leader_commit: state.commit_index, + }; + + let prev_log_index = state.next_index[i.0 as usize] - 1; + drop(state); + + debug!("[LEADER] Raft::broadcast_heartbeat(): Acquiring lock on LOG_STORE"); + let logstore = LOG_STORE.lock().await; + if logstore.get_last_index() > prev_log_index { + args.prev_log_index = prev_log_index; + args.prev_log_term = logstore.0[prev_log_index as usize].log_term; + args.entries = logstore.0[prev_log_index as usize..].to_vec(); + drop(logstore); + debug!("[LEADER] Raft::broadcast_heartbeat(): Dropped lock on LOG_STORE"); + debug!("[LEADER] Raft::broadcast_heartbeat(): Send entries: {:?}", args.entries); + } + + // TODO: Run in background + match self.send_heartbeat(i.0, args).await { + Ok(v) => debug!("[LEADER] Raft::broadcast_heartbeat(): Got reply: {:?}", v), + Err(e) => { + error!( + "[LEADER] Raft::broadcast_heartbeat(): Failed heartbeat to peer_{} ({})", + i.0, e + ); + continue + } + }; + } + } + + async fn send_heartbeat( + &mut self, + node_id: u64, + args: HeartbeatArgs, + ) -> Result<HeartbeatReply, io::Error> { + debug!("Raft::send_heartbeat({}, {:?}", node_id, args); + let addr = self.peers[&node_id]; + + let method = RaftMethod::Heartbeat(args); + let payload = method.try_to_vec()?; + + debug!("Raft::send_heartbeat(): Connecting to peer_{}", node_id); + let mut stream = TcpStream::connect(addr).await?; + debug!("Raft::send_heartbeat(): Writing to stream"); + stream.write_all(&payload).await?; + debug!("Raft::send_heartbeat(): Wrote to stream"); + + debug!("Raft::send_heartbeat(): Reading from stream"); + let mut buf = vec![0_u8; 4096]; + stream.read(&mut buf).await?; + debug!("Raft::send_heartbeat(): Read from stream"); + + let reply = try_from_slice_unchecked::<HeartbeatReply>(&buf)?; + + let mut state = STATE.lock().await; + if reply.success { + debug!("Raft::send_heartbeat(): Got success reply"); + if reply.next_index > 0 { + state.next_index[node_id as usize] = reply.next_index; + state.match_index[node_id as usize] = reply.next_index - 1; + } + } else if reply.term > state.current_term { + debug!("Raft::send_heartbeat(): reply.term > state.current_term"); + state.current_term = reply.term; + state.voted_for = 0; + self.role = Role::Follower; + } + drop(state); + + Ok(reply) + } +} + +pub struct RaftRpc(pub SocketAddr); + +impl RaftRpc { + pub async fn start(&self) { + debug!("RaftRpc::start()"); + + debug!("RaftRpc::start(): Binding to {}", self.0); + let listener = TcpListener::bind(self.0).await.unwrap(); + let mut incoming = listener.incoming(); + + while let Some(stream) = incoming.next().await { + debug!("RaftRpc::start(): Got RPC request"); + let stream = stream.unwrap(); + let (reader, writer) = &mut (&stream, &stream); + + debug!("RaftRpc::start(): Reading from reader..."); + let mut buf = vec![0_u8; 4096]; + reader.read(&mut buf).await.unwrap(); + debug!("RaftRpc::start(): Read from reader"); + + match try_from_slice_unchecked::<RaftMethod>(&buf).unwrap() { + RaftMethod::Vote(args) => { + debug!("RaftRpc::start(): Got RaftMethod::Vote"); + let reply = self.request_vote(args).await; + let payload = reply.try_to_vec().unwrap(); + + debug!("RaftRpc::start(): Vote: Writing to writer..."); + writer.write_all(&payload).await.unwrap(); + debug!("RaftRpc::start(): Vote: Wrote to writer"); + } + + RaftMethod::Heartbeat(args) => { + debug!("RaftRpc::start(): Got RaftMethod::Heartbeat"); + let reply = self.heartbeat(args).await; + let payload = reply.try_to_vec().unwrap(); + + debug!("RaftRpc::start(): Heartbeat: Writing to writer..."); + writer.write_all(&payload).await.unwrap(); + debug!("RaftRpc::start(): Heartbeat: Wrote to writer"); + } + } + } + } + + async fn request_vote(&self, args: VoteArgs) -> VoteReply { + debug!("RaftRpc::request_vote()"); + let mut reply = VoteReply { term: 0, vote_granted: false }; + + debug!("RaftRpc::request_vote(): Acquiring state lock"); + let mut state = STATE.lock().await; + debug!("RaftRpc::request_vote(): Got lock"); + + if args.term < state.current_term { + reply.term = state.current_term; + drop(state); + reply.vote_granted = false; + return reply + } + + if state.voted_for == 0 { + state.current_term = args.term; + state.voted_for = args.candidate_id; + drop(state); + reply.term = args.term; + reply.vote_granted = true; + return reply + } + + drop(state); + reply + } + + async fn heartbeat(&self, args: HeartbeatArgs) -> HeartbeatReply { + debug!("RaftRpc::heartbeat()"); + let mut reply = HeartbeatReply { success: false, term: 0, next_index: 0 }; + + debug!("RaftRpc::heartbeat(): Acquiring state lock"); + let state = STATE.lock().await; + debug!("RaftRpc::heartbeat(): Got state lock"); + let current_term = state.current_term; + drop(state); + debug!("RaftRpc::heartbeat(): Dropped state lock"); + + if args.term < current_term { + reply.success = false; + reply.term = current_term; + return reply + } + + debug!("RaftRpc::heartbeat(): Sending to channel"); + HEARTBEAT_CHAN.0.send(true).await.unwrap(); + debug!("RaftRpc::heartbeat(): Sent to channel"); + + if args.entries.is_empty() { + reply.success = true; + reply.term = current_term; + return reply + } + + debug!("RaftRpc::heartbeat(): Acquiring logstore lock"); + let mut logstore = LOG_STORE.lock().await; + debug!("RaftRpc::heartbeat(): Got logstore lock"); + if args.prev_log_index > logstore.get_last_index() { + reply.success = false; + reply.term = current_term; + reply.next_index = logstore.get_last_index() + 1; + drop(logstore); + return reply + } + + logstore.0.extend_from_slice(&args.entries); + reply.next_index = logstore.get_last_index() + 1; + drop(logstore); + debug!("RaftRpc::heartbeat(): Dropped logstore lock"); + + reply.success = true; + reply.term = current_term; + + reply + } +} + +fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, io::Error> { + let mut data_mut = data; + let result = T::deserialize(&mut data_mut)?; + Ok(result) +} diff --git a/src/method.rs b/src/method.rs @@ -0,0 +1,40 @@ +use borsh::{BorshDeserialize, BorshSerialize}; + +use crate::LogEntry; + +#[derive(BorshSerialize, BorshDeserialize, Debug)] +pub enum RaftMethod { + Vote(VoteArgs), + Heartbeat(HeartbeatArgs), +} + +#[derive(BorshSerialize, BorshDeserialize, Clone, Debug)] +pub struct VoteArgs { + pub term: u64, + pub candidate_id: u64, +} + +#[derive(BorshSerialize, BorshDeserialize, Debug)] +pub struct VoteReply { + pub term: u64, + pub vote_granted: bool, +} + +#[derive(BorshSerialize, BorshDeserialize, Debug)] +pub struct HeartbeatArgs { + pub term: u64, + pub leader_id: u64, + + pub prev_log_index: u64, + pub prev_log_term: u64, + + pub entries: Vec<LogEntry>, + pub leader_commit: u64, +} + +#[derive(BorshSerialize, BorshDeserialize, Debug)] +pub struct HeartbeatReply { + pub success: bool, + pub term: u64, + pub next_index: u64, +}