Proposition de corrigé
Code:
from typing import List
from dataclasses import dataclass
# a) Data model: a point in the plane and the sample data set to fit.
@dataclass
class Point:
    """A point in the plane with coordinates ``x`` and ``y``."""

    x: float
    y: float

    def __repr__(self) -> str:
        # Compact "(x, y)" form instead of the generated
        # "Point(x=..., y=...)"; @dataclass keeps a user-defined __repr__.
        return f"({self.x}, {self.y})"


# Sample observations fitted in parts b)-e).
data_points: List[Point] = [
    Point(1, 7),
    Point(3, 7),
    Point(4, 5),
    Point(5, 10),
    Point(6, 8),
    Point(9, 10),
    Point(8, 7),
    Point(10, 10),
]
# b) Ordinary least-squares fit y ~ m*x + b over data_points, using the
# closed-form solution of the normal equations:
#   m* = (sum(x*y) - sum(x)*sum(y)/n) / (sum(x^2) - sum(x)^2/n)
#   b* = (sum(y) - m* * sum(x)) / n
# NOTE: this raises ZeroDivisionError when data_points is empty (n == 0);
# if all x values are equal the slope denominator is mathematically zero
# (vertical line), so the slope is undefined.
n = len(data_points)
sumprod = 0.0  # sum of x*y
sumx = 0.0     # sum of x
sumy = 0.0     # sum of y
sumx2 = 0.0    # sum of x^2
for p in data_points:
    sumprod += p.x * p.y
    sumx += p.x
    sumy += p.y
    sumx2 += p.x * p.x
m_star = (sumprod - sumx * sumy / n) / (sumx2 - sumx * sumx / n)
b_star = (sumy - m_star * sumx) / n
# c) A straight line in slope-intercept form.
@dataclass
class Line:
    """Straight line ``y = m*x + b``."""

    m: float  # slope
    b: float  # y-intercept
# Least-squares line computed in part b).
best_fit = Line(m=m_star, b=b_star)
# d) Quality of a line: mean of the squared vertical errors.
def mean_squared_error(line: Line, points: List[Point]) -> float:
    """Return the mean squared error of ``line`` over ``points``.

    The error of a point is its vertical residual
    ``p.y - (line.m * p.x + line.b)``; the result is the average of the
    squared residuals.

    Raises:
        ValueError: if ``points`` is empty (the mean is undefined).
    """
    if not points:
        raise ValueError("mean_squared_error() requires at least one point")
    residuals = (p.y - (line.m * p.x + line.b) for p in points)
    return sum(err * err for err in residuals) / len(points)
# e) Sanity check of the least-squares solution. Nudge the slope and the
# intercept of best_fit by -0.2, -0.1, 0, +0.1 or +0.2 (every combination
# except leaving both unchanged). Check that none of these neighbouring
# lines has a smaller MSE.
mse = mean_squared_error(best_fit, data_points)
found_better_mse = False
for mdiff10 in range(-2, 3):
    for bdiff10 in range(-2, 3):
        # Skip only the unperturbed line itself. Skipping every case where
        # *either* offset is 0 would never test changing just the slope or
        # just the intercept.
        if mdiff10 == 0 and bdiff10 == 0:
            continue
        # Offsets are enumerated in integer tenths, then scaled.
        mdiff = mdiff10 / 10
        bdiff = bdiff10 / 10
        new_line = Line(m=best_fit.m + mdiff, b=best_fit.b + bdiff)
        new_mse = mean_squared_error(new_line, data_points)
        if new_mse < mse:
            found_better_mse = True
if not found_better_mse:
    print("No smaller MSE found")
from dataclasses import dataclass
# NOTE(review): parts a)-e) from here on repeat the earlier section of this
# file verbatim (after a second `from dataclasses import dataclass`), so
# every name is redefined. This looks like a copy/paste duplicate; consider
# keeping only one copy.
# a) Data model: a point in the plane and the sample data set to fit.
@dataclass
class Point:
    """A point in the plane with coordinates ``x`` and ``y``."""

    x: float
    y: float

    def __repr__(self) -> str:
        # Compact "(x, y)" form instead of the generated
        # "Point(x=..., y=...)"; @dataclass keeps a user-defined __repr__.
        return f"({self.x}, {self.y})"


# Sample observations fitted in parts b)-e).
data_points: List[Point] = [
    Point(1, 7),
    Point(3, 7),
    Point(4, 5),
    Point(5, 10),
    Point(6, 8),
    Point(9, 10),
    Point(8, 7),
    Point(10, 10),
]
# b) Ordinary least-squares fit y ~ m*x + b over data_points, using the
# closed-form solution of the normal equations:
#   m* = (sum(x*y) - sum(x)*sum(y)/n) / (sum(x^2) - sum(x)^2/n)
#   b* = (sum(y) - m* * sum(x)) / n
# NOTE: this raises ZeroDivisionError when data_points is empty (n == 0);
# if all x values are equal the slope denominator is mathematically zero
# (vertical line), so the slope is undefined.
n = len(data_points)
sumprod = 0.0  # sum of x*y
sumx = 0.0     # sum of x
sumy = 0.0     # sum of y
sumx2 = 0.0    # sum of x^2
for p in data_points:
    sumprod += p.x * p.y
    sumx += p.x
    sumy += p.y
    sumx2 += p.x * p.x
m_star = (sumprod - sumx * sumy / n) / (sumx2 - sumx * sumx / n)
b_star = (sumy - m_star * sumx) / n
# c) A straight line in slope-intercept form.
@dataclass
class Line:
    """Straight line ``y = m*x + b``."""

    m: float  # slope
    b: float  # y-intercept
# Least-squares line computed in part b).
best_fit = Line(m=m_star, b=b_star)
# d) Quality of a line: mean of the squared vertical errors.
def mean_squared_error(line: Line, points: List[Point]) -> float:
    """Return the mean squared error of ``line`` over ``points``.

    The error of a point is its vertical residual
    ``p.y - (line.m * p.x + line.b)``; the result is the average of the
    squared residuals.

    Raises:
        ValueError: if ``points`` is empty (the mean is undefined).
    """
    if not points:
        raise ValueError("mean_squared_error() requires at least one point")
    residuals = (p.y - (line.m * p.x + line.b) for p in points)
    return sum(err * err for err in residuals) / len(points)
# e) Sanity check of the least-squares solution. Nudge the slope and the
# intercept of best_fit by -0.2, -0.1, 0, +0.1 or +0.2 (every combination
# except leaving both unchanged). Check that none of these neighbouring
# lines has a smaller MSE.
# NOTE(review): this part e) is a verbatim repeat of the one earlier in the
# file, so the check runs (and the message prints) twice.
mse = mean_squared_error(best_fit, data_points)
found_better_mse = False
for mdiff10 in range(-2, 3):
    for bdiff10 in range(-2, 3):
        # Skip only the unperturbed line itself. Skipping every case where
        # *either* offset is 0 would never test changing just the slope or
        # just the intercept.
        if mdiff10 == 0 and bdiff10 == 0:
            continue
        # Offsets are enumerated in integer tenths, then scaled.
        mdiff = mdiff10 / 10
        bdiff = bdiff10 / 10
        new_line = Line(m=best_fit.m + mdiff, b=best_fit.b + bdiff)
        new_mse = mean_squared_error(new_line, data_points)
        if new_mse < mse:
            found_better_mse = True
if not found_better_mse:
    print("No smaller MSE found")
Modifié le: mercredi, 14 décembre 2022, 12:58