You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

631 lines
25 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Diagnostics.CodeAnalysis;
  6. using System.IO;
  7. using System.Linq;
  8. using System.Runtime.Serialization;
  9. using System.Security.Cryptography;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. using Ionic.Zip;
  13. using IPA.Loader;
  14. using IPA.Utilities;
  15. using Newtonsoft.Json;
  16. using SemVer;
  17. using UnityEngine;
  18. using UnityEngine.Networking;
  19. using static IPA.Loader.PluginManager;
  20. using Logger = IPA.Logging.Logger;
  21. using Version = SemVer.Version;
  22. namespace IPA.Updating.BeatMods
  23. {
  /// <summary>
  /// Unity component that asks BeatMods for newer versions of installed plugins (and of the
  /// dependencies those plugins declare), then downloads, verifies and stages the updates.
  /// </summary>
  /// <remarks>
  /// Phases are numbered in comments: 1.x gather and aggregate dependencies, 2.x pick versions,
  /// 3.x download and extract.
  /// </remarks>
  [SuppressMessage("ReSharper", "ClassNeverInstantiated.Global")]
  internal class Updater : MonoBehaviour
  {
    /// <summary>The live updater; a second instance destroys itself in <see cref="Awake"/>.</summary>
    public static Updater Instance;

    /// <summary>
    /// Unity entry point. The first instance registers itself as <see cref="Instance"/> and starts
    /// the update check; any later instance destroys itself. Exceptions are logged, not rethrown.
    /// </summary>
    public void Awake()
    {
      try
      {
        if (Instance != null)
        {
          Destroy(this);
          return;
        }

        Instance = this;
        CheckForUpdates();
      }
      catch (Exception e)
      {
        Logger.updater.Error(e);
      }
    }

    /// <summary>Starts <see cref="CheckForUpdatesCoroutine"/> on this component.</summary>
    private void CheckForUpdates()
    {
      StartCoroutine(CheckForUpdatesCoroutine());
    }

    /// <summary>One node of the dependency graph being resolved.</summary>
    private class DependencyObject
    {
      public string Name { get; set; }

      /// <summary>Installed version; only non-null when the plugin is already installed.</summary>
      public Version Version { get; set; }

      /// <summary>Newest acceptable version found on BeatMods; meaningful when <see cref="Resolved"/>.</summary>
      public Version ResolvedVersion { get; set; }

      /// <summary>Versions acceptable to every consumer (intersection of their requirements).</summary>
      public Range Requirement { get; set; }

      /// <summary>A range of versions that are not allowed to be downloaded.</summary>
      public Range Conflicts { get; set; }

      public bool Resolved { get; set; }

      /// <summary>True when nothing needs downloading for this node (see the second pass).</summary>
      public bool Has { get; set; }

      /// <summary>Names of the mods that declared this dependency.</summary>
      public HashSet<string> Consumers { get; set; } = new HashSet<string>();

      public bool MetaRequestFailed { get; set; }

      /// <summary>Local plugin info; null for dependencies that are not installed.</summary>
      public PluginLoader.PluginInfo LocalPluginMeta { get; set; }

      public override string ToString()
      {
        return $"{Name}@{Version}{(Resolved ? $" -> {ResolvedVersion}" : "")} - ({Requirement} ! {Conflicts}) {(Has ? " Already have" : "")}";
      }
    }

    /// <summary>Successful API response bodies, keyed by URL relative to the API base.</summary>
    private readonly Dictionary<string, string> requestCache = new Dictionary<string, string>();

    /// <summary>
    /// Coroutine that GETs <c>ApiEndpoint.ApiBase + url</c> and stores the body in <paramref name="result"/>.
    /// Successful bodies are cached per <paramref name="url"/>.
    /// </summary>
    /// <param name="url">Endpoint path relative to the API base.</param>
    /// <param name="result">Receives the body, or a <see cref="NetworkException"/> as its error.</param>
    private IEnumerator GetBeatModsEndpoint(string url, Ref<string> result)
    {
      if (requestCache.TryGetValue(url, out string cached))
      {
        result.Value = cached;
        yield break;
      }

      using (var request = UnityWebRequest.Get(ApiEndpoint.ApiBase + url))
      {
        yield return request.SendWebRequest();

        if (request.isNetworkError)
        {
          result.Error = new NetworkException($"Network error while trying to download: {request.error}");
          yield break;
        }

        if (request.isHttpError)
        {
          result.Error = request.responseCode == 404
            ? new NetworkException("Not found")
            : new NetworkException($"Server returned error {request.error} while getting data");
          yield break;
        }

        var body = request.downloadHandler.text;
        result.Value = body;
        requestCache[url] = body;
      }
    }

    /// <summary>Decoded mod-info responses, keyed by the request URI.</summary>
    private readonly Dictionary<string, ApiEndpoint.Mod> modCache = new Dictionary<string, ApiEndpoint.Mod>();

    /// <summary>
    /// Coroutine that fetches the BeatMods info for <paramref name="modName"/> at <paramref name="ver"/>
    /// and stores the first entry returned in <paramref name="result"/>.
    /// </summary>
    /// <param name="modName">BeatMods mod name.</param>
    /// <param name="ver">Version string; the first pass passes "" (presumably "latest" — TODO confirm with the API).</param>
    /// <param name="result">Receives the mod, or the request/decoding failure as its error.</param>
    private IEnumerator GetModInfo(string modName, string ver, Ref<ApiEndpoint.Mod> result)
    {
      // EscapeDataString (unlike EscapeUriString) also escapes reserved characters such as '&', '+'
      // and '#', so a value cannot spill out of its slot in the endpoint template.
      var uri = string.Format(ApiEndpoint.GetModInfoEndpoint, Uri.EscapeDataString(modName), Uri.EscapeDataString(ver));
      if (modCache.TryGetValue(uri, out ApiEndpoint.Mod cached))
      {
        result.Value = cached;
        yield break;
      }

      var reqResult = new Ref<string>("");
      yield return GetBeatModsEndpoint(uri, reqResult);
      try { reqResult.Verify(); }
      catch (Exception e)
      {
        // report the request failure itself instead of trying to decode the empty placeholder
        result.Error = e;
        yield break;
      }

      try
      {
        var mods = JsonConvert.DeserializeObject<List<ApiEndpoint.Mod>>(reqResult.Value);
        var mod = mods?.FirstOrDefault() ?? throw new InvalidDataException($"BeatMods returned no entry for {modName}");
        result.Value = mod;
        modCache[uri] = mod;
      }
      catch (Exception e)
      {
        result.Error = new Exception("Error decoding response", e);
      }
    }

    /// <summary>Every version BeatMods lists for a mod, keyed by request URI (unfiltered).</summary>
    private readonly Dictionary<string, List<ApiEndpoint.Mod>> modVersionsCache = new Dictionary<string, List<ApiEndpoint.Mod>>();

    /// <summary>
    /// Coroutine that stores in <paramref name="result"/> every listed version of <paramref name="modName"/>
    /// that satisfies <paramref name="range"/>.
    /// </summary>
    /// <remarks>
    /// The cache holds the full version list and the range is applied on every call, so asking for
    /// the same mod with a different range still yields a correct answer.
    /// </remarks>
    private IEnumerator GetModVersionsMatching(string modName, Range range, Ref<List<ApiEndpoint.Mod>> result)
    {
      var uri = string.Format(ApiEndpoint.GetModsByName, Uri.EscapeDataString(modName));
      if (!modVersionsCache.TryGetValue(uri, out List<ApiEndpoint.Mod> allVersions))
      {
        var reqResult = new Ref<string>("");
        yield return GetBeatModsEndpoint(uri, reqResult);
        try { reqResult.Verify(); }
        catch (Exception e)
        {
          result.Error = e;
          yield break;
        }

        try
        {
          allVersions = JsonConvert.DeserializeObject<List<ApiEndpoint.Mod>>(reqResult.Value)
            ?? throw new InvalidDataException($"BeatMods returned no version list for {modName}");
          modVersionsCache[uri] = allVersions;
        }
        catch (Exception e)
        {
          result.Error = new Exception("Error decoding response", e);
          yield break;
        }
      }

      try
      {
        result.Value = allVersions.Where(m => m != null && range.IsSatisfied(m.Version)).ToList();
      }
      catch (Exception e)
      {
        result.Error = new Exception("Error decoding response", e);
      }
    }

    /// <summary>
    /// Coroutine driving the update check. It seeds the list with every loaded plugin and every
    /// ignored plugin that has an id, runs the three resolution passes, and starts downloads.
    /// </summary>
    private IEnumerator CheckForUpdatesCoroutine()
    {
      var depList = new Ref<List<DependencyObject>>(new List<DependencyObject>());

      // initialize with data to resolve (1.1); only plugins with an id are updatable
      foreach (var plugin in BSMetas)
      {
        var msinfo = plugin.Metadata;
        if (msinfo.Id == null) continue;

        depList.Value.Add(new DependencyObject
        {
          Name = msinfo.Id,
          Version = msinfo.Version,
          Requirement = new Range($">={msinfo.Version}"),
          LocalPluginMeta = plugin
        });
      }

      foreach (var meta in PluginLoader.ignoredPlugins.Where(m => m.Id != null))
      {
        depList.Value.Add(new DependencyObject
        {
          Name = meta.Id,
          Version = meta.Version,
          Requirement = new Range($">={meta.Version}"),
          LocalPluginMeta = new PluginLoader.PluginInfo
          {
            Metadata = meta,
            Plugin = null
          }
        });
      }

      foreach (var dep in depList.Value)
        Logger.updater.Debug($"Phantom Dependency: {dep}");

      yield return DependencyResolveFirstPass(depList);
      foreach (var dep in depList.Value)
        Logger.updater.Debug($"Dependency: {dep}");

      yield return DependencyResolveSecondPass(depList);
      foreach (var dep in depList.Value)
        Logger.updater.Debug($"Dependency: {dep}");

      DependencyResolveFinalPass(depList);
    }

    /// <summary>
    /// First pass. It fetches the dependencies of every entry, transitively (1.2), and then merges
    /// entries with the same name into one node (1.3). <paramref name="list"/> is replaced with the
    /// merged list.
    /// </summary>
    private IEnumerator DependencyResolveFirstPass(Ref<List<DependencyObject>> list)
    {
      // The list grows while it is walked, which is what makes the expansion transitive.
      // Each name is expanded only once: repeats would fetch the same (cached) info again, and
      // circular dependencies would otherwise keep appending entries forever.
      var expanded = new HashSet<string>();
      for (int i = 0; i < list.Value.Count; i++)
      {
        var dep = list.Value[i];
        if (!expanded.Add(dep.Name)) continue;

        var mod = new Ref<ApiEndpoint.Mod>(null);
        yield return GetModInfo(dep.Name, "", mod);
        try { mod.Verify(); }
        catch (Exception e)
        {
          Logger.updater.Error($"Error getting info for {dep.Name}");
          Logger.updater.Error(e);
          dep.MetaRequestFailed = true;
          continue;
        }

        var dependencies = mod.Value.Dependencies;
        if (dependencies == null) continue;

        list.Value.AddRange(dependencies
          .Where(m => m != null && m.Name != null)
          .Select(m => new DependencyObject
          {
            Name = m.Name,
            Requirement = new Range($">={m.Version}"),
            Consumers = new HashSet<string> { dep.Name }
          }));
        // currently no conflicts exist in BeatMods
      }

      // aggregate ranges and consumers per name (1.3); the first entry for a name is the one kept
      var byName = new Dictionary<string, DependencyObject>();
      var final = new List<DependencyObject>();
      foreach (var dep in list.Value)
      {
        if (!byName.TryGetValue(dep.Name, out var existing))
        {
          byName.Add(dep.Name, dep);
          final.Add(dep);
          continue;
        }

        if (dep.Requirement != null)
        {
          existing.Requirement = existing.Requirement == null
            ? dep.Requirement
            : existing.Requirement.Intersect(dep.Requirement);
          existing.Consumers.UnionWith(dep.Consumers);
        }
        else if (dep.Conflicts != null)
        {
          existing.Conflicts = existing.Conflicts == null
            ? dep.Conflicts
            : new Range($"{existing.Conflicts} || {dep.Conflicts}");
        }
      }

      list.Value = final;
    }

    /// <summary>
    /// Second pass. For each node it picks the highest approved, non-conflicting BeatMods version
    /// inside its requirement (2.1) and records it (2.2). It also sets <c>Has</c> for the final pass.
    /// </summary>
    private IEnumerator DependencyResolveSecondPass(Ref<List<DependencyObject>> list)
    {
      foreach (var dep in list.Value)
      {
        // dep.Version is only non-null if the plugin is already installed
        dep.Has = dep.Version != null;
        if (dep.MetaRequestFailed)
        {
          Logger.updater.Warn($"{dep.Name} info request failed, not trying again");
          continue;
        }

        var modsMatching = new Ref<List<ApiEndpoint.Mod>>(null);
        yield return GetModVersionsMatching(dep.Name, dep.Requirement, modsMatching);
        try { modsMatching.Verify(); }
        catch (Exception e)
        {
          Logger.updater.Error($"Error getting mod list for {dep.Name}");
          Logger.updater.Error(e);
          dep.MetaRequestFailed = true;
          continue;
        }

        var ver = modsMatching.Value
          .Where(m => m != null)
          //.Where(m => m.GameVersion.Version == BeatSaber.GameVersion) // game version matches
          .Where(m => m.Status == ApiEndpoint.Mod.ApprovedStatus) // version approved
          .Where(m => dep.Conflicts == null || !dep.Conflicts.IsSatisfied(m.Version)) // not a conflicting version
          .Select(m => m.Version)
          .Max(); // (2.1) get the max version

        dep.Resolved = ver != null;
        if (dep.Resolved)
        {
          dep.ResolvedVersion = ver; // (2.2)
          // nothing to download only if the installed version is exactly the resolved one
          dep.Has = dep.Version == ver;
        }
        // When unresolved, Has stays "installed" so the final pass only warns about missing plugins.
      }
    }

    /// <summary>
    /// Final pass. It logs the outcome for every node and starts a download coroutine for each
    /// resolved node that is not already present (3.1).
    /// </summary>
    private void DependencyResolveFinalPass(Ref<List<DependencyObject>> list)
    {
      var toDl = new List<DependencyObject>();
      foreach (var dep in list.Value)
      {
        if (dep.Resolved)
        {
          Logger.updater.Debug($"Resolved: {dep}");
          if (!dep.Has)
          {
            Logger.updater.Debug($"To Download: {dep}");
            toDl.Add(dep);
          }
        }
        else if (!dep.Has)
        {
          Logger.updater.Warn($"Could not resolve dependency {dep}");
        }
      }

      Logger.updater.Debug($"To Download {string.Join(", ", toDl.Select(d => $"{d.Name}@{d.ResolvedVersion}"))}");
      foreach (var item in toDl)
        StartCoroutine(UpdateModCoroutine(item));
    }

    /// <summary>
    /// Coroutine that downloads the resolved version of <paramref name="item"/> (3.2). It makes up to
    /// three attempts and extracts each download on a thread-pool thread.
    /// </summary>
    private IEnumerator UpdateModCoroutine(DependencyObject item)
    {
      Logger.updater.Debug($"Release: {BeatSaber.ReleaseType}");
      var mod = new Ref<ApiEndpoint.Mod>(null);
      yield return GetModInfo(item.Name, item.ResolvedVersion.ToString(), mod);
      try { mod.Verify(); }
      catch (Exception e)
      {
        Logger.updater.Error($"Error occurred while trying to get information for {item}");
        Logger.updater.Error(e);
        yield break;
      }

      var releaseName = BeatSaber.ReleaseType == BeatSaber.Release.Steam
        ? ApiEndpoint.Mod.DownloadsObject.TypeSteam
        : ApiEndpoint.Mod.DownloadsObject.TypeOculus;
      var platformFile = mod.Value.Downloads?.FirstOrDefault(
        f => f.Type == ApiEndpoint.Mod.DownloadsObject.TypeUniversal || f.Type == releaseName);
      if (platformFile == null)
      {
        Logger.updater.Error($"No download for {item} matches release {BeatSaber.ReleaseType}");
        yield break;
      }

      string url = ApiEndpoint.BeatModBase + platformFile.Path;
      Logger.updater.Debug($"URL = {url}");

      const int maxTries = 3;
      var succeeded = false;
      for (int attempt = 1; attempt <= maxTries; attempt++)
      {
        if (attempt > 1)
          Logger.updater.Debug("Re-trying download...");

        using (var stream = new MemoryStream())
        using (var request = UnityWebRequest.Get(url))
        {
          request.downloadHandler = new StreamDownloadHandler(stream);
          Logger.updater.Debug("Sending request");
          yield return request.SendWebRequest();
          Logger.updater.Debug("Download finished");

          if (request.isNetworkError)
          {
            Logger.updater.Error("Network error while trying to update mod");
            Logger.updater.Error(request.error);
            continue;
          }

          if (request.isHttpError)
          {
            Logger.updater.Error("Server returned an error code while trying to update mod");
            Logger.updater.Error(request.error);
            continue;
          }

          stream.Seek(0, SeekOrigin.Begin); // reset to beginning
          // extract off the main thread; the stream stays alive because we wait for the task below
          // ReSharper disable once AccessToDisposedClosure
          var extractTask = Task.Run(() => ExtractPluginAsync(stream, item, platformFile));
          while (!extractTask.IsCompleted) // IsCompleted also covers faulted and canceled tasks
            yield return null; // pause co-routine until task is done

          if (extractTask.IsFaulted)
          {
            if (extractTask.Exception != null && extractTask.Exception.InnerExceptions.Any(e => e is BeatmodsInterceptException))
            { // at least one failure is an intercept exception
              Logger.updater.Error($"BeatMods did not return expected data for {item.Name}");
            }

            Logger.updater.Error($"Error downloading mod {item.Name}");
            Logger.updater.Error(extractTask.Exception);
            continue;
          }

          succeeded = true;
          break;
        }
      }

      if (succeeded)
        Logger.updater.Debug("Download complete");
      else
        Logger.updater.Warn($"Plugin download failed {maxTries} times, not re-trying");
    }

    /// <summary>Download handler that appends every received chunk to a caller-owned <see cref="MemoryStream"/>.</summary>
    internal class StreamDownloadHandler : DownloadHandlerScript
    {
      public MemoryStream Stream { get; set; }

      public StreamDownloadHandler(MemoryStream stream)
      {
        Stream = stream;
      }

      /// <summary>Pre-sizes the stream. It only ever grows, because a capacity below Length throws.</summary>
      protected override void ReceiveContentLength(int contentLength)
      {
        if (contentLength > Stream.Capacity)
          Stream.Capacity = contentLength;
        Logger.updater.Debug($"Got content length: {contentLength}");
      }

      protected override void CompleteContent()
      {
        Logger.updater.Debug("Download complete");
      }

      protected override bool ReceiveData(byte[] rData, int dataLength)
      {
        if (rData == null || rData.Length < 1)
        {
          Logger.updater.Debug("CustomWebRequest :: ReceiveData - received a null/empty buffer");
          return false;
        }

        Stream.Write(rData, 0, dataLength);
        return true;
      }

      /// <summary>Data is written to <see cref="Stream"/>, so nothing is buffered here.</summary>
      protected override byte[] GetData() { return null; }

      protected override float GetProgress()
      {
        return 0f;
      }

      public override string ToString()
      {
        return $"{base.ToString()} ({Stream})";
      }
    }

    /// <summary>
    /// Extracts a downloaded archive into a scratch folder and checks every file against the MD5
    /// hashes in <paramref name="fileInfo"/>. It then stages the result (3.3).
    /// </summary>
    /// <remarks>
    /// When updating IPA itself, the files are copied into the install directory and the
    /// installed plugin file is launched with <c>-nw=&lt;pid&gt;</c>. Otherwise the files are merged
    /// into <c>IPA/Pending</c>, together with the <see cref="SpecialDeletionsFile"/> list when the old
    /// plugin file is not overwritten. The scratch folder is removed in all cases.
    /// </remarks>
    /// <exception cref="BeatmodsInterceptException">Hashes are missing, or an entry would land outside the scratch folder.</exception>
    /// <exception cref="InvalidDataException">A file does not match its published hash.</exception>
    private void ExtractPluginAsync(MemoryStream stream, DependencyObject item, ApiEndpoint.Mod.DownloadsObject fileInfo)
    {
      Logger.updater.Debug($"Extracting ZIP file for {item.Name}");

      var targetDir = Path.Combine(BeatSaber.InstallPath, "IPA", Path.GetRandomFileName() + "_Pending");
      Directory.CreateDirectory(targetDir);
      var eventualOutput = Path.Combine(BeatSaber.InstallPath, "IPA", "Pending");
      Directory.CreateDirectory(eventualOutput); // no-op when it already exists

      try
      {
        var isSelf = (item.LocalPluginMeta?.Metadata.IsSelf).Unwrap();
        var oldFile = item.LocalPluginMeta?.Metadata.File;
        // install-relative path of the currently installed plugin file; null when none is installed
        var oldRelativePath = oldFile == null ? null : Utils.GetRelativePath(oldFile.FullName, BeatSaber.InstallPath);
        var shouldDeleteOldFile = !isSelf;

        using (var zipFile = ZipFile.Read(stream))
        using (var md5 = MD5.Create())
        {
          Logger.updater.Debug("Streams opened");
          foreach (var entry in zipFile)
          {
            var entryPath = GetSafeExtractionPath(targetDir, entry.FileName);
            if (entry.IsDirectory)
            {
              Logger.updater.Debug($"Creating directory {entry.FileName}");
              Directory.CreateDirectory(entryPath);
              continue;
            }

            using (var ostream = new MemoryStream((int)entry.UncompressedSize))
            {
              entry.Extract(ostream);
              ostream.Seek(0, SeekOrigin.Begin);
              VerifyEntryHash(md5.ComputeHash(ostream), entry.FileName, fileInfo);
              ostream.Seek(0, SeekOrigin.Begin);

              var targetFile = new FileInfo(entryPath);
              Directory.CreateDirectory(targetFile.DirectoryName ?? throw new InvalidOperationException());
              if (oldRelativePath != null && Utils.GetRelativePath(targetFile.FullName, targetDir) == oldRelativePath)
                shouldDeleteOldFile = false; // overwriting old file, no need to delete

              Logger.updater.Debug($"Extracting file {targetFile.FullName}");
              targetFile.Delete();
              using (var fstream = targetFile.Create())
                ostream.CopyTo(fstream);
            }
          }
        }

        if (shouldDeleteOldFile && oldRelativePath != null)
          File.AppendAllLines(Path.Combine(targetDir, SpecialDeletionsFile), new[] { oldRelativePath });

        if (isSelf)
        { // currently updating self, so copy to working dir and update
          Utils.CopyAll(new DirectoryInfo(targetDir), new DirectoryInfo(BeatSaber.InstallPath));
          var deleteFile = Path.Combine(BeatSaber.InstallPath, SpecialDeletionsFile);
          if (File.Exists(deleteFile)) File.Delete(deleteFile);
          Process.Start(new ProcessStartInfo
          {
            // will never actually be null
            FileName = oldFile?.FullName ?? throw new InvalidOperationException(),
            Arguments = $"-nw={Process.GetCurrentProcess().Id}",
            UseShellExecute = false
          });
        }
        else
        {
          Utils.CopyAll(new DirectoryInfo(targetDir), new DirectoryInfo(eventualOutput), SpecialDeletionsFile);
        }
      }
      finally
      {
        Directory.Delete(targetDir, true); // delete extraction site, on success or failure
      }

      Logger.updater.Debug("Extractor exited");
    }

    /// <summary>
    /// Throws unless <paramref name="actualHash"/> equals the hash that <paramref name="fileInfo"/>
    /// lists for <paramref name="entryName"/>.
    /// </summary>
    /// <exception cref="BeatmodsInterceptException">No hash is listed for the entry.</exception>
    /// <exception cref="InvalidDataException">The hashes differ.</exception>
    private static void VerifyEntryHash(byte[] actualHash, string entryName, ApiEndpoint.Mod.DownloadsObject fileInfo)
    {
      // FirstOrDefault (not First) so that a missing hash is reported as an intercept exception
      var expectedHash = fileInfo.Hashes?
        .Where(h => h.File == entryName)
        .Select(h => h.Hash)
        .FirstOrDefault();
      if (expectedHash == null)
        throw new BeatmodsInterceptException("BeatMods did not send the hashes for the zip's content!");
      if (!Utils.UnsafeCompare(actualHash, expectedHash))
        throw new InvalidDataException("The hash for the file doesn't match what is defined");
    }

    /// <summary>
    /// Resolves an archive entry name to a full path inside <paramref name="rootDir"/>. Names that
    /// would escape it (<c>..</c> segments, rooted paths) are rejected; archives come from the network.
    /// </summary>
    /// <exception cref="BeatmodsInterceptException">The entry would be written outside <paramref name="rootDir"/>.</exception>
    private static string GetSafeExtractionPath(string rootDir, string entryName)
    {
      var root = Path.GetFullPath(rootDir);
      if (!root.EndsWith(Path.DirectorySeparatorChar.ToString(), StringComparison.Ordinal))
        root += Path.DirectorySeparatorChar;

      var full = Path.GetFullPath(Path.Combine(root, entryName));
      if (!full.StartsWith(root, StringComparison.Ordinal))
        throw new BeatmodsInterceptException($"Archive entry '{entryName}' would be extracted outside of {rootDir}");
      return full;
    }

    /// <summary>
    /// Name of the file into which the install-relative path of a replaced plugin file is appended.
    /// It is passed to <c>Utils.CopyAll</c> when staging; presumably it is consumed when pending
    /// updates are applied — TODO confirm.
    /// </summary>
    internal const string SpecialDeletionsFile = "$$delete";
  }
  /// <summary>
  /// Raised by the updater when a BeatMods request fails at the network level or the server
  /// answers with an HTTP error status.
  /// </summary>
  [Serializable]
  internal class NetworkException : Exception
  {
    /// <summary>Creates the exception with a message and the failure that caused it.</summary>
    /// <param name="message">Description of the failure.</param>
    /// <param name="innerException">Underlying cause.</param>
    public NetworkException(string message, Exception innerException) : base(message, innerException) { }

    /// <summary>Creates the exception with a message.</summary>
    /// <param name="message">Description of the failure.</param>
    public NetworkException(string message) : base(message) { }

    /// <summary>Creates the exception with the default message.</summary>
    public NetworkException() : base() { }

    /// <summary>Deserialization constructor used by the runtime serializer.</summary>
    protected NetworkException(SerializationInfo info, StreamingContext context) : base(info, context) { }
  }
  /// <summary>
  /// Raised when data received from BeatMods is not what the updater expects, for example
  /// when the file hashes for a downloaded archive are missing.
  /// </summary>
  [Serializable]
  internal class BeatmodsInterceptException : Exception
  {
    /// <summary>Creates the exception with a message and the failure that caused it.</summary>
    /// <param name="message">Description of the unexpected data.</param>
    /// <param name="innerException">Underlying cause.</param>
    public BeatmodsInterceptException(string message, Exception innerException) : base(message, innerException) { }

    /// <summary>Creates the exception with a message.</summary>
    /// <param name="message">Description of the unexpected data.</param>
    public BeatmodsInterceptException(string message) : base(message) { }

    /// <summary>Creates the exception with the default message.</summary>
    public BeatmodsInterceptException() : base() { }

    /// <summary>Deserialization constructor used by the runtime serializer.</summary>
    protected BeatmodsInterceptException(SerializationInfo info, StreamingContext context) : base(info, context) { }
  }
  530. }