lib.rs (16370B)
1 use std::{collections::HashMap, io, net::SocketAddr, time::Duration}; 2 3 use async_channel::{Receiver, Sender}; 4 use async_std::{ 5 io::{ReadExt, WriteExt}, 6 net::{TcpListener, TcpStream}, 7 stream::StreamExt, 8 sync::Mutex, 9 task, 10 }; 11 use borsh::{BorshDeserialize, BorshSerialize}; 12 use futures::{select, FutureExt}; 13 use lazy_static::lazy_static; 14 use log::{debug, error}; 15 use rand::Rng; 16 17 mod method; 18 use crate::method::{HeartbeatArgs, HeartbeatReply, RaftMethod, VoteArgs, VoteReply}; 19 20 #[derive(BorshSerialize, BorshDeserialize, Clone, Debug)] 21 pub struct LogEntry { 22 log_term: u64, 23 log_index: u64, 24 log_data: Vec<u8>, 25 } 26 27 pub struct LogStore(pub Vec<LogEntry>); 28 29 impl LogStore { 30 fn get_last_index(&self) -> u64 { 31 let rlen = self.0.len(); 32 if rlen == 0 { 33 return 0 34 } 35 36 self.0[rlen - 1].log_index 37 } 38 } 39 40 lazy_static! { 41 pub static ref LOG_STORE: Mutex<LogStore> = Mutex::new(LogStore(vec![])); 42 // This is used for heartbeats 43 pub static ref HEARTBEAT_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded(); 44 // This is used to let our node know when it has become a leader 45 pub static ref TOLEADER_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded(); 46 47 pub static ref STATE: Mutex<State> = Mutex::new(State::new()); 48 } 49 50 #[derive(Default)] 51 pub struct State { 52 pub current_term: u64, 53 pub voted_for: u64, 54 pub vote_count: u64, 55 56 pub commit_index: u64, 57 pub _last_applied: u64, 58 59 pub next_index: Vec<u64>, 60 pub match_index: Vec<u64>, 61 } 62 63 impl State { 64 pub fn new() -> Self { 65 Self { 66 current_term: 0, 67 voted_for: 0, 68 vote_count: 0, 69 commit_index: 0, 70 _last_applied: 0, 71 next_index: vec![], 72 match_index: vec![], 73 } 74 } 75 } 76 77 pub enum Role { 78 Follower, 79 Candidate, 80 Leader, 81 } 82 83 pub struct Raft { 84 pub peers: HashMap<u64, SocketAddr>, 85 node_id: u64, 86 role: Role, 87 } 88 89 impl Raft { 90 pub fn new(node_id: u64) -> Self { 91 Self { peers: Default::default(), node_id, role: Role::Follower } 92 } 93 94 pub async fn start(&mut self) { 95 debug!("Raft::start()"); 96 self.role = Role::Follower; 97 98 let mut state = STATE.lock().await; 99 state.current_term = 0; 100 state.voted_for = 0; 101 drop(state); 102 103 let mut rng = rand::thread_rng(); 104 105 loop { 106 let delay = Duration::from_millis(rng.gen_range(0..200) + 300); 107 108 match self.role { 109 Role::Follower => { 110 select! { 111 _ = HEARTBEAT_CHAN.1.recv().fuse() => { 112 debug!("[FOLLOWER] Raft::start(): follower_{} got heartbeat", self.node_id); 113 } 114 _ = task::sleep(delay).fuse() => { 115 debug!("[FOLLOWER] Raft::start(): follower_{} timeout", self.node_id); 116 self.role = Role::Candidate; 117 } 118 } 119 } 120 121 Role::Candidate => { 122 debug!("[CANDIDATE] Raft::start(): peer_{} is now a candidate", self.node_id); 123 let mut state = STATE.lock().await; 124 state.current_term += 1; 125 state.voted_for = self.node_id; 126 state.vote_count = 1; 127 drop(state); 128 129 // TODO: In background 130 debug!("[CANDIDATE] Raft::start(): broadcasting request_vote"); 131 self.broadcast_request_vote().await; 132 133 select! { 134 _ = task::sleep(delay).fuse() => { 135 debug!("[CANDIDATE] Raft::start(): Timeout as candidate, becoming a follower"); 136 self.role = Role::Follower; 137 } 138 _ = TOLEADER_CHAN.1.recv().fuse() => { 139 debug!("[CANDIDATE] Raft::start(): We are now the leader"); 140 self.role = Role::Leader; 141 142 let mut state = STATE.lock().await; 143 state.next_index = vec![1_u64; self.peers.len()]; 144 state.match_index = vec![0_u64; self.peers.len()]; 145 drop(state); 146 147 // TODO: In background 148 let t = task::spawn(async { 149 let mut i = 0; 150 loop { 151 debug!("[CANDIDATE] Raft::start(): Appending data in bg loop"); 152 i += 1; 153 let state = STATE.lock().await; 154 let logentry = LogEntry { 155 log_term: state.current_term, 156 log_index: i, 157 log_data: format!("user send: {}", i).as_bytes().to_vec(), 158 }; 159 drop(state); 160 161 debug!("[CANDIDATE] Raft::start(): Acquiring logstore lock in bg loop"); 162 let mut logstore = LOG_STORE.lock().await; 163 logstore.0.push(logentry); 164 drop(logstore); 165 debug!("[CANDIDATE] Raft::start(): Dropped logstore lock in bg loop"); 166 task::sleep(Duration::from_secs(3)).await; 167 } 168 }); 169 } 170 } 171 } 172 173 Role::Leader => { 174 debug!("[LEADER] Raft::start(): Broadcasting heartbeat as leader"); 175 self.broadcast_heartbeat().await; 176 task::sleep(Duration::from_millis(100)).await; 177 } 178 } 179 } 180 } 181 182 async fn broadcast_request_vote(&mut self) { 183 debug!("Raft::broadcast_request_vote()"); 184 let state = STATE.lock().await; 185 let args = VoteArgs { term: state.current_term, candidate_id: self.node_id }; 186 drop(state); 187 188 // TODO: Do this concurrently 189 for i in self.peers.clone() { 190 debug!("Raft::broadcast_request_vote(): Sending req to peer {}", i.1); 191 match self.send_request_vote(i.0, args.clone()).await { 192 Ok(v) => debug!("Raft::broadcast_request_vote(): Got reply: {:?}", v), 193 Err(e) => { 194 error!("Raft::broadcast_request_vote(): Failed vote to peer {}, ({})", i.1, e); 195 continue 196 } 197 }; 198 } 199 } 200 201 async fn send_request_vote( 202 &mut self, 203 node_id: u64, 204 args: VoteArgs, 205 ) -> Result<VoteReply, io::Error> { 206 debug!("Raft::send_request_vote()"); 207 let addr = self.peers[&node_id]; 208 209 let method = RaftMethod::Vote(args); 210 let payload = method.try_to_vec().unwrap(); 211 212 debug!("Raft::send_request_vote(): Connecting to peer_{}", node_id); 213 let mut stream = TcpStream::connect(addr).await?; 214 debug!("Raft::send_request_vote(): Writing to stream"); 215 stream.write_all(&payload).await?; 216 debug!("Raft::send_request_vote(): Wrote to stream"); 217 218 debug!("Raft::send_request_vote(): Reading from stream"); 219 let mut buf = vec![0_u8; 4096]; 220 stream.read(&mut buf).await?; 221 debug!("Raft::send_request_vote(): Read from stream"); 222 223 let reply = try_from_slice_unchecked::<VoteReply>(&buf)?; 224 let mut state = STATE.lock().await; 225 if reply.term > state.current_term { 226 debug!("Raft::send_request_vote(): reply.term > state.current_term"); 227 state.current_term = reply.term; 228 state.voted_for = 0; 229 drop(state); 230 self.role = Role::Follower; 231 return Ok(reply) 232 } 233 drop(state); 234 235 if reply.vote_granted { 236 debug!("Raft::send_request_vote(): reply.vote_granted == true"); 237 let mut state = STATE.lock().await; 238 state.vote_count += 1; 239 drop(state); 240 } 241 242 let state = STATE.lock().await; 243 if state.vote_count >= (self.peers.len() / 2 + 1).try_into().unwrap() { 244 debug!("Raft::send_request_vote(): Elected for leader"); 245 TOLEADER_CHAN.0.send(true).await.unwrap(); 246 } 247 drop(state); 248 249 Ok(reply) 250 } 251 252 async fn broadcast_heartbeat(&mut self) { 253 debug!("[LEADER] Raft::broadcast_heartbeat()"); 254 255 for i in self.peers.clone() { 256 let state = STATE.lock().await; 257 let mut args = HeartbeatArgs { 258 term: state.current_term, 259 leader_id: self.node_id, 260 prev_log_index: 0, 261 prev_log_term: 0, 262 entries: vec![], 263 leader_commit: state.commit_index, 264 }; 265 266 let prev_log_index = state.next_index[i.0 as usize] - 1; 267 drop(state); 268 269 debug!("[LEADER] Raft::broadcast_heartbeat(): Acquiring lock on LOG_STORE"); 270 let logstore = LOG_STORE.lock().await; 271 if logstore.get_last_index() > prev_log_index { 272 args.prev_log_index = prev_log_index; 273 args.prev_log_term = logstore.0[prev_log_index as usize].log_term; 274 args.entries = logstore.0[prev_log_index as usize..].to_vec(); 275 drop(logstore); 276 debug!("[LEADER] Raft::broadcast_heartbeat(): Dropped lock on LOG_STORE"); 277 debug!("[LEADER] Raft::broadcast_heartbeat(): Send entries: {:?}", args.entries); 278 } 279 280 // TODO: Run in background 281 match self.send_heartbeat(i.0, args).await { 282 Ok(v) => debug!("[LEADER] Raft::broadcast_heartbeat(): Got reply: {:?}", v), 283 Err(e) => { 284 error!( 285 "[LEADER] Raft::broadcast_heartbeat(): Failed heartbeat to peer_{} ({})", 286 i.0, e 287 ); 288 continue 289 } 290 }; 291 } 292 } 293 294 async fn send_heartbeat( 295 &mut self, 296 node_id: u64, 297 args: HeartbeatArgs, 298 ) -> Result<HeartbeatReply, io::Error> { 299 debug!("Raft::send_heartbeat({}, {:?}", node_id, args); 300 let addr = self.peers[&node_id]; 301 302 let method = RaftMethod::Heartbeat(args); 303 let payload = method.try_to_vec()?; 304 305 debug!("Raft::send_heartbeat(): Connecting to peer_{}", node_id); 306 let mut stream = TcpStream::connect(addr).await?; 307 debug!("Raft::send_heartbeat(): Writing to stream"); 308 stream.write_all(&payload).await?; 309 debug!("Raft::send_heartbeat(): Wrote to stream"); 310 311 debug!("Raft::send_heartbeat(): Reading from stream"); 312 let mut buf = vec![0_u8; 4096]; 313 stream.read(&mut buf).await?; 314 debug!("Raft::send_heartbeat(): Read from stream"); 315 316 let reply = try_from_slice_unchecked::<HeartbeatReply>(&buf)?; 317 318 let mut state = STATE.lock().await; 319 if reply.success { 320 debug!("Raft::send_heartbeat(): Got success reply"); 321 if reply.next_index > 0 { 322 state.next_index[node_id as usize] = reply.next_index; 323 state.match_index[node_id as usize] = reply.next_index - 1; 324 } 325 } else if reply.term > state.current_term { 326 debug!("Raft::send_heartbeat(): reply.term > state.current_term"); 327 state.current_term = reply.term; 328 state.voted_for = 0; 329 self.role = Role::Follower; 330 } 331 drop(state); 332 333 Ok(reply) 334 } 335 } 336 337 pub struct RaftRpc(pub SocketAddr); 338 339 impl RaftRpc { 340 pub async fn start(&self) { 341 debug!("RaftRpc::start()"); 342 343 debug!("RaftRpc::start(): Binding to {}", self.0); 344 let listener = TcpListener::bind(self.0).await.unwrap(); 345 let mut incoming = listener.incoming(); 346 347 while let Some(stream) = incoming.next().await { 348 debug!("RaftRpc::start(): Got RPC request"); 349 let stream = stream.unwrap(); 350 let (reader, writer) = &mut (&stream, &stream); 351 352 debug!("RaftRpc::start(): Reading from reader..."); 353 let mut buf = vec![0_u8; 4096]; 354 reader.read(&mut buf).await.unwrap(); 355 debug!("RaftRpc::start(): Read from reader"); 356 357 match try_from_slice_unchecked::<RaftMethod>(&buf).unwrap() { 358 RaftMethod::Vote(args) => { 359 debug!("RaftRpc::start(): Got RaftMethod::Vote"); 360 let reply = self.request_vote(args).await; 361 let payload = reply.try_to_vec().unwrap(); 362 363 debug!("RaftRpc::start(): Vote: Writing to writer..."); 364 writer.write_all(&payload).await.unwrap(); 365 debug!("RaftRpc::start(): Vote: Wrote to writer"); 366 } 367 368 RaftMethod::Heartbeat(args) => { 369 debug!("RaftRpc::start(): Got RaftMethod::Heartbeat"); 370 let reply = self.heartbeat(args).await; 371 let payload = reply.try_to_vec().unwrap(); 372 373 debug!("RaftRpc::start(): Heartbeat: Writing to writer..."); 374 writer.write_all(&payload).await.unwrap(); 375 debug!("RaftRpc::start(): Heartbeat: Wrote to writer"); 376 } 377 } 378 } 379 } 380 381 async fn request_vote(&self, args: VoteArgs) -> VoteReply { 382 debug!("RaftRpc::request_vote()"); 383 let mut reply = VoteReply { term: 0, vote_granted: false }; 384 385 debug!("RaftRpc::request_vote(): Acquiring state lock"); 386 let mut state = STATE.lock().await; 387 debug!("RaftRpc::request_vote(): Got lock"); 388 389 if args.term < state.current_term { 390 reply.term = state.current_term; 391 drop(state); 392 reply.vote_granted = false; 393 return reply 394 } 395 396 if state.voted_for == 0 { 397 state.current_term = args.term; 398 state.voted_for = args.candidate_id; 399 drop(state); 400 reply.term = args.term; 401 reply.vote_granted = true; 402 return reply 403 } 404 405 drop(state); 406 reply 407 } 408 409 async fn heartbeat(&self, args: HeartbeatArgs) -> HeartbeatReply { 410 debug!("RaftRpc::heartbeat()"); 411 let mut reply = HeartbeatReply { success: false, term: 0, next_index: 0 }; 412 413 debug!("RaftRpc::heartbeat(): Acquiring state lock"); 414 let state = STATE.lock().await; 415 debug!("RaftRpc::heartbeat(): Got state lock"); 416 let current_term = state.current_term; 417 drop(state); 418 debug!("RaftRpc::heartbeat(): Dropped state lock"); 419 420 if args.term < current_term { 421 reply.success = false; 422 reply.term = current_term; 423 return reply 424 } 425 426 debug!("RaftRpc::heartbeat(): Sending to channel"); 427 HEARTBEAT_CHAN.0.send(true).await.unwrap(); 428 debug!("RaftRpc::heartbeat(): Sent to channel"); 429 430 if args.entries.is_empty() { 431 reply.success = true; 432 reply.term = current_term; 433 return reply 434 } 435 436 debug!("RaftRpc::heartbeat(): Acquiring logstore lock"); 437 let mut logstore = LOG_STORE.lock().await; 438 debug!("RaftRpc::heartbeat(): Got logstore lock"); 439 if args.prev_log_index > logstore.get_last_index() { 440 reply.success = false; 441 reply.term = current_term; 442 reply.next_index = logstore.get_last_index() + 1; 443 drop(logstore); 444 return reply 445 } 446 447 logstore.0.extend_from_slice(&args.entries); 448 reply.next_index = logstore.get_last_index() + 1; 449 drop(logstore); 450 debug!("RaftRpc::heartbeat(): Dropped logstore lock"); 451 452 reply.success = true; 453 reply.term = current_term; 454 455 reply 456 } 457 } 458 459 fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, io::Error> { 460 let mut data_mut = data; 461 let result = T::deserialize(&mut data_mut)?; 462 Ok(result) 463 }