Explorar el Código

Multiple prediction and probabilities

Lars hace 2 años
padre
commit
dc0f655cae
Se han modificado 3 ficheros con 68 adiciones y 9 borrados
  1. 65 8
      src/predictor.py
  2. 1 0
      src/server.py
  3. 2 1
      src/test.py

+ 65 - 8
src/predictor.py

@@ -5,6 +5,7 @@ import torchvision.transforms as transforms
 import torch.nn.init
 
 from PIL import Image
+from typing import List
 
 keep_prob = 0.9 
 n_classes = 8  # multi hot encoded
@@ -55,18 +56,52 @@ class CNN(torch.nn.Module):
 
 # ----
 
+GrowthDict = {
+            1: "early growth", 
+            2: "mature growth", 
+            3: "overgrown"}
+
+ClonesDict = {
+            0: "0 clones", 
+            1: "1 clone", 
+            2: "2 clones", 
+            3: "3 or more clones"}
+
 class Prediction():
-    def __init__(self, growth:int, cells:int):
+    def __init__(self, growth:int, clones:int):
         self.growth = growth
-        self.cells = cells
+        self.clones = clones
 
     def __str__(self) -> str:
-        growth_level = {1: "early growth", 2: "mature growth", 3: "overgrown"}.get(self.growth, "?")
-        cells = {0: "0 cells", 1: "1 cell", 2: "2 cells", 3: "3 or more cells"}.get(self.cells, "?")
-        return f"{growth_level}, {cells}"
+        growth_level = {
+            1: "early growth", 
+            2: "mature growth", 
+            3: "overgrown"}.get(self.growth, "?")
+        clones = {
+            0: "0 clones", 
+            1: "1 clone", 
+            2: "2 clones", 
+            3: "3 or more clones"}.get(self.clones, "?")
+        return f"{growth_level}, {clones}"
     
     def getDict(self) -> dict:
-        return {"growth" : self.growth, "cells" : self.cells, "text" : str(self)}
+        return {"growth" : self.growth, "clones" : self.clones, "text" : str(self)}
+
+
+class MultiPrediction():
+    def __init__(self, growth_classes : List[int], growth_prop : List[float], clone_classes : List[int], clone_prop : List[float]):
+        self.pred = {
+           "growth_classes" : growth_classes,
+           "growth_prop" : growth_prop,
+           "clone_classes" : clone_classes,
+           "clone_prop" : clone_prop,
+           "text" : (', '.join(['%s (%0.1f%%)' % (GrowthDict[growth_classes[i]], growth_prop[i] * 100.0) for i in range(len(growth_classes))]) +
+                    "\n" + ', '.join(['%s (%0.1f%%)' % (ClonesDict[clone_classes[i]], clone_prop[i] * 100.0) for i in range(len(clone_classes))]))
+        }
+
+    def getDict(self) -> dict:
+        return self.pred
+
 
 class Predictor():
     def __init__(self, pth_filename:str):
@@ -97,7 +132,29 @@ class Predictor():
         # hvilket indeholder de 8 neuroner. Skal decodes til [A, B], ud fra [A1, A2, A3, A4, B1, B2, B3, B4]
         # bruge argmax og reshape til at få final [A,B] 
 
-        output_classes = list(torch.argmax(torch.reshape(prediction[0].detach(), MHE), dim = -1).numpy())
+        #output_classes = list(torch.argmax(torch.reshape(prediction[0].detach(), MHE), dim = -1).numpy())
         # .detach() bruges for at undgå "grad_fn=<SigmoidBackward0>" dvs undgå at den aktivt tracker gradienten for hver operation...
+        
+        k_highest = 2 # find de 2 højest scored classes og probabilities 
+        (probabilities, output_classes) = torch.topk(torch.reshape(prediction[0].detach(), MHE), k_highest)
+
+        def toIntList(iterable) -> List[int]:
+            return [int(x) for x in iterable]
+
+        def toFloatList(iterable) -> List[float]:
+            return [float(x) for x in iterable]
+
+        growth_classes = toIntList(output_classes[0,:].numpy()) # rangeret efter højeste probability, feks [2 1]
+        print(growth_classes)
+
+        growth_probabilities = toFloatList(probabilities[0,:].numpy()) # de relaterede sandsynligheder, feks [9.9999847e+01 5.7163057e-03], 
+        print(growth_probabilities)
+
+        n_cells_classes = toIntList(output_classes[1,:].numpy()) # rangeret efter højeste probability, feks [1 2]
+        print(n_cells_classes)
+
+        n_cells_probabilities = toFloatList(probabilities[1,:].numpy()) # de relaterede sandsynligheder, feks [1.0000000e+02 2.8741499e-02]
+        print(n_cells_probabilities)
+
+        return MultiPrediction(growth_classes, growth_probabilities, n_cells_classes, n_cells_probabilities)
 
-        return Prediction(int(output_classes[0]), int(output_classes[1]))

+ 1 - 0
src/server.py

@@ -123,6 +123,7 @@ function uploadImages() {
   }
   td {
     font-size: 0.9em;
+    white-space: pre-line;
   }
   img {
     width: 100px; /* Adjust the size of the image if necessary */

+ 2 - 1
src/test.py

@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 
 import sys
+import json
 
 def main(model_filename, img_filename):
     from predictor import Predictor
     from PIL import Image
     p = Predictor(model_filename)
     result = p.predict(Image.open(img_filename))
-    print(str(result))
+    print(json.dumps(result.getDict()))
 
 def usage():
     print("""Usage: