2 Komitmen e3e8625bbb ... d9823b24d3

Pembuat SHA1 Pesan Tanggal
  Lars d9823b24d3 Support for multiple images 2 tahun lalu
  Lars c012ace4ff Imports and requirements.txt simplified 2 tahun lalu
3 mengubah file dengan 109 tambahan dan 59 penghapusan
  1. 1 6
      requirements.txt
  2. 13 20
      src/predictor.py
  3. 95 33
      src/server.py

+ 1 - 6
requirements.txt

@@ -5,9 +5,7 @@ filelock==3.13.1
 fsspec==2023.10.0
 gunicorn==21.2.0
 idna==3.4
-imageio==2.31.6
 Jinja2==3.1.2
-lazy_loader==0.3
 MarkupSafe==2.1.3
 mpmath==1.3.0
 networkx==3.2.1
@@ -25,12 +23,9 @@ nvidia-nccl-cu12==2.18.1
 nvidia-nvjitlink-cu12==12.3.52
 nvidia-nvtx-cu12==12.1.105
 packaging==23.2
-Pillow==10.0.1
+Pillow==10.1.0
 requests==2.31.0
-scikit-image==0.22.0
-scipy==1.11.3
 sympy==1.12
-tifffile==2023.9.26
 torch==2.1.0
 torchvision==0.16.0
 triton==2.1.0

+ 13 - 20
src/predictor.py

@@ -1,14 +1,10 @@
 
 
-import numpy as np
-
 import torch
 import torchvision.transforms as transforms
 import torch.nn.init
 
 from PIL import Image
-from skimage import transform, util
-from typing import Tuple
 
 keep_prob = 0.9 
 n_classes = 8  # multi hot encoded
@@ -60,15 +56,17 @@ class CNN(torch.nn.Module):
 # ----
 
 class Prediction():
-    def __init__(self, model_name:str, growth:int, cells:int) -> str:
-        self.model = model_name
+    def __init__(self, growth:int, cells:int):
         self.growth = growth
         self.cells = cells
 
-    def __str__(self):
-        growth_level = {1:"early", 2: "mature", 3: "overgrown"}.get(self.growth, "?")
-        cells = {0: "0", 1: "1", 2: "2", 3: "3 or more"}.get(self.cells, "?")
-        return f"Growth level: {growth_level}. Cells: {cells}. Model: {self.model}"
+    def __str__(self) -> str:
+        growth_level = {1: "early growth", 2: "mature growth", 3: "overgrown"}.get(self.growth, "?")
+        cells = {0: "0 cells", 1: "1 cell", 2: "2 cells", 3: "3 or more cells"}.get(self.cells, "?")
+        return f"{growth_level}, {cells}"
+    
+    def getDict(self) -> dict:
+        return {"growth" : self.growth, "cells" : self.cells, "text" : str(self)}
 
 class Predictor():
     def __init__(self, pth_filename:str):
@@ -83,16 +81,11 @@ class Predictor():
         self.cnn.eval()
 
     def predict(self, input_img:Image) -> Prediction:
-        im_gray = np.array(input_img)
-
-        # resizing to 44 x 44 pixels. Model is trained on 44 x 44 pixel images
-        resize_im = (44, 44)
-        im_gray = util.img_as_ubyte(transform.resize(im_gray, resize_im, order = 1, anti_aliasing = False))
-
-        # convert uint8 to some PIL image format, values range 0..1, float32 (float64 vil give error), kan sikkert også gøres på andre måder
-        image = Image.fromarray(im_gray)  # np.shape(image) = (44,44)
+        
+        # Resize to 44 x 44 pixels, and convert to grayscale (if not already). Model is trained on 44 x 44 pixel images
+        img = input_img.resize((44, 44)).convert("L")
 
-        im_tensor = transforms.ToTensor()(image).unsqueeze(0)  # unsqueeze(0) giver [1,44,44] -> [1,1,44,44]
+        im_tensor = transforms.ToTensor()(img).unsqueeze(0)  # unsqueeze(0) giver [1,44,44] -> [1,1,44,44]
 
         # print(np.shape(im_tensor)) # = [1,1,44,44] det er det format som modellen tager som input.  [n_batch, channel, imx, imy]
         # Hvis batch, n af flere billeder på en gang skal det pakkes i formatet [n,1,44,44]
@@ -107,4 +100,4 @@ class Predictor():
         output_classes = list(torch.argmax(torch.reshape(prediction[0].detach(), MHE), dim = -1).numpy())
         # .detach() bruges for at undgå "grad_fn=<SigmoidBackward0>" dvs undgå at den aktivt tracker gradienten for hver operation...
 
-        return Prediction(self.modelName, int(output_classes[0]), int(output_classes[1]))
+        return Prediction(int(output_classes[0]), int(output_classes[1]))

+ 95 - 33
src/server.py

