# Code:
from typing import List
from dataclasses import dataclass


# a)
@dataclass
class Point:
    """A 2-D data point with coordinates ``x`` and ``y``.

    Rendered as a bare coordinate pair such as ``(1, 7)`` rather than the
    dataclass default ``Point(x=1, y=7)``; the explicit ``__repr__`` below is
    left in place by ``@dataclass``.
    """

    x: float  # x coordinate
    y: float  # y coordinate

    def __repr__(self) -> str:
        return "({}, {})".format(self.x, self.y)


# Sample data set, written as (x, y) pairs and wrapped into Point objects
# in the listed order.
data_points: List[Point] = [
    Point(x, y)
    for x, y in (
        (1, 7),
        (3, 7),
        (4, 5),
        (5, 10),
        (6, 8),
        (9, 10),
        (8, 7),
        (10, 10),
    )
]

# b) Closed-form least-squares fit of y = m*x + b over data_points.
n = len(data_points)

# The slope's denominator is sum((x - mean_x)**2); it is zero when there are
# no points or every x is identical (a vertical "best fit" with no finite
# slope). Test that condition directly: a floating-point check of the
# denominator against 0 is unreliable because of rounding.
if len({p.x for p in data_points}) < 2:
    raise ValueError("need at least two distinct x values to fit a line")

sumprod = sum((p.x * p.y for p in data_points), 0.0)
sumx = sum((p.x for p in data_points), 0.0)
sumy = sum((p.y for p in data_points), 0.0)
sumx2 = sum((p.x * p.x for p in data_points), 0.0)

sxy = sumprod - sumx * sumy / n  # centered cross-product sum of x and y
sxx = sumx2 - sumx * sumx / n  # centered sum of squares of x
m_star = sxy / sxx
b_star = (sumy - m_star * sumx) / n


# c)
@dataclass
class Line:
    """A straight line ``y = m * x + b``."""

    m: float  # slope
    b: float  # y-intercept


best_fit = Line(m=m_star, b=b_star)  # least-squares line from part b)


# d)
def mean_squared_error(line: Line, points: List[Point]) -> float:
    """Return the mean squared vertical error of ``line`` over ``points``.

    Each point's residual is ``p.y - (line.m * p.x + line.b)``; the result
    is the average of the squared residuals.

    Raises:
        ValueError: if ``points`` is empty, since the mean is undefined.
    """
    if not points:
        raise ValueError("mean_squared_error requires at least one point")
    total = 0.0
    for p in points:
        residual = p.y - (line.m * p.x + line.b)
        total += residual * residual
    return total / len(points)


# e) Sanity check: nudge the slope and intercept of the least-squares line by
# multiples of 0.1 (up to +/-0.2) and verify that no nudged line achieves a
# smaller MSE than the best fit.
mse = mean_squared_error(best_fit, data_points)
found_better_mse = False
for mdiff10 in range(-2, 3):
    mdiff = mdiff10 / 10
    for bdiff10 in range(-2, 3):
        bdiff = bdiff10 / 10
        # Skip only the unperturbed line itself; changing just one of the two
        # parameters (e.g. m + 0.1 with b unchanged) is still a candidate.
        if mdiff10 == 0 and bdiff10 == 0:
            continue
        new_line = Line(m=best_fit.m + mdiff, b=best_fit.b + bdiff)
        new_mse = mean_squared_error(new_line, data_points)
        if new_mse < mse:
            found_better_mse = True
            # Report each counterexample instead of failing silently.
            print(f"Smaller MSE found: {new_line} has MSE {new_mse} < {mse}")

if not found_better_mse:
    print("No smaller MSE found")
# Last modified: Wednesday, 14 December 2022, 12:58