Last active: December 3, 2024 16:00
-
-
Save MaxOhn/625af10011f6d7e13a171b08ccf959ff to your computer and use it in GitHub Desktop.
Benchmarking osu!lazer against rosu-pp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
fs::read_to_string, | |
hint::black_box, | |
path::Path, | |
time::{Duration, Instant}, | |
}; | |
use rosu_pp::{ | |
any::{DifficultyAttributes, PerformanceAttributes}, | |
Beatmap, Difficulty, Performance, | |
}; | |
fn main() { | |
// 100 random map ids | |
let map_ids_string = read_to_string("/path/to/map_ids.csv").unwrap(); | |
let map_ids: Vec<u32> = map_ids_string | |
.split(',') | |
.map(str::parse) | |
.map(Result::unwrap) | |
.collect(); | |
print!("Decoding maps: "); | |
let paths: Vec<_> = map_ids | |
.iter() | |
.map(|id| format!("/path/to/{id}.osu")) | |
.collect(); | |
bench(decode_map, &paths); | |
print!("Calculating difficulties: "); | |
let maps: Vec<_> = paths.iter().map(decode_map).collect(); | |
bench(calculate_difficulty, &maps); | |
print!("Calculating performances: "); | |
let attrs: Vec<_> = maps.iter().map(calculate_difficulty).collect(); | |
bench(calculate_performance, &attrs); | |
} | |
fn decode_map(path: impl AsRef<Path>) -> Beatmap { | |
Beatmap::from_path(path).unwrap() | |
} | |
fn calculate_difficulty(map: &Beatmap) -> DifficultyAttributes { | |
Difficulty::new().calculate(map) | |
} | |
fn calculate_performance(diff_attrs: &DifficultyAttributes) -> PerformanceAttributes { | |
Performance::from(diff_attrs.clone()).calculate() | |
} | |
/// Repeatedly runs `f` over every element of `inputs` for about ten seconds,
/// then prints the median and mean time of one full pass over `inputs`.
///
/// Each sample times `CHUNK_SIZE` consecutive passes and records the average
/// per-pass time. A warmup of `CHUNK_SIZE` untimed passes runs first.
///
/// The explicit lifetime `'a` (rather than a higher-ranked `Fn(&I) -> O`) is
/// what allows generic functions such as `decode_map` to be passed directly:
/// their type parameter is instantiated with one concrete reference lifetime.
fn bench<'a, I, O>(f: impl Fn(&'a I) -> O, inputs: &'a [I]) {
    const CHUNK_SIZE: usize = 10;
    const MEASURE_SECS: u64 = 10;

    // One chunk = `CHUNK_SIZE` passes over all inputs. The input goes through
    // `black_box` as well as the output, so the optimizer can neither treat the
    // unchanging inputs as constants and hoist the work out of the repeat loop,
    // nor discard the computation because its result is unused.
    let run_chunk = || {
        for _ in 0..CHUNK_SIZE {
            for input in inputs {
                let _ = black_box(f(black_box(input)));
            }
        }
    };

    // Warmup
    run_chunk();

    let mut chunk_times = Vec::new();
    let start = Instant::now();

    // The body always runs at least once, so `chunk_times` is never empty below.
    while start.elapsed().as_secs() < MEASURE_SECS {
        let curr = Instant::now();
        run_chunk();
        chunk_times.push(curr.elapsed() / CHUNK_SIZE as u32);
    }

    chunk_times.sort_unstable();
    let mean = chunk_times.iter().sum::<Duration>() / chunk_times.len() as u32;
    // Upper median when the number of samples is even.
    let median = chunk_times[chunk_times.len() / 2];
    println!("Median: {median:.2?} | Mean: {mean:.2?}");
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using osu.Game.Beatmaps; | |
using osu.Game.Beatmaps.Formats; | |
using osu.Game.Rulesets; | |
using osu.Game.Rulesets.Catch; | |
using osu.Game.Rulesets.Difficulty; | |
using osu.Game.Rulesets.Mania; | |
using osu.Game.Rulesets.Osu; | |
using osu.Game.Rulesets.Taiko; | |
using osu.Game.Scoring; | |
using System.Diagnostics; | |
namespace Bench
{
    /// <summary>
    /// Benchmarks osu!lazer beatmap decoding, difficulty calculation and performance
    /// calculation over a list of beatmap files, printing median and mean timings.
    /// </summary>
    public class Program
    {
        /// <summary>
        /// Reads the map id list, then benchmarks each stage in turn. Each stage's inputs
        /// are produced once, outside the timed region, by running the previous stage.
        /// </summary>
        static void Main()
        {
            Decoder.RegisterDependencies(new AssemblyRulesetStore());

            // 100 random map ids
            // Trimming entries and dropping empty ones keeps a trailing comma (optionally
            // followed by a newline) from making int.Parse throw.
            string mapIdsString = File.ReadAllText("/path/to/map_ids.csv");
            var mapIds = mapIdsString
                .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
                .Select(int.Parse)
                .ToList();

            Console.Write($"Decoding maps: ");
            var paths = mapIds.Select(mapId => $"/path/to/{mapId}.osu").ToList();
            Bench(DecodeMap, paths);

            Console.Write($"Calculating difficulties: ");
            var maps = paths.Select(DecodeMap).ToList();
            Bench(CalculateDifficulty, maps);

            Console.Write($"Calculating performances: ");
            // Pair each map's ruleset with its difficulty attributes; the ruleset is what
            // creates the matching performance calculator.
            var attrs = maps.Select(map => (CreateRuleset(map.Beatmap.BeatmapInfo.Ruleset), CalculateDifficulty(map))).ToList();
            Bench(CalculatePerformance, attrs);
        }

