You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

635 lines
25 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Diagnostics.CodeAnalysis;
  6. using System.IO;
  7. using System.Linq;
  8. using System.Runtime.Serialization;
  9. using System.Security.Cryptography;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. using Ionic.Zip;
  13. using IPA.Loader;
  14. using IPA.Loader.Features;
  15. using IPA.Utilities;
  16. using Newtonsoft.Json;
  17. using SemVer;
  18. using UnityEngine;
  19. using UnityEngine.Networking;
  20. using static IPA.Loader.PluginManager;
  21. using Logger = IPA.Logging.Logger;
  22. using Version = SemVer.Version;
  23. namespace IPA.Updating.BeatMods
  24. {
  25. [SuppressMessage("ReSharper", "ClassNeverInstantiated.Global")]
  26. internal class Updater : MonoBehaviour
  27. {
  /// <summary>
  /// The updater component that claimed the singleton slot; set by the first instance to wake.
  /// </summary>
  public static Updater Instance;

  /// <summary>
  /// Unity lifecycle hook. The first instance registers itself as <see cref="Instance"/> and starts the
  /// update check; any later instance destroys this component instead. Exceptions are logged, not rethrown.
  /// </summary>
  public void Awake()
  {
      try
      {
          // Unity's overloaded == is used on purpose (a destroyed Updater compares equal to null)
          if (Instance == null)
          {
              Instance = this;
              CheckForUpdates();
              return;
          }

          // an updater already exists; this duplicate component removes itself
          Destroy(this);
      }
      catch (Exception e)
      {
          Logger.updater.Error(e);
      }
  }
  /// <summary>Starts the update check as a coroutine running on this component.</summary>
  private void CheckForUpdates() => StartCoroutine(CheckForUpdatesCoroutine());
  /// <summary>
  /// Working state for one mod while dependencies are being resolved: what is installed, what is
  /// required, what was resolved on BeatMods, and who depends on it.
  /// </summary>
  private class DependencyObject
  {
      /// <summary>Mod name / plugin ID.</summary>
      public string Name { get; set; }
      /// <summary>Locally installed version, or <c>null</c> when the mod is not installed.</summary>
      public Version Version { get; set; }
      /// <summary>Version picked from BeatMods, valid when <see cref="Resolved"/> is set.</summary>
      public Version ResolvedVersion { get; set; }
      /// <summary>Range of acceptable versions.</summary>
      public Range Requirement { get; set; }
      /// <summary>A range of versions that are not allowed to be downloaded.</summary>
      public Range Conflicts { get; set; }
      public bool Resolved { get; set; }
      public bool Has { get; set; }
      /// <summary>Names of the mods that declared this one as a dependency.</summary>
      public HashSet<string> Consumers { get; set; } = new HashSet<string>();
      public bool MetaRequestFailed { get; set; }
      public PluginLoader.PluginInfo LocalPluginMeta { get; set; }

      /// <summary>Debug rendering: <c>name@installed[ -> resolved] - (requirement ! conflicts) [ Already have]</c>.</summary>
      public override string ToString()
      {
          var resolvedPart = Resolved ? $" -> {ResolvedVersion}" : "";
          var havePart = Has ? " Already have" : "";
          return $"{Name}@{Version}{resolvedPart} - ({Requirement} ! {Conflicts}) {havePart}";
      }
  }
  // Response bodies keyed by the relative URL; only successful responses are stored.
  private readonly Dictionary<string, string> requestCache = new Dictionary<string, string>();

  /// <summary>
  /// Coroutine that GETs <c>ApiEndpoint.ApiBase + url</c> and stores the response body in
  /// <paramref name="result"/>. Successful bodies are cached per <paramref name="url"/> on this component.
  /// </summary>
  /// <param name="url">Endpoint path relative to the API base.</param>
  /// <param name="result">Receives the body in <c>Value</c>, or a <see cref="NetworkException"/> in <c>Error</c>.</param>
  private IEnumerator GetBeatModsEndpoint(string url, Ref<string> result)
  {
      if (requestCache.TryGetValue(url, out var cached))
      {
          result.Value = cached;
          yield break;
      }

      using (var www = UnityWebRequest.Get(ApiEndpoint.ApiBase + url))
      {
          yield return www.SendWebRequest();

          if (www.isNetworkError)
          {
              result.Error = new NetworkException($"Network error while trying to download: {www.error}");
              yield break;
          }

          if (www.isHttpError)
          {
              result.Error = www.responseCode == 404
                  ? new NetworkException("Not found")
                  : new NetworkException($"Server returned error {www.error} while getting data");
              yield break;
          }

          result.Value = www.downloadHandler.text;
          requestCache[url] = result.Value;
      }
  }
  // Decoded mod info keyed by the request URI (name + version).
  private readonly Dictionary<string, ApiEndpoint.Mod> modCache = new Dictionary<string, ApiEndpoint.Mod>();

  /// <summary>
  /// Coroutine that fetches BeatMods info for <paramref name="modName"/> at version <paramref name="ver"/>
  /// and stores the first returned entry in <paramref name="result"/>. Decoded results are cached per URI.
  /// </summary>
  /// <param name="modName">Mod name to query.</param>
  /// <param name="ver">Version to query; callers pass <c>""</c> for no specific version.</param>
  /// <param name="result">
  /// Receives the mod in <c>Value</c>. On a transport failure <c>Error</c> is the request's own error;
  /// on a decoding failure (including an empty result list) it wraps the cause as "Error decoding response".
  /// </param>
  private IEnumerator GetModInfo(string modName, string ver, Ref<ApiEndpoint.Mod> result)
  {
      // EscapeDataString also escapes reserved characters such as '&', '+', '#' and '?',
      // which EscapeUriString leaves intact (e.g. a '+' in build metadata would become a space).
      var uri = string.Format(ApiEndpoint.GetModInfoEndpoint, Uri.EscapeDataString(modName), Uri.EscapeDataString(ver));
      if (modCache.TryGetValue(uri, out var cached))
      {
          result.Value = cached;
          yield break;
      }

      var reqResult = new Ref<string>("");
      yield return GetBeatModsEndpoint(uri, reqResult);
      if (reqResult.Error != null)
      { // surface the network failure itself instead of a misleading decode error
          result.Error = reqResult.Error;
          yield break;
      }

      try
      {
          result.Value = JsonConvert.DeserializeObject<List<ApiEndpoint.Mod>>(reqResult.Value).First();
          modCache[uri] = result.Value;
      }
      catch (Exception e)
      {
          result.Error = new Exception("Error decoding response", e);
      }
  }
  // ALL published versions of a mod, keyed by request URI. The URI does not include the version
  // range, so the unfiltered list is cached and the range filter is applied per call.
  private readonly Dictionary<string, List<ApiEndpoint.Mod>> modVersionsCache = new Dictionary<string, List<ApiEndpoint.Mod>>();

  /// <summary>
  /// Coroutine that fetches every BeatMods entry named <paramref name="modName"/> and stores, in
  /// <paramref name="result"/>, a new list containing only the entries whose version satisfies
  /// <paramref name="range"/>.
  /// </summary>
  /// <param name="modName">Mod name to query.</param>
  /// <param name="range">Version range the returned entries must satisfy.</param>
  /// <param name="result">
  /// Receives the filtered list in <c>Value</c>. On a transport failure <c>Error</c> is the request's own error;
  /// on a decoding failure it wraps the cause as "Error decoding response".
  /// </param>
  private IEnumerator GetModVersionsMatching(string modName, Range range, Ref<List<ApiEndpoint.Mod>> result)
  {
      var uri = string.Format(ApiEndpoint.GetModsByName, Uri.EscapeDataString(modName));
      if (!modVersionsCache.TryGetValue(uri, out var allVersions))
      {
          var reqResult = new Ref<string>("");
          yield return GetBeatModsEndpoint(uri, reqResult);
          if (reqResult.Error != null)
          { // surface the network failure itself instead of a misleading decode error
              result.Error = reqResult.Error;
              yield break;
          }

          try
          {
              allVersions = JsonConvert.DeserializeObject<List<ApiEndpoint.Mod>>(reqResult.Value)
                  ?? throw new InvalidDataException("BeatMods returned no mod list");
          }
          catch (Exception e)
          {
              result.Error = new Exception("Error decoding response", e);
              yield break;
          }

          modVersionsCache[uri] = allVersions;
      }

      try
      {
          // a fresh list per call, so callers never mutate the cached one
          result.Value = allVersions
              .Where(m => m != null && range.IsSatisfied(m.Version))
              .ToList();
      }
      catch (Exception e)
      {
          result.Error = new Exception("Error decoding response", e);
      }
  }
  /// <summary>
  /// Entry coroutine for the update check. It seeds the dependency list with every loaded or ignored plugin
  /// that has an ID and has not opted out via <see cref="NoUpdateFeature"/>. Each seed requires a version
  /// strictly newer than the installed one. It then runs the three resolution passes; the final pass starts
  /// the download coroutines.
  /// </summary>
  private IEnumerator CheckForUpdatesCoroutine()
  {
      var depList = new Ref<List<DependencyObject>>(new List<DependencyObject>());

      // initialize with data to resolve (1.1).
      // NoUpdateFeature marks an opt-out, so plugins that carry it must be EXCLUDED.
      // The previous `FirstOrDefault(...) != null` selected only the opted-out plugins.
      foreach (var plugin in BSMetas
          .Where(m => m.Metadata.Id != null) // updatable
          .Where(m => !m.Metadata.Features.Any(f => f is NoUpdateFeature)))
      {
          var msinfo = plugin.Metadata;
          depList.Value.Add(new DependencyObject
          {
              Name = msinfo.Id,
              Version = msinfo.Version,
              Requirement = new Range($">{msinfo.Version}"),
              LocalPluginMeta = plugin
          });
      }

      foreach (var meta in PluginLoader.ignoredPlugins
          .Where(m => m.Id != null)
          .Where(m => !m.Features.Any(f => f is NoUpdateFeature)))
      {
          depList.Value.Add(new DependencyObject
          {
              Name = meta.Id,
              Version = meta.Version,
              Requirement = new Range($">{meta.Version}"),
              LocalPluginMeta = new PluginLoader.PluginInfo
              {
                  Metadata = meta, Plugin = null
              }
          });
      }

      foreach (var dep in depList.Value)
          Logger.updater.Debug($"Phantom Dependency: {dep}");

      yield return DependencyResolveFirstPass(depList);
      foreach (var dep in depList.Value)
          Logger.updater.Debug($"Dependency: {dep}");

      yield return DependencyResolveSecondPass(depList);
      foreach (var dep in depList.Value)
          Logger.updater.Debug($"Dependency: {dep}");

      DependendyResolveFinalPass(depList);
  }
  /// <summary>
  /// First resolution pass. It fetches BeatMods info for each entry and appends that mod's dependencies to
  /// the list, so transitive dependencies are discovered too (1.2). It then collapses entries that share a name
  /// into the first one seen: requirements are intersected, consumers and conflicts are merged (1.3).
  /// </summary>
  /// <param name="list">Entries to expand; its <c>Value</c> is replaced with the merged list.</param>
  private IEnumerator DependencyResolveFirstPass(Ref<List<DependencyObject>> list)
  {
      // Each name is expanded at most once. Without this, a dependency cycle (A -> B -> A) keeps
      // appending entries forever, and shared dependencies are re-expanded once per consumer.
      var expanded = new HashSet<string>();
      for (int i = 0; i < list.Value.Count; i++)
      { // Grab dependencies (1.2); Count is re-read because the list grows inside this loop
          var dep = list.Value[i];
          if (!expanded.Add(dep.Name))
              continue; // duplicate entry: still merged below, but this name was already expanded

          var mod = new Ref<ApiEndpoint.Mod>(null);
          yield return GetModInfo(dep.Name, "", mod);
          try { mod.Verify(); }
          catch (Exception e)
          {
              Logger.updater.Error($"Error getting info for {dep.Name}");
              Logger.updater.Error(e);
              dep.MetaRequestFailed = true;
              continue;
          }

          list.Value.AddRange(mod.Value.Dependencies.Select(m => new DependencyObject
          {
              Name = m.Name,
              Requirement = new Range($"^{m.Version}"),
              Consumers = new HashSet<string> { dep.Name }
          }));
          // currently no conflicts exist in BeatMods
          //list.Value.AddRange(mod.Value.Links.Dependencies.Select(d => new DependencyObject { Name = d.Name, Requirement = d.VersionRange, Consumers = new HashSet<string> { dep.Name } }));
          //list.Value.AddRange(mod.Value.Links.Conflicts.Select(d => new DependencyObject { Name = d.Name, Conflicts = d.VersionRange, Consumers = new HashSet<string> { dep.Name } }));
      }

      var depNames = new HashSet<string>();
      var final = new List<DependencyObject>();
      foreach (var dep in list.Value)
      { // aggregate ranges and the like (1.3)
          if (depNames.Add(dep.Name))
          { // first occurrence of this name: it is the entry that survives
              final.Add(dep);
              continue;
          }

          var toMod = final.First(d => d.Name == dep.Name);
          if (dep.Requirement != null)
          {
              // an entry that so far only carried conflicts has no requirement to intersect with
              toMod.Requirement = toMod.Requirement == null
                  ? dep.Requirement
                  : toMod.Requirement.Intersect(dep.Requirement);
              foreach (var consumer in dep.Consumers)
                  toMod.Consumers.Add(consumer);
          }
          else if (dep.Conflicts != null)
          {
              toMod.Conflicts = toMod.Conflicts == null
                  ? dep.Conflicts
                  : new Range($"{toMod.Conflicts} || {dep.Conflicts}");
          }
      }
      list.Value = final;
  }
  /// <summary>
  /// Second resolution pass (2.1, 2.2). For each entry it lists the BeatMods versions that satisfy the
  /// requirement and picks the highest one that is non-null, approved and outside the conflict range.
  /// That version becomes <c>ResolvedVersion</c>. Afterwards <c>Has</c> is true only when the entry resolved
  /// AND its installed version equals the resolved one. Entries whose info request failed are skipped and
  /// keep <c>Has</c> = "a local version is known".
  /// </summary>
  private IEnumerator DependencyResolveSecondPass(Ref<List<DependencyObject>> list)
  {
      foreach (var entry in list.Value)
      {
          entry.Has = entry.Version != null; // Version is only non-null for already installed mods
          if (entry.MetaRequestFailed)
          {
              Logger.updater.Warn($"{entry.Name} info request failed, not trying again");
              continue;
          }

          var candidates = new Ref<List<ApiEndpoint.Mod>>(null);
          yield return GetModVersionsMatching(entry.Name, entry.Requirement, candidates);
          try
          {
              candidates.Verify();
          }
          catch (Exception e)
          {
              Logger.updater.Error($"Error getting mod list for {entry.Name}");
              Logger.updater.Error(e);
              entry.MetaRequestFailed = true;
              continue;
          }

          // (2.1) highest non-null, approved, non-conflicting version
          // (a game-version check used to live here and is intentionally disabled)
          var best = candidates.Value
              .Where(m => m != null
                          && m.Status == ApiEndpoint.Mod.ApprovedStatus
                          && (entry.Conflicts == null || !entry.Conflicts.IsSatisfied(m.Version)))
              .Select(m => m.Version)
              .Max();

          // (2.2)
          entry.Resolved = best != null;
          if (entry.Resolved)
              entry.ResolvedVersion = best;
          entry.Has = entry.Version == entry.ResolvedVersion && entry.Resolved;
      }
  }
  /// <summary>
  /// Final pass (3.1). It logs the outcome of every entry and warns about entries that are neither resolved
  /// nor present. For each resolved entry whose <c>Has</c> flag is false it starts one
  /// <see cref="UpdateModCoroutine"/>.
  /// </summary>
  private void DependendyResolveFinalPass(Ref<List<DependencyObject>> list)
  { // also starts download of mods
      var pending = new List<DependencyObject>();
      foreach (var entry in list.Value)
      {
          if (!entry.Resolved)
          {
              if (!entry.Has)
                  Logger.updater.Warn($"Could not resolve dependency {entry}");
              continue;
          }

          Logger.updater.Debug($"Resolved: {entry}");
          if (entry.Has)
              continue;

          Logger.updater.Debug($"To Download: {entry}");
          pending.Add(entry);
      }

      var summary = string.Join(", ", pending.Select(d => $"{d.Name}@{d.ResolvedVersion}"));
      Logger.updater.Debug($"To Download {summary}");
      pending.ForEach(d => StartCoroutine(UpdateModCoroutine(d)));
  }
  /// <summary>
  /// Downloads and stages one resolved mod (3.2). It fetches the metadata for the exact resolved version
  /// and picks the universal download or the one for the current release type. It then downloads and
  /// extracts it, trying up to three times. All failures are logged; nothing is thrown to the caller.
  /// </summary>
  /// <param name="item">An entry with <c>ResolvedVersion</c> set.</param>
  private IEnumerator UpdateModCoroutine(DependencyObject item)
  { // (3.2)
      Logger.updater.Debug($"Release: {BeatSaber.ReleaseType}");
      var mod = new Ref<ApiEndpoint.Mod>(null);
      yield return GetModInfo(item.Name, item.ResolvedVersion.ToString(), mod);
      try { mod.Verify(); }
      catch (Exception e)
      {
          Logger.updater.Error($"Error occurred while trying to get information for {item}");
          Logger.updater.Error(e);
          yield break;
      }

      var releaseName = BeatSaber.ReleaseType == BeatSaber.Release.Steam
          ? ApiEndpoint.Mod.DownloadsObject.TypeSteam : ApiEndpoint.Mod.DownloadsObject.TypeOculus;
      // FirstOrDefault: a mod without a matching download used to throw out of the coroutine
      var platformFile = mod.Value.Downloads?.FirstOrDefault(f =>
          f.Type == ApiEndpoint.Mod.DownloadsObject.TypeUniversal || f.Type == releaseName);
      if (platformFile == null)
      {
          Logger.updater.Error($"No {releaseName} or universal download is available for {item}");
          yield break;
      }

      string url = ApiEndpoint.BeatModBase + platformFile.Path;
      Logger.updater.Debug($"URL = {url}");

      const int maxTries = 3;
      // Success is tracked explicitly. The old countdown hit 0 on the last attempt and then
      // reported a failure even when that attempt had succeeded.
      var succeeded = false;
      for (var attempt = 1; attempt <= maxTries && !succeeded; attempt++)
      {
          if (attempt > 1)
              Logger.updater.Debug("Re-trying download...");

          using (var stream = new MemoryStream())
          using (var request = UnityWebRequest.Get(url))
          {
              request.downloadHandler = new StreamDownloadHandler(stream);
              Logger.updater.Debug("Sending request");
              yield return request.SendWebRequest();
              Logger.updater.Debug("Download finished");

              if (request.isNetworkError)
              {
                  Logger.updater.Error("Network error while trying to update mod");
                  Logger.updater.Error(request.error);
                  continue;
              }
              if (request.isHttpError)
              {
                  Logger.updater.Error("Server returned an error code while trying to update mod");
                  Logger.updater.Error(request.error);
                  continue;
              }

              stream.Seek(0, SeekOrigin.Begin); // reset to beginning
              // extraction runs on the thread pool; this coroutine polls the task once per frame
              var downloadTask = Task.Run(() =>
              {
                  // ReSharper disable once AccessToDisposedClosure
                  ExtractPluginAsync(stream, item, platformFile);
              });
              while (!downloadTask.IsCompleted) // also true once the task is faulted or canceled
                  yield return null;

              if (downloadTask.IsFaulted)
              {
                  if (downloadTask.Exception != null && downloadTask.Exception.InnerExceptions.Any(ex => ex is BeatmodsInterceptException))
                      Logger.updater.Error($"BeatMods did not return expected data for {item.Name}");
                  Logger.updater.Error($"Error downloading mod {item.Name}");
                  Logger.updater.Error(downloadTask.Exception);
                  continue;
              }

              succeeded = true;
          }
      }

      if (succeeded)
          Logger.updater.Debug("Download complete");
      else
          Logger.updater.Warn($"Plugin download failed {maxTries} times, not re-trying");
  }
  /// <summary>
  /// Download handler that writes every received chunk into a caller-owned <see cref="MemoryStream"/>
  /// rather than buffering inside Unity. It reports no data through <c>GetData</c> and always reports 0 progress.
  /// </summary>
  internal class StreamDownloadHandler : DownloadHandlerScript
  {
      /// <summary>Destination for the downloaded bytes.</summary>
      public MemoryStream Stream { get; set; }

      public StreamDownloadHandler(MemoryStream stream) => Stream = stream;

      protected override void ReceiveContentLength(int contentLength)
      {
          // pre-size the buffer to the announced length
          Stream.Capacity = contentLength;
          Logger.updater.Debug($"Got content length: {contentLength}");
      }

      protected override void CompleteContent() => Logger.updater.Debug("Download complete");

      protected override bool ReceiveData(byte[] rData, int dataLength)
      {
          var hasData = rData != null && rData.Length >= 1;
          if (hasData)
          {
              Stream.Write(rData, 0, dataLength);
              return true;
          }

          // returning false tells Unity to stop the download
          Logger.updater.Debug("CustomWebRequest :: ReceiveData - received a null/empty buffer");
          return false;
      }

      protected override byte[] GetData() => null;

      protected override float GetProgress() => 0f;

      public override string ToString() => $"{base.ToString()} ({Stream})";
  }
  /// <summary>
  /// Extracts a downloaded mod archive and stages it for installation (3.3).
  /// </summary>
  /// <remarks>
  /// Each entry is checked against the MD5 hash listed for it in <paramref name="fileInfo"/> before it is
  /// written. Files go into a random <c>IPA/*_Pending</c> folder, which is always deleted afterwards, even
  /// on failure. When the mod is IPA itself, the files are copied straight into the install directory and the
  /// updated executable is started with <c>-nw=&lt;current pid&gt;</c>. Otherwise they are copied into
  /// <c>IPA/Pending</c>. If the archive does not overwrite the old plugin file, that file's path relative to
  /// the install directory is recorded in <see cref="SpecialDeletionsFile"/>.
  /// </remarks>
  /// <exception cref="BeatmodsInterceptException">
  /// No hash was sent for an entry, or an entry would be extracted outside the staging folder.
  /// </exception>
  /// <exception cref="InvalidDataException">An extracted file does not match its expected hash.</exception>
  private void ExtractPluginAsync(MemoryStream stream, DependencyObject item, ApiEndpoint.Mod.DownloadsObject fileInfo)
  { // (3.3)
      Logger.updater.Debug($"Extracting ZIP file for {item.Name}");
      var targetDir = Path.Combine(BeatSaber.InstallPath, "IPA", Path.GetRandomFileName() + "_Pending");
      Directory.CreateDirectory(targetDir);
      var eventualOutput = Path.Combine(BeatSaber.InstallPath, "IPA", "Pending");
      Directory.CreateDirectory(eventualOutput); // no-op when it already exists

      var isSelf = (item.LocalPluginMeta?.Metadata.IsSelf).Unwrap();
      try
      {
          // the old plugin file is scheduled for deletion unless the archive overwrites it (never for IPA itself)
          var shouldDeleteOldFile = !isSelf;
          using (var zipFile = ZipFile.Read(stream))
          using (var md5 = MD5.Create()) // one disposable hasher for all entries
          {
              Logger.updater.Debug("Streams opened");
              foreach (var entry in zipFile)
              {
                  var entryPath = GetContainedPath(targetDir, entry.FileName);
                  if (entry.IsDirectory)
                  {
                      Logger.updater.Debug($"Creating directory {entry.FileName}");
                      Directory.CreateDirectory(entryPath);
                      continue;
                  }

                  using (var ostream = new MemoryStream((int)entry.UncompressedSize))
                  {
                      entry.Extract(ostream);
                      ostream.Seek(0, SeekOrigin.Begin);
                      var fileHash = md5.ComputeHash(ostream);

                      // A missing hash means BeatMods sent incomplete data. The old code caught
                      // KeyNotFoundException, which .First() never throws, so that path was dead.
                      var expectedHash = fileInfo.Hashes?
                          .Where(h => h.File == entry.FileName)
                          .Select(h => h.Hash)
                          .FirstOrDefault();
                      if (expectedHash == null)
                          throw new BeatmodsInterceptException("BeatMods did not send the hashes for the zip's content!");
                      if (!Utils.UnsafeCompare(fileHash, expectedHash))
                          throw new InvalidDataException("The hash for the file doesn't match what is defined");

                      ostream.Seek(0, SeekOrigin.Begin);
                      var targetFile = new FileInfo(entryPath);
                      Directory.CreateDirectory(targetFile.DirectoryName ?? throw new InvalidOperationException());
                      if (item.LocalPluginMeta != null &&
                          Utils.GetRelativePath(targetFile.FullName, targetDir) == Utils.GetRelativePath(item.LocalPluginMeta.Metadata.File.FullName, BeatSaber.InstallPath))
                          shouldDeleteOldFile = false; // overwriting old file, no need to delete

                      Logger.updater.Debug($"Extracting file {targetFile.FullName}");
                      targetFile.Delete();
                      using (var fstream = targetFile.Create())
                          ostream.CopyTo(fstream);
                  }
              }
          }

          if (shouldDeleteOldFile && item.LocalPluginMeta != null)
              File.AppendAllLines(Path.Combine(targetDir, SpecialDeletionsFile),
                  new[] { Utils.GetRelativePath(item.LocalPluginMeta.Metadata.File.FullName, BeatSaber.InstallPath) });

          if (isSelf)
          { // currently updating self, so copy to working dir and update
              Utils.CopyAll(new DirectoryInfo(targetDir), new DirectoryInfo(BeatSaber.InstallPath));
              var deleteFile = Path.Combine(BeatSaber.InstallPath, SpecialDeletionsFile);
              if (File.Exists(deleteFile)) File.Delete(deleteFile);
              Process.Start(new ProcessStartInfo
              {
                  // will never actually be null
                  FileName = item.LocalPluginMeta?.Metadata.File.FullName ?? throw new InvalidOperationException(),
                  Arguments = $"-nw={Process.GetCurrentProcess().Id}",
                  UseShellExecute = false
              });
          }
          else
              Utils.CopyAll(new DirectoryInfo(targetDir), new DirectoryInfo(eventualOutput), SpecialDeletionsFile);
      }
      finally
      {
          // delete the extraction site on success and on failure (a failed copy used to leak it)
          Directory.Delete(targetDir, true);
      }
      Logger.updater.Debug("Extractor exited");
  }

  /// <summary>
  /// Resolves archive entry <paramref name="entryName"/> against <paramref name="root"/>. Names that would
  /// resolve outside <paramref name="root"/> (<c>..</c> segments, rooted paths) are rejected ("zip slip").
  /// </summary>
  private static string GetContainedPath(string root, string entryName)
  {
      var rootFull = Path.GetFullPath(root);
      if (!rootFull.EndsWith(Path.DirectorySeparatorChar.ToString(), StringComparison.Ordinal))
          rootFull += Path.DirectorySeparatorChar;
      var fullPath = Path.GetFullPath(Path.Combine(rootFull, entryName));
      if (!fullPath.StartsWith(rootFull, StringComparison.Ordinal))
          throw new BeatmodsInterceptException($"Archive entry '{entryName}' would be extracted outside of {root}");
      return fullPath;
  }

  /// <summary>Name of the staged file listing old plugin files to delete, relative to the install directory.</summary>
  internal const string SpecialDeletionsFile = "$$delete";
  501. }
  /// <summary>
  /// Signals that a BeatMods API request failed at the network or HTTP level.
  /// </summary>
  [Serializable]
  internal class NetworkException : Exception
  {
      /// <summary>Creates the exception with the default message.</summary>
      public NetworkException() : base() { }

      /// <summary>Creates the exception with <paramref name="message"/>.</summary>
      public NetworkException(string message) : base(message) { }

      /// <summary>Creates the exception with <paramref name="message"/>, wrapping <paramref name="innerException"/>.</summary>
      public NetworkException(string message, Exception innerException) : base(message, innerException) { }

      /// <summary>Serialization constructor.</summary>
      protected NetworkException(SerializationInfo info, StreamingContext context) : base(info, context) { }
  }
  /// <summary>
  /// Signals that BeatMods returned data the updater did not expect (for example, missing file hashes).
  /// </summary>
  [Serializable]
  internal class BeatmodsInterceptException : Exception
  {
      /// <summary>Creates the exception with the default message.</summary>
      public BeatmodsInterceptException() : base() { }

      /// <summary>Creates the exception with <paramref name="message"/>.</summary>
      public BeatmodsInterceptException(string message) : base(message) { }

      /// <summary>Creates the exception with <paramref name="message"/>, wrapping <paramref name="innerException"/>.</summary>
      public BeatmodsInterceptException(string message, Exception innerException) : base(message, innerException) { }

      /// <summary>Serialization constructor.</summary>
      protected BeatmodsInterceptException(SerializationInfo info, StreamingContext context) : base(info, context) { }
  }
  534. }