server.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. #!/usr/bin/env python
  2. import gunicorn.app.base
  3. import sys
  4. from predictor import Predictor, Prediction
  5. class CustomUnicornApp(gunicorn.app.base.BaseApplication):
  6. """
  7. This gunicorn app class provides create and exit callbacks for workers,
  8. and runs gunicorn with a single worker and multiple gthreads
  9. """
  10. def __init__(self, create_app_callback, exit_app_callback, host_port):
  11. self._configBind = host_port
  12. self._createAppCallback = create_app_callback
  13. self._exitAppCallback = exit_app_callback
  14. super().__init__()
  15. @staticmethod
  16. def exitWorker(arbiter, worker):
  17. # worker.app provides us with a reference to "self", and we can call the
  18. # exit callback with the object created by the createAppCallback:
  19. self = worker.app
  20. self._exitAppCallback(self._createdApp)
  21. def load_config(self):
  22. self.cfg.set("bind", self._configBind)
  23. self.cfg.set("worker_class", "gthread")
  24. self.cfg.set("workers", 1)
  25. self.cfg.set("threads", 4)
  26. self.cfg.set("worker_exit", CustomUnicornApp.exitWorker)
  27. # Try to uncomment and make 10 requests, to test correct restart of worker:
  28. # self.cfg.set("max_requests", 10)
  29. def load(self):
  30. # This function is invoked when a worker is booted
  31. self._createdApp = self._createAppCallback()
  32. return self._createdApp
  33. # --- index.html contents
  34. html = """<!DOCTYPE html>
  35. <html lang="en">
  36. <head>
  37. <meta charset="UTF-8">
  38. <title>Cell Growth Classifier</title>
  39. <script>
  40. function uploadImages() {
  41. var formData = new FormData();
  42. var imageFiles = document.getElementById("imageInput").files;
  43. var imageFilesArray = [];
  44. for (var i = 0; i < imageFiles.length; i++) {
  45. formData.append("images", imageFiles[i]);
  46. imageFilesArray.push({name: imageFiles[i].name, file: imageFiles[i]});
  47. }
  48. fetch('/upload', {
  49. method: 'POST',
  50. body: formData
  51. })
  52. .then(response => response.json())
  53. .then(data => {
  54. var results = data.results;
  55. var sortedKeys = Object.keys(results).sort(); // Sort the filenames
  56. var table = document.getElementById("resultsTable");
  57. table.innerHTML = ""; // Clear the table first
  58. // Create table header
  59. var header = table.createTHead();
  60. var headerRow = header.insertRow(0);
  61. headerRow.insertCell(0).textContent = "Image";
  62. headerRow.insertCell(1).textContent = "Filename";
  63. headerRow.insertCell(2).textContent = "Analysis";
  64. // Create table body
  65. var tbody = table.createTBody();
  66. // Insert a row in the table for each sorted image result
  67. sortedKeys.forEach((imageName) => {
  68. var imageFile = imageFilesArray.find(file => file.name === imageName);
  69. var row = tbody.insertRow(-1);
  70. var cellImage = row.insertCell(0);
  71. var image = document.createElement("img");
  72. if (imageFile) {
  73. image.src = URL.createObjectURL(imageFile.file);
  74. image.onload = function() {
  75. URL.revokeObjectURL(this.src) // Free up memory
  76. }
  77. } else {
  78. image.alt = "Image not found";
  79. }
  80. image.style.width = '100px'; // Set the image size
  81. cellImage.appendChild(image);
  82. row.insertCell(1).textContent = imageName;
  83. row.insertCell(2).textContent = results[imageName].text;
  84. });
  85. // Set the model used
  86. document.getElementById("modelUsed").textContent = "Model used: " + data.model;
  87. })
  88. .catch(error => {
  89. console.error('Error:', error);
  90. document.getElementById("results").textContent = "An error occurred while uploading the images.";
  91. });
  92. }
  93. </script>
  94. <style>
  95. table {
  96. border-collapse: collapse;
  97. margin-top: 20px;
  98. }
  99. thead, td {
  100. border: 1px solid #ddd;
  101. padding: 8px;
  102. text-align: left;
  103. }
  104. thead {
  105. background-color: #f2f2f2;
  106. color: #333;
  107. font-weight: bold;
  108. }
  109. td {
  110. font-size: 0.9em;
  111. }
  112. img {
  113. width: 100px; /* Adjust the size of the image if necessary */
  114. height: auto;
  115. }
  116. #modelUsed {
  117. margin-top: 20px;
  118. }
  119. </style>
  120. </head>
  121. <body>
  122. <h2>Cell Growth Classifier</h2>
  123. <input type="file" id="imageInput" accept="image/*" multiple>
  124. <button onclick="uploadImages()">Analyze</button>
  125. <p id="modelUsed"></p>
  126. <table id="resultsTable">
  127. </table>
  128. </body>
  129. </html>
  130. """
  131. # ----
  132. from PIL import Image
  133. from bottle import Bottle, request, response
  134. import threading
  135. import io
  136. def startServer(index_html:str, predictor:Predictor, host_port:str):
  137. def create():
  138. app = Bottle()
  139. lock = threading.Lock()
  140. @app.route("/")
  141. def getIndex():
  142. # Serve static content, no lock protection necessary:
  143. return index_html
  144. @app.route('/upload', method='POST')
  145. def upload_image():
  146. uploads = request.files.getall('images')
  147. results = {}
  148. for upload in uploads:
  149. # Read the image file in bytes and open it with Pillow
  150. image_bytes = io.BytesIO(upload.file.read())
  151. try:
  152. with Image.open(image_bytes) as img:
  153. pred:Prediction = app.predictor.predict(img)
  154. results[upload.filename] = pred.getDict()
  155. except IOError:
  156. response.status = 400
  157. return "Invalid image file"
  158. return {"model": app.predictor.modelName, "results": results}
  159. # Store predictor in the bottle app object
  160. app.predictor = predictor
  161. return app
  162. def exit(app):
  163. # Get the service through the app object and save state
  164. pass
  165. CustomUnicornApp(create, exit, host_port).run()
  166. def usage():
  167. print("""Usage:
  168. server.py modelfile host_port
  169. Example:
  170. ./server.py cells_2.pth localhost:8001
  171. """)
  172. if __name__ == "__main__":
  173. if len(sys.argv) != 3:
  174. usage()
  175. else:
  176. p = Predictor(sys.argv[1])
  177. startServer(html, p, sys.argv[2])