From 9ec0118815179e3f25809d204017ada355f8148e Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Sun, 11 Aug 2024 10:59:35 +0800 Subject: [PATCH 1/8] feat: instruction scheduling by hardware pipelining --- backend/transform/src/instr_schedule.rs | 583 +++++++++++++++++ backend/transform/src/instrdag.rs | 440 +++++++++++++ backend/transform/src/lib.rs | 227 ++++++- backend/transform/src/transformer.rs | 9 +- out.txt | 730 ++++++++++++++++++++++ test | 85 +++ utils/instruction/src/riscv/riscvinstr.rs | 7 +- 7 files changed, 2073 insertions(+), 8 deletions(-) create mode 100644 backend/transform/src/instr_schedule.rs create mode 100644 backend/transform/src/instrdag.rs create mode 100644 out.txt create mode 100644 test diff --git a/backend/transform/src/instr_schedule.rs b/backend/transform/src/instr_schedule.rs new file mode 100644 index 00000000..3f4d3497 --- /dev/null +++ b/backend/transform/src/instr_schedule.rs @@ -0,0 +1,583 @@ +use std::{ + cell::RefCell, + cmp::{max, min}, + collections::{HashMap, VecDeque}, + fmt::Display, + rc::Rc, +}; + +use crate::{ + instrdag::{postprocess_call, InstrDag, InstrNode}, + Liveliness, RiscvInstr, +}; +use instruction::{ + riscv::{ + prelude::RiscvInstrTrait, + reg::RiscvReg::A0, + value::RiscvTemp::{self, PhysReg}, + }, + RiscvInstrSet, +}; +use utils::{ + SysycError, ADD_ALLOCATABLES, BFS_STATE_THRESHOLD, HARDWARE_PIPELINE_PARAM, + LIVE_THROUGH, NEAR_END, REDUCE_LIVE, REDUCE_SUB, SOFTWARE_PIPELINE_PARAM, + SUM_MIN_RATIO, +}; + +type Node = Rc>; +#[derive(Clone, PartialEq, Eq, Copy, Debug)] +enum AluKind { + Mem, + Normal, + Branch, + Float, + MulDiv, +} +#[derive(Clone, Copy, Debug)] +pub struct Alu { + kind: AluKind, + complete_cycle: usize, + is_fdiv: bool, +} +impl Alu { + fn new(kind: AluKind) -> Self { + Self { + kind, + complete_cycle: 0, // 开区间 + is_fdiv: false, + } + } +} +fn get_alukind(instr: &RiscvInstr) -> AluKind { + let v = instr.get_rtn_array(); + // println!("get_alukind: {} {:?}",instr,v); + if v[0] != 0 { + AluKind::Mem + } else if v[1] != 0 { + AluKind::Branch + } else if v[2] != 0 { + AluKind::MulDiv + } else if v[3] != 0 { + AluKind::Float + } else { + AluKind::Normal + } +} +// 当前惩罚策略:在指令为 instrs 的情况下,在运行每一条指令期间活跃的最大寄存器数目 +// 接受参数:dag:初始图,instrs:当前的指令序列,基本块内 SSA +// 实现硬件流水线的时候,要多返回一个 flight_time_increment +fn punishment( + dag: &InstrDag, + state: &State, + instr_id: usize, + my_reads: Vec, + my_writes: Vec, +) -> (i32, usize, usize, Alu) { + let instr = state.instrs.last().unwrap(); + let mut score = 0; + // 软件流水线的惩罚 + score += + (dag.nodes[instr_id].borrow().to_end as i32) * SOFTWARE_PIPELINE_PARAM; + for i in my_reads.iter() { + if state.liveliness_map.get(i).unwrap().use_num == 1 + && !state.liveliness_map.get(i).unwrap().is_liveout + { + score -= 1; + } + } + for i in my_writes.iter() { + if !state.liveliness_map.get(i).unwrap().is_livein { + score += 1; + } + } + // 判断选择这条指令之后,有多少节点可以变成可调度节点 + let new_allocatables = dag.nodes[instr_id] + .borrow() + .succ + .iter() + .filter(|x| state.indegs[&x.borrow().id] == 1) + .count(); + let alloc_score = -(new_allocatables as i32) * ADD_ALLOCATABLES; + // 判断使得寄存器生命周期尽快结束的惩罚,一方面可以判断 read/write 的寄存器的尽快结束之和,另一方面可以判断 read/write 的寄存器最小离结束的次数,这一段 read 和 write 都是加,是没问题的 + // 思考 live_through 这个参数定义了没用,该怎么用上 + let mut sum_uses: usize = my_reads + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_liveout { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .sum(); + let mut min_uses: usize = my_reads + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_liveout { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .min() + .unwrap_or(0); + sum_uses += my_writes + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_livein { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .sum::(); + min_uses = min( + my_writes + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_livein { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .min() + .unwrap_or(0), + min_uses, + ); + let mut end_live_score = (sum_uses as i32) * SUM_MIN_RATIO; + end_live_score += min_uses as i32; + // 判断对后继的影响 + let mut succ_sum = 0; + let mut succ_min = 0; + for i in dag.nodes[instr_id].borrow().succ.iter() { + let mut my_succ_reads = Vec::new(); + if i.borrow().instr.is_call() { + my_succ_reads = dag.call_reads[state.call_ids.len()].clone(); + } else { + my_succ_reads = i.borrow().instr.get_riscv_read().clone(); + } + succ_sum += my_succ_reads + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_liveout { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .sum::(); + succ_min = min( + my_succ_reads + .iter() + .map(|x| state.liveliness_map.get(x).unwrap().use_num) + .min() + .unwrap_or(0), + succ_min, + ); + // 对 write 寄存器的情况考虑如上 + let mut my_succ_writes = Vec::new(); + if i.borrow().instr.is_call() { + my_succ_writes = if let Some(tmp) = dag.call_writes[state.call_ids.len()] + { + vec![tmp] + } else { + Vec::new() + }; + } else { + my_succ_writes = i.borrow().instr.get_riscv_write().clone(); + } + succ_sum += my_succ_writes + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_livein { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .sum::(); + succ_min = min( + my_succ_writes + .iter() + .map(|x| { + if state.liveliness_map.get(x).unwrap().is_livein { + state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH + } else { + state.liveliness_map.get(x).unwrap().use_num + } + }) + .min() + .unwrap_or(0), + succ_min, + ); + } + let mut succ_score = (succ_sum as i32) * SUM_MIN_RATIO; + // 算硬件流水线的惩罚 + let mut flight_time_incre = 1; + let ready_time = state.flight_time + flight_time_incre; + let mut flight_idx = 0; + let mut flight_unit = Alu::new(AluKind::Normal); + let old_max = state.alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); + // 增量,认为第一条指令在时刻1发射 + if get_alukind(instr) != AluKind::Normal { + for (idx, alu) in state.alus.iter().enumerate() { + if get_alukind(instr) == alu.kind { + if alu.complete_cycle > ready_time { + // wait + flight_time_incre = alu.complete_cycle - ready_time + 1; + } + flight_idx = idx; + flight_unit = Alu::new(alu.kind); + if instr.is_fdiv() { + flight_unit.is_fdiv = true; + } + flight_unit.complete_cycle = state.flight_time + + flight_time_incre + + instr.get_rtn_array()[4] as usize; + if instr.is_fdiv() && alu.is_fdiv { + flight_unit.complete_cycle += utils::FDIV_WAIT; + } + break; + } + } + } else { + // 从 alus[4],alus[5] 拿出 complete_time 更小的来考虑 + flight_idx = (if state.alus[4].complete_cycle < state.alus[5].complete_cycle + { + 4 + } else { + 5 + }); + flight_unit = Alu::new(state.alus[flight_idx].kind); + if state.alus[flight_idx].complete_cycle > ready_time { + flight_time_incre = + state.alus[flight_idx].complete_cycle - ready_time + 1; + } + flight_unit.complete_cycle = + state.flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; + } + let time_incre = max(flight_unit.complete_cycle, old_max) - old_max; + // println!("------------"); + // println!(" in punishment calculation:"); + // for i in state.instrs.iter(){ + // println!("{}",i); + // } + // println!("alu status:"); + // for (idx,i) in state.alus.iter().enumerate(){ + // if idx==flight_idx{ + // println!("{:?}",flight_unit); + // }else{ + // println!("{:?}",i); + // } + // } + // println!("time_incre: {} flight_time_incre: {} flight_idx: {}",time_incre,flight_time_incre,flight_idx); + // println!("------------------"); + succ_score += succ_min as i32; + score = score * REDUCE_LIVE + + alloc_score * ADD_ALLOCATABLES + + end_live_score * NEAR_END + + succ_score * REDUCE_SUB + + time_incre as i32 * HARDWARE_PIPELINE_PARAM; + //println!("punishment: {} flight_time_incre: {} flight_idx: {} flight_unit: {:?}",score,flight_time_incre,flight_idx,flight_unit); + (score, flight_time_incre, flight_idx, flight_unit) +} +#[derive(Clone)] +struct State { + instrs: RiscvInstrSet, + score: i32, + indegs: HashMap, // 把节点的 id 映射到入度 + liveliness_map: HashMap, + call_ids: Vec, + alus: [Alu; 6], + flight_time: usize, +} +impl Display for State { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "State: \n")?; + for i in self.instrs.iter() { + write!(f, "{}\n", i)?; + } + write!(f, "alus: \n")?; + for i in self.alus.iter() { + write!(f, "{:?} ", i)?; + } + write!( + f, + "score: {} flight_time: {}\n", + self.score, self.flight_time + )?; + Ok(()) + } +} +pub fn get_punishment_by_instrs(instr: &Vec>) -> i32 { + // 算出原始的 score + // 按照上面的方法算硬件流水线 + let mut alus = [ + Alu::new(AluKind::Mem), + Alu::new(AluKind::Branch), + Alu::new(AluKind::MulDiv), + Alu::new(AluKind::Float), + Alu::new(AluKind::Normal), + Alu::new(AluKind::Normal), + ]; + let mut flight_time = 0; + for instr in instr.iter() { + let mut flight_time_incre = 1; + let ready_time = flight_time + flight_time_incre; + let old_max = alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); + if get_alukind(instr) != AluKind::Normal { + for (idx, alu) in alus.iter_mut().enumerate() { + if get_alukind(instr) == alu.kind { + if alu.complete_cycle > ready_time { + flight_time_incre = alu.complete_cycle - ready_time + 1; + } + if instr.is_fdiv() { + alu.is_fdiv = true; + } + alu.complete_cycle = + flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; + if instr.is_fdiv() && alu.is_fdiv { + alu.complete_cycle += utils::FDIV_WAIT; + } + break; + } + } + } else { + let flight_idx = (if alus[4].complete_cycle < alus[5].complete_cycle { + 4 + } else { + 5 + }); + if alus[flight_idx].complete_cycle > ready_time { + flight_time_incre = alus[flight_idx].complete_cycle - ready_time + 1; + } + alus[flight_idx].complete_cycle = + flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; + } + flight_time += flight_time_incre; + } + let t = alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); + t as i32 * HARDWARE_PIPELINE_PARAM +} +// 咱想想怎么设计:改动: +// 1. 先不去 clone state,对于每个可以分配的 instruction 把 instr 先 push 再 pop 最后把 pop_front 得到的 State 再 push 回去 +// 2. 每一步的计算保留以下4个参数:total_punishment,state_idx,node_id,my_reads 最后根据 total_punishment 排序并且把前 BFS_STATE_THRESHOLD 给 push 进去 +pub fn instr_schedule_by_dag( + dag: InstrDag, + liveliness_map: HashMap, +) -> Result { + // println!("{}",dag); + // 计算原始 punishment + let original_instrs: Vec<_> = + dag.nodes.iter().rev().map(|x| x.borrow().instr.clone()).collect(); + let original_punishment = get_punishment_by_instrs(&original_instrs); + let mut states = VecDeque::new(); + // calculate indegs + let mut indegs = HashMap::new(); + for node in dag.nodes.iter() { + indegs.insert(node.borrow().id, node.borrow().in_deg); + } + states.push_back(State { + instrs: Vec::new(), + score: 0, + indegs: indegs.clone(), + liveliness_map, + call_ids: Vec::new(), + alus: [ + Alu::new(AluKind::Mem), + Alu::new(AluKind::Branch), + Alu::new(AluKind::MulDiv), + Alu::new(AluKind::Float), + Alu::new(AluKind::Normal), + Alu::new(AluKind::Normal), + ], + flight_time: 0, + }); + let depth = dag.nodes.len(); // bfs 深度已知,是所需要调度的指令总数 + for _i in 0..depth { + let real_cnt = states.len(); + let mut keeps = Vec::new(); + for j in 0..real_cnt { + let mut state = states.pop_front().unwrap(); + let allocatables: Vec<_> = state + .indegs + .iter() + .filter(|(_k, v)| **v == 0) + .map(|(k, _)| *k) + .collect(); + // println!("allocatables: {:?} _i: {:?} _j: {:?} ", allocatables,_i,_j); + // println!("state instrs:"); + // for i in state.instrs.iter() { + // println!("{}", i); + // } + for i in allocatables.iter() { + //let mut new_state = state.clone(); + state.instrs.push(dag.nodes[*i].borrow().instr.clone()); + // get riscv reads and writes + let mut my_reads = Vec::new(); + let mut my_writes = Vec::new(); + if dag.nodes[*i].borrow().instr.is_call() { + //check state's call_id length + my_reads = dag.call_reads[state.call_ids.len()].clone(); + my_writes = if let Some(tmp) = dag.call_writes[state.call_ids.len()] { + vec![tmp] + } else { + Vec::new() + }; + } else { + my_reads = dag.nodes[*i].borrow().instr.get_riscv_read().clone(); + my_writes = dag.nodes[*i].borrow().instr.get_riscv_write().clone(); + } + let (punish, flight_time_incre, flight_idx, flight_unit) = + punishment(&dag, &state, *i, my_reads.clone(), my_writes.clone()); + let score = state.score + punish; + keeps.push((j, *i, score, flight_time_incre, flight_idx, flight_unit)); + state.instrs.pop(); + } + states.push_back(state); + } + // debug print keeps + if keeps.len() > BFS_STATE_THRESHOLD { + keeps.sort_by(|a, b| a.2.cmp(&b.2)); + // println!("keeps: "); + // for entry in keeps.iter() { + // println!("{:?} {}", entry,dag.nodes[entry.1].borrow().instr); + // } + // println!("======= end keeps ======"); + keeps.truncate(BFS_STATE_THRESHOLD); + } + for i in 0..real_cnt { + // iterate the keeps + let mut cnts: Vec<_> = + keeps.iter().filter(|x| x.0 == i).map(|x| *x).collect(); + if cnts.len() == 0 { + states.pop_front(); + } else if cnts.len() == 1 { + let mut state = states.pop_front().unwrap(); + state.instrs.push(dag.nodes[cnts[0].1].borrow().instr.clone()); + if dag.nodes[cnts[0].1].borrow().instr.is_call() { + state.call_ids.push(cnts[0].1); + } + // calc my_reads + let mut my_reads = Vec::new(); + if state.instrs.last().unwrap().is_call() { + my_reads = dag.call_reads[state.call_ids.len() - 1].clone(); + } else { + my_reads = + dag.nodes[cnts[0].1].borrow().instr.get_riscv_read().clone(); + } + // decl the use in new_state's liveliness_map + for i in my_reads.iter() { + state.liveliness_map.get_mut(i).unwrap().use_num -= 1; + } + state.indegs.remove(&cnts[0].1); + for succ in dag.nodes[cnts[0].1].borrow().succ.iter() { + let mut new_indeg = state.indegs.clone(); + new_indeg.insert( + succ.borrow().id, + new_indeg.get(&succ.borrow().id).unwrap() - 1, + ); + state.indegs = new_indeg; + } + state.flight_time += cnts[0].3; + state.alus[cnts[0].4] = cnts[0].5; + state.score = cnts[0].2; + states.push_back(state); + } else { + let mut state = states.pop_front().unwrap(); + for j in 0..cnts.len() - 1 { + let mut new_state = state.clone(); + new_state.instrs.push(dag.nodes[cnts[j].1].borrow().instr.clone()); + if dag.nodes[cnts[j].1].borrow().instr.is_call() { + new_state.call_ids.push(cnts[j].1); + } + // calc my_reads + let mut my_reads = Vec::new(); + if new_state.instrs.last().unwrap().is_call() { + my_reads = dag.call_reads[new_state.call_ids.len() - 1].clone(); + } else { + my_reads = + dag.nodes[cnts[j].1].borrow().instr.get_riscv_read().clone(); + } + // decl the use in new_state's liveliness_map + for i in my_reads.iter() { + new_state.liveliness_map.get_mut(i).unwrap().use_num -= 1; + } + new_state.indegs.remove(&cnts[j].1); + for succ in dag.nodes[cnts[j].1].borrow().succ.iter() { + let mut new_indeg = new_state.indegs.clone(); + new_indeg.insert( + succ.borrow().id, + new_indeg.get(&succ.borrow().id).unwrap() - 1, + ); + new_state.indegs = new_indeg; + } + new_state.flight_time += cnts[j].3; + new_state.alus[cnts[j].4] = cnts[j].5; + new_state.score = cnts[j].2; + states.push_back(new_state); + } + // 最后一次不 clone 了 + state + .instrs + .push(dag.nodes[cnts[cnts.len() - 1].1].borrow().instr.clone()); + if dag.nodes[cnts[cnts.len() - 1].1].borrow().instr.is_call() { + state.call_ids.push(cnts[cnts.len() - 1].1); + } + // calc my_reads + let mut my_reads = Vec::new(); + if state.instrs.last().unwrap().is_call() { + my_reads = dag.call_reads[state.call_ids.len() - 1].clone(); + } else { + my_reads = dag.nodes[cnts[cnts.len() - 1].1] + .borrow() + .instr + .get_riscv_read() + .clone(); + } + // decl the use in new_state's liveliness_map + for i in my_reads.iter() { + state.liveliness_map.get_mut(i).unwrap().use_num -= 1; + } + state.indegs.remove(&cnts[cnts.len() - 1].1); + for succ in dag.nodes[cnts[cnts.len() - 1].1].borrow().succ.iter() { + let mut new_indeg = state.indegs.clone(); + new_indeg.insert( + succ.borrow().id, + new_indeg.get(&succ.borrow().id).unwrap() - 1, + ); + state.indegs = new_indeg; + } + state.flight_time += cnts[cnts.len() - 1].3; + state.alus[cnts[cnts.len() - 1].4] = cnts[cnts.len() - 1].5; + state.score = cnts[cnts.len() - 1].2; + states.push_back(state); + } + } + } + // for i in states.iter() { + // println!("final state instructions:"); + // for j in i.instrs.iter() { + // println!("{}", j); + // } + // } + // state 排序 + states.make_contiguous().sort_by(|a, b| a.score.cmp(&b.score)); + let mut final_state = states.pop_front().unwrap(); + // println!("final state instructions:"); + // for i in final_state.instrs.iter() { + // println!("{}", i); + // } + if final_state.score >= original_punishment { + final_state.instrs = original_instrs; + } else { + // println!("original punishment: {} final punishment: {}",original_punishment,final_state.score); + } + Ok(postprocess_call( + final_state.instrs, + &mut dag.call_related.clone(), // 是我call的顺序可能会调换,post_process 的时候和原本push进去的顺序不一致 + dag.branch.clone(), + &mut final_state.call_ids.clone(), + )) +} diff --git a/backend/transform/src/instrdag.rs b/backend/transform/src/instrdag.rs new file mode 100644 index 00000000..7177804f --- /dev/null +++ b/backend/transform/src/instrdag.rs @@ -0,0 +1,440 @@ +use std::{ + cell::RefCell, + cmp::max, + collections::{HashMap, HashSet}, + rc::Rc, +}; + +use instruction::riscv::{ + reg::RiscvReg::{A0, SP}, + riscvinstr::RiscvInstrTrait, + value::RiscvTemp, + RiscvInstr, +}; +use rrvm::RiscvNode; +use std::fmt; +use utils::SysycError; + +type Node = Rc>; +#[derive(Clone)] +pub struct InstrNode { + pub id: usize, + pub in_deg: usize, + pub instr: RiscvInstr, + pub succ: Vec, + pub last_use: usize, + pub pred: Vec, + pub to_end: usize, +} +impl InstrNode { + pub fn new(instr: &RiscvInstr, id: usize) -> Self { + Self { + id, + in_deg: 0, + instr: instr.clone(), + succ: Vec::new(), + last_use: 0, + pred: Vec::new(), + to_end: 0, + } + } +} + +#[derive(Clone)] +pub struct InstrDag { + pub nodes: Vec, + pub call_related: HashMap>>, + pub branch: Option>, + pub call_writes: Vec>, + pub call_reads: Vec>, +} +fn preprocess_call( + node: &RiscvNode, + call_related: &mut Vec>>, // 换成一个 hashmap 用建完图之后的 node id 来索引 + call_write: &mut Vec>, + call_reads: &mut Vec>, +) -> Vec> { + let mut instrs = Vec::new(); + let mut save_instr = false; + let mut my_call_related = Vec::new(); + let mut is_last_restore = false; + let mut push_this = false; + for (idx, i) in node.borrow().instrs.iter().enumerate() { + if push_this { + push_this = false; + my_call_related.push(i.clone()); + call_write.push(Some(i.get_riscv_write()[0])); + call_related.push(my_call_related); + my_call_related = Vec::new(); + continue; + } + if is_last_restore { + is_last_restore = false; + if i.get_riscv_read().len() == 1 { + if let RiscvTemp::PhysReg(A0) = i.get_riscv_read()[0] { + my_call_related.push(i.clone()); + //call_write.push(Some(i.get_riscv_write()[0])); + push_this = true; + continue; + } else { + call_write.push(None); + } + } else { + call_write.push(None); + } + call_related.push(my_call_related); + my_call_related = Vec::new(); + } + if i.is_save() { + save_instr = true; + my_call_related.push(i.clone()); + } else if i.is_restore() { + save_instr = false; + my_call_related.push(i.clone()); + is_last_restore = true; + if idx == node.borrow().instrs.len() - 1 { + call_related.push(my_call_related); + call_write.push(None); + break; + } + } else if i.is_call() { + instrs.push(i.clone()); + my_call_related.push(i.clone()); + } else if save_instr { + my_call_related.push(i.clone()); + } else { + instrs.push(i.clone()); + } + } + // process call writes and call reads + for call_instrs in call_related.iter() { + // 获取所有 instr 中的riscv_reads 的并集 + let mut riscv_reads = HashSet::new(); + // 先把 SP 扔进 riscv_reads + riscv_reads.insert(RiscvTemp::PhysReg(SP)); + for instr in call_instrs.iter() { + riscv_reads.extend(instr.get_riscv_read().iter().cloned()); + } + // 在 riscv_read 中删除 call 指令前传 param 的时候写的寄存器 + for instr in call_instrs.iter() { + if instr.is_call() { + break; + } + for i in instr.get_riscv_write().iter() { + riscv_reads.remove(i); + } + } + call_reads.push(riscv_reads.iter().cloned().collect()); + } + instrs +} +pub fn postprocess_call( + instrs: Vec>, + call_related: &mut HashMap>>, + branch_related: Option>, + call_idxs: &mut Vec, +) -> Vec> { + let mut my_instrs = Vec::new(); + for i in instrs { + if i.is_call() { + my_instrs.append( + &mut call_related.get(&call_idxs.pop().unwrap()).unwrap().clone(), + ); + } else { + my_instrs.push(i); + } + } + if let Some(instr) = branch_related { + my_instrs.push(instr); + } + // debug print + // println!("postprocess call instrs:"); + // for i in my_instrs.iter() { + // println!("{}", i); + // } + // println!("---------------postprocess call instrs end---------------------"); + my_instrs +} +impl InstrDag { + pub fn new(node: &RiscvNode) -> Result { + let mut nodes: Vec = Vec::new(); + let mut defs: HashMap>> = HashMap::new(); + let mut uses: HashMap>>> = + HashMap::new(); + let mut last_call: Option = None; + let mut last_loads: Vec = Vec::new(); + let mut call_related = Vec::new(); + let mut last_uses = HashMap::new(); + let mut last_branch: Option> = None; + let mut call_write = Vec::new(); + let mut call_reads = Vec::new(); + let mut li_ret = None; + let mut call_related_map = HashMap::new(); + let mut call_instrs: Vec>> = Vec::new(); + let mut my_call_write = None; + let mut ret_call_writes = Vec::new(); + let mut ret_call_reads = Vec::new(); + // preprocessing call related: 把 call 前后的 从 save 到 restore 的若干条指令保存在 call_related 里面,然后加入到 is_filtered_idx 之后遍历instrs 的时候遇到就直接continue + // println!("original instrs :"); + // for i in node.borrow().instrs.iter() { + // println!("{}", i); + // } + let mut processed_instrs = preprocess_call( + node, + &mut call_related, + &mut call_write, + &mut call_reads, + ); + ret_call_writes = call_write.clone(); + ret_call_reads = call_reads.clone(); + if processed_instrs.len() > 0 { + let last_instr = processed_instrs.last().unwrap(); + if last_instr.is_branch() { + last_branch = Some(last_instr.clone()); + let _ = processed_instrs.pop(); + } + } + // println!("call read temps: {:?}", call_reads); + // println!("call related instructions:"); + // for i in call_related.iter() { + // for j in i.iter() { + // println!("{}", j); + // }println!("----"); + // } + // for i in call_related.iter(){ + // for j in i.iter(){ + // if j.is_call(){ + // println!("get riscv read: {:?}",j.get_riscv_read()); + // println!("get riscv write: {:?}",j.get_riscv_write()); + // println!("call write: {:?}",call_write); + // println!("-----------"); + // } + // } + // } + // 传参 call 回去 param read 会需要记录 + for i in call_related.iter() { + let mut riscv_writes = HashSet::new(); + let mut riscv_reads = HashSet::new(); + for j in i.iter() { + riscv_writes.extend(j.get_riscv_write().iter().cloned()); + riscv_reads.extend(j.get_riscv_read().iter().cloned()); + } + // println!("for total call related instructions: riscvreads {:?}",riscv_reads); + // println!("for total call related instructions: riscvwrites {:?}",riscv_writes); + // println!("------------"); + } + // println!("processed_instrs len: {}",processed_instrs.len()); + // for i in processed_instrs.iter() { + // println!("{}",i); + // } + for (idx, instr) in processed_instrs.iter().rev().enumerate() { + // println!("instr id:{} {}",instr, idx); + // println!("instr read: {:?}",instr.get_riscv_read()); + // println!("instr write: {:?}",instr.get_riscv_write()); + let node = Rc::new(RefCell::new(InstrNode::new(instr, idx))); + if idx == 0 { + if instr.get_riscv_write().len() == 1 + && instr.get_riscv_write()[0] == RiscvTemp::PhysReg(A0) + { + li_ret = Some(node.clone()); + } + } + let mut instr_node_succ = Vec::new(); + let instructions_write = instr.get_riscv_write().clone(); + if instr.is_call() == false { + for instr_write in instructions_write { + instr_node_succ.extend( + uses.get(&instr_write).unwrap_or(&Vec::new()).iter().cloned(), + ); + // println!("in instr {} write extending..",node.borrow().id); + // for i in uses.get(&instr_write).unwrap_or(&Vec::>>::new()).iter().map(|z| z.borrow().id).collect::>() { + // println!("intr write extending to id: {}", i); + // } + // 同时 extend predecessors + for i in uses.get(&instr_write).unwrap_or(&Vec::new()).iter() { + i.borrow_mut().pred.push(node.clone()); + } + uses.remove(&instr_write); + } + } else { + let tmp = call_write.pop().unwrap(); + my_call_write = tmp.clone(); + if let Some(tmp) = tmp { + instr_node_succ + .extend(uses.get(&tmp).unwrap_or(&Vec::new()).iter().cloned()); + uses.get(&tmp).unwrap_or(&Vec::new()).iter().for_each(|x| { + x.borrow_mut().pred.push(node.clone()); + }); + uses.remove(&tmp); + } + } + let instr_read = instr.get_riscv_read().clone(); + if instr.is_call() == false { + for instr_read_temp in instr_read.iter() { + if let Some(def_instr) = defs.get(instr_read_temp) { + instr_node_succ.push(def_instr.clone()); + def_instr.borrow_mut().pred.push(node.clone()); + // println!("in instr def extending {}->{}",node.borrow().id,def_instr.borrow().id); + } + uses.entry(*instr_read_temp).or_default().push(node.clone()); + if !last_uses.contains_key(instr_read_temp) { + last_uses.insert(*instr_read_temp, idx); + } + } + } else { + let tmp = call_reads.pop().unwrap(); + for instr_read_temp in tmp.iter() { + if let Some(def_instr) = defs.get(instr_read_temp) { + instr_node_succ.push(def_instr.clone()); + def_instr.borrow_mut().pred.push(node.clone()); + } + uses.entry(*instr_read_temp).or_default().push(node.clone()); + if !last_uses.contains_key(instr_read_temp) { + last_uses.insert(*instr_read_temp, idx); + } + } + } + // init defs + if instr.is_call() == false { + let instructions_write = instr.get_riscv_write().clone(); + for instr_write in instructions_write.iter() { + defs.insert(*instr_write, node.clone()); + } + } else { + if let Some(tmp) = my_call_write { + defs.insert(tmp, node.clone()); + } + } + // 处理 load call store 指令的依赖关系 + if instr.is_call() { + // 先考虑一下那个最后一条 mov other reg a0 + instr_node_succ.extend(last_loads.iter().cloned()); + last_loads.iter().for_each(|x| { + x.borrow_mut().pred.push(node.clone()); + }); + // println!("in is_call {} extending loads {:?}",node.borrow().id,last_loads.iter().map(|x| x.borrow().id).collect::>()); + last_loads.clear(); + last_call = Some(node.clone()); + if let Some(ret_node) = li_ret.clone() { + instr_node_succ.push(ret_node.clone()); + ret_node.borrow_mut().pred.push(node.clone()); + } + for i in call_instrs.iter() { + instr_node_succ.push(i.clone()); + i.borrow_mut().pred.push(node.clone()); + } + call_instrs.push(node.clone()); + // for i in nodes.iter() { + // instr_node_succ.push(i.clone()); + // } + } else if instr.is_load().unwrap_or(false) { + if let Some(last_call) = last_call.clone() { + // println!("in is_load {} extending last_call {}",node.borrow().id,last_call.borrow().id); + instr_node_succ.push(last_call.clone()); + last_call.borrow_mut().pred.push(node.clone()); + } + last_loads.push(node.clone()); + } else if instr.is_store().unwrap_or(false) { + instr_node_succ.extend(last_loads.iter().cloned()); + last_loads.iter().for_each(|x| { + x.borrow_mut().pred.push(node.clone()); + }); + last_loads.clear(); + last_call = Some(node.clone()); + for i in call_instrs.iter() { + instr_node_succ.push(i.clone()); + i.borrow_mut().pred.push(node.clone()); + } + call_instrs.push(node.clone()); + } + node.borrow_mut().succ = instr_node_succ; + nodes.push(node); + } + for node in nodes.iter() { + // println!("node id: {}", node.borrow().id); + // println!("node successors: {:?}", node.borrow().succ.iter().map(|s| s.borrow().id).collect::>()); + // println!("---------"); + for succ in node.borrow().succ.iter() { + succ.borrow_mut().in_deg += 1; + } + } + for (index, instr) in nodes.iter_mut().enumerate().rev() { + instr.borrow_mut().last_use += + last_uses.iter().filter(|x| *x.1 == index).count(); + } + // construct hashmap,key is the id of the nodes that are call, values are the call instructions + for (idx, instrs) in nodes.iter().enumerate().rev() { + if instrs.borrow().instr.is_call() { + call_related_map.insert(idx, call_related.pop().unwrap()); + } + } + Ok(Self { + nodes, + call_related: call_related_map, + branch: last_branch, + call_reads: ret_call_reads, + call_writes: ret_call_writes, + }) + } + pub fn assign_nodes(&mut self) { + // 先备份一遍所有 node 的 indegs + let indegs = + self.nodes.iter().map(|x| x.borrow().in_deg).collect::>(); + // 开始遍历 + let mut stack_ = Vec::new(); + for i in self.nodes.iter() { + if i.borrow().succ.len() == 0 { + stack_.push(i.clone()); + // get latency + let siz = i.borrow().instr.get_rtn_array()[4] as usize; + i.borrow_mut().to_end = siz; + } + } + while stack_.len() > 0 { + let node = stack_.pop().unwrap(); + for i in node.borrow().pred.iter() { + let new_end = max( + i.borrow().to_end, + node.borrow().to_end + i.borrow().instr.get_rtn_array()[4] as usize, + ); + i.borrow_mut().to_end = new_end; + i.borrow_mut().in_deg -= 1; + if i.borrow().in_deg == 0 { + stack_.push(i.clone()); + } + } + } + // 对每个点恢复 in_deg + for (i, j) in self.nodes.iter().zip(indegs.iter()) { + i.borrow_mut().in_deg = *j; + } + } +} +impl fmt::Display for InstrDag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for node in &self.nodes { + let instr_node = node.borrow(); + writeln!(f, "Node ID: {}", instr_node.id)?; + writeln!(f, "In-degree: {}", instr_node.in_deg)?; + writeln!(f, "Instruction: {}", instr_node.instr)?; + writeln!( + f, + "Successors: {:?}", + instr_node.succ.iter().map(|x| x.borrow().id).collect::>() + )?; + // print successor's in degrees + writeln!( + f, + "Successors' In-degree: {:?}", + instr_node + .succ + .iter() + .map(|x| x.borrow().in_deg) + .collect::>() + )?; + writeln!(f, "Last Use: {}", instr_node.last_use)?; + writeln!(f, "---------------------------")?; + } + Ok(()) + } +} diff --git a/backend/transform/src/lib.rs b/backend/transform/src/lib.rs index ca5a59a8..665ed29c 100644 --- a/backend/transform/src/lib.rs +++ b/backend/transform/src/lib.rs @@ -1,12 +1,31 @@ -use std::{cell::RefCell, collections::HashMap, rc::Rc}; +use std::{ + cell::RefCell, + collections::{BTreeMap, HashMap, HashSet}, + io::{self, Write}, + rc::Rc, +}; +<<<<<<< HEAD +======= +use instr_schedule::instr_schedule_by_dag; +use instrdag::InstrDag; +>>>>>>> f67bb86 (feat: instruction scheduling by hardware pipelining) use instruction::{riscv::prelude::*, temp::TempManager}; use llvm::Value; +use utils::{SysycError::RiscvGenError}; use rrvm::prelude::*; -use transformer::to_riscv; -use utils::{errors::Result, SysycError::RiscvGenError}; +use transformer::{to_riscv, to_rt_type}; +use utils::{ + errors::Result, BLOCKSIZE_THRESHOLD, DEPENDENCY_EXPLORE_DEPTH, + SCHEDULE_THRESHOLD, +}; +<<<<<<< HEAD +======= +pub mod instr_schedule; +pub mod instrdag; +>>>>>>> f67bb86 (feat: instruction scheduling by hardware pipelining) pub mod remove_phi; pub mod transformer; @@ -15,15 +34,104 @@ pub fn get_functions( funcs: Vec, ) -> Result<()> { for func in funcs { - program.funcs.push(convert_func(func, &mut program.temp_mgr)?); + let converted_func = convert_func(func, &mut program.temp_mgr)?; + println!("--- before instr schedule: ---"); + for i in converted_func.0.cfg.blocks.iter() { + for j in i.borrow().instrs.iter() { + println!("{}", j); + } + println!("------------block end-------------"); + // println!( + // "jump instruction: {}", + // i.borrow().jump_instr.as_ref().unwrap() + // ); + } + println!("---end---"); + io::stdout().flush().unwrap(); + let func = instr_schedule( + converted_func.0, + converted_func.1, + converted_func.2, + &mut program.temp_mgr, + )?; + println!("--------"); + for i in func.cfg.blocks.iter() { + for j in i.borrow().instrs.iter() { + println!("{}", j); + } + println!("------------block end-------------"); + } + println!("--------"); + program.funcs.push(func); } Ok(()) } +pub fn instr_schedule( + func: RiscvFunc, + live_ins: Vec>, + live_outs: Vec>, + mgr: &mut TempManager, +) -> Result { + func.cfg.clear_data_flow(); + func.cfg.analysis(); + let mut new_blocks = Vec::new(); + for (idx, node) in func.cfg.blocks.iter().enumerate() { + let nodes = + instr_schedule_block(node, &live_ins[idx], &live_outs[idx], mgr)?; + new_blocks.extend(nodes); + } + Ok(RiscvFunc { + total: mgr.total, + spills: 0, + cfg: RiscvCFG { blocks: new_blocks }, + name: func.name, + params: func.params, + ret_type: func.ret_type, + }) +} +pub fn instr_schedule_block( + riscv_node: &RiscvNode, + live_ins: &HashSet, + live_outs: &HashSet, + mgr: &mut TempManager, +) -> Result> { + if riscv_node.borrow().instrs.len() >= SCHEDULE_THRESHOLD { + return Ok(vec![riscv_node.clone()]); + } + let prev = riscv_node + .borrow() + .prev + .iter() + .map(|v| v.borrow().id) + .collect::>(); + let succ = riscv_node + .borrow() + .succ + .iter() + .map(|v| v.borrow().id) + .collect::>(); + // 判断 prev 和 succ 是否有交集 + if prev.intersection(&succ).count() > 0 + && riscv_node.borrow().instrs.len() <= BLOCKSIZE_THRESHOLD + { + // filter call (instrs 中不能有 call 指令) + if riscv_node.borrow().instrs.iter().any(|instr| instr.is_call()) { + transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) + .map(|v| vec![v]) + } else { + transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) + .map(|v| vec![v]) + } + } else { + transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) + .map(|v| vec![v]) + } +} pub fn convert_func( func: LlvmFunc, mgr: &mut TempManager, -) -> Result { +) -> Result<(RiscvFunc, Vec>, Vec>)> { let mut nodes = Vec::new(); let mut edge = Vec::new(); let mut table = HashMap::new(); @@ -74,6 +182,7 @@ pub fn convert_func( for (u, v) in edge { force_link_node(table.get(&u).unwrap(), table.get(&v).unwrap()) } +<<<<<<< HEAD Ok(RiscvFunc { total: mgr.total, spills: 0, @@ -82,8 +191,116 @@ pub fn convert_func( params: func.params, ret_type: func.ret_type, }) +======= + Ok(( + RiscvFunc { + total: mgr.total, + spills: 0, + cfg: RiscvCFG { blocks: nodes }, + name: func.name, + params: func.params, + ret_type: func.ret_type, + }, + live_ins, + live_outs, + )) +>>>>>>> f67bb86 (feat: instruction scheduling by hardware pipelining) } +fn transform_basic_block_by_pipelining( + node: &RiscvNode, + live_in: &HashSet, + live_out: &HashSet, + _mgr: &mut TempManager, +) -> Result { + let mut instr_dag = InstrDag::new(node)?; + let liveliness_map = get_liveliness_map(&instr_dag, live_in, live_out); + instr_dag.assign_nodes(); + node.borrow_mut().instrs = instr_schedule_by_dag(instr_dag, liveliness_map)?; + Ok(node.clone()) +} +#[derive(Clone, Debug)] +pub struct Liveliness { + is_livein: bool, + is_liveout: bool, + use_num: usize, +} +fn get_liveliness_map( + node: &InstrDag, + live_in: &HashSet, + live_out: &HashSet, +) -> HashMap { + let mut map = HashMap::new(); + let mut call_reads = node.call_reads.clone(); + call_reads.reverse(); + let mut call_writes = node.call_writes.clone(); + call_writes.reverse(); + // 它这里要求是正序遍历,所以遍历次序是和 node 的顺序反的,需要 iter.rev(),同样,call_reads,call_writes 也要reverse再pop + for instrnode in node.nodes.iter().rev() { + let instr = &instrnode.borrow().instr; + if !instr.is_call() { + for tmp in instr.get_riscv_read().iter() { + map + .entry(*tmp) + .or_insert(Liveliness { + is_livein: false, + is_liveout: false, + use_num: 0, + }) + .use_num += 1; + } + for tmp in instr.get_riscv_write().iter() { + map.entry(*tmp).or_insert(Liveliness { + is_livein: false, + is_liveout: false, + use_num: 0, + }); + } + } else { + let call_read = call_reads.pop().unwrap(); + for tmp in call_read.iter() { + map + .entry(*tmp) + .or_insert(Liveliness { + is_livein: false, + is_liveout: false, + use_num: 0, + }) + .use_num += 1; + } + let call_write = call_writes.pop().unwrap(); + for tmp in call_write.iter() { + map.entry(*tmp).or_insert(Liveliness { + is_livein: false, + is_liveout: false, + use_num: 0, + }); + } + } + } + // do live_in + for tmp in live_in.iter() { + map + .entry(*tmp) + .or_insert(Liveliness { + is_livein: true, + is_liveout: false, + use_num: 0, + }) + .is_livein = true; + } + for tmp in live_out.iter() { + map + .entry(*tmp) + .or_insert(Liveliness { + is_livein: false, + is_liveout: true, + use_num: 0, + }) + .is_liveout = true; + } + map +} fn transform_basicblock( node: &LlvmNode, mgr: &mut TempManager, diff --git a/backend/transform/src/transformer.rs b/backend/transform/src/transformer.rs index da2a8dec..847500ec 100644 --- a/backend/transform/src/transformer.rs +++ b/backend/transform/src/transformer.rs @@ -1,4 +1,8 @@ -use instruction::{riscv::convert::*, temp::TempManager, RiscvInstrSet}; +use instruction::{ + riscv::{convert::*, RiscvInstr}, + temp::TempManager, + RiscvInstrSet, +}; use llvm::{LlvmInstr, LlvmInstrVariant}; use utils::errors::Result; @@ -23,3 +27,6 @@ pub fn to_riscv( }?; Ok(riscv_instr) } +pub fn to_rt_type(instr: &RiscvInstr) -> [i32; 5] { + instr.get_rtn_array() +} diff --git a/out.txt b/out.txt new file mode 100644 index 00000000..cc2b845a --- /dev/null +++ b/out.txt @@ -0,0 +1,730 @@ +original punishment: 15 final punishment: 14 +original punishment: 12 final punishment: 8 +original punishment: 9 final punishment: 7 +original punishment: 12 final punishment: 8 +original punishment: 9 final punishment: 7 +original punishment: 12 final punishment: 8 +original punishment: 9 final punishment: 7 +original punishment: 12 final punishment: 11 +original punishment: 19 final punishment: 17 +original punishment: 19 final punishment: 17 +original punishment: 30 final punishment: 29 +original punishment: 19 final punishment: 17 +original punishment: 12 final punishment: 8 +original punishment: 9 final punishment: 7 + .file "./project-eval/testcases/performance/gameoflife-p61glidergun.sy" + .option nopic + .attribute unaligned_access, 0 + .attribute stack_align, 16 + .text + .global sheet1 + .section .sbss, "aw", @nobits + .align 2 + .type sheet1, @object + .size sheet1, 1000000 +sheet1: + .zero 1000000 + .global sheet2 + .align 2 + .type sheet2, @object + .size sheet2, 1000000 +sheet2: + .zero 1000000 + .global width + .align 2 + .type width, @object + .size width, 4 +width: + .zero 4 + .global height + .align 2 + .type height, @object + .size height, 4 +height: + .zero 4 + .global steps + .align 2 + .type steps, @object + .size steps, 4 +steps: + .zero 4 + .global active + .section .sdata, "aw" + .align 2 + .type active, @object + .size active, 4 +active: + .word 1 + .text + .global main + .align 1 + .type read_map, @function +read_map: + addi sp, sp, -48 + sd s2, 0(sp) + sd ra, 8(sp) + sd s4, 16(sp) + sd s5, 24(sp) + sd s1, 32(sp) + sd s3, 40(sp) + la s5, width + call getint + sw a0, 0(s5) + la s4, height + li s1, 1 + call getint + sw a0, 0(s4) + la s3, steps + call getint + sw a0, 0(s3) + call getch + j L_1 + L_2: + li a0, 1 + j L_3 + L_4: + call getch + xori a0, a0, 35 + seqz a0, a0 + bne a0, x0, L_5 + la a0, sheet1 + add s5, a0, s3 + add a0, s5, s2 + sw x0, 0(a0) + j L_6 + L_5: + la a0, sheet1 + add s5, a0, s3 + add s2, s5, s2 + li s5, 1 + sw s5, 0(s2) + L_6: + mv a0, s1 + L_3: + la s5, width + addiw s1, a0, 1 + slliw s2, a0, 2 + lw s5, 0(s5) + slt s5, s5, a0 + seqz s5, s5 + bne s5, x0, L_4 + call getch + mv s1, s4 + L_1: + li s5, 2000 + la s2, height + addiw s4, s1, 1 + mulw s3, s1, s5 + lw s2, 0(s2) + slt s1, s2, s1 + seqz s1, s1 + bne s1, x0, L_2 + ld s2, 0(sp) + ld ra, 8(sp) + ld s4, 16(sp) + ld s5, 24(sp) + ld s1, 32(sp) + ld s3, 40(sp) + addi sp, sp, 48 + ret + .size read_map, .-read_map + .align 1 + .type put_map, @function +put_map: + addi sp, sp, -48 + sd s3, 0(sp) + sd s1, 8(sp) + sd s4, 16(sp) + sd ra, 24(sp) + sd s5, 32(sp) + sd s2, 40(sp) + li a0, 1 + j L_7 + L_8: + li s4, 1 + j L_9 + L_10: + la a0, sheet1 + add s4, a0, s5 + add s1, s4, s1 + lw a0, 0(s1) + xori a0, a0, 1 + seqz a0, a0 + bne a0, x0, L_11 + li a0, 46 + call putch + j L_12 + L_11: + li a0, 35 + call putch + L_12: + mv s4, s2 + L_9: + la a0, width + slliw s1, s4, 2 + addiw s2, s4, 1 + lw a0, 0(a0) + slt s4, a0, s4 + seqz s4, s4 + bne s4, x0, L_10 + li a0, 10 + call putch + mv a0, s3 + L_7: + la s4, height + addiw s3, a0, 1 + li s5, 2000 + lw s2, 0(s4) + mulw s5, a0, s5 + slt s2, s2, a0 + seqz s2, s2 + bne s2, x0, L_8 + ld s3, 0(sp) + ld s1, 8(sp) + ld s4, 16(sp) + ld ra, 24(sp) + ld s5, 32(sp) + ld s2, 40(sp) + addi sp, sp, 48 + ret + .size put_map, .-put_map + .align 1 + .type swap12, @function +swap12: + addi sp, sp, -48 + sd s2, 0(sp) + sd s6, 8(sp) + sd s4, 16(sp) + sd s1, 24(sp) + sd s5, 32(sp) + sd s3, 40(sp) + li s2, 1 + j L_13 + L_14: + li s1, 1 + j L_15 + L_16: + la s6, sheet2 + mv s1, s2 + add s6, s6, s3 + la s2, sheet1 + add s2, s2, s3 + add s6, s6, s5 + lw s6, 0(s6) + add s5, s2, s5 + sw s6, 0(s5) + L_15: + la s6, width + addiw s2, s1, 1 + slliw s5, s1, 2 + lw s6, 0(s6) + slt s1, s6, s1 + seqz s1, s1 + bne s1, x0, L_16 + mv s2, s4 + L_13: + li s3, 2000 + la s6, height + addiw s4, s2, 1 + mulw s3, s2, s3 + lw s5, 0(s6) + slt s2, s5, s2 + seqz s2, s2 + bne s2, x0, L_14 + ld s2, 0(sp) + ld s6, 8(sp) + ld s4, 16(sp) + ld s1, 24(sp) + ld s5, 32(sp) + ld s3, 40(sp) + addi sp, sp, 48 + ret + .size swap12, .-swap12 + .align 1 + .type step, @function +step: + addi sp, sp, -96 + sd s6, 0(sp) + sd s7, 8(sp) + sd s1, 16(sp) + sd s4, 24(sp) + sd s8, 32(sp) + sd s9, 40(sp) + sd s10, 48(sp) + sd s11, 56(sp) + sd s2, 64(sp) + sd s3, 72(sp) + sd s5, 80(sp) + li s4, 1 + j L_17 + L_18: + li a5, 1 + j L_19 + L_20: + lw s2, 0(s4) + lw s1, 0(s6) + addw s2, s2, s1 + lw s1, 0(s7) + addw s2, s2, s1 + lw s1, 0(s8) + addw s2, s2, s1 + lw s1, 0(s9) + addw s2, s2, s1 + lw s1, 0(s10) + addw s2, s2, s1 + lw s1, 0(s11) + addw s2, s2, s1 + lw s1, 0(a2) + addw s3, s2, s1 + lw s1, 0(a3) + xori s1, s1, 1 + seqz s1, s1 + xori s2, s3, 2 + seqz s2, s2 + bne s1, x0, L_21 + li s2, 0 + j L_22 + L_21: + L_22: + bne s2, x0, L_23 + xori s1, s3, 3 + seqz s1, s1 + bne s1, x0, L_24 + sw x0, 0(a4) + j L_25 + L_23: + li s1, 1 + sw s1, 0(a4) + L_26: + L_19: + la s1, width + lw s1, 0(s1) + slt s5, s1, a5 + seqz s5, s5 + addiw s1, a5, -1 + slliw s1, s1, 2 + add s4, t1, s1 + slliw s3, a5, 2 + add s6, t1, s3 + addiw a5, a5, 1 + slliw s2, a5, 2 + add s7, t1, s2 + add s8, a7, s1 + add s9, a7, s2 + add s10, t0, s1 + add s11, t0, s3 + add a2, t0, s2 + add a3, a7, s3 + add a4, a6, s3 + bne s5, x0, L_20 + mv s4, t2 + L_17: + addiw t2, s4, 1 + li s5, 2000 + li s2, 2000 + addiw s1, s4, -1 + li s3, 2000 + mulw s1, s1, s2 + add t1, a0, s1 + la s2, height + mulw s1, t2, s3 + add t0, a0, s1 + lw s1, 0(s2) + slt s2, s1, s4 + seqz s2, s2 + mulw s1, s4, s5 + add a7, a0, s1 + add a6, a1, s1 + bne s2, x0, L_18 + ld s6, 0(sp) + ld s7, 8(sp) + ld s1, 16(sp) + ld s4, 24(sp) + ld s8, 32(sp) + ld s9, 40(sp) + ld s10, 48(sp) + ld s11, 56(sp) + ld s2, 64(sp) + ld s3, 72(sp) + ld s5, 80(sp) + addi sp, sp, 96 + ret + L_24: + li s1, 1 + sw s1, 0(a4) + L_25: + j L_26 + .size step, .-step + .align 1 + .type main, @function +main: + addi sp, sp, -96 + sd s9, 0(sp) + sd s1, 8(sp) + sd s4, 16(sp) + sd s7, 24(sp) + sd ra, 32(sp) + sd s8, 40(sp) + sd s5, 48(sp) + sd s3, 56(sp) + sd s6, 64(sp) + sd s11, 72(sp) + sd s10, 80(sp) + sd s2, 88(sp) + call read_map + li a0, 95 + call _sysy_starttime + j L_27 + L_28: + la s1, active + lw s1, 0(s1) + xori s1, s1, 1 + seqz s1, s1 + bne s1, x0, L_29 + la t0, sheet2 + la t2, sheet1 + li s4, 1 + j L_30 + L_29: + la t0, sheet1 + la t2, sheet2 + li s5, 1 + j L_31 + L_32: + li s3, 1 + j L_33 + L_34: + li a3, 1 + j L_35 + L_36: + lw s5, 0(s4) + lw s1, 0(s1) + xori s11, s1, 1 + seqz s11, s11 + lw s3, 0(s6) + lw s4, 0(s7) + lw s2, 0(s8) + lw s1, 0(s9) + addw s1, s1, s2 + addw s2, s1, s3 + lw s1, 0(s10) + addw s1, s2, s1 + addw s2, s1, s5 + lw s1, 0(a0) + addw s1, s2, s1 + addw s2, s1, s4 + lw s1, 0(a1) + addw s2, s2, s1 + xori s1, s2, 2 + seqz s1, s1 + bne s11, x0, L_37 + li s1, 0 + j L_38 + L_37: + L_38: + bne s1, x0, L_39 + xori s1, s2, 3 + seqz s1, s1 + bne s1, x0, L_40 + sw x0, 0(a2) + j L_41 + L_39: + li s1, 1 + sw s1, 0(a2) + L_42: + L_35: + la s1, width + lw s1, 0(s1) + slt s5, s1, a3 + seqz s5, s5 + addiw s1, a3, -1 + slliw s1, s1, 2 + add s9, a4, s1 + slliw s3, a3, 2 + add s8, a4, s3 + addiw a3, a3, 1 + slliw s2, a3, 2 + add s6, a4, s2 + add s10, a6, s1 + add s4, a6, s2 + add a0, a5, s1 + add s7, a5, s3 + add a1, a5, s2 + add s1, a6, s3 + add a2, a7, s3 + bne s5, x0, L_36 + mv s5, t1 + L_31: + addiw t1, s5, 1 + addiw s3, s5, -1 + li s4, 2000 + li s1, 2000 + li s2, 2000 + mulw s1, s5, s1 + add a7, t2, s1 + add a6, t0, s1 + mulw s1, t1, s2 + add a5, t0, s1 + la s2, height + mulw s1, s3, s4 + add a4, t0, s1 + lw s1, 0(s2) + slt s1, s1, s5 + seqz s1, s1 + bne s1, x0, L_34 + la s2, active + li s1, 2 + sw s1, 0(s2) + L_43: + la s2, steps + la s1, steps + lw s1, 0(s1) + addiw s1, s1, -1 + sw s1, 0(s2) + L_27: + la s1, steps + lw s1, 0(s1) + slt s1, x0, s1 + bne s1, x0, L_28 + li a0, 106 + call _sysy_stoptime + la s1, active + lw s1, 0(s1) + xori s1, s1, 2 + seqz s1, s1 + bne s1, x0, L_32 + j L_44 + L_40: + li s1, 1 + sw s1, 0(a2) + L_41: + j L_42 + L_45: + li a3, 1 + j L_46 + L_47: + lw s2, 0(s4) + lw s1, 0(s6) + addw s2, s2, s1 + lw s1, 0(s7) + addw s2, s2, s1 + lw s1, 0(s8) + addw s2, s2, s1 + lw s1, 0(s9) + addw s2, s2, s1 + lw s1, 0(s10) + addw s2, s2, s1 + lw s1, 0(s11) + addw s2, s2, s1 + lw s1, 0(a0) + addw s3, s2, s1 + lw s1, 0(a1) + xori s2, s1, 1 + seqz s2, s2 + xori s1, s3, 2 + seqz s1, s1 + bne s2, x0, L_48 + li s1, 0 + j L_49 + L_48: + L_49: + bne s1, x0, L_50 + xori s1, s3, 3 + seqz s1, s1 + bne s1, x0, L_51 + sw x0, 0(a2) + j L_52 + L_50: + li s1, 1 + sw s1, 0(a2) + L_53: + L_46: + la s1, width + lw s1, 0(s1) + slt s5, s1, a3 + seqz s5, s5 + addiw s1, a3, -1 + slliw s1, s1, 2 + add s4, a7, s1 + slliw s3, a3, 2 + add s6, a7, s3 + addiw a3, a3, 1 + slliw s2, a3, 2 + add s7, a7, s2 + add s8, a5, s1 + add s9, a5, s2 + add s10, a4, s1 + add s11, a4, s3 + add a0, a4, s2 + add a1, a5, s3 + add a2, a6, s3 + bne s5, x0, L_47 + mv s4, t1 + L_30: + li s5, 2000 + li s2, 2000 + li s3, 2000 + addiw t1, s4, 1 + addiw s1, s4, -1 + mulw s1, s1, s2 + add a7, t0, s1 + la s2, height + mulw s1, s4, s3 + add a6, t2, s1 + add a5, t0, s1 + lw s1, 0(s2) + slt s2, s1, s4 + mulw s1, t1, s5 + seqz s2, s2 + add a4, t0, s1 + bne s2, x0, L_45 + la s2, active + li s1, 1 + sw s1, 0(s2) + j L_43 + L_51: + li s1, 1 + sw s1, 0(a2) + L_52: + j L_53 + L_54: + li s2, 1 + j L_55 + L_56: + la s1, sheet1 + add s1, s1, s5 + add s2, s1, s3 + la s1, sheet2 + add s1, s1, s5 + add s1, s1, s3 + lw s1, 0(s1) + sw s1, 0(s2) + mv s2, s4 + L_55: + la s1, width + addiw s4, s2, 1 + slliw s3, s2, 2 + lw s1, 0(s1) + slt s1, s1, s2 + seqz s1, s1 + bne s1, x0, L_56 + mv s3, s6 + L_33: + li s1, 2000 + la s2, height + addiw s6, s3, 1 + mulw s5, s3, s1 + lw s1, 0(s2) + slt s1, s1, s3 + seqz s1, s1 + bne s1, x0, L_54 + L_44: + call put_map + li a0, 0 + ld s9, 0(sp) + ld s1, 8(sp) + ld s4, 16(sp) + ld s7, 24(sp) + ld ra, 32(sp) + ld s8, 40(sp) + ld s5, 48(sp) + ld s3, 56(sp) + ld s6, 64(sp) + ld s11, 72(sp) + ld s10, 80(sp) + ld s2, 88(sp) + addi sp, sp, 96 + ret + .size main, .-main + + +.text +.global __create_threads +.global __join_threads + + SYS_clone = 220 + CLONE_VM = 256 + SIGCHLD = 17 + __create_threads: + addi a0, a0, -1 + ble a0, zero, .ret_0 + mv a6, a0 + li a5, 0 + mv a1, sp + li a2, 0 + li a3, 0 + li a4, 0 + .L0_builtin: + li a0, (CLONE_VM | SIGCHLD) + li a7, SYS_clone + ecall + bne a0, zero, .ret_i + addi a5, a5, 1 + blt a5, a6, .L0_builtin + .ret_n: + mv a0, a6 + j .L1_builtin + .ret_0: + mv a0, zero + j .L1_builtin + .ret_i: + mv a0, a5 + .L1_builtin: + jr ra + + SYS_waitid = 95 + SYS_exit = 93 + P_ALL = 0 + WEXITED = 4 + __join_threads: + mv a4, a0 + addi a5, a1, -1 + beq a4, a5, .L2_builtin + li a0, P_ALL + li a1, 0 + li a2, 0 + li a3, WEXITED + li a7, SYS_waitid + ecall + .L2_builtin: + beq a4, zero, .L3_builtin + li a0, 0 + li a7, SYS_exit + ecall + .L3_builtin: + jr ra + + + __fill_zero_words: + ble a1, zero, .L8_builtin + addi a1, a1, -1 + slliw a1, a1, 2 + add a2, a1, a0 # 最后一次4字节 + addi a3, a2, -1 + andi a3, a3, -8 # 最后一次8字节 + andi a4, a0, 7 + beq a4, x0, .L4_builtin + + sw x0, 0(a0) + addi a0, a0, 4 + + .L4_builtin: + bgtu a0, a3, .L7_builtin + + .L5_builtin: + sd x0, 0(a0) + addi a0, a0, 8 + ble a0, a3, .L5_builtin + + .L7_builtin: + bgtu a0, a2, .L8_builtin # 如果不够最后一次4字节 + sw x0, 0(a0) + addi a0, a0, 4 + + .L8_builtin: + jr ra + + + + + .ident "SYSYC: (made by RRVM) 1.0.0" diff --git a/test b/test new file mode 100644 index 00000000..7e4c3cb2 --- /dev/null +++ b/test @@ -0,0 +1,85 @@ + .file "./project-eval/testcases/performance/recursive_call_3.sy" + .option nopic + .attribute unaligned_access, 0 + .attribute stack_align, 16 + .text + .text + .global main + .align 1 + .type func_calc_coef, @function +func_calc_coef: + addi sp, sp, -16 + sd s1, 0(sp) + sd ra, 8(sp) + mv s1, a0 + li a0, 0 + blt a1, a0, L_1 + addiw a1, a1, -1 + mv a0, s1 + call func_calc_coef + flw fa3, 0(s1) + li a0, 4 + add a1, s1, a0 + flw fa2, 0(a1) + li a0, 1065353216 + addi a0, a0, 0 + fmv.w.x fa1, a0 + fadd.s fa0, fa1, fa3 + fmul.s fa1, fa3, fa0 + fmul.s fa3, fa3, fa2 + fadd.s fa3, fa2, fa3 + fsub.s fa1, fa0, fa1 + fsub.s fa2, fa2, fa3 + fsw fa1, 0(s1) + fsw fa2, 0(a1) + j L_2 + L_1: + li a1, 4 + add a0, s1, a1 + fmv.w.x fa0, x0 + fsw fa0, 0(a0) + fsw fa0, 0(s1) + L_2: + ld s1, 0(sp) + ld ra, 8(sp) + addi sp, sp, 16 + ret + .size func_calc_coef, .-func_calc_coef + .align 1 + .type main, @function +main: + addi sp, sp, -16 + sd ra, 0(sp) + sd s1, 8(sp) + call getint + mv a1, a0 + addi sp, sp, -16 + mv s1, sp + mv a0, s1 + call func_calc_coef + flw fa1, 0(s1) + li a1, 1065361408 + addi a0, a1, 197 + fmv.w.x fa2, a0 + fmul.s fa1, fa1, fa2 + li a1, 4 + add a1, s1, a1 + flw fa0, 0(a1) + fadd.s fa0, fa1, fa0 + fsub.s fa1, fa0, fa2 + fmv.w.x fa2, x0 + feq.s s1, fa1, fa2 + bne s1, x0, L_3 + j L_4 + L_3: + li a0, 112 + call putch + L_4: + addi sp, sp, 16 + mv a0, x0 + ld ra, 0(sp) + ld s1, 8(sp) + addi sp, sp, 16 + ret + .size main, .-main + .ident "SYSYC: (made by RRVM) 1.0.0" diff --git a/utils/instruction/src/riscv/riscvinstr.rs b/utils/instruction/src/riscv/riscvinstr.rs index cad4af3a..a0a1ef11 100644 --- a/utils/instruction/src/riscv/riscvinstr.rs +++ b/utils/instruction/src/riscv/riscvinstr.rs @@ -1,4 +1,8 @@ -use std::{collections::HashMap, fmt::Display}; +use std::{ + collections::HashMap, + fmt::Display, + ops::{Add, Sub}, +}; use sysyc_derive::UseTemp; use utils::{mapper::LabelMapper, InstrTrait, Label, UseTemp, RTN}; @@ -38,7 +42,6 @@ impl Clone for RiscvInstr { self.clone_box() } } - pub trait RiscvInstrTrait: Display + UseTemp + CloneRiscvInstr + RTN { From 0ea48ff841693f17b619f171300bc8ea3dbf5562 Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Sun, 11 Aug 2024 16:30:03 +0800 Subject: [PATCH 2/8] chore: fix warnings --- backend/transform/src/instr_schedule.rs | 115 ++-- backend/transform/src/instrdag.rs | 57 +- backend/transform/src/lib.rs | 77 +-- out.txt | 730 ------------------------ 4 files changed, 51 insertions(+), 928 deletions(-) delete mode 100644 out.txt diff --git a/backend/transform/src/instr_schedule.rs b/backend/transform/src/instr_schedule.rs index 3f4d3497..c0ec4e13 100644 --- a/backend/transform/src/instr_schedule.rs +++ b/backend/transform/src/instr_schedule.rs @@ -1,20 +1,17 @@ use std::{ - cell::RefCell, cmp::{max, min}, collections::{HashMap, VecDeque}, fmt::Display, - rc::Rc, }; use crate::{ - instrdag::{postprocess_call, InstrDag, InstrNode}, + instrdag::{postprocess_call, InstrDag}, Liveliness, RiscvInstr, }; use instruction::{ riscv::{ prelude::RiscvInstrTrait, - reg::RiscvReg::A0, - value::RiscvTemp::{self, PhysReg}, + value::RiscvTemp::{self}, }, RiscvInstrSet, }; @@ -24,7 +21,6 @@ use utils::{ SUM_MIN_RATIO, }; -type Node = Rc>; #[derive(Clone, PartialEq, Eq, Copy, Debug)] enum AluKind { Mem, @@ -151,12 +147,12 @@ fn punishment( let mut succ_sum = 0; let mut succ_min = 0; for i in dag.nodes[instr_id].borrow().succ.iter() { - let mut my_succ_reads = Vec::new(); + let my_succ_reads={ if i.borrow().instr.is_call() { - my_succ_reads = dag.call_reads[state.call_ids.len()].clone(); + dag.call_reads[state.call_ids.len()].clone() } else { - my_succ_reads = i.borrow().instr.get_riscv_read().clone(); - } + i.borrow().instr.get_riscv_read().clone() + }}; succ_sum += my_succ_reads .iter() .map(|x| { @@ -176,17 +172,17 @@ fn punishment( succ_min, ); // 对 write 寄存器的情况考虑如上 - let mut my_succ_writes = Vec::new(); + let my_succ_writes = { if i.borrow().instr.is_call() { - my_succ_writes = if let Some(tmp) = dag.call_writes[state.call_ids.len()] + if let Some(tmp) = dag.call_writes[state.call_ids.len()] { vec![tmp] } else { Vec::new() - }; + } } else { - my_succ_writes = i.borrow().instr.get_riscv_write().clone(); - } + i.borrow().instr.get_riscv_write().clone() + }}; succ_sum += my_succ_writes .iter() .map(|x| { @@ -243,12 +239,12 @@ fn punishment( } } else { // 从 alus[4],alus[5] 拿出 complete_time 更小的来考虑 - flight_idx = (if state.alus[4].complete_cycle < state.alus[5].complete_cycle + flight_idx = if state.alus[4].complete_cycle < state.alus[5].complete_cycle { 4 } else { 5 - }); + }; flight_unit = Alu::new(state.alus[flight_idx].kind); if state.alus[flight_idx].complete_cycle > ready_time { flight_time_incre = @@ -258,28 +254,14 @@ fn punishment( state.flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; } let time_incre = max(flight_unit.complete_cycle, old_max) - old_max; - // println!("------------"); - // println!(" in punishment calculation:"); - // for i in state.instrs.iter(){ - // println!("{}",i); - // } - // println!("alu status:"); - // for (idx,i) in state.alus.iter().enumerate(){ - // if idx==flight_idx{ - // println!("{:?}",flight_unit); - // }else{ - // println!("{:?}",i); - // } - // } - // println!("time_incre: {} flight_time_incre: {} flight_idx: {}",time_incre,flight_time_incre,flight_idx); - // println!("------------------"); + succ_score += succ_min as i32; score = score * REDUCE_LIVE + alloc_score * ADD_ALLOCATABLES + end_live_score * NEAR_END + succ_score * REDUCE_SUB + time_incre as i32 * HARDWARE_PIPELINE_PARAM; - //println!("punishment: {} flight_time_incre: {} flight_idx: {} flight_unit: {:?}",score,flight_time_incre,flight_idx,flight_unit); + (score, flight_time_incre, flight_idx, flight_unit) } #[derive(Clone)] @@ -325,9 +307,8 @@ pub fn get_punishment_by_instrs(instr: &Vec>) -> i32 { for instr in instr.iter() { let mut flight_time_incre = 1; let ready_time = flight_time + flight_time_incre; - let old_max = alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); if get_alukind(instr) != AluKind::Normal { - for (idx, alu) in alus.iter_mut().enumerate() { + for alu in alus.iter_mut() { if get_alukind(instr) == alu.kind { if alu.complete_cycle > ready_time { flight_time_incre = alu.complete_cycle - ready_time + 1; @@ -344,11 +325,11 @@ pub fn get_punishment_by_instrs(instr: &Vec>) -> i32 { } } } else { - let flight_idx = (if alus[4].complete_cycle < alus[5].complete_cycle { + let flight_idx = if alus[4].complete_cycle < alus[5].complete_cycle { 4 } else { 5 - }); + }; if alus[flight_idx].complete_cycle > ready_time { flight_time_incre = alus[flight_idx].complete_cycle - ready_time + 1; } @@ -360,14 +341,12 @@ pub fn get_punishment_by_instrs(instr: &Vec>) -> i32 { let t = alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); t as i32 * HARDWARE_PIPELINE_PARAM } -// 咱想想怎么设计:改动: // 1. 先不去 clone state,对于每个可以分配的 instruction 把 instr 先 push 再 pop 最后把 pop_front 得到的 State 再 push 回去 // 2. 每一步的计算保留以下4个参数:total_punishment,state_idx,node_id,my_reads 最后根据 total_punishment 排序并且把前 BFS_STATE_THRESHOLD 给 push 进去 pub fn instr_schedule_by_dag( dag: InstrDag, liveliness_map: HashMap, ) -> Result { - // println!("{}",dag); // 计算原始 punishment let original_instrs: Vec<_> = dag.nodes.iter().rev().map(|x| x.borrow().instr.clone()).collect(); @@ -406,17 +385,11 @@ pub fn instr_schedule_by_dag( .filter(|(_k, v)| **v == 0) .map(|(k, _)| *k) .collect(); - // println!("allocatables: {:?} _i: {:?} _j: {:?} ", allocatables,_i,_j); - // println!("state instrs:"); - // for i in state.instrs.iter() { - // println!("{}", i); - // } for i in allocatables.iter() { - //let mut new_state = state.clone(); state.instrs.push(dag.nodes[*i].borrow().instr.clone()); // get riscv reads and writes - let mut my_reads = Vec::new(); - let mut my_writes = Vec::new(); + let my_reads; + let my_writes; if dag.nodes[*i].borrow().instr.is_call() { //check state's call_id length my_reads = dag.call_reads[state.call_ids.len()].clone(); @@ -437,19 +410,13 @@ pub fn instr_schedule_by_dag( } states.push_back(state); } - // debug print keeps if keeps.len() > BFS_STATE_THRESHOLD { keeps.sort_by(|a, b| a.2.cmp(&b.2)); - // println!("keeps: "); - // for entry in keeps.iter() { - // println!("{:?} {}", entry,dag.nodes[entry.1].borrow().instr); - // } - // println!("======= end keeps ======"); keeps.truncate(BFS_STATE_THRESHOLD); } for i in 0..real_cnt { // iterate the keeps - let mut cnts: Vec<_> = + let cnts: Vec<_> = keeps.iter().filter(|x| x.0 == i).map(|x| *x).collect(); if cnts.len() == 0 { states.pop_front(); @@ -460,13 +427,12 @@ pub fn instr_schedule_by_dag( state.call_ids.push(cnts[0].1); } // calc my_reads - let mut my_reads = Vec::new(); + let my_reads = { if state.instrs.last().unwrap().is_call() { - my_reads = dag.call_reads[state.call_ids.len() - 1].clone(); + dag.call_reads[state.call_ids.len() - 1].clone() } else { - my_reads = - dag.nodes[cnts[0].1].borrow().instr.get_riscv_read().clone(); - } + dag.nodes[cnts[0].1].borrow().instr.get_riscv_read().clone() + }}; // decl the use in new_state's liveliness_map for i in my_reads.iter() { state.liveliness_map.get_mut(i).unwrap().use_num -= 1; @@ -493,13 +459,12 @@ pub fn instr_schedule_by_dag( new_state.call_ids.push(cnts[j].1); } // calc my_reads - let mut my_reads = Vec::new(); + let my_reads = { if new_state.instrs.last().unwrap().is_call() { - my_reads = dag.call_reads[new_state.call_ids.len() - 1].clone(); + dag.call_reads[new_state.call_ids.len() - 1].clone() } else { - my_reads = - dag.nodes[cnts[j].1].borrow().instr.get_riscv_read().clone(); - } + dag.nodes[cnts[j].1].borrow().instr.get_riscv_read().clone() + }}; // decl the use in new_state's liveliness_map for i in my_reads.iter() { new_state.liveliness_map.get_mut(i).unwrap().use_num -= 1; @@ -526,16 +491,16 @@ pub fn instr_schedule_by_dag( state.call_ids.push(cnts[cnts.len() - 1].1); } // calc my_reads - let mut my_reads = Vec::new(); + let my_reads={ if state.instrs.last().unwrap().is_call() { - my_reads = dag.call_reads[state.call_ids.len() - 1].clone(); + dag.call_reads[state.call_ids.len() - 1].clone() } else { - my_reads = dag.nodes[cnts[cnts.len() - 1].1] + dag.nodes[cnts[cnts.len() - 1].1] .borrow() .instr .get_riscv_read() - .clone(); - } + .clone() + }}; // decl the use in new_state's liveliness_map for i in my_reads.iter() { state.liveliness_map.get_mut(i).unwrap().use_num -= 1; @@ -556,23 +521,11 @@ pub fn instr_schedule_by_dag( } } } - // for i in states.iter() { - // println!("final state instructions:"); - // for j in i.instrs.iter() { - // println!("{}", j); - // } - // } // state 排序 states.make_contiguous().sort_by(|a, b| a.score.cmp(&b.score)); let mut final_state = states.pop_front().unwrap(); - // println!("final state instructions:"); - // for i in final_state.instrs.iter() { - // println!("{}", i); - // } if final_state.score >= original_punishment { final_state.instrs = original_instrs; - } else { - // println!("original punishment: {} final punishment: {}",original_punishment,final_state.score); } Ok(postprocess_call( final_state.instrs, diff --git a/backend/transform/src/instrdag.rs b/backend/transform/src/instrdag.rs index 7177804f..94441984 100644 --- a/backend/transform/src/instrdag.rs +++ b/backend/transform/src/instrdag.rs @@ -73,7 +73,6 @@ fn preprocess_call( if i.get_riscv_read().len() == 1 { if let RiscvTemp::PhysReg(A0) = i.get_riscv_read()[0] { my_call_related.push(i.clone()); - //call_write.push(Some(i.get_riscv_write()[0])); push_this = true; continue; } else { @@ -147,14 +146,9 @@ pub fn postprocess_call( if let Some(instr) = branch_related { my_instrs.push(instr); } - // debug print - // println!("postprocess call instrs:"); - // for i in my_instrs.iter() { - // println!("{}", i); - // } - // println!("---------------postprocess call instrs end---------------------"); my_instrs } + impl InstrDag { pub fn new(node: &RiscvNode) -> Result { let mut nodes: Vec = Vec::new(); @@ -172,21 +166,15 @@ impl InstrDag { let mut call_related_map = HashMap::new(); let mut call_instrs: Vec>> = Vec::new(); let mut my_call_write = None; - let mut ret_call_writes = Vec::new(); - let mut ret_call_reads = Vec::new(); // preprocessing call related: 把 call 前后的 从 save 到 restore 的若干条指令保存在 call_related 里面,然后加入到 is_filtered_idx 之后遍历instrs 的时候遇到就直接continue - // println!("original instrs :"); - // for i in node.borrow().instrs.iter() { - // println!("{}", i); - // } let mut processed_instrs = preprocess_call( node, &mut call_related, &mut call_write, &mut call_reads, ); - ret_call_writes = call_write.clone(); - ret_call_reads = call_reads.clone(); + let ret_call_writes = call_write.clone(); + let ret_call_reads = call_reads.clone(); if processed_instrs.len() > 0 { let last_instr = processed_instrs.last().unwrap(); if last_instr.is_branch() { @@ -194,23 +182,6 @@ impl InstrDag { let _ = processed_instrs.pop(); } } - // println!("call read temps: {:?}", call_reads); - // println!("call related instructions:"); - // for i in call_related.iter() { - // for j in i.iter() { - // println!("{}", j); - // }println!("----"); - // } - // for i in call_related.iter(){ - // for j in i.iter(){ - // if j.is_call(){ - // println!("get riscv read: {:?}",j.get_riscv_read()); - // println!("get riscv write: {:?}",j.get_riscv_write()); - // println!("call write: {:?}",call_write); - // println!("-----------"); - // } - // } - // } // 传参 call 回去 param read 会需要记录 for i in call_related.iter() { let mut riscv_writes = HashSet::new(); @@ -219,18 +190,8 @@ impl InstrDag { riscv_writes.extend(j.get_riscv_write().iter().cloned()); riscv_reads.extend(j.get_riscv_read().iter().cloned()); } - // println!("for total call related instructions: riscvreads {:?}",riscv_reads); - // println!("for total call related instructions: riscvwrites {:?}",riscv_writes); - // println!("------------"); } - // println!("processed_instrs len: {}",processed_instrs.len()); - // for i in processed_instrs.iter() { - // println!("{}",i); - // } for (idx, instr) in processed_instrs.iter().rev().enumerate() { - // println!("instr id:{} {}",instr, idx); - // println!("instr read: {:?}",instr.get_riscv_read()); - // println!("instr write: {:?}",instr.get_riscv_write()); let node = Rc::new(RefCell::new(InstrNode::new(instr, idx))); if idx == 0 { if instr.get_riscv_write().len() == 1 @@ -246,10 +207,6 @@ impl InstrDag { instr_node_succ.extend( uses.get(&instr_write).unwrap_or(&Vec::new()).iter().cloned(), ); - // println!("in instr {} write extending..",node.borrow().id); - // for i in uses.get(&instr_write).unwrap_or(&Vec::>>::new()).iter().map(|z| z.borrow().id).collect::>() { - // println!("intr write extending to id: {}", i); - // } // 同时 extend predecessors for i in uses.get(&instr_write).unwrap_or(&Vec::new()).iter() { i.borrow_mut().pred.push(node.clone()); @@ -324,12 +281,9 @@ impl InstrDag { i.borrow_mut().pred.push(node.clone()); } call_instrs.push(node.clone()); - // for i in nodes.iter() { - // instr_node_succ.push(i.clone()); - // } + } else if instr.is_load().unwrap_or(false) { if let Some(last_call) = last_call.clone() { - // println!("in is_load {} extending last_call {}",node.borrow().id,last_call.borrow().id); instr_node_succ.push(last_call.clone()); last_call.borrow_mut().pred.push(node.clone()); } @@ -351,9 +305,6 @@ impl InstrDag { nodes.push(node); } for node in nodes.iter() { - // println!("node id: {}", node.borrow().id); - // println!("node successors: {:?}", node.borrow().succ.iter().map(|s| s.borrow().id).collect::>()); - // println!("---------"); for succ in node.borrow().succ.iter() { succ.borrow_mut().in_deg += 1; } diff --git a/backend/transform/src/lib.rs b/backend/transform/src/lib.rs index 665ed29c..4c813d1e 100644 --- a/backend/transform/src/lib.rs +++ b/backend/transform/src/lib.rs @@ -1,31 +1,21 @@ use std::{ cell::RefCell, - collections::{BTreeMap, HashMap, HashSet}, + collections::{HashMap, HashSet}, io::{self, Write}, rc::Rc, }; -<<<<<<< HEAD -======= -use instr_schedule::instr_schedule_by_dag; -use instrdag::InstrDag; ->>>>>>> f67bb86 (feat: instruction scheduling by hardware pipelining) use instruction::{riscv::prelude::*, temp::TempManager}; use llvm::Value; -use utils::{SysycError::RiscvGenError}; +use utils::SysycError::RiscvGenError; use rrvm::prelude::*; -use transformer::{to_riscv, to_rt_type}; +use transformer::to_riscv; use utils::{ - errors::Result, BLOCKSIZE_THRESHOLD, DEPENDENCY_EXPLORE_DEPTH, + errors::Result, SCHEDULE_THRESHOLD, }; -<<<<<<< HEAD -======= -pub mod instr_schedule; -pub mod instrdag; ->>>>>>> f67bb86 (feat: instruction scheduling by hardware pipelining) pub mod remove_phi; pub mod transformer; @@ -36,7 +26,7 @@ pub fn get_functions( for func in funcs { let converted_func = convert_func(func, &mut program.temp_mgr)?; println!("--- before instr schedule: ---"); - for i in converted_func.0.cfg.blocks.iter() { + for i in converted_func.cfg.blocks.iter() { for j in i.borrow().instrs.iter() { println!("{}", j); } @@ -49,9 +39,7 @@ pub fn get_functions( println!("---end---"); io::stdout().flush().unwrap(); let func = instr_schedule( - converted_func.0, - converted_func.1, - converted_func.2, + converted_func, &mut program.temp_mgr, )?; println!("--------"); @@ -69,12 +57,12 @@ pub fn get_functions( pub fn instr_schedule( func: RiscvFunc, - live_ins: Vec>, - live_outs: Vec>, mgr: &mut TempManager, ) -> Result { func.cfg.clear_data_flow(); func.cfg.analysis(); + let live_ins=[]; + let live_outs=[]; let mut new_blocks = Vec::new(); for (idx, node) in func.cfg.blocks.iter().enumerate() { let nodes = @@ -99,39 +87,13 @@ pub fn instr_schedule_block( if riscv_node.borrow().instrs.len() >= SCHEDULE_THRESHOLD { return Ok(vec![riscv_node.clone()]); } - let prev = riscv_node - .borrow() - .prev - .iter() - .map(|v| v.borrow().id) - .collect::>(); - let succ = riscv_node - .borrow() - .succ - .iter() - .map(|v| v.borrow().id) - .collect::>(); - // 判断 prev 和 succ 是否有交集 - if prev.intersection(&succ).count() > 0 - && riscv_node.borrow().instrs.len() <= BLOCKSIZE_THRESHOLD - { - // filter call (instrs 中不能有 call 指令) - if riscv_node.borrow().instrs.iter().any(|instr| instr.is_call()) { - transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) + transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) .map(|v| vec![v]) - } else { - transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) - .map(|v| vec![v]) - } - } else { - transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) - .map(|v| vec![v]) - } } pub fn convert_func( func: LlvmFunc, mgr: &mut TempManager, -) -> Result<(RiscvFunc, Vec>, Vec>)> { +) -> Result { let mut nodes = Vec::new(); let mut edge = Vec::new(); let mut table = HashMap::new(); @@ -182,7 +144,7 @@ pub fn convert_func( for (u, v) in edge { force_link_node(table.get(&u).unwrap(), table.get(&v).unwrap()) } -<<<<<<< HEAD + Ok(RiscvFunc { total: mgr.total, spills: 0, @@ -191,21 +153,8 @@ pub fn convert_func( params: func.params, ret_type: func.ret_type, }) -======= - Ok(( - RiscvFunc { - total: mgr.total, - spills: 0, - cfg: RiscvCFG { blocks: nodes }, - name: func.name, - params: func.params, - ret_type: func.ret_type, - }, - live_ins, - live_outs, - )) ->>>>>>> f67bb86 (feat: instruction scheduling by hardware pipelining) -} + } + fn transform_basic_block_by_pipelining( node: &RiscvNode, diff --git a/out.txt b/out.txt deleted file mode 100644 index cc2b845a..00000000 --- a/out.txt +++ /dev/null @@ -1,730 +0,0 @@ -original punishment: 15 final punishment: 14 -original punishment: 12 final punishment: 8 -original punishment: 9 final punishment: 7 -original punishment: 12 final punishment: 8 -original punishment: 9 final punishment: 7 -original punishment: 12 final punishment: 8 -original punishment: 9 final punishment: 7 -original punishment: 12 final punishment: 11 -original punishment: 19 final punishment: 17 -original punishment: 19 final punishment: 17 -original punishment: 30 final punishment: 29 -original punishment: 19 final punishment: 17 -original punishment: 12 final punishment: 8 -original punishment: 9 final punishment: 7 - .file "./project-eval/testcases/performance/gameoflife-p61glidergun.sy" - .option nopic - .attribute unaligned_access, 0 - .attribute stack_align, 16 - .text - .global sheet1 - .section .sbss, "aw", @nobits - .align 2 - .type sheet1, @object - .size sheet1, 1000000 -sheet1: - .zero 1000000 - .global sheet2 - .align 2 - .type sheet2, @object - .size sheet2, 1000000 -sheet2: - .zero 1000000 - .global width - .align 2 - .type width, @object - .size width, 4 -width: - .zero 4 - .global height - .align 2 - .type height, @object - .size height, 4 -height: - .zero 4 - .global steps - .align 2 - .type steps, @object - .size steps, 4 -steps: - .zero 4 - .global active - .section .sdata, "aw" - .align 2 - .type active, @object - .size active, 4 -active: - .word 1 - .text - .global main - .align 1 - .type read_map, @function -read_map: - addi sp, sp, -48 - sd s2, 0(sp) - sd ra, 8(sp) - sd s4, 16(sp) - sd s5, 24(sp) - sd s1, 32(sp) - sd s3, 40(sp) - la s5, width - call getint - sw a0, 0(s5) - la s4, height - li s1, 1 - call getint - sw a0, 0(s4) - la s3, steps - call getint - sw a0, 0(s3) - call getch - j L_1 - L_2: - li a0, 1 - j L_3 - L_4: - call getch - xori a0, a0, 35 - seqz a0, a0 - bne a0, x0, L_5 - la a0, sheet1 - add s5, a0, s3 - add a0, s5, s2 - sw x0, 0(a0) - j L_6 - L_5: - la a0, sheet1 - add s5, a0, s3 - add s2, s5, s2 - li s5, 1 - sw s5, 0(s2) - L_6: - mv a0, s1 - L_3: - la s5, width - addiw s1, a0, 1 - slliw s2, a0, 2 - lw s5, 0(s5) - slt s5, s5, a0 - seqz s5, s5 - bne s5, x0, L_4 - call getch - mv s1, s4 - L_1: - li s5, 2000 - la s2, height - addiw s4, s1, 1 - mulw s3, s1, s5 - lw s2, 0(s2) - slt s1, s2, s1 - seqz s1, s1 - bne s1, x0, L_2 - ld s2, 0(sp) - ld ra, 8(sp) - ld s4, 16(sp) - ld s5, 24(sp) - ld s1, 32(sp) - ld s3, 40(sp) - addi sp, sp, 48 - ret - .size read_map, .-read_map - .align 1 - .type put_map, @function -put_map: - addi sp, sp, -48 - sd s3, 0(sp) - sd s1, 8(sp) - sd s4, 16(sp) - sd ra, 24(sp) - sd s5, 32(sp) - sd s2, 40(sp) - li a0, 1 - j L_7 - L_8: - li s4, 1 - j L_9 - L_10: - la a0, sheet1 - add s4, a0, s5 - add s1, s4, s1 - lw a0, 0(s1) - xori a0, a0, 1 - seqz a0, a0 - bne a0, x0, L_11 - li a0, 46 - call putch - j L_12 - L_11: - li a0, 35 - call putch - L_12: - mv s4, s2 - L_9: - la a0, width - slliw s1, s4, 2 - addiw s2, s4, 1 - lw a0, 0(a0) - slt s4, a0, s4 - seqz s4, s4 - bne s4, x0, L_10 - li a0, 10 - call putch - mv a0, s3 - L_7: - la s4, height - addiw s3, a0, 1 - li s5, 2000 - lw s2, 0(s4) - mulw s5, a0, s5 - slt s2, s2, a0 - seqz s2, s2 - bne s2, x0, L_8 - ld s3, 0(sp) - ld s1, 8(sp) - ld s4, 16(sp) - ld ra, 24(sp) - ld s5, 32(sp) - ld s2, 40(sp) - addi sp, sp, 48 - ret - .size put_map, .-put_map - .align 1 - .type swap12, @function -swap12: - addi sp, sp, -48 - sd s2, 0(sp) - sd s6, 8(sp) - sd s4, 16(sp) - sd s1, 24(sp) - sd s5, 32(sp) - sd s3, 40(sp) - li s2, 1 - j L_13 - L_14: - li s1, 1 - j L_15 - L_16: - la s6, sheet2 - mv s1, s2 - add s6, s6, s3 - la s2, sheet1 - add s2, s2, s3 - add s6, s6, s5 - lw s6, 0(s6) - add s5, s2, s5 - sw s6, 0(s5) - L_15: - la s6, width - addiw s2, s1, 1 - slliw s5, s1, 2 - lw s6, 0(s6) - slt s1, s6, s1 - seqz s1, s1 - bne s1, x0, L_16 - mv s2, s4 - L_13: - li s3, 2000 - la s6, height - addiw s4, s2, 1 - mulw s3, s2, s3 - lw s5, 0(s6) - slt s2, s5, s2 - seqz s2, s2 - bne s2, x0, L_14 - ld s2, 0(sp) - ld s6, 8(sp) - ld s4, 16(sp) - ld s1, 24(sp) - ld s5, 32(sp) - ld s3, 40(sp) - addi sp, sp, 48 - ret - .size swap12, .-swap12 - .align 1 - .type step, @function -step: - addi sp, sp, -96 - sd s6, 0(sp) - sd s7, 8(sp) - sd s1, 16(sp) - sd s4, 24(sp) - sd s8, 32(sp) - sd s9, 40(sp) - sd s10, 48(sp) - sd s11, 56(sp) - sd s2, 64(sp) - sd s3, 72(sp) - sd s5, 80(sp) - li s4, 1 - j L_17 - L_18: - li a5, 1 - j L_19 - L_20: - lw s2, 0(s4) - lw s1, 0(s6) - addw s2, s2, s1 - lw s1, 0(s7) - addw s2, s2, s1 - lw s1, 0(s8) - addw s2, s2, s1 - lw s1, 0(s9) - addw s2, s2, s1 - lw s1, 0(s10) - addw s2, s2, s1 - lw s1, 0(s11) - addw s2, s2, s1 - lw s1, 0(a2) - addw s3, s2, s1 - lw s1, 0(a3) - xori s1, s1, 1 - seqz s1, s1 - xori s2, s3, 2 - seqz s2, s2 - bne s1, x0, L_21 - li s2, 0 - j L_22 - L_21: - L_22: - bne s2, x0, L_23 - xori s1, s3, 3 - seqz s1, s1 - bne s1, x0, L_24 - sw x0, 0(a4) - j L_25 - L_23: - li s1, 1 - sw s1, 0(a4) - L_26: - L_19: - la s1, width - lw s1, 0(s1) - slt s5, s1, a5 - seqz s5, s5 - addiw s1, a5, -1 - slliw s1, s1, 2 - add s4, t1, s1 - slliw s3, a5, 2 - add s6, t1, s3 - addiw a5, a5, 1 - slliw s2, a5, 2 - add s7, t1, s2 - add s8, a7, s1 - add s9, a7, s2 - add s10, t0, s1 - add s11, t0, s3 - add a2, t0, s2 - add a3, a7, s3 - add a4, a6, s3 - bne s5, x0, L_20 - mv s4, t2 - L_17: - addiw t2, s4, 1 - li s5, 2000 - li s2, 2000 - addiw s1, s4, -1 - li s3, 2000 - mulw s1, s1, s2 - add t1, a0, s1 - la s2, height - mulw s1, t2, s3 - add t0, a0, s1 - lw s1, 0(s2) - slt s2, s1, s4 - seqz s2, s2 - mulw s1, s4, s5 - add a7, a0, s1 - add a6, a1, s1 - bne s2, x0, L_18 - ld s6, 0(sp) - ld s7, 8(sp) - ld s1, 16(sp) - ld s4, 24(sp) - ld s8, 32(sp) - ld s9, 40(sp) - ld s10, 48(sp) - ld s11, 56(sp) - ld s2, 64(sp) - ld s3, 72(sp) - ld s5, 80(sp) - addi sp, sp, 96 - ret - L_24: - li s1, 1 - sw s1, 0(a4) - L_25: - j L_26 - .size step, .-step - .align 1 - .type main, @function -main: - addi sp, sp, -96 - sd s9, 0(sp) - sd s1, 8(sp) - sd s4, 16(sp) - sd s7, 24(sp) - sd ra, 32(sp) - sd s8, 40(sp) - sd s5, 48(sp) - sd s3, 56(sp) - sd s6, 64(sp) - sd s11, 72(sp) - sd s10, 80(sp) - sd s2, 88(sp) - call read_map - li a0, 95 - call _sysy_starttime - j L_27 - L_28: - la s1, active - lw s1, 0(s1) - xori s1, s1, 1 - seqz s1, s1 - bne s1, x0, L_29 - la t0, sheet2 - la t2, sheet1 - li s4, 1 - j L_30 - L_29: - la t0, sheet1 - la t2, sheet2 - li s5, 1 - j L_31 - L_32: - li s3, 1 - j L_33 - L_34: - li a3, 1 - j L_35 - L_36: - lw s5, 0(s4) - lw s1, 0(s1) - xori s11, s1, 1 - seqz s11, s11 - lw s3, 0(s6) - lw s4, 0(s7) - lw s2, 0(s8) - lw s1, 0(s9) - addw s1, s1, s2 - addw s2, s1, s3 - lw s1, 0(s10) - addw s1, s2, s1 - addw s2, s1, s5 - lw s1, 0(a0) - addw s1, s2, s1 - addw s2, s1, s4 - lw s1, 0(a1) - addw s2, s2, s1 - xori s1, s2, 2 - seqz s1, s1 - bne s11, x0, L_37 - li s1, 0 - j L_38 - L_37: - L_38: - bne s1, x0, L_39 - xori s1, s2, 3 - seqz s1, s1 - bne s1, x0, L_40 - sw x0, 0(a2) - j L_41 - L_39: - li s1, 1 - sw s1, 0(a2) - L_42: - L_35: - la s1, width - lw s1, 0(s1) - slt s5, s1, a3 - seqz s5, s5 - addiw s1, a3, -1 - slliw s1, s1, 2 - add s9, a4, s1 - slliw s3, a3, 2 - add s8, a4, s3 - addiw a3, a3, 1 - slliw s2, a3, 2 - add s6, a4, s2 - add s10, a6, s1 - add s4, a6, s2 - add a0, a5, s1 - add s7, a5, s3 - add a1, a5, s2 - add s1, a6, s3 - add a2, a7, s3 - bne s5, x0, L_36 - mv s5, t1 - L_31: - addiw t1, s5, 1 - addiw s3, s5, -1 - li s4, 2000 - li s1, 2000 - li s2, 2000 - mulw s1, s5, s1 - add a7, t2, s1 - add a6, t0, s1 - mulw s1, t1, s2 - add a5, t0, s1 - la s2, height - mulw s1, s3, s4 - add a4, t0, s1 - lw s1, 0(s2) - slt s1, s1, s5 - seqz s1, s1 - bne s1, x0, L_34 - la s2, active - li s1, 2 - sw s1, 0(s2) - L_43: - la s2, steps - la s1, steps - lw s1, 0(s1) - addiw s1, s1, -1 - sw s1, 0(s2) - L_27: - la s1, steps - lw s1, 0(s1) - slt s1, x0, s1 - bne s1, x0, L_28 - li a0, 106 - call _sysy_stoptime - la s1, active - lw s1, 0(s1) - xori s1, s1, 2 - seqz s1, s1 - bne s1, x0, L_32 - j L_44 - L_40: - li s1, 1 - sw s1, 0(a2) - L_41: - j L_42 - L_45: - li a3, 1 - j L_46 - L_47: - lw s2, 0(s4) - lw s1, 0(s6) - addw s2, s2, s1 - lw s1, 0(s7) - addw s2, s2, s1 - lw s1, 0(s8) - addw s2, s2, s1 - lw s1, 0(s9) - addw s2, s2, s1 - lw s1, 0(s10) - addw s2, s2, s1 - lw s1, 0(s11) - addw s2, s2, s1 - lw s1, 0(a0) - addw s3, s2, s1 - lw s1, 0(a1) - xori s2, s1, 1 - seqz s2, s2 - xori s1, s3, 2 - seqz s1, s1 - bne s2, x0, L_48 - li s1, 0 - j L_49 - L_48: - L_49: - bne s1, x0, L_50 - xori s1, s3, 3 - seqz s1, s1 - bne s1, x0, L_51 - sw x0, 0(a2) - j L_52 - L_50: - li s1, 1 - sw s1, 0(a2) - L_53: - L_46: - la s1, width - lw s1, 0(s1) - slt s5, s1, a3 - seqz s5, s5 - addiw s1, a3, -1 - slliw s1, s1, 2 - add s4, a7, s1 - slliw s3, a3, 2 - add s6, a7, s3 - addiw a3, a3, 1 - slliw s2, a3, 2 - add s7, a7, s2 - add s8, a5, s1 - add s9, a5, s2 - add s10, a4, s1 - add s11, a4, s3 - add a0, a4, s2 - add a1, a5, s3 - add a2, a6, s3 - bne s5, x0, L_47 - mv s4, t1 - L_30: - li s5, 2000 - li s2, 2000 - li s3, 2000 - addiw t1, s4, 1 - addiw s1, s4, -1 - mulw s1, s1, s2 - add a7, t0, s1 - la s2, height - mulw s1, s4, s3 - add a6, t2, s1 - add a5, t0, s1 - lw s1, 0(s2) - slt s2, s1, s4 - mulw s1, t1, s5 - seqz s2, s2 - add a4, t0, s1 - bne s2, x0, L_45 - la s2, active - li s1, 1 - sw s1, 0(s2) - j L_43 - L_51: - li s1, 1 - sw s1, 0(a2) - L_52: - j L_53 - L_54: - li s2, 1 - j L_55 - L_56: - la s1, sheet1 - add s1, s1, s5 - add s2, s1, s3 - la s1, sheet2 - add s1, s1, s5 - add s1, s1, s3 - lw s1, 0(s1) - sw s1, 0(s2) - mv s2, s4 - L_55: - la s1, width - addiw s4, s2, 1 - slliw s3, s2, 2 - lw s1, 0(s1) - slt s1, s1, s2 - seqz s1, s1 - bne s1, x0, L_56 - mv s3, s6 - L_33: - li s1, 2000 - la s2, height - addiw s6, s3, 1 - mulw s5, s3, s1 - lw s1, 0(s2) - slt s1, s1, s3 - seqz s1, s1 - bne s1, x0, L_54 - L_44: - call put_map - li a0, 0 - ld s9, 0(sp) - ld s1, 8(sp) - ld s4, 16(sp) - ld s7, 24(sp) - ld ra, 32(sp) - ld s8, 40(sp) - ld s5, 48(sp) - ld s3, 56(sp) - ld s6, 64(sp) - ld s11, 72(sp) - ld s10, 80(sp) - ld s2, 88(sp) - addi sp, sp, 96 - ret - .size main, .-main - - -.text -.global __create_threads -.global __join_threads - - SYS_clone = 220 - CLONE_VM = 256 - SIGCHLD = 17 - __create_threads: - addi a0, a0, -1 - ble a0, zero, .ret_0 - mv a6, a0 - li a5, 0 - mv a1, sp - li a2, 0 - li a3, 0 - li a4, 0 - .L0_builtin: - li a0, (CLONE_VM | SIGCHLD) - li a7, SYS_clone - ecall - bne a0, zero, .ret_i - addi a5, a5, 1 - blt a5, a6, .L0_builtin - .ret_n: - mv a0, a6 - j .L1_builtin - .ret_0: - mv a0, zero - j .L1_builtin - .ret_i: - mv a0, a5 - .L1_builtin: - jr ra - - SYS_waitid = 95 - SYS_exit = 93 - P_ALL = 0 - WEXITED = 4 - __join_threads: - mv a4, a0 - addi a5, a1, -1 - beq a4, a5, .L2_builtin - li a0, P_ALL - li a1, 0 - li a2, 0 - li a3, WEXITED - li a7, SYS_waitid - ecall - .L2_builtin: - beq a4, zero, .L3_builtin - li a0, 0 - li a7, SYS_exit - ecall - .L3_builtin: - jr ra - - - __fill_zero_words: - ble a1, zero, .L8_builtin - addi a1, a1, -1 - slliw a1, a1, 2 - add a2, a1, a0 # 最后一次4字节 - addi a3, a2, -1 - andi a3, a3, -8 # 最后一次8字节 - andi a4, a0, 7 - beq a4, x0, .L4_builtin - - sw x0, 0(a0) - addi a0, a0, 4 - - .L4_builtin: - bgtu a0, a3, .L7_builtin - - .L5_builtin: - sd x0, 0(a0) - addi a0, a0, 8 - ble a0, a3, .L5_builtin - - .L7_builtin: - bgtu a0, a2, .L8_builtin # 如果不够最后一次4字节 - sw x0, 0(a0) - addi a0, a0, 4 - - .L8_builtin: - jr ra - - - - - .ident "SYSYC: (made by RRVM) 1.0.0" From 60287e93a1cffa9dc9147d3ecc543b0e93ce356b Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Tue, 13 Aug 2024 11:22:14 +0800 Subject: [PATCH 3/8] refa: refactor instruction scheduling --- .../instruction_scheduling/instr_schedule.rs | 3 + .../src/instruction_scheduling/instrdag.rs | 61 ++ backend/transform/src/instr_schedule.rs | 536 ------------------ backend/transform/src/instrdag.rs | 391 ------------- backend/transform/src/lib.rs | 174 +----- backend/transform/src/transformer.rs | 9 +- 6 files changed, 69 insertions(+), 1105 deletions(-) delete mode 100644 backend/transform/src/instr_schedule.rs delete mode 100644 backend/transform/src/instrdag.rs diff --git a/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs b/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs index 721a2cc4..4538ff7d 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs @@ -549,6 +549,9 @@ pub fn instr_schedule_by_dag( &mut dag.call_related.clone(), // 是我call的顺序可能会调换,post_process 的时候和原本push进去的顺序不一致 dag.branch.clone(), &mut final_state.call_ids.clone(), +<<<<<<< HEAD dag.li_ret.clone(), +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) )) } diff --git a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs index 1cc0c006..1ea55e1b 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs @@ -6,7 +6,11 @@ use std::{ }; use instruction::riscv::{ +<<<<<<< HEAD reg::RiscvReg::{Fa0, A0, SP}, +======= + reg::RiscvReg::{A0, SP}, +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) riscvinstr::RiscvInstrTrait, value::RiscvTemp, RiscvInstr, @@ -47,14 +51,20 @@ pub struct InstrDag { pub branch: Option>, pub call_writes: Vec>, pub call_reads: Vec>, +<<<<<<< HEAD pub li_ret: Option>, +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) } fn preprocess_call( node: &RiscvNode, call_related: &mut Vec>>, // 换成一个 hashmap 用建完图之后的 node id 来索引 call_write: &mut Vec>, call_reads: &mut Vec>, +<<<<<<< HEAD li_ret: &mut Option>, +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) ) -> Vec> { let mut instrs = Vec::new(); let mut save_instr = false; @@ -77,10 +87,13 @@ fn preprocess_call( my_call_related.push(i.clone()); push_this = true; continue; +<<<<<<< HEAD } else if let RiscvTemp::PhysReg(Fa0) = i.get_riscv_read()[0] { my_call_related.push(i.clone()); push_this = true; continue; +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) } else { call_write.push(None); } @@ -131,6 +144,7 @@ fn preprocess_call( } call_reads.push(riscv_reads.iter().cloned().collect()); } +<<<<<<< HEAD // 判断最后一条 if !instrs.is_empty() { let last_instr = instrs.pop().unwrap(); @@ -152,6 +166,8 @@ fn preprocess_call( instrs.push(last_instr); } } +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) instrs } pub fn postprocess_call( @@ -159,7 +175,10 @@ pub fn postprocess_call( call_related: &mut HashMap>>, branch_related: Option>, call_idxs: &mut Vec, +<<<<<<< HEAD li_ret: Option>, +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) ) -> Vec> { let mut my_instrs = Vec::new(); for i in instrs { @@ -171,9 +190,12 @@ pub fn postprocess_call( my_instrs.push(i); } } +<<<<<<< HEAD if let Some(instr) = li_ret { my_instrs.push(instr); } +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) if let Some(instr) = branch_related { my_instrs.push(instr); } @@ -203,7 +225,10 @@ impl InstrDag { &mut call_related, &mut call_write, &mut call_reads, +<<<<<<< HEAD &mut li_ret, +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) ); let ret_call_writes = call_write.clone(); let ret_call_reads = call_reads.clone(); @@ -225,6 +250,24 @@ impl InstrDag { } for (idx, instr) in processed_instrs.iter().rev().enumerate() { let node = Rc::new(RefCell::new(InstrNode::new(instr, idx))); +<<<<<<< HEAD +======= + if idx == 0 + && instr.get_riscv_write().len() == 1 + && (instr.get_riscv_write()[0] == RiscvTemp::PhysReg(A0) + || if let RiscvTemp::VirtReg(t) = &instr.get_riscv_write()[0] { + if let Some(pre) = t.pre_color { + pre == A0 + } else { + false + } + } else { + false + }) { + li_ret = Some(node.clone()); + } + +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) let mut instr_node_succ = Vec::new(); let instructions_write = instr.get_riscv_write().clone(); if !instr.is_call() { @@ -295,18 +338,33 @@ impl InstrDag { }); last_loads.clear(); last_call = Some(node.clone()); +<<<<<<< HEAD +======= + if let Some(ret_node) = li_ret.clone() { + instr_node_succ.push(ret_node.clone()); + ret_node.borrow_mut().pred.push(node.clone()); + } +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) for i in call_instrs.iter() { instr_node_succ.push(i.clone()); i.borrow_mut().pred.push(node.clone()); } call_instrs.push(node.clone()); +<<<<<<< HEAD } else if instr.is_load() { +======= + } else if instr.is_load().unwrap_or(false) { +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) if let Some(last_call) = last_call.clone() { instr_node_succ.push(last_call.clone()); last_call.borrow_mut().pred.push(node.clone()); } last_loads.push(node.clone()); +<<<<<<< HEAD } else if instr.is_store() { +======= + } else if instr.is_store().unwrap_or(false) { +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) instr_node_succ.extend(last_loads.iter().cloned()); last_loads.iter().for_each(|x| { x.borrow_mut().pred.push(node.clone()); @@ -343,7 +401,10 @@ impl InstrDag { branch: last_branch, call_reads: ret_call_reads, call_writes: ret_call_writes, +<<<<<<< HEAD li_ret, +======= +>>>>>>> 15ca5b3 (refa: refactor instruction scheduling) }) } pub fn assign_nodes(&mut self) { diff --git a/backend/transform/src/instr_schedule.rs b/backend/transform/src/instr_schedule.rs deleted file mode 100644 index c0ec4e13..00000000 --- a/backend/transform/src/instr_schedule.rs +++ /dev/null @@ -1,536 +0,0 @@ -use std::{ - cmp::{max, min}, - collections::{HashMap, VecDeque}, - fmt::Display, -}; - -use crate::{ - instrdag::{postprocess_call, InstrDag}, - Liveliness, RiscvInstr, -}; -use instruction::{ - riscv::{ - prelude::RiscvInstrTrait, - value::RiscvTemp::{self}, - }, - RiscvInstrSet, -}; -use utils::{ - SysycError, ADD_ALLOCATABLES, BFS_STATE_THRESHOLD, HARDWARE_PIPELINE_PARAM, - LIVE_THROUGH, NEAR_END, REDUCE_LIVE, REDUCE_SUB, SOFTWARE_PIPELINE_PARAM, - SUM_MIN_RATIO, -}; - -#[derive(Clone, PartialEq, Eq, Copy, Debug)] -enum AluKind { - Mem, - Normal, - Branch, - Float, - MulDiv, -} -#[derive(Clone, Copy, Debug)] -pub struct Alu { - kind: AluKind, - complete_cycle: usize, - is_fdiv: bool, -} -impl Alu { - fn new(kind: AluKind) -> Self { - Self { - kind, - complete_cycle: 0, // 开区间 - is_fdiv: false, - } - } -} -fn get_alukind(instr: &RiscvInstr) -> AluKind { - let v = instr.get_rtn_array(); - // println!("get_alukind: {} {:?}",instr,v); - if v[0] != 0 { - AluKind::Mem - } else if v[1] != 0 { - AluKind::Branch - } else if v[2] != 0 { - AluKind::MulDiv - } else if v[3] != 0 { - AluKind::Float - } else { - AluKind::Normal - } -} -// 当前惩罚策略:在指令为 instrs 的情况下,在运行每一条指令期间活跃的最大寄存器数目 -// 接受参数:dag:初始图,instrs:当前的指令序列,基本块内 SSA -// 实现硬件流水线的时候,要多返回一个 flight_time_increment -fn punishment( - dag: &InstrDag, - state: &State, - instr_id: usize, - my_reads: Vec, - my_writes: Vec, -) -> (i32, usize, usize, Alu) { - let instr = state.instrs.last().unwrap(); - let mut score = 0; - // 软件流水线的惩罚 - score += - (dag.nodes[instr_id].borrow().to_end as i32) * SOFTWARE_PIPELINE_PARAM; - for i in my_reads.iter() { - if state.liveliness_map.get(i).unwrap().use_num == 1 - && !state.liveliness_map.get(i).unwrap().is_liveout - { - score -= 1; - } - } - for i in my_writes.iter() { - if !state.liveliness_map.get(i).unwrap().is_livein { - score += 1; - } - } - // 判断选择这条指令之后,有多少节点可以变成可调度节点 - let new_allocatables = dag.nodes[instr_id] - .borrow() - .succ - .iter() - .filter(|x| state.indegs[&x.borrow().id] == 1) - .count(); - let alloc_score = -(new_allocatables as i32) * ADD_ALLOCATABLES; - // 判断使得寄存器生命周期尽快结束的惩罚,一方面可以判断 read/write 的寄存器的尽快结束之和,另一方面可以判断 read/write 的寄存器最小离结束的次数,这一段 read 和 write 都是加,是没问题的 - // 思考 live_through 这个参数定义了没用,该怎么用上 - let mut sum_uses: usize = my_reads - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_liveout { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .sum(); - let mut min_uses: usize = my_reads - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_liveout { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .min() - .unwrap_or(0); - sum_uses += my_writes - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_livein { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .sum::(); - min_uses = min( - my_writes - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_livein { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .min() - .unwrap_or(0), - min_uses, - ); - let mut end_live_score = (sum_uses as i32) * SUM_MIN_RATIO; - end_live_score += min_uses as i32; - // 判断对后继的影响 - let mut succ_sum = 0; - let mut succ_min = 0; - for i in dag.nodes[instr_id].borrow().succ.iter() { - let my_succ_reads={ - if i.borrow().instr.is_call() { - dag.call_reads[state.call_ids.len()].clone() - } else { - i.borrow().instr.get_riscv_read().clone() - }}; - succ_sum += my_succ_reads - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_liveout { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .sum::(); - succ_min = min( - my_succ_reads - .iter() - .map(|x| state.liveliness_map.get(x).unwrap().use_num) - .min() - .unwrap_or(0), - succ_min, - ); - // 对 write 寄存器的情况考虑如上 - let my_succ_writes = { - if i.borrow().instr.is_call() { - if let Some(tmp) = dag.call_writes[state.call_ids.len()] - { - vec![tmp] - } else { - Vec::new() - } - } else { - i.borrow().instr.get_riscv_write().clone() - }}; - succ_sum += my_succ_writes - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_livein { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .sum::(); - succ_min = min( - my_succ_writes - .iter() - .map(|x| { - if state.liveliness_map.get(x).unwrap().is_livein { - state.liveliness_map.get(x).unwrap().use_num + LIVE_THROUGH - } else { - state.liveliness_map.get(x).unwrap().use_num - } - }) - .min() - .unwrap_or(0), - succ_min, - ); - } - let mut succ_score = (succ_sum as i32) * SUM_MIN_RATIO; - // 算硬件流水线的惩罚 - let mut flight_time_incre = 1; - let ready_time = state.flight_time + flight_time_incre; - let mut flight_idx = 0; - let mut flight_unit = Alu::new(AluKind::Normal); - let old_max = state.alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); - // 增量,认为第一条指令在时刻1发射 - if get_alukind(instr) != AluKind::Normal { - for (idx, alu) in state.alus.iter().enumerate() { - if get_alukind(instr) == alu.kind { - if alu.complete_cycle > ready_time { - // wait - flight_time_incre = alu.complete_cycle - ready_time + 1; - } - flight_idx = idx; - flight_unit = Alu::new(alu.kind); - if instr.is_fdiv() { - flight_unit.is_fdiv = true; - } - flight_unit.complete_cycle = state.flight_time - + flight_time_incre - + instr.get_rtn_array()[4] as usize; - if instr.is_fdiv() && alu.is_fdiv { - flight_unit.complete_cycle += utils::FDIV_WAIT; - } - break; - } - } - } else { - // 从 alus[4],alus[5] 拿出 complete_time 更小的来考虑 - flight_idx = if state.alus[4].complete_cycle < state.alus[5].complete_cycle - { - 4 - } else { - 5 - }; - flight_unit = Alu::new(state.alus[flight_idx].kind); - if state.alus[flight_idx].complete_cycle > ready_time { - flight_time_incre = - state.alus[flight_idx].complete_cycle - ready_time + 1; - } - flight_unit.complete_cycle = - state.flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; - } - let time_incre = max(flight_unit.complete_cycle, old_max) - old_max; - - succ_score += succ_min as i32; - score = score * REDUCE_LIVE - + alloc_score * ADD_ALLOCATABLES - + end_live_score * NEAR_END - + succ_score * REDUCE_SUB - + time_incre as i32 * HARDWARE_PIPELINE_PARAM; - - (score, flight_time_incre, flight_idx, flight_unit) -} -#[derive(Clone)] -struct State { - instrs: RiscvInstrSet, - score: i32, - indegs: HashMap, // 把节点的 id 映射到入度 - liveliness_map: HashMap, - call_ids: Vec, - alus: [Alu; 6], - flight_time: usize, -} -impl Display for State { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "State: \n")?; - for i in self.instrs.iter() { - write!(f, "{}\n", i)?; - } - write!(f, "alus: \n")?; - for i in self.alus.iter() { - write!(f, "{:?} ", i)?; - } - write!( - f, - "score: {} flight_time: {}\n", - self.score, self.flight_time - )?; - Ok(()) - } -} -pub fn get_punishment_by_instrs(instr: &Vec>) -> i32 { - // 算出原始的 score - // 按照上面的方法算硬件流水线 - let mut alus = [ - Alu::new(AluKind::Mem), - Alu::new(AluKind::Branch), - Alu::new(AluKind::MulDiv), - Alu::new(AluKind::Float), - Alu::new(AluKind::Normal), - Alu::new(AluKind::Normal), - ]; - let mut flight_time = 0; - for instr in instr.iter() { - let mut flight_time_incre = 1; - let ready_time = flight_time + flight_time_incre; - if get_alukind(instr) != AluKind::Normal { - for alu in alus.iter_mut() { - if get_alukind(instr) == alu.kind { - if alu.complete_cycle > ready_time { - flight_time_incre = alu.complete_cycle - ready_time + 1; - } - if instr.is_fdiv() { - alu.is_fdiv = true; - } - alu.complete_cycle = - flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; - if instr.is_fdiv() && alu.is_fdiv { - alu.complete_cycle += utils::FDIV_WAIT; - } - break; - } - } - } else { - let flight_idx = if alus[4].complete_cycle < alus[5].complete_cycle { - 4 - } else { - 5 - }; - if alus[flight_idx].complete_cycle > ready_time { - flight_time_incre = alus[flight_idx].complete_cycle - ready_time + 1; - } - alus[flight_idx].complete_cycle = - flight_time + flight_time_incre + instr.get_rtn_array()[4] as usize; - } - flight_time += flight_time_incre; - } - let t = alus.iter().map(|x| x.complete_cycle).max().unwrap_or(0); - t as i32 * HARDWARE_PIPELINE_PARAM -} -// 1. 先不去 clone state,对于每个可以分配的 instruction 把 instr 先 push 再 pop 最后把 pop_front 得到的 State 再 push 回去 -// 2. 每一步的计算保留以下4个参数:total_punishment,state_idx,node_id,my_reads 最后根据 total_punishment 排序并且把前 BFS_STATE_THRESHOLD 给 push 进去 -pub fn instr_schedule_by_dag( - dag: InstrDag, - liveliness_map: HashMap, -) -> Result { - // 计算原始 punishment - let original_instrs: Vec<_> = - dag.nodes.iter().rev().map(|x| x.borrow().instr.clone()).collect(); - let original_punishment = get_punishment_by_instrs(&original_instrs); - let mut states = VecDeque::new(); - // calculate indegs - let mut indegs = HashMap::new(); - for node in dag.nodes.iter() { - indegs.insert(node.borrow().id, node.borrow().in_deg); - } - states.push_back(State { - instrs: Vec::new(), - score: 0, - indegs: indegs.clone(), - liveliness_map, - call_ids: Vec::new(), - alus: [ - Alu::new(AluKind::Mem), - Alu::new(AluKind::Branch), - Alu::new(AluKind::MulDiv), - Alu::new(AluKind::Float), - Alu::new(AluKind::Normal), - Alu::new(AluKind::Normal), - ], - flight_time: 0, - }); - let depth = dag.nodes.len(); // bfs 深度已知,是所需要调度的指令总数 - for _i in 0..depth { - let real_cnt = states.len(); - let mut keeps = Vec::new(); - for j in 0..real_cnt { - let mut state = states.pop_front().unwrap(); - let allocatables: Vec<_> = state - .indegs - .iter() - .filter(|(_k, v)| **v == 0) - .map(|(k, _)| *k) - .collect(); - for i in allocatables.iter() { - state.instrs.push(dag.nodes[*i].borrow().instr.clone()); - // get riscv reads and writes - let my_reads; - let my_writes; - if dag.nodes[*i].borrow().instr.is_call() { - //check state's call_id length - my_reads = dag.call_reads[state.call_ids.len()].clone(); - my_writes = if let Some(tmp) = dag.call_writes[state.call_ids.len()] { - vec![tmp] - } else { - Vec::new() - }; - } else { - my_reads = dag.nodes[*i].borrow().instr.get_riscv_read().clone(); - my_writes = dag.nodes[*i].borrow().instr.get_riscv_write().clone(); - } - let (punish, flight_time_incre, flight_idx, flight_unit) = - punishment(&dag, &state, *i, my_reads.clone(), my_writes.clone()); - let score = state.score + punish; - keeps.push((j, *i, score, flight_time_incre, flight_idx, flight_unit)); - state.instrs.pop(); - } - states.push_back(state); - } - if keeps.len() > BFS_STATE_THRESHOLD { - keeps.sort_by(|a, b| a.2.cmp(&b.2)); - keeps.truncate(BFS_STATE_THRESHOLD); - } - for i in 0..real_cnt { - // iterate the keeps - let cnts: Vec<_> = - keeps.iter().filter(|x| x.0 == i).map(|x| *x).collect(); - if cnts.len() == 0 { - states.pop_front(); - } else if cnts.len() == 1 { - let mut state = states.pop_front().unwrap(); - state.instrs.push(dag.nodes[cnts[0].1].borrow().instr.clone()); - if dag.nodes[cnts[0].1].borrow().instr.is_call() { - state.call_ids.push(cnts[0].1); - } - // calc my_reads - let my_reads = { - if state.instrs.last().unwrap().is_call() { - dag.call_reads[state.call_ids.len() - 1].clone() - } else { - dag.nodes[cnts[0].1].borrow().instr.get_riscv_read().clone() - }}; - // decl the use in new_state's liveliness_map - for i in my_reads.iter() { - state.liveliness_map.get_mut(i).unwrap().use_num -= 1; - } - state.indegs.remove(&cnts[0].1); - for succ in dag.nodes[cnts[0].1].borrow().succ.iter() { - let mut new_indeg = state.indegs.clone(); - new_indeg.insert( - succ.borrow().id, - new_indeg.get(&succ.borrow().id).unwrap() - 1, - ); - state.indegs = new_indeg; - } - state.flight_time += cnts[0].3; - state.alus[cnts[0].4] = cnts[0].5; - state.score = cnts[0].2; - states.push_back(state); - } else { - let mut state = states.pop_front().unwrap(); - for j in 0..cnts.len() - 1 { - let mut new_state = state.clone(); - new_state.instrs.push(dag.nodes[cnts[j].1].borrow().instr.clone()); - if dag.nodes[cnts[j].1].borrow().instr.is_call() { - new_state.call_ids.push(cnts[j].1); - } - // calc my_reads - let my_reads = { - if new_state.instrs.last().unwrap().is_call() { - dag.call_reads[new_state.call_ids.len() - 1].clone() - } else { - dag.nodes[cnts[j].1].borrow().instr.get_riscv_read().clone() - }}; - // decl the use in new_state's liveliness_map - for i in my_reads.iter() { - new_state.liveliness_map.get_mut(i).unwrap().use_num -= 1; - } - new_state.indegs.remove(&cnts[j].1); - for succ in dag.nodes[cnts[j].1].borrow().succ.iter() { - let mut new_indeg = new_state.indegs.clone(); - new_indeg.insert( - succ.borrow().id, - new_indeg.get(&succ.borrow().id).unwrap() - 1, - ); - new_state.indegs = new_indeg; - } - new_state.flight_time += cnts[j].3; - new_state.alus[cnts[j].4] = cnts[j].5; - new_state.score = cnts[j].2; - states.push_back(new_state); - } - // 最后一次不 clone 了 - state - .instrs - .push(dag.nodes[cnts[cnts.len() - 1].1].borrow().instr.clone()); - if dag.nodes[cnts[cnts.len() - 1].1].borrow().instr.is_call() { - state.call_ids.push(cnts[cnts.len() - 1].1); - } - // calc my_reads - let my_reads={ - if state.instrs.last().unwrap().is_call() { - dag.call_reads[state.call_ids.len() - 1].clone() - } else { - dag.nodes[cnts[cnts.len() - 1].1] - .borrow() - .instr - .get_riscv_read() - .clone() - }}; - // decl the use in new_state's liveliness_map - for i in my_reads.iter() { - state.liveliness_map.get_mut(i).unwrap().use_num -= 1; - } - state.indegs.remove(&cnts[cnts.len() - 1].1); - for succ in dag.nodes[cnts[cnts.len() - 1].1].borrow().succ.iter() { - let mut new_indeg = state.indegs.clone(); - new_indeg.insert( - succ.borrow().id, - new_indeg.get(&succ.borrow().id).unwrap() - 1, - ); - state.indegs = new_indeg; - } - state.flight_time += cnts[cnts.len() - 1].3; - state.alus[cnts[cnts.len() - 1].4] = cnts[cnts.len() - 1].5; - state.score = cnts[cnts.len() - 1].2; - states.push_back(state); - } - } - } - // state 排序 - states.make_contiguous().sort_by(|a, b| a.score.cmp(&b.score)); - let mut final_state = states.pop_front().unwrap(); - if final_state.score >= original_punishment { - final_state.instrs = original_instrs; - } - Ok(postprocess_call( - final_state.instrs, - &mut dag.call_related.clone(), // 是我call的顺序可能会调换,post_process 的时候和原本push进去的顺序不一致 - dag.branch.clone(), - &mut final_state.call_ids.clone(), - )) -} diff --git a/backend/transform/src/instrdag.rs b/backend/transform/src/instrdag.rs deleted file mode 100644 index 94441984..00000000 --- a/backend/transform/src/instrdag.rs +++ /dev/null @@ -1,391 +0,0 @@ -use std::{ - cell::RefCell, - cmp::max, - collections::{HashMap, HashSet}, - rc::Rc, -}; - -use instruction::riscv::{ - reg::RiscvReg::{A0, SP}, - riscvinstr::RiscvInstrTrait, - value::RiscvTemp, - RiscvInstr, -}; -use rrvm::RiscvNode; -use std::fmt; -use utils::SysycError; - -type Node = Rc>; -#[derive(Clone)] -pub struct InstrNode { - pub id: usize, - pub in_deg: usize, - pub instr: RiscvInstr, - pub succ: Vec, - pub last_use: usize, - pub pred: Vec, - pub to_end: usize, -} -impl InstrNode { - pub fn new(instr: &RiscvInstr, id: usize) -> Self { - Self { - id, - in_deg: 0, - instr: instr.clone(), - succ: Vec::new(), - last_use: 0, - pred: Vec::new(), - to_end: 0, - } - } -} - -#[derive(Clone)] -pub struct InstrDag { - pub nodes: Vec, - pub call_related: HashMap>>, - pub branch: Option>, - pub call_writes: Vec>, - pub call_reads: Vec>, -} -fn preprocess_call( - node: &RiscvNode, - call_related: &mut Vec>>, // 换成一个 hashmap 用建完图之后的 node id 来索引 - call_write: &mut Vec>, - call_reads: &mut Vec>, -) -> Vec> { - let mut instrs = Vec::new(); - let mut save_instr = false; - let mut my_call_related = Vec::new(); - let mut is_last_restore = false; - let mut push_this = false; - for (idx, i) in node.borrow().instrs.iter().enumerate() { - if push_this { - push_this = false; - my_call_related.push(i.clone()); - call_write.push(Some(i.get_riscv_write()[0])); - call_related.push(my_call_related); - my_call_related = Vec::new(); - continue; - } - if is_last_restore { - is_last_restore = false; - if i.get_riscv_read().len() == 1 { - if let RiscvTemp::PhysReg(A0) = i.get_riscv_read()[0] { - my_call_related.push(i.clone()); - push_this = true; - continue; - } else { - call_write.push(None); - } - } else { - call_write.push(None); - } - call_related.push(my_call_related); - my_call_related = Vec::new(); - } - if i.is_save() { - save_instr = true; - my_call_related.push(i.clone()); - } else if i.is_restore() { - save_instr = false; - my_call_related.push(i.clone()); - is_last_restore = true; - if idx == node.borrow().instrs.len() - 1 { - call_related.push(my_call_related); - call_write.push(None); - break; - } - } else if i.is_call() { - instrs.push(i.clone()); - my_call_related.push(i.clone()); - } else if save_instr { - my_call_related.push(i.clone()); - } else { - instrs.push(i.clone()); - } - } - // process call writes and call reads - for call_instrs in call_related.iter() { - // 获取所有 instr 中的riscv_reads 的并集 - let mut riscv_reads = HashSet::new(); - // 先把 SP 扔进 riscv_reads - riscv_reads.insert(RiscvTemp::PhysReg(SP)); - for instr in call_instrs.iter() { - riscv_reads.extend(instr.get_riscv_read().iter().cloned()); - } - // 在 riscv_read 中删除 call 指令前传 param 的时候写的寄存器 - for instr in call_instrs.iter() { - if instr.is_call() { - break; - } - for i in instr.get_riscv_write().iter() { - riscv_reads.remove(i); - } - } - call_reads.push(riscv_reads.iter().cloned().collect()); - } - instrs -} -pub fn postprocess_call( - instrs: Vec>, - call_related: &mut HashMap>>, - branch_related: Option>, - call_idxs: &mut Vec, -) -> Vec> { - let mut my_instrs = Vec::new(); - for i in instrs { - if i.is_call() { - my_instrs.append( - &mut call_related.get(&call_idxs.pop().unwrap()).unwrap().clone(), - ); - } else { - my_instrs.push(i); - } - } - if let Some(instr) = branch_related { - my_instrs.push(instr); - } - my_instrs -} - -impl InstrDag { - pub fn new(node: &RiscvNode) -> Result { - let mut nodes: Vec = Vec::new(); - let mut defs: HashMap>> = HashMap::new(); - let mut uses: HashMap>>> = - HashMap::new(); - let mut last_call: Option = None; - let mut last_loads: Vec = Vec::new(); - let mut call_related = Vec::new(); - let mut last_uses = HashMap::new(); - let mut last_branch: Option> = None; - let mut call_write = Vec::new(); - let mut call_reads = Vec::new(); - let mut li_ret = None; - let mut call_related_map = HashMap::new(); - let mut call_instrs: Vec>> = Vec::new(); - let mut my_call_write = None; - // preprocessing call related: 把 call 前后的 从 save 到 restore 的若干条指令保存在 call_related 里面,然后加入到 is_filtered_idx 之后遍历instrs 的时候遇到就直接continue - let mut processed_instrs = preprocess_call( - node, - &mut call_related, - &mut call_write, - &mut call_reads, - ); - let ret_call_writes = call_write.clone(); - let ret_call_reads = call_reads.clone(); - if processed_instrs.len() > 0 { - let last_instr = processed_instrs.last().unwrap(); - if last_instr.is_branch() { - last_branch = Some(last_instr.clone()); - let _ = processed_instrs.pop(); - } - } - // 传参 call 回去 param read 会需要记录 - for i in call_related.iter() { - let mut riscv_writes = HashSet::new(); - let mut riscv_reads = HashSet::new(); - for j in i.iter() { - riscv_writes.extend(j.get_riscv_write().iter().cloned()); - riscv_reads.extend(j.get_riscv_read().iter().cloned()); - } - } - for (idx, instr) in processed_instrs.iter().rev().enumerate() { - let node = Rc::new(RefCell::new(InstrNode::new(instr, idx))); - if idx == 0 { - if instr.get_riscv_write().len() == 1 - && instr.get_riscv_write()[0] == RiscvTemp::PhysReg(A0) - { - li_ret = Some(node.clone()); - } - } - let mut instr_node_succ = Vec::new(); - let instructions_write = instr.get_riscv_write().clone(); - if instr.is_call() == false { - for instr_write in instructions_write { - instr_node_succ.extend( - uses.get(&instr_write).unwrap_or(&Vec::new()).iter().cloned(), - ); - // 同时 extend predecessors - for i in uses.get(&instr_write).unwrap_or(&Vec::new()).iter() { - i.borrow_mut().pred.push(node.clone()); - } - uses.remove(&instr_write); - } - } else { - let tmp = call_write.pop().unwrap(); - my_call_write = tmp.clone(); - if let Some(tmp) = tmp { - instr_node_succ - .extend(uses.get(&tmp).unwrap_or(&Vec::new()).iter().cloned()); - uses.get(&tmp).unwrap_or(&Vec::new()).iter().for_each(|x| { - x.borrow_mut().pred.push(node.clone()); - }); - uses.remove(&tmp); - } - } - let instr_read = instr.get_riscv_read().clone(); - if instr.is_call() == false { - for instr_read_temp in instr_read.iter() { - if let Some(def_instr) = defs.get(instr_read_temp) { - instr_node_succ.push(def_instr.clone()); - def_instr.borrow_mut().pred.push(node.clone()); - // println!("in instr def extending {}->{}",node.borrow().id,def_instr.borrow().id); - } - uses.entry(*instr_read_temp).or_default().push(node.clone()); - if !last_uses.contains_key(instr_read_temp) { - last_uses.insert(*instr_read_temp, idx); - } - } - } else { - let tmp = call_reads.pop().unwrap(); - for instr_read_temp in tmp.iter() { - if let Some(def_instr) = defs.get(instr_read_temp) { - instr_node_succ.push(def_instr.clone()); - def_instr.borrow_mut().pred.push(node.clone()); - } - uses.entry(*instr_read_temp).or_default().push(node.clone()); - if !last_uses.contains_key(instr_read_temp) { - last_uses.insert(*instr_read_temp, idx); - } - } - } - // init defs - if instr.is_call() == false { - let instructions_write = instr.get_riscv_write().clone(); - for instr_write in instructions_write.iter() { - defs.insert(*instr_write, node.clone()); - } - } else { - if let Some(tmp) = my_call_write { - defs.insert(tmp, node.clone()); - } - } - // 处理 load call store 指令的依赖关系 - if instr.is_call() { - // 先考虑一下那个最后一条 mov other reg a0 - instr_node_succ.extend(last_loads.iter().cloned()); - last_loads.iter().for_each(|x| { - x.borrow_mut().pred.push(node.clone()); - }); - // println!("in is_call {} extending loads {:?}",node.borrow().id,last_loads.iter().map(|x| x.borrow().id).collect::>()); - last_loads.clear(); - last_call = Some(node.clone()); - if let Some(ret_node) = li_ret.clone() { - instr_node_succ.push(ret_node.clone()); - ret_node.borrow_mut().pred.push(node.clone()); - } - for i in call_instrs.iter() { - instr_node_succ.push(i.clone()); - i.borrow_mut().pred.push(node.clone()); - } - call_instrs.push(node.clone()); - - } else if instr.is_load().unwrap_or(false) { - if let Some(last_call) = last_call.clone() { - instr_node_succ.push(last_call.clone()); - last_call.borrow_mut().pred.push(node.clone()); - } - last_loads.push(node.clone()); - } else if instr.is_store().unwrap_or(false) { - instr_node_succ.extend(last_loads.iter().cloned()); - last_loads.iter().for_each(|x| { - x.borrow_mut().pred.push(node.clone()); - }); - last_loads.clear(); - last_call = Some(node.clone()); - for i in call_instrs.iter() { - instr_node_succ.push(i.clone()); - i.borrow_mut().pred.push(node.clone()); - } - call_instrs.push(node.clone()); - } - node.borrow_mut().succ = instr_node_succ; - nodes.push(node); - } - for node in nodes.iter() { - for succ in node.borrow().succ.iter() { - succ.borrow_mut().in_deg += 1; - } - } - for (index, instr) in nodes.iter_mut().enumerate().rev() { - instr.borrow_mut().last_use += - last_uses.iter().filter(|x| *x.1 == index).count(); - } - // construct hashmap,key is the id of the nodes that are call, values are the call instructions - for (idx, instrs) in nodes.iter().enumerate().rev() { - if instrs.borrow().instr.is_call() { - call_related_map.insert(idx, call_related.pop().unwrap()); - } - } - Ok(Self { - nodes, - call_related: call_related_map, - branch: last_branch, - call_reads: ret_call_reads, - call_writes: ret_call_writes, - }) - } - pub fn assign_nodes(&mut self) { - // 先备份一遍所有 node 的 indegs - let indegs = - self.nodes.iter().map(|x| x.borrow().in_deg).collect::>(); - // 开始遍历 - let mut stack_ = Vec::new(); - for i in self.nodes.iter() { - if i.borrow().succ.len() == 0 { - stack_.push(i.clone()); - // get latency - let siz = i.borrow().instr.get_rtn_array()[4] as usize; - i.borrow_mut().to_end = siz; - } - } - while stack_.len() > 0 { - let node = stack_.pop().unwrap(); - for i in node.borrow().pred.iter() { - let new_end = max( - i.borrow().to_end, - node.borrow().to_end + i.borrow().instr.get_rtn_array()[4] as usize, - ); - i.borrow_mut().to_end = new_end; - i.borrow_mut().in_deg -= 1; - if i.borrow().in_deg == 0 { - stack_.push(i.clone()); - } - } - } - // 对每个点恢复 in_deg - for (i, j) in self.nodes.iter().zip(indegs.iter()) { - i.borrow_mut().in_deg = *j; - } - } -} -impl fmt::Display for InstrDag { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for node in &self.nodes { - let instr_node = node.borrow(); - writeln!(f, "Node ID: {}", instr_node.id)?; - writeln!(f, "In-degree: {}", instr_node.in_deg)?; - writeln!(f, "Instruction: {}", instr_node.instr)?; - writeln!( - f, - "Successors: {:?}", - instr_node.succ.iter().map(|x| x.borrow().id).collect::>() - )?; - // print successor's in degrees - writeln!( - f, - "Successors' In-degree: {:?}", - instr_node - .succ - .iter() - .map(|x| x.borrow().in_deg) - .collect::>() - )?; - writeln!(f, "Last Use: {}", instr_node.last_use)?; - writeln!(f, "---------------------------")?; - } - Ok(()) - } -} diff --git a/backend/transform/src/lib.rs b/backend/transform/src/lib.rs index 4c813d1e..ca5a59a8 100644 --- a/backend/transform/src/lib.rs +++ b/backend/transform/src/lib.rs @@ -1,20 +1,11 @@ -use std::{ - cell::RefCell, - collections::{HashMap, HashSet}, - io::{self, Write}, - rc::Rc, -}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use instruction::{riscv::prelude::*, temp::TempManager}; use llvm::Value; -use utils::SysycError::RiscvGenError; use rrvm::prelude::*; use transformer::to_riscv; -use utils::{ - errors::Result, - SCHEDULE_THRESHOLD, -}; +use utils::{errors::Result, SysycError::RiscvGenError}; pub mod remove_phi; pub mod transformer; @@ -24,72 +15,11 @@ pub fn get_functions( funcs: Vec, ) -> Result<()> { for func in funcs { - let converted_func = convert_func(func, &mut program.temp_mgr)?; - println!("--- before instr schedule: ---"); - for i in converted_func.cfg.blocks.iter() { - for j in i.borrow().instrs.iter() { - println!("{}", j); - } - println!("------------block end-------------"); - // println!( - // "jump instruction: {}", - // i.borrow().jump_instr.as_ref().unwrap() - // ); - } - println!("---end---"); - io::stdout().flush().unwrap(); - let func = instr_schedule( - converted_func, - &mut program.temp_mgr, - )?; - println!("--------"); - for i in func.cfg.blocks.iter() { - for j in i.borrow().instrs.iter() { - println!("{}", j); - } - println!("------------block end-------------"); - } - println!("--------"); - program.funcs.push(func); + program.funcs.push(convert_func(func, &mut program.temp_mgr)?); } Ok(()) } -pub fn instr_schedule( - func: RiscvFunc, - mgr: &mut TempManager, -) -> Result { - func.cfg.clear_data_flow(); - func.cfg.analysis(); - let live_ins=[]; - let live_outs=[]; - let mut new_blocks = Vec::new(); - for (idx, node) in func.cfg.blocks.iter().enumerate() { - let nodes = - instr_schedule_block(node, &live_ins[idx], &live_outs[idx], mgr)?; - new_blocks.extend(nodes); - } - Ok(RiscvFunc { - total: mgr.total, - spills: 0, - cfg: RiscvCFG { blocks: new_blocks }, - name: func.name, - params: func.params, - ret_type: func.ret_type, - }) -} -pub fn instr_schedule_block( - riscv_node: &RiscvNode, - live_ins: &HashSet, - live_outs: &HashSet, - mgr: &mut TempManager, -) -> Result> { - if riscv_node.borrow().instrs.len() >= SCHEDULE_THRESHOLD { - return Ok(vec![riscv_node.clone()]); - } - transform_basic_block_by_pipelining(riscv_node, live_ins, live_outs, mgr) - .map(|v| vec![v]) -} pub fn convert_func( func: LlvmFunc, mgr: &mut TempManager, @@ -144,7 +74,6 @@ pub fn convert_func( for (u, v) in edge { force_link_node(table.get(&u).unwrap(), table.get(&v).unwrap()) } - Ok(RiscvFunc { total: mgr.total, spills: 0, @@ -153,103 +82,8 @@ pub fn convert_func( params: func.params, ret_type: func.ret_type, }) - } - - -fn transform_basic_block_by_pipelining( - node: &RiscvNode, - live_in: &HashSet, - live_out: &HashSet, - _mgr: &mut TempManager, -) -> Result { - let mut instr_dag = InstrDag::new(node)?; - let liveliness_map = get_liveliness_map(&instr_dag, live_in, live_out); - instr_dag.assign_nodes(); - node.borrow_mut().instrs = instr_schedule_by_dag(instr_dag, liveliness_map)?; - Ok(node.clone()) -} -#[derive(Clone, Debug)] -pub struct Liveliness { - is_livein: bool, - is_liveout: bool, - use_num: usize, -} -fn get_liveliness_map( - node: &InstrDag, - live_in: &HashSet, - live_out: &HashSet, -) -> HashMap { - let mut map = HashMap::new(); - let mut call_reads = node.call_reads.clone(); - call_reads.reverse(); - let mut call_writes = node.call_writes.clone(); - call_writes.reverse(); - // 它这里要求是正序遍历,所以遍历次序是和 node 的顺序反的,需要 iter.rev(),同样,call_reads,call_writes 也要reverse再pop - for instrnode in node.nodes.iter().rev() { - let instr = &instrnode.borrow().instr; - if !instr.is_call() { - for tmp in instr.get_riscv_read().iter() { - map - .entry(*tmp) - .or_insert(Liveliness { - is_livein: false, - is_liveout: false, - use_num: 0, - }) - .use_num += 1; - } - for tmp in instr.get_riscv_write().iter() { - map.entry(*tmp).or_insert(Liveliness { - is_livein: false, - is_liveout: false, - use_num: 0, - }); - } - } else { - let call_read = call_reads.pop().unwrap(); - for tmp in call_read.iter() { - map - .entry(*tmp) - .or_insert(Liveliness { - is_livein: false, - is_liveout: false, - use_num: 0, - }) - .use_num += 1; - } - let call_write = call_writes.pop().unwrap(); - for tmp in call_write.iter() { - map.entry(*tmp).or_insert(Liveliness { - is_livein: false, - is_liveout: false, - use_num: 0, - }); - } - } - } - // do live_in - for tmp in live_in.iter() { - map - .entry(*tmp) - .or_insert(Liveliness { - is_livein: true, - is_liveout: false, - use_num: 0, - }) - .is_livein = true; - } - for tmp in live_out.iter() { - map - .entry(*tmp) - .or_insert(Liveliness { - is_livein: false, - is_liveout: true, - use_num: 0, - }) - .is_liveout = true; - } - map } + fn transform_basicblock( node: &LlvmNode, mgr: &mut TempManager, diff --git a/backend/transform/src/transformer.rs b/backend/transform/src/transformer.rs index 847500ec..da2a8dec 100644 --- a/backend/transform/src/transformer.rs +++ b/backend/transform/src/transformer.rs @@ -1,8 +1,4 @@ -use instruction::{ - riscv::{convert::*, RiscvInstr}, - temp::TempManager, - RiscvInstrSet, -}; +use instruction::{riscv::convert::*, temp::TempManager, RiscvInstrSet}; use llvm::{LlvmInstr, LlvmInstrVariant}; use utils::errors::Result; @@ -27,6 +23,3 @@ pub fn to_riscv( }?; Ok(riscv_instr) } -pub fn to_rt_type(instr: &RiscvInstr) -> [i32; 5] { - instr.get_rtn_array() -} From 79f7ae555f71f2f4bf3007cd4b7a674a2a9a06e8 Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Thu, 15 Aug 2024 18:11:05 +0800 Subject: [PATCH 4/8] fix: bug fix on instr-scheduling --- .../instruction_scheduling/instr_schedule.rs | 3 -- .../src/instruction_scheduling/instrdag.rs | 43 +------------------ 2 files changed, 1 insertion(+), 45 deletions(-) diff --git a/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs b/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs index 4538ff7d..721a2cc4 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs @@ -549,9 +549,6 @@ pub fn instr_schedule_by_dag( &mut dag.call_related.clone(), // 是我call的顺序可能会调换,post_process 的时候和原本push进去的顺序不一致 dag.branch.clone(), &mut final_state.call_ids.clone(), -<<<<<<< HEAD dag.li_ret.clone(), -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) )) } diff --git a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs index 1ea55e1b..e54c2efd 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs @@ -6,11 +6,7 @@ use std::{ }; use instruction::riscv::{ -<<<<<<< HEAD reg::RiscvReg::{Fa0, A0, SP}, -======= - reg::RiscvReg::{A0, SP}, ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) riscvinstr::RiscvInstrTrait, value::RiscvTemp, RiscvInstr, @@ -51,20 +47,14 @@ pub struct InstrDag { pub branch: Option>, pub call_writes: Vec>, pub call_reads: Vec>, -<<<<<<< HEAD pub li_ret: Option>, -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) } fn preprocess_call( node: &RiscvNode, call_related: &mut Vec>>, // 换成一个 hashmap 用建完图之后的 node id 来索引 call_write: &mut Vec>, call_reads: &mut Vec>, -<<<<<<< HEAD li_ret: &mut Option>, -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) ) -> Vec> { let mut instrs = Vec::new(); let mut save_instr = false; @@ -87,13 +77,10 @@ fn preprocess_call( my_call_related.push(i.clone()); push_this = true; continue; -<<<<<<< HEAD } else if let RiscvTemp::PhysReg(Fa0) = i.get_riscv_read()[0] { my_call_related.push(i.clone()); push_this = true; continue; -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) } else { call_write.push(None); } @@ -144,7 +131,6 @@ fn preprocess_call( } call_reads.push(riscv_reads.iter().cloned().collect()); } -<<<<<<< HEAD // 判断最后一条 if !instrs.is_empty() { let last_instr = instrs.pop().unwrap(); @@ -155,6 +141,7 @@ fn preprocess_call( || if let RiscvTemp::VirtReg(t) = &last_instr.get_riscv_write()[0] { if let Some(pre) = t.pre_color { (pre == A0) || (pre == Fa0) + } else { false } @@ -166,8 +153,6 @@ fn preprocess_call( instrs.push(last_instr); } } -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) instrs } pub fn postprocess_call( @@ -175,10 +160,7 @@ pub fn postprocess_call( call_related: &mut HashMap>>, branch_related: Option>, call_idxs: &mut Vec, -<<<<<<< HEAD li_ret: Option>, -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) ) -> Vec> { let mut my_instrs = Vec::new(); for i in instrs { @@ -190,12 +172,9 @@ pub fn postprocess_call( my_instrs.push(i); } } -<<<<<<< HEAD if let Some(instr) = li_ret { my_instrs.push(instr); } -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) if let Some(instr) = branch_related { my_instrs.push(instr); } @@ -225,10 +204,7 @@ impl InstrDag { &mut call_related, &mut call_write, &mut call_reads, -<<<<<<< HEAD &mut li_ret, -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) ); let ret_call_writes = call_write.clone(); let ret_call_reads = call_reads.clone(); @@ -250,8 +226,6 @@ impl InstrDag { } for (idx, instr) in processed_instrs.iter().rev().enumerate() { let node = Rc::new(RefCell::new(InstrNode::new(instr, idx))); -<<<<<<< HEAD -======= if idx == 0 && instr.get_riscv_write().len() == 1 && (instr.get_riscv_write()[0] == RiscvTemp::PhysReg(A0) @@ -267,7 +241,6 @@ impl InstrDag { li_ret = Some(node.clone()); } ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) let mut instr_node_succ = Vec::new(); let instructions_write = instr.get_riscv_write().clone(); if !instr.is_call() { @@ -338,33 +311,22 @@ impl InstrDag { }); last_loads.clear(); last_call = Some(node.clone()); -<<<<<<< HEAD -======= if let Some(ret_node) = li_ret.clone() { instr_node_succ.push(ret_node.clone()); ret_node.borrow_mut().pred.push(node.clone()); } ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) for i in call_instrs.iter() { instr_node_succ.push(i.clone()); i.borrow_mut().pred.push(node.clone()); } call_instrs.push(node.clone()); -<<<<<<< HEAD } else if instr.is_load() { -======= - } else if instr.is_load().unwrap_or(false) { ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) if let Some(last_call) = last_call.clone() { instr_node_succ.push(last_call.clone()); last_call.borrow_mut().pred.push(node.clone()); } last_loads.push(node.clone()); -<<<<<<< HEAD } else if instr.is_store() { -======= - } else if instr.is_store().unwrap_or(false) { ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) instr_node_succ.extend(last_loads.iter().cloned()); last_loads.iter().for_each(|x| { x.borrow_mut().pred.push(node.clone()); @@ -401,10 +363,7 @@ impl InstrDag { branch: last_branch, call_reads: ret_call_reads, call_writes: ret_call_writes, -<<<<<<< HEAD li_ret, -======= ->>>>>>> 15ca5b3 (refa: refactor instruction scheduling) }) } pub fn assign_nodes(&mut self) { From d04e36fe100ccd0b7b6d4d4f7af76a1423ca0c9b Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Thu, 15 Aug 2024 20:33:22 +0800 Subject: [PATCH 5/8] fix: bug fix on float return vals --- backend/pre_optimizer/src/instruction_scheduling/instrdag.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs index e54c2efd..bb1d5c83 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs @@ -141,7 +141,10 @@ fn preprocess_call( || if let RiscvTemp::VirtReg(t) = &last_instr.get_riscv_write()[0] { if let Some(pre) = t.pre_color { (pre == A0) || (pre == Fa0) +<<<<<<< HEAD +======= +>>>>>>> d5782f4 (fix: bug fix on float return vals) } else { false } From e7332ea72d11743c7a83bc23e9a5fb0f775fbde1 Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Sun, 18 Aug 2024 03:59:21 +0800 Subject: [PATCH 6/8] style: fix some design problems --- utils/instruction/src/riscv/riscvinstr.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/utils/instruction/src/riscv/riscvinstr.rs b/utils/instruction/src/riscv/riscvinstr.rs index a0a1ef11..97c3665f 100644 --- a/utils/instruction/src/riscv/riscvinstr.rs +++ b/utils/instruction/src/riscv/riscvinstr.rs @@ -1,8 +1,4 @@ -use std::{ - collections::HashMap, - fmt::Display, - ops::{Add, Sub}, -}; +use std::{collections::HashMap, fmt::Display}; use sysyc_derive::UseTemp; use utils::{mapper::LabelMapper, InstrTrait, Label, UseTemp, RTN}; From d6c8f58c8424e3c77164bedda140b5eb99d43633 Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Sun, 18 Aug 2024 20:29:42 +0800 Subject: [PATCH 7/8] =?UTF-8?q?feat:=20bug=20fix=20on=20instruction=20sche?= =?UTF-8?q?duling=20debug=20mode=20(=E7=8E=B0=E5=9C=A8=E7=B3=BB=E6=95=B0?= =?UTF-8?q?=E6=98=AF=E5=AF=84=E5=AD=98=E5=99=A8=E9=98=B2=E6=AD=A2=E6=BA=A2?= =?UTF-8?q?=E5=87=BA=E7=9A=84=E7=B3=BB=E6=95=B0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../instruction_scheduling/instr_schedule.rs | 2 +- .../src/instruction_scheduling/instrdag.rs | 22 ++++++++++++++----- utils/src/constants.rs | 10 ++++----- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs b/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs index 721a2cc4..f4507eb8 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instr_schedule.rs @@ -513,7 +513,7 @@ pub fn instr_schedule_by_dag( // calc my_reads let my_reads = { if state.instrs.last().unwrap().is_call() { - dag.call_reads.last().unwrap().clone() + dag.call_reads[state.call_ids.len() - 1].clone() } else { dag.nodes[*instr_idx].borrow().instr.get_riscv_read().clone() } diff --git a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs index bb1d5c83..9eaecda0 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs @@ -25,6 +25,7 @@ pub struct InstrNode { pub last_use: usize, pub pred: Vec, pub to_end: usize, + pub out_deg: usize, } impl InstrNode { pub fn new(instr: &RiscvInstr, id: usize) -> Self { @@ -36,6 +37,7 @@ impl InstrNode { last_use: 0, pred: Vec::new(), to_end: 0, + out_deg: 0, } } } @@ -201,6 +203,10 @@ impl InstrDag { let mut call_related_map = HashMap::new(); let mut call_instrs: Vec>> = Vec::new(); let mut my_call_write = None; + // eprintln!("instrs before preprocess_call"); + // for i in node.borrow().instrs.iter() { + // eprintln!("{}", i); + // } // preprocessing call related: 把 call 前后的 从 save 到 restore 的若干条指令保存在 call_related 里面,然后加入到 is_filtered_idx 之后遍历instrs 的时候遇到就直接continue let mut processed_instrs = preprocess_call( node, @@ -346,9 +352,11 @@ impl InstrDag { nodes.push(node); } for node in nodes.iter() { + let len = node.borrow().succ.len(); for succ in node.borrow().succ.iter() { succ.borrow_mut().in_deg += 1; } + node.borrow_mut().out_deg += len; } for (index, instr) in nodes.iter_mut().enumerate().rev() { instr.borrow_mut().last_use += @@ -370,10 +378,12 @@ impl InstrDag { }) } pub fn assign_nodes(&mut self) { + // 这个函数是软流水函数 实际不参与运算 // 先备份一遍所有 node 的 indegs - let indegs = - self.nodes.iter().map(|x| x.borrow().in_deg).collect::>(); + let out_degs = + self.nodes.iter().map(|x| x.borrow().out_deg).collect::>(); // 开始遍历 + let mut stack_ = Vec::new(); for i in self.nodes.iter() { if i.borrow().succ.is_empty() { @@ -390,15 +400,15 @@ impl InstrDag { node.borrow().to_end + i.borrow().instr.get_rtn_array()[4] as usize, ); i.borrow_mut().to_end = new_end; - i.borrow_mut().in_deg -= 1; - if i.borrow().in_deg == 0 { + i.borrow_mut().out_deg -= 1; + if i.borrow().out_deg == 0 { stack_.push(i.clone()); } } } // 对每个点恢复 in_deg - for (i, j) in self.nodes.iter().zip(indegs.iter()) { - i.borrow_mut().in_deg = *j; + for (i, j) in self.nodes.iter().zip(out_degs.iter()) { + i.borrow_mut().out_deg = *j; } } } diff --git a/utils/src/constants.rs b/utils/src/constants.rs index d694b4ed..2aeb0583 100644 --- a/utils/src/constants.rs +++ b/utils/src/constants.rs @@ -32,13 +32,13 @@ pub static DEPENDENCY_EXPLORE_DEPTH: i32 = 10; // software pipelining 过程中 pub static BLOCKSIZE_THRESHOLD: usize = 100; // software pipelining 判断如果基本本块大小超了 BLOCKSIZE_THRESHOLD 后就不进行针对基本本块的优化 pub static BFS_STATE_THRESHOLD: usize = 9; // 在 instr_scheduling 中,每轮 bfs 所保留的状态的阈值 // for instruction scheduling: register punishment -pub static ADD_ALLOCATABLES: i32 = 0; -pub static NEAR_END: i32 = 0; // 寄存器生命周期更快结束的指令优先 -pub static REDUCE_SUB: i32 = 0; // 后继中的节点对应指令,寄存器生命周期更快结束的指令优先 -pub static REDUCE_LIVE: i32 = 0; +pub static ADD_ALLOCATABLES: i32 = 20; +pub static NEAR_END: i32 = 10; // 寄存器生命周期更快结束的指令优先 +pub static REDUCE_SUB: i32 = 4; // 后继中的节点对应指令,寄存器生命周期更快结束的指令优先 +pub static REDUCE_LIVE: i32 = 2; pub static LIVE_THROUGH: usize = 30; pub static SUM_MIN_RATIO: i32 = 1; pub static SCHEDULE_THRESHOLD: usize = 15000; pub static SOFTWARE_PIPELINE_PARAM: i32 = 0; // 拓扑排序后软流水的权重 -pub static HARDWARE_PIPELINE_PARAM: i32 = 1; // 拓扑排序后硬件流水的权重 +pub static HARDWARE_PIPELINE_PARAM: i32 = 0; // 拓扑排序后硬件流水的权重 pub static FDIV_WAIT: usize = 20; // fdiv 的 repeat rate From b87c3a95747e4232cd64eb889da629816534525d Mon Sep 17 00:00:00 2001 From: Rosayxy Date: Sun, 18 Aug 2024 23:27:50 +0800 Subject: [PATCH 8/8] feat: adjust constants to both hardware pipelining and registers --- .../src/instruction_scheduling/instrdag.rs | 26 ------------------- utils/src/constants.rs | 14 +++++----- 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs index 9eaecda0..6dea6969 100644 --- a/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs +++ b/backend/pre_optimizer/src/instruction_scheduling/instrdag.rs @@ -143,10 +143,6 @@ fn preprocess_call( || if let RiscvTemp::VirtReg(t) = &last_instr.get_riscv_write()[0] { if let Some(pre) = t.pre_color { (pre == A0) || (pre == Fa0) -<<<<<<< HEAD - -======= ->>>>>>> d5782f4 (fix: bug fix on float return vals) } else { false } @@ -203,10 +199,6 @@ impl InstrDag { let mut call_related_map = HashMap::new(); let mut call_instrs: Vec>> = Vec::new(); let mut my_call_write = None; - // eprintln!("instrs before preprocess_call"); - // for i in node.borrow().instrs.iter() { - // eprintln!("{}", i); - // } // preprocessing call related: 把 call 前后的 从 save 到 restore 的若干条指令保存在 call_related 里面,然后加入到 is_filtered_idx 之后遍历instrs 的时候遇到就直接continue let mut processed_instrs = preprocess_call( node, @@ -235,20 +227,6 @@ impl InstrDag { } for (idx, instr) in processed_instrs.iter().rev().enumerate() { let node = Rc::new(RefCell::new(InstrNode::new(instr, idx))); - if idx == 0 - && instr.get_riscv_write().len() == 1 - && (instr.get_riscv_write()[0] == RiscvTemp::PhysReg(A0) - || if let RiscvTemp::VirtReg(t) = &instr.get_riscv_write()[0] { - if let Some(pre) = t.pre_color { - pre == A0 - } else { - false - } - } else { - false - }) { - li_ret = Some(node.clone()); - } let mut instr_node_succ = Vec::new(); let instructions_write = instr.get_riscv_write().clone(); @@ -320,10 +298,6 @@ impl InstrDag { }); last_loads.clear(); last_call = Some(node.clone()); - if let Some(ret_node) = li_ret.clone() { - instr_node_succ.push(ret_node.clone()); - ret_node.borrow_mut().pred.push(node.clone()); - } for i in call_instrs.iter() { instr_node_succ.push(i.clone()); i.borrow_mut().pred.push(node.clone()); diff --git a/utils/src/constants.rs b/utils/src/constants.rs index 2aeb0583..fc67c258 100644 --- a/utils/src/constants.rs +++ b/utils/src/constants.rs @@ -27,18 +27,18 @@ pub static VEC_EXTERN: [&str; 17] = [ pub static VEC_MACRO: [&str; 2] = ["starttime", "stoptime"]; pub const MAX_PHI_NUM: usize = 10; -pub static EXTEND_TIMES: i32 = 4; // software pipelining 循环展开的次数 -pub static DEPENDENCY_EXPLORE_DEPTH: i32 = 10; // software pipelining 过程中,对于数组的依赖,所枚举到的深度 +pub static EXTEND_TIMES: i32 = 0; // software pipelining 循环展开的次数 +pub static DEPENDENCY_EXPLORE_DEPTH: i32 = 0; // software pipelining 过程中,对于数组的依赖,所枚举到的深度 pub static BLOCKSIZE_THRESHOLD: usize = 100; // software pipelining 判断如果基本本块大小超了 BLOCKSIZE_THRESHOLD 后就不进行针对基本本块的优化 pub static BFS_STATE_THRESHOLD: usize = 9; // 在 instr_scheduling 中,每轮 bfs 所保留的状态的阈值 // for instruction scheduling: register punishment -pub static ADD_ALLOCATABLES: i32 = 20; -pub static NEAR_END: i32 = 10; // 寄存器生命周期更快结束的指令优先 -pub static REDUCE_SUB: i32 = 4; // 后继中的节点对应指令,寄存器生命周期更快结束的指令优先 -pub static REDUCE_LIVE: i32 = 2; +pub static ADD_ALLOCATABLES: i32 = 0; +pub static NEAR_END: i32 = 0; // 寄存器生命周期更快结束的指令优先 +pub static REDUCE_SUB: i32 = 0; // 后继中的节点对应指令,寄存器生命周期更快结束的指令优先 +pub static REDUCE_LIVE: i32 = 0; pub static LIVE_THROUGH: usize = 30; pub static SUM_MIN_RATIO: i32 = 1; pub static SCHEDULE_THRESHOLD: usize = 15000; pub static SOFTWARE_PIPELINE_PARAM: i32 = 0; // 拓扑排序后软流水的权重 -pub static HARDWARE_PIPELINE_PARAM: i32 = 0; // 拓扑排序后硬件流水的权重 +pub static HARDWARE_PIPELINE_PARAM: i32 = 1; // 拓扑排序后硬件流水的权重 pub static FDIV_WAIT: usize = 20; // fdiv 的 repeat rate