You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

534 lines
20 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. using IPA.Utilities;
  2. using IPA.Loader;
  3. using Ionic.Zip;
  4. using Newtonsoft.Json;
  5. using System;
  6. using System.Collections;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using System.IO;
  10. using System.Linq;
  11. using System.Security.Cryptography;
  12. using System.Text;
  13. using System.Text.RegularExpressions;
  14. using System.Threading;
  15. using System.Threading.Tasks;
  16. using UnityEngine;
  17. using UnityEngine.Networking;
  18. using SemVer;
  19. using Logger = IPA.Logging.Logger;
  20. using Version = SemVer.Version;
  21. using IPA.Updating.Backup;
  22. using System.Runtime.Serialization;
  23. using System.Reflection;
  24. using static IPA.Loader.PluginManager;
  25. namespace IPA.Updating.ModsaberML
  26. {
  27. class Updater : MonoBehaviour
  28. {
  /// <summary>
  /// The updater component that claimed the singleton slot in <see cref="Awake"/>.
  /// </summary>
  public static Updater instance;

  /// <summary>
  /// Registers this component as the singleton and kicks off the update check.
  /// If another instance is already registered, this component destroys itself instead.
  /// Any exception raised along the way is logged rather than propagated.
  /// </summary>
  public void Awake()
  {
      try
      {
          if (instance != null)
          {
              // A previous updater already owns the slot; this one is redundant.
              Destroy(this);
              return;
          }

          instance = this;
          CheckForUpdates();
      }
      catch (Exception ex)
      {
          Logger.updater.Error(ex);
      }
  }
  /// <summary>
  /// Starts <see cref="CheckForUpdatesCoroutine"/> on this component.
  /// </summary>
  private void CheckForUpdates() => StartCoroutine(CheckForUpdatesCoroutine());
  /// <summary>
  /// Book-keeping record for one mod while dependencies are being resolved.
  /// </summary>
  private class DependencyObject
  {
      /// <summary>Name used to look the mod up through the API.</summary>
      public string Name { get; set; }

      /// <summary>Locally installed version; stays <c>null</c> for entries that only came from a dependency list.</summary>
      public Version Version { get; set; }

      /// <summary>Version chosen by the second resolution pass, when one was found.</summary>
      public Version ResolvedVersion { get; set; }

      /// <summary>Version range that the chosen version has to satisfy.</summary>
      public Range Requirement { get; set; }

      /// <summary>Whether <see cref="ResolvedVersion"/> has been set by the resolver.</summary>
      public bool Resolved { get; set; }

      /// <summary>Whether a suitable version is already present locally.</summary>
      public bool Has { get; set; }

      /// <summary>Names of the mods that asked for this dependency.</summary>
      public HashSet<string> Consumers { get; set; } = new HashSet<string>();

      /// <summary>Set when fetching this mod's metadata failed.</summary>
      public bool MetaRequestFailed { get; set; }

      /// <summary>Metadata of the installed plugin this entry was created from, if any.</summary>
      public BSPluginMeta LocalPluginMeta { get; set; }

      /// <summary>
      /// Formats the entry for log output: name, version, optional resolved version,
      /// requirement and an "already have" marker.
      /// </summary>
      public override string ToString()
      {
          var resolvedPart = Resolved ? $" -> {ResolvedVersion}" : "";
          var ownedPart = Has ? " Already have" : "";
          return $"{Name}@{Version}{resolvedPart} - ({Requirement}) {ownedPart}";
      }
  }
  /// <summary>
  /// Successfully decoded API responses, keyed by the request URI.
  /// </summary>
  private readonly Dictionary<string, ApiEndpoint.Mod> requestCache = new Dictionary<string, ApiEndpoint.Mod>();

  /// <summary>
  /// Coroutine that fetches the metadata of a mod and stores it in <paramref name="result"/>.
  /// A cached response is returned without touching the network. Failures are reported through
  /// <c>result.Error</c> instead of being thrown.
  /// </summary>
  /// <param name="name">Mod name inserted into the endpoint template.</param>
  /// <param name="ver">Version string inserted into the endpoint template.</param>
  /// <param name="result">Receives either the decoded mod in <c>Value</c> or the failure in <c>Error</c>.</param>
  private IEnumerator DownloadModInfo(string name, string ver, Ref<ApiEndpoint.Mod> result)
  {
      var uri = ApiEndpoint.ApiBase + string.Format(ApiEndpoint.GetApprovedEndpoint, name, ver);

      if (requestCache.TryGetValue(uri, out ApiEndpoint.Mod cached))
      {
          result.Value = cached;
          yield break;
      }

      using (var request = UnityWebRequest.Get(uri))
      {
          yield return request.SendWebRequest();

          if (request.isNetworkError)
          {
              result.Error = new NetworkException($"Network error while trying to download: {request.error}");
              yield break;
          }

          if (request.isHttpError)
          {
              result.Error = request.responseCode == 404
                  ? new NetworkException("Not found")
                  : new NetworkException($"Server returned error {request.error} while getting data");
              yield break;
          }

          ApiEndpoint.Mod decoded;
          try
          {
              decoded = JsonConvert.DeserializeObject<ApiEndpoint.Mod>(request.downloadHandler.text);
          }
          catch (Exception e)
          {
              result.Error = new Exception("Error decoding response", e);
              yield break;
          }

          if (decoded == null)
          {
              // An empty or literal "null" body deserializes to null without throwing. Report it as
              // a failure (and keep it out of the cache) so callers never dereference a null mod.
              result.Error = new Exception("Error decoding response: the server returned an empty document");
              yield break;
          }

          result.Value = decoded;
          requestCache[uri] = decoded;
      }
  }
  /// <summary>
  /// Top-level update coroutine: seeds the dependency list from the installed plugins that carry
  /// Modsaber info, runs the two asynchronous resolution passes, then hands the result to the
  /// final pass. The list is logged before and after each asynchronous pass.
  /// </summary>
  private IEnumerator CheckForUpdatesCoroutine()
  {
      var depList = new Ref<List<DependencyObject>>(new List<DependencyObject>());

      // initialize with data to resolve (1.1)
      foreach (var plugin in BSMetas)
      {
          var msinfo = plugin.ModsaberInfo;
          if (msinfo == null)
              continue; // not updatable

          var seed = new DependencyObject
          {
              Name = msinfo.InternalName,
              Version = new Version(msinfo.CurrentVersion),
              Requirement = new Range($">={msinfo.CurrentVersion}"),
              LocalPluginMeta = plugin
          };
          depList.Value.Add(seed);
      }

      foreach (var entry in depList.Value)
      {
          Logger.updater.Debug($"Phantom Dependency: {entry}");
      }

      yield return DependencyResolveFirstPass(depList);

      foreach (var entry in depList.Value)
      {
          Logger.updater.Debug($"Dependency: {entry}");
      }

      yield return DependencyResolveSecondPass(depList);

      foreach (var entry in depList.Value)
      {
          Logger.updater.Debug($"Dependency: {entry}");
      }

      DependendyResolveFinalPass(depList);
  }
  /// <summary>
  /// First resolution pass: walks the list, appending the dependencies each mod declares
  /// (entries appended during the walk are visited too), then collapses entries that share a
  /// name into one whose requirement is the intersection of all requirements and whose consumer
  /// set is the union of all consumers. <c>list.Value</c> is replaced with the collapsed list.
  /// </summary>
  /// <param name="list">Dependency list to expand and aggregate in place.</param>
  private IEnumerator DependencyResolveFirstPass(Ref<List<DependencyObject>> list)
  {
      // Names whose dependency lists were already requested. Without this, two mods that depend
      // on each other keep appending one another and the loop below never terminates.
      var expandedNames = new HashSet<string>();

      for (int i = 0; i < list.Value.Count; i++)
      { // Grab dependencies (1.2)
          var dep = list.Value[i];

          if (!expandedNames.Add(dep.Name))
              continue; // same mod seen earlier; this entry is still merged in the aggregation step

          var mod = new Ref<ApiEndpoint.Mod>(null);

          #region TEMPORARY get latest // SHOULD BE GREATEST OF VERSION
          yield return DownloadModInfo(dep.Name, "", mod);
          #endregion

          try { mod.Verify(); }
          catch (Exception e)
          {
              Logger.updater.Error($"Error getting info for {dep.Name}");
              Logger.updater.Error(e);
              dep.MetaRequestFailed = true;
              continue;
          }

          list.Value.AddRange(mod.Value.Dependencies.Select(d => new DependencyObject
          {
              Name = d.Name,
              Requirement = d.VersionRange,
              Consumers = new HashSet<string>() { dep.Name }
          }));
      }

      var depNames = new HashSet<string>();
      var final = new List<DependencyObject>();

      foreach (var dep in list.Value)
      { // aggregate ranges and the like (1.3)
          if (depNames.Add(dep.Name))
          { // first entry with this name; keep it
              final.Add(dep);
              continue;
          }

          // Fold the repeated entry into the one that was kept.
          var toMod = final.First(d => d.Name == dep.Name);
          toMod.Requirement = toMod.Requirement.Intersect(dep.Requirement);
          toMod.Consumers.UnionWith(dep.Consumers);
      }

      list.Value = final;
  }
  /// <summary>
  /// Second resolution pass: for every dependency whose metadata request has not failed, builds
  /// a map of mod version to game version, picks the highest version that satisfies the
  /// requirement among those whose game version equals <c>BeatSaber.GameVersion</c>, and records
  /// whether that version is already installed.
  /// </summary>
  /// <param name="list">Aggregated dependency list; its entries are updated in place.</param>
  private IEnumerator DependencyResolveSecondPass(Ref<List<DependencyObject>> list)
  {
      // gets map of mod version -> game version (2.0)
      IEnumerator GetGameVersionMap(string modname, Ref<Dictionary<Version, Version>> map)
      {
          map.Value = new Dictionary<Version, Version>();

          var latest = new Ref<ApiEndpoint.Mod>(null);
          yield return DownloadModInfo(modname, "", latest);
          try { latest.Verify(); }
          catch (Exception)
          {
              map.Value = null;
              map.Error = new Exception($"Error getting info for {modname}", latest.Error);
              yield break;
          }

          // Indexer assignment instead of Add: a version listed twice must not throw and abort
          // the whole coroutine.
          map.Value[latest.Value.Version] = latest.Value.GameVersion;

          foreach (var ver in latest.Value.OldVersions)
          {
              // A fresh Ref per request, so an error recorded for one version cannot leak into
              // the verification of the following ones.
              var old = new Ref<ApiEndpoint.Mod>(null);
              yield return DownloadModInfo(modname, ver.ToString(), old);
              try { old.Verify(); }
              catch (Exception e)
              {
                  Logger.updater.Error($"Error getting info for {modname}v{ver}");
                  Logger.updater.Error(e);
                  continue;
              }

              map.Value[old.Value.Version] = old.Value.GameVersion;
          }
      }

      foreach (var dep in list.Value)
      {
          dep.Has = dep.Version != null; // dep.Version is only not null if its already installed

          if (dep.MetaRequestFailed)
          {
              Logger.updater.Warn($"{dep.Name} info request failed, not trying again");
              continue;
          }

          var dict = new Ref<Dictionary<Version, Version>>(null);
          yield return GetGameVersionMap(dep.Name, dict);
          try { dict.Verify(); }
          catch (Exception e)
          {
              Logger.updater.Error($"Error getting map for {dep.Name}");
              Logger.updater.Error(e);
              dep.MetaRequestFailed = true;
              continue;
          }

          // (2.1) only versions published for the running game version are candidates
          var candidates = dict.Value
              .Where(kvp => kvp.Value == BeatSaber.GameVersion)
              .Select(kvp => kvp.Key);
          var ver = dep.Requirement.MaxSatisfying(candidates);

          // (2.2)
          dep.Resolved = ver != null;
          if (dep.Resolved)
              dep.ResolvedVersion = ver;

          // Already satisfied only when the installed version is exactly the resolved one.
          dep.Has = dep.Resolved && dep.Version == dep.ResolvedVersion;
      }
  }
  /// <summary>
  /// Final pass: collects every resolved dependency that is not already present, logs the ones
  /// that could not be resolved, and starts one download coroutine per mod to fetch.
  /// All downloads share a freshly created temporary directory.
  /// </summary>
  /// <param name="list">Dependency list produced by the earlier passes.</param>
  private void DependendyResolveFinalPass(Ref<List<DependencyObject>> list)
  { // also starts download of mods
      var toDl = new List<DependencyObject>();

      foreach (var dep in list.Value)
      { // figure out which ones need to be downloaded (3.1)
          if (dep.Resolved)
          {
              Logger.updater.Debug($"Resolved: {dep.ToString()}");
              if (!dep.Has)
              {
                  Logger.updater.Debug($"To Download: {dep.ToString()}");
                  toDl.Add(dep);
              }
          }
          else if (!dep.Has)
          {
              Logger.updater.Warn($"Could not resolve dependency {dep}");
          }
      }

      Logger.updater.Debug($"To Download {string.Join(", ", toDl.Select(d => $"{d.Name}@{d.ResolvedVersion}"))}");

      if (toDl.Count == 0)
          return; // nothing to fetch; do not leave an empty temp directory behind on every run

      string tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName() + Path.GetRandomFileName());
      Directory.CreateDirectory(tempDirectory);
      Logger.updater.Debug($"Temp directory: {tempDirectory}");

      foreach (var item in toDl)
          StartCoroutine(UpdateModCoroutine(item, tempDirectory));
  }
  /// <summary>
  /// Coroutine that downloads the archive for one resolved mod and extracts it, retrying the
  /// download and extraction up to three times.
  /// </summary>
  /// <param name="item">Dependency to install; <c>ResolvedVersion</c> selects the release to fetch.</param>
  /// <param name="tempDirectory">Directory passed through to the extractor for its backups.</param>
  private IEnumerator UpdateModCoroutine(DependencyObject item, string tempDirectory)
  { // (3.2)
      Logger.updater.Debug($"Release: {BeatSaber.ReleaseType}");

      var mod = new Ref<ApiEndpoint.Mod>(null);
      yield return DownloadModInfo(item.Name, item.ResolvedVersion.ToString(), mod);
      try { mod.Verify(); }
      catch (Exception e)
      {
          Logger.updater.Error($"Error occurred while trying to get information for {item}");
          Logger.updater.Error(e);
          yield break;
      }

      // The Steam file is used for Steam installs and whenever no Oculus file is present.
      ApiEndpoint.Mod.PlatformFile platformFile;
      if (BeatSaber.ReleaseType == BeatSaber.Release.Steam || mod.Value.Files.Oculus == null)
          platformFile = mod.Value.Files.Steam;
      else
          platformFile = mod.Value.Files.Oculus;

      if (platformFile == null)
      {
          Logger.updater.Error($"No downloadable file is listed for {item}");
          yield break;
      }

      string url = platformFile.DownloadPath;
      Logger.updater.Debug($"URL = {url}");

      const int MaxTries = 3;

      // Success is tracked explicitly: counting remaining tries alone cannot tell
      // "succeeded on the last attempt" apart from "every attempt failed".
      bool succeeded = false;

      for (int attempt = 1; attempt <= MaxTries; attempt++)
      {
          if (attempt > 1)
              Logger.updater.Debug($"Re-trying download...");

          using (var stream = new MemoryStream())
          using (var request = UnityWebRequest.Get(url))
          {
              var dlh = new StreamDownloadHandler(stream);
              request.downloadHandler = dlh;

              Logger.updater.Debug("Sending request");
              yield return request.SendWebRequest();
              Logger.updater.Debug("Download finished");

              if (request.isNetworkError)
              {
                  Logger.updater.Error("Network error while trying to update mod");
                  Logger.updater.Error(request.error);
                  continue;
              }

              if (request.isHttpError)
              {
                  Logger.updater.Error($"Server returned an error code while trying to update mod");
                  Logger.updater.Error(request.error);
                  continue;
              }

              stream.Seek(0, SeekOrigin.Begin); // reset to beginning

              // use slightly more multithreaded approach than coroutines
              var downloadTask = Task.Run(() => ExtractPluginAsync(stream, item, platformFile, tempDirectory));

              while (!downloadTask.IsCompleted)
                  yield return null; // pause coroutine until task is done

              if (downloadTask.IsFaulted)
              {
                  Logger.updater.Error($"Error downloading mod {item.Name}");
                  Logger.updater.Error(downloadTask.Exception);
                  continue;
              }

              succeeded = true;
              break;
          }
      }

      if (succeeded)
          Logger.updater.Debug("Download complete");
      else
          Logger.updater.Warn($"Plugin download failed {MaxTries} times, not re-trying");
  }
  /// <summary>
  /// <see cref="DownloadHandlerScript"/> that appends every received chunk to a caller-supplied
  /// <see cref="MemoryStream"/>.
  /// </summary>
  internal class StreamDownloadHandler : DownloadHandlerScript
  {
      /// <summary>Stream that receives the downloaded bytes.</summary>
      public MemoryStream Stream { get; set; }

      /// <summary>Creates a handler that writes into <paramref name="stream"/>.</summary>
      public StreamDownloadHandler(MemoryStream stream) : base()
      {
          Stream = stream;
      }

      /// <summary>
      /// Sizes the target stream to the announced content length when that is a valid capacity.
      /// </summary>
      protected override void ReceiveContentLength(int contentLength)
      {
          // MemoryStream.Capacity rejects values below the current Length (and negatives).
          // Skip the resize in that case instead of throwing from inside the download callback.
          if (contentLength >= Stream.Length)
              Stream.Capacity = contentLength;

          Logger.updater.Debug($"Got content length: {contentLength}");
      }

      /// <summary>Logs that the download has completed.</summary>
      protected override void CompleteContent()
      {
          Logger.updater.Debug("Download complete");
      }

      /// <summary>
      /// Appends the first <paramref name="dataLength"/> bytes of <paramref name="data"/> to the stream.
      /// </summary>
      /// <returns><c>false</c> when the buffer is null or empty, otherwise <c>true</c>.</returns>
      protected override bool ReceiveData(byte[] data, int dataLength)
      {
          if (data == null || data.Length < 1)
          {
              Logger.updater.Debug("CustomWebRequest :: ReceiveData - received a null/empty buffer");
              return false;
          }

          Stream.Write(data, 0, dataLength);
          return true;
      }

      /// <summary>Always <c>null</c>; the payload is read from <see cref="Stream"/> instead.</summary>
      protected override byte[] GetData() { return null; }

      /// <summary>Progress is not tracked; always returns 0.</summary>
      protected override float GetProgress()
      {
          return 0f;
      }

      /// <inheritdoc/>
      public override string ToString()
      {
          return $"{base.ToString()} ({Stream?.ToString()})";
      }
  }
  /// <summary>
  /// Verifies the downloaded archive against the published hash and extracts it into the
  /// current directory. Existing files are backed up first; if anything fails, newly created
  /// files are deleted, the backup is restored and the exception is rethrown.
  /// Runs synchronously on the calling thread despite its name.
  /// </summary>
  /// <param name="stream">Downloaded archive, positioned at its beginning.</param>
  /// <param name="item">Dependency being installed.</param>
  /// <param name="fileInfo">Expected archive hash and per-file hashes.</param>
  /// <param name="tempDirectory">Directory in which the backup unit is created.</param>
  /// <exception cref="Exception">The archive or one of its files does not match its expected hash.</exception>
  /// <exception cref="InvalidDataException">An archive entry resolves to a path outside the current directory.</exception>
  private void ExtractPluginAsync(MemoryStream stream, DependencyObject item, ApiEndpoint.Mod.PlatformFile fileInfo, string tempDirectory)
  { // (3.3)
      Logger.updater.Debug($"Extracting ZIP file for {item.Name}");

      using (var archiveSha = new SHA1CryptoServiceProvider())
      {
          // GetBuffer() exposes the whole backing array, which can be longer than the data that
          // was written; hash only the first Length bytes.
          var hash = archiveSha.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);
          if (!LoneFunctions.UnsafeCompare(hash, fileInfo.Hash))
              throw new Exception("The hash for the file doesn't match what is defined");
      }

      var installRoot = Path.GetFullPath(Environment.CurrentDirectory);
      var separators = new[] { Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar };
      var rootPrefix = installRoot.TrimEnd(separators) + Path.DirectorySeparatorChar;

      // Maps an archive entry name to an absolute path, rejecting entries (e.g. "../x") that
      // would land outside the install root.
      string ResolveInsideRoot(string entryName)
      {
          var full = Path.GetFullPath(Path.Combine(installRoot, entryName));
          var comparable = full.TrimEnd(separators) + Path.DirectorySeparatorChar;
          if (!comparable.StartsWith(rootPrefix, StringComparison.Ordinal))
              throw new InvalidDataException($"Archive entry '{entryName}' resolves outside of the install directory");
          return full;
      }

      var newFiles = new List<FileInfo>();
      var backup = new BackupUnit(tempDirectory, $"backup-{item.Name}");

      try
      {
          bool shouldDeleteOldFile = true;

          using (var zipFile = ZipFile.Read(stream))
          {
              Logger.updater.Debug("Streams opened");

              foreach (var entry in zipFile)
              {
                  if (entry.IsDirectory)
                  {
                      Logger.updater.Debug($"Creating directory {entry.FileName}");
                      Directory.CreateDirectory(ResolveInsideRoot(entry.FileName));
                      continue;
                  }

                  using (var ostream = new MemoryStream((int)entry.UncompressedSize))
                  {
                      entry.Extract(ostream);
                      ostream.Seek(0, SeekOrigin.Begin);

                      using (var fileSha = new SHA1CryptoServiceProvider())
                      {
                          var fileHash = fileSha.ComputeHash(ostream);
                          if (!LoneFunctions.UnsafeCompare(fileHash, fileInfo.FileHashes[entry.FileName]))
                              throw new Exception("The hash for the file doesn't match what is defined");
                      }

                      ostream.Seek(0, SeekOrigin.Begin);

                      FileInfo targetFile = new FileInfo(ResolveInsideRoot(entry.FileName));
                      Directory.CreateDirectory(targetFile.DirectoryName);

                      if (targetFile.FullName == item.LocalPluginMeta?.Filename)
                          shouldDeleteOldFile = false; // overwriting old file, no need to delete

                      if (targetFile.Exists)
                          backup.Add(targetFile);
                      else
                          newFiles.Add(targetFile);

                      Logger.updater.Debug($"Extracting file {targetFile.FullName}");

                      targetFile.Delete();

                      // Dispose the file stream so the data is flushed and the handle released;
                      // an open handle would also block the rollback below from deleting the file.
                      using (var fstream = targetFile.Create())
                          ostream.CopyTo(fstream);
                  }
              }
          }

          if (item.LocalPluginMeta?.Plugin is SelfPlugin)
          { // currently updating self
              Process.Start(new ProcessStartInfo
              {
                  FileName = item.LocalPluginMeta.Filename,
                  Arguments = $"-nw={Process.GetCurrentProcess().Id}",
                  UseShellExecute = false
              });
          }
          else if (shouldDeleteOldFile && item.LocalPluginMeta != null)
              File.Delete(item.LocalPluginMeta.Filename);
      }
      catch (Exception)
      { // something failed; restore
          foreach (var file in newFiles)
              file.Delete();
          backup.Restore();
          backup.Delete();
          throw;
      }

      backup.Delete();
      Logger.updater.Debug("Extractor exited");
  }
  429. }
  /// <summary>
  /// Exception used by the updater to report network and HTTP failures.
  /// </summary>
  [Serializable]
  internal class NetworkException : Exception
  {
      /// <summary>Creates the exception with the default message.</summary>
      public NetworkException() { }

      /// <summary>Creates the exception with the given message.</summary>
      public NetworkException(string message)
          : base(message) { }

      /// <summary>Creates the exception with the given message and underlying cause.</summary>
      public NetworkException(string message, Exception innerException)
          : base(message, innerException) { }

      /// <summary>Serialization constructor; forwards to the base <see cref="Exception"/> constructor.</summary>
      protected NetworkException(SerializationInfo info, StreamingContext context)
          : base(info, context) { }
  }
  446. }