@@ -45,10 +45,14 @@ html = """<!DOCTYPE html>
 <meta charset="UTF-8">
 <title>Cell Growth Classifier</title>
 <script>
-function uploadImage() {
+function uploadImages() {
     var formData = new FormData();
-    var imageFile = document.getElementById("imageInput").files[0];
-    formData.append("image", imageFile);
+    var imageFiles = document.getElementById("imageInput").files;
+    var imageFilesArray = [];
+    for (var i = 0; i < imageFiles.length; i++) {
+        formData.append("images", imageFiles[i]);
+        imageFilesArray.push({name: imageFiles[i].name, file: imageFiles[i]});
+    }
 
     fetch('/upload', {
         method: 'POST',
@@ -56,36 +60,94 @@ function uploadImage() {
     })
     .then(response => response.json())
     .then(data => {
-        document.getElementById("results").textContent = `${data.results}`;
-        var image = document.createElement("img");
-        image.src = URL.createObjectURL(imageFile);
-        image.onload = function() {
-            URL.revokeObjectURL(image.src) // Free up memory
-        }
-        document.getElementById("imageDisplay").innerHTML = '';
-        document.getElementById("imageDisplay").appendChild(image);
+        var results = data.results;
+        var sortedKeys = Object.keys(results).sort(); // Sort the filenames
+        var table = document.getElementById("resultsTable");
+        table.innerHTML = ""; // Clear the table first
+
+        // Create table header
+        var header = table.createTHead();
+        var headerRow = header.insertRow(0);
+        headerRow.insertCell(0).textContent = "Image";
+        headerRow.insertCell(1).textContent = "Filename";
+        headerRow.insertCell(2).textContent = "Analysis";
+
+        // Create table body
+        var tbody = table.createTBody();
+
+        // Insert a row in the table for each sorted image result
+        sortedKeys.forEach((imageName) => {
+            var imageFile = imageFilesArray.find(file => file.name === imageName);
+            var row = tbody.insertRow(-1);
+            var cellImage = row.insertCell(0);
+            var image = document.createElement("img");
+
+            if (imageFile) {
+                image.src = URL.createObjectURL(imageFile.file);
+                image.onload = function() {
+                    URL.revokeObjectURL(this.src) // Free up memory
+                }
+            } else {
+                image.alt = "Image not found";
+            }
+            image.style.width = '100px'; // Set the image size
+            cellImage.appendChild(image);
+
+            row.insertCell(1).textContent = imageName;
+            row.insertCell(2).textContent = results[imageName].text;
+        });
+
+        // Set the model used
+        document.getElementById("modelUsed").textContent = "Model used: " + data.model;
     })
     .catch(error => {
         console.error('Error:', error);
-        document.getElementById("results").textContent = "An error occurred while uploading the image.";
+        document.getElementById("results").textContent = "An error occurred while uploading the images.";
     });
 }
 </script>
+<style>
+  table {
+    border-collapse: collapse;
+    margin-top: 20px;
+  }
+  thead, td {
+    border: 1px solid #ddd;
+    padding: 8px;
+    text-align: left;
+  }
+  thead {
+    background-color: #f2f2f2;
+    color: #333;
+    font-weight: bold;
+  }
+  td {
+    font-size: 0.9em;
+  }
+  img {
+    width: 100px; /* Adjust the size of the image if necessary */
+    height: auto;
+  }
+  #modelUsed {
+    margin-top: 20px;
+  }
+</style>
+
 </head>
 <body>
 
-<h2>Cell growth classifier demo</h2>
-<input type="file" id="imageInput" accept="image/*"> <br><br>
-<button onclick="uploadImage()">Upload</button>
-
-<h3>Image:</h3>
-<div id="imageDisplay"></div>
+<h2>Cell Growth Classifier</h2>
+<input type="file" id="imageInput" accept="image/*" multiple>
+<button onclick="uploadImages()">Analyze</button>
 
-<h3>Results:</h3>
-<div id="results"></div>
+<p id="modelUsed"></p>
+<table id="resultsTable">
+    
+</table>
 
 </body>
 </html>
+
 """
 
 # ---- 
@@ -107,19 +169,19 @@ def startServer(index_html:str, predictor:Predictor, host_port:str):
 
         @app.route('/upload', method='POST')
         def upload_image():
-            upload = request.files.get('image')
-            if upload is None:
-                return "No image uploaded"
-            
-            # Read the image file in bytes and open it with Pillow
-            image_bytes = io.BytesIO(upload.file.read())
-            try:
-                with Image.open(image_bytes) as img:
-                    pred:Prediction = app.predictor.predict(img)
-                    return {"results": str(pred)}
-            except IOError:
-                response.status = 400
-                return "Invalid image file"
+            uploads = request.files.getall('images')
+            results = {}
+            for upload in uploads:
+                # Read the image file in bytes and open it with Pillow
+                image_bytes = io.BytesIO(upload.file.read())
+                try:
+                    with Image.open(image_bytes) as img:
+                        pred:Prediction = app.predictor.predict(img)
+                        results[upload.filename] = pred.getDict()
+                except IOError:
+                    response.status = 400
+                    return "Invalid image file"
+            return {"model": app.predictor.modelName, "results": results}
   
         # Store predictor in the bottle app object
         app.predictor = predictor