        /// <summary>Loads the beatmap file at <paramref name="path"/>.</summary>
        static FlatWorkingBeatmap DecodeMap(string path) => new FlatWorkingBeatmap(path);

        /// <summary>
        /// Creates the map's ruleset and runs that ruleset's difficulty calculator on the map.
        /// </summary>
        static DifficultyAttributes CalculateDifficulty(IWorkingBeatmap map)
            => CreateRuleset(map.Beatmap.BeatmapInfo.Ruleset).CreateDifficultyCalculator(map).Calculate();

        /// <summary>
        /// Calculates performance attributes for a default-constructed <see cref="ScoreInfo"/>.
        /// Returns null when the ruleset's <c>CreatePerformanceCalculator</c> returns null.
        /// </summary>
        static PerformanceAttributes? CalculatePerformance((Ruleset ruleset, DifficultyAttributes diffAttrs) input)
            => input.ruleset.CreatePerformanceCalculator()?.Calculate(new ScoreInfo(), input.diffAttrs);

        /// <summary>
        /// Repeatedly runs <paramref name="f"/> over every input for about ten seconds,
        /// then prints the median and mean time of one full pass over the inputs. Each
        /// sample times <c>chunk_size</c> consecutive passes and records the per-pass
        /// average; an untimed warmup of <c>chunk_size</c> passes runs first.
        /// </summary>
        static void Bench<I, O>(Func<I, O> f, List<I> inputs)
        {
            const int chunk_size = 10;
            const double measure_seconds = 10;

            // Warmup
            for (int i = 0; i < chunk_size; i++)
            {
                foreach (I input in inputs)
                {
                    var _ = f(input);
                }
            }

            var chunkTimes = new List<TimeSpan>();
            var start = Stopwatch.StartNew();

            // TotalSeconds, not Seconds: Seconds is only the 0-59 seconds component and
            // would wrap back below the limit once a minute had elapsed.
            while (start.Elapsed.TotalSeconds < measure_seconds)
            {
                var curr = Stopwatch.StartNew();
                for (int i = 0; i < chunk_size; i++)
                {
                    foreach (I input in inputs)
                    {
                        var _ = f(input);
                    }
                }
                var elapsed = curr.Elapsed;
                chunkTimes.Add(elapsed / chunk_size);
            }

            // The loop body always runs at least once, so chunkTimes is never empty here.
            chunkTimes.Sort();
            var mean = new TimeSpan(chunkTimes.Aggregate(0L, (sum, next) => sum + next.Ticks) / chunkTimes.Count);
            // Upper median when the number of samples is even.
            var median = chunkTimes[chunkTimes.Count / 2];
            Console.WriteLine($"Median: {FormatTimeSpan(median)} | Mean: {FormatTimeSpan(mean)}");
        }

        /// <summary>
        /// Maps a ruleset's online id (0 osu!, 1 taiko, 2 catch, 3 mania) to a new ruleset
        /// instance.
        /// </summary>
        /// <exception cref="InvalidDataException">The online id is not 0-3.</exception>
        static Ruleset CreateRuleset(RulesetInfo info)
        {
            switch (info.OnlineID)
            {
                case 0:
                    return new OsuRuleset();
                case 1:
                    return new TaikoRuleset();
                case 2:
                    return new CatchRuleset();
                case 3:
                    return new ManiaRuleset();
                default:
                    throw new InvalidDataException($"invalid ruleset id {info.OnlineID}");
            }
        }

        /// <summary>
        /// Formats a duration as seconds, milliseconds or microseconds with two decimals, or
        /// as whole nanoseconds below one microsecond. The formatting is culture-independent.
        /// </summary>
        static string FormatTimeSpan(TimeSpan time)
        {
            const long ns_per_tick = 100L;
            const long ticks_per_milli = 1000000L / ns_per_tick;
            const long ticks_per_micro = 1000L / ns_per_tick;
            var invariant = System.Globalization.CultureInfo.InvariantCulture;

            // Compare total ticks rather than TimeSpan.Seconds (the 0-59 component), and
            // derive the fraction numerically: string-padding the tick remainder loses
            // its leading zeros (e.g. 1.05ms would otherwise print as "1.50ms").
            if (time.Ticks >= TimeSpan.TicksPerSecond)
                return time.TotalSeconds.ToString("F2", invariant) + "s";
            else if (time.Ticks >= ticks_per_milli)
                return ((double)time.Ticks / ticks_per_milli).ToString("F2", invariant) + "ms";
            else if (time.Ticks >= ticks_per_micro)
                return ((double)time.Ticks / ticks_per_micro).ToString("F2", invariant) + "µs";
            else
                return $"{time.Ticks * ns_per_tick}ns";
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.