You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

533 lines
20 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. using IPA.Utilities;
  2. using IPA.Loader;
  3. using Ionic.Zip;
  4. using Newtonsoft.Json;
  5. using System;
  6. using System.Collections;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using System.IO;
  10. using System.Linq;
  11. using System.Security.Cryptography;
  12. using System.Text;
  13. using System.Text.RegularExpressions;
  14. using System.Threading;
  15. using System.Threading.Tasks;
  16. using UnityEngine;
  17. using UnityEngine.Networking;
  18. using SemVer;
  19. using Logger = IPA.Logging.Logger;
  20. using Version = SemVer.Version;
  21. using IPA.Updating.Backup;
  22. using System.Runtime.Serialization;
  23. using System.Reflection;
  24. using static IPA.Loader.PluginManager;
  25. namespace IPA.Updating.ModsaberML
  26. {
  /// <summary>
  /// Checks ModSaber for updates to the installed plugins, resolves their dependencies against the
  /// running game version, and downloads, verifies and extracts whatever needs to be installed.
  /// </summary>
  /// <remarks>
  /// Phase numbers in the comments ((1.1), (2.1), (3.2), ...) mark the steps of that pipeline:
  /// 1.x gather and aggregate dependencies, 2.x choose a version for each, 3.x download and install.
  /// </remarks>
  class Updater : MonoBehaviour
  {
      /// <summary>The first <see cref="Updater"/> that woke up; later instances destroy themselves.</summary>
      public static Updater instance;

      /// <summary>
      /// Unity entry point. Registers this component as the singleton and starts the update check;
      /// a duplicate instance destroys itself instead. Exceptions are logged, not propagated.
      /// </summary>
      public void Awake()
      {
          try
          {
              if (instance != null)
              {
                  Destroy(this);
              }
              else
              {
                  instance = this;
                  CheckForUpdates();
              }
          }
          catch (Exception e)
          {
              Logger.updater.Error(e);
          }
      }

      /// <summary>Starts <see cref="CheckForUpdatesCoroutine"/> on this component.</summary>
      private void CheckForUpdates()
      {
          StartCoroutine(CheckForUpdatesCoroutine());
      }

      /// <summary>
      /// One node of the dependency graph being resolved. Entries seeded from installed plugins carry
      /// <see cref="Version"/> and <see cref="LocalPluginMeta"/>; entries discovered as dependencies
      /// only carry <see cref="Name"/>, <see cref="Requirement"/> and <see cref="Consumers"/>.
      /// </summary>
      private class DependencyObject
      {
          /// <summary>ModSaber internal name of the mod.</summary>
          public string Name { get; set; }
          /// <summary>Installed version; only non-null when the mod is already installed.</summary>
          public Version Version { get; set; } = null;
          /// <summary>Version chosen by the second pass; meaningful only when <see cref="Resolved"/> is true.</summary>
          public Version ResolvedVersion { get; set; } = null;
          /// <summary>Intersection of every version range requested for this mod.</summary>
          public Range Requirement { get; set; } = null;
          /// <summary>True when the second pass found a version satisfying <see cref="Requirement"/>.</summary>
          public bool Resolved { get; set; } = false;
          /// <summary>True when the installed version already equals the resolved version (nothing to download).</summary>
          public bool Has { get; set; } = false;
          /// <summary>Names of the mods that depend on this one.</summary>
          public HashSet<string> Consumers { get; set; } = new HashSet<string>();
          /// <summary>Set when fetching this mod's metadata failed during the first pass.</summary>
          public bool MetaRequestFailed { get; set; } = false;
          /// <summary>Metadata of the installed plugin, or <c>null</c> if it is not installed.</summary>
          public BSPluginMeta LocalPluginMeta { get; set; } = null;

          public override string ToString()
          {
              return $"{Name}@{Version}{(Resolved ? $" -> {ResolvedVersion}" : "")} - ({Requirement}) {(Has ? $" Already have" : "")}";
          }
      }

      /// <summary>Successfully decoded API responses, keyed by full request URI.</summary>
      private Dictionary<string, ApiEndpoint.Mod> requestCache = new Dictionary<string, ApiEndpoint.Mod>();

      /// <summary>
      /// Fetches the approved-mod record for <paramref name="name"/> at <paramref name="ver"/>, or returns
      /// it from <see cref="requestCache"/> if it was already downloaded. Callers pass an empty
      /// <paramref name="ver"/> to get the latest version.
      /// </summary>
      /// <param name="result">
      /// Receives the decoded record in <c>Value</c>, or a <see cref="NetworkException"/> / decoding
      /// exception in <c>Error</c>. Neither field is cleared first, so pass a fresh Ref per call.
      /// </param>
      private IEnumerator DownloadModInfo(string name, string ver, Ref<ApiEndpoint.Mod> result)
      {
          var uri = ApiEndpoint.ApiBase + string.Format(ApiEndpoint.GetApprovedEndpoint, name, ver);

          if (requestCache.TryGetValue(uri, out ApiEndpoint.Mod cached))
          {
              result.Value = cached;
              yield break;
          }

          using (var request = UnityWebRequest.Get(uri))
          {
              yield return request.SendWebRequest();

              if (request.isNetworkError)
              {
                  result.Error = new NetworkException($"Network error while trying to download: {request.error}");
                  yield break;
              }
              if (request.isHttpError)
              {
                  result.Error = request.responseCode == 404
                      ? new NetworkException("Not found")
                      : new NetworkException($"Server returned error {request.error} while getting data");
                  yield break;
              }

              try
              {
                  result.Value = JsonConvert.DeserializeObject<ApiEndpoint.Mod>(request.downloadHandler.text);
                  requestCache[uri] = result.Value;
              }
              catch (Exception e)
              {
                  result.Error = new Exception("Error decoding response", e);
              }
          }
      }

      /// <summary>
      /// Runs the whole update pipeline: seeds the dependency list with every installed plugin that has
      /// ModSaber info, runs both resolution passes, then starts downloads for whatever is missing.
      /// </summary>
      private IEnumerator CheckForUpdatesCoroutine()
      {
          var depList = new Ref<List<DependencyObject>>(new List<DependencyObject>());
          foreach (var plugin in BSMetas)
          { // initialize with data to resolve (1.1)
              var msinfo = plugin.ModsaberInfo;
              if (msinfo == null)
                  continue; // no ModSaber info: not updatable

              depList.Value.Add(new DependencyObject
              {
                  Name = msinfo.InternalName,
                  Version = new Version(msinfo.CurrentVersion),
                  // never resolve to something older than what is installed
                  Requirement = new Range($">={msinfo.CurrentVersion}"),
                  LocalPluginMeta = plugin
              });
          }

          foreach (var dep in depList.Value)
              Logger.updater.Debug($"Phantom Dependency: {dep.ToString()}");

          yield return DependencyResolveFirstPass(depList);

          foreach (var dep in depList.Value)
              Logger.updater.Debug($"Dependency: {dep.ToString()}");

          yield return DependencyResolveSecondPass(depList);

          foreach (var dep in depList.Value)
              Logger.updater.Debug($"Dependency: {dep.ToString()}");

          DependencyResolveFinalPass(depList);
      }

      /// <summary>
      /// Expands <paramref name="list"/> with the transitive dependencies reported by the API (1.2).
      /// It then collapses entries with the same name into one, intersecting their version ranges and
      /// merging their consumers (1.3). The first entry for a name is kept, so installed plugins (seeded
      /// first) keep their <see cref="DependencyObject.Version"/> and
      /// <see cref="DependencyObject.LocalPluginMeta"/>.
      /// </summary>
      private IEnumerator DependencyResolveFirstPass(Ref<List<DependencyObject>> list)
      {
          // Names whose dependencies have already been fetched. Expanding each name only once avoids
          // redundant work and stops dependency cycles (A -> B -> A) from growing the list forever.
          var expanded = new HashSet<string>();

          // Index loop on purpose: the list grows while it is being walked.
          for (int i = 0; i < list.Value.Count; i++)
          { // Grab dependencies (1.2)
              var dep = list.Value[i];
              if (!expanded.Add(dep.Name))
                  continue;

              var mod = new Ref<ApiEndpoint.Mod>(null);
              // TEMPORARY: fetches the latest version (original note: "SHOULD BE GREATEST OF VERSION")
              yield return DownloadModInfo(dep.Name, "", mod);

              try { mod.Verify(); }
              catch (Exception e)
              {
                  Logger.updater.Error($"Error getting info for {dep.Name}");
                  Logger.updater.Error(e);
                  dep.MetaRequestFailed = true;
                  continue;
              }

              list.Value.AddRange(mod.Value.Dependencies.Select(d => new DependencyObject
              {
                  Name = d.Name,
                  Requirement = d.VersionRange,
                  Consumers = new HashSet<string>() { dep.Name }
              }));
          }

          // aggregate ranges and the like (1.3)
          var byName = new Dictionary<string, DependencyObject>();
          var final = new List<DependencyObject>();
          foreach (var dep in list.Value)
          {
              if (byName.TryGetValue(dep.Name, out DependencyObject existing))
              {
                  existing.Requirement = existing.Requirement.Intersect(dep.Requirement);
                  foreach (var consumer in dep.Consumers)
                      existing.Consumers.Add(consumer);
              }
              else
              {
                  byName.Add(dep.Name, dep);
                  final.Add(dep);
              }
          }
          list.Value = final;
      }

      /// <summary>
      /// For each dependency, picks the highest version that satisfies its requirement and targets the
      /// running game version (2.1, 2.2). It then sets <see cref="DependencyObject.Resolved"/>,
      /// <see cref="DependencyObject.ResolvedVersion"/> and <see cref="DependencyObject.Has"/>.
      /// </summary>
      private IEnumerator DependencyResolveSecondPass(Ref<List<DependencyObject>> list)
      {
          // Fills map.Value with mod version -> game version for every known version of a mod (2.0).
          // On failure map.Value is null and map.Error is set.
          IEnumerator GetGameVersionMap(string modname, Ref<Dictionary<Version, Version>> map)
          {
              map.Value = new Dictionary<Version, Version>();

              var latest = new Ref<ApiEndpoint.Mod>(null);
              yield return DownloadModInfo(modname, "", latest);
              try { latest.Verify(); }
              catch (Exception)
              {
                  map.Value = null;
                  map.Error = new Exception($"Error getting info for {modname}", latest.Error);
                  yield break;
              }

              // Indexer instead of Add: a duplicate version in the server data must not abort the
              // whole update check with an ArgumentException.
              map.Value[latest.Value.Version] = latest.Value.GameVersion;

              foreach (var ver in latest.Value.OldVersions)
              {
                  // A fresh Ref per request: DownloadModInfo never clears Error, so a shared Ref
                  // could let one failed lookup make every later Verify() fail
                  // (assuming Verify checks Error).
                  var old = new Ref<ApiEndpoint.Mod>(null);
                  yield return DownloadModInfo(modname, ver.ToString(), old);
                  try { old.Verify(); }
                  catch (Exception e)
                  {
                      Logger.updater.Error($"Error getting info for {modname}v{ver}");
                      Logger.updater.Error(e);
                      continue;
                  }
                  map.Value[old.Value.Version] = old.Value.GameVersion;
              }
          }

          foreach (var dep in list.Value)
          {
              dep.Has = dep.Version != null; // dep.Version is only not null if its already installed
              if (dep.MetaRequestFailed)
              {
                  Logger.updater.Warn($"{dep.Name} info request failed, not trying again");
                  continue;
              }

              var dict = new Ref<Dictionary<Version, Version>>(null);
              yield return GetGameVersionMap(dep.Name, dict);
              try { dict.Verify(); }
              catch (Exception e)
              {
                  Logger.updater.Error($"Error getting map for {dep.Name}");
                  Logger.updater.Error(e);
                  continue;
              }

              var compatible = dict.Value
                  .Where(kvp => kvp.Value == BeatSaber.GameVersion)
                  .Select(kvp => kvp.Key);
              var ver = dep.Requirement.MaxSatisfying(compatible); // (2.1)

              dep.Resolved = ver != null; // (2.2)
              if (dep.Resolved)
                  dep.ResolvedVersion = ver;
              dep.Has = dep.Resolved && dep.Version == dep.ResolvedVersion;
          }
      }

      /// <summary>
      /// Selects the resolved dependencies that are not already installed (3.1). It then starts one
      /// <see cref="UpdateModCoroutine"/> per mod, all sharing a freshly created temp directory for
      /// backups.
      /// </summary>
      private void DependencyResolveFinalPass(Ref<List<DependencyObject>> list)
      { // also starts download of mods
          var toDl = new List<DependencyObject>();
          foreach (var dep in list.Value)
          { // figure out which ones need to be downloaded (3.1)
              if (dep.Resolved)
              {
                  Logger.updater.Debug($"Resolved: {dep.ToString()}");
                  if (!dep.Has)
                  {
                      Logger.updater.Debug($"To Download: {dep.ToString()}");
                      toDl.Add(dep);
                  }
              }
              else if (!dep.Has)
              {
                  Logger.updater.Warn($"Could not resolve dependency {dep}");
              }
          }

          Logger.updater.Debug($"To Download {string.Join(", ", toDl.Select(d => $"{d.Name}@{d.ResolvedVersion}"))}");

          if (toDl.Count == 0)
              return; // nothing to install; don't leave an empty temp directory behind

          string tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName() + Path.GetRandomFileName());
          Directory.CreateDirectory(tempDirectory);
          Logger.updater.Debug($"Temp directory: {tempDirectory}");

          foreach (var item in toDl)
              StartCoroutine(UpdateModCoroutine(item, tempDirectory));
      }

      /// <summary>
      /// Downloads the platform-appropriate archive for <paramref name="item"/> and extracts it on the
      /// thread pool (3.2). Both download and extraction are retried up to three times in total.
      /// </summary>
      private IEnumerator UpdateModCoroutine(DependencyObject item, string tempDirectory)
      { // (3.2)
          Logger.updater.Debug($"Release: {BeatSaber.ReleaseType}");

          var mod = new Ref<ApiEndpoint.Mod>(null);
          yield return DownloadModInfo(item.Name, item.ResolvedVersion.ToString(), mod);
          try { mod.Verify(); }
          catch (Exception e)
          {
              Logger.updater.Error($"Error occurred while trying to get information for {item}");
              Logger.updater.Error(e);
              yield break;
          }

          // Steam files are the fallback when no Oculus build is published.
          var platformFile = BeatSaber.ReleaseType == BeatSaber.Release.Steam || mod.Value.Files.Oculus == null
              ? mod.Value.Files.Steam
              : mod.Value.Files.Oculus;
          if (platformFile == null)
          {
              Logger.updater.Error($"No download available for {item}");
              yield break;
          }

          string url = platformFile.DownloadPath;
          Logger.updater.Debug($"URL = {url}");

          const int MaxTries = 3;
          // Tracked explicitly: deriving success from the remaining-tries counter misreported a success
          // on the final attempt as a failure.
          bool succeeded = false;
          for (int attempt = 1; attempt <= MaxTries && !succeeded; attempt++)
          {
              if (attempt > 1)
                  Logger.updater.Debug($"Re-trying download...");

              using (var stream = new MemoryStream())
              using (var request = UnityWebRequest.Get(url))
              {
                  request.downloadHandler = new StreamDownloadHandler(stream);
                  Logger.updater.Debug("Sending request");
                  yield return request.SendWebRequest();
                  Logger.updater.Debug("Download finished");

                  if (request.isNetworkError)
                  {
                      Logger.updater.Error("Network error while trying to update mod");
                      Logger.updater.Error(request.error);
                      continue;
                  }
                  if (request.isHttpError)
                  {
                      Logger.updater.Error($"Server returned an error code while trying to update mod");
                      Logger.updater.Error(request.error);
                      continue;
                  }

                  stream.Seek(0, SeekOrigin.Begin); // reset to beginning

                  // Hashing, unzipping and disk writes run on the thread pool; the coroutine just polls.
                  // The stream stays alive because this using block is not left until the task is done.
                  var extractTask = Task.Run(() => ExtractPluginAsync(stream, item, platformFile, tempDirectory));
                  while (!extractTask.IsCompleted) // also true when faulted or canceled
                      yield return null;

                  if (extractTask.IsFaulted)
                  {
                      Logger.updater.Error($"Error downloading mod {item.Name}");
                      Logger.updater.Error(extractTask.Exception);
                      continue;
                  }

                  succeeded = true;
              }
          }

          if (succeeded)
              Logger.updater.Debug("Download complete");
          else
              Logger.updater.Warn($"Plugin download failed {MaxTries} times, not re-trying");
      }

      /// <summary>Unity download handler that appends every received chunk to a <see cref="MemoryStream"/>.</summary>
      internal class StreamDownloadHandler : DownloadHandlerScript
      {
          /// <summary>Destination for the downloaded bytes.</summary>
          public MemoryStream Stream { get; set; }

          public StreamDownloadHandler(MemoryStream stream) : base()
          {
              Stream = stream;
          }

          protected override void ReceiveContentLength(int contentLength)
          {
              // Pre-size the buffer when the server announces a usable length; setting a negative
              // capacity would throw.
              if (contentLength > 0 && contentLength > Stream.Capacity)
                  Stream.Capacity = contentLength;
              Logger.updater.Debug($"Got content length: {contentLength}");
          }

          protected override void CompleteContent()
          {
              Logger.updater.Debug("Download complete");
          }

          protected override bool ReceiveData(byte[] data, int dataLength)
          {
              if (data == null || data.Length < 1)
              {
                  Logger.updater.Debug("CustomWebRequest :: ReceiveData - received a null/empty buffer");
                  return false; // returning false aborts the download
              }
              Stream.Write(data, 0, dataLength);
              return true;
          }

          protected override byte[] GetData() { return null; }

          protected override float GetProgress()
          {
              return 0f;
          }

          public override string ToString()
          {
              return $"{base.ToString()} ({Stream?.ToString()})";
          }
      }

      /// <summary>
      /// Resolves a ZIP entry name against the game directory (<see cref="Environment.CurrentDirectory"/>)
      /// and rejects names that would land outside it (e.g. "../" segments or rooted paths).
      /// The comparison is case-insensitive, matching Windows path semantics.
      /// </summary>
      /// <exception cref="IOException">The entry would be written outside the game directory.</exception>
      private static string GetExtractionPath(string entryName)
      {
          string baseDir = Path.GetFullPath(Environment.CurrentDirectory);
          string fullPath = Path.GetFullPath(Path.Combine(baseDir, entryName));
          string basePrefix = baseDir[baseDir.Length - 1] == Path.DirectorySeparatorChar
              ? baseDir
              : baseDir + Path.DirectorySeparatorChar;

          bool inside = fullPath.StartsWith(basePrefix, StringComparison.OrdinalIgnoreCase)
                     || string.Equals(fullPath, baseDir, StringComparison.OrdinalIgnoreCase);
          if (!inside)
              throw new IOException($"Refusing to extract '{entryName}' outside of {baseDir}");
          return fullPath;
      }

      /// <summary>
      /// Verifies the downloaded archive and each contained file against the hashes published by the API.
      /// It then extracts the files into the game directory (3.3). Overwritten files are backed up under
      /// <paramref name="tempDirectory"/> and restored if anything fails. Runs synchronously on the
      /// calling thread despite its name.
      /// </summary>
      /// <param name="stream">Downloaded archive, positioned at its start.</param>
      /// <exception cref="Exception">A hash does not match; the error is rethrown after a restore attempt.</exception>
      private void ExtractPluginAsync(MemoryStream stream, DependencyObject item, ApiEndpoint.Mod.PlatformFile fileInfo, string tempDirectory)
      { // (3.3)
          Logger.updater.Debug($"Extracting ZIP file for {item.Name}");

          byte[] hash;
          using (var sha = new SHA1CryptoServiceProvider())
          {
              // GetBuffer() returns the whole backing array, which can be longer than the data written
              // (e.g. when no content length was announced), so only hash the first Length bytes.
              // GetBuffer() does not move the stream position, so ZipFile.Read below still starts at 0.
              hash = sha.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);
          }
          if (!LoneFunctions.UnsafeCompare(hash, fileInfo.Hash))
              throw new Exception("The hash for the file doesn't match what is defined");

          var newFiles = new List<FileInfo>();
          var backup = new BackupUnit(tempDirectory, $"backup-{item.Name}");
          try
          {
              bool shouldDeleteOldFile = true;
              using (var zipFile = ZipFile.Read(stream))
              {
                  Logger.updater.Debug("Streams opened");
                  foreach (var entry in zipFile)
                  {
                      if (entry.IsDirectory)
                      {
                          Logger.updater.Debug($"Creating directory {entry.FileName}");
                          Directory.CreateDirectory(GetExtractionPath(entry.FileName));
                          continue;
                      }

                      using (var ostream = new MemoryStream((int)entry.UncompressedSize))
                      {
                          entry.Extract(ostream);
                          ostream.Seek(0, SeekOrigin.Begin);

                          byte[] fileHash;
                          using (var sha = new SHA1CryptoServiceProvider())
                              fileHash = sha.ComputeHash(ostream);
                          if (!LoneFunctions.UnsafeCompare(fileHash, fileInfo.FileHashes[entry.FileName]))
                              throw new Exception("The hash for the file doesn't match what is defined");
                          ostream.Seek(0, SeekOrigin.Begin);

                          var targetFile = new FileInfo(GetExtractionPath(entry.FileName));
                          Directory.CreateDirectory(targetFile.DirectoryName);
                          if (targetFile.FullName == item.LocalPluginMeta?.Filename)
                              shouldDeleteOldFile = false; // overwriting old file, no need to delete

                          if (targetFile.Exists)
                              backup.Add(targetFile);
                          else
                              newFiles.Add(targetFile); // removed again if the update is rolled back

                          Logger.updater.Debug($"Extracting file {targetFile.FullName}");
                          targetFile.Delete();
                          // Dispose so the handle is released and the data is flushed to disk.
                          using (var fstream = targetFile.Create())
                              ostream.CopyTo(fstream);
                      }
                  }
              }

              if (item.LocalPluginMeta?.Plugin is SelfPlugin)
              { // currently updating self: hand off to the new executable, passing our PID
                  Process.Start(new ProcessStartInfo
                  {
                      FileName = item.LocalPluginMeta.Filename,
                      Arguments = $"-nw={Process.GetCurrentProcess().Id}",
                      UseShellExecute = false
                  });
              }
              else if (shouldDeleteOldFile && item.LocalPluginMeta != null)
              {
                  File.Delete(item.LocalPluginMeta.Filename);
              }
          }
          catch (Exception)
          { // something failed; restore
              foreach (var file in newFiles)
                  file.Delete();
              backup.Restore();
              backup.Delete();
              throw;
          }

          backup.Delete();
          Logger.updater.Debug("Extractor exited");
      }
  }
  /// <summary>
  /// Signals that a request to the mod API failed at the network or HTTP level.
  /// Follows the standard exception constructor set, including the serialization constructor.
  /// </summary>
  [Serializable]
  internal class NetworkException : Exception
  {
      /// <summary>Creates the exception with the default <see cref="Exception"/> message.</summary>
      public NetworkException() : base() { }

      /// <summary>Creates the exception with the given <paramref name="message"/>.</summary>
      public NetworkException(string message) : base(message) { }

      /// <summary>Creates the exception with a <paramref name="message"/> wrapping <paramref name="innerException"/>.</summary>
      public NetworkException(string message, Exception innerException) : base(message, innerException) { }

      /// <summary>Deserialization constructor required by <see cref="SerializableAttribute"/>.</summary>
      protected NetworkException(SerializationInfo info, StreamingContext context) : base(info, context) { }
  }
  445. }